11//! HTTP middleware components.
22
3+ use std:: net:: SocketAddr ;
34use std:: sync:: Arc ;
45use std:: time:: { Duration , Instant } ;
56
67use axum:: {
7- extract:: { Request , State } ,
8+ extract:: { ConnectInfo , Request , State } ,
89 http:: { HeaderValue , Method , StatusCode , header} ,
910 middleware:: Next ,
1011 response:: { IntoResponse , Response } ,
@@ -178,7 +179,11 @@ fn get_rate_limit_key(request: &Request, state: &AppState) -> String {
178179 }
179180 }
180181
181- // Default to unknown when not behind proxy or headers not present
182+ if let Some ( ConnectInfo ( addr) ) = request. extensions ( ) . get :: < ConnectInfo < SocketAddr > > ( ) {
183+ return format ! ( "ip:{}" , addr. ip( ) ) ;
184+ }
185+
186+ // Default to unknown only when connection metadata is unavailable.
182187 "ip:unknown" . to_string ( )
183188}
184189
@@ -458,7 +463,10 @@ pub async fn health_check_bypass_middleware(request: Request, next: Next) -> Res
458463
459464#[ cfg( test) ]
460465mod tests {
466+ use axum:: body:: Body ;
467+
461468 use super :: * ;
469+ use crate :: config:: ServerConfig ;
462470
463471 #[ test]
464472 fn test_request_id ( ) {
@@ -477,4 +485,30 @@ mod tests {
477485 . contains( & "Authorization" . to_string( ) )
478486 ) ;
479487 }
488+
489+ #[ tokio:: test]
490+ async fn rate_limit_key_uses_connect_info_for_direct_clients ( ) {
491+ let config = ServerConfig :: default ( ) ;
492+ let state = AppState :: new ( config) . await . unwrap ( ) ;
493+ let mut request = Request :: builder ( ) . body ( Body :: empty ( ) ) . unwrap ( ) ;
494+ let addr: SocketAddr = "203.0.113.10:4242" . parse ( ) . unwrap ( ) ;
495+ request. extensions_mut ( ) . insert ( ConnectInfo ( addr) ) ;
496+
497+ assert_eq ! ( get_rate_limit_key( & request, & state) , "ip:203.0.113.10" ) ;
498+ }
499+
500+ #[ tokio:: test]
501+ async fn rate_limit_key_prefers_proxy_headers_when_trusted ( ) {
502+ let mut config = ServerConfig :: default ( ) ;
503+ config. rate_limit . trust_proxy = true ;
504+ let state = AppState :: new ( config) . await . unwrap ( ) ;
505+ let mut request = Request :: builder ( )
506+ . header ( "X-Forwarded-For" , "198.51.100.7, 203.0.113.8" )
507+ . body ( Body :: empty ( ) )
508+ . unwrap ( ) ;
509+ let addr: SocketAddr = "203.0.113.10:4242" . parse ( ) . unwrap ( ) ;
510+ request. extensions_mut ( ) . insert ( ConnectInfo ( addr) ) ;
511+
512+ assert_eq ! ( get_rate_limit_key( & request, & state) , "ip:198.51.100.7" ) ;
513+ }
480514}
0 commit comments