Skip to content

Commit 5b691bc

Browse files
committed
fix(run): include max tokens in dry-run estimates
1 parent 7954d02 commit 5b691bc

3 files changed

Lines changed: 120 additions & 44 deletions

File tree

src/cortex-cli/src/agent_cmd/tests.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
#[cfg(test)]
44
mod tests {
55
use crate::agent_cmd::cli::{CopyArgs, ExportArgs};
6-
use crate::agent_cmd::loader::{
7-
load_builtin_agents, parse_frontmatter, read_file_with_encoding,
8-
};
6+
use crate::agent_cmd::loader::{load_builtin_agents, parse_frontmatter};
97
use crate::agent_cmd::types::AgentMode;
8+
use crate::utils::file::read_file_with_encoding;
109

1110
#[test]
1211
fn test_read_file_with_utf8() {

src/cortex-cli/src/run_cmd/cli.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ pub struct RunCli {
161161
/// Maximum tokens for response (used for token validation).
162162
/// If specified, cortex will validate that prompt + max_tokens
163163
/// does not exceed the model's context limit before making the API call.
164-
#[arg(long = "max-tokens")]
164+
#[arg(long = "max-tokens", value_parser = clap::value_parser!(u32).range(1..))]
165165
pub max_tokens: Option<u32>,
166166

167167
/// Custom system prompt to use instead of the default.

src/cortex-cli/src/run_cmd/execution.rs

Lines changed: 117 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@ use super::output::{copy_to_clipboard, send_notification};
1717
use super::session::{SessionMode, resolve_session_id};
1818
use super::system::check_file_descriptor_limits;
1919

20+
#[derive(Debug, PartialEq, Eq)]
21+
struct DryRunTokenEstimate {
22+
user_prompt_tokens: u32,
23+
attachment_tokens: u32,
24+
system_prompt_tokens: u32,
25+
tool_tokens: u32,
26+
tool_count: u32,
27+
total_input_tokens: u32,
28+
max_response_tokens: Option<u32>,
29+
total_with_max_response: Option<u32>,
30+
}
31+
2032
impl RunCli {
2133
/// Run the command.
2234
pub async fn run(self) -> Result<()> {
@@ -784,55 +796,28 @@ impl RunCli {
784796

785797
/// Run in dry-run mode - show token estimates without executing.
786798
async fn run_dry_run(&self, message: &str, attachments: &[FileAttachment]) -> Result<()> {
787-
use cortex_engine::tokenizer::TokenCounter;
788-
789799
let config = cortex_engine::Config::default();
790800
let model = self
791801
.model
792802
.as_ref()
793803
.map(|m| resolve_model_alias(m).to_string())
794804
.unwrap_or_else(|| config.model.clone());
795805

796-
let mut counter = TokenCounter::for_model(&model);
797-
798-
// Count user prompt tokens
799-
let user_prompt_tokens = counter.count(message);
800-
801-
// Count attachment tokens
802-
let mut attachment_tokens = 0u32;
803-
for attachment in attachments {
804-
let content =
805-
std::fs::read_to_string(&attachment.path).unwrap_or_else(|_| String::new());
806-
attachment_tokens += counter.count(&content);
807-
// Add overhead for file markers
808-
attachment_tokens += 20; // Approximate overhead for "--- File: ... ---" markers
809-
}
810-
811-
// Estimate system prompt tokens (typical system prompt is ~500-2000 tokens)
812-
// This is an approximation as the actual system prompt varies
813-
let system_prompt_tokens = 1500u32;
814-
815-
// Estimate tool definition tokens
816-
// Each tool definition is approximately 100-200 tokens on average
817-
// Common tools: Execute, Read, Write, Edit, LS, Grep, Glob, etc.
818-
let tool_count = 15; // Approximate number of default tools
819-
let tool_tokens = tool_count * 150; // ~150 tokens per tool definition
820-
821-
// Calculate totals
822-
let total_input_tokens =
823-
user_prompt_tokens + attachment_tokens + system_prompt_tokens + tool_tokens;
806+
let estimate = self.estimate_dry_run_tokens(&model, message, attachments);
824807

825808
// Output based on format
826809
if matches!(self.format, OutputFormat::Json | OutputFormat::Jsonl) {
827810
let output = serde_json::json!({
828811
"dry_run": true,
829812
"model": model,
830813
"token_estimates": {
831-
"user_prompt": user_prompt_tokens,
832-
"attachments": attachment_tokens,
833-
"system_prompt": system_prompt_tokens,
834-
"tool_definitions": tool_tokens,
835-
"total_input": total_input_tokens,
814+
"user_prompt": estimate.user_prompt_tokens,
815+
"attachments": estimate.attachment_tokens,
816+
"system_prompt": estimate.system_prompt_tokens,
817+
"tool_definitions": estimate.tool_tokens,
818+
"total_input": estimate.total_input_tokens,
819+
"max_response": estimate.max_response_tokens,
820+
"total_with_max_response": estimate.total_with_max_response,
836821
},
837822
"message_preview": if message.len() > 100 {
838823
format!("{}...", &message[..100])
@@ -849,24 +834,36 @@ impl RunCli {
849834
println!("Model: {}", model);
850835
println!();
851836
println!("Token Breakdown:");
852-
println!(" User prompt: {:>8} tokens", user_prompt_tokens);
837+
println!(
838+
" User prompt: {:>8} tokens",
839+
estimate.user_prompt_tokens
840+
);
853841
if !attachments.is_empty() {
854842
println!(
855843
" Attachments: {:>8} tokens ({} files)",
856-
attachment_tokens,
844+
estimate.attachment_tokens,
857845
attachments.len()
858846
);
859847
}
860848
println!(
861849
" System prompt: {:>8} tokens (estimated)",
862-
system_prompt_tokens
850+
estimate.system_prompt_tokens
863851
);
864852
println!(
865853
" Tool definitions: {:>8} tokens (estimated, {} tools)",
866-
tool_tokens, tool_count
854+
estimate.tool_tokens, estimate.tool_count
867855
);
868856
println!(" {}", "-".repeat(30));
869-
println!(" Total input: {:>8} tokens", total_input_tokens);
857+
println!(
858+
" Total input: {:>8} tokens",
859+
estimate.total_input_tokens
860+
);
861+
if let Some(max_tokens) = estimate.max_response_tokens {
862+
println!(" Max response: {:>8} tokens", max_tokens);
863+
if let Some(total_with_max_response) = estimate.total_with_max_response {
864+
println!(" Input + response: {:>8} tokens", total_with_max_response);
865+
}
866+
}
870867
println!();
871868
println!("Note: System prompt and tool definition token counts are estimates.");
872869
println!("Actual counts may vary based on agent configuration.");
@@ -884,4 +881,84 @@ impl RunCli {
884881

885882
Ok(())
886883
}
884+
885+
fn estimate_dry_run_tokens(
886+
&self,
887+
model: &str,
888+
message: &str,
889+
attachments: &[FileAttachment],
890+
) -> DryRunTokenEstimate {
891+
use cortex_engine::tokenizer::TokenCounter;
892+
893+
let mut counter = TokenCounter::for_model(&model);
894+
895+
// 统计用户提示词 token。
896+
let user_prompt_tokens = counter.count(message);
897+
898+
// 统计附件 token。
899+
let mut attachment_tokens = 0u32;
900+
for attachment in attachments {
901+
let content =
902+
std::fs::read_to_string(&attachment.path).unwrap_or_else(|_| String::new());
903+
attachment_tokens += counter.count(&content);
904+
// 加上文件标记的近似开销。
905+
attachment_tokens += 20;
906+
}
907+
908+
// 系统提示词会随配置变化,这里使用常见范围内的近似值。
909+
let system_prompt_tokens = 1500u32;
910+
911+
// 工具定义按默认工具数量和单个工具的平均 token 估算。
912+
let tool_count = 15;
913+
let tool_tokens = tool_count * 150;
914+
915+
// 计算输入总量,并在传入 max_tokens 时纳入响应上限。
916+
let total_input_tokens =
917+
user_prompt_tokens + attachment_tokens + system_prompt_tokens + tool_tokens;
918+
let total_with_max_response = self
919+
.max_tokens
920+
.map(|max| total_input_tokens.saturating_add(max));
921+
922+
DryRunTokenEstimate {
923+
user_prompt_tokens,
924+
attachment_tokens,
925+
system_prompt_tokens,
926+
tool_tokens,
927+
tool_count,
928+
total_input_tokens,
929+
max_response_tokens: self.max_tokens,
930+
total_with_max_response,
931+
}
932+
}
933+
}
934+
935+
#[cfg(test)]
936+
mod tests {
937+
use super::*;
938+
use clap::Parser;
939+
940+
#[test]
941+
fn dry_run_estimate_includes_max_response_tokens() {
942+
let cli = RunCli::try_parse_from(["run", "--dry-run", "--max-tokens", "4096", "Long task"])
943+
.expect("max tokens above zero should parse");
944+
945+
let estimate = cli.estimate_dry_run_tokens("gpt-4o", "Long task", &[]);
946+
947+
assert_eq!(estimate.max_response_tokens, Some(4096));
948+
assert_eq!(
949+
estimate.total_with_max_response,
950+
Some(estimate.total_input_tokens + 4096)
951+
);
952+
}
953+
954+
#[test]
955+
fn run_rejects_zero_max_tokens() {
956+
let error = RunCli::try_parse_from(["run", "--max-tokens", "0", "Long task"])
957+
.expect_err("zero max tokens should be rejected");
958+
959+
assert!(
960+
error.to_string().contains("invalid value"),
961+
"unexpected error: {error}"
962+
);
963+
}
887964
}

0 commit comments

Comments
 (0)