Merged
72 changes: 64 additions & 8 deletions sentience/backends/snapshot.py
@@ -21,6 +21,7 @@
    cache.invalidate()  # Force refresh on next get()
"""

+import asyncio
import time
from typing import TYPE_CHECKING, Any

@@ -37,6 +38,57 @@
    from .protocol import BrowserBackend


+def _is_execution_context_destroyed_error(e: Exception) -> bool:
+    """
+    Playwright (and other browser backends) can throw while a navigation is in-flight.
+
+    Common symptoms:
+    - "Execution context was destroyed, most likely because of a navigation"
+    - "Cannot find context with specified id"
+    """
+    msg = str(e).lower()
+    return (
+        "execution context was destroyed" in msg
+        or "most likely because of a navigation" in msg
+        or "cannot find context with specified id" in msg
+    )
+
+
+async def _eval_with_navigation_retry(
+    backend: "BrowserBackend",
+    expression: str,
+    *,
+    retries: int = 10,
+    settle_state: str = "interactive",
+    settle_timeout_ms: int = 10000,
+) -> Any:
+    """
+    Evaluate JS, retrying up to `retries` times if the page is mid-navigation.
+
+    This makes snapshots resilient to cases like:
+    - press Enter (navigation) → snapshot immediately → context destroyed
+    """
+    last_err: Exception | None = None
+    for attempt in range(retries + 1):
+        try:
+            return await backend.eval(expression)
+        except Exception as e:
+            last_err = e
+            if not _is_execution_context_destroyed_error(e) or attempt >= retries:
+                raise
+            # Navigation is in-flight; wait for the new document context, then retry.
+            try:
+                await backend.wait_ready_state(state=settle_state, timeout_ms=settle_timeout_ms)  # type: ignore[arg-type]
+            except Exception:
+                # If readyState polling also fails mid-nav, still retry after a short backoff.
+                pass
+            # Linearly increasing backoff, capped at 1.5s, tuned for real navigations.
+            await asyncio.sleep(min(0.25 * (attempt + 1), 1.5))
+
+    # Unreachable in practice, but keeps type-checkers happy.
+    raise last_err if last_err else RuntimeError("eval failed")
+

class CachedSnapshot:
    """
    Snapshot cache with staleness detection.
@@ -289,13 +341,14 @@ async def _snapshot_via_extension(
    ext_options = _build_extension_options(options)

    # Call extension's snapshot function
-    result = await backend.eval(
+    result = await _eval_with_navigation_retry(
+        backend,
        f"""
        (() => {{
            const options = {_json_serialize(ext_options)};
            return window.sentience.snapshot(options);
        }})()
-        """
+        """,
    )

    if result is None:
@@ -310,14 +363,15 @@
    if options.show_overlay:
        raw_elements = result.get("raw_elements", [])
        if raw_elements:
-            await backend.eval(
+            await _eval_with_navigation_retry(
+                backend,
                f"""
                (() => {{
                    if (window.sentience && window.sentience.showOverlay) {{
                        window.sentience.showOverlay({_json_serialize(raw_elements)}, null);
                    }}
                }})()
-                """
+                """,
            )

    # Build and return Snapshot
@@ -341,13 +395,14 @@ async def _snapshot_via_api(
        raw_options["screenshot"] = options.screenshot

    # Call extension to get raw elements
-    raw_result = await backend.eval(
+    raw_result = await _eval_with_navigation_retry(
+        backend,
        f"""
        (() => {{
            const options = {_json_serialize(raw_options)};
            return window.sentience.snapshot(options);
        }})()
-        """
+        """,
    )

    if raw_result is None:
@@ -372,14 +427,15 @@
    if options.show_overlay:
        elements = api_result.get("elements", [])
        if elements:
-            await backend.eval(
+            await _eval_with_navigation_retry(
+                backend,
                f"""
                (() => {{
                    if (window.sentience && window.sentience.showOverlay) {{
                        window.sentience.showOverlay({_json_serialize(elements)}, null);
                    }}
                }})()
-                """
+                """,
            )

    return Snapshot(**snapshot_data)
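To make the retry behavior concrete, here is a minimal sketch of the helper in action. It is illustrative only and not part of the PR: FlakyBackend is a hypothetical stand-in for the real BrowserBackend protocol, failing once with the message Playwright raises mid-navigation and then succeeding.

# Assumes _eval_with_navigation_retry from sentience/backends/snapshot.py is in scope.
import asyncio


class FlakyBackend:
    """Hypothetical backend whose first eval() dies mid-navigation."""

    def __init__(self) -> None:
        self.calls = 0

    async def eval(self, expression: str):
        self.calls += 1
        if self.calls == 1:
            raise RuntimeError(
                "Execution context was destroyed, most likely because of a navigation"
            )
        return {"status": "ok"}

    async def wait_ready_state(self, state: str, timeout_ms: int) -> None:
        await asyncio.sleep(0)  # pretend the new document reached the requested readyState


async def main() -> None:
    backend = FlakyBackend()
    result = await _eval_with_navigation_retry(backend, "window.sentience.snapshot({})")
    assert result == {"status": "ok"}
    assert backend.calls == 2  # one destroyed-context failure, one successful retry


asyncio.run(main())

With the defaults, the helper tolerates up to ten destroyed-context failures in a row, sleeping at most 1.5s between attempts, while any non-navigation error is re-raised immediately.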
44 changes: 34 additions & 10 deletions sentience/llm_provider.py
@@ -5,6 +5,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

from .llm_provider_utils import get_api_key_from_env, handle_provider_error, require_package
from .llm_response_builder import LLMResponseBuilder
@@ -777,21 +778,44 @@ def __init__(
        elif load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

+        device = (device or "auto").strip().lower()
+
        # Determine torch dtype
        if torch_dtype == "auto":
-            dtype = torch.float16 if device != "cpu" else torch.float32
+            dtype = torch.float16 if device not in {"cpu"} else torch.float32
        else:
            dtype = getattr(torch, torch_dtype)

-        # Load model
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            quantization_config=quantization_config,
-            torch_dtype=dtype if quantization_config is None else None,
-            device_map=device,
-            trust_remote_code=True,
-            low_cpu_mem_usage=True,
-        )
+        # device_map is a Transformers concept (not a literal "cpu/mps/cuda" device string).
+        # - "auto" enables Accelerate device mapping.
+        # - Otherwise, we load normally and then move the model to the requested device.
+        device_map: str | None = "auto" if device == "auto" else None
+
+        def _load(*, device_map_override: str | None) -> Any:
+            return AutoModelForCausalLM.from_pretrained(
+                model_name,
+                quantization_config=quantization_config,
+                torch_dtype=dtype if quantization_config is None else None,
+                device_map=device_map_override,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+            )
+
+        try:
+            self.model = _load(device_map_override=device_map)
+        except KeyError as e:
+            # Some envs / accelerate versions can crash on auto mapping (e.g. KeyError: 'cpu').
+            # Keep demo ergonomics: default stays "auto", but we gracefully fall back.
+            if device == "auto" and ("cpu" in str(e).lower()):
+                device = "cpu"
+                dtype = torch.float32
+                self.model = _load(device_map_override=None)
+            else:
+                raise
+
+        # If we didn't use device_map, move model explicitly (only safe for non-quantized loads).
+        if device_map is None and quantization_config is None and device in {"cpu", "cuda", "mps"}:
+            self.model = self.model.to(device)
        self.model.eval()

    def generate(
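The load-with-fallback flow above generalizes beyond this provider. Below is a minimal standalone sketch of the same pattern; the function name is illustrative and not the provider's API, and it assumes torch and transformers are installed (quantization and trust_remote_code omitted for brevity).

import torch
from transformers import AutoModelForCausalLM


def load_model(model_name: str, device: str = "auto"):
    device = (device or "auto").strip().lower()
    dtype = torch.float16 if device != "cpu" else torch.float32
    try:
        # "auto" hands placement to Accelerate's device mapping.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if device == "auto" else None,
            low_cpu_mem_usage=True,
        )
    except KeyError as e:
        # Some accelerate versions raise KeyError('cpu') during auto mapping.
        if device != "auto" or "cpu" not in str(e).lower():
            raise
        device, dtype = "cpu", torch.float32  # mirror the __init__ fallback above
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, low_cpu_mem_usage=True
        )
    if device in {"cpu", "cuda", "mps"}:
        model = model.to(device)  # explicit move; skipped when Accelerate mapped the model
    return model.eval()

Called as load_model("gpt2"), for example, this keeps the demo-friendly "auto" default but degrades to a plain float32 CPU load instead of crashing on hosts where auto mapping fails.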