@@ -12,6 +12,7 @@ import {
1212 ToolFailure ,
1313 ToolResultPart ,
1414 type ToolResultValue ,
15+ Usage ,
1516} from "./schema"
1617import { type AnyTool , type ExecutableTools , type Tools , toDefinitions } from "./tool"
1718
@@ -72,19 +73,42 @@ export const stream = <T extends Tools>(options: StreamOptions<T>): Stream.Strea
7273 tools : [ ...options . request . tools . filter ( ( tool ) => ! runtimeToolNames . has ( tool . name ) ) , ...runtimeTools ] ,
7374 } )
7475
75- const loop = ( request : LLMRequest , step : number ) : Stream . Stream < LLMEvent , LLMError > =>
76+ const loop = (
77+ request : LLMRequest ,
78+ step : number ,
79+ usage : Usage | undefined ,
80+ providerMetadata : ProviderMetadata | undefined ,
81+ ) : Stream . Stream < LLMEvent , LLMError > =>
7682 Stream . unwrap (
7783 Effect . gen ( function * ( ) {
78- const state : StepState = { assistantContent : [ ] , toolCalls : [ ] , finishReason : undefined }
84+ const state : StepState = {
85+ assistantContent : [ ] ,
86+ toolCalls : [ ] ,
87+ finishReason : undefined ,
88+ usage : undefined ,
89+ providerMetadata : undefined ,
90+ }
7991
8092 const modelStream = options
8193 . stream ( request )
94+ . pipe ( Stream . map ( ( event ) => indexStep ( event , step ) ) )
8295 . pipe ( Stream . tap ( ( event ) => Effect . sync ( ( ) => accumulate ( state , event ) ) ) )
96+ . pipe ( Stream . filter ( ( event ) => event . type !== "finish" ) )
8397
8498 const continuation = Stream . unwrap (
8599 Effect . gen ( function * ( ) {
86- if ( state . finishReason !== "tool-calls" || state . toolCalls . length === 0 ) return Stream . empty
87- if ( options . toolExecution === "none" ) return Stream . empty
100+ const totalUsage = addUsage ( usage , state . usage )
101+ const totalProviderMetadata = mergeProviderMetadata ( providerMetadata , state . providerMetadata )
102+ const finishStream = Stream . fromIterable ( [
103+ LLMEvent . finish ( {
104+ reason : state . finishReason ?? "unknown" ,
105+ usage : totalUsage ,
106+ providerMetadata : totalProviderMetadata ,
107+ } ) ,
108+ ] )
109+
110+ if ( state . finishReason !== "tool-calls" || state . toolCalls . length === 0 ) return finishStream
111+ if ( options . toolExecution === "none" ) return finishStream
88112
89113 const dispatched = yield * Effect . forEach (
90114 state . toolCalls ,
@@ -93,24 +117,36 @@ export const stream = <T extends Tools>(options: StreamOptions<T>): Stream.Strea
93117 )
94118 const resultStream = Stream . fromIterable ( dispatched . flatMap ( ( [ call , result ] ) => emitEvents ( call , result ) ) )
95119
96- if ( ! options . stopWhen ) return resultStream
97- if ( options . stopWhen ( { step, request } ) ) return resultStream
120+ if ( ! options . stopWhen ) return resultStream . pipe ( Stream . concat ( finishStream ) )
121+ if ( options . stopWhen ( { step, request } ) ) return resultStream . pipe ( Stream . concat ( finishStream ) )
98122
99- return resultStream . pipe ( Stream . concat ( loop ( followUpRequest ( request , state , dispatched ) , step + 1 ) ) )
123+ return resultStream . pipe (
124+ Stream . concat (
125+ loop ( followUpRequest ( request , state , dispatched ) , step + 1 , totalUsage , totalProviderMetadata ) ,
126+ ) ,
127+ )
100128 } ) ,
101129 )
102130
103131 return modelStream . pipe ( Stream . concat ( continuation ) )
104132 } ) ,
105133 )
106134
107- return loop ( initialRequest , 0 )
135+ return loop ( initialRequest , 0 , undefined , undefined )
136+ }
137+
138+ const indexStep = ( event : LLMEvent , index : number ) : LLMEvent => {
139+ if ( event . type === "step-start" ) return LLMEvent . stepStart ( { index } )
140+ if ( event . type === "step-finish" ) return LLMEvent . stepFinish ( { ...event , index } )
141+ return event
108142}
109143
110144interface StepState {
111145 assistantContent : ContentPart [ ]
112146 toolCalls : ToolCallPart [ ]
113147 finishReason : FinishReason | undefined
148+ usage : Usage | undefined
149+ providerMetadata : ProviderMetadata | undefined
114150}
115151
116152const accumulate = ( state : StepState , event : LLMEvent ) => {
@@ -154,11 +190,45 @@ const accumulate = (state: StepState, event: LLMEvent) => {
154190 )
155191 return
156192 }
157- if ( event . type === "step-finish" || event . type === "request-finish" ) {
193+ if ( event . type === "step-finish" ) {
158194 state . finishReason = event . reason === "stop" && state . toolCalls . length > 0 ? "tool-calls" : event . reason
195+ state . usage = addUsage ( state . usage , event . usage )
196+ state . providerMetadata = mergeProviderMetadata ( state . providerMetadata , event . providerMetadata )
197+ return
198+ }
199+ if ( event . type === "finish" ) {
200+ state . finishReason ??= event . reason
201+ state . usage ??= event . usage
202+ state . providerMetadata = mergeProviderMetadata ( state . providerMetadata , event . providerMetadata )
159203 }
160204}
161205
206+ const addUsage = ( left : Usage | undefined , right : Usage | undefined ) => {
207+ if ( ! left ) return right
208+ if ( ! right ) return left
209+ type UsageKey =
210+ | "inputTokens"
211+ | "outputTokens"
212+ | "nonCachedInputTokens"
213+ | "cacheReadInputTokens"
214+ | "cacheWriteInputTokens"
215+ | "reasoningTokens"
216+ | "totalTokens"
217+ const sum = ( key : UsageKey ) =>
218+ left [ key ] === undefined && right [ key ] === undefined ? undefined : Number ( left [ key ] ?? 0 ) + Number ( right [ key ] ?? 0 )
219+
220+ return new Usage ( {
221+ inputTokens : sum ( "inputTokens" ) ,
222+ outputTokens : sum ( "outputTokens" ) ,
223+ nonCachedInputTokens : sum ( "nonCachedInputTokens" ) ,
224+ cacheReadInputTokens : sum ( "cacheReadInputTokens" ) ,
225+ cacheWriteInputTokens : sum ( "cacheWriteInputTokens" ) ,
226+ reasoningTokens : sum ( "reasoningTokens" ) ,
227+ totalTokens : sum ( "totalTokens" ) ,
228+ providerMetadata : mergeProviderMetadata ( left . providerMetadata , right . providerMetadata ) ,
229+ } )
230+ }
231+
162232const sameProviderMetadata = ( left : ProviderMetadata | undefined , right : ProviderMetadata | undefined ) =>
163233 left === right || JSON . stringify ( left ) === JSON . stringify ( right )
164234
@@ -200,17 +270,17 @@ const dispatch = (tools: Tools, call: ToolCallPart): Effect.Effect<ToolResultVal
200270 if ( ! tool . execute )
201271 return Effect . succeed ( { type : "error" as const , value : `Tool has no execute handler: ${ call . name } ` } )
202272
203- return decodeAndExecute ( tool , call . input ) . pipe (
273+ return decodeAndExecute ( tool , call ) . pipe (
204274 Effect . catchTag ( "LLM.ToolFailure" , ( failure ) =>
205275 Effect . succeed ( { type : "error" as const , value : failure . message } satisfies ToolResultValue ) ,
206276 ) ,
207277 )
208278}
209279
210- const decodeAndExecute = ( tool : AnyTool , input : unknown ) : Effect . Effect < ToolResultValue , ToolFailure > =>
211- tool . _decode ( input ) . pipe (
280+ const decodeAndExecute = ( tool : AnyTool , call : ToolCallPart ) : Effect . Effect < ToolResultValue , ToolFailure > =>
281+ tool . _decode ( call . input ) . pipe (
212282 Effect . mapError ( ( error ) => new ToolFailure ( { message : `Invalid tool input: ${ error . message } ` } ) ) ,
213- Effect . flatMap ( ( decoded ) => tool . execute ! ( decoded ) ) ,
283+ Effect . flatMap ( ( decoded ) => tool . execute ! ( decoded , { id : call . id , name : call . name } ) ) ,
214284 Effect . flatMap ( ( value ) =>
215285 tool . _encode ( value ) . pipe (
216286 Effect . mapError (
0 commit comments