Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions bench/bench_t1_1_7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Benchmark T1-1-7 operators vs torch.

feature_alpha_dropout / mse_loss / flip / fliplr / pixel_unshuffle

python bench/bench_t1_1_7.py
"""

import torch
import torch.nn.functional as F
import triton.testing

import ntops

DEVICE = "cuda"
DTYPE = torch.float32


def _report(name, shape_str, ms_nt, ms_th, nbytes):
bw_nt = nbytes / ms_nt * 1e-6
bw_th = nbytes / ms_th * 1e-6
print(
f" {name:22s} {shape_str:22s} "
f"九齿 {bw_nt:7.0f} GB/s | torch {bw_th:7.0f} GB/s | "
f"speedup {ms_th / ms_nt:.2f}x"
)


def bench_feature_alpha_dropout():
print("\n[feature_alpha_dropout]")
for shape in [(64, 256, 32, 32), (128, 512, 16, 16), (32, 256, 64, 64)]:
x = torch.randn(shape, dtype=DTYPE, device=DEVICE)
nbytes = x.numel() * x.element_size() * 2
ms_nt = triton.testing.do_bench(
lambda: ntops.torch.feature_alpha_dropout(x, p=0.5, training=True)
)
ms_th = triton.testing.do_bench(
lambda: F.feature_alpha_dropout(x, p=0.5, training=True)
)
_report("feature_alpha_dropout", str(shape), ms_nt, ms_th, nbytes)


def bench_mse_loss():
print("\n[mse_loss]")
for shape in [(4096, 4096), (8192, 8192), (4096 * 4096,)]:
x = torch.randn(shape, dtype=DTYPE, device=DEVICE)
t = torch.randn(shape, dtype=DTYPE, device=DEVICE)
nbytes = x.numel() * x.element_size() * 2 # 2 reads
ms_nt = triton.testing.do_bench(
lambda: ntops.torch.mse_loss(x, t, reduction="mean")
)
ms_th = triton.testing.do_bench(
lambda: F.mse_loss(x, t, reduction="mean")
)
_report("mse_loss", str(shape), ms_nt, ms_th, nbytes)


def bench_flip():
print("\n[flip]")
cases = [((4096, 4096), (0,)), ((4096, 4096), (1,)), ((8192, 8192), (0, 1))]
for shape, dims in cases:
x = torch.randn(shape, dtype=DTYPE, device=DEVICE)
nbytes = x.numel() * x.element_size() * 2 # 1 read + 1 write
ms_nt = triton.testing.do_bench(lambda: ntops.torch.flip(x, dims))
ms_th = triton.testing.do_bench(lambda: torch.flip(x, dims))
_report("flip", f"{shape} dims={dims}", ms_nt, ms_th, nbytes)


def bench_fliplr():
print("\n[fliplr]")
for shape in [(4096, 4096), (8192, 8192)]:
x = torch.randn(shape, dtype=DTYPE, device=DEVICE)
nbytes = x.numel() * x.element_size() * 2
ms_nt = triton.testing.do_bench(lambda: ntops.torch.fliplr(x))
ms_th = triton.testing.do_bench(lambda: torch.fliplr(x))
_report("fliplr", str(shape), ms_nt, ms_th, nbytes)


def bench_pixel_unshuffle():
print("\n[pixel_unshuffle]")
cases = [((32, 64, 112, 112), 2), ((16, 128, 128, 128), 4), ((64, 64, 64, 64), 2)]
for shape, r in cases:
x = torch.randn(shape, dtype=DTYPE, device=DEVICE)
nbytes = x.numel() * x.element_size() * 2
ms_nt = triton.testing.do_bench(
lambda: ntops.torch.pixel_unshuffle(x, r)
)
ms_th = triton.testing.do_bench(lambda: F.pixel_unshuffle(x, r))
_report("pixel_unshuffle", f"{shape} r={r}", ms_nt, ms_th, nbytes)


def main():
print(f"device: {torch.cuda.get_device_name()} dtype: {DTYPE}")
bench_feature_alpha_dropout()
bench_mse_loss()
bench_flip()
bench_fliplr()
bench_pixel_unshuffle()


if __name__ == "__main__":
main()
118 changes: 118 additions & 0 deletions bench/tune_flip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Tune the pinned launch config (block_size / num_warps / num_stages) for
``ntops.torch.flip`` on the current GPU.

Performance evaluation runs with auto-tuning disabled (``max_num_configs=1``),
so the values baked into ``ntops/torch/flip.py`` decide the score. This script
sweeps a small grid under those exact conditions and prints, per shape, the
fastest config plus the speedup over ``torch.flip``.

Usage
-----
python bench/tune_flip.py
"""

import itertools

import torch

import ntops
from ntops.torch.utils import _cached_make

# Shapes that matter for a bandwidth-bound op: the medium case that has not yet
# saturated memory, and the large cases at the bandwidth ceiling. Small shapes
# are launch-overhead bound and not informative for config tuning.
_SHAPES = [
([4096, 4096], (0, 1)),
([4096, 4096], (1,)),
([8192, 8192], (1,)),
([8192, 8192], (0,)),
]

_BLOCK_SIZES = [512, 1024, 2048, 4096, 8192]
_NUM_WARPS = [4, 8, 16]
_NUM_STAGES = [1, 2]

_DTYPES = [torch.float32, torch.float16]


def _time(fn, n_warmup=10, n_repeat=50):
for _ in range(n_warmup):
fn()
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(n_repeat):
fn()
end.record()
torch.cuda.synchronize()
return start.elapsed_time(end) / n_repeat


def _run_config(input, output, dims, block_size, num_warps, num_stages):
kernel = _cached_make(
ntops.kernels.flip.premake,
input.ndim,
dims,
block_size=block_size,
num_warps=num_warps,
num_stages=num_stages,
max_num_configs=1,
)
return lambda: kernel(input, output)


def tune():
if not torch.cuda.is_available():
raise RuntimeError("CUDA not available")

for dtype in _DTYPES:
print(f"\n{'='*92}")
print(f"flip config sweep | dtype={dtype} | device={torch.cuda.get_device_name()}")
print("=" * 92)

for shape, dims in _SHAPES:
input = torch.randn(shape, dtype=dtype, device="cuda")
output = torch.empty(shape, dtype=dtype, device="cuda")
num_bytes = input.numel() * input.element_size() * 2

torch_ms = _time(lambda: torch.flip(input, list(dims)))

results = []
for bs, nw, ns in itertools.product(
_BLOCK_SIZES, _NUM_WARPS, _NUM_STAGES
):
try:
fn = _run_config(input, output, dims, bs, nw, ns)
ms = _time(fn)
except Exception as exc: # noqa: BLE001
print(f" skip bs={bs} nw={nw} ns={ns}: {type(exc).__name__}")
continue
results.append((ms, bs, nw, ns))

results.sort()
best_ms, bbs, bnw, bns = results[0]
best_gbps = num_bytes / (best_ms * 1e-3) / 1e9
torch_gbps = num_bytes / (torch_ms * 1e-3) / 1e9

print(
f"\nshape={shape} dims={dims} "
f"(torch {torch_ms:.4f} ms / {torch_gbps:.0f} GB/s)"
)
print(
f" BEST block_size={bbs:<5} num_warps={bnw:<3} num_stages={bns} "
f"-> {best_ms:.4f} ms / {best_gbps:.0f} GB/s "
f"(speedup vs torch {torch_ms / best_ms:.2f})"
)
print(" top 5:")
for ms, bs, nw, ns in results[:5]:
gbps = num_bytes / (ms * 1e-3) / 1e9
print(
f" block_size={bs:<5} num_warps={nw:<3} num_stages={ns} "
f"{ms:.4f} ms / {gbps:.0f} GB/s"
)


if __name__ == "__main__":
tune()
170 changes: 170 additions & 0 deletions bench/tune_mse_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Tune the pinned launch configs for ``ntops.torch.mse_loss`` on the current
GPU.

Two kernels are tuned independently:
* the reduction path (``reduction="mean"|"sum"``) -- the defining, perf
critical kernel; partials buffer size depends on ``block_size``;
* the element-wise path (``reduction="none"``).

Performance evaluation runs with auto-tuning disabled (``max_num_configs=1``),
so the values baked into ``ntops/torch/mse_loss.py`` decide the score. This
script sweeps a small grid under those exact conditions and prints, per shape,
the fastest config plus the speedup over ``torch.nn.functional.mse_loss``.

Usage
-----
python bench/tune_mse_loss.py
"""

import itertools
import math

import torch
import torch.nn.functional as F

import ntops
from ntops.torch.utils import _cached_make

# Numbers of elements to tune over (bandwidth-bound regime). Small sizes are
# launch-overhead bound and not informative for config selection.
_NUMELS = [1024 * 1024, 4096 * 4096, 8192 * 8192]

_BLOCK_SIZES = [512, 1024, 2048, 4096, 8192]
_NUM_WARPS = [4, 8, 16]
_NUM_STAGES = [1, 2]

_DTYPES = [torch.float32, torch.float16]


def _time(fn, n_warmup=10, n_repeat=50):
for _ in range(n_warmup):
fn()
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(n_repeat):
fn()
end.record()
torch.cuda.synchronize()
return start.elapsed_time(end) / n_repeat


def _reduce_runner(flat_in, flat_tg, block_size, num_warps, num_stages):
numel = flat_in.numel()
num_partials = max(1, math.ceil(numel / block_size))
partials = torch.empty(num_partials, dtype=torch.float32, device=flat_in.device)

kernel = _cached_make(
ntops.kernels.mse_loss.reduce_premake,
block_size=block_size,
num_warps=num_warps,
num_stages=num_stages,
max_num_configs=1,
)

def run():
kernel(flat_in, flat_tg, partials)
return partials.sum()

return run


def _none_runner(input, target, output, block_size, num_warps, num_stages):
kernel = _cached_make(
ntops.kernels.mse_loss.premake,
input.ndim,
block_size=block_size,
num_warps=num_warps,
num_stages=num_stages,
max_num_configs=1,
)
return lambda: kernel(input, target, output)


def _sweep(label, make_runner, num_bytes, torch_ms):
results = []
for bs, nw, ns in itertools.product(_BLOCK_SIZES, _NUM_WARPS, _NUM_STAGES):
try:
ms = _time(make_runner(bs, nw, ns))
except Exception as exc: # noqa: BLE001
print(f" skip bs={bs} nw={nw} ns={ns}: {type(exc).__name__}")
continue
results.append((ms, bs, nw, ns))

results.sort()
best_ms, bbs, bnw, bns = results[0]
best_gbps = num_bytes / (best_ms * 1e-3) / 1e9
torch_gbps = num_bytes / (torch_ms * 1e-3) / 1e9

print(f"\n [{label}] (torch {torch_ms:.4f} ms / {torch_gbps:.0f} GB/s)")
print(
f" BEST block_size={bbs:<5} num_warps={bnw:<3} num_stages={bns} "
f"-> {best_ms:.4f} ms / {best_gbps:.0f} GB/s "
f"(speedup vs torch {torch_ms / best_ms:.2f})"
)
for ms, bs, nw, ns in results[:5]:
gbps = num_bytes / (ms * 1e-3) / 1e9
print(
f" block_size={bs:<5} num_warps={nw:<3} num_stages={ns} "
f"{ms:.4f} ms / {gbps:.0f} GB/s"
)


def _check_reduce_correctness(dtype):
"""Sanity check that the reduction kernel matches F.mse_loss before trusting
any timing numbers."""
x = torch.randn(40000, dtype=dtype, device="cuda")
y = torch.randn(40000, dtype=dtype, device="cuda")
run = _reduce_runner(x, y, 1024, 4, 1)
got = (run() / x.numel()).to(dtype)
ref = F.mse_loss(x, y, reduction="mean")
tol = 1e-3 if dtype == torch.float32 else 1e-2
assert torch.allclose(got, ref, rtol=tol, atol=tol), (got.item(), ref.item())


def tune():
if not torch.cuda.is_available():
raise RuntimeError("CUDA not available")

for dtype in _DTYPES:
_check_reduce_correctness(dtype)
itemsize = torch.empty(0, dtype=dtype).element_size()

print(f"\n{'='*92}")
print(
f"mse_loss config sweep | dtype={dtype} | "
f"device={torch.cuda.get_device_name()}"
)
print("=" * 92)

for numel in _NUMELS:
side = int(round(numel**0.5))
input = torch.randn(numel, dtype=dtype, device="cuda")
target = torch.randn(numel, dtype=dtype, device="cuda")
output = torch.empty_like(input)

print(f"\nnumel={numel} (~{side}^2, {numel * itemsize / 1e6:.1f} MB)")

# Reduction path: reads input + target (2x).
torch_ms = _time(lambda: F.mse_loss(input, target, reduction="sum"))
_sweep(
"reduce (sum/mean)",
lambda bs, nw, ns: _reduce_runner(input, target, bs, nw, ns),
numel * itemsize * 2,
torch_ms,
)

# Element-wise path: reads input + target, writes output (3x).
torch_ms = _time(lambda: F.mse_loss(input, target, reduction="none"))
_sweep(
"none (element-wise)",
lambda bs, nw, ns: _none_runner(input, target, output, bs, nw, ns),
numel * itemsize * 3,
torch_ms,
)


if __name__ == "__main__":
tune()
Loading