Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions py/src/braintrust/devserver/cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"x-bt-project-id",
"x-bt-stream-fmt",
"x-bt-use-cache",
"x-bt-use-gateway",
"x-stainless-os",
"x-stainless-lang",
"x-stainless-package-version",
Expand Down
6 changes: 6 additions & 0 deletions py/src/braintrust/devserver/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class ParsedEvalBody(TypedDict, total=False):
scores: list[ParsedFunctionId]
experiment_name: str
project_id: str
on_complete_webhook: str
parent: str | ParsedParent
stream: bool

Expand Down Expand Up @@ -244,6 +245,11 @@ def parse_eval_body(request_data: str | bytes | dict) -> ParsedEvalBody:
raise ValidationError("project_id must be a string")
parsed["project_id"] = data["project_id"]

if "on_complete_webhook" in data:
if not isinstance(data["on_complete_webhook"], str):
raise ValidationError("on_complete_webhook must be a string")
parsed["on_complete_webhook"] = data["on_complete_webhook"]

if "parent" in data:
parent = data["parent"]
# InvokeParent can be a string or a complex object
Expand Down
93 changes: 93 additions & 0 deletions py/src/braintrust/devserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import json
import sys
import textwrap
import urllib.error
import urllib.request
from datetime import datetime, timezone
from typing import Any


Expand Down Expand Up @@ -52,6 +55,80 @@

_all_evaluators: dict[str, Evaluator[Any, Any]] = {}

# Defaults for completion-webhook delivery; used as the keyword defaults of
# dispatch_completion_webhook.
# Total number of delivery attempts (the first try plus retries).
WEBHOOK_ATTEMPTS = 3
# Seconds to sleep after the 1st, 2nd, 3rd... failed attempt; the last entry is
# reused when there are more retries than entries.
WEBHOOK_BACKOFF_SECONDS = (1.0, 2.0, 4.0)
# Per-request timeout, in seconds, handed to urllib.request.urlopen.
WEBHOOK_TIMEOUT_SECONDS = 10.0


def _pick_string(data: dict[str, Any], keys: list[str]) -> str | None:
for key in keys:
value = data.get(key)
if isinstance(value, str) and value:
return value
return None


def build_completion_webhook_payload(summary: dict[str, Any]) -> dict[str, Any]:
    """Build the JSON-serialisable body sent to a completion webhook.

    The payload carries a fixed ``"experiment.completed"`` event name, the
    *summary* dict unchanged, an ``"experiment"`` dict of identifying fields
    and a UTC ISO-8601 ``"timestamp"`` taken when the payload is built.

    Each ``"experiment"`` field is read from *summary* under its camelCase
    name first and its snake_case name second; it is ``None`` when neither
    holds a non-empty string.
    """
    # (camelCase output field, snake_case fallback key) in output order.
    field_aliases = (
        ("projectId", "project_id"),
        ("projectName", "project_name"),
        ("projectUrl", "project_url"),
        ("experimentId", "experiment_id"),
        ("experimentName", "experiment_name"),
        ("experimentUrl", "experiment_url"),
    )
    experiment = {
        camel: _pick_string(summary, [camel, snake])
        for camel, snake in field_aliases
    }
    payload: dict[str, Any] = {
        "event": "experiment.completed",
        "summary": summary,
        "experiment": experiment,
    }
    payload["timestamp"] = datetime.now(timezone.utc).isoformat()
    return payload


def _post_completion_webhook_request(webhook_url: str, body: dict[str, Any], timeout: float) -> None:
payload = json.dumps(body).encode("utf-8")
request = urllib.request.Request(
webhook_url,
data=payload,
headers={"Content-Type": "application/json"},
method="POST",
)
with urllib.request.urlopen(request, timeout=timeout) as response:
status = response.getcode()
if status < 200 or status >= 300:
raise RuntimeError(f"Webhook request failed with status {status}")


async def _send_completion_webhook_request(webhook_url: str, body: dict[str, Any], timeout: float) -> None:
    """Await a single webhook POST without blocking the event loop.

    The blocking ``_post_completion_webhook_request`` call is handed to
    ``asyncio.to_thread``; any exception it raises propagates to the caller.
    """
    delivery = asyncio.to_thread(
        _post_completion_webhook_request,
        webhook_url,
        body,
        timeout,
    )
    await delivery


async def dispatch_completion_webhook(
    webhook_url: str,
    summary: dict[str, Any],
    *,
    attempts: int = WEBHOOK_ATTEMPTS,
    backoff_seconds: tuple[float, ...] = WEBHOOK_BACKOFF_SECONDS,
    timeout_seconds: float = WEBHOOK_TIMEOUT_SECONDS,
) -> None:
    """Deliver the completion webhook for *summary*, retrying on failure.

    The payload is built once and re-sent unchanged on every attempt.

    Args:
        webhook_url: Destination URL.
        summary: Experiment summary dict embedded in the payload.
        attempts: Maximum number of delivery attempts. If it is less than 1,
            no request is made and the function returns normally.
        backoff_seconds: Delays between attempts; entry ``k`` is slept after
            the ``k+1``-th failure and the last entry is reused when there are
            more retries than entries. An empty sequence retries immediately.
        timeout_seconds: Per-attempt timeout passed to the sender.

    Raises:
        Exception: Whatever the final attempt raised, once all attempts fail.
    """
    payload = build_completion_webhook_payload(summary)
    for attempt in range(1, attempts + 1):
        try:
            await _send_completion_webhook_request(
                webhook_url,
                payload,
                timeout_seconds,
            )
        except Exception:
            if attempt >= attempts:
                # Out of retries: surface the last failure to the caller.
                raise
            # Guard the lookup: indexing an empty sequence here would raise
            # IndexError from inside the handler and hide the delivery error.
            if backoff_seconds:
                delay = backoff_seconds[min(attempt - 1, len(backoff_seconds) - 1)]
                await asyncio.sleep(delay)
        else:
            return


class _ParameterOverrideHooks:
def __init__(self, hooks: EvalHooks[Any], parameters: ValidatedParameters):
Expand Down Expand Up @@ -177,6 +254,7 @@ async def run_eval(request: Request) -> JSONResponse | StreamingResponse:

# Check if streaming is requested
stream = eval_data.get("stream", False)
on_complete_webhook = eval_data.get("on_complete_webhook")

# Set up SSE headers for streaming
sse_queue = SSEQueue()
Expand Down Expand Up @@ -210,6 +288,20 @@ def stream_fn(event: SSEProgressEvent):
# Use create_task to schedule the async write without blocking
asyncio.create_task(sse_queue.put_event("progress", event))

async def on_complete_fn(summary: ExperimentSummary):
    """Send the completion webhook, if the request asked for one.

    Does nothing when no ``on_complete_webhook`` was supplied. Delivery
    failures are reported on stderr and swallowed so they never propagate to
    the caller of this hook.
    """
    if on_complete_webhook:
        try:
            formatted = format_summary(summary)
            await dispatch_completion_webhook(on_complete_webhook, formatted)
        except Exception as err:
            message = f"Failed to deliver completion webhook to {on_complete_webhook}: {err}"
            print(message, file=sys.stderr)

parent = eval_data.get("parent")
if parent:
parent = parse_parent(parent)
Expand All @@ -234,6 +326,7 @@ def stream_fn(event: SSEProgressEvent):
],
"stream": stream_fn,
"on_start": on_start_fn,
"on_complete": on_complete_fn,
"data": dataset,
"task": task,
"experiment_name": eval_data.get("experiment_name"),
Expand Down
145 changes: 145 additions & 0 deletions py/src/braintrust/devserver/test_completion_webhook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import asyncio
import json

import pytest
from braintrust.test_helpers import has_devserver_installed


def _parse_sse_events(response_text: str) -> list[dict[str, object]]:
events = []
lines = response_text.strip().split("\n")
i = 0
while i < len(lines):
if lines[i].startswith("event: "):
event_type = lines[i][7:].strip()
i += 1
if i < len(lines) and lines[i].startswith("data: "):
raw_data = lines[i][6:].strip()
try:
data = json.loads(raw_data) if raw_data else None
except json.JSONDecodeError:
data = raw_data
events.append({"event": event_type, "data": data})
i += 1
else:
events.append({"event": event_type, "data": None})
else:
i += 1
return events


def test_dispatch_completion_webhook_retries(monkeypatch):
    """Two failed sends followed by a success use three attempts and two sleeps."""
    from braintrust.devserver import server as devserver_module

    send_log = []
    recorded_delays = []

    async def fake_send(webhook_url, body, timeout):
        # Fail the first two calls, succeed on the third.
        send_log.append((webhook_url, body, timeout))
        if len(send_log) < 3:
            raise RuntimeError("transient")

    async def fake_sleep(seconds):
        # Record the requested delay instead of actually waiting.
        recorded_delays.append(seconds)

    monkeypatch.setattr(devserver_module, "_send_completion_webhook_request", fake_send)
    monkeypatch.setattr(devserver_module.asyncio, "sleep", fake_sleep)

    delivery = devserver_module.dispatch_completion_webhook(
        "https://example.com/webhook",
        {"projectName": "my-project", "experimentName": "my-exp"},
        attempts=3,
        backoff_seconds=(1.0, 2.0, 4.0),
        timeout_seconds=10.0,
    )
    asyncio.run(delivery)

    assert len(send_log) == 3
    assert recorded_delays == [1.0, 2.0]


def test_parse_eval_body_accepts_on_complete_webhook():
    """A string ``on_complete_webhook`` is carried through to the parsed body."""
    from braintrust.devserver.schemas import parse_eval_body

    request_body = {
        "name": "my-eval",
        "on_complete_webhook": "https://example.com/webhook",
    }

    result = parse_eval_body(request_body)

    assert result["on_complete_webhook"] == "https://example.com/webhook"


@pytest.mark.skipif(not has_devserver_installed(), reason="Devserver dependencies not installed (requires .[cli])")
def test_eval_webhook_failure_non_fatal_for_stream(monkeypatch):
    """A failing completion webhook must not break a streaming /eval response.

    Login, webhook dispatch and ``EvalAsync`` are replaced with fakes; the
    dispatch fake always raises. The stream is still expected to return 200
    with ``summary`` and ``done`` events, and the dispatch fake must have been
    called exactly once with the requested URL.
    """
    from braintrust import Evaluator
    from braintrust.devserver import server as devserver_module
    from braintrust.devserver.server import create_app
    from braintrust.logger import BraintrustState
    from starlette.testclient import TestClient

    webhook_target = "https://example.com/webhook"

    evaluator = Evaluator(
        project_name="test-project",
        eval_name="test-eval",
        data=lambda: [{"input": "x", "expected": "x"}],
        task=lambda input_value, _hooks: input_value,
        scores=[],
        experiment_name=None,
        metadata=None,
    )

    class FakeSummary:
        def as_dict(self):
            return {
                "project_name": "test-project",
                "experiment_name": "test-eval",
                "scores": {},
            }

    class FakeResult:
        summary = FakeSummary()

    dispatch_calls = []

    async def fake_cached_login(**_kwargs):
        return BraintrustState()

    async def fake_dispatch(webhook_url, summary, **_kwargs):
        # Record the call, then simulate a delivery failure.
        dispatch_calls.append((webhook_url, summary))
        raise RuntimeError("webhook delivery failed")

    async def fake_eval_async(*, on_complete, **_kwargs):
        # Trigger the completion hook the way a finished eval would.
        await on_complete(FakeSummary())
        return FakeResult()

    replacements = (
        ("cached_login", fake_cached_login),
        ("dispatch_completion_webhook", fake_dispatch),
        ("EvalAsync", fake_eval_async),
    )
    for attr_name, fake in replacements:
        monkeypatch.setattr(devserver_module, attr_name, fake)

    request_headers = {
        "x-bt-auth-token": "test-api-key",
        "x-bt-org-name": "test-org",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }
    request_json = {
        "name": "test-eval",
        "stream": True,
        "on_complete_webhook": webhook_target,
        "data": [{"input": "x", "expected": "x"}],
    }
    test_client = TestClient(create_app([evaluator]))
    response = test_client.post("/eval", headers=request_headers, json=request_json)

    assert response.status_code == 200
    event_types = [entry["event"] for entry in _parse_sse_events(response.text)]

    assert "summary" in event_types
    assert "done" in event_types
    assert len(dispatch_calls) == 1
    called_url, called_summary = dispatch_calls[0]
    assert called_url == "https://example.com/webhook"
    assert called_summary["experimentName"] == "test-eval"
22 changes: 22 additions & 0 deletions py/src/braintrust/devserver/test_server_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,28 @@ def test_devserver_health_check(client):
assert response.text == "Hello, world!"


def test_cors_preflight_allows_gateway_header(client):
    """CORS preflight for /eval must list x-bt-use-gateway as an allowed header.

    The Braintrust Playground sends this header when gateway routing is
    enabled; if the devserver does not allow it, the browser rejects the
    real request with a CORS error.
    """
    preflight_headers = {
        "origin": "https://www.braintrust.dev",
        "access-control-request-method": "POST",
        "access-control-request-headers": "x-bt-use-gateway",
    }

    response = client.options("/eval", headers=preflight_headers)

    assert response.status_code == 200
    allowed = response.headers.get("access-control-allow-headers", "")
    failure_message = f"x-bt-use-gateway not found in access-control-allow-headers: {allowed}"
    assert "x-bt-use-gateway" in allowed, failure_message


@pytest.mark.vcr
def test_devserver_list_evaluators(client, api_key, org_name):
"""Test listing evaluators endpoint."""
Expand Down
Loading