Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3188,6 +3188,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.reasoning_budget_message = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
add_opt(common_arg(
{"--reasoning-budget-force-tool-call"},
{"--no-reasoning-budget-force-tool-call"},
string_format(
"if the conversation contains defined tools, force the model to output a tool call immediately "
"after the thinking block is closed, if the close is forced by the thinking budget (default: disabled)"
),
[](common_params & params, bool value) {
params.sampling.reasoning_budget_force_tool = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_FORCE_TOOL"));
add_opt(common_arg(
{"--reasoning-block-tool-call-start"},
{"--no-reasoning-block-tool-call-start"},
string_format(
"after the thinking block is closed, if the close is forced by the thinking budget (default: disabled)"
),
[](common_params & params, bool value) {
params.sampling.reasoning_block_tool_start = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_PREVENT_TOOL_CALL"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
Expand Down
16 changes: 13 additions & 3 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,8 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
data.supports_thinking = true;
data.thinking_start_tag = "[THINK]";
data.thinking_end_tag = "[/THINK]";
data.tool_start_tag = "[TOOL_CALLS]";

data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
Expand Down Expand Up @@ -1185,6 +1187,7 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
data.supports_thinking = true;
data.thinking_start_tag = "<|channel>thought";
data.thinking_end_tag = "<channel|>";
data.tool_start_tag = "<|tool_call>";

data.preserved_tokens = {
"<|channel>",
Expand Down Expand Up @@ -1471,6 +1474,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp

data.prompt += data.generation_prompt;
}
data.tool_start_tag = SECTION_BEGIN;

auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
// Kimi K2 Thinking format:
Expand Down Expand Up @@ -1591,6 +1595,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat

data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
data.tool_start_tag = TOOL_CALL_START;

if (inputs.has_continuation()) {
const auto & msg = inputs.continue_msg;
Expand Down Expand Up @@ -1675,10 +1680,12 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ

const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>";
const std::string TOOL_START = "<|tool_call_start|>";
const std::string GEN_PROMPT = "<|im_start|>assistant\n";

data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
data.tool_start_tag = TOOL_START;

if (inputs.has_continuation()) {
const auto & msg = inputs.continue_msg;
Expand Down Expand Up @@ -1710,8 +1717,8 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ
)
);

auto content = p.content(p.until_one_of({"<|tool_call_start|>", "["}));
auto maybe_start = p.optional(p.literal("<|tool_call_start|>"));
auto content = p.content(p.until_one_of({TOOL_START, "["}));
auto maybe_start = p.optional(p.literal(TOOL_START));
return generation_prompt + reasoning + content + maybe_start + tool_calls + end;
});

Expand Down Expand Up @@ -1846,6 +1853,8 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha
const std::string INVOKE_END = "</" + DSML + "invoke>";
const std::string PARAM_START = "<" + DSML + "parameter";
const std::string PARAM_END = "</" + DSML + "parameter>";

data.tool_start_tag = FC_START;
const std::string GEN_PROMPT = "<|Assistant|>";

if (inputs.has_continuation()) {
Expand Down Expand Up @@ -2396,7 +2405,8 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
if (auto_params.supports_thinking) {
auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start);
auto_params.thinking_end_tag = trim_whitespace(autoparser.reasoning.end);
auto_params.thinking_end_tag = autoparser.reasoning.end;
auto_params.tool_start_tag = autoparser.tools.format.section_start.empty() ? autoparser.tools.format.per_call_start : autoparser.tools.format.section_start;
}
common_peg_arena arena;
arena.load(auto_params.parser);
Expand Down
1 change: 1 addition & 0 deletions common/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ struct common_chat_params {
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
std::string parser;
std::string tool_start_tag; // e.g., "<tool_calls>"
};

// per-message parsing syntax
Expand Down
3 changes: 3 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ struct common_params_sampling {
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
bool reasoning_budget_force_tool; // force tool call after reasoning forcibly ends
bool reasoning_block_tool_start; // block tool call markers in reasoning block
std::string tool_call_start; // the starting marker for tool calls

bool backend_sampling = false;

Expand Down
14 changes: 13 additions & 1 deletion common/reasoning-budget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ struct common_reasoning_budget_ctx {
token_matcher start_matcher;
token_matcher end_matcher;
std::vector<llama_token> forced_tokens;
std::set<llama_token> blocked_tokens;

int32_t budget; // maximum tokens in reasoning block
int32_t remaining; // tokens remaining in budget
Expand Down Expand Up @@ -144,6 +145,14 @@ static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_tok
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;

if (ctx->state != REASONING_BUDGET_FORCING) {
// if we have blocked tokens and we're during reasoning, force all blocked tokens to -inf
if (ctx->state == REASONING_BUDGET_COUNTING && !ctx->blocked_tokens.empty()) {
for (size_t i = 0; i < cur_p->size; i++) {
if (ctx->blocked_tokens.find(cur_p->data[i].id) != ctx->blocked_tokens.end()) {
cur_p->data[i].logit = -INFINITY;
}
}
}
// passthrough — don't modify logits
return;
}
Expand Down Expand Up @@ -209,6 +218,7 @@ static struct llama_sampler * common_reasoning_budget_init_state(
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
const std::set<llama_token> & blocked_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
// promote COUNTING with budget <= 0 to FORCING
Expand All @@ -223,6 +233,7 @@ static struct llama_sampler * common_reasoning_budget_init_state(
/* .start_matcher = */ { start_tokens, 0 },
/* .end_matcher = */ { end_tokens, 0 },
/* .forced_tokens = */ forced_tokens,
/* .blocked_tokens= */ blocked_tokens,
/* .budget = */ budget,
/* .remaining = */ budget,
/* .state = */ initial_state,
Expand All @@ -236,9 +247,10 @@ struct llama_sampler * common_reasoning_budget_init(
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
const std::set<llama_token> & blocked_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, blocked_tokens, budget, initial_state);
}

common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) {
Expand Down
3 changes: 3 additions & 0 deletions common/reasoning-budget.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "llama.h"

#include <cstdint>
#include <set>
#include <vector>

enum common_reasoning_budget_state {
Expand All @@ -28,6 +29,7 @@ enum common_reasoning_budget_state {
// start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires
// blocked_tokens - tokens that should be disallowed during reasoning
// budget - max tokens allowed in the reasoning block
// initial_state - initial state
//
Expand All @@ -36,6 +38,7 @@ struct llama_sampler * common_reasoning_budget_init(
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
const std::set<llama_token> & blocked_tokens,
int32_t budget,
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE);

Expand Down
16 changes: 13 additions & 3 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,13 +292,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}

// reasoning budget sampler (skip when budget is unlimited unless a lazy grammar is active, which needs rbudget for thinking-block suppression)
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty() && (params.grammar_lazy || params.reasoning_budget_tokens >= 0)) {
// reasoning budget sampler (skip when budget is unlimited unless a lazy grammar is active, which needs rbudget for
// thinking-block suppression, or we are forcing blocking tool-start tokens)
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty() && (params.grammar_lazy || params.reasoning_budget_tokens >= 0 || params.reasoning_block_tool_start)) {
std::set<llama_token> blocked_tokens;
if (params.reasoning_block_tool_start && !params.tool_call_start.empty()) {
auto tstart = params.tool_call_start;
auto tstart_tokens = common_tokenize(vocab, tstart, true, false);
if (!tstart_tokens.empty()) {
blocked_tokens.insert(tstart_tokens[0]);
}
}
rbudget = common_reasoning_budget_init(
vocab,
params.reasoning_budget_start,
params.reasoning_budget_end,
params.reasoning_budget_forced,
blocked_tokens,
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens);

for (const auto & token : prefill_tokens) {
Expand Down Expand Up @@ -433,7 +443,7 @@ static bool grammar_should_apply(struct common_sampler * gsmpl) {
if (gsmpl->params.grammar_lazy) {
// if grammar is lazy, only apply when reasoning budget is not active
const auto state = common_reasoning_budget_get_state(gsmpl->rbudget);
return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE;
return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE || state == REASONING_BUDGET_FORCING;
}
return true;
}
Expand Down
80 changes: 76 additions & 4 deletions tests/test-reasoning-budget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <set>
#include <string>
#include <vector>

Expand All @@ -26,7 +27,8 @@ static void test_reasoning_budget(
int32_t budget,
common_reasoning_budget_state initial_state,
size_t expected_force_start, // token index where forcing should start (SIZE_MAX = never)
size_t expected_force_end // token index where forcing should end (after this, no more forcing)
size_t expected_force_end, // token index where forcing should end (after this, no more forcing)
const std::set<llama_token> & blocked_tokens = {}
) {
// Find the maximum token ID to ensure our vocab covers all tokens
llama_token max_token = 0;
Expand All @@ -43,6 +45,7 @@ static void test_reasoning_budget(
start_tokens,
end_tokens,
forced_tokens,
blocked_tokens,
budget,
initial_state
);
Expand Down Expand Up @@ -152,7 +155,7 @@ static void test_reasoning_budget_clone_mid_counting() {
const std::vector<llama_token> end = {101};
const std::vector<llama_token> forced = {102, 101};

auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 2, REASONING_BUDGET_IDLE);
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, {}, 2, REASONING_BUDGET_IDLE);

llama_sampler_accept(sampler, 100); // COUNTING, remaining=2
llama_sampler_accept(sampler, 50); // COUNTING, remaining=1
Expand All @@ -171,7 +174,7 @@ static void test_reasoning_budget_clone_mid_forcing() {
const std::vector<llama_token> end = {101};
const std::vector<llama_token> forced = {102, 101};

auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 0, REASONING_BUDGET_FORCING);
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, {}, 0, REASONING_BUDGET_FORCING);

GGML_ASSERT(get_forced_token(sampler, 102) == 102);
llama_sampler_accept(sampler, 102); // advance to the second forced token
Expand All @@ -184,6 +187,74 @@ static void test_reasoning_budget_clone_mid_forcing() {
llama_sampler_free(sampler);
}

// Verify that tokens in `blocked_tokens` are suppressed only while we are
// counting down the reasoning budget. In IDLE / DONE the sampler must
// passthrough (blocked tokens stay finite); in COUNTING they get -INFINITY.
static void test_reasoning_budget_blocked_tokens() {
const std::vector<llama_token> start = {100};
const std::vector<llama_token> end = {101};
const std::vector<llama_token> forced = {102, 101};
const std::set<llama_token> blocked = {200, 201};
const llama_token max_token = 250;

auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, blocked, 10, REASONING_BUDGET_IDLE);

auto check = [&](bool expect_blocked, const char * label) {
std::vector<llama_token_data> cur;
cur.reserve((size_t) max_token + 1);
for (size_t i = 0; i <= (size_t) max_token; i++) {
cur.emplace_back(llama_token_data{(llama_token) i, logf((float) (i + 1)), 0.0f});
}
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
llama_sampler_apply(sampler, &cur_p);

for (llama_token t : {200, 201}) {
const bool is_blocked = !std::isfinite(cur[t].logit);
if (is_blocked != expect_blocked) {
fprintf(stderr, "blocked_tokens test FAILED at '%s': token %d expected blocked=%d, got blocked=%d\n",
label, (int) t, expect_blocked ? 1 : 0, is_blocked ? 1 : 0);
GGML_ASSERT(false && "blocked token logit mismatch");
}
}
// a non-blocked, non-forced token should remain finite outside FORCING
GGML_ASSERT(std::isfinite(cur[50].logit) && "non-blocked token was unexpectedly masked");
};

// IDLE: passthrough — blocked tokens must not be touched yet
check(false, "IDLE before start");

// Enter COUNTING via the start token
llama_sampler_accept(sampler, 100);
check(true, "COUNTING after start");

// Still COUNTING after a normal token inside the reasoning block
llama_sampler_accept(sampler, 50);
check(true, "COUNTING mid-block");

// Natural end → DONE, blocking must stop
llama_sampler_accept(sampler, 101);
check(false, "DONE after natural end");

llama_sampler_free(sampler);

// Sanity: with an empty blocked set, COUNTING still passes everything through.
auto * sampler_empty = common_reasoning_budget_init(nullptr, start, end, forced, {}, 10, REASONING_BUDGET_IDLE);
llama_sampler_accept(sampler_empty, 100); // COUNTING
{
std::vector<llama_token_data> cur;
cur.reserve((size_t) max_token + 1);
for (size_t i = 0; i <= (size_t) max_token; i++) {
cur.emplace_back(llama_token_data{(llama_token) i, logf((float) (i + 1)), 0.0f});
}
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
llama_sampler_apply(sampler_empty, &cur_p);
for (size_t i = 0; i < cur.size(); i++) {
GGML_ASSERT(std::isfinite(cur[i].logit) && "empty blocked set should not mask any token");
}
}
llama_sampler_free(sampler_empty);
}

// UTF-8 boundary detection unit test
// Tests common_utf8_is_complete() from reasoning-budget.h
static void test_utf8_boundary_detection() {
Expand Down Expand Up @@ -312,8 +383,9 @@ int main(void) {

test_reasoning_budget_clone_mid_counting();
test_reasoning_budget_clone_mid_forcing();
test_reasoning_budget_blocked_tokens();

printf("OK (8 tests passed)\n");
printf("OK (9 tests passed)\n");

printf("Testing UTF-8 boundary detection... ");
test_utf8_boundary_detection();
Expand Down
2 changes: 2 additions & 0 deletions tools/cli/cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ struct cli_context {

task.params.sampling.reasoning_budget_tokens = defaults.sampling.reasoning_budget_tokens;
task.params.sampling.generation_prompt = chat_params.generation_prompt;
task.params.sampling.reasoning_block_tool_start = defaults.sampling.reasoning_block_tool_start;
task.params.sampling.tool_call_start = chat_params.tool_start_tag;

if (!chat_params.thinking_start_tag.empty()) {
task.params.sampling.reasoning_budget_start =
Expand Down
3 changes: 3 additions & 0 deletions tools/server/server-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,10 @@ json oaicompat_chat_params_parse(
llama_params["reasoning_budget_tokens"] = reasoning_budget;
llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag;
llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag;
llama_params["reasoning_budget_tool_start_tag"] = chat_params.tool_start_tag;
llama_params["reasoning_budget_message"] = opt.reasoning_budget_message;
llama_params["reasoning_budget_force_tool"] = opt.reasoning_budget_force_tool;
llama_params["reasoning_block_tool_start"] = opt.reasoning_block_tool_start;
}
}

Expand Down
2 changes: 2 additions & 0 deletions tools/server/server-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ struct server_chat_params {
int reasoning_budget = -1;
std::string reasoning_budget_message;
std::string media_path;
bool reasoning_budget_force_tool = false;
bool reasoning_block_tool_start = false;
bool force_pure_content = false;
};

Expand Down
2 changes: 2 additions & 0 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,8 @@ struct server_context_impl {
/* reasoning_budget */ params_base.sampling.reasoning_budget_tokens,
/* reasoning_budget_msg */ params_base.sampling.reasoning_budget_message,
/* media_path */ params_base.media_path,
/* reasoning_force_tool */ params_base.sampling.reasoning_budget_force_tool,
/* reasoning_block_tool */ params_base.sampling.reasoning_block_tool_start,
/* force_pure_content */ params_base.force_pure_content_parser
};
}
Expand Down
Loading
Loading