Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 108 additions & 14 deletions infini_train/src/kernels/cuda/accumulate_grad.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,120 @@ template <typename T>
__global__ void AdamAccumulateGradKernel(const T *grad_data, T *param_data, size_t num_elements, T *m_data, T *v_data,
                                         float learning_rate, float beta1, float beta2, float eps,
                                         const float bias_correction_m, const float bias_correction_v) {
    // Vectorized Adam step: each thread updates VEC_SIZE consecutive elements,
    // staging them through registers with one 128-bit (int4) load/store per
    // array to maximize memory throughput.
    // Precondition: all four buffers are 16-byte aligned (allocator-returned
    // base pointers; vec_idx * sizeof(T) is always a multiple of 16).
    constexpr int VEC_SIZE = 16 / sizeof(T);
    // Widen to size_t BEFORE multiplying: blockIdx.x * blockDim.x is a 32-bit
    // product and silently overflows on large grids.
    const size_t vec_idx = (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VEC_SIZE;
    // First index of the ragged tail that cannot form a full vector.
    const size_t e_start = num_elements / VEC_SIZE * VEC_SIZE;

    if (vec_idx < e_start) {
        // Full vector: bulk-copy 16 bytes of each array into registers.
        T local_grad[VEC_SIZE], local_param[VEC_SIZE], local_m[VEC_SIZE], local_v[VEC_SIZE];
        *reinterpret_cast<int4 *>(local_grad) = *reinterpret_cast<const int4 *>(grad_data + vec_idx);
        *reinterpret_cast<int4 *>(local_param) = *reinterpret_cast<const int4 *>(param_data + vec_idx);
        *reinterpret_cast<int4 *>(local_m) = *reinterpret_cast<const int4 *>(m_data + vec_idx);
        *reinterpret_cast<int4 *>(local_v) = *reinterpret_cast<const int4 *>(v_data + vec_idx);

#pragma unroll
        for (int i = 0; i < VEC_SIZE; ++i) {
            // Compute in float regardless of T for accuracy.
            float g = common::cuda::Cast<float>(local_grad[i]);
            float p = common::cuda::Cast<float>(local_param[i]);
            float m = common::cuda::Cast<float>(local_m[i]);
            float v = common::cuda::Cast<float>(local_v[i]);

            // Standard Adam moment updates followed by bias correction.
            m = beta1 * m + (1.0f - beta1) * g;
            v = beta2 * v + (1.0f - beta2) * g * g;
            const float m_hat = m / bias_correction_m;
            const float v_hat = v / bias_correction_v;

            // Fast round-to-nearest intrinsics for the update denominator.
            p -= learning_rate * m_hat * __frcp_rn(__fsqrt_rn(v_hat) + eps);

            // Convert back to T for the vectorized write-back.
            local_m[i] = common::cuda::Cast<T>(m);
            local_v[i] = common::cuda::Cast<T>(v);
            local_param[i] = common::cuda::Cast<T>(p);
        }

        // One 128-bit store per array.
        *reinterpret_cast<int4 *>(param_data + vec_idx) = *reinterpret_cast<int4 *>(local_param);
        *reinterpret_cast<int4 *>(m_data + vec_idx) = *reinterpret_cast<int4 *>(local_m);
        *reinterpret_cast<int4 *>(v_data + vec_idx) = *reinterpret_cast<int4 *>(local_v);
    } else if (vec_idx == e_start) {
        // Exactly one thread lands on e_start; it processes the (< VEC_SIZE)
        // tail scalar-wise so no other thread needs per-element bounds checks.
        for (size_t idx = vec_idx; idx < num_elements; ++idx) {
            m_data[idx] = common::cuda::Fma(common::cuda::Cast<T>(beta1), m_data[idx],
                                            common::cuda::Cast<T>(1.0f - beta1) * grad_data[idx]);
            v_data[idx] = common::cuda::Fma(common::cuda::Cast<T>(beta2), v_data[idx],
                                            common::cuda::Cast<T>(1.0f - beta2) * grad_data[idx] * grad_data[idx]);

            const float m_hat = common::cuda::Cast<float>(m_data[idx]) / bias_correction_m;
            const float v_hat = common::cuda::Cast<float>(v_data[idx]) / bias_correction_v;

            param_data[idx] = common::cuda::Sub(
                param_data[idx], common::cuda::Cast<T>(learning_rate * m_hat * __frcp_rn(__fsqrt_rn(v_hat) + eps)));
        }
    }
}

void AdamAccumulateGrad(const std::shared_ptr<Tensor> &grad, const std::shared_ptr<Tensor> &param,
const std::shared_ptr<Tensor> &m, const std::shared_ptr<Tensor> &v, float learning_rate,
float beta1, float beta2, float eps, int64_t t) {
// size_t num_elements = grad->NumElements();

// const float bias_correction_m = 1.0f - std::pow(beta1, t);
// const float bias_correction_v = 1.0f - std::pow(beta2, t);

// int threads_per_block = 256;
// int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block;

// auto device = grad->GetDevice();
// const auto &cuda_stream = dynamic_cast<infini_train::core::cuda::CudaStream *>(
// infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
// ->cuda_stream();

// DispatchFunc<INFINI_ALL_FLOATING_TYPES>(
// grad->Dtype(),
// [=]<typename T>() {
// AdamAccumulateGradKernel<<<num_blocks, threads_per_block, 0, cuda_stream>>>(
// static_cast<const T *>(grad->DataPtr()), static_cast<T *>(param->DataPtr()), num_elements,
// static_cast<T *>(m->DataPtr()), static_cast<T *>(v->DataPtr()), learning_rate, beta1, beta2, eps,
// bias_correction_m, bias_correction_v);
// },
// "CUDA AdamAccumulateGrad");

size_t num_elements = grad->NumElements();



const float bias_correction_m = 1.0f - std::pow(beta1, t);
const float bias_correction_v = 1.0f - std::pow(beta2, t);

int threads_per_block = 256;
int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block;


auto device = grad->GetDevice();
const auto &cuda_stream = dynamic_cast<infini_train::core::cuda::CudaStream *>(
Expand All @@ -76,13 +165,18 @@ void AdamAccumulateGrad(const std::shared_ptr<Tensor> &grad, const std::shared_p
DispatchFunc<INFINI_ALL_FLOATING_TYPES>(
grad->Dtype(),
[=]<typename T>() {
int element_size = sizeof(T);
int VEC_SIZE = 16 / element_size;
int threads_per_block = 256;
int total_threads = (num_elements + VEC_SIZE - 1) / VEC_SIZE;
int num_blocks = (total_threads + threads_per_block - 1) / threads_per_block;
AdamAccumulateGradKernel<<<num_blocks, threads_per_block, 0, cuda_stream>>>(
static_cast<const T *>(grad->DataPtr()), static_cast<T *>(param->DataPtr()), num_elements,
static_cast<T *>(m->DataPtr()), static_cast<T *>(v->DataPtr()), learning_rate, beta1, beta2, eps,
bias_correction_m, bias_correction_v);
},
"CUDA AdamAccumulateGrad");
}
}
} // namespace infini_train::kernels::cuda

#define REGISTER_CUDA_ACCUMULATE_GRAD_KERNEL(kernel_name) \
Expand Down
60 changes: 54 additions & 6 deletions infini_train/src/kernels/cuda/cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,48 @@ namespace infini_train::kernels::cuda {

template <typename Tdst, typename Tsrc>
__global__ void CastKernel(Tdst *dst, const Tsrc *src, size_t num_elements, size_t offset) {
    // Vectorized dtype cast: each thread converts VEC_SIZE consecutive
    // elements; `offset` lets the host loop the same launch over huge tensors
    // (grid-stride at the launch level).
    constexpr int VEC_SIZE = 4;
    // Widen blockIdx.x to size_t BEFORE the multiply: the 32-bit product
    // blockIdx.x * blockDim.x overflows on large grids.
    const size_t idx = (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VEC_SIZE + offset;

    if (idx + VEC_SIZE <= num_elements) {
        Tsrc s_vec[VEC_SIZE];
        Tdst d_vec[VEC_SIZE];

        // Pick the load width from sizeof(Tsrc): 2-byte types take one 8-byte
        // load, 4-byte types one 16-byte load; anything else falls back to
        // scalar loads.
        if constexpr (sizeof(Tsrc) == 2) {
            *reinterpret_cast<longlong1 *>(s_vec) = *reinterpret_cast<const longlong1 *>(src + idx);
        } else if constexpr (sizeof(Tsrc) == 4) {
            *reinterpret_cast<int4 *>(s_vec) = *reinterpret_cast<const int4 *>(src + idx);
        } else {
            for (int i = 0; i < VEC_SIZE; ++i) {
                s_vec[i] = src[idx + i];
            }
        }

        // Convert entirely in registers.
#pragma unroll
        for (int i = 0; i < VEC_SIZE; ++i) {
            d_vec[i] = common::cuda::Cast<Tdst>(s_vec[i]);
        }

        // Store width chosen from sizeof(Tdst), mirroring the load above.
        if constexpr (sizeof(Tdst) == 2) {
            *reinterpret_cast<longlong1 *>(dst + idx) = *reinterpret_cast<longlong1 *>(d_vec);
        } else if constexpr (sizeof(Tdst) == 4) {
            *reinterpret_cast<int4 *>(dst + idx) = *reinterpret_cast<int4 *>(d_vec);
        } else {
            for (int i = 0; i < VEC_SIZE; ++i) {
                dst[idx + i] = d_vec[i];
            }
        }
    } else {
        // Ragged tail: idx + VEC_SIZE > num_elements here, so this loop runs
        // at most VEC_SIZE - 1 times (and not at all for fully out-of-range
        // threads).
        for (size_t i = idx; i < num_elements; ++i) {
            dst[i] = common::cuda::Cast<Tdst>(src[i]);
        }
    }
}

Expand All @@ -29,15 +67,25 @@ std::shared_ptr<Tensor> Cast(std::shared_ptr<Tensor> input, DataType dtype) {
->cuda_stream();

const size_t num_elements = input->NumElements();

// const size_t num_elements = input->NumElements();
// dim3 block_dims(256);
// dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x));
// const size_t step = grid_dims.x * block_dims.x;

// 这里的 VEC_SIZE 必须与 Kernel 内部保持一致
int VEC_SIZE = 4;
dim3 block_dims(256);
dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x));
const size_t step = grid_dims.x * block_dims.x;
// 每个线程干 4 个人的活,所以线程总数除以 4
dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x * VEC_SIZE));
const size_t step = grid_dims.x * block_dims.x * VEC_SIZE;

DispatchFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
{dtype, input->Dtype()},
[=]<typename Tdst, typename Tsrc>() {
auto dst = static_cast<Tdst *>(dst_tensor->DataPtr());
auto src = static_cast<const Tsrc *>(input->DataPtr());
// 网格步进循环处理超大规模 Tensor
for (size_t offset = 0; offset < num_elements; offset += step) {
CastKernel<<<grid_dims, block_dims, 0, cuda_stream>>>(dst, src, num_elements, offset);
}
Expand All @@ -53,4 +101,4 @@ std::shared_ptr<Tensor> Cast(std::shared_ptr<Tensor> input, DataType dtype) {

REGISTER_CUDA_CAST_KERNEL(Cast)

#undef REGISTER_CUDA_CAST_KERNEL
#undef REGISTER_CUDA_CAST_KERNEL
61 changes: 52 additions & 9 deletions infini_train/src/kernels/cuda/fill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,50 @@
namespace infini_train::kernels::cuda {

template <typename T> __global__ void FillKernel(T *data, T value, size_t size) {
    // Vectorized fill: each thread writes VEC_SIZE elements (16 bytes) with a
    // single 128-bit store. Precondition: `data` is 16-byte aligned
    // (allocator-returned base pointer; idx * sizeof(T) is a multiple of 16).
    constexpr int VEC_SIZE = 16 / sizeof(T);
    // Widen to size_t BEFORE the multiply: blockIdx.x * blockDim.x is a 32-bit
    // product and silently overflows on large grids.
    const size_t idx = (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VEC_SIZE;

    if (idx + VEC_SIZE <= size) {
        // Materialize the fill value VEC_SIZE times in registers...
        T local[VEC_SIZE];
#pragma unroll
        for (int i = 0; i < VEC_SIZE; ++i) {
            local[i] = value;
        }
        // ...then emit one int4 store to saturate memory bandwidth.
        *reinterpret_cast<int4 *>(data + idx) = *reinterpret_cast<int4 *>(local);
    } else {
        // Ragged tail: scalar stores for the final (< VEC_SIZE) elements;
        // threads fully past `size` do nothing.
        for (size_t i = idx; i < size; ++i) {
            data[i] = value;
        }
    }
}

// TODO(dcj): refactor Fill kernel with elementwise template
void Fill(std::shared_ptr<Tensor> tensor, void *value_ptr) {
const int num_tokens = tensor->NumElements();
const int threads_per_block = 256;
const int num_blocks = (num_tokens + threads_per_block - 1) / threads_per_block;
// const int num_tokens = tensor->NumElements();
// const int threads_per_block = 256;
// const int num_blocks = (num_tokens + threads_per_block - 1) / threads_per_block;
// auto device = tensor->GetDevice();
// const auto &cuda_stream = dynamic_cast<infini_train::core::cuda::CudaStream *>(
// infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
// ->cuda_stream();
// DispatchFunc<INFINI_ALL_TYPES>(
// tensor->Dtype(),
// [=]<typename T>() {
// FillKernel<T><<<num_blocks, threads_per_block, 0, cuda_stream>>>(
// static_cast<T *>(tensor->DataPtr()), *(static_cast<T *>(value_ptr)), tensor->NumElements());
// },
// "CUDA Fill");

size_t num_elements = tensor->NumElements();
auto device = tensor->GetDevice();
const auto &cuda_stream = dynamic_cast<infini_train::core::cuda::CudaStream *>(
infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
Expand All @@ -30,16 +63,26 @@ void Fill(std::shared_ptr<Tensor> tensor, void *value_ptr) {
DispatchFunc<INFINI_ALL_TYPES>(
tensor->Dtype(),
[=]<typename T>() {
// 每一个 T 类型的大小
int element_size = sizeof(T);
// 计算向量化步长,通常 float 是 4 个,half 是 8 个
int VEC_SIZE = 16 / element_size;
int threads_per_block = 256;
// 因为每个线程处理 VEC_SIZE 个元素,所以 Block 总数要除以步长
int total_threads = (num_elements + VEC_SIZE - 1) / VEC_SIZE;
int num_blocks = (total_threads + threads_per_block - 1) / threads_per_block;

FillKernel<T><<<num_blocks, threads_per_block, 0, cuda_stream>>>(
static_cast<T *>(tensor->DataPtr()), *(static_cast<T *>(value_ptr)), tensor->NumElements());
static_cast<T *>(tensor->DataPtr()), *(static_cast<T *>(value_ptr)), num_elements);
},
"CUDA Fill");
}

} // namespace infini_train::kernels::cuda

#define REGISTER_CUDA_FILL_KERNEL(kernel_name) \
REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name)

REGISTER_CUDA_FILL_KERNEL(Fill)

#undef REGISTER_CUDA_FILL_KERNEL
#undef REGISTER_CUDA_FILL_KERNEL
Loading
Loading