-
Notifications
You must be signed in to change notification settings - Fork 116
feat: consolidate to llguidance from xgrammar #1077
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| import pydantic | ||
|
|
||
| if TYPE_CHECKING: | ||
| import llguidance | ||
| from transformers import PreTrainedModel, PreTrainedTokenizerBase | ||
|
|
||
| # First Party | ||
|
|
@@ -112,11 +113,66 @@ def load_transformers_lora(local_or_remote_path: str) -> tuple: | |
| return model, tokenizer | ||
|
|
||
|
|
||
| # Modified from VLLM v0.9.2 code base | ||
| # https://github.com/vllm-project/vllm/blob/v0.9.2/vllm/model_executor/guided_decoding/guidance_logits_processors.py | ||
| class _GuidanceLogitsProcessor: | ||
| """A HuggingFace logits processor that enforces an llguidance grammar.""" | ||
|
|
||
| def __init__(self, grammar: str, ll_tokenizer: llguidance.LLTokenizer) -> None: | ||
| """Initialize the processor with a compiled grammar and an llguidance tokenizer.""" | ||
| self.grammar = grammar | ||
| self.vocab_size: int = ll_tokenizer.vocab_size | ||
| self.ll_tokenizer: llguidance.LLTokenizer = ll_tokenizer | ||
| self.ll_matchers: list[llguidance.LLMatcher] = [] | ||
| self.bitmasks: list = [] | ||
| self.new_sampling: bool = False | ||
| self.batch_size: int = -1 | ||
|
|
||
| def __call__(self, batch_input_ids, batch_scores): # type: ignore[no-untyped-def] | ||
| """Apply the grammar's allowed-token bitmask to ``batch_scores`` in place.""" | ||
| # Third Party | ||
| import llguidance | ||
| import llguidance.torch | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: these imports run on every generation step. |
||
|
|
||
| i_batch, _ = batch_input_ids.shape | ||
| s_batch, _ = batch_scores.shape | ||
| assert i_batch == s_batch | ||
|
|
||
| if self.batch_size != i_batch: | ||
| self.batch_size = i_batch | ||
| self.bitmasks = [ | ||
| llguidance.torch.allocate_token_bitmask(1, self.vocab_size) # type: ignore[attr-defined] | ||
| for _ in range(self.batch_size) | ||
| ] | ||
| self.ll_matchers = [ | ||
| llguidance.LLMatcher(self.ll_tokenizer, self.grammar) | ||
| for _ in range(self.batch_size) | ||
| ] | ||
|
|
||
| for input_ids, scores, ll_matcher, bitmask in zip( | ||
| batch_input_ids, batch_scores, self.ll_matchers, self.bitmasks | ||
| ): | ||
| if self.new_sampling and len(input_ids) > 0: | ||
| ll_matcher.consume_token(input_ids.tolist()[-1]) # type: ignore[attr-defined] | ||
| err = ll_matcher.get_error() # type: ignore[attr-defined] | ||
| if err: | ||
| logging.warning("Error in LLMatcher: %s", err) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The previous version in |
||
|
|
||
| llguidance.torch.fill_next_token_bitmask(ll_matcher, bitmask, 0) | ||
| llguidance.torch.apply_token_bitmask_inplace( # type: ignore[attr-defined] | ||
| scores, bitmask.to(scores.device) | ||
| ) | ||
|
|
||
| self.new_sampling = True | ||
| return batch_scores | ||
|
|
||
|
|
||
| def chat_completion_request_to_transformers_inputs( | ||
| request: dict, | ||
| tokenizer: PreTrainedTokenizerBase | None = None, | ||
| model: PreTrainedModel | None = None, | ||
| constrained_decoding_prefix: str | None = None, | ||
| ll_tokenizer: llguidance.LLTokenizer | None = None, | ||
| ) -> tuple[dict, dict]: | ||
| """Translate an OpenAI-style chat completion request. | ||
|
|
||
|
|
@@ -130,14 +186,17 @@ def chat_completion_request_to_transformers_inputs( | |
| model: HuggingFace model object. Only required if the request uses constrained | ||
| decoding. | ||
| constrained_decoding_prefix: Optional generation prefix to append to the prompt. | ||
| ll_tokenizer: Pre-built ``llguidance.LLTokenizer``. Only used when the request | ||
| uses constrained decoding; if not provided, one is constructed from | ||
| ``tokenizer``. Pass an existing instance to avoid the construction cost. | ||
|
|
||
| Returns: | ||
| Tuple of ``(generate_input, other_input)`` where ``generate_input`` contains | ||
| kwargs to pass directly to ``generate()`` and ``other_input`` contains | ||
| additional parameters for ``generate_with_transformers``. | ||
|
|
||
| Raises: | ||
| ImportError: If ``torch``, ``transformers``, or ``xgrammar`` packages | ||
| ImportError: If ``torch``, ``transformers``, or ``llguidance`` packages | ||
| are not installed (the latter only when constrained decoding is used). | ||
| TypeError: If ``tokenizer.apply_chat_template()`` returns an unexpected type. | ||
| ValueError: If padding or end-of-sequence token IDs cannot be determined | ||
|
|
@@ -191,7 +250,8 @@ def chat_completion_request_to_transformers_inputs( | |
|
|
||
| # generate() will fail with many different creative error messages if tokens aren't | ||
| # on the right device. | ||
| input_tokens = input_tokens.to(model.device) # type: ignore[union-attr] | ||
| if model is not None: | ||
| input_tokens = input_tokens.to(model.device) | ||
| generate_input["input_tokens"] = input_tokens | ||
|
|
||
| # The generate() method sometimes needs to know what is the integer ID | ||
|
|
@@ -234,9 +294,10 @@ def chat_completion_request_to_transformers_inputs( | |
| ): | ||
| # Constrained decoding in Hugging Face requires using a third-party library | ||
| # to create a callback function to be invoked from inside generate() | ||
| with import_optional("xgrammar"): | ||
| with import_optional("llguidance"): | ||
| # Third Party | ||
| import xgrammar as xgr # type: ignore[import-not-found] | ||
| import llguidance | ||
| import llguidance.hf | ||
| if tokenizer is None: | ||
| raise ValueError( | ||
| "Request specifies constrained decoding, but no " | ||
|
|
@@ -245,22 +306,19 @@ def chat_completion_request_to_transformers_inputs( | |
| if model is None: | ||
| raise ValueError( | ||
| "Request specifies constrained decoding, but no " | ||
| "tokenizer object was passed to this function." | ||
| "model object was passed to this function." | ||
| ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The xgrammar version needed |
||
|
|
||
| # Different parts of a Hugging Face model will have different opinions about | ||
| # the number of tokens in the tokenizer's vocabulary, because of course they do. | ||
| # Gather together all the possibilities and pick the biggest one. | ||
| vocab_size = max(tokenizer.vocab_size, len(tokenizer), model.vocab_size) | ||
| if ll_tokenizer is None: | ||
| ll_tokenizer = llguidance.hf.from_tokenizer(tokenizer) # type: ignore[arg-type] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The previous code did |
||
|
|
||
| tokenizer_info = xgr.TokenizerInfo.from_huggingface( | ||
| tokenizer, vocab_size=vocab_size | ||
| ) | ||
| grammar_compiler = xgr.GrammarCompiler(tokenizer_info) | ||
| compiled_grammar = grammar_compiler.compile_json_schema( | ||
| grammar = llguidance.LLMatcher.grammar_from_json_schema( | ||
| request["extra_body"]["structured_outputs"]["json"] | ||
| # NOTE: Mellea's structured output for hf sets `whitespace_flexible` to False. | ||
| # Not doing so here to match the previous granite formatter behavior. | ||
| # defaults={"whitespace_flexible": False}, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Either commit to the choice (delete the dead code) or pass |
||
| ) | ||
| logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar) | ||
| logits_processor = _GuidanceLogitsProcessor(grammar, ll_tokenizer) | ||
|
|
||
| # The "logits_processor" argument to generate() must be a list. | ||
| generate_input["logits_processor"] = [logits_processor] # type: ignore[assignment] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AGENTS.md §5 requires types on core functions. The previous version had
`(batch_input_ids: torch.Tensor, batch_scores: torch.Tensor) -> torch.Tensor`, and `bitmasks: list[torch.Tensor]` on line 127. You can keep the runtime imports lazy and still restore the annotations by adding `import torch` to the existing `if TYPE_CHECKING:` block at the top of the file.