-
Notifications
You must be signed in to change notification settings - Fork 113
[PECOBLR-587] Azure Service Principal Credential Provider #621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5e2b792
05ab3e8
bef7ac6
0877b6f
83d6a9e
b21d015
6d85d19
f518085
814d1cb
ce4543c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,13 @@ | ||
from enum import Enum | ||
from typing import Optional, List | ||
|
||
from databricks.sql.auth.authenticators import ( | ||
AuthProvider, | ||
AccessTokenAuthProvider, | ||
ExternalAuthProvider, | ||
DatabricksOAuthProvider, | ||
AzureServicePrincipalCredentialProvider, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this coming from sdk? |
||
) | ||
|
||
|
||
class AuthType(Enum): | ||
DATABRICKS_OAUTH = "databricks-oauth" | ||
AZURE_OAUTH = "azure-oauth" | ||
# other supported types (access_token) can be inferred | ||
# we can add more types as needed later | ||
from databricks.sql.auth.common import AuthType | ||
|
||
|
||
class ClientContext: | ||
|
@@ -24,6 +18,9 @@ def __init__( | |
auth_type: Optional[str] = None, | ||
oauth_scopes: Optional[List[str]] = None, | ||
oauth_client_id: Optional[str] = None, | ||
oauth_client_secret: Optional[str] = None, | ||
azure_tenant_id: Optional[str] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a todo to remove the need of passing tenant id ? databricks/databricks-sdk-py#638 Also see : https://databricks.atlassian.net/browse/PECOBLR-212 |
||
azure_workspace_resource_id: Optional[str] = None, | ||
oauth_redirect_port_range: Optional[List[int]] = None, | ||
use_cert_as_auth: Optional[str] = None, | ||
tls_client_cert_file: Optional[str] = None, | ||
|
@@ -35,6 +32,9 @@ def __init__( | |
self.auth_type = auth_type | ||
self.oauth_scopes = oauth_scopes | ||
self.oauth_client_id = oauth_client_id | ||
self.oauth_client_secret = oauth_client_secret | ||
self.azure_tenant_id = azure_tenant_id | ||
self.azure_workspace_resource_id = azure_workspace_resource_id | ||
self.oauth_redirect_port_range = oauth_redirect_port_range | ||
self.use_cert_as_auth = use_cert_as_auth | ||
self.tls_client_cert_file = tls_client_cert_file | ||
|
@@ -45,7 +45,17 @@ def __init__( | |
def get_auth_provider(cfg: ClientContext): | ||
if cfg.credentials_provider: | ||
return ExternalAuthProvider(cfg.credentials_provider) | ||
if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]: | ||
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value: | ||
return ExternalAuthProvider( | ||
AzureServicePrincipalCredentialProvider( | ||
cfg.hostname, | ||
cfg.oauth_client_id, | ||
cfg.oauth_client_secret, | ||
cfg.azure_tenant_id, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We were talking about this in JDBC if tenant ID can be derived dynamically and not be a required config |
||
cfg.azure_workspace_resource_id, | ||
) | ||
) | ||
elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]: | ||
assert cfg.oauth_redirect_port_range is not None | ||
assert cfg.oauth_client_id is not None | ||
assert cfg.oauth_scopes is not None | ||
|
@@ -103,9 +113,15 @@ def get_client_id_and_redirect_port(use_azure_auth: bool): | |
|
||
def get_python_sql_connector_auth_provider(hostname: str, **kwargs): | ||
auth_type = kwargs.get("auth_type") | ||
(client_id, redirect_port_range) = get_client_id_and_redirect_port( | ||
auth_type == AuthType.AZURE_OAUTH.value | ||
) | ||
client_id = kwargs.get("oauth_client_id") | ||
redirect_port_range = kwargs.get("oauth_redirect_port_range") | ||
|
||
if auth_type == AuthType.AZURE_SP_M2M.value: | ||
pass | ||
jprakash-db marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not just do if |
||
(client_id, redirect_port_range) = get_client_id_and_redirect_port( | ||
auth_type == AuthType.AZURE_OAUTH.value | ||
) | ||
if kwargs.get("username") or kwargs.get("password"): | ||
raise ValueError( | ||
"Username/password authentication is no longer supported. " | ||
|
@@ -119,9 +135,12 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs): | |
use_cert_as_auth=kwargs.get("_use_cert_as_auth"), | ||
tls_client_cert_file=kwargs.get("_tls_client_cert_file"), | ||
oauth_scopes=PYSQL_OAUTH_SCOPES, | ||
oauth_client_id=kwargs.get("oauth_client_id") or client_id, | ||
oauth_client_id=client_id, | ||
oauth_client_secret=kwargs.get("oauth_client_secret"), | ||
azure_tenant_id=kwargs.get("azure_tenant_id"), | ||
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"), | ||
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]] | ||
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port") | ||
if client_id and kwargs.get("oauth_redirect_port") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is behaviour change from before i.e. earlier we required client_id to come in from kwargs and now it is ok even if we derive it from get_client_id_and_redirect_port, is this intended? |
||
else redirect_port_range, | ||
oauth_persistence=kwargs.get("experimental_oauth_persistence"), | ||
credentials_provider=kwargs.get("credentials_provider"), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,21 @@ | ||
import abc | ||
import base64 | ||
import logging | ||
from typing import Callable, Dict, List | ||
|
||
from databricks.sql.auth.oauth import OAuthManager | ||
from databricks.sql.auth.endpoint import get_oauth_endpoints, infer_cloud_from_host | ||
from typing import Callable, Dict, List, Optional | ||
from databricks.sql.common.http import HttpHeader | ||
from databricks.sql.auth.oauth import ( | ||
OAuthManager, | ||
RefreshableTokenSource, | ||
ClientCredentialsTokenSource, | ||
) | ||
from databricks.sql.auth.endpoint import get_oauth_endpoints | ||
from databricks.sql.auth.common import AuthType, get_effective_azure_login_app_id | ||
|
||
# Private API: this is an evolving interface and it will change in the future. | ||
# Please must not depend on it in your applications. | ||
from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence | ||
|
||
logger = logging.getLogger(__name__) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't think this is being used, can we add logging? |
||
|
||
|
||
class AuthProvider: | ||
def add_headers(self, request_headers: Dict[str, str]): | ||
|
@@ -146,3 +152,80 @@ def add_headers(self, request_headers: Dict[str, str]): | |
headers = self._header_factory() | ||
for k, v in headers.items(): | ||
request_headers[k] = v | ||
|
||
|
||
class AzureServicePrincipalCredentialProvider(CredentialsProvider): | ||
""" | ||
A credential provider for Azure Service Principal authentication with Databricks. | ||
This class implements the CredentialsProvider protocol to authenticate requests | ||
to Databricks REST APIs using Azure Active Directory (AAD) service principal | ||
credentials. It handles OAuth 2.0 client credentials flow to obtain access tokens | ||
from Azure AD and automatically refreshes them when they expire. | ||
Attributes: | ||
hostname (str): The Databricks workspace hostname. | ||
oauth_client_id (str): The Azure service principal's client ID. | ||
oauth_client_secret (str): The Azure service principal's client secret. | ||
azure_tenant_id (str): The Azure AD tenant ID. | ||
azure_workspace_resource_id (str, optional): The Azure workspace resource ID. | ||
""" | ||
|
||
AZURE_AAD_ENDPOINT = "https://login.microsoftonline.com" | ||
AZURE_TOKEN_ENDPOINT = "oauth2/token" | ||
|
||
AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/" | ||
|
||
DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token" | ||
DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = ( | ||
"X-Databricks-Azure-Workspace-Resource-Id" | ||
) | ||
|
||
def __init__( | ||
self, | ||
hostname, | ||
oauth_client_id, | ||
oauth_client_secret, | ||
azure_tenant_id, | ||
azure_workspace_resource_id=None, | ||
): | ||
self.hostname = hostname | ||
self.oauth_client_id = oauth_client_id | ||
self.oauth_client_secret = oauth_client_secret | ||
self.azure_tenant_id = azure_tenant_id | ||
self.azure_workspace_resource_id = azure_workspace_resource_id | ||
|
||
def auth_type(self) -> str: | ||
return AuthType.AZURE_SP_M2M.value | ||
|
||
def get_token_source(self, resource: str) -> RefreshableTokenSource: | ||
return ClientCredentialsTokenSource( | ||
token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}", | ||
oauth_client_id=self.oauth_client_id, | ||
oauth_client_secret=self.oauth_client_secret, | ||
extra_params={"resource": resource}, | ||
) | ||
|
||
def __call__(self, *args, **kwargs) -> HeaderFactory: | ||
inner = self.get_token_source( | ||
resource=get_effective_azure_login_app_id(self.hostname) | ||
) | ||
cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE) | ||
|
||
def header_factory() -> Dict[str, str]: | ||
inner_token = inner.get_token() | ||
cloud_token = cloud.get_token() | ||
|
||
headers = { | ||
HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}", | ||
self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token, | ||
} | ||
|
||
if self.azure_workspace_resource_id: | ||
headers[ | ||
self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER | ||
] = self.azure_workspace_resource_id | ||
|
||
return headers | ||
|
||
return header_factory |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from enum import Enum | ||
from typing import Optional | ||
|
||
|
||
class AuthType(Enum): | ||
DATABRICKS_OAUTH = "databricks-oauth" | ||
AZURE_OAUTH = "azure-oauth" | ||
AZURE_SP_M2M = "azure-sp-m2m" | ||
|
||
|
||
def get_effective_azure_login_app_id(hostname) -> str: | ||
""" | ||
Get the effective Azure login app ID for a given hostname. | ||
This function determines the appropriate Azure login app ID based on the hostname. | ||
If the hostname does not match any of these domains, it returns the default Databricks resource ID. | ||
|
||
""" | ||
azure_app_ids = { | ||
".dev.azuredatabricks.net": "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc", | ||
".staging.azuredatabricks.net": "4a67d088-db5c-48f1-9ff2-0aace800ae68", | ||
} | ||
|
||
for domain, app_id in azure_app_ids.items(): | ||
if domain in hostname: | ||
return app_id | ||
|
||
# default databricks resource id | ||
return "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we add these IDs as constants at the top so they're in one place |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,19 +6,63 @@ | |
import webbrowser | ||
from datetime import datetime, timezone | ||
from http.server import HTTPServer | ||
from typing import List | ||
from typing import List, Optional | ||
|
||
import oauthlib.oauth2 | ||
import requests | ||
from oauthlib.oauth2.rfc6749.errors import OAuth2Error | ||
from requests.exceptions import RequestException | ||
|
||
from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader | ||
from databricks.sql.common.http import OAuthResponse | ||
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler | ||
from databricks.sql.auth.endpoint import OAuthEndpointCollection | ||
from abc import abstractmethod, ABC | ||
from urllib.parse import urlencode | ||
import jwt | ||
import time | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Token: | ||
""" | ||
A class to represent a token. | ||
Attributes: | ||
access_token (str): The access token string. | ||
token_type (str): The type of token (e.g., "Bearer"). | ||
refresh_token (str): The refresh token string. | ||
""" | ||
|
||
def __init__(self, access_token: str, token_type: str, refresh_token: str): | ||
self.access_token = access_token | ||
self.token_type = token_type | ||
self.refresh_token = refresh_token | ||
|
||
def is_expired(self): | ||
try: | ||
decoded_token = jwt.decode( | ||
self.access_token, options={"verify_signature": False} | ||
) | ||
exp_time = decoded_token.get("exp") | ||
current_time = time.time() | ||
buffer_time = 30 # 30 seconds buffer | ||
return exp_time and (exp_time - buffer_time) <= current_time | ||
except Exception as e: | ||
logger.error("Failed to decode token: %s", e) | ||
return e | ||
jprakash-db marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class RefreshableTokenSource(ABC): | ||
@abstractmethod | ||
def get_token(self) -> Token: | ||
pass | ||
|
||
@abstractmethod | ||
def refresh(self) -> Token: | ||
pass | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we add a comment here that we have duplicate code here with the sdk (with the code pointer) and in the long term we should try to unify? |
||
|
||
|
||
class IgnoreNetrcAuth(requests.auth.AuthBase): | ||
"""This auth method is a no-op. | ||
|
@@ -258,3 +302,63 @@ def get_tokens(self, hostname: str, scope=None): | |
client, token_request_url, redirect_url, code, verifier | ||
) | ||
return self.__get_tokens_from_response(oauth_response) | ||
|
||
|
||
class ClientCredentialsTokenSource(RefreshableTokenSource): | ||
""" | ||
A token source that uses client credentials to get a token from the token endpoint. | ||
It will refresh the token if it is expired. | ||
Attributes: | ||
token_url (str): The URL of the token endpoint. | ||
oauth_client_id (str): The client ID. | ||
oauth_client_secret (str): The client secret. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
token_url, | ||
oauth_client_id, | ||
oauth_client_secret, | ||
extra_params: dict = {}, | ||
): | ||
self.oauth_client_id = oauth_client_id | ||
self.oauth_client_secret = oauth_client_secret | ||
self.token_url = token_url | ||
self.extra_params = extra_params | ||
self.token: Optional[Token] = None | ||
self._http_client = DatabricksHttpClient.get_instance() | ||
|
||
def get_token(self) -> Token: | ||
if self.token is None or self.token.is_expired(): | ||
self.token = self.refresh() | ||
return self.token | ||
|
||
def refresh(self) -> Token: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how is the refresh mechanism being handled for the existing credential providers? is there opportunity to dedup/reuse? |
||
headers = { | ||
HttpHeader.CONTENT_TYPE.value: "application/x-www-form-urlencoded", | ||
} | ||
data = urlencode( | ||
{ | ||
"grant_type": "client_credentials", | ||
"client_id": self.oauth_client_id, | ||
"client_secret": self.oauth_client_secret, | ||
**self.extra_params, | ||
} | ||
) | ||
|
||
response = self._http_client.execute( | ||
method=HttpMethod.POST, url=self.token_url, headers=headers, data=data | ||
) | ||
|
||
if response.status_code == 200: | ||
oauth_response = OAuthResponse(**response.json()) | ||
return Token( | ||
oauth_response.access_token, | ||
oauth_response.token_type, | ||
oauth_response.refresh_token, | ||
) | ||
else: | ||
raise Exception( | ||
f"Failed to get token: {response.status_code} {response.text}" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does the poetry update have to go in the same PR?