Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions bench/bench_frac_scatter_mlml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Benchmark frac, scatter_add, multilabel_margin_loss vs torch.

python bench/bench_frac_scatter_mlml.py
"""

import torch
import torch.nn.functional as F
import triton.testing

import ntops

DEVICE = "cuda"
DTYPE = torch.float32


def _report(name, shape_str, ms_nt, ms_th, nbytes):
bw_nt = nbytes / ms_nt * 1e-6
bw_th = nbytes / ms_th * 1e-6
print(
f" {name:22s} {shape_str:18s} "
f"九齿 {bw_nt:7.0f} GB/s | torch {bw_th:7.0f} GB/s | "
f"speedup {ms_th / ms_nt:.2f}x"
)


def bench_frac():
print("\n[frac]")
for shape in [(4096 * 4096,), (4096, 4096), (8192, 8192)]:
x = torch.randn(shape, dtype=DTYPE, device=DEVICE) * 5
nbytes = x.numel() * x.element_size() * 2 # 1 read + 1 write
ms_nt = triton.testing.do_bench(lambda: ntops.torch.frac(x))
ms_th = triton.testing.do_bench(lambda: torch.frac(x))
_report("frac", str(shape), ms_nt, ms_th, nbytes)


def bench_scatter_add():
print("\n[scatter_add]")
# (input_shape, dim, k)
cases = [
((1024, 256), 1, 128),
((1024, 256), 1, 256),
((4096, 512), 1, 256),
((8192, 256), 0, 4096),
]
for (shape, dim, k) in cases:
inp = torch.randn(shape, dtype=DTYPE, device=DEVICE)
idx_shape = list(shape); idx_shape[dim] = k
t = shape[dim]
idx = torch.randint(0, t, idx_shape, device=DEVICE)
src = torch.randn(idx_shape, dtype=DTYPE, device=DEVICE)
nbytes = (inp.numel() + src.numel()) * inp.element_size() * 2
label = f"shape={shape} dim={dim} k={k}"
try:
ms_nt = triton.testing.do_bench(
lambda: ntops.torch.scatter_add(inp, dim, idx, src)
)
except Exception as exc: # noqa: BLE001
print(f" {'scatter_add':22s} {label:18s} 九齿 SKIP ({type(exc).__name__})")
continue
ms_th = triton.testing.do_bench(
lambda: torch.scatter_add(inp, dim, idx, src)
)
_report("scatter_add", label, ms_nt, ms_th, nbytes)


def bench_mlml():
print("\n[multilabel_margin_loss]")
cases = [(64, 16), (256, 32), (512, 64), (1024, 32)]
for (n, c) in cases:
x = torch.randn(n, c, dtype=DTYPE, device=DEVICE)
target = torch.full((n, c), -1, dtype=torch.int64, device=DEVICE)
for i in range(n):
num = torch.randint(1, c // 2 + 1, (1,)).item()
target[i, :num] = torch.randperm(c, device=DEVICE)[:num]
nbytes = x.numel() * x.element_size() * 2
ms_nt = triton.testing.do_bench(
lambda: ntops.torch.multilabel_margin_loss(x, target, reduction="mean")
)
ms_th = triton.testing.do_bench(
lambda: F.multilabel_margin_loss(x, target, reduction="mean")
)
_report("mlml", f"N={n} C={c}", ms_nt, ms_th, nbytes)


def main():
print(f"device: {torch.cuda.get_device_name()} dtype: {DTYPE}")
bench_frac()
bench_scatter_add()
bench_mlml()


if __name__ == "__main__":
main()
59 changes: 59 additions & 0 deletions bench/bench_fractional_max_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Benchmark fractional_max_pool2d/3d (method B) vs torch.

python bench/bench_fractional_max_pool.py
"""

import torch
import torch.nn.functional as F
import triton.testing

import ntops

DEVICE = "cuda"
DTYPE = torch.float32


def _report(name, shape_str, ms_nt, ms_th):
print(f" {name:10s} {shape_str:28s} 九齿 {ms_nt:8.3f} ms | torch {ms_th:8.3f} ms | speedup {ms_th / ms_nt:.2f}x")


def main():
print(f"device: {torch.cuda.get_device_name()} dtype: {DTYPE}\n")

print("[fractional_max_pool2d]")
cases_2d = [
(8, 16, 16, 16, 2, 2, 12, 12),
(8, 32, 32, 32, 2, 2, 24, 24),
(16, 64, 56, 56, 3, 3, 40, 40),
(32, 64, 112, 112, 2, 2, 80, 80),
]
for n, c, h, w, kh, kw, oh, ow in cases_2d:
x = torch.randn(n, c, h, w, dtype=DTYPE, device=DEVICE)
s = torch.rand(n, c, 2, dtype=DTYPE, device=DEVICE)
ms_nt = triton.testing.do_bench(
lambda: ntops.torch.fractional_max_pool2d(x, (kh, kw), output_size=(oh, ow), _random_samples=s)
)
ms_th = triton.testing.do_bench(
lambda: F.fractional_max_pool2d(x, (kh, kw), output_size=(oh, ow), _random_samples=s)
)
_report("fmp2d", f"({n},{c},{h},{w})->({oh},{ow})", ms_nt, ms_th)

print("\n[fractional_max_pool3d]")
cases_3d = [
(4, 16, 16, 16, 16, 2, 2, 2, 12, 12, 12),
(8, 32, 16, 32, 32, 2, 2, 2, 12, 24, 24),
]
for n, c, d, h, w, kd, kh, kw, od, oh, ow in cases_3d:
x = torch.randn(n, c, d, h, w, dtype=DTYPE, device=DEVICE)
s = torch.rand(n, c, 3, dtype=DTYPE, device=DEVICE)
ms_nt = triton.testing.do_bench(
lambda: ntops.torch.fractional_max_pool3d(x, (kd, kh, kw), output_size=(od, oh, ow), _random_samples=s)
)
ms_th = triton.testing.do_bench(
lambda: F.fractional_max_pool3d(x, (kd, kh, kw), output_size=(od, oh, ow), _random_samples=s)
)
_report("fmp3d", f"({n},{c},{d},{h},{w})->({od},{oh},{ow})", ms_nt, ms_th)


if __name__ == "__main__":
main()
65 changes: 65 additions & 0 deletions bench/tune_frac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Tune (num_warps, block_size) for the frac kernel.

Evaluation disables auto-tuning (max_num_configs=1), so the winning config
must be passed explicitly into premake. This sweeps at a fixed large shape
under eval-like conditions and reports the best GB/s per dtype.

Run on each platform (NVIDIA / Iluvatar / MetaX) and bake the winners into a
device-adaptive _launch_config in ntops/torch/frac.py.

python bench/tune_frac.py
"""

import torch
import triton.testing

import ntops
from ntops.torch.utils import _cached_make, set_default_max_num_configs

DEVICE = "cuda"
SHAPE = (8192, 8192)
BLOCK_SIZES = (256, 512, 1024, 2048, 4096, 8192)
NUM_WARPS = (4, 8, 16)


def _tune(dtype):
input = torch.randn(SHAPE, dtype=dtype, device=DEVICE) * 5
output = torch.empty_like(input)

# frac reads 1 tensor and writes 1 → 2 bytes per element.
nbytes = 2 * input.numel() * input.element_size()
best_bw, best_cfg = 0.0, None

for block_size in BLOCK_SIZES:
for num_warps in NUM_WARPS:
try:
kernel = _cached_make(
ntops.kernels.frac.premake,
input.ndim,
block_size=block_size,
num_warps=num_warps,
num_stages=1,
)
ms = triton.testing.do_bench(lambda k=kernel: k(input, output))
bw = nbytes / ms * 1e-6
print(f" block={block_size:5d} warps={num_warps:2d} {bw:7.0f} GB/s")
if bw > best_bw:
best_bw, best_cfg = bw, (num_warps, block_size)
except Exception as exc: # noqa: BLE001
print(f" block={block_size:5d} warps={num_warps:2d} SKIP ({type(exc).__name__})")

return best_bw, best_cfg


def main():
set_default_max_num_configs(1)
print(f"device: {torch.cuda.get_device_name()}")

for dtype in (torch.float32, torch.float16):
print(f"\n[{dtype}]")
bw, cfg = _tune(dtype)
print(f" best: num_warps={cfg[0]}, block_size={cfg[1]} ({bw:.0f} GB/s)")


if __name__ == "__main__":
main()
77 changes: 77 additions & 0 deletions bench/tune_fractional_max_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Tune (num_warps, block_size) for the fractional_max_pool2d gather kernel.

The data-dependent gather is launch/occupancy-sensitive: block_size=1024 gives
too few programs for small/medium M. Sweep to find the config that maximizes
occupancy across sizes.

python bench/tune_fractional_max_pool.py
"""

import torch
import triton.testing

import ntops
from ntops.torch.fractional_max_pool2d import _intervals
from ntops.torch.utils import _cached_make, set_default_max_num_configs

DEVICE = "cuda"
DTYPE = torch.float32
BLOCK_SIZES = (64, 128, 256, 512, 1024, 2048)
NUM_WARPS = (1, 2, 4, 8)

CASES = [
(8, 16, 16, 16, 2, 2, 12, 12),
(16, 64, 56, 56, 3, 3, 40, 40),
(32, 64, 112, 112, 2, 2, 80, 80),
]


def _base_offset(input, start_h, start_w):
n, c, h, w = input.shape
nc_base = (torch.arange(n * c, device=input.device).reshape(n, c)) * (h * w)
bo = nc_base[..., None, None] + (start_h * w)[..., :, None] + start_w[..., None, :]
return bo.reshape(-1).to(torch.int64)


def main():
set_default_max_num_configs(1)
print(f"device: {torch.cuda.get_device_name()}")

for n, c, h, w, kh, kw, oh, ow in CASES:
x = torch.randn(n, c, h, w, dtype=DTYPE, device=DEVICE).reshape(-1)
xr = x.reshape(n, c, h, w)
s = torch.rand(n, c, 2, dtype=DTYPE, device=DEVICE)
sh = _intervals(s[..., 1], h, oh, kh)
sw = _intervals(s[..., 0], w, ow, kw)
bo = _base_offset(xr, sh, sw)
m = bo.numel()
out = torch.empty((m,), dtype=DTYPE, device=DEVICE)

ms_th = triton.testing.do_bench(
lambda: torch.nn.functional.fractional_max_pool2d(
xr, (kh, kw), output_size=(oh, ow), _random_samples=s
)
)
print(f"\n({n},{c},{h},{w})->({oh},{ow}) M={m} (torch {ms_th:.3f} ms)")
best = (1e9, None)
for block_size in BLOCK_SIZES:
for num_warps in NUM_WARPS:
try:
kernel = _cached_make(
ntops.kernels.fractional_max_pool2d.premake,
block_size=block_size,
num_warps=num_warps,
num_stages=1,
max_num_configs=1,
)
ms = triton.testing.do_bench(lambda k=kernel: k(bo, out, x, w, kh, kw, m))
if ms < best[0]:
best = (ms, (num_warps, block_size))
print(f" block={block_size:5d} warps={num_warps:2d} {ms:7.3f} ms ({ms_th/ms:.2f}x)")
except Exception as exc: # noqa: BLE001
print(f" block={block_size:5d} warps={num_warps:2d} SKIP ({type(exc).__name__})")
print(f" best: num_warps={best[1][0]}, block_size={best[1][1]} ({ms_th/best[0]:.2f}x torch)")


if __name__ == "__main__":
main()
92 changes: 92 additions & 0 deletions bench/tune_scatter_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Tune scatter_add's atomic-scatter config.

The atomic scatter is latency-bound (constant ~6 GB/s regardless of size), which
means too few concurrent atomics. The fix is many concurrent threads each doing
~1 atomic: small block_size + enough warps + many programs. This sweeps to find
the config that hides atomic latency.

python bench/tune_scatter_add.py
"""

import torch
import triton.testing

import ntops
from ntops.torch.utils import _cached_make, set_default_max_num_configs

DEVICE = "cuda"
DTYPE = torch.float32

# (input_shape, dim, k) — a few representative scatter shapes.
SHAPES = [
((1024, 256), 1, 256),
((4096, 512), 1, 256),
]
BLOCK_SIZES = (64, 128, 256, 512, 1024, 2048)
NUM_WARPS = (1, 2, 4, 8, 16)


def _next_pow2(x):
return 1 << (x - 1).bit_length()


def _run(shape, dim, k, block_size, num_warps):
inp = torch.randn(shape, dtype=DTYPE, device=DEVICE)
idx_shape = list(shape)
idx_shape[dim] = k
t = shape[dim]
idx = torch.randint(0, t, idx_shape, device=DEVICE).to(torch.int64)
src = torch.randn(idx_shape, dtype=DTYPE, device=DEVICE)

inp_p = inp.movedim(dim, -1).contiguous()
t_size = inp_p.shape[-1]
rows = inp_p.numel() // t_size
output = inp_p.reshape(rows, t_size).clone()
idx_p = idx.movedim(dim, -1).contiguous().reshape(rows, -1)
k_size = idx_p.shape[-1]
src_p = src.movedim(dim, -1).contiguous().reshape(rows, k_size)

kernel = _cached_make(
ntops.kernels.scatter_add.premake,
block_size=block_size,
num_warps=num_warps,
num_stages=1,
max_num_configs=1,
)
return triton.testing.do_bench(
lambda: kernel(output, idx_p, src_p, t_size, k_size, rows * k_size)
)


def main():
set_default_max_num_configs(1)
print(f"device: {torch.cuda.get_device_name()}")

for shape, dim, k in SHAPES:
inp = torch.randn(shape, dtype=DTYPE, device=DEVICE)
idx_shape = list(shape)
idx_shape[dim] = k
idx = torch.randint(0, shape[dim], idx_shape, device=DEVICE)
src = torch.randn(idx_shape, dtype=DTYPE, device=DEVICE)
ms_th = triton.testing.do_bench(lambda: torch.scatter_add(inp, dim, idx, src))

print(f"\nshape={shape} dim={dim} k={k} (torch {ms_th:.3f} ms)")
best = (1e9, None)
for block_size in BLOCK_SIZES:
for num_warps in NUM_WARPS:
try:
ms = _run(shape, dim, k, block_size, num_warps)
print(
f" block={block_size:5d} warps={num_warps:2d} "
f"{ms:8.3f} ms ({ms_th / ms:.2f}x torch)"
)
if ms < best[0]:
best = (ms, (num_warps, block_size))
except Exception as exc: # noqa: BLE001
print(f" block={block_size:5d} warps={num_warps:2d} SKIP ({type(exc).__name__})")
print(f" best: num_warps={best[1][0]}, block_size={best[1][1]} "
f"({ms_th / best[0]:.2f}x torch)")


if __name__ == "__main__":
main()
Loading