Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 5 additions & 60 deletions mellea/backends/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from ..core.base import AbstractMelleaTool
from ..formatters import ChatFormatter, TemplateFormatter, granite as granite_formatters
from ..formatters.granite.base.util import _GuidanceLogitsProcessor
from ..helpers import message_to_openai_message, messages_to_docs, send_to_queue
from ..stdlib.components import Intrinsic, Message
from ..stdlib.requirements import ALoraRequirement, LLMaJRequirement
Expand Down Expand Up @@ -159,65 +160,6 @@ def _cleanup_kv_cache(cache_info: HFAloraCacheInfo) -> None:
torch.cuda.empty_cache()


# modified from VLLM v0.9.2 code base
# https://github.com/vllm-project/vllm/blob/v0.9.2/vllm/model_executor/guided_decoding/guidance_logits_processors.py
class _GuidanceLogitsProcessor:
    """Logits processor that restricts next-token scores to an llguidance grammar.

    The processor is invoked once per decoding step with the whole batch. It
    keeps one ``llguidance.LLMatcher`` and one token bitmask per batch row and
    rebuilds both whenever the batch size differs from the previous call.
    """

    def __init__(self, grammar: str, ll_tokenizer: llguidance.LLTokenizer) -> None:
        """Store the grammar and tokenizer; per-row state is created lazily.

        Args:
            grammar: Grammar handed to every ``llguidance.LLMatcher``.
            ll_tokenizer: llguidance tokenizer; its ``vocab_size`` sizes the
                token bitmasks.
        """
        self.grammar = grammar
        self.vocab_size: int = ll_tokenizer.vocab_size
        self.ll_tokenizer: llguidance.LLTokenizer = ll_tokenizer
        # Per-row matchers and bitmasks; filled in by the first ``__call__``.
        self.ll_matchers: list[llguidance.LLMatcher] = []
        self.bitmasks: list[torch.Tensor] = []
        # Stays False until the first ``__call__`` finishes, so that call does
        # not feed a token to the matchers (presumably because the ids seen at
        # that point are prompt tokens -- confirm against the caller).
        self.new_sampling: bool = False
        # -1 can never equal a real batch size, forcing allocation on first use.
        self.batch_size: int = -1

    def __call__(
        self, batch_input_ids: torch.Tensor, batch_scores: torch.Tensor
    ) -> torch.Tensor:
        """Apply the grammar's allowed-token bitmask to ``batch_scores`` in place.

        Args:
            batch_input_ids: 2-D tensor of token ids, one row per sequence.
            batch_scores: 2-D tensor of next-token scores, one row per sequence.

        Returns:
            The same ``batch_scores`` tensor after masking.

        Raises:
            AssertionError: If the two inputs disagree on batch size.
        """
        i_batch, _ = batch_input_ids.shape
        s_batch, _ = batch_scores.shape
        assert i_batch == s_batch

        # s_batch, s_vocab = batch_scores.shape
        # assert s_vocab == self.vocab_size
        #
        # NOTE: somehow, this does not hold. s_vocab is not same as either of
        # * self._tokenizer._tokenizer.get_vocab_size(with_added_tokens=True) == self.vocab_size == ll_tokenizer.vocab_size
        # * self._tokenizer._tokenizer.get_vocab_size(with_added_tokens=False)

        if self.batch_size != i_batch:
            # Batch size changed (or first call): rebuild all per-row state.
            self.batch_size = i_batch
            self.bitmasks = [
                llguidance.torch.allocate_token_bitmask(1, self.vocab_size)  # type: ignore[attr-defined]
                for _ in range(self.batch_size)
            ]

            self.ll_matchers = [
                llguidance.LLMatcher(self.ll_tokenizer, self.grammar)
                for _ in range(self.batch_size)
            ]

        for input_ids, scores, ll_matcher, bitmask in zip(
            batch_input_ids, batch_scores, self.ll_matchers, self.bitmasks
        ):
            if self.new_sampling and len(input_ids) > 0:
                # Read only the newest token; converting the whole row with
                # ``tolist()`` on every step costs time linear in its length.
                ll_matcher.consume_token(int(input_ids[-1]))  # type: ignore[attr-defined]
                err = ll_matcher.get_error()  # type: ignore[attr-defined]
                if err:
                    MelleaLogger.get_logger().warning("Error in LLMatcher: %s", err)

            llguidance.torch.fill_next_token_bitmask(ll_matcher, bitmask, 0)
            llguidance.torch.apply_token_bitmask_inplace(  # type: ignore[attr-defined]
                scores, bitmask.to(scores.device)
            )

        self.new_sampling = True

        return batch_scores


class LocalHFBackend(FormatterBackend, AdapterMixin):
"""The LocalHFBackend uses Huggingface's transformers library for inference, and uses a Formatter to convert `Component`s into prompts. This backend also supports Activated LoRAs (ALoras)](https://arxiv.org/pdf/2504.12397).

Expand Down Expand Up @@ -620,7 +562,10 @@ async def _generate_from_intrinsic(

generate_input, other_input = (
granite_formatters.base.util.chat_completion_request_to_transformers_inputs( # type: ignore
rewritten, self._tokenizer, self._model
rewritten,
self._tokenizer,
self._model,
ll_tokenizer=self._llguidance_tokenizer,
)
)

Expand Down
88 changes: 73 additions & 15 deletions mellea/formatters/granite/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pydantic

if TYPE_CHECKING:
import llguidance
from transformers import PreTrainedModel, PreTrainedTokenizerBase

# First Party
Expand Down Expand Up @@ -112,11 +113,66 @@ def load_transformers_lora(local_or_remote_path: str) -> tuple:
return model, tokenizer


# Modified from VLLM v0.9.2 code base
# https://github.com/vllm-project/vllm/blob/v0.9.2/vllm/model_executor/guided_decoding/guidance_logits_processors.py
class _GuidanceLogitsProcessor:
"""A HuggingFace logits processor that enforces an llguidance grammar."""

def __init__(self, grammar: str, ll_tokenizer: llguidance.LLTokenizer) -> None:
    """Initialize the processor with a grammar string and an llguidance tokenizer.

    Only stores its arguments and empty placeholders; the per-row matchers
    and bitmasks are allocated by ``__call__`` once the batch size is known.

    Args:
        grammar: Grammar passed to each ``llguidance.LLMatcher``.
        ll_tokenizer: llguidance tokenizer; its ``vocab_size`` is cached and
            used to size the token bitmasks.
    """
    self.grammar = grammar
    self.vocab_size: int = ll_tokenizer.vocab_size
    self.ll_tokenizer: llguidance.LLTokenizer = ll_tokenizer
    # One matcher and one bitmask per batch row, built lazily in ``__call__``.
    self.ll_matchers: list[llguidance.LLMatcher] = []
    self.bitmasks: list = []
    # Set to True at the end of the first ``__call__``; until then no token
    # is fed to the matchers.
    self.new_sampling: bool = False
    # -1 never matches a real batch size, so the first call always allocates.
    self.batch_size: int = -1

def __call__(self, batch_input_ids, batch_scores): # type: ignore[no-untyped-def]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AGENTS.md §5 requires types on core functions. The previous version had (batch_input_ids: torch.Tensor, batch_scores: torch.Tensor) -> torch.Tensor, and bitmasks: list[torch.Tensor] on line 127. You can keep the runtime imports lazy and still restore the annotations by adding import torch to the existing if TYPE_CHECKING: block at the top of the file.

"""Apply the grammar's allowed-token bitmask to ``batch_scores`` in place."""
# Third Party
import llguidance
import llguidance.torch
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: these run on every generation step. sys.modules caches it so cost is small, but hoisting into __init__ (or module-level lazy) would be cleaner — you can't reach __call__ without going through the constructor.


i_batch, _ = batch_input_ids.shape
s_batch, _ = batch_scores.shape
assert i_batch == s_batch

if self.batch_size != i_batch:
self.batch_size = i_batch
self.bitmasks = [
llguidance.torch.allocate_token_bitmask(1, self.vocab_size) # type: ignore[attr-defined]
for _ in range(self.batch_size)
]
self.ll_matchers = [
llguidance.LLMatcher(self.ll_tokenizer, self.grammar)
for _ in range(self.batch_size)
]

for input_ids, scores, ll_matcher, bitmask in zip(
batch_input_ids, batch_scores, self.ll_matchers, self.bitmasks
):
if self.new_sampling and len(input_ids) > 0:
ll_matcher.consume_token(input_ids.tolist()[-1]) # type: ignore[attr-defined]
err = ll_matcher.get_error() # type: ignore[attr-defined]
if err:
logging.warning("Error in LLMatcher: %s", err)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous version in huggingface.py used MelleaLogger.get_logger().warning(...). Suggest keeping that here — the rest of the HF backend standardizes on it (~15 callsites), and users who configure the mellea logger (level, handlers) won't see warnings routed through the root logger otherwise.


llguidance.torch.fill_next_token_bitmask(ll_matcher, bitmask, 0)
llguidance.torch.apply_token_bitmask_inplace( # type: ignore[attr-defined]
scores, bitmask.to(scores.device)
)

self.new_sampling = True
return batch_scores


def chat_completion_request_to_transformers_inputs(
request: dict,
tokenizer: PreTrainedTokenizerBase | None = None,
model: PreTrainedModel | None = None,
constrained_decoding_prefix: str | None = None,
ll_tokenizer: llguidance.LLTokenizer | None = None,
) -> tuple[dict, dict]:
"""Translate an OpenAI-style chat completion request.

Expand All @@ -130,14 +186,17 @@ def chat_completion_request_to_transformers_inputs(
model: HuggingFace model object. Only required if the request uses constrained
decoding.
constrained_decoding_prefix: Optional generation prefix to append to the prompt.
ll_tokenizer: Pre-built ``llguidance.LLTokenizer``. Only used when the request
uses constrained decoding; if not provided, one is constructed from
``tokenizer``. Pass an existing instance to avoid the construction cost.

Returns:
Tuple of ``(generate_input, other_input)`` where ``generate_input`` contains
kwargs to pass directly to ``generate()`` and ``other_input`` contains
additional parameters for ``generate_with_transformers``.

Raises:
ImportError: If ``torch``, ``transformers``, or ``xgrammar`` packages
ImportError: If ``torch``, ``transformers``, or ``llguidance`` packages
are not installed (the latter only when constrained decoding is used).
TypeError: If ``tokenizer.apply_chat_template()`` returns an unexpected type.
ValueError: If padding or end-of-sequence token IDs cannot be determined
Expand Down Expand Up @@ -191,7 +250,8 @@ def chat_completion_request_to_transformers_inputs(

# generate() will fail with many different creative error messages if tokens aren't
# on the right device.
input_tokens = input_tokens.to(model.device) # type: ignore[union-attr]
if model is not None:
input_tokens = input_tokens.to(model.device)
generate_input["input_tokens"] = input_tokens

# The generate() method sometimes needs to know what is the integer ID
Expand Down Expand Up @@ -234,9 +294,10 @@ def chat_completion_request_to_transformers_inputs(
):
# Constrained decoding in Hugging Face requires using a third-party library
# to create a callback function to be invoked from inside generate()
with import_optional("xgrammar"):
with import_optional("llguidance"):
# Third Party
import xgrammar as xgr # type: ignore[import-not-found]
import llguidance
import llguidance.hf
if tokenizer is None:
raise ValueError(
"Request specifies constrained decoding, but no "
Expand All @@ -245,22 +306,19 @@ def chat_completion_request_to_transformers_inputs(
if model is None:
raise ValueError(
"Request specifies constrained decoding, but no "
"tokenizer object was passed to this function."
"model object was passed to this function."
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The xgrammar version needed model for model.vocab_size in the max(...) reconciliation. The llguidance path doesn't appear to use model anywhere in this branch, so this check (and the model argument requirement for constrained decoding generally) can probably be dropped.


# Different parts of a Hugging Face model will have different opinions about
# the number of tokens in the tokenizer's vocabulary, because of course they do.
# Gather together all the possibilities and pick the biggest one.
vocab_size = max(tokenizer.vocab_size, len(tokenizer), model.vocab_size)
if ll_tokenizer is None:
ll_tokenizer = llguidance.hf.from_tokenizer(tokenizer) # type: ignore[arg-type]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous code did vocab_size = max(tokenizer.vocab_size, len(tokenizer), model.vocab_size) to defend against HF's vocab-size disagreement across components (resized embeddings, added special tokens, etc.). Worth confirming llguidance.hf.from_tokenizer handles the model.vocab_size > tokenizer.vocab_size case — the old comment ("of course they do") implied this had bitten someone before.


tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size
)
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_json_schema(
grammar = llguidance.LLMatcher.grammar_from_json_schema(
request["extra_body"]["structured_outputs"]["json"]
# NOTE: Mellea's structured output for hf sets `whitespace_flexible` to False.
# Not doing so here to match the previous granite formatter behavior.
# defaults={"whitespace_flexible": False},
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either commit to the choice (delete the dead code) or pass defaults=... explicitly so the intent lives in code rather than a comment. As written, this is the kind of note that rots — a future reader will wonder whether it was forgotten.

)
logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
logits_processor = _GuidanceLogitsProcessor(grammar, ll_tokenizer)

# The "logits_processor" argument to generate() must be a list.
generate_input["logits_processor"] = [logits_processor] # type: ignore[assignment]
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ hf = [
"peft>=0.18.1", # Native aLoRA support added in PEFT 0.18.0
"transformers>=4.53.2,<5",
"trl==0.19.1",
"xgrammar==0.1.33", # Necessary for granite_common intrinsics. Pinned due to Issue 990.
"huggingface-hub>=0.33.4",
]

Expand Down
Loading
Loading