diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..12c495a 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -9,7 +9,10 @@ bmm, clamp, conv2d, + combinations, + corrcoef, cos, + count_nonzero, div, dropout, eq, @@ -19,12 +22,14 @@ gt, isinf, isnan, + kl_div, layer_norm, le, lt, max_pool2d, mm, mul, + narrow, ne, neg, pow, @@ -52,7 +57,10 @@ "bmm", "clamp", "conv2d", + "combinations", + "corrcoef", "cos", + "count_nonzero", "div", "dropout", "eq", @@ -62,12 +70,14 @@ "gt", "isinf", "isnan", + "kl_div", "layer_norm", "le", "lt", "max_pool2d", "mm", "mul", + "narrow", "ne", "neg", "pow", diff --git a/src/ntops/kernels/combinations.py b/src/ntops/kernels/combinations.py new file mode 100644 index 0000000..b69e934 --- /dev/null +++ b/src/ntops/kernels/combinations.py @@ -0,0 +1,128 @@ +import functools +import math + +import ninetoothed +from ninetoothed import Tensor + + +def _num_combinations(n, r, with_replacement): + if r < 0: + raise ValueError("r must be non-negative") + + if r == 0: + return 1 + + if n == 0: + return 0 + + if with_replacement: + return math.comb(n + r - 1, r) + + if r > n: + return 0 + + return math.comb(n, r) + + +def arrangement(input, output, input_size, r, with_replacement, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + return input, output, input_size, r, with_replacement + + +def application(input, output, input_size, r, with_replacement): + if r == 0: + return + + if r == 1: + row = 0 + + for i in range(input_size): + output[row, 0] = input[i] # noqa: F841 + row += 1 + + elif r == 2: + row = 0 + + if with_replacement: + for i in range(input_size): + for j in range(i, input_size): + output[row, 0] = input[i] # noqa: F841 + output[row, 1] = input[j] # noqa: F841 + row += 1 + else: + for i in range(input_size): + for j in range(i + 1, input_size): + output[row, 0] = input[i] # noqa: F841 + output[row, 1] = input[j] # noqa: F841 + row 
def premake(
    input_size,
    r=2,
    with_replacement=False,
    dtype=None,
    block_size=None,
):
    """Build the (arrangement, application, tensors) triple for combinations.

    Args:
        input_size: Number of elements in the 1-D input; coerced with ``int``.
        r: Combination length; coerced with ``int``. Only ``0 <= r <= 3`` is
            accepted.
        with_replacement: Coerced with ``bool`` and forwarded as a constexpr.
        dtype: Element type given to both the input and output tensors.
        block_size: Bound into the arrangement via ``functools.partial``.

    Returns:
        ``(arrangement, application, tensors)`` where ``tensors`` is
        ``(input, output, input_size, r, with_replacement)``; the last three
        are 0-dim constexpr tensors carrying the coerced argument values.

    Raises:
        ValueError: If ``r`` is negative.
        NotImplementedError: If ``r`` is greater than 3.
    """
    n = int(input_size)
    r = int(r)
    with_replacement = bool(with_replacement)

    if r < 0:
        raise ValueError("r must be non-negative")
    if r > 3:
        raise NotImplementedError("combinations currently only supports r <= 3")

    # Tensors are created in the same order as the kernel parameters.
    source = Tensor(1, dtype=dtype)
    result = Tensor(2, dtype=dtype)
    constexprs = tuple(
        Tensor(0, constexpr=True, value=value)
        for value in (n, r, with_replacement)
    )

    row_count = _num_combinations(n=n, r=r, with_replacement=with_replacement)

    source.shape = (n,)
    result.shape = (row_count, r)

    return (
        functools.partial(arrangement, block_size=block_size),
        application,
        (source, result, *constexprs),
    )
def application_2d(input, output, num_rows, num_cols):
    """Write the row-wise Pearson correlation matrix of ``input`` to ``output``.

    ``input`` is indexed as ``input[row, col]`` and ``output`` as
    ``output[i, j]`` for ``i, j`` in ``range(num_rows)``. Every entry is NaN
    when ``num_cols <= 1``; otherwise the diagonal is written as 1.0, the
    upper triangle is computed in float32, and the lower triangle is copied
    from the already-written upper triangle.
    """
    # Sample (N - 1) normaliser applied to covariance and both variances.
    correction = num_cols - 1

    for i in range(num_rows):
        for j in range(num_rows):
            if num_cols <= 1:
                output[i, j] = float("nan")  # noqa: F841

            elif i == j:
                output[i, j] = 1.0  # noqa: F841

            elif j < i:
                # Symmetric entry: (j, i) was written earlier in this loop nest.
                output[i, j] = output[j, i]  # noqa: F841

            else:
                mean_i = ntl.zeros((), dtype=ntl.float32)
                mean_j = ntl.zeros((), dtype=ntl.float32)

                for k in range(num_cols):
                    mean_i += input[i, k].to(ntl.float32)
                    mean_j += input[j, k].to(ntl.float32)

                mean_i = mean_i / num_cols
                mean_j = mean_j / num_cols

                cov = ntl.zeros((), dtype=ntl.float32)
                var_i = ntl.zeros((), dtype=ntl.float32)
                var_j = ntl.zeros((), dtype=ntl.float32)

                for k in range(num_cols):
                    xi = input[i, k].to(ntl.float32) - mean_i
                    xj = input[j, k].to(ntl.float32) - mean_j

                    cov += xi * xj
                    var_i += xi * xi
                    var_j += xj * xj

                cov = cov / correction
                var_i = var_i / correction
                var_j = var_j / correction

                denom = var_i * var_j

                inv_sqrt = ntl.rsqrt(denom)

                # Two Newton-Raphson steps to improve rsqrt precision.
                inv_sqrt = inv_sqrt * (1.5 - 0.5 * denom * inv_sqrt * inv_sqrt)
                inv_sqrt = inv_sqrt * (1.5 - 0.5 * denom * inv_sqrt * inv_sqrt)

                corr = cov * inv_sqrt

                # Clamp rounding overshoot into [-1, 1] with comparisons rather
                # than minimum/maximum. Comparisons against NaN are false, so a
                # NaN correlation (e.g. zero variance -> denom == 0) is stored
                # as NaN.
                # NOTE(review): assumes ntl.minimum/ntl.maximum follow Triton's
                # default non-propagating NaN handling, which would have turned
                # NaN into -1.0 here -- confirm.
                corr = ntl.where(corr > 1.0, 1.0, corr)
                corr = ntl.where(corr < -1.0, -1.0, corr)

                output[i, j] = corr  # noqa: F841
def _make_application(input_shape, reduce_dims):
    """Generate the ``application(input, output)`` function for one shape.

    The function is produced as Python source text with one ``for`` loop per
    axis: kept (non-reduced) axes form the outer loops, reduced axes the
    inner loops. An int64 accumulator is reset per output element and written
    to ``output`` once the inner loops finish.

    Args:
        input_shape: Sizes of the input axes; each is coerced with ``int`` and
            baked into the generated ``range(...)`` calls.
        reduce_dims: Axes to reduce over, already normalised to non-negative
            indices.

    Returns:
        The compiled ``application`` function. Its source is registered in
        ``linecache`` under a synthetic filename derived from the source hash.
    """
    ndim = len(input_shape)
    input_shape = tuple(int(x) for x in input_shape)
    reduce_dims = tuple(int(x) for x in reduce_dims)
    keep_dims = tuple(axis for axis in range(ndim) if axis not in reduce_dims)

    axis_vars = [f"i{axis}" for axis in range(ndim)]

    lines = ["def application(input, output):"]

    level = 1

    # Outer loops: kept axes, one iteration per output element.
    for axis in keep_dims:
        lines.append(f"{_indent(level)}for i{axis} in range({input_shape[axis]}):")
        level += 1

    lines.append(f"{_indent(level)}acc = ntl.zeros((), dtype=ntl.int64)")

    # Inner loops: reduced axes.
    for axis in reduce_dims:
        lines.append(f"{_indent(level)}for i{axis} in range({input_shape[axis]}):")
        level += 1

    input_index = _index_expr(axis_vars)

    if len(reduce_dims) == 0:
        # dim=(): nothing is reduced, each element maps to 0 or 1.
        lines.append(
            f"{_indent(level)}acc = ntl.where(input[{input_index}] != 0, 1, 0).to(ntl.int64)"
        )
    else:
        lines.append(
            f"{_indent(level)}acc += ntl.where(input[{input_index}] != 0, 1, 0).to(ntl.int64)"
        )

    if len(keep_dims) == 0:
        # Full reduction: the scalar result lives in a shape-(1,) output.
        output_index = "0"
    else:
        output_index = _index_expr([f"i{axis}" for axis in keep_dims])

    # The store sits at the depth of the innermost kept-axis loop, i.e. after
    # all reduction loops for that output element have finished.
    write_level = 1 + len(keep_dims)
    lines.append(f"{_indent(write_level)}output[{output_index}] = acc  # noqa: F841")

    source = "\n".join(lines) + "\n"

    # One unique pseudo-filename per distinct source. An empty filename made
    # every generated function share one linecache entry, so each new source
    # replaced the text recorded for the earlier ones.
    # NOTE(review): presumably registered so source lookups by filename
    # (e.g. inspect.getsource) work on the generated function -- confirm.
    digest = hashlib.sha1(source.encode("utf-8")).hexdigest()
    filename = f"<ntops_count_nonzero_{digest}>"
    linecache.cache[filename] = (
        len(source),
        None,
        source.splitlines(True),
        filename,
    )

    namespace = {
        "ntl": ntl,
    }
    code = compile(source, filename, "exec")
    exec(code, namespace)

    return namespace["application"]
def application(
    input,
    target,
    output,
    batch_size,
    feature_size,
    reduction,
    log_target,
):
    """Accumulate the pointwise KL terms over a 2-D input and store the result.

    Every ``input[i, j]`` / ``target[i, j]`` pair is converted to float32 and
    its loss term added to a float32 accumulator. The accumulator is then
    scaled according to ``reduction`` (0: sum, 1: batchmean, 2: mean) and
    written to ``output[0]``.
    """
    total = ntl.zeros((), dtype=ntl.float32)

    for row in range(batch_size):
        for col in range(feature_size):
            pred = input[row, col].to(ntl.float32)
            tgt = target[row, col].to(ntl.float32)

            if log_target:
                # Target is given in log space: exp(target) * (target - input).
                term = ntl.exp(tgt) * (tgt - pred)
            else:
                # A zero target contributes 0. Evaluating t * log(t) directly
                # at t == 0 would give 0 * -inf = NaN, so log() is fed 1.0 for
                # those elements and the term is forced to 0.
                target_is_zero = tgt == 0
                log_operand = ntl.where(target_is_zero, 1.0, tgt)

                term = ntl.where(
                    target_is_zero,
                    0.0,
                    tgt * (ntl.log(log_operand) - pred),
                )

            total += term

    # Reduction codes: 0 = sum (no scaling), 1 = batchmean, 2 = mean.
    if reduction == 1:
        total = total / batch_size
    elif reduction == 2:
        total = total / (batch_size * feature_size)

    output[0] = total  # noqa: F841
def arrangement(input, output, dim=None, start=None, length=None, block_size=None):
    """Slice ``input`` along ``dim`` and apply the element-wise arrangement.

    Only axis ``dim`` is restricted, to ``[start, start + length)``; every
    other axis is taken in full. The sliced input and ``output`` are then
    handed to the shared element-wise arrangement with ``block_size``.
    """
    index = [slice(None) for _ in range(input.ndim)]
    index[dim] = slice(start, start + length)

    narrowed = input[tuple(index)]

    return element_wise_arrangement(narrowed, output, block_size=block_size)
ntops.torch.le import le from ntops.torch.lt import lt @@ -25,6 +29,7 @@ from ntops.torch.max_pool2d import max_pool2d from ntops.torch.mm import mm from ntops.torch.mul import mul +from ntops.torch.narrow import narrow from ntops.torch.ne import ne from ntops.torch.neg import neg from ntops.torch.pow import pow @@ -50,8 +55,11 @@ "bitwise_or", "bmm", "clamp", + "combinations", "conv2d", + "corrcoef", "cos", + "count_nonzero", "div", "dropout", "eq", @@ -61,6 +69,7 @@ "gt", "isinf", "isnan", + "kl_div", "layer_norm", "le", "lt", @@ -68,6 +77,7 @@ "max_pool2d", "mm", "mul", + "narrow", "ne", "neg", "pow", @@ -82,4 +92,4 @@ "softmax", "sub", "tanh", -] +] \ No newline at end of file diff --git a/src/ntops/torch/combinations.py b/src/ntops/torch/combinations.py new file mode 100644 index 0000000..2e2c768 --- /dev/null +++ b/src/ntops/torch/combinations.py @@ -0,0 +1,70 @@ +import math + +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def _num_combinations(n, r, with_replacement): + if r < 0: + raise ValueError("r must be non-negative") + + if r == 0: + return 1 + + if n == 0: + return 0 + + if with_replacement: + return math.comb(n + r - 1, r) + + if r > n: + return 0 + + return math.comb(n, r) + + +def combinations(input, r=2, with_replacement=False, *, out=None): + assert input.ndim == 1, "combinations only supports 1-D input" + + r = int(r) + with_replacement = bool(with_replacement) + + assert r >= 0, "r must be non-negative" + assert r <= 3, "combinations currently only supports r <= 3" + + input_size = input.shape[0] + + num_rows = _num_combinations( + n=input_size, + r=r, + with_replacement=with_replacement, + ) + + if out is None: + out = torch.empty( + (num_rows, r), + dtype=input.dtype, + device=input.device, + ) + else: + assert tuple(out.shape) == (num_rows, r), ( + f"invalid out shape, expected {(num_rows, r)}, got {tuple(out.shape)}" + ) + assert out.dtype == input.dtype, "out dtype must match input dtype" + assert 
def corrcoef(input):
    """Compute correlation coefficients of ``input`` with the ntops kernel.

    A 1-D input, or a 2-D input with a single row, yields a scalar: the
    kernel writes into a shape ``(1,)`` buffer, which is reshaped to 0-dim
    when the buffer supports ``reshape``. Any other 2-D input yields a
    ``(rows, rows)`` matrix.
    """
    assert input.ndim in (1, 2), "corrcoef only supports 1-D or 2-D input"

    shape = tuple(input.shape)
    scalar_result = input.ndim == 1 or shape[0] == 1

    if scalar_result:
        output_shape = (1,)
        # Both scalar kernels take one size: the element count of a 1-D
        # input, or the column count of a (1, N) input -- the last dim.
        size_args = (shape[-1],)
    else:
        output_shape = (shape[0], shape[0])
        size_args = shape

    output = torch.empty(output_shape, dtype=input.dtype, device=input.device)

    kernel = _cached_make(ntops.kernels.corrcoef.premake, shape)
    kernel(input, output, *size_args)

    if scalar_result and hasattr(output, "reshape"):
        return output.reshape(())

    return output
dims: + d = int(d) + if d < 0: + d += ndim + if d < 0 or d >= ndim: + raise IndexError("dim out of range") + if d in normalized: + raise ValueError("dim contains duplicate values") + normalized.append(d) + + return tuple(normalized) + + +def _output_shape(input_shape, reduce_dims): + return tuple( + size + for axis, size in enumerate(input_shape) + if axis not in reduce_dims + ) + + +def _dim_cache_key(dim): + if dim is None: + return None + + if isinstance(dim, int): + return (int(dim),) + + # 关键修复:list -> tuple,避免 _cached_make 报 unhashable type: 'list' + return tuple(int(d) for d in dim) + + +def count_nonzero(input, dim=None): + reduce_dims = _normalize_dims(dim, input.ndim) + output_shape = _output_shape(tuple(input.shape), reduce_dims) + + actual_output_shape = output_shape if len(output_shape) > 0 else (1,) + + output = torch.empty( + actual_output_shape, + dtype=torch.int64, + device=input.device, + ) + + dim_key = _dim_cache_key(dim) + + kernel = _cached_make( + ntops.kernels.count_nonzero.premake, + tuple(input.shape), + dim_key, + ) + + kernel(input, output) + + if len(output_shape) == 0: + if hasattr(output, "reshape"): + return output.reshape(()) + + return output + + return output \ No newline at end of file diff --git a/src/ntops/torch/kl_div.py b/src/ntops/torch/kl_div.py new file mode 100644 index 0000000..3968df9 --- /dev/null +++ b/src/ntops/torch/kl_div.py @@ -0,0 +1,65 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +REDUCTION_SUM = 0 +REDUCTION_BATCHMEAN = 1 +REDUCTION_MEAN = 2 + + +def _reduction_to_code(reduction): + if reduction is None: + reduction = "mean" + + if reduction == "sum": + return REDUCTION_SUM + + if reduction == "batchmean": + return REDUCTION_BATCHMEAN + + if reduction == "mean": + return REDUCTION_MEAN + + raise NotImplementedError( + "kl_div currently only supports reduction='sum', 'batchmean', or 'mean'" + ) + + +def kl_div(input, target, reduction="mean", log_target=False): + assert 
def narrow(input, dim, start, length, *, out=None):
    """Copy the slice ``[start, start + length)`` of ``input`` along ``dim``.

    Args:
        input: Tensor to slice; must have at least one dimension.
        dim: Axis to narrow. Negative values count from the last axis.
        start: First index of the slice. Negative values count from the end
            of ``dim``. Coerced with ``int``.
        length: Number of elements to keep along ``dim``. Coerced with ``int``.
        out: Optional destination tensor. When omitted, a new tensor with
            ``input``'s dtype and device is allocated.

    Returns:
        ``out`` (or the newly allocated tensor) holding the narrowed data.

    Raises:
        IndexError: If ``dim`` or ``start`` is out of range.
        RuntimeError: If ``input`` is 0-dim, ``length`` is negative, or
            ``start + length`` exceeds the size of ``dim``.
    """
    ndim = input.ndim

    if ndim == 0:
        raise RuntimeError("narrow() cannot be applied to a 0-dim tensor")

    # A bare ``dim % ndim`` would silently wrap out-of-range axes (e.g. 5 -> 1
    # for a 2-D input), so the range is checked first.
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )

    dim = dim % ndim
    size = input.shape[dim]
    start = int(start)
    length = int(length)

    if length < 0:
        raise RuntimeError("narrow(): length must be non-negative")

    if not -size <= start <= size:
        raise IndexError(
            f"start out of range (expected to be in range of "
            f"[{-size}, {size}], but got {start})"
        )

    # The kernel arrangement builds ``slice(start, start + length)``, so a
    # negative start has to be made absolute before it is baked in.
    if start < 0:
        start += size

    if start + length > size:
        raise RuntimeError(
            f"start ({start}) + length ({length}) exceeds dimension size ({size})"
        )

    if out is None:
        shape = list(input.shape)
        shape[dim] = length
        out = torch.empty(shape, dtype=input.dtype, device=input.device)

    # Nothing to copy: skip building and launching a kernel.
    if out.numel() == 0:
        return out

    kernel = _cached_make(
        ntops.kernels.narrow.premake,
        ndim,
        dim,
        start,
        length,
    )

    kernel(input, out)

    return out
@pytest.mark.parametrize(*generate_arguments())
@pytest.mark.parametrize("size,r,with_replacement", _COMBINATIONS_TEST_CASES)
def test_combinations(shape, dtype, device, rtol, atol, size, r, with_replacement):
    """Check ``ntops.torch.combinations`` against ``torch.combinations``.

    Only float32 is exercised. Other dtypes are reported as skipped; a plain
    ``return`` would record them as passed although nothing was checked.
    """
    # The shared parametrization supplies these; this test does not use them.
    _ = (shape, rtol, atol)

    if dtype != torch.float32:
        pytest.skip("combinations test only covers float32")

    input = torch.randn(
        (size,),
        dtype=dtype,
        device=device,
    )

    output = ntops.torch.combinations(
        input,
        r=r,
        with_replacement=with_replacement,
    )

    reference = torch.combinations(
        input,
        r=r,
        with_replacement=with_replacement,
    )

    assert output.shape == reference.shape
    assert output.dtype == reference.dtype
    assert output.device == reference.device

    # Values are copied, not computed, so the match must be exact.
    assert torch.equal(output, reference)
@pytest.mark.parametrize(*generate_arguments())
@pytest.mark.parametrize("input_shape,strides,dim", _COUNT_NONZERO_TEST_CASES)
def test_count_nonzero(shape, dtype, device, rtol, atol, input_shape, strides, dim):
    """Check ``ntops.torch.count_nonzero`` against ``torch.count_nonzero``.

    Dtypes outside ``_SUPPORTED_DTYPES`` are reported as skipped; a plain
    ``return`` would record them as passed although nothing was checked.
    """
    # The shared parametrization supplies these; this test does not use them.
    _ = (shape, rtol, atol)

    if dtype not in _SUPPORTED_DTYPES:
        pytest.skip("count_nonzero test only covers int32, float32, and uint8")

    if dtype == torch.uint8:
        # Unsigned: draw from [0, 3) so zeros and non-zeros both occur.
        base = torch.randint(
            0,
            3,
            input_shape,
            dtype=dtype,
            device=device,
        )
    else:
        base = torch.randint(
            -2,
            3,
            input_shape,
            device=device,
        ).to(dtype)

    if strides is not None:
        # Exercise non-contiguous layouts with the same values.
        input = torch.empty_strided(
            input_shape,
            strides,
            dtype=dtype,
            device=device,
        )
        input.copy_(base)
    else:
        input = base

    if dim is None:
        # Call without ``dim`` so the default-argument path is covered too.
        output = ntops.torch.count_nonzero(input)
        reference = torch.count_nonzero(input)
    else:
        output = ntops.torch.count_nonzero(input, dim=dim)
        reference = torch.count_nonzero(input, dim=dim)

    assert output.shape == reference.shape
    assert output.dtype == reference.dtype
    assert output.device == reference.device
    assert torch.equal(output, reference)
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
@pytest.mark.parametrize("case_shape,reduction,log_target", _KL_DIV_TEST_CASES)
def test_kl_div(shape, dtype, device, rtol, atol, case_shape, reduction, log_target):
    """Compare ``ntops.torch.kl_div`` with ``torch.nn.functional.kl_div``.

    Tolerances are chosen per dtype, overriding the parametrized values;
    dtypes without an entry are skipped.
    """
    tolerances = {
        torch.float32: (1e-4, 1e-5),
        torch.float16: (1e-1, 1e-2),
        torch.bfloat16: (5e-2, 1e-2),
    }

    if dtype not in tolerances:
        pytest.skip("kl_div test only covers float32, float16, and bfloat16")

    rtol, atol = tolerances[dtype]

    input = torch.randn(case_shape, dtype=dtype, device=device)

    if log_target:
        target = torch.randn(case_shape, dtype=dtype, device=device)
    else:
        # Offset keeps every target at or above 0.1, i.e. strictly positive.
        target = torch.rand(case_shape, dtype=dtype, device=device) + 0.1

    options = {"reduction": reduction, "log_target": log_target}

    ninetoothed_output = ntops.torch.kl_div(input, target, **options)
    reference_output = F.kl_div(input, target, **options)

    assert torch.allclose(
        ninetoothed_output,
        reference_output,
        rtol=rtol,
        atol=atol,
    )
dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + + dim = input.ndim - 1 + size = input.shape[dim] + + start = size // 3 + length = size - start + + ninetoothed_output = ntops.torch.narrow(input, dim, start, length) + reference_output = torch.narrow(input, dim, start, length) + + assert torch.equal(ninetoothed_output, reference_output) \ No newline at end of file