Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 124 additions & 5 deletions src/infiniop/ops/rearrange/nvidia/rearrange_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#define ARRAY_TYPE_SIZE size_t

// 与 DEFINE_KERNELS_BY_CONSTRAINT 耦合,需要同时修改
#define MAX_BLOCK_ARRAY_SIZE 5
#define MAX_GRID_ARRAY_SIZE 5
#define MAX_BLOCK_ARRAY_SIZE 6
#define MAX_GRID_ARRAY_SIZE 6

template <int ArrSize, typename ArrayType>
struct ArrayStruct {
Expand Down Expand Up @@ -182,35 +182,148 @@ struct Constraint {
DEFINE_REARRANGE_KERNEL(float1, constraint_num, block_array_size, grid_array_size) \
DEFINE_REARRANGE_KERNEL(float2, constraint_num, block_array_size, grid_array_size) \
DEFINE_REARRANGE_KERNEL(float4, constraint_num, block_array_size, grid_array_size) \
DEFINE_REARRANGE_KERNEL(double4, constraint_num, block_array_size, grid_array_size)
DEFINE_REARRANGE_KERNEL(double4_32a, constraint_num, block_array_size, grid_array_size)

// 与 MAX_BLOCK_ARRAY_SIZE 和 MAX_GRID_ARRAY_SIZE 耦合,需要同时修改
// 为1-5和1-5的所有组合生成内核
// 为1-6和1-6的所有组合生成内核
DEFINE_KERNELS_BY_CONSTRAINT(1, 1)
DEFINE_KERNELS_BY_CONSTRAINT(1, 2)
DEFINE_KERNELS_BY_CONSTRAINT(1, 3)
DEFINE_KERNELS_BY_CONSTRAINT(1, 4)
DEFINE_KERNELS_BY_CONSTRAINT(1, 5)
DEFINE_KERNELS_BY_CONSTRAINT(1, 6)
DEFINE_KERNELS_BY_CONSTRAINT(2, 1)
DEFINE_KERNELS_BY_CONSTRAINT(2, 2)
DEFINE_KERNELS_BY_CONSTRAINT(2, 3)
DEFINE_KERNELS_BY_CONSTRAINT(2, 4)
DEFINE_KERNELS_BY_CONSTRAINT(2, 5)
DEFINE_KERNELS_BY_CONSTRAINT(2, 6)
DEFINE_KERNELS_BY_CONSTRAINT(3, 1)
DEFINE_KERNELS_BY_CONSTRAINT(3, 2)
DEFINE_KERNELS_BY_CONSTRAINT(3, 3)
DEFINE_KERNELS_BY_CONSTRAINT(3, 4)
DEFINE_KERNELS_BY_CONSTRAINT(3, 5)
DEFINE_KERNELS_BY_CONSTRAINT(3, 6)
DEFINE_KERNELS_BY_CONSTRAINT(4, 1)
DEFINE_KERNELS_BY_CONSTRAINT(4, 2)
DEFINE_KERNELS_BY_CONSTRAINT(4, 3)
DEFINE_KERNELS_BY_CONSTRAINT(4, 4)
DEFINE_KERNELS_BY_CONSTRAINT(4, 5)
DEFINE_KERNELS_BY_CONSTRAINT(4, 6)
DEFINE_KERNELS_BY_CONSTRAINT(5, 1)
DEFINE_KERNELS_BY_CONSTRAINT(5, 2)
DEFINE_KERNELS_BY_CONSTRAINT(5, 3)
DEFINE_KERNELS_BY_CONSTRAINT(5, 4)
DEFINE_KERNELS_BY_CONSTRAINT(5, 5)
DEFINE_KERNELS_BY_CONSTRAINT(5, 6)
DEFINE_KERNELS_BY_CONSTRAINT(6, 1)
DEFINE_KERNELS_BY_CONSTRAINT(6, 2)
DEFINE_KERNELS_BY_CONSTRAINT(6, 3)
DEFINE_KERNELS_BY_CONSTRAINT(6, 4)
DEFINE_KERNELS_BY_CONSTRAINT(6, 5)
DEFINE_KERNELS_BY_CONSTRAINT(6, 6)

// ==============================================================================
// 动态Kernel - 支持任意维度的fallback实现
// ==============================================================================

template <typename Tmem_type>
__global__ void rearrange_dynamic_kernel(
void *__restrict__ dst,
const void *__restrict__ src,
const size_t block_dim,
const size_t block_len_total,
const ARRAY_TYPE_SIZE *block_len,
const ARRAY_TYPE_STRIDE *src_block_stride,
const ARRAY_TYPE_STRIDE *dst_block_stride,
const size_t grid_dim,
const ARRAY_TYPE_SIZE *grid_len,
const ARRAY_TYPE_STRIDE *src_grid_stride,
const ARRAY_TYPE_STRIDE *dst_grid_stride,
const size_t constraint_num,
const Constraint<ARRAY_TYPE_SIZE> *constraints) {

size_t thread_idx = threadIdx.x;
if (thread_idx >= block_len_total) {
return;
}

// 使用共享内存存储grid级别的偏移量
__shared__ ptrdiff_t shared_src_offset;
__shared__ ptrdiff_t shared_dst_offset;
__shared__ ARRAY_TYPE_SIZE shared_constraints_grid_idx_multiple[2]; // 最多支持2个约束

// 第0号线程计算grid偏移
if (threadIdx.x == 0) {
ptrdiff_t src_offset = 0;
ptrdiff_t dst_offset = 0;
size_t remaining = blockIdx.x;

// 初始化约束
for (size_t j = 0; j < constraint_num && j < 2; j++) {
shared_constraints_grid_idx_multiple[j] = 0;
}

// 计算grid维度的偏移
for (int i = grid_dim - 1; i >= 0; i--) {
size_t idx = remaining % grid_len[i];
remaining /= grid_len[i];
src_offset += idx * src_grid_stride[i];
dst_offset += idx * dst_grid_stride[i];

// 处理约束
for (size_t j = 0; j < constraint_num && j < 2; j++) {
if (i == constraints[j].grid_idx) {
shared_constraints_grid_idx_multiple[j] = idx * constraints[j].grid_div_block;
}
}
}

shared_src_offset = src_offset;
shared_dst_offset = dst_offset;
}

__syncthreads();

// 所有线程读取共享内存
ptrdiff_t src_offset = shared_src_offset;
ptrdiff_t dst_offset = shared_dst_offset;
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[2];
for (size_t j = 0; j < constraint_num && j < 2; j++) {
constraints_grid_idx_multiple[j] = shared_constraints_grid_idx_multiple[j];
}

// 计算block维度的偏移
size_t remaining = thread_idx;
for (int i = block_dim - 1; i >= 0; i--) {
size_t idx = remaining % block_len[i];
remaining /= block_len[i];

src_offset += idx * src_block_stride[i];
dst_offset += idx * dst_block_stride[i];

// 检查约束
for (size_t j = 0; j < constraint_num && j < 2; j++) {
if (i == constraints[j].block_idx) {
if (constraints_grid_idx_multiple[j] + idx >= constraints[j].total_len) {
return;
}
}
}
}

// 执行数据拷贝
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) =
*reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset);
}

// 为不同的数据类型生成动态kernel的模板实例
template __global__ void rearrange_dynamic_kernel<uchar1>(void *, const void *, const size_t, const size_t, const ARRAY_TYPE_SIZE *, const ARRAY_TYPE_STRIDE *, const ARRAY_TYPE_STRIDE *, const size_t, const ARRAY_TYPE_SIZE *, const ARRAY_TYPE_STRIDE *, const ARRAY_TYPE_STRIDE *, const size_t, const Constraint<ARRAY_TYPE_SIZE> *);
template __global__ void rearrange_dynamic_kernel<uchar2>(void *, const void *, const size_t, const size_t, const ARRAY_TYPE_SIZE *, const ARRAY_TYPE_STRIDE *, const ARRAY_TYPE_STRIDE *, const size_t, const ARRAY_TYPE_SIZE *, const ARRAY_TYPE_STRIDE *, const ARRAY_TYPE_STRIDE *, const size_t, const Constraint<ARRAY_TYPE_SIZE> *);
template __global__ void rearrange_dynamic_kernel<float1>(void *, const void *, const size_t, const size_t, const ARRAY_TYPE_SIZE *, const ARRAY_TYPE_STRIDE *, const ARRAY_TYPE_STRIDE *, const size_t, const ARRAY_TYPE_SIZE *, const ARRAY_TYPE_STRIDE *, const ARRAY_TYPE_STRIDE *, const size_t, const Constraint<ARRAY_TYPE_SIZE> *);
template __global__ void rearrange_dynamic_kernel<float2>(void *, const void *, const size_t, const size_t, const ARRAY_TYPE_SIZE *, const ARRAY_TYPE_STRIDE *, const ARRAY_TYPE_STRIDE *, const size_t, const ARRAY_TYPE_SIZE *, const ARRAY_TYPE_STRIDE *, const ARRAY_TYPE_STRIDE *, const size_t, const Constraint<ARRAY_TYPE_SIZE> *);
template __global__ void rearrange_dynamic_kernel<float4>(void *, const void *, const size_t, const size_t, const ARRAY_TYPE_SIZE *, const ARRAY_TYPE_STRIDE *, const ARRAY_TYPE_STRIDE *, const size_t, const ARRAY_TYPE_SIZE *, const ARRAY_TYPE_STRIDE *, const ARRAY_TYPE_STRIDE *, const size_t, const Constraint<ARRAY_TYPE_SIZE> *);
template __global__ void rearrange_dynamic_kernel<double4_32a>(void *, const void *, const size_t, const size_t, const ARRAY_TYPE_SIZE *, const ARRAY_TYPE_STRIDE *, const ARRAY_TYPE_STRIDE *, const size_t, const ARRAY_TYPE_SIZE *, const ARRAY_TYPE_STRIDE *, const ARRAY_TYPE_STRIDE *, const size_t, const Constraint<ARRAY_TYPE_SIZE> *);

// 准备参数结构体
struct RearrangeParams {
Expand Down Expand Up @@ -258,7 +371,7 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
GET_REARRANGE_KERNEL(float4, block_array_size, grid_array_size, constraint_num); \
break; \
case 32: \
GET_REARRANGE_KERNEL(double4, block_array_size, grid_array_size, constraint_num); \
GET_REARRANGE_KERNEL(double4_32a, block_array_size, grid_array_size, constraint_num); \
break; \
default: \
return INFINI_STATUS_BAD_PARAM; \
Expand Down Expand Up @@ -294,6 +407,9 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
case 5: \
GET_REARRANGE_KERNEL_BY_CONSTRAINT(block_array_size, 5); \
break; \
case 6: \
GET_REARRANGE_KERNEL_BY_CONSTRAINT(block_array_size, 6); \
break; \
}

#define GET_REARRANGE_KERNEL_BY_BLOCK_NUM \
Expand All @@ -313,6 +429,9 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
case 5: \
GET_REARRANGE_KERNEL_BY_GRID_NUM(5); \
break; \
case 6: \
GET_REARRANGE_KERNEL_BY_GRID_NUM(6); \
break; \
}

GET_REARRANGE_KERNEL_BY_BLOCK_NUM
Expand Down
Loading