Commit aa2a736

Merge branch 'main' into main
2 parents: be7df99 + cae01ce · commit aa2a736

8 files changed: +157 -56 lines changed

Cargo.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ members = [
 resolver = "2"
 
 [workspace.package]
-version = "2025.7.5"
+version = "2025.7.6"
 rust-version = "1.86.0"
 license = "Apache-2.0"

tensorzero-core/src/cache.rs

Lines changed: 82 additions & 47 deletions
@@ -3,7 +3,7 @@ use std::sync::Arc;
 
 use crate::clickhouse::{ClickHouseConnectionInfo, TableName};
 use crate::embeddings::{EmbeddingRequest, EmbeddingResponse};
-use crate::error::{Error, ErrorDetails};
+use crate::error::{warn_discarded_cache_write, Error, ErrorDetails};
 use crate::inference::types::file::serialize_with_file_data;
 use crate::inference::types::{
     ContentBlockChunk, ContentBlockOutput, FinishReason, ModelInferenceRequest,
@@ -13,7 +13,7 @@ use crate::model::StreamResponse;
 use crate::serde_util::deserialize_json_string;
 use blake3::Hash;
 use clap::ValueEnum;
-use serde::de::DeserializeOwned;
+use serde::de::{DeserializeOwned, IgnoredAny};
 use serde::{Deserialize, Serialize};
 use std::fmt::Debug;

@@ -210,11 +210,39 @@ pub struct CacheData<T: CacheOutput> {
 /// to/from ClickHouse
 /// We use a marker trait rather than an enum so that the expected type can be enforced by the caller
 /// (e.g. `infer_stream` will never try to deserialize a `NonStreamingCacheData`)
-pub trait CacheOutput {}
+pub trait CacheOutput {
+    /// If this returns `false`, then we'll log a warning and skip writing this entry to the cache
+    fn should_write_to_cache(&self) -> bool;
+}
 
-impl CacheOutput for StreamingCacheData {}
-impl CacheOutput for NonStreamingCacheData {}
-impl CacheOutput for EmbeddingCacheData {}
+impl CacheOutput for StreamingCacheData {
+    fn should_write_to_cache(&self) -> bool {
+        true
+    }
+}
+impl CacheOutput for NonStreamingCacheData {
+    fn should_write_to_cache(&self) -> bool {
+        for block in &self.blocks {
+            if let ContentBlockOutput::ToolCall(tool_call) = block {
+                // We skip writing to the cache if the tool call arguments are not valid JSON.
+                // We're assuming that it's almost never useful to have an invalid tool call cached
+                // (in particular, that tensorzero is not being used with a provider/model that only ever
+                // emits invalid JSON for its tool call arguments).
+                // The invalid tool call will still be returned to the user, but we won't create a
+                // cache entry, even if the user turned on caching.
+                if serde_json::from_str::<IgnoredAny>(&tool_call.arguments).is_err() {
+                    return false;
+                }
+            }
+        }
+        true
+    }
+}
+impl CacheOutput for EmbeddingCacheData {
+    fn should_write_to_cache(&self) -> bool {
+        true
+    }
+}
 
 #[derive(Debug, Deserialize, Serialize)]
 #[serde(transparent)]
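
The validity check added above leans on serde's `IgnoredAny`, which drives the JSON parser to completion while discarding everything it parses, so the arguments string is validated without allocating a `serde_json::Value`. A minimal standalone sketch of the same check (the `is_valid_json` helper is illustrative, not part of this diff):

use serde::de::IgnoredAny;

// Returns true if `s` is well-formed JSON; no parse tree is allocated.
fn is_valid_json(s: &str) -> bool {
    serde_json::from_str::<IgnoredAny>(s).is_ok()
}

fn main() {
    assert!(is_valid_json(r#"{"location": "Tokyo", "units": "celsius"}"#));
    assert!(!is_valid_json("Not valid 'JSON'"));
}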
@@ -237,6 +265,24 @@ pub struct StreamingCacheData {
     pub chunks: Vec<CachedProviderInferenceResponseChunk>,
 }
 
+fn spawn_maybe_cache_write<T: Serialize + CacheOutput + Send + Sync + 'static>(
+    row: FullCacheRow<T>,
+    clickhouse_client: ClickHouseConnectionInfo,
+) {
+    tokio::spawn(async move {
+        if row.data.output.should_write_to_cache() {
+            if let Err(e) = clickhouse_client
+                .write(&[row], TableName::ModelInferenceCache)
+                .await
+            {
+                tracing::warn!("Failed to write to cache: {e}");
+            }
+        } else {
+            warn_discarded_cache_write(&row.data.raw_response);
+        }
+    });
+}
+
 // This doesn't block
 pub fn start_cache_write<T: Serialize + CacheOutput + Send + Sync + 'static>(
     clickhouse_client: &ClickHouseConnectionInfo,
@@ -255,28 +301,21 @@ pub fn start_cache_write<T: Serialize + CacheOutput + Send + Sync + 'static>(
     let output_tokens = usage.output_tokens;
     let clickhouse_client = clickhouse_client.clone();
     let finish_reason = finish_reason.cloned();
-    tokio::spawn(async move {
-        if let Err(e) = clickhouse_client
-            .write(
-                &[FullCacheRow {
-                    short_cache_key,
-                    long_cache_key,
-                    data: CacheData {
-                        output,
-                        raw_request,
-                        raw_response,
-                        input_tokens,
-                        output_tokens,
-                        finish_reason,
-                    },
-                }],
-                TableName::ModelInferenceCache,
-            )
-            .await
-        {
-            tracing::warn!("Failed to write to cache: {e}");
-        }
-    });
+    spawn_maybe_cache_write(
+        FullCacheRow {
+            short_cache_key,
+            long_cache_key,
+            data: CacheData {
+                output,
+                raw_request,
+                raw_response,
+                input_tokens,
+                output_tokens,
+                finish_reason,
+            },
+        },
+        clickhouse_client,
+    );
     Ok(())
 }

@@ -322,25 +361,21 @@ pub fn start_cache_write_streaming(
     };
     let raw_request = raw_request.to_string();
     let clickhouse_client = clickhouse_client.clone();
-    tokio::spawn(async move {
-        clickhouse_client
-            .write(
-                &[FullCacheRow {
-                    short_cache_key,
-                    long_cache_key,
-                    data: CacheData {
-                        output,
-                        raw_request,
-                        raw_response: String::new(),
-                        input_tokens,
-                        output_tokens,
-                        finish_reason,
-                    },
-                }],
-                TableName::ModelInferenceCache,
-            )
-            .await
-    });
+    spawn_maybe_cache_write(
+        FullCacheRow {
+            short_cache_key,
+            long_cache_key,
+            data: CacheData {
+                output,
+                raw_request,
+                raw_response: String::new(),
+                input_tokens,
+                output_tokens,
+                finish_reason,
+            },
+        },
+        clickhouse_client,
+    );
     Ok(())
 }
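
Beyond deduplicating the two call sites, the shared helper slightly changes the streaming path's behavior: the old `start_cache_write_streaming` task discarded the result of `.write(...)`, whereas `spawn_maybe_cache_write` logs a warning on a failed write in both the streaming and non-streaming paths. And since `StreamingCacheData::should_write_to_cache` always returns `true`, only non-streaming entries can currently be discarded by the new check.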

tensorzero-core/src/endpoints/feedback.rs

Lines changed: 5 additions & 1 deletion
@@ -476,7 +476,11 @@ async fn get_function_name(
         MetricConfigLevel::Episode => "episode_id_uint",
     };
     let query = format!(
-        "SELECT function_name FROM {table_name} FINAL WHERE {identifier_key} = toUInt128(toUUID('{target_id}'))"
+        "SELECT function_name
+        FROM {table_name}
+        WHERE {identifier_key} = toUInt128(toUUID('{target_id}'))
+        LIMIT 1
+        SETTINGS max_threads=1"
     );
     let function_name = connection_info
         .run_query_synchronous_no_params(query)
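
The rewrite is presumably about query cost: `FINAL` forces ClickHouse to merge-deduplicate every matching part at read time, whereas the new query can stop scanning as soon as a single row is found (`LIMIT 1`), and `SETTINGS max_threads=1` keeps a point lookup from fanning out across cores. Skipping deduplication should be harmless here, since any duplicate rows for a given inference or episode id would carry the same `function_name`.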

tensorzero-core/src/error.rs

Lines changed: 8 additions & 0 deletions
@@ -52,6 +52,14 @@ pub fn set_unstable_error_json(unstable_error_json: bool) -> Result<(), Error> {
     })
 }
 
+pub fn warn_discarded_cache_write(raw_response: &str) {
+    if *DEBUG.get().unwrap_or(&false) {
+        tracing::warn!("Skipping cache write due to invalid output:\nRaw response: {raw_response}");
+    } else {
+        tracing::warn!("Skipping cache write due to invalid output");
+    }
+}
+
 pub fn warn_discarded_thought_block(provider_type: &str, thought: &Thought) {
     if *DEBUG.get().unwrap_or(&false) {
         tracing::warn!("Provider type `{provider_type}` does not support input thought blocks, discarding: {thought:?}");

tensorzero-core/src/providers/dummy.rs

Lines changed: 5 additions & 0 deletions
@@ -333,6 +333,11 @@ impl InferenceProvider for DummyProvider {
                 arguments: serde_json::to_string(&*DUMMY_TOOL_RESPONSE).unwrap(),
                 id: "0".to_string(),
             })],
+            "invalid_tool_arguments" => vec![ContentBlockOutput::ToolCall(ToolCall {
+                name: "get_temperature".to_string(),
+                arguments: "Not valid 'JSON'".to_string(),
+                id: "0".to_string(),
+            })],
             "reasoner" => vec![
                 ContentBlockOutput::Thought(Thought {
                     text: Some("hmmm".to_string()),

tensorzero-core/tests/e2e/cache.rs

Lines changed: 49 additions & 0 deletions
@@ -9,16 +9,25 @@ use reqwest_eventsource::RequestBuilderExt;
 use serde_json::json;
 use serde_json::Value;
 use std::time::Duration;
+use tensorzero::CacheParamsOptions;
+use tensorzero::ClientInferenceParams;
+use tensorzero::ClientInput;
+use tensorzero::ClientInputMessage;
+use tensorzero::ClientInputMessageContent;
 use tensorzero::ContentBlockChunk;
+use tensorzero::InferenceOutput;
 use tensorzero_core::cache::cache_lookup_streaming;
 use tensorzero_core::cache::start_cache_write_streaming;
+use tensorzero_core::cache::CacheEnabledMode;
 use tensorzero_core::cache::NonStreamingCacheData;
 use tensorzero_core::inference::types::ContentBlock;
 use tensorzero_core::inference::types::ContentBlockOutput;
 use tensorzero_core::inference::types::FinishReason;
 use tensorzero_core::inference::types::ProviderInferenceResponseChunk;
 use tensorzero_core::inference::types::Text;
 use tensorzero_core::inference::types::TextChunk;
+use tensorzero_core::inference::types::TextKind;
+use tracing_test::traced_test;
 use uuid::Uuid;
 
 use tensorzero_core::cache::cache_lookup;
@@ -33,6 +42,7 @@ use tensorzero_core::inference::types::{
 };
 
 use crate::common::get_gateway_endpoint;
+use crate::providers::common::make_embedded_gateway;
 use tensorzero_core::clickhouse::test_helpers::{
     get_clickhouse, select_chat_inference_clickhouse, select_model_inference_clickhouse,
 };
@@ -312,6 +322,45 @@ async fn test_cache_stream_write_and_read() {
     assert!(result.is_none());
 }
 
+#[traced_test]
+#[tokio::test]
+pub async fn test_dont_cache_invalid_tool_call() {
+    let client = make_embedded_gateway().await;
+    let randomness = Uuid::now_v7();
+    let params = ClientInferenceParams {
+        model_name: Some("dummy::invalid_tool_arguments".to_string()),
+        input: ClientInput {
+            system: None,
+            messages: vec![ClientInputMessage {
+                role: Role::User,
+                content: vec![ClientInputMessageContent::Text(TextKind::Text {
+                    text: format!("Test inference: {randomness}"),
+                })],
+            }],
+        },
+        cache_options: CacheParamsOptions {
+            enabled: CacheEnabledMode::On,
+            max_age_s: None,
+        },
+        ..Default::default()
+    };
+    client.inference(params.clone()).await.unwrap();
+
+    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
+    let clickhouse = get_clickhouse().await;
+    assert!(logs_contain("Skipping cache write"));
+
+    // Run again, and check that we get a cache miss
+    let res = client.inference(params).await.unwrap();
+    let InferenceOutput::NonStreaming(res) = res else {
+        panic!("Expected non-streaming inference response");
+    };
+    let model_inference = select_model_inference_clickhouse(&clickhouse, res.inference_id())
+        .await
+        .unwrap();
+    assert_eq!(model_inference.get("cached").unwrap(), false);
+}
+
 #[tokio::test]
 pub async fn test_streaming_cache_with_err() {
     let episode_id = Uuid::now_v7();
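
For context on the log assertion in the new test: `#[traced_test]` from the `tracing-test` crate installs a capturing subscriber for the test and brings a `logs_contain` helper into scope, which is what lets the test assert on the gateway's "Skipping cache write" warning. A minimal sketch:

use tracing_test::traced_test;

#[traced_test]
#[tokio::test]
async fn warns_and_captures() {
    // Events emitted while the test runs are captured by tracing-test
    tracing::warn!("Skipping cache write due to invalid output");
    assert!(logs_contain("Skipping cache write"));
}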

ui/package.json

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
   "name": "tensorzero-ui",
   "private": true,
   "type": "module",
-  "version": "2025.7.5",
+  "version": "2025.7.6",
   "scripts": {
     "build": "NODE_ENV=production react-router build",
     "dev": "react-router dev",
