diff --git a/HONOR_CODE.md b/HONOR_CODE.md new file mode 100644 index 0000000..5d22ad7 --- /dev/null +++ b/HONOR_CODE.md @@ -0,0 +1,71 @@ +# 2026 春季启元人工智能大赛诚信守则(Honor Code) + + +本人作为 2026 春季启元人工智能大赛(以下简称“比赛”)的参赛选手,郑重承诺严格遵守比赛规则及本诚信守则,秉持诚信、公正、廉洁的参赛原则,自觉维护比赛的公平性与严肃性。本人充分理解并认可,违反本准则将导致参赛资格被取消、比赛成绩作废等相应后果,且愿意承担由此产生的一切责任。 + +## 一、参赛诚信承诺 + +1. 本人保证所提交的赛题PR(Pull Request)中包含的算子实现代码及相关文档,均为本人(及参赛团队,如为团队参赛)在比赛期间独立完成或在明确标注参考来源的基础上进行开发,不存在任何欺诈、抄袭、作弊行为。 + +2. 本人承诺主动、全面、真实地披露赛题实现过程中所有参考的外部资源,尤其是开源代码资源,不隐瞒任何可能影响比赛公平性的信息。 + +3. 本人保证不采用任何不正当手段获取比赛优势,包括但不限于窃取其他参赛选手的代码成果、利用非比赛允许的工具或技术、与他人串通作弊等。 + +## 二、参考资源说明 + +本人确认已按比赛要求,将本次赛题实现过程中涉及的参考资源信息单独撰写至`REFERENCE.md`文件中,该文件将与本诚信守则一同作为PR附件提交。`REFERENCE.md`需根据实际参考情况,按以下要求完整填写,信息不完整或虚假填写将视为违反本准则: + +**情况1:无参考外部开源代码及核心实现思路** + +`REFERENCE.md`中需明确声明:“本次赛题提交的算子代码、核心算法逻辑及实现方案均为本人(及参赛团队)独立设计与开发,未参考任何外部开源项目、技术文档中的核心代码片段或实现思路,未接受任何第三方的技术指导或代码支持。” + +**情况2:有参考外部开源代码及相关资源** + +对每个参考资源提供以下信息陈述: +1. 参考开源项目/资源名称 + +2. 参考资源链接(GitHub/Gitee/论文/技术文档等) + +3. 参考的具体内容(请明确说明参考的代码片段、算法逻辑、实现思路等,需标注对应资源的具体位置,如文件路径、代码行数等) + +4. 本人对参考内容的修改与优化说明:(请详细说明在参考基础上,本人所做的独立开发、修改、优化工作,体现自身技术贡献) + +5. 若是开源项目,提供参考资源的开源协议类型:(如MIT、Apache 2.0、GPL等) + +6. 其他需要补充说明的信息 + + +## 三、禁止行为确认 + +本人明确知晓并承诺避免以下违反比赛公平性的行为,若存在以下任一情况,自愿接受比赛组委会的相应处罚: + +1. 未经授权复制、抄袭他人(包括其他参赛选手、开源项目、商业代码)的代码、算法或技术方案,且未进行明确标注; + +2. 隐瞒或虚假披露参考资源信息,包括遗漏重要参考来源、伪造参考内容说明等; + +3. 与其他参赛选手或第三方串通,进行代码共享、成果交换等违规协作; + +4. 利用比赛平台漏洞、技术缺陷或非比赛允许的工具获取不正当利益; + +5. 伪造比赛相关证明材料、提交虚假信息; + +6. 其他违反比赛规则及公序良俗的不诚信行为。 + + +## 四、责任与确认 + +1. 本人充分理解,比赛组委会将对所有提交的PR进行代码溯源、参考信息核查等公平性审查,若发现本人存在违反本准则的行为,有权随时取消本人的参赛资格、作废比赛成绩,情节严重的将在比赛相关平台进行公示。 + +2. 若因本人违反本准则导致比赛争议或第三方权益受损(如开源协议侵权等),本人将独立承担全部法律责任及相关损失,与比赛组委会无关。 + +3. 本人确认已仔细阅读并完全理解本诚信守则的全部内容,自愿签署本准则,接受比赛组委会的监督与审查。 + +## 五、签署信息 + +参赛选手姓名(团队参赛需填写所有成员姓名) + + 李浩坤 + +签署日期 + +___2026___年__6__月__17__日 \ No newline at end of file diff --git a/README.md b/README.md index 77a300f..f1ae839 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,14 @@ # ntops NineToothed operators for LLMs. 
def arrangement(a, b, output, block_size=None):
    """Tile two single-column inputs and the two-column output by row blocks.

    ``a`` and ``b`` are each cut into ``(block_size, 1)`` tiles which are then
    expanded with ``(-1, 2)``; ``output`` is cut into ``(block_size, 2)``
    tiles.  When ``block_size`` is ``None`` a fresh
    ``ninetoothed.block_size()`` value is used instead.

    Returns the arranged ``(a, b, output)`` triple, in that order.
    """
    rows = ninetoothed.block_size() if block_size is None else block_size

    def _two_columns(column):
        # presumably the expand makes each (rows, 1) tile line up with the
        # (rows, 2) output tile -- TODO confirm against NineToothed docs
        return column.tile((rows, 1)).expand((-1, 2))

    a_tiles = _two_columns(a)
    b_tiles = _two_columns(b)

    return a_tiles, b_tiles, output.tile((rows, 2))
# Element-wise body of the meshgrid kernel: copy each input tile into the
# matching output tile (``x`` -> ``X``, ``y`` -> ``Y``).
#
# A plain assignment is used on purpose.  Selecting through
# ``ntl.where(x >= x, x, <float32 zero>)`` is not a faithful copy: the
# comparison is false for NaN, so NaN inputs would come out as 0, and mixing
# an integer tile with a float32 operand can push wide integers through
# float32 and lose precision.
#
# Comments are kept outside the function body because the body is consumed by
# the NineToothed code generator rather than executed as ordinary Python.
def application(x, y, X, Y):
    X = x  # noqa: F841
    Y = y  # noqa: F841
# Brute-force mode of one input tile: for every element, count how many
# elements of the same tile compare equal to it and keep the value with the
# highest count.  Work is quadratic in the tile length ``n``.
def application(input, output):
    n = input.shape[0]
    # NOTE(review): the running best value is created as float32 regardless of
    # the input dtype -- confirm this is acceptable for integer / wider inputs.
    best_val = ntl.cast(0, ntl.float32)
    # Starts below any possible count so the first non-NaN candidate is taken.
    best_count = ntl.cast(-1, ntl.int32)
    for i in range(n):
        val = input[i]
        # ``val == val`` is false only for NaN, so NaN candidates never win.
        # presumably NaN is also what padded slots hold (premake passes
        # ``other=float("nan")``) -- TODO confirm
        ok = val == val
        count = ntl.cast(0, ntl.int32)
        for j in range(n):
            vj = input[j]
            same = ntl.where(vj == val, ntl.cast(1, ntl.int32), ntl.cast(0, ntl.int32))
            # Explicitly exclude NaN elements from the count.
            same = ntl.where(vj != vj, ntl.cast(0, ntl.int32), same)
            count = count + same
        # ``>=`` means that on equal counts the candidate visited later
        # replaces the earlier one.
        # NOTE(review): verify this tie-break against torch.mode.
        better = (count >= best_count) & ok
        best_val = ntl.where(better, val, best_val)
        best_count = ntl.where(better, count, best_count)
    # If no candidate was accepted (every element NaN) the initial 0 is stored.
    output = best_val
# Circular shift of one row tile: ``output[i]`` receives
# ``input[(i - s) mod n]`` where ``n`` and ``s`` are read from the one-element
# ``size`` and ``shift`` tensors.
def application(input, size, shift, output):
    # Both scalars are converted to int32 before any index arithmetic.
    n = ntl.cast(size[0], ntl.int32)
    s = ntl.cast(shift[0], ntl.int32)
    # NOTE(review): the loop runs over the tile length (``output.shape[0]``)
    # while indices wrap at ``n``; behaviour for positions >= n, or for rows
    # longer than a single tile, cannot be told from here -- confirm.
    for i in range(output.shape[0]):
        # Adding ``n`` before the modulo keeps the dividend non-negative;
        # assumes 0 <= s <= n -- TODO confirm the caller guarantees this.
        src_idx = (i + n - s) % n
        output[i] = input[src_idx]
def meshgrid(*tensors, indexing="ij"):
    """Build coordinate grids, using the NineToothed kernel for two vectors.

    The kernel path is taken only for exactly two non-empty 1-D tensors with
    ``indexing`` equal to ``"ij"`` or ``"xy"``.  Every other call -- a single
    tensor, three or more tensors, a single list/tuple of tensors, 0-d
    tensors, empty tensors, or any other ``indexing`` value -- is forwarded
    unchanged to ``torch.meshgrid`` so that its results and its error
    reporting apply.

    Args:
        *tensors: The input tensors, in ``torch.meshgrid`` calling convention.
        indexing: ``"ij"`` (default) or ``"xy"``.

    Returns:
        On the kernel path, a tuple ``(X, Y)`` of newly allocated tensors of
        shape ``(len(x), len(y))`` for ``"ij"`` or ``(len(y), len(x))`` for
        ``"xy"``; otherwise whatever ``torch.meshgrid`` returns.
    """
    use_kernel = (
        len(tensors) == 2
        and indexing in ("ij", "xy")
        and all(
            isinstance(t, torch.Tensor) and t.ndim == 1 and t.numel() > 0
            for t in tensors
        )
    )

    if not use_kernel:
        return torch.meshgrid(*tensors, indexing=indexing)

    x, y = tensors

    # Broadcast both vectors to the grid shape as zero-copy views; the kernel
    # then materialises them into the dense outputs.
    if indexing == "ij":
        grid_shape = (x.size(0), y.size(0))
        x_grid = x.view(-1, 1).expand(grid_shape)
        y_grid = y.view(1, -1).expand(grid_shape)
    else:
        grid_shape = (y.size(0), x.size(0))
        x_grid = x.view(1, -1).expand(grid_shape)
        y_grid = y.view(-1, 1).expand(grid_shape)

    X = torch.empty_like(x_grid)
    Y = torch.empty_like(y_grid)

    kernel = _cached_make(ntops.kernels.meshgrid.premake, X.ndim)
    kernel(x_grid, y_grid, X, Y)

    return X, Y
value (matching torch.mode) + with torch.no_grad(): + idx_tensor = torch.arange(input.shape[dim], device=input.device, dtype=torch.long) + idx_shape = [1] * input.ndim + idx_shape[dim] = input.shape[dim] + idx_broadcast = idx_tensor.view(idx_shape) + + mask = input == values.unsqueeze(dim) + indices = (idx_broadcast * mask).max(dim=dim, keepdim=True).values + + if keepdim: + values = values.unsqueeze(dim) + else: + indices = indices.squeeze(dim) + + return values, indices diff --git a/src/ntops/torch/roll.py b/src/ntops/torch/roll.py new file mode 100644 index 0000000..885e2b1 --- /dev/null +++ b/src/ntops/torch/roll.py @@ -0,0 +1,27 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def roll(input, shifts, dims=None): + if input.ndim == 1: + return _roll_1d(input, shifts, 0 if dims is None else dims) + + return torch.roll(input, shifts, dims) + + +def _roll_1d(input, shift, dim): + """Roll a 1D tensor using ninetoothed kernel.""" + n = input.shape[0] + shift = shift % n + + input_2d = input.view(1, n) + output_2d = torch.empty(1, n, device=input.device, dtype=input.dtype) + size_t = torch.tensor([n], dtype=torch.float32, device=input.device) + shift_t = torch.tensor([shift], dtype=torch.float32, device=input.device) + + kernel = _cached_make(ntops.kernels.roll.premake, input_2d.ndim) + kernel(input_2d, size_t, shift_t, output_2d) + + return output_2d.view(n) diff --git a/tests/test_cartesian_prod.py b/tests/test_cartesian_prod.py new file mode 100644 index 0000000..87c5202 --- /dev/null +++ b/tests/test_cartesian_prod.py @@ -0,0 +1,58 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +class TestCartesianProd: + def test_two_1d_tensors(self): + a = torch.tensor([1, 2], device="cuda") + b = torch.tensor([3, 4], device="cuda") + + ninetoothed_output = ntops.torch.cartesian_prod(a, b) + reference_output = torch.cartesian_prod(a, b) + + assert 
@skip_if_cuda_not_available
class TestColumnStack:
    """Compare ``ntops.torch.column_stack`` with ``torch.column_stack``."""

    def test_two_1d_tensors(self):
        a = torch.tensor([1, 2, 3], device="cuda")
        b = torch.tensor([4, 5, 6], device="cuda")

        ninetoothed_output = ntops.torch.column_stack([a, b])
        reference_output = torch.column_stack([a, b])

        assert torch.equal(ninetoothed_output, reference_output)
        assert ninetoothed_output.shape == (3, 2)

    def test_three_1d_tensors(self):
        a = torch.randn(5, device="cuda")
        b = torch.randn(5, device="cuda")
        c = torch.randn(5, device="cuda")

        ninetoothed_output = ntops.torch.column_stack([a, b, c])
        reference_output = torch.column_stack([a, b, c])

        assert torch.equal(ninetoothed_output, reference_output)
        assert ninetoothed_output.shape == (5, 3)

    def test_2d_tensors(self):
        a = torch.randn(3, 2, device="cuda")
        b = torch.randn(3, 4, device="cuda")

        ninetoothed_output = ntops.torch.column_stack([a, b])
        reference_output = torch.column_stack([a, b])

        assert torch.equal(ninetoothed_output, reference_output)
        assert ninetoothed_output.shape == (3, 6)

    def test_single_tensor(self):
        a = torch.randn(5, device="cuda")

        ninetoothed_output = ntops.torch.column_stack([a])
        reference_output = torch.column_stack([a])

        assert torch.equal(ninetoothed_output, reference_output)

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int64])
    def test_dtype(self, dtype):
        # Distinct values that every tested dtype represents exactly.
        # Truncating random normals to integers yields mostly zeros, which
        # would let a swapped or dropped column go unnoticed.
        a = torch.arange(1, 5, device="cuda").to(dtype)
        b = torch.arange(5, 9, device="cuda").to(dtype)

        ninetoothed_output = ntops.torch.column_stack([a, b])
        reference_output = torch.column_stack([a, b])

        assert ninetoothed_output.dtype == reference_output.dtype
        assert torch.equal(ninetoothed_output, reference_output)
reference_x, reference_y = torch.meshgrid(x, y, indexing="ij") + + assert torch.equal(ninetoothed_x, reference_x) + assert torch.equal(ninetoothed_y, reference_y) + assert ninetoothed_x.shape == (3, 2) + assert ninetoothed_y.shape == (3, 2) + + def test_xy_indexing(self): + x = torch.tensor([1, 2, 3], device="cuda") + y = torch.tensor([4, 5], device="cuda") + + ninetoothed_x, ninetoothed_y = ntops.torch.meshgrid(x, y, indexing="xy") + reference_x, reference_y = torch.meshgrid(x, y, indexing="xy") + + assert torch.equal(ninetoothed_x, reference_x) + assert torch.equal(ninetoothed_y, reference_y) + assert ninetoothed_x.shape == (2, 3) + + def test_three_vectors(self): + x = torch.randn(3, device="cuda") + y = torch.randn(4, device="cuda") + z = torch.randn(5, device="cuda") + + ninetoothed_grid = ntops.torch.meshgrid(x, y, z, indexing="ij") + reference_grid = torch.meshgrid(x, y, z, indexing="ij") + + for ninetoothed_out, reference_out in zip(ninetoothed_grid, reference_grid): + assert torch.equal(ninetoothed_out, reference_out) + assert ninetoothed_out.shape == reference_out.shape + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int64]) + def test_dtype(self, dtype): + x = torch.tensor([1, 2, 3], device="cuda", dtype=dtype) + y = torch.tensor([4, 5, 6], device="cuda", dtype=dtype) + + ninetoothed_x, ninetoothed_y = ntops.torch.meshgrid(x, y, indexing="ij") + reference_x, reference_y = torch.meshgrid(x, y, indexing="ij") + + assert torch.equal(ninetoothed_x, reference_x) + assert torch.equal(ninetoothed_y, reference_y) + + def test_default_indexing(self): + x = torch.randn(3, device="cuda") + y = torch.randn(4, device="cuda") + + ninetoothed_x, ninetoothed_y = ntops.torch.meshgrid(x, y) + reference_x, reference_y = torch.meshgrid(x, y, indexing="ij") + + assert torch.equal(ninetoothed_x, reference_x) + assert torch.equal(ninetoothed_y, reference_y) diff --git a/tests/test_mode.py b/tests/test_mode.py new file mode 100644 index 
def test_keepdim(self):
    """Check values, indices and shapes of ``mode`` with ``keepdim=True``."""
    # Each dim=1 group of 3 values has a clear mode (no ties)
    input = torch.tensor([[[1, 2, 2, 1], [2, 2, 2, 2], [1, 1, 1, 2]],
                          [[1, 2, 1, 1], [2, 3, 2, 2], [1, 2, 2, 2]]],
                         device="cuda", dtype=torch.float32)

    ninetoothed_values, ninetoothed_indices = ntops.torch.mode(input, dim=1, keepdim=True)
    reference_values, reference_indices = torch.mode(input, dim=1, keepdim=True)

    assert torch.equal(ninetoothed_values, reference_values)
    # Compare the indices themselves, not only their shape, so that a wrong
    # index with the right shape is caught.
    assert torch.equal(ninetoothed_indices, reference_indices)
    assert ninetoothed_values.shape == reference_values.shape
    assert ninetoothed_indices.shape == reference_indices.shape
@skip_if_cuda_not_available
class TestRoll:
    """Compare ``ntops.torch.roll`` with ``torch.roll``."""

    @pytest.mark.parametrize("n", [1, 5, 10, 30])
    def test_1d(self, n):
        input = torch.randn(n, device="cuda")

        for shift in [0, 1, -1, n // 2, n - 1]:
            ninetoothed_output = ntops.torch.roll(input, shift)
            reference_output = torch.roll(input, shift)

            assert torch.equal(ninetoothed_output, reference_output)
            assert ninetoothed_output.shape == input.shape

    @pytest.mark.parametrize("shape", [(3, 5), (5, 3), (2, 4, 6)])
    def test_2d_3d(self, shape):
        input = torch.randn(*shape, device="cuda")

        for shift, dims in [(1, 0), (2, 1), (1, -1), (2, -2)]:
            # Valid axes are -ndim .. ndim - 1; comparing abs(dims) with ndim
            # would silently skip the legitimate dims == -ndim case.
            if -input.ndim <= dims < input.ndim:
                ninetoothed_output = ntops.torch.roll(input, shift, dims)
                reference_output = torch.roll(input, shift, dims)

                assert torch.equal(ninetoothed_output, reference_output)

    def test_multi_dim_shifts(self):
        input = torch.randn(4, 5, 6, device="cuda")

        ninetoothed_output = ntops.torch.roll(input, (1, 2), (0, 1))
        reference_output = torch.roll(input, (1, 2), (0, 1))

        assert torch.equal(ninetoothed_output, reference_output)

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int64])
    def test_dtype(self, dtype):
        # Distinct values that every tested dtype represents exactly.
        # Truncating random normals to integers yields mostly zeros, for
        # which any shift (even a wrong one) compares equal.
        input = torch.arange(10, device="cuda").to(dtype)

        ninetoothed_output = ntops.torch.roll(input, 3)
        reference_output = torch.roll(input, 3)

        assert ninetoothed_output.dtype == reference_output.dtype
        assert torch.equal(ninetoothed_output, reference_output)

    def test_full_roll(self):
        input = torch.randn(5, 5, device="cuda")

        ninetoothed_output = ntops.torch.roll(input, (5, 5), (0, 1))
        reference_output = torch.roll(input, (5, 5), (0, 1))

        assert torch.equal(ninetoothed_output, reference_output)
        assert torch.equal(ninetoothed_output, input)