Skip to content

Commit 5848b40

Browse files
committed
issue/838 - Cambricon Batched RoPE
1 parent 12cde8e commit 5848b40

File tree

2 files changed

+37
-11
lines changed

2 files changed

+37
-11
lines changed

src/infiniop/ops/rope/bang/rope_bang.mlu

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
4040
const Tdata *sin_table,
4141
const Tdata *cos_table,
4242
cnrtQueue_t queue) {
43-
auto dimx = uint32_t(info.seqlen);
44-
auto dimy = uint32_t(info.nhead);
43+
auto batch_size = uint32_t(info.batch);
44+
auto seqlen = uint32_t(info.seqlen);
45+
auto nhead = uint32_t(info.nhead);
4546
auto table_dim = uint32_t(info.table_dim);
4647

4748
cnrtDim3_t k_dim;
@@ -53,12 +54,12 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
5354
k_dim.z = 1;
5455
k_type = CNRT_FUNC_TYPE_UNION1;
5556

56-
// Launch kernel
57+
// Launch kernel with batch dimension
5758
ropeKernel<<<k_dim, k_type, queue>>>(
5859
y, x, pos_ids, sin_table, cos_table,
59-
dimx, dimy, table_dim,
60-
info.y_stride_seqlen, info.y_stride_nhead,
61-
info.x_stride_seqlen, info.x_stride_nhead,
60+
batch_size, seqlen, nhead, table_dim,
61+
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
62+
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead,
6263
info.algo);
6364

6465
cnrtQueueSync(queue);

src/infiniop/ops/rope/bang/rope_bang_kernel.mlu

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,14 @@ __mlu_global__ void ropeKernel(
6262
const Tindex *pos_ids,
6363
const Tdata *sin_table,
6464
const Tdata *cos_table,
65+
uint32_t batch_size,
6566
uint32_t seqlen,
6667
uint32_t nhead,
6768
uint32_t table_dim,
69+
ptrdiff_t y_stride_batch,
6870
ptrdiff_t y_stride_seqlen,
6971
ptrdiff_t y_stride_nhead,
72+
ptrdiff_t x_stride_batch,
7073
ptrdiff_t x_stride_seqlen,
7174
ptrdiff_t x_stride_nhead,
7275
infiniopRoPEAlgo_t algo) {
@@ -106,7 +109,7 @@ __mlu_global__ void ropeKernel(
106109
}
107110

108111
// Task distribution
109-
const int batch_volume = seqlen * nhead;
112+
const int batch_volume = batch_size * seqlen * nhead;
110113
const int remaining_tasks = batch_volume % taskDim;
111114
const int base_tasks_per_core = batch_volume / taskDim;
112115
const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
@@ -136,13 +139,35 @@ __mlu_global__ void ropeKernel(
136139

137140
// Main processing loop
138141
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
139-
int seq_idx = i / nhead;
142+
// Calculate 3D indices from flattened task index
143+
int batch_idx = i / (seqlen * nhead);
144+
int seq_idx = (i % (seqlen * nhead)) / nhead;
140145
int head_idx = i % nhead;
141146

142-
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
143-
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
147+
// Calculate offsets with batch dimension
148+
// Note: For GPT-NeoX, the stride calculations might be different
149+
int out_offset = batch_idx * y_stride_batch + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
150+
int in_offset = batch_idx * x_stride_batch + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
151+
152+
// Get position index for this sequence
153+
// Position IDs buffered in NRAM (srcP) are indexed by seq_idx only; IDs read from global memory are indexed per batch when batch_size > 1
154+
Tindex pos_idx;
155+
if (use_pos_ids_buffer) {
156+
// Position IDs loaded in NRAM
157+
pos_idx = srcP[seq_idx];
158+
} else {
159+
// Position IDs in global memory
160+
// pos_ids may be [seqlen] or [batch_size, seqlen]; the shape is not passed to the kernel, so batch_size > 1 is used to select the batched layout
161+
if (batch_size > 1) {
162+
// Assume position IDs have shape [batch_size, seqlen]
163+
int pos_flat_idx = batch_idx * seqlen + seq_idx;
164+
pos_idx = pos_ids[pos_flat_idx];
165+
} else {
166+
// Single batch case: position IDs shape is [seqlen]
167+
pos_idx = pos_ids[seq_idx];
168+
}
169+
}
144170

145-
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
146171
int rot_offset = pos_idx * table_dim;
147172

148173
int processed = 0;

0 commit comments

Comments
 (0)