Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Added
- New feature: Support for macOS and Linux.
- Documentation: Added API documentation in the Wiki.
- Bulk copy now supports `Authentication=ActiveDirectoryServicePrincipal`
via an `entra_id_token_factory` callback registered on the mssql-py-core
connection. The callback is invoked by mssql-tds mid-handshake (FedAuth
workflow 0x02) so the tenant id can be resolved from the server-supplied
STS URL. Requires `mssql-py-core` 0.1.5+. Partial fix for #534.

### Changed
- Improved error handling in the connection module.
Expand Down
157 changes: 150 additions & 7 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
This module handles authentication for the mssql_python package.
"""

import hashlib
import platform
import struct
import threading
Expand Down Expand Up @@ -154,6 +155,143 @@ def _acquire_token(
raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e


def _parse_tenant_id(sts_url: str) -> Optional[str]:
"""Extract tenant ID (GUID or domain) from a FedAuthInfo STS URL.

Expected formats:
https://login.microsoftonline.com/<tenant>/
https://login.microsoftonline.com/<tenant>/?...
https://login.microsoftonline.com/<tenant>
where <tenant> is either a GUID (e.g. ``aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee``)
or a verified domain (e.g. ``contoso.onmicrosoft.com``). Both forms are
accepted by ``azure.identity.ClientSecretCredential``.
"""
# pylint: disable=import-outside-toplevel
from urllib.parse import urlparse

try:
parsed = urlparse(sts_url)
except (ValueError, AttributeError):
return None
# Reject anything that isn't an https URL with a netloc. ``urlparse`` will
# happily put a bare string like ``"tenant-guid"`` into ``path``, which
# would then look like a valid tenant. Azure AD STS URLs are always https.
if parsed.scheme != "https" or not parsed.netloc:
return None
path = (parsed.path or "").strip("/")
if not path:
return None
first_segment = path.split("/", 1)[0]
return first_segment or None
Comment on lines +172 to +185


class ServicePrincipalAuth:
"""Builds an ``entra_id_token_factory`` callable for ActiveDirectoryServicePrincipal.

The bulkcopy path through mssql-py-core uses callback-based token
acquisition (FedAuth workflow ``0x02``) because tenant_id is only known
from the STS URL that the server returns during the TDS handshake.
"""

@staticmethod
def make_token_factory(client_id: str, client_secret: str):
"""Return a callable suitable for ``entra_id_token_factory``.

Signature: ``(spn: str, sts_url: str, auth_method: str) -> bytes``.
Returns the JWT encoded as UTF-16LE bytes (the TDS FedAuth wire format).

``ClientSecretCredential`` instances are reused across calls via the
module-level ``_credential_cache``, keyed by
``("serviceprincipal", tenant_id, client_id)`` so that azure-identity's
in-memory token cache (which is per-credential-instance) actually
works across handshake retries, reconnects, and separate bulkcopy
invocations using the same identity.
"""
if not client_id:
raise ValueError("ServicePrincipal auth requires a non-empty client_id (UID)")
if not client_secret:
raise ValueError("ServicePrincipal auth requires a non-empty client_secret (PWD)")

def _factory(spn: str, sts_url: str, auth_method: str) -> bytes:
# pylint: disable=import-outside-toplevel,unused-argument
try:
from azure.identity import ClientSecretCredential
from azure.core.exceptions import ClientAuthenticationError
except ImportError as e:
raise RuntimeError(
"Azure authentication libraries are not installed. "
"Please install with: pip install azure-identity azure-core"
) from e

if not spn:
raise RuntimeError(
"ServicePrincipal token factory: empty SPN from server "
"(cannot construct token scope)"
)
tenant_id = _parse_tenant_id(sts_url)
if not tenant_id:
raise RuntimeError(f"Could not extract tenant_id from STS URL: {sts_url!r}")

logger.info(
"ServicePrincipal token factory: acquiring token for tenant=%s, spn=%s",
tenant_id,
spn,
)
try:
# Reuse the shared credential cache (introduced for MSI in PR #573)
# so SP credentials get the same per-instance token reuse semantics
# as the other AD methods.
#
# The cache key includes a hash of client_secret so a rotated
# secret produces a different cache entry. Without this, an
# external secret rotation would not invalidate the cached
# ClientSecretCredential: azure-identity's internal token cache
# would keep returning the previously-issued token (good for
# up to ~1 hour) until expiry, masking the rotation. Hashing
# avoids storing the raw secret in the dict key.
secret_hash = hashlib.sha256(client_secret.encode("utf-8")).hexdigest()
cache_key = _credential_cache_key(
"serviceprincipal",
{
"tenant_id": tenant_id,
"client_id": client_id,
"secret_hash": secret_hash,
},
)
with _credential_cache_lock:
credential = _credential_cache.get(cache_key)
if credential is None:
credential = ClientSecretCredential(
tenant_id=tenant_id,
client_id=client_id,
client_secret=client_secret,
)
_credential_cache[cache_key] = credential
# mssql-tds passes the resource SPN; azure-identity wants a scope.
scope = spn if spn.endswith("/.default") else spn.rstrip("/") + "/.default"
token = credential.get_token(scope).token
logger.info(
"ServicePrincipal token factory: token acquired, length=%d chars",
len(token),
)
return token.encode("utf-16-le")
except ClientAuthenticationError as e:
# Keep the detailed provider error in debug logs only. The
# surfaced message is intentionally generic so that any
# secret-bearing provider text never reaches the user-facing
# exception chain.
logger.error(
"ServicePrincipal authentication failed: tenant=%s, error=%s",
tenant_id,
str(e),
)
raise RuntimeError(
"ServicePrincipal authentication failed; " "see debug logs for provider details"
) from None

return _factory


def _extract_msi_client_id(connection_string: str) -> Optional[str]:
"""Pull UID out of a connection string for user-assigned MSI.

Expand Down Expand Up @@ -230,6 +368,17 @@ def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[
# Managed identity authentication (system- or user-assigned)
logger.debug("process_auth_parameters: Managed identity authentication detected")
auth_type = "msi"
elif value_lower == AuthType.SERVICE_PRINCIPAL.value:
# ServicePrincipal authentication. ODBC (msodbcsql 17.3+)
# handles this natively for regular queries, so leave
# auth_type=None to let ODBC own the query path.
# Bulkcopy still needs the auth type — extract_auth_type()
# propagates it as "serviceprincipal" so the bulkcopy path
# can register an entra_id_token_factory callback (Model B,
# required because tenant_id is only known from the STS URL
# that the server returns during the FedAuth handshake).
logger.debug("process_auth_parameters: Service principal authentication detected")
auth_type = None
modified_parameters.append(param)

logger.debug(
Expand Down Expand Up @@ -299,6 +448,7 @@ def extract_auth_type(connection_string: str) -> Optional[str]:
AuthType.DEVICE_CODE.value: "devicecode",
AuthType.DEFAULT.value: "default",
AuthType.MSI.value: "msi",
AuthType.SERVICE_PRINCIPAL.value: "serviceprincipal",
}
for part in connection_string.split(";"):
key, _, value = part.strip().partition("=")
Expand All @@ -313,13 +463,6 @@ def process_connection_string(
"""
Process connection string and handle authentication.

NOTE: Returns a 4-tuple. Callers must unpack all four elements.
Destructuring with three names raises ``ValueError: too many values
to unpack``. The fourth element (``credential_kwargs``) is needed by
Connection.__init__ to persist credential constructor args (e.g. the
user-assigned MSI ``client_id``) for the bulkcopy fresh-token path,
since UID is stripped from the sanitized connection string.

Args:
connection_string: The connection string to process

Expand Down
1 change: 1 addition & 0 deletions mssql_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ class AuthType(Enum):
DEVICE_CODE = "activedirectorydevicecode"
DEFAULT = "activedirectorydefault"
MSI = "activedirectorymsi"
SERVICE_PRINCIPAL = "activedirectoryserviceprincipal"


class SQLTypes:
Expand Down
87 changes: 62 additions & 25 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2933,31 +2933,60 @@ def bulkcopy(

# Token acquisition — only thing cursor must handle (needs azure-identity SDK)
if self.connection._auth_type:
# Fresh token acquisition for mssql-py-core connection. credential
# kwargs (e.g. user-assigned MSI client_id) were captured by
# Connection.__init__ before remove_sensitive_params stripped UID
# from connection_str — re-parsing here would miss them.
from mssql_python.auth import AADAuth

try:
raw_token = AADAuth.get_raw_token(
# Fresh token acquisition for mssql-py-core connection
from mssql_python.auth import AADAuth, ServicePrincipalAuth

if self.connection._auth_type == "serviceprincipal":
# Model B: callback-based. tenant_id is only known from the
# STS URL the server returns mid-handshake, so we register a
# factory that py-core invokes during FedAuth (workflow 0x02).
client_id = params.get("uid", "")
client_secret = params.get("pwd", "")
if not client_id or not client_secret:
raise RuntimeError(
"Bulk copy with Authentication=ActiveDirectoryServicePrincipal "
"requires UID (client_id) and PWD (client_secret) in the "
"connection string."
)
try:
factory = ServicePrincipalAuth.make_token_factory(client_id, client_secret)
except (RuntimeError, ValueError) as e:
raise RuntimeError(
f"Bulk copy failed: unable to build ServicePrincipal token factory: {e}"
) from e
pycore_context["entra_id_token_factory"] = factory
# Keep authentication/user_name/password in pycore_context —
# py-core's auth validator + transformer need them to resolve
# the auth method to ActiveDirectoryServicePrincipal before
# the factory is dispatched at handshake time.
Comment on lines +2951 to +2961
Comment on lines +2957 to +2961
logger.debug("Bulk copy: registered ServicePrincipal token factory")
else:
# Model A: pre-acquired token. Used for Default, DeviceCode,
# Interactive (non-Windows), MSI (system- or user-assigned),
# and any other AD method whose tenant_id is discoverable
# client-side via Azure Identity SDK. credential kwargs
# (e.g. user-assigned MSI client_id) were captured by
# Connection.__init__ before remove_sensitive_params stripped
# UID from connection_str — re-parsing here would miss them.
try:
raw_token = AADAuth.get_raw_token(
self.connection._auth_type,
self.connection._credential_kwargs,
)
except (RuntimeError, ValueError) as e:
raise RuntimeError(
f"Bulk copy failed: unable to acquire Azure AD token "
f"for auth_type '{self.connection._auth_type}': {e}"
) from e
pycore_context["access_token"] = raw_token
# Token replaces credential fields — py-core's validator rejects
# access_token combined with authentication/user_name/password.
for key in ("authentication", "user_name", "password"):
pycore_context.pop(key, None)
logger.debug(
"Bulk copy: acquired fresh Azure AD token for auth_type=%s",
self.connection._auth_type,
self.connection._credential_kwargs,
)
except (RuntimeError, ValueError) as e:
raise RuntimeError(
f"Bulk copy failed: unable to acquire Azure AD token "
f"for auth_type '{self.connection._auth_type}': {e}"
) from e
pycore_context["access_token"] = raw_token
# Token replaces credential fields — py-core's validator rejects
# access_token combined with authentication/user_name/password.
for key in ("authentication", "user_name", "password"):
pycore_context.pop(key, None)
logger.debug(
"Bulk copy: acquired fresh Azure AD token for auth_type=%s",
self.connection._auth_type,
)

pycore_connection = None
pycore_cursor = None
Expand Down Expand Up @@ -3007,9 +3036,17 @@ def bulkcopy(
raise type(e)(str(e)) from None

finally:
# Clear sensitive data to minimize memory exposure
# Clear sensitive data to minimize memory exposure. The
# entra_id_token_factory closure captures client_secret, so drop
# our dict reference to it (Rust still holds an Arc until the
# connection is dropped, but at least we don't keep an extra ref).
if pycore_context:
for key in ("password", "user_name", "access_token"):
for key in (
"password",
"user_name",
"access_token",
"entra_id_token_factory",
):
pycore_context.pop(key, None)
# Clean up bulk copy resources
for resource in (pycore_cursor, pycore_connection):
Expand Down
Loading
Loading