From 340ed0fe5e76735a28ad540231992ce6d6fbb4d0 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 25 May 2026 12:39:34 +0000 Subject: [PATCH] add gumbel_softmax, heaviside, hsplit, slice_scatter and slogdet operators --- src/ntops/kernels/__init__.py | 10 ++ src/ntops/kernels/gumbel_softmax.py | 116 +++++++++++++++ src/ntops/kernels/heaviside.py | 26 ++++ src/ntops/kernels/hsplit.py | 47 +++++++ src/ntops/kernels/slice_scatter.py | 69 +++++++++ src/ntops/kernels/slogdet.py | 209 ++++++++++++++++++++++++++++ src/ntops/torch/__init__.py | 10 ++ src/ntops/torch/gumbel_softmax.py | 20 +++ src/ntops/torch/heaviside.py | 57 ++++++++ src/ntops/torch/hsplit.py | 133 ++++++++++++++++++ src/ntops/torch/slice_scatter.py | 70 ++++++++++ src/ntops/torch/slogdet.py | 66 +++++++++ tests/test_gumbel_softmax.py | 65 +++++++++ tests/test_heaviside.py | 29 ++++ tests/test_hsplit.py | 61 ++++++++ tests/test_slice_scatter.py | 47 +++++++ tests/test_slogdet.py | 84 +++++++++++ 17 files changed, 1119 insertions(+) create mode 100644 src/ntops/kernels/gumbel_softmax.py create mode 100644 src/ntops/kernels/heaviside.py create mode 100644 src/ntops/kernels/hsplit.py create mode 100644 src/ntops/kernels/slice_scatter.py create mode 100644 src/ntops/kernels/slogdet.py create mode 100644 src/ntops/torch/gumbel_softmax.py create mode 100644 src/ntops/torch/heaviside.py create mode 100644 src/ntops/torch/hsplit.py create mode 100644 src/ntops/torch/slice_scatter.py create mode 100644 src/ntops/torch/slogdet.py create mode 100644 tests/test_gumbel_softmax.py create mode 100644 tests/test_heaviside.py create mode 100644 tests/test_hsplit.py create mode 100644 tests/test_slice_scatter.py create mode 100644 tests/test_slogdet.py diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..ed125d6 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -16,7 +16,10 @@ exp, ge, gelu, + gumbel_softmax, gt, + heaviside, + hsplit, isinf, isnan, layer_norm, @@ -36,6 +39,8 @@ sigmoid, silu, sin, + slice_scatter, 
+ slogdet, softmax, sub, tanh, @@ -59,7 +64,10 @@ "exp", "ge", "gelu", + "gumbel_softmax", "gt", + "heaviside", + "hsplit", "isinf", "isnan", "layer_norm", @@ -79,6 +87,8 @@ "sigmoid", "silu", "sin", + "slice_scatter", + "slogdet", "softmax", "sub", "tanh", diff --git a/src/ntops/kernels/gumbel_softmax.py b/src/ntops/kernels/gumbel_softmax.py new file mode 100644 index 0000000..e06717a --- /dev/null +++ b/src/ntops/kernels/gumbel_softmax.py @@ -0,0 +1,116 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement + + +def application(input, tau, hard, output): + out_dtype = output.dtype.dtype + + tau_f = ntl.cast(tau, ntl.float32) + hard_enabled = hard != ntl.cast(0.0, ntl.float64) + + zero_f = ntl.cast(0.0, ntl.float32) + one_f = ntl.cast(1.0, ntl.float32) + eps_f = ntl.cast(1.0e-6, ntl.float32) + neg_inf_f = ntl.cast(float("-inf"), ntl.float32) + + prev_max = neg_inf_f + denominator = zero_f + + # First pass: compute max and denominator for stable softmax. + for i in range(input.shape[0]): + input_i = ntl.cast(input[i], ntl.float32) + + # For masked lanes, input_i may be -inf. Do not feed -inf to sin. 
+ seed_input_i = ntl.where(input_i == neg_inf_f, zero_f, input_i) + + idx_f = ntl.cast(i + 1, ntl.float32) + + seed_i = ( + seed_input_i * ntl.cast(12.9898, ntl.float32) + + idx_f * ntl.cast(78.233, ntl.float32) + ) + + random_i = ntl.sin(seed_i) * ntl.cast(43758.5453, ntl.float32) + random_floor_i = ntl.floor(random_i) + u_raw_i = random_i - random_floor_i + + u_min_i = ntl.maximum(u_raw_i, eps_f) + u_i = ntl.where(u_min_i > one_f - eps_f, one_f - eps_f, u_min_i) + + log_u_i = ntl.log(u_i) + neg_log_u_i = -log_u_i + log_neg_log_u_i = ntl.log(neg_log_u_i) + gumbel_i = -log_neg_log_u_i + + value_i = (input_i + gumbel_i) / tau_f + + block_max_i = ntl.max(value_i) + curr_max = ntl.cast(ntl.maximum(prev_max, block_max_i), ntl.float32) + + value_diff_i = value_i - curr_max + value_exp_i = ntl.exp(value_diff_i) + + prev_diff_i = prev_max - curr_max + prev_exp_i = ntl.exp(prev_diff_i) + + denominator = denominator * prev_exp_i + ntl.sum(value_exp_i) + prev_max = curr_max + + # Second pass: write soft or hard output. 
+ for i in range(input.shape[0]): + input_i = ntl.cast(input[i], ntl.float32) + + seed_input_i = ntl.where(input_i == neg_inf_f, zero_f, input_i) + + idx_f = ntl.cast(i + 1, ntl.float32) + + seed_i = ( + seed_input_i * ntl.cast(12.9898, ntl.float32) + + idx_f * ntl.cast(78.233, ntl.float32) + ) + + random_i = ntl.sin(seed_i) * ntl.cast(43758.5453, ntl.float32) + random_floor_i = ntl.floor(random_i) + u_raw_i = random_i - random_floor_i + + u_min_i = ntl.maximum(u_raw_i, eps_f) + u_i = ntl.where(u_min_i > one_f - eps_f, one_f - eps_f, u_min_i) + + log_u_i = ntl.log(u_i) + neg_log_u_i = -log_u_i + log_neg_log_u_i = ntl.log(neg_log_u_i) + gumbel_i = -log_neg_log_u_i + + value_i = (input_i + gumbel_i) / tau_f + + soft_exp_i = ntl.exp(value_i - prev_max) + soft_i = soft_exp_i / denominator + + hard_i = ntl.where(value_i == prev_max, one_f, zero_f) + + result_i = ntl.where(hard_enabled, hard_i, soft_i) + + output[i] = ntl.cast(result_i, out_dtype) + + +def premake(ndim, dim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) + + tensors = ( + Tensor( + ndim, + dtype=dtype, + other=float("-inf"), + shape_options={"constexpr": True}, + ), + Tensor(0, dtype=ninetoothed.float64), # tau + Tensor(0, dtype=ninetoothed.float64), # hard: 0.0 / 1.0 + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/kernels/heaviside.py b/src/ntops/kernels/heaviside.py new file mode 100644 index 0000000..8c3f548 --- /dev/null +++ b/src/ntops/kernels/heaviside.py @@ -0,0 +1,26 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, values, output): + zero = input * 0 + one = zero + 1 + + tmp = ntl.where(input < zero, zero, one) + output = ntl.where(input == zero, values, tmp) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + 
arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), # input + Tensor(ndim, dtype=dtype), # values + Tensor(ndim, dtype=dtype), # output + ) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/kernels/hsplit.py b/src/ntops/kernels/hsplit.py new file mode 100644 index 0000000..88a1dce --- /dev/null +++ b/src/ntops/kernels/hsplit.py @@ -0,0 +1,47 @@ +# src/ntops/kernels/hsplit.py + +import functools + +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement as element_wise_arrangement + + +def hsplit_arrangement( + input, + output, + dim=None, + start=None, + end=None, + block_size=None, +): + if dim is None: + dim = 0 if input.ndim == 1 else 1 + + slices = [slice(None)] * input.ndim + slices[dim] = slice(start, end) + + input_slice = input[tuple(slices)] + + return element_wise_arrangement(input_slice, output, block_size=block_size) + + +def hsplit_application(input, output): + output = input # noqa: F841 + + +def premake(ndim, dim, start, end, dtype=None, block_size=None): + arrangement_ = functools.partial( + hsplit_arrangement, + dim=dim, + start=start, + end=end, + block_size=block_size, + ) + + tensors = ( + Tensor(ndim, dtype=dtype), # input + Tensor(ndim, dtype=dtype), # output + ) + + return arrangement_, hsplit_application, tensors \ No newline at end of file diff --git a/src/ntops/kernels/slice_scatter.py b/src/ntops/kernels/slice_scatter.py new file mode 100644 index 0000000..5767be0 --- /dev/null +++ b/src/ntops/kernels/slice_scatter.py @@ -0,0 +1,69 @@ +import functools + +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement as element_wise_arrangement + + +def copy_arrangement(input, output, block_size=None): + return element_wise_arrangement(input, output, block_size=block_size) + + +def copy_application(input, output): + output = input # noqa: F841 + + +def premake_copy(ndim, 
dtype=None, block_size=None): + arrangement_ = functools.partial( + copy_arrangement, + block_size=block_size, + ) + + tensors = ( + Tensor(ndim, dtype=dtype), # input + Tensor(ndim, dtype=dtype), # output + ) + + return arrangement_, copy_application, tensors + + +def scatter_arrangement( + source, + output, + dim=None, + start=None, + end=None, + step=None, + block_size=None, +): + if step is None: + step = 1 + + slices = [slice(None)] * output.ndim + slices[dim] = slice(start, end, step) + + output_slice = output[tuple(slices)] + + return element_wise_arrangement(source, output_slice, block_size=block_size) + + +def scatter_application(source, output): + output = source # noqa: F841 + + +def premake_scatter(ndim, dim, start, end, step, dtype=None, block_size=None): + arrangement_ = functools.partial( + scatter_arrangement, + dim=dim, + start=start, + end=end, + step=step, + block_size=block_size, + ) + + tensors = ( + Tensor(ndim, dtype=dtype), # source + Tensor(ndim, dtype=dtype), # output + ) + + return arrangement_, scatter_application, tensors \ No newline at end of file diff --git a/src/ntops/kernels/slogdet.py b/src/ntops/kernels/slogdet.py new file mode 100644 index 0000000..f7ca33f --- /dev/null +++ b/src/ntops/kernels/slogdet.py @@ -0,0 +1,209 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def arrangement(input, sign, logabsdet, block_size=None): + # block_size 必须是 2 的幂。 + # 例如 matrix_size=3 时,用 4x4 tile。 + input = input.tile((block_size, block_size)) + + sign = sign.tile((1, 1)) + logabsdet = logabsdet.tile((1, 1)) + + return input, sign, logabsdet + + +def _abs_f32(x): + zero = ntl.cast(0.0, ntl.float32) + return ntl.where(x < zero, -x, x) + + +def _sign_f32(x): + zero = ntl.cast(0.0, ntl.float32) + one = ntl.cast(1.0, ntl.float32) + minus_one = ntl.cast(-1.0, ntl.float32) + + return ntl.where(x > zero, one, ntl.where(x < zero, minus_one, zero)) + + +def application(input, sign, 
logabsdet): + dtype = ntl.float32 + + zero = ntl.cast(0.0, dtype) + one = ntl.cast(1.0, dtype) + minus_one = ntl.cast(-1.0, dtype) + neg_inf = ntl.cast(float("-inf"), dtype) + + # input 是 block_size x block_size block。 + # 对于 3x3,实际是 4x4 block,越界位置由 Tensor(other=0.0) 填充。 + a = ntl.cast(input, dtype) + + row_idx = ntl.cast(input.offsets(-2), ntl.int32) + col_idx = ntl.cast(input.offsets(-1), ntl.int32) + + rows = row_idx[:, None] + cols = col_idx[None, :] + + # 真实矩阵大小,不是 block_size。 + # 例如 3x3 输入时,input.source.shape[-2] == 3。 + n_i32 = ntl.cast(input.source.shape[-2], ntl.int32) + + det_sign = one + log_abs_det = zero + singular = zero != zero + + # 只循环真实矩阵大小。 + for k in range(input.source.shape[-2]): + k_i32 = ntl.cast(k, ntl.int32) + + # 取第 k 列:不用 a[:, k],用 mask + sum。 + col_k = ntl.sum( + ntl.where(cols == k_i32, a, zero), + 1, + ) + + col_abs = _abs_f32(col_k) + + valid_rows = (row_idx >= k_i32) & (row_idx < n_i32) + + masked_abs = ntl.where(valid_rows, col_abs, minus_one) + pivot_abs = ntl.max(masked_abs) + + is_zero_pivot = pivot_abs == zero + singular = singular | is_zero_pivot + + pivot_mask = (masked_abs == pivot_abs) & valid_rows + + pivot_is_k = ( + ntl.sum( + ntl.where(pivot_mask & (row_idx == k_i32), one, zero) + ) + > zero + ) + + det_sign = ntl.where(pivot_is_k, det_sign, -det_sign) + + # 第 k 行 + row_k = ntl.sum( + ntl.where(rows == k_i32, a, zero), + 0, + ) + + # pivot 行 + pivot_row = ntl.sum( + ntl.where(pivot_mask[:, None], a, zero), + 0, + ) + + # 交换第 k 行和 pivot 行。 + a = ntl.where( + rows == k_i32, + pivot_row[None, :], + ntl.where( + pivot_mask[:, None], + row_k[None, :], + a, + ), + ) + + # pivot = a[k, k] + pivot = ntl.sum( + ntl.where((rows == k_i32) & (cols == k_i32), a, zero) + ) + + pivot_abs_after_swap = _abs_f32(pivot) + pivot_sign = _sign_f32(pivot) + + det_sign = det_sign * pivot_sign + + safe_pivot_abs = ntl.where( + pivot_abs_after_swap == zero, + one, + pivot_abs_after_swap, + ) + + log_abs_det = log_abs_det + ntl.log(safe_pivot_abs) 
+ + safe_pivot = ntl.where( + pivot_abs_after_swap == zero, + one, + pivot, + ) + + row_k_after_swap = ntl.sum( + ntl.where(rows == k_i32, a, zero), + 0, + ) + + col_k_after_swap = ntl.sum( + ntl.where(cols == k_i32, a, zero), + 1, + ) + + factor = col_k_after_swap / safe_pivot + update = a - factor[:, None] * row_k_after_swap[None, :] + + # 只更新真实矩阵范围内的 trailing submatrix。 + update_mask = ( + (rows > k_i32) + & (cols > k_i32) + & (rows < n_i32) + & (cols < n_i32) + ) + + a = ntl.where(update_mask, update, a) + + final_sign = ntl.where(singular, zero, det_sign) + final_logabsdet = ntl.where(singular, neg_inf, log_abs_det) + + sign[0, 0] = final_sign # noqa: F841 + logabsdet[0, 0] = final_logabsdet # noqa: F841 + + +def premake(ndim, matrix_size, block_size, dtype=None): + arrangement_ = functools.partial( + arrangement, + block_size=block_size, + ) + + input_tensor = Tensor( + ndim, + shape=(matrix_size, matrix_size), + dtype=dtype, + other=0.0, + shape_options=( + {"constexpr": True, "upper_bound": 16}, + {"constexpr": True, "upper_bound": 16}, + ), + ) + + sign_tensor = Tensor( + 2, + shape=(1, 1), + dtype=dtype, + shape_options=( + {"constexpr": True, "upper_bound": 1}, + {"constexpr": True, "upper_bound": 1}, + ), + ) + + logabsdet_tensor = Tensor( + 2, + shape=(1, 1), + dtype=dtype, + shape_options=( + {"constexpr": True, "upper_bound": 1}, + {"constexpr": True, "upper_bound": 1}, + ), + ) + + tensors = ( + input_tensor, + sign_tensor, + logabsdet_tensor, + ) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 82fc596..9dc8894 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -16,6 +16,9 @@ from ntops.torch.ge import ge from ntops.torch.gelu import gelu from ntops.torch.gt import gt +from ntops.torch.gumbel_softmax import gumbel_softmax +from ntops.torch.hsplit import hsplit +from ntops.torch.heaviside import heaviside from 
ntops.torch.isinf import isinf from ntops.torch.isnan import isnan from ntops.torch.layer_norm import layer_norm @@ -36,6 +39,8 @@ from ntops.torch.sigmoid import sigmoid from ntops.torch.silu import silu from ntops.torch.sin import sin +from ntops.torch.slice_scatter import slice_scatter +from ntops.torch.slogdet import slogdet from ntops.torch.softmax import softmax from ntops.torch.sub import sub from ntops.torch.tanh import tanh @@ -58,7 +63,10 @@ "exp", "ge", "gelu", + "gumbel_softmax", "gt", + "heaviside", + "hsplit", "isinf", "isnan", "layer_norm", @@ -79,6 +87,8 @@ "sigmoid", "silu", "sin", + "slice_scatter", + "slogdet", "softmax", "sub", "tanh", diff --git a/src/ntops/torch/gumbel_softmax.py b/src/ntops/torch/gumbel_softmax.py new file mode 100644 index 0000000..379639f --- /dev/null +++ b/src/ntops/torch/gumbel_softmax.py @@ -0,0 +1,20 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def gumbel_softmax(input, tau=1.0, hard=False, eps=1e-10, dim=-1): + output = torch.empty_like(input) + + kernel = _cached_make( + ntops.kernels.gumbel_softmax.premake, + input.ndim, + dim, + ) + + hard_value = 1.0 if hard else 0.0 + + kernel(input, float(tau), hard_value, output) + + return output \ No newline at end of file diff --git a/src/ntops/torch/heaviside.py b/src/ntops/torch/heaviside.py new file mode 100644 index 0000000..92b04ab --- /dev/null +++ b/src/ntops/torch/heaviside.py @@ -0,0 +1,57 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def _storage_ptr(x): + candidates = [x] + + underlying = getattr(x, "_underlying", None) + if underlying is not None: + candidates.append(underlying) + + for obj in candidates: + data_ptr = getattr(obj, "data_ptr", None) + if callable(data_ptr): + return data_ptr() + + if isinstance(obj, torch.Tensor): + if hasattr(obj, "untyped_storage"): + return obj.untyped_storage().data_ptr() + + return obj.storage().data_ptr() + + return None + + +def _same_storage(a, b): + 
if a is b: + return True + + pa = _storage_ptr(a) + pb = _storage_ptr(b) + + return pa is not None and pb is not None and pa == pb + + +def heaviside(input, values, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make( + ntops.kernels.heaviside.premake, + input.ndim, + dtype=input.dtype, + ) + + need_tmp = _same_storage(out, input) + + actual_out = torch.empty_like(input) if need_tmp else out + + kernel(input, values, actual_out) + + if need_tmp: + out.copy_(actual_out) + + return out diff --git a/src/ntops/torch/hsplit.py b/src/ntops/torch/hsplit.py new file mode 100644 index 0000000..7c28bba --- /dev/null +++ b/src/ntops/torch/hsplit.py @@ -0,0 +1,133 @@ +# src/ntops/torch/hsplit.py + +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def _hsplit_dim(input): + if input.ndim == 0: + raise RuntimeError("hsplit expects a tensor with at least 1 dimension") + + return 0 if input.ndim == 1 else 1 + + +def _unwrap_indices_or_sections(indices_or_sections): + value = indices_or_sections + + # 兼容 infinicore 测试框架里的 Sections 包装类 + for name in ( + "value", + "values", + "data", + "sections", + "indices", + "indices_or_sections", + ): + if hasattr(value, name): + attr = getattr(value, name) + if not callable(attr): + value = attr + break + + if not isinstance(value, (int, list, tuple)) and hasattr(value, "__dict__"): + for attr in vars(value).values(): + if isinstance(attr, (int, list, tuple)): + value = attr + break + + return value + + +def _normalize_index(index, size): + index = int(index) + + if index < 0: + index += size + + return max(0, min(index, size)) + + +def _get_split_ranges(size, indices_or_sections): + indices_or_sections = _unwrap_indices_or_sections(indices_or_sections) + + if isinstance(indices_or_sections, int): + sections = indices_or_sections + + if sections <= 0: + raise RuntimeError("number of sections must be larger than 0") + + base = size // sections + extra = size % sections + + ranges 
= [] + start = 0 + + for i in range(sections): + length = base + (1 if i < extra else 0) + end = start + length + ranges.append((start, end)) + start = end + + return ranges + + indices = [_normalize_index(index, size) for index in indices_or_sections] + + starts = [0] + indices + ends = indices + [size] + + return list(zip(starts, ends)) + + +def _empty_output(input, dim, length): + output_shape = list(input.shape) + output_shape[dim] = length + + return torch.empty( + tuple(output_shape), + dtype=input.dtype, + device=input.device, + ) + + +def _copy_slice(input, dim, start, end): + length = end - start + + output = _empty_output(input, dim, length) + + # 空切片不需要 launch kernel + if length == 0: + return output + + kernel = _cached_make( + ntops.kernels.hsplit.premake, + input.ndim, + dim, + start, + end, + ) + + kernel(input, output) + + return output + + +def hsplit(input, indices_or_sections): + dim = _hsplit_dim(input) + size = input.shape[dim] + + ranges = _get_split_ranges(size, indices_or_sections) + + outputs = [] + + for start, end in ranges: + # fast path: + # 如果这一段就是完整 input,直接返回 input,不 copy + if start == 0 and end == size: + outputs.append(input) + continue + + outputs.append(_copy_slice(input, dim, start, end)) + + return tuple(outputs) \ No newline at end of file diff --git a/src/ntops/torch/slice_scatter.py b/src/ntops/torch/slice_scatter.py new file mode 100644 index 0000000..594afd9 --- /dev/null +++ b/src/ntops/torch/slice_scatter.py @@ -0,0 +1,70 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def _normalize_dim(dim, ndim): + if dim < 0: + dim += ndim + + if dim < 0 or dim >= ndim: + raise IndexError("dim out of range") + + return dim + + +def _normalize_start_end(start, end, size): + if start is None: + start = 0 + + if end is None: + end = size + + if start < 0: + start += size + + if end < 0: + end += size + + start = max(0, min(start, size)) + end = max(0, min(end, size)) + + return start, end + + +def 
slice_scatter(input, source, dim=0, start=None, end=None, step=1): + if input.ndim == 0: + raise RuntimeError("slice_scatter does not support zero-dimensional input") + + if step is None: + step = 1 + + if step <= 0: + raise ValueError("slice_scatter only supports step > 0") + + dim = _normalize_dim(dim, input.ndim) + start, end = _normalize_start_end(start, end, input.shape[dim]) + output = torch.empty_like(input) + + copy_kernel = _cached_make( + ntops.kernels.slice_scatter.premake_copy, + input.ndim, + ) + + scatter_kernel = _cached_make( + ntops.kernels.slice_scatter.premake_scatter, + input.ndim, + dim, + start, + end, + step, + ) + + # 第一步:output = input + copy_kernel(input, output) + + # 第二步:output[..., start:end:step, ...] = source + scatter_kernel(source, output) + + return output \ No newline at end of file diff --git a/src/ntops/torch/slogdet.py b/src/ntops/torch/slogdet.py new file mode 100644 index 0000000..d7a5c80 --- /dev/null +++ b/src/ntops/torch/slogdet.py @@ -0,0 +1,66 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def _next_power_of_2(x): + return 1 << (int(x) - 1).bit_length() + + +def _to_scalar(tensor): + # PyTorch Tensor 路径 + try: + return tensor[0, 0] + except Exception: + pass + + # infinicore Tensor 路径:优先尝试 reshape 成 0-d + try: + return tensor.reshape(()) + except Exception: + pass + + try: + return torch.reshape(tensor, ()) + except Exception: + pass + + # 如果没有 reshape,就尝试 squeeze + try: + return tensor.squeeze() + except Exception: + pass + + try: + return torch.squeeze(tensor) + except Exception: + pass + + # 最后兜底:返回 1x1 buffer + return tensor + + +def slogdet(input): + if input.ndim != 2: + raise NotImplementedError("ntops slogdet currently supports 2D matrices only") + + if input.shape[0] != input.shape[1]: + raise RuntimeError("slogdet expects a square matrix") + + matrix_size = int(input.shape[0]) + block_size = _next_power_of_2(matrix_size) + + sign_buffer = torch.empty((1, 1), dtype=input.dtype, 
device=input.device) + logabsdet_buffer = torch.empty((1, 1), dtype=input.dtype, device=input.device) + + kernel = _cached_make( + ntops.kernels.slogdet.premake, + input.ndim, + matrix_size, + block_size, + ) + + kernel(input, sign_buffer, logabsdet_buffer) + + return _to_scalar(sign_buffer), _to_scalar(logabsdet_buffer) \ No newline at end of file diff --git a/tests/test_gumbel_softmax.py b/tests/test_gumbel_softmax.py new file mode 100644 index 0000000..e1d2127 --- /dev/null +++ b/tests/test_gumbel_softmax.py @@ -0,0 +1,65 @@ +import random + +import pytest +import torch + +import ntops +from tests.utils import generate_arguments + + +_FLOAT_DTYPES = ( + torch.float16, + torch.bfloat16, + torch.float32, +) + + +def _tolerance(dtype, rtol, atol): + if dtype == torch.float16: + return max(rtol, 1e-2), max(atol, 1e-3) + if dtype == torch.bfloat16: + return max(rtol, 5e-2), max(atol, 1e-2) + return max(rtol, 1e-4), max(atol, 1e-5) + + +@pytest.mark.parametrize(*generate_arguments()) +def test_gumbel_softmax(shape, dtype, device, rtol, atol): + if dtype not in _FLOAT_DTYPES: + return + + if len(shape) == 0: + return + + input = torch.randn(shape, dtype=dtype, device=device) + + dim = random.randint(0, input.ndim - 1) + tau = random.choice([0.5, 1.0, 1.5]) + hard = random.choice([False, True]) + + output = ntops.torch.gumbel_softmax( + input, + tau=tau, + hard=hard, + dim=dim, + ) + + assert output.shape == input.shape + assert output.dtype == input.dtype + assert output.device == input.device + + output_fp32 = output.to(torch.float32) + + assert torch.isfinite(output_fp32).all() + + rtol, atol = _tolerance(dtype, rtol, atol) + + sum_output = output_fp32.sum(dim=dim) + expected_sum = torch.ones_like(sum_output) + + assert torch.allclose(sum_output, expected_sum, rtol=rtol, atol=atol) + + if hard: + assert ((output_fp32 == 0.0) | (output_fp32 == 1.0)).all() + else: + assert (output_fp32 >= 0.0).all() + assert (output_fp32 <= 1.0).all() \ No newline at end of file diff 
--git a/tests/test_heaviside.py b/tests/test_heaviside.py new file mode 100644 index 0000000..1a9cb95 --- /dev/null +++ b/tests/test_heaviside.py @@ -0,0 +1,29 @@ + + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_heaviside(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + values = torch.randn(shape, dtype=dtype, device=device) + if input.numel() > 0: + input_flat = input.flatten() + input_flat[0] = 0 + + if input.numel() > 1: + input_flat[1] = -1 + + if input.numel() > 2: + input_flat[2] = 1 + + ninetoothed_output = ntops.torch.heaviside(input, values) + reference_output = torch.heaviside(input, values) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) \ No newline at end of file diff --git a/tests/test_hsplit.py b/tests/test_hsplit.py new file mode 100644 index 0000000..b02881c --- /dev/null +++ b/tests/test_hsplit.py @@ -0,0 +1,61 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_hsplit_sections(shape, dtype, device, rtol, atol): + if len(shape) == 0: + return + + input = torch.randn(shape, dtype=dtype, device=device) + + dim = 0 if input.ndim == 1 else 1 + size = input.shape[dim] + + # 为了避免某些实现要求整除,这里选 sections=1,最稳 + sections = 1 + + ninetoothed_outputs = ntops.torch.hsplit(input, sections) + reference_outputs = torch.hsplit(input, sections) + + assert len(ninetoothed_outputs) == len(reference_outputs) + + for ninetoothed_output, reference_output in zip( + ninetoothed_outputs, + reference_outputs, + ): + assert torch.equal(ninetoothed_output, reference_output) + + +@skip_if_cuda_not_available 
+@pytest.mark.parametrize(*generate_arguments()) +def test_hsplit_indices(shape, dtype, device, rtol, atol): + if len(shape) == 0: + return + + input = torch.randn(shape, dtype=dtype, device=device) + + dim = 0 if input.ndim == 1 else 1 + size = input.shape[dim] + + # indices_or_sections 为 list 的情况 + indices = [ + size // 3, + 2 * size // 3, + ] + + ninetoothed_outputs = ntops.torch.hsplit(input, indices) + reference_outputs = torch.hsplit(input, indices) + + assert len(ninetoothed_outputs) == len(reference_outputs) + + for ninetoothed_output, reference_output in zip( + ninetoothed_outputs, + reference_outputs, + ): + assert torch.equal(ninetoothed_output, reference_output) \ No newline at end of file diff --git a/tests/test_slice_scatter.py b/tests/test_slice_scatter.py new file mode 100644 index 0000000..b1aabd8 --- /dev/null +++ b/tests/test_slice_scatter.py @@ -0,0 +1,47 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_slice_scatter(shape, dtype, device, rtol, atol): + if len(shape) == 0: + return + + input = torch.randn(shape, dtype=dtype, device=device) + + dim = input.ndim - 1 + size = input.shape[dim] + + start = size // 3 + end = size + step = 1 + + source_shape = list(shape) + source_shape[dim] = end - start + + source = torch.randn(source_shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.slice_scatter( + input, + source, + dim=dim, + start=start, + end=end, + step=step, + ) + + reference_output = torch.slice_scatter( + input, + source, + dim=dim, + start=start, + end=end, + step=step, + ) + + assert torch.equal(ninetoothed_output, reference_output) \ No newline at end of file diff --git a/tests/test_slogdet.py b/tests/test_slogdet.py new file mode 100644 index 0000000..0f3ff17 --- /dev/null +++ b/tests/test_slogdet.py @@ -0,0 +1,84 @@ +import 
pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments as generate_base_arguments + + +def generate_arguments(): + names = "shape,strides,dtype,device,rtol,atol" + + _, base_values = generate_base_arguments() + + matrix_cases = [ + ((1, 1), None), + ((2, 2), None), + ((3, 3), (3, 1)), + ((4, 4), None), + ((8, 8), (512, 1)), + ((16, 16), None), + ] + + values = [] + + for _, dtype, device, rtol, atol in base_values: + if dtype != torch.float32: + continue + + for shape, strides in matrix_cases: + values.append( + ( + shape, + strides, + dtype, + device, + rtol, + atol, + ) + ) + + return names, values + + +def _make_input(shape, strides, dtype, device): + if strides is None: + input = torch.randn(shape, dtype=dtype, device=device) + else: + input = torch.empty_strided(shape, strides, dtype=dtype, device=device) + input.normal_() + input.diagonal().add_(0.1) + + return input + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_slogdet(shape, strides, dtype, device, rtol, atol): + input = _make_input(shape, strides, dtype, device) + + ninetoothed_sign, ninetoothed_logabsdet = ntops.torch.slogdet(input) + reference_sign, reference_logabsdet = torch.slogdet(input) + + assert ninetoothed_sign.shape == reference_sign.shape + assert ninetoothed_logabsdet.shape == reference_logabsdet.shape + + assert ninetoothed_sign.dtype == reference_sign.dtype + assert ninetoothed_logabsdet.dtype == reference_logabsdet.dtype + + assert ninetoothed_sign.device == reference_sign.device + assert ninetoothed_logabsdet.device == reference_logabsdet.device + + assert torch.allclose( + ninetoothed_sign, + reference_sign, + rtol=rtol, + atol=atol, + ) + + assert torch.allclose( + ninetoothed_logabsdet, + reference_logabsdet, + rtol=rtol, + atol=atol, + ) \ No newline at end of file