Skip to content

Commit bd2babf

Browse files
committed
fix(app-server): attach API middleware stack
1 parent 7954d02 commit bd2babf

7 files changed

Lines changed: 169 additions & 8 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/cortex-app-server/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,4 @@ gethostname = "0.5"
7373

7474
[dev-dependencies]
7575
tokio-test = { workspace = true }
76+
tower = { version = "0.5", default-features = false, features = ["util"] }

src/cortex-app-server/src/api/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub use types::{
3838
/// Create the API routes.
3939
pub fn routes() -> Router<Arc<AppState>> {
4040
Router::new()
41+
.without_v07_checks()
4142
// Health and metrics
4243
.route("/health", get(health::health_check))
4344
.route("/metrics", get(health::get_metrics))

src/cortex-app-server/src/lib.rs

Lines changed: 163 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,8 @@ pub mod websocket;
3636
use std::net::SocketAddr;
3737
use std::sync::Arc;
3838

39-
use axum::Router;
39+
use axum::{Router, middleware as axum_middleware};
4040
use tokio::net::TcpListener;
41-
use tower_http::cors::CorsLayer;
4241
use tower_http::trace::TraceLayer;
4342
use tracing::{info, warn};
4443

@@ -131,15 +130,171 @@ pub fn create_router(state: AppState) -> Router {
131130
/// This variant is useful when you need to keep a reference to the state
132131
/// for cleanup purposes (e.g., during graceful shutdown).
133132
pub fn create_router_with_state(state: Arc<AppState>) -> Router {
134-
let api_routes = api::routes()
135-
.merge(websocket::routes())
136-
.merge(streaming::routes())
137-
.merge(share::routes())
138-
.merge(admin::routes());
133+
let cors_layer = middleware::cors_layer(&state.config.cors_origins);
134+
135+
let api_routes = add_api_middleware(
136+
api::routes()
137+
.merge(websocket::routes())
138+
.merge(streaming::routes())
139+
.merge(share::routes())
140+
.merge(admin::routes()),
141+
Arc::clone(&state),
142+
);
139143

140144
Router::new()
145+
.without_v07_checks()
141146
.nest("/api/v1", api_routes)
142147
.layer(TraceLayer::new_for_http())
143-
.layer(CorsLayer::permissive())
148+
.layer(cors_layer)
144149
.with_state(state)
145150
}
151+
152+
fn add_api_middleware(
153+
router: Router<Arc<AppState>>,
154+
state: Arc<AppState>,
155+
) -> Router<Arc<AppState>> {
156+
router
157+
.layer(axum_middleware::from_fn_with_state(
158+
Arc::clone(&state),
159+
middleware::rate_limit_middleware,
160+
))
161+
.layer(axum_middleware::from_fn(
162+
middleware::content_type_middleware,
163+
))
164+
.layer(axum_middleware::from_fn_with_state(
165+
state,
166+
middleware::timeout_middleware,
167+
))
168+
.layer(axum_middleware::from_fn(
169+
middleware::security_headers_middleware,
170+
))
171+
}
172+
173+
#[cfg(test)]
174+
mod tests {
175+
use super::*;
176+
use axum::{
177+
body::Body,
178+
http::{Request, StatusCode, header},
179+
routing::get,
180+
};
181+
use tower::ServiceExt;
182+
183+
async fn test_app(config: ServerConfig) -> Router {
184+
let state = AppState::new(config).await.unwrap();
185+
create_router(state)
186+
}
187+
188+
async fn slow_test_handler() -> &'static str {
189+
tokio::time::sleep(std::time::Duration::from_millis(25)).await;
190+
"done"
191+
}
192+
193+
#[tokio::test]
194+
async fn create_router_applies_security_headers_middleware() {
195+
let app = test_app(ServerConfig::default()).await;
196+
197+
let response = app
198+
.oneshot(
199+
Request::builder()
200+
.uri("/api/v1/models")
201+
.body(Body::empty())
202+
.unwrap(),
203+
)
204+
.await
205+
.unwrap();
206+
207+
assert_eq!(response.status(), StatusCode::OK);
208+
assert_eq!(
209+
response.headers().get("X-Content-Type-Options").unwrap(),
210+
"nosniff"
211+
);
212+
assert_eq!(response.headers().get("X-Frame-Options").unwrap(), "DENY");
213+
}
214+
215+
#[tokio::test]
216+
async fn create_router_applies_rate_limit_middleware() {
217+
let mut config = ServerConfig::default();
218+
config.rate_limit.requests_per_minute = 1;
219+
config.rate_limit.burst_size = 1;
220+
config.rate_limit.exempt_paths.clear();
221+
222+
let app = test_app(config).await;
223+
224+
let first = app
225+
.clone()
226+
.oneshot(
227+
Request::builder()
228+
.uri("/api/v1/models")
229+
.body(Body::empty())
230+
.unwrap(),
231+
)
232+
.await
233+
.unwrap();
234+
assert_eq!(first.status(), StatusCode::OK);
235+
236+
let second = app
237+
.oneshot(
238+
Request::builder()
239+
.uri("/api/v1/models")
240+
.body(Body::empty())
241+
.unwrap(),
242+
)
243+
.await
244+
.unwrap();
245+
assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
246+
assert_eq!(second.headers().get(header::RETRY_AFTER).unwrap(), "60");
247+
assert_eq!(
248+
second.headers().get("X-Content-Type-Options").unwrap(),
249+
"nosniff"
250+
);
251+
}
252+
253+
#[tokio::test]
254+
async fn create_router_applies_content_type_middleware() {
255+
let app = test_app(ServerConfig::default()).await;
256+
257+
let response = app
258+
.oneshot(
259+
Request::builder()
260+
.method("POST")
261+
.uri("/api/v1/models")
262+
.header(header::CONTENT_TYPE, "text/plain")
263+
.body(Body::from("not json"))
264+
.unwrap(),
265+
)
266+
.await
267+
.unwrap();
268+
269+
assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
270+
assert_eq!(
271+
response.headers().get("X-Content-Type-Options").unwrap(),
272+
"nosniff"
273+
);
274+
}
275+
276+
#[tokio::test]
277+
async fn create_router_applies_timeout_middleware() {
278+
let mut config = ServerConfig::default();
279+
config.request_timeout = 0;
280+
config.rate_limit.enabled = false;
281+
282+
let state = Arc::new(AppState::new(config).await.unwrap());
283+
let app = add_api_middleware(
284+
Router::new().route("/slow", get(slow_test_handler)),
285+
Arc::clone(&state),
286+
)
287+
.with_state(state);
288+
289+
let response = app
290+
.oneshot(Request::builder().uri("/slow").body(Body::empty()).unwrap())
291+
.await
292+
.unwrap();
293+
294+
assert_eq!(response.status(), StatusCode::GATEWAY_TIMEOUT);
295+
assert_eq!(
296+
response.headers().get("X-Content-Type-Options").unwrap(),
297+
"nosniff"
298+
);
299+
}
300+
}

src/cortex-app-server/src/share.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use crate::storage::StoredMessage;
2222
/// Create share routes.
2323
pub fn routes() -> Router<Arc<AppState>> {
2424
Router::new()
25+
.without_v07_checks()
2526
.route("/share", post(create_share))
2627
.route("/share/:token", get(get_shared_session))
2728
.route("/share/:token", delete(revoke_share))

src/cortex-app-server/src/streaming.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use crate::state::AppState;
3030
/// Create streaming API routes.
3131
pub fn routes() -> Router<Arc<AppState>> {
3232
Router::new()
33+
.without_v07_checks()
3334
// CLI Session management
3435
.route("/cli/sessions", post(create_cli_session))
3536
.route("/cli/sessions", get(list_cli_sessions))

src/cortex-app-server/src/websocket.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use crate::state::AppState;
2828
/// Create WebSocket routes.
2929
pub fn routes() -> Router<Arc<AppState>> {
3030
Router::new()
31+
.without_v07_checks()
3132
.route("/ws", get(websocket_handler))
3233
.route("/ws/sessions/:id", get(session_websocket_handler))
3334
}

0 commit comments

Comments
 (0)