-
Notifications
You must be signed in to change notification settings - Fork 325
feat: support invalid_token_ids in sampling params #1305
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
36 changes: 36 additions & 0 deletions
36
lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
|
|
||
@triton.jit
def _fwd_kernel_apply_invalid_token(
    Logits,  # logits buffer; row b starts at Logits + b * stride_logit_b
    invalid_token_ids,  # flat buffer of token ids to suppress, grouped per batch row
    cu_invalid_token_num,  # cumulative per-row counts, length batch + 1
    stride_logit_b,  # stride between consecutive batch rows of Logits
):
    # One program per batch row: overwrite this row's forbidden token logits
    # with -inf so those ids can never be selected during sampling.
    cur_batch = tl.program_id(0)
    # Row cur_batch owns invalid_token_ids[start_index:end_index].
    start_index = tl.load(cu_invalid_token_num + cur_batch)
    end_index = tl.load(cu_invalid_token_num + cur_batch + 1)
    # Sequential scalar loop; per-row id counts are expected to be small
    # (capped by INVALID_TOKEN_IDS_MAX_LENGTH on the host side — see caller).
    for i in range(start_index, end_index):
        cur_invalid_token_id = tl.load(invalid_token_ids + i)
        cur_logit_ptr = Logits + cur_batch * stride_logit_b + cur_invalid_token_id
        tl.store(cur_logit_ptr, float("-inf"))
    return
|
|
||
|
|
||
def apply_invalid_token_ids(
    Logits: torch.Tensor,
    invalid_token_ids: torch.Tensor,
    cu_invalid_token_num: torch.Tensor,
):
    """Mask forbidden token ids by writing -inf into each row of ``Logits`` in place.

    Args:
        Logits: (batch_size, vocab_size) logits tensor; modified in place.
        invalid_token_ids: flat 1-D tensor of token ids to suppress, grouped
            per batch row.
        cu_invalid_token_num: (batch_size + 1,) cumulative counts; row ``b``'s
            ids live in ``invalid_token_ids[cu_invalid_token_num[b]:cu_invalid_token_num[b + 1]]``.
    """
    batch_size = Logits.shape[0]
    # Fast path: no ids to mask anywhere — skip the kernel launch entirely.
    if invalid_token_ids.numel() == 0:
        return
    assert (
        cu_invalid_token_num.numel() == batch_size + 1
    ), "cu_invalid_token_num must have batch_size + 1 entries"
    # One kernel program per batch row.
    grid = (batch_size,)
    _fwd_kernel_apply_invalid_token[grid](
        Logits=Logits,
        invalid_token_ids=invalid_token_ids,
        cu_invalid_token_num=cu_invalid_token_num,
        stride_logit_b=Logits.stride(0),
    )
    return
File renamed without changes.
File renamed without changes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |||||||||||||||||||||||
# Fixed-capacity limits for the ctypes-backed sampling-param fields below;
# each can be overridden via its LIGHTLLM_* environment variable.
REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048))
JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048))
# Capacity of the fixed-size InvalidTokenIds array (ids masked to -inf at sampling).
INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10))
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class StopSequence(ctypes.Structure): | ||||||||||||||||||||||||
|
|
@@ -205,6 +206,25 @@ def to_list(self): | |||||||||||||||||||||||
| return list(self.ids[: self.size]) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
class InvalidTokenIds(ctypes.Structure):
    # Fixed-capacity C-layout container for token ids that must never be
    # generated; mirrors the AllowedTokenIds structure style.
    _pack_ = 4
    _fields_ = [
        ("ids", ctypes.c_int * INVALID_TOKEN_IDS_MAX_LENGTH),
        ("size", ctypes.c_int),
    ]

    def initialize(self, ids: List[int]):
        """Copy ``ids`` into the fixed array; rejects lists over the capacity cap."""
        self.size = len(ids)
        assert (
            self.size <= INVALID_TOKEN_IDS_MAX_LENGTH
        ), f"Too many invalid token IDs {self.size} > {INVALID_TOKEN_IDS_MAX_LENGTH}."
        for idx, token_id in enumerate(ids):
            self.ids[idx] = token_id
        return

    def to_list(self):
        """Return the stored ids as a plain Python list."""
        return [tid for tid in self.ids[: self.size]]
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class ExponentialDecayLengthPenalty(ctypes.Structure): | ||||||||||||||||||||||||
| _pack_ = 4 | ||||||||||||||||||||||||
| _fields_ = [ | ||||||||||||||||||||||||
|
|
@@ -304,6 +324,8 @@ class SamplingParams(ctypes.Structure): | |||||||||||||||||||||||
| # processor which only retains scores for the given token ids. Defaults to None. | ||||||||||||||||||||||||
| # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. | ||||||||||||||||||||||||
| ("allowed_token_ids", AllowedTokenIds), | ||||||||||||||||||||||||
| # if provided, the invalid token ids will be ignored during generation | ||||||||||||||||||||||||
| ("invalid_token_ids", InvalidTokenIds), | ||||||||||||||||||||||||
| ("stop_sequences", StopSequenceGroups), | ||||||||||||||||||||||||
| ("exponential_decay_length_penalty", ExponentialDecayLengthPenalty), | ||||||||||||||||||||||||
| ("group_request_id", ctypes.c_int64), # p d mode used params | ||||||||||||||||||||||||
|
|
@@ -394,6 +416,11 @@ def init(self, tokenizer, **kwargs): | |||||||||||||||||||||||
| self.allowed_token_ids = AllowedTokenIds() | ||||||||||||||||||||||||
| self.allowed_token_ids.initialize(allowed_token_ids) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Initialize invalid_token_ids | ||||||||||||||||||||||||
| invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys()) | ||||||||||||||||||||||||
| self.invalid_token_ids = InvalidTokenIds() | ||||||||||||||||||||||||
| self.invalid_token_ids.initialize(list[int](invalid_token_ids)) | ||||||||||||||||||||||||
|
Comment on lines
+420
to
+422
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This initialization logic has several critical issues: every `logit_bias` key is treated as an invalid token regardless of its bias value (a positive bias should not block a token), and `list[int](invalid_token_ids)` calls a generic alias as a constructor — use a plain `list(invalid_token_ids)` instead.
Suggested change
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if self.do_sample is False: | ||||||||||||||||||||||||
| self.temperature = 1.0 | ||||||||||||||||||||||||
| self.top_p = 1.0 | ||||||||||||||||||||||||
|
|
@@ -493,6 +520,7 @@ def to_dict(self): | |||||||||||||||||||||||
| "guided_grammar": self.guided_grammar.to_str(), | ||||||||||||||||||||||||
| "guided_json": self.guided_json.to_str(), | ||||||||||||||||||||||||
| "allowed_token_ids": self.allowed_token_ids.to_list(), | ||||||||||||||||||||||||
| "invalid_token_ids": self.invalid_token_ids.to_list(), | ||||||||||||||||||||||||
| "group_request_id": self.group_request_id, | ||||||||||||||||||||||||
| "move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(), | ||||||||||||||||||||||||
| "skip_special_tokens": self.skip_special_tokens, | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| """ | ||
| Smoke test for the invalid_token_ids feature (logit_bias path). | ||
|
|
||
| Hits the lightllm-native /generate endpoint, which forwards `logit_bias` keys | ||
| into the SamplingParams `invalid_token_ids` field. The kernel masks those | ||
| ids to -inf, so they must never appear in the output. | ||
|
|
||
| Run: | ||
| python test/test_api/test_invalid_token_ids.py | ||
|
|
||
| Assumes the server is up on http://localhost:8000 and the model tokenizer | ||
| is Qwen3.5 (matches the launch command in the PR description). | ||
| """ | ||
|
|
||
|
|
||
| import json | ||
| import sys | ||
| from typing import Dict, List, Tuple | ||
|
|
||
| import requests | ||
| from transformers import AutoTokenizer | ||
|
|
||
|
|
||
# Endpoint of the locally running lightllm server (see module docstring).
URL = "http://localhost:8000/generate"
HEADERS = {"Content-Type": "application/json"}
# Local tokenizer path; must match the model the server was launched with.
MODEL_DIR = "/nvme/models/Qwen3.5-35B-A3B"

# Words whose token ids get blocked via logit_bias.
# Stay under INVALID_TOKEN_IDS_MAX_LENGTH (default 10).
BLOCK_WORDS = ["the", " the", "The", " is", " a", " of", " and"]
|
|
||
|
|
||
def _post_generate(prompt: str, parameters: dict, timeout: int = 120) -> dict:
    """POST one /generate request and return the decoded JSON response."""
    body = json.dumps({"inputs": prompt, "parameters": parameters})
    resp = requests.post(URL, headers=HEADERS, data=body, timeout=timeout)
    if resp.status_code == 200:
        return resp.json()
    raise RuntimeError(f"{resp.status_code} {resp.text}")
|
|
||
|
|
||
| def _generated_text(resp: dict) -> str: | ||
| text = resp["generated_text"] | ||
| return text[0] if isinstance(text, list) else text | ||
|
|
||
|
|
||
| def _token_ids_from_details(resp: dict) -> List[int]: | ||
| tokens = resp.get("tokens", []) | ||
| if tokens and isinstance(tokens[0], list): | ||
| tokens = tokens[0] | ||
| out: List[int] = [] | ||
| for tok in tokens: | ||
| tid = tok.get("id") | ||
| if tid is not None: | ||
| out.append(int(tid)) | ||
| return out | ||
|
|
||
|
|
||
def _build_block_map(tokenizer) -> Tuple[Dict[int, float], Dict[int, str]]:
    """Map token id -> bias (-100 = block) and id -> source word."""
    bias: Dict[int, float] = {}
    source: Dict[int, str] = {}
    for word in BLOCK_WORDS:
        for token_id in tokenizer.encode(word, add_special_tokens=False):
            # First word to produce a given id wins, matching setdefault semantics.
            if token_id not in bias:
                bias[token_id] = -100.0
                source[token_id] = word
    return bias, source
|
|
||
|
|
||
def test_invalid_token_ids():
    """End-to-end smoke test: tokens blocked via logit_bias must never be generated.

    Runs two greedy requests against the live server — one baseline, one with
    every BLOCK_WORDS token id biased to -100 — then checks the blocked ids
    never appear in the biased output. Exits with status 1 on failure.
    """
    print("[1/3] Loading tokenizer...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)

    bias_map, source_map = _build_block_map(tokenizer)
    blocked_ids = sorted(bias_map.keys())
    print(f" Blocking {len(blocked_ids)} token ids: {blocked_ids}")
    for tid in blocked_ids:
        print(f" {tid:6d} <- {source_map[tid]!r}")

    # Prompt chosen to make the blocked common words ("the", "is", ...) very
    # likely under greedy decoding, so the baseline should emit them.
    prompt = "Write three short English sentences about San Francisco. " "Mention the bay, the bridge and the weather."
    base_params = {
        # Greedy decoding keeps both runs deterministic and comparable.
        "do_sample": False,
        "temperature": 1.0,
        "max_new_tokens": 80,
        "return_details": True,
    }

    print("[2/3] Baseline request (no logit_bias)...", flush=True)
    base_resp = _post_generate(prompt, dict(base_params))
    base_text = _generated_text(base_resp)
    base_ids = _token_ids_from_details(base_resp)
    print(f" text: {base_text!r}")
    base_hits = [tid for tid in base_ids if tid in bias_map]
    print(f" blocked-tokens that appeared in baseline: {len(base_hits)} ({base_hits[:10]})")

    print("[3/3] logit_bias request...", flush=True)
    bias_params = dict(base_params)
    # JSON object keys must be strings; the server maps them back to ints.
    bias_params["logit_bias"] = {str(k): v for k, v in bias_map.items()}
    biased_resp = _post_generate(prompt, bias_params)
    biased_text = _generated_text(biased_resp)
    biased_ids = _token_ids_from_details(biased_resp)
    print(f" text: {biased_text!r}")
    biased_hits = [(tid, source_map[tid]) for tid in biased_ids if tid in bias_map]
    print(f" blocked-tokens that appeared with bias: {len(biased_hits)} ({biased_hits[:10]})")

    failures = []
    if biased_hits:
        failures.append(f"Blocked token ids leaked into biased output: {biased_hits}")

    # Sanity check: the baseline should have produced at least one of the blocked tokens.
    # If it did not, the test is uninformative (but still passes the strict check above).
    if not base_hits:
        print(
            " WARNING: baseline did not produce any of the target tokens; "
            "the assertion below is trivially satisfied."
        )

    if biased_text == base_text:
        failures.append("Biased output is identical to baseline; bias may not be applied.")

    if failures:
        for f in failures:
            print(f"FAIL: {f}", file=sys.stderr)
        sys.exit(1)

    print("PASS: invalid_token_ids correctly suppressed blocked tokens.")
|
|
||
|
|
||
if __name__ == "__main__":
    # Manual smoke test: requires a running lightllm server (see module docstring).
    test_invalid_token_ids()
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The limit of 10 invalid token IDs is quite restrictive for many use cases. Consider increasing
`INVALID_TOKEN_IDS_MAX_LENGTH` to a more reasonable value (e.g., 64 or 128) to provide more flexibility for users who need to mask a larger set of tokens.