diff --git a/HONOR_CODE.md b/HONOR_CODE.md new file mode 100644 index 0000000..5d22ad7 --- /dev/null +++ b/HONOR_CODE.md @@ -0,0 +1,71 @@ +# 2026 春季启元人工智能大赛诚信守则(Honor Code) + + +本人作为 2026 春季启元人工智能大赛(以下简称“比赛”)的参赛选手,郑重承诺严格遵守比赛规则及本诚信守则,秉持诚信、公正、廉洁的参赛原则,自觉维护比赛的公平性与严肃性。本人充分理解并认可,违反本准则将导致参赛资格被取消、比赛成绩作废等相应后果,且愿意承担由此产生的一切责任。 + +## 一、参赛诚信承诺 + +1. 本人保证所提交的赛题PR(Pull Request)中包含的算子实现代码及相关文档,均为本人(及参赛团队,如为团队参赛)在比赛期间独立完成或在明确标注参考来源的基础上进行开发,不存在任何欺诈、抄袭、作弊行为。 + +2. 本人承诺主动、全面、真实地披露赛题实现过程中所有参考的外部资源,尤其是开源代码资源,不隐瞒任何可能影响比赛公平性的信息。 + +3. 本人保证不采用任何不正当手段获取比赛优势,包括但不限于窃取其他参赛选手的代码成果、利用非比赛允许的工具或技术、与他人串通作弊等。 + +## 二、参考资源说明 + +本人确认已按比赛要求,将本次赛题实现过程中涉及的参考资源信息单独撰写至`REFERENCE.md`文件中,该文件将与本诚信守则一同作为PR附件提交。`REFERENCE.md`需根据实际参考情况,按以下要求完整填写,信息不完整或虚假填写将视为违反本准则: + +**情况1:无参考外部开源代码及核心实现思路** + +`REFERENCE.md`中需明确声明:“本次赛题提交的算子代码、核心算法逻辑及实现方案均为本人(及参赛团队)独立设计与开发,未参考任何外部开源项目、技术文档中的核心代码片段或实现思路,未接受任何第三方的技术指导或代码支持。” + +**情况2:有参考外部开源代码及相关资源** + +对每个参考资源提供以下信息陈述: +1. 参考开源项目/资源名称 + +2. 参考资源链接(GitHub/Gitee/论文/技术文档等) + +3. 参考的具体内容(请明确说明参考的代码片段、算法逻辑、实现思路等,需标注对应资源的具体位置,如文件路径、代码行数等) + +4. 本人对参考内容的修改与优化说明:(请详细说明在参考基础上,本人所做的独立开发、修改、优化工作,体现自身技术贡献) + +5. 若是开源项目,提供参考资源的开源协议类型:(如MIT、Apache 2.0、GPL等) + +6. 其他需要补充说明的信息 + + +## 三、禁止行为确认 + +本人明确知晓并承诺避免以下违反比赛公平性的行为,若存在以下任一情况,自愿接受比赛组委会的相应处罚: + +1. 未经授权复制、抄袭他人(包括其他参赛选手、开源项目、商业代码)的代码、算法或技术方案,且未进行明确标注; + +2. 隐瞒或虚假披露参考资源信息,包括遗漏重要参考来源、伪造参考内容说明等; + +3. 与其他参赛选手或第三方串通,进行代码共享、成果交换等违规协作; + +4. 利用比赛平台漏洞、技术缺陷或非比赛允许的工具获取不正当利益; + +5. 伪造比赛相关证明材料、提交虚假信息; + +6. 其他违反比赛规则及公序良俗的不诚信行为。 + + +## 四、责任与确认 + +1. 本人充分理解,比赛组委会将对所有提交的PR进行代码溯源、参考信息核查等公平性审查,若发现本人存在违反本准则的行为,有权随时取消本人的参赛资格、作废比赛成绩,情节严重的将在比赛相关平台进行公示。 + +2. 若因本人违反本准则导致比赛争议或第三方权益受损(如开源协议侵权等),本人将独立承担全部法律责任及相关损失,与比赛组委会无关。 + +3. 本人确认已仔细阅读并完全理解本诚信守则的全部内容,自愿签署本准则,接受比赛组委会的监督与审查。 + +## 五、签署信息 + +参赛选手姓名(团队参赛需填写所有成员姓名) + + 李浩坤 + +签署日期 + +___2026___年__6__月__17__日 \ No newline at end of file diff --git a/README.md b/README.md index 77a300f..189354f 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,15 @@ # ntops + NineToothed operators for LLMs. 
+ +## src/ntops/kernels + +该目录存放九齿(NineToothed)算子的内核实现。九齿是一种自研的 DSL,算子的计算逻辑写在各内核文件的 `application` 函数中。 + +## src/ntops/torch + +该目录为九齿算子提供 PyTorch 包装层。 + +## tests + +该目录存放九齿算子与对应 PyTorch 算子的正确性对比测试。 \ No newline at end of file diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..3240e9b 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -7,6 +7,7 @@ bitwise_not, bitwise_or, bmm, + chunk, clamp, conv2d, cos, @@ -14,6 +15,8 @@ dropout, eq, exp, + eye, + flatten, ge, gelu, gt, @@ -29,6 +32,7 @@ neg, pow, relu, + repeat, rms_norm, rotary_position_embedding, rsqrt, @@ -39,6 +43,7 @@ softmax, sub, tanh, + unbind, ) __all__ = [ @@ -50,6 +55,7 @@ "bitwise_not", "bitwise_or", "bmm", + "chunk", "clamp", "conv2d", "cos", @@ -57,6 +63,8 @@ "dropout", "eq", "exp", + "eye", + "flatten", "ge", "gelu", "gt", @@ -72,6 +80,7 @@ "neg", "pow", "relu", + "repeat", "rms_norm", "rotary_position_embedding", "rsqrt", @@ -82,4 +91,5 @@ "softmax", "sub", "tanh", + "unbind", ] diff --git a/src/ntops/kernels/chunk.py b/src/ntops/kernels/chunk.py new file mode 100644 index 0000000..a2c4744 --- /dev/null +++ b/src/ntops/kernels/chunk.py @@ -0,0 +1,26 @@ +import functools + +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +# The slicing is done by the torch wrapper (input.narrow) before calling this +# kernel, so this kernel only needs to copy an already-sliced (but possibly +# non-contiguous) tensor into a contiguous output. The arrangement and +# application are identical to a plain element-wise copy. +# +# Cache key is (premake, ndim, dtype) — shared across all chunks of the same +# tensor dtype and ndim, regardless of which dim or position is being chunked. +# Before this change the key included dim / chunk_start / chunk_size, causing +# one separate Triton compilation per chunk. 
+ + +def application(input, output): + output = input # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) + return arrangement_, application, tensors diff --git a/src/ntops/kernels/eye.py b/src/ntops/kernels/eye.py new file mode 100644 index 0000000..833506a --- /dev/null +++ b/src/ntops/kernels/eye.py @@ -0,0 +1,22 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(rows, cols, output): + output = ntl.where(rows == cols, 1.0, 0.0) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/flatten.py b/src/ntops/kernels/flatten.py new file mode 100644 index 0000000..f6b82fa --- /dev/null +++ b/src/ntops/kernels/flatten.py @@ -0,0 +1,18 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, output): + output = ntl.where(input >= 0, input, input) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/repeat.py b/src/ntops/kernels/repeat.py new file mode 100644 index 0000000..aab55fc --- /dev/null +++ b/src/ntops/kernels/repeat.py @@ -0,0 +1,18 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def 
application(input, output): + output = ntl.where(input >= input, input, input) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/unbind.py b/src/ntops/kernels/unbind.py new file mode 100644 index 0000000..cc5d07d --- /dev/null +++ b/src/ntops/kernels/unbind.py @@ -0,0 +1,21 @@ +import functools + +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +# Plain element-wise copy: input → output. +# The torch wrapper calls this with `moved = input.movedim(dim, 0)` as the +# source and a fresh contiguous tensor as the destination. By the time this +# kernel runs, the "unbind axis" has already been moved to dim-0 via a +# zero-cost view, so a single kernel invocation copies all slices in parallel +# instead of launching one kernel per slice. 
+def application(input, output): + output = input # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 82fc596..a8f2903 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -6,6 +6,7 @@ from ntops.torch.bitwise_not import bitwise_not from ntops.torch.bitwise_or import bitwise_or from ntops.torch.bmm import bmm +from ntops.torch.chunk import chunk from ntops.torch.clamp import clamp from ntops.torch.conv2d import conv2d from ntops.torch.cos import cos @@ -13,6 +14,8 @@ from ntops.torch.dropout import dropout from ntops.torch.eq import eq from ntops.torch.exp import exp +from ntops.torch.eye import eye +from ntops.torch.flatten import flatten from ntops.torch.ge import ge from ntops.torch.gelu import gelu from ntops.torch.gt import gt @@ -29,6 +32,7 @@ from ntops.torch.neg import neg from ntops.torch.pow import pow from ntops.torch.relu import relu +from ntops.torch.repeat import repeat from ntops.torch.rms_norm import rms_norm from ntops.torch.rotary_position_embedding import rotary_position_embedding from ntops.torch.rsqrt import rsqrt @@ -39,6 +43,7 @@ from ntops.torch.softmax import softmax from ntops.torch.sub import sub from ntops.torch.tanh import tanh +from ntops.torch.unbind import unbind __all__ = [ "abs", @@ -49,6 +54,7 @@ "bitwise_not", "bitwise_or", "bmm", + "chunk", "clamp", "conv2d", "cos", @@ -56,6 +62,8 @@ "dropout", "eq", "exp", + "eye", + "flatten", "ge", "gelu", "gt", @@ -72,6 +80,7 @@ "neg", "pow", "relu", + "repeat", "rms_norm", "rotary_position_embedding", "rsqrt", @@ -82,4 +91,5 @@ "softmax", "sub", "tanh", + "unbind", ] diff --git a/src/ntops/torch/chunk.py b/src/ntops/torch/chunk.py new file mode 100644 index 0000000..a0df7a0 --- /dev/null +++ 
b/src/ntops/torch/chunk.py
@@ -0,0 +1,89 @@
+import torch
+
+import ninetoothed
+import ntops
+from ntops.torch.utils import _cached_make
+
+_DTYPE_MAP = {
+    torch.float16: ninetoothed.float16,
+    torch.bfloat16: ninetoothed.bfloat16,
+    torch.float32: ninetoothed.float32,
+    torch.float64: ninetoothed.float64,
+    torch.int8: ninetoothed.int8,
+    torch.int16: ninetoothed.int16,
+    torch.int32: ninetoothed.int32,
+    torch.int64: ninetoothed.int64,
+}
+
+
+def chunk(input, chunks, dim=0):
+    """Split ``input`` into at most ``chunks`` pieces along ``dim``.
+
+    The argument checks and piece sizes follow ``torch.chunk``. Pieces that
+    are already contiguous are returned as views of ``input``; the others are
+    copied into new contiguous tensors by the NineToothed kernel.
+
+    Raises:
+        RuntimeError: If ``input`` is 0-dimensional or ``chunks`` is not
+            positive.
+        IndexError: If ``dim`` is outside ``[-input.ndim, input.ndim - 1]``.
+    """
+    if input.ndim == 0:
+        raise RuntimeError("chunk expects at least a 1-dimensional tensor")
+
+    if chunks <= 0:
+        raise RuntimeError(
+            f"chunk expects `chunks` to be greater than 0, got: {chunks}"
+        )
+
+    if not -input.ndim <= dim < input.ndim:
+        raise IndexError(
+            "Dimension out of range (expected to be in range of "
+            f"[{-input.ndim}, {input.ndim - 1}], but got {dim})"
+        )
+
+    if dim < 0:
+        dim += input.ndim
+
+    dim_size = input.shape[dim]
+
+    # `torch.chunk` still returns `chunks` pieces when the dimension is empty,
+    # so hand back that many zero-length views instead of an empty tuple.
+    if dim_size == 0:
+        return tuple(input.narrow(dim, 0, 0) for _ in range(chunks))
+
+    chunk_size = (dim_size + chunks - 1) // chunks
+    starts = range(0, dim_size, chunk_size)
+
+    # Fast path: narrowing a contiguous tensor along its leading dimension
+    # always yields contiguous views, so no kernel launch is needed.
+    if input.is_contiguous() and dim == 0:
+        return tuple(
+            input.narrow(0, start, min(chunk_size, dim_size - start))
+            for start in starts
+        )
+
+    # General path: slice in Python, then copy only the pieces that are not
+    # contiguous. One kernel, requested with only `ndim` and `dtype`, serves
+    # every piece.
+    kernel = _cached_make(
+        ntops.kernels.chunk.premake,
+        input.ndim,
+        dtype=_DTYPE_MAP.get(input.dtype),
+    )
+
+    outputs = []
+
+    for start in starts:
+        chunk_view = input.narrow(dim, start, min(chunk_size, dim_size - start))
+
+        if chunk_view.is_contiguous():
+            outputs.append(chunk_view)
+        else:
+            out_chunk = torch.empty(
+                chunk_view.shape, dtype=input.dtype, device=input.device
+            )
+            kernel(chunk_view, out_chunk)
+            outputs.append(out_chunk)
+
+    return tuple(outputs)
diff --git a/src/ntops/torch/eye.py b/src/ntops/torch/eye.py
new file mode 100644
index 0000000..3bcf0e0
--- /dev/null
+++ b/src/ntops/torch/eye.py
@@ -0,0 +1,40 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def eye(n, m=None, *, dtype=None, device=None, out=None):
+    """Return an ``n`` x ``m`` tensor with ones on the main diagonal.
+
+    ``m`` defaults to ``n``. When ``out`` is given, the result is written into
+    it and returned, and ``dtype`` and ``device`` are ignored. Otherwise
+    ``dtype`` defaults to ``torch.get_default_dtype()`` and ``device`` to CUDA
+    when it is available.
+    """
+    if m is None:
+        m = n
+
+    if out is not None:
+        # The kernel writes into `out`, so the index tensors below have to be
+        # created on the same device.
+        device = out.device
+    elif device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    if dtype is None:
+        # Follow `torch.eye`, which uses the global default dtype.
+        dtype = torch.get_default_dtype()
+
+    rows = torch.arange(n, device=device).reshape(n, 1).expand(n, m)
+    cols = torch.arange(m, device=device).reshape(1, m).expand(n, m)
+
+    if out is None:
+        out = torch.empty(n, m, dtype=dtype, device=device)
+
+    # The kernel compares the row and column index of every element.
+    kernel = _cached_make(ntops.kernels.eye.premake, 2)
+
+    kernel(rows, cols, out)
+
+    return out
diff --git a/src/ntops/torch/flatten.py b/src/ntops/torch/flatten.py
new file mode 100644
index 0000000..2a6b59c
--- /dev/null
+++ b/src/ntops/torch/flatten.py
@@ -0,0 +1,54 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def flatten(input, start_dim=0, end_dim=-1):
+    """Flatten dimensions ``start_dim`` through ``end_dim`` of ``input``.
+
+    The result is always a newly allocated tensor filled by the NineToothed
+    kernel. Invalid dimension arguments raise the same exception types as
+    ``torch.flatten``.
+    """
+    if input.ndim == 0:
+        # `torch.flatten` turns a 0-dim tensor into a 1-dim one.
+        input = input.reshape(1)
+
+    ndim = input.ndim
+
+    for dim in (start_dim, end_dim):
+        if not -ndim <= dim < ndim:
+            raise IndexError(
+                "Dimension out of range (expected to be in range of "
+                f"[{-ndim}, {ndim - 1}], but got {dim})"
+            )
+
+    if end_dim < 0:
+        end_dim = ndim + end_dim
+
+    if start_dim < 0:
+        start_dim = ndim + start_dim
+
+    if start_dim > end_dim:
+        raise RuntimeError(
+            "flatten() has invalid args: start_dim cannot come after end_dim"
+        )
+
+    flattened_numel = 1
+
+    for dim in range(start_dim, end_dim + 1):
+        flattened_numel *= input.shape[dim]
+
+    out_shape = input.shape[:start_dim] + (flattened_numel,) + input.shape[end_dim + 1 :]
+
+    out = torch.empty(out_shape,
dtype=input.dtype, device=input.device) + + # Reshape input to match output ndim so the kernel can process both uniformly. + reshaped_input = input.reshape(out_shape) + + kernel = _cached_make(ntops.kernels.flatten.premake, out.ndim) + + kernel(reshaped_input, out) + + return out diff --git a/src/ntops/torch/repeat.py b/src/ntops/torch/repeat.py new file mode 100644 index 0000000..f85ebbf --- /dev/null +++ b/src/ntops/torch/repeat.py @@ -0,0 +1,18 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def repeat(input, *sizes): + if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)): + sizes = tuple(sizes[0]) + + repeated = input.repeat(*sizes) + out = torch.empty_like(repeated) + + kernel = _cached_make(ntops.kernels.repeat.premake, repeated.ndim) + + kernel(repeated, out) + + return out diff --git a/src/ntops/torch/unbind.py b/src/ntops/torch/unbind.py new file mode 100644 index 0000000..dc54044 --- /dev/null +++ b/src/ntops/torch/unbind.py @@ -0,0 +1,47 @@ +import torch + +import ninetoothed +import ntops +from ntops.torch.utils import _cached_make + +_DTYPE_MAP = { + torch.float16: ninetoothed.float16, + torch.bfloat16: ninetoothed.bfloat16, + torch.float32: ninetoothed.float32, + torch.float64: ninetoothed.float64, + torch.int8: ninetoothed.int8, + torch.int16: ninetoothed.int16, + torch.int32: ninetoothed.int32, + torch.int64: ninetoothed.int64, +} + + +def unbind(input, dim=0): + if dim < 0: + dim = input.ndim + dim + + # movedim is a zero-cost view: (d0,..,dim,..,dk) → (dim_size, d0,..,dk). + # After this, every "slice" is simply moved[i], and the copy problem + # reduces to a single contiguous-output kernel regardless of which dim + # was originally requested. + moved = input.movedim(dim, 0) + n_slices = moved.shape[0] + + # Fast path: moved is already contiguous (happens when dim=0 and input is + # contiguous). Return views directly — zero kernel launches. 
+ if moved.is_contiguous(): + return tuple(moved[i] for i in range(n_slices)) + + # General path: ONE kernel launch copies the entire non-contiguous `moved` + # into a contiguous output buffer. Previously this was n_slices separate + # launches (one per slice), each suffering its own launch overhead. + output = torch.empty_like(moved, memory_format=torch.contiguous_format) + kernel = _cached_make( + ntops.kernels.unbind.premake, + moved.ndim, + dtype=_DTYPE_MAP.get(input.dtype), + ) + kernel(moved, output) + + # output[i] is a contiguous view into the output buffer. + return tuple(output[i] for i in range(n_slices)) diff --git a/tests/test_chunk.py b/tests/test_chunk.py new file mode 100644 index 0000000..b2f5682 --- /dev/null +++ b/tests/test_chunk.py @@ -0,0 +1,45 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_chunk(shape, dtype, device, rtol, atol): + # TODO: Test for `float16` later. 
+ if dtype is torch.float16: + return + + input = torch.randn(shape, dtype=dtype, device=device) + chunks = max(1, input.shape[0] // 2) + + ninetoothed_output = ntops.torch.chunk(input, chunks) + reference_output = torch.chunk(input, chunks) + + assert len(ninetoothed_output) == len(reference_output) + + for ninetoothed_chunk, reference_chunk in zip(ninetoothed_output, reference_output): + assert torch.allclose(ninetoothed_chunk, reference_chunk, rtol=rtol, atol=atol) + assert ninetoothed_chunk.shape == reference_chunk.shape + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("ndim", [1, 2, 3, 4]) +def test_chunk_dims(ndim): + shape = tuple(range(3, ndim + 3)) + input = torch.randn(shape, device="cuda") + + for dim in range(ndim): + chunks = max(1, input.shape[dim] // 2) + + ninetoothed_output = ntops.torch.chunk(input, chunks, dim) + reference_output = torch.chunk(input, chunks, dim) + + assert len(ninetoothed_output) == len(reference_output) + + for ninetoothed_chunk, reference_chunk in zip(ninetoothed_output, reference_output): + assert torch.allclose(ninetoothed_chunk, reference_chunk) + assert ninetoothed_chunk.shape == reference_chunk.shape diff --git a/tests/test_eye.py b/tests/test_eye.py new file mode 100644 index 0000000..c1af77d --- /dev/null +++ b/tests/test_eye.py @@ -0,0 +1,39 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +class TestEye: + @pytest.mark.parametrize("n", [1, 3, 5, 10]) + def test_square(self, n): + ninetoothed_output = ntops.torch.eye(n, device="cuda") + reference_output = torch.eye(n, device="cuda") + + assert torch.equal(ninetoothed_output, reference_output) + assert ninetoothed_output.device.type == "cuda" + assert ninetoothed_output.shape == (n, n) + + @pytest.mark.parametrize("n, m", [(3, 5), (5, 3), (1, 10), (10, 1)]) + def test_rectangular(self, n, m): + ninetoothed_output = ntops.torch.eye(n, m, device="cuda") + reference_output = 
torch.eye(n, m, device="cuda") + + assert torch.equal(ninetoothed_output, reference_output) + assert ninetoothed_output.shape == (n, m) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int64]) + def test_dtype(self, dtype): + n = 4 + ninetoothed_output = ntops.torch.eye(n, dtype=dtype, device="cuda") + reference_output = torch.eye(n, dtype=dtype, device="cuda") + + assert torch.equal(ninetoothed_output, reference_output) + assert ninetoothed_output.dtype == dtype + + def test_default_device(self): + if torch.cuda.is_available(): + result = ntops.torch.eye(3) + assert result.device.type == "cuda" diff --git a/tests/test_flatten.py b/tests/test_flatten.py new file mode 100644 index 0000000..6946cc2 --- /dev/null +++ b/tests/test_flatten.py @@ -0,0 +1,37 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_flatten_default(shape, dtype, device, rtol, atol): + # TODO: Test for `float16` later. 
+ if dtype is torch.float16: + return + + input = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.flatten(input) + reference_output = torch.flatten(input) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) + assert ninetoothed_output.shape == reference_output.shape + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("ndim", [2, 3, 4]) +def test_flatten_partial(ndim): + shape = tuple(range(2, ndim + 2)) + input = torch.randn(shape, device="cuda") + + for start_dim in range(ndim): + for end_dim in range(start_dim, ndim): + ninetoothed_output = ntops.torch.flatten(input, start_dim, end_dim) + reference_output = torch.flatten(input, start_dim, end_dim) + + assert torch.allclose(ninetoothed_output, reference_output) + assert ninetoothed_output.shape == reference_output.shape diff --git a/tests/test_repeat.py b/tests/test_repeat.py new file mode 100644 index 0000000..ac2177c --- /dev/null +++ b/tests/test_repeat.py @@ -0,0 +1,59 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_repeat_same(shape, dtype, device, rtol, atol): + # TODO: Test for `float16` later. 
+    if dtype is torch.float16:
+        return
+
+    input = torch.randn(shape, dtype=dtype, device=device)
+    repeats = tuple(2 for _ in range(input.ndim))
+
+    ninetoothed_output = ntops.torch.repeat(input, *repeats)
+    reference_output = input.repeat(*repeats)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
+    assert ninetoothed_output.shape == reference_output.shape
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("ndim", [1, 2, 3, 4])
+def test_repeat_various(ndim):
+    shape = tuple(range(2, ndim + 2))
+    input = torch.randn(shape, device="cuda")
+
+    repeat_specs = [
+        tuple(1 for _ in range(ndim)),
+        tuple(3 for _ in range(ndim)),
+        (1,) * (ndim - 1) + (4,),
+        (2, 3) if ndim >= 2 else (1,),
+    ]
+
+    for repeats in repeat_specs:
+        if len(repeats) != ndim:
+            continue
+
+        ninetoothed_output = ntops.torch.repeat(input, *repeats)
+        reference_output = input.repeat(*repeats)
+
+        assert torch.allclose(ninetoothed_output, reference_output)
+        assert ninetoothed_output.shape == reference_output.shape
+
+
+@skip_if_cuda_not_available
+def test_repeat_list_input():
+    input = torch.tensor([[1, 2], [3, 4]], device="cuda")
+    sizes = (2, 3)
+
+    # Pass the sizes as one list, which is what this test is named for.
+    ninetoothed_output = ntops.torch.repeat(input, list(sizes))
+    reference_output = input.repeat(*sizes)
+
+    assert torch.equal(ninetoothed_output, reference_output)
diff --git a/tests/test_unbind.py b/tests/test_unbind.py
new file mode 100644
index 0000000..d93564e
--- /dev/null
+++ b/tests/test_unbind.py
@@ -0,0 +1,53 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_unbind(shape, dtype, device, rtol, atol):
+    # TODO: Test for `float16` later.
+ if dtype is torch.float16: + return + + input = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.unbind(input) + reference_output = torch.unbind(input) + + assert len(ninetoothed_output) == len(reference_output) + + for ninetoothed_tensor, reference_tensor in zip(ninetoothed_output, reference_output): + assert torch.allclose(ninetoothed_tensor, reference_tensor, rtol=rtol, atol=atol) + assert ninetoothed_tensor.shape == reference_tensor.shape + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("ndim", [1, 2, 3, 4]) +def test_unbind_dims(ndim): + shape = tuple(range(2, ndim + 2)) + input = torch.randn(shape, device="cuda") + + for dim in range(ndim): + ninetoothed_output = ntops.torch.unbind(input, dim) + reference_output = torch.unbind(input, dim) + + assert len(ninetoothed_output) == len(reference_output) + + for ninetoothed_tensor, reference_tensor in zip(ninetoothed_output, reference_output): + assert torch.allclose(ninetoothed_tensor, reference_tensor) + assert ninetoothed_tensor.shape == reference_tensor.shape + + +@skip_if_cuda_not_available +def test_unbind_concatenation(): + input = torch.randn(3, 4, 5, device="cuda") + + for dim in range(3): + ninetoothed_output = ntops.torch.unbind(input, dim) + stacked = torch.stack(ninetoothed_output, dim) + + assert torch.allclose(stacked, input)