Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions bench/bench_t1_1_8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Benchmark T1-1-8 operators vs torch.

kl_div / count_nonzero / narrow / corrcoef / combinations

python bench/bench_t1_1_8.py
"""

import torch
import torch.nn.functional as F
import triton.testing

import ntops

# Device on which every benchmark input tensor in this script is created.
DEVICE = "cuda"
# Element type of every benchmark input tensor in this script.
DTYPE = torch.float32


def _report_bw(name, shape_str, ms_nt, ms_th, nbytes):
bw_nt = nbytes / ms_nt * 1e-6
bw_th = nbytes / ms_th * 1e-6
print(
f" {name:16s} {shape_str:24s} "
f"九齿 {bw_nt:7.0f} GB/s | torch {bw_th:7.0f} GB/s | "
f"speedup {ms_th / ms_nt:.2f}x"
)


def _report_ms(name, shape_str, ms_nt, ms_th):
print(
f" {name:16s} {shape_str:24s} "
f"九齿 {ms_nt:8.3f} ms | torch {ms_th:8.3f} ms | "
f"speedup {ms_th / ms_nt:.2f}x"
)


def bench_kl_div():
    """Time ``kl_div`` with ``reduction="mean"`` for ntops and torch.

    Inputs are a log-softmax tensor and a softmax target of the same shape.
    The reported bandwidth counts two reads of that shape per call.
    """
    print("\n[kl_div]")
    shapes = ((4096, 4096), (8192, 8192), (1024, 8192))
    for shape in shapes:
        log_probs = F.log_softmax(
            torch.randn(shape, dtype=DTYPE, device=DEVICE), dim=-1
        )
        target = F.softmax(torch.randn(shape, dtype=DTYPE, device=DEVICE), dim=-1)

        def run_ntops():
            return ntops.torch.kl_div(log_probs, target, reduction="mean")

        def run_torch():
            return F.kl_div(log_probs, target, reduction="mean")

        elapsed_nt = triton.testing.do_bench(run_ntops)
        elapsed_th = triton.testing.do_bench(run_torch)

        # Two input tensors are read; the scalar result is not counted.
        traffic = log_probs.numel() * log_probs.element_size() * 2
        _report_bw("kl_div", str(shape), elapsed_nt, elapsed_th, traffic)


def bench_count_nonzero():
    """Time ``count_nonzero`` for ntops and torch, with and without ``dim``.

    Inputs are random 0/1 values cast to ``DTYPE``. The reported bandwidth
    counts a single read of the input per call.
    """
    print("\n[count_nonzero]")
    # (shape, dim) pairs; dim=None reduces over the whole tensor.
    scenarios = (
        ((4096, 4096), None),
        ((8192, 8192), None),
        ((8192, 8192), 1),
    )
    for shape, dim in scenarios:
        data = torch.randint(0, 2, shape, device=DEVICE).to(DTYPE)

        def run_ntops():
            return ntops.torch.count_nonzero(data, dim)

        def run_torch():
            return torch.count_nonzero(data, dim)

        elapsed_nt = triton.testing.do_bench(run_ntops)
        elapsed_th = triton.testing.do_bench(run_torch)

        traffic = data.numel() * data.element_size()  # one read of the input
        _report_bw(
            "count_nonzero", f"{shape} dim={dim}", elapsed_nt, elapsed_th, traffic
        )


def bench_narrow():
    """Time ``narrow`` for ntops against a materializing torch baseline.

    The reported bandwidth counts one read and one write of the output slice.
    """
    print("\n[narrow]")
    # (shape, dim, start, length)
    cases = [
        ((8192, 8192), 0, 1024, 4096),
        ((8192, 8192), 1, 1024, 4096),
        ((4096, 4096), 0, 0, 2048),
    ]
    for shape, dim, start, length in cases:
        x = torch.randn(shape, dtype=DTYPE, device=DEVICE)
        # The slice keeps every extent of `x` except `dim`, which shrinks to
        # `length`. Deriving the count from the tensor itself keeps the byte
        # accounting right for any rank and for negative `dim`, not only for
        # 2-D cases with dim in {0, 1}.
        out_numel = x.numel() // shape[dim] * length
        nbytes = out_numel * x.element_size() * 2  # read slice + write
        ms_nt = triton.testing.do_bench(
            lambda: ntops.torch.narrow(x, dim, start, length)
        )
        # torch.narrow returns a zero-copy view (O(1) metadata, no memory
        # traffic), so comparing our materializing copy against it is apples
        # to oranges -- the "148437 GB/s" it reports is bytes / ~0 time, not
        # real bandwidth. Add .contiguous() so torch also materializes the
        # slice: the fair, same-work comparison (matches benchmark_narrow in
        # test_narrow.py).
        ms_th = triton.testing.do_bench(
            lambda: torch.narrow(x, dim, start, length).contiguous()
        )
        _report_bw("narrow", f"{shape} d={dim} l={length}", ms_nt, ms_th, nbytes)


def bench_corrcoef():
    """Time ``corrcoef`` for ntops and torch on ``(m, n)`` random inputs.

    Results are reported as latency only.
    """
    print("\n[corrcoef]")
    for rows, cols in ((64, 4096), (128, 8192), (256, 16384)):
        samples = torch.randn(rows, cols, dtype=DTYPE, device=DEVICE)

        def run_ntops():
            return ntops.torch.corrcoef(samples)

        def run_torch():
            return torch.corrcoef(samples)

        elapsed_nt = triton.testing.do_bench(run_ntops)
        elapsed_th = triton.testing.do_bench(run_torch)
        _report_ms("corrcoef", f"({rows}, {cols})", elapsed_nt, elapsed_th)


def bench_combinations():
    """Time ``combinations`` for ntops and torch on 1-D random inputs.

    Results are reported as latency only.
    """
    print("\n[combinations]")
    # (input length, combination size r)
    scenarios = ((64, 2), (128, 2), (256, 2))
    for length, size in scenarios:
        values = torch.randn(length, dtype=DTYPE, device=DEVICE)

        def run_ntops():
            return ntops.torch.combinations(values, r=size)

        def run_torch():
            return torch.combinations(values, r=size)

        elapsed_nt = triton.testing.do_bench(run_ntops)
        elapsed_th = triton.testing.do_bench(run_torch)
        _report_ms("combinations", f"n={length} r={size}", elapsed_nt, elapsed_th)


def main():
    """Run every operator benchmark in this script, in a fixed order.

    Raises:
        RuntimeError: If CUDA is not available. Every benchmark allocates its
            inputs on ``DEVICE`` ("cuda"), so fail up front with a clear
            message instead of erroring inside the first CUDA call.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")

    print(f"device: {torch.cuda.get_device_name()} dtype: {DTYPE}")
    bench_kl_div()
    bench_count_nonzero()
    bench_narrow()
    bench_corrcoef()
    bench_combinations()


if __name__ == "__main__":
    main()
183 changes: 183 additions & 0 deletions bench/tune_count_nonzero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
"""Tune the pinned launch configs for ``ntops.torch.count_nonzero`` on the
current GPU.

Two kernels are tuned independently:
* the global path (``dim=None``) -- flattens, one partial per block;
* the dim path (``dim`` given) -- reshapes to ``(M, N)``, one partial per
``(row, block)``.

Both are memory-bound partial-sum reductions reading the input once.
Performance evaluation runs with auto-tuning disabled (``max_num_configs=1``),
so the values baked into ``ntops/torch/count_nonzero.py`` decide the score; the
block size also sizes the partials buffer host-side. This sweeps
``block_size x num_warps x num_stages`` under those conditions and prints, per
shape, the fastest config plus the speedup over ``torch.count_nonzero``.

Usage
-----
python bench/tune_count_nonzero.py
"""

import itertools
import math

import torch

import ntops
from ntops.torch.utils import _cached_make

# Flat input sizes (in elements) swept by tune(); each is a perfect square.
_NUMELS = [1024 * 1024, 4096 * 4096, 8192 * 8192]

# Launch-config search space: _sweep() tries the full Cartesian product of
# these three lists.
_BLOCK_SIZES = [512, 1024, 2048, 4096, 8192]
_NUM_WARPS = [4, 8, 16]
_NUM_STAGES = [1, 2]

# Input dtypes that tune() checks for correctness and then sweeps, in order.
_DTYPES = [torch.float32, torch.float16]


def _time(fn, n_warmup=10, n_repeat=50):
    """Return the average duration of one ``fn()`` call, using CUDA events.

    Args:
        fn: Zero-argument callable to time.
        n_warmup: Untimed calls made before measuring.
        n_repeat: Timed calls; the total elapsed time is divided by this.

    Returns:
        Elapsed time between the two events divided by ``n_repeat``
        (the callers in this file treat the value as milliseconds).
    """

    def _call_repeatedly(times):
        for _ in range(times):
            fn()

    _call_repeatedly(n_warmup)
    # Drain the warm-up work so it is not attributed to the timed region.
    torch.cuda.synchronize()

    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)

    tic.record()
    _call_repeatedly(n_repeat)
    toc.record()
    # Wait for the end event to complete before reading the elapsed time.
    torch.cuda.synchronize()

    total = tic.elapsed_time(toc)
    return total / n_repeat


def _global_runner(flat, block_size, num_warps, num_stages):
    """Build a callable that runs the global count kernel with a pinned config.

    Args:
        flat: Input tensor; its ``numel()`` sizes the partials buffer.
        block_size: Elements per block; also sets the number of partials.
        num_warps: Launch parameter forwarded to the kernel factory.
        num_stages: Launch parameter forwarded to the kernel factory.

    Returns:
        A zero-argument function that launches the kernel into a preallocated
        int64 partials buffer and returns the sum of that buffer.
    """
    # One partial per block, and never an empty buffer.
    block_count = max(1, math.ceil(flat.numel() / block_size))
    scratch = torch.empty(block_count, dtype=torch.int64, device=flat.device)

    pinned = {
        "block_size": block_size,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "max_num_configs": 1,  # single config: no auto-tuning
    }
    launch = _cached_make(ntops.kernels.count_nonzero.global_premake, **pinned)

    def run():
        launch(flat, scratch)
        return scratch.sum()

    return run


def _dim_runner(x2d, block_size, num_warps, num_stages):
    """Build a callable that runs the per-row count kernel with a pinned config.

    Args:
        x2d: 2-D input of shape ``(m, n)``.
        block_size: Elements per block along the second dimension.
        num_warps: Launch parameter forwarded to the kernel factory.
        num_stages: Launch parameter forwarded to the kernel factory.

    Returns:
        A zero-argument function that launches the kernel into a preallocated
        ``(m, num_blocks)`` int64 buffer and returns its sum over ``dim=1``.
    """
    rows, cols = x2d.shape
    # One partial per (row, block), with at least one block per row.
    blocks_per_row = max(1, math.ceil(cols / block_size))
    scratch = torch.empty(
        (rows, blocks_per_row), dtype=torch.int64, device=x2d.device
    )

    pinned = {
        "block_size": block_size,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "max_num_configs": 1,  # single config: no auto-tuning
    }
    launch = _cached_make(ntops.kernels.count_nonzero.dim_premake, **pinned)

    def run():
        launch(x2d, scratch)
        return scratch.sum(dim=1)

    return run


def _sweep(label, make_runner, num_bytes, torch_ms):
    """Time every launch config for one case and print the fastest ones.

    Args:
        label: Heading printed for this case.
        make_runner: ``(block_size, num_warps, num_stages) -> callable``; the
            returned callable is what gets timed.
        num_bytes: Bytes attributed to one call, used to derive GB/s.
        torch_ms: Baseline torch time in milliseconds for the same case.

    Configs that raise while being built or timed are reported and skipped.
    If none succeeds, only the torch baseline and a notice are printed.
    """
    results = []
    for bs, nw, ns in itertools.product(_BLOCK_SIZES, _NUM_WARPS, _NUM_STAGES):
        try:
            ms = _time(make_runner(bs, nw, ns))
        except Exception as exc:  # noqa: BLE001
            # Best-effort sweep: an unsupported config must not abort the run.
            print(f" skip bs={bs} nw={nw} ns={ns}: {type(exc).__name__}")
            continue
        results.append((ms, bs, nw, ns))

    torch_gbps = num_bytes / (torch_ms * 1e-3) / 1e9
    print(f"\n [{label}] (torch {torch_ms:.4f} ms / {torch_gbps:.0f} GB/s)")

    if not results:
        # Every config was skipped; indexing results[0] would raise IndexError.
        print(" no config ran successfully")
        return

    # Tuples sort by time first, so the fastest config ends up at index 0.
    results.sort()
    best_ms, bbs, bnw, bns = results[0]
    best_gbps = num_bytes / (best_ms * 1e-3) / 1e9

    print(
        f" BEST block_size={bbs:<5} num_warps={bnw:<3} num_stages={bns} "
        f"-> {best_ms:.4f} ms / {best_gbps:.0f} GB/s "
        f"(speedup vs torch {torch_ms / best_ms:.2f})"
    )
    for ms, bs, nw, ns in results[:5]:
        gbps = num_bytes / (ms * 1e-3) / 1e9
        print(
            f" block_size={bs:<5} num_warps={nw:<3} num_stages={ns} "
            f"{ms:.4f} ms / {gbps:.0f} GB/s"
        )


def _make_input(numel, dtype):
    """Create a 1-D CUDA tensor of normal samples with small values zeroed.

    Entries whose absolute value is below 0.5 are replaced by zero, so the
    result contains a mix of zero and non-zero elements.
    """
    samples = torch.randn(numel, dtype=dtype, device="cuda")
    near_zero = samples.abs() < 0.5
    zeros = torch.zeros_like(samples)
    return torch.where(near_zero, zeros, samples)


def _check_correctness(dtype):
    """Assert that both kernels and the wrapper agree with torch for ``dtype``.

    Uses a 40000-element input, also viewed as ``(200, 200)``.
    """
    flat = _make_input(40000, dtype)

    # Global path: a single total for the whole tensor.
    total = _global_runner(flat, 1024, 4, 1)()
    reference = torch.count_nonzero(flat)
    assert total.item() == reference.item(), (total, reference)

    # Dim path: per-row counts over the last dimension.
    grid = flat.reshape(200, 200)
    per_row = _dim_runner(grid, 1024, 4, 1)()
    assert torch.equal(per_row, torch.count_nonzero(grid, dim=1))

    # Leading (dim=0) coalesced path, exercised end-to-end through the wrapper.
    per_col = ntops.torch.count_nonzero(grid, dim=0)
    assert torch.equal(per_col, torch.count_nonzero(grid, dim=0))


def tune():
    """Sweep launch configs for both count_nonzero kernels and print results.

    For every dtype in ``_DTYPES`` the kernels are first checked against
    torch. Then, for every size in ``_NUMELS``, the global path and the
    ``dim=1`` path are each swept with ``_sweep``.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")

    for dtype in _DTYPES:
        _check_correctness(dtype)
        itemsize = torch.empty(0, dtype=dtype).element_size()

        print(f"\n{'='*92}")
        print(
            f"count_nonzero config sweep | dtype={dtype} | "
            f"device={torch.cuda.get_device_name()}"
        )
        print("=" * 92)

        for numel in _NUMELS:
            # Floor integer square root: exact for the perfect squares in
            # _NUMELS, and it guarantees side * side <= numel for any other
            # size, so the slice + reshape below can never run past the input
            # (a rounded float sqrt could round up).
            side = math.isqrt(numel)
            x = _make_input(numel, dtype)

            print(f"\nnumel={numel} (~{side}^2, {numel * itemsize / 1e6:.1f} MB)")

            # Global path: reads the whole input once.
            # The lambdas are consumed within this iteration, so capturing the
            # loop variables by reference is safe.
            torch_ms = _time(lambda: torch.count_nonzero(x))
            _sweep(
                "global (dim=None)",
                lambda bs, nw, ns: _global_runner(x, bs, nw, ns),
                numel * itemsize,
                torch_ms,
            )

            # Dim path: reduce the last dim of a squarish (side, side) view.
            x2d = x[: side * side].reshape(side, side)
            torch_ms = _time(lambda: torch.count_nonzero(x2d, dim=1))
            _sweep(
                "dim=1",
                lambda bs, nw, ns: _dim_runner(x2d, bs, nw, ns),
                side * side * itemsize,
                torch_ms,
            )


if __name__ == "__main__":
    tune()
Loading