Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/common/cuda/kernel_commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
using cuda_bfloat16 = nv_bfloat16;
using cuda_bfloat162 = nv_bfloat162;
#elif defined(WITH_ILUVATAR)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#elif WITH_METAX // TODO: Use `defined`.
using cuda_bfloat16 = nv_bfloat16;
using cuda_bfloat162 = nv_bfloat162;
#elif defined(WITH_METAX)
#include <mcr/mc_runtime.h>
using cuda_bfloat16 = maca_bfloat16;
using cuda_bfloat162 = maca_bfloat162;
Expand All @@ -23,10 +27,11 @@ constexpr int CUDA_BLOCK_SIZE_128 = 128;
constexpr int CUDA_BLOCK_SIZE_256 = 256;
constexpr int CUDA_BLOCK_SIZE_512 = 512;
constexpr int CUDA_BLOCK_SIZE_1024 = 1024;
constexpr int CUDA_BLOCK_SIZE_2048 = 2048;

// Query the maximum threads per block for the current CUDA device.
inline int QueryMaxThreadsPerBlock() {
#ifdef WITH_NVIDIA
#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR)
int device = 0;
cudaGetDevice(&device);
cudaDeviceProp prop;
Expand All @@ -43,7 +48,9 @@ inline int GetOptimalBlockSize() {
int max_threads = QueryMaxThreadsPerBlock();

// Select the largest supported block size for better performance.
if (max_threads >= CUDA_BLOCK_SIZE_1024) {
if (max_threads >= CUDA_BLOCK_SIZE_2048) {
return CUDA_BLOCK_SIZE_2048;
} else if (max_threads >= CUDA_BLOCK_SIZE_1024) {
return CUDA_BLOCK_SIZE_1024;
} else if (max_threads >= CUDA_BLOCK_SIZE_512) {
return CUDA_BLOCK_SIZE_512;
Expand Down
54 changes: 54 additions & 0 deletions src/cuda/add/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#ifndef INFINI_OPS_CUDA_ADD_KERNEL_CUH_
#define INFINI_OPS_CUDA_ADD_KERNEL_CUH_

#include "common/cuda/kernel_commons.h"

namespace infini::ops {

// Element-wise binary addition functor used by AddKernel.
// Selects the fastest device intrinsic for the element type at compile
// time; types without a dedicated intrinsic fall back to operator+.
struct AddOp {
  // Number of input tensors this operator consumes.
  static constexpr std::size_t num_inputs = 2;

  template <typename T>
  __device__ __forceinline__ T operator()(const T& input,
                                          const T& other) const {
    constexpr bool kIsHalf2 = std::is_same_v<T, half2>;
    constexpr bool kIsHalfLikeScalar =
        std::is_same_v<T, half> ||
        std::is_same_v<T, TypeMapType<DataType::kBFloat16>>;
    if constexpr (kIsHalf2) {
      // Packed half pair: both lanes added with a single instruction.
      return __hadd2(input, other);
    } else if constexpr (kIsHalfLikeScalar) {
      return __hadd(input, other);
    } else if constexpr (std::is_same_v<T, float>) {
      // Single-precision add, round-to-nearest-even.
      return __fadd_rn(input, other);
    } else {
      return input + other;
    }
  }
};

// Element-wise addition kernel: out[i] = input[i] + other[i].
//
// Launched in chunks by the host: `offset` is the linear index of the first
// element this launch covers, and each thread handles at most one element.
// Contiguous tensors are indexed directly; strided tensors go through
// IndexToOffset with their shape/stride descriptors (device arrays of
// length `ndim`).
//
// NOTE(review): BLOCK_SIZE is currently unused in the body — the host-side
// dispatch instantiates one kernel per block size, presumably so this can
// later carry __launch_bounds__(BLOCK_SIZE); confirm intent before removing.
template <typename T, unsigned int BLOCK_SIZE>
__global__ void AddKernel(T* out, const T* input, const T* other,
                          const size_t* out_shape, const size_t* input_shape,
                          const size_t* other_shape,
                          const ptrdiff_t* out_strides,
                          const ptrdiff_t* input_strides,
                          const ptrdiff_t* other_strides, size_t output_size,
                          size_t ndim, size_t offset, bool out_contiguous,
                          bool input_contiguous, bool other_contiguous) {
  // Cast before multiplying: blockIdx.x * blockDim.x is a 32-bit unsigned
  // product and silently overflows for grids covering > 2^32 elements.
  size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x + offset;

  if (idx < output_size) {
    size_t out_idx =
        out_contiguous ? idx : IndexToOffset(idx, ndim, out_shape, out_strides);
    size_t input_idx =
        input_contiguous ? idx
                         : IndexToOffset(idx, ndim, input_shape, input_strides);
    size_t other_idx =
        other_contiguous ? idx
                         : IndexToOffset(idx, ndim, other_shape, other_strides);

    out[out_idx] = AddOp{}(input[input_idx], other[other_idx]);
  }
}

} // namespace infini::ops

#endif
85 changes: 31 additions & 54 deletions src/cuda/add/kernel.h
Original file line number Diff line number Diff line change
@@ -1,57 +1,18 @@
#ifndef INFINI_OPS_CUDA_ADD_KERNEL_H_
#define INFINI_OPS_CUDA_ADD_KERNEL_H_

#include <utility>
#include <cstdint>

// clang-format off
#include <cuda_runtime.h>
// clang-format on

#include "base/add.h"
#include "common/cuda/kernel_commons.h"
#include "common/generic_utils.h"
#include "cuda/add/kernel.cuh"

namespace infini::ops {

typedef struct AddOp {
public:
static constexpr std::size_t num_inputs = 2;
template <typename T>
__device__ __forceinline__ T operator()(const T& input,
const T& other) const {
if constexpr (std::is_same_v<T, half2>) {
return __hadd2(input, other);
} else if constexpr (std::is_same_v<T, half> ||
std::is_same_v<T, TypeMapType<DataType::kBFloat16>>) {
return __hadd(input, other);
} else if constexpr (std::is_same_v<T, float>) {
return __fadd_rn(input, other);
} else {
return input + other;
}
}
} AddOp;

template <typename T>
__global__ void AddKernel(
T* out, const T* input, const T* other, const Tensor::Size* out_shape,
const Tensor::Size* input_shape, const Tensor::Size* other_shape,
const Tensor::Stride* out_strides, const Tensor::Stride* input_strides,
const Tensor::Stride* other_strides, size_t output_size, size_t ndim,
size_t offset, bool out_contiguous, bool input_contiguous,
bool other_contiguous) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;

if (idx < output_size) {
Tensor::Size out_idx =
out_contiguous ? idx : IndexToOffset(idx, ndim, out_shape, out_strides);
Tensor::Size input_idx =
input_contiguous ? idx
: IndexToOffset(idx, ndim, input_shape, input_strides);
Tensor::Size other_idx =
other_contiguous ? idx
: IndexToOffset(idx, ndim, other_shape, other_strides);

out[out_idx] = AddOp{}(input[input_idx], other[other_idx]);
}
}

template <typename Backend>
class CudaAdd : public Add {
public:
Expand Down Expand Up @@ -96,24 +57,40 @@ class CudaAdd : public Add {
out_type_,
[&](auto tag) {
using T = typename decltype(tag)::type;
// TODO(lzm): currently hard-code block_size to be 256.
auto cuda_stream =
static_cast<typename Backend::stream_t>(stream_ ? stream_ : 0);
int block_size = GetOptimalBlockSize();
dim3 blockDims(
std::min(static_cast<Tensor::Size>(256), output_size_));
std::min(static_cast<Tensor::Size>(block_size), output_size_));
dim3 gridDims(utils::CeilDiv(output_size_, blockDims.x));
size_t step = gridDims.x * blockDims.x;

T* d_out = reinterpret_cast<T*>(out.data());
const T* d_input = reinterpret_cast<const T*>(input.data());
const T* d_other = reinterpret_cast<const T*>(other.data());

for (size_t i = 0; i < output_size_; i += step) {
AddKernel<<<gridDims, blockDims, 0,
static_cast<typename Backend::stream_t>(stream_)>>>(
d_out, d_input, d_other, d_out_shape_, d_input_shape_,
d_other_shape_, d_out_strides_, d_input_strides_,
d_other_strides_, output_size_, ndim_, i, is_out_contiguous_,
is_input_contiguous_, is_other_contiguous_);
#define LAUNCH_ADD_KERNEL(BLOCK_SIZE) \
for (size_t i = 0; i < output_size_; i += step) { \
AddKernel<T, BLOCK_SIZE><<<gridDims, blockDims, 0, cuda_stream>>>( \
d_out, d_input, d_other, d_out_shape_, d_input_shape_, d_other_shape_, \
d_out_strides_, d_input_strides_, d_other_strides_, output_size_, \
ndim_, i, is_out_contiguous_, is_input_contiguous_, \
is_other_contiguous_); \
}

if (block_size == CUDA_BLOCK_SIZE_2048) {
LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_2048)
} else if (block_size == CUDA_BLOCK_SIZE_1024) {
LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_1024)
} else if (block_size == CUDA_BLOCK_SIZE_512) {
LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_512)
} else if (block_size == CUDA_BLOCK_SIZE_256) {
LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_256)
} else {
LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_128)
}

#undef LAUNCH_ADD_KERNEL
},
"CudaAdd::operator()");
}
Expand Down
37 changes: 25 additions & 12 deletions src/cuda/causal_softmax/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,13 @@
// clang-format on

#include "base/causal_softmax.h"
#include "common/cuda/kernel_commons.h"
#include "cuda/causal_softmax/kernel.cuh"
#include "data_type.h"
#include "dispatcher.h"

namespace infini::ops {

namespace causal_softmax {

constexpr unsigned int kBlockSize = 256;

} // namespace causal_softmax

template <typename Backend>
class CudaCausalSoftmax : public CausalSoftmax {
public:
Expand All @@ -41,16 +36,34 @@ class CudaCausalSoftmax : public CausalSoftmax {
std::abort();
}

int block_size = GetOptimalBlockSize();

DispatchFunc<DataType::kFloat32, DataType::kFloat16, DataType::kBFloat16>(
out.dtype(),
[&](auto tag) {
using T = typename decltype(tag)::type;
CausalSoftmaxKernel<causal_softmax::kBlockSize, T, float>
<<<grid, causal_softmax::kBlockSize, 0, cuda_stream>>>(
reinterpret_cast<T*>(out.data()),
reinterpret_cast<const T*>(input.data()), batch_size_,
seq_len_, total_seq_len_, stride_out_batch, stride_out_row,
stride_input_batch, stride_input_row);

#define LAUNCH_CAUSAL_SOFTMAX_KERNEL(BLOCK_SIZE) \
CausalSoftmaxKernel<BLOCK_SIZE, T, float> \
<<<grid, BLOCK_SIZE, 0, cuda_stream>>>( \
reinterpret_cast<T*>(out.data()), \
reinterpret_cast<const T*>(input.data()), batch_size_, seq_len_, \
total_seq_len_, stride_out_batch, stride_out_row, \
stride_input_batch, stride_input_row);

if (block_size == CUDA_BLOCK_SIZE_2048) {
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_2048)
} else if (block_size == CUDA_BLOCK_SIZE_1024) {
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_1024)
} else if (block_size == CUDA_BLOCK_SIZE_512) {
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_512)
} else if (block_size == CUDA_BLOCK_SIZE_256) {
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_256)
} else {
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_128)
}

#undef LAUNCH_CAUSAL_SOFTMAX_KERNEL
},
"CudaCausalSoftmax::operator()");
}
Expand Down
38 changes: 25 additions & 13 deletions src/cuda/rms_norm/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,13 @@
// clang-format on

#include "base/rms_norm.h"
#include "common/cuda/kernel_commons.h"
#include "cuda/rms_norm/kernel.cuh"
#include "data_type.h"
#include "dispatcher.h"

namespace infini::ops {

namespace {

constexpr unsigned int kBlockSize = 256;

} // namespace

template <typename Backend>
class CudaRmsNorm : public RmsNorm {
public:
Expand All @@ -43,17 +38,34 @@ class CudaRmsNorm : public RmsNorm {
std::abort();
}

int block_size = GetOptimalBlockSize();

DispatchFunc<DataType::kFloat32, DataType::kFloat16, DataType::kBFloat16>(
out.dtype(),
[&](auto tag) {
using T = typename decltype(tag)::type;
RmsNormKernel<kBlockSize, float, T, T>
<<<num_blocks, kBlockSize, 0, cuda_stream>>>(
reinterpret_cast<T*>(out.data()), stride_out_batch,
stride_out_nhead, reinterpret_cast<const T*>(input.data()),
stride_input_batch, stride_input_nhead,
reinterpret_cast<const T*>(weight.data()), nhead_, dim_,
eps_);

#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \
RmsNormKernel<BLOCK_SIZE, float, T, T> \
<<<num_blocks, BLOCK_SIZE, 0, cuda_stream>>>( \
reinterpret_cast<T*>(out.data()), stride_out_batch, \
stride_out_nhead, reinterpret_cast<const T*>(input.data()), \
stride_input_batch, stride_input_nhead, \
reinterpret_cast<const T*>(weight.data()), nhead_, dim_, eps_);

if (block_size == CUDA_BLOCK_SIZE_2048) {
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_2048)
} else if (block_size == CUDA_BLOCK_SIZE_1024) {
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_1024)
} else if (block_size == CUDA_BLOCK_SIZE_512) {
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_512)
} else if (block_size == CUDA_BLOCK_SIZE_256) {
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_256)
} else {
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_128)
}

#undef LAUNCH_RMS_NORM_KERNEL
},
"CudaRmsNorm::operator()");
}
Expand Down
5 changes: 3 additions & 2 deletions src/cuda/swiglu/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ class CudaSwiglu : public Swiglu {
ndim_, i, is_out_contiguous_, is_input_contiguous_, \
is_gate_contiguous_); \
}

if (block_size == CUDA_BLOCK_SIZE_1024) {
if (block_size == CUDA_BLOCK_SIZE_2048) {
LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_2048)
} else if (block_size == CUDA_BLOCK_SIZE_1024) {
LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_1024)
} else if (block_size == CUDA_BLOCK_SIZE_512) {
LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_512)
Expand Down
6 changes: 6 additions & 0 deletions src/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ class Device {
return std::string{StringFromType(type_)} + ":" + std::to_string(index_);
}

// Two devices compare equal iff they share both the device type and the
// device index.
bool operator==(const Device& other) const {
  if (type_ != other.type_) {
    return false;
  }
  return index_ == other.index_;
}

// Inequality is defined as the negation of equality (== is symmetric).
bool operator!=(const Device& other) const { return !(other == *this); }

private:
Type type_{Type::kCpu};

Expand Down
2 changes: 1 addition & 1 deletion src/iluvatar/gemm/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct IluvatarBackend {
// Iluvatar uses `cudaDataType` for `computeType`, so we need to use
// `CUDA_R_32F` instead of `CUBLAS_COMPUTE_32F_FAST_TF32`.
static constexpr auto BLAS_COMPUTE_32F = CUDA_R_32F;

static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = CUDA_R_32F;

// Iluvatar uses `CUBLAS_GEMM_DEFAULT_TENSOR_OP` instead of
Expand Down
Loading