From 5ae1418037a5818fa134437d6a356ccc9572ed5c Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 21 Jun 2026 15:19:18 +0000 Subject: [PATCH] add tensor_split unflatten moveaxis channel_shuffle im2col operators --- src/ntops/kernels/__init__.py | 10 ++ src/ntops/kernels/channel_shuffle.py | 32 ++++++ src/ntops/kernels/im2col.py | 140 +++++++++++++++++++++++++++ src/ntops/kernels/moveaxis.py | 37 +++++++ src/ntops/kernels/tensor_split.py | 29 ++++++ src/ntops/kernels/unflatten.py | 29 ++++++ src/ntops/torch/__init__.py | 11 ++- src/ntops/torch/channel_shuffle.py | 22 +++++ src/ntops/torch/im2col.py | 79 +++++++++++++++ src/ntops/torch/moveaxis.py | 69 +++++++++++++ src/ntops/torch/tensor_split.py | 103 ++++++++++++++++++++ src/ntops/torch/unflatten.py | 81 ++++++++++++++++ tests/test_channel_shuffle.py | 26 +++++ tests/test_im2col.py | 56 +++++++++++ tests/test_moveaxis.py | 26 +++++ tests/test_tensor_split.py | 33 +++++++ tests/test_unflatten.py | 28 ++++++ 17 files changed, 810 insertions(+), 1 deletion(-) create mode 100644 src/ntops/kernels/channel_shuffle.py create mode 100644 src/ntops/kernels/im2col.py create mode 100644 src/ntops/kernels/moveaxis.py create mode 100644 src/ntops/kernels/tensor_split.py create mode 100644 src/ntops/kernels/unflatten.py create mode 100644 src/ntops/torch/channel_shuffle.py create mode 100644 src/ntops/torch/im2col.py create mode 100644 src/ntops/torch/moveaxis.py create mode 100644 src/ntops/torch/tensor_split.py create mode 100644 src/ntops/torch/unflatten.py create mode 100644 tests/test_channel_shuffle.py create mode 100644 tests/test_im2col.py create mode 100644 tests/test_moveaxis.py create mode 100644 tests/test_tensor_split.py create mode 100644 tests/test_unflatten.py diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..02bcd52 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -8,6 +8,7 @@ bitwise_or, bmm, clamp, + channel_shuffle, conv2d, cos, div, @@ -17,6 +18,7 @@ ge, gelu, gt, + im2col, isinf, isnan, layer_norm, @@ -24,6 +26,7 @@ lt, max_pool2d, mm, + moveaxis, mul, ne, neg, @@ -39,6 +42,8 @@ softmax, sub, tanh, + tensor_split, + unflatten, ) __all__ = [ @@ -50,6 +55,7 @@ "bitwise_not", "bitwise_or", "bmm", + "channel_shuffle", "clamp", "conv2d", "cos", @@ -60,6 +66,7 @@ "ge", "gelu", "gt", + "im2col", "isinf", "isnan", "layer_norm", @@ -67,6 +74,7 @@ "lt", "max_pool2d", "mm", + "moveaxis", "mul", "ne", "neg", @@ -82,4 +90,6 @@ "softmax", "sub", "tanh", + "tensor_split", + "unflatten", ] diff --git a/src/ntops/kernels/channel_shuffle.py b/src/ntops/kernels/channel_shuffle.py new file mode 100644 index 0000000..892d9cf --- /dev/null +++ b/src/ntops/kernels/channel_shuffle.py @@ -0,0 +1,32 @@ +import functools + +import ninetoothed +from ninetoothed import Tensor + + +def arrangement(input, output, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + input = input.flatten().tile((block_size,)) + output = output.flatten().tile((block_size,)) + + return input, output + + +def application(input, output): + output = input + + +def premake(dtype=None, block_size=None): + arrangement_ = functools.partial( + arrangement, + block_size=block_size, + ) + + tensors = ( + Tensor(5, dtype=dtype), + Tensor(4, dtype=dtype), + ) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/kernels/im2col.py b/src/ntops/kernels/im2col.py new file mode 100644 index 0000000..9088f9a --- /dev/null +++ b/src/ntops/kernels/im2col.py @@ -0,0 +1,140 @@ +import functools + +from ninetoothed import Symbol, Tensor + + +BLOCK_SIZE_M = 32 +BLOCK_SIZE_N = 32 + + +def arrangement( + input, + output, + kernel_size_h=None, + kernel_size_w=None, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + ceil_mode=None, + block_size_m=None, + block_size_n=None, +): + if kernel_size_h is None: + kernel_size_h = Symbol("kernel_size_h", constexpr=True, upper_bound=8) + + if kernel_size_w is None: + kernel_size_w = Symbol("kernel_size_w", constexpr=True, upper_bound=8) + + if stride_h is None: + stride_h = Symbol("stride_h", constexpr=True) + + if stride_w is None: + stride_w = Symbol("stride_w", constexpr=True) + + if padding_h is None: + padding_h = Symbol("padding_h", constexpr=True) + + if padding_w is None: + padding_w = Symbol("padding_w", constexpr=True) + + if dilation_h is None: + dilation_h = Symbol("dilation_h", constexpr=True) + + if dilation_w is None: + dilation_w = Symbol("dilation_w", constexpr=True) + + if ceil_mode is None: + ceil_mode = False + + if block_size_m is None: + block_size_m = BLOCK_SIZE_M + + if block_size_n is None: + block_size_n = BLOCK_SIZE_N + input_arranged = input.pad( + ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w)) + ) + + input_arranged = input_arranged.tile( + (1, input.shape[1], kernel_size_h, kernel_size_w), + strides=(-1, -1, stride_h, stride_w), + dilation=(1, 1, dilation_h, dilation_w), + floor_mode=not ceil_mode, + ) + + input_arranged = input_arranged.squeeze(1) + input_arranged.dtype = input_arranged.dtype.squeeze(0) + + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1) + + input_arranged = input_arranged.tile((block_size_m, block_size_n)) + + # output: (N * OH * OW, C * KH * KW) + output_arranged = output.tile((block_size_m, block_size_n)) + + return input_arranged, output_arranged + + +def application(input, output): + output = input + + +def premake( + kernel_size_h=None, + kernel_size_w=None, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + ceil_mode=None, + dtype=None, + block_size_m=None, + block_size_n=None, +): + arrangement_ = functools.partial( + arrangement, + kernel_size_h=kernel_size_h, + kernel_size_w=kernel_size_w, + stride_h=stride_h, + stride_w=stride_w, + padding_h=padding_h, + padding_w=padding_w, + dilation_h=dilation_h, + dilation_w=dilation_w, + ceil_mode=ceil_mode, + block_size_m=block_size_m, + block_size_n=block_size_n, + ) + + input = Tensor( + 4, + dtype=dtype, + shape_options=( + {"upper_bound": 16}, # N + {"upper_bound": 64}, # C + {"upper_bound": 256}, # H + {"upper_bound": 256}, # W + ), + ) + + output = Tensor( + 2, + dtype=dtype, + shape_options=( + {"upper_bound": 65536}, # N * OH * OW + {"upper_bound": 1024}, # C * KH * KW + ), + ) + + tensors = ( + input, + output, + ) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/kernels/moveaxis.py b/src/ntops/kernels/moveaxis.py new file mode 100644 index 0000000..2ca8f8e --- /dev/null +++ b/src/ntops/kernels/moveaxis.py @@ -0,0 +1,37 @@ +import functools + +import ninetoothed +from ninetoothed import Tensor + + +def arrangement(input, output, permutation, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + assert input.ndim == output.ndim + + input_arranged = input.permute(permutation) + input_arranged = input_arranged.flatten().tile((block_size,)) + + output_arranged = output.flatten().tile((block_size,)) + + return input_arranged, output_arranged + + +def application(input, output): + output = input + + +def premake(ndim, permutation, dtype=None, block_size=None): + arrangement_ = functools.partial( + arrangement, + permutation=permutation, + block_size=block_size, + ) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/kernels/tensor_split.py b/src/ntops/kernels/tensor_split.py new file mode 100644 index 0000000..1488731 --- /dev/null +++ b/src/ntops/kernels/tensor_split.py @@ -0,0 +1,29 @@ +import ninetoothed +from ninetoothed import Tensor + + +def arrangement(input, output, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + assert input.ndim == output.ndim + + input = input.flatten().tile((block_size,)) + output = output.flatten().tile((block_size,)) + + return input, output + + +def application(input, output): + output = input + + +def premake(ndim): + return ( + arrangement, + application, + ( + Tensor(ndim), + Tensor(ndim), + ), + ) \ No newline at end of file diff --git a/src/ntops/kernels/unflatten.py b/src/ntops/kernels/unflatten.py new file mode 100644 index 0000000..d6d6c64 --- /dev/null +++ b/src/ntops/kernels/unflatten.py @@ -0,0 +1,29 @@ +import functools + +import ninetoothed +from ninetoothed import Tensor + + +def arrangement(input, output, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + input = input.flatten().tile((block_size,)) + output = output.flatten().tile((block_size,)) + + return input, output + + +def application(input, output): + output = input + + +def premake(input_ndim, output_ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(input_ndim, dtype=dtype), + Tensor(output_ndim, dtype=dtype), + ) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 82fc596..865dc57 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -7,6 +7,7 @@ from ntops.torch.bitwise_or import bitwise_or from ntops.torch.bmm import bmm from ntops.torch.clamp import clamp +from ntops.torch.channel_shuffle import channel_shuffle from ntops.torch.conv2d import conv2d from ntops.torch.cos import cos from ntops.torch.div import div @@ -16,6 +17,7 @@ from ntops.torch.ge import ge from ntops.torch.gelu import gelu from ntops.torch.gt import gt +from ntops.torch.im2col import im2col from ntops.torch.isinf import isinf from ntops.torch.isnan import isnan from ntops.torch.layer_norm import layer_norm @@ -24,6 +26,7 @@ from ntops.torch.matmul import matmul from ntops.torch.max_pool2d import max_pool2d from ntops.torch.mm import mm +from ntops.torch.moveaxis import moveaxis from ntops.torch.mul import mul from ntops.torch.ne import ne from ntops.torch.neg import neg @@ -39,7 +42,8 @@ from ntops.torch.softmax import softmax from ntops.torch.sub import sub from ntops.torch.tanh import tanh - +from ntops.torch.tensor_split import tensor_split +from ntops.torch.unflatten import unflatten __all__ = [ "abs", "add", @@ -49,6 +53,7 @@ "bitwise_not", "bitwise_or", "bmm", + "channel_shuffle" "clamp", "conv2d", "cos", @@ -59,6 +64,7 @@ "ge", "gelu", "gt", + "im2col", "isinf", "isnan", "layer_norm", @@ -67,6 +73,7 @@ "matmul", "max_pool2d", "mm", + "moveaxis", "mul", "ne", "neg", @@ -82,4 +89,6 @@ "softmax", "sub", "tanh", + "tensor_split", + "unflatten", ] diff --git a/src/ntops/torch/channel_shuffle.py b/src/ntops/torch/channel_shuffle.py new file mode 100644 index 0000000..13f9397 --- /dev/null +++ b/src/ntops/torch/channel_shuffle.py @@ -0,0 +1,22 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def channel_shuffle(input, groups): + n, c, h, w = input.shape + + assert groups > 0 + assert c % groups == 0 + + channels_per_group = c // groups + + input = input.view(n, groups, channels_per_group, h, w) + input = input.transpose(1, 2) + + output = torch.empty((n, c, h, w), dtype=input.dtype, device=input.device) + kernel = _cached_make(ntops.kernels.channel_shuffle.premake) + kernel(input, output) + + return output \ No newline at end of file diff --git a/src/ntops/torch/im2col.py b/src/ntops/torch/im2col.py new file mode 100644 index 0000000..aaaec27 --- /dev/null +++ b/src/ntops/torch/im2col.py @@ -0,0 +1,79 @@ +import torch + +import ntops +from ntops.torch.pooling import _calculate_output_size +from ntops.torch.utils import _cached_make + + +def im2col( + input, + kernel_size, + dilation=1, + padding=0, + stride=1, +): + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + + if isinstance(stride, int): + stride = (stride, stride) + + if isinstance(padding, int): + padding = (padding, padding) + + if isinstance(dilation, int): + dilation = (dilation, dilation) + + n, c, h, w = input.shape + + h_ = _calculate_output_size( + h, + kernel_size[0], + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + ) + + w_ = _calculate_output_size( + w, + kernel_size[1], + stride=stride[1], + padding=padding[1], + dilation=dilation[1], + ) + + output_2d = torch.empty( + ( + n * h_ * w_, + c * kernel_size[0] * kernel_size[1], + ), + dtype=input.dtype, + device=input.device, + ) + + kernel = _cached_make( + ntops.kernels.im2col.premake, + kernel_size_h=kernel_size[0], + kernel_size_w=kernel_size[1], + stride_h=stride[0], + stride_w=stride[1], + padding_h=padding[0], + padding_w=padding[1], + dilation_h=dilation[0], + dilation_w=dilation[1], + ceil_mode=False, + block_size_m=32, + block_size_n=32, + ) + + kernel(input, output_2d) + + output = output_2d.reshape( + n, + h_ * w_, + c * kernel_size[0] * kernel_size[1], + ) + + output = output.permute(0, 2, 1) + + return output \ No newline at end of file diff --git a/src/ntops/torch/moveaxis.py b/src/ntops/torch/moveaxis.py new file mode 100644 index 0000000..612f0b7 --- /dev/null +++ b/src/ntops/torch/moveaxis.py @@ -0,0 +1,69 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def _to_tuple(x): + if isinstance(x, int): + return (x,) + + return tuple(x) + + +def _normalize_dim(dim, ndim): + if dim < 0: + dim += ndim + + assert 0 <= dim < ndim + + return dim + + +def _make_permutation(ndim, source, destination): + source = _to_tuple(source) + destination = _to_tuple(destination) + + assert len(source) == len(destination) + + source = tuple(_normalize_dim(dim, ndim) for dim in source) + destination = tuple(_normalize_dim(dim, ndim) for dim in destination) + + assert len(set(source)) == len(source) + assert len(set(destination)) == len(destination) + + permutation = [dim for dim in range(ndim) if dim not in source] + + for dst, src in sorted(zip(destination, source)): + permutation.insert(dst, src) + + return tuple(permutation) + + +def moveaxis(input, source, destination): + ndim = input.ndim + + assert ndim > 0 + + permutation = _make_permutation(ndim, source, destination) + + output_shape = tuple(input.shape[dim] for dim in permutation) + + output = torch.empty( + output_shape, + dtype=input.dtype, + device=input.device, + ) + + if output.numel() == 0: + return output + + kernel = _cached_make( + ntops.kernels.moveaxis.premake, + input.ndim, + permutation, + ) + + kernel(input, output) + + return output \ No newline at end of file diff --git a/src/ntops/torch/tensor_split.py b/src/ntops/torch/tensor_split.py new file mode 100644 index 0000000..5dce464 --- /dev/null +++ b/src/ntops/torch/tensor_split.py @@ -0,0 +1,103 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def _normalize_dim(dim, ndim): + if dim < 0: + dim += ndim + + assert 0 <= dim < ndim + + return dim + + +def _normalize_index(index, dim_size): + if index < 0: + index += dim_size + + if index < 0: + index = 0 + + if index > dim_size: + index = dim_size + + return index + + +def _split_starts_and_sizes(dim_size, indices_or_sections): + if isinstance(indices_or_sections, int): + sections = indices_or_sections + + assert sections > 0 + + base = dim_size // sections + extra = dim_size % sections + + sizes = [] + for i in range(sections): + if i < extra: + sizes.append(base + 1) + else: + sizes.append(base) + + starts = [] + start = 0 + for size in sizes: + starts.append(start) + start += size + + return starts, sizes + + if isinstance(indices_or_sections, torch.Tensor): + indices = indices_or_sections.detach().cpu().tolist() + else: + indices = list(indices_or_sections) + + indices = [_normalize_index(int(index), dim_size) for index in indices] + + starts = [0] + indices + ends = indices + [dim_size] + + sizes = [] + for start, end in zip(starts, ends): + size = end - start + if size < 0: + size = 0 + sizes.append(size) + + return starts, sizes + + +def tensor_split(input, indices_or_sections, dim=0): + ndim = input.ndim + + assert ndim > 0 + + dim = _normalize_dim(dim, ndim) + + dim_size = input.shape[dim] + starts, sizes = _split_starts_and_sizes(dim_size, indices_or_sections) + + kernel = _cached_make(ntops.kernels.tensor_split.premake, input.ndim) + + outputs = [] + + for start, size in zip(starts, sizes): + output_shape = list(input.shape) + output_shape[dim] = size + + output = torch.empty( + output_shape, + dtype=input.dtype, + device=input.device, + ) + + if output.numel() != 0: + input_slice = input.narrow(dim, start, size) + kernel(input_slice, output) + + outputs.append(output) + + return tuple(outputs) \ No newline at end of file diff --git a/src/ntops/torch/unflatten.py b/src/ntops/torch/unflatten.py new file mode 100644 index 0000000..c7a0796 --- /dev/null +++ b/src/ntops/torch/unflatten.py @@ -0,0 +1,81 @@ +import math + +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def _normalize_dim(dim, ndim): + if dim < 0: + dim += ndim + + assert 0 <= dim < ndim + + return dim + + +def _normalize_sizes(sizes, dim_size): + if isinstance(sizes, int): + sizes = (sizes,) + + sizes = tuple(sizes) + + infer_index = None + known_product = 1 + + for i, size in enumerate(sizes): + size = int(size) + + if size == -1: + assert infer_index is None + infer_index = i + else: + assert size >= 0 + known_product *= size + + sizes = list(sizes) + + if infer_index is not None: + assert known_product != 0 + assert dim_size % known_product == 0 + sizes[infer_index] = dim_size // known_product + else: + assert math.prod(sizes) == dim_size + + return tuple(sizes) + + +def unflatten(input, dim, sizes): + ndim = input.ndim + + assert ndim > 0 + + dim = _normalize_dim(dim, ndim) + + sizes = _normalize_sizes(sizes, input.shape[dim]) + + output_shape = ( + tuple(input.shape[:dim]) + + tuple(sizes) + + tuple(input.shape[dim + 1:]) + ) + + output = torch.empty( + output_shape, + dtype=input.dtype, + device=input.device, + ) + + if output.numel() == 0: + return output + + kernel = _cached_make( + ntops.kernels.unflatten.premake, + input.ndim, + output.ndim, + ) + + kernel(input, output) + + return output \ No newline at end of file diff --git a/tests/test_channel_shuffle.py b/tests/test_channel_shuffle.py new file mode 100644 index 0000000..a515bf0 --- /dev/null +++ b/tests/test_channel_shuffle.py @@ -0,0 +1,26 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("device", ("cuda",)) +@pytest.mark.parametrize( + "dtype, rtol, atol", + ( + (torch.float32, 1e-5, 1e-5), + (torch.float16, 1e-3, 1e-3), + ), +) +@pytest.mark.parametrize("groups", (1, 2, 3, 4, 6)) +@pytest.mark.parametrize("n, c, h, w", ((2, 12, 32, 32),)) +def test_channel_shuffle(n, c, h, w, groups, dtype, device, rtol, atol): + input = torch.randn((n, c, h, w), dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.channel_shuffle(input, groups) + reference_output = F.channel_shuffle(input, groups) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) \ No newline at end of file diff --git a/tests/test_im2col.py b/tests/test_im2col.py new file mode 100644 index 0000000..f0bb098 --- /dev/null +++ b/tests/test_im2col.py @@ -0,0 +1,56 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("device", ("cuda",)) +@pytest.mark.parametrize( + "dtype, rtol, atol", + ( + (torch.float32, 1e-5, 1e-5), + (torch.float16, 1e-3, 1e-3), + ), +) +@pytest.mark.parametrize("dilation", (1, 2, (2, 3))) +@pytest.mark.parametrize("padding", (0, 1, (2, 3))) +@pytest.mark.parametrize("stride", (1, 2, (2, 3))) +@pytest.mark.parametrize("kernel_size", ((1, 1), (3, 3))) +@pytest.mark.parametrize("n, c, h, w", ((2, 3, 112, 112),)) +def test_im2col( + n, + c, + h, + w, + kernel_size, + stride, + padding, + dilation, + dtype, + device, + rtol, + atol, +): + input = torch.randn((n, c, h, w), dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.im2col( + input, + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride, + ) + + reference_output = F.unfold( + input, + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride, + ) + + assert ninetoothed_output.shape == reference_output.shape + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) \ No newline at end of file diff --git a/tests/test_moveaxis.py b/tests/test_moveaxis.py new file mode 100644 index 0000000..fa66139 --- /dev/null +++ b/tests/test_moveaxis.py @@ -0,0 +1,26 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_moveaxis_single_dim(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + + if input.ndim == 0: + pytest.skip("moveaxis does not support scalar input") + + source = random.randint(0, input.ndim - 1) + destination = random.randint(0, input.ndim - 1) + + ninetoothed_output = ntops.torch.moveaxis(input, source, destination) + reference_output = torch.moveaxis(input, source, destination) + + assert ninetoothed_output.shape == reference_output.shape + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) \ No newline at end of file diff --git a/tests/test_tensor_split.py b/tests/test_tensor_split.py new file mode 100644 index 0000000..5388df4 --- /dev/null +++ b/tests/test_tensor_split.py @@ -0,0 +1,33 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +@pytest.mark.parametrize("sections", [1, 2, 3, 4]) +def test_tensor_split_sections(shape, dtype, device, rtol, atol, sections): + input = torch.randn(shape, dtype=dtype, device=device) + + if input.ndim == 0: + pytest.skip("tensor_split does not support scalar input") + + for dim in range(-input.ndim, input.ndim): + ninetoothed_outputs = ntops.torch.tensor_split(input, sections, dim=dim) + reference_outputs = torch.tensor_split(input, sections, dim=dim) + + assert len(ninetoothed_outputs) == len(reference_outputs) + + for ninetoothed_output, reference_output in zip( + ninetoothed_outputs, + reference_outputs, + ): + assert torch.allclose( + ninetoothed_output, + reference_output, + rtol=rtol, + atol=atol, + ) \ No newline at end of file diff --git a/tests/test_unflatten.py b/tests/test_unflatten.py new file mode 100644 index 0000000..9956eb2 --- /dev/null +++ b/tests/test_unflatten.py @@ -0,0 +1,28 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_unflatten(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + + if input.ndim == 0: + pytest.skip("unflatten does not support scalar input") + + dim = random.randint(0, input.ndim - 1) + dim_size = input.shape[dim] + + sizes = (1, dim_size) + + ninetoothed_output = ntops.torch.unflatten(input, dim, sizes) + reference_output = torch.unflatten(input, dim, sizes) + + assert ninetoothed_output.shape == reference_output.shape + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) \ No newline at end of file