diff --git a/HONOR_CODE.md b/HONOR_CODE.md new file mode 100644 index 0000000..5d22ad7 --- /dev/null +++ b/HONOR_CODE.md @@ -0,0 +1,71 @@ +# 2026 春季启元人工智能大赛诚信守则(Honor Code) + + +本人作为 2026 春季启元人工智能大赛(以下简称“比赛”)的参赛选手,郑重承诺严格遵守比赛规则及本诚信守则,秉持诚信、公正、廉洁的参赛原则,自觉维护比赛的公平性与严肃性。本人充分理解并认可,违反本准则将导致参赛资格被取消、比赛成绩作废等相应后果,且愿意承担由此产生的一切责任。 + +## 一、参赛诚信承诺 + +1. 本人保证所提交的赛题PR(Pull Request)中包含的算子实现代码及相关文档,均为本人(及参赛团队,如为团队参赛)在比赛期间独立完成或在明确标注参考来源的基础上进行开发,不存在任何欺诈、抄袭、作弊行为。 + +2. 本人承诺主动、全面、真实地披露赛题实现过程中所有参考的外部资源,尤其是开源代码资源,不隐瞒任何可能影响比赛公平性的信息。 + +3. 本人保证不采用任何不正当手段获取比赛优势,包括但不限于窃取其他参赛选手的代码成果、利用非比赛允许的工具或技术、与他人串通作弊等。 + +## 二、参考资源说明 + +本人确认已按比赛要求,将本次赛题实现过程中涉及的参考资源信息单独撰写至`REFERENCE.md`文件中,该文件将与本诚信守则一同作为PR附件提交。`REFERENCE.md`需根据实际参考情况,按以下要求完整填写,信息不完整或虚假填写将视为违反本准则: + +**情况1:无参考外部开源代码及核心实现思路** + +`REFERENCE.md`中需明确声明:“本次赛题提交的算子代码、核心算法逻辑及实现方案均为本人(及参赛团队)独立设计与开发,未参考任何外部开源项目、技术文档中的核心代码片段或实现思路,未接受任何第三方的技术指导或代码支持。” + +**情况2:有参考外部开源代码及相关资源** + +对每个参考资源提供以下信息陈述: +1. 参考开源项目/资源名称 + +2. 参考资源链接(GitHub/Gitee/论文/技术文档等) + +3. 参考的具体内容(请明确说明参考的代码片段、算法逻辑、实现思路等,需标注对应资源的具体位置,如文件路径、代码行数等) + +4. 本人对参考内容的修改与优化说明:(请详细说明在参考基础上,本人所做的独立开发、修改、优化工作,体现自身技术贡献) + +5. 若是开源项目,提供参考资源的开源协议类型:(如MIT、Apache 2.0、GPL等) + +6. 其他需要补充说明的信息 + + +## 三、禁止行为确认 + +本人明确知晓并承诺避免以下违反比赛公平性的行为,若存在以下任一情况,自愿接受比赛组委会的相应处罚: + +1. 未经授权复制、抄袭他人(包括其他参赛选手、开源项目、商业代码)的代码、算法或技术方案,且未进行明确标注; + +2. 隐瞒或虚假披露参考资源信息,包括遗漏重要参考来源、伪造参考内容说明等; + +3. 与其他参赛选手或第三方串通,进行代码共享、成果交换等违规协作; + +4. 利用比赛平台漏洞、技术缺陷或非比赛允许的工具获取不正当利益; + +5. 伪造比赛相关证明材料、提交虚假信息; + +6. 其他违反比赛规则及公序良俗的不诚信行为。 + + +## 四、责任与确认 + +1. 本人充分理解,比赛组委会将对所有提交的PR进行代码溯源、参考信息核查等公平性审查,若发现本人存在违反本准则的行为,有权随时取消本人的参赛资格、作废比赛成绩,情节严重的将在比赛相关平台进行公示。 + +2. 若因本人违反本准则导致比赛争议或第三方权益受损(如开源协议侵权等),本人将独立承担全部法律责任及相关损失,与比赛组委会无关。 + +3. 本人确认已仔细阅读并完全理解本诚信守则的全部内容,自愿签署本准则,接受比赛组委会的监督与审查。 + +## 五、签署信息 + +参赛选手姓名(团队参赛需填写所有成员姓名) + + 李浩坤 + +签署日期 + +___2026___年__6__月__17__日 \ No newline at end of file diff --git a/README.md b/README.md index 77a300f..f1ae839 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,14 @@ # ntops NineToothed operators for LLMs. + +## src/ntops/kernels + +这个目录下写的是九齿算子,九齿是一种自研的DSL,application函数中写算子逻辑 + +## src/ntops/torch + +这个目录下为九齿算子提供pytorch的包装层 + +## tests + +test目录下完成九齿算子和torch算子的正确性的对比测试 \ No newline at end of file diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..bc8bdbc 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -27,6 +27,7 @@ mul, ne, neg, + outer, pow, relu, rms_norm, @@ -39,6 +40,10 @@ softmax, sub, tanh, + trace, + tril, + triu, + triu_indices, ) __all__ = [ @@ -70,6 +75,7 @@ "mul", "ne", "neg", + "outer", "pow", "relu", "rms_norm", @@ -82,4 +88,8 @@ "softmax", "sub", "tanh", + "trace", + "tril", + "triu", + "triu_indices", ] diff --git a/src/ntops/kernels/outer.py b/src/ntops/kernels/outer.py new file mode 100644 index 0000000..cc066b7 --- /dev/null +++ b/src/ntops/kernels/outer.py @@ -0,0 +1,43 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def arrangement(input, other, output, block_size_m=None, block_size_n=None): + if block_size_m is None: + block_size_m = ninetoothed.block_size() + + if block_size_n is None: + block_size_n = ninetoothed.block_size() + + output_arranged = output.tile((block_size_m, block_size_n)) + + input_arranged = input.tile((block_size_m, 1)) + input_arranged = input_arranged.expand((-1, output_arranged.shape[1])) + + other_arranged = other.tile((1, block_size_n)) + other_arranged = other_arranged.expand((output_arranged.shape[0], -1)) + + return input_arranged, other_arranged, output_arranged + + +def application(input, other, output): + output = input * other # noqa: F841 + + +def premake(dtype=None, block_size_m=None, block_size_n=None): + arrangement_ = functools.partial( + arrangement, + block_size_m=block_size_m, + block_size_n=block_size_n, + ) + + tensors = ( + Tensor(2, dtype=dtype), + Tensor(2, dtype=dtype), + Tensor(2, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/trace.py b/src/ntops/kernels/trace.py new file mode 100644 index 0000000..cb46e9a --- /dev/null +++ b/src/ntops/kernels/trace.py @@ -0,0 +1,22 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, output): + sum_val = ntl.sum(input) + output = sum_val + ntl.cast(0, ntl.float32) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(1, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/tril.py b/src/ntops/kernels/tril.py new file mode 100644 index 0000000..fb13189 --- /dev/null +++ b/src/ntops/kernels/tril.py @@ -0,0 +1,23 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(rows, cols, input, output): + output = ntl.where(cols <= rows, input, 0) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/triu.py b/src/ntops/kernels/triu.py new file mode 100644 index 0000000..6ecba08 --- /dev/null +++ b/src/ntops/kernels/triu.py @@ -0,0 +1,23 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(rows, cols, input, output): + output = ntl.where(cols >= rows, input, 0) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/triu_indices.py b/src/ntops/kernels/triu_indices.py new file mode 100644 index 0000000..a93e6dc --- /dev/null +++ b/src/ntops/kernels/triu_indices.py @@ -0,0 +1,22 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(rows, cols, output): + output = ntl.where(cols >= rows, 1, 0) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 82fc596..249032c 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -27,6 +27,7 @@ from ntops.torch.mul import mul from ntops.torch.ne import ne from ntops.torch.neg import neg +from ntops.torch.outer import outer from ntops.torch.pow import pow from ntops.torch.relu import relu from ntops.torch.rms_norm import rms_norm @@ -39,6 +40,10 @@ from ntops.torch.softmax import softmax from ntops.torch.sub import sub from ntops.torch.tanh import tanh +from ntops.torch.trace import trace +from ntops.torch.tril import tril +from ntops.torch.triu import triu +from ntops.torch.triu_indices import triu_indices __all__ = [ "abs", @@ -70,6 +75,7 @@ "mul", "ne", "neg", + "outer", "pow", "relu", "rms_norm", @@ -82,4 +88,8 @@ "softmax", "sub", "tanh", + "trace", + "tril", + "triu", + "triu_indices", ] diff --git a/src/ntops/torch/outer.py b/src/ntops/torch/outer.py new file mode 100644 index 0000000..baaad1b --- /dev/null +++ b/src/ntops/torch/outer.py @@ -0,0 +1,21 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def outer(input, other, *, out=None): + m, n = input.size(0), other.size(0) + + if out is None: + out = torch.empty(m, n, dtype=input.dtype, device=input.device) + else: + out = out.view(m, n) + + input_2d = input.unsqueeze(1) + other_2d = other.unsqueeze(0) + + kernel = _cached_make(ntops.kernels.outer.premake) + kernel(input_2d, other_2d, out) + + return out diff --git a/src/ntops/torch/trace.py b/src/ntops/torch/trace.py new file mode 100644 index 0000000..4875aae --- /dev/null +++ b/src/ntops/torch/trace.py @@ -0,0 +1,14 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def trace(input): + diagonal = torch.diagonal(input) + output = torch.zeros(1, dtype=diagonal.dtype, device=diagonal.device) + + kernel = _cached_make(ntops.kernels.trace.premake, diagonal.ndim) + kernel(diagonal, output) + + return output.squeeze(0) diff --git a/src/ntops/torch/tril.py b/src/ntops/torch/tril.py new file mode 100644 index 0000000..5c9630e --- /dev/null +++ b/src/ntops/torch/tril.py @@ -0,0 +1,28 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def tril(input, diagonal=0, *, out=None): + n, m = input.shape[-2], input.shape[-1] + + rows = torch.arange(n, device=input.device).reshape(n, 1).expand(n, m) + cols = torch.arange(m, device=input.device).reshape(1, m).expand(n, m) + + # Adjust column comparison for diagonal offset: cols <= rows + diagonal + if diagonal != 0: + cols = cols - diagonal + + # Expand rows and cols to match input's ndim + rows = rows.reshape((1,) * (input.ndim - 2) + (n, m)).expand_as(input) + cols = cols.reshape((1,) * (input.ndim - 2) + (n, m)).expand_as(input) + + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.tril.premake, input.ndim) + + kernel(rows, cols, input, out) + + return out diff --git a/src/ntops/torch/triu.py b/src/ntops/torch/triu.py new file mode 100644 index 0000000..f23cea3 --- /dev/null +++ b/src/ntops/torch/triu.py @@ -0,0 +1,28 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def triu(input, diagonal=0, *, out=None): + n, m = input.shape[-2], input.shape[-1] + + rows = torch.arange(n, device=input.device).reshape(n, 1).expand(n, m) + cols = torch.arange(m, device=input.device).reshape(1, m).expand(n, m) + + # Adjust column comparison for diagonal offset: cols >= rows + diagonal + if diagonal != 0: + cols = cols - diagonal + + # Expand rows and cols to match input's ndim + rows = rows.reshape((1,) * (input.ndim - 2) + (n, m)).expand_as(input) + cols = cols.reshape((1,) * (input.ndim - 2) + (n, m)).expand_as(input) + + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.triu.premake, input.ndim) + + kernel(rows, cols, input, out) + + return out diff --git a/src/ntops/torch/triu_indices.py b/src/ntops/torch/triu_indices.py new file mode 100644 index 0000000..42e0a7f --- /dev/null +++ b/src/ntops/torch/triu_indices.py @@ -0,0 +1,26 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def triu_indices(n, m=None, offset=0, *, device=None): + if m is None: + m = n + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + rows = torch.arange(n, device=device).reshape(n, 1).expand(n, m) + cols = torch.arange(m, device=device).reshape(1, m).expand(n, m) + cols = cols - offset + + mask = torch.empty(n, m, dtype=torch.int32, device=device) + + kernel = _cached_make(ntops.kernels.triu_indices.premake, 2) + kernel(rows, cols, mask) + + indices = torch.nonzero(mask) + indices = indices.T.contiguous() + + return indices diff --git a/test_all.py b/test_all.py new file mode 100644 index 0000000..ae4d609 --- /dev/null +++ b/test_all.py @@ -0,0 +1,52 @@ +# 测试T-1-3这个组的所有算子 + +import torch +import sys +sys.path.insert(0, 'src') +import ntops +import inspect + +print('=== Testing triu_indices first (no other ops) ===') +print('sig:', inspect.signature(ntops.torch.triu_indices)) +r = ntops.torch.triu_indices(5, device='cuda') +ref = torch.triu_indices(5, 5, device='cuda') +assert torch.equal(r, ref), 'triu_indices alone failed' +print('triu_indices alone: OK') + +print('\n=== Now test all together ===') +x = torch.randn(5, 5, device='cuda') + +# tril +out = ntops.torch.tril(x) +ref = torch.tril(x) +assert torch.equal(out, ref), 'tril failed' +print('tril: OK') + +# triu +out = ntops.torch.triu(x) +ref = torch.triu(x) +assert torch.equal(out, ref), 'triu failed' +print('triu: OK') + +# trace +out = ntops.torch.trace(x) +ref = torch.trace(x) +assert torch.allclose(out, ref), 'trace failed' +print('trace: OK') + +# outer +a = torch.randn(5, device='cuda') +b = torch.randn(7, device='cuda') +out = ntops.torch.outer(a, b) +ref = torch.outer(a, b) +assert torch.allclose(out, ref), 'outer failed' +print('outer: OK') + +# triu_indices again +print('sig after other ops:', inspect.signature(ntops.torch.triu_indices)) +r = ntops.torch.triu_indices(5, device='cuda') +ref = torch.triu_indices(5, 5, device='cuda') +assert torch.equal(r, ref), 'triu_indices after tril failed' +print('triu_indices after tril: OK') + +print('\nALL PASSED') diff --git a/test_trace_debug.py b/test_trace_debug.py new file mode 100644 index 0000000..1533a90 --- /dev/null +++ b/test_trace_debug.py @@ -0,0 +1,25 @@ +"""调试: trace DSL - 1D diagonal + element_wise + ntl.sum""" +import torch, sys; sys.path.insert(0, 'src') +import ninetoothed.language as ntl +from ninetoothed import Tensor, make +from ntops.kernels.element_wise import arrangement + +def application(input, output): + sum_val = ntl.sum(input) + output = sum_val + ntl.cast(0, ntl.float32) + +def premake(ndim, dtype=None, block_size=None): + arr = functools.partial(arrangement, block_size=block_size) + tensors = (Tensor(ndim, dtype=dtype), Tensor(1, dtype=dtype)) + return arr, application, tensors + +import functools +N = 5 +x = torch.arange(N, dtype=torch.float32, device='cuda') +out = torch.zeros(1, device='cuda') + +k = make(arrangement, application, (Tensor(1), Tensor(1))) +k(x, out) +print(f"sum: {out.item()}") +print(f"expected: {x.sum().item()}") +print(f"correct: {abs(out.item() - x.sum().item()) < 1e-5}") diff --git a/tests/test_outer.py b/tests/test_outer.py new file mode 100644 index 0000000..3c9f7a0 --- /dev/null +++ b/tests/test_outer.py @@ -0,0 +1,40 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +class TestOuter: + @pytest.mark.parametrize("m, n", [(1, 1), (3, 5), (5, 3), (10, 1), (1, 10)]) + def test_shapes(self, m, n): + input = torch.randn(m, device="cuda") + other = torch.randn(n, device="cuda") + + ninetoothed_output = ntops.torch.outer(input, other) + reference_output = torch.outer(input, other) + + assert torch.allclose(ninetoothed_output, reference_output) + assert ninetoothed_output.shape == (m, n) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) + def test_dtype(self, dtype): + input = torch.randn(5, device="cuda").to(dtype) + other = torch.randn(7, device="cuda").to(dtype) + + ninetoothed_output = ntops.torch.outer(input, other) + reference_output = torch.outer(input, other) + + assert torch.allclose(ninetoothed_output, reference_output, atol=0.01, rtol=0.01) + + def test_known_values(self): + input = torch.tensor([1.0, 2.0, 3.0], device="cuda") + other = torch.tensor([4.0, 5.0], device="cuda") + + ninetoothed_output = ntops.torch.outer(input, other) + reference_output = torch.outer(input, other) + + expected = torch.tensor([[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]], device="cuda") + assert torch.equal(ninetoothed_output, expected) + assert torch.equal(ninetoothed_output, reference_output) diff --git a/tests/test_trace.py b/tests/test_trace.py new file mode 100644 index 0000000..31316ca --- /dev/null +++ b/tests/test_trace.py @@ -0,0 +1,45 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +class TestTrace: + @pytest.mark.parametrize("n", [1, 3, 5, 10]) + def test_square(self, n): + input = torch.randn(n, n, device="cuda") + + ninetoothed_output = ntops.torch.trace(input) + reference_output = torch.trace(input) + + assert torch.allclose(ninetoothed_output, reference_output) + assert ninetoothed_output.shape == () + + @pytest.mark.parametrize("n, m", [(3, 5), (5, 3)]) + def test_rectangular(self, n, m): + input = torch.randn(n, m, device="cuda") + + ninetoothed_output = ntops.torch.trace(input) + reference_output = torch.trace(input) + + assert torch.allclose(ninetoothed_output, reference_output) + assert ninetoothed_output.shape == () + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int64]) + def test_dtype(self, dtype): + input = torch.randn(4, 4, device="cuda").to(dtype) + + ninetoothed_output = ntops.torch.trace(input) + reference_output = torch.trace(input) + + assert torch.allclose(ninetoothed_output, reference_output.to(dtype)) + + def test_zero_diagonal(self): + input = torch.zeros(0, 0, device="cuda") + + ninetoothed_output = ntops.torch.trace(input) + reference_output = torch.trace(input) + + assert ninetoothed_output.item() == 0.0 diff --git a/tests/test_tril.py b/tests/test_tril.py new file mode 100644 index 0000000..cfb084a --- /dev/null +++ b/tests/test_tril.py @@ -0,0 +1,53 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +class TestTril: + @pytest.mark.parametrize("n", [1, 3, 5, 10]) + def test_square(self, n): + input = torch.randn(n, n, device="cuda") + + ninetoothed_output = ntops.torch.tril(input) + reference_output = torch.tril(input) + + assert torch.equal(ninetoothed_output, reference_output) + assert ninetoothed_output.shape == input.shape + + @pytest.mark.parametrize("n, m", [(3, 5), (5, 3), (1, 10), (10, 1)]) + def test_rectangular(self, n, m): + input = torch.randn(n, m, device="cuda") + + ninetoothed_output = ntops.torch.tril(input) + reference_output = torch.tril(input) + + assert torch.equal(ninetoothed_output, reference_output) + + @pytest.mark.parametrize("diagonal", [-2, -1, 0, 1, 2]) + def test_diagonal(self, diagonal): + input = torch.randn(5, 5, device="cuda") + + ninetoothed_output = ntops.torch.tril(input, diagonal) + reference_output = torch.tril(input, diagonal) + + assert torch.equal(ninetoothed_output, reference_output) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int64]) + def test_dtype(self, dtype): + input = torch.randn(4, 4, device="cuda").to(dtype) + + ninetoothed_output = ntops.torch.tril(input) + reference_output = torch.tril(input) + + assert torch.equal(ninetoothed_output, reference_output) + + def test_3d(self): + input = torch.randn(2, 3, 3, device="cuda") + + ninetoothed_output = ntops.torch.tril(input) + reference_output = torch.tril(input) + + assert torch.equal(ninetoothed_output, reference_output) diff --git a/tests/test_triu.py b/tests/test_triu.py new file mode 100644 index 0000000..e8b539c --- /dev/null +++ b/tests/test_triu.py @@ -0,0 +1,53 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +class TestTriu: + @pytest.mark.parametrize("n", [1, 3, 5, 10]) + def test_square(self, n): + input = torch.randn(n, n, device="cuda") + + ninetoothed_output = ntops.torch.triu(input) + reference_output = torch.triu(input) + + assert torch.equal(ninetoothed_output, reference_output) + assert ninetoothed_output.shape == input.shape + + @pytest.mark.parametrize("n, m", [(3, 5), (5, 3), (1, 10), (10, 1)]) + def test_rectangular(self, n, m): + input = torch.randn(n, m, device="cuda") + + ninetoothed_output = ntops.torch.triu(input) + reference_output = torch.triu(input) + + assert torch.equal(ninetoothed_output, reference_output) + + @pytest.mark.parametrize("diagonal", [-2, -1, 0, 1, 2]) + def test_diagonal(self, diagonal): + input = torch.randn(5, 5, device="cuda") + + ninetoothed_output = ntops.torch.triu(input, diagonal) + reference_output = torch.triu(input, diagonal) + + assert torch.equal(ninetoothed_output, reference_output) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int64]) + def test_dtype(self, dtype): + input = torch.randn(4, 4, device="cuda").to(dtype) + + ninetoothed_output = ntops.torch.triu(input) + reference_output = torch.triu(input) + + assert torch.equal(ninetoothed_output, reference_output) + + def test_3d(self): + input = torch.randn(2, 3, 3, device="cuda") + + ninetoothed_output = ntops.torch.triu(input) + reference_output = torch.triu(input) + + assert torch.equal(ninetoothed_output, reference_output) diff --git a/tests/test_triu_indices.py b/tests/test_triu_indices.py new file mode 100644 index 0000000..a743dd8 --- /dev/null +++ b/tests/test_triu_indices.py @@ -0,0 +1,38 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +class TestTriuIndices: + @pytest.mark.parametrize("n", [1, 3, 5, 10]) + def test_square(self, n): + ninetoothed_output = ntops.torch.triu_indices(n, device="cuda") + reference_output = torch.triu_indices(n, n, device="cuda") + + assert torch.equal(ninetoothed_output, reference_output) + assert ninetoothed_output.shape == reference_output.shape + + @pytest.mark.parametrize("n, m", [(3, 5), (5, 3), (1, 10), (10, 1)]) + def test_rectangular(self, n, m): + ninetoothed_output = ntops.torch.triu_indices(n, m, device="cuda") + reference_output = torch.triu_indices(n, m, device="cuda") + + assert torch.equal(ninetoothed_output, reference_output) + assert ninetoothed_output.shape == reference_output.shape + + @pytest.mark.parametrize("offset", [-3, -1, 0, 1, 3]) + def test_offset(self, offset): + n, m = 5, 5 + + ninetoothed_output = ntops.torch.triu_indices(n, m, offset, device="cuda") + reference_output = torch.triu_indices(n, m, offset, device="cuda") + + assert torch.equal(ninetoothed_output, reference_output) + + def test_default_device(self): + if torch.cuda.is_available(): + result = ntops.torch.triu_indices(3, device="cuda") + assert result.device.type == "cuda"