Skip to content

Commit 957791b

Browse files
Use correct scopes for OAuth U2M and M2M flows (#228)
* Refactor OAuthManager: explicitly define which flow to use Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Refactoring: when refreshing OAuth token, use same scopes as when getting the one Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Use correct scopes for U2M and M2M flows Signed-off-by: Levko Kravets <levko.ne@gmail.com> * Tests Signed-off-by: Levko Kravets <levko.ne@gmail.com> --------- Signed-off-by: Levko Kravets <levko.ne@gmail.com>
1 parent 3953e5d commit 957791b

File tree

8 files changed

+164
-54
lines changed

8 files changed

+164
-54
lines changed

lib/DBSQLClient.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import Status from './dto/Status';
1717
import HiveDriverError from './errors/HiveDriverError';
1818
import { buildUserAgentString, definedOrError } from './utils';
1919
import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication';
20-
import DatabricksOAuth from './connection/auth/DatabricksOAuth';
20+
import DatabricksOAuth, { OAuthFlow } from './connection/auth/DatabricksOAuth';
2121
import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger';
2222
import DBSQLLogger from './DBSQLLogger';
2323
import CloseableCollection from './utils/CloseableCollection';
@@ -125,6 +125,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
125125
});
126126
case 'databricks-oauth':
127127
return new DatabricksOAuth({
128+
flow: options.oauthClientSecret === undefined ? OAuthFlow.U2M : OAuthFlow.M2M,
128129
host: options.host,
129130
persistence: options.persistence,
130131
azureTenantId: options.azureTenantId,

lib/connection/auth/DatabricksOAuth/OAuthManager.ts

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@ import HiveDriverError from '../../../errors/HiveDriverError';
44
import { LogLevel } from '../../../contracts/IDBSQLLogger';
55
import OAuthToken from './OAuthToken';
66
import AuthorizationCode from './AuthorizationCode';
7-
import { OAuthScope, OAuthScopes } from './OAuthScope';
7+
import { OAuthScope, OAuthScopes, scopeDelimiter } from './OAuthScope';
88
import IClientContext from '../../../contracts/IClientContext';
99

10+
export enum OAuthFlow {
11+
U2M = 'U2M',
12+
M2M = 'M2M',
13+
}
14+
1015
export interface OAuthManagerOptions {
16+
flow: OAuthFlow;
1117
host: string;
1218
callbackPorts?: Array<number>;
1319
clientId?: string;
@@ -47,9 +53,7 @@ export default abstract class OAuthManager {
4753

4854
protected abstract getCallbackPorts(): Array<number>;
4955

50-
protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
51-
return requestedScopes;
52-
}
56+
protected abstract getScopes(requestedScopes: OAuthScopes): OAuthScopes;
5357

5458
protected async getClient(): Promise<BaseClient> {
5559
// Obtain http agent each time when we need an OAuth client
@@ -113,17 +117,11 @@ export default abstract class OAuthManager {
113117
if (!accessToken || !refreshToken) {
114118
throw new Error('Failed to refresh token: invalid response');
115119
}
116-
return new OAuthToken(accessToken, refreshToken);
120+
return new OAuthToken(accessToken, refreshToken, token.scopes);
117121
}
118122

119-
private async refreshAccessTokenM2M(): Promise<OAuthToken> {
120-
const { access_token: accessToken, refresh_token: refreshToken } = await this.getTokenM2M();
121-
122-
if (!accessToken) {
123-
throw new Error('Failed to fetch access token');
124-
}
125-
126-
return new OAuthToken(accessToken, refreshToken);
123+
private async refreshAccessTokenM2M(token: OAuthToken): Promise<OAuthToken> {
124+
return this.getTokenM2M(token.scopes ?? []);
127125
}
128126

129127
public async refreshAccessToken(token: OAuthToken): Promise<OAuthToken> {
@@ -137,10 +135,16 @@ export default abstract class OAuthManager {
137135
throw error;
138136
}
139137

140-
return this.options.clientSecret === undefined ? this.refreshAccessTokenU2M(token) : this.refreshAccessTokenM2M();
138+
switch (this.options.flow) {
139+
case OAuthFlow.U2M:
140+
return this.refreshAccessTokenU2M(token);
141+
case OAuthFlow.M2M:
142+
return this.refreshAccessTokenM2M(token);
143+
// no default
144+
}
141145
}
142146

143-
private async getTokenU2M(scopes: OAuthScopes) {
147+
private async getTokenU2M(scopes: OAuthScopes): Promise<OAuthToken> {
144148
const client = await this.getClient();
145149

146150
const authCode = new AuthorizationCode({
@@ -153,37 +157,47 @@ export default abstract class OAuthManager {
153157

154158
const { code, verifier, redirectUri } = await authCode.fetch(mappedScopes);
155159

156-
return client.grant({
160+
const { access_token: accessToken, refresh_token: refreshToken } = await client.grant({
157161
grant_type: 'authorization_code',
158162
code,
159163
code_verifier: verifier,
160164
redirect_uri: redirectUri,
161165
});
166+
167+
if (!accessToken) {
168+
throw new Error('Failed to fetch access token');
169+
}
170+
return new OAuthToken(accessToken, refreshToken, mappedScopes);
162171
}
163172

164-
private async getTokenM2M() {
173+
private async getTokenM2M(scopes: OAuthScopes): Promise<OAuthToken> {
165174
const client = await this.getClient();
166175

176+
const mappedScopes = this.getScopes(scopes);
177+
167178
// M2M flow doesn't really support token refreshing, and refresh should not be available
168179
// in response. Each time access token expires, client can just acquire a new one using
169180
// client secret. Here we explicitly return access token only as a sign that we're not going
170181
// to use refresh token for M2M flow anywhere later
171182
const { access_token: accessToken } = await client.grant({
172183
grant_type: 'client_credentials',
173-
scope: 'all-apis', // this is the only allowed scope for M2M flow
184+
scope: mappedScopes.join(scopeDelimiter),
174185
});
175-
return { access_token: accessToken, refresh_token: undefined };
176-
}
177-
178-
public async getToken(scopes: OAuthScopes): Promise<OAuthToken> {
179-
const { access_token: accessToken, refresh_token: refreshToken } =
180-
this.options.clientSecret === undefined ? await this.getTokenU2M(scopes) : await this.getTokenM2M();
181186

182187
if (!accessToken) {
183188
throw new Error('Failed to fetch access token');
184189
}
190+
return new OAuthToken(accessToken, undefined, mappedScopes);
191+
}
185192

186-
return new OAuthToken(accessToken, refreshToken);
193+
public async getToken(scopes: OAuthScopes): Promise<OAuthToken> {
194+
switch (this.options.flow) {
195+
case OAuthFlow.U2M:
196+
return this.getTokenU2M(scopes);
197+
case OAuthFlow.M2M:
198+
return this.getTokenM2M(scopes);
199+
// no default
200+
}
187201
}
188202

189203
public static getManager(options: OAuthManagerOptions): OAuthManager {
@@ -245,6 +259,14 @@ export class DatabricksOAuthManager extends OAuthManager {
245259
protected getCallbackPorts(): Array<number> {
246260
return this.options.callbackPorts ?? DatabricksOAuthManager.defaultCallbackPorts;
247261
}
262+
263+
protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
264+
if (this.options.flow === OAuthFlow.M2M) {
265+
// this is the only allowed scope for M2M flow
266+
return [OAuthScope.allAPIs];
267+
}
268+
return requestedScopes;
269+
}
248270
}
249271

250272
export class AzureOAuthManager extends OAuthManager {
@@ -273,7 +295,18 @@ export class AzureOAuthManager extends OAuthManager {
273295
protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
274296
// There is no corresponding scopes in Azure, instead, access control will be delegated to Databricks
275297
const tenantId = this.options.azureTenantId ?? AzureOAuthManager.datatricksAzureApp;
276-
const azureScopes = [`${tenantId}/user_impersonation`];
298+
299+
const azureScopes = [];
300+
301+
switch (this.options.flow) {
302+
case OAuthFlow.U2M:
303+
azureScopes.push(`${tenantId}/user_impersonation`);
304+
break;
305+
case OAuthFlow.M2M:
306+
azureScopes.push(`${tenantId}/.default`);
307+
break;
308+
// no default
309+
}
277310

278311
if (requestedScopes.includes(OAuthScope.offlineAccess)) {
279312
azureScopes.push(OAuthScope.offlineAccess);

lib/connection/auth/DatabricksOAuth/OAuthScope.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
export enum OAuthScope {
22
offlineAccess = 'offline_access',
33
SQL = 'sql',
4+
allAPIs = 'all-apis',
45
}
56

67
export type OAuthScopes = Array<string>;

lib/connection/auth/DatabricksOAuth/OAuthToken.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1+
import { OAuthScopes } from './OAuthScope';
2+
13
export default class OAuthToken {
24
private readonly _accessToken: string;
35

46
private readonly _refreshToken?: string;
57

8+
private readonly _scopes?: OAuthScopes;
9+
610
private _expirationTime?: number;
711

8-
constructor(accessToken: string, refreshToken?: string) {
12+
constructor(accessToken: string, refreshToken?: string, scopes?: OAuthScopes) {
913
this._accessToken = accessToken;
1014
this._refreshToken = refreshToken;
15+
this._scopes = scopes;
1116
}
1217

1318
get accessToken(): string {
@@ -18,6 +23,10 @@ export default class OAuthToken {
1823
return this._refreshToken;
1924
}
2025

26+
get scopes(): OAuthScopes | undefined {
27+
return this._scopes;
28+
}
29+
2130
get expirationTime(): number {
2231
// This token has already been verified, and we are just parsing it.
2332
// If it has been tampered with, it will be rejected on the server side.

lib/connection/auth/DatabricksOAuth/index.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import { HeadersInit } from 'node-fetch';
22
import IAuthentication from '../../contracts/IAuthentication';
33
import OAuthPersistence, { OAuthPersistenceCache } from './OAuthPersistence';
4-
import OAuthManager, { OAuthManagerOptions } from './OAuthManager';
4+
import OAuthManager, { OAuthManagerOptions, OAuthFlow } from './OAuthManager';
55
import { OAuthScopes, defaultOAuthScopes } from './OAuthScope';
66
import IClientContext from '../../../contracts/IClientContext';
77

8+
export { OAuthFlow };
9+
810
interface DatabricksOAuthOptions extends OAuthManagerOptions {
911
scopes?: OAuthScopes;
1012
persistence?: OAuthPersistence;

tests/unit/DBSQLClient.test.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ describe('DBSQLClient.initAuthProvider', () => {
344344
authType: 'databricks-oauth',
345345
// host is used when creating OAuth manager, so make it look like a real AWS instance
346346
host: 'example.dev.databricks.com',
347+
oauthClientSecret: 'test-secret',
347348
});
348349

349350
expect(provider).to.be.instanceOf(DatabricksOAuth);

0 commit comments

Comments
 (0)