Skip to content

Commit 46ae4ba

Browse files
authored
FEAT: Add ActiveDirectoryMSI support for bulk copy (#573)
1 parent a276b15 commit 46ae4ba

5 files changed

Lines changed: 357 additions & 23 deletions

File tree

mssql_python/auth.py

Lines changed: 108 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,33 @@
1111

1212
from mssql_python.logging import logger
1313
from mssql_python.constants import AuthType, ConstantsDDBC
14+
from mssql_python.connection_string_parser import _ConnectionStringParser
1415

1516
# Module-level credential instance cache.
1617
# Reusing credential objects allows the Azure Identity SDK's built-in
1718
# in-memory token cache to work, avoiding redundant token acquisitions.
1819
# See: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/TOKEN_CACHING.md
19-
_credential_cache: Dict[str, object] = {}
20+
#
21+
# Cache is keyed on auth_type alone, or on (auth_type, sorted credential_kwargs) when kwargs are given, which is
22+
# bounded by the distinct credentials a single process ever uses.
23+
_credential_cache: Dict[object, object] = {}
2024
_credential_cache_lock = threading.Lock()
2125

2226

27+
def _credential_cache_key(auth_type: str, credential_kwargs: Optional[Dict[str, str]]):
28+
"""Build a hashable cache key from auth_type and optional credential kwargs.
29+
30+
Returns the plain auth_type string when no kwargs are provided so that
31+
callers caching by string (the original behavior) keep working. When
32+
kwargs are present (e.g. user-assigned MSI client_id), the key is a
33+
tuple of ``(auth_type, sorted_kwargs_items)`` so different kwargs map
34+
to different cached credentials.
35+
"""
36+
if not credential_kwargs:
37+
return auth_type
38+
return (auth_type, tuple(sorted(credential_kwargs.items())))
39+
40+
2341
class AADAuth:
2442
"""Handles Azure Active Directory authentication"""
2543

@@ -37,24 +55,26 @@ def get_token_struct(token: str) -> bytes:
3755
return struct.pack(f"<I{len(token_bytes)}s", len(token_bytes), token_bytes)
3856

3957
@staticmethod
def get_token(auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None) -> bytes:
    """Get DDBC token struct for the specified authentication type.

    Args:
        auth_type: Short auth-type name passed through to
            ``AADAuth._acquire_token``.
        credential_kwargs: Optional credential constructor kwargs, passed
            through unchanged to ``AADAuth._acquire_token``.

    Returns:
        The first element of the pair returned by
        ``AADAuth._acquire_token`` (the DDBC token struct).
    """
    # _acquire_token yields (ddbc_struct, raw_jwt); only the struct is needed here.
    ddbc_struct, _raw_jwt = AADAuth._acquire_token(auth_type, credential_kwargs)
    return ddbc_struct
4462

4563
@staticmethod
def get_raw_token(auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None) -> str:
    """Acquire a raw JWT for the mssql-py-core connection (bulk copy).

    Delegates to ``AADAuth._acquire_token``, which re-uses a cached
    credential instance; that lets the Azure Identity SDK's in-memory token
    cache return a still-valid token instead of making a new round-trip.

    Args:
        auth_type: Short auth-type name passed through to
            ``AADAuth._acquire_token``.
        credential_kwargs: Optional credential constructor kwargs, passed
            through unchanged to ``AADAuth._acquire_token``.

    Returns:
        The second element of the pair returned by
        ``AADAuth._acquire_token`` (the raw JWT string).
    """
    # _acquire_token yields (ddbc_struct, raw_jwt); only the JWT is needed here.
    _ddbc_struct, jwt = AADAuth._acquire_token(auth_type, credential_kwargs)
    return jwt
5573

5674
@staticmethod
57-
def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
75+
def _acquire_token(
76+
auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None
77+
) -> Tuple[bytes, str]:
5878
"""Internal: acquire token and return (ddbc_struct, raw_jwt)."""
5979
# Import Azure libraries inside method to support test mocking
6080
# pylint: disable=import-outside-toplevel
@@ -63,6 +83,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
6383
DefaultAzureCredential,
6484
DeviceCodeCredential,
6585
InteractiveBrowserCredential,
86+
ManagedIdentityCredential,
6687
)
6788
from azure.core.exceptions import ClientAuthenticationError
6889
except ImportError as e:
@@ -76,6 +97,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
7697
"default": DefaultAzureCredential,
7798
"devicecode": DeviceCodeCredential,
7899
"interactive": InteractiveBrowserCredential,
100+
"msi": ManagedIdentityCredential,
79101
}
80102

81103
credential_class = credential_map.get(auth_type)
@@ -89,20 +111,22 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
89111
credential_class.__name__,
90112
)
91113

114+
kwargs = credential_kwargs or {}
115+
cache_key = _credential_cache_key(auth_type, kwargs)
92116
try:
93117
with _credential_cache_lock:
94-
if auth_type not in _credential_cache:
118+
if cache_key not in _credential_cache:
95119
logger.debug(
96120
"get_token: Creating new credential instance for auth_type=%s",
97121
auth_type,
98122
)
99-
_credential_cache[auth_type] = credential_class()
123+
_credential_cache[cache_key] = credential_class(**kwargs)
100124
else:
101125
logger.debug(
102126
"get_token: Reusing cached credential instance for auth_type=%s",
103127
auth_type,
104128
)
105-
credential = _credential_cache[auth_type]
129+
credential = _credential_cache[cache_key]
106130
raw_token = credential.get_token("https://database.windows.net/.default").token
107131
logger.info(
108132
"get_token: Azure AD token acquired successfully - token_length=%d chars",
@@ -130,6 +154,28 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
130154
raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e
131155

132156

157+
def _extract_msi_client_id(connection_string: str) -> Optional[str]:
    """Return the UID from a connection string for user-assigned MSI.

    With ActiveDirectoryMSI, a UID (when present) is taken to be the
    user-assigned identity's ``client_id``. When UID is missing, empty, or
    only whitespace, ``None`` is returned, which callers treat as
    system-assigned MSI.

    The string is parsed with ``_ConnectionStringParser`` rather than by
    splitting on ``;`` and ``=``. The intent is that braced ODBC values are
    interpreted properly: ``UID={hello=world}`` should give ``hello=world``,
    and a semicolon inside another braced value such as
    ``Database={foo;uid=victim;bar}`` should not be mistaken for a
    top-level ``UID=``.

    Args:
        connection_string: The original, unsanitized connection string.

    Returns:
        The stripped UID value, or ``None`` if there is no usable UID.
    """
    # No try/except on purpose: the same string is expected to have been
    # parsed already by Connection.__init__ (via _construct_connection_string),
    # so a parse failure here points to an upstream bug. It should propagate
    # rather than silently turn user-assigned MSI into system-assigned MSI.
    parser = _ConnectionStringParser(validate_keywords=False)
    raw_uid = parser._parse(connection_string).get("uid")
    if not raw_uid:
        return None
    client_id = raw_uid.strip()
    return client_id if client_id else None
177+
178+
133179
def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]:
134180
"""
135181
Process connection parameters and extract authentication type.
@@ -180,6 +226,10 @@ def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[
180226
# Default authentication (uses DefaultAzureCredential)
181227
logger.debug("process_auth_parameters: Default Azure authentication detected")
182228
auth_type = "default"
229+
elif value_lower == AuthType.MSI.value:
230+
# Managed identity authentication (system- or user-assigned)
231+
logger.debug("process_auth_parameters: Managed identity authentication detected")
232+
auth_type = "msi"
183233
modified_parameters.append(param)
184234

185235
logger.debug(
@@ -212,7 +262,9 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]:
212262
return result
213263

214264

215-
def get_auth_token(auth_type: str) -> Optional[bytes]:
265+
def get_auth_token(
266+
auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None
267+
) -> Optional[bytes]:
216268
"""Get DDBC authentication token struct based on auth type."""
217269
logger.debug("get_auth_token: Starting - auth_type=%s", auth_type)
218270
if not auth_type:
@@ -225,7 +277,7 @@ def get_auth_token(auth_type: str) -> Optional[bytes]:
225277
return None # Let Windows handle AADInteractive natively
226278

227279
try:
228-
token = AADAuth.get_token(auth_type)
280+
token = AADAuth.get_token(auth_type, credential_kwargs)
229281
logger.info("get_auth_token: Token acquired successfully - auth_type=%s", auth_type)
230282
return token
231283
except (ValueError, RuntimeError) as e:
@@ -246,6 +298,7 @@ def extract_auth_type(connection_string: str) -> Optional[str]:
246298
AuthType.INTERACTIVE.value: "interactive",
247299
AuthType.DEVICE_CODE.value: "devicecode",
248300
AuthType.DEFAULT.value: "default",
301+
AuthType.MSI.value: "msi",
249302
}
250303
for part in connection_string.split(";"):
251304
key, _, value = part.strip().partition("=")
@@ -256,16 +309,28 @@ def extract_auth_type(connection_string: str) -> Optional[str]:
256309

257310
def process_connection_string(
258311
connection_string: str,
259-
) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str]]:
312+
) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str], Optional[Dict[str, str]]]:
260313
"""
261314
Process connection string and handle authentication.
262315
316+
NOTE: Returns a 4-tuple. Callers must unpack all four elements.
317+
Destructuring with three names raises ``ValueError: too many values
318+
to unpack``. The fourth element (``credential_kwargs``) is needed by
319+
Connection.__init__ to persist credential constructor args (e.g. the
320+
user-assigned MSI ``client_id``) for the bulkcopy fresh-token path,
321+
since UID is stripped from the sanitized connection string.
322+
263323
Args:
264324
connection_string: The connection string to process
265325
266326
Returns:
267-
Tuple[str, Optional[Dict], Optional[str]]: Processed connection string,
268-
attrs_before dict if needed, and auth_type string for bulk copy token acquisition
327+
Tuple[str, Optional[Dict], Optional[str], Optional[Dict[str, str]]]:
328+
Processed connection string, attrs_before dict if needed, auth_type
329+
string for bulk copy token acquisition, and credential constructor
330+
kwargs (e.g. user-assigned MSI ``client_id``) to be persisted on
331+
the Connection so bulkcopy can re-use them when acquiring a fresh
332+
token after sanitization has stripped UID from the connection
333+
string.
269334
270335
Raises:
271336
ValueError: If the connection string is invalid or empty
@@ -301,12 +366,33 @@ def process_connection_string(
301366

302367
modified_parameters, auth_type = process_auth_parameters(parameters)
303368

369+
# Capture credential kwargs (e.g. user-assigned MSI client_id) before
370+
# remove_sensitive_params strips UID from the parameter list. Pass the
371+
# original connection_string (not modified_parameters) so the helper can
372+
# use the canonical _ConnectionStringParser — handles braced values like
373+
# UID={hello=world} correctly.
374+
credential_kwargs: Dict[str, str] = {}
375+
if auth_type == "msi":
376+
client_id = _extract_msi_client_id(connection_string)
377+
if client_id:
378+
credential_kwargs["client_id"] = client_id
379+
logger.debug(
380+
"process_connection_string: ActiveDirectoryMSI with UID — "
381+
"user-assigned managed identity selected (client_id length=%d)",
382+
len(client_id),
383+
)
384+
else:
385+
logger.debug(
386+
"process_connection_string: ActiveDirectoryMSI without UID — "
387+
"system-assigned managed identity selected"
388+
)
389+
304390
if auth_type:
305391
logger.info(
306392
"process_connection_string: Authentication type detected - auth_type=%s", auth_type
307393
)
308394
modified_parameters = remove_sensitive_params(modified_parameters)
309-
token_struct = get_auth_token(auth_type)
395+
token_struct = get_auth_token(auth_type, credential_kwargs or None)
310396
if token_struct:
311397
logger.info(
312398
"process_connection_string: Token authentication configured successfully - auth_type=%s",
@@ -316,6 +402,7 @@ def process_connection_string(
316402
";".join(modified_parameters) + ";",
317403
{ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct},
318404
auth_type,
405+
credential_kwargs or None,
319406
)
320407
else:
321408
logger.warning(
@@ -326,4 +413,9 @@ def process_connection_string(
326413
"process_connection_string: Connection string processing complete - has_auth=%s",
327414
bool(auth_type),
328415
)
329-
return ";".join(modified_parameters) + ";", None, auth_type
416+
return (
417+
";".join(modified_parameters) + ";",
418+
None,
419+
auth_type,
420+
credential_kwargs or None,
421+
)

mssql_python/connection.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,12 @@ def __init__(
321321
# We intentionally do NOT cache the token — a fresh one is acquired
322322
# each time bulkcopy() is called to avoid expired-token errors.
323323
self._auth_type = None
324+
# Credential constructor kwargs (e.g. user-assigned MSI client_id)
325+
# captured at __init__ time before remove_sensitive_params strips UID
326+
# from self.connection_str. bulkcopy() re-uses these when acquiring a
327+
# fresh token; re-parsing self.connection_str at that point would miss
328+
# them because UID is already gone.
329+
self._credential_kwargs: Optional[Dict[str, str]] = None
324330

325331
# Check if the connection string contains authentication parameters
326332
# This is important for processing the connection string correctly.
@@ -335,6 +341,7 @@ def __init__(
335341
# On Windows Interactive, process_connection_string returns None
336342
# (DDBC handles auth natively), so fall back to the connection string.
337343
self._auth_type = connection_result[2] or extract_auth_type(self.connection_str)
344+
self._credential_kwargs = connection_result[3]
338345

339346
self._closed = False
340347
self._timeout = timeout

mssql_python/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ class AuthType(Enum):
337337
INTERACTIVE = "activedirectoryinteractive"
338338
DEVICE_CODE = "activedirectorydevicecode"
339339
DEFAULT = "activedirectorydefault"
340+
MSI = "activedirectorymsi"
340341

341342

342343
class SQLTypes:

mssql_python/cursor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2933,11 +2933,17 @@ def bulkcopy(
29332933

29342934
# Token acquisition — only thing cursor must handle (needs azure-identity SDK)
29352935
if self.connection._auth_type:
2936-
# Fresh token acquisition for mssql-py-core connection
2936+
# Fresh token acquisition for mssql-py-core connection. credential
2937+
# kwargs (e.g. user-assigned MSI client_id) were captured by
2938+
# Connection.__init__ before remove_sensitive_params stripped UID
2939+
# from connection_str — re-parsing here would miss them.
29372940
from mssql_python.auth import AADAuth
29382941

29392942
try:
2940-
raw_token = AADAuth.get_raw_token(self.connection._auth_type)
2943+
raw_token = AADAuth.get_raw_token(
2944+
self.connection._auth_type,
2945+
self.connection._credential_kwargs,
2946+
)
29412947
except (RuntimeError, ValueError) as e:
29422948
raise RuntimeError(
29432949
f"Bulk copy failed: unable to acquire Azure AD token "

0 commit comments

Comments
 (0)