Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
requires-python = ">=3.10,<3.13"
requires-python = ">=3.10"

dependencies = [
"cloudpickle>=3.1.1",
Expand Down
55 changes: 9 additions & 46 deletions src/runpod_flash/cli/commands/login.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import asyncio
import datetime as dt
from typing import Optional

import typer
from rich.console import Console
Expand All @@ -15,20 +13,8 @@

console = Console()

POLL_INTERVAL_SECONDS = 2.0
DEFAULT_TIMEOUT_SECONDS = 600.0


def _parse_expires_at(value: Optional[str]) -> Optional[dt.datetime]:
if not value:
return None
try:
return dt.datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
return None


async def _login(open_browser: bool, timeout_seconds: float) -> None:
async def _login(open_browser: bool) -> None:
async with RunpodGraphQLClient(require_api_key=False) as client:
request = await client.create_flash_auth_request()
request_id = request.get("id")
Expand All @@ -45,46 +31,23 @@ async def _login(open_browser: bool, timeout_seconds: float) -> None:
if open_browser:
typer.launch(auth_url)

expires_at = _parse_expires_at(request.get("expiresAt"))
deadline = dt.datetime.now(dt.timezone.utc) + dt.timedelta(
seconds=timeout_seconds
)
if expires_at and expires_at < deadline:
deadline = expires_at

with console.status("[dim]Waiting for authorization...[/dim]"):
while True:
status_payload = await client.get_flash_auth_request_status(request_id)
status = status_payload.get("status")
api_key = status_payload.get("apiKey")

if api_key and status in {"APPROVED", "CONSUMED"}:
check_and_migrate_legacy_credentials()
path = save_api_key(api_key)
console.print(
f"[green]Logged in.[/green] Credentials saved to [dim]{path}[/dim]"
)
console.print()
return

if status in {"DENIED", "EXPIRED", "CONSUMED"}:
raise RuntimeError(f"login failed: {status.lower()}")
api_key = console.input("Paste the API key shown after authorization: ").strip()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This echoes the key to the terminal — it ends up in scrollback, screen recordings, and shoulder-surfing range. For a security-focused PR, prefer rich.prompt.Prompt.ask("Paste the API key", password=True) or getpass.getpass. The user already has the key visible in the browser; they don't need it echoed again here.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A wrong/empty paste currently forces the user to rerun flash login, which means re-approving in the browser. A small retry loop (e.g. 3 attempts) would be friendlier without weakening security.


if dt.datetime.now(dt.timezone.utc) >= deadline:
raise RuntimeError("login timed out")
if not api_key:
raise RuntimeError("no api key provided")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now any non-empty string gets written to the credentials file. Worth at minimum a prefix check (e.g. rpa_) so a fat-fingered paste fails loudly. Even better: fire a cheap authenticated call (e.g. myself) to confirm the key works before persisting, so users don't discover the bad paste on the next command.


await asyncio.sleep(POLL_INTERVAL_SECONDS)
check_and_migrate_legacy_credentials()
path = save_api_key(api_key)
console.print(f"[green]Logged in.[/green] Credentials saved to [dim]{path}[/dim]")
console.print()


def login_command(
no_open: bool = typer.Option(False, "--no-open", help="do not open the browser"),
timeout: float = typer.Option(
DEFAULT_TIMEOUT_SECONDS, "--timeout", help="max wait time in seconds"
),
):
"""Authenticate and save a Runpod API key for flash."""
try:
asyncio.run(_login(open_browser=not no_open, timeout_seconds=timeout))
asyncio.run(_login(open_browser=not no_open))
except RuntimeError as exc:
print_error(console, str(exc))
raise typer.Exit(code=1)
14 changes: 0 additions & 14 deletions src/runpod_flash/core/api/runpod.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,20 +991,6 @@ async def create_flash_auth_request(self) -> Dict[str, Any]:
result = await self._execute_graphql(mutation)
return result.get("createFlashAuthRequest", {})

async def get_flash_auth_request_status(self, request_id: str) -> Dict[str, Any]:
query = """
query flashAuthRequestStatus($flashAuthRequestId: String!) {
flashAuthRequestStatus(flashAuthRequestId: $flashAuthRequestId) {
id
status
expiresAt
apiKey
}
}
"""
result = await self._execute_graphql(query, {"flashAuthRequestId": request_id})
return result.get("flashAuthRequestStatus", {})

async def close(self):
"""Close the HTTP session."""
if self.session and not self.session.closed:
Expand Down
62 changes: 62 additions & 0 deletions src/runpod_flash/core/resources/request_logs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import re
from dataclasses import dataclass
Expand Down Expand Up @@ -34,6 +35,67 @@ class QBRequestLogBatch:
ready_worker_ids: List[str] = field(default_factory=list)


@dataclass
class SSEEvent:
id: str
data: dict[str, Any]


@dataclass
class LogEvent:
source: str
line: str
ts: str


def parse_sse_event(data: str) -> Optional[SSEEvent]:
"""
Parses an SSE line into a dictionary
"""
if not data:
return None

try:
event_id_line, data_line = filter(bool, data.split("\n"))
event_id = event_id_line.split(":", 1)[1].strip()
data_json = data_line.split(":", 1)[1].strip()
data = json.loads(data_json)
return SSEEvent(id=event_id, data=data)
except Exception as e:
log.error("Failed to parse SSE event: %s", e)
return None
Comment on lines +51 to +66


def parse_log_event(data: dict[str, Any]) -> Optional[LogEvent]:
"""
Parses a log event from a dictionary
"""
try:
return LogEvent(source=data["source"], line=data["line"], ts=data["ts"])
except Exception as e:
log.error("Failed to parse log event: %s", e)
return None


async def stream_pod_logs(pod_id: str, tail: int = 0):
"""
Streams logs from pod using SSE
"""
Comment on lines +38 to +83
if tail < 0:
raise ValueError("tail must be greater than 0")

url = f"{RUNPOD_HAPI_URL}/v1/pod/{pod_id}/logs?stream=true&tail={tail}"

async with get_authenticated_httpx_client() as client:
async with client.get(url) as response:
async for line in response.aiter_lines():
event = parse_sse_event(line)
if event:
log_event = parse_log_event(event.data)
if log_event:
yield log_event
Comment on lines +89 to +96

Comment on lines +38 to +97

class QBRequestLogFetcher:
def __init__(
self,
Expand Down
74 changes: 25 additions & 49 deletions tests/unit/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,6 @@

import pytest

from runpod_flash.cli.commands.login import _parse_expires_at


class TestParseExpiresAt:
def test_iso_format(self):
result = _parse_expires_at("2026-03-01T12:00:00Z")
assert result is not None
assert result.year == 2026

def test_none_input(self):
assert _parse_expires_at(None) is None

def test_empty_string(self):
assert _parse_expires_at("") is None

def test_invalid_string(self):
assert _parse_expires_at("not-a-date") is None


class TestGraphQLClientNoKeyForLogin:
"""Login mutations must not send stored credentials."""
Expand Down Expand Up @@ -61,16 +43,13 @@ def test_require_api_key_true_loads_key(self):
assert client.api_key == "loaded-key"


def _make_mock_client(**status_return):
def _make_mock_client():
"""Build an AsyncMock that works as an async context manager."""
client = AsyncMock()
client.create_flash_auth_request.return_value = {
"id": "req-123",
"expiresAt": None,
}
client.get_flash_auth_request_status.return_value = status_return
# _login uses `async with RunpodGraphQLClient(...) as client:`,
# so __aenter__ must return the same mock instance
client.__aenter__.return_value = client
return client

Expand All @@ -85,42 +64,39 @@ def _get_login_fn():


class TestLoginFlow:
async def test_login_denied(self):
mock_client = _make_mock_client(status="DENIED", apiKey=None)
async def test_login_saves_pasted_key(self, isolate_credentials_file):
mock_client = _make_mock_client()
_login = _get_login_fn()

with patch(
"runpod_flash.cli.commands.login.RunpodGraphQLClient",
return_value=mock_client,
with (
patch(
"runpod_flash.cli.commands.login.RunpodGraphQLClient",
return_value=mock_client,
),
patch("runpod_flash.cli.commands.login.console") as mock_console,
):
with pytest.raises(RuntimeError, match="login failed: denied"):
await _login(open_browser=False, timeout_seconds=5)

async def test_login_approved_saves_key(self, isolate_credentials_file):
mock_client = _make_mock_client(status="APPROVED", apiKey="fresh-api-key")
_login = _get_login_fn()

with patch(
"runpod_flash.cli.commands.login.RunpodGraphQLClient",
return_value=mock_client,
):
await _login(open_browser=False, timeout_seconds=5)
mock_console.input.return_value = "pasted-api-key"
await _login(open_browser=False)
assert isolate_credentials_file.exists()
assert "fresh-api-key" in isolate_credentials_file.read_text()
assert "pasted-api-key" in isolate_credentials_file.read_text()

async def test_login_expired(self):
mock_client = _make_mock_client(status="EXPIRED", apiKey=None)
async def test_login_empty_key_raises(self):
mock_client = _make_mock_client()
_login = _get_login_fn()

with patch(
"runpod_flash.cli.commands.login.RunpodGraphQLClient",
return_value=mock_client,
with (
patch(
"runpod_flash.cli.commands.login.RunpodGraphQLClient",
return_value=mock_client,
),
patch("runpod_flash.cli.commands.login.console") as mock_console,
):
with pytest.raises(RuntimeError, match="login failed: expired"):
await _login(open_browser=False, timeout_seconds=5)
mock_console.input.return_value = " "
with pytest.raises(RuntimeError, match="no api key provided"):
await _login(open_browser=False)

async def test_no_request_id_raises(self):
mock_client = _make_mock_client(status="APPROVED", apiKey="key")
mock_client = _make_mock_client()
mock_client.create_flash_auth_request.return_value = {}
_login = _get_login_fn()

Expand All @@ -129,4 +105,4 @@ async def test_no_request_id_raises(self):
return_value=mock_client,
):
with pytest.raises(RuntimeError, match="auth request failed"):
await _login(open_browser=False, timeout_seconds=5)
await _login(open_browser=False)
Loading
Loading