diff --git a/bench/bench_t1_1_8.py b/bench/bench_t1_1_8.py
new file mode 100644
index 0000000..ef5a944
--- /dev/null
+++ b/bench/bench_t1_1_8.py
@@ -0,0 +1,133 @@
+"""Benchmark T1-1-8 operators vs torch.
+
+    kl_div / count_nonzero / narrow / corrcoef / combinations
+
+    python bench/bench_t1_1_8.py
+"""
+
+import torch
+import torch.nn.functional as F
+import triton.testing
+
+import ntops
+
+DEVICE = "cuda"
+DTYPE = torch.float32
+
+
+def _report_bw(name, shape_str, ms_nt, ms_th, nbytes):
+    """Print the bandwidth of ntops and torch plus the torch/ntops speedup."""
+    # bytes / ms * 1e-6 == bytes / s * 1e-9 == GB/s.
+    bw_nt = nbytes / ms_nt * 1e-6
+    bw_th = nbytes / ms_th * 1e-6
+    print(
+        f" {name:16s} {shape_str:24s} "
+        f"九齿 {bw_nt:7.0f} GB/s | torch {bw_th:7.0f} GB/s | "
+        f"speedup {ms_th / ms_nt:.2f}x"
+    )
+
+
+def _report_ms(name, shape_str, ms_nt, ms_th):
+    """Print the latency of ntops and torch plus the torch/ntops speedup."""
+    print(
+        f" {name:16s} {shape_str:24s} "
+        f"九齿 {ms_nt:8.3f} ms | torch {ms_th:8.3f} ms | "
+        f"speedup {ms_th / ms_nt:.2f}x"
+    )
+
+
+def bench_kl_div():
+    """Compare kl_div(reduction="mean") bandwidth on three 2-D shapes."""
+    print("\n[kl_div]")
+    for shape in [(4096, 4096), (8192, 8192), (1024, 8192)]:
+        x = F.log_softmax(torch.randn(shape, dtype=DTYPE, device=DEVICE), dim=-1)
+        t = F.softmax(torch.randn(shape, dtype=DTYPE, device=DEVICE), dim=-1)
+        nbytes = x.numel() * x.element_size() * 2  # 2 reads
+        ms_nt = triton.testing.do_bench(
+            lambda: ntops.torch.kl_div(x, t, reduction="mean")
+        )
+        ms_th = triton.testing.do_bench(
+            lambda: F.kl_div(x, t, reduction="mean")
+        )
+        _report_bw("kl_div", str(shape), ms_nt, ms_th, nbytes)
+
+
+def bench_count_nonzero():
+    """Compare count_nonzero bandwidth (dim=None and dim=1) on 0/1 inputs."""
+    print("\n[count_nonzero]")
+    cases = [((4096, 4096), None), ((8192, 8192), None), ((8192, 8192), 1)]
+    for shape, dim in cases:
+        x = torch.randint(0, 2, shape, device=DEVICE).to(DTYPE)
+        nbytes = x.numel() * x.element_size()  # 1 read
+        ms_nt = triton.testing.do_bench(lambda: ntops.torch.count_nonzero(x, dim))
+        ms_th = triton.testing.do_bench(lambda: torch.count_nonzero(x, dim))
+        _report_bw("count_nonzero", f"{shape} dim={dim}", ms_nt, ms_th, nbytes)
+
+
+def bench_narrow():
+    """Compare narrow against torch.narrow(...).contiguous() on 2-D slices."""
+    print("\n[narrow]")
+    # (shape, dim, start, length)
+    cases = [
+        ((8192, 8192), 0, 1024, 4096),
+        ((8192, 8192), 1, 1024, 4096),
+        ((4096, 4096), 0, 0, 2048),
+    ]
+    for shape, dim, start, length in cases:
+        x = torch.randn(shape, dtype=DTYPE, device=DEVICE)
+        out_numel = (length if dim == 0 else shape[0]) * (
+            length if dim == 1 else shape[1]
+        )
+        nbytes = out_numel * x.element_size() * 2  # read slice + write
+        ms_nt = triton.testing.do_bench(
+            lambda: ntops.torch.narrow(x, dim, start, length)
+        )
+        # torch.narrow returns a zero-copy view (O(1) metadata, no memory
+        # traffic), so comparing our materializing copy against it is apples to
+        # oranges -- the "148437 GB/s" it reports is bytes / ~0 time, not real
+        # bandwidth. Add .contiguous() so torch also materializes the slice: the
+        # fair, same-work comparison (matches benchmark_narrow in test_narrow.py).
+        ms_th = triton.testing.do_bench(
+            lambda: torch.narrow(x, dim, start, length).contiguous()
+        )
+        _report_bw("narrow", f"{shape} d={dim} l={length}", ms_nt, ms_th, nbytes)
+
+
+def bench_corrcoef():
+    """Compare corrcoef latency on three (m, n) matrices."""
+    print("\n[corrcoef]")
+    cases = [(64, 4096), (128, 8192), (256, 16384)]
+    for m, n in cases:
+        x = torch.randn(m, n, dtype=DTYPE, device=DEVICE)
+        ms_nt = triton.testing.do_bench(lambda: ntops.torch.corrcoef(x))
+        ms_th = triton.testing.do_bench(lambda: torch.corrcoef(x))
+        _report_ms("corrcoef", f"({m}, {n})", ms_nt, ms_th)
+
+
+def bench_combinations():
+    """Compare combinations latency for r=2 on 1-D inputs of growing length."""
+    print("\n[combinations]")
+    for n, r in [(64, 2), (128, 2), (256, 2)]:
+        x = torch.randn(n, dtype=DTYPE, device=DEVICE)
+        ms_nt = triton.testing.do_bench(
+            lambda: ntops.torch.combinations(x, r=r)
+        )
+        ms_th = triton.testing.do_bench(lambda: torch.combinations(x, r=r))
+        _report_ms("combinations", f"n={n} r={r}", ms_nt, ms_th)
+
+
+def main():
+    """Run every benchmark in this file; requires a CUDA device."""
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA not available")
+
+    print(f"device: {torch.cuda.get_device_name()} dtype: {DTYPE}")
+    bench_kl_div()
+    bench_count_nonzero()
+    bench_narrow()
+    bench_corrcoef()
+    bench_combinations()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/bench/tune_count_nonzero.py b/bench/tune_count_nonzero.py
new file mode 100644
index 0000000..2bea2eb
--- /dev/null
+++ b/bench/tune_count_nonzero.py
@@ -0,0 +1,195 @@
+"""Tune the pinned launch configs for ``ntops.torch.count_nonzero`` on the
+current GPU.
+
+Two kernels are tuned independently:
+ * the global path (``dim=None``) -- flattens, one partial per block;
+ * the dim path (``dim`` given) -- reshapes to ``(M, N)``, one partial per
+ ``(row, block)``.
+
+Both are memory-bound partial-sum reductions reading the input once.
+Performance evaluation runs with auto-tuning disabled (``max_num_configs=1``),
+so the values baked into ``ntops/torch/count_nonzero.py`` decide the score; the
+block size also sizes the partials buffer host-side. This sweeps
+``block_size x num_warps x num_stages`` under those conditions and prints, per
+shape, the fastest config plus the speedup over ``torch.count_nonzero``.
+
+Usage
+-----
+ python bench/tune_count_nonzero.py
+"""
+
+import itertools
+import math
+
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+_NUMELS = [1024 * 1024, 4096 * 4096, 8192 * 8192]
+
+_BLOCK_SIZES = [512, 1024, 2048, 4096, 8192]
+_NUM_WARPS = [4, 8, 16]
+_NUM_STAGES = [1, 2]
+
+_DTYPES = [torch.float32, torch.float16]
+
+
+def _time(fn, n_warmup=10, n_repeat=50):
+    """Return the mean time in ms of ``fn`` over ``n_repeat`` calls (CUDA events)."""
+    for _ in range(n_warmup):
+        fn()
+    torch.cuda.synchronize()
+
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    for _ in range(n_repeat):
+        fn()
+    end.record()
+    torch.cuda.synchronize()
+    return start.elapsed_time(end) / n_repeat
+
+
+def _global_runner(flat, block_size, num_warps, num_stages):
+    """Build a closure that runs the global kernel and sums its int64 partials."""
+    numel = flat.numel()
+    num_partials = max(1, math.ceil(numel / block_size))
+    partials = torch.empty(num_partials, dtype=torch.int64, device=flat.device)
+
+    kernel = _cached_make(
+        ntops.kernels.count_nonzero.global_premake,
+        block_size=block_size,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        max_num_configs=1,
+    )
+
+    def run():
+        kernel(flat, partials)
+        return partials.sum()
+
+    return run
+
+
+def _dim_runner(x2d, block_size, num_warps, num_stages):
+    """Build a closure that runs the dim kernel and sums the partials per row."""
+    m, n = x2d.shape
+    num_blocks = max(1, math.ceil(n / block_size))
+    partials = torch.empty((m, num_blocks), dtype=torch.int64, device=x2d.device)
+
+    kernel = _cached_make(
+        ntops.kernels.count_nonzero.dim_premake,
+        block_size=block_size,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        max_num_configs=1,
+    )
+
+    def run():
+        kernel(x2d, partials)
+        return partials.sum(dim=1)
+
+    return run
+
+
+def _sweep(label, make_runner, num_bytes, torch_ms):
+    """Time every launch config and print the fastest one plus the top five."""
+    results = []
+    for bs, nw, ns in itertools.product(_BLOCK_SIZES, _NUM_WARPS, _NUM_STAGES):
+        try:
+            ms = _time(make_runner(bs, nw, ns))
+        except Exception as exc:  # noqa: BLE001
+            print(f" skip bs={bs} nw={nw} ns={ns}: {type(exc).__name__}")
+            continue
+        results.append((ms, bs, nw, ns))
+
+    if not results:
+        # Every config was skipped above; there is nothing to rank, and
+        # ``results[0]`` would otherwise raise an unhelpful IndexError.
+        print(f"\n [{label}] no launch config ran successfully")
+        return
+
+    results.sort()
+    best_ms, bbs, bnw, bns = results[0]
+    best_gbps = num_bytes / (best_ms * 1e-3) / 1e9
+    torch_gbps = num_bytes / (torch_ms * 1e-3) / 1e9
+
+    print(f"\n [{label}] (torch {torch_ms:.4f} ms / {torch_gbps:.0f} GB/s)")
+    print(
+        f" BEST block_size={bbs:<5} num_warps={bnw:<3} num_stages={bns} "
+        f"-> {best_ms:.4f} ms / {best_gbps:.0f} GB/s "
+        f"(speedup vs torch {torch_ms / best_ms:.2f})"
+    )
+    for ms, bs, nw, ns in results[:5]:
+        gbps = num_bytes / (ms * 1e-3) / 1e9
+        print(
+            f" block_size={bs:<5} num_warps={nw:<3} num_stages={ns} "
+            f"{ms:.4f} ms / {gbps:.0f} GB/s"
+        )
+
+
+def _make_input(numel, dtype):
+    """Return a random CUDA vector with entries of magnitude below 0.5 zeroed."""
+    x = torch.randn(numel, dtype=dtype, device="cuda")
+    return torch.where(x.abs() < 0.5, torch.zeros_like(x), x)
+
+
+def _check_correctness(dtype):
+    """Assert the kernels agree with ``torch.count_nonzero`` before timing."""
+    x = _make_input(40000, dtype)
+    got = _global_runner(x, 1024, 4, 1)()
+    assert got.item() == torch.count_nonzero(x).item(), (got, torch.count_nonzero(x))
+
+    x2 = x.reshape(200, 200)
+    got2 = _dim_runner(x2, 1024, 4, 1)()
+    assert torch.equal(got2, torch.count_nonzero(x2, dim=1))
+
+    # Leading (dim=0) coalesced path, exercised end-to-end through the wrapper.
+    got3 = ntops.torch.count_nonzero(x2, dim=0)
+    assert torch.equal(got3, torch.count_nonzero(x2, dim=0))
+
+
+def tune():
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA not available")
+
+    for dtype in _DTYPES:
+        _check_correctness(dtype)
+        itemsize = torch.empty(0, dtype=dtype).element_size()
+
+        print(f"\n{'='*92}")
+        print(
+            f"count_nonzero config sweep | dtype={dtype} | "
+            f"device={torch.cuda.get_device_name()}"
+        )
+        print("=" * 92)
+
+        for numel in _NUMELS:
+            side = int(round(numel**0.5))
+            x = _make_input(numel, dtype)
+
+            print(f"\nnumel={numel} (~{side}^2, {numel * itemsize / 1e6:.1f} MB)")
+
+            # Global path: reads the whole input once.
+            torch_ms = _time(lambda: torch.count_nonzero(x))
+            _sweep(
+                "global (dim=None)",
+                lambda bs, nw, ns: _global_runner(x, bs, nw, ns),
+                numel * itemsize,
+                torch_ms,
+            )
+
+            # Dim path: reduce the last dim of a squarish (side, side) view.
+            x2d = x[: side * side].reshape(side, side)
+            torch_ms = _time(lambda: torch.count_nonzero(x2d, dim=1))
+            _sweep(
+                "dim=1",
+                lambda bs, nw, ns: _dim_runner(x2d, bs, nw, ns),
+                side * side * itemsize,
+                torch_ms,
+            )
+
+
+if __name__ == "__main__":
+    tune()
diff --git a/bench/tune_kl_div.py b/bench/tune_kl_div.py
new file mode 100644
index 0000000..4d18807
--- /dev/null
+++ b/bench/tune_kl_div.py
@@ -0,0 +1,183 @@
+"""Tune the pinned launch configs for ``ntops.torch.kl_div`` on the current GPU.
+
+Two kernels are tuned independently:
+ * the reduction path (``reduction="sum"|"mean"|"batchmean"``) -- the defining,
+ perf critical kernel; partials buffer size depends on ``block_size``;
+ * the element-wise path (``reduction="none"``).
+
+Performance evaluation runs with auto-tuning disabled (``max_num_configs=1``),
+so the values baked into ``ntops/torch/kl_div.py`` decide the score. This script
+sweeps a small grid under those exact conditions and prints, per shape, the
+fastest config plus the speedup over ``torch.nn.functional.kl_div``. The default
+``log_target=False`` path (extra ``log`` + ``where``) is the heavier one and is
+what gets tuned here.
+
+Usage
+-----
+ python bench/tune_kl_div.py
+"""
+
+import itertools
+import math
+
+import torch
+import torch.nn.functional as F
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+# Numbers of elements to tune over (bandwidth-bound regime). Small sizes are
+# launch-overhead bound and not informative for config selection.
+_NUMELS = [1024 * 1024, 4096 * 4096, 8192 * 8192]
+
+_BLOCK_SIZES = [512, 1024, 2048, 4096, 8192]
+_NUM_WARPS = [4, 8, 16]
+_NUM_STAGES = [1, 2]
+
+_DTYPES = [torch.float32, torch.float16]
+
+
+def _time(fn, n_warmup=10, n_repeat=50):
+    """Return the mean time in ms of ``fn`` over ``n_repeat`` calls (CUDA events)."""
+    for _ in range(n_warmup):
+        fn()
+    torch.cuda.synchronize()
+
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    for _ in range(n_repeat):
+        fn()
+    end.record()
+    torch.cuda.synchronize()
+    return start.elapsed_time(end) / n_repeat
+
+
+def _reduce_runner(flat_in, flat_tg, block_size, num_warps, num_stages):
+    """Build a closure that runs the reduce kernel and sums its float32 partials."""
+    numel = flat_in.numel()
+    num_partials = max(1, math.ceil(numel / block_size))
+    partials = torch.empty(num_partials, dtype=torch.float32, device=flat_in.device)
+
+    kernel = _cached_make(
+        ntops.kernels.kl_div.reduce_premake,
+        log_target=False,
+        block_size=block_size,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        max_num_configs=1,
+    )
+
+    def run():
+        kernel(flat_in, flat_tg, partials)
+        return partials.sum()
+
+    return run
+
+
+def _none_runner(input, target, output, block_size, num_warps, num_stages):
+    """Build a closure that runs the element-wise kernel into ``output``."""
+    kernel = _cached_make(
+        ntops.kernels.kl_div.premake,
+        input.ndim,
+        log_target=False,
+        block_size=block_size,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        max_num_configs=1,
+    )
+    return lambda: kernel(input, target, output)
+
+
+def _sweep(label, make_runner, num_bytes, torch_ms):
+    """Time every launch config and print the fastest one plus the top five."""
+    results = []
+    for bs, nw, ns in itertools.product(_BLOCK_SIZES, _NUM_WARPS, _NUM_STAGES):
+        try:
+            ms = _time(make_runner(bs, nw, ns))
+        except Exception as exc:  # noqa: BLE001
+            print(f" skip bs={bs} nw={nw} ns={ns}: {type(exc).__name__}")
+            continue
+        results.append((ms, bs, nw, ns))
+
+    if not results:
+        # Every config was skipped above; there is nothing to rank, and
+        # ``results[0]`` would otherwise raise an unhelpful IndexError.
+        print(f"\n [{label}] no launch config ran successfully")
+        return
+
+    results.sort()
+    best_ms, bbs, bnw, bns = results[0]
+    best_gbps = num_bytes / (best_ms * 1e-3) / 1e9
+    torch_gbps = num_bytes / (torch_ms * 1e-3) / 1e9
+
+    print(f"\n [{label}] (torch {torch_ms:.4f} ms / {torch_gbps:.0f} GB/s)")
+    print(
+        f" BEST block_size={bbs:<5} num_warps={bnw:<3} num_stages={bns} "
+        f"-> {best_ms:.4f} ms / {best_gbps:.0f} GB/s "
+        f"(speedup vs torch {torch_ms / best_ms:.2f})"
+    )
+    for ms, bs, nw, ns in results[:5]:
+        gbps = num_bytes / (ms * 1e-3) / 1e9
+        print(
+            f" block_size={bs:<5} num_warps={nw:<3} num_stages={ns} "
+            f"{ms:.4f} ms / {gbps:.0f} GB/s"
+        )
+
+
+def _check_reduce_correctness(dtype):
+    """Sanity check that the reduction kernel matches F.kl_div before trusting
+    any timing numbers."""
+    x = torch.randn(40000, dtype=dtype, device="cuda").log_softmax(dim=-1)
+    y = torch.randn(40000, dtype=dtype, device="cuda").softmax(dim=-1)
+    run = _reduce_runner(x, y, 1024, 4, 1)
+    got = (run() / x.numel()).to(dtype)
+    ref = F.kl_div(x, y, reduction="mean")
+    tol = 1e-3 if dtype == torch.float32 else 1e-2
+    assert torch.allclose(got, ref, rtol=tol, atol=tol), (got.item(), ref.item())
+
+
+def tune():
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA not available")
+
+    for dtype in _DTYPES:
+        _check_reduce_correctness(dtype)
+        itemsize = torch.empty(0, dtype=dtype).element_size()
+
+        print(f"\n{'='*92}")
+        print(
+            f"kl_div config sweep | dtype={dtype} | "
+            f"device={torch.cuda.get_device_name()}"
+        )
+        print("=" * 92)
+
+        for numel in _NUMELS:
+            side = int(round(numel**0.5))
+            input = torch.randn(numel, dtype=dtype, device="cuda").log_softmax(dim=-1)
+            target = torch.randn(numel, dtype=dtype, device="cuda").softmax(dim=-1)
+            output = torch.empty_like(input)
+
+            print(f"\nnumel={numel} (~{side}^2, {numel * itemsize / 1e6:.1f} MB)")
+
+            # 
Reduction path: reads input + target (2x).
+            torch_ms = _time(lambda: F.kl_div(input, target, reduction="sum"))
+            _sweep(
+                "reduce (sum/mean/batchmean)",
+                lambda bs, nw, ns: _reduce_runner(input, target, bs, nw, ns),
+                numel * itemsize * 2,
+                torch_ms,
+            )
+
+            # Element-wise path: reads input + target, writes output (3x).
+            torch_ms = _time(lambda: F.kl_div(input, target, reduction="none"))
+            _sweep(
+                "none (element-wise)",
+                lambda bs, nw, ns: _none_runner(input, target, output, bs, nw, ns),
+                numel * itemsize * 3,
+                torch_ms,
+            )
+
+
+if __name__ == "__main__":
+    tune()
diff --git a/bench/tune_narrow.py b/bench/tune_narrow.py
new file mode 100644
index 0000000..80d7fc6
--- /dev/null
+++ b/bench/tune_narrow.py
@@ -0,0 +1,159 @@
+"""Tune the pinned launch config for ``ntops.torch.narrow`` on the current GPU.
+
+narrow materializes a strided slice with a copy kernel. Performance evaluation
+runs with auto-tuning disabled (``max_num_configs=1``), so the value baked into
+``ntops/torch/narrow.py`` decides the score. This sweeps
+``block_size x num_warps x num_stages`` on two representative slices -- a strided
+inner-dim slice (the harder memory pattern) and a contiguous leading-dim slice
+-- and prints, per shape, the fastest config plus the speedup over
+``torch.narrow(...).contiguous()``. ``_launch_config`` must key on hardware only,
+so pick a config that is good across both patterns.
+
+Usage
+-----
+ python bench/tune_narrow.py
+"""
+
+import itertools
+
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+_NUMELS = [1024 * 1024, 4096 * 4096, 8192 * 8192]
+
+_BLOCK_SIZES = [256, 512, 1024, 2048, 4096, 8192]
+_NUM_WARPS = [4, 8, 16]
+_NUM_STAGES = [1, 2]
+
+_DTYPES = [torch.float32, torch.float16]
+
+
+def _time(fn, n_warmup=10, n_repeat=50):
+    """Return the mean time in ms of ``fn`` over ``n_repeat`` calls (CUDA events)."""
+    for _ in range(n_warmup):
+        fn()
+    torch.cuda.synchronize()
+
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    for _ in range(n_repeat):
+        fn()
+    end.record()
+    torch.cuda.synchronize()
+    return start.elapsed_time(end) / n_repeat
+
+
+def _runner(input, dim, length, block_size, num_warps, num_stages):
+    """Build a closure that copies ``input.narrow(dim, 0, length)`` with the kernel."""
+    src = input.narrow(dim, 0, length)
+    output = torch.empty(src.shape, dtype=input.dtype, device=input.device)
+
+    kernel = _cached_make(
+        ntops.kernels.narrow.premake,
+        src.ndim,
+        block_size=block_size,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        max_num_configs=1,
+    )
+
+    return lambda: kernel(src, output)
+
+
+def _sweep(label, make_runner, num_bytes, torch_ms):
+    """Time every launch config and print the fastest one plus the top five."""
+    results = []
+    for bs, nw, ns in itertools.product(_BLOCK_SIZES, _NUM_WARPS, _NUM_STAGES):
+        try:
+            ms = _time(make_runner(bs, nw, ns))
+        except Exception as exc:  # noqa: BLE001
+            print(f" skip bs={bs} nw={nw} ns={ns}: {type(exc).__name__}")
+            continue
+        results.append((ms, bs, nw, ns))
+
+    if not results:
+        # Every config was skipped above; there is nothing to rank, and
+        # ``results[0]`` would otherwise raise an unhelpful IndexError.
+        print(f"\n [{label}] no launch config ran successfully")
+        return
+
+    results.sort()
+    best_ms, bbs, bnw, bns = results[0]
+    best_gbps = num_bytes / (best_ms * 1e-3) / 1e9
+    torch_gbps = num_bytes / (torch_ms * 1e-3) / 1e9
+
+    print(f"\n [{label}] (torch {torch_ms:.4f} ms / {torch_gbps:.0f} GB/s)")
+    print(
+        f" BEST block_size={bbs:<5} num_warps={bnw:<3} num_stages={bns} "
+        f"-> {best_ms:.4f} ms / {best_gbps:.0f} GB/s "
+        f"(speedup vs torch {torch_ms / best_ms:.2f})"
+    )
+    for ms, bs, nw, ns in results[:5]:
+        gbps = num_bytes / (ms * 1e-3) / 1e9
+        print(
+            f" block_size={bs:<5} num_warps={nw:<3} num_stages={ns} "
+            f"{ms:.4f} ms / 
{gbps:.0f} GB/s"
+        )
+
+
+def _check_correctness(dtype):
+    x = torch.randn(64, 1000, dtype=dtype, device="cuda")
+    src = x.narrow(1, 0, 500)
+    out = torch.empty(src.shape, dtype=dtype, device="cuda")
+    kernel = _cached_make(
+        ntops.kernels.narrow.premake,
+        src.ndim,
+        block_size=1024,
+        num_warps=4,
+        num_stages=1,
+        max_num_configs=1,
+    )
+    kernel(src, out)
+    assert torch.equal(out, src.contiguous())
+
+
+def tune():
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA not available")
+
+    for dtype in _DTYPES:
+        _check_correctness(dtype)
+        itemsize = torch.empty(0, dtype=dtype).element_size()
+
+        print(f"\n{'='*92}")
+        print(
+            f"narrow config sweep | dtype={dtype} | "
+            f"device={torch.cuda.get_device_name()}"
+        )
+        print("=" * 92)
+
+        for numel in _NUMELS:
+            side = int(round(numel**0.5))
+            full = torch.randn(side, side, dtype=dtype, device="cuda")
+            length = side // 2
+            out_bytes = side * length * itemsize * 2  # read slice + write output
+
+            print(f"\nnumel~{numel} ({side}x{side}, {numel * itemsize / 1e6:.1f} MB)")
+
+            torch_ms = _time(lambda: full.narrow(1, 0, length).contiguous())
+            _sweep(
+                "dim=1 (strided)",
+                lambda bs, nw, ns: _runner(full, 1, length, bs, nw, ns),
+                out_bytes,
+                torch_ms,
+            )
+
+            torch_ms = _time(lambda: full.narrow(0, 0, length).contiguous())
+            _sweep(
+                "dim=0 (contiguous)",
+                lambda bs, nw, ns: _runner(full, 0, length, bs, nw, ns),
+                out_bytes,
+                torch_ms,
+            )
+
+
+if __name__ == "__main__":
+    tune()
diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py
index f6934ef..4c1baa8 100644
--- a/src/ntops/kernels/__init__.py
+++ b/src/ntops/kernels/__init__.py
@@ -8,8 +8,10 @@
     bitwise_or,
     bmm,
     clamp,
+    combinations,
     conv2d,
     cos,
+    count_nonzero,
     div,
     dropout,
     eq,
@@ -19,12 +21,14 @@
     gt,
     isinf,
     isnan,
+    kl_div,
     layer_norm,
     le,
     lt,
     max_pool2d,
     mm,
     mul,
+    narrow,
     ne,
     neg,
     pow,
@@ -51,8 +55,10 @@
     "bitwise_or",
     "bmm",
     "clamp",
+    "combinations",
     "conv2d",
     "cos",
+    "count_nonzero",
     "div",
     "dropout",
     "eq",
@@ -62,12 +68,14 @@
     "gt",
     "isinf",
     "isnan",
+    "kl_div",
     "layer_norm",
     "le",
     "lt",
     "max_pool2d",
     "mm",
     "mul",
+    "narrow",
     "ne",
     "neg",
     "pow",
diff --git a/src/ntops/kernels/combinations.py b/src/ntops/kernels/combinations.py
new file mode 100644
index 0000000..1e0aad3
--- /dev/null
+++ b/src/ntops/kernels/combinations.py
@@ -0,0 +1,50 @@
+"""combinations gather kernel: output[m] = input[index[m]].
+
+combinations is a combinatorial gather. The wrapper enumerates the flat gather
+indices on the host (cheaply, e.g. triu_indices for r=2 — avoiding torch's
+n**r meshgrid), and this kernel does the data-dependent gather via
+``ntl.load(input.data_ptr() + index)`` (the read-side of scatter_add's atomic
+trick). Affine addressing throughout; the index value comparison is allowed.
+"""
+
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+
+def arrangement(index, output, input, m, block_size=None):
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    # index and output are flattened and tiled identically, so tile k of one
+    # lines up with tile k of the other.
+    index_arranged = index.flatten().tile((block_size,))
+    output_arranged = output.flatten().tile((block_size,))
+
+    # NOTE(review): ``input`` is passed through un-arranged and then addressed as
+    # ``data_ptr() + index`` in ``application``; that presumably assumes a
+    # stride-1 source -- confirm the wrapper only passes contiguous tensors.
+    # index first → grid; input un-arranged source for data_ptr().
+    return index_arranged, output_arranged, input, m
+
+
+def application(index, output, input, m):
+    p = index.offsets()
+    mask = p < m
+
+    output = ntl.load(input.data_ptr() + index, mask=mask, other=0)  # noqa: F841
+
+
+def premake(block_size=None):
+    tensors = (
+        Tensor(1, dtype=int),  # index (M,) flat gather indices
+        Tensor(1),  # output (M,)
+        Tensor(1, shape_options={"constexpr": True}),  # input flat (source)
+        Tensor(0, dtype=int, constexpr=True),  # m
+    )
+
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/count_nonzero.py b/src/ntops/kernels/count_nonzero.py
new file mode 100644
index 0000000..4682968
--- /dev/null
+++ b/src/ntops/kernels/count_nonzero.py
@@ -0,0 +1,110 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+# ---------------------------------------------------------------------------
+# count_nonzero = sum(input != 0), reduced either globally (dim=None) or over a
+# set of dims. Both paths are single-pass partial-sum kernels: each program
+# counts the nonzeros in one block and writes an int64 partial; the host sums
+# the partials. int64 (not float32) partials keep the count exact for large
+# inputs. ``other=0`` pads the trailing block -- 0 is counted as zero, so
+# padding never inflates the count.
+# ---------------------------------------------------------------------------
+
+
+def _global_arrangement(input, output, block_size=None):
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    input_arranged = input.flatten().tile((block_size,))
+    # One output element per input tile: the slot for that block's partial count.
+    output_arranged = output.flatten().tile((1,))
+
+    return input_arranged, output_arranged
+
+
+def global_application(input, output):
+    output = ntl.sum(ntl.where(input != 0, 1, 0))  # noqa: F841
+
+
+def global_premake(input_dtype=None, block_size=None):
+    """Return ``(arrangement, application, tensors)`` for the global kernel."""
+    arrangement_ = functools.partial(_global_arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(1, other=0, dtype=input_dtype),
+        Tensor(1, dtype=ninetoothed.int64),
+    )
+
+    return arrangement_, global_application, tensors
+
+
+# Dim path: the host reshapes to (M, N) with the reduced dims trailing. Each
+# (1, block_size) tile becomes one program writing a partial into the
+# (M, num_blocks) buffer; the host then sums along the blocks per row.
+def _dim_arrangement(input, output, block_size=None):
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    input_arranged = input.tile((1, block_size))
+    output_arranged = output.tile((1, 1))
+
+    return input_arranged, output_arranged
+
+
+def dim_application(input, output):
+    output = ntl.sum(ntl.where(input != 0, 1, 0))  # noqa: F841
+
+
+def dim_premake(input_dtype=None, block_size=None):
+    """Return ``(arrangement, application, tensors)`` for the dim kernel."""
+    arrangement_ = functools.partial(_dim_arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(2, other=0, dtype=input_dtype),
+        Tensor(2, dtype=ninetoothed.int64),
+    )
+
+    return arrangement_, dim_application, tensors
+
+
+# Leading path: reduce a contiguous block of *leading* dims, viewed host-side as
+# ``(R, inner)`` with ``inner`` the contiguous trailing dims. 
Reducing axis 0
+# directly (instead of permuting it to the back, which would materialize a
+# transpose) is done with a ``(reduce_block, block_size)`` tile: the ``block_size``
+# columns are read coalesced while ``ntl.sum(..., axis=0)`` reduces the rows. One
+# partial per ``(row-block, column)`` is written; the host sums the row-blocks.
+# This mirrors avg_pool2d's ``ntl.sum(axis=-1)`` + output squeeze, transposed to
+# axis 0. ``other=0`` pads both ragged edges and is counted as zero.
+def _leading_arrangement(input, output, block_size=None, reduce_block=None):
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    if reduce_block is None:
+        reduce_block = 32
+
+    input_arranged = input.tile((reduce_block, block_size))
+
+    output_arranged = output.tile((1, block_size))
+    output_arranged.dtype = output_arranged.dtype.squeeze(0)
+
+    return input_arranged, output_arranged
+
+
+def leading_application(input, output):
+    output = ntl.sum(ntl.where(input != 0, 1, 0), axis=0)  # noqa: F841
+
+
+def leading_premake(input_dtype=None, block_size=None, reduce_block=None):
+    arrangement_ = functools.partial(
+        _leading_arrangement, block_size=block_size, reduce_block=reduce_block
+    )
+
+    tensors = (
+        Tensor(2, other=0, dtype=input_dtype),
+        Tensor(2, dtype=ninetoothed.int64),
+    )
+
+    return arrangement_, leading_application, tensors
diff --git a/src/ntops/kernels/kl_div.py b/src/ntops/kernels/kl_div.py
new file mode 100644
index 0000000..10e5fe9
--- /dev/null
+++ b/src/ntops/kernels/kl_div.py
@@ -0,0 +1,101 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement as _element_wise_arrangement
+
+
+# ---------------------------------------------------------------------------
+# Element-wise path (reduction="none"): output keeps the input shape.
+#
+# Pointwise KL divergence (matching torch.nn.functional.kl_div):
+# * log_target=False: target * (log(target) - input), with the
+# 0 * log(0) = 0 convention -> contributions where target <= 0 are zeroed
+# via ``where`` (this also masks the log(0) = -inf produced by padding).
+# * log_target=True: exp(target) * (target - input).
+# The math runs in float32 for accuracy (log/exp lose precision in fp16); the
+# store casts back to the output dtype, mirroring the silu/softmax kernels.
+# NOTE(review): torch evaluates the log_target=False term with xlogy, which
+# presumably yields NaN (not 0) for negative targets -- confirm that zeroing
+# them here is an intended divergence.
+# ---------------------------------------------------------------------------
+
+
+def application(input, target, output):
+    input_fp32 = ntl.cast(input, ntl.float32)
+    target_fp32 = ntl.cast(target, ntl.float32)
+    pointwise = target_fp32 * (ntl.log(target_fp32) - input_fp32)
+    output = ntl.where(target_fp32 > 0, pointwise, 0)  # noqa: F841
+
+
+def log_target_application(input, target, output):
+    input_fp32 = ntl.cast(input, ntl.float32)
+    target_fp32 = ntl.cast(target, ntl.float32)
+    output = ntl.exp(target_fp32) * (target_fp32 - input_fp32)  # noqa: F841
+
+
+def premake(ndim, log_target=False, dtype=None, block_size=None):
+    """Return ``(arrangement, application, tensors)`` for the element-wise kernel."""
+    arrangement_ = functools.partial(_element_wise_arrangement, block_size=block_size)
+
+    application_ = log_target_application if log_target else application
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+    )
+
+    return arrangement_, application_, tensors
+
+
+# ---------------------------------------------------------------------------
+# Reduction path (reduction="sum"/"mean"/"batchmean"): a single-pass partial-sum
+# kernel emits one float32 partial per block; the host sums the partials and
+# applies the reduction scaling. ~2N memory traffic, no full intermediate
+# tensor. ``other=0`` pads the trailing block: target=0 yields a zeroed
+# contribution under both formulas, so padding never perturbs the sum.
+# ---------------------------------------------------------------------------
+
+
+def _reduce_arrangement(input, target, output, block_size=None):
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    input_arranged = input.flatten().tile((block_size,))
+    target_arranged = target.flatten().tile((block_size,))
+    output_arranged = output.flatten().tile((1,))
+
+    return input_arranged, target_arranged, output_arranged
+
+
+def reduce_application(input, target, output):
+    input_fp32 = ntl.cast(input, ntl.float32)
+    target_fp32 = ntl.cast(target, ntl.float32)
+    pointwise = target_fp32 * (ntl.log(target_fp32) - input_fp32)
+    masked = ntl.where(target_fp32 > 0, pointwise, 0)
+    output = ntl.sum(masked)  # noqa: F841
+
+
+def reduce_log_target_application(input, target, output):
+    input_fp32 = ntl.cast(input, ntl.float32)
+    target_fp32 = ntl.cast(target, ntl.float32)
+    output = ntl.sum(ntl.exp(target_fp32) * (target_fp32 - input_fp32))  # noqa: F841
+
+
+def reduce_premake(input_dtype=None, log_target=False, block_size=None):
+    arrangement_ = functools.partial(_reduce_arrangement, block_size=block_size)
+
+    application_ = (
+        reduce_log_target_application if log_target else reduce_application
+    )
+
+    tensors = (
+        Tensor(1, other=0, dtype=input_dtype),
+        Tensor(1, other=0, dtype=input_dtype),
+        Tensor(1, dtype=ninetoothed.float32),
+    )
+
+    return arrangement_, application_, tensors
diff --git a/src/ntops/kernels/narrow.py b/src/ntops/kernels/narrow.py
new file mode 100644
index 0000000..24d842b
--- /dev/null
+++ b/src/ntops/kernels/narrow.py
@@ -0,0 +1,23 @@
+import functools
+
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+# narrow returns the slice ``input[..., start:start+length, ...]``. The torch
+# wrapper builds that as a strided view (a regular ``offset + stride`` pattern,
+# unlike a data-dependent gather), and this trivial copy kernel materializes it
+# into a contiguous output -- the same "strided view + copy" approach as
+# pixel_unshuffle.
+def application(input, output):
+    output = input  # noqa: F841
+
+
+def premake(ndim, dtype=None, block_size=None):
+    """Return ``(arrangement, application, tensors)`` for the copy kernel."""
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py
index 82fc596..a3af707 100644
--- a/src/ntops/torch/__init__.py
+++ b/src/ntops/torch/__init__.py
@@ -7,8 +7,11 @@
 from ntops.torch.bitwise_or import bitwise_or
 from ntops.torch.bmm import bmm
 from ntops.torch.clamp import clamp
+from ntops.torch.combinations import combinations
 from ntops.torch.conv2d import conv2d
+from ntops.torch.corrcoef import corrcoef
 from ntops.torch.cos import cos
+from ntops.torch.count_nonzero import count_nonzero
 from ntops.torch.div import div
 from ntops.torch.dropout import dropout
 from ntops.torch.eq import eq
@@ -18,6 +21,7 @@
 from ntops.torch.gt import gt
 from ntops.torch.isinf import isinf
 from ntops.torch.isnan import isnan
+from ntops.torch.kl_div import kl_div
 from ntops.torch.layer_norm import layer_norm
 from ntops.torch.le import le
 from ntops.torch.lt import lt
@@ -25,6 +29,7 @@
 from ntops.torch.max_pool2d import max_pool2d
 from ntops.torch.mm import mm
 from ntops.torch.mul import mul
+from ntops.torch.narrow import narrow
 from ntops.torch.ne import ne
 from ntops.torch.neg import neg
 from ntops.torch.pow import pow
@@ -50,8 +55,11 @@
     "bitwise_or",
     "bmm",
     "clamp",
+    "combinations",
     "conv2d",
+    "corrcoef",
     "cos",
+    "count_nonzero",
     "div",
     "dropout",
     "eq",
@@ -61,6 +69,7 @@
     "gt",
     "isinf",
     "isnan",
+    "kl_div",
     "layer_norm",
     "le",
     "lt",
@@ -68,6 +77,7 @@
     "max_pool2d",
     "mm",
     "mul",
+    "narrow",
     "ne",
"neg",
     "pow",
diff --git a/src/ntops/torch/combinations.py b/src/ntops/torch/combinations.py
new file mode 100644
index 0000000..851737e
--- /dev/null
+++ b/src/ntops/torch/combinations.py
@@ -0,0 +1,91 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make, _device_key
+
+# combinations = combinatorial gather, output[k, j] = input[idx[k, j]].
+#
+# r == 2 (the default/common case): the index tuples are torch.triu_indices
+# (O(num_comb), avoiding torch's n*n value meshgrid), and a ninetoothed kernel
+# does the data-dependent gather → multi-x speedup over torch.combinations.
+#
+# r >= 3: torch's own algorithm (value meshgrid + triu mask). It is rarely used
+# and its output C(n, r) is already huge for moderate n; a separate index pass
+# for a ninetoothed gather would only add work, so we match torch here.
+_CONFIGS = {
+    "nvidia": (4, 256),
+    "iluvatar": (8, 512),
+    "metax": (4, 256),
+    "default": (4, 256),
+}
+
+
+def combinations(input, r=2, with_replacement=False):
+    """Return all length-``r`` combinations of the elements of a 1-D tensor.
+
+    ``r == 0`` returns an empty 1-D tensor and ``r == 1`` an ``(n, 1)`` copy;
+    ``r == 2`` gathers the pairs given by ``torch.triu_indices`` with the
+    ninetoothed kernel; ``r >= 3`` masks an ``r``-dimensional meshgrid.
+    ``with_replacement`` also admits repeated elements in a combination.
+
+    Raises:
+        RuntimeError: if ``input`` is not 1-D or ``r`` is negative.
+    """
+    if input.dim() != 1:
+        raise RuntimeError(f"Expect a 1D vector, but got shape {list(input.shape)}")
+
+    if r < 0:
+        raise RuntimeError(f"Expect a non-negative number, but got {r}")
+
+    if r == 0:
+        return torch.empty(0, dtype=input.dtype, device=input.device)
+
+    n = input.size(0)
+
+    if r == 1:
+        return input.reshape(n, 1).clone()
+
+    if r == 2:
+        offset = 0 if with_replacement else 1
+        ij = torch.triu_indices(n, n, offset=offset, device=input.device)  # (2, num_comb)
+        num_comb = ij.shape[1]
+        index = ij.t().contiguous().reshape(-1).to(torch.int64)
+
+        m = index.numel()
+        output = torch.empty((m,), dtype=input.dtype, device=input.device)
+
+        if m == 0:
+            # No pairs exist (n < 2, or n == 0 with replacement): skip the launch.
+            return output.reshape(num_comb, 2)
+
+        # The kernel reads ``input.data_ptr() + index``, which ignores the
+        # tensor's stride, so a strided 1-D view has to be compacted first.
+        source = input.contiguous()
+
+        # Use the "default" entry for device keys without a tuned config.
+        num_warps, block_size = _CONFIGS.get(_device_key(), _CONFIGS["default"])
+        kernel = _cached_make(
+            ntops.kernels.combinations.premake,
+            block_size=block_size,
+            num_warps=num_warps,
+            num_stages=1,
+            max_num_configs=1,
+        )
+        kernel(index, output, source, m)
+
+        return output.reshape(num_comb, 2)
+
+    # r >= 3: torch's value-meshgrid algorithm (matches 
def combinations(input, r=2, with_replacement=False):
    """Return every length-``r`` combination of a 1-D tensor.

    Mirrors ``torch.combinations``. For ``r == 2`` the index pairs come from
    ``torch.triu_indices`` and the values are gathered by the ninetoothed
    kernel; ``r >= 3`` uses a torch-only meshgrid + mask.

    Args:
        input: 1-D tensor whose elements are combined.
        r: Number of elements per combination; must be non-negative.
        with_replacement: If true, an element may repeat within a combination.

    Returns:
        An empty 1-D tensor when ``r == 0``; otherwise a ``(num_comb, r)``
        tensor with the dtype and device of ``input``. Rows follow the
        row-major order of the selected index tuples.

    Raises:
        RuntimeError: If ``input`` is not 1-D or ``r`` is negative.
    """
    if input.dim() != 1:
        raise RuntimeError(f"Expect a 1D vector, but got shape {list(input.shape)}")

    if r < 0:
        raise RuntimeError(f"Expect a non-negative number, but got {r}")

    if r == 0:
        return torch.empty(0, dtype=input.dtype, device=input.device)

    n = input.size(0)

    if r == 1:
        # clone() so the result never aliases the caller's storage.
        return input.reshape(n, 1).clone()

    if r == 2:
        offset = 0 if with_replacement else 1
        ij = torch.triu_indices(n, n, offset=offset, device=input.device)  # (2, num_comb)
        num_comb = ij.shape[1]

        # Nothing to gather (n < 2, or n == 0 with replacement): return the
        # empty result directly instead of launching a kernel over zero
        # elements, like the other wrappers do for empty work.
        if num_comb == 0:
            return torch.empty((0, 2), dtype=input.dtype, device=input.device)

        # Flatten the pairs row by row so output[2k] / output[2k + 1] hold the
        # two members of combination k.
        index = ij.t().contiguous().reshape(-1).to(torch.int64)

        m = index.numel()
        output = torch.empty((m,), dtype=input.dtype, device=input.device)

        num_warps, block_size = _CONFIGS[_device_key()]
        kernel = _cached_make(
            ntops.kernels.combinations.premake,
            block_size=block_size,
            num_warps=num_warps,
            num_stages=1,
            max_num_configs=1,
        )
        kernel(index, output, input, m)

        return output.reshape(num_comb, 2)

    # r >= 3: meshgrid + mask. The mask compares INDICES (arange grids), not
    # values -- input values are unsorted, so comparing values would select the
    # wrong tuples and order.
    rng = torch.arange(n, device=input.device)
    index_grids = torch.meshgrid(*([rng] * r), indexing="ij")
    mask = torch.ones(index_grids[0].shape, dtype=torch.bool, device=input.device)
    for i in range(r - 1):
        if with_replacement:
            mask = mask & (index_grids[i] <= index_grids[i + 1])
        else:
            mask = mask & (index_grids[i] < index_grids[i + 1])

    value_grids = torch.meshgrid(*([input] * r), indexing="ij")
    return torch.stack([g.masked_select(mask) for g in value_grids], dim=1)
def corrcoef(input):
    """Pearson correlation coefficients of the rows of ``input``.

    Rows are variables and columns are observations. The covariance product
    goes through ``ntops.torch.mm``; the mean, centering and normalisation are
    plain torch ops.

    Args:
        input: Tensor with at most two dimensions. Inputs that are neither
            floating point nor complex are converted to the default float
            dtype first.

    Returns:
        The correlation matrix clamped to ``[-1, 1]``; squeezed when the input
        had fewer than two dimensions.

    Raises:
        RuntimeError: If ``input`` has more than two dimensions.
    """
    ndim = input.dim()
    if ndim > 2:
        raise RuntimeError(
            f"corrcoef(): expected input to have two or fewer dimensions but got "
            f"{ndim}"
        )

    # A lone variable is handled as a single-row matrix and squeezed at the end.
    squeeze_result = ndim < 2
    samples = input.reshape(1, -1) if squeeze_result else input

    if not (torch.is_floating_point(samples) or torch.is_complex(samples)):
        samples = samples.to(torch.get_default_dtype())

    num_observations = samples.shape[1]

    # Deviation of every observation from its row mean.
    deviations = (samples - samples.mean(dim=1, keepdim=True)).contiguous()

    # Covariance: deviations @ deviations^T, scaled by (N - 1).
    covariance = ntops.torch.mm(deviations, deviations.t().contiguous()) / (
        num_observations - 1
    )

    # Divide rows, then columns, by the standard deviations and clamp.
    sigma = covariance.diagonal().sqrt()
    correlation = (covariance / sigma.unsqueeze(1) / sigma.unsqueeze(0)).clamp(-1, 1)

    return correlation.squeeze() if squeeze_result else correlation
@functools.lru_cache(maxsize=None)
def _launch_config():
    """Return ``(num_warps, block_size)`` for the count reduction on this GPU.

    Launch parameters are fixed per device rather than auto-tuned; the block
    size also fixes the length of the partials buffer, so it has to be known
    on the host. The choice depends on the device name only, never on input
    shapes: MetaX uses 4 warps, Iluvatar 8, everything else 16, all with an
    8192-element block. ``num_stages`` stays 1 at every call site.
    """
    name = torch.cuda.get_device_name().lower() if torch.cuda.is_available() else ""

    if "metax" in name:
        return 4, 8192
    if "iluvatar" in name:
        return 8, 8192
    return 16, 8192


def _normalize_dims(dim, ndim):
    """Return ``dim`` as a tuple of unique, non-negative dimension indices.

    Args:
        dim: A single int or an iterable of ints; negative values count from
            the end.
        ndim: Number of dimensions of the tensor being reduced.

    Raises:
        IndexError: If any entry falls outside ``[-ndim, ndim - 1]``.
        RuntimeError: If the same dimension is listed more than once.
    """
    dims = (dim,) if isinstance(dim, int) else tuple(dim)
    dims = tuple(d if d >= 0 else d + ndim for d in dims)

    for d in dims:
        if not (0 <= d < ndim):
            raise IndexError(
                f"Dimension out of range (expected to be in range of "
                f"[{-ndim}, {ndim - 1}])"
            )

    if len(set(dims)) != len(dims):
        raise RuntimeError(f"dim {dim} appears multiple times")

    return dims


def count_nonzero(input, dim=None):
    """Count the non-zero elements of ``input``, optionally along ``dim``.

    Args:
        input: Tensor to inspect.
        dim: ``None`` to count over the whole tensor, an int, or a sequence of
            ints naming the dimensions to reduce. An empty sequence is treated
            like ``None`` (full reduction).

    Returns:
        An ``int64`` tensor: 0-dim for a full reduction, otherwise shaped like
        ``input`` with the reduced dimensions removed.

    Raises:
        IndexError: If a dimension is out of range.
        RuntimeError: If a dimension is listed more than once.
    """
    # Validate up front. An empty ``dim`` used to fall through to the
    # dimension-wise path and fail on ``sorted_dims[0]``; it now shares the
    # global path.
    dims = () if dim is None else _normalize_dims(dim, input.dim())

    num_warps, block_size = _launch_config()

    if not dims:
        flat = input.reshape(-1)
        numel = flat.numel()

        if numel == 0:
            return torch.zeros((), dtype=torch.int64, device=input.device)

        # One slot per ``block_size``-element chunk; the final sum is done by
        # torch.
        num_partials = max(1, math.ceil(numel / block_size))
        partials = torch.empty(num_partials, dtype=torch.int64, device=input.device)

        kernel = _cached_make(
            ntops.kernels.count_nonzero.global_premake,
            block_size=block_size,
            num_warps=num_warps,
            num_stages=1,
        )
        kernel(flat, partials)

        return partials.sum()

    kept = tuple(i for i in range(input.dim()) if i not in dims)
    kept_shape = tuple(input.shape[i] for i in kept)

    # Fast path: reducing a consecutive block of *leading* dims (with trailing
    # dims kept) is a column reduction on the ``(R, inner)`` view -- no
    # transpose. Trailing-dim reductions stay on the general path, where the
    # permute below is the identity.
    sorted_dims = sorted(dims)
    is_consecutive = sorted_dims == list(range(sorted_dims[0], sorted_dims[-1] + 1))
    if is_consecutive and sorted_dims[0] == 0 and sorted_dims[-1] < input.dim() - 1:
        b = sorted_dims[-1]
        r = math.prod(input.shape[: b + 1])
        inner = math.prod(input.shape[b + 1 :])

        if r == 0 or inner == 0:
            return torch.zeros(kept_shape, dtype=torch.int64, device=input.device)

        x2d = input.reshape(r, inner)

        # Small 2-D tile; the global block size would make it far too large.
        col_block = 256
        reduce_block = 32
        num_row_blocks = max(1, math.ceil(r / reduce_block))
        partials = torch.empty(
            (num_row_blocks, inner), dtype=torch.int64, device=input.device
        )

        kernel = _cached_make(
            ntops.kernels.count_nonzero.leading_premake,
            block_size=col_block,
            reduce_block=reduce_block,
            num_warps=num_warps,
            num_stages=1,
        )
        kernel(x2d, partials)

        return partials.sum(dim=0).reshape(kept_shape)

    # General path: move reduced dims to the back and collapse to (M, N).
    permuted = input.permute(kept + dims).contiguous()
    m = math.prod(kept_shape) if kept_shape else 1
    n = permuted.numel() // m if m else 0
    x2d = permuted.reshape(m, n)

    if n == 0:
        return torch.zeros(kept_shape, dtype=torch.int64, device=input.device)

    num_blocks = max(1, math.ceil(n / block_size))
    partials = torch.empty((m, num_blocks), dtype=torch.int64, device=input.device)

    kernel = _cached_make(
        ntops.kernels.count_nonzero.dim_premake,
        block_size=block_size,
        num_warps=num_warps,
        num_stages=1,
    )
    kernel(x2d, partials)

    return partials.sum(dim=1).reshape(kept_shape)
@functools.lru_cache(maxsize=None)
def _launch_config():
    """Return ``(num_warps, reduce_block_size, none_block_size)`` for this GPU.

    Launch parameters are fixed per device rather than auto-tuned; the
    reduction block size also fixes the length of the partial-sums buffer, so
    it has to be known on the host. The choice depends on the device name
    only, never on input shapes:

    * MetaX: 4 warps, 8192-element reduction blocks, 1024-element blocks for
      ``reduction="none"``.
    * Iluvatar: 16 warps, 4096 / 1024.
    * Anything else (NVIDIA, unmeasured devices, no CUDA): 16 warps,
      8192 / 512.

    ``num_stages`` stays 1 at every call site.
    """
    name = torch.cuda.get_device_name().lower() if torch.cuda.is_available() else ""

    if "metax" in name:
        return 4, 8192, 1024
    if "iluvatar" in name:
        return 16, 4096, 1024
    return 16, 8192, 512


def kl_div(input, target, reduction="mean", log_target=False):
    """Kullback-Leibler divergence loss, mirroring ``F.kl_div``.

    Args:
        input: Log-probabilities.
        target: Probabilities, or log-probabilities when ``log_target`` is
            true. Broadcast against ``input`` when the shapes differ.
        reduction: ``"none"``, ``"mean"``, ``"sum"`` or ``"batchmean"``.
        log_target: Whether ``target`` is given in log space.

    Returns:
        The element-wise loss for ``reduction="none"``; otherwise a 0-dim
        tensor in ``input``'s dtype. For empty inputs ``"sum"`` yields 0 and
        the averaging reductions yield NaN when the divisor is zero.

    Raises:
        ValueError: If ``reduction`` is not one of the supported names.
    """
    if reduction not in ("none", "mean", "sum", "batchmean"):
        raise ValueError(f"unsupported reduction: {reduction!r}")

    # ``batchmean`` divides by the ORIGINAL input's leading dimension (and
    # skips the division for 0-dim input), so capture it before broadcasting
    # can change the leading dimension.
    batch_size = input.shape[0] if input.dim() != 0 else 1

    if input.shape != target.shape:
        input, target = torch.broadcast_tensors(input, target)
    input = input.contiguous()
    target = target.contiguous()

    num_warps, reduce_block_size, none_block_size = _launch_config()
    numel = input.numel()

    if reduction == "none":
        output = torch.empty_like(input)

        # Nothing to compute for an empty tensor; skip the launch.
        if numel == 0:
            return output

        kernel = _cached_make(
            ntops.kernels.kl_div.premake,
            input.ndim,
            log_target=log_target,
            block_size=none_block_size,
            num_warps=num_warps,
            num_stages=1,
        )
        kernel(input, target, output)

        return output

    if numel == 0:
        # An empty sum is exactly zero. Do not read it back from an
        # uninitialised partials buffer.
        total = torch.zeros((), dtype=torch.float32, device=input.device)
    else:
        flat_input = input.reshape(-1)
        flat_target = target.reshape(-1)

        # One float32 slot per ``reduce_block_size``-element chunk; the final
        # sum is done by torch.
        num_partials = max(1, math.ceil(numel / reduce_block_size))
        partials = torch.empty(
            num_partials, dtype=torch.float32, device=input.device
        )

        kernel = _cached_make(
            ntops.kernels.kl_div.reduce_premake,
            log_target=log_target,
            block_size=reduce_block_size,
            num_warps=num_warps,
            num_stages=1,
        )
        kernel(flat_input, flat_target, partials)

        total = partials.sum()

    if reduction == "mean":
        total = total / numel
    elif reduction == "batchmean":
        total = total / batch_size

    return total.to(input.dtype)
@functools.lru_cache(maxsize=None)
def _launch_config():
    """Return ``(block_size, num_warps)`` for the strided-copy kernel.

    Launch parameters are fixed per device rather than auto-tuned. The device
    name alone decides the result; names that match no known vendor (or a
    host without CUDA) get the final fallback. ``num_stages`` stays 1 at the
    call site.
    """
    device_name = ""
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name().lower()

    # Checked in order; the first vendor substring found wins.
    for vendor, config in (
        ("metax", (4096, 4)),
        ("iluvatar", (2048, 4)),
        ("nvidia", (512, 16)),
    ):
        if vendor in device_name:
            return config

    return 2048, 4


def narrow(input, dim, start, length):
    """Return a freshly allocated copy of ``input`` narrowed along ``dim``.

    Args:
        input: Tensor with at least one dimension.
        dim: Dimension to narrow; negative values count from the end.
        start: First index of the slice, as an int or a tensor whose
            ``.item()`` gives it. Negative values count from the end.
        length: Number of elements to keep along ``dim``.

    Returns:
        A new tensor holding the ``length`` elements starting at ``start``
        along ``dim``.

    Raises:
        RuntimeError: If ``input`` is 0-dim or the slice exceeds the dimension.
        IndexError: If ``dim`` is out of range.
    """
    ndim = input.dim()
    if ndim == 0:
        raise RuntimeError("narrow() cannot be applied to a 0-dim tensor.")

    if dim < -ndim or dim >= ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    if dim < 0:
        dim += ndim

    size = input.shape[dim]

    # A tensor ``start`` is reduced to a plain Python int first.
    if torch.is_tensor(start):
        start = int(start.item())
    if start < 0:
        start += size

    if start < 0 or start + length > size:
        raise RuntimeError(
            f"start ({start}) + length ({length}) exceeds dimension size ({size})."
        )

    # View of the requested slice; the kernel copies it into ``output``.
    src = input.narrow(dim, start, length)
    output = torch.empty(src.shape, dtype=input.dtype, device=input.device)

    if output.numel() != 0:
        block_size, num_warps = _launch_config()
        kernel = _cached_make(
            ntops.kernels.narrow.premake,
            src.ndim,
            block_size=block_size,
            num_warps=num_warps,
            num_stages=1,
        )
        kernel(src, output)

    return output
+# --------------------------------------------------------------------------- + +_NS = [1, 2, 3, 5, 8] +_RS = [0, 1, 2, 3, 4] + + +def _make_input(n, dtype, device): + if dtype == torch.int64: + return torch.randint(-1000, 1000, (n,), dtype=dtype, device=device) + return torch.randn(n, dtype=dtype, device=device) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int64]) +@pytest.mark.parametrize("with_replacement", [False, True]) +@pytest.mark.parametrize("r", _RS) +@pytest.mark.parametrize("n", _NS) +def test_combinations(n, r, with_replacement, dtype): + device = "cuda" + input = _make_input(n, dtype, device) + + output = ntops.torch.combinations(input, r=r, with_replacement=with_replacement) + expected = torch.combinations(input, r=r, with_replacement=with_replacement) + + assert output.shape == expected.shape + assert output.dtype == expected.dtype + assert output.device == expected.device + assert torch.equal(output, expected) + + +@skip_if_cuda_not_available +def test_combinations_default_args(): + """Default r=2, with_replacement=False.""" + device = "cuda" + input = torch.randn(6, device=device) + + output = ntops.torch.combinations(input) + expected = torch.combinations(input) + + assert torch.equal(output, expected) + + +@skip_if_cuda_not_available +def test_combinations_r_greater_than_n(): + """r > n without replacement yields an empty (0, r) result.""" + device = "cuda" + input = torch.randn(3, device=device) + + output = ntops.torch.combinations(input, r=5) + expected = torch.combinations(input, r=5) + + assert output.shape == expected.shape # (0, 5) + assert torch.equal(output, expected) + + +@skip_if_cuda_not_available +def test_combinations_empty_input(): + device = "cuda" + input = torch.randn(0, device=device) + + for r in (0, 1, 2): + output = ntops.torch.combinations(input, r=r) + expected = torch.combinations(input, r=r) + assert output.shape == expected.shape + assert torch.equal(output, 
def benchmark_combinations(
    n,
    r=2,
    with_replacement=False,
    dtype=torch.float32,
    device="cuda",
    n_warmup=10,
    n_repeat=100,
):
    """Compare ntops.torch.combinations vs torch.combinations.

    Returns timing (ms) for both plus the speedup ratio. For ``r == 2`` the
    ntops implementation builds the index pairs with ``torch.triu_indices``
    and gathers the values through its ninetoothed kernel; other values of
    ``r`` run on torch ops only, so their timings mostly measure wrapper
    overhead.

    Args:
        n: Length of the random 1-D input.
        r: Combination size.
        with_replacement: Forwarded to both implementations.
        dtype: Input dtype.
        device: Device for the input; ``"cuda"`` requires CUDA.
        n_warmup: Untimed warm-up iterations run for both implementations.
        n_repeat: Timed iterations per implementation.

    Raises:
        RuntimeError: If ``device`` is ``"cuda"`` and CUDA is unavailable.

    Example
    -------
    >>> results = benchmark_combinations(256, r=2)
    >>> print(results)
    """
    if not torch.cuda.is_available() and device == "cuda":
        raise RuntimeError("CUDA not available")

    input = torch.randn(n, dtype=dtype, device=device)

    def run_ntops():
        ntops.torch.combinations(input, r=r, with_replacement=with_replacement)

    def run_torch():
        torch.combinations(input, r=r, with_replacement=with_replacement)

    for _ in range(n_warmup):
        run_ntops()
        run_torch()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(n_repeat):
        run_ntops()
    end.record()
    torch.cuda.synchronize()
    ntops_ms = start.elapsed_time(end) / n_repeat

    start.record()
    for _ in range(n_repeat):
        run_torch()
    end.record()
    torch.cuda.synchronize()
    torch_ms = start.elapsed_time(end) / n_repeat

    return {
        "n": n,
        "r": r,
        "with_replacement": with_replacement,
        "dtype": str(dtype),
        "ntops_time_ms": ntops_ms,
        "torch_time_ms": torch_ms,
        "speedup": torch_ms / ntops_ms,
    }
Run with: + pytest tests/test_combinations.py::test_benchmark_sweep -v -s + """ + header = ( + f"{'n':>8} {'r':>4} {'num_comb':>12} " + f"{'ntops(ms)':>11} {'torch(ms)':>11} {'speedup':>9}" + ) + print(f"\n{'='*len(header)}") + print(f"combinations sweep | r={r}") + print("=" * len(header)) + print(header) + print("-" * len(header)) + + ns = [64, 128, 256] if r == 2 else [16, 32, 48] + for n in ns: + res = benchmark_combinations(n, r=r) + num_comb = torch.combinations( + torch.arange(n), r=r + ).shape[0] + print( + f"{n:>8} {r:>4} {num_comb:>12} " + f"{res['ntops_time_ms']:>11.4f} {res['torch_time_ms']:>11.4f} " + f"{res['speedup']:>9.2f}" + ) + + print("=" * len(header)) + + +@skip_if_cuda_not_available +def test_benchmark_interface(): + """Smoke-test that the benchmark interface runs without error.""" + results = benchmark_combinations(64, r=2, n_warmup=2, n_repeat=5) + assert results["ntops_time_ms"] > 0 diff --git a/tests/test_corrcoef.py b/tests/test_corrcoef.py new file mode 100644 index 0000000..a615c59 --- /dev/null +++ b/tests/test_corrcoef.py @@ -0,0 +1,171 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + +# --------------------------------------------------------------------------- +# Correctness tests (compared against torch.corrcoef). Tolerances are loose +# because the covariance goes through the ninetoothed mm kernel (tf32 by +# default on NVIDIA) and float accumulation differs from torch's. +# --------------------------------------------------------------------------- + +_SHAPES = [ + [2, 50], + [4, 100], + [8, 256], + [16, 1024], + [3, 4097], # observations not a multiple of the mm tiling +] + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype, rtol, atol", [ + # Loose: the covariance matmul uses tf32 by default on NVIDIA (~1e-3 rel). 
+ (torch.float32, 1e-2, 1e-2), +]) +@pytest.mark.parametrize("shape", _SHAPES) +def test_corrcoef_2d(shape, dtype, rtol, atol): + device = "cuda" + input = torch.randn(shape, dtype=dtype, device=device) + + output = ntops.torch.corrcoef(input) + expected = torch.corrcoef(input) + + assert output.shape == expected.shape # (D, D) + assert output.dtype == expected.dtype + assert torch.allclose(output, expected, rtol=rtol, atol=atol) + + +@skip_if_cuda_not_available +def test_corrcoef_1d(): + """A 1-D input yields the scalar 1.0.""" + device = "cuda" + input = torch.randn(100, device=device) + + output = ntops.torch.corrcoef(input) + expected = torch.corrcoef(input) + + assert output.shape == expected.shape # () + assert torch.allclose(output, expected, rtol=1e-3, atol=1e-3) + + +@skip_if_cuda_not_available +def test_corrcoef_diagonal_is_one(): + device = "cuda" + input = torch.randn(6, 200, device=device) + + output = ntops.torch.corrcoef(input) + + assert torch.allclose( + output.diagonal(), torch.ones(6, device=device), rtol=1e-3, atol=1e-3 + ) + assert output.abs().max().item() <= 1.0 + 1e-5 # clamped to [-1, 1] + + +@skip_if_cuda_not_available +def test_corrcoef_integer_input_promotes(): + device = "cuda" + input = torch.randint(0, 10, (4, 100), device=device) + + output = ntops.torch.corrcoef(input) + expected = torch.corrcoef(input) + + assert output.dtype == expected.dtype # float32 + assert torch.allclose(output, expected, rtol=1e-2, atol=1e-2) + + +@skip_if_cuda_not_available +def test_corrcoef_too_many_dims(): + device = "cuda" + input = torch.randn(2, 3, 4, device=device) + + with pytest.raises(RuntimeError): + ntops.torch.corrcoef(input) + + +# --------------------------------------------------------------------------- +# Performance benchmark interface +# --------------------------------------------------------------------------- + +def benchmark_corrcoef( + shape, + dtype=torch.float32, + device="cuda", + n_warmup=10, + n_repeat=100, +): + """Compare 
ntops.torch.corrcoef vs torch.corrcoef. + + Example + ------- + >>> results = benchmark_corrcoef([64, 4096]) + """ + if not torch.cuda.is_available() and device == "cuda": + raise RuntimeError("CUDA not available") + + input = torch.randn(shape, dtype=dtype, device=device) + + for _ in range(n_warmup): + ntops.torch.corrcoef(input) + torch.corrcoef(input) + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for _ in range(n_repeat): + ntops.torch.corrcoef(input) + end.record() + torch.cuda.synchronize() + ntops_ms = start.elapsed_time(end) / n_repeat + + start.record() + for _ in range(n_repeat): + torch.corrcoef(input) + end.record() + torch.cuda.synchronize() + torch_ms = start.elapsed_time(end) / n_repeat + + return { + "shape": shape, + "dtype": str(dtype), + "ntops_time_ms": ntops_ms, + "torch_time_ms": torch_ms, + "speedup": torch_ms / ntops_ms, + } + + +_SWEEP_SHAPES = [[32, 4096], [128, 8192], [512, 8192]] + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_benchmark_sweep(dtype): + """Sweep sizes. 
Run with: + pytest tests/test_corrcoef.py::test_benchmark_sweep -v -s + """ + header = ( + f"{'shape':>16} {'ntops(ms)':>11} {'torch(ms)':>11} {'speedup':>9}" + ) + print(f"\n{'='*len(header)}") + print(f"corrcoef sweep | dtype={dtype}") + print("=" * len(header)) + print(header) + print("-" * len(header)) + + for shape in _SWEEP_SHAPES: + res = benchmark_corrcoef(shape, dtype=dtype) + print( + f"{str(shape):>16} {res['ntops_time_ms']:>11.4f} " + f"{res['torch_time_ms']:>11.4f} {res['speedup']:>9.2f}" + ) + + print("=" * len(header)) + + +@skip_if_cuda_not_available +def test_benchmark_interface(): + results = benchmark_corrcoef([32, 512], n_warmup=2, n_repeat=5) + assert results["ntops_time_ms"] > 0 diff --git a/tests/test_count_nonzero.py b/tests/test_count_nonzero.py new file mode 100644 index 0000000..24ec3c7 --- /dev/null +++ b/tests/test_count_nonzero.py @@ -0,0 +1,197 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + +# --------------------------------------------------------------------------- +# Correctness tests (compared against torch.count_nonzero). The count is exact, +# so torch.equal; output is always int64. 
+# --------------------------------------------------------------------------- + +_SHAPES = [ + [16], + [1024], + [4097], # not a multiple of the reduction block size + [32, 64], + [8, 7, 5], + [4, 3, 16, 16], +] + + +def _make_input(shape, dtype, device): + """A tensor with a healthy mix of zeros and nonzeros.""" + if dtype == torch.int64: + x = torch.randint(-2, 3, shape, dtype=dtype, device=device) + else: + x = torch.randn(shape, dtype=dtype, device=device) + x = torch.where(x.abs() < 0.5, torch.zeros_like(x), x) + return x + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int64]) +@pytest.mark.parametrize("shape", _SHAPES) +def test_count_nonzero_global(shape, dtype): + device = "cuda" + input = _make_input(shape, dtype, device) + + output = ntops.torch.count_nonzero(input) + expected = torch.count_nonzero(input) + + assert output.shape == expected.shape # scalar () + assert output.dtype == expected.dtype # int64 + assert torch.equal(output, expected) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) +@pytest.mark.parametrize("shape, dim", [ + ([32, 64], 0), + ([32, 64], 1), + ([32, 64], -1), + ([8, 7, 5], 0), + ([8, 7, 5], 1), + ([8, 7, 5], 2), + ([8, 7, 5], (0, 1)), + ([8, 7, 5], (1, 2)), + ([8, 7, 5], (0, 1, 2)), + ([4, 3, 16, 16], (2, 3)), +]) +def test_count_nonzero_dim(shape, dim, dtype): + device = "cuda" + input = _make_input(shape, dtype, device) + + output = ntops.torch.count_nonzero(input, dim=dim) + expected = torch.count_nonzero(input, dim=dim) + + assert output.shape == expected.shape + assert output.dtype == expected.dtype + assert torch.equal(output, expected) + + +@skip_if_cuda_not_available +def test_count_nonzero_all_zeros_and_all_nonzeros(): + device = "cuda" + zeros = torch.zeros(100, device=device) + ones = torch.ones(100, device=device) + + assert ntops.torch.count_nonzero(zeros).item() == 0 + assert 
ntops.torch.count_nonzero(ones).item() == 100 + + +@skip_if_cuda_not_available +def test_count_nonzero_empty(): + device = "cuda" + input = torch.randn(0, device=device) + + output = ntops.torch.count_nonzero(input) + expected = torch.count_nonzero(input) + + assert torch.equal(output, expected) + + +# --------------------------------------------------------------------------- +# Performance benchmark interface +# --------------------------------------------------------------------------- + +def benchmark_count_nonzero( + shape, + dim=None, + dtype=torch.float32, + device="cuda", + n_warmup=10, + n_repeat=100, +): + """Compare ntops.torch.count_nonzero vs torch.count_nonzero. Bandwidth + assumes the input is read once. + + Example + ------- + >>> results = benchmark_count_nonzero([4096, 4096]) + """ + if not torch.cuda.is_available() and device == "cuda": + raise RuntimeError("CUDA not available") + + input = torch.randn(shape, dtype=dtype, device=device) + input = torch.where(input.abs() < 0.5, torch.zeros_like(input), input) + + def run_ntops(): + ntops.torch.count_nonzero(input, dim=dim) + + def run_torch(): + torch.count_nonzero(input, dim=dim) + + for _ in range(n_warmup): + run_ntops() + run_torch() + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for _ in range(n_repeat): + run_ntops() + end.record() + torch.cuda.synchronize() + ntops_ms = start.elapsed_time(end) / n_repeat + + start.record() + for _ in range(n_repeat): + run_torch() + end.record() + torch.cuda.synchronize() + torch_ms = start.elapsed_time(end) / n_repeat + + num_bytes = input.numel() * input.element_size() + + return { + "shape": shape, + "dim": dim, + "dtype": str(dtype), + "ntops_time_ms": ntops_ms, + "torch_time_ms": torch_ms, + "ntops_bandwidth_GBs": num_bytes / (ntops_ms * 1e-3) / 1e9, + "torch_bandwidth_GBs": num_bytes / (torch_ms * 1e-3) / 1e9, + "speedup": torch_ms / ntops_ms, + } + + 
+_SWEEP_SHAPES = [[1024, 1024], [4096, 4096], [8192, 8192]] + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("dim", [None, 0, 1]) +def test_benchmark_sweep(dim, dtype): + """Sweep sizes. Run with: + pytest tests/test_count_nonzero.py::test_benchmark_sweep -v -s + """ + header = ( + f"{'shape':>16} {'dim':>6} " + f"{'ntops(ms)':>11} {'torch(ms)':>11} " + f"{'ntops(GB/s)':>13} {'torch(GB/s)':>13} {'speedup':>9}" + ) + print(f"\n{'='*len(header)}") + print(f"count_nonzero sweep | dim={dim} | dtype={dtype}") + print("=" * len(header)) + print(header) + print("-" * len(header)) + + for shape in _SWEEP_SHAPES: + res = benchmark_count_nonzero(shape, dim=dim, dtype=dtype) + print( + f"{str(shape):>16} {str(dim):>6} " + f"{res['ntops_time_ms']:>11.4f} {res['torch_time_ms']:>11.4f} " + f"{res['ntops_bandwidth_GBs']:>13.1f} {res['torch_bandwidth_GBs']:>13.1f} " + f"{res['speedup']:>9.2f}" + ) + + print("=" * len(header)) + + +@skip_if_cuda_not_available +def test_benchmark_interface(): + results = benchmark_count_nonzero([512, 512], n_warmup=2, n_repeat=5) + assert results["ntops_time_ms"] > 0 diff --git a/tests/test_kl_div.py b/tests/test_kl_div.py new file mode 100644 index 0000000..87984e3 --- /dev/null +++ b/tests/test_kl_div.py @@ -0,0 +1,248 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + +pytestmark = pytest.mark.filterwarnings( + "ignore:.*batchmean.*:UserWarning" +) + +# --------------------------------------------------------------------------- +# Correctness tests (compared against torch.nn.functional.kl_div) +# +# ``kl_div`` expects ``input`` to be log-probabilities and ``target`` to be +# probabilities (or log-probabilities when ``log_target=True``); the helpers +# below build valid inputs via log_softmax / softmax over the last dim. 
+# --------------------------------------------------------------------------- + +_SHAPES = [ + [16], + [1024], + [4097], # not a multiple of the reduction block size + [32, 64], + [8, 7, 5], + [4, 3, 16, 16], + [1], # single element +] + + +def _make_inputs(shape, dtype, device, log_target): + input = torch.randn(shape, dtype=dtype, device=device).log_softmax(dim=-1) + + if log_target: + target = torch.randn(shape, dtype=dtype, device=device).log_softmax(dim=-1) + else: + target = torch.randn(shape, dtype=dtype, device=device).softmax(dim=-1) + + return input, target + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype, rtol, atol", [ + (torch.float32, 1e-3, 1e-3), + (torch.float16, 1e-2, 1e-2), +]) +@pytest.mark.parametrize("log_target", [False, True]) +@pytest.mark.parametrize("reduction", ["none", "sum", "mean", "batchmean"]) +@pytest.mark.parametrize("shape", _SHAPES) +def test_kl_div(shape, reduction, log_target, dtype, rtol, atol): + device = "cuda" + input, target = _make_inputs(shape, dtype, device, log_target) + + output = ntops.torch.kl_div( + input, target, reduction=reduction, log_target=log_target + ) + expected = F.kl_div( + input, target, reduction=reduction, log_target=log_target + ) + + assert output.shape == expected.shape + assert output.dtype == expected.dtype + assert torch.allclose(output, expected, rtol=rtol, atol=atol) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("reduction", ["none", "sum", "mean", "batchmean"]) +def test_kl_div_target_with_zeros(reduction): + """target == 0 must contribute 0 (the 0*log(0)=0 convention), not NaN.""" + device = "cuda" + input = torch.randn(8, 16, device=device).log_softmax(dim=-1) + target = torch.rand(8, 16, device=device) + target[target < 0.3] = 0.0 # inject exact zeros + + output = ntops.torch.kl_div(input, target, reduction=reduction) + expected = F.kl_div(input, target, reduction=reduction) + + assert not torch.isnan(output).any() + assert torch.allclose(output, expected, 
def benchmark_kl_div(
    shape,
    reduction="mean",
    log_target=False,
    dtype=torch.float32,
    device="cuda",
    n_warmup=10,
    n_repeat=100,
):
    """Compare ntops.torch.kl_div vs F.kl_div.

    Returns timing (ms) and effective memory bandwidth (GB/s) for both,
    plus the speedup ratio. Bandwidth assumes both ``input`` and ``target``
    are read once (2x input bytes).

    Args:
        shape: Shape of ``input`` and ``target``.
        reduction: Reduction forwarded to both implementations.
        log_target: Forwarded to both implementations; also selects whether
            the generated target is in log space.
        dtype: Tensor dtype.
        device: Device for the tensors; ``"cuda"`` requires CUDA.
        n_warmup: Untimed warm-up iterations run for both implementations.
        n_repeat: Timed iterations per implementation.

    Raises:
        RuntimeError: If ``device`` is ``"cuda"`` and CUDA is unavailable.

    Example
    -------
    >>> results = benchmark_kl_div([4096, 4096], "mean")
    >>> print(results)
    """
    if not torch.cuda.is_available() and device == "cuda":
        raise RuntimeError("CUDA not available")

    # Use the correctness tests' helper so the target is in log space exactly
    # when ``log_target`` is set (it used to be plain probabilities always).
    input, target = _make_inputs(shape, dtype, device, log_target)

    def run_ntops():
        ntops.torch.kl_div(input, target, reduction=reduction, log_target=log_target)

    def run_torch():
        F.kl_div(input, target, reduction=reduction, log_target=log_target)

    for _ in range(n_warmup):
        run_ntops()
        run_torch()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(n_repeat):
        run_ntops()
    end.record()
    torch.cuda.synchronize()
    ntops_ms = start.elapsed_time(end) / n_repeat

    start.record()
    for _ in range(n_repeat):
        run_torch()
    end.record()
    torch.cuda.synchronize()
    torch_ms = start.elapsed_time(end) / n_repeat

    num_bytes = input.numel() * input.element_size() * 2
    ntops_gbps = num_bytes / (ntops_ms * 1e-3) / 1e9
    torch_gbps = num_bytes / (torch_ms * 1e-3) / 1e9

    return {
        "shape": shape,
        "reduction": reduction,
        "log_target": log_target,
        "dtype": str(dtype),
        "ntops_time_ms": ntops_ms,
        "torch_time_ms": torch_ms,
        "ntops_bandwidth_GBs": ntops_gbps,
        "torch_bandwidth_GBs": torch_gbps,
        "speedup": torch_ms / ntops_ms,
    }
Run with: + pytest tests/test_kl_div.py::test_benchmark_sweep -v -s + """ + header = ( + f"{'shape':>16} {'MB':>8} " + f"{'ntops(ms)':>11} {'torch(ms)':>11} " + f"{'ntops(GB/s)':>13} {'torch(GB/s)':>13} {'speedup':>9}" + ) + print(f"\n{'='*len(header)}") + print(f"kl_div sweep | reduction={reduction} | dtype={dtype}") + print("=" * len(header)) + print(header) + print("-" * len(header)) + + for shape in _SWEEP_SHAPES: + res = benchmark_kl_div(shape, reduction=reduction, dtype=dtype) + mb = ( + torch.empty(shape, dtype=dtype).numel() + * torch.empty(0, dtype=dtype).element_size() + ) / 1e6 + print( + f"{str(shape):>16} {mb:>8.1f} " + f"{res['ntops_time_ms']:>11.4f} {res['torch_time_ms']:>11.4f} " + f"{res['ntops_bandwidth_GBs']:>13.1f} {res['torch_bandwidth_GBs']:>13.1f} " + f"{res['speedup']:>9.2f}" + ) + + print("=" * len(header)) + + +@skip_if_cuda_not_available +def test_benchmark_interface(): + """Smoke-test that the benchmark interface runs without error.""" + results = benchmark_kl_div([512, 512], n_warmup=2, n_repeat=5) + assert results["ntops_time_ms"] > 0 + assert results["ntops_bandwidth_GBs"] > 0 diff --git a/tests/test_narrow.py b/tests/test_narrow.py new file mode 100644 index 0000000..ac6189e --- /dev/null +++ b/tests/test_narrow.py @@ -0,0 +1,181 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + +# --------------------------------------------------------------------------- +# Correctness tests (compared against torch.narrow). narrow is a pure copy of a +# strided slice, so the output must match bit-for-bit -- torch.equal. 
# ---------------------------------------------------------------------------

_CASES = [
    # (shape, dim, start, length)
    ([16], 0, 3, 8),
    ([16], 0, 0, 16),  # full
    ([16], -1, 2, 4),  # negative dim
    ([16], 0, -5, 3),  # negative start
    ([8, 7], 0, 1, 4),
    ([8, 7], 1, 2, 3),
    ([8, 7], -1, -4, 2),
    ([4, 5, 6], 1, 1, 3),
    ([4, 5, 6], 2, 0, 6),
    ([4, 5, 6], 0, 2, 0),  # zero length
    ([3, 4097], 1, 5, 4090),  # not a multiple of the copy block size
]


@skip_if_cuda_not_available
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int64])
@pytest.mark.parametrize("shape, dim, start, length", _CASES)
def test_narrow(shape, dim, start, length, dtype):
    """ntops.torch.narrow must equal torch.narrow in shape, dtype and values."""
    device = "cuda"
    if dtype == torch.int64:
        # randn does not produce integer tensors, so draw integers instead.
        input = torch.randint(-1000, 1000, shape, dtype=dtype, device=device)
    else:
        input = torch.randn(shape, dtype=dtype, device=device)

    output = ntops.torch.narrow(input, dim, start, length)
    expected = torch.narrow(input, dim, start, length)

    assert output.shape == expected.shape
    assert output.dtype == expected.dtype
    assert torch.equal(output, expected)


@skip_if_cuda_not_available
def test_narrow_tensor_start():
    """torch accepts a 0-dim tensor start."""
    device = "cuda"
    input = torch.randn(10, device=device)

    output = ntops.torch.narrow(input, 0, torch.tensor(3), 4)
    expected = torch.narrow(input, 0, 3, 4)

    assert torch.equal(output, expected)


@skip_if_cuda_not_available
def test_narrow_out_of_range_dim():
    """A dim beyond the input's rank must raise IndexError."""
    device = "cuda"
    input = torch.randn(4, 5, device=device)

    with pytest.raises(IndexError):
        ntops.torch.narrow(input, 2, 0, 1)


@skip_if_cuda_not_available
def test_narrow_length_too_large():
    """start + length past the end of the dim must raise RuntimeError."""
    device = "cuda"
    input = torch.randn(8, device=device)

    with pytest.raises(RuntimeError):
        ntops.torch.narrow(input, 0, 5, 10)


# ---------------------------------------------------------------------------
# Performance benchmark interface
# ---------------------------------------------------------------------------

def _cuda_time_ms(fn, n_repeat):
    """Return the mean time in ms of ``n_repeat`` calls to ``fn``, via CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(n_repeat):
        fn()
    end.record()
    # elapsed_time is only valid once both events have completed.
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_repeat


def benchmark_narrow(
    shape,
    dim=0,
    fraction=0.5,
    dtype=torch.float32,
    device="cuda",
    n_warmup=10,
    n_repeat=100,
):
    """Compare ntops.torch.narrow (materializing copy) vs torch.narrow (view)
    + .contiguous(). Bandwidth assumes the slice is read and written once.

    Args:
        shape: Shape of the random input tensor.
        dim: Dimension that is narrowed.
        fraction: Portion of ``shape[dim]`` kept, starting at index 0; the
            resulting length is at least 1.
        dtype: Element type of the input tensor.
        device: Device the input tensor is created on.
        n_warmup: Untimed calls of each implementation before measuring.
        n_repeat: Timed calls per implementation; the mean is reported.

    Returns:
        dict with keys ``shape``, ``dim``, ``dtype``, ``ntops_time_ms``,
        ``torch_time_ms``, ``ntops_bandwidth_GBs``, ``torch_bandwidth_GBs``
        and ``speedup`` (torch time / ntops time).

    Raises:
        RuntimeError: if ``device == "cuda"`` and CUDA is not available.

    Example
    -------
    >>> results = benchmark_narrow([4096, 4096], dim=1)
    """
    if not torch.cuda.is_available() and device == "cuda":
        raise RuntimeError("CUDA not available")

    input = torch.randn(shape, dtype=dtype, device=device)
    length = max(1, int(shape[dim] * fraction))

    def run_ntops():
        ntops.torch.narrow(input, dim, 0, length)

    def run_torch():
        torch.narrow(input, dim, 0, length).contiguous()

    for _ in range(n_warmup):
        run_ntops()
        run_torch()
    torch.cuda.synchronize()

    ntops_ms = _cuda_time_ms(run_ntops, n_repeat)
    torch_ms = _cuda_time_ms(run_torch, n_repeat)

    # Size the slice from the zero-copy torch view instead of launching one
    # more ntops copy just to count elements.
    # NOTE(review): assumes ntops.torch.narrow returns the same shape as
    # torch.narrow, which test_narrow asserts.
    sliced = torch.narrow(input, dim, 0, length)
    num_bytes = sliced.numel() * sliced.element_size() * 2

    return {
        "shape": shape,
        "dim": dim,
        "dtype": str(dtype),
        "ntops_time_ms": ntops_ms,
        "torch_time_ms": torch_ms,
        "ntops_bandwidth_GBs": num_bytes / (ntops_ms * 1e-3) / 1e9,
        "torch_bandwidth_GBs": num_bytes / (torch_ms * 1e-3) / 1e9,
        "speedup": torch_ms / ntops_ms,
    }


_SWEEP_SHAPES = [[1024, 1024], [4096, 4096], [8192, 8192]]


@skip_if_cuda_not_available
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("dim", [0, 1])
def test_benchmark_sweep(dim, dtype):
    """Sweep sizes. Run with:
    pytest tests/test_narrow.py::test_benchmark_sweep -v -s
    """
    header = (
        f"{'shape':>16} {'dim':>4} "
        f"{'ntops(ms)':>11} {'torch(ms)':>11} "
        f"{'ntops(GB/s)':>13} {'torch(GB/s)':>13} {'speedup':>9}"
    )
    print(f"\n{'='*len(header)}")
    print(f"narrow sweep | dim={dim} | dtype={dtype}")
    print("=" * len(header))
    print(header)
    print("-" * len(header))

    for shape in _SWEEP_SHAPES:
        res = benchmark_narrow(shape, dim=dim, dtype=dtype)
        print(
            f"{str(shape):>16} {dim:>4} "
            f"{res['ntops_time_ms']:>11.4f} {res['torch_time_ms']:>11.4f} "
            f"{res['ntops_bandwidth_GBs']:>13.1f} {res['torch_bandwidth_GBs']:>13.1f} "
            f"{res['speedup']:>9.2f}"
        )

    print("=" * len(header))


@skip_if_cuda_not_available
def test_benchmark_interface():
    """Smoke-test that the benchmark interface runs and reports a positive time."""
    results = benchmark_narrow([512, 512], n_warmup=2, n_repeat=5)
    assert results["ntops_time_ms"] > 0