|
7 | 7 | import platform |
8 | 8 | import struct |
9 | 9 | import threading |
10 | | -from typing import Tuple, Dict, Optional, List |
| 10 | +from typing import Tuple, Dict, Optional |
11 | 11 |
|
12 | 12 | from mssql_python.logging import logger |
13 | 13 | from mssql_python.constants import AuthType, ConstantsDDBC |
14 | | -from mssql_python.connection_string_parser import _ConnectionStringParser |
15 | 14 |
|
16 | 15 | # Module-level credential instance cache. |
17 | 16 | # Reusing credential objects allows the Azure Identity SDK's built-in |
|
23 | 22 | _credential_cache: Dict[object, object] = {} |
24 | 23 | _credential_cache_lock = threading.Lock() |
25 | 24 |
|
| 25 | +# Canonical keys to strip when handing an Entra-token connection to ODBC. |
| 26 | +_SENSITIVE_KEYS = frozenset({"UID", "PWD", "Trusted_Connection", "Authentication"}) |
| 27 | + |
| 28 | +# Map Authentication connection-string values to internal short names. |
| 29 | +_AUTH_TYPE_MAP: Dict[str, str] = { |
| 30 | + AuthType.INTERACTIVE.value: "interactive", |
| 31 | + AuthType.DEVICE_CODE.value: "devicecode", |
| 32 | + AuthType.DEFAULT.value: "default", |
| 33 | + AuthType.MSI.value: "msi", |
| 34 | +} |
| 35 | + |
26 | 36 |
|
27 | 37 | def _credential_cache_key(auth_type: str, credential_kwargs: Optional[Dict[str, str]]): |
28 | 38 | """Build a hashable cache key from auth_type and optional credential kwargs. |
@@ -154,112 +164,36 @@ def _acquire_token( |
154 | 164 | raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e |
155 | 165 |
|
156 | 166 |
|
157 | | -def _extract_msi_client_id(connection_string: str) -> Optional[str]: |
158 | | - """Pull UID out of a connection string for user-assigned MSI. |
159 | | -
|
160 | | - For ActiveDirectoryMSI, UID (when present) carries the user-assigned |
161 | | - identity's ``client_id``. Returns None for system-assigned MSI. |
162 | | -
|
163 | | - Uses the canonical ``_ConnectionStringParser`` so braced ODBC values |
164 | | - are handled correctly: a ``UID={hello=world}`` resolves to the value |
165 | | - ``hello=world`` (no surrounding braces, no false split on the inner |
166 | | - ``=``), and a semicolon inside a legitimate braced value (e.g. |
167 | | - ``Database={foo;uid=victim;bar}``) cannot spoof a top-level ``UID=``. |
| 167 | +def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]: |
168 | 168 | """ |
169 | | - # Connection.__init__ already parsed the same string through |
170 | | - # _ConnectionStringParser via _construct_connection_string, so by the |
171 | | - # time we get here the input is guaranteed parseable. No defensive |
172 | | - # try/except: a parse failure now means a real bug upstream and should |
173 | | - # propagate, not silently degrade user-assigned MSI to system-assigned. |
174 | | - parsed = _ConnectionStringParser(validate_keywords=False)._parse(connection_string) |
175 | | - uid = (parsed.get("uid") or "").strip() |
176 | | - return uid or None |
| 169 | + Extract authentication type from parsed connection parameters. |
177 | 170 |
|
178 | | - |
179 | | -def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]: |
180 | | - """ |
181 | | - Process connection parameters and extract authentication type. |
| 171 | + Returns the internal auth type string needed for token acquisition, |
| 172 | + or None when the driver should handle authentication natively |
| 173 | + (e.g. Windows Interactive). |
182 | 174 |
|
183 | 175 | Args: |
184 | | - parameters: List of connection string parameters |
| 176 | + parsed_params: Dictionary of normalized connection parameters |
185 | 177 |
|
186 | 178 | Returns: |
187 | | - Tuple[list, Optional[str]]: Modified parameters and authentication type |
188 | | -
|
189 | | - Raises: |
190 | | - ValueError: If an invalid authentication type is provided |
| 179 | + Optional[str]: Authentication type string or None |
191 | 180 | """ |
192 | | - logger.debug("process_auth_parameters: Processing %d connection parameters", len(parameters)) |
193 | | - modified_parameters = [] |
194 | | - auth_type = None |
195 | | - |
196 | | - for param in parameters: |
197 | | - param = param.strip() |
198 | | - if not param: |
199 | | - continue |
200 | | - |
201 | | - if "=" not in param: |
202 | | - modified_parameters.append(param) |
203 | | - continue |
204 | | - |
205 | | - key, value = param.split("=", 1) |
206 | | - key_lower = key.lower() |
207 | | - value_lower = value.lower() |
208 | | - |
209 | | - if key_lower == "authentication": |
210 | | - # Check for supported authentication types and set auth_type accordingly |
211 | | - if value_lower == AuthType.INTERACTIVE.value: |
212 | | - auth_type = "interactive" |
213 | | - logger.debug("process_auth_parameters: Interactive authentication detected") |
214 | | - # Interactive authentication (browser-based); only append parameter for non-Windows |
215 | | - if platform.system().lower() == "windows": |
216 | | - logger.debug( |
217 | | - "process_auth_parameters: Windows platform - using native AADInteractive" |
218 | | - ) |
219 | | - auth_type = None # Let Windows handle AADInteractive natively |
220 | | - |
221 | | - elif value_lower == AuthType.DEVICE_CODE.value: |
222 | | - # Device code authentication (for devices without browser) |
223 | | - logger.debug("process_auth_parameters: Device code authentication detected") |
224 | | - auth_type = "devicecode" |
225 | | - elif value_lower == AuthType.DEFAULT.value: |
226 | | - # Default authentication (uses DefaultAzureCredential) |
227 | | - logger.debug("process_auth_parameters: Default Azure authentication detected") |
228 | | - auth_type = "default" |
229 | | - elif value_lower == AuthType.MSI.value: |
230 | | - # Managed identity authentication (system- or user-assigned) |
231 | | - logger.debug("process_auth_parameters: Managed identity authentication detected") |
232 | | - auth_type = "msi" |
233 | | - modified_parameters.append(param) |
234 | | - |
235 | | - logger.debug( |
236 | | - "process_auth_parameters: Processing complete - auth_type=%s, param_count=%d", |
237 | | - auth_type, |
238 | | - len(modified_parameters), |
239 | | - ) |
240 | | - return modified_parameters, auth_type |
241 | | - |
242 | | - |
243 | | -def remove_sensitive_params(parameters: List[str]) -> List[str]: |
244 | | - """Remove sensitive parameters from connection string""" |
245 | | - logger.debug( |
246 | | - "remove_sensitive_params: Removing sensitive parameters - input_count=%d", len(parameters) |
247 | | - ) |
248 | | - exclude_keys = [ |
249 | | - "uid=", |
250 | | - "pwd=", |
251 | | - "trusted_connection=", |
252 | | - "authentication=", |
253 | | - ] |
254 | | - result = [ |
255 | | - param |
256 | | - for param in parameters |
257 | | - if not any(param.lower().startswith(exclude) for exclude in exclude_keys) |
258 | | - ] |
259 | | - logger.debug( |
260 | | - "remove_sensitive_params: Sensitive parameters removed - output_count=%d", len(result) |
261 | | - ) |
262 | | - return result |
| 181 | + auth_type = extract_auth_type(parsed_params) |
| 182 | + if not auth_type: |
| 183 | + return None |
| 184 | + |
| 185 | + # On Windows, Interactive auth is handled natively by the ODBC driver. |
| 186 | + if auth_type == "interactive" and platform.system().lower() == "windows": |
| 187 | + logger.debug("process_auth_parameters: Windows platform - using native AADInteractive") |
| 188 | + return None |
| 189 | + |
| 190 | + logger.debug("process_auth_parameters: auth_type=%s", auth_type) |
| 191 | + return auth_type |
| 192 | + |
| 193 | + |
| 194 | +def remove_sensitive_params(parsed_params: Dict[str, str]) -> Dict[str, str]: |
| 195 | + """Return a copy of *parsed_params* without credentials / auth keys.""" |
| 196 | + return {k: v for k, v in parsed_params.items() if k not in _SENSITIVE_KEYS} |
263 | 197 |
|
264 | 198 |
|
265 | 199 | def get_auth_token( |
@@ -287,135 +221,13 @@ def get_auth_token( |
287 | 221 | return None |
288 | 222 |
|
289 | 223 |
|
290 | | -def extract_auth_type(connection_string: str) -> Optional[str]: |
291 | | - """Extract Entra ID auth type from a connection string. |
292 | | -
|
293 | | - Used as a fallback when process_connection_string does not propagate |
294 | | - auth_type (e.g. Windows Interactive where DDBC handles auth natively). |
295 | | - Bulkcopy still needs the auth type to acquire a token via Azure Identity. |
296 | | - """ |
297 | | - auth_map = { |
298 | | - AuthType.INTERACTIVE.value: "interactive", |
299 | | - AuthType.DEVICE_CODE.value: "devicecode", |
300 | | - AuthType.DEFAULT.value: "default", |
301 | | - AuthType.MSI.value: "msi", |
302 | | - } |
303 | | - for part in connection_string.split(";"): |
304 | | - key, _, value = part.strip().partition("=") |
305 | | - if key.strip().lower() == "authentication": |
306 | | - return auth_map.get(value.strip().lower()) |
307 | | - return None |
308 | | - |
309 | | - |
310 | | -def process_connection_string( |
311 | | - connection_string: str, |
312 | | -) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str], Optional[Dict[str, str]]]: |
313 | | - """ |
314 | | - Process connection string and handle authentication. |
315 | | -
|
316 | | - NOTE: Returns a 4-tuple. Callers must unpack all four elements. |
317 | | - Destructuring with three names raises ``ValueError: too many values |
318 | | - to unpack``. The fourth element (``credential_kwargs``) is needed by |
319 | | - Connection.__init__ to persist credential constructor args (e.g. the |
320 | | - user-assigned MSI ``client_id``) for the bulkcopy fresh-token path, |
321 | | - since UID is stripped from the sanitized connection string. |
322 | | -
|
323 | | - Args: |
324 | | - connection_string: The connection string to process |
| 224 | +def extract_auth_type(parsed_params: Dict[str, str]) -> Optional[str]: |
| 225 | + """Map the Authentication connection-string value to an internal type name. |
325 | 226 |
|
326 | | - Returns: |
327 | | - Tuple[str, Optional[Dict], Optional[str], Optional[Dict[str, str]]]: |
328 | | - Processed connection string, attrs_before dict if needed, auth_type |
329 | | - string for bulk copy token acquisition, and credential constructor |
330 | | - kwargs (e.g. user-assigned MSI ``client_id``) to be persisted on |
331 | | - the Connection so bulkcopy can re-use them when acquiring a fresh |
332 | | - token after sanitization has stripped UID from the connection |
333 | | - string. |
334 | | -
|
335 | | - Raises: |
336 | | - ValueError: If the connection string is invalid or empty |
| 227 | + Returns ``"interactive"``, ``"devicecode"``, ``"default"``, ``"msi"``, |
| 228 | + or *None* for unrecognised / absent values. This is a pure mapping with |
| 229 | + no platform checks — use :func:`process_auth_parameters` when you need |
| 230 | + the Windows-Interactive suppression logic. |
337 | 231 | """ |
338 | | - logger.debug( |
339 | | - "process_connection_string: Starting - conn_str_length=%d", |
340 | | - len(connection_string) if isinstance(connection_string, str) else 0, |
341 | | - ) |
342 | | - # Check type first |
343 | | - if not isinstance(connection_string, str): |
344 | | - logger.error( |
345 | | - "process_connection_string: Invalid type - expected str, got %s", |
346 | | - type(connection_string).__name__, |
347 | | - ) |
348 | | - raise ValueError("Connection string must be a string") |
349 | | - |
350 | | - # Then check if empty |
351 | | - if not connection_string: |
352 | | - logger.error("process_connection_string: Connection string is empty") |
353 | | - raise ValueError("Connection string cannot be empty") |
354 | | - |
355 | | - parameters = connection_string.split(";") |
356 | | - logger.debug( |
357 | | - "process_connection_string: Split connection string - parameter_count=%d", len(parameters) |
358 | | - ) |
359 | | - |
360 | | - # Validate that there's at least one valid parameter |
361 | | - if not any("=" in param for param in parameters): |
362 | | - logger.error( |
363 | | - "process_connection_string: Invalid connection string format - no key=value pairs found" |
364 | | - ) |
365 | | - raise ValueError("Invalid connection string format") |
366 | | - |
367 | | - modified_parameters, auth_type = process_auth_parameters(parameters) |
368 | | - |
369 | | - # Capture credential kwargs (e.g. user-assigned MSI client_id) before |
370 | | - # remove_sensitive_params strips UID from the parameter list. Pass the |
371 | | - # original connection_string (not modified_parameters) so the helper can |
372 | | - # use the canonical _ConnectionStringParser — handles braced values like |
373 | | - # UID={hello=world} correctly. |
374 | | - credential_kwargs: Dict[str, str] = {} |
375 | | - if auth_type == "msi": |
376 | | - client_id = _extract_msi_client_id(connection_string) |
377 | | - if client_id: |
378 | | - credential_kwargs["client_id"] = client_id |
379 | | - logger.debug( |
380 | | - "process_connection_string: ActiveDirectoryMSI with UID — " |
381 | | - "user-assigned managed identity selected (client_id length=%d)", |
382 | | - len(client_id), |
383 | | - ) |
384 | | - else: |
385 | | - logger.debug( |
386 | | - "process_connection_string: ActiveDirectoryMSI without UID — " |
387 | | - "system-assigned managed identity selected" |
388 | | - ) |
389 | | - |
390 | | - if auth_type: |
391 | | - logger.info( |
392 | | - "process_connection_string: Authentication type detected - auth_type=%s", auth_type |
393 | | - ) |
394 | | - modified_parameters = remove_sensitive_params(modified_parameters) |
395 | | - token_struct = get_auth_token(auth_type, credential_kwargs or None) |
396 | | - if token_struct: |
397 | | - logger.info( |
398 | | - "process_connection_string: Token authentication configured successfully - auth_type=%s", |
399 | | - auth_type, |
400 | | - ) |
401 | | - return ( |
402 | | - ";".join(modified_parameters) + ";", |
403 | | - {ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct}, |
404 | | - auth_type, |
405 | | - credential_kwargs or None, |
406 | | - ) |
407 | | - else: |
408 | | - logger.warning( |
409 | | - "process_connection_string: Token acquisition failed, proceeding without token" |
410 | | - ) |
411 | | - |
412 | | - logger.debug( |
413 | | - "process_connection_string: Connection string processing complete - has_auth=%s", |
414 | | - bool(auth_type), |
415 | | - ) |
416 | | - return ( |
417 | | - ";".join(modified_parameters) + ";", |
418 | | - None, |
419 | | - auth_type, |
420 | | - credential_kwargs or None, |
421 | | - ) |
| 232 | + auth_value = parsed_params.get("Authentication", "").strip().lower() |
| 233 | + return _AUTH_TYPE_MAP.get(auth_value) |
0 commit comments