Skip to content

Commit 4cfa073

Browse files
committed
fix(app-server): rate limit direct clients by IP
1 parent 7954d02 commit 4cfa073

2 files changed

Lines changed: 42 additions & 5 deletions

File tree

src/cortex-app-server/src/lib.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,12 @@ where
9999
};
100100

101101
let listener = TcpListener::bind(addr).await?;
102-
axum::serve(listener, app)
103-
.with_graceful_shutdown(shutdown)
104-
.await?;
102+
axum::serve(
103+
listener,
104+
app.into_make_service_with_connect_info::<SocketAddr>(),
105+
)
106+
.with_graceful_shutdown(shutdown)
107+
.await?;
105108

106109
// Graceful shutdown: close all active sessions first
107110
// This ensures WebSocket clients receive proper close frames

src/cortex-app-server/src/middleware.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
//! HTTP middleware components.
22
3+
use std::net::SocketAddr;
34
use std::sync::Arc;
45
use std::time::{Duration, Instant};
56

67
use axum::{
7-
extract::{Request, State},
8+
extract::{ConnectInfo, Request, State},
89
http::{HeaderValue, Method, StatusCode, header},
910
middleware::Next,
1011
response::{IntoResponse, Response},
@@ -178,7 +179,11 @@ fn get_rate_limit_key(request: &Request, state: &AppState) -> String {
178179
}
179180
}
180181

181-
// Default to unknown when not behind proxy or headers not present
182+
if let Some(ConnectInfo(addr)) = request.extensions().get::<ConnectInfo<SocketAddr>>() {
183+
return format!("ip:{}", addr.ip());
184+
}
185+
186+
// Default to unknown only when connection metadata is unavailable.
182187
"ip:unknown".to_string()
183188
}
184189

@@ -458,7 +463,10 @@ pub async fn health_check_bypass_middleware(request: Request, next: Next) -> Res
458463

459464
#[cfg(test)]
460465
mod tests {
466+
use axum::body::Body;
467+
461468
use super::*;
469+
use crate::config::ServerConfig;
462470

463471
#[test]
464472
fn test_request_id() {
@@ -477,4 +485,30 @@ mod tests {
477485
.contains(&"Authorization".to_string())
478486
);
479487
}
488+
489+
#[tokio::test]
490+
async fn rate_limit_key_uses_connect_info_for_direct_clients() {
491+
let config = ServerConfig::default();
492+
let state = AppState::new(config).await.unwrap();
493+
let mut request = Request::builder().body(Body::empty()).unwrap();
494+
let addr: SocketAddr = "203.0.113.10:4242".parse().unwrap();
495+
request.extensions_mut().insert(ConnectInfo(addr));
496+
497+
assert_eq!(get_rate_limit_key(&request, &state), "ip:203.0.113.10");
498+
}
499+
500+
#[tokio::test]
501+
async fn rate_limit_key_prefers_proxy_headers_when_trusted() {
502+
let mut config = ServerConfig::default();
503+
config.rate_limit.trust_proxy = true;
504+
let state = AppState::new(config).await.unwrap();
505+
let mut request = Request::builder()
506+
.header("X-Forwarded-For", "198.51.100.7, 203.0.113.8")
507+
.body(Body::empty())
508+
.unwrap();
509+
let addr: SocketAddr = "203.0.113.10:4242".parse().unwrap();
510+
request.extensions_mut().insert(ConnectInfo(addr));
511+
512+
assert_eq!(get_rate_limit_key(&request, &state), "ip:198.51.100.7");
513+
}
480514
}

0 commit comments

Comments
 (0)