Skip to content

Commit dfac273

Browse files
style: apply black 26.5.1 reformat
Mechanical commit, no semantic changes. black 26.x changed its line-wrapping rules for several common patterns (trailing-comma handling in long function calls, kwarg layout in multi-line calls). Applied across 13 src files. Run locally with `poetry run black src` against black 26.5.1. Isolated from the dep-bump commit so future bisect/blame can skip past the reformat. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent 7764dd5 commit dfac273

13 files changed

Lines changed: 69 additions & 57 deletions

File tree

src/databricks/sql/auth/auth.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs)
102102
# TODO : unify all the auth mechanisms with the Python SDK
103103

104104
auth_type = kwargs.get("auth_type")
105-
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
105+
client_id, redirect_port_range = get_client_id_and_redirect_port(
106106
auth_type == AuthType.AZURE_OAUTH.value
107107
)
108108

@@ -124,9 +124,11 @@ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs)
124124
azure_client_secret=kwargs.get("azure_client_secret"),
125125
azure_tenant_id=kwargs.get("azure_tenant_id"),
126126
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
127-
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
128-
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
129-
else redirect_port_range,
127+
oauth_redirect_port_range=(
128+
[kwargs["oauth_redirect_port"]]
129+
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
130+
else redirect_port_range
131+
),
130132
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
131133
credentials_provider=kwargs.get("credentials_provider"),
132134
identity_federation_client_id=kwargs.get("identity_federation_client_id"),

src/databricks/sql/auth/authenticators.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,10 @@ class CredentialsProvider(abc.ABC):
3333
for authenticating requests to Databricks REST APIs"""
3434

3535
@abc.abstractmethod
36-
def auth_type(self) -> str:
37-
...
36+
def auth_type(self) -> str: ...
3837

3938
@abc.abstractmethod
40-
def __call__(self, *args, **kwargs) -> HeaderFactory:
41-
...
39+
def __call__(self, *args, **kwargs) -> HeaderFactory: ...
4240

4341

4442
# Private API: this is an evolving interface and it will change in the future.
@@ -109,7 +107,7 @@ def _initial_get_token(self):
109107
if self._access_token and self._refresh_token:
110108
self._update_token_if_expired()
111109
else:
112-
(access_token, refresh_token) = self.oauth_manager.get_tokens(
110+
access_token, refresh_token = self.oauth_manager.get_tokens(
113111
hostname=self._hostname, scope=self._scopes_as_str
114112
)
115113
self._access_token = access_token
@@ -231,9 +229,9 @@ def header_factory() -> Dict[str, str]:
231229
}
232230

233231
if self.azure_workspace_resource_id:
234-
headers[
235-
self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
236-
] = self.azure_workspace_resource_id
232+
headers[self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER] = (
233+
self.azure_workspace_resource_id
234+
)
237235

238236
return headers
239237

src/databricks/sql/auth/oauth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def __get_authorization_code(self, client, auth_url, scope, state, challenge):
134134
try:
135135
with HTTPServer(("", port), handler) as httpd:
136136
redirect_url = OAuthManager.__get_redirect_url(port)
137-
(auth_req_uri, _, _) = client.prepare_authorization_request(
137+
auth_req_uri, _, _ = client.prepare_authorization_request(
138138
authorization_url=auth_url,
139139
redirect_url=redirect_url,
140140
scope=scope,
@@ -269,7 +269,7 @@ def get_tokens(self, hostname: str, scope=None):
269269
auth_url = self.idp_endpoint.get_authorization_url(hostname)
270270

271271
state = OAuthManager.__token_urlsafe(16)
272-
(verifier, challenge) = OAuthManager.__get_challenge()
272+
verifier, challenge = OAuthManager.__get_challenge()
273273
client = oauthlib.oauth2.WebApplicationClient(self.client_id)
274274

275275
try:

src/databricks/sql/auth/thrift_http_client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,11 @@ def open(self):
122122
pool_class = HTTPSConnectionPool
123123
_pool_kwargs.update(
124124
{
125-
"cert_reqs": ssl.CERT_REQUIRED
126-
if self._ssl_options.tls_verify
127-
else ssl.CERT_NONE,
125+
"cert_reqs": (
126+
ssl.CERT_REQUIRED
127+
if self._ssl_options.tls_verify
128+
else ssl.CERT_NONE
129+
),
128130
"ca_certs": self._ssl_options.tls_trusted_ca_file,
129131
"cert_file": self._ssl_options.tls_client_cert_file,
130132
"key_file": self._ssl_options.tls_client_cert_key_file,

src/databricks/sql/auth/token_federation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ def add_headers(self, request_headers: Dict[str, str]):
111111
"""Add authentication headers to the request."""
112112

113113
if self._cached_token and not self._cached_token.is_expired():
114-
request_headers[
115-
"Authorization"
116-
] = f"{self._cached_token.token_type} {self._cached_token.access_token}"
114+
request_headers["Authorization"] = (
115+
f"{self._cached_token.token_type} {self._cached_token.access_token}"
116+
)
117117
return
118118

119119
# Get the external headers first to check if we need token federation

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,8 @@ def _filter_json_result_set(
227227
if not case_sensitive:
228228
allowed_values = [v.upper() for v in allowed_values]
229229
# Helper lambda to get column value based on case sensitivity
230-
get_column_value = (
231-
lambda row: row[column_index].upper()
232-
if not case_sensitive
233-
else row[column_index]
230+
get_column_value = lambda row: (
231+
row[column_index].upper() if not case_sensitive else row[column_index]
234232
)
235233

236234
# Filter rows based on allowed values

src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,11 @@ def _open(self):
161161
pool_class = HTTPSConnectionPool
162162
pool_kwargs.update(
163163
{
164-
"cert_reqs": ssl.CERT_REQUIRED
165-
if self.ssl_options.tls_verify
166-
else ssl.CERT_NONE,
164+
"cert_reqs": (
165+
ssl.CERT_REQUIRED
166+
if self.ssl_options.tls_verify
167+
else ssl.CERT_NONE
168+
),
167169
"ca_certs": self.ssl_options.tls_trusted_ca_file,
168170
"cert_file": self.ssl_options.tls_client_cert_file,
169171
"key_file": self.ssl_options.tls_client_cert_key_file,

src/databricks/sql/backend/thrift_backend.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from databricks.sql.result_set import ThriftResultSet
1313
from databricks.sql.telemetry.models.event import StatementType
1414

15-
1615
if TYPE_CHECKING:
1716
from databricks.sql.client import Cursor
1817
from databricks.sql.result_set import ResultSet
@@ -680,7 +679,10 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
680679
num_rows,
681680
) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
682681
elif t_row_set.arrowBatches is not None:
683-
(arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
682+
(
683+
arrow_table,
684+
num_rows,
685+
) = convert_arrow_based_set_to_arrow_table(
684686
t_row_set.arrowBatches, lz4_compressed, schema_bytes
685687
)
686688
else:
@@ -1046,11 +1048,13 @@ def execute_command(
10461048
statement=operation,
10471049
runAsync=True,
10481050
# For async operation we don't want the direct results
1049-
getDirectResults=None
1050-
if async_op
1051-
else ttypes.TSparkGetDirectResults(
1052-
maxRows=max_rows,
1053-
maxBytes=max_bytes,
1051+
getDirectResults=(
1052+
None
1053+
if async_op
1054+
else ttypes.TSparkGetDirectResults(
1055+
maxRows=max_rows,
1056+
maxBytes=max_bytes,
1057+
)
10541058
),
10551059
canReadArrowResult=True if pyarrow else False,
10561060
canDecompressLZ4Result=lz4_compression,

src/databricks/sql/client.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,9 @@ def read(self) -> Optional[OAuthToken]:
340340
http_path=http_path,
341341
port=kwargs.get("_port", 443),
342342
client_context=client_context,
343-
user_agent=self.session.useragent_header
344-
if hasattr(self, "session")
345-
else None,
343+
user_agent=(
344+
self.session.useragent_header if hasattr(self, "session") else None
345+
),
346346
enable_telemetry=enable_telemetry,
347347
)
348348
raise e
@@ -390,9 +390,11 @@ def read(self) -> Optional[OAuthToken]:
390390

391391
driver_connection_params = DriverConnectionParameters(
392392
http_path=http_path,
393-
mode=DatabricksClientType.SEA
394-
if self.session.use_sea
395-
else DatabricksClientType.THRIFT,
393+
mode=(
394+
DatabricksClientType.SEA
395+
if self.session.use_sea
396+
else DatabricksClientType.THRIFT
397+
),
396398
host_info=HostDetails(host_url=server_hostname, port=self.session.port),
397399
auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider),
398400
auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider),

src/databricks/sql/common/unified_http_client.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,13 @@ def _setup_pool_managers(self):
148148
"num_pools": self.config.pool_connections,
149149
"maxsize": self.config.pool_maxsize,
150150
"retries": self._retry_policy,
151-
"timeout": urllib3.Timeout(
152-
connect=self.config.socket_timeout, read=self.config.socket_timeout
153-
)
154-
if self.config.socket_timeout
155-
else None,
151+
"timeout": (
152+
urllib3.Timeout(
153+
connect=self.config.socket_timeout, read=self.config.socket_timeout
154+
)
155+
if self.config.socket_timeout
156+
else None
157+
),
156158
"ssl_context": ssl_context,
157159
}
158160

0 commit comments

Comments
 (0)