1111
1212from mssql_python .logging import logger
1313from mssql_python .constants import AuthType , ConstantsDDBC
14+ from mssql_python .connection_string_parser import _ConnectionStringParser
1415
1516# Module-level credential instance cache.
1617# Reusing credential objects allows the Azure Identity SDK's built-in
1718# in-memory token cache to work, avoiding redundant token acquisitions.
1819# See: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/TOKEN_CACHING.md
19- _credential_cache : Dict [str , object ] = {}
20+ #
21+ # Cache is keyed on (auth_type, sorted credential_kwargs), which is
22+ # bounded by the distinct credentials a single process ever uses.
23+ _credential_cache : Dict [object , object ] = {}
2024_credential_cache_lock = threading .Lock ()
2125
2226
27+ def _credential_cache_key (auth_type : str , credential_kwargs : Optional [Dict [str , str ]]):
28+ """Build a hashable cache key from auth_type and optional credential kwargs.
29+
30+ Returns the plain auth_type string when no kwargs are provided so that
31+ callers caching by string (the original behavior) keep working. When
32+ kwargs are present (e.g. user-assigned MSI client_id), the key is a
33+ tuple of ``(auth_type, sorted_kwargs_items)`` so different kwargs map
34+ to different cached credentials.
35+ """
36+ if not credential_kwargs :
37+ return auth_type
38+ return (auth_type , tuple (sorted (credential_kwargs .items ())))
39+
40+
2341class AADAuth :
2442 """Handles Azure Active Directory authentication"""
2543
@@ -37,24 +55,26 @@ def get_token_struct(token: str) -> bytes:
3755 return struct .pack (f"<I{ len (token_bytes )} s" , len (token_bytes ), token_bytes )
3856
@staticmethod
def get_token(auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None) -> bytes:
    """Get DDBC token struct for the specified authentication type.

    Args:
        auth_type: Short authentication-type name understood by
            ``AADAuth._acquire_token``.
        credential_kwargs: Optional keyword arguments forwarded unchanged
            to ``AADAuth._acquire_token``. Defaults to ``None``, which
            keeps the original single-argument call working.

    Returns:
        The first element of the pair returned by
        ``AADAuth._acquire_token`` (the DDBC token struct).
    """
    # _acquire_token returns (ddbc_struct, raw_jwt); only the struct is needed here.
    token_struct, _ = AADAuth._acquire_token(auth_type, credential_kwargs)
    return token_struct
4462
@staticmethod
def get_raw_token(auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None) -> str:
    """Acquire a raw JWT for the mssql-py-core connection (bulk copy).

    Uses the cached credential instance so the Azure Identity SDK's
    built-in token cache can serve a valid token without a round-trip
    when the previous token has not yet expired.

    Args:
        auth_type: Short authentication-type name understood by
            ``AADAuth._acquire_token``.
        credential_kwargs: Optional keyword arguments forwarded unchanged
            to ``AADAuth._acquire_token``. Defaults to ``None``, which
            keeps the original single-argument call working.

    Returns:
        The second element of the pair returned by
        ``AADAuth._acquire_token`` (the raw JWT string).
    """
    # _acquire_token returns (ddbc_struct, raw_jwt); only the JWT is needed here.
    _, raw_token = AADAuth._acquire_token(auth_type, credential_kwargs)
    return raw_token
5573
5674 @staticmethod
57- def _acquire_token (auth_type : str ) -> Tuple [bytes , str ]:
75+ def _acquire_token (
76+ auth_type : str , credential_kwargs : Optional [Dict [str , str ]] = None
77+ ) -> Tuple [bytes , str ]:
5878 """Internal: acquire token and return (ddbc_struct, raw_jwt)."""
5979 # Import Azure libraries inside method to support test mocking
6080 # pylint: disable=import-outside-toplevel
@@ -63,6 +83,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
6383 DefaultAzureCredential ,
6484 DeviceCodeCredential ,
6585 InteractiveBrowserCredential ,
86+ ManagedIdentityCredential ,
6687 )
6788 from azure .core .exceptions import ClientAuthenticationError
6889 except ImportError as e :
@@ -76,6 +97,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
7697 "default" : DefaultAzureCredential ,
7798 "devicecode" : DeviceCodeCredential ,
7899 "interactive" : InteractiveBrowserCredential ,
100+ "msi" : ManagedIdentityCredential ,
79101 }
80102
81103 credential_class = credential_map .get (auth_type )
@@ -89,20 +111,22 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
89111 credential_class .__name__ ,
90112 )
91113
114+ kwargs = credential_kwargs or {}
115+ cache_key = _credential_cache_key (auth_type , kwargs )
92116 try :
93117 with _credential_cache_lock :
94- if auth_type not in _credential_cache :
118+ if cache_key not in _credential_cache :
95119 logger .debug (
96120 "get_token: Creating new credential instance for auth_type=%s" ,
97121 auth_type ,
98122 )
99- _credential_cache [auth_type ] = credential_class ()
123+ _credential_cache [cache_key ] = credential_class (** kwargs )
100124 else :
101125 logger .debug (
102126 "get_token: Reusing cached credential instance for auth_type=%s" ,
103127 auth_type ,
104128 )
105- credential = _credential_cache [auth_type ]
129+ credential = _credential_cache [cache_key ]
106130 raw_token = credential .get_token ("https://database.windows.net/.default" ).token
107131 logger .info (
108132 "get_token: Azure AD token acquired successfully - token_length=%d chars" ,
@@ -130,6 +154,28 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
130154 raise RuntimeError (f"Failed to create { credential_class .__name__ } : { e } " ) from e
131155
132156
def _extract_msi_client_id(connection_string: str) -> Optional[str]:
    """Pull UID out of a connection string for user-assigned MSI.

    For ActiveDirectoryMSI, UID (when present) carries the user-assigned
    identity's ``client_id``.

    Parsing is delegated to ``_ConnectionStringParser`` instead of a naive
    split on ``;``. The intent is that braced ODBC values are handled
    correctly: ``UID={hello=world}`` should resolve to ``hello=world``, and
    a semicolon inside a braced value such as
    ``Database={foo;uid=victim;bar}`` should not spoof a top-level ``UID=``.
    NOTE(review): that behavior lives in the parser, which is not visible
    from this function; confirm against ``_ConnectionStringParser``.

    Args:
        connection_string: The original, unsanitized connection string.

    Returns:
        The whitespace-stripped UID value, or ``None`` when UID is absent
        or blank (i.e. system-assigned MSI).

    Raises:
        Whatever ``_ConnectionStringParser._parse`` raises; parse errors
        are deliberately not caught here.
    """
    # Connection.__init__ is expected to have already parsed the same string
    # through _ConnectionStringParser, so the input should be parseable by
    # the time we get here. No defensive try/except: a parse failure now
    # means a real bug upstream and should propagate, not silently degrade
    # user-assigned MSI to system-assigned.
    parsed = _ConnectionStringParser(validate_keywords=False)._parse(connection_string)
    # presumably the parser normalizes keys to lowercase — verify against
    # _ConnectionStringParser. ``or ""`` covers both a missing key and None.
    uid = (parsed.get("uid") or "").strip()
    return uid or None
177+
178+
133179def process_auth_parameters (parameters : List [str ]) -> Tuple [List [str ], Optional [str ]]:
134180 """
135181 Process connection parameters and extract authentication type.
@@ -180,6 +226,10 @@ def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[
180226 # Default authentication (uses DefaultAzureCredential)
181227 logger .debug ("process_auth_parameters: Default Azure authentication detected" )
182228 auth_type = "default"
229+ elif value_lower == AuthType .MSI .value :
230+ # Managed identity authentication (system- or user-assigned)
231+ logger .debug ("process_auth_parameters: Managed identity authentication detected" )
232+ auth_type = "msi"
183233 modified_parameters .append (param )
184234
185235 logger .debug (
@@ -212,7 +262,9 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]:
212262 return result
213263
214264
215- def get_auth_token (auth_type : str ) -> Optional [bytes ]:
265+ def get_auth_token (
266+ auth_type : str , credential_kwargs : Optional [Dict [str , str ]] = None
267+ ) -> Optional [bytes ]:
216268 """Get DDBC authentication token struct based on auth type."""
217269 logger .debug ("get_auth_token: Starting - auth_type=%s" , auth_type )
218270 if not auth_type :
@@ -225,7 +277,7 @@ def get_auth_token(auth_type: str) -> Optional[bytes]:
225277 return None # Let Windows handle AADInteractive natively
226278
227279 try :
228- token = AADAuth .get_token (auth_type )
280+ token = AADAuth .get_token (auth_type , credential_kwargs )
229281 logger .info ("get_auth_token: Token acquired successfully - auth_type=%s" , auth_type )
230282 return token
231283 except (ValueError , RuntimeError ) as e :
@@ -246,6 +298,7 @@ def extract_auth_type(connection_string: str) -> Optional[str]:
246298 AuthType .INTERACTIVE .value : "interactive" ,
247299 AuthType .DEVICE_CODE .value : "devicecode" ,
248300 AuthType .DEFAULT .value : "default" ,
301+ AuthType .MSI .value : "msi" ,
249302 }
250303 for part in connection_string .split (";" ):
251304 key , _ , value = part .strip ().partition ("=" )
@@ -256,16 +309,28 @@ def extract_auth_type(connection_string: str) -> Optional[str]:
256309
257310def process_connection_string (
258311 connection_string : str ,
259- ) -> Tuple [str , Optional [Dict [int , bytes ]], Optional [str ]]:
312+ ) -> Tuple [str , Optional [Dict [int , bytes ]], Optional [str ], Optional [ Dict [ str , str ]] ]:
260313 """
261314 Process connection string and handle authentication.
262315
316+ NOTE: Returns a 4-tuple. Callers must unpack all four elements.
317+ Destructuring with three names raises ``ValueError: too many values
318+ to unpack``. The fourth element (``credential_kwargs``) is needed by
319+ Connection.__init__ to persist credential constructor args (e.g. the
320+ user-assigned MSI ``client_id``) for the bulkcopy fresh-token path,
321+ since UID is stripped from the sanitized connection string.
322+
263323 Args:
264324 connection_string: The connection string to process
265325
266326 Returns:
267- Tuple[str, Optional[Dict], Optional[str]]: Processed connection string,
268- attrs_before dict if needed, and auth_type string for bulk copy token acquisition
327+ Tuple[str, Optional[Dict], Optional[str], Optional[Dict[str, str]]]:
328+ Processed connection string, attrs_before dict if needed, auth_type
329+ string for bulk copy token acquisition, and credential constructor
330+ kwargs (e.g. user-assigned MSI ``client_id``) to be persisted on
331+ the Connection so bulkcopy can re-use them when acquiring a fresh
332+ token after sanitization has stripped UID from the connection
333+ string.
269334
270335 Raises:
271336 ValueError: If the connection string is invalid or empty
@@ -301,12 +366,33 @@ def process_connection_string(
301366
302367 modified_parameters , auth_type = process_auth_parameters (parameters )
303368
369+ # Capture credential kwargs (e.g. user-assigned MSI client_id) before
370+ # remove_sensitive_params strips UID from the parameter list. Pass the
371+ # original connection_string (not modified_parameters) so the helper can
372+ # use the canonical _ConnectionStringParser — handles braced values like
373+ # UID={hello=world} correctly.
374+ credential_kwargs : Dict [str , str ] = {}
375+ if auth_type == "msi" :
376+ client_id = _extract_msi_client_id (connection_string )
377+ if client_id :
378+ credential_kwargs ["client_id" ] = client_id
379+ logger .debug (
380+ "process_connection_string: ActiveDirectoryMSI with UID — "
381+ "user-assigned managed identity selected (client_id length=%d)" ,
382+ len (client_id ),
383+ )
384+ else :
385+ logger .debug (
386+ "process_connection_string: ActiveDirectoryMSI without UID — "
387+ "system-assigned managed identity selected"
388+ )
389+
304390 if auth_type :
305391 logger .info (
306392 "process_connection_string: Authentication type detected - auth_type=%s" , auth_type
307393 )
308394 modified_parameters = remove_sensitive_params (modified_parameters )
309- token_struct = get_auth_token (auth_type )
395+ token_struct = get_auth_token (auth_type , credential_kwargs or None )
310396 if token_struct :
311397 logger .info (
312398 "process_connection_string: Token authentication configured successfully - auth_type=%s" ,
@@ -316,6 +402,7 @@ def process_connection_string(
316402 ";" .join (modified_parameters ) + ";" ,
317403 {ConstantsDDBC .SQL_COPT_SS_ACCESS_TOKEN .value : token_struct },
318404 auth_type ,
405+ credential_kwargs or None ,
319406 )
320407 else :
321408 logger .warning (
@@ -326,4 +413,9 @@ def process_connection_string(
326413 "process_connection_string: Connection string processing complete - has_auth=%s" ,
327414 bool (auth_type ),
328415 )
329- return ";" .join (modified_parameters ) + ";" , None , auth_type
416+ return (
417+ ";" .join (modified_parameters ) + ";" ,
418+ None ,
419+ auth_type ,
420+ credential_kwargs or None ,
421+ )
0 commit comments