Skip to content

Commit 71d45a2

Browse files
[PECO-728] Add OAuth support (#147)
* [PECO-728] Add OAuth support Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Cleanup DBSQLClient code; remove redundant and no longer needed NoSaslAuthentication Signed-off-by: Levko Kravets <levko.ne@gmail.com> * DBSQLClient: options for auth types Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Tests Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Tests Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Fix: move comment to appropriate place Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Improve tests Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Use proper client ID; improve callback handling Signed-off-by: Levko Kravets <levko.ne@gmail.com> --------- Signed-off-by: Levko Kravets <levko.ne@gmail.com>
1 parent 0a2bdb4 commit 71d45a2

File tree

15 files changed

+1389
-32
lines changed

15 files changed

+1389
-32
lines changed

lib/DBSQLClient.ts

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import Status from './dto/Status';
1717
import HiveDriverError from './errors/HiveDriverError';
1818
import { buildUserAgentString, definedOrError } from './utils';
1919
import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication';
20+
import DatabricksOAuth from './connection/auth/DatabricksOAuth';
2021
import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger';
2122
import DBSQLLogger from './DBSQLLogger';
2223

@@ -61,7 +62,21 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
6162
}
6263

6364
private getConnectionOptions(options: ConnectionOptions): IConnectionOptions {
64-
const { host, port, path, token, clientId, ...otherOptions } = options;
65+
const {
66+
host,
67+
port,
68+
path,
69+
clientId,
70+
authType,
71+
// @ts-expect-error TS2339: Property 'token' does not exist on type 'ConnectionOptions'
72+
token,
73+
// @ts-expect-error TS2339: Property 'persistence' does not exist on type 'ConnectionOptions'
74+
persistence,
75+
// @ts-expect-error TS2339: Property 'provider' does not exist on type 'ConnectionOptions'
76+
provider,
77+
...otherOptions
78+
} = options;
79+
6580
return {
6681
host,
6782
port: port || 443,
@@ -76,22 +91,41 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
7691
};
7792
}
7893

94+
private getAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
95+
if (authProvider) {
96+
return authProvider;
97+
}
98+
99+
switch (options.authType) {
100+
case undefined:
101+
case 'access-token':
102+
return new PlainHttpAuthentication({
103+
username: 'token',
104+
password: options.token,
105+
});
106+
case 'databricks-oauth':
107+
return new DatabricksOAuth({
108+
host: options.host,
109+
logger: this.logger,
110+
persistence: options.persistence,
111+
});
112+
case 'custom':
113+
return options.provider;
114+
// no default
115+
}
116+
}
117+
79118
/**
80119
* Connects DBSQLClient to endpoint
81120
* @public
82121
* @param options - host, path, and token are required
83-
* @param authProvider - Optional custom authentication provider
122+
* @param authProvider - [DEPRECATED - use `authType: 'custom'] Optional custom authentication provider
84123
* @returns Session object that can be used to execute statements
85124
* @example
86125
* const session = client.connect({host, path, token});
87126
*/
88127
public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise<IDBSQLClient> {
89-
authProvider =
90-
authProvider ||
91-
new PlainHttpAuthentication({
92-
username: 'token',
93-
password: options.token,
94-
});
128+
authProvider = this.getAuthProvider(options, authProvider);
95129

96130
this.connection = await this.connectionProvider.connect(this.getConnectionOptions(options), authProvider);
97131

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import http, { IncomingMessage, Server, ServerResponse } from 'http';
2+
import { BaseClient, CallbackParamsType, generators } from 'openid-client';
3+
import open from 'open';
4+
import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger';
5+
6+
export interface AuthorizationCodeOptions {
7+
client: BaseClient;
8+
ports: Array<number>;
9+
logger?: IDBSQLLogger;
10+
}
11+
12+
const scopeDelimiter = ' ';
13+
14+
async function startServer(
15+
host: string,
16+
port: number,
17+
requestHandler: (req: IncomingMessage, res: ServerResponse) => void,
18+
): Promise<Server> {
19+
const server = http.createServer(requestHandler);
20+
21+
return new Promise((resolve, reject) => {
22+
const errorListener = (error: Error) => {
23+
server.off('error', errorListener);
24+
reject(error);
25+
};
26+
27+
server.on('error', errorListener);
28+
server.listen(port, host, () => {
29+
server.off('error', errorListener);
30+
resolve(server);
31+
});
32+
});
33+
}
34+
35+
async function stopServer(server: Server): Promise<void> {
36+
if (!server.listening) {
37+
return;
38+
}
39+
40+
return new Promise((resolve, reject) => {
41+
const errorListener = (error: Error) => {
42+
server.off('error', errorListener);
43+
reject(error);
44+
};
45+
46+
server.on('error', errorListener);
47+
server.close(() => {
48+
server.off('error', errorListener);
49+
resolve();
50+
});
51+
});
52+
}
53+
54+
export interface AuthorizationCodeFetchResult {
55+
code: string;
56+
verifier: string;
57+
redirectUri: string;
58+
}
59+
60+
export default class AuthorizationCode {
61+
private readonly client: BaseClient;
62+
63+
private readonly host: string = 'localhost';
64+
65+
private readonly ports: Array<number>;
66+
67+
private readonly logger?: IDBSQLLogger;
68+
69+
constructor(options: AuthorizationCodeOptions) {
70+
this.client = options.client;
71+
this.ports = options.ports;
72+
this.logger = options.logger;
73+
}
74+
75+
private async openUrl(url: string) {
76+
return open(url);
77+
}
78+
79+
public async fetch(scopes: Array<string>): Promise<AuthorizationCodeFetchResult> {
80+
const verifierString = generators.codeVerifier(32);
81+
const challengeString = generators.codeChallenge(verifierString);
82+
const state = generators.state(16);
83+
84+
let receivedParams: CallbackParamsType | undefined;
85+
86+
const server = await this.startServer((req, res) => {
87+
const params = this.client.callbackParams(req);
88+
if (params.state === state) {
89+
receivedParams = params;
90+
res.writeHead(200);
91+
res.end(this.renderCallbackResponse());
92+
server.stop();
93+
} else {
94+
res.writeHead(404);
95+
res.end();
96+
}
97+
});
98+
99+
const redirectUri = `http://${server.host}:${server.port}/`;
100+
const authUrl = this.client.authorizationUrl({
101+
response_type: 'code',
102+
response_mode: 'query',
103+
scope: scopes.join(scopeDelimiter),
104+
code_challenge: challengeString,
105+
code_challenge_method: 'S256',
106+
state,
107+
redirect_uri: redirectUri,
108+
});
109+
110+
await this.openUrl(authUrl);
111+
await server.stopped();
112+
113+
if (!receivedParams || !receivedParams.code) {
114+
if (receivedParams?.error) {
115+
const errorMessage = `OAuth error: ${receivedParams.error} ${receivedParams.error_description}`;
116+
throw new Error(errorMessage);
117+
}
118+
throw new Error(`No path parameters were returned to the callback at ${redirectUri}`);
119+
}
120+
121+
return { code: receivedParams.code, verifier: verifierString, redirectUri };
122+
}
123+
124+
private async startServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
125+
for (const port of this.ports) {
126+
const host = this.host; // eslint-disable-line prefer-destructuring
127+
try {
128+
const server = await startServer(host, port, requestHandler); // eslint-disable-line no-await-in-loop
129+
this.logger?.log(LogLevel.info, `Listening for OAuth authorization callback at ${host}:${port}`);
130+
131+
let resolveStopped: () => void;
132+
let rejectStopped: (reason?: any) => void;
133+
const stoppedPromise = new Promise<void>((resolve, reject) => {
134+
resolveStopped = resolve;
135+
rejectStopped = reject;
136+
});
137+
138+
return {
139+
host,
140+
port,
141+
server,
142+
stop: () => stopServer(server).then(resolveStopped).catch(rejectStopped),
143+
stopped: () => stoppedPromise,
144+
};
145+
} catch (error) {
146+
// if port already in use - try another one, otherwise re-throw an exception
147+
if (error instanceof Error && 'code' in error && error.code === 'EADDRINUSE') {
148+
this.logger?.log(LogLevel.debug, `Failed to start server at ${host}:${port}: ${error.code}`);
149+
} else {
150+
throw error;
151+
}
152+
}
153+
}
154+
155+
throw new Error('Failed to start server: all ports are in use');
156+
}
157+
158+
private renderCallbackResponse(): string {
159+
const applicationName = 'Databricks Sql Connector';
160+
161+
return `<html lang="en">
162+
<head>
163+
<title>Close this Tab</title>
164+
<style>
165+
body {
166+
font-family: "Barlow", Helvetica, Arial, sans-serif;
167+
padding: 20px;
168+
background-color: #f3f3f3;
169+
}
170+
</style>
171+
</head>
172+
<body>
173+
<h1>Please close this tab.</h1>
174+
<p>
175+
The ${applicationName} received a response. You may close this tab.
176+
</p>
177+
</body>
178+
</html>`;
179+
}
180+
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import { Issuer, BaseClient } from 'openid-client';
2+
import HiveDriverError from '../../../errors/HiveDriverError';
3+
import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger';
4+
import OAuthToken from './OAuthToken';
5+
import AuthorizationCode from './AuthorizationCode';
6+
7+
const oidcConfigPath = 'oidc/.well-known/oauth-authorization-server';
8+
9+
export interface OAuthManagerOptions {
10+
host: string;
11+
callbackPorts: Array<number>;
12+
clientId: string;
13+
logger?: IDBSQLLogger;
14+
}
15+
16+
export default class OAuthManager {
17+
private readonly options: OAuthManagerOptions;
18+
19+
private readonly logger?: IDBSQLLogger;
20+
21+
private issuer?: Issuer;
22+
23+
private client?: BaseClient;
24+
25+
constructor(options: OAuthManagerOptions) {
26+
this.options = options;
27+
this.logger = options.logger;
28+
}
29+
30+
private async getClient(): Promise<BaseClient> {
31+
if (!this.issuer) {
32+
const { host } = this.options;
33+
const schema = host.startsWith('https://') ? '' : 'https://';
34+
const trailingSlash = host.endsWith('/') ? '' : '/';
35+
this.issuer = await Issuer.discover(`${schema}${host}${trailingSlash}${oidcConfigPath}`);
36+
}
37+
38+
if (!this.client) {
39+
this.client = new this.issuer.Client({
40+
client_id: this.options.clientId,
41+
token_endpoint_auth_method: 'none',
42+
});
43+
}
44+
45+
return this.client;
46+
}
47+
48+
public async refreshAccessToken(token: OAuthToken): Promise<OAuthToken> {
49+
try {
50+
if (!token.hasExpired) {
51+
// The access token is fine. Just return it.
52+
return token;
53+
}
54+
} catch (error) {
55+
this.logger?.log(LogLevel.error, `${error}`);
56+
throw error;
57+
}
58+
59+
if (!token.refreshToken) {
60+
const message = `OAuth access token expired on ${token.expirationTime}.`;
61+
this.logger?.log(LogLevel.error, message);
62+
throw new HiveDriverError(message);
63+
}
64+
65+
// Try to refresh using the refresh token
66+
this.logger?.log(
67+
LogLevel.debug,
68+
`Attempting to refresh OAuth access token that expired on ${token.expirationTime}`,
69+
);
70+
71+
const client = await this.getClient();
72+
const { access_token: accessToken, refresh_token: refreshToken } = await client.refresh(token.refreshToken);
73+
if (!accessToken || !refreshToken) {
74+
throw new Error('Failed to refresh token: invalid response');
75+
}
76+
return new OAuthToken(accessToken, refreshToken);
77+
}
78+
79+
public async getToken(scopes: Array<string>): Promise<OAuthToken> {
80+
const client = await this.getClient();
81+
const authCode = new AuthorizationCode({
82+
client,
83+
ports: this.options.callbackPorts,
84+
logger: this.logger,
85+
});
86+
87+
const { code, verifier, redirectUri } = await authCode.fetch(scopes);
88+
89+
const { access_token: accessToken, refresh_token: refreshToken } = await client.grant({
90+
grant_type: 'authorization_code',
91+
code,
92+
code_verifier: verifier,
93+
redirect_uri: redirectUri,
94+
});
95+
96+
if (!accessToken) {
97+
throw new Error('Failed to fetch access token');
98+
}
99+
100+
return new OAuthToken(accessToken, refreshToken);
101+
}
102+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import OAuthToken from './OAuthToken';
2+
3+
export default interface OAuthPersistence {
4+
persist(host: string, token: OAuthToken): Promise<void>;
5+
6+
read(host: string): Promise<OAuthToken | undefined>;
7+
}

0 commit comments

Comments
 (0)