20 changes: 20 additions & 0 deletions include/infiniop/ops/random_sample.h
@@ -2,6 +2,7 @@
#define __INFINIOP_RANDOM_SAMPLE_API_H__

#include "../operator_descriptor.h"
#include <stdint.h>

typedef struct InfiniopDescriptor *infiniopRandomSampleDescriptor_t;

@@ -15,6 +16,22 @@ __C __export infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
infiniopRandomSampleDescriptor_t desc,
size_t *size);

/**
 * @brief Performs random sampling with repetition-penalty support.
 *
 * @param repetition_penalty Penalty factor. 1.0 disables the penalty; values
 *                        greater than 1.0 make penalized tokens less likely
 *                        (positive logits are divided by the factor, negative
 *                        logits are multiplied by it).
 * @param previous_tokens Array of UNIQUE token IDs that have appeared in the
 *                        sequence. Duplicates are not checked for and would
 *                        compound the penalty, so pass each ID once (vLLM-style).
 *                        May be NULL if no tokens have been generated yet; when
 *                        NULL, or when previous_tokens_len is 0, the operator
 *                        falls back to a full-history penalty (applied to every
 *                        token) for backward compatibility.
 * @param previous_tokens_len Number of unique tokens in previous_tokens.
 *                        Must be 0 if previous_tokens is NULL.
 *
 * @note The penalty is applied only to tokens in previous_tokens, which keeps
 *       the pass at O(U) rather than O(T), where U = unique tokens and
 *       T = total generated tokens (vLLM's approach).
 */
__C __export infiniStatus_t infiniopRandomSample(
infiniopRandomSampleDescriptor_t desc,
void *workspace,
@@ -25,6 +42,9 @@ __C __export infiniStatus_t infiniopRandomSample(
float topp,
int topk,
float temperature,
float repetition_penalty,
const uint32_t *previous_tokens, // Array of unique previously generated token IDs
size_t previous_tokens_len, // Number of unique tokens (0 if NULL)
void *stream);
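
Since the operator expects each ID at most once, callers typically deduplicate the raw generation history before the call. A minimal sketch (the helper name and the use of std::unordered_set are illustrative, not part of this library):

```cpp
#include <cstdint>
#include <unordered_set>
#include <vector>

// Collapse a full generation history into the unique-ID array this API expects.
std::vector<std::uint32_t> unique_history(const std::vector<std::uint32_t> &generated) {
    std::unordered_set<std::uint32_t> seen;
    std::vector<std::uint32_t> unique;
    unique.reserve(generated.size());
    for (std::uint32_t id : generated) {
        if (seen.insert(id).second) { // true only on first occurrence
            unique.push_back(id);
        }
    }
    return unique; // pass unique.data() and unique.size() to infiniopRandomSample
}
```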

__C __export infiniStatus_t infiniopDestroyRandomSampleDescriptor(
2 changes: 2 additions & 0 deletions src/infiniop-test/src/ops/random_sample.cpp
@@ -66,6 +66,7 @@ std::shared_ptr<infiniop_test::Result> Test::run(
topp,
topk,
temperature,
1.0f, // repetition_penalty (default to 1.0 for backward compatibility)
nullptr, // previous_tokens (none)
0, // previous_tokens_len (must be 0 when previous_tokens is NULL)
nullptr),
return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution."));

@@ -87,6 +88,7 @@ std::shared_ptr<infiniop_test::Result> Test::run(
topp,
topk,
temperature,
1.0f, // repetition_penalty (default to 1.0 for backward compatibility)
nullptr, // previous_tokens (none)
0, // previous_tokens_len (must be 0 when previous_tokens is NULL)
nullptr);
},
warm_ups, iterations);
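Both calls above pass 1.0f and an empty history, so existing test vectors are unaffected. A test that actually exercises the penalty path might call the operator like this fragment (a sketch reusing the locals of Test::run, whose names are assumed here; the token IDs and the 1.3f factor are made up, and the argument order follows the CALCULATE macro in operator.cc):

```cpp
// Sketch only: drive the penalty path with a fabricated history.
const uint32_t history[] = {3, 17, 42}; // unique, previously sampled token IDs
infiniopRandomSample(op_desc, workspace, workspace_size,
                     result, probs, random_val,
                     topp, topk, temperature,
                     1.3f,       // repetition_penalty > 1.0 discourages repeats
                     history, 3, // previous_tokens, previous_tokens_len
                     nullptr);   // stream
```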
57 changes: 53 additions & 4 deletions src/infiniop/ops/random_sample/cpu/random_sample_cpu.cc
@@ -3,6 +3,7 @@
#include "../info.h"
#include "infinicore.h"
#include <algorithm>
#include <cstdint>
#include <vector>

namespace op::random_sample::cpu {

@@ -75,7 +76,8 @@ struct Algo {
infiniStatus_t random(
void *workspace, size_t workspace_size,
void *result, void const *probs, size_t n,
float random_val, float topp, int topk, float temperature, float repetition_penalty,
const uint32_t *previous_tokens, size_t previous_tokens_len,
void *stream) {

struct KVPair {
@@ -88,10 +90,51 @@
};

auto idx = reinterpret_cast<Tidx *>(result);

// Apply repetition penalty if needed (the buffer stays empty when disabled)
std::vector<typename ComputeType<Tval>::type> penalized_probs;
if (repetition_penalty != 1.0f) {
    penalized_probs.resize(n);
// Initialize with original values
for (size_t i = 0; i < n; i++) {
penalized_probs[i] = get<Tidx, Tval>(probs, i);
}

// If previous_tokens are provided, only penalize those tokens (proper repetition penalty)
// Otherwise, penalize all tokens (full-history penalty for backward compatibility)
if (previous_tokens != nullptr && previous_tokens_len > 0) {
// Proper repetition penalty: only penalize previously generated tokens
for (size_t i = 0; i < previous_tokens_len; i++) {
uint32_t token_id = previous_tokens[i];
if (token_id < n) {
auto val = penalized_probs[token_id];
if (val > 0) {
penalized_probs[token_id] = val / repetition_penalty;
} else {
penalized_probs[token_id] = val * repetition_penalty;
}
}
}
} else {
// Full-history penalty: penalize all tokens (backward compatibility)
for (size_t i = 0; i < n; i++) {
auto val = penalized_probs[i];
if (val > 0) {
penalized_probs[i] = val / repetition_penalty;
} else {
penalized_probs[i] = val * repetition_penalty;
}
}
}
}

// build & sort
std::vector<KVPair> pairs(n);
for (size_t i = 0; i < n; i++) {
if (repetition_penalty != 1.0f) {
pairs[i] = {static_cast<Tidx>(i), penalized_probs[i]};
} else {
pairs[i] = {static_cast<Tidx>(i), get<Tidx, Tval>(probs, i)};
}
}
std::sort(pairs.begin(), pairs.end());
// softmax & sum
@@ -101,7 +144,9 @@
pairs[i].val = pairs[i - 1].val + std::exp((pairs[i].val - max_val) / temperature);
}
// topk & topp & limit
// Handle disabled topk (0 or -1 means consider all tokens, like vLLM)
size_t effective_topk = (topk <= 0) ? n : std::min(static_cast<size_t>(topk), n);
auto const pk = pairs[effective_topk - 1].val,
pp = pairs[n - 1].val * topp,
plimit = random_val * std::min(pk, pp);
// sample
@@ -125,12 +170,16 @@ infiniStatus_t Descriptor::calculate(
float topp,
int topk,
float temperature,
float repetition_penalty,
const uint32_t *previous_tokens,
size_t previous_tokens_len,
void *stream) const {

Calculate::calculate<Algo>(
Algo{}, _info, workspace, workspace_size,
result, probs,
random_val, topp, topk, temperature, repetition_penalty,
previous_tokens, previous_tokens_len,
stream);

return INFINI_STATUS_SUCCESS;
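The sign-dependent branches above follow the HF/vLLM logits convention: dividing a positive logit and multiplying a negative one both push a seen token toward lower probability. A standalone illustration of the arithmetic (not part of the PR):

```cpp
#include <cstdio>

// Same rule as the branches in random_sample_cpu.cc, isolated for clarity.
static float apply_penalty(float logit, float penalty) {
    return logit > 0.0f ? logit / penalty : logit * penalty;
}

int main() {
    // With penalty = 1.2: 2.0 -> ~1.6667 and -1.0 -> -1.2; both moves make
    // the corresponding token less likely after softmax.
    std::printf("%.4f %.4f\n", apply_penalty(2.0f, 1.2f), apply_penalty(-1.0f, 1.2f));
    return 0;
}
```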
100 changes: 87 additions & 13 deletions src/infiniop/ops/random_sample/metax/random_sample_kernel.h
@@ -3,6 +3,10 @@
#include <hccub/device/device_radix_sort.cuh>
#include <hccub/device/device_reduce.cuh>
#include <hccub/device/device_scan.cuh>
#include <hcr/hc_runtime_api.h>
#include <algorithm>
#include <cstdint>
#include <vector>

namespace op::random_sample::metax {

@@ -75,6 +79,8 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
size_random += align256(sizeof(Tval) * n);
// indices_out
size_random += align256(sizeof(Tidx) * n);
// sorted_out (needed when repetition_penalty != 1.0)
size_random += align256(sizeof(Tval) * n);
// cub device api
size_t size_radix_sort;
CHECK_METAX((radixSort<Tval, Tidx>(
@@ -161,6 +167,8 @@ static __global__ void randomSampleKernel(
const Tidx *__restrict__ indices_out,
size_t n,
float random, float topp, size_t topk) {
// topk should already be validated to be > 0 and <= n by the caller
// (disabled topk 0/-1 is converted to n before calling this kernel)
topk = cub::Min()(topk, n);
auto p = (Tval)(random * cub::Min()(topp * (float)sorted[n - 1], (float)sorted[topk - 1]));
for (size_t i = 0;; ++i) {
@@ -205,7 +213,8 @@ struct Algo {
infiniStatus_t random(
void *workspace_, size_t workspace_size,
void *result_, const void *probs, size_t n,
float random_val, float topp, int topk, float temperature, float repetition_penalty,
const uint32_t *previous_tokens, size_t previous_tokens_len,
void *stream_) const {

using Tval = typename CudaTval<Tval_>::Type;
@@ -226,19 +235,81 @@
auto indices_out = reinterpret_cast<Tidx *>(workspace);
workspace += align256(sizeof(Tidx) * n);

auto block = cub::Min()((size_t)block_size, n);
auto grid = (n + block - 1) / block;

// Apply repetition penalty if needed (adjust the logits on the host before sorting)
if (repetition_penalty != 1.0f) {
// Allocate temporary output buffer for radixSort from workspace (before CUB workspace)
auto sorted_out = reinterpret_cast<Tval *>(workspace);
workspace += align256(sizeof(Tval) * n);

// Now set CUB workspace pointer and size
workspace_ = reinterpret_cast<void *>(workspace);
workspace_size = workspace_end - workspace;

// Copy logits to host memory
std::vector<Tval> host_logits(n);
CHECK_METAX(hcMemcpyAsync(host_logits.data(), logits, n * sizeof(Tval), hcMemcpyDeviceToHost, stream));
CHECK_METAX(hcStreamSynchronize(stream));

// Apply penalty: if previous_tokens are provided, only penalize those tokens
// Otherwise, penalize all tokens (full-history penalty for backward compatibility)
if (previous_tokens != nullptr && previous_tokens_len > 0) {
// Proper repetition penalty: only penalize previously generated tokens
for (size_t i = 0; i < previous_tokens_len; i++) {
uint32_t token_id = previous_tokens[i];
if (token_id < n) {
float val = static_cast<float>(host_logits[token_id]);
if (val > 0) {
host_logits[token_id] = static_cast<Tval>(val / repetition_penalty);
} else {
host_logits[token_id] = static_cast<Tval>(val * repetition_penalty);
}
}
}
} else {
// Full-history penalty: penalize all tokens (backward compatibility)
for (size_t i = 0; i < n; i++) {
float val = static_cast<float>(host_logits[i]);
if (val > 0) {
host_logits[i] = static_cast<Tval>(val / repetition_penalty);
} else {
host_logits[i] = static_cast<Tval>(val * repetition_penalty);
}
}
}

// Copy penalized logits to sorted buffer (will be used as input to radixSort)
CHECK_METAX(hcMemcpyAsync(sorted, host_logits.data(), n * sizeof(Tval), hcMemcpyHostToDevice, stream));
CHECK_METAX(hcStreamSynchronize(stream));

// sort with penalized logits
fillIndices<<<grid, block, 0, stream>>>(indices, n);
CHECK_METAX(radixSort(
workspace_, workspace_size,
sorted, sorted_out,
indices, indices_out,
n,
stream));

// Copy sorted_out back to sorted for softmax
CHECK_METAX(hcMemcpyAsync(sorted, sorted_out, n * sizeof(Tval), hcMemcpyDeviceToDevice, stream));
} else {
// Set CUB workspace pointer and size
workspace_ = reinterpret_cast<void *>(workspace);
workspace_size = workspace_end - workspace;

// sort
fillIndices<<<grid, block, 0, stream>>>(indices, n);
CHECK_METAX(radixSort(
workspace_, workspace_size,
logits, sorted,
indices, indices_out,
n,
stream));
}
// softmax
partialSoftmaxKernel<<<grid, block, 0, stream>>>(sorted, n, temperature);
setSoftmaxMaxKernel<<<1, 1, 0, stream>>>(sorted);
@@ -248,10 +319,13 @@
sorted, n,
stream));
// sample
// Handle disabled topk (0 or -1 means consider all tokens, like vLLM)
int effective_topk = (topk <= 0) ? static_cast<int>(n) : topk;
randomSampleKernel<<<1, 1, 0, stream>>>(
result,
sorted, indices_out, n,
random_val, topp, effective_topk);

return INFINI_STATUS_SUCCESS;
}
};
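The penalty path above round-trips the logits through host memory with two stream synchronizations, which is simple but stalls the device. If previous_tokens lived in device memory — an assumption; the current API hands it over as a host pointer, as the host-side loop shows — the penalty could instead run on-device with one thread per remembered token, roughly like this sketch (uniqueness of the IDs guarantees the writes do not race):

```cpp
// Sketch, not part of this PR: device-side penalty, one thread per unique ID.
template <class Tval>
static __global__ void applyRepetitionPenaltyKernel(
    Tval *logits, const uint32_t *prev, size_t prev_len,
    size_t n, float penalty) {
    size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= prev_len) return;
    uint32_t id = prev[i];
    if (id >= n) return; // skip out-of-range IDs, matching the host loop
    float v = static_cast<float>(logits[id]);
    logits[id] = static_cast<Tval>(v > 0.0f ? v / penalty : v * penalty);
}
// Launch sketch: applyRepetitionPenaltyKernel<<<grid, block, 0, stream>>>(
//     logits, prev_dev, prev_len, n, repetition_penalty);
```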
@@ -83,6 +83,9 @@ infiniStatus_t Descriptor::calculate(
float topp,
int topk,
float temperature,
float repetition_penalty,
const uint32_t *previous_tokens,
size_t previous_tokens_len,
void *stream) const {

if (workspace_size < _min_workspace_size) {
@@ -94,7 +97,8 @@
Calculate::calculate<Algo>(
Algo{block_size}, _info, workspace, workspace_size,
result, probs,
random_val, topp, topk, temperature, repetition_penalty,
previous_tokens, previous_tokens_len,
stream);

return INFINI_STATUS_SUCCESS;
6 changes: 5 additions & 1 deletion src/infiniop/ops/random_sample/operator.cc
@@ -134,6 +134,9 @@ __C infiniStatus_t infiniopRandomSample(
float topp,
int topk,
float temperature,
float repetition_penalty,
const uint32_t *previous_tokens,
size_t previous_tokens_len,
void *stream) {

#define CALCULATE(CASE, NAMESPACE) \
@@ -142,7 +145,8 @@
->calculate(workspace, workspace_size, \
result, probs, \
random_val, \
topp, topk, temperature, repetition_penalty, \
previous_tokens, previous_tokens_len, \
stream)

switch (desc->device_type) {