Skip to content

Commit d5b4ba6

Browse files
authored
feat: consolidated streaming and timeout improvements (#90)
* feat: consolidated streaming and timeout improvements

  ## Summary

  This PR consolidates **3 feature PRs** for streaming and timeout improvements.

  ### Included PRs:

  - #25: Add per-chunk timeout for SSE streaming to prevent hangs
  - #29: Add early JSON validation for tool call arguments
  - #34: Add per-tool timeout in batch execution

  ### Key Changes:

  - Added CHUNK_TIMEOUT_SECS constant for SSE streaming (60s)
  - Wrapped SSE event iteration with tokio::time::timeout
  - Added validate_arguments() method to StreamToolCall
  - Added is_valid_complete() helper and complete_tool_call_validated() method
  - Added DEFAULT_TOOL_TIMEOUT_SECS constant (60 seconds) for batch tool execution
  - Added tool_timeout_secs field to BatchToolArgs for configuration
  - Applied individual timeout to each tool execution in execute_parallel()

  ### Files Modified:

  - src/cortex-engine/src/client/cortex.rs
  - src/cortex-engine/src/streaming.rs
  - src/cortex-engine/src/tools/handlers/batch.rs
  - src/cortex-engine/src/tools/unified_executor.rs

  Closes #25, #29, #34

* fix(batch): implement timeout_secs parameter for overall batch timeout

  Address Greptile review feedback: The timeout_secs parameter was documented and accepted but never used. Now it properly wraps the entire parallel execution with a batch-level timeout, separate from the per-tool tool_timeout_secs.

  - Add batch-level timeout wrapper around execute_parallel
  - Return descriptive error message when batch times out
  - Add test for batch timeout behavior
1 parent 3c1a144 commit d5b4ba6

4 files changed

Lines changed: 161 additions & 16 deletions

File tree

src/cortex-engine/src/client/cortex.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
//! - Responses API (streaming SSE)
66
//! - Credit system with price verification
77
8+
use std::time::Duration;
9+
810
use async_trait::async_trait;
911
use eventsource_stream::Eventsource;
1012
use futures::StreamExt;
1113
use reqwest::Client;
1214
use serde::{Deserialize, Serialize};
1315
use tokio::sync::mpsc;
16+
use tokio::time::timeout;
1417
use tokio_stream::wrappers::ReceiverStream;
1518

1619
use super::{
@@ -22,6 +25,11 @@ use crate::error::{CortexError, Result};
2225

2326
const DEFAULT_CORTEX_URL: &str = "https://api.cortex.foundation";
2427

28+
/// Timeout in seconds for receiving individual SSE chunks during streaming.
29+
/// If no data is received within this duration, the connection is terminated
30+
/// to prevent indefinite hangs when connections stall mid-stream.
31+
const CHUNK_TIMEOUT_SECS: u64 = 60;
32+
2533
/// Pricing information for a model.
2634
#[derive(Debug, Clone, Serialize, Deserialize)]
2735
pub struct PricingInfo {
@@ -567,8 +575,26 @@ impl ModelClient for CortexClient {
567575
let mut stream = std::pin::pin!(stream);
568576
let mut accumulated_text = String::new();
569577
let mut usage = TokenUsage::default();
570-
571-
while let Some(event_result) = stream.next().await {
578+
let chunk_timeout = Duration::from_secs(CHUNK_TIMEOUT_SECS);
579+
580+
loop {
581+
// Apply per-chunk timeout to prevent indefinite hangs when connections stall
582+
let event_result = match timeout(chunk_timeout, stream.next()).await {
583+
Ok(Some(result)) => result,
584+
Ok(None) => break, // Stream ended normally
585+
Err(_) => {
586+
// Timeout elapsed - no data received within CHUNK_TIMEOUT_SECS
587+
let _ = tx
588+
.send(Err(CortexError::BackendError {
589+
message: format!(
590+
"SSE chunk timeout - no data received for {} seconds",
591+
CHUNK_TIMEOUT_SECS
592+
),
593+
}))
594+
.await;
595+
break;
596+
}
597+
};
572598
match event_result {
573599
Ok(event) => {
574600
if event.data.is_empty() || event.data == "[DONE]" {

src/cortex-engine/src/streaming.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,18 @@ impl StreamContent {
143143
}
144144
}
145145

146+
/// Complete a tool call and validate its arguments.
147+
/// Returns Ok(()) if the tool call was found and arguments are valid JSON.
148+
/// Returns Err if the tool call is not found or its arguments are invalid JSON;
/// in the invalid-JSON case the tool call has still been marked complete.
149+
pub fn complete_tool_call_validated(&mut self, id: &str) -> Result<(), String> {
150+
if let Some(tc) = self.tool_calls.iter_mut().find(|tc| tc.id == id) {
151+
tc.complete = true;
152+
tc.validate_arguments()
153+
} else {
154+
Err(format!("Tool call with id '{}' not found", id))
155+
}
156+
}
157+
146158
/// Check if has content.
147159
pub fn has_content(&self) -> bool {
148160
!self.text.is_empty() || !self.tool_calls.is_empty()
@@ -171,6 +183,23 @@ impl StreamToolCall {
171183
None
172184
}
173185
}
186+
187+
/// Validate that arguments contain valid JSON.
188+
/// Returns Ok(()) if valid, Err with details if invalid.
189+
pub fn validate_arguments(&self) -> Result<(), String> {
190+
if self.arguments.trim().is_empty() {
191+
return Ok(()); // Empty is valid (no arguments)
192+
}
193+
serde_json::from_str::<serde_json::Value>(&self.arguments)
194+
.map(|_| ())
195+
.map_err(|e| format!("Invalid JSON in tool call arguments: {}", e))
196+
}
197+
198+
/// Check if arguments are complete and valid JSON.
199+
/// Returns true only if complete and valid.
200+
pub fn is_valid_complete(&self) -> bool {
201+
self.complete && self.validate_arguments().is_ok()
202+
}
174203
}
175204

176205
/// Token counts.

src/cortex-engine/src/tools/handlers/batch.rs

Lines changed: 103 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ pub const MAX_BATCH_SIZE: usize = 10;
2323
/// Default timeout for batch execution in seconds.
2424
pub const DEFAULT_BATCH_TIMEOUT_SECS: u64 = 300;
2525

26+
/// Default timeout for individual tool execution in seconds.
27+
/// This prevents a single tool from consuming the entire batch timeout.
28+
pub const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 60;
29+
2630
/// Tools that cannot be called within a batch (prevent recursion).
2731
/// Note: Task is now allowed to enable parallel task execution.
2832
pub const DISALLOWED_TOOLS: &[&str] = &["Batch", "batch", "Agent", "agent"];
@@ -45,6 +49,10 @@ pub struct BatchToolArgs {
4549
/// Optional timeout in seconds for the entire batch (default: 300s).
4650
#[serde(default)]
4751
pub timeout_secs: Option<u64>,
52+
/// Optional timeout in seconds for individual tools (default: 60s).
53+
/// This prevents a single tool from consuming the entire batch timeout.
54+
#[serde(default)]
55+
pub tool_timeout_secs: Option<u64>,
4856
}
4957

5058
/// Result of a single tool call within the batch.
@@ -158,7 +166,7 @@ impl BatchToolHandler {
158166
&self,
159167
calls: Vec<BatchToolCall>,
160168
context: &ToolContext,
161-
timeout_duration: Duration,
169+
tool_timeout: Duration,
162170
) -> BatchResult {
163171
let start_time = Instant::now();
164172
let results = Arc::new(Mutex::new(Vec::with_capacity(calls.len())));
@@ -176,9 +184,9 @@ impl BatchToolHandler {
176184
async move {
177185
let call_start = Instant::now();
178186

179-
// Execute with per-call timeout (use batch timeout for each call)
187+
// Execute with per-tool timeout to prevent single tools from blocking others
180188
let execution_result = timeout(
181-
timeout_duration,
189+
tool_timeout,
182190
executor.execute_tool(&call.tool, call.arguments, &ctx),
183191
)
184192
.await;
@@ -202,19 +210,23 @@ impl BatchToolHandler {
202210
duration_ms,
203211
},
204212
Ok(Err(e)) => BatchCallResult {
205-
tool: tool_name,
213+
tool: tool_name.clone(),
206214
index,
207215
success: false,
208216
result: None,
209217
error: Some(format!("Execution error: {}", e)),
210218
duration_ms,
211219
},
212220
Err(_) => BatchCallResult {
213-
tool: tool_name,
221+
tool: tool_name.clone(),
214222
index,
215223
success: false,
216224
result: None,
217-
error: Some(format!("Timeout after {}s", timeout_duration.as_secs())),
225+
error: Some(format!(
226+
"Tool '{}' timed out after {}s",
227+
tool_name,
228+
tool_timeout.as_secs()
229+
)),
218230
duration_ms,
219231
},
220232
};
@@ -328,14 +340,30 @@ impl ToolHandler for BatchToolHandler {
328340
// Validate calls
329341
self.validate_calls(&args.calls)?;
330342

331-
// Determine timeout
332-
let timeout_secs = args.timeout_secs.unwrap_or(DEFAULT_BATCH_TIMEOUT_SECS);
333-
let timeout_duration = Duration::from_secs(timeout_secs);
334-
335-
// Execute all tools in parallel
336-
let batch_result = self
337-
.execute_parallel(args.calls, context, timeout_duration)
338-
.await;
343+
// Determine overall batch timeout (wraps around entire parallel execution)
344+
let batch_timeout_secs = args.timeout_secs.unwrap_or(DEFAULT_BATCH_TIMEOUT_SECS);
345+
let batch_timeout = Duration::from_secs(batch_timeout_secs);
346+
347+
// Determine per-tool timeout (prevents single tool from blocking others)
348+
let tool_timeout_secs = args.tool_timeout_secs.unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS);
349+
let tool_timeout = Duration::from_secs(tool_timeout_secs);
350+
351+
// Execute all tools in parallel with overall batch timeout
352+
let batch_result = match timeout(
353+
batch_timeout,
354+
self.execute_parallel(args.calls, context, tool_timeout),
355+
)
356+
.await
357+
{
358+
Ok(result) => result,
359+
Err(_) => {
360+
// Batch-level timeout exceeded
361+
return Ok(ToolResult::error(format!(
362+
"Batch execution timed out after {}s. Consider using a longer timeout_secs or reducing the number of tools.",
363+
batch_timeout_secs
364+
)));
365+
}
366+
};
339367

340368
// Format output
341369
let output = self.format_result(&batch_result);
@@ -384,6 +412,12 @@ pub fn batch_tool_definition() -> ToolDefinition {
384412
"description": "Optional timeout in seconds for the entire batch execution (default: 300)",
385413
"minimum": 1,
386414
"maximum": 600
415+
},
416+
"tool_timeout_secs": {
417+
"type": "integer",
418+
"description": "Optional timeout in seconds for individual tool execution (default: 60). Prevents a single tool from consuming the entire batch timeout.",
419+
"minimum": 1,
420+
"maximum": 300
387421
}
388422
}
389423
}),
@@ -409,6 +443,7 @@ pub async fn execute_batch(
409443
})
410444
.collect(),
411445
timeout_secs: None,
446+
tool_timeout_secs: None,
412447
};
413448

414449
let arguments = serde_json::to_value(args)
@@ -652,4 +687,58 @@ mod tests {
652687
elapsed.as_millis()
653688
);
654689
}
690+
691+
#[tokio::test]
692+
async fn test_batch_timeout() {
693+
// Create an executor with a slow tool
694+
struct SlowExecutor;
695+
696+
#[async_trait]
697+
impl BatchToolExecutor for SlowExecutor {
698+
async fn execute_tool(
699+
&self,
700+
_name: &str,
701+
_arguments: Value,
702+
_context: &ToolContext,
703+
) -> Result<ToolResult> {
704+
// Sleep longer than batch timeout
705+
tokio::time::sleep(Duration::from_secs(5)).await;
706+
Ok(ToolResult::success("Done"))
707+
}
708+
709+
fn has_tool(&self, _name: &str) -> bool {
710+
true
711+
}
712+
}
713+
714+
let executor: Arc<dyn BatchToolExecutor> = Arc::new(SlowExecutor);
715+
let handler = BatchToolHandler::new(executor);
716+
let context = ToolContext::new(PathBuf::from("."));
717+
718+
// Use a very short batch timeout (1 second) to test timeout behavior
719+
let args = json!({
720+
"calls": [
721+
{"tool": "SlowTool", "arguments": {}}
722+
],
723+
"timeout_secs": 1
724+
});
725+
726+
let start = Instant::now();
727+
let result = handler.execute(args, &context).await;
728+
let elapsed = start.elapsed();
729+
730+
assert!(result.is_ok());
731+
let tool_result = result.unwrap();
732+
733+
// Should timeout quickly (around 1 second)
734+
assert!(
735+
elapsed.as_secs() < 3,
736+
"Batch should have timed out quickly, but took {}s",
737+
elapsed.as_secs()
738+
);
739+
740+
// Should have timed out
741+
assert!(!tool_result.success);
742+
assert!(tool_result.output.contains("timed out"));
743+
}
655744
}

src/cortex-engine/src/tools/unified_executor.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@ impl UnifiedToolExecutor {
466466
Ok(BatchToolArgs {
467467
calls,
468468
timeout_secs: arguments.get("timeout_secs").and_then(|v| v.as_u64()),
469+
tool_timeout_secs: arguments.get("tool_timeout_secs").and_then(|v| v.as_u64()),
469470
})
470471
} else {
471472
Err(CortexError::InvalidInput(

0 commit comments

Comments
 (0)