Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
bitwise_or,
bmm,
clamp,
channel_shuffle,
conv2d,
cos,
div,
Expand All @@ -17,13 +18,15 @@
ge,
gelu,
gt,
im2col,
isinf,
isnan,
layer_norm,
le,
lt,
max_pool2d,
mm,
moveaxis,
mul,
ne,
neg,
Expand All @@ -39,6 +42,8 @@
softmax,
sub,
tanh,
tensor_split,
unflatten,
)

__all__ = [
Expand All @@ -50,6 +55,7 @@
"bitwise_not",
"bitwise_or",
"bmm",
"channel_shuffle",
"clamp",
"conv2d",
"cos",
Expand All @@ -60,13 +66,15 @@
"ge",
"gelu",
"gt",
"im2col",
"isinf",
"isnan",
"layer_norm",
"le",
"lt",
"max_pool2d",
"mm",
"moveaxis",
"mul",
"ne",
"neg",
Expand All @@ -82,4 +90,6 @@
"softmax",
"sub",
"tanh",
"tensor_split",
"unflatten",
]
32 changes: 32 additions & 0 deletions src/ntops/kernels/channel_shuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import functools

import ninetoothed
from ninetoothed import Tensor


def arrangement(input, output, block_size=None):
if block_size is None:
block_size = ninetoothed.block_size()

input = input.flatten().tile((block_size,))
output = output.flatten().tile((block_size,))

return input, output


def application(input, output):
output = input


def premake(dtype=None, block_size=None):
arrangement_ = functools.partial(
arrangement,
block_size=block_size,
)

tensors = (
Tensor(5, dtype=dtype),
Tensor(4, dtype=dtype),
)

return arrangement_, application, tensors
140 changes: 140 additions & 0 deletions src/ntops/kernels/im2col.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import functools

from ninetoothed import Symbol, Tensor


BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 32


def arrangement(
input,
output,
kernel_size_h=None,
kernel_size_w=None,
stride_h=None,
stride_w=None,
padding_h=None,
padding_w=None,
dilation_h=None,
dilation_w=None,
ceil_mode=None,
block_size_m=None,
block_size_n=None,
):
if kernel_size_h is None:
kernel_size_h = Symbol("kernel_size_h", constexpr=True, upper_bound=8)

if kernel_size_w is None:
kernel_size_w = Symbol("kernel_size_w", constexpr=True, upper_bound=8)

if stride_h is None:
stride_h = Symbol("stride_h", constexpr=True)

if stride_w is None:
stride_w = Symbol("stride_w", constexpr=True)

if padding_h is None:
padding_h = Symbol("padding_h", constexpr=True)

if padding_w is None:
padding_w = Symbol("padding_w", constexpr=True)

if dilation_h is None:
dilation_h = Symbol("dilation_h", constexpr=True)

if dilation_w is None:
dilation_w = Symbol("dilation_w", constexpr=True)

if ceil_mode is None:
ceil_mode = False

if block_size_m is None:
block_size_m = BLOCK_SIZE_M

if block_size_n is None:
block_size_n = BLOCK_SIZE_N
input_arranged = input.pad(
((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w))
)

input_arranged = input_arranged.tile(
(1, input.shape[1], kernel_size_h, kernel_size_w),
strides=(-1, -1, stride_h, stride_w),
dilation=(1, 1, dilation_h, dilation_w),
floor_mode=not ceil_mode,
)

input_arranged = input_arranged.squeeze(1)
input_arranged.dtype = input_arranged.dtype.squeeze(0)

input_arranged = input_arranged.ravel()
input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1)

input_arranged = input_arranged.tile((block_size_m, block_size_n))

# output: (N * OH * OW, C * KH * KW)
output_arranged = output.tile((block_size_m, block_size_n))

return input_arranged, output_arranged


def application(input, output):
output = input


def premake(
kernel_size_h=None,
kernel_size_w=None,
stride_h=None,
stride_w=None,
padding_h=None,
padding_w=None,
dilation_h=None,
dilation_w=None,
ceil_mode=None,
dtype=None,
block_size_m=None,
block_size_n=None,
):
arrangement_ = functools.partial(
arrangement,
kernel_size_h=kernel_size_h,
kernel_size_w=kernel_size_w,
stride_h=stride_h,
stride_w=stride_w,
padding_h=padding_h,
padding_w=padding_w,
dilation_h=dilation_h,
dilation_w=dilation_w,
ceil_mode=ceil_mode,
block_size_m=block_size_m,
block_size_n=block_size_n,
)

input = Tensor(
4,
dtype=dtype,
shape_options=(
{"upper_bound": 16}, # N
{"upper_bound": 64}, # C
{"upper_bound": 256}, # H
{"upper_bound": 256}, # W
),
)

output = Tensor(
2,
dtype=dtype,
shape_options=(
{"upper_bound": 65536}, # N * OH * OW
{"upper_bound": 1024}, # C * KH * KW
),
)

tensors = (
input,
output,
)

return arrangement_, application, tensors
37 changes: 37 additions & 0 deletions src/ntops/kernels/moveaxis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import functools

import ninetoothed
from ninetoothed import Tensor


def arrangement(input, output, permutation, block_size=None):
if block_size is None:
block_size = ninetoothed.block_size()

assert input.ndim == output.ndim

input_arranged = input.permute(permutation)
input_arranged = input_arranged.flatten().tile((block_size,))

output_arranged = output.flatten().tile((block_size,))

return input_arranged, output_arranged


def application(input, output):
output = input


def premake(ndim, permutation, dtype=None, block_size=None):
arrangement_ = functools.partial(
arrangement,
permutation=permutation,
block_size=block_size,
)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
29 changes: 29 additions & 0 deletions src/ntops/kernels/tensor_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import ninetoothed
from ninetoothed import Tensor


def arrangement(input, output, block_size=None):
if block_size is None:
block_size = ninetoothed.block_size()

assert input.ndim == output.ndim

input = input.flatten().tile((block_size,))
output = output.flatten().tile((block_size,))

return input, output


def application(input, output):
output = input


def premake(ndim):
return (
arrangement,
application,
(
Tensor(ndim),
Tensor(ndim),
),
)
29 changes: 29 additions & 0 deletions src/ntops/kernels/unflatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import functools

import ninetoothed
from ninetoothed import Tensor


def arrangement(input, output, block_size=None):
if block_size is None:
block_size = ninetoothed.block_size()

input = input.flatten().tile((block_size,))
output = output.flatten().tile((block_size,))

return input, output


def application(input, output):
output = input


def premake(input_ndim, output_ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(input_ndim, dtype=dtype),
Tensor(output_ndim, dtype=dtype),
)

return arrangement_, application, tensors
Loading