Skip to content

Commit a0910dc

Browse files
author
Sentience Dev
committed
type checking
1 parent 6156db1 commit a0910dc

File tree

1 file changed

+33
-20
lines changed

1 file changed

+33
-20
lines changed

sentience/visual_agent.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import time
2222
import uuid
2323
from pathlib import Path
24-
from typing import Any, Optional
24+
from typing import TYPE_CHECKING, Any, Optional
2525

2626
from .actions import click, click_async
2727
from .agent import SentienceAgent, SentienceAgentAsync, _safe_tracer_call
@@ -33,13 +33,26 @@
3333
from .snapshot_diff import SnapshotDiff
3434
from .trace_event_builder import TraceEventBuilder
3535

36-
try:
36+
# Only import PIL types for type checking, not at runtime
37+
if TYPE_CHECKING:
3738
from PIL import Image, ImageDraw, ImageFont
39+
else:
40+
# Runtime placeholders (this branch always runs outside type checking); the real modules are imported below as PILImage/PILImageDraw/PILImageFont
41+
Image = None
42+
ImageDraw = None
43+
ImageFont = None
44+
45+
try:
46+
from PIL import Image as PILImage, ImageDraw as PILImageDraw, ImageFont as PILImageFont
3847

3948
PIL_AVAILABLE = True
4049
except ImportError:
4150
PIL_AVAILABLE = False
42-
print("⚠️ Warning: PIL/Pillow not available. Install with: pip install Pillow")
51+
# Define placeholders so the PIL* names still exist at module level when Pillow is not installed
52+
PILImage = None # type: ignore
53+
PILImageDraw = None # type: ignore
54+
PILImageFont = None # type: ignore
55+
# Don't print warning here - it will be printed when the class is instantiated
4356

4457

4558
class SentienceVisualAgentAsync(SentienceAgentAsync):
@@ -84,7 +97,7 @@ def __init__(
8497
# Track previous snapshot for diff computation
8598
self._previous_snapshot: Snapshot | None = None
8699

87-
def _decode_screenshot(self, screenshot_data_url: str) -> Image.Image:
100+
def _decode_screenshot(self, screenshot_data_url: str) -> "PILImage.Image":
88101
"""
89102
Decode base64 screenshot data URL to PIL Image
90103
@@ -106,7 +119,7 @@ def _decode_screenshot(self, screenshot_data_url: str) -> Image.Image:
106119
image_bytes = base64.b64decode(base64_data)
107120

108121
# Create PIL Image from bytes
109-
return Image.open(io.BytesIO(image_bytes))
122+
return PILImage.open(io.BytesIO(image_bytes))
110123

111124
def _find_label_position(
112125
self,
@@ -206,7 +219,7 @@ def _draw_labeled_screenshot(
206219
self,
207220
snapshot: Snapshot,
208221
elements: list[Element],
209-
) -> Image.Image:
222+
) -> "PILImage.Image":
210223
"""
211224
Draw bounding boxes and labels on screenshot.
212225
@@ -222,18 +235,18 @@ def _draw_labeled_screenshot(
222235

223236
# Decode screenshot
224237
img = self._decode_screenshot(snapshot.screenshot)
225-
draw = ImageDraw.Draw(img)
238+
draw = PILImageDraw.Draw(img)
226239

227240
# Try to load a font, fallback to default if not available
228241
try:
229242
# Try to use a system font
230-
font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16)
243+
font = PILImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16)
231244
except:
232245
try:
233-
font = ImageFont.truetype("arial.ttf", 16)
246+
font = PILImageFont.truetype("arial.ttf", 16)
234247
except:
235248
# Use default font if system fonts not available
236-
font = ImageFont.load_default()
249+
font = PILImageFont.load_default()
237250

238251
image_width, image_height = img.size
239252
existing_labels: list[dict[str, Any]] = []
@@ -342,7 +355,7 @@ def _draw_labeled_screenshot(
342355
return img
343356

344357
def _encode_image_to_base64(
345-
self, image: Image.Image, format: str = "PNG", max_size_mb: float = 20.0
358+
self, image: "PILImage.Image", format: str = "PNG", max_size_mb: float = 20.0
346359
) -> str:
347360
"""
348361
Encode PIL Image to base64 data URL with size optimization.
@@ -367,7 +380,7 @@ def _encode_image_to_base64(
367380
# Convert RGBA to RGB for JPEG
368381
if image.mode in ("RGBA", "LA", "P"):
369382
# Create white background
370-
rgb_image = Image.new("RGB", image.size, (255, 255, 255))
383+
rgb_image = PILImage.new("RGB", image.size, (255, 255, 255))
371384
if image.mode == "P":
372385
image = image.convert("RGBA")
373386
rgb_image.paste(image, mask=image.split()[-1] if image.mode == "RGBA" else None)
@@ -1175,7 +1188,7 @@ def __init__(
11751188
# Track previous snapshot for diff computation
11761189
self._previous_snapshot: Snapshot | None = None
11771190

1178-
def _decode_screenshot(self, screenshot_data_url: str) -> Image.Image:
1191+
def _decode_screenshot(self, screenshot_data_url: str) -> "PILImage.Image":
11791192
"""
11801193
Decode base64 screenshot data URL to PIL Image
11811194
@@ -1197,7 +1210,7 @@ def _decode_screenshot(self, screenshot_data_url: str) -> Image.Image:
11971210
image_bytes = base64.b64decode(base64_data)
11981211

11991212
# Load image from bytes
1200-
return Image.open(io.BytesIO(image_bytes))
1213+
return PILImage.open(io.BytesIO(image_bytes))
12011214

12021215
def _find_label_position(
12031216
self,
@@ -1284,7 +1297,7 @@ def _draw_labeled_screenshot(
12841297
self,
12851298
snapshot: Snapshot,
12861299
elements: list[Element],
1287-
) -> Image.Image:
1300+
) -> "PILImage.Image":
12881301
"""
12891302
Draw labeled screenshot with bounding boxes and element IDs.
12901303
@@ -1297,18 +1310,18 @@ def _draw_labeled_screenshot(
12971310
"""
12981311
# Decode screenshot
12991312
img = self._decode_screenshot(snapshot.screenshot)
1300-
draw = ImageDraw.Draw(img)
1313+
draw = PILImageDraw.Draw(img)
13011314

13021315
# Load font (fallback to default if not available)
13031316
try:
1304-
font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16)
1317+
font = PILImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16)
13051318
except OSError:
13061319
try:
1307-
font = ImageFont.truetype(
1320+
font = PILImageFont.truetype(
13081321
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16
13091322
)
13101323
except OSError:
1311-
font = ImageFont.load_default()
1324+
font = PILImageFont.load_default()
13121325

13131326
image_width, image_height = img.size
13141327
existing_labels: list[dict[str, float]] = []
@@ -1410,7 +1423,7 @@ def _draw_labeled_screenshot(
14101423

14111424
def _encode_image_to_base64(
14121425
self,
1413-
image: Image.Image,
1426+
image: "PILImage.Image",
14141427
format: str = "PNG",
14151428
max_size_mb: float = 20.0,
14161429
) -> str:

0 commit comments

Comments
 (0)