Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions core/tools/systemMessageTools/detectToolCallStart.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
import { SystemMessageToolsFramework } from "./types";

interface DetectToolCallStartOptions {
allowAlternateStarts?: boolean;
}

export function detectToolCallStart(
buffer: string,
toolCallFramework: SystemMessageToolsFramework,
options: DetectToolCallStartOptions = {},
) {
const allowAlternateStarts = options.allowAlternateStarts ?? true;
const starts = toolCallFramework.acceptedToolCallStarts;
let modifiedBuffer = buffer;
let isInToolCall = false;
let isInPartialStart = false;
const lowerCaseBuffer = buffer.toLowerCase();
for (let i = 0; i < starts.length; i++) {
if (i !== 0 && !allowAlternateStarts) {
continue;
}

const [start, replacement] = starts[i];
if (lowerCaseBuffer.startsWith(start)) {
// for non-standard cases like no ```tool codeblock, etc, replace before adding to buffer, case insensitive
Expand Down
9 changes: 8 additions & 1 deletion core/tools/systemMessageTools/interceptSystemToolCalls.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export async function* interceptSystemToolCalls(
): AsyncGenerator<ChatMessage[], PromptLog | undefined> {
let buffer = "";
let parseState: ToolCallParseState | undefined;
let sawAssistantNonWhitespaceText = false;

while (true) {
const result = await messageGenerator.next();
Expand Down Expand Up @@ -71,7 +72,10 @@ export async function* interceptSystemToolCalls(
buffer += chunk;
if (!parseState) {
const { isInPartialStart, isInToolCall, modifiedBuffer } =
detectToolCallStart(buffer, systemToolFramework);
detectToolCallStart(buffer, systemToolFramework, {
// Only allow loose "TOOL_NAME:" starts at the beginning of assistant output.
allowAlternateStarts: !sawAssistantNonWhitespaceText,
});

if (isInPartialStart) {
continue;
Expand Down Expand Up @@ -109,6 +113,9 @@ export async function* interceptSystemToolCalls(
content: [{ type: "text", text: buffer }],
},
];
if (/\S/.test(buffer)) {
sawAssistantNonWhitespaceText = true;
}
}
buffer = "";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,15 @@ describe("detectToolCallStart", () => {
expect(result.isInPartialStart).toBe(false);
expect(result.modifiedBuffer).toBe(buffer);
});

it("skips non-standard starts when alternate starts are disabled", () => {
const buffer = "TOOL_NAME: example_tool";
const result = detectToolCallStart(buffer, framework, {
allowAlternateStarts: false,
});

expect(result.isInToolCall).toBe(false);
expect(result.isInPartialStart).toBe(false);
expect(result.modifiedBuffer).toBe(buffer);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,8 @@ describe("interceptSystemToolCalls", () => {
).toBe("}");
});

it("processes tool_name without codeblock format", async () => {
it("processes tool_name without codeblock format at assistant output start", async () => {
const messages: ChatMessage[][] = [
[{ role: "assistant", content: "I'll help you with that.\n" }],
[{ role: "assistant", content: "TOOL_NAME: test_tool\n" }],
[{ role: "assistant", content: "BEGIN_ARG: arg1\n" }],
[{ role: "assistant", content: "value1\n" }],
Expand All @@ -194,30 +193,7 @@ describe("interceptSystemToolCalls", () => {
framework,
);

// First chunk should be normal text
let result = await generator.next();
expect(result.value).toEqual([
{
role: "assistant",
content: [{ type: "text", text: "I'll help you with that." }],
},
]);

result = await generator.next();
expect(result.value).toEqual([
{
role: "assistant",
content: [
{
type: "text",
text: "\n",
},
],
},
]);

// The system should detect the tool_name format and convert it
result = await generator.next();
expect(
(result.value as AssistantChatMessage[])[0].toolCalls?.[0].function?.name,
).toBe("test_tool");
Expand All @@ -242,6 +218,43 @@ describe("interceptSystemToolCalls", () => {
).toBe("}");
});

it("does not intercept quoted tool syntax in explanatory text", async () => {
const messages: ChatMessage[][] = [
[{ role: "assistant", content: "Here is the syntax:\n" }],
[{ role: "assistant", content: "TOOL_NAME: read_file\n" }],
[{ role: "assistant", content: "BEGIN_ARG: filepath\n" }],
[{ role: "assistant", content: "path/to/the_file.txt\n" }],
[{ role: "assistant", content: "END_ARG\n" }],
];

const generator = interceptSystemToolCalls(
createAsyncGenerator(messages),
abortController,
framework,
);

const outputChunks: string[] = [];
while (true) {
const result = await generator.next();
if (result.done || !result.value) {
break;
}

const chunkText = (
(result.value as AssistantChatMessage[])[0].content as {
type: "text";
text: string;
}[]
)[0].text;
outputChunks.push(chunkText);
expect((result.value as AssistantChatMessage[])[0].toolCalls).toBeFalsy();
}

expect(outputChunks.join("")).toBe(
"Here is the syntax:\nTOOL_NAME: read_file\nBEGIN_ARG: filepath\npath/to/the_file.txt\nEND_ARG\n",
);
});

it("ignores content after a tool call", async () => {
const messages: ChatMessage[][] = [
[{ role: "assistant", content: "```tool\n" }],
Expand Down
Loading