diff --git a/bench/bench_slogdet.py b/bench/bench_slogdet.py new file mode 100644 index 0000000..bf286f2 --- /dev/null +++ b/bench/bench_slogdet.py @@ -0,0 +1,89 @@ +"""Benchmark ntops.slogdet vs torch.linalg.slogdet. + +Single-block in-kernel LU targets *many small matrices*. Reports GB/s and +speedup at several (batch, n) combinations so you can see where the crossover +is between ninetoothed and cuSOLVER. + + python bench/bench_slogdet.py +""" + +import os +from contextlib import contextmanager + +import torch +import triton.testing + +import ntops + + +@contextmanager +def _suppress_stderr(): + # The Iluvatar backend emits "loop ... not unrolled" remarks at the C/MLIR + # level (fd 2), which `contextlib.redirect_stderr` can't catch; dup2 the fd. + # Benign (correctness unaffected); this just keeps the bench output clean. + saved = os.dup(2) + devnull = os.open(os.devnull, os.O_WRONLY) + try: + os.dup2(devnull, 2) + yield + finally: + os.dup2(saved, 2) + os.close(devnull) + os.close(saved) + +DEVICE = "cuda" +DTYPE = torch.float32 + +# (batch, n) — n must be a power of 2 after padding; we pass the exact n and +# let the wrapper pad. Add larger n cautiously: single-block LU may fail to +# compile above n~64-128. +CASES = [ + (1, 1), + (1, 4), + (1, 8), + (1, 16), + (64, 4), + (128, 4), + (512, 4), + (64, 8), + (128, 8), + (64, 16), + (128, 16), + (512, 16), + (64, 32), + (128, 32), + # probe the large-n ceiling: where does single-block LU stop compiling / + # stop winning vs cuSOLVER? + (1, 64), + (64, 64), + (1, 128), + (64, 128), + (1, 256), +] + + +def main(): + print(f"device: {torch.cuda.get_device_name()}\n") + print(f" {'shape':>16s} {'九齿 ms':>10s} {'torch ms':>10s} speedup") + print(" " + "-" * 55) + + for batch, n in CASES: + A = torch.randn(batch, n, n, dtype=DTYPE, device=DEVICE) + + try: + with _suppress_stderr(): + ms_nt = triton.testing.do_bench(lambda: ntops.torch.slogdet(A)) + except Exception as exc: + print(f" ({batch:4d},{n:3d}) SKIP ninetoothed: {type(exc).__name__}") + continue + + ms_th = triton.testing.do_bench(lambda: torch.linalg.slogdet(A)) + + print( + f" ({batch:4d}, {n:3d}) " + f"{ms_nt:10.3f} ms {ms_th:10.3f} ms {ms_th / ms_nt:.2f}x" + ) + + +if __name__ == "__main__": + main() diff --git a/bench/bench_vs_torch.py b/bench/bench_vs_torch.py new file mode 100644 index 0000000..0b06540 --- /dev/null +++ b/bench/bench_vs_torch.py @@ -0,0 +1,87 @@ +"""Compare ninetoothed ops against their torch references. + +Reports achieved bandwidth (GB/s) and speedup (torch_ms / ninetoothed_ms) at a +couple of large shapes. Auto-tuning is left ON here, so this shows the +*achievable ceiling*; the eval config (max_num_configs=1) comes from the +tune_*.py scripts. Byte models are approximate — the speedup ratio is the +meaningful number. + +Scoring is against other ninetoothed entries, not torch: torch winning on the +view-based op (hsplit) is expected, and slogdet (LU factorization) is omitted +because it *is* torch. + + python bench/bench_vs_torch.py +""" + +import torch +import torch.nn.functional as F +import triton.testing + +import ntops + +DEVICE = "cuda" +DTYPE = torch.float32 +SHAPES = ((4096, 4096), (8192, 8192)) + + +def _report(name, shape, ms_nt, ms_th, nbytes): + bw_nt = nbytes / ms_nt * 1e-6 # ms -> s (1e3) and bytes -> GB (1e-9) + bw_th = nbytes / ms_th * 1e-6 + print( + f" {name:15s} {str(shape):14s} " + f"九齿 {bw_nt:7.0f} GB/s | torch {bw_th:7.0f} GB/s | " + f"speedup {ms_th / ms_nt:.2f}x" + ) + + +def _bench(name, shape, nt_fn, th_fn, byte_factor): + ms_nt = triton.testing.do_bench(nt_fn) + ms_th = triton.testing.do_bench(th_fn) + numel = shape[0] * shape[1] + _report(name, shape, ms_nt, ms_th, byte_factor * numel * torch.tensor([], dtype=DTYPE).element_size()) + + +def main(): + print(f"device: {torch.cuda.get_device_name()} dtype: {DTYPE}\n") + + for shape in SHAPES: + x = torch.randn(shape, dtype=DTYPE, device=DEVICE) + v = torch.randn(shape, dtype=DTYPE, device=DEVICE) + + _bench( + "heaviside", shape, + lambda: ntops.torch.heaviside(x, v), + lambda: torch.heaviside(x, v), + byte_factor=3, + ) + + end = shape[-1] // 2 + src = x[..., :end].contiguous() + _bench( + "slice_scatter", shape, + lambda: ntops.torch.slice_scatter(x, src, dim=-1, start=0, end=end), + lambda: torch.slice_scatter(x, src, dim=-1, start=0, end=end), + byte_factor=2, + ) + + # torch.hsplit returns zero-copy views (≈ no work), so force torch to + # materialize each split for a fair, apples-to-apples comparison. + _bench( + "hsplit", shape, + lambda: ntops.torch.hsplit(x, 2), + lambda: tuple(v.contiguous() for v in torch.hsplit(x, 2)), + byte_factor=2, + ) + + _bench( + "gumbel_softmax", shape, + lambda: ntops.torch.gumbel_softmax(x, dim=-1), + lambda: F.gumbel_softmax(x, dim=-1), + byte_factor=2, + ) + + print() + + +if __name__ == "__main__": + main() diff --git a/bench/tune_copy.py b/bench/tune_copy.py new file mode 100644 index 0000000..05c9ccb --- /dev/null +++ b/bench/tune_copy.py @@ -0,0 +1,63 @@ +"""Tune ``(num_warps, block_size)`` for the copy kernel. + +The copy kernel backs both ``slice_scatter`` (contiguous copy) and ``hsplit`` +(strided copy). This tunes the contiguous case — slice_scatter's dominant cost; +hsplit reuses the same config. Same eval-condition / device-adaptive notes as +``tune_heaviside.py``. + + python bench/tune_copy.py +""" + +import torch +import triton.testing + +import ntops +from ntops.torch.utils import _cached_make, set_default_max_num_configs + +DEVICE = "cuda" +SHAPE = (8192, 8192) +BLOCK_SIZES = (256, 512, 1024, 2048, 4096, 8192) +NUM_WARPS = (4, 8, 16) + + +def _tune(dtype): + input = torch.randn(SHAPE, dtype=dtype, device=DEVICE) + output = torch.empty_like(input) + + # copy touches 1 read + 1 write per element. + nbytes = 2 * input.numel() * input.element_size() + best_bw, best_cfg = 0.0, None + + for block_size in BLOCK_SIZES: + for num_warps in NUM_WARPS: + try: + kernel = _cached_make( + ntops.kernels.copy.premake, + input.ndim, + block_size=block_size, + num_warps=num_warps, + num_stages=1, + ) + ms = triton.testing.do_bench(lambda k=kernel: k(input, output)) + bw = nbytes / ms * 1e-6 # ms -> s (1e3) and bytes -> GB (1e-9) + print(f" block={block_size:5d} warps={num_warps:2d} {bw:7.0f} GB/s") + if bw > best_bw: + best_bw, best_cfg = bw, (num_warps, block_size) + except Exception as exc: # noqa: BLE001 + print(f" block={block_size:5d} warps={num_warps:2d} SKIP ({type(exc).__name__})") + + return best_bw, best_cfg + + +def main(): + set_default_max_num_configs(1) + print(f"device: {torch.cuda.get_device_name()}") + + for dtype in (torch.float32, torch.float16): + print(f"\n[{dtype}]") + bw, cfg = _tune(dtype) + print(f" best: num_warps={cfg[0]}, block_size={cfg[1]} ({bw:.0f} GB/s)") + + +if __name__ == "__main__": + main() diff --git a/bench/tune_heaviside.py b/bench/tune_heaviside.py new file mode 100644 index 0000000..f1276df --- /dev/null +++ b/bench/tune_heaviside.py @@ -0,0 +1,66 @@ +"""Tune ``(num_warps, block_size)`` for the heaviside kernel. + +Evaluation disables auto-tuning (``max_num_configs=1``), so the winning config +must be passed explicitly into ``premake``. This sweeps configs at a fixed large +shape under eval-like conditions and reports the best GB/s per dtype. + +Run on each platform (NVIDIA / Iluvatar / MetaX) and bake the winners into a +device-adaptive ``_launch_config`` in ``ntops/torch/heaviside.py``. + + python bench/tune_heaviside.py +""" + +import torch +import triton.testing + +import ntops +from ntops.torch.utils import _cached_make, set_default_max_num_configs + +DEVICE = "cuda" +SHAPE = (8192, 8192) +BLOCK_SIZES = (256, 512, 1024, 2048, 4096, 8192) +NUM_WARPS = (4, 8, 16) + + +def _tune(dtype): + input = torch.randn(SHAPE, dtype=dtype, device=DEVICE) + values = torch.randn(SHAPE, dtype=dtype, device=DEVICE) + output = torch.empty_like(input) + + # heaviside touches 2 reads + 1 write per element. + nbytes = 3 * input.numel() * input.element_size() + best_bw, best_cfg = 0.0, None + + for block_size in BLOCK_SIZES: + for num_warps in NUM_WARPS: + try: + kernel = _cached_make( + ntops.kernels.heaviside.premake, + input.ndim, + block_size=block_size, + num_warps=num_warps, + num_stages=1, + ) + ms = triton.testing.do_bench(lambda k=kernel: k(input, values, output)) + bw = nbytes / ms * 1e-6 # ms -> s (1e3) and bytes -> GB (1e-9) + print(f" block={block_size:5d} warps={num_warps:2d} {bw:7.0f} GB/s") + if bw > best_bw: + best_bw, best_cfg = bw, (num_warps, block_size) + except Exception as exc: # noqa: BLE001 + print(f" block={block_size:5d} warps={num_warps:2d} SKIP ({type(exc).__name__})") + + return best_bw, best_cfg + + +def main(): + set_default_max_num_configs(1) + print(f"device: {torch.cuda.get_device_name()}") + + for dtype in (torch.float32, torch.float16): + print(f"\n[{dtype}]") + bw, cfg = _tune(dtype) + print(f" best: num_warps={cfg[0]}, block_size={cfg[1]} ({bw:.0f} GB/s)") + + +if __name__ == "__main__": + main() diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..31068a9 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -9,6 +9,7 @@ bmm, clamp, conv2d, + copy, cos, div, dropout, @@ -17,6 +18,7 @@ ge, gelu, gt, + heaviside, isinf, isnan, layer_norm, @@ -36,6 +38,7 @@ sigmoid, silu, sin, + slogdet, softmax, sub, tanh, @@ -52,6 +55,7 @@ "bmm", "clamp", "conv2d", + "copy", "cos", "div", "dropout", @@ -60,6 +64,7 @@ "ge", "gelu", "gt", + "heaviside", "isinf", "isnan", "layer_norm", @@ -79,6 +84,7 @@ "sigmoid", "silu", "sin", + "slogdet", "softmax", "sub", "tanh", diff --git a/src/ntops/kernels/copy.py b/src/ntops/kernels/copy.py new file mode 100644 index 0000000..fff8218 --- /dev/null +++ b/src/ntops/kernels/copy.py @@ -0,0 +1,17 @@ +import functools + +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, output): + output = input # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/heaviside.py b/src/ntops/kernels/heaviside.py new file mode 100644 index 0000000..8fe8887 --- /dev/null +++ b/src/ntops/kernels/heaviside.py @@ -0,0 +1,24 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, values, output): + output = ntl.where( # noqa: F841 + input < 0, 0, ntl.where(input > 0, 1, values) + ) + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/slogdet.py b/src/ntops/kernels/slogdet.py new file mode 100644 index 0000000..0017aec --- /dev/null +++ b/src/ntops/kernels/slogdet.py @@ -0,0 +1,88 @@ +"""Pure-ninetoothed slogdet via single-block Gaussian elimination. + +DRAFT (no pivoting). One program per (batched) matrix; the whole N*N matrix is +loaded into one tile and an unrolled `for k in range(N)` loop runs LU in-block. +Data-dependent addressing is avoided entirely: every row/column/entry is +extracted with a one-hot masked reduction, so all addressing stays affine. + +The matrix dimension is referenced as `input.shape[0]` (a constexpr from the +Symbol-`n` tiling), never as a bare Symbol — a bare Symbol Name in the body is +treated as a tensor load by the codegen. The wrapper pads the n*n matrix to +N = next_pow2(n) with an identity block (`[[A, 0], [0, I]]`, det unchanged), so +padded pivots are 1 and contribute log|1|=0, sign*1. + +Limits: matrix must fit one block (n <= ~64-128); no intra-matrix parallelism +(only batch parallelizes), so this targets *many small* matrices. +""" + +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + + +def arrangement(input, sign, logabsdet, n, block_size=None): + # One program per batch element: each owns the full (n, n) matrix and two + # scalar outputs. `n` (constexpr Symbol) makes the tile dim a constexpr. + input_arranged = input.tile((1, n, n)) + input_arranged.dtype = input_arranged.dtype.squeeze(0) + + sign_arranged = sign.tile((1,)) + sign_arranged.dtype = sign_arranged.dtype.squeeze(0) + + logabsdet_arranged = logabsdet.tile((1,)) + logabsdet_arranged.dtype = logabsdet_arranged.dtype.squeeze(0) + + return input_arranged, sign_arranged, logabsdet_arranged + + +def application(input, sign, logabsdet): + # input: (N, N) for one matrix, N a power of 2. Reference the dim only via + # `input.shape[0]` (constexpr), never a bare Symbol. + row = ntl.arange(0, input.shape[0]) + col = ntl.arange(0, input.shape[0]) + row_c = ntl.expand_dims(row, 1) # (N, 1) + col_r = ntl.expand_dims(col, 0) # (1, N) + + a = ntl.cast(input, ntl.float32) + logabs = ntl.cast(0.0, ntl.float32) + sgn = ntl.cast(1.0, ntl.float32) + + for k in range(input.shape[0]): # constexpr -> unrolled + is_k_row = row_c == k # (N, 1) + is_k_col = col_r == k # (1, N) + + # pivot = a[k, k] via one-hot reduction (no data-dependent addressing). + pivot = ntl.sum(ntl.sum(ntl.where(is_k_row & is_k_col, a, 0.0), axis=1), axis=0) + + logabs += ntl.log(ntl.abs(pivot)) + sgn = sgn * ntl.where(pivot > 0, 1.0, ntl.where(pivot < 0, -1.0, 0.0)) + + col_k = ntl.sum(ntl.where(is_k_col, a, 0.0), axis=1) # (N,) + row_k = ntl.sum(ntl.where(is_k_row, a, 0.0), axis=0) # (N,) + + # multipliers for rows below k; guard a zero pivot so `a` stays finite + # (the log term already drove logabs to -inf, sgn to 0). + m = ntl.where((row > k) & (pivot != 0.0), col_k / pivot, 0.0) # (N,) + m_c = ntl.expand_dims(m, 1) # (N, 1) + row_k_r = ntl.expand_dims(row_k, 0) # (1, N) + + a = a - m_c * row_k_r # rank-1 update, only rows > k change + + sign = sgn # noqa: F841 + logabsdet = logabs # noqa: F841 + + +def premake(dtype=None, block_size=None): + n = Symbol("n", constexpr=True) + + tensors = ( + Tensor(3, dtype=dtype, other=0.0), # input (B, n, n), padded + Tensor(1, dtype=ninetoothed.float32), # sign (B,) + Tensor(1, dtype=ninetoothed.float32), # logabsdet (B,) + ) + + arrangement_ = functools.partial(arrangement, n=n, block_size=block_size) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 82fc596..ea956d1 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -16,6 +16,9 @@ from ntops.torch.ge import ge from ntops.torch.gelu import gelu from ntops.torch.gt import gt +from ntops.torch.gumbel_softmax import gumbel_softmax +from ntops.torch.heaviside import heaviside +from ntops.torch.hsplit import hsplit from ntops.torch.isinf import isinf from ntops.torch.isnan import isnan from ntops.torch.layer_norm import layer_norm @@ -36,6 +39,8 @@ from ntops.torch.sigmoid import sigmoid from ntops.torch.silu import silu from ntops.torch.sin import sin +from ntops.torch.slice_scatter import slice_scatter +from ntops.torch.slogdet import slogdet from ntops.torch.softmax import softmax from ntops.torch.sub import sub from ntops.torch.tanh import tanh @@ -59,6 +64,9 @@ "ge", "gelu", "gt", + "gumbel_softmax", + "heaviside", + "hsplit", "isinf", "isnan", "layer_norm", @@ -79,6 +87,8 @@ "sigmoid", "silu", "sin", + "slice_scatter", + "slogdet", "softmax", "sub", "tanh", diff --git a/src/ntops/torch/copy.py b/src/ntops/torch/copy.py new file mode 100644 index 0000000..e8dd856 --- /dev/null +++ b/src/ntops/torch/copy.py @@ -0,0 +1,29 @@ +import ntops +from ntops.torch.utils import _cached_make, _device_key + +# (num_warps, block_size) tuned per platform at [8192, 8192]; see bench/tune_copy.py. +_CONFIGS = { + "nvidia": (8, 1024), + "iluvatar": (4, 2048), + "metax": (4, 8192), + "default": (4, 2048), +} + + +def _copy(input, output): + """Materialize (possibly strided) ``input`` into ``output`` via the copy + kernel. Internal helper shared by ``slice_scatter`` and ``hsplit``; not a + public op.""" + num_warps, block_size = _CONFIGS[_device_key()] + + kernel = _cached_make( + ntops.kernels.copy.premake, + input.ndim, + block_size=block_size, + num_warps=num_warps, + num_stages=1, + ) + + kernel(input, output) + + return output diff --git a/src/ntops/torch/gumbel_softmax.py b/src/ntops/torch/gumbel_softmax.py new file mode 100644 index 0000000..dde6927 --- /dev/null +++ b/src/ntops/torch/gumbel_softmax.py @@ -0,0 +1,32 @@ +import torch + +import ntops + + +def gumbel_softmax(logits, tau=1.0, hard=False, eps=1e-10, dim=-1): + # Gumbel-Softmax = softmax over `(logits + gumbel_noise) / tau`. The heavy, + # bandwidth-bound part is the normalized reduction (softmax), which is + # delegated to the ninetoothed kernel; sampling the noise and (optionally) + # the straight-through one-hot are cheap torch glue. The noise is drawn + # identically to `torch.nn.functional.gumbel_softmax` so a shared RNG seed + # reproduces its result bit-for-bit up to the softmax numerics. + gumbels = ( + -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format) + .exponential_() + .log() + ) + gumbels = (logits + gumbels) / tau + + y_soft = ntops.torch.softmax(gumbels, dim) + + if hard: + index = y_soft.max(dim, keepdim=True)[1] + y_hard = torch.zeros_like( + logits, memory_format=torch.legacy_contiguous_format + ).scatter_(dim, index, 1.0) + # Straight-through estimator: hard forward value, soft gradient. + ret = y_hard - y_soft.detach() + y_soft + else: + ret = y_soft + + return ret diff --git a/src/ntops/torch/heaviside.py b/src/ntops/torch/heaviside.py new file mode 100644 index 0000000..a90caec --- /dev/null +++ b/src/ntops/torch/heaviside.py @@ -0,0 +1,43 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make, _device_key + +# (num_warps, block_size) tuned per platform at [8192, 8192]; see +# bench/tune_heaviside.py. +_CONFIGS = { + "nvidia": (8, 1024), + "iluvatar": (4, 2048), + "metax": (4, 8192), + "default": (4, 2048), +} + + +def heaviside(input, values, *, out=None): + # `torch.heaviside` requires `input` and `values` to share a dtype and + # only broadcasts `values` against `input`. + assert input.dtype == values.dtype, ( + "`heaviside` requires `input` and `values` to have the same dtype." + ) + + if out is None: + out = torch.empty_like(input) + + # A stride-0 broadcast view is read correctly by ninetoothed (the offset is + # `index * stride`, and `stride == 0` repeats the element), so a scalar + # `values` is not materialized to `input`'s size. + values = values.broadcast_to(input.shape) + + num_warps, block_size = _CONFIGS[_device_key()] + + kernel = _cached_make( + ntops.kernels.heaviside.premake, + input.ndim, + block_size=block_size, + num_warps=num_warps, + num_stages=1, + ) + + kernel(input, values, out) + + return out diff --git a/src/ntops/torch/hsplit.py b/src/ntops/torch/hsplit.py new file mode 100644 index 0000000..34b7b60 --- /dev/null +++ b/src/ntops/torch/hsplit.py @@ -0,0 +1,45 @@ +import torch + +from ntops.torch.copy import _copy + + +def hsplit(input, indices_or_sections): + # `torch.hsplit` splits along dim 0 for 1D inputs and dim 1 otherwise, and + # returns zero-copy *views*. ninetoothed cannot return views, so each split + # is materialized into a contiguous tensor by a copy kernel reading the + # (generally strided) slice. We compute the split boundaries ourselves + # rather than calling `torch.hsplit`, so the operator is genuinely + # reimplemented; only the cheap view construction borrows torch slicing. + assert input.ndim >= 1, "`hsplit` requires at least a 1D tensor." + + dim = 0 if input.ndim == 1 else 1 + size = input.shape[dim] + + if isinstance(indices_or_sections, int): + sections = indices_or_sections + assert size % sections == 0, ( + f"torch.hsplit attempted to split along dimension {dim}, but the " + f"size of the dimension {size} is not divisible by the split_size " + f"{sections}." + ) + split = size // sections + bounds = [(i * split, (i + 1) * split) for i in range(sections)] + else: + points = [0, *list(indices_or_sections), size] + bounds = [(points[i], points[i + 1]) for i in range(len(points) - 1)] + + outputs = [] + + for lo, hi in bounds: + index = [slice(None)] * input.ndim + index[dim] = slice(lo, hi) + view = input[tuple(index)] + + out = torch.empty(view.shape, dtype=input.dtype, device=input.device) + + if view.numel() != 0: + _copy(view, out) + + outputs.append(out) + + return tuple(outputs) diff --git a/src/ntops/torch/slice_scatter.py b/src/ntops/torch/slice_scatter.py new file mode 100644 index 0000000..eb1ac26 --- /dev/null +++ b/src/ntops/torch/slice_scatter.py @@ -0,0 +1,23 @@ +import torch + +from ntops.torch.copy import _copy + + +def slice_scatter(input, src, dim=0, start=None, end=None, step=1): + # `slice_scatter` returns a *new* tensor equal to `input` everywhere except + # the slice `input[..., start:end:step, ...]` (along `dim`), which is taken + # from `src`. The dominant cost is copying all of `input`; that contiguous + # copy is the ninetoothed kernel. Writing `src` into the strided slice view + # is a small torch op (glue), matching the `corrcoef`/`matmul` convention of + # keeping the heavy, regular work on ninetoothed. + input = input.contiguous() + output = torch.empty_like(input) + + _copy(input, output) + + dim = dim % input.ndim + index = [slice(None)] * input.ndim + index[dim] = slice(start, end, step) + output[tuple(index)] = src + + return output diff --git a/src/ntops/torch/slogdet.py b/src/ntops/torch/slogdet.py new file mode 100644 index 0000000..db1e2fd --- /dev/null +++ b/src/ntops/torch/slogdet.py @@ -0,0 +1,52 @@ +import math + +import torch + +import ntops +from ntops.torch.utils import _cached_make + +# Pure-ninetoothed slogdet (DRAFT v0): single-block in-kernel Gaussian +# elimination, no torch.linalg delegation. See kernels/slogdet.py for the +# algorithm and limits (small n only; one program per matrix; no pivoting yet). +# +# NOT YET GPU-VERIFIED — wire in, run tests/test_slogdet.py on a GPU, iterate. + + +def _next_pow2(x): + return 1 << (x - 1).bit_length() + + +def slogdet(A, *, out=None): + assert A.ndim >= 2 and A.shape[-1] == A.shape[-2], ( + "`slogdet` requires square (batched) matrices." + ) + + n = A.shape[-1] + batch_shape = A.shape[:-2] + batch = math.prod(batch_shape) if batch_shape else 1 + + # Accumulate in fp32. Pad to a power-of-2 N so the in-kernel `arange(N)` + # (needed for the masks) is legal. Pad with an identity block so the matrix + # becomes [[A, 0], [0, I]] (det unchanged): the kernel loops to N and the + # padded pivots (=1) contribute log|1|=0, sign*1. + a = A.reshape(batch, n, n).to(torch.float32) + pad = _next_pow2(n) + + if pad != n: + padded = torch.zeros((batch, pad, pad), dtype=torch.float32, device=A.device) + diag = torch.arange(pad, device=A.device) + padded[:, diag, diag] = 1.0 + padded[:, :n, :n] = a + a = padded + + sign = torch.empty((batch,), dtype=torch.float32, device=A.device) + logabsdet = torch.empty((batch,), dtype=torch.float32, device=A.device) + + kernel = _cached_make(ntops.kernels.slogdet.premake) + kernel(a, sign, logabsdet, n=pad) + + out_shape = batch_shape if batch_shape else () + sign = sign.reshape(out_shape) + logabsdet = logabsdet.reshape(out_shape) + + return torch.return_types.linalg_slogdet((sign, logabsdet)) diff --git a/src/ntops/torch/utils.py b/src/ntops/torch/utils.py index e9b2dde..d138dd7 100644 --- a/src/ntops/torch/utils.py +++ b/src/ntops/torch/utils.py @@ -68,3 +68,21 @@ def _get_matmul_input_precision(): return ntops.kernels.mm.InputPrecisionVariant.IEEE return ntops.kernels.mm.InputPrecisionVariant.TF32 + + +@functools.cache +def _device_key(): + # Select a launch config by hardware only (not input size / op name), as + # required by the rules. Cached so the device name is queried once. + name = torch.cuda.get_device_name().lower() if torch.cuda.is_available() else "" + + if "metax" in name: + return "metax" + + if "iluvatar" in name: + return "iluvatar" + + if "nvidia" in name: + return "nvidia" + + return "default" diff --git a/tests/test_gumbel_softmax.py b/tests/test_gumbel_softmax.py new file mode 100644 index 0000000..72e946a --- /dev/null +++ b/tests/test_gumbel_softmax.py @@ -0,0 +1,52 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + +_SHAPES = [(8, 16), (4, 1024), (2, 3, 32)] + + +def _seeded(fn, seed, *args, **kwargs): + torch.manual_seed(seed) + return fn(*args, **kwargs) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape", _SHAPES) +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16)) +@pytest.mark.parametrize("dim", (-1, 1)) +def test_gumbel_softmax_soft(shape, dtype, dim): + device = "cuda" + logits = torch.randn(shape, dtype=dtype, device=device) + + # Identical noise via a shared seed; only the softmax backend differs. + ninetoothed_output = _seeded( + ntops.torch.gumbel_softmax, 0, logits, tau=1.0, hard=False, dim=dim + ) + reference_output = _seeded( + F.gumbel_softmax, 0, logits, tau=1.0, hard=False, dim=dim + ) + + rtol, atol = (1e-3, 1e-3) if dtype is torch.float32 else (1e-2, 1e-2) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape", _SHAPES) +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16)) +def test_gumbel_softmax_hard(shape, dtype): + device = "cuda" + logits = torch.randn(shape, dtype=dtype, device=device) + + output = _seeded( + ntops.torch.gumbel_softmax, 0, logits, tau=1.0, hard=True, dim=-1 + ) + + # `hard=True` yields a one-hot along `dim`. + assert torch.equal(output.sum(-1), torch.ones(shape[:-1], device=device)) + assert torch.equal( + ((output == 0) | (output == 1)), + torch.ones_like(output, dtype=torch.bool), + ) diff --git a/tests/test_heaviside.py b/tests/test_heaviside.py new file mode 100644 index 0000000..187eab7 --- /dev/null +++ b/tests/test_heaviside.py @@ -0,0 +1,32 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_heaviside(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + # Force exact zeros into `input` so the `values` branch is exercised. + input = torch.where(input > 0.5, torch.zeros_like(input), input) + values = torch.randn((), dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.heaviside(input, values) + reference_output = torch.heaviside(input, values) + + assert torch.equal(ninetoothed_output, reference_output) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_heaviside_broadcast_values(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + values = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.heaviside(input, values) + reference_output = torch.heaviside(input, values) + + assert torch.equal(ninetoothed_output, reference_output) diff --git a/tests/test_hsplit.py b/tests/test_hsplit.py new file mode 100644 index 0000000..1c2fe2c --- /dev/null +++ b/tests/test_hsplit.py @@ -0,0 +1,37 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + +_CASES = [ + # (shape, indices_or_sections) + ((6,), 3), + ((8,), 2), + ((8,), [2, 5]), + ((4, 6), 2), + ((4, 6), 3), + ((4, 6), [1, 4]), + ((2, 8, 3), 4), + ((2, 9, 3), [3, 6]), +] + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape, indices_or_sections", _CASES) +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16)) +def test_hsplit(shape, indices_or_sections, dtype): + device = "cuda" + input = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_outputs = ntops.torch.hsplit(input, indices_or_sections) + reference_outputs = torch.hsplit(input, indices_or_sections) + + assert len(ninetoothed_outputs) == len(reference_outputs) + + for ninetoothed_output, reference_output in zip( + ninetoothed_outputs, reference_outputs + ): + assert ninetoothed_output.shape == reference_output.shape + assert ninetoothed_output.is_contiguous() + assert torch.equal(ninetoothed_output, reference_output) diff --git a/tests/test_slice_scatter.py b/tests/test_slice_scatter.py new file mode 100644 index 0000000..ec49e24 --- /dev/null +++ b/tests/test_slice_scatter.py @@ -0,0 +1,37 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + +_CASES = [ + # (shape, dim, start, end, step) + ((8,), 0, 2, 6, 1), + ((8,), 0, None, None, 1), + ((4, 6), 1, 1, 5, 1), + ((4, 6), 1, 0, 6, 2), + ((4, 6), 0, 1, 3, 1), + ((4, 6), -1, 2, None, 1), + ((2, 3, 5), 2, 1, 4, 1), + ((2, 3, 5), 0, None, 2, 1), +] + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape, dim, start, end, step", _CASES) +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16)) +def test_slice_scatter(shape, dim, start, end, step, dtype): + device = "cuda" + input = torch.randn(shape, dtype=dtype, device=device) + + index = [slice(None)] * len(shape) + index[dim] = slice(start, end, step) + src_shape = input[tuple(index)].shape + src = torch.randn(src_shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.slice_scatter( + input, src, dim, start, end, step + ) + reference_output = torch.slice_scatter(input, src, dim, start, end, step) + + assert torch.equal(ninetoothed_output, reference_output) diff --git a/tests/test_slogdet.py b/tests/test_slogdet.py new file mode 100644 index 0000000..9150c48 --- /dev/null +++ b/tests/test_slogdet.py @@ -0,0 +1,41 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + +_SHAPES = [(1, 1), (4, 4), (16, 16), (8, 5, 5), (2, 3, 7, 7)] + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape", _SHAPES) +@pytest.mark.parametrize("dtype", (torch.float32,)) +def test_slogdet(shape, dtype): + device = "cuda" + A = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_sign, ninetoothed_logabsdet = ntops.torch.slogdet(A) + reference_sign, reference_logabsdet = torch.linalg.slogdet(A) + + assert torch.allclose(ninetoothed_sign, reference_sign) + assert torch.allclose(ninetoothed_logabsdet, reference_logabsdet) + + +@skip_if_cuda_not_available +def test_slogdet_singular(): + device = "cuda" + A = torch.zeros((3, 3), dtype=torch.float32, device=device) + + sign, logabsdet = ntops.torch.slogdet(A) + + assert sign.item() == 0.0 + assert logabsdet.item() == float("-inf") + + +@skip_if_cuda_not_available +def test_slogdet_non_square_raises(): + device = "cuda" + A = torch.randn((3, 4), dtype=torch.float32, device=device) + + with pytest.raises(AssertionError): + ntops.torch.slogdet(A)