10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
@@ -36,6 +36,11 @@
softmax,
sub,
tanh,
bitwise_left_shift,
index_select,
fold,
mish,
log2,
)

__all__ = [
@@ -76,4 +81,9 @@
"softmax",
"sub",
"tanh",
"bitwise_left_shift",
"index_select",
"fold",
"mish",
"log2",
]
33 changes: 33 additions & 0 deletions src/ntops/kernels/bitwise_left_shift.py
@@ -0,0 +1,33 @@
import functools

from ninetoothed import Tensor
import ninetoothed.language as ntl

from ntops.kernels.element_wise import arrangement


def application(input, other, output):
if input.dtype == ntl.int32:
mask = (other > 31) | (other < 0)
elif input.dtype == ntl.int64:
mask = (other > 63) | (other < 0)
elif input.dtype == ntl.uint8:
mask = (other > 7) | (other < 0)
else:
mask = ntl.zeros_like(other)

shift = ntl.where(mask, ntl.zeros_like(other), other)
input = ntl.where(mask, ntl.zeros_like(input), input)
output = input << shift


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
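For reference, the masking in `application` forces the result to 0 whenever the shift amount is negative or exceeds the largest valid shift for the element type. A minimal eager-mode sketch of the same behaviour (the helper name `masked_left_shift_reference` is hypothetical, not part of this change):

```python
import torch


def masked_left_shift_reference(input, other):
    # Hypothetical sketch, not part of this change: shift amounts outside
    # [0, max_shift] zero out both the value and the shift, so the result is 0.
    max_shift = {torch.int32: 31, torch.int64: 63, torch.uint8: 7}.get(input.dtype)
    if max_shift is None:
        return input << other
    mask = (other > max_shift) | (other < 0)
    shift = torch.where(mask, torch.zeros_like(other), other)
    value = torch.where(mask, torch.zeros_like(input), input)
    return value << shift
```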
42 changes: 42 additions & 0 deletions src/ntops/kernels/fold.py
@@ -0,0 +1,42 @@
import functools

from ninetoothed import Tensor
import ninetoothed.language as ntl

def arrangement(*tensors, L_pow2, kernel_size_h, kernel_size_w, stride_h, stride_w, dilation_h, dilation_w, padding_h, padding_w, block_size=None):
# input: (N, C * k_h * k_w, H_out * W_out)
# output: (N, C, H_in, W_in)
input, output, L_val = tensors

# Arrange output so that it lines up element for element with input
output = output.tile((1, 1, kernel_size_h, kernel_size_w), (1, 1, stride_h, stride_w), (1, 1, dilation_h, dilation_w))
# => output: (N, C, H_out, W_out), dtype=(1, 1, k_h, k_w)
output = output.ravel() # => output: (N, C, H_out, W_out, 1, 1, k_h, k_w)
output = output.permute((0, 1, 4, 5, 6, 7, 2, 3))
# => output: (N, C, 1, 1, k_h, k_w, H_out, W_out)
output = output.flatten(start_dim=0, end_dim=6).flatten(start_dim=1)
# => output: (N * C * k_h * k_w, H_out * W_out)
output = output.tile((block_size, L_pow2)).squeeze(1)
# => output: (... // block_size, ), dtype=(block_size, L_pow2)

input = input.flatten(end_dim=2) # => input: (N * C * k_h * k_w, H_out * W_out)
input = input.tile((block_size, L_pow2)).squeeze(1)
# => input: (... // block_size), dtype=(block_size, L_pow2)

return input, output, L_val

def application(input, output, L):
# input: (block_size, L_pow2)
# output: (block_size, L_pow2)
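# Overlapping sliding windows can target the same output element, so accumulate with an atomic add.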
ntl.atomic_add(output.data_ptr() + output.offsets(), input)

def premake(L_pow2, kernel_size_h, kernel_size_w, stride_h, stride_w, dilation_h, dilation_w, padding_h, padding_w, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, L_pow2=L_pow2, kernel_size_h=kernel_size_h, kernel_size_w=kernel_size_w, stride_h=stride_h, stride_w=stride_w, dilation_h=dilation_h, dilation_w=dilation_w, padding_h=padding_h, padding_w=padding_w, block_size=block_size)

tensors = (
Tensor(3, dtype=dtype, other=0, shape_options={'constexpr': True}),
Tensor(4, dtype=dtype, other=0, shape_options={'constexpr': True}),
Tensor(0, dtype=int, constexpr=True), # L
)

return arrangement_, application, tensors
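The atomic accumulation is needed because overlapping sliding windows contribute to the same output pixel, and fold sums all of those contributions. A small sanity check of that semantics against the PyTorch reference (hypothetical snippet, not part of this change), using the standard fold/unfold coverage identity:

```python
import torch
import torch.nn.functional as F

# Folding an unfolded tensor adds each pixel once per sliding window covering it,
# so fold(unfold(x)) equals x times the per-pixel coverage count.
x = torch.randn(2, 3, 8, 8)
cols = F.unfold(x, kernel_size=3, padding=1)  # (N, C * 3 * 3, L)
folded = F.fold(cols, output_size=(8, 8), kernel_size=3, padding=1)
coverage = F.fold(
    F.unfold(torch.ones_like(x), kernel_size=3, padding=1),
    output_size=(8, 8), kernel_size=3, padding=1,
)
assert torch.allclose(folded, x * coverage)
```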
83 changes: 83 additions & 0 deletions src/ntops/kernels/index_select.py
@@ -0,0 +1,83 @@

import functools

from ninetoothed import Tensor
import ninetoothed.language as ntl

def arrangement(input, output, index, T, S, T_pow2, S_pow2, dim, block_size=None):
non_target_dim = tuple(i for i in range(input.ndim) if i != dim)
input = input.permute(non_target_dim + (dim,))
input = input.flatten(end_dim=-1) # shape: (..., T)

output = output.permute(non_target_dim + (dim,))
output = output.flatten(end_dim=-1) # shape: (..., S)

# input: (..., T)
# output: (..., S)
# index: (S,)
input_tiled = input.tile((block_size, T_pow2)).squeeze(1) # shape: (..., ), dtype=(block_size, T_pow2)
output_tiled = output.tile((block_size, S_pow2)).squeeze(1) # shape: (..., ), dtype=(block_size, S_pow2)

index_expand = index.unsqueeze(0).expand((input_tiled.shape[0], -1)) # shape: (..., S)
index_expand = index_expand.tile((1, S_pow2)).squeeze(1) # shape: (..., ), dtype=(1, S_pow2)

return input_tiled, output_tiled, index_expand, T, S

# def application(input, output, index):
# # input: (block_size, T)
# # output: (block_size, S)
# # index: (1, S)
# # index_select implemented with gather.
# # Triton 3.0.0 does not support the gather operation, so it cannot be used on Moore Threads GPUs.
# # Kept here for reference only.
# index_expand = ntl.broadcast_to(index, (input.shape[0], index.shape[1]))
# # index_expand: (block_size, S)
# output = ntl.gather(input, index, axis=1)

def application(input, output, index, T, S):
# input: (block_size, T_pow2)
# output: (block_size, S_pow2)
# index: (1, S_pow2)

# T_pow2 satisfies arange's power-of-two size requirement.
col_indices = ntl.arange(0, input.shape[1]) # shape: (T_pow2,)

# Add dimensions and broadcast to (block_size, S, T_pow2).
col_indices = ntl.expand_dims(col_indices, 0) # shape: (1, T_pow2)
col_indices = ntl.expand_dims(col_indices, 0) # shape: (1, 1, T_pow2)
col_indices = ntl.broadcast_to(col_indices, (input.shape[0], output.shape[1], input.shape[1]))

# Expand input to (block_size, S, T_pow2).
input_expanded = ntl.expand_dims(input, 1) # shape: (block_size, 1, T_pow2)
input_expanded = ntl.broadcast_to(input_expanded, (input.shape[0], output.shape[1], input.shape[1]))

# Expand index to (block_size, S, T_pow2).
index_expanded = ntl.expand_dims(index, 2) # shape: (block_size, S, 1)
index_expanded = ntl.broadcast_to(index_expanded, (input.shape[0], output.shape[1], input.shape[1]))

# Match only within the valid column range; columns at or beyond the original T are masked out.
col_valid = col_indices < T
match_mask = (col_indices == index_expanded)
mask = ntl.where(col_valid, match_mask, False)

# Use where to pick out the matching values.
selected = ntl.where(mask, input_expanded, 0.0) # shape: (block_size, S, T_pow2)

# Sum over the last dimension to obtain the result.
result = ntl.sum(selected, axis=2) # shape: (block_size, S)

# Write the result back to the output.
output = result

def premake(ndim, dim, T_pow2, S_pow2, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, dim=dim, T_pow2=T_pow2, S_pow2=S_pow2, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype, other=0, shape_options={'constexpr': True}),
Tensor(ndim, dtype=dtype, other=0, shape_options={'constexpr': True}),
Tensor(1, dtype=int, shape_options={'constexpr': True}),
Tensor(0, dtype=int, constexpr=True), # T
Tensor(0, dtype=int, constexpr=True), # S
)

return arrangement_, application, tensors
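Since a gather primitive is unavailable here, `application` selects columns with a broadcast one-hot match followed by a sum reduction. A minimal eager-mode sketch of that trick (the helper name `one_hot_select_reference` is hypothetical, not part of this change); the result equals `torch.index_select(block, 1, index)`:

```python
import torch


def one_hot_select_reference(block, index):
    # Hypothetical sketch, not part of this change: emulate a gather along dim 1
    # with a one-hot column match and a sum reduction.
    # block: (block_size, T); index: (S,) with values in [0, T).
    cols = torch.arange(block.shape[1])                  # (T,)
    match = cols[None, None, :] == index[None, :, None]  # (1, S, T)
    selected = torch.where(match, block[:, None, :], torch.zeros_like(block)[:, None, :])
    return selected.sum(dim=2)                           # (block_size, S)
```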
20 changes: 20 additions & 0 deletions src/ntops/kernels/log2.py
@@ -0,0 +1,20 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
dtype = input.dtype
log2_dtype = dtype if dtype != ntl.float16 else ntl.float32
output = ntl.cast(ntl.log2(ntl.cast(input, log2_dtype)), dtype)


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

return arrangement_, application, tensors
37 changes: 37 additions & 0 deletions src/ntops/kernels/mish.py
@@ -0,0 +1,37 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def _softplus(x):
return ntl.log(ntl.exp(-ntl.abs(x)) + 1.0) + ntl.maximum(x, 0.0)


def _tanh(x):
# Use exp(-2 * x): safe for the non-negative softplus outputs, where exp(2 * x) could overflow.
exp_neg_2x = ntl.exp(-2 * x)
return (1 - exp_neg_2x) / (1 + exp_neg_2x)


def application(input, output):
dtype = input.dtype
if dtype == ntl.float16:
mish_dtype = ntl.float32
elif dtype == ntl.bfloat16:
mish_dtype = ntl.float32
else:
mish_dtype = dtype

input_f32 = ntl.cast(input, mish_dtype)
output_softplus_f32 = _softplus(input_f32)
output_f32 = _tanh(output_softplus_f32)
output = ntl.cast(output_f32 * input_f32, dtype)


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

return arrangement_, application, tensors
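`application` evaluates mish(x) = x * tanh(softplus(x)), upcasting half-precision inputs to float32. A quick check of the formula against the PyTorch reference, using the same overflow-safe softplus form (hypothetical snippet, not part of this change):

```python
import torch
import torch.nn.functional as F

# mish(x) = x * tanh(softplus(x)), with softplus written in the same
# overflow-safe form as _softplus above.
x = torch.randn(1024)
softplus = torch.log(torch.exp(-x.abs()) + 1.0) + torch.clamp(x, min=0.0)
assert torch.allclose(x * torch.tanh(softplus), F.mish(x), atol=1e-6)
```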
10 changes: 10 additions & 0 deletions src/ntops/torch/__init__.py
@@ -36,6 +36,11 @@
from ntops.torch.softmax import softmax
from ntops.torch.sub import sub
from ntops.torch.tanh import tanh
from ntops.torch.bitwise_left_shift import bitwise_left_shift
from ntops.torch.index_select import index_select
from ntops.torch.fold import fold
from ntops.torch.mish import mish
from ntops.torch.log2 import log2

__all__ = [
"abs",
@@ -76,4 +81,9 @@
"softmax",
"sub",
"tanh",
"bitwise_left_shift",
"index_select",
"fold",
"mish",
"log2",
]
33 changes: 33 additions & 0 deletions src/ntops/torch/bitwise_left_shift.py
@@ -0,0 +1,33 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def bitwise_left_shift(input, other, *, out=None):
# Check if we need to handle non-contiguous inplace operation
is_inplace_input = out is not None and out.data_ptr() == input.data_ptr()

if out is None:
out = torch.empty_like(input)

# Special case: an in-place operation on a non-contiguous tensor. When out and input
# are the same tensor (in-place) and input has non-standard strides (non-contiguous),
# the flatten() used by element_wise.arrangement in the ninetoothed framework loses
# the memory layout, so the GPU kernel cannot write results back through those strides.
# Work around this by computing on contiguous copies and then copying the result back
# with copy_(), which honours the target tensor's strides and writes the data to the
# correct memory locations.
if is_inplace_input and not input.is_contiguous():
input_contig = input.contiguous()
other_contig = other.contiguous() if not other.is_contiguous() else other
out_contig = torch.empty_like(input_contig)

kernel = _cached_make(ntops.kernels.bitwise_left_shift.premake, input.ndim)
kernel(input_contig, other_contig, out_contig)

out.copy_(out_contig)
else:
kernel = _cached_make(ntops.kernels.bitwise_left_shift.premake, input.ndim)
kernel(input, other, out)

return out
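A usage sketch of the path the `copy_()` workaround covers, i.e. an in-place call where the input is a non-contiguous view (hypothetical example, not part of this change; assumes a CUDA device):

```python
import torch

from ntops.torch import bitwise_left_shift

# `out is input` and the input is a transposed (hence non-contiguous) view,
# so the contiguous-copy path is taken and copy_() writes the result back.
a = torch.arange(16, dtype=torch.int32, device="cuda").reshape(4, 4).t()
b = torch.full_like(a, 2)
bitwise_left_shift(a, b, out=a)
```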
88 changes: 88 additions & 0 deletions src/ntops/torch/fold.py
@@ -0,0 +1,88 @@
import torch

import ntops
from ntops.torch.utils import _cached_make

def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if isinstance(dilation, int):
dilation = (dilation, dilation)
if isinstance(padding, int):
padding = (padding, padding)
if isinstance(stride, int):
stride = (stride, stride)
if isinstance(output_size, int):
output_size = (output_size, output_size)

# Remember whether the input was given unbatched (2-D)
input_was_2d = input.ndim == 2
if input_was_2d:
input = input.view((1, input.shape[0], input.shape[1]))

N, Ckk, L = input.shape
H_out, W_out = output_size
K_h, K_w = kernel_size
D_h, D_w = dilation
P_h, P_w = padding
S_h, S_w = stride

# Validate the channel dimension and compute the expected number of sliding positions L
C = Ckk // (K_h * K_w)
if C * K_h * K_w != Ckk:
raise ValueError(f"Input channel dimension {Ckk} is not divisible by kernel size product {K_h * K_w}")

L_h = (H_out + 2 * P_h - (D_h * (K_h - 1) + 1)) // S_h + 1
L_w = (W_out + 2 * P_w - (D_w * (K_w - 1) + 1)) // S_w + 1
if L != L_h * L_w:
raise ValueError(f"Input L {L} != computed L_h*L_w {L_h * L_w}")

# Allocate the padded output tensor; it must start at zero because the kernel
# accumulates into it with atomic adds.
out_padded_h = H_out + 2 * P_h
out_padded_w = W_out + 2 * P_w
out = torch.zeros(
(N, C, out_padded_h, out_padded_w),
dtype=input.dtype,
device=input.device,
)

# Build and launch the kernel; L is rounded up to the next power of two for the tile size.
block_size = 128
L_pow2 = 1 << (L - 1).bit_length()
kernel = _cached_make(
ntops.kernels.fold.premake,
L_pow2,
kernel_size[0],
kernel_size[1],
stride[0],
stride[1],
dilation[0],
dilation[1],
padding[0],
padding[1],
dtype=input.dtype,
block_size=block_size
)
kernel(input, out, L)

# Strip the padding
result = out
if P_h > 0 or P_w > 0:
# Direct slicing is not supported here for now, so use narrow instead
result = torch.narrow(result, 2, P_h, H_out)
result = torch.narrow(result, 3, P_w, W_out)

# In-place removal of the padding is hard to express under the ninetoothed framework,
# so copy the result into a fresh tensor to guarantee a contiguous layout.
output = torch.empty(
(N, C, H_out, W_out),
dtype=input.dtype,
device=input.device)
output.copy_(result)

if input_was_2d:
output = output.view((output.shape[1], output.shape[2], output.shape[3]))

return output
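A usage sketch comparing the wrapper with the PyTorch reference for a padded, strided case (hypothetical example, not part of this change; assumes a CUDA device):

```python
import torch
import torch.nn.functional as F

from ntops.torch import fold

# (N, C * k_h * k_w, L) columns folded back to (N, C, H_out, W_out); with
# output_size=(8, 8), kernel_size=2, stride=2, padding=1, L = 5 * 5 = 25.
cols = torch.randn(2, 3 * 2 * 2, 25, device="cuda")
result = fold(cols, output_size=(8, 8), kernel_size=2, stride=2, padding=1)
expected = F.fold(cols, output_size=(8, 8), kernel_size=2, stride=2, padding=1)
torch.testing.assert_close(result, expected)
```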