@@ -0,0 +1,36 @@
import torch
import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_apply_invalid_token(
Logits,
invalid_token_ids,
cu_invalid_token_num,
stride_logit_b,
):
cur_batch = tl.program_id(0)
start_index = tl.load(cu_invalid_token_num + cur_batch)
end_index = tl.load(cu_invalid_token_num + cur_batch + 1)
for i in range(start_index, end_index):
cur_invalid_token_id = tl.load(invalid_token_ids + i)
cur_logit_ptr = Logits + cur_batch * stride_logit_b + cur_invalid_token_id
tl.store(cur_logit_ptr, float("-inf"))
return


def apply_invalid_token_ids(
Logits: torch.Tensor,
invalid_token_ids: torch.Tensor,
cu_invalid_token_num: torch.Tensor,
):
batch_size = Logits.shape[0]
grid = (batch_size,)
_fwd_kernel_apply_invalid_token[grid](
Logits=Logits,
invalid_token_ids=invalid_token_ids,
cu_invalid_token_num=cu_invalid_token_num,
stride_logit_b=Logits.stride(0),
)
return
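
The wrapper above expects a flattened, CSR-style layout: invalid_token_ids concatenates the invalid ids of every request in the batch, and cu_invalid_token_num holds batch_size + 1 prefix offsets into that buffer. A minimal pure-PyTorch sketch of the same masking, useful as a CPU reference; the helper name and the toy tensors below are illustrative and not part of this PR:

import torch

def apply_invalid_token_ids_reference(
    logits: torch.Tensor,                # [batch_size, vocab_size]
    invalid_token_ids: torch.Tensor,     # flattened ids of all requests
    cu_invalid_token_num: torch.Tensor,  # [batch_size + 1] prefix offsets
) -> None:
    # Mirror the Triton kernel: per request, set the logits of its invalid
    # token ids to -inf so those tokens can never be sampled.
    for b in range(logits.shape[0]):
        start = int(cu_invalid_token_num[b])
        end = int(cu_invalid_token_num[b + 1])
        ids = invalid_token_ids[start:end].long()
        logits[b, ids] = float("-inf")

# Batch of 2: request 0 bans ids {3, 7}, request 1 bans {0}.
logits = torch.zeros(2, 10)
apply_invalid_token_ids_reference(
    logits,
    invalid_token_ids=torch.tensor([3, 7, 0], dtype=torch.int32),
    cu_invalid_token_num=torch.tensor([0, 2, 3], dtype=torch.int32),
)
assert torch.isinf(logits[0, 3]) and torch.isinf(logits[1, 0])
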
4 changes: 4 additions & 0 deletions lightllm/server/core/objs/py_sampling_params.py
@@ -54,6 +54,8 @@ def __init__(
# processor which only retains scores for the given token ids. Defaults to None.
# allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
allowed_token_ids: Optional[List[int]] = None,
# if provided, the invalid token ids will be suppressed (masked to -inf) during generation
invalid_token_ids: Optional[List[int]] = None,
# p d mode used params
group_request_id: Optional[int] = None,
# move kv to deocde node, only used in pd mode
@@ -89,6 +91,7 @@ def __init__(
self.guided_grammar = guided_grammar
self.guided_json = guided_json
self.allowed_token_ids = allowed_token_ids
self.invalid_token_ids = invalid_token_ids
self.group_request_id = group_request_id
self.move_kv_to_decode_node = move_kv_to_decode_node
self.suggested_dp_index = suggested_dp_index
@@ -269,6 +272,7 @@ def to_dict(self):
ret["guided_grammar"] = self.guided_grammar
ret["guided_json"] = self.guided_json
ret["allowed_token_ids"] = self.allowed_token_ids
ret["invalid_token_ids"] = self.invalid_token_ids
ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node
ret["seed"] = self.seed
return ret
28 changes: 28 additions & 0 deletions lightllm/server/core/objs/sampling_params.py
@@ -17,6 +17,7 @@
REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048))
JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048))
INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10))
Contributor (review comment, severity: medium):

The limit of 10 invalid token IDs is quite restrictive for many use cases. Consider increasing INVALID_TOKEN_IDS_MAX_LENGTH to a more reasonable value (e.g., 64 or 128) to provide more flexibility for users who need to mask a larger set of tokens.
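
Since the cap is read from the environment at import time, it can be raised without code changes; a minimal sketch (the value 128 is only an example):

import os
# Must be set before lightllm.server.core.objs.sampling_params is imported,
# because INVALID_TOKEN_IDS_MAX_LENGTH is evaluated at module import time.
os.environ["LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH"] = "128"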



class StopSequence(ctypes.Structure):
@@ -205,6 +206,25 @@ def to_list(self):
return list(self.ids[: self.size])


class InvalidTokenIds(ctypes.Structure):
_pack_ = 4
_fields_ = [
("ids", ctypes.c_int * INVALID_TOKEN_IDS_MAX_LENGTH),
("size", ctypes.c_int),
]

def initialize(self, ids: List[int]):
self.size = len(ids)
assert (
self.size <= INVALID_TOKEN_IDS_MAX_LENGTH
), f"Too many invalid token IDs {self.size} > {INVALID_TOKEN_IDS_MAX_LENGTH}."
self.ids[: self.size] = ids[:]
return

def to_list(self):
return list(self.ids[: self.size])
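
A quick sketch of how this struct behaves (the ids are arbitrary examples):

ids = InvalidTokenIds()
ids.initialize([13, 402, 11])
assert ids.to_list() == [13, 402, 11]
# initialize() asserts if more than INVALID_TOKEN_IDS_MAX_LENGTH ids are passed,
# so callers are expected to trim or deduplicate the list beforehand.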


class ExponentialDecayLengthPenalty(ctypes.Structure):
_pack_ = 4
_fields_ = [
@@ -304,6 +324,8 @@ class SamplingParams(ctypes.Structure):
# processor which only retains scores for the given token ids. Defaults to None.
# allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
("allowed_token_ids", AllowedTokenIds),
# if provided, the invalid token ids will be suppressed (masked to -inf) during generation
("invalid_token_ids", InvalidTokenIds),
("stop_sequences", StopSequenceGroups),
("exponential_decay_length_penalty", ExponentialDecayLengthPenalty),
("group_request_id", ctypes.c_int64), # p d mode used params
@@ -394,6 +416,11 @@ def init(self, tokenizer, **kwargs):
self.allowed_token_ids = AllowedTokenIds()
self.allowed_token_ids.initialize(allowed_token_ids)

# Initialize invalid_token_ids
invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys())
self.invalid_token_ids = InvalidTokenIds()
self.invalid_token_ids.initialize(list[int](invalid_token_ids))
Comment on lines +420 to +422
Contributor (review comment, severity: critical):

This initialization logic has several critical issues:

  1. Fragile construction: list[int](invalid_token_ids) calls the list constructor through a GenericAlias subscription; it happens to work on CPython 3.9+, but the subscript does nothing at runtime and the idiomatic form is list(invalid_token_ids).
  2. Logic Error: The code populates invalid_token_ids by masking all keys present in logit_bias to -inf. However, logit_bias is typically used for both boosting (positive values) and suppressing (negative values) tokens. This implementation incorrectly invalidates tokens that the user intended to boost.
  3. Field Inconsistency: The invalid_token_ids field added to py_sampling_params.py is ignored here, as the code only looks at logit_bias keys.
Suggested change
invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys())
self.invalid_token_ids = InvalidTokenIds()
self.invalid_token_ids.initialize(list[int](invalid_token_ids))
# Initialize invalid_token_ids from the dedicated field or suppressed logit_bias
invalid_token_ids = kwargs.get("invalid_token_ids") or []
if not invalid_token_ids:
logit_bias = kwargs.get("logit_bias") or {}
invalid_token_ids = [int(k) for k, v in logit_bias.items() if v <= -100]
self.invalid_token_ids = InvalidTokenIds()
self.invalid_token_ids.initialize(list(invalid_token_ids))
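
To illustrate the semantics this suggestion is aiming for (the token ids below are hypothetical), only keys with a strongly negative bias end up in invalid_token_ids:

logit_bias = {"1234": -100.0, "5678": 5.0}  # 1234 suppressed, 5678 boosted
invalid_token_ids = [int(k) for k, v in logit_bias.items() if v <= -100]
assert invalid_token_ids == [1234]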


if self.do_sample is False:
self.temperature = 1.0
self.top_p = 1.0
@@ -493,6 +520,7 @@ def to_dict(self):
"guided_grammar": self.guided_grammar.to_str(),
"guided_json": self.guided_json.to_str(),
"allowed_token_ids": self.allowed_token_ids.to_list(),
"invalid_token_ids": self.invalid_token_ids.to_list(),
"group_request_id": self.group_request_id,
"move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(),
"skip_special_tokens": self.skip_special_tokens,
8 changes: 8 additions & 0 deletions lightllm/server/router/model_infer/infer_batch.py
@@ -444,6 +444,9 @@ def __init__(
if len(self.allowed_token_ids) == 0:
self.allowed_token_ids = None

# if provided, invalid_token_ids are masked to -inf during sampling (see generic_post_process.sample)
self.invalid_token_ids = self.shm_param.invalid_token_ids.to_list()

# p d mode use params
if self.shm_param.move_kv_to_decode_node.exists:
self.move_kv_to_decode_node = self.shm_param.move_kv_to_decode_node.to_dict()
@@ -456,6 +459,11 @@ def __init__(
logger.error("allowed_token_ids contain tokenid >= vobsize, we remove these token ids")
self.allowed_token_ids = [e for e in self.allowed_token_ids if e < vocab_size]

if len(self.invalid_token_ids) > 0:
if not all(e < vocab_size for e in self.invalid_token_ids):
logger.error("invalid_token_ids contain tokenid >= vobsize, we remove these token ids")
self.invalid_token_ids = [e for e in self.invalid_token_ids if e < vocab_size]

# nixl decode node information
if self.shm_param.nixl_params.data_len > 0:
self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get())
@@ -1,7 +1,8 @@
import torch
from typing import List, Tuple
from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty
from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache
from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty import apply_penalty
from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty_gpu_cache import apply_penalty_gpu_cache
from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import apply_invalid_token_ids
from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager
from lightllm.utils.envs_utils import get_env_start_args
@@ -15,7 +16,10 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
b_top_ks,
b_length_penalty_param,
b_mask_eos_reqs,
invalid_token_ids,
cu_invalid_token_num,
is_all_greedy,
has_invalid_token_ids,
skip_top_k,
skip_top_p,
exist_req_use_random_seed,
@@ -63,6 +67,14 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
eos_ids=eos_ids,
sampling_params_manager=sampling_params_manager,
)

if has_invalid_token_ids:
apply_invalid_token_ids(
Logits=logits,
invalid_token_ids=invalid_token_ids,
cu_invalid_token_num=cu_invalid_token_num,
)

logits.div_(b_temperatures.view((-1, 1)))
probs = torch.softmax(logits, dim=-1)

@@ -152,6 +164,12 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
skip_top_p = True
exist_req_use_random_seed = False

# invalid token ids
invalid_token_ids: List[int] = []
has_invalid_token_ids = False
cu_invalid_token_num = [0]
invalid_token_num_start = 0

for i, req_obj in enumerate(reqs):
sample_param = req_obj.sampling_param
shm_param = sample_param.shm_param
@@ -173,6 +191,11 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
if req_obj.generator is not None:
exist_req_use_random_seed = True
req_idxes.append(req_obj.req_idx)
invalid_token_num_start += len(req_obj.sampling_param.invalid_token_ids)
cu_invalid_token_num.append(invalid_token_num_start)
if len(req_obj.sampling_param.invalid_token_ids) > 0:
has_invalid_token_ids = True
invalid_token_ids.extend(req_obj.sampling_param.invalid_token_ids)

req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32)
temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32)
@@ -183,14 +206,25 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
)
mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool)

if has_invalid_token_ids:
invalid_token_ids_cpu = g_pin_mem_manager.gen_from_list(
key="invalid_token_ids", data=invalid_token_ids, dtype=torch.int32
)
cu_invalid_token_num_cpu = g_pin_mem_manager.gen_from_list(
key="cu_invalid_token_num", data=cu_invalid_token_num, dtype=torch.int32
)

return (
req_idxes_cpu.cuda(non_blocking=True),
temperatures_cpu.cuda(non_blocking=True),
top_ps_cpu.cuda(non_blocking=True),
top_ks_cpu.cuda(non_blocking=True),
length_penalty_param_cpu.cuda(non_blocking=True),
mask_eos_reqs_cpu.cuda(non_blocking=True),
invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None,
cu_invalid_token_num_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None,
is_all_greedy,
has_invalid_token_ids,
skip_top_k,
skip_top_p,
exist_req_use_random_seed,
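
For reference, _get_post_sample_tensors flattens the per-request invalid_token_ids lists into one id buffer plus a prefix sum of counts, which is exactly the layout the masking kernel consumes. A small worked example (the request contents are illustrative):

per_request = [[5, 9], [], [2]]  # invalid_token_ids of three requests
invalid_token_ids = []
cu_invalid_token_num = [0]
for ids in per_request:
    invalid_token_ids.extend(ids)
    cu_invalid_token_num.append(cu_invalid_token_num[-1] + len(ids))
assert invalid_token_ids == [5, 9, 2]
assert cu_invalid_token_num == [0, 2, 2, 3]
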
129 changes: 129 additions & 0 deletions test/test_api/test_invalid_token_ids.py
@@ -0,0 +1,129 @@
"""
Smoke test for the invalid_token_ids feature (logit_bias path).

Hits the lightllm-native /generate endpoint, which forwards `logit_bias` keys
into the SamplingParams `invalid_token_ids` field. The kernel masks those
ids to -inf, so they must never appear in the output.

Run:
python test/test_api/test_invalid_token_ids.py

Assumes the server is up on http://localhost:8000 and the model tokenizer
is Qwen3.5 (matches the launch command in the PR description).
"""


import json
import sys
from typing import Dict, List, Tuple

import requests
from transformers import AutoTokenizer


URL = "http://localhost:8000/generate"
HEADERS = {"Content-Type": "application/json"}
MODEL_DIR = "/nvme/models/Qwen3.5-35B-A3B"

# Stay under INVALID_TOKEN_IDS_MAX_LENGTH (default 10).
BLOCK_WORDS = ["the", " the", "The", " is", " a", " of", " and"]


def _post_generate(prompt: str, parameters: dict, timeout: int = 120) -> dict:
payload = {"inputs": prompt, "parameters": parameters}
resp = requests.post(URL, headers=HEADERS, data=json.dumps(payload), timeout=timeout)
if resp.status_code != 200:
raise RuntimeError(f"{resp.status_code} {resp.text}")
return resp.json()


def _generated_text(resp: dict) -> str:
text = resp["generated_text"]
return text[0] if isinstance(text, list) else text


def _token_ids_from_details(resp: dict) -> List[int]:
tokens = resp.get("tokens", [])
if tokens and isinstance(tokens[0], list):
tokens = tokens[0]
out: List[int] = []
for tok in tokens:
tid = tok.get("id")
if tid is not None:
out.append(int(tid))
return out


def _build_block_map(tokenizer) -> Tuple[Dict[int, float], Dict[int, str]]:
"""Map token id -> bias (-100 = block) and id -> source word."""
bias: Dict[int, float] = {}
source: Dict[int, str] = {}
for w in BLOCK_WORDS:
ids = tokenizer.encode(w, add_special_tokens=False)
for tid in ids:
bias.setdefault(tid, -100.0)
source.setdefault(tid, w)
return bias, source


def test_invalid_token_ids():
print("[1/3] Loading tokenizer...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)

bias_map, source_map = _build_block_map(tokenizer)
blocked_ids = sorted(bias_map.keys())
print(f" Blocking {len(blocked_ids)} token ids: {blocked_ids}")
for tid in blocked_ids:
print(f" {tid:6d} <- {source_map[tid]!r}")

prompt = "Write three short English sentences about San Francisco. " "Mention the bay, the bridge and the weather."
base_params = {
"do_sample": False,
"temperature": 1.0,
"max_new_tokens": 80,
"return_details": True,
}

print("[2/3] Baseline request (no logit_bias)...", flush=True)
base_resp = _post_generate(prompt, dict(base_params))
base_text = _generated_text(base_resp)
base_ids = _token_ids_from_details(base_resp)
print(f" text: {base_text!r}")
base_hits = [tid for tid in base_ids if tid in bias_map]
print(f" blocked-tokens that appeared in baseline: {len(base_hits)} ({base_hits[:10]})")

print("[3/3] logit_bias request...", flush=True)
bias_params = dict(base_params)
bias_params["logit_bias"] = {str(k): v for k, v in bias_map.items()}
biased_resp = _post_generate(prompt, bias_params)
biased_text = _generated_text(biased_resp)
biased_ids = _token_ids_from_details(biased_resp)
print(f" text: {biased_text!r}")
biased_hits = [(tid, source_map[tid]) for tid in biased_ids if tid in bias_map]
print(f" blocked-tokens that appeared with bias: {len(biased_hits)} ({biased_hits[:10]})")

failures = []
if biased_hits:
failures.append(f"Blocked token ids leaked into biased output: {biased_hits}")

# Sanity check: the baseline should have produced at least one of the blocked tokens.
# If it did not, the test is uninformative (but still passes the strict check above).
if not base_hits:
print(
" WARNING: baseline did not produce any of the target tokens; "
"the assertion below is trivially satisfied."
)

if biased_text == base_text:
failures.append("Biased output is identical to baseline; bias may not be applied.")

if failures:
for f in failures:
print(f"FAIL: {f}", file=sys.stderr)
sys.exit(1)

print("PASS: invalid_token_ids correctly suppressed blocked tokens.")


if __name__ == "__main__":
test_invalid_token_ids()