Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions bench/bench_slogdet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Benchmark ntops.slogdet vs torch.linalg.slogdet.

Single-block in-kernel LU targets *many small matrices*. Reports per-call time
(ms) and speedup at several (batch, n) combinations so you can see where the
crossover is between ninetoothed and cuSOLVER.

python bench/bench_slogdet.py
"""

import os
from contextlib import contextmanager

import torch
import triton.testing

import ntops


@contextmanager
def _suppress_stderr():
    """Silence file descriptor 2 for the duration of the ``with`` block.

    The Iluvatar backend emits "loop ... not unrolled" remarks at the C/MLIR
    level (fd 2), which ``contextlib.redirect_stderr`` can't catch, so the
    descriptor itself is pointed at ``os.devnull`` with ``dup2``. The remarks
    are benign (correctness unaffected); this just keeps the bench output
    clean.

    The original fd 2 is restored on exit, including when the body raises.
    Both temporary descriptors are released even if opening ``os.devnull`` or
    restoring fd 2 fails part-way.
    """
    saved = os.dup(2)
    try:
        devnull = os.open(os.devnull, os.O_WRONLY)
        try:
            os.dup2(devnull, 2)
            yield
        finally:
            # Put the real stderr back before dropping the devnull handle.
            os.dup2(saved, 2)
            os.close(devnull)
    finally:
        # Runs on every path, so the duplicate of fd 2 is never leaked.
        os.close(saved)

DEVICE = "cuda"
DTYPE = torch.float32

# (batch, n) pairs to benchmark. n must be a power of 2 after padding; the
# exact n is passed and the wrapper pads. Add larger n cautiously: single-block
# LU may fail to compile above n~64-128.
CASES = [
    # One matrix per call.
    (1, 1), (1, 4), (1, 8), (1, 16),
    # Batches of small matrices.
    (64, 4), (128, 4), (512, 4),
    (64, 8), (128, 8),
    (64, 16), (128, 16), (512, 16),
    (64, 32), (128, 32),
    # Probe the large-n ceiling: where does single-block LU stop compiling /
    # stop winning vs cuSOLVER?
    (1, 64), (64, 64),
    (1, 128), (64, 128),
    (1, 256),
]


def main():
    """Benchmark ``ntops.torch.slogdet`` against ``torch.linalg.slogdet``.

    For every ``(batch, n)`` in ``CASES`` a random ``(batch, n, n)`` tensor is
    timed with ``triton.testing.do_bench`` and one table row is printed with
    both timings and the ``torch_ms / ninetoothed_ms`` ratio. A case whose
    ninetoothed run raises is printed as a ``SKIP`` row, and the torch timing
    for that case is not taken.
    """
    print(f"device: {torch.cuda.get_device_name()}\n")
    print(f" {'shape':>16s} {'九齿 ms':>10s} {'torch ms':>10s} speedup")
    print(" " + "-" * 55)

    for batch, n in CASES:
        A = torch.randn(batch, n, n, dtype=DTYPE, device=DEVICE)
        # One label for both row kinds so SKIP rows line up with timed rows.
        label = f"({batch:4d}, {n:3d})"

        try:
            with _suppress_stderr():
                ms_nt = triton.testing.do_bench(lambda: ntops.torch.slogdet(A))
        except Exception as exc:
            # Deliberately broad: any failure of the ninetoothed path (the
            # CASES list probes sizes that may not compile) skips the case
            # instead of aborting the whole sweep.
            print(f" {label} SKIP ninetoothed: {type(exc).__name__}")
            continue

        ms_th = triton.testing.do_bench(lambda: torch.linalg.slogdet(A))

        print(
            f" {label} "
            f"{ms_nt:10.3f} ms {ms_th:10.3f} ms {ms_th / ms_nt:.2f}x"
        )


if __name__ == "__main__":
    main()
87 changes: 87 additions & 0 deletions bench/bench_vs_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Compare ninetoothed ops against their torch references.

Reports achieved bandwidth (GB/s) and speedup (torch_ms / ninetoothed_ms) at a
couple of large shapes. Auto-tuning is left ON here, so this shows the
*achievable ceiling*; the eval config (max_num_configs=1) comes from the
tune_*.py scripts. Byte models are approximate — the speedup ratio is the
meaningful number.

Scoring is against other ninetoothed entries, not torch: torch winning on the
view-based op (hsplit) is expected, and slogdet (LU factorization) is omitted
because it *is* torch.

python bench/bench_vs_torch.py
"""

import torch
import torch.nn.functional as F
import triton.testing

import ntops

# Device and dtype used for every benchmark tensor created in main().
DEVICE = "cuda"
DTYPE = torch.float32
# 2-D shapes that main() benchmarks one after another.
SHAPES = ((4096, 4096), (8192, 8192))


def _report(name, shape, ms_nt, ms_th, nbytes):
    """Print one result row: bandwidth of both implementations and the speedup.

    ``nbytes`` is the modelled memory traffic of a single call. Dividing it by
    a time in milliseconds and scaling by 1e-6 gives GB/s (ms -> s is 1e3,
    bytes -> GB is 1e-9). The speedup is ``ms_th / ms_nt``.
    """
    bandwidth_nt, bandwidth_th = (nbytes / ms * 1e-6 for ms in (ms_nt, ms_th))
    speedup = ms_th / ms_nt
    row = (
        f" {name:15s} {str(shape):14s} "
        f"九齿 {bandwidth_nt:7.0f} GB/s | torch {bandwidth_th:7.0f} GB/s | "
        f"speedup {speedup:.2f}x"
    )
    print(row)


def _bench(name, shape, nt_fn, th_fn, byte_factor):
    """Time both callables and print a bandwidth/speedup row for them.

    Args:
        name: Label printed for the op.
        shape: Shape of the benchmarked tensor; any rank is accepted.
        nt_fn: Zero-argument callable that runs the ninetoothed op.
        th_fn: Zero-argument callable that runs the torch reference.
        byte_factor: Modelled number of element-sized memory accesses per
            element (reads plus writes) used for the bandwidth estimate.
    """
    ms_nt = triton.testing.do_bench(nt_fn)
    ms_th = triton.testing.do_bench(th_fn)
    # torch.Size.numel() multiplies every dimension, so the byte model is not
    # tied to 2-D shapes.
    numel = torch.Size(shape).numel()
    element_size = torch.tensor([], dtype=DTYPE).element_size()
    _report(name, shape, ms_nt, ms_th, byte_factor * numel * element_size)


def main():
    """Benchmark each ninetoothed op against its torch reference.

    For every shape in ``SHAPES`` two random tensors are created and the
    heaviside, slice_scatter, hsplit and gumbel_softmax pairs are passed to
    ``_bench`` in that order, followed by a blank line.
    """
    print(f"device: {torch.cuda.get_device_name()} dtype: {DTYPE}\n")

    for shape in SHAPES:
        x = torch.randn(shape, dtype=DTYPE, device=DEVICE)
        v = torch.randn(shape, dtype=DTYPE, device=DEVICE)

        _bench(
            name="heaviside",
            shape=shape,
            nt_fn=lambda: ntops.torch.heaviside(x, v),
            th_fn=lambda: torch.heaviside(x, v),
            byte_factor=3,
        )

        # slice_scatter writes a contiguous copy of the first half of the last
        # dimension back over that same region.
        half = shape[-1] // 2
        patch = x[..., :half].contiguous()
        _bench(
            name="slice_scatter",
            shape=shape,
            nt_fn=lambda: ntops.torch.slice_scatter(x, patch, dim=-1, start=0, end=half),
            th_fn=lambda: torch.slice_scatter(x, patch, dim=-1, start=0, end=half),
            byte_factor=2,
        )

        # torch.hsplit returns zero-copy views (≈ no work), so force torch to
        # materialize each split for a fair, apples-to-apples comparison.
        _bench(
            name="hsplit",
            shape=shape,
            nt_fn=lambda: ntops.torch.hsplit(x, 2),
            th_fn=lambda: tuple(part.contiguous() for part in torch.hsplit(x, 2)),
            byte_factor=2,
        )

        _bench(
            name="gumbel_softmax",
            shape=shape,
            nt_fn=lambda: ntops.torch.gumbel_softmax(x, dim=-1),
            th_fn=lambda: F.gumbel_softmax(x, dim=-1),
            byte_factor=2,
        )

        print()


if __name__ == "__main__":
    main()
63 changes: 63 additions & 0 deletions bench/tune_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tune ``(num_warps, block_size)`` for the copy kernel.

The copy kernel backs both ``slice_scatter`` (contiguous copy) and ``hsplit``
(strided copy). This tunes the contiguous case — slice_scatter's dominant cost;
hsplit reuses the same config. Same eval-condition / device-adaptive notes as
``tune_heaviside.py``.

python bench/tune_copy.py
"""

import torch
import triton.testing

import ntops
from ntops.torch.utils import _cached_make, set_default_max_num_configs

# Device and fixed shape of the tensors that _tune() benchmarks.
DEVICE = "cuda"
SHAPE = (8192, 8192)
# Candidate values; _tune() tries every (block_size, num_warps) combination.
BLOCK_SIZES = (256, 512, 1024, 2048, 4096, 8192)
NUM_WARPS = (4, 8, 16)


def _tune(dtype):
    """Sweep every ``(block_size, num_warps)`` pair for the copy kernel.

    Each configuration is built with ``_cached_make`` and timed at ``SHAPE``
    with ``triton.testing.do_bench``. One line is printed per configuration:
    its bandwidth in GB/s, or ``SKIP`` with the exception type when building
    or running it raises.

    Returns:
        ``(best_bw, best_cfg)`` where ``best_cfg`` is ``(num_warps,
        block_size)`` of the fastest configuration, or ``None`` when no
        configuration produced a bandwidth above zero.
    """
    src = torch.randn(SHAPE, dtype=dtype, device=DEVICE)
    dst = torch.empty_like(src)

    # copy touches 1 read + 1 write per element.
    nbytes = 2 * src.numel() * src.element_size()
    best_bw = 0.0
    best_cfg = None

    for block_size in BLOCK_SIZES:
        for num_warps in NUM_WARPS:
            prefix = f" block={block_size:5d} warps={num_warps:2d} "
            try:
                kernel = _cached_make(
                    ntops.kernels.copy.premake,
                    src.ndim,
                    block_size=block_size,
                    num_warps=num_warps,
                    num_stages=1,
                )
                elapsed_ms = triton.testing.do_bench(lambda k=kernel: k(src, dst))
                # ms -> s (1e3) and bytes -> GB (1e-9)
                bw = nbytes / elapsed_ms * 1e-6
                print(f"{prefix}{bw:7.0f} GB/s")
                if bw > best_bw:
                    best_bw = bw
                    best_cfg = (num_warps, block_size)
            except Exception as exc:  # noqa: BLE001
                print(f"{prefix}SKIP ({type(exc).__name__})")

    return best_bw, best_cfg


def main():
    """Run the copy-kernel sweep for float32 and float16 and print each winner.

    ``set_default_max_num_configs(1)`` is called first; the module docstring
    refers to this as the eval condition. If ``_tune`` finds no working
    configuration for a dtype, that is reported instead of a winner.
    """
    set_default_max_num_configs(1)
    print(f"device: {torch.cuda.get_device_name()}")

    for dtype in (torch.float32, torch.float16):
        print(f"\n[{dtype}]")
        bw, cfg = _tune(dtype)
        if cfg is None:
            # _tune returns None when every configuration was skipped;
            # indexing it would raise TypeError and hide the SKIP lines above.
            print(" best: none (every configuration was skipped)")
            continue
        num_warps, block_size = cfg
        print(f" best: num_warps={num_warps}, block_size={block_size} ({bw:.0f} GB/s)")


if __name__ == "__main__":
    main()
66 changes: 66 additions & 0 deletions bench/tune_heaviside.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Tune ``(num_warps, block_size)`` for the heaviside kernel.

Evaluation disables auto-tuning (``max_num_configs=1``), so the winning config
must be passed explicitly into ``premake``. This sweeps configs at a fixed large
shape under eval-like conditions and reports the best GB/s per dtype.

Run on each platform (NVIDIA / Iluvatar / MetaX) and bake the winners into a
device-adaptive ``_launch_config`` in ``ntops/torch/heaviside.py``.

python bench/tune_heaviside.py
"""

import torch
import triton.testing

import ntops
from ntops.torch.utils import _cached_make, set_default_max_num_configs

# Device and fixed shape of the tensors that _tune() benchmarks.
DEVICE = "cuda"
SHAPE = (8192, 8192)
# Candidate values; _tune() tries every (block_size, num_warps) combination.
BLOCK_SIZES = (256, 512, 1024, 2048, 4096, 8192)
NUM_WARPS = (4, 8, 16)


def _tune(dtype):
    """Sweep every ``(block_size, num_warps)`` pair for the heaviside kernel.

    Each configuration is built with ``_cached_make`` and timed at ``SHAPE``
    with ``triton.testing.do_bench``. One line is printed per configuration:
    its bandwidth in GB/s, or ``SKIP`` with the exception type when building
    or running it raises.

    Returns:
        ``(best_bw, best_cfg)`` where ``best_cfg`` is ``(num_warps,
        block_size)`` of the fastest configuration, or ``None`` when no
        configuration produced a bandwidth above zero.
    """
    x = torch.randn(SHAPE, dtype=dtype, device=DEVICE)
    fill = torch.randn(SHAPE, dtype=dtype, device=DEVICE)
    result = torch.empty_like(x)

    # heaviside touches 2 reads + 1 write per element.
    nbytes = 3 * x.numel() * x.element_size()
    best_bw = 0.0
    best_cfg = None

    for block_size in BLOCK_SIZES:
        for num_warps in NUM_WARPS:
            prefix = f" block={block_size:5d} warps={num_warps:2d} "
            try:
                kernel = _cached_make(
                    ntops.kernels.heaviside.premake,
                    x.ndim,
                    block_size=block_size,
                    num_warps=num_warps,
                    num_stages=1,
                )
                elapsed_ms = triton.testing.do_bench(
                    lambda k=kernel: k(x, fill, result)
                )
                # ms -> s (1e3) and bytes -> GB (1e-9)
                bw = nbytes / elapsed_ms * 1e-6
                print(f"{prefix}{bw:7.0f} GB/s")
                if bw > best_bw:
                    best_bw = bw
                    best_cfg = (num_warps, block_size)
            except Exception as exc:  # noqa: BLE001
                print(f"{prefix}SKIP ({type(exc).__name__})")

    return best_bw, best_cfg


def main():
    """Run the heaviside sweep for float32 and float16 and print each winner.

    ``set_default_max_num_configs(1)`` is called first so the sweep runs under
    the eval-like condition described in the module docstring. If ``_tune``
    finds no working configuration for a dtype, that is reported instead of a
    winner.
    """
    set_default_max_num_configs(1)
    print(f"device: {torch.cuda.get_device_name()}")

    for dtype in (torch.float32, torch.float16):
        print(f"\n[{dtype}]")
        bw, cfg = _tune(dtype)
        if cfg is None:
            # _tune returns None when every configuration was skipped;
            # indexing it would raise TypeError and hide the SKIP lines above.
            print(" best: none (every configuration was skipped)")
            continue
        num_warps, block_size = cfg
        print(f" best: num_warps={num_warps}, block_size={block_size} ({bw:.0f} GB/s)")


if __name__ == "__main__":
    main()
6 changes: 6 additions & 0 deletions src/ntops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
bmm,
clamp,
conv2d,
copy,
cos,
div,
dropout,
Expand All @@ -17,6 +18,7 @@
ge,
gelu,
gt,
heaviside,
isinf,
isnan,
layer_norm,
Expand All @@ -36,6 +38,7 @@
sigmoid,
silu,
sin,
slogdet,
softmax,
sub,
tanh,
Expand All @@ -52,6 +55,7 @@
"bmm",
"clamp",
"conv2d",
"copy",
"cos",
"div",
"dropout",
Expand All @@ -60,6 +64,7 @@
"ge",
"gelu",
"gt",
"heaviside",
"isinf",
"isnan",
"layer_norm",
Expand All @@ -79,6 +84,7 @@
"sigmoid",
"silu",
"sin",
"slogdet",
"softmax",
"sub",
"tanh",
Expand Down
17 changes: 17 additions & 0 deletions src/ntops/kernels/copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import functools

from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
    # Copy kernel body: a single assignment of `input` to the `output`
    # parameter. Presumably ninetoothed lowers this assignment to the store
    # into the output tile — TODO confirm against the ninetoothed docs.
    # The noqa silences the "assigned but never used" lint for that assignment.
    output = input  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
    """Return the ``(arrangement, application, tensors)`` triple for the copy kernel.

    The arrangement is the shared element-wise ``arrangement`` with
    ``block_size`` bound. ``tensors`` holds two ``Tensor(ndim, dtype=dtype)``
    placeholders, one per parameter of ``application``.
    """
    input_tensor = Tensor(ndim, dtype=dtype)
    output_tensor = Tensor(ndim, dtype=dtype)

    return (
        functools.partial(arrangement, block_size=block_size),
        application,
        (input_tensor, output_tensor),
    )
24 changes: 24 additions & 0 deletions src/ntops/kernels/heaviside.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, values, output):
    # Element-wise step function: 0 where input < 0, 1 where input > 0, and
    # the matching element of `values` otherwise (i.e. where input == 0).
    # NOTE(review): a NaN input would fail both comparisons and therefore
    # also select `values` — confirm this is the intended behavior versus the
    # torch reference.
    # Presumably ninetoothed lowers the assignment to `output` into the store
    # to the output tile — TODO confirm; the noqa silences the unused-variable
    # lint for that assignment.
    output = ntl.where(  # noqa: F841
        input < 0, 0, ntl.where(input > 0, 1, values)
    )


def premake(ndim, dtype=None, block_size=None):
    """Return the ``(arrangement, application, tensors)`` triple for heaviside.

    The arrangement is the shared element-wise ``arrangement`` with
    ``block_size`` bound. ``tensors`` holds three ``Tensor(ndim, dtype=dtype)``
    placeholders, one per parameter of ``application`` (input, values, output).
    """
    bound_arrangement = functools.partial(arrangement, block_size=block_size)

    # One placeholder per application parameter, all with the same rank/dtype.
    placeholders = tuple(Tensor(ndim, dtype=dtype) for _ in range(3))

    return bound_arrangement, application, placeholders
Loading