Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/context_profiler/analyzers/token_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from context_profiler.analyzers.base import AnalyzerResult, BaseAnalyzer
from context_profiler.models import APIRequest, BlockType, Role
from context_profiler.pricing import estimate_cost


class TokenCounterAnalyzer(BaseAnalyzer):
Expand Down Expand Up @@ -60,12 +61,19 @@ def analyze(self, request: APIRequest) -> AnalyzerResult:
tool_use_tokens = by_content_type.get("tool_use", 0)
tool_result_tokens = by_content_type.get("tool_result", 0)

cost = estimate_cost(
input_tokens=total_tokens,
output_tokens=0,
model=request.model,
)

summary = {
"total_input_tokens": total_tokens,
"message_tokens": total_tokens - tool_def_tokens,
"tool_definition_tokens": tool_def_tokens,
"system_prompt_tokens": request.system_prompt_tokens,
"source_format": request.source_format,
"model": request.model,
"by_role": dict(by_role),
"by_content_type": dict(by_content_type),
"tool_use_tokens": tool_use_tokens,
Expand All @@ -76,6 +84,9 @@ def analyze(self, request: APIRequest) -> AnalyzerResult:
"tool_definitions": tool_defs_detail,
}

if cost is not None:
summary["cost"] = cost

warnings = []
if tool_def_tokens > total_tokens * 0.3:
warnings.append(
Expand Down
59 changes: 59 additions & 0 deletions src/context_profiler/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from context_profiler.context_diff import analyze_context_diff
from context_profiler.formats import describe_format
from context_profiler.models import Session
from context_profiler.pricing import estimate_cost
from context_profiler.profiler import ProfileResult
from context_profiler.session_insights import analyze_session_insights

Expand Down Expand Up @@ -158,6 +159,37 @@ def diagnose_result(result: ProfileResult, session: Session | None = None) -> di
"recommendation": "Move stable repeated arguments into references or shorter identifiers.",
})

# Context overflow risk from budget forecast
forecast = session_insights.get("budget_forecast")
if forecast and forecast.get("estimated_overflow_turn") is not None:
current_turn_count = len(session.requests) if session else 0
overflow_turn = forecast["estimated_overflow_turn"]
utilization = forecast["current_utilization"]
# Trigger if overflow is within 2x the current turn count
if current_turn_count > 0 and overflow_turn <= current_turn_count * 2:
if utilization > 0.8:
severity = "critical"
elif utilization > 0.5:
severity = "warning"
else:
severity = "info"
issues.append({
"code": "CONTEXT_OVERFLOW_RISK",
"severity": severity,
"message": "Context is projected to overflow the model window at the current growth rate.",
"evidence": {
"growth_rate_per_turn": forecast["growth_rate_per_turn"],
"current_utilization": forecast["current_utilization"],
"estimated_overflow_turn": forecast["estimated_overflow_turn"],
"context_window_tokens": forecast["context_window_tokens"],
"model": forecast["model"],
},
"recommendation": "Consider compacting earlier turns, summarizing tool results, or removing stale context before the window fills.",
})

# Cost estimation
cost_info = _compute_cost(result, session)

return {
"schema_version": "0.1",
"source": result.source,
Expand All @@ -168,12 +200,39 @@ def diagnose_result(result: ProfileResult, session: Session | None = None) -> di
"warnings": result.all_warnings,
},
"issues": issues,
"cost": cost_info,
"diff_summary": diff["diff_summary"],
"diff_hints": diff["diff_hints"] + session_insights["hints"],
"session_insights": session_insights,
}


def _compute_cost(result: ProfileResult, session: Session | None = None) -> dict[str, Any] | None:
"""Compute cost estimation for the profiled request or session."""
token = result.analyzer_results.get("token_counter")
if not token:
return None

summary = token.summary
model = summary.get("model", "unknown")

if session and session.requests:
# Session mode: sum input tokens across all requests
total_input = sum(req.total_input_tokens for req in session.requests)
cost = estimate_cost(input_tokens=total_input, model=model)
if cost:
cost["mode"] = "session"
cost["num_requests"] = len(session.requests)
return cost
else:
# Snapshot mode: single request
total_input = summary.get("total_input_tokens", 0)
cost = estimate_cost(input_tokens=total_input, model=model)
if cost:
cost["mode"] = "snapshot"
return cost


def _analysis_scope(result: ProfileResult, session: Session | None = None) -> dict[str, Any]:
source_format = _source_format(result) or (
session.metadata.get("source_format") if session is not None else None
Expand Down
109 changes: 109 additions & 0 deletions src/context_profiler/pricing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Model pricing table for cost estimation.

Maps model name patterns to input/output pricing per 1M tokens (USD).
Prices are approximate and should be updated as providers change rates.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class ModelPricing:
"""Pricing for a single model tier."""

input_per_1m: float # USD per 1M input tokens
output_per_1m: float # USD per 1M output tokens
display_name: str


# Patterns are matched in order; first match wins.
# Use lowercase substrings for matching against model identifiers.
PRICING_TABLE: list[tuple[list[str], ModelPricing]] = [
# Claude models
(
["claude-opus-4", "claude-4-opus"],
ModelPricing(input_per_1m=15.0, output_per_1m=75.0, display_name="Claude Opus 4"),
),
(
["claude-sonnet-4", "claude-4-sonnet"],
ModelPricing(input_per_1m=3.0, output_per_1m=15.0, display_name="Claude Sonnet 4"),
),
(
["claude-3-5-sonnet", "claude-3.5-sonnet"],
ModelPricing(input_per_1m=3.0, output_per_1m=15.0, display_name="Claude 3.5 Sonnet"),
),
(
["claude-3-5-haiku", "claude-3.5-haiku"],
ModelPricing(input_per_1m=0.80, output_per_1m=4.0, display_name="Claude 3.5 Haiku"),
),
(
["claude-3-opus"],
ModelPricing(input_per_1m=15.0, output_per_1m=75.0, display_name="Claude 3 Opus"),
),
(
["claude-3-sonnet"],
ModelPricing(input_per_1m=3.0, output_per_1m=15.0, display_name="Claude 3 Sonnet"),
),
(
["claude-3-haiku"],
ModelPricing(input_per_1m=0.25, output_per_1m=1.25, display_name="Claude 3 Haiku"),
),
# GPT models
(
["gpt-4o-mini"],
ModelPricing(input_per_1m=0.15, output_per_1m=0.60, display_name="GPT-4o mini"),
),
(
["gpt-4o"],
ModelPricing(input_per_1m=2.50, output_per_1m=10.0, display_name="GPT-4o"),
),
(
["gpt-4-turbo"],
ModelPricing(input_per_1m=10.0, output_per_1m=30.0, display_name="GPT-4 Turbo"),
),
]


def lookup_pricing(model: str) -> ModelPricing | None:
"""Find pricing for a model by matching name patterns.

Returns None if no match is found.
"""
if not model or model == "unknown":
return None

model_lower = model.lower()
for patterns, pricing in PRICING_TABLE:
for pattern in patterns:
if pattern in model_lower:
return pricing
return None


def estimate_cost(
input_tokens: int,
output_tokens: int = 0,
model: str = "unknown",
) -> dict[str, Any] | None:
"""Estimate cost for a request given token counts and model.

Returns a dict with cost breakdown, or None if model is unknown.
"""
pricing = lookup_pricing(model)
if pricing is None:
return None

input_cost = (input_tokens / 1_000_000) * pricing.input_per_1m
output_cost = (output_tokens / 1_000_000) * pricing.output_per_1m

return {
"estimated_input_cost_usd": round(input_cost, 6),
"estimated_output_cost_usd": round(output_cost, 6),
"estimated_total_cost_usd": round(input_cost + output_cost, 6),
"estimated_model": pricing.display_name,
"input_tokens": input_tokens,
"output_tokens": output_tokens,
}
9 changes: 9 additions & 0 deletions src/context_profiler/reporters/cli_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,15 @@ def _render_token_summary(console: Console, summary: dict) -> None:
tool_table.add_row(f" {tool_name}", _format_tokens(tokens), _pct(tokens, total))
console.print(tool_table)

cost = summary.get("cost")
if cost:
console.print()
console.print("[bold] Estimated Cost[/bold]")
model_name = cost.get("estimated_model", "unknown")
input_cost = cost.get("estimated_input_cost_usd", 0)
console.print(f" Model: {model_name}")
console.print(f" Input cost: ${input_cost:.4f}")


def _render_timeline(console: Console, timeline: list[dict]) -> None:
console.print()
Expand Down
68 changes: 68 additions & 0 deletions src/context_profiler/session_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@
_COMPRESSION_DROP_MIN_TOKENS = 5_000
_COMPRESSION_DROP_RATIO = 0.15

# Context window sizes by model family (tokens)
# Ordered longest-prefix-first for correct matching
_CONTEXT_WINDOW_SIZES: list[tuple[str, int]] = [
("gpt-4o-mini", 128_000),
("gpt-4-turbo", 128_000),
("gpt-4o", 128_000),
("claude", 200_000),
]
_DEFAULT_CONTEXT_WINDOW = 128_000
_OVERFLOW_RISK_HORIZON_MULTIPLIER = 2 # warn if overflow within 2x current turn count


def analyze_session_insights(session: Session | None) -> dict[str, Any]:
"""Summarize session-level carryover, budget, and artifact lifecycle signals."""
Expand All @@ -41,6 +52,7 @@ def analyze_session_insights(session: Session | None) -> dict[str, Any]:
artifacts = _artifact_lifecycles(blocks_by_request)
artifact_duplications = _artifact_duplications(session)
propagation = _propagation_graph(blocks_by_request)
forecast = budget_forecast(session)
hints = _build_hints(carryover, budget_events, artifacts, artifact_duplications)

return {
Expand All @@ -49,6 +61,7 @@ def analyze_session_insights(session: Session | None) -> dict[str, Any]:
"artifact_lifecycles": artifacts,
"artifact_duplications": artifact_duplications,
"propagation": propagation,
"budget_forecast": forecast,
"hints": hints,
}

Expand Down Expand Up @@ -525,3 +538,58 @@ def _find_first_key(value: Any, keys: tuple[str, ...]) -> Any | None:
if found is not None:
return found
return None


# ---------------------------------------------------------------------------
# Budget Forecast
# ---------------------------------------------------------------------------


def _resolve_context_window(model: str) -> tuple[str, int]:
"""Match a model string to a known context window size.

Returns (matched_model_family, context_window_tokens).
"""
lowered = model.lower()
for prefix, size in _CONTEXT_WINDOW_SIZES:
if prefix in lowered:
return prefix, size
return "default", _DEFAULT_CONTEXT_WINDOW


def budget_forecast(session: Session) -> dict[str, Any] | None:
"""Predict when a session will hit the context window limit.

Returns None for sessions with fewer than 2 requests (not enough data).
"""
if session is None or len(session.requests) < 2:
return None

token_counts = [req.total_input_tokens for req in session.requests]
num_turns = len(token_counts)

# Determine model from the last request (most representative)
model_raw = session.requests[-1].model or "unknown"
matched_model, context_window = _resolve_context_window(model_raw)

# Calculate turn-over-turn deltas for growth rate
deltas = [token_counts[i] - token_counts[i - 1] for i in range(1, num_turns)]
growth_rate = sum(deltas) / len(deltas) if deltas else 0.0

current_tokens = token_counts[-1]
current_utilization = current_tokens / context_window if context_window else 0.0

# Predict overflow turn (linear extrapolation from current position)
estimated_overflow_turn: int | None = None
if growth_rate > 0:
remaining = context_window - current_tokens
turns_until_overflow = remaining / growth_rate
estimated_overflow_turn = num_turns + int(turns_until_overflow)

return {
"growth_rate_per_turn": round(growth_rate, 1),
"current_utilization": round(current_utilization, 4),
"estimated_overflow_turn": estimated_overflow_turn,
"context_window_tokens": context_window,
"model": matched_model,
}
Loading
Loading