@@ -12,28 +12,37 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <chrono>
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <cstdlib>
#include "paddle/extension.h"
#include "../ngram_match_common.cuh"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
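// (The fallback above maps PD_BUILD_STATIC_OP onto PD_BUILD_OP with a
// "static_op_" prefix for Paddle headers that do not define it.)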

// Read the threshold from the SPEC_TOKENUM_THRESHOLD environment variable
// once and cache it (default: 1024)
static int get_hybrid_threshold() {
    static int threshold = -1;
    if (threshold < 0) {
        char *env_var = getenv("SPEC_TOKENUM_THRESHOLD");
        threshold = env_var ? std::stoi(env_var) : 1024;
    }
    return threshold;
}
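// Example (assumed usage): launching with SPEC_TOKENUM_THRESHOLD=2048 set in
// the environment overrides the default; the value is parsed with std::stoi,
// so it must be a plain integer.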

// CPU implementation for fallback
static int sum_mixed(const int *value, int num) {
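    // Note: the upper bound below is inclusive, so this sums value[0..num],
    // i.e. num + 1 entries; callers presumably pass the last index, not the count.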
    int sum_value = 0;
    for (int i = 0; i <= num; i++) {
        sum_value += value[i];
    }
    return sum_value;
}

static void find_candidate_pred_tokens_mixed_cpu(const int64_t *input_ids,
                                                 const int64_t *input_ids_len,
                                                 const int64_t *pre_ids,
                                                 const int64_t *step_idx,
@@ -46,15 +55,11 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
                                                 int64_t pre_ids_stride,
                                                 int64_t draft_tokens_stride,
                                                 int64_t max_batch_size,
                                                 int max_ngram_size,
                                                 int min_ngram_size,
                                                 const int max_draft_tokens,
                                                 int threshold) {

    int unprocessed_batch_size = 0;
    for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
        if (seq_lens_decoder[batch_idx] > 0) {
@@ -98,7 +103,6 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
            const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);

            // Iterate through sliding windows of size ngram_size
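            // e.g. with cur_pre_ids ending in [..., 7, 9] and ngram_size = 2,
            // scan cur_input_ids for the window [7, 9] (illustrative values)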
            for (int64_t i = 0; i <= cur_input_ids_len - ngram_size && !match_global; ++i) {
                // Check if the current window matches the ngram
                bool match_local = true;
@@ -118,7 +122,6 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,

                    seq_lens_this_time[batch_idx] = ori_seq_len_this_time + cur_draft_token_num;
                    memcpy(cur_draft_tokens + ori_seq_len_this_time, cur_input_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
                    match_global = true;
                    break;
                }
@@ -143,7 +146,6 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,

            if (start_idx >= end_idx)
                continue;

            seq_lens_this_time[batch_idx] = ori_seq_len_this_time + cur_draft_token_num;
            memcpy(cur_draft_tokens + ori_seq_len_this_time, cur_pre_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
@@ -168,6 +170,8 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
                    const int max_ngram_size,
                    const int min_ngram_size,
                    const int max_draft_tokens) {
    // Dispatch to the CPU fallback or the GPU kernels based on tensor placement
    bool is_cpu = input_ids.place() == paddle::PlaceType::kCPU;

    auto input_ids_shape = input_ids.shape();
    const int64_t input_ids_stride = input_ids_shape[1];
@@ -178,9 +182,14 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
    auto draft_tokens_shape = draft_tokens.shape();
    const int64_t draft_tokens_stride = draft_tokens_shape[1];

    const int max_batch_size = static_cast<int>(seq_lens_this_time.shape()[0]);
    const int threshold = get_hybrid_threshold();

    if (is_cpu) {
        // CPU fallback implementation
        find_candidate_pred_tokens_mixed_cpu(
            input_ids.data<int64_t>(),
            input_ids_len.data<int64_t>(),
            pre_ids.data<int64_t>(),
            step_idx.data<int64_t>(),
@@ -195,7 +204,50 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
            max_batch_size,
            max_ngram_size,
            min_ngram_size,
            max_draft_tokens,
            threshold);
    } else {
        // GPU implementation
        cudaStream_t stream = input_ids.stream();

        // Allocate temporary buffer for unprocessed counts
        auto unprocessed_counts = paddle::empty({max_batch_size}, paddle::DataType::INT32, input_ids.place());
        int* unprocessed_counts_ptr = unprocessed_counts.data<int>();

        // Calculate unprocessed counts for each batch
        int threads_per_block = std::min(256, max_batch_size);
        int num_blocks = (max_batch_size + threads_per_block - 1) / threads_per_block;
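        // e.g. max_batch_size = 1000 gives threads_per_block = 256 and
        // num_blocks = (1000 + 255) / 256 = 4 (ceiling division)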

        ngram_match_gpu::launch_calc_unprocessed_counts_mixed_kernel(
            seq_lens_decoder.data<int>(),
            unprocessed_counts_ptr,
            max_batch_size,
            threads_per_block,
            num_blocks,
            stream);

        // Launch main kernel - one block per batch sample
        constexpr int kBlockSize = 256;
        ngram_match_gpu::hybrid_mtp_ngram_kernel<kBlockSize><<<max_batch_size, kBlockSize, 0, stream>>>(
            input_ids.data<int64_t>(),
            input_ids_len.data<int64_t>(),
            pre_ids.data<int64_t>(),
            step_idx.data<int64_t>(),
            draft_token_num.data<int>(),
            const_cast<int64_t*>(draft_tokens.data<int64_t>()),
            const_cast<int*>(seq_lens_this_time.data<int>()),
            seq_lens_decoder.data<int>(),
            max_dec_len.data<int64_t>(),
            input_ids_stride,
            pre_ids_stride,
            draft_tokens_stride,
            max_batch_size,
            max_ngram_size,
            min_ngram_size,
            max_draft_tokens,
            threshold,
            unprocessed_counts_ptr);
    }
}
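
// A minimal sketch (an assumption, not the actual implementation) of the
// per-batch counting helper that ngram_match_common.cuh is expected to
// provide; it mirrors the seq_lens_decoder > 0 predicate used by the CPU path:
//
//   __global__ void calc_unprocessed_counts_mixed_kernel(const int *seq_lens_decoder,
//                                                        int *unprocessed_counts,
//                                                        int max_batch_size) {
//       int idx = blockIdx.x * blockDim.x + threadIdx.x;
//       if (idx < max_batch_size) {
//           unprocessed_counts[idx] = seq_lens_decoder[idx] > 0 ? 1 : 0;
//       }
//   }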

PD_BUILD_STATIC_OP(hybrid_mtp_ngram)
    // ...
custom_ops/gpu_ops/speculate_decoding/ngram_match.cc (214 changes: 0 additions & 214 deletions)

This file was deleted.
