Skip to content

Commit b476518

Browse files
committed
Resolving merge conficts
1 parent 46ae4ba commit b476518

4 files changed

Lines changed: 179 additions & 450 deletions

File tree

mssql_python/auth.py

Lines changed: 43 additions & 231 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
import platform
88
import struct
99
import threading
10-
from typing import Tuple, Dict, Optional, List
10+
from typing import Tuple, Dict, Optional
1111

1212
from mssql_python.logging import logger
1313
from mssql_python.constants import AuthType, ConstantsDDBC
14-
from mssql_python.connection_string_parser import _ConnectionStringParser
1514

1615
# Module-level credential instance cache.
1716
# Reusing credential objects allows the Azure Identity SDK's built-in
@@ -23,6 +22,17 @@
2322
_credential_cache: Dict[object, object] = {}
2423
_credential_cache_lock = threading.Lock()
2524

25+
# Canonical keys to strip when handing an Entra-token connection to ODBC.
26+
_SENSITIVE_KEYS = frozenset({"UID", "PWD", "Trusted_Connection", "Authentication"})
27+
28+
# Map Authentication connection-string values to internal short names.
29+
_AUTH_TYPE_MAP: Dict[str, str] = {
30+
AuthType.INTERACTIVE.value: "interactive",
31+
AuthType.DEVICE_CODE.value: "devicecode",
32+
AuthType.DEFAULT.value: "default",
33+
AuthType.MSI.value: "msi",
34+
}
35+
2636

2737
def _credential_cache_key(auth_type: str, credential_kwargs: Optional[Dict[str, str]]):
2838
"""Build a hashable cache key from auth_type and optional credential kwargs.
@@ -154,112 +164,36 @@ def _acquire_token(
154164
raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e
155165

156166

157-
def _extract_msi_client_id(connection_string: str) -> Optional[str]:
158-
"""Pull UID out of a connection string for user-assigned MSI.
159-
160-
For ActiveDirectoryMSI, UID (when present) carries the user-assigned
161-
identity's ``client_id``. Returns None for system-assigned MSI.
162-
163-
Uses the canonical ``_ConnectionStringParser`` so braced ODBC values
164-
are handled correctly: a ``UID={hello=world}`` resolves to the value
165-
``hello=world`` (no surrounding braces, no false split on the inner
166-
``=``), and a semicolon inside a legitimate braced value (e.g.
167-
``Database={foo;uid=victim;bar}``) cannot spoof a top-level ``UID=``.
167+
def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]:
168168
"""
169-
# Connection.__init__ already parsed the same string through
170-
# _ConnectionStringParser via _construct_connection_string, so by the
171-
# time we get here the input is guaranteed parseable. No defensive
172-
# try/except: a parse failure now means a real bug upstream and should
173-
# propagate, not silently degrade user-assigned MSI to system-assigned.
174-
parsed = _ConnectionStringParser(validate_keywords=False)._parse(connection_string)
175-
uid = (parsed.get("uid") or "").strip()
176-
return uid or None
169+
Extract authentication type from parsed connection parameters.
177170
178-
179-
def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]:
180-
"""
181-
Process connection parameters and extract authentication type.
171+
Returns the internal auth type string needed for token acquisition,
172+
or None when the driver should handle authentication natively
173+
(e.g. Windows Interactive).
182174
183175
Args:
184-
parameters: List of connection string parameters
176+
parsed_params: Dictionary of normalized connection parameters
185177
186178
Returns:
187-
Tuple[list, Optional[str]]: Modified parameters and authentication type
188-
189-
Raises:
190-
ValueError: If an invalid authentication type is provided
179+
Optional[str]: Authentication type string or None
191180
"""
192-
logger.debug("process_auth_parameters: Processing %d connection parameters", len(parameters))
193-
modified_parameters = []
194-
auth_type = None
195-
196-
for param in parameters:
197-
param = param.strip()
198-
if not param:
199-
continue
200-
201-
if "=" not in param:
202-
modified_parameters.append(param)
203-
continue
204-
205-
key, value = param.split("=", 1)
206-
key_lower = key.lower()
207-
value_lower = value.lower()
208-
209-
if key_lower == "authentication":
210-
# Check for supported authentication types and set auth_type accordingly
211-
if value_lower == AuthType.INTERACTIVE.value:
212-
auth_type = "interactive"
213-
logger.debug("process_auth_parameters: Interactive authentication detected")
214-
# Interactive authentication (browser-based); only append parameter for non-Windows
215-
if platform.system().lower() == "windows":
216-
logger.debug(
217-
"process_auth_parameters: Windows platform - using native AADInteractive"
218-
)
219-
auth_type = None # Let Windows handle AADInteractive natively
220-
221-
elif value_lower == AuthType.DEVICE_CODE.value:
222-
# Device code authentication (for devices without browser)
223-
logger.debug("process_auth_parameters: Device code authentication detected")
224-
auth_type = "devicecode"
225-
elif value_lower == AuthType.DEFAULT.value:
226-
# Default authentication (uses DefaultAzureCredential)
227-
logger.debug("process_auth_parameters: Default Azure authentication detected")
228-
auth_type = "default"
229-
elif value_lower == AuthType.MSI.value:
230-
# Managed identity authentication (system- or user-assigned)
231-
logger.debug("process_auth_parameters: Managed identity authentication detected")
232-
auth_type = "msi"
233-
modified_parameters.append(param)
234-
235-
logger.debug(
236-
"process_auth_parameters: Processing complete - auth_type=%s, param_count=%d",
237-
auth_type,
238-
len(modified_parameters),
239-
)
240-
return modified_parameters, auth_type
241-
242-
243-
def remove_sensitive_params(parameters: List[str]) -> List[str]:
244-
"""Remove sensitive parameters from connection string"""
245-
logger.debug(
246-
"remove_sensitive_params: Removing sensitive parameters - input_count=%d", len(parameters)
247-
)
248-
exclude_keys = [
249-
"uid=",
250-
"pwd=",
251-
"trusted_connection=",
252-
"authentication=",
253-
]
254-
result = [
255-
param
256-
for param in parameters
257-
if not any(param.lower().startswith(exclude) for exclude in exclude_keys)
258-
]
259-
logger.debug(
260-
"remove_sensitive_params: Sensitive parameters removed - output_count=%d", len(result)
261-
)
262-
return result
181+
auth_type = extract_auth_type(parsed_params)
182+
if not auth_type:
183+
return None
184+
185+
# On Windows, Interactive auth is handled natively by the ODBC driver.
186+
if auth_type == "interactive" and platform.system().lower() == "windows":
187+
logger.debug("process_auth_parameters: Windows platform - using native AADInteractive")
188+
return None
189+
190+
logger.debug("process_auth_parameters: auth_type=%s", auth_type)
191+
return auth_type
192+
193+
194+
def remove_sensitive_params(parsed_params: Dict[str, str]) -> Dict[str, str]:
195+
"""Return a copy of *parsed_params* without credentials / auth keys."""
196+
return {k: v for k, v in parsed_params.items() if k not in _SENSITIVE_KEYS}
263197

264198

265199
def get_auth_token(
@@ -287,135 +221,13 @@ def get_auth_token(
287221
return None
288222

289223

290-
def extract_auth_type(connection_string: str) -> Optional[str]:
291-
"""Extract Entra ID auth type from a connection string.
292-
293-
Used as a fallback when process_connection_string does not propagate
294-
auth_type (e.g. Windows Interactive where DDBC handles auth natively).
295-
Bulkcopy still needs the auth type to acquire a token via Azure Identity.
296-
"""
297-
auth_map = {
298-
AuthType.INTERACTIVE.value: "interactive",
299-
AuthType.DEVICE_CODE.value: "devicecode",
300-
AuthType.DEFAULT.value: "default",
301-
AuthType.MSI.value: "msi",
302-
}
303-
for part in connection_string.split(";"):
304-
key, _, value = part.strip().partition("=")
305-
if key.strip().lower() == "authentication":
306-
return auth_map.get(value.strip().lower())
307-
return None
308-
309-
310-
def process_connection_string(
311-
connection_string: str,
312-
) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str], Optional[Dict[str, str]]]:
313-
"""
314-
Process connection string and handle authentication.
315-
316-
NOTE: Returns a 4-tuple. Callers must unpack all four elements.
317-
Destructuring with three names raises ``ValueError: too many values
318-
to unpack``. The fourth element (``credential_kwargs``) is needed by
319-
Connection.__init__ to persist credential constructor args (e.g. the
320-
user-assigned MSI ``client_id``) for the bulkcopy fresh-token path,
321-
since UID is stripped from the sanitized connection string.
322-
323-
Args:
324-
connection_string: The connection string to process
224+
def extract_auth_type(parsed_params: Dict[str, str]) -> Optional[str]:
225+
"""Map the Authentication connection-string value to an internal type name.
325226
326-
Returns:
327-
Tuple[str, Optional[Dict], Optional[str], Optional[Dict[str, str]]]:
328-
Processed connection string, attrs_before dict if needed, auth_type
329-
string for bulk copy token acquisition, and credential constructor
330-
kwargs (e.g. user-assigned MSI ``client_id``) to be persisted on
331-
the Connection so bulkcopy can re-use them when acquiring a fresh
332-
token after sanitization has stripped UID from the connection
333-
string.
334-
335-
Raises:
336-
ValueError: If the connection string is invalid or empty
227+
Returns ``"interactive"``, ``"devicecode"``, ``"default"``, ``"msi"``,
228+
or *None* for unrecognised / absent values. This is a pure mapping with
229+
no platform checks — use :func:`process_auth_parameters` when you need
230+
the Windows-Interactive suppression logic.
337231
"""
338-
logger.debug(
339-
"process_connection_string: Starting - conn_str_length=%d",
340-
len(connection_string) if isinstance(connection_string, str) else 0,
341-
)
342-
# Check type first
343-
if not isinstance(connection_string, str):
344-
logger.error(
345-
"process_connection_string: Invalid type - expected str, got %s",
346-
type(connection_string).__name__,
347-
)
348-
raise ValueError("Connection string must be a string")
349-
350-
# Then check if empty
351-
if not connection_string:
352-
logger.error("process_connection_string: Connection string is empty")
353-
raise ValueError("Connection string cannot be empty")
354-
355-
parameters = connection_string.split(";")
356-
logger.debug(
357-
"process_connection_string: Split connection string - parameter_count=%d", len(parameters)
358-
)
359-
360-
# Validate that there's at least one valid parameter
361-
if not any("=" in param for param in parameters):
362-
logger.error(
363-
"process_connection_string: Invalid connection string format - no key=value pairs found"
364-
)
365-
raise ValueError("Invalid connection string format")
366-
367-
modified_parameters, auth_type = process_auth_parameters(parameters)
368-
369-
# Capture credential kwargs (e.g. user-assigned MSI client_id) before
370-
# remove_sensitive_params strips UID from the parameter list. Pass the
371-
# original connection_string (not modified_parameters) so the helper can
372-
# use the canonical _ConnectionStringParser — handles braced values like
373-
# UID={hello=world} correctly.
374-
credential_kwargs: Dict[str, str] = {}
375-
if auth_type == "msi":
376-
client_id = _extract_msi_client_id(connection_string)
377-
if client_id:
378-
credential_kwargs["client_id"] = client_id
379-
logger.debug(
380-
"process_connection_string: ActiveDirectoryMSI with UID — "
381-
"user-assigned managed identity selected (client_id length=%d)",
382-
len(client_id),
383-
)
384-
else:
385-
logger.debug(
386-
"process_connection_string: ActiveDirectoryMSI without UID — "
387-
"system-assigned managed identity selected"
388-
)
389-
390-
if auth_type:
391-
logger.info(
392-
"process_connection_string: Authentication type detected - auth_type=%s", auth_type
393-
)
394-
modified_parameters = remove_sensitive_params(modified_parameters)
395-
token_struct = get_auth_token(auth_type, credential_kwargs or None)
396-
if token_struct:
397-
logger.info(
398-
"process_connection_string: Token authentication configured successfully - auth_type=%s",
399-
auth_type,
400-
)
401-
return (
402-
";".join(modified_parameters) + ";",
403-
{ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct},
404-
auth_type,
405-
credential_kwargs or None,
406-
)
407-
else:
408-
logger.warning(
409-
"process_connection_string: Token acquisition failed, proceeding without token"
410-
)
411-
412-
logger.debug(
413-
"process_connection_string: Connection string processing complete - has_auth=%s",
414-
bool(auth_type),
415-
)
416-
return (
417-
";".join(modified_parameters) + ";",
418-
None,
419-
auth_type,
420-
credential_kwargs or None,
421-
)
232+
auth_value = parsed_params.get("Authentication", "").strip().lower()
233+
return _AUTH_TYPE_MAP.get(auth_value)

0 commit comments

Comments
 (0)