10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
@@ -36,6 +36,11 @@
softmax,
sub,
tanh,
logsumexp,
lp_pool1d,
lp_pool2d,
lp_pool3d,
max,
)

__all__ = [
@@ -76,4 +81,9 @@
"softmax",
"sub",
"tanh",
"logsumexp",
"lp_pool1d",
"lp_pool2d",
"lp_pool3d",
"max",
]
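For context, each new kernel follows the repository's premake pattern: premake returns an (arrangement, application, tensors) triple that ninetoothed compiles into a launchable kernel. A minimal sketch of wiring up the new logsumexp, assuming ninetoothed.make accepts such a triple and that the compiled kernel is launched with arguments in the same order as tensors (the launch call is an assumption, not confirmed by this diff):

```python
# Hedged sketch, not the repository's confirmed launch API.
import ninetoothed
import torch

from ntops.kernels.logsumexp import premake

arrangement, application, tensors = premake(ndim=2, dim=1)
kernel = ninetoothed.make(arrangement, application, tensors)

x = torch.randn(4, 1024, device="cuda")
out = torch.empty(4, 1, device="cuda", dtype=x.dtype)
kernel(x, out)  # hypothetical launch; arguments mirror `tensors`
```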
45 changes: 45 additions & 0 deletions src/ntops/kernels/logsumexp.py
@@ -0,0 +1,45 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def _exp(x, dtype):
exp_dtype = dtype if dtype != ntl.float16 else ntl.float32
return ntl.cast(ntl.exp(ntl.cast(x, exp_dtype)), dtype)

def _log(x, dtype):
log_dtype = dtype if dtype != ntl.float16 else ntl.float32
return ntl.cast(ntl.log(ntl.cast(x, log_dtype)), dtype)

def application(input, output):
# input and output: (C // block_size,)
# input.dtype (inner block): (block_size,)
dtype = output.dtype.dtype
prev_max = ntl.cast(float("-inf"), dtype)
denominator = ntl.cast(0, dtype)

for i in range(input.shape[0]):
input_i = ntl.cast(input[i], dtype)
curr_max = ntl.cast(ntl.maximum(prev_max, ntl.max(input_i)), dtype)
input_max_diff_exp = _exp(input_i - curr_max, dtype)
prev_curr_max_diff_exp = _exp(prev_max - curr_max, dtype)
denominator = denominator * prev_curr_max_diff_exp + ntl.sum(input_max_diff_exp)
prev_max = curr_max

output[0] = prev_max + _log(denominator, dtype)


def premake(ndim, dim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

tensors = (
Tensor(
ndim, dtype=dtype, other=float("-inf"), shape_options={"constexpr": True}
),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
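The loop in application is the standard streaming logsumexp recurrence: whenever the running maximum grows, the accumulated denominator is rescaled by exp(prev_max - curr_max) before the new block's contribution is added. A NumPy sketch of the same recurrence, useful as a sanity check; the blocking here only mimics the (C // block_size, block_size) arrangement:

```python
# NumPy reference for the streaming recurrence in `application` above.
import numpy as np

def streaming_logsumexp(blocks):
    prev_max, denom = float("-inf"), 0.0
    for block in blocks:
        curr_max = max(prev_max, block.max())
        # Rescale the old denominator to the new max, then add this block.
        denom = denom * np.exp(prev_max - curr_max) + np.exp(block - curr_max).sum()
        prev_max = curr_max
    return prev_max + np.log(denom)

x = np.random.randn(1024)
ref = np.log(np.exp(x).sum())
assert np.isclose(streaming_logsumexp(np.split(x, 8)), ref)
```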
79 changes: 79 additions & 0 deletions src/ntops/kernels/lp_pool1d.py
@@ -0,0 +1,79 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor
from ninetoothed.language import libdevice


def arrangement(input, output, norm_type, kernel_size_val, kernel_size, stride, block_size, ceil_mode):
if block_size is None:
block_size = ninetoothed.block_size()

# input: (N, C, L_in) output: (N, C, L_out)

input_arranged = input.tile((1, 1, kernel_size), (1, 1, stride), floor_mode=not ceil_mode)
# => (N, C, L_out), dtype=(1, 1, k)
input_arranged = input_arranged.ravel()
# => (N, C, L_out, 1, 1, k)
input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1)
# => (N*C*L_out, k)
# Round k up to the nearest power of 2
nearest_pow2 = 1 << (kernel_size - 1).bit_length()
input_arranged = input_arranged.tile((1, nearest_pow2))
# => (..., k // nearest_pow2 = 1), dtype=(1, nearest_pow2)
input_arranged.dtype = input_arranged.dtype.squeeze(0)
# => (..., 1), dtype=(nearest_pow2, )
input_arranged = input_arranged.tile((block_size, -1))
# => (..., 1), dtype=(block_size, 1), dtype=(nearest_pow2, )
input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1)
# => (..., 1), dtype=(block_size, nearest_pow2)

output_arranged = output.tile((1, 1, 1))
# => (N, C, L_out), dtype=(1, 1, 1)
output_arranged = output_arranged.ravel()
# => (N, C, L_out, 1, 1, 1)
output_arranged = output_arranged.flatten(end_dim=3).flatten(start_dim=1)
# => (N*C*L_out, 1)
output_arranged = output_arranged.tile((block_size, -1))
# => (..., 1), dtype=(block_size, 1)
output_arranged.dtype = output_arranged.dtype.squeeze(1)
# => (..., 1), dtype=(block_size, )

return input_arranged, output_arranged, norm_type, kernel_size_val


def _pow(x, norm, dtype):
pow_dtype = dtype if dtype != ntl.float16 else ntl.float32
return ntl.cast(libdevice.pow(ntl.cast(x, pow_dtype), norm), dtype)

def application(input, output, norm_type, kernel_size):
# input: (block_size, nearest_pow2)
# output: (block_size)
dtype = input.dtype
mask = input < 1e20
cnt = ntl.sum(ntl.cast(mask, ntl.int32), axis=1)
input_masked = ntl.where(~mask, 0, input)
x_pow = _pow(input_masked, norm_type, dtype)
acc_sim = ntl.sum(x_pow, 1) / cnt * kernel_size
output = _pow(acc_sim, 1.0 / norm_type, dtype)


def premake(ndim, kernel_size, stride, ceil_mode=False, dtype=None, block_size=None):
arrangement_ = functools.partial(
arrangement,
kernel_size=kernel_size,
stride=stride,
block_size=block_size,
ceil_mode=ceil_mode,
)

tensors = (
Tensor(ndim, dtype=dtype, other=float("inf")), # input
Tensor(ndim, dtype=dtype), # output
Tensor(0, dtype=dtype), # norm_type
Tensor(0, dtype=dtype, constexpr=True), # kernel_size
)

return arrangement_, application, tensors
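Because the input tensor pads out-of-window slots with +inf (the other=float("inf") above), application can recover the valid-element count with a threshold mask and then apply the same (mean(window^p) * kernel_size)^(1/p) formula PyTorch uses. A NumPy sketch of that per-window computation; the 1e20 threshold mirrors the kernel's mask:

```python
# NumPy reference for one pooling window, padded to nearest_pow2 with +inf.
import numpy as np

def window_lp(window, norm_type, kernel_size):
    mask = window < 1e20                    # valid (non-padding) slots
    cnt = mask.sum()
    x_pow = np.where(mask, window, 0.0) ** norm_type
    acc = x_pow.sum() / cnt * kernel_size   # PyTorch-style mean * kernel_size
    return acc ** (1.0 / norm_type)

w = np.array([1.0, 2.0, 3.0, np.inf])      # kernel_size=3 padded to nearest_pow2=4
print(window_lp(w, 2.0, 3))                # sqrt(1 + 4 + 9) ≈ 3.742
```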
163 changes: 163 additions & 0 deletions src/ntops/kernels/lp_pool2d.py
@@ -0,0 +1,163 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor
from ninetoothed.language import libdevice


def _pow(x, norm, dtype):
pow_dtype = dtype if dtype != ntl.float16 else ntl.float32
return ntl.cast(libdevice.pow(ntl.cast(x, pow_dtype), norm), dtype)

def arrangement_ceil_mode(*tensors, kernel_size_h, kernel_size_w, stride_h, stride_w, block_size, ceil_mode):
input, output, norm_type, kernel_size_flatted = tensors
if block_size is None:
block_size = ninetoothed.block_size()

# input: (N, C, H_in, W_in) output: (N, C, H_out, W_out)
# cf. the max_pool2d arrangement in the examples

input_arranged = input.tile((1, 1, kernel_size_h, kernel_size_w), (1, 1, stride_h, stride_w), floor_mode=not ceil_mode)
# => (N, C, H_out, W_out), dtype=(1, 1, k_h, k_w)
input_arranged = input_arranged.ravel()
# => (N, C, H_out, W_out, 1, 1, k_h, k_w)
input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1)
# => (N*C*H_out*W_out, k_h*k_w)
# Round k_h * k_w up to the nearest power of 2
nearest_pow2 = 1 << (kernel_size_h * kernel_size_w - 1).bit_length()
input_arranged = input_arranged.tile((1, nearest_pow2))
# => (..., k_h*k_w // nearest_pow2 = 1), dtype=(1, nearest_pow2)
input_arranged.dtype = input_arranged.dtype.squeeze(0)
# => (..., 1), dtype=(nearest_pow2, )
input_arranged = input_arranged.tile((block_size, -1))
# => (..., 1), dtype=(block_size, 1), dtype=(nearest_pow2, )
input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1)
# => (..., 1), dtype=(block_size, nearest_pow2)

output_arranged = output.tile((1, 1, 1, 1))
# => (N, C, H_out, W_out), dtype=(1, 1, 1, 1)
output_arranged = output_arranged.ravel()
# => (N, C, H_out, W_out, 1, 1, 1, 1)
output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1)
# => (N*C*H_out*W_out, 1)
output_arranged = output_arranged.tile((block_size, -1))
# => (..., 1), dtype=(block_size, 1)
output_arranged.dtype = output_arranged.dtype.squeeze(1)
# => (..., 1), dtype=(block_size, )

return input_arranged, output_arranged, norm_type, kernel_size_flatted



def application_ceil_mode(input, output, norm_type, kernel_size_flatted):
# input: (block_size, nearest_pow2); the outermost level produced by the
# arrangement is what gets parallelized across programs
# output: (block_size, )
# PyTorch's implementation differs from its documentation: the docs describe
# sum(window^p)^(1/p), but the actual implementation computes
# (mean(window^p) * (kernel_size_h * kernel_size_w))^(1/p).
# The two agree when stride == kernel_size, but diverge when
# stride != kernel_size and ceil_mode=True, mainly in boundary handling:
# PyTorch's formula amplifies boundary values, because boundary windows
# contain fewer than kernel_size_h * kernel_size_w valid elements.
# Two implementations are given below.
# This is the zero-padding variant (to use it, change the input tensor's
# default `other` value to 0):
# dtype = input.dtype
# x_pow = _pow(input, norm_type, dtype)
# acc = ntl.sum(x_pow, axis=0)
# output = _pow(acc, 1.0 / norm_type, dtype)

# To match PyTorch (and pass the tests), the version below follows its actual implementation.
dtype = input.dtype
mask = input < 1e20
cnt = ntl.sum(ntl.cast(mask, ntl.int32), axis=1)
input_masked = ntl.where(~mask, 0, input)
x_pow = _pow(input_masked, norm_type, dtype)
acc_sim = ntl.sum(x_pow, 1) / cnt * kernel_size_flatted
output = _pow(acc_sim, 1.0 / norm_type, dtype)


def premake_ceil_mode(ndim, kernel_size_h, kernel_size_w, stride_h, stride_w, block_size=None, ceil_mode=False, dtype=None):
arrangement_ = functools.partial(
arrangement_ceil_mode,
kernel_size_h=kernel_size_h,
kernel_size_w=kernel_size_w,
stride_h=stride_h,
stride_w=stride_w,
block_size=block_size,
ceil_mode=ceil_mode,
)

tensors = (
Tensor(ndim, dtype=dtype, other=float("inf")), # input
Tensor(ndim, dtype=dtype), # output
Tensor(0, dtype=dtype), # norm_type
Tensor(0, dtype=dtype), # kernel_size_flatted
)

return arrangement_, application_ceil_mode, tensors



def arrangement(input, output, norm_type, kernel_size_h, kernel_size_w, stride_h, stride_w, block_size, ceil_mode):
if block_size is None:
block_size = ninetoothed.block_size()

# input: (N, C, H_in, W_in) output: (N, C, H_out, W_out)
# cf. the max_pool2d arrangement in the examples

input_arranged = input.tile((1, 1, kernel_size_h, kernel_size_w), (1, 1, stride_h, stride_w), floor_mode=not ceil_mode)
# => (N, C, H_out, W_out), dtype=(1, 1, k_h, k_w)
input_arranged = input_arranged.ravel()
# => (N, C, H_out, W_out, 1, 1, k_h, k_w)
input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1)
# => (N*C*H_out*W_out, k_h*k_w)
# Round k_h * k_w up to the nearest power of 2
nearest_pow2 = 1 << (kernel_size_h * kernel_size_w - 1).bit_length()
input_arranged = input_arranged.tile((1, nearest_pow2))
# => (..., k_h*k_w // nearest_pow2 = 1), dtype=(1, nearest_pow2)
input_arranged.dtype = input_arranged.dtype.squeeze(0)
# => (..., 1), dtype=(nearest_pow2, )
input_arranged = input_arranged.tile((block_size, -1))
# => (..., 1), dtype=(block_size, 1), dtype=(nearest_pow2, )
input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1)
# => (..., 1), dtype=(block_size, nearest_pow2)

output_arranged = output.tile((1, 1, 1, 1))
# => (N, C, H_out, W_out), dtype=(1, 1, 1, 1)
output_arranged = output_arranged.ravel()
# => (N, C, H_out, W_out, 1, 1, 1, 1)
output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1)
# => (N*C*H_out*W_out, 1)
output_arranged = output_arranged.tile((block_size, -1))
# => (..., 1), dtype=(block_size, 1)
output_arranged.dtype = output_arranged.dtype.squeeze(1)
# => (..., 1), dtype=(block_size, )

return input_arranged, output_arranged, norm_type

def application(input, output, norm_type):
# input: (block_size, nearest_pow2)
# output: (block_size, )
dtype = input.dtype
x_pow = _pow(input, norm_type, dtype)
acc = ntl.sum(x_pow, axis=1)
output = _pow(acc, 1.0 / norm_type, dtype)


def premake(ndim, kernel_size_h, kernel_size_w, stride_h, stride_w, block_size=None, ceil_mode=False, dtype=None):
arrangement_ = functools.partial(
arrangement,
kernel_size_h=kernel_size_h,
kernel_size_w=kernel_size_w,
stride_h=stride_h,
stride_w=stride_w,
block_size=block_size,
ceil_mode=ceil_mode,
)

tensors = (
Tensor(ndim, dtype=dtype, other=0), # input
Tensor(ndim, dtype=dtype), # output
Tensor(0, dtype=dtype), # norm_type
)

return arrangement_, application, tensors
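The comment in application_ceil_mode can be checked directly against PyTorch: with ceil_mode=True and stride != kernel_size, boundary windows hold fewer valid elements yet are still scaled by the full kernel size. A small demonstration, with shapes chosen so that ceil_mode adds a partial boundary window:

```python
# Demonstrates the docs-vs-implementation gap described above.
import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 6, 6)
out = F.lp_pool2d(x, norm_type=2.0, kernel_size=3, stride=2, ceil_mode=True)
print(out.shape)  # torch.Size([1, 1, 3, 3]); the last row/column windows are partial
print(out)
# Every entry is 3.0 == (mean(1) * 9) ** 0.5, even though the boundary
# windows only cover 6 (or 4) valid ones; the documented sum ** (1/p)
# would give sqrt(6) ≈ 2.449 (or 2.0) there instead.
```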