Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 260 additions & 5 deletions pufferlib/extensions/cuda/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
#include <cstdio>
#include <cstdint>


#define WARP_SIZE 32
#define FULL_MASK 0xffffffff
#define RMS_THREADS 128

#define SEQ_SIZE 256
#define BLOCK_SIZE 256
inline int grid_size(int N) {
Expand Down Expand Up @@ -45,6 +50,134 @@ void dispatch_and_launch(const at::Tensor& example_tensor, Args... args) {
}
*/

// Max hidden dim supported by optimized kernel (each thread stores H/RMS_THREADS values)

// RMSNorm forward pass. One thread block normalizes one row of H elements:
//     out[row, h] = weight[h] * x[row, h] * rsqrt(mean(x[row, :]^2) + eps)
//
// Launch layout: gridDim.x = number of rows, blockDim.x == RMS_THREADS
// (SH_SUMS holds exactly one partial sum per warp of an RMS_THREADS-wide block,
// and the full-mask shuffles require whole warps).
//
// Parameters:
//   out          - output, same layout as x (row-major, H contiguous elements per row)
//   inv_norm_buf - receives the per-row inverse RMS (one float per block)
//   x            - input rows
//   weight       - per-feature scale, H elements
//   eps          - added to the mean square before the reciprocal square root
//   T_total, B   - not read by this kernel; the row is taken from blockIdx.x
//   H            - number of elements per row
template<typename T>
__global__ void rmsnorm_forward_kernel_optimized(
    T* __restrict__ out,
    float* __restrict__ inv_norm_buf,
    const T* __restrict__ x,
    const T* __restrict__ weight,
    double eps,
    int T_total,
    int H,
    int B
) {
    constexpr int NUM_WARPS = RMS_THREADS / WARP_SIZE;
    // Number of row elements each thread keeps in its private cache.
    constexpr int CACHE_SIZE = WARP_SIZE;

    __shared__ float SH_SUMS[NUM_WARPS];
    __shared__ float SH_INV_RMS;
    float X_VALUES[CACHE_SIZE];

    const int tid = threadIdx.x;
    const int lane = tid % WARP_SIZE;
    const int warp_id = tid / WARP_SIZE;
    // 64-bit row offset so rows * H cannot overflow a 32-bit int.
    const size_t base = (size_t)blockIdx.x * (size_t)H;

    // Pass 1: per-thread sum of squares; remember the values we have room for.
    float sum_sq = 0.0f;
    int curxv = 0;
    for (int h = tid; h < H; h += blockDim.x) {
        float x_val = float(x[base + h]);
        // Rows longer than CACHE_SIZE * blockDim.x do not fit in the cache;
        // the overflow elements are simply re-read from x in pass 2.
        if (curxv < CACHE_SIZE) {
            X_VALUES[curxv] = x_val;
        }
        curxv++;
        sum_sq += x_val * x_val;
    }

    // Warp-level tree reduction; lane 0 ends up with the warp's total.
    for (int s = WARP_SIZE / 2; s >= 1; s /= 2) {
        sum_sq += __shfl_down_sync(FULL_MASK, sum_sq, s);
    }

    if (lane == 0) {
        SH_SUMS[warp_id] = sum_sq;
    }
    __syncthreads();

    if (tid == 0) {
        // Scalar reads: a float array is not guaranteed to be 16-byte aligned,
        // so it must not be loaded through a float4 pointer.
        float hsum = SH_SUMS[0];
        for (int w = 1; w < NUM_WARPS; w++) {
            hsum += SH_SUMS[w];
        }

        // eps is double, so the addition is carried out in double before narrowing.
        float inv_rms = rsqrtf(float(hsum / H + eps));
        inv_norm_buf[blockIdx.x] = inv_rms;
        SH_INV_RMS = inv_rms;
    }
    __syncthreads();

    // Pass 2: scale every element by weight and the row's inverse RMS.
    const float inv_rms = SH_INV_RMS;
    curxv = 0;
    for (int h = tid; h < H; h += blockDim.x) {
        float x_val = (curxv < CACHE_SIZE) ? X_VALUES[curxv] : float(x[base + h]);
        curxv++;
        out[base + h] = T(float(weight[h]) * x_val * inv_rms);
    }
}


// RMSNorm backward pass. One thread block handles one row of H elements.
//
// With inv_rms = inv_norm_buf[row] and S = sum_h(weight[h] * grad_out[row, h] * x[row, h]):
//     grad_x[row, h]      = weight[h] * grad_out[row, h] * inv_rms
//                           - x[row, h] * S * inv_rms^3 / H
//     grad_weight[row, h] = grad_out[row, h] * x[row, h] * inv_rms
// grad_weight is written per element (same layout as grad_x); this kernel does
// not reduce it across rows.
//
// Launch layout: gridDim.x = number of rows, blockDim.x == RMS_THREADS
// (SH_SUMS holds exactly one partial sum per warp of an RMS_THREADS-wide block,
// and the full-mask shuffles require whole warps).
//
// Parameters:
//   grad_x, grad_weight - outputs, H contiguous elements per row
//   grad_out            - upstream gradient, same layout
//   inv_norm_buf        - per-row inverse RMS (one float per block)
//   x_buf               - forward-pass input rows
//   weight              - per-feature scale, H elements
//   eps, T_total, B     - not read by this kernel
//   H                   - number of elements per row
template<typename T>
__global__ void rmsnorm_backward_kernel_optimized(
    T* __restrict__ grad_x,
    T* __restrict__ grad_weight,
    const T* __restrict__ grad_out,
    const float* __restrict__ inv_norm_buf,
    const T* __restrict__ x_buf,
    const T* __restrict__ weight,
    double eps,
    int T_total,
    int H,
    int B
) {
    constexpr int NUM_WARPS = RMS_THREADS / WARP_SIZE;
    // Number of row elements each thread keeps in its private caches.
    constexpr int CACHE_SIZE = WARP_SIZE;

    __shared__ float SH_SUMS[NUM_WARPS];
    __shared__ float SH_WGX;

    float X_VALUES[CACHE_SIZE];
    float G_VALUES[CACHE_SIZE];

    const int tid = threadIdx.x;
    const int lane = tid % WARP_SIZE;
    const int warp_id = tid / WARP_SIZE;
    // 64-bit row offset so rows * H cannot overflow a 32-bit int.
    const size_t base = (size_t)blockIdx.x * (size_t)H;

    const float inv_rms = inv_norm_buf[blockIdx.x];
    const float inv_rms_3 = inv_rms * inv_rms * inv_rms;

    // Pass 1: per-thread partial of sum(w * g * x); cache x and g where they fit.
    float wg_x = 0.0f;
    int curxv = 0;
    for (int h = tid; h < H; h += blockDim.x) {
        float x = float(x_buf[base + h]);
        float g = float(grad_out[base + h]);
        float w = float(weight[h]);
        // Rows longer than CACHE_SIZE * blockDim.x do not fit in the caches;
        // the overflow elements are re-read from global memory in pass 2.
        if (curxv < CACHE_SIZE) {
            X_VALUES[curxv] = x;
            G_VALUES[curxv] = g;
        }
        curxv++;
        wg_x += w * g * x;
    }

    // Warp-level tree reduction; lane 0 ends up with the warp's total.
    for (int s = WARP_SIZE / 2; s >= 1; s /= 2) {
        wg_x += __shfl_down_sync(FULL_MASK, wg_x, s);
    }

    if (lane == 0) {
        SH_SUMS[warp_id] = wg_x;
    }
    __syncthreads();

    if (tid == 0) {
        // Scalar reads: a float array is not guaranteed to be 16-byte aligned,
        // so it must not be loaded through a float4 pointer.
        float total = SH_SUMS[0];
        for (int w = 1; w < NUM_WARPS; w++) {
            total += SH_SUMS[w];
        }
        SH_WGX = total;
    }
    __syncthreads();

    // Pass 2: write both gradients for every element of the row.
    const float wgx_total = SH_WGX;
    const float gradx_end = wgx_total * inv_rms_3 / float(H);
    curxv = 0;
    for (int h = tid; h < H; h += blockDim.x) {
        const size_t idx = base + h;
        const bool cached = curxv < CACHE_SIZE;
        float x = cached ? X_VALUES[curxv] : float(x_buf[idx]);
        float g = cached ? G_VALUES[curxv] : float(grad_out[idx]);
        float w = float(weight[h]);
        curxv++;

        grad_x[idx] = T(w * g * inv_rms - x * gradx_end);
        grad_weight[idx] = T(g * x * inv_rms);
    }
}


template<typename T>
__global__ void rmsnorm_forward_kernel(
T* __restrict__ out,
Expand Down Expand Up @@ -95,20 +228,24 @@ __global__ void rmsnorm_backward_kernel(
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= T_total*H*B) return;
int base = idx % H;
int h_idx = idx % H;
int norm_idx = idx / H;
int vec_offset = norm_idx * H; // Start offset of this vector in memory
// previously used `base = idx % H` and then `base + h` in the wg_x loop
// was wrong because base is the h-index not the vector offset.

float inv_rms = inv_norm_buf[norm_idx];
float inv_rms_3 = inv_rms * inv_rms * inv_rms;

grad_x[idx] = weight[base] * grad_out[idx] * inv_rms;
grad_weight[idx] = grad_out[idx] * inv_rms;
grad_x[idx] = weight[h_idx] * grad_out[idx] * inv_rms;
// was previously missing x_buf[idx] term in grad_weight
grad_weight[idx] = grad_out[idx] * x_buf[idx] * inv_rms;

float wg_x = 0.0f;
for (int h=0; h<H; h++) {
float x = x_buf[base + h];
float x = x_buf[vec_offset + h];
float w = weight[h];
float g = grad_out[base + h];
float g = grad_out[vec_offset + h];
wg_x += w*g*x;
}
float x = x_buf[idx];
Expand Down Expand Up @@ -158,6 +295,74 @@ __global__ void rmsnorm_backward_kernel(
}
*/

// Host-side launcher for rmsnorm_forward_kernel_optimized.
//
// Launches one block of RMS_THREADS threads per row (B * T_total rows) on
// `stream`. All pointers are passed straight through to the kernel. Launch
// errors are reported on stderr; nothing is returned to the caller.
//
// An empty problem (no rows, or H <= 0) is a no-op: a zero-sized grid is not a
// valid launch configuration, and H == 0 would divide by zero in the kernel.
template<typename T>
void launch_rmsnorm_forward_optimized(
    T* __restrict__ out,
    float* __restrict__ inv_norm_buf,
    const T* __restrict__ x,
    const T* __restrict__ weight,
    double eps,
    int T_total,
    int H,
    int B,
    cudaStream_t stream
) {
    int blocks = B * T_total;
    if (blocks <= 0 || H <= 0) {
        return;  // nothing to normalize
    }

    rmsnorm_forward_kernel_optimized<T><<<blocks, RMS_THREADS, 0, stream>>>(
        out,
        inv_norm_buf,
        x,
        weight,
        eps,
        T_total,
        H,
        B
    );

    // Kernel launches do not return a status; pick up launch-configuration errors here.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA kernel launch error in forward: %s\n", cudaGetErrorString(err));
    }
}


// Host-side launcher for rmsnorm_backward_kernel_optimized.
//
// Launches one block of RMS_THREADS threads per row (B * T_total rows) on
// `stream`. All pointers are passed straight through to the kernel. Launch
// errors are reported on stderr; nothing is returned to the caller.
//
// An empty problem (no rows, or H <= 0) is a no-op: a zero-sized grid is not a
// valid launch configuration, and H == 0 would divide by zero in the kernel.
template<typename T>
void launch_rmsnorm_backward_optimized(
    T* __restrict__ grad_x,
    T* __restrict__ grad_weight,
    const T* __restrict__ grad_out,
    const float* __restrict__ inv_norm_buf,
    const T* __restrict__ x_buf,
    const T* __restrict__ weight,
    double eps,
    int T_total,
    int H,
    int B,
    cudaStream_t stream
) {
    int blocks = B * T_total;
    if (blocks <= 0 || H <= 0) {
        return;  // nothing to differentiate
    }

    rmsnorm_backward_kernel_optimized<T><<<blocks, RMS_THREADS, 0, stream>>>(
        grad_x,
        grad_weight,
        grad_out,
        inv_norm_buf,
        x_buf,
        weight,
        eps,
        T_total,
        H,
        B
    );

    // Kernel launches do not return a status; pick up launch-configuration errors here.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA kernel launch error in backward: %s\n", cudaGetErrorString(err));
    }
}


template<typename T>
void launch_rmsnorm_forward(
T* __restrict__ out,
Expand Down Expand Up @@ -1649,3 +1854,53 @@ void launch_sample_logits(
fprintf(stderr, "sample_logits kernel error: %s\n", cudaGetErrorString(err));
}
}

extern "C" {

// C-linkage helper: waits on cudaDeviceSynchronize(). The returned status is
// discarded here.
void sync_device() {
    (void)cudaDeviceSynchronize();
}

// C-linkage helper: returns the description string for the most recent CUDA
// runtime error. cudaGetLastError() also clears that error state.
const char* get_last_error() {
    const cudaError_t last = cudaGetLastError();
    return cudaGetErrorString(last);
}

// Original forward: float32 C entry point that forwards its arguments to
// launch_rmsnorm_forward<float>, passing nullptr as the stream.
void launch_rmsnorm_forward_original_f32(
    float* out, float* inv_norm_buf,
    const float* x, const float* weight,
    double eps, int T_total, int H, int B
) {
    launch_rmsnorm_forward<float>(
        out,
        inv_norm_buf,
        x,
        weight,
        eps,
        T_total,
        H,
        B,
        nullptr
    );
}

// Optimized forward: float32 C entry point that forwards its arguments to
// launch_rmsnorm_forward_optimized<float>, passing nullptr as the stream.
void launch_rmsnorm_forward_optimized_f32(
    float* out, float* inv_norm_buf,
    const float* x, const float* weight,
    double eps, int T_total, int H, int B
) {
    launch_rmsnorm_forward_optimized<float>(
        out,
        inv_norm_buf,
        x,
        weight,
        eps,
        T_total,
        H,
        B,
        nullptr
    );
}

// Original backward: float32 C entry point that forwards its arguments to
// launch_rmsnorm_backward<float>, passing nullptr as the stream.
void launch_rmsnorm_backward_original_f32(
    float* grad_x, float* grad_weight,
    const float* grad_out, const float* inv_norm_buf,
    const float* x_buf, const float* weight,
    double eps, int T_total, int H, int B
) {
    launch_rmsnorm_backward<float>(
        grad_x,
        grad_weight,
        grad_out,
        inv_norm_buf,
        x_buf,
        weight,
        eps,
        T_total,
        H,
        B,
        nullptr
    );
}

// Optimized backward: float32 C entry point that forwards its arguments to
// launch_rmsnorm_backward_optimized<float>, passing nullptr as the stream.
void launch_rmsnorm_backward_optimized_f32(
    float* grad_x, float* grad_weight,
    const float* grad_out, const float* inv_norm_buf,
    const float* x_buf, const float* weight,
    double eps, int T_total, int H, int B
) {
    launch_rmsnorm_backward_optimized<float>(
        grad_x,
        grad_weight,
        grad_out,
        inv_norm_buf,
        x_buf,
        weight,
        eps,
        T_total,
        H,
        B,
        nullptr
    );
}

} // extern "C"
Loading