@@ -3,7 +3,7 @@ use std::sync::Arc;
33
44use crate :: clickhouse:: { ClickHouseConnectionInfo , TableName } ;
55use crate :: embeddings:: { EmbeddingRequest , EmbeddingResponse } ;
6- use crate :: error:: { Error , ErrorDetails } ;
6+ use crate :: error:: { warn_discarded_cache_write , Error , ErrorDetails } ;
77use crate :: inference:: types:: file:: serialize_with_file_data;
88use crate :: inference:: types:: {
99 ContentBlockChunk , ContentBlockOutput , FinishReason , ModelInferenceRequest ,
@@ -13,7 +13,7 @@ use crate::model::StreamResponse;
1313use crate :: serde_util:: deserialize_json_string;
1414use blake3:: Hash ;
1515use clap:: ValueEnum ;
16- use serde:: de:: DeserializeOwned ;
16+ use serde:: de:: { DeserializeOwned , IgnoredAny } ;
1717use serde:: { Deserialize , Serialize } ;
1818use std:: fmt:: Debug ;
1919
@@ -210,11 +210,39 @@ pub struct CacheData<T: CacheOutput> {
210210/// to/from ClickHouse
211211/// We use a marker trait rather than an enum so that the expected type can be enforced by the caller
212212/// (e.g. `infer_stream` will never try to deserialize a `NonStreamingCacheData`)
213- pub trait CacheOutput { }
213+ pub trait CacheOutput {
214+ /// If this return `false`, then we'll log a warning and skip writing this entry to the cache
215+ fn should_write_to_cache ( & self ) -> bool ;
216+ }
214217
215- impl CacheOutput for StreamingCacheData { }
216- impl CacheOutput for NonStreamingCacheData { }
217- impl CacheOutput for EmbeddingCacheData { }
218+ impl CacheOutput for StreamingCacheData {
219+ fn should_write_to_cache ( & self ) -> bool {
220+ true
221+ }
222+ }
223+ impl CacheOutput for NonStreamingCacheData {
224+ fn should_write_to_cache ( & self ) -> bool {
225+ for block in & self . blocks {
226+ if let ContentBlockOutput :: ToolCall ( tool_call) = block {
227+ // We skip writing to the cache if the tool call arguments are not valid JSON
228+ // We're assuming that it's almost never useful to have an invalid tool call cached
229+ // (in particular, tensorzero is not being used with a provider/model that only ever
230+ // emits invalid json for its tool call arguments).
231+ // The invalid tool call will still be returned to the user, but we won't create a
232+ // cache entry, even if the user turned on caching.
233+ if serde_json:: from_str :: < IgnoredAny > ( & tool_call. arguments ) . is_err ( ) {
234+ return false ;
235+ }
236+ }
237+ }
238+ true
239+ }
240+ }
241+ impl CacheOutput for EmbeddingCacheData {
242+ fn should_write_to_cache ( & self ) -> bool {
243+ true
244+ }
245+ }
218246
219247#[ derive( Debug , Deserialize , Serialize ) ]
220248#[ serde( transparent) ]
@@ -237,6 +265,24 @@ pub struct StreamingCacheData {
237265 pub chunks : Vec < CachedProviderInferenceResponseChunk > ,
238266}
239267
268+ fn spawn_maybe_cache_write < T : Serialize + CacheOutput + Send + Sync + ' static > (
269+ row : FullCacheRow < T > ,
270+ clickhouse_client : ClickHouseConnectionInfo ,
271+ ) {
272+ tokio:: spawn ( async move {
273+ if row. data . output . should_write_to_cache ( ) {
274+ if let Err ( e) = clickhouse_client
275+ . write ( & [ row] , TableName :: ModelInferenceCache )
276+ . await
277+ {
278+ tracing:: warn!( "Failed to write to cache: {e}" ) ;
279+ }
280+ } else {
281+ warn_discarded_cache_write ( & row. data . raw_response ) ;
282+ }
283+ } ) ;
284+ }
285+
240286// This doesn't block
241287pub fn start_cache_write < T : Serialize + CacheOutput + Send + Sync + ' static > (
242288 clickhouse_client : & ClickHouseConnectionInfo ,
@@ -255,28 +301,21 @@ pub fn start_cache_write<T: Serialize + CacheOutput + Send + Sync + 'static>(
255301 let output_tokens = usage. output_tokens ;
256302 let clickhouse_client = clickhouse_client. clone ( ) ;
257303 let finish_reason = finish_reason. cloned ( ) ;
258- tokio:: spawn ( async move {
259- if let Err ( e) = clickhouse_client
260- . write (
261- & [ FullCacheRow {
262- short_cache_key,
263- long_cache_key,
264- data : CacheData {
265- output,
266- raw_request,
267- raw_response,
268- input_tokens,
269- output_tokens,
270- finish_reason,
271- } ,
272- } ] ,
273- TableName :: ModelInferenceCache ,
274- )
275- . await
276- {
277- tracing:: warn!( "Failed to write to cache: {e}" ) ;
278- }
279- } ) ;
304+ spawn_maybe_cache_write (
305+ FullCacheRow {
306+ short_cache_key,
307+ long_cache_key,
308+ data : CacheData {
309+ output,
310+ raw_request,
311+ raw_response,
312+ input_tokens,
313+ output_tokens,
314+ finish_reason,
315+ } ,
316+ } ,
317+ clickhouse_client,
318+ ) ;
280319 Ok ( ( ) )
281320}
282321
@@ -322,25 +361,21 @@ pub fn start_cache_write_streaming(
322361 } ;
323362 let raw_request = raw_request. to_string ( ) ;
324363 let clickhouse_client = clickhouse_client. clone ( ) ;
325- tokio:: spawn ( async move {
326- clickhouse_client
327- . write (
328- & [ FullCacheRow {
329- short_cache_key,
330- long_cache_key,
331- data : CacheData {
332- output,
333- raw_request,
334- raw_response : String :: new ( ) ,
335- input_tokens,
336- output_tokens,
337- finish_reason,
338- } ,
339- } ] ,
340- TableName :: ModelInferenceCache ,
341- )
342- . await
343- } ) ;
364+ spawn_maybe_cache_write (
365+ FullCacheRow {
366+ short_cache_key,
367+ long_cache_key,
368+ data : CacheData {
369+ output,
370+ raw_request,
371+ raw_response : String :: new ( ) ,
372+ input_tokens,
373+ output_tokens,
374+ finish_reason,
375+ } ,
376+ } ,
377+ clickhouse_client,
378+ ) ;
344379 Ok ( ( ) )
345380}
346381
0 commit comments