Skip to content
142 changes: 78 additions & 64 deletions lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import thrift from 'thrift';
import Int64 from 'node-int64';
import os from 'os';

import { EventEmitter } from 'events';
import TCLIService from '../thrift/TCLIService';
import { TProtocolVersion } from '../thrift/TCLIService_types';
import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient';
import IDriver from './contracts/IDriver';
import IClientContext, { ClientConfig } from './contracts/IClientContext';
Expand All @@ -15,9 +13,12 @@ import IDBSQLSession from './contracts/IDBSQLSession';
import IAuthentication from './connection/contracts/IAuthentication';
import HttpConnection from './connection/connections/HttpConnection';
import IConnectionOptions from './connection/contracts/IConnectionOptions';
import Status from './dto/Status';
import HiveDriverError from './errors/HiveDriverError';
import { buildUserAgentString, definedOrError, serializeQueryTags } from './utils';
import { buildUserAgentString } from './utils';
import IBackend from './contracts/IBackend';
import { InternalConnectionOptions } from './contracts/InternalConnectionOptions';
import ThriftBackend from './thrift-backend/ThriftBackend';
import SeaBackend from './sea/SeaBackend';
import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication';
import DatabricksOAuth, { OAuthFlow } from './connection/auth/DatabricksOAuth';
import {
Expand Down Expand Up @@ -47,19 +48,6 @@ function prependSlash(str: string): string {
return str;
}

function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) {
if (!catalogName && !schemaName) {
return {};
}

return {
initialNamespace: {
catalogName,
schemaName,
},
};
}

export type ThriftLibrary = Pick<typeof thrift, 'createClient'>;

/**
Expand Down Expand Up @@ -111,6 +99,8 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I

private readonly sessions = new CloseableCollection<DBSQLSession>();

private backend?: IBackend;

// Telemetry components — `telemetryClient` is the shared per-host owner
// (process-wide via TelemetryClientProvider). The exporter, aggregator,
// circuit-breaker registry and feature-flag cache live on it. Each
Expand Down Expand Up @@ -633,34 +623,35 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I

this.connectionProvider = this.createConnectionProvider(options);

const thriftConnection = await this.connectionProvider.getThriftConnection();
// M0: `useSEA` is consumed via a non-exported internal-options cast so it
// doesn't ship in the public `.d.ts`. Mirrors Python's `kwargs.get("use_sea")`
// pattern (see databricks-sql-python/src/databricks/sql/session.py).
const internalOptions = options as ConnectionOptions & InternalConnectionOptions;
const backend = internalOptions.useSEA
? new SeaBackend()
: new ThriftBackend({
context: this,
onConnectionEvent: (event, payload) => this.forwardConnectionEvent(event, payload),
});

thriftConnection.on('error', (error: Error) => {
// Error.stack already contains error type and message, so log stack if available,
// otherwise fall back to just error type + message
this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`);
// Publish `this.backend` only after a successful `connect()`. Otherwise a
// failed connect would leave a half-initialized backend in place, and the
// next `openSession()` would slip past the `!this.backend` guard and
// surface a misleading "backend not implemented" / partial-state error
// instead of the accurate "DBSQLClient: not connected".
try {
await backend.connect(options);
} catch (err) {
// `IBackend.close()` is documented as safe on a partially-initialized
// backend; best-effort cleanup so we don't leak sockets / state.
try {
this.emit('error', error);
} catch (e) {
// EventEmitter will throw unhandled error when emitting 'error' event.
// Since we already logged it few lines above, just suppress this behaviour
await backend.close();
} catch (closeErr) {
// Swallow; the original error is what the caller needs to see.
}
});

thriftConnection.on('reconnecting', (params: { delay: number; attempt: number }) => {
this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(params)}`);
this.emit('reconnecting', params);
});

thriftConnection.on('close', () => {
this.logger.log(LogLevel.debug, 'Closing connection.');
this.emit('close');
});

thriftConnection.on('timeout', () => {
this.logger.log(LogLevel.debug, 'Connection timed out.');
this.emit('timeout');
});
throw err;
}
this.backend = backend;

// Initialize telemetry if enabled. The env var DATABRICKS_TELEMETRY_DISABLED
// is a hard kill switch for ops/IT teams who can't redeploy app code.
Expand Down Expand Up @@ -691,6 +682,41 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
return this;
}

private forwardConnectionEvent(event: 'error' | 'reconnecting' | 'close' | 'timeout', payload?: unknown): void {
switch (event) {
case 'error': {
// `payload` is typed `unknown` because the cross-backend
// `IBackend.onConnectionEvent` doesn't constrain the error shape.
// Normalize to `Error` so the stack/name/message access below is safe
// for any backend that emits a non-Error value (e.g. a bare string).
const error = payload instanceof Error ? payload : new Error(String(payload));
this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`);
try {
this.emit('error', error);
} catch (e) {
// EventEmitter throws when 'error' has no listeners; we've already logged it.
}
return;
}
case 'reconnecting':
this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(payload)}`);
this.emit('reconnecting', payload);
return;
case 'close':
this.logger.log(LogLevel.debug, 'Closing connection.');
this.emit('close');
return;
case 'timeout':
this.logger.log(LogLevel.debug, 'Connection timed out.');
this.emit('timeout');
// Explicit return mirrors the other cases and protects against
// fall-through if a new event is added below.
// eslint-disable-next-line no-useless-return
return;
// no default
}
}

/**
* Starts new session
* @public
Expand All @@ -701,6 +727,10 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
* const session = await client.openSession();
*/
public async openSession(request: OpenSessionRequest = {}): Promise<IDBSQLSession> {
if (!this.backend) {
throw new HiveDriverError('DBSQLClient: not connected');
}

// Track connection open latency
const startTime = Date.now();

Expand All @@ -711,30 +741,11 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
if (this.config.enableMetricViewMetadata) {
configuration['spark.sql.thriftserver.metadata.metricview.enabled'] = 'true';
}

// Serialize queryTags dict and set in configuration; takes precedence over configuration.QUERY_TAGS
if (request.queryTags !== undefined) {
const serialized = serializeQueryTags(request.queryTags);
if (serialized) {
configuration.QUERY_TAGS = serialized;
} else {
delete configuration.QUERY_TAGS;
}
}

const response = await this.driver.openSession({
client_protocol_i64: new Int64(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8),
...getInitialNamespaceOptions(request.initialCatalog, request.initialSchema),
const sessionBackend = await this.backend.openSession({
...request,
configuration,
canUseMultipleCatalogs: true,
});

Status.assert(response.status);
const session = new DBSQLSession({
handle: definedOrError(response.sessionHandle),
context: this,
serverProtocolVersion: response.serverProtocolVersion,
});
const session = new DBSQLSession({ backend: sessionBackend, context: this });
this.sessions.add(session);

// Emit connection.open telemetry event. The DriverConfiguration blob
Expand Down Expand Up @@ -772,6 +783,9 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
*/
public async close(): Promise<void> {
await this.sessions.closeAll();
await this.backend?.close();

this.backend = undefined;

// Cleanup telemetry. Releasing our refcount on the shared TelemetryClient
// is awaited because the underlying close() drains the final HTTP POST —
Expand Down
Loading
Loading