5252
5353logger = logging .getLogger (__name__ )
5454
55+ _FORWARDED_AUTH_FLOW_HEADERS = ("User-Agent" ,)
56+
5557
5658class PKCEParameters (BaseModel ):
5759 """PKCE (Proof Key for Code Exchange) parameters."""
@@ -477,6 +479,14 @@ def _add_auth_header(self, request: httpx.Request) -> None:
477479 if self .context .current_tokens and self .context .current_tokens .access_token : # pragma: no branch
478480 request .headers ["Authorization" ] = f"Bearer { self .context .current_tokens .access_token } "
479481
482+ def _forward_request_headers (self , source_request : httpx .Request , outgoing_request : httpx .Request ) -> httpx .Request :
483+ """Forward selected caller headers to OAuth flow requests."""
484+ for header_name in _FORWARDED_AUTH_FLOW_HEADERS :
485+ header_value = source_request .headers .get (header_name )
486+ if header_value is not None and header_name not in outgoing_request .headers :
487+ outgoing_request .headers [header_name ] = header_value
488+ return outgoing_request
489+
480490 async def _handle_oauth_metadata_response (self , response : httpx .Response ) -> None :
481491 content = await response .aread ()
482492 metadata = OAuthMetadata .model_validate_json (content )
@@ -508,6 +518,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
508518 if not self .context .is_token_valid () and self .context .can_refresh_token ():
509519 # Try to refresh token
510520 refresh_request = await self ._refresh_token () # pragma: no cover
521+ refresh_request = self ._forward_request_headers (request , refresh_request ) # pragma: no cover
511522 refresh_response = yield refresh_request # pragma: no cover
512523
513524 if not await self ._handle_refresh_response (refresh_response ): # pragma: no cover
@@ -532,6 +543,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
532543
533544 for url in prm_discovery_urls : # pragma: no branch
534545 discovery_request = create_oauth_metadata_request (url )
546+ discovery_request = self ._forward_request_headers (request , discovery_request )
535547
536548 discovery_response = yield discovery_request # sending request
537549
@@ -558,6 +570,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
558570 # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
559571 for url in asm_discovery_urls : # pragma: no branch
560572 oauth_metadata_request = create_oauth_metadata_request (url )
573+ oauth_metadata_request = self ._forward_request_headers (request , oauth_metadata_request )
561574 oauth_metadata_response = yield oauth_metadata_request
562575
563576 ok , asm = await handle_auth_metadata_response (oauth_metadata_response )
@@ -596,13 +609,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
596609 self .context .client_metadata ,
597610 self .context .get_authorization_base_url (self .context .server_url ),
598611 )
612+ registration_request = self ._forward_request_headers (request , registration_request )
599613 registration_response = yield registration_request
600614 client_information = await handle_registration_response (registration_response )
601615 self .context .client_info = client_information
602616 await self .context .storage .set_client_info (client_information )
603617
604618 # Step 5: Perform authorization and complete token exchange
605- token_response = yield await self ._perform_authorization ()
619+ token_request = await self ._perform_authorization ()
620+ token_request = self ._forward_request_headers (request , token_request )
621+ token_response = yield token_request
606622 await self ._handle_token_response (token_response )
607623 except Exception : # pragma: no cover
608624 logger .exception ("OAuth flow error" )
@@ -624,7 +640,9 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
624640 )
625641
626642 # Step 2b: Perform (re-)authorization and token exchange
627- token_response = yield await self ._perform_authorization ()
643+ token_request = await self ._perform_authorization ()
644+ token_request = self ._forward_request_headers (request , token_request )
645+ token_response = yield token_request
628646 await self ._handle_token_response (token_response )
629647 except Exception : # pragma: no cover
630648 logger .exception ("OAuth flow error" )
0 commit comments