Skip to content

[PECOBLR-587] Azure Service Principal Credential Provider #621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 100 additions & 11 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ requests = "^2.18.1"
oauthlib = "^3.1.0"
openpyxl = "^3.0.10"
urllib3 = ">=1.26"
python-dateutil = "^2.8.0"
pyarrow = [
{ version = ">=14.0.1", python = ">=3.8,<3.13", optional=true },
{ version = ">=18.0.0", python = ">=3.13", optional=true }
]
python-dateutil = "^2.8.0"
pyjwt = "^2.0.0"


[tool.poetry.extras]
pyarrow = ["pyarrow"]

[tool.poetry.dev-dependencies]
[tool.poetry.group.dev.dependencies]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does the poetry update have to go in the same PR?

pytest = "^7.1.2"
mypy = "^1.10.1"
pylint = ">=2.12.0"
Expand Down
47 changes: 33 additions & 14 deletions src/databricks/sql/auth/auth.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
from enum import Enum
from typing import Optional, List

from databricks.sql.auth.authenticators import (
AuthProvider,
AccessTokenAuthProvider,
ExternalAuthProvider,
DatabricksOAuthProvider,
AzureServicePrincipalCredentialProvider,
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this coming from sdk?



class AuthType(Enum):
    """Named values that the ``auth_type`` setting is compared against."""

    # Other supported types (access_token) can be inferred, so they have no
    # member here; more types can be added as needed later.
    DATABRICKS_OAUTH = "databricks-oauth"
    AZURE_OAUTH = "azure-oauth"
from databricks.sql.auth.common import AuthType


class ClientContext:
Expand All @@ -24,6 +18,9 @@ def __init__(
auth_type: Optional[str] = None,
oauth_scopes: Optional[List[str]] = None,
oauth_client_id: Optional[str] = None,
oauth_client_secret: Optional[str] = None,
azure_tenant_id: Optional[str] = None,
azure_workspace_resource_id: Optional[str] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a TODO to remove the need to pass the tenant ID? See databricks/databricks-sdk-py#638

Also see : https://databricks.atlassian.net/browse/PECOBLR-212

oauth_redirect_port_range: Optional[List[int]] = None,
use_cert_as_auth: Optional[str] = None,
tls_client_cert_file: Optional[str] = None,
Expand All @@ -35,6 +32,9 @@ def __init__(
self.auth_type = auth_type
self.oauth_scopes = oauth_scopes
self.oauth_client_id = oauth_client_id
self.oauth_client_secret = oauth_client_secret
self.azure_tenant_id = azure_tenant_id
self.azure_workspace_resource_id = azure_workspace_resource_id
self.oauth_redirect_port_range = oauth_redirect_port_range
self.use_cert_as_auth = use_cert_as_auth
self.tls_client_cert_file = tls_client_cert_file
Expand All @@ -45,7 +45,17 @@ def __init__(
def get_auth_provider(cfg: ClientContext):
if cfg.credentials_provider:
return ExternalAuthProvider(cfg.credentials_provider)
if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
return ExternalAuthProvider(
AzureServicePrincipalCredentialProvider(
cfg.hostname,
cfg.oauth_client_id,
cfg.oauth_client_secret,
cfg.azure_tenant_id,
cfg.azure_workspace_resource_id,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were discussing this for JDBC: whether the tenant ID can be derived dynamically instead of being a required config.

)
)
elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
assert cfg.oauth_redirect_port_range is not None
assert cfg.oauth_client_id is not None
assert cfg.oauth_scopes is not None
Expand Down Expand Up @@ -103,9 +113,15 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):

def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
auth_type = kwargs.get("auth_type")
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
auth_type == AuthType.AZURE_OAUTH.value
)
client_id = kwargs.get("oauth_client_id")
redirect_port_range = kwargs.get("oauth_redirect_port_range")

if auth_type == AuthType.AZURE_SP_M2M.value:
pass
else:
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just do `if auth_type != AuthType.AZURE_SP_M2M.value:`?

auth_type == AuthType.AZURE_OAUTH.value
)
if kwargs.get("username") or kwargs.get("password"):
raise ValueError(
"Username/password authentication is no longer supported. "
Expand All @@ -119,9 +135,12 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
use_cert_as_auth=kwargs.get("_use_cert_as_auth"),
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
oauth_scopes=PYSQL_OAUTH_SCOPES,
oauth_client_id=kwargs.get("oauth_client_id") or client_id,
oauth_client_id=client_id,
oauth_client_secret=kwargs.get("oauth_client_secret"),
azure_tenant_id=kwargs.get("azure_tenant_id"),
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
if client_id and kwargs.get("oauth_redirect_port")
else redirect_port_range,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a behaviour change from before: earlier we required client_id to come in from kwargs, whereas now it is accepted even when we derive it from get_client_id_and_redirect_port. Is this intended?

oauth_persistence=kwargs.get("experimental_oauth_persistence"),
credentials_provider=kwargs.get("credentials_provider"),
Expand Down
93 changes: 88 additions & 5 deletions src/databricks/sql/auth/authenticators.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import abc
import base64
import logging
from typing import Callable, Dict, List

from databricks.sql.auth.oauth import OAuthManager
from databricks.sql.auth.endpoint import get_oauth_endpoints, infer_cloud_from_host
from typing import Callable, Dict, List, Optional
from databricks.sql.common.http import HttpHeader
from databricks.sql.auth.oauth import (
OAuthManager,
RefreshableTokenSource,
ClientCredentialsTokenSource,
)
from databricks.sql.auth.endpoint import get_oauth_endpoints
from databricks.sql.auth.common import AuthType, get_effective_azure_login_app_id

# Private API: this is an evolving interface and it will change in the future.
# Please do not depend on it in your applications.
from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence

logger = logging.getLogger(__name__)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't think this is being used, can we add logging?



class AuthProvider:
def add_headers(self, request_headers: Dict[str, str]):
Expand Down Expand Up @@ -146,3 +152,80 @@ def add_headers(self, request_headers: Dict[str, str]):
headers = self._header_factory()
for k, v in headers.items():
request_headers[k] = v


class AzureServicePrincipalCredentialProvider(CredentialsProvider):
    """
    Credentials provider that authenticates to Databricks as an Azure service principal.

    Tokens are requested through ``ClientCredentialsTokenSource`` objects pointed at
    the token endpoint of the configured Azure AD tenant. Calling the provider
    yields a header factory that, on every invocation, asks two token sources for a
    token: one whose resource is the Azure login app ID resolved from the
    workspace hostname, and one whose resource is ``AZURE_MANAGED_RESOURCE``.

    Attributes:
        hostname (str): The Databricks workspace hostname.
        oauth_client_id (str): The Azure service principal's client ID.
        oauth_client_secret (str): The Azure service principal's client secret.
        azure_tenant_id (str): The Azure AD tenant ID.
        azure_workspace_resource_id (str, optional): The Azure workspace resource ID.
    """

    # Pieces of the token URL: <AAD endpoint>/<tenant id>/<token endpoint>
    AZURE_AAD_ENDPOINT = "https://login.microsoftonline.com"
    AZURE_TOKEN_ENDPOINT = "oauth2/token"

    # Resource requested for the second (management) token
    AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/"

    # Names of the extra request headers emitted next to Authorization
    DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token"
    DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = (
        "X-Databricks-Azure-Workspace-Resource-Id"
    )

    def __init__(
        self,
        hostname,
        oauth_client_id,
        oauth_client_secret,
        azure_tenant_id,
        azure_workspace_resource_id=None,
    ):
        """Store the connection and service-principal settings; no request is made here."""
        self.hostname = hostname
        self.oauth_client_id = oauth_client_id
        self.oauth_client_secret = oauth_client_secret
        self.azure_tenant_id = azure_tenant_id
        self.azure_workspace_resource_id = azure_workspace_resource_id

    def auth_type(self) -> str:
        """Return the ``AuthType.AZURE_SP_M2M`` string value."""
        return AuthType.AZURE_SP_M2M.value

    def get_token_source(self, resource: str) -> RefreshableTokenSource:
        """
        Build a client-credentials token source for ``resource``.

        Args:
            resource: Value sent as the ``resource`` extra parameter of the token request.

        Returns:
            A new ``ClientCredentialsTokenSource`` bound to this tenant's token URL.
        """
        tenant_token_url = "{}/{}/{}".format(
            self.AZURE_AAD_ENDPOINT, self.azure_tenant_id, self.AZURE_TOKEN_ENDPOINT
        )
        return ClientCredentialsTokenSource(
            token_url=tenant_token_url,
            oauth_client_id=self.oauth_client_id,
            oauth_client_secret=self.oauth_client_secret,
            extra_params={"resource": resource},
        )

    def __call__(self, *args, **kwargs) -> HeaderFactory:
        """
        Create the header factory used to decorate outgoing requests.

        The two token sources are created once per call of the provider and are
        captured by the returned closure, which queries them each time it runs.
        """
        login_app_id = get_effective_azure_login_app_id(self.hostname)
        workspace_source = self.get_token_source(resource=login_app_id)
        management_source = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)

        def header_factory() -> Dict[str, str]:
            # Same order as always: workspace token first, management token second
            workspace_token = workspace_source.get_token()
            management_token = management_source.get_token()

            headers = {}
            headers[
                HttpHeader.AUTHORIZATION.value
            ] = f"{workspace_token.token_type} {workspace_token.access_token}"
            headers[self.DATABRICKS_AZURE_SP_TOKEN_HEADER] = management_token.access_token

            # The resource-id header is only sent when a resource id was configured
            if self.azure_workspace_resource_id:
                headers[
                    self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
                ] = self.azure_workspace_resource_id

            return headers

        return header_factory
28 changes: 28 additions & 0 deletions src/databricks/sql/auth/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from enum import Enum
from typing import Optional


class AuthType(Enum):
    """String identifiers for the authentication flows that are selected by name."""

    # OAuth against Databricks
    DATABRICKS_OAUTH = "databricks-oauth"
    # OAuth against Azure
    AZURE_OAUTH = "azure-oauth"
    # Azure service principal
    AZURE_SP_M2M = "azure-sp-m2m"


def get_effective_azure_login_app_id(hostname) -> str:
    """
    Resolve the Azure login app ID to use for a given hostname.

    The hostname is checked, in order, for the dev and then the staging Azure
    Databricks domain fragment (a substring test). The first fragment found
    selects its app ID. When neither fragment occurs in the hostname, the
    default Databricks resource ID is returned.
    """
    # (domain fragment, app ID) pairs, tested in this order
    known_environments = (
        (".dev.azuredatabricks.net", "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc"),
        (".staging.azuredatabricks.net", "4a67d088-db5c-48f1-9ff2-0aace800ae68"),
    )
    # default databricks resource id
    default_resource_id = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"

    return next(
        (app_id for fragment, app_id in known_environments if fragment in hostname),
        default_resource_id,
    )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add these IDs as constants at the top so they're in one place

108 changes: 106 additions & 2 deletions src/databricks/sql/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,63 @@
import webbrowser
from datetime import datetime, timezone
from http.server import HTTPServer
from typing import List
from typing import List, Optional

import oauthlib.oauth2
import requests
from oauthlib.oauth2.rfc6749.errors import OAuth2Error
from requests.exceptions import RequestException

from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader
from databricks.sql.common.http import OAuthResponse
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
from databricks.sql.auth.endpoint import OAuthEndpointCollection
from abc import abstractmethod, ABC
from urllib.parse import urlencode
import jwt
import time

logger = logging.getLogger(__name__)


class Token:
    """
    A class to represent a token.

    Attributes:
        access_token (str): The access token string.
        token_type (str): The type of token (e.g., "Bearer").
        refresh_token (str): The refresh token string.
    """

    def __init__(self, access_token: str, token_type: str, refresh_token: str):
        self.access_token = access_token
        self.token_type = token_type
        self.refresh_token = refresh_token

    def is_expired(self) -> bool:
        """
        Report whether the access token has expired or is about to.

        The access token is decoded as a JWT without signature verification,
        purely to read its ``exp`` claim.

        Returns:
            True when ``exp`` minus a 30 second buffer is at or before the
            current time. False when the token carries no (truthy) ``exp``
            claim. True when the token cannot be decoded or evaluated; the
            error is logged and the token is reported as expired.
        """
        try:
            decoded_token = jwt.decode(
                self.access_token, options={"verify_signature": False}
            )
            exp_time = decoded_token.get("exp")
            current_time = time.time()
            buffer_time = 30  # 30 seconds buffer
            # bool(...) so callers never receive None/0 from the `and` short-circuit
            return bool(exp_time and (exp_time - buffer_time) <= current_time)
        except Exception as e:
            logger.error("Failed to decode token: %s", e)
            # Previously the exception object itself was returned, which is
            # truthy; return an explicit True so the result is a real bool
            # while callers still see the token as expired.
            return True


class RefreshableTokenSource(ABC):
    """Abstract interface of an object that supplies a ``Token`` and can refresh it."""

    @abstractmethod
    def get_token(self) -> Token:
        """Return a token; concrete subclasses define how it is obtained."""

    @abstractmethod
    def refresh(self) -> Token:
        """Obtain and return a new token; concrete subclasses define how."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a comment here that we have duplicate code here with the sdk (with the code pointer) and in the long term we should try to unify?



class IgnoreNetrcAuth(requests.auth.AuthBase):
"""This auth method is a no-op.
Expand Down Expand Up @@ -258,3 +302,63 @@ def get_tokens(self, hostname: str, scope=None):
client, token_request_url, redirect_url, code, verifier
)
return self.__get_tokens_from_response(oauth_response)


class ClientCredentialsTokenSource(RefreshableTokenSource):
    """
    A token source that uses client credentials to get a token from the token endpoint.
    It will refresh the token if it is expired.

    Attributes:
        token_url (str): The URL of the token endpoint.
        oauth_client_id (str): The client ID.
        oauth_client_secret (str): The client secret.
        extra_params (dict): Additional form fields sent with every token request.
        token (Optional[Token]): The most recently fetched token, or None before the first fetch.
    """

    def __init__(
        self,
        token_url,
        oauth_client_id,
        oauth_client_secret,
        extra_params: Optional[dict] = None,
    ):
        self.oauth_client_id = oauth_client_id
        self.oauth_client_secret = oauth_client_secret
        self.token_url = token_url
        # A `{}` literal as the default would be one dict object shared by every
        # instance created without extra_params; build a fresh one per instance.
        self.extra_params = {} if extra_params is None else extra_params
        self.token: Optional[Token] = None
        self._http_client = DatabricksHttpClient.get_instance()

    def get_token(self) -> Token:
        """Return the cached token, fetching a new one first if none is cached or it is expired."""
        if self.token is None or self.token.is_expired():
            self.token = self.refresh()
        return self.token

    def refresh(self) -> Token:
        """
        Request a new token from the token endpoint with the client-credentials grant.

        Returns:
            A ``Token`` built from the endpoint's JSON response.

        Raises:
            Exception: If the endpoint answers with a status code other than 200.
        """
        # NOTE(review): a reviewer asked how the refresh mechanism is handled for the
        # existing credential providers and whether there is an opportunity to
        # dedup/reuse.
        headers = {
            HttpHeader.CONTENT_TYPE.value: "application/x-www-form-urlencoded",
        }
        # extra_params is spread last, so it can add fields such as "resource"
        data = urlencode(
            {
                "grant_type": "client_credentials",
                "client_id": self.oauth_client_id,
                "client_secret": self.oauth_client_secret,
                **self.extra_params,
            }
        )

        response = self._http_client.execute(
            method=HttpMethod.POST, url=self.token_url, headers=headers, data=data
        )

        if response.status_code == 200:
            oauth_response = OAuthResponse(**response.json())
            return Token(
                oauth_response.access_token,
                oauth_response.token_type,
                oauth_response.refresh_token,
            )
        else:
            raise Exception(
                f"Failed to get token: {response.status_code} {response.text}"
            )
Loading
Loading