Skip to content

Commit dd14413

Browse files
authored
Preserve native LLM tool context (anomalyco#27116)
1 parent b9e7cbf commit dd14413

18 files changed

Lines changed: 244 additions & 75 deletions

packages/llm/example/tutorial.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ const streamText = LLM.stream(request).pipe(
7878
Stream.tap((event) =>
7979
Effect.sync(() => {
8080
if (event.type === "text-delta") process.stdout.write(`\ntext: ${event.text}`)
81-
if (event.type === "request-finish") process.stdout.write(`\nfinish: ${event.reason}\n`)
81+
if (event.type === "finish") process.stdout.write(`\nfinish: ${event.reason}\n`)
8282
}),
8383
),
8484
Stream.runDrain,
@@ -185,7 +185,7 @@ const FakeProtocol = Protocol.make<FakeBody, string, string, void>({
185185
event: Schema.String,
186186
initial: () => undefined,
187187
step: (_, frame) => Effect.succeed([undefined, [{ type: "text-delta", id: "text-0", text: frame }]] as const),
188-
onHalt: () => [{ type: "request-finish", reason: "stop" }],
188+
onHalt: () => [{ type: "finish", reason: "stop" }],
189189
},
190190
})
191191

packages/llm/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ export type {
1717
ExecutableTools,
1818
Tool as ToolShape,
1919
ToolExecute,
20+
ToolExecuteContext,
2021
Tools,
2122
ToolSchema,
2223
} from "./tool"

packages/llm/src/protocols/openai-responses.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ type StepResult = readonly [ParserState, ReadonlyArray<LLMEvent>]
380380
const NO_EVENTS: StepResult["1"] = []
381381

382382
// `response.completed` / `response.incomplete` are clean finishes that emit a
383-
// `request-finish` event; `response.failed` is a hard failure that emits a
383+
// `finish` event; `response.failed` is a hard failure that emits a
384384
// `provider-error`. All three end the stream — kept in one set so `step` and
385385
// the protocol's `terminal` predicate stay in sync.
386386
const TERMINAL_TYPES = new Set(["response.completed", "response.incomplete", "response.failed"])

packages/llm/src/protocols/utils/lifecycle.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ export const finish = (
8080
usage: input.usage,
8181
providerMetadata: input.providerMetadata,
8282
}),
83-
LLMEvent.requestFinish(input),
83+
LLMEvent.finish(input),
8484
)
8585
return { ...stepped, stepStarted: false }
8686
}

packages/llm/src/schema/events.ts

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { Schema } from "effect"
2-
import { ContentBlockID, FinishReason, ProtocolID, ProviderMetadata, ResponseID, RouteID, ToolCallID } from "./ids"
2+
import { ContentBlockID, FinishReason, ProtocolID, ProviderMetadata, RouteID, ToolCallID } from "./ids"
33
import { ModelRef } from "./options"
44
import { ToolResultValue } from "./messages"
55

@@ -66,14 +66,13 @@ export class Usage extends Schema.Class<Usage>("LLM.Usage")({
6666
get visibleOutputTokens() {
6767
return Math.max(0, (this.outputTokens ?? 0) - (this.reasoningTokens ?? 0))
6868
}
69+
70+
static from(input: UsageInput) {
71+
return input instanceof Usage ? input : new Usage(input)
72+
}
6973
}
7074

71-
export const RequestStart = Schema.Struct({
72-
type: Schema.tag("request-start"),
73-
id: ResponseID,
74-
model: ModelRef,
75-
}).annotate({ identifier: "LLM.Event.RequestStart" })
76-
export type RequestStart = Schema.Schema.Type<typeof RequestStart>
75+
export type UsageInput = Usage | ConstructorParameters<typeof Usage>[0]
7776

7877
export const StepStart = Schema.Struct({
7978
type: Schema.tag("step-start"),
@@ -185,13 +184,13 @@ export const StepFinish = Schema.Struct({
185184
}).annotate({ identifier: "LLM.Event.StepFinish" })
186185
export type StepFinish = Schema.Schema.Type<typeof StepFinish>
187186

188-
export const RequestFinish = Schema.Struct({
189-
type: Schema.tag("request-finish"),
187+
export const Finish = Schema.Struct({
188+
type: Schema.tag("finish"),
190189
reason: FinishReason,
191190
usage: Schema.optional(Usage),
192191
providerMetadata: Schema.optional(ProviderMetadata),
193-
}).annotate({ identifier: "LLM.Event.RequestFinish" })
194-
export type RequestFinish = Schema.Schema.Type<typeof RequestFinish>
192+
}).annotate({ identifier: "LLM.Event.Finish" })
193+
export type Finish = Schema.Schema.Type<typeof Finish>
195194

196195
export const ProviderErrorEvent = Schema.Struct({
197196
type: Schema.tag("provider-error"),
@@ -202,7 +201,6 @@ export const ProviderErrorEvent = Schema.Struct({
202201
export type ProviderErrorEvent = Schema.Schema.Type<typeof ProviderErrorEvent>
203202

204203
const llmEventTagged = Schema.Union([
205-
RequestStart,
206204
StepStart,
207205
TextStart,
208206
TextDelta,
@@ -217,13 +215,15 @@ const llmEventTagged = Schema.Union([
217215
ToolResult,
218216
ToolError,
219217
StepFinish,
220-
RequestFinish,
218+
Finish,
221219
ProviderErrorEvent,
222220
]).pipe(Schema.toTaggedUnion("type"))
223221

224222
type WithID<Event extends { readonly id: unknown }, ID> = Omit<Event, "type" | "id"> & { readonly id: ID | string }
223+
type WithUsage<Event extends { readonly usage?: Usage }> = Omit<Event, "type" | "usage"> & {
224+
readonly usage?: UsageInput
225+
}
225226

226-
const responseID = (value: ResponseID | string) => ResponseID.make(value)
227227
const contentBlockID = (value: ContentBlockID | string) => ContentBlockID.make(value)
228228
const toolCallID = (value: ToolCallID | string) => ToolCallID.make(value)
229229

@@ -233,7 +233,6 @@ const toolCallID = (value: ToolCallID | string) => ToolCallID.make(value)
233233
* `events.filter(LLMEvent.guards["tool-call"])`.
234234
*/
235235
export const LLMEvent = Object.assign(llmEventTagged, {
236-
requestStart: (input: WithID<RequestStart, ResponseID>) => RequestStart.make({ ...input, id: responseID(input.id) }),
237236
stepStart: StepStart.make,
238237
textStart: (input: WithID<TextStart, ContentBlockID>) => TextStart.make({ ...input, id: contentBlockID(input.id) }),
239238
textDelta: (input: WithID<TextDelta, ContentBlockID>) => TextDelta.make({ ...input, id: contentBlockID(input.id) }),
@@ -252,11 +251,18 @@ export const LLMEvent = Object.assign(llmEventTagged, {
252251
toolCall: (input: WithID<ToolCall, ToolCallID>) => ToolCall.make({ ...input, id: toolCallID(input.id) }),
253252
toolResult: (input: WithID<ToolResult, ToolCallID>) => ToolResult.make({ ...input, id: toolCallID(input.id) }),
254253
toolError: (input: WithID<ToolError, ToolCallID>) => ToolError.make({ ...input, id: toolCallID(input.id) }),
255-
stepFinish: StepFinish.make,
256-
requestFinish: RequestFinish.make,
254+
stepFinish: (input: WithUsage<StepFinish>) =>
255+
StepFinish.make({
256+
...input,
257+
usage: input.usage === undefined ? undefined : Usage.from(input.usage),
258+
}),
259+
finish: (input: WithUsage<Finish>) =>
260+
Finish.make({
261+
...input,
262+
usage: input.usage === undefined ? undefined : Usage.from(input.usage),
263+
}),
257264
providerError: ProviderErrorEvent.make,
258265
is: {
259-
requestStart: llmEventTagged.guards["request-start"],
260266
stepStart: llmEventTagged.guards["step-start"],
261267
textStart: llmEventTagged.guards["text-start"],
262268
textDelta: llmEventTagged.guards["text-delta"],
@@ -271,7 +277,7 @@ export const LLMEvent = Object.assign(llmEventTagged, {
271277
toolResult: llmEventTagged.guards["tool-result"],
272278
toolError: llmEventTagged.guards["tool-error"],
273279
stepFinish: llmEventTagged.guards["step-finish"],
274-
requestFinish: llmEventTagged.guards["request-finish"],
280+
finish: llmEventTagged.guards.finish,
275281
providerError: llmEventTagged.guards["provider-error"],
276282
},
277283
})

packages/llm/src/tool-runtime.ts

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import {
1212
ToolFailure,
1313
ToolResultPart,
1414
type ToolResultValue,
15+
Usage,
1516
} from "./schema"
1617
import { type AnyTool, type ExecutableTools, type Tools, toDefinitions } from "./tool"
1718

@@ -72,19 +73,42 @@ export const stream = <T extends Tools>(options: StreamOptions<T>): Stream.Strea
7273
tools: [...options.request.tools.filter((tool) => !runtimeToolNames.has(tool.name)), ...runtimeTools],
7374
})
7475

75-
const loop = (request: LLMRequest, step: number): Stream.Stream<LLMEvent, LLMError> =>
76+
const loop = (
77+
request: LLMRequest,
78+
step: number,
79+
usage: Usage | undefined,
80+
providerMetadata: ProviderMetadata | undefined,
81+
): Stream.Stream<LLMEvent, LLMError> =>
7682
Stream.unwrap(
7783
Effect.gen(function* () {
78-
const state: StepState = { assistantContent: [], toolCalls: [], finishReason: undefined }
84+
const state: StepState = {
85+
assistantContent: [],
86+
toolCalls: [],
87+
finishReason: undefined,
88+
usage: undefined,
89+
providerMetadata: undefined,
90+
}
7991

8092
const modelStream = options
8193
.stream(request)
94+
.pipe(Stream.map((event) => indexStep(event, step)))
8295
.pipe(Stream.tap((event) => Effect.sync(() => accumulate(state, event))))
96+
.pipe(Stream.filter((event) => event.type !== "finish"))
8397

8498
const continuation = Stream.unwrap(
8599
Effect.gen(function* () {
86-
if (state.finishReason !== "tool-calls" || state.toolCalls.length === 0) return Stream.empty
87-
if (options.toolExecution === "none") return Stream.empty
100+
const totalUsage = addUsage(usage, state.usage)
101+
const totalProviderMetadata = mergeProviderMetadata(providerMetadata, state.providerMetadata)
102+
const finishStream = Stream.fromIterable([
103+
LLMEvent.finish({
104+
reason: state.finishReason ?? "unknown",
105+
usage: totalUsage,
106+
providerMetadata: totalProviderMetadata,
107+
}),
108+
])
109+
110+
if (state.finishReason !== "tool-calls" || state.toolCalls.length === 0) return finishStream
111+
if (options.toolExecution === "none") return finishStream
88112

89113
const dispatched = yield* Effect.forEach(
90114
state.toolCalls,
@@ -93,24 +117,36 @@ export const stream = <T extends Tools>(options: StreamOptions<T>): Stream.Strea
93117
)
94118
const resultStream = Stream.fromIterable(dispatched.flatMap(([call, result]) => emitEvents(call, result)))
95119

96-
if (!options.stopWhen) return resultStream
97-
if (options.stopWhen({ step, request })) return resultStream
120+
if (!options.stopWhen) return resultStream.pipe(Stream.concat(finishStream))
121+
if (options.stopWhen({ step, request })) return resultStream.pipe(Stream.concat(finishStream))
98122

99-
return resultStream.pipe(Stream.concat(loop(followUpRequest(request, state, dispatched), step + 1)))
123+
return resultStream.pipe(
124+
Stream.concat(
125+
loop(followUpRequest(request, state, dispatched), step + 1, totalUsage, totalProviderMetadata),
126+
),
127+
)
100128
}),
101129
)
102130

103131
return modelStream.pipe(Stream.concat(continuation))
104132
}),
105133
)
106134

107-
return loop(initialRequest, 0)
135+
return loop(initialRequest, 0, undefined, undefined)
136+
}
137+
138+
const indexStep = (event: LLMEvent, index: number): LLMEvent => {
139+
if (event.type === "step-start") return LLMEvent.stepStart({ index })
140+
if (event.type === "step-finish") return LLMEvent.stepFinish({ ...event, index })
141+
return event
108142
}
109143

110144
interface StepState {
111145
assistantContent: ContentPart[]
112146
toolCalls: ToolCallPart[]
113147
finishReason: FinishReason | undefined
148+
usage: Usage | undefined
149+
providerMetadata: ProviderMetadata | undefined
114150
}
115151

116152
const accumulate = (state: StepState, event: LLMEvent) => {
@@ -154,11 +190,45 @@ const accumulate = (state: StepState, event: LLMEvent) => {
154190
)
155191
return
156192
}
157-
if (event.type === "step-finish" || event.type === "request-finish") {
193+
if (event.type === "step-finish") {
158194
state.finishReason = event.reason === "stop" && state.toolCalls.length > 0 ? "tool-calls" : event.reason
195+
state.usage = addUsage(state.usage, event.usage)
196+
state.providerMetadata = mergeProviderMetadata(state.providerMetadata, event.providerMetadata)
197+
return
198+
}
199+
if (event.type === "finish") {
200+
state.finishReason ??= event.reason
201+
state.usage ??= event.usage
202+
state.providerMetadata = mergeProviderMetadata(state.providerMetadata, event.providerMetadata)
159203
}
160204
}
161205

206+
const addUsage = (left: Usage | undefined, right: Usage | undefined) => {
207+
if (!left) return right
208+
if (!right) return left
209+
type UsageKey =
210+
| "inputTokens"
211+
| "outputTokens"
212+
| "nonCachedInputTokens"
213+
| "cacheReadInputTokens"
214+
| "cacheWriteInputTokens"
215+
| "reasoningTokens"
216+
| "totalTokens"
217+
const sum = (key: UsageKey) =>
218+
left[key] === undefined && right[key] === undefined ? undefined : Number(left[key] ?? 0) + Number(right[key] ?? 0)
219+
220+
return new Usage({
221+
inputTokens: sum("inputTokens"),
222+
outputTokens: sum("outputTokens"),
223+
nonCachedInputTokens: sum("nonCachedInputTokens"),
224+
cacheReadInputTokens: sum("cacheReadInputTokens"),
225+
cacheWriteInputTokens: sum("cacheWriteInputTokens"),
226+
reasoningTokens: sum("reasoningTokens"),
227+
totalTokens: sum("totalTokens"),
228+
providerMetadata: mergeProviderMetadata(left.providerMetadata, right.providerMetadata),
229+
})
230+
}
231+
162232
const sameProviderMetadata = (left: ProviderMetadata | undefined, right: ProviderMetadata | undefined) =>
163233
left === right || JSON.stringify(left) === JSON.stringify(right)
164234

@@ -200,17 +270,17 @@ const dispatch = (tools: Tools, call: ToolCallPart): Effect.Effect<ToolResultVal
200270
if (!tool.execute)
201271
return Effect.succeed({ type: "error" as const, value: `Tool has no execute handler: ${call.name}` })
202272

203-
return decodeAndExecute(tool, call.input).pipe(
273+
return decodeAndExecute(tool, call).pipe(
204274
Effect.catchTag("LLM.ToolFailure", (failure) =>
205275
Effect.succeed({ type: "error" as const, value: failure.message } satisfies ToolResultValue),
206276
),
207277
)
208278
}
209279

210-
const decodeAndExecute = (tool: AnyTool, input: unknown): Effect.Effect<ToolResultValue, ToolFailure> =>
211-
tool._decode(input).pipe(
280+
const decodeAndExecute = (tool: AnyTool, call: ToolCallPart): Effect.Effect<ToolResultValue, ToolFailure> =>
281+
tool._decode(call.input).pipe(
212282
Effect.mapError((error) => new ToolFailure({ message: `Invalid tool input: ${error.message}` })),
213-
Effect.flatMap((decoded) => tool.execute!(decoded)),
283+
Effect.flatMap((decoded) => tool.execute!(decoded, { id: call.id, name: call.name })),
214284
Effect.flatMap((value) =>
215285
tool._encode(value).pipe(
216286
Effect.mapError(

packages/llm/src/tool.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { Effect, JsonSchema, Schema } from "effect"
2-
import type { ToolDefinition as ToolDefinitionClass } from "./schema"
2+
import type { ToolCallPart, ToolDefinition as ToolDefinitionClass } from "./schema"
33
import { ToolDefinition, ToolFailure } from "./schema"
44

55
/**
@@ -8,9 +8,14 @@ import { ToolDefinition, ToolFailure } from "./schema"
88
* beyond pure data conversion belongs in the handler closure.
99
*/
1010
export type ToolSchema<T> = Schema.Codec<T, any, never, never>
11+
export interface ToolExecuteContext {
12+
readonly id: ToolCallPart["id"]
13+
readonly name: ToolCallPart["name"]
14+
}
1115

1216
export type ToolExecute<Parameters extends ToolSchema<any>, Success extends ToolSchema<any>> = (
1317
params: Schema.Schema.Type<Parameters>,
18+
context?: ToolExecuteContext,
1419
) => Effect.Effect<Schema.Schema.Type<Success>, ToolFailure>
1520

1621
/**
@@ -61,7 +66,7 @@ type TypedToolConfig = {
6166
type DynamicToolConfig = {
6267
readonly description: string
6368
readonly jsonSchema: JsonSchema.JsonSchema
64-
readonly execute?: (params: unknown) => Effect.Effect<unknown, ToolFailure>
69+
readonly execute?: (params: unknown, context?: ToolExecuteContext) => Effect.Effect<unknown, ToolFailure>
6570
}
6671

6772
/**
@@ -110,7 +115,7 @@ export function make<Parameters extends ToolSchema<any>, Success extends ToolSch
110115
export function make(config: {
111116
readonly description: string
112117
readonly jsonSchema: JsonSchema.JsonSchema
113-
readonly execute: (params: unknown) => Effect.Effect<unknown, ToolFailure>
118+
readonly execute: (params: unknown, context?: ToolExecuteContext) => Effect.Effect<unknown, ToolFailure>
114119
}): AnyExecutableTool
115120
export function make(config: {
116121
readonly description: string

packages/llm/test/adapter.test.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ const request = LLM.request({
5151

5252
const raiseEvent = (event: FakeEvent): import("../src/schema").LLMEvent =>
5353
event.type === "finish"
54-
? { type: "request-finish", reason: event.reason }
54+
? { type: "finish", reason: event.reason }
5555
: { type: "text-delta", id: "text-0", text: event.text }
5656

5757
const fakeProtocol = Protocol.make<FakeBody, FakeEvent, FakeEvent, void>({
@@ -112,8 +112,8 @@ describe("llm route", () => {
112112
const events = Array.from(yield* llm.stream(request).pipe(Stream.runCollect))
113113
const response = yield* llm.generate(request)
114114

115-
expect(events.map((event) => event.type)).toEqual(["text-delta", "request-finish"])
116-
expect(response.events.map((event) => event.type)).toEqual(["text-delta", "request-finish"])
115+
expect(events.map((event) => event.type)).toEqual(["text-delta", "finish"])
116+
expect(response.events.map((event) => event.type)).toEqual(["text-delta", "finish"])
117117
}),
118118
)
119119

packages/llm/test/llm.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ describe("llm constructors", () => {
127127
LLMResponse.text({
128128
events: [
129129
{ type: "text-delta", id: "text-0", text: "hi" },
130-
{ type: "request-finish", reason: "stop" },
130+
{ type: "finish", reason: "stop" },
131131
],
132132
}),
133133
).toBe("hi")

0 commit comments

Comments
 (0)