@@ -62,11 +62,14 @@ __mlu_global__ void ropeKernel(
6262 const Tindex *pos_ids,
6363 const Tdata *sin_table,
6464 const Tdata *cos_table,
65+ uint32_t batch_size,
6566 uint32_t seqlen,
6667 uint32_t nhead,
6768 uint32_t table_dim,
69+ ptrdiff_t y_stride_batch,
6870 ptrdiff_t y_stride_seqlen,
6971 ptrdiff_t y_stride_nhead,
72+ ptrdiff_t x_stride_batch,
7073 ptrdiff_t x_stride_seqlen,
7174 ptrdiff_t x_stride_nhead,
7275 infiniopRoPEAlgo_t algo) {
@@ -106,7 +109,7 @@ __mlu_global__ void ropeKernel(
106109 }
107110
108111 // Task distribution
109- const int batch_volume = seqlen * nhead;
112+ const int batch_volume = batch_size * seqlen * nhead;
110113 const int remaining_tasks = batch_volume % taskDim;
111114 const int base_tasks_per_core = batch_volume / taskDim;
112115 const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
@@ -136,13 +139,35 @@ __mlu_global__ void ropeKernel(
136139
137140 // Main processing loop
138141 for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
139- int seq_idx = i / nhead;
142+ // Calculate 3D indices from flattened task index
143+ int batch_idx = i / (seqlen * nhead);
144+ int seq_idx = (i % (seqlen * nhead)) / nhead;
140145 int head_idx = i % nhead;
141146
142- int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
143- int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
147+ // Calculate offsets with batch dimension
148+ // Note: For GPT-NeoX, the stride calculations might be different
149+ int out_offset = batch_idx * y_stride_batch + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
150+ int in_offset = batch_idx * x_stride_batch + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
151+
152+ // Get position index for this sequence
153+ // Position IDs are shared across batches or per batch depending on input
154+ Tindex pos_idx;
155+ if (use_pos_ids_buffer) {
156+ // Position IDs loaded in NRAM
157+ pos_idx = srcP[seq_idx];
158+ } else {
159+ // Position IDs in global memory
160+ // Handle both cases: position IDs shape could be [seqlen] or [batch_size, seqlen]
161+ if (batch_size > 1) {
162+ // Assume position IDs have shape [batch_size, seqlen]
163+ int pos_flat_idx = batch_idx * seqlen + seq_idx;
164+ pos_idx = pos_ids[pos_flat_idx];
165+ } else {
166+ // Single batch case: position IDs shape is [seqlen]
167+ pos_idx = pos_ids[seq_idx];
168+ }
169+ }
144170
145- Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
146171 int rot_offset = pos_idx * table_dim;
147172
148173 int processed = 0;
0 commit comments