1010from ..captcha import CaptchaHandler , CaptchaOptions
1111from ..captcha_strategies import ExternalSolver , HumanHandoffSolver , VisionSolver
1212from ..llm_interaction_handler import LLMInteractionHandler
13- from ..llm_provider import LLMProvider
13+ from ..llm_provider import LLMProvider , LLMResponse
1414from ..models import Snapshot , StepHookContext
1515from ..permissions import PermissionPolicy
1616from ..runtime_agent import RuntimeAgent , RuntimeStep
@@ -84,6 +84,9 @@ class PredicateBrowserAgentConfig:
8484 # Prompt / token controls
8585 history_last_n : int = 0 # 0 disables LLM-facing step history (lowest token usage)
8686
87+ # Opt-in: track token usage from LLM provider responses (best-effort; depends on provider reporting).
88+ token_usage_enabled : bool = False
89+
8790 # Compact prompt customization
8891 # Signature: builder(task_goal, step_goal, dom_context, snapshot, history_summary) -> (system, user)
8992 compact_prompt_builder : Callable [
@@ -146,6 +149,112 @@ def apply_captcha_config_to_runtime(
146149 )
147150
148151
@dataclass
class TokenUsageTotals:
    """Running token-usage counters accumulated over one or more LLM calls."""

    calls: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    def add(self, resp: "LLMResponse") -> None:
        """Fold one provider response into the running totals.

        Usage reporting is best-effort: non-integer (e.g. missing/None) fields
        count as zero, and a missing total falls back to prompt + completion.
        """
        self.calls += 1
        prompt = resp.prompt_tokens
        completion = resp.completion_tokens
        total = resp.total_tokens
        if not isinstance(prompt, int):
            prompt = 0
        if not isinstance(completion, int):
            completion = 0
        if not isinstance(total, int):
            total = prompt + completion
        # Clamp negatives so a misbehaving provider can never shrink the totals.
        self.prompt_tokens += max(0, int(prompt))
        self.completion_tokens += max(0, int(completion))
        self.total_tokens += max(0, int(total))
167+
168+
class _TokenUsageCollector:
    """Aggregates LLM token usage, keyed both by caller role and by model name.

    Accounting is best-effort: counts are only as accurate as the wrapped
    provider's usage reporting (see TokenUsageTotals.add).
    """

    def __init__(self) -> None:
        # role label (e.g. "executor", "vision_verifier") -> running totals
        self._by_role: dict[str, TokenUsageTotals] = {}
        # reported model name ("unknown" when the provider gives none) -> totals
        self._by_model: dict[str, TokenUsageTotals] = {}

    def record(self, *, role: str, resp: "LLMResponse") -> None:
        """Fold one LLM response into both the role and the model aggregates."""
        self._by_role.setdefault(role, TokenUsageTotals()).add(resp)
        model = str(resp.model_name or "").strip() or "unknown"
        self._by_model.setdefault(model, TokenUsageTotals()).add(resp)

    def reset(self) -> None:
        """Drop all accumulated usage."""
        self._by_role.clear()
        self._by_model.clear()

    @staticmethod
    def _as_dict(totals: "TokenUsageTotals") -> dict:
        """Serialize one totals record into the plain-dict summary shape."""
        return {
            "calls": totals.calls,
            "prompt_tokens": totals.prompt_tokens,
            "completion_tokens": totals.completion_tokens,
            "total_tokens": totals.total_tokens,
        }

    def summary(self) -> dict[str, Any]:
        """Return {"total": ..., "by_role": ..., "by_model": ...} as plain dicts.

        The grand total is computed from the per-role totals (every recorded
        call lands in exactly one role bucket).
        """
        grand = TokenUsageTotals()
        for totals in self._by_role.values():
            grand.calls += totals.calls
            grand.prompt_tokens += totals.prompt_tokens
            grand.completion_tokens += totals.completion_tokens
            grand.total_tokens += totals.total_tokens

        return {
            "total": self._as_dict(grand),
            "by_role": {k: self._as_dict(v) for k, v in self._by_role.items()},
            "by_model": {k: self._as_dict(v) for k, v in self._by_model.items()},
        }
220+
221+
class _TokenAccountingProvider(LLMProvider):
    """Transparent LLMProvider wrapper that reports each response's token usage
    to a shared _TokenUsageCollector under a fixed role label.

    Delegates all provider behavior to the wrapped instance; accounting is
    best-effort and never interferes with the underlying LLM call.
    """

    def __init__(self, *, inner: "LLMProvider", collector: "_TokenUsageCollector", role: str):
        # The base class wants a model identifier; fall back to a placeholder
        # when the wrapped provider does not expose one.
        super().__init__(model=getattr(inner, "model_name", "wrapped"))
        self._inner = inner
        self._collector = collector
        self._role = role

    def _record(self, resp: "LLMResponse") -> None:
        """Best-effort usage accounting; swallows any collector failure so the
        actual LLM call result is always returned to the caller."""
        try:
            self._collector.record(role=self._role, resp=resp)
        except Exception:
            pass

    def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> "LLMResponse":
        resp = self._inner.generate(system_prompt, user_prompt, **kwargs)
        self._record(resp)
        return resp

    def supports_json_mode(self) -> bool:
        return self._inner.supports_json_mode()

    def supports_vision(self) -> bool:
        return self._inner.supports_vision()

    def generate_with_image(
        self, system_prompt: str, user_prompt: str, image_base64: str, **kwargs
    ) -> "LLMResponse":
        resp = self._inner.generate_with_image(system_prompt, user_prompt, image_base64, **kwargs)
        self._record(resp)
        return resp

    @property
    def model_name(self) -> str:
        # Delegate so the wrapper always mirrors the wrapped provider's model,
        # even if it changes after construction.
        return self._inner.model_name
256+
257+
149258class _RuntimeAgentWithPromptOverrides (RuntimeAgent ):
150259 def __init__ (
151260 self ,
@@ -227,9 +336,33 @@ def __init__(
227336 config : PredicateBrowserAgentConfig = PredicateBrowserAgentConfig (),
228337 ) -> None :
229338 self .runtime = runtime
230- self .executor = executor
231- self .vision_executor = vision_executor
232- self .vision_verifier = vision_verifier
339+ self ._token_usage : _TokenUsageCollector | None = (
340+ _TokenUsageCollector () if bool (config .token_usage_enabled ) else None
341+ )
342+
343+ # Optionally wrap providers for best-effort token usage accounting.
344+ if self ._token_usage is not None :
345+ self .executor = _TokenAccountingProvider (
346+ inner = executor , collector = self ._token_usage , role = "executor"
347+ )
348+ self .vision_executor = (
349+ _TokenAccountingProvider (
350+ inner = vision_executor , collector = self ._token_usage , role = "vision_executor"
351+ )
352+ if vision_executor is not None
353+ else None
354+ )
355+ self .vision_verifier = (
356+ _TokenAccountingProvider (
357+ inner = vision_verifier , collector = self ._token_usage , role = "vision_verifier"
358+ )
359+ if vision_verifier is not None
360+ else None
361+ )
362+ else :
363+ self .executor = executor
364+ self .vision_executor = vision_executor
365+ self .vision_verifier = vision_verifier
233366 self .config = config
234367
235368 # LLM-facing step history summaries (bounded)
@@ -252,6 +385,23 @@ def __init__(
252385 history_summary_provider = self ._get_history_summary ,
253386 )
254387
388+ def get_token_usage (self ) -> dict [str , Any ]:
389+ """
390+ Best-effort token usage summary.
391+
392+ Only available when `PredicateBrowserAgentConfig.token_usage_enabled=True`.
393+ """
394+ if self ._token_usage is None :
395+ return {"enabled" : False , "reason" : "token_usage_enabled is False" }
396+ out = self ._token_usage .summary ()
397+ out ["enabled" ] = True
398+ return out
399+
400+ def reset_token_usage (self ) -> None :
401+ if self ._token_usage is None :
402+ return
403+ self ._token_usage .reset ()
404+
255405 def _get_history_summary (self ) -> str :
256406 if int (self .config .history_last_n ) <= 0 :
257407 return ""
0 commit comments