49 changes: 49 additions & 0 deletions config/tool_use_with_ask_user.yaml
@@ -0,0 +1,49 @@
# Config with ask_user tool for handling underspecified tasks
# Based on tool_use.yaml

agent:
  templates:
    system_template: |-
      You are a helpful assistant that can interact with a computer to solve tasks.
    instance_template: |-
      <uploaded_files>
      {{working_dir}}
      </uploaded_files>
      I've uploaded a python code repository in the directory {{working_dir}}. Consider the following PR description:

      <pr_description>
      {{problem_statement}}
      </pr_description>

      Can you help me implement the necessary changes to the repository so that the requirements specified in the <pr_description> are met?
      I've already taken care of all changes to any of the test files described in the <pr_description>. This means you DON'T have to modify the testing logic or any of the tests in any way!
      Your task is to make the minimal changes to non-test files in the {{working_dir}} directory to ensure the <pr_description> is satisfied.
      Follow these steps to resolve the issue:
      1. As a first step, it might be a good idea to find and read code relevant to the <pr_description>
      2. Create a script to reproduce the error and execute it with `python <filename.py>` using the bash tool, to confirm the error
      3. Edit the source code of the repo to resolve the issue
      4. Rerun your reproduce script and confirm that the error is fixed!
      5. Think about edge cases and make sure your fix handles them as well

      IMPORTANT: Your output will be checked by an auto-grader looking for exact answers.
      This task may be missing critical information.
      Use the ask_user tool to ask me for any missing details.

      Your thinking should be thorough, so it's fine if it's very long.
    next_step_template: |-
      OBSERVATION:
      {{observation}}
    next_step_no_output_template: |-
      Your command ran successfully and did not produce any output.
  tools:
    execution_timeout: 450
    bundles:
      - path: tools/registry
      - path: tools/edit_anthropic
      - path: tools/submit
      - path: tools/ask_user
    env_variables:
      USE_FILEMAP: 'true'
    enable_bash_tool: true
    parse_function:
      type: function_calling
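
To make the intended interaction concrete, here is a hedged sketch of the turn this config sets up: the agent emits an ask_user action, and with the host-side interception added to sweagent/agent/agents.py below, the observation it receives is the simulated user's answer rather than output from the container. The question and reply text are hypothetical.

# Illustrative only; the question and reply below are hypothetical.
action = 'ask_user "What time and timezone should the scheduled job use?"'
# Observation returned to the agent by the user-simulator LLM (host side):
observation = "Run it at 02:30 Europe/Berlin."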
184 changes: 184 additions & 0 deletions sweagent/agent/agents.py
@@ -4,10 +4,14 @@
import copy
import json
import logging
import os
import re
import shlex
import time
from pathlib import Path, PurePosixPath
from typing import Annotated, Any, Literal

import litellm
import yaml
from jinja2 import Template
from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -57,6 +61,164 @@
from sweagent.utils.patch_formatter import PatchFormatter


# Global task definitions cache for ask_user interception
_task_definitions_cache: dict | None = None
_task_definitions_path: str | None = None


def _load_task_definitions(output_dir: Path | None) -> dict | None:
    """Load task definitions from output directory if available."""
    global _task_definitions_cache, _task_definitions_path

    if output_dir is None:
        return None

    task_def_file = output_dir / "task_definitions.json"
    task_def_str = str(task_def_file)

    # Return cached if same file
    if _task_definitions_path == task_def_str and _task_definitions_cache is not None:
        return _task_definitions_cache

    if task_def_file.exists():
        try:
            with open(task_def_file, "r") as f:
                _task_definitions_cache = json.load(f)
                _task_definitions_path = task_def_str
                return _task_definitions_cache
        except Exception:
            pass
    return None
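
# --- Editor's illustrative sketch (not part of this diff) -------------------
# Shape of <output_dir>/task_definitions.json that _load_task_definitions() and
# _handle_ask_user_on_host() expect: a mapping from instance_id to a task
# definition carrying the complete prompt, the underspecified prompt, and the
# removed segments. The instance id, path, and values below are hypothetical.
import json
from pathlib import Path

output_dir = Path("runs/my_run")  # hypothetical output directory
output_dir.mkdir(parents=True, exist_ok=True)
task_definitions = {
    "example__instance-1": {
        "primary_task": "Add a cron job that runs cleanup.py at 02:30 Europe/Berlin.",
        "underspecified_task": "Add a cron job that runs cleanup.py.",
        "removed_segments": [{"value": "02:30 Europe/Berlin"}],
        "expected_questions": ["What time and timezone should the job use?"],
    }
}
(output_dir / "task_definitions.json").write_text(json.dumps(task_definitions, indent=2))
# -----------------------------------------------------------------------------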


def _handle_ask_user_on_host(question: str, context: str, instance_id: str, output_dir: Path | None, logger) -> str:
    """Handle ask_user command on the host side using litellm.

    This function intercepts ask_user calls to run the LLM call on the host,
    which can reach internal API endpoints that the container cannot access.
    """
    task_defs = _load_task_definitions(output_dir)

    if task_defs is None or instance_id not in task_defs:
        return f"Error: No task definition found for instance {instance_id}"

    task_def = task_defs[instance_id]
    primary_task = task_def.get("primary_task", "")
    underspecified_prompt = task_def.get("underspecified_task", "")

    # Extract removed values
    removed_values = []
    if "removed_segments" in task_def:
        for seg in task_def["removed_segments"]:
            if isinstance(seg, dict) and seg.get("value"):
                removed_values.append(seg["value"])

    removed_values_str = ", ".join(removed_values) if removed_values else "None specified"
    underspec_str = underspecified_prompt or "Not provided"

    system_prompt = f"""You are simulating a user who has a task in mind but didn't fully specify it.

The user originally intended to give this COMPLETE prompt:
{primary_task}

But they actually gave this UNDERSPECIFIED version:
{underspec_str}

The parts that were removed/made vague:
{removed_values_str}

An AI assistant (who only sees the underspecified version) is now asking you a clarifying question.

Your job: Compare the two prompts, find what's MISSING from the underspecified version, and provide the EXACT information from the complete prompt.

Guidelines:
- Find the EXACT values that are in the complete prompt but missing from the underspecified one
- Provide those specific values (times, names, dates, numbers, phrases, etc.)
- Be concise - just answer what's asked
- Don't reveal you're a simulation

ENVIRONMENT CONTEXT:
- The agent is working in a repository at /workspace (or the working directory specified in the prompt)
- The agent has full access to the repository files
- The agent can read, write, and execute files in the repository
- When providing file paths, use paths relative to the repository root
"""

    user_prompt = f"The assistant asks: {question}"
    if context:
        user_prompt += f"\n\nContext: {context}"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Get API credentials from environment
    api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY")
    api_base = os.environ.get("OPENAI_BASE_URL") or os.environ.get("LLM_BASE_URL")
    model = os.environ.get("USER_SIMULATOR_MODEL", "openai/gpt-5.2")

    if not api_key:
        return "Error: No API key available for user simulation"

    try:
        logger.info(f"ask_user intercepted on host: question='{question[:100]}...'")
        # Drop unsupported params for models like GPT-5 that don't support temperature
        litellm.drop_params = True
        response = litellm.completion(
            model=model,
            messages=messages,
            temperature=1,  # GPT-5 only supports temperature=1
            api_key=api_key,
            api_base=api_base,
            timeout=30,
        )
        result = response.choices[0].message.content
        logger.info(f"ask_user response generated: '{result[:100]}...'")
        return result
    except Exception as e:
        logger.error(f"ask_user LLM call failed: {e}")
        return f"Error generating user response: {str(e)}"


def _parse_ask_user_command(command: str) -> tuple[str, str] | None:
    """Parse ask_user command to extract question and optional context.

    Handles formats like:
    - ask_user "question"
    - ask_user "question" "context"
    - ask_user 'question'
    - ask_user question_without_quotes

    Returns (question, context) tuple or None if not an ask_user command.
    """
    command = command.strip()
    if not command.startswith("ask_user"):
        return None

    # Remove the "ask_user" prefix
    args_str = command[8:].strip()
    if not args_str:
        return None

    try:
        # Use shlex to properly parse quoted arguments
        args = shlex.split(args_str)
        if len(args) >= 1:
            question = args[0]
            context = args[1] if len(args) > 1 else ""
            return (question, context)
    except ValueError:
        # Fallback: try simple quote extraction
        match = re.match(r'["\'](.+?)["\'](?:\s+["\'](.+?)["\'])?', args_str)
        if match:
            return (match.group(1), match.group(2) or "")
        # Last resort: treat entire string as question
        return (args_str, "")

    return None
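
# --- Editor's illustrative sketch (not part of this diff) -------------------
# Quick check of the formats listed in the docstring above; the question
# strings are hypothetical.
assert _parse_ask_user_command('ask_user "What timezone?"') == ("What timezone?", "")
assert _parse_ask_user_command('ask_user "What timezone?" "for the cron entry"') == (
    "What timezone?",
    "for the cron entry",
)
assert _parse_ask_user_command("ask_user 'Which branch?'") == ("Which branch?", "")
assert _parse_ask_user_command("ask_user which_config_file") == ("which_config_file", "")
assert _parse_ask_user_command("ls -la") is None
# -----------------------------------------------------------------------------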


class TemplateConfig(BaseModel):
"""This configuration is used to define almost all message templates that are
formatted by the agent and sent to the LM.
@@ -943,6 +1105,28 @@ def handle_action(self, step: StepOutput) -> StepOutput:
        self._chook.on_action_started(step=step)
        execution_t0 = time.perf_counter()
        run_action: str = self.tools.guard_multiline_input(step.action).strip()

        # Intercept ask_user commands and handle on HOST side
        # This is needed because Modal containers cannot reach internal API endpoints
        ask_user_args = _parse_ask_user_command(run_action)
        if ask_user_args is not None:
            question, context = ask_user_args
            instance_id = self._problem_statement.id if self._problem_statement else "unknown"
            # output_dir is parent of traj_path's parent (traj_path = output_dir/instance_id/instance_id.traj)
            output_dir = self.traj_path.parent.parent if self.traj_path else None
            step.observation = _handle_ask_user_on_host(
                question=question,
                context=context,
                instance_id=instance_id,
                output_dir=output_dir,
                logger=self.logger,
            )
            step.execution_time = time.perf_counter() - execution_t0
            self._total_execution_time += step.execution_time
            self._chook.on_action_executed(step=step)
            step.state = self.tools.get_state(env=self._env)
            return self.handle_submission(step)

        try:
            step.observation = self._env.communicate(
                input=run_action,
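A quick sanity check of the path relationship the interception relies on (per the comment in the ask_user branch above): with traj_path laid out as output_dir/instance_id/instance_id.traj, two .parent steps recover output_dir, which is where task_definitions.json is looked up. The paths below are hypothetical.

from pathlib import Path

traj_path = Path("runs/my_run/example__instance-1/example__instance-1.traj")
output_dir = traj_path.parent.parent
assert output_dir == Path("runs/my_run")
assert output_dir / "task_definitions.json" == Path("runs/my_run/task_definitions.json")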
123 changes: 123 additions & 0 deletions sweagent/run/hooks/task_definition_injection.py
@@ -0,0 +1,123 @@
"""
RunHook for injecting complete task definitions into containers for ask_user tool.

This hook writes the full task definition (including underspecified version and removed
segments) to a file in the container so the ask_user tool can access it.
"""

import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional

from sweagent.agent.problem_statement import ProblemStatement, ProblemStatementConfig
from sweagent.environment.swe_env import SWEEnv
from sweagent.run.hooks.abstract import RunHook

logger = logging.getLogger(__name__)


class TaskDefinitionInjectionHook(RunHook):
    """
    Inject complete task definitions into container for ask_user tool.

    Writes task definition to /tmp/task_definition.json in the container, which
    the ask_user tool reads to provide accurate clarifications.
    """

    def __init__(
        self,
        task_definitions: Optional[Dict[str, Dict[str, Any]]] = None,
        task_definitions_file: Optional[Path] = None,
    ):
        """
        Initialize hook with task definitions.

        Args:
            task_definitions: Dict mapping instance_id to task definition, OR
            task_definitions_file: Path to JSON file with task definitions

        Task definition should contain:
            - primary_task: Complete task description
            - underspecified_task: Partial task given to agent
            - removed_segments: List of removed segments
            - expected_questions: Expected clarification questions
        """
        super().__init__()
        self.task_definitions = task_definitions
        self.task_definitions_file = task_definitions_file

    def _load_task_definitions(self) -> Dict[str, Dict[str, Any]]:
        """Load task definitions from file or return cached dict."""
        if self.task_definitions is not None:
            return self.task_definitions

        if self.task_definitions_file and self.task_definitions_file.exists():
            with open(self.task_definitions_file) as f:
                return json.load(f)

        return {}

    def on_instance_start(
        self,
        *,
        index: int,
        env: SWEEnv,
        problem_statement: ProblemStatement | ProblemStatementConfig,
    ) -> None:
        """
        Inject task definition into container before agent starts.

        Called after environment is ready but before agent.setup().
        """
        instance_id = problem_statement.id

        # Load task definitions
        task_definitions = self._load_task_definitions()

        # Check if we have a task definition for this instance
        if instance_id not in task_definitions:
            logger.debug(f"No task definition found for instance {instance_id}, skipping injection")
            return

        task_def = task_definitions[instance_id]

        # Write task definition to container
        task_def_path = "/tmp/task_definition.json"
        task_def_json = json.dumps(task_def, indent=2)

        try:
            # Write file to container using swerex
            logger.info(f"Injecting task definition for {instance_id} to {task_def_path}")

            # Create a temporary file write command
            command = f"cat > {task_def_path} << 'TASK_DEFINITION_EOF'\n{task_def_json}\nTASK_DEFINITION_EOF"
            env.communicate(command, check="raise")

            # Set environment variables for the ask_user tool
            env_vars = {
                "TASK_DEFINITION_PATH": task_def_path,
                "HAS_TASK_DEFINITION": "true",
            }

            # Pass through API credentials for the user simulator LLM
            api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY")
            base_url = os.environ.get("OPENAI_BASE_URL") or os.environ.get("LLM_BASE_URL")
            simulator_model = os.environ.get("USER_SIMULATOR_MODEL")

            if api_key:
                env_vars["OPENAI_API_KEY"] = api_key
                env_vars["LLM_API_KEY"] = api_key
            if base_url:
                env_vars["OPENAI_BASE_URL"] = base_url
                env_vars["LLM_BASE_URL"] = base_url
            if simulator_model:
                env_vars["USER_SIMULATOR_MODEL"] = simulator_model

            env.set_env_variables(env_vars)

            logger.info(f"Successfully injected task definition for {instance_id}")
        except Exception as e:
            logger.error(f"Failed to inject task definition for {instance_id}: {e}")
            # Don't raise - let the run continue even if injection fails
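
A minimal wiring sketch for the hook, assuming task definitions live in a JSON file. The path is hypothetical, and how the hook is registered with the run entry point is outside this diff, so only construction is shown.

from pathlib import Path

from sweagent.run.hooks.task_definition_injection import TaskDefinitionInjectionHook

hook = TaskDefinitionInjectionHook(
    task_definitions_file=Path("output/task_definitions.json"),  # hypothetical path
)
# The run entry point would register this hook alongside its other RunHooks so
# that on_instance_start() fires once the environment for each instance is ready.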