Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 241 additions & 4 deletions src/codegen/cli/commands/agent/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Agent command for creating remote agent runs."""

import json
from pathlib import Path

import requests
import typer
Expand All @@ -13,6 +14,9 @@
from codegen.cli.auth.token_manager import get_current_org_name, get_current_token
from codegen.cli.rich.spinners import create_spinner
from codegen.cli.utils.org import resolve_org_id
from codegen.git.repo_operator.local_git_repo import LocalGitRepo
from codegen.git.repo_operator.repo_operator import RepoOperator
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing import: API_ENDPOINT is used but not imported
This will raise a NameError at runtime when pull() accesses API_ENDPOINT.

Suggested change
from codegen.git.repo_operator.repo_operator import RepoOperator
from codegen.cli.api.endpoints import API_ENDPOINT

from codegen.git.schemas.repo_config import RepoConfig

console = Console()

Expand Down Expand Up @@ -144,28 +148,33 @@ def agent_callback(ctx: typer.Context):
raise typer.Exit()


# For backward compatibility, also allow `codegen agent --prompt "..."` and `codegen agent --id X --json`
# For backward compatibility, also allow `codegen agent --prompt "..."`, `codegen agent --id X --json`, and `codegen agent --id X pull`
def agent(
    action: str | None = typer.Argument(None, help="Action to perform: 'pull' to checkout PR branch"),
    prompt: str | None = typer.Option(None, "--prompt", "-p", help="The prompt to send to the agent"),
    agent_id: int | None = typer.Option(None, "--id", help="Agent run ID to fetch or pull"),
    as_json: bool = typer.Option(False, "--json", help="Output raw JSON response"),
    org_id: int | None = typer.Option(None, help="Organization ID (defaults to CODEGEN_ORG_ID/REPOSITORY_ORG_ID or auto-detect)"),
    model: str | None = typer.Option(None, help="Model to use for this agent run (optional)"),
    repo_id: int | None = typer.Option(None, help="Repository ID to use for this agent run (optional)"),
) -> None:
    """Create a new agent run with the given prompt, fetch an existing agent run by ID, or pull PR branch.

    Dispatch order (first match wins):

    1. ``--prompt`` given: delegate to ``create``; ``--id`` and ``action`` are
       not consulted in this branch.
    2. ``--id`` given and ``action == "pull"``: delegate to ``pull``.
    3. ``--id`` given (any other or no action): delegate to ``get``.
    4. Otherwise: print usage and exit with status 1.

    Raises:
        typer.Exit: with code 1 when neither ``--prompt`` nor ``--id`` is supplied.
    """
    if prompt:
        # If prompt is provided, create the agent run
        create(prompt=prompt, org_id=org_id, model=model, repo_id=repo_id)
    elif agent_id and action == "pull":
        # If agent ID and pull action provided, pull the PR branch
        pull(agent_id=agent_id, org_id=org_id)
    elif agent_id:
        # If agent ID is provided, fetch the agent run
        get(agent_id=agent_id, as_json=as_json, org_id=org_id)
    else:
        # If none of the above, show help
        console.print("[red]Error:[/red] Either --prompt or --id is required")
        console.print("Usage:")
        console.print("  [cyan]codegen agent --prompt 'Your prompt here'[/cyan]    # Create agent run")
        console.print("  [cyan]codegen agent --id 123 --json[/cyan]               # Fetch agent run as JSON")
        console.print("  [cyan]codegen agent --id 123 pull[/cyan]                 # Pull PR branch")
        raise typer.Exit(1)


Expand Down Expand Up @@ -232,3 +241,231 @@ def get(
except Exception as e:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logic bug: argument position mismatch for typer.Option in pull
The pull function defines agent_id as --id but positional argument agent_id is also required when calling pull() internally.

Suggested change
except Exception as e:
agent_id: int = typer.Argument(..., help="Agent run ID to pull PR branch for"),

console.print(f"[red]Unexpected error:[/red] {e}")
raise typer.Exit(1)


@agent_app.command()
def pull(
agent_id: int = typer.Option(..., "--id", help="Agent run ID to pull PR branch for"),
org_id: int | None = typer.Option(None, help="Organization ID (defaults to CODEGEN_ORG_ID/REPOSITORY_ORG_ID or auto-detect)"),
):
"""Fetch and checkout the PR branch for an agent run."""
token = get_current_token()
if not token:
console.print("[red]Error:[/red] Not authenticated. Please run 'codegen login' first.")
raise typer.Exit(1)

resolved_org_id = resolve_org_id(org_id)
if resolved_org_id is None:
console.print("[red]Error:[/red] Organization ID not provided. Pass --org-id, set CODEGEN_ORG_ID, or REPOSITORY_ORG_ID.")
raise typer.Exit(1)

# Check if we're in a git repository
try:
current_repo = LocalGitRepo(Path.cwd())
if not current_repo.has_remote():
console.print("[red]Error:[/red] Current directory is not a git repository with remotes.")
raise typer.Exit(1)
except Exception:
console.print("[red]Error:[/red] Current directory is not a valid git repository.")
raise typer.Exit(1)

# Fetch agent run data
spinner = create_spinner(f"Fetching agent run {agent_id}...")
spinner.start()

try:
headers = {"Authorization": f"Bearer {token}"}
url = f"{API_ENDPOINT.rstrip('/')}/v1/organizations/{resolved_org_id}/agent/run/{agent_id}"
response = requests.get(url, headers=headers)
response.raise_for_status()
agent_data = response.json()
except requests.HTTPError as e:
org_name = get_current_org_name()
org_display = f"{org_name} ({resolved_org_id})" if org_name else f"organization {resolved_org_id}"

if e.response.status_code == 404:
console.print(f"[red]Error:[/red] Agent run {agent_id} not found in {org_display}.")
elif e.response.status_code == 403:
console.print(f"[red]Error:[/red] Access denied to agent run {agent_id} in {org_display}. Check your permissions.")
else:
console.print(f"[red]Error:[/red] HTTP {e.response.status_code}: {e}")
raise typer.Exit(1)
except requests.RequestException as e:
console.print(f"[red]Error fetching agent run:[/red] {e}")
raise typer.Exit(1)
finally:
spinner.stop()

# Check if agent run has PRs
github_prs = agent_data.get("github_pull_requests", [])
if not github_prs:
console.print(f"[yellow]Warning:[/yellow] Agent run {agent_id} has no associated pull requests.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Security issue: unsanitized external URL usage in GitHub API call
github_api_url is built directly from owner and repo extracted from the PR URL; malicious repo names could lead to SSRF.

Suggested change
console.print(f"[yellow]Warning:[/yellow] Agent run {agent_id} has no associated pull requests.")
from urllib.parse import quote_plus
owner = quote_plus(owner)
repo = quote_plus(repo)
github_api_url = f"https://api.github.com/repos/{owner}/{repo}/pulls/{pr_number}" # safe URL encoding

raise typer.Exit(1)

if len(github_prs) > 1:
console.print(f"[yellow]Warning:[/yellow] Agent run {agent_id} has multiple PRs. Using the first one.")

pr = github_prs[0]
pr_url = pr.get("url")
head_branch_name = pr.get("head_branch_name")

if not pr_url:
console.print("[red]Error:[/red] PR URL not found in agent run data.")
raise typer.Exit(1)

if not head_branch_name:
# Try to extract branch name from PR URL as fallback
# GitHub PR URLs often follow patterns like:
# https://github.com/owner/repo/pull/123
# We can use GitHub API to get the branch name
console.print("[yellow]Info:[/yellow] HEAD branch name not in API response, attempting to fetch from GitHub...")
try:
# Extract owner, repo, and PR number from PR URL manually
# Expected format: https://github.com/owner/repo/pull/123
if not pr_url.startswith("https://github.com/"):
msg = f"Only GitHub URLs are supported, got: {pr_url}"
raise ValueError(msg)

# Remove the GitHub base and split the path
path_parts = pr_url.replace("https://github.com/", "").split("/")
if len(path_parts) < 4 or path_parts[2] != "pull":
msg = f"Invalid GitHub PR URL format: {pr_url}"
raise ValueError(msg)

owner = path_parts[0]
repo = path_parts[1]
pr_number = path_parts[3]

# Use GitHub API to get PR details
import requests as github_requests

github_api_url = f"https://api.github.com/repos/{owner}/{repo}/pulls/{pr_number}"

github_response = github_requests.get(github_api_url)
if github_response.status_code == 200:
pr_data = github_response.json()
head_branch_name = pr_data.get("head", {}).get("ref")
if head_branch_name:
console.print(f"[green]✓ Found branch name from GitHub API:[/green] {head_branch_name}")
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logic error: fetching the remote before validating that the branch exists may be unnecessary, but the main issue is that resetting an existing local branch is not implemented
checkout_remote_branch will fail if the local branch exists and has diverged; the code already warns that the branch will be reset but never performs the reset.

Suggested change
else:
if head_branch_name in local_branches:
repo_operator.git_cli.git('reset', '--hard', f'origin/{head_branch_name}')
else:
repo_operator.checkout_remote_branch(head_branch_name)

console.print("[red]Error:[/red] Could not extract branch name from GitHub API response.")
raise typer.Exit(1)
else:
console.print(f"[red]Error:[/red] Failed to fetch PR details from GitHub API (status: {github_response.status_code})")
console.print("[yellow]Tip:[/yellow] The PR may be private or the GitHub API rate limit may be exceeded.")
raise typer.Exit(1)
except Exception as e:
console.print(f"[red]Error:[/red] Could not fetch branch name from GitHub: {e}")
console.print("[yellow]Tip:[/yellow] The backend may need to be updated to include branch information.")
raise typer.Exit(1)

# Parse PR URL to get repository information
try:
# Extract owner and repo from PR URL manually
# Expected format: https://github.com/owner/repo/pull/123
if not pr_url.startswith("https://github.com/"):
msg = f"Only GitHub URLs are supported, got: {pr_url}"
raise ValueError(msg)

# Remove the GitHub base and split the path
path_parts = pr_url.replace("https://github.com/", "").split("/")
if len(path_parts) < 4 or path_parts[2] != "pull":
msg = f"Invalid GitHub PR URL format: {pr_url}"
raise ValueError(msg)

owner = path_parts[0]
repo = path_parts[1]
pr_repo_full_name = f"{owner}/{repo}"
except Exception as e:
console.print(f"[red]Error:[/red] Could not parse PR URL: {pr_url} - {e}")
raise typer.Exit(1)

# Check if current repository matches PR repository
current_repo_full_name = current_repo.full_name
if not current_repo_full_name:
console.print("[red]Error:[/red] Could not determine current repository name.")
raise typer.Exit(1)

if current_repo_full_name.lower() != pr_repo_full_name.lower():
console.print("[red]Error:[/red] Repository mismatch!")
console.print(f" Current repo: [cyan]{current_repo_full_name}[/cyan]")
console.print(f" PR repo: [cyan]{pr_repo_full_name}[/cyan]")
console.print("[yellow]Tip:[/yellow] Navigate to the correct repository directory first.")
raise typer.Exit(1)

# Perform git operations with safety checks
try:
repo_config = RepoConfig.from_repo_path(str(Path.cwd()))
repo_operator = RepoOperator(repo_config)

# Safety check: warn if repository has uncommitted changes
if repo_operator.git_cli.is_dirty():
console.print("[yellow]⚠️ Warning:[/yellow] You have uncommitted changes in your repository.")
console.print("These changes may be lost when switching branches.")

# Get user confirmation
confirm = typer.confirm("Do you want to continue? Your changes may be lost.")
if not confirm:
console.print("[yellow]Operation cancelled.[/yellow]")
raise typer.Exit(0)

console.print("[blue]Proceeding with branch checkout...[/blue]")

console.print(f"[blue]Repository match confirmed:[/blue] {current_repo_full_name}")
console.print(f"[blue]Fetching and checking out branch:[/blue] {head_branch_name}")

# Fetch the branch from remote
fetch_spinner = create_spinner("Fetching latest changes from remote...")
fetch_spinner.start()
try:
fetch_result = repo_operator.fetch_remote("origin")
if fetch_result.name != "SUCCESS":
console.print(f"[yellow]Warning:[/yellow] Fetch result: {fetch_result.name}")
except Exception as e:
console.print(f"[red]Error during fetch:[/red] {e}")
raise
finally:
fetch_spinner.stop()

# Check if the branch already exists locally
local_branches = [b.name for b in repo_operator.git_cli.branches]
if head_branch_name in local_branches:
console.print(f"[yellow]Info:[/yellow] Local branch '{head_branch_name}' already exists. It will be reset to match the remote.")

# Checkout the remote branch
checkout_spinner = create_spinner(f"Checking out branch {head_branch_name}...")
checkout_spinner.start()
try:
checkout_result = repo_operator.checkout_remote_branch(head_branch_name)
if checkout_result.name == "SUCCESS":
console.print(f"[green]✓ Successfully checked out branch:[/green] {head_branch_name}")
elif checkout_result.name == "NOT_FOUND":
console.print(f"[red]Error:[/red] Branch {head_branch_name} not found on remote.")
console.print("[yellow]Tip:[/yellow] The branch may have been deleted or renamed.")
raise typer.Exit(1)
else:
console.print(f"[yellow]Warning:[/yellow] Checkout result: {checkout_result.name}")
except Exception as e:
console.print(f"[red]Error during checkout:[/red] {e}")
raise
finally:
checkout_spinner.stop()

# Display success info
console.print(
Panel(
f"[green]✓ Successfully pulled PR branch![/green]\n\n"
f"[cyan]Agent Run:[/cyan] {agent_id}\n"
f"[cyan]Repository:[/cyan] {current_repo_full_name}\n"
f"[cyan]Branch:[/cyan] {head_branch_name}\n"
f"[cyan]PR URL:[/cyan] {pr_url}",
title="🌿 [bold]Branch Checkout Complete[/bold]",
border_style="green",
box=box.ROUNDED,
padding=(1, 2),
)
)

except Exception as e:
console.print(f"[red]Error during git operations:[/red] {e}")
raise typer.Exit(1)