99from typing import Tuple , Dict , Optional , List
1010
1111from mssql_python .logging import logger
12- from mssql_python .constants import AuthType
12+ from mssql_python .constants import AuthType , ConstantsDDBC
1313
1414
1515class AADAuth :
@@ -30,7 +30,25 @@ def get_token_struct(token: str) -> bytes:
3030
3131 @staticmethod
3232 def get_token (auth_type : str ) -> bytes :
33- """Get token using the specified authentication type"""
33+ """Get DDBC token struct for the specified authentication type."""
34+ token_struct , _ = AADAuth ._acquire_token (auth_type )
35+ return token_struct
36+
37+ @staticmethod
38+ def get_raw_token (auth_type : str ) -> str :
39+ """Acquire a fresh raw JWT for the mssql-py-core connection (bulk copy).
40+
41+ This deliberately does NOT cache the credential or token — each call
42+ creates a new Azure Identity credential instance and requests a token.
43+ A fresh acquisition avoids expired-token errors when bulkcopy() is
44+ called long after the original DDBC connect().
45+ """
46+ _ , raw_token = AADAuth ._acquire_token (auth_type )
47+ return raw_token
48+
49+ @staticmethod
50+ def _acquire_token (auth_type : str ) -> Tuple [bytes , str ]:
51+ """Internal: acquire token and return (ddbc_struct, raw_jwt)."""
3452 # Import Azure libraries inside method to support test mocking
3553 # pylint: disable=import-outside-toplevel
3654 try :
@@ -53,30 +71,27 @@ def get_token(auth_type: str) -> bytes:
5371 "interactive" : InteractiveBrowserCredential ,
5472 }
5573
56- credential_class = credential_map [auth_type ]
74+ credential_class = credential_map .get (auth_type )
75+ if not credential_class :
76+ raise ValueError (
77+ f"Unsupported auth_type '{ auth_type } '. " f"Supported: { ', ' .join (credential_map )} "
78+ )
5779 logger .info (
5880 "get_token: Starting Azure AD authentication - auth_type=%s, credential_class=%s" ,
5981 auth_type ,
6082 credential_class .__name__ ,
6183 )
6284
6385 try :
64- logger .debug (
65- "get_token: Creating credential instance - credential_class=%s" ,
66- credential_class .__name__ ,
67- )
6886 credential = credential_class ()
69- logger .debug (
70- "get_token: Requesting token from Azure AD - scope=https://database.windows.net/.default"
71- )
72- token = credential .get_token ("https://database.windows.net/.default" ).token
87+ raw_token = credential .get_token ("https://database.windows.net/.default" ).token
7388 logger .info (
7489 "get_token: Azure AD token acquired successfully - token_length=%d chars" ,
75- len (token ),
90+ len (raw_token ),
7691 )
77- return AADAuth .get_token_struct (token )
92+ token_struct = AADAuth .get_token_struct (raw_token )
93+ return token_struct , raw_token
7894 except ClientAuthenticationError as e :
79- # Re-raise with more specific context about Azure AD authentication failure
8095 logger .error (
8196 "get_token: Azure AD authentication failed - credential_class=%s, error=%s" ,
8297 credential_class .__name__ ,
@@ -88,7 +103,6 @@ def get_token(auth_type: str) -> bytes:
88103 f"user cancellation, network issues, or unsupported configuration."
89104 ) from e
90105 except Exception as e :
91- # Catch any other unexpected exceptions
92106 logger .error (
93107 "get_token: Unexpected error during credential creation - credential_class=%s, error=%s" ,
94108 credential_class .__name__ ,
@@ -180,7 +194,7 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]:
180194
181195
182196def get_auth_token (auth_type : str ) -> Optional [bytes ]:
183- """Get authentication token based on auth type"""
197+ """Get DDBC authentication token struct based on auth type. """
184198 logger .debug ("get_auth_token: Starting - auth_type=%s" , auth_type )
185199 if not auth_type :
186200 logger .debug ("get_auth_token: No auth_type specified, returning None" )
@@ -202,17 +216,37 @@ def get_auth_token(auth_type: str) -> Optional[bytes]:
202216 return None
203217
204218
219+ def extract_auth_type (connection_string : str ) -> Optional [str ]:
220+ """Extract Entra ID auth type from a connection string.
221+
222+ Used as a fallback when process_connection_string does not propagate
223+ auth_type (e.g. Windows Interactive where DDBC handles auth natively).
224+ Bulkcopy still needs the auth type to acquire a token via Azure Identity.
225+ """
226+ auth_map = {
227+ AuthType .INTERACTIVE .value : "interactive" ,
228+ AuthType .DEVICE_CODE .value : "devicecode" ,
229+ AuthType .DEFAULT .value : "default" ,
230+ }
231+ for part in connection_string .split (";" ):
232+ key , _ , value = part .strip ().partition ("=" )
233+ if key .strip ().lower () == "authentication" :
234+ return auth_map .get (value .strip ().lower ())
235+ return None
236+
237+
205238def process_connection_string (
206239 connection_string : str ,
207- ) -> Tuple [str , Optional [Dict [int , bytes ]]]:
240+ ) -> Tuple [str , Optional [Dict [int , bytes ]], Optional [ str ] ]:
208241 """
209242 Process connection string and handle authentication.
210243
211244 Args:
212245 connection_string: The connection string to process
213246
214247 Returns:
215- Tuple[str, Optional[Dict]]: Processed connection string and attrs_before dict if needed
248+ Tuple[str, Optional[Dict], Optional[str]]: Processed connection string,
249+ attrs_before dict if needed, and auth_type string for bulk copy token acquisition
216250
217251 Raises:
218252 ValueError: If the connection string is invalid or empty
@@ -259,7 +293,11 @@ def process_connection_string(
259293 "process_connection_string: Token authentication configured successfully - auth_type=%s" ,
260294 auth_type ,
261295 )
262- return ";" .join (modified_parameters ) + ";" , {1256 : token_struct }
296+ return (
297+ ";" .join (modified_parameters ) + ";" ,
298+ {ConstantsDDBC .SQL_COPT_SS_ACCESS_TOKEN .value : token_struct },
299+ auth_type ,
300+ )
263301 else :
264302 logger .warning (
265303 "process_connection_string: Token acquisition failed, proceeding without token"
@@ -269,4 +307,4 @@ def process_connection_string(
269307 "process_connection_string: Connection string processing complete - has_auth=%s" ,
270308 bool (auth_type ),
271309 )
272- return ";" .join (modified_parameters ) + ";" , None
310+ return ";" .join (modified_parameters ) + ";" , None , auth_type
0 commit comments