11 changes: 9 additions & 2 deletions include/infinicore_infer/models/jiuge.h
@@ -101,15 +101,19 @@ __C __export struct KVCache *createPagedKVCache(
/// @param temperature Sampling temperature (0. means greedy sampling)
/// @param topk Sampling top-k (1 means greedy sampling)
/// @param topp Sampling top-p
/// @param is_prefill Whether to run the prefill path: 0 means decode, 1 means prefill
/// @param enable_paged_attn Whether to enable paged attention
/// @param repetition_penalty Repetition penalty factor (1.0 means no penalty)
/// @param previous_tokens_per_req Per-request pointers to arrays of unique token IDs (vLLM-style, for efficient repetition penalty)
/// @param previous_tokens_len_per_req Number of unique tokens per request
/// @param output Output token array, one token per request; length must be at least nreq
__C __export void
inferBatchJiuge(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
const float *temperature, const uint32_t *topk, const float *topp,
const float *repetition_penalty,
const uint32_t *const *previous_tokens_per_req,
const uint32_t *previous_tokens_len_per_req,
uint32_t *output);

__C __export void
Expand All @@ -120,6 +124,9 @@ inferBatch(struct JiugeModel *,
const int32_t *block_tables,
const int32_t *slot_mapping,
const float *temperature, const uint32_t *topk, const float *topp,
const float *repetition_penalty,
const uint32_t *const *previous_tokens_per_req,
const uint32_t *previous_tokens_len_per_req,
const uint32_t is_prefill, const bool enable_paged_attn,
uint32_t *output);
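Note: the kernel-side implementation behind these declarations is not part of this diff. As a point of reference only, a minimal Python sketch of the vLLM-style penalty that these parameters describe (the function name and NumPy usage here are illustrative, not part of this PR) could look like:

import numpy as np

def apply_repetition_penalty(logits, unique_prev_tokens, penalty):
    # logits: 1-D float array over the vocabulary for one request.
    # unique_prev_tokens: the per-request array passed via
    # previous_tokens_per_req; penalty: 1.0 means no penalty.
    if penalty == 1.0 or len(unique_prev_tokens) == 0:
        return logits
    seen = np.asarray(unique_prev_tokens)
    vals = logits[seen]
    # vLLM convention: positive logits shrink, negative logits grow
    # more negative, so previously seen tokens become less likely.
    logits[seen] = np.where(vals > 0.0, vals / penalty, vals * penalty)
    return logits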

36 changes: 34 additions & 2 deletions scripts/infer_task.py
@@ -1,20 +1,34 @@
class InferTask:
def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens, repetition_penalty=1.0):
self.id = id
self.finish_reason = None
self.tokens = tokens
self.max_tokens = max_tokens
self.temperature = temperature
self.topk = topk
self.topp = topp
self.repetition_penalty = repetition_penalty
self.end_tokens = end_tokens
self._kv_cache = None
self.pos = 0

# vLLM-style unique token tracking for efficient repetition penalty
# Track unique token IDs that have been generated (not the full sequence)
# Initialize with prompt tokens so they are also penalized
self._unique_generated_tokens = set(tokens)
self._unique_tokens_array = sorted(self._unique_generated_tokens) # Pre-sort for efficiency
self._unique_tokens_dirty = False # Already initialized, no need to rebuild

def bind_kvcache(self, kv_cache, pos=0):
self._kv_cache = kv_cache
self.pos = pos
# Update tokens and add any new tokens to unique set
remaining_tokens = self.tokens[pos:]
for token in remaining_tokens:
if token not in self._unique_generated_tokens:
self._unique_generated_tokens.add(token)
self._unique_tokens_dirty = True
self.tokens = remaining_tokens

def release_kvcache(self):
cache = self._kv_cache
@@ -34,6 +48,24 @@ def next(self, out_token):
self.finish_reason = "length"
else:
self.tokens = [out_token]
# Incrementally update unique token set (vLLM-style)
# Only add if it's a new token (O(1) average)
if out_token not in self._unique_generated_tokens:
self._unique_generated_tokens.add(out_token)
self._unique_tokens_dirty = True

def get_unique_previous_tokens(self):
"""
Returns a sorted list of the unique token IDs seen so far (prompt and generated).
This is the vLLM-style "seen tokens" list used for efficient repetition penalty.

Returns:
tuple: (array, length) where array is a sorted list of unique token IDs
"""
if self._unique_tokens_dirty:
self._unique_tokens_array = sorted(self._unique_generated_tokens)
self._unique_tokens_dirty = False
return self._unique_tokens_array, len(self._unique_tokens_array)
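Note: a quick usage sketch of the tracking above (token IDs are made up, and this assumes scripts/ is on sys.path):

from infer_task import InferTask

# Prompt tokens seed the unique set, so prompt repeats are penalized too.
task = InferTask(0, [5, 7, 5], max_tokens=16, temperature=1.0,
                 topk=1, topp=1.0, end_tokens=[2], repetition_penalty=1.2)
print(task.get_unique_previous_tokens())   # ([5, 7], 2)

task.next(9)   # a new token marks the set dirty; re-sorted on next read
print(task.get_unique_previous_tokens())   # ([5, 7, 9], 3)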


class KVCache:
68 changes: 65 additions & 3 deletions scripts/jiuge.py
@@ -20,6 +20,7 @@
from infer_task import InferTask, KVCache

from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import ctypes

torch.set_default_device("cpu")

@@ -395,11 +396,55 @@ def __init__(self, tasks: List[InferTask]):
self.temperaturas_list = [t.temperature for t in tasks]
self.topks_list = [t.topk for t in tasks]
self.topps_list = [t.topp for t in tasks]
self.repetition_penalties_list = [t.repetition_penalty for t in tasks]

# Flatten token lists
flat_tokens = [tok for toks in token_lists for tok in toks]
self.ntok = len(flat_tokens)

# Collect unique tokens per request (vLLM-style for efficient repetition penalty)
# Each request has its own list of unique token IDs
self.unique_tokens_arrays = [] # List of arrays, one per request
self.unique_tokens_lens = [] # List of lengths, one per request
self.unique_tokens_flat = [] # Flattened array for C API
self.unique_tokens_offsets = [0] # Offsets into flat array

total_unique_tokens = 0
for task in tasks:
tokens_array, tokens_len = task.get_unique_previous_tokens()
self.unique_tokens_arrays.append(tokens_array)
self.unique_tokens_lens.append(tokens_len)
self.unique_tokens_flat.extend(tokens_array)
total_unique_tokens += tokens_len
self.unique_tokens_offsets.append(total_unique_tokens)

# Convert to C-compatible arrays
if total_unique_tokens > 0:
self.unique_tokens_c = (c_uint * total_unique_tokens)(*self.unique_tokens_flat)
# Create array of pointers, one per request
self.unique_tokens_ptrs = []
for req_idx in range(self.nreq):
offset = self.unique_tokens_offsets[req_idx]
length = self.unique_tokens_lens[req_idx]
if length > 0:
# Create pointer to the start of this request's tokens in the flat array
ptr = ctypes.cast(
ctypes.addressof(self.unique_tokens_c) + offset * ctypes.sizeof(c_uint),
POINTER(c_uint)
)
else:
ptr = None
self.unique_tokens_ptrs.append(ptr)
# Create array of pointers (use None for empty requests)
self.unique_tokens_ptrs_array = (POINTER(c_uint) * self.nreq)(*self.unique_tokens_ptrs)
else:
self.unique_tokens_c = None
# All requests have no previous tokens
self.unique_tokens_ptrs_array = (POINTER(c_uint) * self.nreq)(*[None] * self.nreq)

# Array of lengths per request
self.unique_tokens_lens_array = (c_uint * self.nreq)(*self.unique_tokens_lens)

# Convert to ctypes arrays in one pass
self.tokens = (c_uint * self.ntok)(*flat_tokens)
self.req_lens = (c_uint * self.nreq)(*self.req_lens_list)
@@ -408,6 +453,7 @@ def __init__(self, tasks: List[InferTask]):
self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list)
self.topks = (c_uint * self.nreq)(*self.topks_list)
self.topps = (c_float * self.nreq)(*self.topps_list)
self.repetition_penalties = (c_float * self.nreq)(*self.repetition_penalties_list)

def input_args(self):
return (
@@ -420,6 +466,9 @@ def input_args(self):
self.temperaturas,
self.topks,
self.topps,
self.repetition_penalties,
self.unique_tokens_ptrs_array, # Array of pointers to unique tokens per request
self.unique_tokens_lens_array, # Array of lengths per request
)
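Note: the pointer arithmetic in __init__ is the subtle part of this change, and the flat buffer (self.unique_tokens_c) must stay referenced for as long as the pointers are in use, since ctypes.cast does not keep its target alive. A self-contained sketch of the same flatten-then-point pattern (toy data, no model required):

import ctypes
from ctypes import POINTER, c_uint

# Two requests with unique tokens [3, 8] and [1, 4, 6], flattened back to back.
flat = (c_uint * 5)(3, 8, 1, 4, 6)
offsets, lens = [0, 2], [2, 3]

ptrs = (POINTER(c_uint) * 2)(*[
    ctypes.cast(ctypes.addressof(flat) + off * ctypes.sizeof(c_uint),
                POINTER(c_uint))
    for off in offsets
])

# Each pointer views its request's slice of the shared flat buffer.
for i in range(2):
    print([ptrs[i][j] for j in range(lens[i])])   # [3, 8] then [1, 4, 6]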


@@ -534,7 +583,7 @@ def load_all_safetensors_from_dir(dir_path_: str):
else:
raise ValueError("Unsupported model architecture")


if "llama" == config["model_type"]:
from tokenizers import decoders as _dec
backend = getattr(self.tokenizer, "backend_tokenizer", None)
@@ -593,9 +642,21 @@ def drop_kv_cache(self, kv_cache):
def batch_infer_one_round(self, tasks: List[InferTask]):
output = (c_uint * len(tasks))()
batch_inputs = JiugeBatchedTask(tasks)
args = batch_inputs.input_args()
self.jiuge_model.infer_batch(
self.model_instance,
args[0], # tokens
args[1], # ntok
args[2], # req_lens
args[3], # nreq
args[4], # req_pos
args[5], # kv_caches
args[6], # temperature
args[7], # topk
args[8], # topp
args[9], # repetition_penalty
args[10], # previous_tokens_per_req
args[11], # previous_tokens_len_per_req
output,
)
return list(output)
@@ -616,6 +677,7 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.
topk_,
topp_,
self.eos_token_id,
1.0, # repetition_penalty default
)
infer_task.bind_kvcache(KVCache(self))

@@ -648,7 +710,7 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.

def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
tasks = [
InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id, 1.0)
for i in range(batch_size)
]
kv_caches = [KVCache(self) for _ in range(batch_size)]