Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 73 additions & 6 deletions src/bot/features/image_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,79 @@ async def process_image(
)

def _detect_image_type(self, image_bytes: bytes) -> str:
"""Detect type of image"""
# Simple heuristic based on image characteristics
# In practice, could use ML model for better detection

# For now, return generic type
return "screenshot"
"""Detect type of image using format and dimension heuristics."""
fmt = self._detect_format(image_bytes)
width, height = self._get_dimensions(image_bytes, fmt)

if width == 0 or height == 0:
return "generic"

aspect = width / height

# Very wide images are likely diagrams or flowcharts
if aspect > 2.5:
return "diagram"

# Very tall images are likely mobile screenshots or scrolling captures
if aspect < 0.4:
return "screenshot"

# Common desktop/mobile screenshot aspect ratios (16:9, 16:10, 9:16, etc.)
if 1.2 < aspect < 2.0 and width >= 800:
return "screenshot"

# Phone-portrait screenshots
if 0.4 <= aspect <= 0.65 and height >= 1000:
return "screenshot"

# Square-ish images with moderate resolution are often UI mockups
if 0.8 <= aspect <= 1.25 and width >= 400:
return "ui_mockup"

# Small images are likely icons or thumbnails
if width < 256 and height < 256:
return "generic"

return "generic"

@staticmethod
def _get_dimensions(image_bytes: bytes, fmt: str) -> tuple:
"""Extract width and height from image bytes without PIL."""
try:
if fmt == "png" and len(image_bytes) >= 24:
# PNG: width at offset 16 (4 bytes BE), height at offset 20 (4 bytes BE)
w = int.from_bytes(image_bytes[16:20], "big")
h = int.from_bytes(image_bytes[20:24], "big")
return w, h
elif fmt == "jpeg" and len(image_bytes) > 2:
# JPEG: scan for SOF0/SOF2 markers (0xFF 0xC0 / 0xFF 0xC2)
i = 2
while i < len(image_bytes) - 9:
if image_bytes[i] != 0xFF:
i += 1
continue
marker = image_bytes[i + 1]
if marker in (0xC0, 0xC2):
h = int.from_bytes(image_bytes[i + 5 : i + 7], "big")
w = int.from_bytes(image_bytes[i + 7 : i + 9], "big")
return w, h
# Skip to next marker
length = int.from_bytes(image_bytes[i + 2 : i + 4], "big")
i += 2 + length
elif fmt == "gif" and len(image_bytes) >= 10:
# GIF: width at offset 6 (2 bytes LE), height at offset 8 (2 bytes LE)
w = int.from_bytes(image_bytes[6:8], "little")
h = int.from_bytes(image_bytes[8:10], "little")
return w, h
elif fmt == "webp" and len(image_bytes) >= 30:
# WebP VP8: dimensions at offset 26-30
if image_bytes[12:16] == b"VP8 " and len(image_bytes) >= 30:
w = int.from_bytes(image_bytes[26:28], "little") & 0x3FFF
h = int.from_bytes(image_bytes[28:30], "little") & 0x3FFF
return w, h
except Exception:
pass
return 0, 0

def _detect_format(self, image_bytes: bytes) -> str:
"""Detect image format from magic bytes"""
Expand Down
6 changes: 5 additions & 1 deletion src/bot/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ async def agentic_text(
# Rate limit check
rate_limiter = context.bot_data.get("rate_limiter")
if rate_limiter:
allowed, limit_message = await rate_limiter.check_rate_limit(user_id, 0.001)
allowed, limit_message = await rate_limiter.check_rate_limit(user_id, 0.0)
if not allowed:
await update.message.reply_text(f"⏱️ {limit_message}")
return
Expand Down Expand Up @@ -536,6 +536,10 @@ async def agentic_text(

context.user_data["claude_session_id"] = claude_response.session_id

# Track actual cost post-execution
if rate_limiter and claude_response.cost and claude_response.cost > 0:
await rate_limiter.check_rate_limit(user_id, claude_response.cost, 0)

# Track directory changes
from .handlers.message import _update_working_directory_from_claude_response

Expand Down
10 changes: 7 additions & 3 deletions src/claude/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ async def run_command(
user_id: int,
session_id: Optional[str] = None,
on_stream: Optional[Callable[[StreamUpdate], None]] = None,
force_new: bool = False,
) -> ClaudeResponse:
"""Run Claude Code command with full integration."""
logger.info(
Expand All @@ -63,11 +64,13 @@ async def run_command(
working_directory=str(working_directory),
session_id=session_id,
prompt_length=len(prompt),
force_new=force_new,
)

# If no session_id provided, try to find an existing session for this
# user+directory combination (auto-resume)
if not session_id:
# user+directory combination (auto-resume).
# Skip auto-resume when force_new is set (e.g. after /new command).
if not session_id and not force_new:
existing_session = await self._find_resumable_session(
user_id, working_directory
)
Expand Down Expand Up @@ -120,7 +123,7 @@ async def stream_handler(update: StreamUpdate):
)

# For critical tools, we should fail fast
if tool_name in ["Task", "Read", "Write", "Edit"]:
if tool_name in ["Task", "Read", "Write", "Edit", "Bash"]:
# Create comprehensive error message
admin_instructions = self._get_admin_instructions(
list(blocked_tools)
Expand Down Expand Up @@ -296,6 +299,7 @@ async def _execute_with_fallback(
or "JSON decode error" in error_str
or "TaskGroup" in error_str
or "ExceptionGroup" in error_str
or "Unknown message type" in error_str
):
self._sdk_failed_count += 1
logger.warning(
Expand Down
158 changes: 152 additions & 6 deletions src/claude/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
- Track tool calls
- Security validation
- Usage analytics
- Bash directory boundary enforcement
"""

import shlex
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple

import structlog

Expand All @@ -17,18 +19,133 @@

logger = structlog.get_logger()

# Commands that modify the filesystem and should have paths checked
_FS_MODIFYING_COMMANDS: Set[str] = {
"mkdir",
"touch",
"cp",
"mv",
"rm",
"rmdir",
"ln",
"install",
"tee",
}

# Commands that are read-only or don't take filesystem paths
_READ_ONLY_COMMANDS: Set[str] = {
"cat",
"ls",
"head",
"tail",
"less",
"more",
"which",
"whoami",
"pwd",
"echo",
"printf",
"env",
"printenv",
"date",
"wc",
"sort",
"uniq",
"diff",
"file",
"stat",
"du",
"df",
"tree",
"realpath",
"dirname",
"basename",
}

# Actions / expressions that make ``find`` a filesystem-modifying command
_FIND_MUTATING_ACTIONS: Set[str] = {"-delete", "-exec", "-execdir", "-ok", "-okdir"}


def check_bash_directory_boundary(
command: str,
working_directory: Path,
approved_directory: Path,
) -> Tuple[bool, Optional[str]]:
"""Check if a bash command's absolute paths stay within the approved directory.

Returns (True, None) if the command is safe, or (False, error_message) if it
attempts to write outside the approved directory boundary.
"""
try:
tokens = shlex.split(command)
except ValueError:
# If we can't parse the command, let it through —
# the sandbox will catch it at the OS level
return True, None

if not tokens:
return True, None

base_command = Path(tokens[0]).name

# Read-only commands are always allowed
if base_command in _READ_ONLY_COMMANDS:
return True, None

# Handle ``find`` specially: only dangerous when it contains mutating actions
if base_command == "find":
has_mutating_action = any(t in _FIND_MUTATING_ACTIONS for t in tokens[1:])
if not has_mutating_action:
return True, None
# Fall through to path checking below
elif base_command not in _FS_MODIFYING_COMMANDS:
# Only check filesystem-modifying commands
return True, None

# Check each argument for paths outside the boundary
resolved_approved = approved_directory.resolve()

for token in tokens[1:]:
# Skip flags
if token.startswith("-"):
continue

# Resolve both absolute and relative paths against the working
# directory so that traversal sequences like ``../../evil`` are
# caught instead of being silently allowed.
if token.startswith("/"):
resolved = Path(token).resolve()
else:
resolved = (working_directory / token).resolve()

try:
resolved.relative_to(resolved_approved)
except ValueError:
return False, (
f"Directory boundary violation: '{base_command}' targets "
f"'{token}' which is outside approved directory "
f"'{resolved_approved}'"
)

return True, None


class ToolMonitor:
"""Monitor and validate Claude's tool usage."""

def __init__(
self, config: Settings, security_validator: Optional[SecurityValidator] = None
self,
config: Settings,
security_validator: Optional[SecurityValidator] = None,
agentic_mode: bool = False,
):
"""Initialize tool monitor."""
self.config = config
self.security_validator = security_validator
self.agentic_mode = agentic_mode
self.tool_usage: Dict[str, int] = defaultdict(int)
self.security_violations: List[Dict[str, Any]] = []
self.disable_tool_validation = getattr(config, "disable_tool_validation", False)

async def validate_tool_call(
self,
Expand All @@ -45,9 +162,19 @@ async def validate_tool_call(
user_id=user_id,
)

# When disabled, skip only allowlist/disallowlist name checks.
# Keep path and command safety validation active.
if self.disable_tool_validation:
logger.debug(
"Tool name validation disabled; skipping allow/disallow checks",
tool_name=tool_name,
user_id=user_id,
)

# Check if tool is allowed
if (
hasattr(self.config, "claude_allowed_tools")
not self.disable_tool_validation
and hasattr(self.config, "claude_allowed_tools")
and self.config.claude_allowed_tools
):
if tool_name not in self.config.claude_allowed_tools:
Expand All @@ -63,7 +190,8 @@ async def validate_tool_call(

# Check if tool is explicitly disallowed
if (
hasattr(self.config, "claude_disallowed_tools")
not self.disable_tool_validation
and hasattr(self.config, "claude_disallowed_tools")
and self.config.claude_disallowed_tools
):
if tool_name in self.config.claude_disallowed_tools:
Expand Down Expand Up @@ -109,8 +237,9 @@ async def validate_tool_call(
logger.warning("Invalid file path in tool call", **violation)
return False, error

# Validate shell commands
if tool_name in ["bash", "shell", "Bash"]:
# Validate shell commands (skip in agentic mode — Claude Code runs
# inside its own sandbox, and these patterns block normal gh/git usage)
if tool_name in ["bash", "shell", "Bash"] and not self.agentic_mode:
command = tool_input.get("command", "")

# Check for dangerous commands
Expand Down Expand Up @@ -145,6 +274,23 @@ async def validate_tool_call(
logger.warning("Dangerous command detected", **violation)
return False, f"Dangerous command pattern detected: {pattern}"

# Check directory boundary for filesystem-modifying commands
valid, error = check_bash_directory_boundary(
command, working_directory, self.config.approved_directory
)
if not valid:
violation = {
"type": "directory_boundary_violation",
"tool_name": tool_name,
"command": command,
"user_id": user_id,
"working_directory": str(working_directory),
"error": error,
}
self.security_violations.append(violation)
logger.warning("Directory boundary violation", **violation)
return False, error

# Track usage
self.tool_usage[tool_name] += 1

Expand Down
9 changes: 5 additions & 4 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
from src.exceptions import ConfigurationError
from src.notifications.service import NotificationService
from src.scheduler.scheduler import JobScheduler
from src.security.audit import AuditLogger, InMemoryAuditStorage
from src.security.audit import AuditLogger, InMemoryAuditStorage, SQLiteAuditStorage
from src.security.auth import (
AuthenticationManager,
InMemoryTokenStorage,
SQLiteTokenStorage,
TokenAuthProvider,
WhitelistAuthProvider,
)
Expand Down Expand Up @@ -113,7 +114,7 @@ async def create_application(config: Settings) -> Dict[str, Any]:

# Add token provider if enabled
if config.enable_token_auth:
token_storage = InMemoryTokenStorage() # TODO: Use database storage
token_storage = SQLiteTokenStorage(storage.db_manager)
providers.append(TokenAuthProvider(config.auth_token_secret, token_storage))

# Fall back to allowing all users in development mode
Expand All @@ -130,8 +131,8 @@ async def create_application(config: Settings) -> Dict[str, Any]:
security_validator = SecurityValidator(config.approved_directory)
rate_limiter = RateLimiter(config)

# Create audit storage and logger
audit_storage = InMemoryAuditStorage() # TODO: Use database storage in production
# Create audit storage and logger (SQLite-backed for persistence across restarts)
audit_storage = SQLiteAuditStorage(storage.db_manager)
audit_logger = AuditLogger(audit_storage)

# Create Claude integration components with persistent storage
Expand Down
Loading