Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 92 additions & 4 deletions src/tool/ambient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::safety::{self, PermissionRequest, PermissionResult, SafetySystem, Urg
use anyhow::Result;
use async_trait::async_trait;
use chrono::Utc;
use serde::Deserialize;
use serde::{Deserialize, Deserializer};
use serde_json::{Map, Value, json};
use std::collections::HashSet;
use std::sync::{Arc, Mutex, OnceLock};
Expand Down Expand Up @@ -119,10 +119,98 @@ impl EndAmbientCycleTool {
}
}

// ---------------------------------------------------------------------------
// Custom deserializers: accept either a JSON number or a numeric string for
// u32 fields. Claude tool calls occasionally serialize numeric arguments as
// strings (e.g. {"compactions": "0"} instead of {"compactions": 0}), which
// caused every ambient cycle to fail with `invalid type: string "0", expected
// u32`. See issue #133 / upstream PR #173.
// ---------------------------------------------------------------------------

fn deserialize_string_or_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::{self, Visitor};
struct StringOrU32;

impl<'de> Visitor<'de> for StringOrU32 {
type Value = u32;

fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.write_str("u32 or string representing u32")
}

fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
where
E: de::Error,
{
u32::try_from(v).map_err(E::custom)
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
v.parse().map_err(E::custom)
}
}

deserializer.deserialize_any(StringOrU32)
}

fn deserialize_string_or_option_u32<'de, D>(deserializer: D) -> Result<Option<u32>, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::{self, Visitor};
struct StringOrOptionU32;

impl<'de> Visitor<'de> for StringOrOptionU32 {
type Value = Option<u32>;

fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.write_str("optional u32 or string representing u32")
}

fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(None)
}

fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
deserialize_string_or_u32(deserializer).map(Some)
}

fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
where
E: de::Error,
{
u32::try_from(v).map_err(E::custom).map(Some)
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
v.parse().map_err(E::custom).map(Some)
}
}

deserializer.deserialize_option(StringOrOptionU32)
}

#[derive(Deserialize)]
struct EndCycleInput {
summary: String,
#[serde(deserialize_with = "deserialize_string_or_u32")]
memories_modified: u32,
#[serde(deserialize_with = "deserialize_string_or_u32")]
compactions: u32,
#[serde(default)]
proactive_work: Option<String>,
Expand All @@ -132,7 +220,7 @@ struct EndCycleInput {

#[derive(Deserialize)]
struct NextScheduleInput {
#[serde(default)]
#[serde(default, deserialize_with = "deserialize_string_or_option_u32")]
wake_in_minutes: Option<u32>,
#[serde(default)]
context: Option<String>,
Expand Down Expand Up @@ -280,7 +368,7 @@ impl ScheduleAmbientTool {

#[derive(Deserialize)]
struct ScheduleInput {
#[serde(default)]
#[serde(default, deserialize_with = "deserialize_string_or_option_u32")]
wake_in_minutes: Option<u32>,
#[serde(default)]
wake_at: Option<String>,
Expand Down Expand Up @@ -722,7 +810,7 @@ struct ScheduleToolInput {
schedule_id: Option<String>,
#[serde(default)]
task: Option<String>,
#[serde(default)]
#[serde(default, deserialize_with = "deserialize_string_or_option_u32")]
wake_in_minutes: Option<u32>,
#[serde(default)]
wake_at: Option<String>,
Expand Down
76 changes: 76 additions & 0 deletions src/tool/ambient/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,79 @@ async fn test_schedule_tool_requires_time() {
.expect_err("should require wake_in_minutes or wake_at");
assert!(err.to_string().contains("wake_in_minutes"));
}

// ---------------------------------------------------------------------------
// Regression tests for issue #133 / upstream PR #173:
// Claude tool calls sometimes serialize numeric parameters as strings
// (e.g. {"compactions": "0"}). Both number and string forms must deserialize.
// ---------------------------------------------------------------------------

#[test]
fn test_end_cycle_input_accepts_stringified_u32_fields() {
let input = json!({
"summary": "Compaction skipped",
"memories_modified": "3",
"compactions": "0",
"next_schedule": {
"wake_in_minutes": "20",
"context": "Recheck stale facts",
"priority": "normal"
}
});

let parsed: EndCycleInput = serde_json::from_value(input).unwrap();
assert_eq!(parsed.memories_modified, 3);
assert_eq!(parsed.compactions, 0);
let ns = parsed.next_schedule.unwrap();
assert_eq!(ns.wake_in_minutes, Some(20));
}

#[test]
fn test_end_cycle_input_still_accepts_native_u32_fields() {
// Make sure the new deserializer didn't break the existing JSON-number form.
let input = json!({
"summary": "All good",
"memories_modified": 7,
"compactions": 2,
"next_schedule": {
"wake_in_minutes": 15
}
});
let parsed: EndCycleInput = serde_json::from_value(input).unwrap();
assert_eq!(parsed.memories_modified, 7);
assert_eq!(parsed.compactions, 2);
assert_eq!(parsed.next_schedule.unwrap().wake_in_minutes, Some(15));
}

#[test]
fn test_schedule_input_accepts_stringified_wake_in_minutes() {
let input = json!({
"wake_in_minutes": "45",
"context": "Verify CI"
});
let parsed: ScheduleInput = serde_json::from_value(input).unwrap();
assert_eq!(parsed.wake_in_minutes, Some(45));
assert_eq!(parsed.context, "Verify CI");
}

#[test]
fn test_schedule_tool_input_accepts_stringified_wake_in_minutes() {
let input = json!({
"task": "Check on tests",
"wake_in_minutes": "60"
});
let parsed: ScheduleToolInput = serde_json::from_value(input).unwrap();
assert_eq!(parsed.wake_in_minutes, Some(60));
}

#[test]
fn test_end_cycle_input_rejects_non_numeric_string() {
// Defensive: a non-numeric string must still be rejected, not silently
// treated as 0.
let input = json!({
"summary": "Bad input",
"memories_modified": "not-a-number",
"compactions": 0
});
assert!(serde_json::from_value::<EndCycleInput>(input).is_err());
}