Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions crates/agentic-server/benches/proxy_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use tokio::runtime::Runtime;

use agentic_core::config::Config;
use agentic_core::proxy::ProxyState;
use agentic_server::app::build_router;
use agentic_server::app::{ServerConfig, build_router};

fn bench_config(llm_url: &str) -> Config {
Config {
Expand Down Expand Up @@ -72,7 +72,8 @@ async fn spawn_llm() -> String {

async fn spawn_gateway(config: Config) -> String {
let state = ProxyState::new(config).unwrap();
let router = build_router(state);
let server_config = ServerConfig::from_env();
let router = build_router(state, &server_config);

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
Expand Down
46 changes: 45 additions & 1 deletion crates/agentic-server/src/app.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,57 @@
use agentic_core::proxy::ProxyState;
use axum::Router;
use axum::routing::{get, post};
use http::HeaderValue;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};

use crate::handler::{health, proxy_responses, ready};

pub fn build_router(state: ProxyState) -> Router {
/// Server-level configuration read from environment variables.
pub struct ServerConfig {
pub cors_allowed_origins: Vec<String>,
}

impl ServerConfig {
/// Read `CORS_ALLOWED_ORIGINS` (comma-separated). Unset or empty = permissive.
#[must_use]
pub fn from_env() -> Self {
let cors_allowed_origins = std::env::var("CORS_ALLOWED_ORIGINS")
.ok()
.map(|s| {
s.split(',')
.map(str::trim)
.filter(|o| !o.is_empty())
.map(str::to_owned)
.collect::<Vec<_>>()
})
.unwrap_or_default();
Self { cors_allowed_origins }
}

fn cors_layer(&self) -> CorsLayer {
let allow_origin = if self.cors_allowed_origins.is_empty() {
AllowOrigin::any()
} else {
let origins: Vec<HeaderValue> = self
.cors_allowed_origins
.iter()
.filter_map(|o| o.parse().ok())
.collect();
AllowOrigin::list(origins)
};

CorsLayer::new()
.allow_origin(allow_origin)
.allow_methods(Any)
.allow_headers(Any)
}
}

pub fn build_router(state: ProxyState, server_config: &ServerConfig) -> Router {
Router::new()
.route("/health", get(health))
.route("/ready", get(ready))
.route("/v1/responses", post(proxy_responses))
.layer(server_config.cors_layer())
.with_state(state)
}
5 changes: 3 additions & 2 deletions crates/agentic-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ use agentic_core::config::Config;
use agentic_core::error::Error;
use agentic_core::proxy::ProxyState;
use agentic_core::readiness::wait_llm_ready;
use agentic_server::app::build_router;
use agentic_server::app::{ServerConfig, build_router};
use tokio::net::TcpListener;
use tracing::info;

async fn serve_gateway(config: Config, host: &str, port: u16) -> Result<(), Error> {
let addr = format!("{host}:{port}");
let state = ProxyState::new(config)?;
let router = build_router(state);
let server_config = ServerConfig::from_env();
let router = build_router(state, &server_config);
let listener = TcpListener::bind(&addr).await?;
info!("gateway listening on {addr}");
axum::serve(listener, router).await?;
Expand Down
80 changes: 80 additions & 0 deletions crates/agentic-server/tests/cors_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use axum::Router;
use axum::response::IntoResponse;
use axum::routing::get;
use http::StatusCode;
use tokio::net::TcpListener;

use agentic_core::config::Config;
use agentic_core::proxy::ProxyState;

fn test_config(llm_url: &str) -> Config {
Config {
llm_api_base: llm_url.to_owned(),
openai_api_key: None,
llm_ready_timeout_s: 5.0,
llm_ready_interval_s: 0.1,
}
}

async fn spawn_mock_llm() -> (String, tokio::task::JoinHandle<()>) {
let app = Router::new().route("/health", get(|| async { StatusCode::OK.into_response() }));
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
(format!("http://{addr}"), handle)
}

async fn spawn_gateway(config: Config) -> (String, tokio::task::JoinHandle<()>) {
let state = ProxyState::new(config).unwrap();
let server_config = agentic_server::app::ServerConfig::from_env();
let router = agentic_server::app::build_router(state, &server_config);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
axum::serve(listener, router).await.unwrap();
});
(format!("http://{addr}"), handle)
}

#[tokio::test]
async fn test_cors_preflight_returns_200() {
let (llm_url, _h1) = spawn_mock_llm().await;
let config = test_config(&llm_url);
let (gw_url, _h2) = spawn_gateway(config).await;

let client = reqwest::Client::new();
let resp = client
.request(reqwest::Method::OPTIONS, format!("{gw_url}/v1/responses"))
.header("Origin", "http://example.com")
.header("Access-Control-Request-Method", "POST")
.header("Access-Control-Request-Headers", "Content-Type,Authorization")
.send()
.await
.unwrap();

assert_eq!(resp.status(), 200);
assert!(resp.headers().contains_key("access-control-allow-origin"));
assert!(resp.headers().contains_key("access-control-allow-methods"));
assert!(resp.headers().contains_key("access-control-allow-headers"));
}

#[tokio::test]
async fn test_cors_headers_on_regular_request() {
let (llm_url, _h1) = spawn_mock_llm().await;
let config = test_config(&llm_url);
let (gw_url, _h2) = spawn_gateway(config).await;

let client = reqwest::Client::new();
let resp = client
.post(format!("{gw_url}/v1/responses"))
.header("Origin", "http://example.com")
.header("Content-Type", "application/json")
.body(r#"{"model":"test","input":"hi"}"#)
.send()
.await
.unwrap();

assert!(resp.headers().contains_key("access-control-allow-origin"));
}
3 changes: 2 additions & 1 deletion crates/agentic-server/tests/health_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ async fn spawn_mock_llm() -> (String, tokio::task::JoinHandle<()>) {

async fn spawn_gateway(config: Config) -> (String, tokio::task::JoinHandle<()>) {
let state = ProxyState::new(config).unwrap();
let router = agentic_server::app::build_router(state);
let server_config = agentic_server::app::ServerConfig::from_env();
let router = agentic_server::app::build_router(state, &server_config);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
Expand Down