@@ -36,9 +36,8 @@ pub mod websocket;
3636use std:: net:: SocketAddr ;
3737use std:: sync:: Arc ;
3838
39- use axum:: Router ;
39+ use axum:: { Router , middleware as axum_middleware } ;
4040use tokio:: net:: TcpListener ;
41- use tower_http:: cors:: CorsLayer ;
4241use tower_http:: trace:: TraceLayer ;
4342use tracing:: { info, warn} ;
4443
@@ -131,15 +130,171 @@ pub fn create_router(state: AppState) -> Router {
131130/// This variant is useful when you need to keep a reference to the state
132131/// for cleanup purposes (e.g., during graceful shutdown).
133132pub fn create_router_with_state ( state : Arc < AppState > ) -> Router {
134- let api_routes = api:: routes ( )
135- . merge ( websocket:: routes ( ) )
136- . merge ( streaming:: routes ( ) )
137- . merge ( share:: routes ( ) )
138- . merge ( admin:: routes ( ) ) ;
133+ let cors_layer = middleware:: cors_layer ( & state. config . cors_origins ) ;
134+
135+ let api_routes = add_api_middleware (
136+ api:: routes ( )
137+ . merge ( websocket:: routes ( ) )
138+ . merge ( streaming:: routes ( ) )
139+ . merge ( share:: routes ( ) )
140+ . merge ( admin:: routes ( ) ) ,
141+ Arc :: clone ( & state) ,
142+ ) ;
139143
140144 Router :: new ( )
145+ . without_v07_checks ( )
141146 . nest ( "/api/v1" , api_routes)
142147 . layer ( TraceLayer :: new_for_http ( ) )
143- . layer ( CorsLayer :: permissive ( ) )
148+ . layer ( cors_layer )
144149 . with_state ( state)
145150}
151+
152+ fn add_api_middleware (
153+ router : Router < Arc < AppState > > ,
154+ state : Arc < AppState > ,
155+ ) -> Router < Arc < AppState > > {
156+ router
157+ . layer ( axum_middleware:: from_fn_with_state (
158+ Arc :: clone ( & state) ,
159+ middleware:: rate_limit_middleware,
160+ ) )
161+ . layer ( axum_middleware:: from_fn (
162+ middleware:: content_type_middleware,
163+ ) )
164+ . layer ( axum_middleware:: from_fn_with_state (
165+ state,
166+ middleware:: timeout_middleware,
167+ ) )
168+ . layer ( axum_middleware:: from_fn (
169+ middleware:: security_headers_middleware,
170+ ) )
171+ }
172+
173+ #[ cfg( test) ]
174+ mod tests {
175+ use super :: * ;
176+ use axum:: {
177+ body:: Body ,
178+ http:: { Request , StatusCode , header} ,
179+ routing:: get,
180+ } ;
181+ use tower:: ServiceExt ;
182+
183+ async fn test_app ( config : ServerConfig ) -> Router {
184+ let state = AppState :: new ( config) . await . unwrap ( ) ;
185+ create_router ( state)
186+ }
187+
188+ async fn slow_test_handler ( ) -> & ' static str {
189+ tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 25 ) ) . await ;
190+ "done"
191+ }
192+
193+ #[ tokio:: test]
194+ async fn create_router_applies_security_headers_middleware ( ) {
195+ let app = test_app ( ServerConfig :: default ( ) ) . await ;
196+
197+ let response = app
198+ . oneshot (
199+ Request :: builder ( )
200+ . uri ( "/api/v1/models" )
201+ . body ( Body :: empty ( ) )
202+ . unwrap ( ) ,
203+ )
204+ . await
205+ . unwrap ( ) ;
206+
207+ assert_eq ! ( response. status( ) , StatusCode :: OK ) ;
208+ assert_eq ! (
209+ response. headers( ) . get( "X-Content-Type-Options" ) . unwrap( ) ,
210+ "nosniff"
211+ ) ;
212+ assert_eq ! ( response. headers( ) . get( "X-Frame-Options" ) . unwrap( ) , "DENY" ) ;
213+ }
214+
215+ #[ tokio:: test]
216+ async fn create_router_applies_rate_limit_middleware ( ) {
217+ let mut config = ServerConfig :: default ( ) ;
218+ config. rate_limit . requests_per_minute = 1 ;
219+ config. rate_limit . burst_size = 1 ;
220+ config. rate_limit . exempt_paths . clear ( ) ;
221+
222+ let app = test_app ( config) . await ;
223+
224+ let first = app
225+ . clone ( )
226+ . oneshot (
227+ Request :: builder ( )
228+ . uri ( "/api/v1/models" )
229+ . body ( Body :: empty ( ) )
230+ . unwrap ( ) ,
231+ )
232+ . await
233+ . unwrap ( ) ;
234+ assert_eq ! ( first. status( ) , StatusCode :: OK ) ;
235+
236+ let second = app
237+ . oneshot (
238+ Request :: builder ( )
239+ . uri ( "/api/v1/models" )
240+ . body ( Body :: empty ( ) )
241+ . unwrap ( ) ,
242+ )
243+ . await
244+ . unwrap ( ) ;
245+ assert_eq ! ( second. status( ) , StatusCode :: TOO_MANY_REQUESTS ) ;
246+ assert_eq ! ( second. headers( ) . get( header:: RETRY_AFTER ) . unwrap( ) , "60" ) ;
247+ assert_eq ! (
248+ second. headers( ) . get( "X-Content-Type-Options" ) . unwrap( ) ,
249+ "nosniff"
250+ ) ;
251+ }
252+
253+ #[ tokio:: test]
254+ async fn create_router_applies_content_type_middleware ( ) {
255+ let app = test_app ( ServerConfig :: default ( ) ) . await ;
256+
257+ let response = app
258+ . oneshot (
259+ Request :: builder ( )
260+ . method ( "POST" )
261+ . uri ( "/api/v1/models" )
262+ . header ( header:: CONTENT_TYPE , "text/plain" )
263+ . body ( Body :: from ( "not json" ) )
264+ . unwrap ( ) ,
265+ )
266+ . await
267+ . unwrap ( ) ;
268+
269+ assert_eq ! ( response. status( ) , StatusCode :: UNSUPPORTED_MEDIA_TYPE ) ;
270+ assert_eq ! (
271+ response. headers( ) . get( "X-Content-Type-Options" ) . unwrap( ) ,
272+ "nosniff"
273+ ) ;
274+ }
275+
276+ #[ tokio:: test]
277+ async fn create_router_applies_timeout_middleware ( ) {
278+ let mut config = ServerConfig :: default ( ) ;
279+ config. request_timeout = 0 ;
280+ config. rate_limit . enabled = false ;
281+
282+ let state = Arc :: new ( AppState :: new ( config) . await . unwrap ( ) ) ;
283+ let app = add_api_middleware (
284+ Router :: new ( ) . route ( "/slow" , get ( slow_test_handler) ) ,
285+ Arc :: clone ( & state) ,
286+ )
287+ . with_state ( state) ;
288+
289+ let response = app
290+ . oneshot ( Request :: builder ( ) . uri ( "/slow" ) . body ( Body :: empty ( ) ) . unwrap ( ) )
291+ . await
292+ . unwrap ( ) ;
293+
294+ assert_eq ! ( response. status( ) , StatusCode :: GATEWAY_TIMEOUT ) ;
295+ assert_eq ! (
296+ response. headers( ) . get( "X-Content-Type-Options" ) . unwrap( ) ,
297+ "nosniff"
298+ ) ;
299+ }
300+ }
0 commit comments