Skip to content

Commit 253ab0f

Browse files
Support OAuth flow for Databricks Azure (#154)
* [PECO-833] Support Azure OAuth flow Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Fix tests Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Add tests Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Use custom authorization endpoint for Azure OAuth flow Signed-off-by: Levko Kravets <levko.ne@gmail.com> --------- Signed-off-by: Levko Kravets <levko.ne@gmail.com>
1 parent 71d45a2 commit 253ab0f

File tree

10 files changed

+338
-205
lines changed

10 files changed

+338
-205
lines changed

.eslintrc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
"no-bitwise": "off",
1515
"@typescript-eslint/no-throw-literal": "off",
1616
"no-restricted-syntax": "off",
17-
"no-case-declarations": "off"
17+
"no-case-declarations": "off",
18+
"max-classes-per-file": "off"
1819
}
1920
}
2021
]

lib/DBSQLClient.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
108108
host: options.host,
109109
logger: this.logger,
110110
persistence: options.persistence,
111+
azureTenantId: options.azureTenantId,
111112
});
112113
case 'custom':
113114
return options.provider;

lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@ import http, { IncomingMessage, Server, ServerResponse } from 'http';
22
import { BaseClient, CallbackParamsType, generators } from 'openid-client';
33
import open from 'open';
44
import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger';
5+
import { OAuthScopes, scopeDelimiter } from './OAuthScope';
56

67
export interface AuthorizationCodeOptions {
78
client: BaseClient;
89
ports: Array<number>;
910
logger?: IDBSQLLogger;
1011
}
1112

12-
const scopeDelimiter = ' ';
13-
1413
async function startServer(
1514
host: string,
1615
port: number,
@@ -76,7 +75,7 @@ export default class AuthorizationCode {
7675
return open(url);
7776
}
7877

79-
public async fetch(scopes: Array<string>): Promise<AuthorizationCodeFetchResult> {
78+
public async fetch(scopes: OAuthScopes): Promise<AuthorizationCodeFetchResult> {
8079
const verifierString = generators.codeVerifier(32);
8180
const challengeString = generators.codeChallenge(verifierString);
8281
const state = generators.state(16);
@@ -96,7 +95,7 @@ export default class AuthorizationCode {
9695
}
9796
});
9897

99-
const redirectUri = `http://${server.host}:${server.port}/`;
98+
const redirectUri = `http://${server.host}:${server.port}`;
10099
const authUrl = this.client.authorizationUrl({
101100
response_type: 'code',
102101
response_mode: 'query',

lib/connection/auth/DatabricksOAuth/OAuthManager.ts

Lines changed: 121 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,41 +3,62 @@ import HiveDriverError from '../../../errors/HiveDriverError';
33
import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger';
44
import OAuthToken from './OAuthToken';
55
import AuthorizationCode from './AuthorizationCode';
6-
7-
const oidcConfigPath = 'oidc/.well-known/oauth-authorization-server';
6+
import { OAuthScope, OAuthScopes } from './OAuthScope';
87

98
export interface OAuthManagerOptions {
109
host: string;
11-
callbackPorts: Array<number>;
12-
clientId: string;
10+
callbackPorts?: Array<number>;
11+
clientId?: string;
12+
azureTenantId?: string;
1313
logger?: IDBSQLLogger;
1414
}
1515

16-
export default class OAuthManager {
17-
private readonly options: OAuthManagerOptions;
16+
function getDatabricksOIDCUrl(host: string): string {
17+
const schema = host.startsWith('https://') ? '' : 'https://';
18+
const trailingSlash = host.endsWith('/') ? '' : '/';
19+
return `${schema}${host}${trailingSlash}oidc`;
20+
}
21+
22+
export default abstract class OAuthManager {
23+
protected readonly options: OAuthManagerOptions;
1824

19-
private readonly logger?: IDBSQLLogger;
25+
protected readonly logger?: IDBSQLLogger;
2026

21-
private issuer?: Issuer;
27+
protected issuer?: Issuer;
2228

23-
private client?: BaseClient;
29+
protected client?: BaseClient;
2430

2531
constructor(options: OAuthManagerOptions) {
2632
this.options = options;
2733
this.logger = options.logger;
2834
}
2935

30-
private async getClient(): Promise<BaseClient> {
36+
protected abstract getOIDCConfigUrl(): string;
37+
38+
protected abstract getAuthorizationUrl(): string;
39+
40+
protected abstract getClientId(): string;
41+
42+
protected abstract getCallbackPorts(): Array<number>;
43+
44+
protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
45+
return requestedScopes;
46+
}
47+
48+
protected async getClient(): Promise<BaseClient> {
3149
if (!this.issuer) {
32-
const { host } = this.options;
33-
const schema = host.startsWith('https://') ? '' : 'https://';
34-
const trailingSlash = host.endsWith('/') ? '' : '/';
35-
this.issuer = await Issuer.discover(`${schema}${host}${trailingSlash}${oidcConfigPath}`);
50+
const issuer = await Issuer.discover(this.getOIDCConfigUrl());
51+
// Overwrite `authorization_endpoint` in default config (specifically needed for Azure flow
52+
// where this URL has to be different)
53+
this.issuer = new Issuer({
54+
...issuer.metadata,
55+
authorization_endpoint: this.getAuthorizationUrl(),
56+
});
3657
}
3758

3859
if (!this.client) {
3960
this.client = new this.issuer.Client({
40-
client_id: this.options.clientId,
61+
client_id: this.getClientId(),
4162
token_endpoint_auth_method: 'none',
4263
});
4364
}
@@ -76,15 +97,17 @@ export default class OAuthManager {
7697
return new OAuthToken(accessToken, refreshToken);
7798
}
7899

79-
public async getToken(scopes: Array<string>): Promise<OAuthToken> {
100+
public async getToken(scopes: OAuthScopes): Promise<OAuthToken> {
80101
const client = await this.getClient();
81102
const authCode = new AuthorizationCode({
82103
client,
83-
ports: this.options.callbackPorts,
104+
ports: this.getCallbackPorts(),
84105
logger: this.logger,
85106
});
86107

87-
const { code, verifier, redirectUri } = await authCode.fetch(scopes);
108+
const mappedScopes = this.getScopes(scopes);
109+
110+
const { code, verifier, redirectUri } = await authCode.fetch(mappedScopes);
88111

89112
const { access_token: accessToken, refresh_token: refreshToken } = await client.grant({
90113
grant_type: 'authorization_code',
@@ -99,4 +122,84 @@ export default class OAuthManager {
99122

100123
return new OAuthToken(accessToken, refreshToken);
101124
}
125+
126+
public static getManager(options: OAuthManagerOptions): OAuthManager {
127+
// normalize
128+
const host = options.host.toLowerCase().replace('https://', '').split('/')[0];
129+
130+
// eslint-disable-next-line @typescript-eslint/no-use-before-define
131+
const managers = [AWSOAuthManager, AzureOAuthManager];
132+
133+
for (const OAuthManagerClass of managers) {
134+
for (const domain of OAuthManagerClass.domains) {
135+
if (host.endsWith(domain)) {
136+
return new OAuthManagerClass(options);
137+
}
138+
}
139+
}
140+
141+
throw new Error(`OAuth is not supported for ${options.host}`);
142+
}
143+
}
144+
145+
export class AWSOAuthManager extends OAuthManager {
146+
public static domains = ['.cloud.databricks.com', '.dev.databricks.com'];
147+
148+
public static defaultClientId = 'databricks-sql-connector';
149+
150+
public static defaultCallbackPorts = [8030];
151+
152+
protected getOIDCConfigUrl(): string {
153+
return `${getDatabricksOIDCUrl(this.options.host)}/.well-known/oauth-authorization-server`;
154+
}
155+
156+
protected getAuthorizationUrl(): string {
157+
return `${getDatabricksOIDCUrl(this.options.host)}/oauth2/v2.0/authorize`;
158+
}
159+
160+
protected getClientId(): string {
161+
return this.options.clientId ?? AWSOAuthManager.defaultClientId;
162+
}
163+
164+
protected getCallbackPorts(): Array<number> {
165+
return this.options.callbackPorts ?? AWSOAuthManager.defaultCallbackPorts;
166+
}
167+
}
168+
169+
export class AzureOAuthManager extends OAuthManager {
170+
public static domains = ['.azuredatabricks.net', '.databricks.azure.cn', '.databricks.azure.us'];
171+
172+
public static defaultClientId = '96eecda7-19ea-49cc-abb5-240097d554f5';
173+
174+
public static defaultCallbackPorts = [8030];
175+
176+
public static datatricksAzureApp = '2ff814a6-3304-4ab8-85cb-cd0e6f879c1d';
177+
178+
protected getOIDCConfigUrl(): string {
179+
return 'https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration';
180+
}
181+
182+
protected getAuthorizationUrl(): string {
183+
return `${getDatabricksOIDCUrl(this.options.host)}/oauth2/v2.0/authorize`;
184+
}
185+
186+
protected getClientId(): string {
187+
return this.options.clientId ?? AzureOAuthManager.defaultClientId;
188+
}
189+
190+
protected getCallbackPorts(): Array<number> {
191+
return this.options.callbackPorts ?? AzureOAuthManager.defaultCallbackPorts;
192+
}
193+
194+
protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
195+
// There is no corresponding scopes in Azure, instead, access control will be delegated to Databricks
196+
const tenantId = this.options.azureTenantId ?? AzureOAuthManager.datatricksAzureApp;
197+
const azureScopes = [`${tenantId}/user_impersonation`];
198+
199+
if (requestedScopes.includes(OAuthScope.offlineAccess)) {
200+
azureScopes.push(OAuthScope.offlineAccess);
201+
}
202+
203+
return azureScopes;
204+
}
102205
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
export enum OAuthScope {
2+
offlineAccess = 'offline_access',
3+
SQL = 'sql',
4+
}
5+
6+
export type OAuthScopes = Array<string>;
7+
8+
export const defaultOAuthScopes: OAuthScopes = [OAuthScope.SQL, OAuthScope.offlineAccess];
9+
10+
export const scopeDelimiter = ' ';

lib/connection/auth/DatabricksOAuth/index.ts

Lines changed: 13 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,69 +3,42 @@ import IAuthentication from '../../contracts/IAuthentication';
33
import HttpTransport from '../../transports/HttpTransport';
44
import IDBSQLLogger from '../../../contracts/IDBSQLLogger';
55
import OAuthPersistence from './OAuthPersistence';
6-
import OAuthManager from './OAuthManager';
6+
import OAuthManager, { OAuthManagerOptions } from './OAuthManager';
7+
import { OAuthScopes, defaultOAuthScopes } from './OAuthScope';
78

8-
interface DatabricksOAuthOptions {
9-
host: string;
10-
redirectPorts?: Array<number>;
11-
clientId?: string;
12-
scopes?: Array<string>;
9+
interface DatabricksOAuthOptions extends OAuthManagerOptions {
10+
scopes?: OAuthScopes;
1311
logger?: IDBSQLLogger;
1412
persistence?: OAuthPersistence;
1513
headers?: HttpHeaders;
1614
}
1715

18-
const defaultOAuthOptions = {
19-
clientId: 'databricks-sql-connector',
20-
redirectPorts: [8030],
21-
scopes: ['sql', 'offline_access'],
22-
} satisfies Partial<DatabricksOAuthOptions>;
23-
2416
export default class DatabricksOAuth implements IAuthentication {
25-
private readonly host: string;
26-
27-
private readonly redirectPorts: Array<number>;
28-
29-
private readonly clientId: string;
30-
31-
private readonly scopes: Array<string>;
17+
private readonly options: DatabricksOAuthOptions;
3218

3319
private readonly logger?: IDBSQLLogger;
3420

35-
private readonly persistence?: OAuthPersistence;
36-
37-
private readonly headers?: HttpHeaders;
38-
3921
private readonly manager: OAuthManager;
4022

4123
constructor(options: DatabricksOAuthOptions) {
42-
this.host = options.host;
43-
this.redirectPorts = options.redirectPorts || defaultOAuthOptions.redirectPorts;
44-
this.clientId = options.clientId || defaultOAuthOptions.clientId;
45-
this.scopes = options.scopes || defaultOAuthOptions.scopes;
24+
this.options = options;
4625
this.logger = options.logger;
47-
this.persistence = options.persistence;
48-
this.headers = options.headers;
49-
50-
this.manager = new OAuthManager({
51-
host: this.host,
52-
callbackPorts: this.redirectPorts,
53-
clientId: this.clientId,
54-
logger: this.logger,
55-
});
26+
this.manager = OAuthManager.getManager(this.options);
5627
}
5728

5829
public async authenticate(transport: HttpTransport): Promise<void> {
59-
let token = await this.persistence?.read(this.host);
30+
const { host, scopes, headers, persistence } = this.options;
31+
32+
let token = await persistence?.read(host);
6033
if (!token) {
61-
token = await this.manager.getToken(this.scopes);
34+
token = await this.manager.getToken(scopes ?? defaultOAuthScopes);
6235
}
6336

6437
token = await this.manager.refreshAccessToken(token);
65-
await this.persistence?.persist(this.host, token);
38+
await persistence?.persist(host, token);
6639

6740
transport.updateHeaders({
68-
...this.headers,
41+
...headers,
6942
Authorization: `Bearer ${token.accessToken}`,
7043
});
7144
}

lib/contracts/IDBSQLClient.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ type AuthOptions =
1515
| {
1616
authType: 'databricks-oauth';
1717
persistence?: OAuthPersistence;
18+
azureTenantId?: string;
1819
}
1920
| {
2021
authType: 'custom';

tests/unit/DBSQLClient.test.js

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ const DBSQLSession = require('../../dist/DBSQLSession').default;
55

66
const PlainHttpAuthentication = require('../../dist/connection/auth/PlainHttpAuthentication').default;
77
const DatabricksOAuth = require('../../dist/connection/auth/DatabricksOAuth').default;
8+
const { AWSOAuthManager, AzureOAuthManager } = require('../../dist/connection/auth/DatabricksOAuth/OAuthManager');
89
const HttpConnection = require('../../dist/connection/connections/HttpConnection').default;
910

1011
const ConnectionProviderMock = (connection) => ({
@@ -255,14 +256,42 @@ describe('DBSQLClient.getAuthProvider', () => {
255256
expect(provider.password).to.be.equal(testAccessToken);
256257
});
257258

258-
it('should use Databricks OAuth method', () => {
259+
it('should use Databricks OAuth method (AWS)', () => {
259260
const client = new DBSQLClient();
260261

261262
const provider = client.getAuthProvider({
262263
authType: 'databricks-oauth',
264+
// host is used when creating OAuth manager, so make it look like a real AWS instance
265+
host: 'example.dev.databricks.com',
263266
});
264267

265268
expect(provider).to.be.instanceOf(DatabricksOAuth);
269+
expect(provider.manager).to.be.instanceOf(AWSOAuthManager);
270+
});
271+
272+
it('should use Databricks OAuth method (Azure)', () => {
273+
const client = new DBSQLClient();
274+
275+
const provider = client.getAuthProvider({
276+
authType: 'databricks-oauth',
277+
// host is used when creating OAuth manager, so make it look like a real Azure instance
278+
host: 'example.databricks.azure.us',
279+
});
280+
281+
expect(provider).to.be.instanceOf(DatabricksOAuth);
282+
expect(provider.manager).to.be.instanceOf(AzureOAuthManager);
283+
});
284+
285+
it('should throw error when OAuth not supported for host', () => {
286+
const client = new DBSQLClient();
287+
288+
expect(() => {
289+
client.getAuthProvider({
290+
authType: 'databricks-oauth',
291+
// use host which is not supported for sure
292+
host: 'example.com',
293+
});
294+
}).to.throw();
266295
});
267296

268297
it('should use custom auth method', () => {

0 commit comments

Comments
 (0)