Merged
4 changes: 3 additions & 1 deletion cpp/include/tensorrt_llm/executor/executor.h
@@ -509,11 +509,12 @@ class GuidedDecodingParams
         kSTRUCTURAL_TAG = 4,
     };
 
-    explicit GuidedDecodingParams(GuideType guideType, std::optional<std::string> guide = std::nullopt);
+    explicit GuidedDecodingParams(GuideType guideType, std::optional<std::string> guide = std::nullopt, std::optional<std::int32_t> guidanceStartTokenId = std::nullopt);
 
     bool operator==(GuidedDecodingParams const& other) const;
     [[nodiscard]] GuideType getGuideType() const;
     [[nodiscard]] std::optional<std::string> getGuide() const;
+    [[nodiscard]] std::optional<std::int32_t> getGuidanceStartTokenId() const;
 
 private:
     friend class Serialization;
@@ -523,6 +524,7 @@ class GuidedDecodingParams
     /// @brief The detailed guide string. It could be a JSON schema, a regular expression or an EBNF grammar depending on
     /// mGuideType.
     std::optional<std::string> mGuide;
+    std::optional<std::int32_t> mGuidanceStartTokenId;
 };
 
 using RetentionPriority = SizeType32;
10 changes: 8 additions & 2 deletions cpp/tensorrt_llm/executor/guidedDecodingParams.cpp
@@ -22,17 +22,18 @@
 namespace tensorrt_llm::executor
 {
 
-GuidedDecodingParams::GuidedDecodingParams(GuideType guideType, std::optional<std::string> guide)
+GuidedDecodingParams::GuidedDecodingParams(GuideType guideType, std::optional<std::string> guide, std::optional<std::int32_t> guidanceStartTokenId)
     : mGuideType{guideType}
     , mGuide{std::move(guide)}
+    , mGuidanceStartTokenId{guidanceStartTokenId}
 {
     TLLM_CHECK_WITH_INFO(mGuideType == GuideType::kJSON || mGuide.has_value(),
         "The guide string must be provided unless using GuideType::kJSON.");
 }
 
 bool GuidedDecodingParams::operator==(GuidedDecodingParams const& other) const
 {
-    return mGuideType == other.mGuideType && mGuide == other.mGuide;
+    return mGuideType == other.mGuideType && mGuide == other.mGuide && mGuidanceStartTokenId == other.mGuidanceStartTokenId;
 }
 
 GuidedDecodingParams::GuideType GuidedDecodingParams::getGuideType() const
@@ -45,4 +46,9 @@ std::optional<std::string> GuidedDecodingParams::getGuide() const
     return mGuide;
 }
 
+std::optional<std::int32_t> GuidedDecodingParams::getGuidanceStartTokenId() const
+{
+    return mGuidanceStartTokenId;
+}
+
 } // namespace tensorrt_llm::executor
5 changes: 4 additions & 1 deletion cpp/tensorrt_llm/executor/serialization.cpp
@@ -1563,20 +1563,23 @@ GuidedDecodingParams Serialization::deserializeGuidedDecodingParams(std::istream
 {
     auto guideType = su::deserializeWithGetterType<decltype(&GuidedDecodingParams::getGuideType)>(is);
     auto guide = su::deserializeWithGetterType<decltype(&GuidedDecodingParams::getGuide)>(is);
-    return GuidedDecodingParams(guideType, guide);
+    auto guidanceStartTokenId = su::deserializeWithGetterType<decltype(&GuidedDecodingParams::getGuidanceStartTokenId)>(is);
+    return GuidedDecodingParams(guideType, guide, guidanceStartTokenId);
 }
 
 void Serialization::serialize(GuidedDecodingParams const& guidedDecodingParams, std::ostream& os)
 {
     su::serialize(guidedDecodingParams.getGuideType(), os);
     su::serialize(guidedDecodingParams.getGuide(), os);
+    su::serialize(guidedDecodingParams.getGuidanceStartTokenId(), os);
 }
 
 size_t Serialization::serializedSize(GuidedDecodingParams const& guidedDecodingParams)
 {
     size_t totalSize = 0;
     totalSize += su::serializedSize(guidedDecodingParams.getGuideType());
     totalSize += su::serializedSize(guidedDecodingParams.getGuide());
+    totalSize += su::serializedSize(guidedDecodingParams.getGuidanceStartTokenId());
     return totalSize;
 }
11 changes: 6 additions & 5 deletions cpp/tensorrt_llm/nanobind/executor/request.cpp
@@ -542,23 +542,24 @@ void initRequestBindings(nb::module_& m)
         .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG);
 
     auto guidedDecodingParamsGetstate
-        = [](tle::GuidedDecodingParams const& self) { return nb::make_tuple(self.getGuideType(), self.getGuide()); };
+        = [](tle::GuidedDecodingParams const& self) { return nb::make_tuple(self.getGuideType(), self.getGuide(), self.getGuidanceStartTokenId()); };
 
     auto guidedDecodingParamsSetstate = [](tle::GuidedDecodingParams& self, nb::tuple const& state)
     {
-        if (state.size() != 2)
+        if (state.size() != 3)
         {
             throw std::runtime_error("Invalid GuidedDecodingParams state!");
         }
         new (&self) tle::GuidedDecodingParams(
-            nb::cast<tle::GuidedDecodingParams::GuideType>(state[0]), nb::cast<std::optional<std::string>>(state[1]));
+            nb::cast<tle::GuidedDecodingParams::GuideType>(state[0]), nb::cast<std::optional<std::string>>(state[1]), nb::cast<std::optional<int32_t>>(state[2]));
     };
 
     pyGuidedDecodingParams
-        .def(nb::init<tle::GuidedDecodingParams::GuideType, std::optional<std::string>>(), nb::arg("guide_type"),
-            nb::arg("guide") = nb::none())
+        .def(nb::init<tle::GuidedDecodingParams::GuideType, std::optional<std::string>, std::optional<std::int32_t>>(), nb::arg("guide_type"),
+            nb::arg("guide") = nb::none(), nb::arg("guidance_start_token_id") = nb::none())
         .def_prop_ro("guide_type", &tle::GuidedDecodingParams::getGuideType)
         .def_prop_ro("guide", &tle::GuidedDecodingParams::getGuide)
+        .def_prop_ro("guidance_start_token_id", &tle::GuidedDecodingParams::getGuidanceStartTokenId)
         .def("__getstate__", guidedDecodingParamsGetstate)
         .def("__setstate__", guidedDecodingParamsSetstate);
5 changes: 3 additions & 2 deletions cpp/tensorrt_llm/pybind/executor/request.cpp
@@ -496,11 +496,11 @@ void initRequestBindings(pybind11::module_& m)
         .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG);
 
     auto guidedDecodingParamsGetstate
-        = [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide()); };
+        = [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide(), self.getGuidanceStartTokenId()); };
 
     auto guidedDecodingParamsSetstate = [](py::tuple state)
     {
-        if (state.size() != 2)
+        if (state.size() != 3)
         {
             throw std::runtime_error("Invalid GuidedDecodingParams state!");
         }
@@ -513,6 +513,7 @@ void initRequestBindings(pybind11::module_& m)
             py::arg("guide") = py::none())
         .def_property_readonly("guide_type", &tle::GuidedDecodingParams::getGuideType)
         .def_property_readonly("guide", &tle::GuidedDecodingParams::getGuide)
+        .def_property_readonly("guidance_start_token_id", &tle::GuidedDecodingParams::getGuidanceStartTokenId)
        .def(py::pickle(guidedDecodingParamsGetstate, guidedDecodingParamsSetstate));
 
     auto requestGetstate = [](tle::Request const& self)
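Since `__getstate__`/`__setstate__` now carry a 3-tuple, the new field survives pickling. A minimal round-trip sketch against the Python bindings (assuming a build with this change's nanobind bindings, which accept the `guidance_start_token_id` keyword; the schema and token id are arbitrary examples):

```python
import pickle

from tensorrt_llm.bindings import executor as tle

# Construct params with the new optional start token id.
params = tle.GuidedDecodingParams(
    tle.GuidedDecodingParams.GuideType.JSON_SCHEMA,
    guide='{"type": "object"}',
    guidance_start_token_id=128000,  # arbitrary example id
)

# The pickle state is now a 3-tuple, so the field round-trips intact.
restored = pickle.loads(pickle.dumps(params))
assert restored.guidance_start_token_id == 128000
```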
60 changes: 60 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/grammar_matcher.py
@@ -29,6 +29,10 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
     def is_terminated(self) -> bool:
         pass
 
+    @abstractmethod
+    def guidance_started(self) -> bool:
+        pass
+
 
 class GrammarMatcherFactory(ABC):
 
@@ -56,7 +60,60 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
 
     def is_terminated(self) -> bool:
         return self._matcher.is_terminated()
+
+    def guidance_started(self) -> bool:
+        return True
+
+
+class GrammarMatcherWrapper(GrammarMatcher):
+    """Defers guided decoding until a designated start token is generated."""
+
+    def __init__(self, matcher: GrammarMatcher, guidance_start_token_id: int):
+        super().__init__()
+        self._matcher = matcher
+        self._guidance_start_token_id = guidance_start_token_id
+        self._guidance_started = False
+        self._steps_after_guidance_start = 0
+
+    def accept_token(self, token_id: int) -> bool:
+        if not self._guidance_started:
+            if token_id == self._guidance_start_token_id:
+                self._guidance_started = True
+                self._steps_after_guidance_start = 0
+            # Tokens before the start token are accepted unconditionally.
+            return True
+        self._steps_after_guidance_start += 1
+        return self._matcher.accept_token(token_id)
+
+    def rollback(self, num_tokens: int) -> None:
+        if not self._guidance_started:
+            return
+        # Cannot roll back more than what was accepted after guidance started.
+        num_tokens_to_rollback = min(num_tokens,
+                                     self._steps_after_guidance_start)
+        if num_tokens > self._steps_after_guidance_start:
+            # The rollback crosses the start token, so guidance is re-armed.
+            self._guidance_started = False
+        self._matcher.rollback(num_tokens_to_rollback)
+        self._steps_after_guidance_start -= num_tokens_to_rollback
+
+    def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
+                                index: int) -> None:
+        self._matcher.fill_next_token_bitmask(next_token_bitmask, index)
+
+    def is_terminated(self) -> bool:
+        return self._matcher.is_terminated()
+
+    def guidance_started(self) -> bool:
+        return self._guidance_started
+
+
+class GrammarMatcherFactoryWrapper(GrammarMatcherFactory):
+
+    def __init__(self, factory: GrammarMatcherFactory):
+        super().__init__()
+        self._factory = factory
+
+    def create(self,
+               guided_decoding_params: GuidedDecodingParams) -> GrammarMatcher:
+        matcher = self._factory.create(guided_decoding_params)
+        # Compare against None explicitly: token id 0 is falsy but valid.
+        if guided_decoding_params.guidance_start_token_id is not None:
+            return GrammarMatcherWrapper(
+                matcher, guided_decoding_params.guidance_start_token_id)
+        return matcher
+
 
 class XGrammarMatcherFactory(GrammarMatcherFactory):
 
@@ -167,6 +224,9 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
     def is_terminated(self) -> bool:
         return self._is_terminated
 
+    def guidance_started(self) -> bool:
+        return True
+
     def _check_err(self) -> None:
         if self._matcher.is_error():
             raise ValueError(
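Taken together, the wrapper leaves the underlying matcher untouched until the start token appears, and re-arms itself if a rollback crosses that token. A toy sketch of that lifecycle (the `EchoMatcher` stub is hypothetical, standing in for a real XGrammar/LLGuidance matcher; token ids are made up):

```python
import torch

from tensorrt_llm._torch.pyexecutor.grammar_matcher import (
    GrammarMatcher, GrammarMatcherWrapper)


class EchoMatcher(GrammarMatcher):
    """Hypothetical stand-in that accepts everything and never terminates."""

    def accept_token(self, token_id: int) -> bool:
        return True

    def rollback(self, num_tokens: int) -> None:
        pass

    def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
                                index: int) -> None:
        next_token_bitmask[index].fill_(-1)  # allow all tokens

    def is_terminated(self) -> bool:
        return False

    def guidance_started(self) -> bool:
        return True


matcher = GrammarMatcherWrapper(EchoMatcher(), guidance_start_token_id=42)
assert matcher.accept_token(7) and not matcher.guidance_started()  # free-running
assert matcher.accept_token(42) and matcher.guidance_started()     # gate opens
matcher.accept_token(8)  # forwarded to the wrapped matcher
matcher.rollback(2)      # crosses the start token, so the gate closes again
assert not matcher.guidance_started()
```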
17 changes: 10 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -10,7 +10,7 @@
 from ...bindings.internal.batch_manager import LlmRequestType
 from ...logger import logger
 from ..hostfunc import hostfunc
-from .grammar_matcher import (GrammarMatcher, LLGuidanceMatcherFactory,
+from .grammar_matcher import (GrammarMatcher, GrammarMatcherFactoryWrapper, LLGuidanceMatcherFactory,
                               XGrammarMatcherFactory)
 from .llm_request import LlmRequest
 from .scheduler import ScheduledRequests
@@ -161,6 +161,7 @@ def __init__(self,
             raise ValueError(
                 f"Invalid guided decoding backend: {self.guided_decoding_backend}"
             )
+        self.grammar_matcher_factory = GrammarMatcherFactoryWrapper(self.grammar_matcher_factory)
         logger.info(
             f"Guided decoder initialized with backend: {self.guided_decoding_backend}"
         )
@@ -249,9 +250,10 @@ def _build(self, requests: GuidedRequests) -> None:
 
                 self.num_advanced_tokens[slot] += 1
                 if not matcher.is_terminated():
-                    matcher.fill_next_token_bitmask(self.bitmask_host, offset)
-                    self.token_mask_host[offset] = 1
-                    self.num_guided_tokens[slot] += 1
+                    if matcher.guidance_started():
+                        matcher.fill_next_token_bitmask(self.bitmask_host, offset)
+                        self.token_mask_host[offset] = 1
+                        self.num_guided_tokens[slot] += 1
                 # Process draft tokens
                 for i, tid in enumerate(req.draft_tokens, 1):
                     accepted = matcher.accept_token(tid)
@@ -260,10 +262,11 @@
                     self.num_advanced_tokens[slot] += 1
                     if matcher.is_terminated():
                         break
-                    matcher.fill_next_token_bitmask(self.bitmask_host,
-                                                    offset + i)
-                    self.token_mask_host[offset + i] = 1
-                    self.num_guided_tokens[slot] += 1
+                    if matcher.guidance_started():
+                        matcher.fill_next_token_bitmask(self.bitmask_host,
+                                                        offset + i)
+                        self.token_mask_host[offset + i] = 1
+                        self.num_guided_tokens[slot] += 1
 
                 if req.is_draft:
                     assert len(req.draft_tokens) == 0
15 changes: 11 additions & 4 deletions tensorrt_llm/sampling_params.py
@@ -27,10 +27,14 @@ class GuidedDecodingParams:
     grammar: Optional[str] = None
     json_object: bool = False
     structural_tag: Optional[str] = None
+    guidance_start_token_id: Optional[int] = None
 
     def _validate(self):
+        exclude_fields = set(["guidance_start_token_id"])
         num_guides = 0
         for _field in fields(self):
+            if _field.name in exclude_fields:
+                continue
             num_guides += bool(getattr(self, _field.name))
         if num_guides > 1:
             raise ValueError(f"Only one guide can be used for a request, but got {num_guides}.")
@@ -459,28 +463,31 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams:
             return None
 
         if self.guided_decoding.json_object:
-            return tllme.GuidedDecodingParams(tllme.GuidedDecodingParams.GuideType.JSON)
+            return tllme.GuidedDecodingParams(
+                tllme.GuidedDecodingParams.GuideType.JSON, None, self.guided_decoding.guidance_start_token_id,
+            )
         elif self.guided_decoding.json is not None:
             json_schema = self.guided_decoding.json
             if isinstance(json_schema, BaseModel):
                 json_schema = json_schema.model_json_schema()
             if isinstance(json_schema, dict):
                 json_schema = json.dumps(json_schema)
             return tllme.GuidedDecodingParams(
-                tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema
+                tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema, self.guided_decoding.guidance_start_token_id
             )
         elif self.guided_decoding.regex is not None:
             return tllme.GuidedDecodingParams(
-                tllme.GuidedDecodingParams.GuideType.REGEX, self.guided_decoding.regex
+                tllme.GuidedDecodingParams.GuideType.REGEX, self.guided_decoding.regex, self.guided_decoding.guidance_start_token_id
             )
         elif self.guided_decoding.grammar is not None:
             return tllme.GuidedDecodingParams(
-                tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR, self.guided_decoding.grammar
+                tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR, self.guided_decoding.grammar, self.guided_decoding.guidance_start_token_id
            )
         elif self.guided_decoding.structural_tag is not None:
             return tllme.GuidedDecodingParams(
                 tllme.GuidedDecodingParams.GuideType.STRUCTURAL_TAG,
                 self.guided_decoding.structural_tag,
+                self.guided_decoding.guidance_start_token_id,
             )
         else:
             return None
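At the user-facing API level, the new field composes with exactly one guide. A short usage sketch (the token id 32000 is illustrative — in practice it would be the vocabulary id of something like a `<tool_call>` marker):

```python
from tensorrt_llm.sampling_params import GuidedDecodingParams, SamplingParams

# Guidance stays dormant until the model emits token id 32000 (illustrative);
# from that point on, the output must follow the JSON schema.
params = SamplingParams(
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(
        json={"type": "object", "properties": {"name": {"type": "string"}}},
        guidance_start_token_id=32000,
    ),
)
```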
6 changes: 3 additions & 3 deletions tensorrt_llm/serve/openai_protocol.py
@@ -100,7 +100,7 @@ class ResponseFormat(OpenAIBaseModel):
     schema: Optional[dict] = None
     structures: Optional[List[StructuralTag]] = None
     triggers: Optional[List[str]] = None
-
+    guidance_start_token_id: Optional[int] = None
 
 class DisaggregatedParams(OpenAIBaseModel):
     request_type: str
@@ -189,9 +189,9 @@ def _response_format_to_guided_decoding_params(
             raise ValueError(
                 "The 'schema' field is required when response_format.type is 'json'."
             )
-        return GuidedDecodingParams(json=response_format.schema)
+        return GuidedDecodingParams(json=response_format.schema, guidance_start_token_id=response_format.guidance_start_token_id)
     elif response_format.type == "json_object":
-        return GuidedDecodingParams(json_object=True)
+        return GuidedDecodingParams(json_object=True, guidance_start_token_id=response_format.guidance_start_token_id)
     elif response_format.type == "structural_tag":
         return GuidedDecodingParams(
             structural_tag=response_format.model_dump_json(by_alias=True,
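On the serving side, clients opt in through `response_format`. A sketch of a chat-completions request exercising the new field (endpoint, model name, and token id are illustrative):

```python
import requests  # any HTTP client works

payload = {
    "model": "my-model",
    "messages": [{"role": "user", "content": "Call the weather tool."}],
    "response_format": {
        "type": "json",
        "schema": {"type": "object"},
        # Guided decoding kicks in only after this token id is generated.
        "guidance_start_token_id": 32000,
    },
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
```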