diff --git a/bench/bench_frac_scatter_mlml.py b/bench/bench_frac_scatter_mlml.py new file mode 100644 index 0000000..34661c2 --- /dev/null +++ b/bench/bench_frac_scatter_mlml.py @@ -0,0 +1,93 @@ +"""Benchmark frac, scatter_add, multilabel_margin_loss vs torch. + + python bench/bench_frac_scatter_mlml.py +""" + +import torch +import torch.nn.functional as F +import triton.testing + +import ntops + +DEVICE = "cuda" +DTYPE = torch.float32 + + +def _report(name, shape_str, ms_nt, ms_th, nbytes): + bw_nt = nbytes / ms_nt * 1e-6 + bw_th = nbytes / ms_th * 1e-6 + print( + f" {name:22s} {shape_str:18s} " + f"九齿 {bw_nt:7.0f} GB/s | torch {bw_th:7.0f} GB/s | " + f"speedup {ms_th / ms_nt:.2f}x" + ) + + +def bench_frac(): + print("\n[frac]") + for shape in [(4096 * 4096,), (4096, 4096), (8192, 8192)]: + x = torch.randn(shape, dtype=DTYPE, device=DEVICE) * 5 + nbytes = x.numel() * x.element_size() * 2 # 1 read + 1 write + ms_nt = triton.testing.do_bench(lambda: ntops.torch.frac(x)) + ms_th = triton.testing.do_bench(lambda: torch.frac(x)) + _report("frac", str(shape), ms_nt, ms_th, nbytes) + + +def bench_scatter_add(): + print("\n[scatter_add]") + # (input_shape, dim, k) + cases = [ + ((1024, 256), 1, 128), + ((1024, 256), 1, 256), + ((4096, 512), 1, 256), + ((8192, 256), 0, 4096), + ] + for (shape, dim, k) in cases: + inp = torch.randn(shape, dtype=DTYPE, device=DEVICE) + idx_shape = list(shape); idx_shape[dim] = k + t = shape[dim] + idx = torch.randint(0, t, idx_shape, device=DEVICE) + src = torch.randn(idx_shape, dtype=DTYPE, device=DEVICE) + nbytes = (inp.numel() + src.numel()) * inp.element_size() * 2 + label = f"shape={shape} dim={dim} k={k}" + try: + ms_nt = triton.testing.do_bench( + lambda: ntops.torch.scatter_add(inp, dim, idx, src) + ) + except Exception as exc: # noqa: BLE001 + print(f" {'scatter_add':22s} {label:18s} 九齿 SKIP ({type(exc).__name__})") + continue + ms_th = triton.testing.do_bench( + lambda: torch.scatter_add(inp, dim, idx, src) + ) + 
_report("scatter_add", label, ms_nt, ms_th, nbytes) + + +def bench_mlml(): + print("\n[multilabel_margin_loss]") + cases = [(64, 16), (256, 32), (512, 64), (1024, 32)] + for (n, c) in cases: + x = torch.randn(n, c, dtype=DTYPE, device=DEVICE) + target = torch.full((n, c), -1, dtype=torch.int64, device=DEVICE) + for i in range(n): + num = torch.randint(1, c // 2 + 1, (1,)).item() + target[i, :num] = torch.randperm(c, device=DEVICE)[:num] + nbytes = x.numel() * x.element_size() * 2 + ms_nt = triton.testing.do_bench( + lambda: ntops.torch.multilabel_margin_loss(x, target, reduction="mean") + ) + ms_th = triton.testing.do_bench( + lambda: F.multilabel_margin_loss(x, target, reduction="mean") + ) + _report("mlml", f"N={n} C={c}", ms_nt, ms_th, nbytes) + + +def main(): + print(f"device: {torch.cuda.get_device_name()} dtype: {DTYPE}") + bench_frac() + bench_scatter_add() + bench_mlml() + + +if __name__ == "__main__": + main() diff --git a/bench/bench_fractional_max_pool.py b/bench/bench_fractional_max_pool.py new file mode 100644 index 0000000..1d6aefd --- /dev/null +++ b/bench/bench_fractional_max_pool.py @@ -0,0 +1,59 @@ +"""Benchmark fractional_max_pool2d/3d (fused method Z) vs torch. 
+ + python bench/bench_fractional_max_pool.py +""" + +import torch +import torch.nn.functional as F +import triton.testing + +import ntops + +DEVICE = "cuda" +DTYPE = torch.float32 + + +def _report(name, shape_str, ms_nt, ms_th): + print(f" {name:10s} {shape_str:28s} 九齿 {ms_nt:8.3f} ms | torch {ms_th:8.3f} ms | speedup {ms_th / ms_nt:.2f}x") + + +def main(): + print(f"device: {torch.cuda.get_device_name()} dtype: {DTYPE}\n") + + print("[fractional_max_pool2d]") + cases_2d = [ + (8, 16, 16, 16, 2, 2, 12, 12), + (8, 32, 32, 32, 2, 2, 24, 24), + (16, 64, 56, 56, 3, 3, 40, 40), + (32, 64, 112, 112, 2, 2, 80, 80), + ] + for n, c, h, w, kh, kw, oh, ow in cases_2d: + x = torch.randn(n, c, h, w, dtype=DTYPE, device=DEVICE) + s = torch.rand(n, c, 2, dtype=DTYPE, device=DEVICE) + ms_nt = triton.testing.do_bench( + lambda: ntops.torch.fractional_max_pool2d(x, (kh, kw), output_size=(oh, ow), _random_samples=s) + ) + ms_th = triton.testing.do_bench( + lambda: F.fractional_max_pool2d(x, (kh, kw), output_size=(oh, ow), _random_samples=s) + ) + _report("fmp2d", f"({n},{c},{h},{w})->({oh},{ow})", ms_nt, ms_th) + + print("\n[fractional_max_pool3d]") + cases_3d = [ + (4, 16, 16, 16, 16, 2, 2, 2, 12, 12, 12), + (8, 32, 16, 32, 32, 2, 2, 2, 12, 24, 24), + ] + for n, c, d, h, w, kd, kh, kw, od, oh, ow in cases_3d: + x = torch.randn(n, c, d, h, w, dtype=DTYPE, device=DEVICE) + s = torch.rand(n, c, 3, dtype=DTYPE, device=DEVICE) + ms_nt = triton.testing.do_bench( + lambda: ntops.torch.fractional_max_pool3d(x, (kd, kh, kw), output_size=(od, oh, ow), _random_samples=s) + ) + ms_th = triton.testing.do_bench( + lambda: F.fractional_max_pool3d(x, (kd, kh, kw), output_size=(od, oh, ow), _random_samples=s) + ) + _report("fmp3d", f"({n},{c},{d},{h},{w})->({od},{oh},{ow})", ms_nt, ms_th) + + +if __name__ == "__main__": + main() diff --git a/bench/tune_frac.py b/bench/tune_frac.py new file mode 100644 index 0000000..bba6b4a --- /dev/null +++ b/bench/tune_frac.py @@ -0,0 +1,65 @@ +"""Tune 
(num_warps, block_size) for the frac kernel. + +Evaluation disables auto-tuning (max_num_configs=1), so the winning config +must be passed explicitly into premake. This sweeps at a fixed large shape +under eval-like conditions and reports the best GB/s per dtype. + +Run on each platform (NVIDIA / Iluvatar / MetaX) and bake the winners into the +device-keyed _CONFIGS table in ntops/torch/frac.py. + + python bench/tune_frac.py +""" + +import torch +import triton.testing + +import ntops +from ntops.torch.utils import _cached_make, set_default_max_num_configs + +DEVICE = "cuda" +SHAPE = (8192, 8192) +BLOCK_SIZES = (256, 512, 1024, 2048, 4096, 8192) +NUM_WARPS = (4, 8, 16) + + +def _tune(dtype): + input = torch.randn(SHAPE, dtype=dtype, device=DEVICE) * 5 + output = torch.empty_like(input) + + # frac reads 1 tensor and writes 1 → 2 element-sized accesses per element. + nbytes = 2 * input.numel() * input.element_size() + best_bw, best_cfg = 0.0, None + + for block_size in BLOCK_SIZES: + for num_warps in NUM_WARPS: + try: + kernel = _cached_make( + ntops.kernels.frac.premake, + input.ndim, + block_size=block_size, + num_warps=num_warps, + num_stages=1, + ) + ms = triton.testing.do_bench(lambda k=kernel: k(input, output)) + bw = nbytes / ms * 1e-6 + print(f" block={block_size:5d} warps={num_warps:2d} {bw:7.0f} GB/s") + if bw > best_bw: + best_bw, best_cfg = bw, (num_warps, block_size) + except Exception as exc: # noqa: BLE001 + print(f" block={block_size:5d} warps={num_warps:2d} SKIP ({type(exc).__name__})") + + return best_bw, best_cfg + + +def main(): + set_default_max_num_configs(1) + print(f"device: {torch.cuda.get_device_name()}") + + for dtype in (torch.float32, torch.float16): + print(f"\n[{dtype}]") + bw, cfg = _tune(dtype) + print(f" best: num_warps={cfg[0]}, block_size={cfg[1]} ({bw:.0f} GB/s)" if cfg else " best: none (every config was skipped)") + + +if __name__ == "__main__": + main() diff --git a/bench/tune_fractional_max_pool.py b/bench/tune_fractional_max_pool.py new file mode 100644 index 0000000..93048f9 --- /dev/null +++ 
b/bench/tune_fractional_max_pool.py @@ -0,0 +1,75 @@ +"""Tune (num_warps, block_size) for the fused fractional_max_pool2d kernel. + +The data-dependent gather is launch/occupancy-sensitive: block_size=1024 gives +too few programs for small/medium M. Sweep to find the config that maximizes +occupancy across sizes. + + python bench/tune_fractional_max_pool.py +""" + +import torch +import triton.testing + +import ntops +from ntops.torch.utils import _cached_make, set_default_max_num_configs + +DEVICE = "cuda" +DTYPE = torch.float32 +BLOCK_SIZES = (64, 128, 256, 512, 1024, 2048) +NUM_WARPS = (1, 2, 4, 8) + +CASES = [ + (8, 16, 16, 16, 2, 2, 12, 12), + (16, 64, 56, 56, 3, 3, 40, 40), + (32, 64, 112, 112, 2, 2, 80, 80), +] + + +def _alpha(in_size, out_size, kernel_size): + # Interval step (in - pool) / (out - 1), as computed by ntops.torch.fractional_max_pool2d; + # 0.0 when out == 1 because the kernel forces the last output to start at in - pool. + return float(in_size - kernel_size) / (out_size - 1) if out_size > 1 else 0.0 + + +def main(): + set_default_max_num_configs(1) + print(f"device: {torch.cuda.get_device_name()}") + + for n, c, h, w, kh, kw, oh, ow in CASES: + x = torch.randn(n, c, h, w, dtype=DTYPE, device=DEVICE).reshape(-1) + xr = x.reshape(n, c, h, w) + s = torch.rand(n, c, 2, dtype=DTYPE, device=DEVICE) + sf = s.reshape(-1) # the kernel indexes samples as a flat (N*C*2,) tensor + alpha_h = _alpha(h, oh, kh) + alpha_w = _alpha(w, ow, kw) + m = n * c * oh * ow + out = torch.empty((m,), dtype=DTYPE, device=DEVICE) + + ms_th = triton.testing.do_bench( + lambda: torch.nn.functional.fractional_max_pool2d( + xr, (kh, kw), output_size=(oh, ow), _random_samples=s + ) + ) + print(f"\n({n},{c},{h},{w})->({oh},{ow}) M={m} (torch {ms_th:.3f} ms)") + best = (1e9, None) + for block_size in BLOCK_SIZES: + for num_warps in NUM_WARPS: + try: + kernel = _cached_make( + ntops.kernels.fractional_max_pool2d.premake, + block_size=block_size, + num_warps=num_warps, + 
num_stages=1, + max_num_configs=1, + ) + ms = triton.testing.do_bench(lambda k=kernel: k(out, x, sf, alpha_h, alpha_w, h, w, oh, ow, kh, kw, m)) + if ms < best[0]: + best = (ms, (num_warps, block_size)) + print(f" block={block_size:5d} warps={num_warps:2d} {ms:7.3f} ms ({ms_th/ms:.2f}x)") + except Exception as exc: # noqa: BLE001 + print(f" block={block_size:5d} warps={num_warps:2d} SKIP ({type(exc).__name__})") + print(f" best: num_warps={best[1][0]}, block_size={best[1][1]} ({ms_th/best[0]:.2f}x torch)" if best[1] else " best: none (every config was skipped)") + + +if __name__ == "__main__": + main() diff --git a/bench/tune_scatter_add.py b/bench/tune_scatter_add.py new file mode 100644 index 0000000..f96a322 --- /dev/null +++ b/bench/tune_scatter_add.py @@ -0,0 +1,92 @@ +"""Tune scatter_add's atomic-scatter config. + +The atomic scatter is latency-bound (constant ~6 GB/s regardless of size), which +means too few concurrent atomics. The fix is many concurrent threads each doing +~1 atomic: small block_size + enough warps + many programs. This sweeps to find +the config that hides atomic latency. + + python bench/tune_scatter_add.py +""" + +import torch +import triton.testing + +import ntops +from ntops.torch.utils import _cached_make, set_default_max_num_configs + +DEVICE = "cuda" +DTYPE = torch.float32 + +# (input_shape, dim, k) — a few representative scatter shapes. 
+SHAPES = [ + ((1024, 256), 1, 256), + ((4096, 512), 1, 256), +] +BLOCK_SIZES = (64, 128, 256, 512, 1024, 2048) +NUM_WARPS = (1, 2, 4, 8, 16) + + +def _next_pow2(x): + return 1 << (x - 1).bit_length() + + +def _run(shape, dim, k, block_size, num_warps): + inp = torch.randn(shape, dtype=DTYPE, device=DEVICE) + idx_shape = list(shape) + idx_shape[dim] = k + t = shape[dim] + idx = torch.randint(0, t, idx_shape, device=DEVICE).to(torch.int64) + src = torch.randn(idx_shape, dtype=DTYPE, device=DEVICE) + + inp_p = inp.movedim(dim, -1).contiguous() + t_size = inp_p.shape[-1] + rows = inp_p.numel() // t_size + output = inp_p.reshape(rows, t_size).clone() + idx_p = idx.movedim(dim, -1).contiguous().reshape(rows, -1) + k_size = idx_p.shape[-1] + src_p = src.movedim(dim, -1).contiguous().reshape(rows, k_size) + + kernel = _cached_make( + ntops.kernels.scatter_add.premake, + block_size=block_size, + num_warps=num_warps, + num_stages=1, + max_num_configs=1, + ) + return triton.testing.do_bench( + lambda: kernel(src_p, idx_p, output, t_size, k_size, rows * k_size) # kernel order: src, index, output, t, k, n + ) + + +def main(): + set_default_max_num_configs(1) + print(f"device: {torch.cuda.get_device_name()}") + + for shape, dim, k in SHAPES: + inp = torch.randn(shape, dtype=DTYPE, device=DEVICE) + idx_shape = list(shape) + idx_shape[dim] = k + idx = torch.randint(0, shape[dim], idx_shape, device=DEVICE) + src = torch.randn(idx_shape, dtype=DTYPE, device=DEVICE) + ms_th = triton.testing.do_bench(lambda: torch.scatter_add(inp, dim, idx, src)) + + print(f"\nshape={shape} dim={dim} k={k} (torch {ms_th:.3f} ms)") + best = (1e9, None) + for block_size in BLOCK_SIZES: + for num_warps in NUM_WARPS: + try: + ms = _run(shape, dim, k, block_size, num_warps) + print( + f" block={block_size:5d} warps={num_warps:2d} " + f"{ms:8.3f} ms ({ms_th / ms:.2f}x torch)" + ) + if ms < best[0]: + best = (ms, (num_warps, block_size)) + except Exception as exc: # noqa: BLE001 + print(f" block={block_size:5d} warps={num_warps:2d} SKIP 
({type(exc).__name__})") + print(f" best: num_warps={best[1][0]}, block_size={best[1][1]} " + f"({ms_th / best[0]:.2f}x torch)" if best[1] else " best: none (every config was skipped)") + + +if __name__ == "__main__": + main() diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..0ba6ee3 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -14,6 +14,9 @@ dropout, eq, exp, + frac, + fractional_max_pool2d, + fractional_max_pool3d, ge, gelu, gt, @@ -25,6 +28,7 @@ max_pool2d, mm, mul, + multilabel_margin_loss, ne, neg, pow, @@ -33,6 +37,7 @@ rotary_position_embedding, rsqrt, scaled_dot_product_attention, + scatter_add, sigmoid, silu, sin, @@ -57,6 +62,9 @@ "dropout", "eq", "exp", + "frac", + "fractional_max_pool2d", + "fractional_max_pool3d", "ge", "gelu", "gt", @@ -68,6 +76,7 @@ "max_pool2d", "mm", "mul", + "multilabel_margin_loss", "ne", "neg", "pow", @@ -76,6 +85,7 @@ "rotary_position_embedding", "rsqrt", "scaled_dot_product_attention", + "scatter_add", "sigmoid", "silu", "sin", diff --git a/src/ntops/kernels/frac.py b/src/ntops/kernels/frac.py new file mode 100644 index 0000000..e8dded9 --- /dev/null +++ b/src/ntops/kernels/frac.py @@ -0,0 +1,24 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, output): + # frac(x) = x - trunc(x); trunc rounds toward zero. floor only accepts + # fp32/fp64, so compute in fp32 and cast back (ceil(x) = -floor(-x) avoids + # the int-range issues of an int cast). element_wise is single-level so + # `output.dtype` is the element type. 
+ x = ntl.cast(input, ntl.float32) + trunc = ntl.where(x >= 0, ntl.floor(x), -ntl.floor(-x)) + output = ntl.cast(x - trunc, output.dtype) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/fractional_max_pool2d.py b/src/ntops/kernels/fractional_max_pool2d.py new file mode 100644 index 0000000..b3a7366 --- /dev/null +++ b/src/ntops/kernels/fractional_max_pool2d.py @@ -0,0 +1,94 @@ +"""fractional_max_pool2d — fully fused (method Z). + +Everything is in one kernel: from the raw per-(n,c) random samples, each program +computes the window start (PyTorch's interval formula), the input base offset, +then gathers the kH*kW window via data-dependent loads and maxes. The host does +no per-element work (no interval tensors, no base_offset materialization), so the +whole op is a single launch — competitive with torch's fused kernel. + +Earlier methods: B (torch gather + ninetoothed max) was ~0.01x (materialized +windows); C (host base_offset + in-kernel gather) had a fast kernel but ~16 host +torch ops dominated (~0.37ms fixed). Z removes the host glue entirely. 
+""" + +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def arrangement(output, input, samples, alpha_h, alpha_w, h, w, oh, ow, kh, kw, m, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + output_arranged = output.flatten().tile((block_size,)) + + return output_arranged, input, samples, alpha_h, alpha_w, h, w, oh, ow, kh, kw, m + + +def application(output, input, samples, alpha_h, alpha_w, h, w, oh, ow, kh, kw, m): + p = output.offsets() # flat output index per element + mask = p < m + + ohw = oh * ow + nc = p // ohw + rem = p - nc * ohw + oh_i = rem // ow + ow_i = rem - oh_i * ow + + # samples is (N*C, 2): [..., 1] -> H sample, [..., 0] -> W sample. + sample_h = ntl.cast( + ntl.load(samples.data_ptr() + nc * 2 + 1, mask=mask, other=0.0), ntl.float32 + ) + sample_w = ntl.cast( + ntl.load(samples.data_ptr() + nc * 2 + 0, mask=mask, other=0.0), ntl.float32 + ) + + oh_f = ntl.cast(oh_i, ntl.float32) + ow_f = ntl.cast(ow_i, ntl.float32) + + # start = floor((idx + u)*alpha) - floor(u*alpha); last index forced to in-pool. 
+ sh = ntl.cast( + ntl.floor((oh_f + sample_h) * alpha_h) - ntl.floor(sample_h * alpha_h), ntl.int32 + ) + sw = ntl.cast( + ntl.floor((ow_f + sample_w) * alpha_w) - ntl.floor(sample_w * alpha_w), ntl.int32 + ) + sh = ntl.where(oh_i == oh - 1, h - kh, sh) + sw = ntl.where(ow_i == ow - 1, w - kw, sw) + + base = (nc * h + sh) * w + sw + + acc = ntl.load(input.data_ptr() + base, mask=mask, other=float("-inf")) + for dh in range(kh): + for dw in range(kw): + acc = ntl.maximum( + acc, + ntl.load( + input.data_ptr() + base + (dh * w + dw), mask=mask, other=float("-inf") + ), + ) + + output = acc # noqa: F841 + + +def premake(block_size=None): + tensors = ( + Tensor(1), # output (M,) + Tensor(1, shape_options={"constexpr": True}), # input flat (source) + Tensor(1, shape_options={"constexpr": True}), # samples flat (N*C*2) (source) + Tensor(0, constexpr=True), # alpha_h + Tensor(0, constexpr=True), # alpha_w + Tensor(0, dtype=int, constexpr=True), # h + Tensor(0, dtype=int, constexpr=True), # w + Tensor(0, dtype=int, constexpr=True), # oh + Tensor(0, dtype=int, constexpr=True), # ow + Tensor(0, dtype=int, constexpr=True), # kh + Tensor(0, dtype=int, constexpr=True), # kw + Tensor(0, dtype=int, constexpr=True), # m + ) + + arrangement_ = functools.partial(arrangement, block_size=block_size) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/fractional_max_pool3d.py b/src/ntops/kernels/fractional_max_pool3d.py new file mode 100644 index 0000000..57d7dbe --- /dev/null +++ b/src/ntops/kernels/fractional_max_pool3d.py @@ -0,0 +1,97 @@ +"""fractional_max_pool3d — fully fused (method Z). 
3D analog of the 2d kernel.""" + +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def arrangement( + output, input, samples, alpha_d, alpha_h, alpha_w, + d, h, w, od, oh, ow, kd, kh, kw, m, block_size=None, +): + if block_size is None: + block_size = ninetoothed.block_size() + + output_arranged = output.flatten().tile((block_size,)) + + return ( + output_arranged, input, samples, alpha_d, alpha_h, alpha_w, + d, h, w, od, oh, ow, kd, kh, kw, m, + ) + + +def application( + output, input, samples, alpha_d, alpha_h, alpha_w, + d, h, w, od, oh, ow, kd, kh, kw, m, +): + p = output.offsets() + mask = p < m + + ohw = oh * ow + odhw = od * ohw + nc = p // odhw + rem = p - nc * odhw + od_i = rem // ohw + rem2 = rem - od_i * ohw + oh_i = rem2 // ow + ow_i = rem2 - oh_i * ow + + # samples is (N*C, 3): [0] -> D, [1] -> H, [2] -> W. + sample_d = ntl.cast(ntl.load(samples.data_ptr() + nc * 3 + 0, mask=mask, other=0.0), ntl.float32) + sample_h = ntl.cast(ntl.load(samples.data_ptr() + nc * 3 + 1, mask=mask, other=0.0), ntl.float32) + sample_w = ntl.cast(ntl.load(samples.data_ptr() + nc * 3 + 2, mask=mask, other=0.0), ntl.float32) + + od_f = ntl.cast(od_i, ntl.float32) + oh_f = ntl.cast(oh_i, ntl.float32) + ow_f = ntl.cast(ow_i, ntl.float32) + + sd = ntl.cast(ntl.floor((od_f + sample_d) * alpha_d) - ntl.floor(sample_d * alpha_d), ntl.int32) + sh = ntl.cast(ntl.floor((oh_f + sample_h) * alpha_h) - ntl.floor(sample_h * alpha_h), ntl.int32) + sw = ntl.cast(ntl.floor((ow_f + sample_w) * alpha_w) - ntl.floor(sample_w * alpha_w), ntl.int32) + sd = ntl.where(od_i == od - 1, d - kd, sd) + sh = ntl.where(oh_i == oh - 1, h - kh, sh) + sw = ntl.where(ow_i == ow - 1, w - kw, sw) + + base = ((nc * d + sd) * h + sh) * w + sw + + acc = ntl.load(input.data_ptr() + base, mask=mask, other=float("-inf")) + for dd in range(kd): + for dh in range(kh): + for dw in range(kw): + acc = ntl.maximum( + acc, + ntl.load( + input.data_ptr() + base + 
(dd * h * w + dh * w + dw), + mask=mask, + other=float("-inf"), + ), + ) + + output = acc # noqa: F841 + + +def premake(block_size=None): + tensors = ( + Tensor(1), # output (M,) + Tensor(1, shape_options={"constexpr": True}), # input flat (source) + Tensor(1, shape_options={"constexpr": True}), # samples flat (N*C*3) (source) + Tensor(0, constexpr=True), # alpha_d + Tensor(0, constexpr=True), # alpha_h + Tensor(0, constexpr=True), # alpha_w + Tensor(0, dtype=int, constexpr=True), # d + Tensor(0, dtype=int, constexpr=True), # h + Tensor(0, dtype=int, constexpr=True), # w + Tensor(0, dtype=int, constexpr=True), # od + Tensor(0, dtype=int, constexpr=True), # oh + Tensor(0, dtype=int, constexpr=True), # ow + Tensor(0, dtype=int, constexpr=True), # kd + Tensor(0, dtype=int, constexpr=True), # kh + Tensor(0, dtype=int, constexpr=True), # kw + Tensor(0, dtype=int, constexpr=True), # m + ) + + arrangement_ = functools.partial(arrangement, block_size=block_size) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/multilabel_margin_loss.py b/src/ntops/kernels/multilabel_margin_loss.py new file mode 100644 index 0000000..4fd9917 --- /dev/null +++ b/src/ntops/kernels/multilabel_margin_loss.py @@ -0,0 +1,76 @@ +"""multilabel_margin_loss via one-hot gather + masked reduction. + +DRAFT — needs GPU iteration. One program per sample (row). Per sample: + + raw = sum_{j in valid targets} sum_{i not a target} max(0, 1 - (x[t_j] - x[i])) + +where valid targets are the contiguous non-negative prefix of `target`. The +gather x[target[j]] and the "is i a target" test are done with one-hot masked +reductions (affine addressing). The wrapper divides by C and reduces over N. + +Padding to C = next_pow2: x -> -inf (so padded classes contribute max(0, -inf) += 0 and are never targets), target -> -1 (excluded from the valid prefix). 
+""" + +import functools + +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + + +def arrangement(x, target, output, c, block_size=None): + x_arranged = x.tile((1, c)) + x_arranged.dtype = x_arranged.dtype.squeeze(0) + + target_arranged = target.tile((1, c)) + target_arranged.dtype = target_arranged.dtype.squeeze(0) + + output_arranged = output.tile((1,)) + output_arranged.dtype = output_arranged.dtype.squeeze(0) + + return x_arranged, target_arranged, output_arranged + + +def application(x, target, output): + # x, target: (C,) for one sample (C a power of 2; padded x=-inf, target=-1). + cls = ntl.arange(0, x.shape[0]) # class indices + x_f = ntl.cast(x, ntl.float32) + + # valid target positions = contiguous non-negative prefix (before the first negative entry). + neg_pos = ntl.where(target < 0, cls, x.shape[0]) + first_neg = ntl.min(neg_pos) + valid_j = cls < first_neg # (Cj,) + + tgt_col = ntl.expand_dims(target, 1) # (Cj, 1) + cls_row = ntl.expand_dims(cls, 0) # (1, Ci) + onehot = tgt_col == cls_row # (Cj, Ci): target[j] == i + x_row = ntl.expand_dims(x_f, 0) # (1, Ci) + + # xt[j] = x[target[j]] + xt = ntl.sum(ntl.where(onehot, x_row, 0.0), axis=1) # (Cj,) + + valid_col = ntl.expand_dims(valid_j, 1) # (Cj, 1) + # in_T_count[i] = #valid j with target[j] == i -> i is a target iff > 0 + in_t_count = ntl.sum(ntl.where(valid_col & onehot, 1.0, 0.0), axis=0) # (Ci,) + not_in_t = ntl.expand_dims(in_t_count == 0, 0) # (1, Ci) + + xt_col = ntl.expand_dims(xt, 1) # (Cj, 1) + margin = 1.0 - (xt_col - x_row) # (Cj, Ci) + term = ntl.where(margin > 0, margin, 0.0) # relu + + mask = valid_col & not_in_t # (Cj, Ci) + output = ntl.sum(ntl.sum(ntl.where(mask, term, 0.0), axis=1), axis=0) # noqa: F841 + + +def premake(block_size=None): + c = Symbol("c", constexpr=True) + + tensors = ( + Tensor(2, other=float("-inf")), # x (N, C), pad -inf + Tensor(2, other=-1), # target (N, C), pad -1 + Tensor(1), # output (N,) raw per-sample loss (pre /C) + ) + + arrangement_ = 
functools.partial(arrangement, c=c, block_size=block_size) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/scatter_add.py b/src/ntops/kernels/scatter_add.py new file mode 100644 index 0000000..a76c40c --- /dev/null +++ b/src/ntops/kernels/scatter_add.py @@ -0,0 +1,58 @@ +"""scatter_add via data-dependent atomic scatter (O(N), like torch). + +DRAFT — needs GPU iteration. The one-hot version was O(T*K) and lost / OOM'd, so +this uses `ntl.atomic_add` into data-dependent output positions (histc's +primitive, PR #60), extended to a *data-dependent* target out_ptr + (r*T + idx). + +The wrapper moves `dim` to the last axis and flattens to output (R, T) (a clone +of input, the accumulation base) and src/index (R, K). One program owns a block +of flat source elements p; row r = p // K; it atomic-adds src[p] into +output[r, index[p]]. + +GRID NOTE: ninetoothed sets grid = prod(first arg's shape). `src` is first so the +grid = number of source blocks (O(N)). `output` is passed through un-arranged +(data_ptr() needs a source tensor); if it were first the grid would be R*T — a +huge oversized grid that kills perf and faults out of bounds. 
+""" + +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def arrangement(src, index, output, t, k, n, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + src_arranged = src.flatten().tile((block_size,)) + index_arranged = index.flatten().tile((block_size,)) + + return src_arranged, index_arranged, output, t, k, n + + +def application(src, index, output, t, k, n): + p = src.offsets() # flat source positions of this block + row = p // k + target = row * t + index # data-dependent flat offset into output + mask = p < n # exclude tile padding past the real source length + + ntl.atomic_add(output.data_ptr() + target, src, mask=mask) + + +def premake(block_size=None): + tensors = ( + Tensor(2, other=0), # src (R, K) + Tensor(2, other=0), # index (R, K) + # output (R, T): clone of input, accumulated in place. constexpr shape so + # the autotuner sees concrete sizes (it can't bound an un-arranged source). 
+ Tensor(2, shape_options={"constexpr": True}), + Tensor(0, dtype=int, constexpr=True), # t = T (output row length) + Tensor(0, dtype=int, constexpr=True), # k = K (src/index row length) + Tensor(0, dtype=int, constexpr=True), # n = R*K (real source length) + ) + + arrangement_ = functools.partial(arrangement, block_size=block_size) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 82fc596..af722fd 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -13,6 +13,9 @@ from ntops.torch.dropout import dropout from ntops.torch.eq import eq from ntops.torch.exp import exp +from ntops.torch.frac import frac +from ntops.torch.fractional_max_pool2d import fractional_max_pool2d +from ntops.torch.fractional_max_pool3d import fractional_max_pool3d from ntops.torch.ge import ge from ntops.torch.gelu import gelu from ntops.torch.gt import gt @@ -25,6 +28,7 @@ from ntops.torch.max_pool2d import max_pool2d from ntops.torch.mm import mm from ntops.torch.mul import mul +from ntops.torch.multilabel_margin_loss import multilabel_margin_loss from ntops.torch.ne import ne from ntops.torch.neg import neg from ntops.torch.pow import pow @@ -33,6 +37,7 @@ from ntops.torch.rotary_position_embedding import rotary_position_embedding from ntops.torch.rsqrt import rsqrt from ntops.torch.scaled_dot_product_attention import scaled_dot_product_attention +from ntops.torch.scatter_add import scatter_add from ntops.torch.sigmoid import sigmoid from ntops.torch.silu import silu from ntops.torch.sin import sin @@ -56,6 +61,9 @@ "dropout", "eq", "exp", + "frac", + "fractional_max_pool2d", + "fractional_max_pool3d", "ge", "gelu", "gt", @@ -68,6 +76,7 @@ "max_pool2d", "mm", "mul", + "multilabel_margin_loss", "ne", "neg", "pow", @@ -76,6 +85,7 @@ "rotary_position_embedding", "rsqrt", "scaled_dot_product_attention", + "scatter_add", "sigmoid", "silu", "sin", diff --git a/src/ntops/torch/frac.py 
b/src/ntops/torch/frac.py new file mode 100644 index 0000000..c40cc97 --- /dev/null +++ b/src/ntops/torch/frac.py @@ -0,0 +1,33 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make, _device_key + +# (num_warps, block_size) tuned per platform at [8192, 8192]; see +# bench/tune_frac.py. Both fp32 and fp16 are within ~4% of their respective +# peaks with these configs. +_CONFIGS = { + "nvidia": (8, 512), + "iluvatar": (4, 2048), + "metax": (4, 1024), + "default": (4, 2048), +} + + +def frac(input, *, out=None): + if out is None: + out = torch.empty_like(input) + + num_warps, block_size = _CONFIGS[_device_key()] + + kernel = _cached_make( + ntops.kernels.frac.premake, + input.ndim, + block_size=block_size, + num_warps=num_warps, + num_stages=1, + ) + + kernel(input, out) + + return out diff --git a/src/ntops/torch/fractional_max_pool2d.py b/src/ntops/torch/fractional_max_pool2d.py new file mode 100644 index 0000000..131c8ab --- /dev/null +++ b/src/ntops/torch/fractional_max_pool2d.py @@ -0,0 +1,78 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make, _device_key + +# (num_warps, block_size) tuned per platform; see bench/tune_fractional_max_pool.py. +_CONFIGS = { + "nvidia": (4, 256), + "iluvatar": (8, 512), + "metax": (4, 256), + "default": (4, 256), +} + + +def fractional_max_pool2d( + input, + kernel_size, + output_size=None, + output_ratio=None, + return_indices=False, + _random_samples=None, +): + assert not return_indices, "`return_indices` is not supported yet." + + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + kernel_h, kernel_w = kernel_size + + input = input.contiguous() + n, c, h, w = input.shape + + if output_size is not None: + if isinstance(output_size, int): + output_size = (output_size, output_size) + out_h, out_w = output_size + else: + assert output_ratio is not None, "Either output_size or output_ratio is required." 
+ if isinstance(output_ratio, (int, float)): + output_ratio = (output_ratio, output_ratio) + out_h = int(h * output_ratio[0]) + out_w = int(w * output_ratio[1]) + + if _random_samples is None: + _random_samples = torch.rand(n, c, 2, dtype=input.dtype, device=input.device) + + # alpha = (in - pool) / (out - 1), guarded for out == 1 (the only output is + # forced to in-pool in-kernel, so the value is unused there). float32 matches + # torch's scalar_t interval arithmetic. + alpha_h = float(h - kernel_h) / (out_h - 1) if out_h > 1 else 0.0 + alpha_w = float(w - kernel_w) / (out_w - 1) if out_w > 1 else 0.0 + + m = n * c * out_h * out_w + output = torch.empty((m,), dtype=input.dtype, device=input.device) + + num_warps, block_size = _CONFIGS[_device_key()] + kernel = _cached_make( + ntops.kernels.fractional_max_pool2d.premake, + block_size=block_size, + num_warps=num_warps, + num_stages=1, + max_num_configs=1, + ) + kernel( + output, + input.reshape(-1), + _random_samples.reshape(-1), + alpha_h, + alpha_w, + h, + w, + out_h, + out_w, + kernel_h, + kernel_w, + m, + ) + + return output.reshape(n, c, out_h, out_w) diff --git a/src/ntops/torch/fractional_max_pool3d.py b/src/ntops/torch/fractional_max_pool3d.py new file mode 100644 index 0000000..56d6dfc --- /dev/null +++ b/src/ntops/torch/fractional_max_pool3d.py @@ -0,0 +1,80 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make, _device_key + +_CONFIGS = { + "nvidia": (4, 256), + "iluvatar": (8, 512), + "metax": (4, 256), + "default": (4, 256), +} + + +def fractional_max_pool3d( + input, + kernel_size, + output_size=None, + output_ratio=None, + return_indices=False, + _random_samples=None, +): + assert not return_indices, "`return_indices` is not supported yet." 
+ + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + kernel_d, kernel_h, kernel_w = kernel_size + + input = input.contiguous() + n, c, d, h, w = input.shape + + if output_size is not None: + if isinstance(output_size, int): + output_size = (output_size, output_size, output_size) + out_d, out_h, out_w = output_size + else: + assert output_ratio is not None, "Either output_size or output_ratio is required." + if isinstance(output_ratio, (int, float)): + output_ratio = (output_ratio, output_ratio, output_ratio) + out_d = int(d * output_ratio[0]) + out_h = int(h * output_ratio[1]) + out_w = int(w * output_ratio[2]) + + if _random_samples is None: + _random_samples = torch.rand(n, c, 3, dtype=input.dtype, device=input.device) + + alpha_d = float(d - kernel_d) / (out_d - 1) if out_d > 1 else 0.0 + alpha_h = float(h - kernel_h) / (out_h - 1) if out_h > 1 else 0.0 + alpha_w = float(w - kernel_w) / (out_w - 1) if out_w > 1 else 0.0 + + m = n * c * out_d * out_h * out_w + output = torch.empty((m,), dtype=input.dtype, device=input.device) + + num_warps, block_size = _CONFIGS[_device_key()] + kernel = _cached_make( + ntops.kernels.fractional_max_pool3d.premake, + block_size=block_size, + num_warps=num_warps, + num_stages=1, + max_num_configs=1, + ) + kernel( + output, + input.reshape(-1), + _random_samples.reshape(-1), + alpha_d, + alpha_h, + alpha_w, + d, + h, + w, + out_d, + out_h, + out_w, + kernel_d, + kernel_h, + kernel_w, + m, + ) + + return output.reshape(n, c, out_d, out_h, out_w) diff --git a/src/ntops/torch/multilabel_margin_loss.py b/src/ntops/torch/multilabel_margin_loss.py new file mode 100644 index 0000000..47e7cff --- /dev/null +++ b/src/ntops/torch/multilabel_margin_loss.py @@ -0,0 +1,37 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def _next_pow2(x): + return 1 << (x - 1).bit_length() + + +def multilabel_margin_loss(input, target, reduction="mean"): + # One program per sample 
computes the raw (pre-/C) margin sum; the wrapper + # divides by C and reduces over N. + orig_ndim = input.ndim + if orig_ndim == 1: + input = input.unsqueeze(0) + target = target.unsqueeze(0) + + n, c = input.shape + out = torch.empty((n,), dtype=torch.float32, device=input.device) + + kernel = _cached_make(ntops.kernels.multilabel_margin_loss.premake) + kernel(input, target.to(torch.int64), out, c=_next_pow2(c)) + + out = out / c + + if orig_ndim == 1: + out = out.squeeze(0) + + if reduction == "mean": + result = out.mean() + elif reduction == "sum": + result = out.sum() + else: + result = out + + return result.to(input.dtype) diff --git a/src/ntops/torch/scatter_add.py b/src/ntops/torch/scatter_add.py new file mode 100644 index 0000000..938511c --- /dev/null +++ b/src/ntops/torch/scatter_add.py @@ -0,0 +1,49 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make, _device_key + +# (num_warps, block_size) tuned per platform; see bench/tune_scatter_add.py. +_CONFIGS = { + "nvidia": (8, 256), + "iluvatar": (2, 1024), + "metax": (2, 1024), + "default": (8, 256), +} + + +def scatter_add(input, dim, index, src): + # Move `dim` to the last axis and flatten to output (R, T) / index,src (R, K). + # output starts as a clone of input (the accumulation base); the kernel + # atomic-adds each src element into output[r, index]. Supports the common + # case where index/src match input in the non-`dim` dims (asserted). + dim = dim % input.ndim + + inp = input.movedim(dim, -1).contiguous() + permuted_shape = inp.shape + t_size = permuted_shape[-1] + rows = inp.numel() // t_size + + idx = index.movedim(dim, -1).contiguous() + k_size = idx.shape[-1] + assert idx.numel() // k_size == rows, ( + "scatter_add currently requires index/src to match input in non-`dim` dims." 
+ ) + + output = inp.reshape(rows, t_size).clone() + idx = idx.reshape(rows, k_size).to(torch.int64) + src = src.movedim(dim, -1).contiguous().reshape(rows, k_size) + + # Force a single config (eval-aligned): the un-arranged `output` source has + # no autotunable size bounds, so the autotuner can't run over it. + num_warps, block_size = _CONFIGS[_device_key()] + kernel = _cached_make( + ntops.kernels.scatter_add.premake, + block_size=block_size, + num_warps=num_warps, + num_stages=1, + max_num_configs=1, + ) + kernel(src, idx, output, t_size, k_size, rows * k_size) + + return output.reshape(permuted_shape).movedim(-1, dim) diff --git a/src/ntops/torch/utils.py b/src/ntops/torch/utils.py index e9b2dde..d138dd7 100644 --- a/src/ntops/torch/utils.py +++ b/src/ntops/torch/utils.py @@ -68,3 +68,21 @@ def _get_matmul_input_precision(): return ntops.kernels.mm.InputPrecisionVariant.IEEE return ntops.kernels.mm.InputPrecisionVariant.TF32 + + +@functools.cache +def _device_key(): + # Select a launch config by hardware only (not input size / op name), as + # required by the rules. Cached so the device name is queried once. + name = torch.cuda.get_device_name().lower() if torch.cuda.is_available() else "" + + if "metax" in name: + return "metax" + + if "iluvatar" in name: + return "iluvatar" + + if "nvidia" in name: + return "nvidia" + + return "default" diff --git a/tests/test_frac.py b/tests/test_frac.py new file mode 100644 index 0000000..a0d059a --- /dev/null +++ b/tests/test_frac.py @@ -0,0 +1,18 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_frac(shape, dtype, device, rtol, atol): + # Spread values so integer parts are nontrivial (exercises trunc). 
@skip_if_cuda_not_available
@pytest.mark.parametrize("n, c, h, w, kh, kw, oh, ow", _CASES)
@pytest.mark.parametrize("dtype", (torch.float32, torch.float16))
def test_fractional_max_pool2d(n, c, h, w, kh, kw, oh, ow, dtype):
    """Check ntops pooling equals torch's exactly when both share samples."""
    cuda = "cuda"
    x = torch.randn(n, c, h, w, dtype=dtype, device=cuda)
    shared_samples = torch.rand(n, c, 2, dtype=dtype, device=cuda)

    # Both implementations receive identical arguments, including the same
    # random samples, so the outputs can be compared for exact equality.
    window = (kh, kw)
    pool_kwargs = {"output_size": (oh, ow), "_random_samples": shared_samples}

    actual = ntops.torch.fractional_max_pool2d(x, window, **pool_kwargs)
    expected = F.fractional_max_pool2d(x, window, **pool_kwargs)

    assert torch.equal(actual, expected)
+@pytest.mark.parametrize("n, c, d, h, w, kd, kh, kw, od, oh, ow", _CASES) +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16)) +def test_fractional_max_pool3d(n, c, d, h, w, kd, kh, kw, od, oh, ow, dtype): + device = "cuda" + input = torch.randn(n, c, d, h, w, dtype=dtype, device=device) + samples = torch.rand(n, c, 3, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.fractional_max_pool3d( + input, (kd, kh, kw), output_size=(od, oh, ow), _random_samples=samples + ) + reference_output = F.fractional_max_pool3d( + input, (kd, kh, kw), output_size=(od, oh, ow), _random_samples=samples + ) + + assert torch.equal(ninetoothed_output, reference_output) diff --git a/tests/test_multilabel_margin_loss.py b/tests/test_multilabel_margin_loss.py new file mode 100644 index 0000000..eb85786 --- /dev/null +++ b/tests/test_multilabel_margin_loss.py @@ -0,0 +1,51 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +def _random_target(shape, c, device): + # A contiguous non-negative prefix of label indices, then -1 padding. 
+ target = torch.full(shape, -1, dtype=torch.int64, device=device) + n = shape[0] if len(shape) == 2 else 1 + rows = target.reshape(n, c) + for r in range(n): + num = torch.randint(1, c + 1, (1,)).item() + labels = torch.randperm(c, device=device)[:num] + rows[r, :num] = labels + return target + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape", [(4, 8), (2, 16), (8, 32)]) +@pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) +def test_multilabel_margin_loss(shape, reduction): + device = "cuda" + c = shape[-1] + input = torch.randn(shape, dtype=torch.float32, device=device) + target = _random_target(shape, c, device) + + ninetoothed_output = ntops.torch.multilabel_margin_loss( + input, target, reduction=reduction + ) + reference_output = F.multilabel_margin_loss(input, target, reduction=reduction) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=1e-3, atol=1e-3) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) +def test_multilabel_margin_loss_1d(reduction): + device = "cuda" + c = 8 + input = torch.randn((c,), dtype=torch.float32, device=device) + target = _random_target((c,), c, device) + + ninetoothed_output = ntops.torch.multilabel_margin_loss( + input, target, reduction=reduction + ) + reference_output = F.multilabel_margin_loss(input, target, reduction=reduction) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=1e-3, atol=1e-3) diff --git a/tests/test_scatter_add.py b/tests/test_scatter_add.py new file mode 100644 index 0000000..4125fa9 --- /dev/null +++ b/tests/test_scatter_add.py @@ -0,0 +1,35 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + +_CASES = [ + # (input_shape, dim, k) — index/src share input's non-dim dims, dim-size = k + ((8,), 0, 5), + ((8,), 0, 8), + ((4, 6), 0, 4), + ((4, 6), 1, 6), + ((4, 6), 1, 3), + ((2, 3, 5), 2, 5), + ((2, 3, 5), 0, 2), +] + + +@skip_if_cuda_not_available 
+@pytest.mark.parametrize("shape, dim, k", _CASES) +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16)) +def test_scatter_add(shape, dim, k, dtype): + device = "cuda" + input = torch.randn(shape, dtype=dtype, device=device) + + index_shape = list(shape) + index_shape[dim] = k + t = shape[dim] + index = torch.randint(0, t, index_shape, device=device) + src = torch.randn(index_shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.scatter_add(input, dim, index, src) + reference_output = torch.scatter_add(input, dim, index, src) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=1e-2, atol=1e-2)