2 changes: 2 additions & 0 deletions src/infiniop/devices/nvidia/nvidia_kernel_common.cuh
@@ -9,7 +9,9 @@

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifndef ENABLE_HYGON_API
#include <cuda_fp8.h>
#endif

// Possible maximum number of threads per block for CUDA architectures
// Used for picking correct kernel launch configuration
14 changes: 13 additions & 1 deletion src/infiniop/ops/add/operator.cc
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/add_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#include "nvidia/add_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
@@ -51,6 +51,9 @@ __C infiniStatus_t infiniopCreateAddDescriptor(
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
@@ -91,6 +94,9 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
@@ -139,6 +145,9 @@ __C infiniStatus_t infiniopAdd(
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
@@ -181,6 +190,9 @@ infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
DELETE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
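Reviewer note: each `CREATE`/`GET`/`CALCULATE`/`DELETE` macro invocation above adds one `case` to a switch on the device type and forwards to the named backend, so the four new `ENABLE_HYGON_API` blocks simply route `INFINI_DEVICE_HYGON` onto the existing `nvidia` code path. A self-contained mock of that dispatch shape (all names below are illustrative stand-ins, not the actual InfiniOp macros or API):

```cpp
// Illustrative mock only; the real CREATE macro in operator.cc differs.
#include <cstdio>

enum DeviceType { DEVICE_NVIDIA, DEVICE_HYGON, DEVICE_CPU };

static int create_add_nvidia() { std::printf("nvidia backend\n"); return 0; }
static int create_add_cpu() { std::printf("cpu backend\n"); return 0; }

// One case per device; the second argument names the backend to forward to.
#define CREATE(CASE, BACKEND) \
    case CASE:                \
        return create_add_##BACKEND()

static int create_add(DeviceType device) {
    switch (device) {
        CREATE(DEVICE_NVIDIA, nvidia);
        CREATE(DEVICE_HYGON, nvidia); // Hygon reuses the NVIDIA path via HIP
        CREATE(DEVICE_CPU, cpu);
    default:
        return -1; // unsupported device type
    }
}

int main() { return create_add(DEVICE_HYGON); } // prints "nvidia backend"
```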
3 changes: 3 additions & 0 deletions src/infiniop/ops/logsoftmax/cuda/kernel.cuh
@@ -56,6 +56,9 @@ __device__ void logSoftmaxKernel(
}
#if CUDART_VERSION >= 12090
max_val = BlockReduce(temp_storage).Reduce(max_val, ::cuda::maximum());
#elif defined(ENABLE_HYGON_API)
max_val = BlockReduce(temp_storage).Reduce(
max_val, [](const float &a, const float &b) { return (a > b) ? a : b; }, BLOCK_SIZE);
#else
max_val = BlockReduce(temp_storage).Reduce(max_val, cub::Max());
#endif
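The same three-way guard — `::cuda::maximum()` on CUDA 12.9+, an explicit comparison lambda on Hygon, `cub::Max()` otherwise — recurs in the lp_norm and topksoftmax kernels further down. A minimal, standalone sketch of the Hygon branch, assuming `<cub/cub.cuh>` (or the hipCUB equivalent) is available on the target toolchain; the demo harness itself is not part of this PR:

```cuda
// Block-wide max with an explicit comparison functor, mirroring the
// ENABLE_HYGON_API branch above. Names and sizes here are illustrative.
#include <cub/cub.cuh>
#include <cmath>
#include <cstdio>

template <int BLOCK_SIZE>
__global__ void block_max_demo(const float *in, float *out, int n) {
    using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    int i = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    float v = (i < n) ? in[i] : -INFINITY;

    // Lambda instead of cub::Max() / ::cuda::maximum(); the third argument
    // is the number of threads holding valid inputs.
    float block_max = BlockReduce(temp_storage).Reduce(
        v, [](const float &a, const float &b) { return (a > b) ? a : b; }, BLOCK_SIZE);

    if (threadIdx.x == 0) {
        out[blockIdx.x] = block_max; // only thread 0 receives the valid aggregate
    }
}

int main() {
    constexpr int N = 256;
    float h_in[N], h_out = 0.0f;
    for (int i = 0; i < N; ++i) {
        h_in[i] = static_cast<float>(i);
    }
    float *d_in = nullptr, *d_out = nullptr;
    cudaMalloc(&d_in, N * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, N * sizeof(float), cudaMemcpyHostToDevice);
    block_max_demo<N><<<1, N>>>(d_in, d_out, N);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("block max = %.1f (expected 255.0)\n", h_out);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```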
6 changes: 6 additions & 0 deletions src/infiniop/ops/lp_norm/cuda/kernel.cuh
@@ -19,6 +19,9 @@ __device__ void blockLPNormKernel(
__shared__ float global_max;
#if CUDART_VERSION >= 12090
float max_block = BlockReduce(temp_storage).Reduce(local_max, ::cuda::maximum());
#elif defined(ENABLE_HYGON_API)
float max_block = BlockReduce(temp_storage).Reduce(
local_max, [](const float &a, const float &b) { return (a > b) ? a : b; }, BLOCK_SIZE);
#else
float max_block = BlockReduce(temp_storage).Reduce(local_max, cub::Max());
#endif
@@ -75,6 +78,9 @@ __device__ void blockLPNormStridesKernel(
__shared__ float global_max;
#if CUDART_VERSION >= 12090
float max_block = BlockReduce(temp_storage).Reduce(local_max, ::cuda::maximum());
#elif defined(ENABLE_HYGON_API)
float max_block = BlockReduce(temp_storage).Reduce(
local_max, [](const float &a, const float &b) { return (a > b) ? a : b; }, BLOCK_SIZE);
#else
float max_block = BlockReduce(temp_storage).Reduce(local_max, cub::Max());
#endif
2 changes: 2 additions & 0 deletions src/infiniop/ops/ones/cuda/kernel.cuh
@@ -27,8 +27,10 @@ public:
return 1;
} else if constexpr (std::is_same_v<T, uint64_t>) { // 10
return 1;
#ifndef ENABLE_HYGON_API
} else if constexpr (std::is_same_v<T, cuda_fp8_e4m3>) { // 11
return cuda_fp8_e4m3(1.0f);
#endif
} else if constexpr (std::is_same_v<T, half>) { // 12
return __float2half(1.0f);
} else if constexpr (std::is_same_v<T, float>) { // 13
2 changes: 2 additions & 0 deletions src/infiniop/ops/ones/nvidia/ones_nvidia.cu
@@ -79,8 +79,10 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<256, cuda::OnesOp, uint32_t>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_U64: // 10
return _device_info->calculate<256, cuda::OnesOp, uint64_t>(_info, workspace, output, inputs, stream);
#ifndef ENABLE_HYGON_API
case INFINI_DTYPE_F8: // 11
return _device_info->calculate<256, cuda::OnesOp, cuda_fp8_e4m3>(_info, workspace, output, inputs, stream);
#endif
case INFINI_DTYPE_F16: // 12
return _device_info->calculate<256, cuda::OnesOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32: // 13
16 changes: 16 additions & 0 deletions src/infiniop/ops/topkrouter/cuda/kernel.cuh
@@ -27,6 +27,15 @@ struct CustomLess {
}
};

// Warp-level sum reduction for Hygon platform
template <int warp_threads>
__inline__ __device__ float WarpSum(float val) {
for (int mask = warp_threads / 2; mask > 0; mask /= 2) {
val += __shfl_xor_sync(0xffffffff, val, mask);
}
return val;
}

template <typename T, int BLOCK_THREADS = 256>
__global__ void topkrouter_kernel(float *values_topk, // output data, shape [N, topk]
int *indices_topk, // output indices, shape [N, topk]
@@ -137,12 +146,19 @@ __global__ void topkrouter_kernel(float *values_topk, // output data
value = sigmoid_func(data_input[index]);
}
{
#ifdef ENABLE_HYGON_API
float warp_sum = WarpSum<warp_threads>(value);
if (0 == tid) {
share_sum = warp_sum + 1e-9f;
}
#else
typedef cub::WarpReduce<float, warp_threads> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage;
float warp_sum = WarpReduce(temp_storage).Sum(value);
if (0 == tid) {
share_sum = warp_sum + 1e-9f;
}
#endif
}
__syncwarp();

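The shuffle-based `WarpSum` added here (and repeated in the topksoftmax kernel below) is a butterfly reduction: after log2(`warp_threads`) `__shfl_xor_sync` steps every lane holds the full warp sum, and only lane 0 publishes it to shared memory. Unlike `cub::WarpReduce`, it needs no `TempStorage`, which presumably is why it was chosen for the Hygon path. A standalone sanity check of the pattern, with a hypothetical test harness that is not part of this PR:

```cuda
// Minimal check of the shuffle-based warp sum used in the kernels above.
#include <cstdio>
#include <cuda_runtime.h>

template <int warp_threads>
__inline__ __device__ float WarpSum(float val) {
    // Butterfly reduction: each step folds in the value of the XOR-partner lane.
    for (int mask = warp_threads / 2; mask > 0; mask /= 2) {
        val += __shfl_xor_sync(0xffffffff, val, mask);
    }
    return val;
}

__global__ void warp_sum_demo(float *out) {
    float v = static_cast<float>(threadIdx.x); // lanes contribute 0..31
    float s = WarpSum<32>(v);                  // every lane now holds 0+1+...+31 = 496
    if (threadIdx.x == 0) {
        *out = s;
    }
}

int main() {
    float *d_out = nullptr, h_out = 0.0f;
    cudaMalloc(&d_out, sizeof(float));
    warp_sum_demo<<<1, 32>>>(d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("warp sum = %.1f (expected 496.0)\n", h_out);
    cudaFree(d_out);
    return 0;
}
```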
19 changes: 19 additions & 0 deletions src/infiniop/ops/topksoftmax/cuda/kernel.cuh
@@ -19,6 +19,15 @@ inline __device__ float exp_func(T x) {
return __expf(data);
}

// Warp-level sum reduction for Hygon platform
template <int warp_threads>
__inline__ __device__ float WarpSum(float val) {
for (int mask = warp_threads / 2; mask > 0; mask /= 2) {
val += __shfl_xor_sync(0xffffffff, val, mask);
}
return val;
}

template <typename T, int BLOCK_SIZE = 128>
__global__ void softmax_topk_row_kernel(float *values_topk, // output data, shape [N, topk]
int *indices_topk, // output indices, shape [N, topk]
@@ -57,6 +66,9 @@ __global__ void softmax_topk_row_kernel(float *values_topk, // output data, shape
__shared__ typename BlockReduce::TempStorage temp_storage_max;
#if CUDART_VERSION >= 12090
T value_max = BlockReduce(temp_storage_max).Reduce(thread_max, ::cuda::maximum());
#elif defined(ENABLE_HYGON_API)
T value_max = BlockReduce(temp_storage_max).Reduce(
thread_max, [](const T &a, const T &b) { return (a > b) ? a : b; }, BLOCK_SIZE);
#else
T value_max = BlockReduce(temp_storage_max).Reduce(thread_max, cub::Max());
#endif
@@ -117,12 +129,19 @@ __global__ void softmax_topk_row_kernel(float *values_topk, // output data, shape
// Step 5: sum of the top-k values //
// ------------------------------------------------ //
{
#ifdef ENABLE_HYGON_API
float warp_sum = WarpSum<32>(value);
if (0 == tid) {
shared_sum = warp_sum + 1e-9f;
}
#else
typedef cub::WarpReduce<float, 32> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage;
float warp_sum = WarpReduce(temp_storage).Sum(value);
if (0 == tid) {
shared_sum = warp_sum + 1e-9f;
}
#endif
}
__syncwarp();

2 changes: 2 additions & 0 deletions src/infiniop/ops/zeros/cuda/kernel.cuh
@@ -27,8 +27,10 @@ public:
return 0;
} else if constexpr (std::is_same_v<T, uint64_t>) { // 10
return 0;
#ifndef ENABLE_HYGON_API
} else if constexpr (std::is_same_v<T, cuda_fp8_e4m3>) { // 11
return cuda_fp8_e4m3(0.0f);
#endif
} else if constexpr (std::is_same_v<T, half>) { // 12
return __float2half(0.0f);
} else if constexpr (std::is_same_v<T, float>) { // 13
2 changes: 2 additions & 0 deletions src/infiniop/ops/zeros/nvidia/zeros_nvidia.cu
@@ -79,8 +79,10 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<256, cuda::ZerosOp, uint32_t>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_U64: // 10
return _device_info->calculate<256, cuda::ZerosOp, uint64_t>(_info, workspace, output, inputs, stream);
#ifndef ENABLE_HYGON_API
case INFINI_DTYPE_F8: // 11
return _device_info->calculate<256, cuda::ZerosOp, cuda_fp8_e4m3>(_info, workspace, output, inputs, stream);
#endif
case INFINI_DTYPE_F16: // 12
return _device_info->calculate<256, cuda::ZerosOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32: // 13
23 changes: 11 additions & 12 deletions xmake/hygon.lua
@@ -62,18 +62,13 @@ target("infiniop-hygon")
add_cxflags("-fPIC")

-- Add Hygon DCU-specific compile flags
add_cuflags("-arch=gfx906", "-arch=gfx926", "-arch=gfx928", "-arch=gfx936")
-- Detect the actual GPU architecture; default to gfx906 if none is specified
local hygon_arch = os.getenv("HYGON_ARCH") or "gfx906"
add_cuflags("-arch=" .. hygon_arch)
print("编译海光DCU架构: " .. hygon_arch)

-- Reuse the NVIDIA CUDA implementations through the HIP compatibility layer
-- Only build the 7 operators supported by Hygon DCU: rope, gemm, causal_softmax, random_sample, rearrange, rms_norm, swiglu
add_files("../src/infiniop/devices/nvidia/*.cu")
add_files("../src/infiniop/ops/rope/nvidia/*.cu")
add_files("../src/infiniop/ops/gemm/nvidia/*.cu")
add_files("../src/infiniop/ops/causal_softmax/nvidia/*.cu")
add_files("../src/infiniop/ops/random_sample/nvidia/*.cu")
add_files("../src/infiniop/ops/rearrange/nvidia/*.cu")
add_files("../src/infiniop/ops/rms_norm/nvidia/*.cu")
add_files("../src/infiniop/ops/swiglu/nvidia/*.cu")
add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu")

if has_config("ninetoothed") then
add_files("../build/ninetoothed/*.c", {cxflags = {"-Wno-return-type"}})
@@ -107,7 +102,9 @@ target("infinirt-hygon")
add_cxflags("-fPIC")

-- Add Hygon DCU-specific compile flags
add_cuflags("-arch=gfx906", "-arch=gfx926", "-arch=gfx928", "-arch=gfx936")
-- Detect the actual GPU architecture; default to gfx906 if none is specified
local hygon_arch = os.getenv("HYGON_ARCH") or "gfx906"
add_cuflags("-arch=" .. hygon_arch)

add_files("../src/infinirt/cuda/*.cu")
target_end()
@@ -140,7 +137,9 @@ target("infiniccl-hygon")
add_cxflags("-fPIC")

-- Add Hygon DCU-specific compile flags
add_cuflags("-arch=gfx906", "-arch=gfx926", "-arch=gfx928", "-arch=gfx936")
-- Detect the actual GPU architecture; default to gfx906 if none is specified
local hygon_arch = os.getenv("HYGON_ARCH") or "gfx906"
add_cuflags("-arch=" .. hygon_arch)

-- Use NCCL (NVIDIA Collective Communications Library)
add_links("nccl")
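With the hard-coded list of four `-arch` flags replaced by an environment lookup in all three Hygon targets, a build for a non-default DCU would presumably be driven by setting the variable before invoking xmake, e.g. `HYGON_ARCH=gfx928 xmake`; when the variable is unset, the build falls back to gfx906.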