Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/agent/runloop/unified/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ impl CtrlCState {
}

pub(crate) fn is_cancel_requested(&self) -> bool {
self.cancel_requested.load(Ordering::Relaxed)
self.cancel_requested.load(Ordering::Acquire)
}

pub(crate) fn is_exit_requested(&self) -> bool {
self.exit_requested.load(Ordering::Relaxed)
self.exit_requested.load(Ordering::Acquire)
}

pub(crate) fn disarm_exit(&self) {
Expand Down
45 changes: 26 additions & 19 deletions src/agent/runloop/unified/turn/run_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1289,12 +1289,10 @@ pub(crate) async fn run_single_agent_loop_unified(
let _updated_snapshot = {
let mut guard = tools.write().await;
guard.retain(|tool| {
!tool
.function
tool.function
.as_ref()
.unwrap()
.name
.starts_with("mcp_")
.map(|f| !f.name.starts_with("mcp_"))
.unwrap_or(true)
});
guard.extend(new_definitions);
guard.clone()
Expand Down Expand Up @@ -1368,12 +1366,10 @@ pub(crate) async fn run_single_agent_loop_unified(
let _updated_snapshot = {
let mut guard = tools.write().await;
guard.retain(|tool| {
!tool
.function
tool.function
.as_ref()
.unwrap()
.name
.starts_with("mcp_")
.map(|f| !f.name.starts_with("mcp_"))
.unwrap_or(true)
});
guard.extend(new_definitions);
guard.clone()
Expand Down Expand Up @@ -2448,15 +2444,26 @@ pub(crate) async fn run_single_agent_loop_unified(
// This prevents the loop from breaking after tool execution
let _ = final_text.take();
for call in &tool_calls {
let name = call
.function
.as_ref()
.expect("Tool call must have function")
.name
.as_str();
let args_val = call
.parsed_arguments()
.unwrap_or_else(|_| serde_json::json!({}));
let Some(function) = call.function.as_ref() else {
tracing::warn!("Malformed tool call: missing function definition");
working_history.push(uni::Message::system(
"Skipped malformed tool call: missing function definition".to_string(),
));
continue;
};
let name = function.name.as_str();
let args_val = match call.parsed_arguments() {
Ok(args) => args,
Err(err) => {
tracing::warn!("Failed to parse args for '{}': {}", name, err);
let error_msg = format!(
"Tool '{}' received invalid arguments: {}",
name, err
);
working_history.push(uni::Message::system(error_msg));
continue;
}
};

// Normalize args for loop detection: strip pagination params and normalize paths
let normalized_args = if let Some(obj) = args_val.as_object() {
Expand Down
8 changes: 8 additions & 0 deletions vtcode-core/src/llm/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,14 @@ impl LLMProvider for AnthropicProvider {
}

async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError> {
// Validate API key before making request
if self.api_key.trim().is_empty() {
return Err(LLMError::Authentication {
message: "Anthropic API key is not configured. Set ANTHROPIC_API_KEY environment variable.".to_string(),
metadata: None,
});
}

let anthropic_request = self.convert_to_anthropic_format(&request)?;
let url = format!("{}/messages", self.base_url);

Expand Down
26 changes: 22 additions & 4 deletions vtcode-core/src/llm/providers/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,16 +227,25 @@ impl LLMProvider for GeminiProvider {
}

async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError> {
// Validate API key before making request
if self.api_key.trim().is_empty() {
return Err(LLMError::Authentication {
message: "Gemini API key is not configured. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.".to_string(),
metadata: None,
});
}

let gemini_request = self.convert_to_gemini_request(&request)?;

let url = format!(
"{}/models/{}:generateContent?key={}",
self.base_url, request.model, self.api_key
"{}/models/{}:generateContent",
self.base_url, request.model
);

let response = self
.http_client
.post(&url)
.header("x-goog-api-key", self.api_key.as_ref())
.json(&gemini_request)
.send()
.await
Expand All @@ -257,16 +266,25 @@ impl LLMProvider for GeminiProvider {
}

async fn stream(&self, request: LLMRequest) -> Result<LLMStream, LLMError> {
// Validate API key before making request
if self.api_key.trim().is_empty() {
return Err(LLMError::Authentication {
message: "Gemini API key is not configured. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.".to_string(),
metadata: None,
});
}

let gemini_request = self.convert_to_gemini_request(&request)?;

let url = format!(
"{}/models/{}:streamGenerateContent?key={}",
self.base_url, request.model, self.api_key
"{}/models/{}:streamGenerateContent",
self.base_url, request.model
);

let response = self
.http_client
.post(&url)
.header("x-goog-api-key", self.api_key.as_ref())
.json(&gemini_request)
.send()
.await
Expand Down
102 changes: 102 additions & 0 deletions vtcode-tools/src/acp_tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ mod shared {

const ERR_ARGS_OBJECT: &str = "Arguments must be an object";
const ERR_CLIENT_UNINITIALIZED: &str = "ACP client not initialized";
/// Maximum allowed length for agent IDs to prevent DoS via oversized strings.
const MAX_AGENT_ID_LEN: usize = 256;
/// Maximum allowed length for action names.
const MAX_ACTION_LEN: usize = 128;
/// Maximum JSON depth for call_args to prevent stack overflow.
const MAX_JSON_DEPTH: usize = 32;
/// Maximum size for call_args payload in bytes.
const MAX_ARGS_SIZE: usize = 1024 * 1024; // 1MB

pub fn extract_args_object(args: &Value) -> anyhow::Result<&serde_json::Map<String, Value>> {
args.as_object()
Expand Down Expand Up @@ -54,6 +62,78 @@ mod shared {
}
Ok(())
}

/// Validate agent ID format: alphanumeric, hyphens, underscores only, length limit.
pub fn validate_agent_id(agent_id: &str) -> anyhow::Result<()> {
if agent_id.is_empty() {
return Err(anyhow::anyhow!("agent_id cannot be empty"));
}
if agent_id.len() > MAX_AGENT_ID_LEN {
return Err(anyhow::anyhow!(
"agent_id exceeds maximum length of {} characters",
MAX_AGENT_ID_LEN
));
}
if !agent_id
.chars()
.all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.')
{
return Err(anyhow::anyhow!(
"agent_id contains invalid characters (allowed: alphanumeric, hyphen, underscore, dot)"
));
}
Ok(())
}

/// Validate action name format.
pub fn validate_action(action: &str) -> anyhow::Result<()> {
if action.is_empty() {
return Err(anyhow::anyhow!("action cannot be empty"));
}
if action.len() > MAX_ACTION_LEN {
return Err(anyhow::anyhow!(
"action exceeds maximum length of {} characters",
MAX_ACTION_LEN
));
}
if !action
.chars()
.all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.')
{
return Err(anyhow::anyhow!(
"action contains invalid characters"
));
}
Ok(())
}

/// Validate call_args size and depth.
pub fn validate_call_args(args: &Value) -> anyhow::Result<()> {
let serialized = serde_json::to_string(args)
.map_err(|e| anyhow::anyhow!("Failed to serialize args: {}", e))?;
if serialized.len() > MAX_ARGS_SIZE {
return Err(anyhow::anyhow!(
"call_args exceeds maximum size of {} bytes",
MAX_ARGS_SIZE
));
}
if json_depth(args) > MAX_JSON_DEPTH {
return Err(anyhow::anyhow!(
"call_args exceeds maximum nesting depth of {}",
MAX_JSON_DEPTH
));
}
Ok(())
}

/// Calculate JSON nesting depth.
fn json_depth(value: &Value) -> usize {
match value {
Value::Array(arr) => 1 + arr.iter().map(json_depth).max().unwrap_or(0),
Value::Object(obj) => 1 + obj.values().map(json_depth).max().unwrap_or(0),
_ => 0,
}
}
}

/// ACP Inter-Agent Communication Tool
Expand Down Expand Up @@ -106,6 +186,23 @@ impl Tool for AcpTool {
let obj = shared::extract_args_object(args)?;
shared::validate_field_exists(obj, "action")?;
shared::validate_field_exists(obj, "remote_agent_id")?;
// Validate formats
if let Some(action) = obj.get("action").and_then(|v| v.as_str()) {
shared::validate_action(action)?;
}
if let Some(agent_id) = obj.get("remote_agent_id").and_then(|v| v.as_str()) {
shared::validate_agent_id(agent_id)?;
}
// Validate method if provided
if let Some(method) = obj.get("method").and_then(|v| v.as_str()) {
if method != "sync" && method != "async" {
return Err(anyhow::anyhow!("Invalid method '{}': must be 'sync' or 'async'", method));
}
}
// Validate call_args if provided
if let Some(call_args) = obj.get("args") {
shared::validate_call_args(call_args)?;
}
Ok(())
}

Expand All @@ -116,7 +213,12 @@ impl Tool for AcpTool {
let remote_agent_id = shared::get_required_field(obj, "remote_agent_id", None)?;
let method = obj.get("method").and_then(|v| v.as_str()).unwrap_or("sync");

// Validate inputs before use
shared::validate_action(action)?;
shared::validate_agent_id(remote_agent_id)?;

let call_args = obj.get("args").cloned().unwrap_or(json!({}));
shared::validate_call_args(&call_args)?;

let client = self.client.read().await;
let client = shared::check_client_initialized(&*client)?;
Expand Down
18 changes: 17 additions & 1 deletion vtcode-tools/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,23 @@ impl CachedToolExecutor {

// Execute tool (caller provides actual execution)
// This is where your tool registry would call the actual tool
let result = self.execute_tool_internal(tool_name, &*owned_args).await?;
let result = match self.execute_tool_internal(tool_name, &*owned_args).await {
Ok(r) => r,
Err(e) => {
// Invoke error handlers before propagating
if let Err(hook_err) = self.middleware.on_error(&req, &e).await {
eprintln!("[vtcode-tools] Middleware on_error hook failed: {}", hook_err);
}
// Update failed stats
{
let mut stats = self.stats.write().await;
stats.failed_calls += 1;
}
// Record failure in pattern detector
self.record_pattern(tool_name, false, start.elapsed().as_millis() as u64).await;
return Err(e);
}
};

let duration_ms = start.elapsed().as_millis() as u64;

Expand Down
19 changes: 17 additions & 2 deletions vtcode-tools/src/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,40 @@ pub struct DetectedPattern {
pub confidence: f64,
}

/// Maximum events to retain before eviction (prevents unbounded memory growth).
const MAX_EVENTS_CAPACITY: usize = 1000;

/// Pattern detector using sequence analysis.
pub struct PatternDetector {
events: Vec<ToolEvent>,
patterns: HashMap<String, DetectedPattern>,
sequence_length: usize,
max_events: usize,
}

impl PatternDetector {
/// Create new detector with sliding window size.
pub fn new(sequence_length: usize) -> Self {
Self::with_capacity(sequence_length, MAX_EVENTS_CAPACITY)
}

/// Create new detector with custom event capacity limit.
pub fn with_capacity(sequence_length: usize, max_events: usize) -> Self {
Self {
events: Vec::with_capacity(64),
events: Vec::with_capacity(64.min(max_events)),
patterns: HashMap::with_capacity(16),
sequence_length,
max_events: max_events.max(sequence_length * 2),
}
}

/// Add an event to the detector.
/// Add an event to the detector with automatic eviction.
pub fn record_event(&mut self, event: ToolEvent) {
// Evict oldest events if at capacity (sliding window)
if self.events.len() >= self.max_events {
let drain_count = self.max_events / 4; // Remove 25% of oldest
self.events.drain(0..drain_count);
}
self.events.push(event);
self.analyze();
}
Expand Down
Loading