Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/infiniop/ops/rope/bang/rope_bang.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,27 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
const Tdata *sin_table,
const Tdata *cos_table,
cnrtQueue_t queue) {
auto dimx = uint32_t(info.seqlen);
auto dimy = uint32_t(info.nhead);
auto batch_size = uint32_t(info.batch);
auto seqlen = uint32_t(info.seqlen);
auto nhead = uint32_t(info.nhead);
auto table_dim = uint32_t(info.table_dim);

cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;

// Configure kernel launch parameters
// Using union type 1 with 4 cores
k_dim.x = 4;
k_dim.y = 1;
k_dim.z = 1;
k_type = CNRT_FUNC_TYPE_UNION1;

// Launch kernel
// Launch kernel with batch dimension
ropeKernel<<<k_dim, k_type, queue>>>(
y, x, pos_ids, sin_table, cos_table,
dimx, dimy, table_dim,
info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_seqlen, info.x_stride_nhead,
batch_size, seqlen, nhead, table_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead,
info.algo);

cnrtQueueSync(queue);
Expand Down
35 changes: 30 additions & 5 deletions src/infiniop/ops/rope/bang/rope_bang_kernel.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,14 @@ __mlu_global__ void ropeKernel(
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table,
uint32_t batch_size,
uint32_t seqlen,
uint32_t nhead,
uint32_t table_dim,
ptrdiff_t y_stride_batch,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_batch,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead,
infiniopRoPEAlgo_t algo) {
Expand Down Expand Up @@ -106,7 +109,7 @@ __mlu_global__ void ropeKernel(
}

// Task distribution
const int batch_volume = seqlen * nhead;
const int batch_volume = batch_size * seqlen * nhead;
const int remaining_tasks = batch_volume % taskDim;
const int base_tasks_per_core = batch_volume / taskDim;
const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
Expand Down Expand Up @@ -136,13 +139,35 @@ __mlu_global__ void ropeKernel(

// Main processing loop
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
int seq_idx = i / nhead;
// Calculate 3D indices from flattened task index
int batch_idx = i / (seqlen * nhead);
int seq_idx = (i % (seqlen * nhead)) / nhead;
int head_idx = i % nhead;

int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Calculate offsets with batch dimension
// Note: For GPT-NeoX, the stride calculations might be different
int out_offset = batch_idx * y_stride_batch + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
int in_offset = batch_idx * x_stride_batch + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;

// Get position index for this sequence
// Position IDs are shared across batches or per batch depending on input
Tindex pos_idx;
if (use_pos_ids_buffer) {
// Position IDs loaded in NRAM
pos_idx = srcP[seq_idx];
} else {
// Position IDs in global memory
// Handle both cases: position IDs shape could be [seqlen] or [batch_size, seqlen]
if (batch_size > 1) {
// Assume position IDs have shape [batch_size, seqlen]
int pos_flat_idx = batch_idx * seqlen + seq_idx;
pos_idx = pos_ids[pos_flat_idx];
} else {
// Single batch case: position IDs shape is [seqlen]
pos_idx = pos_ids[seq_idx];
}
}

Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
int rot_offset = pos_idx * table_dim;

int processed = 0;
Expand Down