def application(input, p, seed, a, b, output):
    """Alpha-dropout one tile of ``input`` and store the rescaled result.

    Positions whose random draw is not greater than ``p`` are replaced with
    ``alpha_prime``; the tile is then mapped through ``x * a + b``.
    """
    # Value substituted for dropped positions. Presumably the negative SELU
    # saturation value (-alpha * scale) -- TODO confirm.
    alpha_prime = -1.7580993408473766

    # NOTE(review): the draw is indexed by ``input.offsets()``, which looks
    # like one draw per element. torch's ``feature_alpha_dropout`` masks whole
    # channels -- confirm that element-wise masking is intended.
    keep = ntl.rand(seed, input.offsets()) > p

    dropped = ntl.where(
        keep,
        input,
        alpha_prime,
    )

    # ``output`` is only ever assigned, hence the unused-variable suppression;
    # presumably ninetoothed turns this assignment into the store -- confirm.
    output = dropped * a + b  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
    """Return the ``(arrangement, application, tensors)`` triple for the kernel.

    ``ndim`` and ``dtype`` describe the input and output tensors;
    ``block_size`` is forwarded to the element-wise arrangement.
    """
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    # Listed in the same order as the parameters of ``application``; the
    # 0-dimensional tensors carry the scalar arguments.
    tensors = (
        Tensor(ndim, dtype=dtype),  # input
        Tensor(0, dtype=ninetoothed.float64),  # p
        Tensor(0, dtype=ninetoothed.int64),  # seed
        Tensor(0, dtype=ninetoothed.float64),  # a
        Tensor(0, dtype=ninetoothed.float64),  # b
        Tensor(ndim, dtype=dtype),  # output
    )

    return arrangement_, application, tensors
def _normalize_dims(dims, ndim):
    """Return ``dims`` as a tuple of unique, non-negative dimension indices.

    Args:
        dims: A single int or an iterable of ints; negative values count from
            the end.
        ndim: Number of dimensions of the tensor being indexed.

    Returns:
        The normalized indices in first-occurrence order, duplicates removed.

    Raises:
        AssertionError: If any index falls outside ``[-ndim, ndim)``.
    """
    if isinstance(dims, int):
        dims = (dims,)

    normalized = tuple(dim + ndim if dim < 0 else dim for dim in dims)

    # Raised explicitly (same exception type as before) so the check is not
    # stripped when Python runs with ``-O``.
    if not all(0 <= dim < ndim for dim in normalized):
        raise AssertionError("`dims` out of range.")

    # ``dict.fromkeys`` keeps first-occurrence order while dropping repeats.
    return tuple(dict.fromkeys(normalized))


def arrangement(
    input,
    output,
    dims,
    block_size=None,
):
    """Arrange ``input`` reversed along ``dims`` and ``output`` as-is.

    Both tensors are flattened and tiled into 1-D blocks of ``block_size``
    so that each output block is paired with a block of the reversed input.
    """
    if block_size is None:
        block_size = ninetoothed.block_size()

    ndim = input.ndim
    dims = _normalize_dims(dims, ndim)

    # A step of -1 reverses the dimension; every other dimension is kept.
    slices = tuple(
        slice(None, None, -1) if dim in dims else slice(None)
        for dim in range(ndim)
    )

    input_arranged = input[slices]
    input_arranged = input_arranged.flatten().tile((block_size,))

    output_arranged = output.flatten().tile((block_size,))

    return input_arranged, output_arranged


def premake(
    ndim,
    dims,
    dtype=None,
    block_size=None,
):
    """Return the ``(arrangement, application, tensors)`` triple for ``flip``.

    ``dims`` is normalized here so the bound arrangement always receives a
    canonical tuple.
    """
    dims = _normalize_dims(dims, ndim)

    arrangement_ = functools.partial(
        arrangement,
        dims=dims,
        block_size=block_size,
    )

    # Input and output share rank and dtype; shapes are marked constexpr.
    tensors = (
        Tensor(
            ndim,
            dtype=dtype,
            shape_options={"constexpr": True},
        ),
        Tensor(
            ndim,
            dtype=dtype,
            shape_options={"constexpr": True},
        ),
    )

    return arrangement_, application, tensors
# Integer codes for the reduction mode accepted by ``premake``.
REDUCTION_NONE = 0
REDUCTION_MEAN = 1
REDUCTION_SUM = 2


def reduction_all_arrangement(input, target, output, inv_numel=None, block_size=None):
    """Arrange ``input`` and ``target`` for a reduction over every dimension.

    Both operands go through the shared reduction arrangement with ``dim``
    set to all of ``input``'s dimensions; ``output`` is tiled with ``(1,)``.
    ``inv_numel`` is passed through unchanged when it is given.
    """
    input_arranged, target_arranged = reduction_arrangement(
        input,
        target,
        dim=tuple(range(input.ndim)),
        block_size=block_size,
    )

    output_arranged = output.tile((1,))

    # The sum kernel has no ``inv_numel`` parameter, so it is only returned
    # when the caller supplied one.
    if inv_numel is None:
        return input_arranged, target_arranged, output_arranged

    return input_arranged, target_arranged, output_arranged, inv_numel


def application_none(input, target, output):
    """Store the element-wise squared difference of ``input`` and ``target``."""
    diff = input - target
    output = diff * diff  # noqa: F841


def application_sum(input, target, output):
    """Store the sum of squared differences over all blocks of ``input``."""
    dtype = output.dtype
    # NOTE(review): only float16 is promoted to a float32 accumulator; every
    # other dtype accumulates in its own precision -- confirm this is
    # intended for other reduced-precision types.
    acc_dtype = ntl.float32 if dtype == ntl.float16 else dtype

    acc = ntl.cast(0, acc_dtype)

    # One partial sum per block along the first arranged dimension.
    for i in range(input.shape[0]):
        diff = ntl.cast(input[i] - target[i], acc_dtype)
        acc += ntl.sum(diff * diff)

    output = ntl.cast(acc, dtype)  # noqa: F841


def application_mean(input, target, output, inv_numel):
    """Store the sum of squared differences scaled by ``inv_numel``."""
    dtype = output.dtype
    # Same accumulator choice as ``application_sum``.
    acc_dtype = ntl.float32 if dtype == ntl.float16 else dtype

    acc = ntl.cast(0, acc_dtype)

    for i in range(input.shape[0]):
        diff = ntl.cast(input[i] - target[i], acc_dtype)
        acc += ntl.sum(diff * diff)

    # The scale is applied before casting back to the output dtype.
    output = ntl.cast(acc * inv_numel, dtype)  # noqa: F841


def premake(
    ndim,
    reduction=REDUCTION_MEAN,
    dtype=None,
    block_size=None,
):
    """Return the ``(arrangement, application, tensors)`` triple for MSE loss.

    ``reduction`` is one of the ``REDUCTION_*`` codes. ``REDUCTION_NONE``
    uses the element-wise arrangement; the other two reduce over all
    dimensions into a one-element output, and ``REDUCTION_MEAN`` additionally
    takes a 0-dimensional ``inv_numel`` tensor.
    """
    if reduction == REDUCTION_NONE:
        arrangement_ = functools.partial(
            element_wise_arrangement,
            block_size=block_size,
        )

        # input, target, output -- all of rank ``ndim``.
        tensors = (
            Tensor(ndim, dtype=dtype),
            Tensor(ndim, dtype=dtype),
            Tensor(ndim, dtype=dtype),
        )

        return arrangement_, application_none, tensors

    assert reduction in (
        REDUCTION_MEAN,
        REDUCTION_SUM,
    ), "`reduction` must be 0, 1, or 2."

    arrangement_ = functools.partial(
        reduction_all_arrangement,
        block_size=block_size,
    )

    input = Tensor(
        ndim,
        dtype=dtype,
        shape_options={"constexpr": True},
    )

    target = Tensor(
        ndim,
        dtype=dtype,
        shape_options={"constexpr": True},
    )

    output = Tensor(
        1,
        dtype=dtype,
        shape_options=(
            {"constexpr": True, "upper_bound": 1},
        ),
    )
    # The symbolic shape is replaced with the literal ``(1,)``.
    output.shape = (1,)

    if reduction == REDUCTION_SUM:
        tensors = (
            input,
            target,
            output,
        )
        return arrangement_, application_sum, tensors

    # Remaining case is REDUCTION_MEAN, which needs the extra scalar.
    inv_numel = Tensor(0, dtype=ninetoothed.float64)

    tensors = (
        input,
        target,
        output,
        inv_numel,
    )

    return arrangement_, application_mean, tensors
def arrangement(
    input,
    output,
    downscale_factor=None,
    block_size=None,
):
    """Arrange ``input`` and ``output`` so matching blocks line up for a copy.

    When ``downscale_factor`` is not given, a constexpr symbol named
    ``"downscale_factor"`` (upper bound 16) is used in its place.
    """
    if downscale_factor is None:
        downscale_factor = Symbol(
            "downscale_factor",
            constexpr=True,
            upper_bound=16,
        )

    if block_size is None:
        block_size = ninetoothed.block_size()

    factor2 = downscale_factor * downscale_factor

    # input: [N, C, H * r, W * r]
    # arranged: [N, C, H, W, r * r]
    input_arranged = input.tile(
        (1, 1, downscale_factor, downscale_factor),
        strides=(-1, -1, downscale_factor, downscale_factor),
    )
    input_arranged = input_arranged.ravel()
    input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1)
    input_arranged = input_arranged.tile((block_size, -1))

    # output: [N, C * r * r, H, W]
    # arranged: [N, C, H, W, r * r]
    output_arranged = output.tile(
        (1, factor2, 1, 1),
        strides=(-1, factor2, -1, -1),
    )
    # Same ravel/flatten/tile chain as the input so both sides end up with
    # the same arranged layout.
    output_arranged = output_arranged.ravel()
    output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1)
    output_arranged = output_arranged.tile((block_size, -1))
    return input_arranged, output_arranged


def premake(
    dtype=None,
    block_size=None,
):
    """Return the ``(arrangement, application, tensors)`` triple.

    Both tensors are 4-dimensional; only the last dimension carries shape
    options (constexpr, upper bound 8192).
    """
    arrangement_ = functools.partial(
        arrangement,
        block_size=block_size,
    )

    tensors = (
        Tensor(
            4,
            dtype=dtype,
            shape_options=(
                None,
                None,
                None,
                {"constexpr": True, "upper_bound": 8192},
            ),
        ),
        Tensor(
            4,
            dtype=dtype,
            shape_options=(
                None,
                None,
                None,
                {"constexpr": True, "upper_bound": 8192},
            ),
        ),
    )

    return arrangement_, application, tensors
def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
    """Run the ntops feature-alpha-dropout kernel on ``input``.

    Args:
        input: Tensor with at least two dimensions.
        p: Drop probability; must satisfy ``0 <= p < 1``.
        training: When false, ``input`` is returned untouched.
        inplace: When true, the result is written back into ``input``.

    Returns:
        ``input`` itself when nothing is dropped or ``inplace`` is set,
        otherwise a newly allocated tensor shaped like ``input``.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1)``.
        AssertionError: If ``input`` has fewer than two dimensions.
    """
    invalid_probability = p < 0.0 or p >= 1.0
    if invalid_probability:
        raise ValueError(
            f"dropout probability has to satisfy 0 <= p < 1, but got {p}"
        )

    assert input.ndim >= 2, "Feature dropout requires at least 2 dimensions in the input"

    if not training or p == 0:
        return input

    # A fresh seed per call; the kernel derives its random draws from it.
    seed = random.randrange(0, 2**31)
    output = input if inplace else torch.empty_like(input)

    alpha_prime = -1.7580993408473766
    drop_prob = float(p)
    keep_prob = 1.0 - drop_prob

    # Affine correction ``x * scale + shift`` applied by the kernel after
    # dropped positions have been replaced with ``alpha_prime``.
    scale = 1.0 / math.sqrt(
        keep_prob * (1.0 + drop_prob * alpha_prime * alpha_prime)
    )
    shift = -scale * drop_prob * alpha_prime

    kernel = _cached_make(
        ntops.kernels.feature_alpha_dropout.premake,
        input.ndim,
    )
    kernel(input, drop_prob, seed, float(scale), float(shift), output)

    return output
def _normalize_dims(dims, ndim):
    """Return ``dims`` as a tuple of unique, non-negative dimension indices.

    Args:
        dims: A single int or an iterable of ints; negative values count from
            the end.
        ndim: Number of dimensions of the tensor being flipped.

    Returns:
        The normalized indices in first-occurrence order, duplicates removed.

    Raises:
        AssertionError: If any index falls outside ``[-ndim, ndim)``.
    """
    if isinstance(dims, int):
        dims = (dims,)

    normalized = tuple(dim + ndim if dim < 0 else dim for dim in dims)

    # Raised explicitly (same exception type as before) so the check is not
    # stripped when Python runs with ``-O``.
    if not all(0 <= dim < ndim for dim in normalized):
        raise AssertionError("`dims` out of range.")

    # ``dict.fromkeys`` keeps first-occurrence order while dropping repeats.
    return tuple(dict.fromkeys(normalized))


def flip(input, dims):
    """Return a new tensor with ``input`` reversed along ``dims``.

    Args:
        input: Tensor to flip.
        dims: A single int or an iterable of ints naming the dimensions to
            reverse; negative values count from the end.

    Raises:
        AssertionError: If any entry of ``dims`` is out of range.
    """
    # Normalizing first means the kernel is requested with one canonical
    # tuple regardless of how the caller spelled the dimensions.
    dims = _normalize_dims(dims, input.ndim)

    output = torch.empty_like(input)

    kernel = _cached_make(
        ntops.kernels.flip.premake,
        input.ndim,
        dims,
    )

    kernel(
        input,
        output,
    )

    return output
# Integer codes handed to the kernel ``premake`` for each reduction mode.
_REDUCTION_NONE = 0
_REDUCTION_MEAN = 1
_REDUCTION_SUM = 2


def _get_reduction_enum(reduction):
    """Map a reduction name to its integer code.

    Raises:
        ValueError: If ``reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``.
    """
    if reduction == "none":
        return _REDUCTION_NONE
    if reduction == "mean":
        return _REDUCTION_MEAN
    if reduction == "sum":
        return _REDUCTION_SUM
    raise ValueError("`reduction` must be one of 'none', 'mean', or 'sum'.")


def _as_scalar(output):
    """Return ``output`` reshaped to zero dimensions when it supports that."""
    if hasattr(output, "reshape"):
        return output.reshape(())
    if hasattr(output, "view"):
        return output.view(())
    return output


def mse_loss(input, target, reduction="mean"):
    """Compute the squared-error loss between ``input`` and ``target``.

    Args:
        input: Predictions.
        target: Ground truth; must have exactly the same shape as ``input``.
        reduction: ``'none'`` for the element-wise squared error, ``'sum'``
            for its total, ``'mean'`` for its average.

    Returns:
        A tensor shaped like ``input`` for ``'none'``, otherwise a
        0-dimensional tensor.

    Raises:
        AssertionError: If the shapes differ.
        ValueError: If ``reduction`` is not a recognized name.
    """
    assert input.shape == target.shape, "`input` and `target` must have the same shape."

    reduction_enum = _get_reduction_enum(reduction)

    if input.numel() == 0:
        # Nothing to reduce: skip the kernel and avoid dividing by a zero
        # element count. The sum of no terms is 0 and the mean is undefined
        # (nan), which is what torch reports for empty inputs.
        if reduction_enum == _REDUCTION_NONE:
            return torch.empty_like(input)

        fill = float("nan") if reduction_enum == _REDUCTION_MEAN else 0.0
        return torch.full((), fill, dtype=input.dtype, device=input.device)

    if reduction_enum == _REDUCTION_NONE:
        output = torch.empty_like(input)
    else:
        # The reducing kernels write into a one-element buffer.
        output = torch.empty((1,), dtype=input.dtype, device=input.device)

    kernel = _cached_make(
        ntops.kernels.mse_loss.premake,
        input.ndim,
        reduction_enum,
    )

    if reduction_enum == _REDUCTION_MEAN:
        # The mean kernel multiplies the accumulated sum by this reciprocal.
        kernel(input, target, output, float(1.0 / input.numel()))
    else:
        kernel(input, target, output)

    if reduction_enum == _REDUCTION_NONE:
        return output

    return _as_scalar(output)
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_feature_alpha_dropout(shape, dtype, device, rtol, atol):
    """Compare ntops and torch feature-alpha-dropout statistically.

    The two implementations use different random streams, so the test checks
    the output shape, the fraction of dropped values, and that every kept
    value equals ``input * a + b``.
    """
    input = torch.randn(shape, dtype=dtype, device=device)
    p = random.uniform(0.05, 0.9)
    call_options = {"p": p, "training": True}

    # PyTorch's feature_alpha_dropout does not support 1D input.
    if input.ndim < 2:
        with pytest.raises(RuntimeError):
            F.feature_alpha_dropout(input, **call_options)

        with pytest.raises(AssertionError):
            ntops.torch.feature_alpha_dropout(input, **call_options)

        return

    ninetoothed_output = ntops.torch.feature_alpha_dropout(input, **call_options)
    reference_output = F.feature_alpha_dropout(input, **call_options)

    assert ninetoothed_output.shape == reference_output.shape

    # Same affine parameters the wrapper computes for this ``p``.
    alpha_prime = -1.7580993408473766
    q = 1.0 - p
    a = 1.0 / math.sqrt(q * (1.0 + p * alpha_prime * alpha_prime))
    b = -a * p * alpha_prime

    # Every dropped position ends up at this single value.
    drop_value = torch.tensor(
        alpha_prime * a + b,
        dtype=ninetoothed_output.dtype,
        device=ninetoothed_output.device,
    )
    tolerance = max(atol, 1e-3)

    def dropped_mask(result):
        return torch.isclose(result, drop_value, rtol=rtol, atol=tolerance)

    ninetoothed_dropped = dropped_mask(ninetoothed_output)
    reference_dropped = dropped_mask(reference_output)

    total = input.numel()
    ninetoothed_drop_ratio = ninetoothed_dropped.sum().item() / total
    reference_drop_ratio = reference_dropped.sum().item() / total

    assert abs(ninetoothed_drop_ratio - reference_drop_ratio) < 0.1

    kept = ~ninetoothed_dropped
    expected_kept = input[kept] * a + b

    assert torch.allclose(
        ninetoothed_output[kept],
        expected_kept,
        rtol=rtol,
        atol=tolerance,
    )
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
@pytest.mark.parametrize("reduction", ("none", "mean", "sum"))
def test_mse_loss(shape, dtype, device, rtol, atol, reduction):
    """Check ``ntops.torch.mse_loss`` against torch for every reduction mode.

    The reduction is a test parameter rather than a random pick, so each
    mode runs for every argument combination and a failure is reproducible.
    """
    input = torch.randn(shape, dtype=dtype, device=device)
    target = torch.randn(shape, dtype=dtype, device=device)

    ninetoothed_output = ntops.torch.mse_loss(
        input,
        target,
        reduction=reduction,
    )

    reference_output = F.mse_loss(
        input,
        target,
        reduction=reduction,
    )

    assert torch.allclose(
        ninetoothed_output,
        reference_output,
        rtol=rtol,
        atol=atol,
    )
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
@pytest.mark.parametrize("downscale_factor", (1, 2, 4))
@pytest.mark.parametrize(
    "input_shape",
    (
        (13, 1, 8, 8),
        (2, 3, 16, 24),
        (1, 4, 32, 32),
    ),
)
def test_pixel_unshuffle(shape, dtype, device, rtol, atol, downscale_factor, input_shape):
    """Check ``ntops.torch.pixel_unshuffle`` against torch on 4D inputs."""
    # The generated ``shape`` is not used; the 4D shapes come from
    # ``input_shape`` instead.
    del shape

    input = torch.randn(input_shape, dtype=dtype, device=device)

    ninetoothed_output = ntops.torch.pixel_unshuffle(input, downscale_factor)
    reference_output = F.pixel_unshuffle(input, downscale_factor)

    outputs_match = torch.allclose(
        ninetoothed_output,
        reference_output,
        rtol=rtol,
        atol=atol,
    )
    assert outputs_match