Skip to content

Commit 3572889

Browse files
[PECO-909] Automatically renew oauth token when refresh token is available (#156)
* HiveDriver: obtain a thrift client before each request (allows to re-create client if needed) Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Move auth logic to DBSQLClient Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Remove redundant HttpTransport class Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Cache OAuth tokens in memory by default to avoid re-running OAuth flow on every request Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Re-create thrift client when auth credentials (e.g. oauth token) change Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Update tests Signed-off-by: Levko Kravets <levko.ne@gmail.com> --------- Signed-off-by: Levko Kravets <levko.ne@gmail.com>
1 parent 7801813 commit 3572889

File tree

17 files changed

+554
-564
lines changed

17 files changed

+554
-564
lines changed

lib/DBSQLClient.ts

Lines changed: 64 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
1-
import thrift from 'thrift';
1+
import thrift, { HttpHeaders } from 'thrift';
22

33
import { EventEmitter } from 'events';
44
import TCLIService from '../thrift/TCLIService';
55
import { TProtocolVersion } from '../thrift/TCLIService_types';
6-
import IDBSQLClient, { ConnectionOptions, OpenSessionRequest, ClientOptions } from './contracts/IDBSQLClient';
6+
import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient';
77
import HiveDriver from './hive/HiveDriver';
88
import { Int64 } from './hive/Types';
99
import DBSQLSession from './DBSQLSession';
1010
import IDBSQLSession from './contracts/IDBSQLSession';
11-
import IThriftConnection from './connection/contracts/IThriftConnection';
12-
import IConnectionProvider from './connection/contracts/IConnectionProvider';
1311
import IAuthentication from './connection/contracts/IAuthentication';
1412
import HttpConnection from './connection/connections/HttpConnection';
1513
import IConnectionOptions from './connection/contracts/IConnectionOptions';
1614
import Status from './dto/Status';
1715
import HiveDriverError from './errors/HiveDriverError';
18-
import { buildUserAgentString, definedOrError } from './utils';
16+
import { areHeadersEqual, buildUserAgentString, definedOrError } from './utils';
1917
import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication';
2018
import DatabricksOAuth from './connection/auth/DatabricksOAuth';
2119
import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger';
@@ -42,26 +40,25 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) {
4240
}
4341

4442
export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
45-
private client: TCLIService.Client | null;
43+
private client: TCLIService.Client | null = null;
4644

47-
private connection: IThriftConnection | null;
45+
private authProvider: IAuthentication | null = null;
4846

49-
private connectionProvider: IConnectionProvider;
47+
private connectionOptions: ConnectionOptions | null = null;
48+
49+
private additionalHeaders: HttpHeaders = {};
5050

5151
private readonly logger: IDBSQLLogger;
5252

5353
private readonly thrift = thrift;
5454

5555
constructor(options?: ClientOptions) {
5656
super();
57-
this.connectionProvider = new HttpConnection();
5857
this.logger = options?.logger || new DBSQLLogger();
59-
this.client = null;
60-
this.connection = null;
6158
this.logger.log(LogLevel.info, 'Created DBSQLClient');
6259
}
6360

64-
private getConnectionOptions(options: ConnectionOptions): IConnectionOptions {
61+
private getConnectionOptions(options: ConnectionOptions, headers: HttpHeaders): IConnectionOptions {
6562
const {
6663
host,
6764
port,
@@ -85,6 +82,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
8582
https: true,
8683
...otherOptions,
8784
headers: {
85+
...headers,
8886
'User-Agent': buildUserAgentString(options.clientId),
8987
},
9088
},
@@ -126,39 +124,8 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
126124
* const session = client.connect({host, path, token});
127125
*/
128126
public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise<IDBSQLClient> {
129-
authProvider = this.getAuthProvider(options, authProvider);
130-
131-
this.connection = await this.connectionProvider.connect(this.getConnectionOptions(options), authProvider);
132-
133-
this.client = this.thrift.createClient(TCLIService, this.connection.getConnection());
134-
135-
this.connection.getConnection().on('error', (error: Error) => {
136-
// Error.stack already contains error type and message, so log stack if available,
137-
// otherwise fall back to just error type + message
138-
this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`);
139-
try {
140-
this.emit('error', error);
141-
} catch (e) {
142-
// EventEmitter will throw unhandled error when emitting 'error' event.
143-
// Since we already logged it few lines above, just suppress this behaviour
144-
}
145-
});
146-
147-
this.connection.getConnection().on('reconnecting', (params: { delay: number; attempt: number }) => {
148-
this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(params)}`);
149-
this.emit('reconnecting', params);
150-
});
151-
152-
this.connection.getConnection().on('close', () => {
153-
this.logger.log(LogLevel.debug, 'Closing connection.');
154-
this.emit('close');
155-
});
156-
157-
this.connection.getConnection().on('timeout', () => {
158-
this.logger.log(LogLevel.debug, 'Connection timed out.');
159-
this.emit('timeout');
160-
});
161-
127+
this.authProvider = this.getAuthProvider(options, authProvider);
128+
this.connectionOptions = options;
162129
return this;
163130
}
164131

@@ -172,11 +139,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
172139
* const session = await client.openSession();
173140
*/
174141
public async openSession(request: OpenSessionRequest = {}): Promise<IDBSQLSession> {
175-
if (!this.connection?.isConnected()) {
176-
throw new HiveDriverError('DBSQLClient: connection is lost');
177-
}
178-
179-
const driver = new HiveDriver(this.getClient());
142+
const driver = new HiveDriver(() => this.getClient());
180143

181144
const response = await driver.openSession({
182145
client_protocol_i64: new Int64(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6),
@@ -187,23 +150,64 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
187150
return new DBSQLSession(driver, definedOrError(response.sessionHandle), this.logger);
188151
}
189152

190-
public getClient() {
191-
if (!this.client) {
192-
throw new HiveDriverError('DBSQLClient: client is not initialized');
153+
private async getClient() {
154+
if (!this.connectionOptions || !this.authProvider) {
155+
throw new HiveDriverError('DBSQLClient: not connected');
156+
}
157+
158+
const authHeaders = await this.authProvider.authenticate();
159+
// When auth headers change - recreate client. Thrift library does not provide API for updating
160+
// changed options, therefore we have to recreate both connection and client to apply new headers
161+
if (!this.client || !areHeadersEqual(this.additionalHeaders, authHeaders)) {
162+
this.logger.log(LogLevel.info, 'DBSQLClient: initializing thrift client');
163+
this.additionalHeaders = authHeaders;
164+
const connectionOptions = this.getConnectionOptions(this.connectionOptions, this.additionalHeaders);
165+
166+
const connection = await this.createConnection(connectionOptions);
167+
this.client = this.thrift.createClient(TCLIService, connection.getConnection());
193168
}
194169

195170
return this.client;
196171
}
197172

198-
public async close(): Promise<void> {
199-
if (this.connection) {
200-
const thriftConnection = this.connection.getConnection();
173+
private async createConnection(options: IConnectionOptions) {
174+
const connectionProvider = new HttpConnection();
175+
const connection = await connectionProvider.connect(options);
176+
const thriftConnection = connection.getConnection();
201177

202-
if (typeof thriftConnection.end === 'function') {
203-
this.connection.getConnection().end();
178+
thriftConnection.on('error', (error: Error) => {
179+
// Error.stack already contains error type and message, so log stack if available,
180+
// otherwise fall back to just error type + message
181+
this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`);
182+
try {
183+
this.emit('error', error);
184+
} catch (e) {
185+
// EventEmitter will throw unhandled error when emitting 'error' event.
186+
// Since we already logged it few lines above, just suppress this behaviour
204187
}
188+
});
205189

206-
this.connection = null;
207-
}
190+
thriftConnection.on('reconnecting', (params: { delay: number; attempt: number }) => {
191+
this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(params)}`);
192+
this.emit('reconnecting', params);
193+
});
194+
195+
thriftConnection.on('close', () => {
196+
this.logger.log(LogLevel.debug, 'Closing connection.');
197+
this.emit('close');
198+
});
199+
200+
thriftConnection.on('timeout', () => {
201+
this.logger.log(LogLevel.debug, 'Connection timed out.');
202+
this.emit('timeout');
203+
});
204+
205+
return connection;
206+
}
207+
208+
public async close(): Promise<void> {
209+
this.client = null;
210+
this.authProvider = null;
211+
this.connectionOptions = null;
208212
}
209213
}

lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,15 @@ export default interface OAuthPersistence {
55

66
read(host: string): Promise<OAuthToken | undefined>;
77
}
8+
9+
export class OAuthPersistenceCache implements OAuthPersistence {
10+
private tokens: Record<string, OAuthToken | undefined> = {};
11+
12+
async persist(host: string, token: OAuthToken) {
13+
this.tokens[host] = token;
14+
}
15+
16+
async read(host: string) {
17+
return this.tokens[host];
18+
}
19+
}

lib/connection/auth/DatabricksOAuth/index.ts

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import { HttpHeaders } from 'thrift';
22
import IAuthentication from '../../contracts/IAuthentication';
3-
import HttpTransport from '../../transports/HttpTransport';
43
import IDBSQLLogger from '../../../contracts/IDBSQLLogger';
5-
import OAuthPersistence from './OAuthPersistence';
4+
import OAuthPersistence, { OAuthPersistenceCache } from './OAuthPersistence';
65
import OAuthManager, { OAuthManagerOptions } from './OAuthManager';
76
import { OAuthScopes, defaultOAuthScopes } from './OAuthScope';
87

@@ -20,26 +19,30 @@ export default class DatabricksOAuth implements IAuthentication {
2019

2120
private readonly manager: OAuthManager;
2221

22+
private readonly defaultPersistence = new OAuthPersistenceCache();
23+
2324
constructor(options: DatabricksOAuthOptions) {
2425
this.options = options;
2526
this.logger = options.logger;
2627
this.manager = OAuthManager.getManager(this.options);
2728
}
2829

29-
public async authenticate(transport: HttpTransport): Promise<void> {
30-
const { host, scopes, headers, persistence } = this.options;
30+
public async authenticate(): Promise<HttpHeaders> {
31+
const { host, scopes, headers } = this.options;
32+
33+
const persistence = this.options.persistence ?? this.defaultPersistence;
3134

32-
let token = await persistence?.read(host);
35+
let token = await persistence.read(host);
3336
if (!token) {
3437
token = await this.manager.getToken(scopes ?? defaultOAuthScopes);
3538
}
3639

3740
token = await this.manager.refreshAccessToken(token);
38-
await persistence?.persist(host, token);
41+
await persistence.persist(host, token);
3942

40-
transport.updateHeaders({
43+
return {
4144
...headers,
4245
Authorization: `Bearer ${token.accessToken}`,
43-
});
46+
};
4447
}
4548
}

lib/connection/auth/PlainHttpAuthentication.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import { HttpHeaders } from 'thrift';
22
import IAuthentication from '../contracts/IAuthentication';
3-
import HttpTransport from '../transports/HttpTransport';
43

54
interface PlainHttpAuthenticationOptions {
65
username?: string;
@@ -21,10 +20,10 @@ export default class PlainHttpAuthentication implements IAuthentication {
2120
this.headers = options?.headers || {};
2221
}
2322

24-
public async authenticate(transport: HttpTransport): Promise<void> {
25-
transport.updateHeaders({
23+
public async authenticate(): Promise<HttpHeaders> {
24+
return {
2625
...this.headers,
2726
Authorization: `Bearer ${this.password}`,
28-
});
27+
};
2928
}
3029
}

lib/connection/connections/HttpConnection.ts

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import http, { IncomingMessage } from 'http';
55
import IThriftConnection from '../contracts/IThriftConnection';
66
import IConnectionProvider from '../contracts/IConnectionProvider';
77
import IConnectionOptions, { Options } from '../contracts/IConnectionOptions';
8-
import IAuthentication from '../contracts/IAuthentication';
9-
import HttpTransport from '../transports/HttpTransport';
108
import globalConfig from '../../globalConfig';
119

1210
type NodeOptions = {
@@ -21,7 +19,7 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne
2119

2220
private connection: any;
2321

24-
connect(options: IConnectionOptions, authProvider: IAuthentication): Promise<IThriftConnection> {
22+
async connect(options: IConnectionOptions): Promise<IThriftConnection> {
2523
const agentOptions: http.AgentOptions = {
2624
keepAlive: true,
2725
maxSockets: 5,
@@ -33,7 +31,7 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne
3331
? new https.Agent({ ...agentOptions, minVersion: 'TLSv1.2' })
3432
: new http.Agent(agentOptions);
3533

36-
const httpTransport = new HttpTransport({
34+
const thriftOptions = {
3735
transport: thrift.TBufferedTransport,
3836
protocol: thrift.TBinaryProtocol,
3937
...options.options,
@@ -43,15 +41,12 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne
4341
...(options.options?.nodeOptions || {}),
4442
timeout: options.options?.socketTimeout ?? globalConfig.socketTimeout,
4543
},
46-
});
47-
48-
return authProvider.authenticate(httpTransport).then(() => {
49-
this.connection = this.thrift.createHttpConnection(options.host, options.port, httpTransport.getOptions());
44+
};
5045

51-
this.addCookieHandler();
46+
this.connection = this.thrift.createHttpConnection(options.host, options.port, thriftOptions);
47+
this.addCookieHandler();
5248

53-
return this;
54-
});
49+
return this;
5550
}
5651

5752
getConnection() {
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import HttpTransport from '../transports/HttpTransport';
1+
import { HttpHeaders } from 'thrift';
22

33
export default interface IAuthentication {
4-
authenticate(transport: HttpTransport): Promise<void>;
4+
authenticate(): Promise<HttpHeaders>;
55
}
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import IConnectionOptions from './IConnectionOptions';
2-
import IAuthentication from './IAuthentication';
32
import IThriftConnection from './IThriftConnection';
43

54
export default interface IConnectionProvider {
6-
connect(options: IConnectionOptions, authProvider: IAuthentication): Promise<IThriftConnection>;
5+
connect(options: IConnectionOptions): Promise<IThriftConnection>;
76
}

0 commit comments

Comments
 (0)