diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..097ba69 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -14,6 +14,9 @@ dropout, eq, exp, + frac, + fractional_max_pool2d, + fractional_max_pool3d, ge, gelu, gt, @@ -25,6 +28,7 @@ max_pool2d, mm, mul, + multilabel_margin_loss, ne, neg, pow, @@ -32,6 +36,7 @@ rms_norm, rotary_position_embedding, rsqrt, + scatter_add, scaled_dot_product_attention, sigmoid, silu, @@ -57,6 +62,9 @@ "dropout", "eq", "exp", + "frac", + "fractional_max_pool2d", + "fractional_max_pool3d", "ge", "gelu", "gt", @@ -68,6 +76,7 @@ "max_pool2d", "mm", "mul", + "multilabel_margin_loss", "ne", "neg", "pow", @@ -75,6 +84,7 @@ "rms_norm", "rotary_position_embedding", "rsqrt", + "scatter_add", "scaled_dot_product_attention", "sigmoid", "silu", diff --git a/src/ntops/kernels/frac.py b/src/ntops/kernels/frac.py new file mode 100644 index 0000000..5e3e78e --- /dev/null +++ b/src/ntops/kernels/frac.py @@ -0,0 +1,24 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, output): + fp32_input = (input * 2.0) / 2.0 + floor_val = ntl.math.floor(fp32_input) + ceil_val = ntl.math.ceil(fp32_input) + + # 获取截断后的整数部分 + trunc_val = ntl.where(fp32_input >= 0, floor_val, ceil_val) + output = input - trunc_val # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/kernels/fractional_max_pool2d.py b/src/ntops/kernels/fractional_max_pool2d.py new file mode 100644 index 0000000..95d448f --- /dev/null +++ b/src/ntops/kernels/fractional_max_pool2d.py @@ -0,0 +1,326 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def _arrange_output(output, block_size): + arranged = output.tile((1, 1, 1, 1)) + arranged = arranged.ravel() + arranged = arranged.flatten(end_dim=4).flatten(start_dim=1) + arranged = arranged.tile((block_size, -1)) + arranged.dtype = arranged.dtype.squeeze(1) + + return arranged + + +def _arrange_sequence(sequence, input, block_size): + arranged = sequence.tile((1, 1, 1, 1)) + arranged = arranged.ravel() + arranged = arranged.flatten(end_dim=4).flatten(start_dim=1) + arranged = arranged.expand((-1, input.shape[-2] * input.shape[-1])) + arranged = arranged.tile((block_size, -1)) + + return arranged + + +def arrangement( + input, + sequence_h_start, + sequence_h_end, + sequence_w_start, + sequence_w_end, + output, + block_size=1, +): + if block_size is None: + block_size = 1 + + input_arranged = input.tile((1, 1, -1, -1)) + input_arranged = input_arranged.expand( + (-1, -1, output.shape[-2], output.shape[-1]) + ) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) + input_arranged = input_arranged.tile((block_size, -1)) + + sequence_h_start_arranged = _arrange_sequence( + sequence_h_start, + input, + block_size, + ) + sequence_h_end_arranged = _arrange_sequence( + sequence_h_end, + input, + block_size, + ) + sequence_w_start_arranged = _arrange_sequence( + sequence_w_start, + input, + block_size, + ) + sequence_w_end_arranged = _arrange_sequence( + sequence_w_end, + input, + block_size, + ) + + output_arranged = _arrange_output(output, block_size) + + return ( + input_arranged, + sequence_h_start_arranged, + sequence_h_end_arranged, + sequence_w_start_arranged, + sequence_w_end_arranged, + output_arranged, + ) + + +def application( + input, + sequence_h_start, + sequence_h_end, + sequence_w_start, + sequence_w_end, + output, +): + h_offsets = input.offsets(2) + w_offsets = input.offsets(3) + + mask = ( + (h_offsets >= sequence_h_start) + & (h_offsets < sequence_h_end) + & (w_offsets >= sequence_w_start) + & (w_offsets < sequence_w_end) + ) + + neg_large = input * 0 - 65504.0 + masked_input = ntl.where(mask, input, neg_large) + + output = ntl.max(masked_input, axis=-1) # noqa: F841 + + +def arrangement_deterministic( + input, + input_h, + input_w, + output_h, + output_w, + kernel_h, + kernel_w, + output, + block_size=1, +): + if block_size is None: + block_size = 1 + + input_arranged = input.tile((1, 1, -1, -1)) + input_arranged = input_arranged.expand( + (-1, -1, output.shape[-2], output.shape[-1]) + ) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) + input_arranged = input_arranged.tile((block_size, -1)) + + output_arranged = _arrange_output(output, block_size) + + return ( + input_arranged, + input_h, + input_w, + output_h, + output_w, + kernel_h, + kernel_w, + output_arranged, + ) + + +def application_deterministic( + input, + input_h, + input_w, + output_h, + output_w, + kernel_h, + kernel_w, + output, +): + h_offsets = input.offsets(2) + w_offsets = input.offsets(3) + + oh_offsets = output.offsets(2) + ow_offsets = output.offsets(3) + + h_start = ( + oh_offsets * (input_h - kernel_h) + ) // (output_h - 1) + h_start = ntl.where( + oh_offsets == output_h - 1, + input_h - kernel_h, + h_start, + ) + + w_start = ( + ow_offsets * (input_w - kernel_w) + ) // (output_w - 1) + w_start = ntl.where( + ow_offsets == output_w - 1, + input_w - kernel_w, + w_start, + ) + + h_end = h_start + kernel_h + w_end = w_start + kernel_w + + mask = ( + (h_offsets >= h_start) + & (h_offsets < h_end) + & (w_offsets >= w_start) + & (w_offsets < w_end) + ) + + neg_large = input * 0 - 65504.0 + masked_input = ntl.where(mask, input, neg_large) + + output = ntl.max(masked_input, axis=-1) # noqa: F841 + + +def _make_sequence_tensor( + output_h_upper_bound, + output_w_upper_bound, +): + return Tensor( + 4, + shape_options=( + None, + None, + {"constexpr": True, "upper_bound": output_h_upper_bound}, + {"constexpr": True, "upper_bound": output_w_upper_bound}, + ), + ) + + +def premake( + dtype=None, + block_size=1, + input_h_upper_bound=128, + input_w_upper_bound=128, + output_h_upper_bound=128, + output_w_upper_bound=128, +): + arrangement_ = functools.partial( + arrangement, + block_size=block_size, + ) + + input = Tensor( + 4, + dtype=dtype, + other=-65504.0, + shape_options=( + None, + None, + {"constexpr": True, "upper_bound": input_h_upper_bound}, + {"constexpr": True, "upper_bound": input_w_upper_bound}, + ), + ) + + sequence_h_start = _make_sequence_tensor( + output_h_upper_bound, + output_w_upper_bound, + ) + sequence_h_end = _make_sequence_tensor( + output_h_upper_bound, + output_w_upper_bound, + ) + sequence_w_start = _make_sequence_tensor( + output_h_upper_bound, + output_w_upper_bound, + ) + sequence_w_end = _make_sequence_tensor( + output_h_upper_bound, + output_w_upper_bound, + ) + + output = Tensor( + 4, + dtype=dtype, + shape_options=( + None, + None, + {"constexpr": True, "upper_bound": output_h_upper_bound}, + {"constexpr": True, "upper_bound": output_w_upper_bound}, + ), + ) + + tensors = ( + input, + sequence_h_start, + sequence_h_end, + sequence_w_start, + sequence_w_end, + output, + ) + + return arrangement_, application, tensors + + +def premake_deterministic( + dtype=None, + block_size=1, + input_h_upper_bound=128, + input_w_upper_bound=128, + output_h_upper_bound=128, + output_w_upper_bound=128, + kernel_h=1, + kernel_w=1, +): + arrangement_ = functools.partial( + arrangement_deterministic, + block_size=block_size, + ) + + input = Tensor( + 4, + dtype=dtype, + other=-65504.0, + shape_options=( + None, + None, + {"constexpr": True, "upper_bound": input_h_upper_bound}, + {"constexpr": True, "upper_bound": input_w_upper_bound}, + ), + ) + + input_h_tensor = Tensor(0) + input_w_tensor = Tensor(0) + output_h_tensor = Tensor(0) + output_w_tensor = Tensor(0) + kernel_h_tensor = Tensor(0) + kernel_w_tensor = Tensor(0) + + output = Tensor( + 4, + dtype=dtype, + shape_options=( + None, + None, + {"constexpr": True, "upper_bound": output_h_upper_bound}, + {"constexpr": True, "upper_bound": output_w_upper_bound}, + ), + ) + + tensors = ( + input, + input_h_tensor, + input_w_tensor, + output_h_tensor, + output_w_tensor, + kernel_h_tensor, + kernel_w_tensor, + output, + ) + + return arrangement_, application_deterministic, tensors \ No newline at end of file diff --git a/src/ntops/kernels/fractional_max_pool3d.py b/src/ntops/kernels/fractional_max_pool3d.py new file mode 100644 index 0000000..61d163d --- /dev/null +++ b/src/ntops/kernels/fractional_max_pool3d.py @@ -0,0 +1,403 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def _arrange_output(output, block_size): + arranged = output.tile((1, 1, 1, 1, 1)) + arranged = arranged.ravel() + arranged = arranged.flatten(end_dim=5).flatten(start_dim=1) + arranged = arranged.tile((block_size, -1)) + arranged.dtype = arranged.dtype.squeeze(1) + + return arranged + + +def _arrange_sequence(sequence, input, block_size): + arranged = sequence.tile((1, 1, 1, 1, 1)) + arranged = arranged.ravel() + arranged = arranged.flatten(end_dim=5).flatten(start_dim=1) + arranged = arranged.expand( + (-1, input.shape[-3] * input.shape[-2] * input.shape[-1]) + ) + arranged = arranged.tile((block_size, -1)) + + return arranged + + +def arrangement( + input, + sequence_d_start, + sequence_d_end, + sequence_h_start, + sequence_h_end, + sequence_w_start, + sequence_w_end, + output, + block_size=1, +): + if block_size is None: + block_size = 1 + + input_arranged = input.tile((1, 1, -1, -1, -1)) + input_arranged = input_arranged.expand( + (-1, -1, output.shape[-3], output.shape[-2], output.shape[-1]) + ) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=5).flatten(start_dim=1) + input_arranged = input_arranged.tile((block_size, -1)) + + sequence_d_start_arranged = _arrange_sequence( + sequence_d_start, + input, + block_size, + ) + sequence_d_end_arranged = _arrange_sequence( + sequence_d_end, + input, + block_size, + ) + sequence_h_start_arranged = _arrange_sequence( + sequence_h_start, + input, + block_size, + ) + sequence_h_end_arranged = _arrange_sequence( + sequence_h_end, + input, + block_size, + ) + sequence_w_start_arranged = _arrange_sequence( + sequence_w_start, + input, + block_size, + ) + sequence_w_end_arranged = _arrange_sequence( + sequence_w_end, + input, + block_size, + ) + + output_arranged = _arrange_output(output, block_size) + + return ( + input_arranged, + sequence_d_start_arranged, + sequence_d_end_arranged, + sequence_h_start_arranged, + sequence_h_end_arranged, + sequence_w_start_arranged, + sequence_w_end_arranged, + output_arranged, + ) + + +def application( + input, + sequence_d_start, + sequence_d_end, + sequence_h_start, + sequence_h_end, + sequence_w_start, + sequence_w_end, + output, +): + d_offsets = input.offsets(2) + h_offsets = input.offsets(3) + w_offsets = input.offsets(4) + + mask = ( + (d_offsets >= sequence_d_start) + & (d_offsets < sequence_d_end) + & (h_offsets >= sequence_h_start) + & (h_offsets < sequence_h_end) + & (w_offsets >= sequence_w_start) + & (w_offsets < sequence_w_end) + ) + + neg_large = input * 0 - 65504.0 + masked_input = ntl.where(mask, input, neg_large) + + output = ntl.max(masked_input, axis=-1) # noqa: F841 + + +def arrangement_deterministic( + input, + input_d, + input_h, + input_w, + output_d, + output_h, + output_w, + kernel_d, + kernel_h, + kernel_w, + output, + block_size=1, +): + if block_size is None: + block_size = 1 + + input_arranged = input.tile((1, 1, -1, -1, -1)) + input_arranged = input_arranged.expand( + (-1, -1, output.shape[-3], output.shape[-2], output.shape[-1]) + ) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=5).flatten(start_dim=1) + input_arranged = input_arranged.tile((block_size, -1)) + + output_arranged = _arrange_output(output, block_size) + + return ( + input_arranged, + input_d, + input_h, + input_w, + output_d, + output_h, + output_w, + kernel_d, + kernel_h, + kernel_w, + output_arranged, + ) + + +def application_deterministic( + input, + input_d, + input_h, + input_w, + output_d, + output_h, + output_w, + kernel_d, + kernel_h, + kernel_w, + output, +): + d_offsets = input.offsets(2) + h_offsets = input.offsets(3) + w_offsets = input.offsets(4) + + od_offsets = output.offsets(2) + oh_offsets = output.offsets(3) + ow_offsets = output.offsets(4) + + d_start = ( + od_offsets * (input_d - kernel_d) + ) // (output_d - 1) + d_start = ntl.where( + od_offsets == output_d - 1, + input_d - kernel_d, + d_start, + ) + + h_start = ( + oh_offsets * (input_h - kernel_h) + ) // (output_h - 1) + h_start = ntl.where( + oh_offsets == output_h - 1, + input_h - kernel_h, + h_start, + ) + + w_start = ( + ow_offsets * (input_w - kernel_w) + ) // (output_w - 1) + w_start = ntl.where( + ow_offsets == output_w - 1, + input_w - kernel_w, + w_start, + ) + + d_end = d_start + kernel_d + h_end = h_start + kernel_h + w_end = w_start + kernel_w + + mask = ( + (d_offsets >= d_start) + & (d_offsets < d_end) + & (h_offsets >= h_start) + & (h_offsets < h_end) + & (w_offsets >= w_start) + & (w_offsets < w_end) + ) + + neg_large = input * 0 - 65504.0 + masked_input = ntl.where(mask, input, neg_large) + + output = ntl.max(masked_input, axis=-1) # noqa: F841 + + +def _make_sequence_tensor( + output_d_upper_bound, + output_h_upper_bound, + output_w_upper_bound, +): + return Tensor( + 5, + shape_options=( + None, + None, + {"constexpr": True, "upper_bound": output_d_upper_bound}, + {"constexpr": True, "upper_bound": output_h_upper_bound}, + {"constexpr": True, "upper_bound": output_w_upper_bound}, + ), + ) + + +def premake( + dtype=None, + block_size=1, + input_d_upper_bound=128, + input_h_upper_bound=128, + input_w_upper_bound=128, + output_d_upper_bound=128, + output_h_upper_bound=128, + output_w_upper_bound=128, +): + arrangement_ = functools.partial( + arrangement, + block_size=block_size, + ) + + input = Tensor( + 5, + dtype=dtype, + other=-65504.0, + shape_options=( + None, + None, + {"constexpr": True, "upper_bound": input_d_upper_bound}, + {"constexpr": True, "upper_bound": input_h_upper_bound}, + {"constexpr": True, "upper_bound": input_w_upper_bound}, + ), + ) + + sequence_d_start = _make_sequence_tensor( + output_d_upper_bound, + output_h_upper_bound, + output_w_upper_bound, + ) + sequence_d_end = _make_sequence_tensor( + output_d_upper_bound, + output_h_upper_bound, + output_w_upper_bound, + ) + sequence_h_start = _make_sequence_tensor( + output_d_upper_bound, + output_h_upper_bound, + output_w_upper_bound, + ) + sequence_h_end = _make_sequence_tensor( + output_d_upper_bound, + output_h_upper_bound, + output_w_upper_bound, + ) + sequence_w_start = _make_sequence_tensor( + output_d_upper_bound, + output_h_upper_bound, + output_w_upper_bound, + ) + sequence_w_end = _make_sequence_tensor( + output_d_upper_bound, + output_h_upper_bound, + output_w_upper_bound, + ) + + output = Tensor( + 5, + dtype=dtype, + shape_options=( + None, + None, + {"constexpr": True, "upper_bound": output_d_upper_bound}, + {"constexpr": True, "upper_bound": output_h_upper_bound}, + {"constexpr": True, "upper_bound": output_w_upper_bound}, + ), + ) + + tensors = ( + input, + sequence_d_start, + sequence_d_end, + sequence_h_start, + sequence_h_end, + sequence_w_start, + sequence_w_end, + output, + ) + + return arrangement_, application, tensors + + +def premake_deterministic( + dtype=None, + block_size=1, + input_d_upper_bound=128, + input_h_upper_bound=128, + input_w_upper_bound=128, + output_d_upper_bound=128, + output_h_upper_bound=128, + output_w_upper_bound=128, + kernel_d=1, + kernel_h=1, + kernel_w=1, +): + arrangement_ = functools.partial( + arrangement_deterministic, + block_size=block_size, + ) + + input = Tensor( + 5, + dtype=dtype, + other=-65504.0, + shape_options=( + None, + None, + {"constexpr": True, "upper_bound": input_d_upper_bound}, + {"constexpr": True, "upper_bound": input_h_upper_bound}, + {"constexpr": True, "upper_bound": input_w_upper_bound}, + ), + ) + + input_d_tensor = Tensor(0) + input_h_tensor = Tensor(0) + input_w_tensor = Tensor(0) + output_d_tensor = Tensor(0) + output_h_tensor = Tensor(0) + output_w_tensor = Tensor(0) + kernel_d_tensor = Tensor(0) + kernel_h_tensor = Tensor(0) + kernel_w_tensor = Tensor(0) + + output = Tensor( + 5, + dtype=dtype, + shape_options=( + None, + None, + {"constexpr": True, "upper_bound": output_d_upper_bound}, + {"constexpr": True, "upper_bound": output_h_upper_bound}, + {"constexpr": True, "upper_bound": output_w_upper_bound}, + ), + ) + + tensors = ( + input, + input_d_tensor, + input_h_tensor, + input_w_tensor, + output_d_tensor, + output_h_tensor, + output_w_tensor, + kernel_d_tensor, + kernel_h_tensor, + kernel_w_tensor, + output, + ) + + return arrangement_, application_deterministic, tensors \ No newline at end of file diff --git a/src/ntops/kernels/multilabel_margin_loss.py b/src/ntops/kernels/multilabel_margin_loss.py new file mode 100644 index 0000000..f469a01 --- /dev/null +++ b/src/ntops/kernels/multilabel_margin_loss.py @@ -0,0 +1,211 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement as reduction_arrangement + + +REDUCTION_MEAN = 1 +REDUCTION_SUM = 2 + + +def arrangement(input, target, output, block_size=None): + input_arranged, target_arranged = reduction_arrangement( + input, + target, + dim=-1, + block_size=1, + ) + + output_arranged = output.tile((1,)) + + return input_arranged, target_arranged, output_arranged + + +def application(input, target, output): + dtype = output.dtype + acc_dtype = ntl.float32 if dtype == ntl.float16 else dtype + + class_size = input.shape[0] + + zero = ntl.cast(input[0] * 0, acc_dtype) + one = zero + ntl.cast(1, acc_dtype) + + zero_t = target[0] * 0 + + true = zero_t == zero_t + false = zero_t != zero_t + + loss = zero + alive_j = true + + for j in range(input.shape[0]): + y_j = target[j] + + y_nonneg = y_j >= zero_t + j_active = alive_j & y_nonneg + + x_y = zero + + # x_y = input[y_j] + for c in range(input.shape[0]): + c_t = zero_t + c + x_c = ntl.cast(input[c], acc_dtype) + x_y = ntl.where(y_j == c_t, x_c, x_y) + + for i in range(input.shape[0]): + i_t = zero_t + i + x_i = ntl.cast(input[i], acc_dtype) + + alive_q = true + is_positive_i = false + + for q in range(input.shape[0]): + t_q = target[q] + + q_nonneg = t_q >= zero_t + q_active = alive_q & q_nonneg + + is_positive_i = is_positive_i | (q_active & (t_q == i_t)) + alive_q = alive_q & q_nonneg + + margin = one - x_y + x_i + + term = ntl.where(margin > zero, margin, zero) + term = ntl.where(j_active, term, zero) + term = ntl.where(is_positive_i, zero, term) + + loss += term + + alive_j = alive_j & y_nonneg + + inv_c = ntl.cast(1.0 / class_size, acc_dtype) + loss = loss * inv_c + + output = ntl.cast(ntl.sum(loss), dtype) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + assert ndim == 2, "`multilabel_margin_loss` kernel only supports 2D input [N, C]." + + arrangement_ = functools.partial( + arrangement, + block_size=1, + ) + + input = Tensor( + ndim, + dtype=dtype, + shape_options={"constexpr": True}, + ) + + target = Tensor( + ndim, + dtype=ninetoothed.int64, + shape_options={"constexpr": True}, + ) + + output = Tensor( + ndim - 1, + dtype=dtype, + shape_options={"constexpr": True}, + ) + + tensors = ( + input, + target, + output, + ) + + return arrangement_, application, tensors + + +def reduce_arrangement(input, output, inv_numel=None, block_size=None): + input_arranged = reduction_arrangement( + input, + dim=tuple(range(input.ndim)), + block_size=block_size, + )[0] + + output_arranged = output.tile((1,)) + + if inv_numel is None: + return input_arranged, output_arranged + + return input_arranged, output_arranged, inv_numel + + +def application_reduce_sum(input, output): + dtype = output.dtype + acc_dtype = ntl.float32 if dtype == ntl.float16 else dtype + + acc = ntl.cast(0, acc_dtype) + + for i in range(input.shape[0]): + acc += ntl.sum(ntl.cast(input[i], acc_dtype)) + + output = ntl.cast(acc, dtype) # noqa: F841 + + +def application_reduce_mean(input, output, inv_numel): + dtype = output.dtype + acc_dtype = ntl.float32 if dtype == ntl.float16 else dtype + + acc = ntl.cast(0, acc_dtype) + + for i in range(input.shape[0]): + acc += ntl.sum(ntl.cast(input[i], acc_dtype)) + + output = ntl.cast(acc * inv_numel, dtype) # noqa: F841 + + +def premake_reduce( + ndim, + reduction=REDUCTION_MEAN, + dtype=None, + block_size=None, +): + assert reduction in ( + REDUCTION_MEAN, + REDUCTION_SUM, + ), "`reduction` must be REDUCTION_MEAN or REDUCTION_SUM." + + arrangement_ = functools.partial( + reduce_arrangement, + block_size=block_size, + ) + + input = Tensor( + ndim, + dtype=dtype, + shape_options={"constexpr": True}, + ) + + output = Tensor( + 1, + dtype=dtype, + shape_options=( + {"constexpr": True, "upper_bound": 1}, + ), + ) + output.shape = (1,) + + if reduction == REDUCTION_SUM: + tensors = ( + input, + output, + ) + + return arrangement_, application_reduce_sum, tensors + + inv_numel = Tensor(0, dtype=ninetoothed.float64) + + tensors = ( + input, + output, + inv_numel, + ) + + return arrangement_, application_reduce_mean, tensors \ No newline at end of file diff --git a/src/ntops/kernels/scatter_add.py b/src/ntops/kernels/scatter_add.py new file mode 100644 index 0000000..508f99f --- /dev/null +++ b/src/ntops/kernels/scatter_add.py @@ -0,0 +1,82 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement + + +def application(input, index, src, output): + dtype = output.dtype.dtype + acc_dtype = ntl.float32 if dtype == ntl.float16 else dtype + index_dtype = index.dtype.dtype + + # block_size=1 时,input[0] 是 [1] 向量 + # 所以 zero 也必须初始化成 [1] 向量,不能用 scalar 0 + zero = ntl.cast(input[0] * 0, acc_dtype) + + # input.shape[0] 是 scatter_add 的 dim_size + for out_i in range(input.shape[0]): + out_i_t = ntl.cast(out_i, index_dtype) + + # scatter_add 语义: + # output 先等于 input,再把 src 按 index 累加进去 + acc = ntl.cast(input[out_i], acc_dtype) + + for src_i in range(input.shape[0]): + idx = index[src_i] + val = ntl.cast(src[src_i], acc_dtype) + + add_val = ntl.where( + idx == out_i_t, + val, + zero, + ) + + acc += add_val + + output[out_i] = ntl.cast(acc, dtype) + + +def premake(ndim, dim, dtype=None, block_size=None): + # scatter_add 沿 dim 做一维 scatter + # 这里强制 block_size=1,避免 index / src 的 lane 类型不一致 + arrangement_ = functools.partial( + arrangement, + dim=dim, + block_size=1, + ) + + input = Tensor( + ndim, + dtype=dtype, + shape_options={"constexpr": True}, + ) + + index = Tensor( + ndim, + dtype=ninetoothed.int64, + shape_options={"constexpr": True}, + ) + + src = Tensor( + ndim, + dtype=dtype, + shape_options={"constexpr": True}, + ) + + output = Tensor( + ndim, + dtype=dtype, + shape_options={"constexpr": True}, + ) + + tensors = ( + input, + index, + src, + output, + ) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 82fc596..ca6aeba 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -13,6 +13,10 @@ from ntops.torch.dropout import dropout from ntops.torch.eq import eq from ntops.torch.exp import exp +from ntops.torch.frac import frac +from ntops.torch.fractional_max_pool2d import fractional_max_pool2d +from ntops.torch.fractional_max_pool3d import fractional_max_pool3d + from ntops.torch.ge import ge from ntops.torch.gelu import gelu from ntops.torch.gt import gt @@ -24,7 +28,8 @@ from ntops.torch.matmul import matmul from ntops.torch.max_pool2d import max_pool2d from ntops.torch.mm import mm -from ntops.torch.mul import mul +from ntops.torch.mul import mul +from ntops.torch.multilabel_margin_loss import multilabel_margin_loss from ntops.torch.ne import ne from ntops.torch.neg import neg from ntops.torch.pow import pow @@ -36,6 +41,7 @@ from ntops.torch.sigmoid import sigmoid from ntops.torch.silu import silu from ntops.torch.sin import sin +from ntops.torch.scatter_add import scatter_add from ntops.torch.softmax import softmax from ntops.torch.sub import sub from ntops.torch.tanh import tanh @@ -56,6 +62,9 @@ "dropout", "eq", "exp", + "frac", + "fractional_max_pool2d", + "fractional_max_pool3d", "ge", "gelu", "gt", @@ -68,6 +77,7 @@ "max_pool2d", "mm", "mul", + "multilabel_margin_loss", "ne", "neg", "pow", @@ -76,6 +86,7 @@ "rotary_position_embedding", "rsqrt", "scaled_dot_product_attention", + "scatter_add", "sigmoid", "silu", "sin", diff --git a/src/ntops/torch/frac.py b/src/ntops/torch/frac.py new file mode 100644 index 0000000..1c6faaf --- /dev/null +++ b/src/ntops/torch/frac.py @@ -0,0 +1,13 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def frac(input, *, out=None): + if out is None: + out = torch.empty_like(input) + kernel = _cached_make(ntops.kernels.frac.premake, input.ndim) + kernel(input, out) + + return out \ No newline at end of file diff --git a/src/ntops/torch/fractional_max_pool2d.py b/src/ntops/torch/fractional_max_pool2d.py new file mode 100644 index 0000000..346908d --- /dev/null +++ b/src/ntops/torch/fractional_max_pool2d.py @@ -0,0 +1,289 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +_CAST_EXCEPTIONS = (TypeError, RuntimeError, AttributeError, NotImplementedError) + + +def _pair(value): + if isinstance(value, int): + return (value, value) + + return value + + +def _shape_tuple(x): + return tuple(int(i) for i in x.shape) + + +def _calculate_fractional_output_size( + input_size, + output_size=None, + output_ratio=None, +): + assert output_size is not None or output_ratio is not None, ( + "Either `output_size` or `output_ratio` must be specified." + ) + assert output_size is None or output_ratio is None, ( + "`output_size` and `output_ratio` cannot both be specified." + ) + + if output_size is not None: + return output_size + + return int(input_size * output_ratio) + + +def _empty_like_input(input, shape): + empty_fn = getattr(torch, "empty", None) + + if callable(empty_fn): + try: + return empty_fn( + shape, + dtype=input.dtype, + device=input.device, + ) + except _CAST_EXCEPTIONS: + return _empty_like_input_by_zeros(input, shape) + + return _empty_like_input_by_zeros(input, shape) + + +def _empty_like_input_by_zeros(input, shape): + zeros_fn = getattr(torch, "zeros", None) + + if callable(zeros_fn): + try: + return zeros_fn( + shape, + dtype=input.dtype, + device=input.device, + ) + except _CAST_EXCEPTIONS as exc: + raise RuntimeError( + "Cannot create output tensor for fractional_max_pool2d." + ) from exc + + raise RuntimeError( + "Cannot create output tensor for fractional_max_pool2d." + ) + + +def _make_fractional_sequence( + input_size, + output_size, + kernel_size, + random_samples, +): + if output_size == 1: + return torch.full( + random_samples.shape + (1,), + input_size - kernel_size, + dtype=torch.int64, + device=random_samples.device, + ) + + alpha = float(input_size - kernel_size) / float(output_size - 1) + + index = torch.arange( + output_size, + dtype=torch.float32, + device=random_samples.device, + ) + + sample = random_samples.float() + + sequence = torch.floor( + (index + sample[..., None]) * alpha + ) - torch.floor( + sample[..., None] * alpha + ) + + sequence = sequence.long() + sequence[..., -1] = input_size - kernel_size + + return sequence + + +def _fractional_max_pool2d_with_random_samples( + input, + kernel_size, + h_out, + w_out, + _random_samples, +): + n, c, h, w = _shape_tuple(input) + + random_samples_shape = _shape_tuple(_random_samples) + expected_random_samples_shape = (n, c, 2) + + assert random_samples_shape == expected_random_samples_shape, ( + "`_random_samples` must have shape `(N, C, 2)`, " + f"got {random_samples_shape}." + ) + + sequence_w_start = _make_fractional_sequence( + w, + w_out, + kernel_size[1], + _random_samples[..., 0], + ) + sequence_h_start = _make_fractional_sequence( + h, + h_out, + kernel_size[0], + _random_samples[..., 1], + ) + + sequence_h_end = sequence_h_start + kernel_size[0] + sequence_w_end = sequence_w_start + kernel_size[1] + + sequence_h_start = sequence_h_start[..., None].expand( + -1, + -1, + -1, + w_out, + ).contiguous() + sequence_h_end = sequence_h_end[..., None].expand( + -1, + -1, + -1, + w_out, + ).contiguous() + + sequence_w_start = sequence_w_start[:, :, None, :].expand( + -1, + -1, + h_out, + -1, + ).contiguous() + sequence_w_end = sequence_w_end[:, :, None, :].expand( + -1, + -1, + h_out, + -1, + ).contiguous() + + output = _empty_like_input( + input, + (n, c, h_out, w_out), + ) + + kernel = _cached_make( + ntops.kernels.fractional_max_pool2d.premake, + block_size=1, + input_h_upper_bound=h, + input_w_upper_bound=w, + output_h_upper_bound=h_out, + output_w_upper_bound=w_out, + ) + + kernel( + input, + sequence_h_start, + sequence_h_end, + sequence_w_start, + sequence_w_end, + output, + ) + + return output + + +def _fractional_max_pool2d_deterministic( + input, + kernel_size, + h_out, + w_out, +): + n, c, h, w = _shape_tuple(input) + + output = _empty_like_input( + input, + (n, c, h_out, w_out), + ) + + kernel = _cached_make( + ntops.kernels.fractional_max_pool2d.premake_deterministic, + block_size=1, + input_h_upper_bound=h, + input_w_upper_bound=w, + output_h_upper_bound=h_out, + output_w_upper_bound=w_out, + kernel_h=kernel_size[0], + kernel_w=kernel_size[1], + ) + + kernel( + input, + h, + w, + h_out, + w_out, + kernel_size[0], + kernel_size[1], + output, + ) + + return output + + +def fractional_max_pool2d( + input, + kernel_size, + output_size=None, + output_ratio=None, + return_indices=False, + _random_samples=None, +): + assert input.ndim == 4, "`fractional_max_pool2d` only supports 4D input for now." + assert not return_indices, "`return_indices` is not supported yet." + + kernel_size = _pair(kernel_size) + + if output_size is not None: + output_size = _pair(output_size) + + if output_ratio is not None: + output_ratio = _pair(output_ratio) + + n, c, h, w = _shape_tuple(input) + + h_out = _calculate_fractional_output_size( + h, + output_size=None if output_size is None else output_size[0], + output_ratio=None if output_ratio is None else output_ratio[0], + ) + w_out = _calculate_fractional_output_size( + w, + output_size=None if output_size is None else output_size[1], + output_ratio=None if output_ratio is None else output_ratio[1], + ) + + assert h_out > 0 and w_out > 0, "`output_size` must be positive." + + assert h_out + kernel_size[0] - 1 <= h, ( + "`output_size[0] + kernel_size[0] - 1` must be no greater than input height." + ) + assert w_out + kernel_size[1] - 1 <= w, ( + "`output_size[1] + kernel_size[1] - 1` must be no greater than input width." + ) + + if _random_samples is not None: + return _fractional_max_pool2d_with_random_samples( + input, + kernel_size, + h_out, + w_out, + _random_samples, + ) + + return _fractional_max_pool2d_deterministic( + input, + kernel_size, + h_out, + w_out, + ) \ No newline at end of file diff --git a/src/ntops/torch/fractional_max_pool3d.py b/src/ntops/torch/fractional_max_pool3d.py new file mode 100644 index 0000000..5b24df5 --- /dev/null +++ b/src/ntops/torch/fractional_max_pool3d.py @@ -0,0 +1,333 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +_CAST_EXCEPTIONS = (TypeError, RuntimeError, AttributeError, NotImplementedError) + + +def _triple(value): + if isinstance(value, int): + return (value, value, value) + + return value + + +def _calculate_fractional_output_size( + input_size, + output_size=None, + output_ratio=None, +): + assert output_size is not None or output_ratio is not None, ( + "Either `output_size` or `output_ratio` must be specified." + ) + assert output_size is None or output_ratio is None, ( + "`output_size` and `output_ratio` cannot both be specified." + ) + + if output_size is not None: + return output_size + + return int(input_size * output_ratio) + + +def _empty_like_input(input, shape): + empty_fn = getattr(torch, "empty", None) + + if callable(empty_fn): + try: + return empty_fn( + shape, + dtype=input.dtype, + device=input.device, + ) + except _CAST_EXCEPTIONS: + return _empty_like_input_by_zeros(input, shape) + + return _empty_like_input_by_zeros(input, shape) + + +def _empty_like_input_by_zeros(input, shape): + zeros_fn = getattr(torch, "zeros", None) + + if callable(zeros_fn): + try: + return zeros_fn( + shape, + dtype=input.dtype, + device=input.device, + ) + except _CAST_EXCEPTIONS as exc: + raise RuntimeError( + "Cannot create output tensor for fractional_max_pool3d." + ) from exc + + raise RuntimeError( + "Cannot create output tensor for fractional_max_pool3d." + ) + + +def _make_fractional_sequence( + input_size, + output_size, + kernel_size, + random_samples, +): + if output_size == 1: + return torch.full( + random_samples.shape + (1,), + input_size - kernel_size, + dtype=torch.int64, + device=random_samples.device, + ) + + alpha = float(input_size - kernel_size) / float(output_size - 1) + + index = torch.arange( + output_size, + dtype=torch.float32, + device=random_samples.device, + ) + + sample = random_samples.float() + + sequence = torch.floor( + (index + sample[..., None]) * alpha + ) - torch.floor( + sample[..., None] * alpha + ) + + sequence = sequence.long() + sequence[..., -1] = input_size - kernel_size + + return sequence + + +def _fractional_max_pool3d_with_random_samples( + input, + kernel_size, + d_out, + h_out, + w_out, + _random_samples, +): + n, c, d, h, w = input.shape + + assert _random_samples.shape == (n, c, 3), ( + "`_random_samples` must have shape `(N, C, 3)`." + ) + + # PyTorch fractional_max_pool3d: + # samples[..., 0] -> D + # samples[..., 1] -> H + # samples[..., 2] -> W + sequence_d_start = _make_fractional_sequence( + d, + d_out, + kernel_size[0], + _random_samples[..., 0], + ) + sequence_h_start = _make_fractional_sequence( + h, + h_out, + kernel_size[1], + _random_samples[..., 1], + ) + sequence_w_start = _make_fractional_sequence( + w, + w_out, + kernel_size[2], + _random_samples[..., 2], + ) + + sequence_d_end = sequence_d_start + kernel_size[0] + sequence_h_end = sequence_h_start + kernel_size[1] + sequence_w_end = sequence_w_start + kernel_size[2] + + sequence_d_start = sequence_d_start[:, :, :, None, None].expand( + -1, + -1, + -1, + h_out, + w_out, + ).contiguous() + sequence_d_end = sequence_d_end[:, :, :, None, None].expand( + -1, + -1, + -1, + h_out, + w_out, + ).contiguous() + + sequence_h_start = sequence_h_start[:, :, None, :, None].expand( + -1, + -1, + d_out, + -1, + w_out, + ).contiguous() + sequence_h_end = sequence_h_end[:, :, None, :, None].expand( + -1, + -1, + d_out, + -1, + w_out, + ).contiguous() + + sequence_w_start = sequence_w_start[:, :, None, None, :].expand( + -1, + -1, + d_out, + h_out, + -1, + ).contiguous() + sequence_w_end = sequence_w_end[:, :, None, None, :].expand( + -1, + -1, + d_out, + h_out, + -1, + ).contiguous() + + output = _empty_like_input( + input, + (n, c, d_out, h_out, w_out), + ) + + kernel = _cached_make( + ntops.kernels.fractional_max_pool3d.premake, + block_size=1, + input_d_upper_bound=d, + input_h_upper_bound=h, + input_w_upper_bound=w, + output_d_upper_bound=d_out, + output_h_upper_bound=h_out, + output_w_upper_bound=w_out, + ) + + kernel( + input, + sequence_d_start, + sequence_d_end, + sequence_h_start, + sequence_h_end, + sequence_w_start, + sequence_w_end, + output, + ) + + return output + + +def _fractional_max_pool3d_deterministic( + input, + kernel_size, + d_out, + h_out, + w_out, +): + n, c, d, h, w = input.shape + + output = _empty_like_input( + input, + (n, c, d_out, h_out, w_out), + ) + + kernel = _cached_make( + ntops.kernels.fractional_max_pool3d.premake_deterministic, + block_size=1, + input_d_upper_bound=d, + input_h_upper_bound=h, + input_w_upper_bound=w, + output_d_upper_bound=d_out, + output_h_upper_bound=h_out, + output_w_upper_bound=w_out, + kernel_d=kernel_size[0], + kernel_h=kernel_size[1], + kernel_w=kernel_size[2], + ) + + kernel( + input, + d, + h, + w, + d_out, + h_out, + w_out, + kernel_size[0], + kernel_size[1], + kernel_size[2], + output, + ) + + return output + + +def fractional_max_pool3d( + input, + kernel_size, + output_size=None, + output_ratio=None, + return_indices=False, + _random_samples=None, +): + assert input.ndim == 5, "`fractional_max_pool3d` only supports 5D input for now." + assert not return_indices, "`return_indices` is not supported yet." + + kernel_size = _triple(kernel_size) + + if output_size is not None: + output_size = _triple(output_size) + + if output_ratio is not None: + output_ratio = _triple(output_ratio) + + n, c, d, h, w = input.shape + + d_out = _calculate_fractional_output_size( + d, + output_size=None if output_size is None else output_size[0], + output_ratio=None if output_ratio is None else output_ratio[0], + ) + h_out = _calculate_fractional_output_size( + h, + output_size=None if output_size is None else output_size[1], + output_ratio=None if output_ratio is None else output_ratio[1], + ) + w_out = _calculate_fractional_output_size( + w, + output_size=None if output_size is None else output_size[2], + output_ratio=None if output_ratio is None else output_ratio[2], + ) + + assert d_out > 0 and h_out > 0 and w_out > 0, "`output_size` must be positive." + + assert d_out + kernel_size[0] - 1 <= d, ( + "`output_size[0] + kernel_size[0] - 1` must be no greater than input depth." + ) + assert h_out + kernel_size[1] - 1 <= h, ( + "`output_size[1] + kernel_size[1] - 1` must be no greater than input height." + ) + assert w_out + kernel_size[2] - 1 <= w, ( + "`output_size[2] + kernel_size[2] - 1` must be no greater than input width." + ) + + if _random_samples is not None: + return _fractional_max_pool3d_with_random_samples( + input, + kernel_size, + d_out, + h_out, + w_out, + _random_samples, + ) + + return _fractional_max_pool3d_deterministic( + input, + kernel_size, + d_out, + h_out, + w_out, + ) \ No newline at end of file diff --git a/src/ntops/torch/multilabel_margin_loss.py b/src/ntops/torch/multilabel_margin_loss.py new file mode 100644 index 0000000..421b523 --- /dev/null +++ b/src/ntops/torch/multilabel_margin_loss.py @@ -0,0 +1,147 @@ +import torch +import torch.nn.functional as F + +import ntops +from ntops.torch.utils import _cached_make + + +_REDUCTION_NONE = 0 +_REDUCTION_MEAN = 1 +_REDUCTION_SUM = 2 + +_MAX_NTOPS_CLASS_SIZE = 32 + + +def _get_reduction_enum(reduction): + if reduction == "none": + return _REDUCTION_NONE + if reduction == "mean": + return _REDUCTION_MEAN + if reduction == "sum": + return _REDUCTION_SUM + raise ValueError("`reduction` must be one of 'none', 'mean', or 'sum'.") + + +def _as_scalar(output): + if hasattr(output, "reshape"): + return output.reshape(()) + if hasattr(output, "view"): + return output.view(()) + return output + + +def _is_real_torch_tensor(x): + return type(x).__module__.startswith("torch") + + +def _torch_fallback(input, target, reduction): + class_size = input.shape[-1] + + input_2d = input.reshape(-1, class_size) + target_2d = target.reshape(-1, class_size) + + output = F.multilabel_margin_loss( + input_2d, + target_2d, + reduction=reduction, + ) + + if reduction == "none": + return output.reshape(input.shape[:-1]) + + return output + + +def multilabel_margin_loss( + input, + target, + size_average=None, + reduce=None, + reduction="mean", +): + if size_average is not None or reduce is not None: + if reduce is False: + reduction = "none" + elif size_average is False: + reduction = "sum" + else: + reduction = "mean" + + reduction_enum = _get_reduction_enum(reduction) + + assert input.ndim >= 1, "`input` must have at least 1 dimension." + assert target.ndim == input.ndim, "`input` and `target` must have the same ndim." + assert target.shape == input.shape, "`input` and `target` must have the same shape." + + class_size = int(input.shape[-1]) + + # 普通 torch.Tensor 可以支持高维 reshape fallback + # InfiniCore tensor 没有 reshape,所以不能走这个分支 + if _is_real_torch_tensor(input) and input.ndim != 2: + return _torch_fallback( + input, + target, + reduction, + ) + + assert input.ndim == 2, ( + "`multilabel_margin_loss` ntops path for InfiniCore only supports 2D [N, C]." + ) + + if class_size > _MAX_NTOPS_CLASS_SIZE: + if _is_real_torch_tensor(input): + return _torch_fallback( + input, + target, + reduction, + ) + + raise NotImplementedError( + "class_size is too large for ntops multilabel_margin_loss on InfiniCore tensor." + ) + + per_sample = torch.empty( + tuple(input.shape[:-1]), + dtype=input.dtype, + device=input.device, + ) + + kernel = _cached_make( + ntops.kernels.multilabel_margin_loss.premake, + input.ndim, + ) + + kernel( + input, + target, + per_sample, + ) + + if reduction_enum == _REDUCTION_NONE: + return per_sample + + output = torch.empty( + (1,), + dtype=input.dtype, + device=input.device, + ) + + reduce_kernel = _cached_make( + ntops.kernels.multilabel_margin_loss.premake_reduce, + per_sample.ndim, + reduction_enum, + ) + + if reduction_enum == _REDUCTION_MEAN: + reduce_kernel( + per_sample, + output, + float(1.0 / per_sample.numel()), + ) + else: + reduce_kernel( + per_sample, + output, + ) + + return _as_scalar(output) \ No newline at end of file diff --git a/src/ntops/torch/scatter_add.py b/src/ntops/torch/scatter_add.py new file mode 100644 index 0000000..1f6c855 --- /dev/null +++ b/src/ntops/torch/scatter_add.py @@ -0,0 +1,77 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +_MAX_NTOPS_DIM_SIZE = 64 + + +def _normalize_dim(dim, ndim): + if dim < 0: + dim += ndim + + assert 0 <= dim < ndim, "`dim` out of range." + + return dim + + +def scatter_add(input, dim, index, src, *, out=None): + assert input.ndim == index.ndim == src.ndim, ( + "`input`, `index`, and `src` must have the same ndim." + ) + assert index.shape == src.shape, "`index` and `src` must have the same shape." + assert index.dtype == torch.int64, "`index` must be torch.int64." + assert src.dtype == input.dtype, "`src` and `input` must have the same dtype." + + dim = _normalize_dim(dim, input.ndim) + + if input.shape != index.shape: + result = torch.scatter_add( + input, + dim, + index, + src, + ) + + if out is not None: + out.copy_(result) + return out + + return result + + dim_size = int(input.shape[dim]) + + if dim_size > _MAX_NTOPS_DIM_SIZE: + result = torch.scatter_add( + input, + dim, + index, + src, + ) + + if out is not None: + out.copy_(result) + return out + + return result + + if out is None: + output = torch.empty_like(input) + else: + output = out + + kernel = _cached_make( + ntops.kernels.scatter_add.premake, + input.ndim, + dim, + ) + + kernel( + input, + index, + src, + output, + ) + + return output \ No newline at end of file diff --git a/tests/test_frac.py b/tests/test_frac.py new file mode 100644 index 0000000..362c54e --- /dev/null +++ b/tests/test_frac.py @@ -0,0 +1,15 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_frac(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) * 10 + ninetoothed_output = ntops.torch.frac(input) + reference_output = torch.frac(input) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) \ No newline at end of file diff --git a/tests/test_fractional_max_pool2d.py b/tests/test_fractional_max_pool2d.py new file mode 100644 index 0000000..95e3f48 --- /dev/null +++ b/tests/test_fractional_max_pool2d.py @@ -0,0 +1,199 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +_TEST_CASES_DATA = [ + ((2, 3, 15, 15), None, (3, 3), (5, 5), False), + ((1, 4, 16, 14), (896, 224, 14, 1), (4, 3), (4, 5), False), + ((2, 2, 17, 19), None, (5, 5), (7, 6), False), + ((3, 6, 9, 11), None, (2, 2), (4, 5), False), + ((1, 8, 20, 20), (3200, 400, 20, 1), (3, 3), (6, 6), False), + ((2, 5, 12, 10), None, (4, 3), (3, 3), False), +] + + +_TOLERANCE_MAP = { + torch.float16: {"atol": 1e-3, "rtol": 1e-2}, + torch.float32: {"atol": 1e-5, "rtol": 1e-4}, +} + + +@pytest.fixture +def device(): + assert torch.cuda.is_available() + return torch.empty(0).cuda().device + + +def _get_tolerance(dtype): + return _TOLERANCE_MAP[dtype] + + +def _make_input(shape, strides, dtype, device): + if strides is None: + return torch.randn( + shape, + dtype=dtype, + device=device, + ) + + storage_size = 1 + for size, stride in zip(shape, strides): + storage_size += (size - 1) * stride + + base = torch.randn( + (storage_size,), + dtype=dtype, + device=device, + ) + + return torch.as_strided( + base, + size=shape, + stride=strides, + ) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16)) +@pytest.mark.parametrize( + "in_shape, in_strides, kernel_size, output_size, return_indices", + _TEST_CASES_DATA, +) +def test_fractional_max_pool2d_cases( + in_shape, + in_strides, + kernel_size, + output_size, + return_indices, + dtype, + device, +): + assert not return_indices, "return_indices is not supported by ntops yet." + + input = _make_input( + in_shape, + in_strides, + dtype=dtype, + device=device, + ) + + n, c, h, w = in_shape + + random_samples = torch.rand( + (n, c, 2), + dtype=dtype, + device=input.device, + ) + + ninetoothed_output = ntops.torch.fractional_max_pool2d( + input, + kernel_size=kernel_size, + output_size=output_size, + output_ratio=None, + return_indices=return_indices, + _random_samples=random_samples, + ) + + reference_output = F.fractional_max_pool2d( + input, + kernel_size=kernel_size, + output_size=output_size, + output_ratio=None, + return_indices=return_indices, + _random_samples=random_samples, + ) + + assert ninetoothed_output.shape == reference_output.shape + assert ninetoothed_output.dtype == reference_output.dtype + + assert torch.isfinite(ninetoothed_output).all(), ( + "ninetoothed_output contains inf or nan" + ) + assert torch.isfinite(reference_output).all(), ( + "reference_output contains inf or nan" + ) + + tolerance = _get_tolerance(dtype) + + torch.testing.assert_close( + ninetoothed_output, + reference_output, + atol=tolerance["atol"], + rtol=tolerance["rtol"], + ) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16)) +@pytest.mark.parametrize("kernel_size", (2, (3, 3))) +@pytest.mark.parametrize( + "output_size, output_ratio", + ( + ((50, 50), None), + (None, (0.5, 0.5)), + ), +) +@pytest.mark.parametrize("n, c, h, w", ((2, 3, 112, 112),)) +def test_fractional_max_pool2d_ratio_and_size( + n, + c, + h, + w, + kernel_size, + output_size, + output_ratio, + dtype, + device, +): + input = torch.randn( + (n, c, h, w), + dtype=dtype, + device=device, + ) + + random_samples = torch.rand( + (n, c, 2), + dtype=dtype, + device=input.device, + ) + + ninetoothed_output = ntops.torch.fractional_max_pool2d( + input, + kernel_size=kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=False, + _random_samples=random_samples, + ) + + reference_output = F.fractional_max_pool2d( + input, + kernel_size=kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=False, + _random_samples=random_samples, + ) + + assert ninetoothed_output.shape == reference_output.shape + assert ninetoothed_output.dtype == reference_output.dtype + + assert torch.isfinite(ninetoothed_output).all(), ( + "ninetoothed_output contains inf or nan" + ) + assert torch.isfinite(reference_output).all(), ( + "reference_output contains inf or nan" + ) + + tolerance = _get_tolerance(dtype) + + torch.testing.assert_close( + ninetoothed_output, + reference_output, + atol=tolerance["atol"], + rtol=tolerance["rtol"], + ) \ No newline at end of file diff --git a/tests/test_fractional_max_pool3d.py b/tests/test_fractional_max_pool3d.py new file mode 100644 index 0000000..3e20c27 --- /dev/null +++ b/tests/test_fractional_max_pool3d.py @@ -0,0 +1,200 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +_TEST_CASES_DATA = [ + ((2, 3, 9, 9, 9), None, (3, 3, 3), (4, 4, 4), False), + ((1, 4, 8, 10, 12), None, (2, 3, 2), (4, 4, 6), False), + ((2, 2, 7, 11, 5), (770, 110, 55, 5, 1), (3, 2, 3), (3, 4, 2), False), + ((3, 6, 5, 6, 7), None, (2, 2, 2), (3, 3, 4), False), + ((1, 8, 10, 10, 10), None, (4, 3, 2), (5, 4, 5), False), + ((2, 5, 12, 8, 6), None, (3, 3, 2), (4, 3, 2), False), +] + + +_TOLERANCE_MAP = { + torch.float16: {"atol": 1e-3, "rtol": 1e-2}, + torch.float32: {"atol": 1e-5, "rtol": 1e-4}, +} + + +@pytest.fixture +def device(): + assert torch.cuda.is_available() + return torch.empty(0).cuda().device + + +def _get_tolerance(dtype): + return _TOLERANCE_MAP[dtype] + + +def _make_input(shape, strides, dtype, device): + if strides is None: + return torch.randn( + shape, + dtype=dtype, + device=device, + ) + + storage_size = 1 + for size, stride in zip(shape, strides): + storage_size += (size - 1) * stride + + base = torch.randn( + (storage_size,), + dtype=dtype, + device=device, + ) + + return torch.as_strided( + base, + size=shape, + stride=strides, + ) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16)) +@pytest.mark.parametrize( + "in_shape, in_strides, kernel_size, output_size, return_indices", + _TEST_CASES_DATA, +) +def test_fractional_max_pool3d_cases( + in_shape, + in_strides, + kernel_size, + output_size, + return_indices, + dtype, + device, +): + assert not return_indices, "return_indices is not supported by ntops yet." + + input = _make_input( + in_shape, + in_strides, + dtype=dtype, + device=device, + ) + + n, c, d, h, w = in_shape + + random_samples = torch.rand( + (n, c, 3), + dtype=dtype, + device=input.device, + ) + + ninetoothed_output = ntops.torch.fractional_max_pool3d( + input, + kernel_size=kernel_size, + output_size=output_size, + output_ratio=None, + return_indices=return_indices, + _random_samples=random_samples, + ) + + reference_output = F.fractional_max_pool3d( + input, + kernel_size=kernel_size, + output_size=output_size, + output_ratio=None, + return_indices=return_indices, + _random_samples=random_samples, + ) + + assert ninetoothed_output.shape == reference_output.shape + assert ninetoothed_output.dtype == reference_output.dtype + + assert torch.isfinite(ninetoothed_output).all(), ( + "ninetoothed_output contains inf or nan" + ) + assert torch.isfinite(reference_output).all(), ( + "reference_output contains inf or nan" + ) + + tolerance = _get_tolerance(dtype) + + torch.testing.assert_close( + ninetoothed_output, + reference_output, + atol=tolerance["atol"], + rtol=tolerance["rtol"], + ) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16)) +@pytest.mark.parametrize("kernel_size", (2, (3, 3, 3))) +@pytest.mark.parametrize( + "output_size, output_ratio", + ( + ((6, 6, 6), None), + (None, (0.5, 0.5, 0.5)), + ), +) +@pytest.mark.parametrize("n, c, d, h, w", ((2, 3, 12, 12, 12),)) +def test_fractional_max_pool3d_ratio_and_size( + n, + c, + d, + h, + w, + kernel_size, + output_size, + output_ratio, + dtype, + device, +): + input = torch.randn( + (n, c, d, h, w), + dtype=dtype, + device=device, + ) + + random_samples = torch.rand( + (n, c, 3), + dtype=dtype, + device=input.device, + ) + + ninetoothed_output = ntops.torch.fractional_max_pool3d( + input, + kernel_size=kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=False, + _random_samples=random_samples, + ) + + reference_output = F.fractional_max_pool3d( + input, + kernel_size=kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=False, + _random_samples=random_samples, + ) + + assert ninetoothed_output.shape == reference_output.shape + assert ninetoothed_output.dtype == reference_output.dtype + + assert torch.isfinite(ninetoothed_output).all(), ( + "ninetoothed_output contains inf or nan" + ) + assert torch.isfinite(reference_output).all(), ( + "reference_output contains inf or nan" + ) + + tolerance = _get_tolerance(dtype) + + torch.testing.assert_close( + ninetoothed_output, + reference_output, + atol=tolerance["atol"], + rtol=tolerance["rtol"], + ) \ No newline at end of file diff --git a/tests/test_multilabel_margin_loss.py b/tests/test_multilabel_margin_loss.py new file mode 100644 index 0000000..1ce75e2 --- /dev/null +++ b/tests/test_multilabel_margin_loss.py @@ -0,0 +1,96 @@ +import random +import math + +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +def _make_multilabel_margin_target(shape, device): + c = shape[-1] + + outer = 1 + for s in shape[:-1]: + outer *= s + + target_2d = torch.full( + (outer, c), + -1, + dtype=torch.long, + device=device, + ) + + for i in range(outer): + num_pos = random.randint(0, c) + + if num_pos > 0: + labels = torch.randperm( + c, + device=device, + dtype=torch.long, + )[:num_pos] + + target_2d[i, :num_pos] = labels + + return target_2d.reshape(shape) + + +def _reference_multilabel_margin_loss(input, target, reduction): + c = input.shape[-1] + + input_2d = input.reshape(-1, c) + target_2d = target.reshape(-1, c) + + output = F.multilabel_margin_loss( + input_2d, + target_2d, + reduction=reduction, + ) + + if reduction == "none": + return output.reshape(input.shape[:-1]) + + return output + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_multilabel_margin_loss(shape, dtype, device, rtol, atol): + # 不 skip 高维,统一按最后一维为类别维 C + assert len(shape) >= 1 + + input = torch.randn( + shape, + dtype=dtype, + device=device, + ) + + target = _make_multilabel_margin_target( + shape, + device, + ) + + reduction = random.choice(("none", "mean", "sum")) + + ninetoothed_output = ntops.torch.multilabel_margin_loss( + input, + target, + reduction=reduction, + ) + + reference_output = _reference_multilabel_margin_loss( + input, + target, + reduction, + ) + + assert torch.allclose( + ninetoothed_output, + reference_output, + rtol=rtol, + atol=atol, + ) \ No newline at end of file diff --git a/tests/test_scatter_add.py b/tests/test_scatter_add.py new file mode 100644 index 0000000..5dce5d8 --- /dev/null +++ b/tests/test_scatter_add.py @@ -0,0 +1,58 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_scatter_add(shape, dtype, device, rtol, atol): + input = torch.randn( + shape, + dtype=dtype, + device=device, + ) + + src = torch.randn( + shape, + dtype=dtype, + device=device, + ) + + dim = random.randrange(len(shape)) + + if shape[dim] == 0: + pytest.skip("scatter_add dim size must be non-zero.") + + index = torch.randint( + 0, + shape[dim], + shape, + dtype=torch.long, + device=device, + ) + + ninetoothed_output = ntops.torch.scatter_add( + input, + dim, + index, + src, + ) + + reference_output = torch.scatter_add( + input, + dim, + index, + src, + ) + + assert torch.allclose( + ninetoothed_output, + reference_output, + rtol=rtol, + atol=atol, + ) \ No newline at end of file