Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,22 @@ jobs:
Release,
Debug,
]
sanitizer: [
none,
address,
thread,
undefined,
]
exclude:
# Sanitizers not supported on Clang targeting MSVC (llvm-arm64)
- setup: { build: 'llvm-arm64' }
sanitizer: address
- setup: { build: 'llvm-arm64' }
sanitizer: thread
- setup: { build: 'llvm-arm64' }
sanitizer: undefined
runs-on: ${{ matrix.setup.os }}
name: ${{ matrix.setup.os }}-${{ matrix.setup.build }}-${{ matrix.type }}
name: ${{ matrix.setup.os }}-${{ matrix.setup.build }}-${{ matrix.type }}-sanitizer-${{ matrix.sanitizer }}
timeout-minutes: 30

steps:
Expand All @@ -58,7 +72,7 @@ jobs:
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.11
with:
key: ${{ matrix.setup.os }}-${{ matrix.setup.build }}-${{ matrix.type }}
key: ${{ matrix.setup.os }}-${{ matrix.setup.build }}-${{ matrix.type }}-sanitizer-${{ matrix.sanitizer }}

- name: Set up CMake
uses: lukka/get-cmake@latest
Expand All @@ -75,7 +89,7 @@ jobs:
- name: Configure CMake
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: cmake -B ${{github.workspace}}/build ${{ matrix.setup.defines }} -DCMAKE_BUILD_TYPE=${{ matrix.type }}
run: cmake -B ${{github.workspace}}/build ${{ matrix.setup.defines }} -DCMAKE_BUILD_TYPE=${{ matrix.type }} -DMINJA_SANITIZER=${{ matrix.sanitizer }}

- name: Build
run: cmake --build ${{github.workspace}}/build --config ${{ matrix.type }} --parallel
Expand Down
9 changes: 9 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ option(MINJA_EXAMPLE_ENABLED "minja: Build with example"
option(MINJA_FUZZTEST_ENABLED "minja: fuzztests enabled" MINJA_FUZZTEST_ENABLED_DEFAULT)
option(MINJA_FUZZTEST_FUZZING_MODE "minja: run fuzztests (if enabled) in fuzzing mode" OFF)
option(MINJA_USE_VENV "minja: use Python venv for build" MINJA_USE_VENV_DEFAULT)
set(MINJA_SANITIZERS thread address undefined none)
set(MINJA_SANITIZER none CACHE STRING "minja: sanitizer to use")
set_property(CACHE MINJA_SANITIZER PROPERTY STRINGS ${MINJA_SANITIZERS})

if (NOT MSVC AND NOT CMAKE_CXX_COMPILER_FRONTEND_VARIANT STREQUAL "MSVC" AND NOT MINJA_SANITIZER STREQUAL "none")
message(STATUS "Using -fsanitize=${MINJA_SANITIZER}")
add_compile_options("-fsanitize=${MINJA_SANITIZER}")
link_libraries ("-fsanitize=${MINJA_SANITIZER}")
endif()

set(CMAKE_CXX_STANDARD 17)

Expand Down
22 changes: 21 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# minja.hpp - A minimalistic C++ Jinja templating engine for LLM chat templates

_**This is not an official Google product**_
_**Used to be at https://github.com/google/minja, but I've left Google and I'll only maintain my fork from now on**_

Minja is a minimalistic reimplementation of the [Jinja](https://github.com/pallets/jinja/) templating engine to integrate in/with C++ LLM projects (it's used in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/11016), [Jan](https://jan.ai/) (through [cortex.cpp](https://github.com/menloresearch/cortex.cpp/pull/1814)), [GPT4All](https://github.com/nomic-ai/gpt4all/pull/3433) and [Docker Model Runner](https://github.com/docker/model-runner)).

Expand Down Expand Up @@ -212,6 +212,26 @@ Main limitations (non-exhaustive list):
./scripts/fuzzing_tests.sh
```

- Sanitizer tests:

```bash
for sanitizer in ADDRESS THREAD UNDEFINED ; do
docker run --rm \
-v "$PWD":/src:ro \
-v "$PWD/build-sanitizer-${sanitizer}":/src/build \
-w /src \
"$(echo "
FROM ghcr.io/astral-sh/uv:debian-slim
RUN apt-get update && apt-get install -y build-essential libcurl4-openssl-dev cmake clang-tidy
" | docker build . -q -f - )" \
bash -c "
cmake -B build -DCMAKE_BUILD_TYPE=Debug -DMINJA_SANITIZER=${sanitizer} && \
cmake --build build -j --config Debug && \
ctest --test-dir build -j -C Debug --output-on-failure
"
done
```

- If your model's template doesn't run fine, please consider the following before [opening a bug](https://github.com/googlestaging/minja/issues/new):

- Is the template using any unsupported filter / test / method / global function, and which one(s)?
Expand Down
11 changes: 9 additions & 2 deletions include/minja/chat-template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,25 @@ class chat_template {
};
};
const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
const auto contains_arg_needle = [&](const std::string & out_str) {
return contains(out_str, "<parameter=argument_needle>")
|| contains(out_str, "\"argument_needle\":")
|| contains(out_str, "'argument_needle':")
|| contains(out_str, ">argument_needle<")
|| contains(out_str, "<parameter name=\"argument_needle\">");
};

// Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
}), {}, false);
auto tool_call_renders_str_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
auto tool_call_renders_str_arguments = contains_arg_needle(out);
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
}), {}, false);
auto tool_call_renders_obj_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
auto tool_call_renders_obj_arguments = contains_arg_needle(out);

caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
Expand Down
63 changes: 37 additions & 26 deletions include/minja/minja.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ inline std::string normalize_newlines(const std::string & s) {
}

/* Values that behave roughly like in Python. */
class Value : public std::enable_shared_from_this<Value> {
class Value {
public:
using CallableType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
using FilterType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
Expand Down Expand Up @@ -158,12 +158,14 @@ class Value : public std::enable_shared_from_this<Value> {
Value(const json & v) {
if (v.is_object()) {
auto object = std::make_shared<ObjectType>();
object->reserve(v.size());
for (auto it = v.begin(); it != v.end(); ++it) {
(*object)[it.key()] = it.value();
object->emplace_back(it.key(), Value(it.value()));
}
object_ = std::move(object);
} else if (v.is_array()) {
auto array = std::make_shared<ArrayType>();
array->reserve(v.size());
for (const auto& item : v) {
array->push_back(Value(item));
}
Expand Down Expand Up @@ -610,7 +612,7 @@ static std::string error_location_suffix(const std::string & source, size_t pos)
return out.str();
}

class Context : public std::enable_shared_from_this<Context> {
class Context {
protected:
Value values_;
std::shared_ptr<Context> parent_;
Expand Down Expand Up @@ -850,12 +852,12 @@ struct LoopControlTemplateToken : public TemplateToken {

struct CallTemplateToken : public TemplateToken {
std::shared_ptr<Expression> expr;
CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e)
CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e)
: TemplateToken(Type::Call, loc, pre, post), expr(std::move(e)) {}
};

struct EndCallTemplateToken : public TemplateToken {
EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post)
EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post)
: TemplateToken(Type::EndCall, loc, pre, post) {}
};

Expand Down Expand Up @@ -1060,11 +1062,18 @@ class MacroNode : public TemplateNode {
}
}
}
void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
if (!name) throw std::runtime_error("MacroNode.name is null");
if (!body) throw std::runtime_error("MacroNode.body is null");
auto callable = Value::callable([this, macro_context](const std::shared_ptr<Context> & call_context, ArgumentsValue & args) {
auto execution_context = Context::make(Value::object(), macro_context);

// Use init-capture to avoid dangling 'this' pointer and circular references
auto callable = Value::callable([weak_context = std::weak_ptr<Context>(context),
name = name, params = params, body = body,
named_param_positions = named_param_positions]
(const std::shared_ptr<Context> & call_context, ArgumentsValue & args) {
auto context_locked = weak_context.lock();
if (!context_locked) throw std::runtime_error("Macro context no longer valid");
auto execution_context = Context::make(Value::object(), context_locked);

if (call_context->contains("caller")) {
execution_context->set("caller", call_context->get("caller"));
Expand All @@ -1075,7 +1084,7 @@ class MacroNode : public TemplateNode {
auto & arg = args.args[i];
if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
param_set[i] = true;
auto & param_name = params[i].first;
const auto & param_name = params[i].first;
execution_context->set(param_name, arg);
}
for (auto & [arg_name, value] : args.kwargs) {
Expand All @@ -1094,7 +1103,7 @@ class MacroNode : public TemplateNode {
}
return body->render(execution_context);
});
macro_context->set(name->get_name(), callable);
context->set(name->get_name(), callable);
}
};

Expand Down Expand Up @@ -1264,7 +1273,7 @@ class SubscriptExpr : public Expression {
}
return result;

} else if (target_value.is_array()) {
} else if (target_value.is_array()) {
auto result = Value::array();
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
result.push_back(target_value.at(i));
Expand Down Expand Up @@ -1313,7 +1322,7 @@ static bool in(const Value & value, const Value & container) {
return (((container.is_array() || container.is_object()) && container.contains(value)) ||
(value.is_string() && container.is_string() &&
container.to_str().find(value.to_str()) != std::string::npos));
};
}

class BinaryOpExpr : public Expression {
public:
Expand Down Expand Up @@ -1640,13 +1649,17 @@ class CallNode : public TemplateNode {
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
if (!expr) throw std::runtime_error("CallNode.expr is null");
if (!body) throw std::runtime_error("CallNode.body is null");

auto caller = Value::callable([this, context](const std::shared_ptr<Context> &, ArgumentsValue &) -> Value {
return Value(body->render(context));

// Use init-capture to avoid dangling 'this' pointer and circular references
auto caller = Value::callable([weak_context = std::weak_ptr<Context>(context), body=body]
(const std::shared_ptr<Context> &, ArgumentsValue &) -> Value {
auto context_locked = weak_context.lock();
if (!context_locked) throw std::runtime_error("Caller context no longer valid");
return Value(body->render(context_locked));
});

context->set("caller", caller);

auto call_expr = dynamic_cast<CallExpr*>(expr.get());
if (!call_expr) {
throw std::runtime_error("Invalid call block syntax - expected function call");
Expand All @@ -1657,7 +1670,7 @@ class CallNode : public TemplateNode {
throw std::runtime_error("Call target must be callable: " + function.dump());
}
ArgumentsValue args = call_expr->args.evaluate(context);

Value result = function.call(context, args);
out << result.to_str();
}
Expand Down Expand Up @@ -2192,7 +2205,7 @@ class Parser {

auto value = parseValue();

while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
while (it != end && consumeSpaces() && peekSymbols({ "[", ".", "(" })) {
if (!consumeToken("[").empty()) {
std::shared_ptr<Expression> index;
auto slice_loc = get_location();
Expand All @@ -2215,7 +2228,7 @@ class Parser {
}
}
}

if ((has_first_colon || has_second_colon)) {
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
} else {
Expand All @@ -2237,15 +2250,13 @@ class Parser {
auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
}
} else if (peekSymbols({ "(" })) {
auto callParams = parseCallArgs();
value = std::make_shared<CallExpr>(get_location(), std::move(value), std::move(callParams));
}
consumeSpaces();
}

if (peekSymbols({ "(" })) {
auto location = get_location();
auto callParams = parseCallArgs();
value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
}
return value;
}

Expand Down Expand Up @@ -2725,7 +2736,7 @@ inline std::shared_ptr<Context> Context::builtins() {
globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
throw std::runtime_error(args.at("message").get<std::string>());
}));
globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr<Context> &, Value & args) {
globals.set("tojson", simple_function("tojson", { "value", "indent", "ensure_ascii" }, [](const std::shared_ptr<Context> &, Value & args) {
return Value(args.at("value").dump(args.get<int64_t>("indent", -1), /* to_json= */ true));
}));
globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr<Context> &, Value & args) {
Expand Down
18 changes: 15 additions & 3 deletions scripts/fetch_templates_and_goldens.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def strftime_now(format):
now = datetime.datetime.strptime(TEST_DATE, "%Y-%m-%d")
return now.strftime(format)

def tojson(value, indent=None, ensure_ascii=False, sort_keys=False):
return json.dumps(value, indent=indent, ensure_ascii=ensure_ascii, sort_keys=sort_keys)

def join_cmake_path(parent, child):
'''
Expand Down Expand Up @@ -119,8 +121,11 @@ def __init__(self, template, env=None, filters=None, global_functions=None):
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[jinja2.ext.loopcontrols]
extensions=[jinja2.ext.loopcontrols],
)
# https://jinja.palletsprojects.com/en/stable/api/#policies
env.policies["json.dumps_function"] = tojson
env.filters['tojson'] = tojson
if filters:
for name, func in filters.items():
env.filters[name] = func
Expand Down Expand Up @@ -187,17 +192,24 @@ def make_tool_call(tool_name, arguments):
}

dummy_args_obj = {"argument_needle": "print('Hello, World!')"}
contains_arg_needle = lambda out_str: (
"<parameter=argument_needle>" in out_str
or '"argument_needle":' in out_str
or "'argument_needle':" in out_str
or ">argument_needle<" in out_str
or "<parameter name=\"argument_needle\">" in out_str
)

out = self.try_raw_render([
dummy_user_msg,
make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]),
])
tool_call_renders_str_arguments = "<parameter=argument_needle>" in out or '"argument_needle":' in out or "'argument_needle':" in out
tool_call_renders_str_arguments = contains_arg_needle(out)
out = self.try_raw_render([
dummy_user_msg,
make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]),
])
tool_call_renders_obj_arguments = "<parameter=argument_needle>" in out or '"argument_needle":' in out or "'argument_needle':" in out
tool_call_renders_obj_arguments = contains_arg_needle(out)

caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments
caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments
Expand Down
2 changes: 2 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ set(MODEL_IDS
Qwen/Qwen3-235B-A22B-Thinking-2507
Qwen/Qwen3-Coder-30B-A3B-Instruct
Qwen/QwQ-32B
zai-org/GLM-4.6

# Broken, TODO:
# ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8
Expand All @@ -334,6 +335,7 @@ set(MODEL_IDS
# HuggingFaceTB/SmolVLM-256M-Instruct
# HuggingFaceTB/SmolVLM-500M-Instruct
# HuggingFaceTB/SmolVLM-Instruct
# unsloth/MiniMax-M2 # https://github.com/ochafik/minja/pull/7#issuecomment-3478459580
# meta-llama/Llama-3.2-11B-Vision-Instruct
# unsloth/DeepSeek-R1
)
Expand Down
Loading