Skip to content

Commit e50e86d

Browse files
authored
[PECOBLR-587] Azure Service Principal Credential Provider (#621)
* basic setup * Nit * working * moved pyjwt to code dependency * nit * nit * nit * nit * nit * nit * testing sdk * Refractor * logging * nit * nit * nit * nit * nit
1 parent 90f0ac1 commit e50e86d

File tree

9 files changed

+643
-80
lines changed

9 files changed

+643
-80
lines changed

poetry.lock

Lines changed: 100 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,18 @@ requests = "^2.18.1"
2020
oauthlib = "^3.1.0"
2121
openpyxl = "^3.0.10"
2222
urllib3 = ">=1.26"
23+
python-dateutil = "^2.8.0"
2324
pyarrow = [
2425
{ version = ">=14.0.1", python = ">=3.8,<3.13", optional=true },
2526
{ version = ">=18.0.0", python = ">=3.13", optional=true }
2627
]
27-
python-dateutil = "^2.8.0"
28+
pyjwt = "^2.0.0"
29+
2830

2931
[tool.poetry.extras]
3032
pyarrow = ["pyarrow"]
3133

32-
[tool.poetry.dev-dependencies]
34+
[tool.poetry.group.dev.dependencies]
3335
pytest = "^7.1.2"
3436
mypy = "^1.10.1"
3537
pylint = ">=2.12.0"

src/databricks/sql/auth/auth.py

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,29 @@
1-
from enum import Enum
21
from typing import Optional, List
32

43
from databricks.sql.auth.authenticators import (
54
AuthProvider,
65
AccessTokenAuthProvider,
76
ExternalAuthProvider,
87
DatabricksOAuthProvider,
8+
AzureServicePrincipalCredentialProvider,
99
)
10-
11-
12-
class AuthType(Enum):
13-
DATABRICKS_OAUTH = "databricks-oauth"
14-
AZURE_OAUTH = "azure-oauth"
15-
# other supported types (access_token) can be inferred
16-
# we can add more types as needed later
17-
18-
19-
class ClientContext:
20-
def __init__(
21-
self,
22-
hostname: str,
23-
access_token: Optional[str] = None,
24-
auth_type: Optional[str] = None,
25-
oauth_scopes: Optional[List[str]] = None,
26-
oauth_client_id: Optional[str] = None,
27-
oauth_redirect_port_range: Optional[List[int]] = None,
28-
use_cert_as_auth: Optional[str] = None,
29-
tls_client_cert_file: Optional[str] = None,
30-
oauth_persistence=None,
31-
credentials_provider=None,
32-
):
33-
self.hostname = hostname
34-
self.access_token = access_token
35-
self.auth_type = auth_type
36-
self.oauth_scopes = oauth_scopes
37-
self.oauth_client_id = oauth_client_id
38-
self.oauth_redirect_port_range = oauth_redirect_port_range
39-
self.use_cert_as_auth = use_cert_as_auth
40-
self.tls_client_cert_file = tls_client_cert_file
41-
self.oauth_persistence = oauth_persistence
42-
self.credentials_provider = credentials_provider
10+
from databricks.sql.auth.common import AuthType, ClientContext
4311

4412

4513
def get_auth_provider(cfg: ClientContext):
4614
if cfg.credentials_provider:
4715
return ExternalAuthProvider(cfg.credentials_provider)
48-
if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
16+
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
17+
return ExternalAuthProvider(
18+
AzureServicePrincipalCredentialProvider(
19+
cfg.hostname,
20+
cfg.azure_client_id,
21+
cfg.azure_client_secret,
22+
cfg.azure_tenant_id,
23+
cfg.azure_workspace_resource_id,
24+
)
25+
)
26+
elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
4927
assert cfg.oauth_redirect_port_range is not None
5028
assert cfg.oauth_client_id is not None
5129
assert cfg.oauth_scopes is not None
@@ -102,10 +80,13 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
10280

10381

10482
def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
83+
# TODO : unify all the auth mechanisms with the Python SDK
84+
10585
auth_type = kwargs.get("auth_type")
10686
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
10787
auth_type == AuthType.AZURE_OAUTH.value
10888
)
89+
10990
if kwargs.get("username") or kwargs.get("password"):
11091
raise ValueError(
11192
"Username/password authentication is no longer supported. "
@@ -120,6 +101,10 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
120101
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
121102
oauth_scopes=PYSQL_OAUTH_SCOPES,
122103
oauth_client_id=kwargs.get("oauth_client_id") or client_id,
104+
azure_client_id=kwargs.get("azure_client_id"),
105+
azure_client_secret=kwargs.get("azure_client_secret"),
106+
azure_tenant_id=kwargs.get("azure_tenant_id"),
107+
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
123108
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
124109
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
125110
else redirect_port_range,

src/databricks/sql/auth/authenticators.py

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
import abc
2-
import base64
32
import logging
43
from typing import Callable, Dict, List
5-
6-
from databricks.sql.auth.oauth import OAuthManager
7-
from databricks.sql.auth.endpoint import get_oauth_endpoints, infer_cloud_from_host
4+
from databricks.sql.common.http import HttpHeader
5+
from databricks.sql.auth.oauth import (
6+
OAuthManager,
7+
RefreshableTokenSource,
8+
ClientCredentialsTokenSource,
9+
)
10+
from databricks.sql.auth.endpoint import get_oauth_endpoints
11+
from databricks.sql.auth.common import (
12+
AuthType,
13+
get_effective_azure_login_app_id,
14+
get_azure_tenant_id_from_host,
15+
)
816

917
# Private API: this is an evolving interface and it will change in the future.
1018
# Please do not depend on it in your applications.
@@ -146,3 +154,82 @@ def add_headers(self, request_headers: Dict[str, str]):
146154
headers = self._header_factory()
147155
for k, v in headers.items():
148156
request_headers[k] = v
157+
158+
159+
class AzureServicePrincipalCredentialProvider(CredentialsProvider):
    """
    Credentials provider that authenticates as an Azure service principal.

    Calling an instance yields a header factory. Each invocation of that factory
    asks two client-credentials token sources for a token -- one scoped to the
    Databricks login application for the workspace host, one scoped to the Azure
    management resource -- and returns the resulting HTTP headers.

    Attributes:
        hostname (str): The Databricks workspace hostname.
        azure_client_id (str): Client ID of the Azure service principal.
        azure_client_secret (str): Client secret of the Azure service principal.
        azure_tenant_id (str): Azure AD tenant ID. When not supplied it is looked
            up from the workspace host at construction time.
        azure_workspace_resource_id (str, optional): Azure resource ID of the
            workspace; sent as an extra header when set.
    """

    # Token URL is "<AAD endpoint>/<tenant id>/<token endpoint>".
    AZURE_AAD_ENDPOINT = "https://login.microsoftonline.com"
    AZURE_TOKEN_ENDPOINT = "oauth2/token"

    # Resource requested for the management-plane token.
    AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/"

    # Header names emitted by the header factory.
    DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token"
    DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = (
        "X-Databricks-Azure-Workspace-Resource-Id"
    )

    def __init__(
        self,
        hostname,
        azure_client_id,
        azure_client_secret,
        azure_tenant_id=None,
        azure_workspace_resource_id=None,
    ):
        self.hostname = hostname
        self.azure_client_id = azure_client_id
        self.azure_client_secret = azure_client_secret
        self.azure_workspace_resource_id = azure_workspace_resource_id

        # A falsy tenant ID (None or "") triggers discovery via the workspace host.
        if not azure_tenant_id:
            azure_tenant_id = get_azure_tenant_id_from_host(hostname)
        self.azure_tenant_id = azure_tenant_id

    def auth_type(self) -> str:
        """Return the auth-type identifier for this provider."""
        return AuthType.AZURE_SP_M2M.value

    def get_token_source(self, resource: str) -> RefreshableTokenSource:
        """Build a client-credentials token source for the given ``resource``."""
        token_url = (
            f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}"
        )
        return ClientCredentialsTokenSource(
            token_url=token_url,
            client_id=self.azure_client_id,
            client_secret=self.azure_client_secret,
            extra_params={"resource": resource},
        )

    def __call__(self, *args, **kwargs) -> HeaderFactory:
        """Create the token sources and return a factory producing auth headers."""
        workspace_source = self.get_token_source(
            resource=get_effective_azure_login_app_id(self.hostname)
        )
        management_source = self.get_token_source(
            resource=self.AZURE_MANAGED_RESOURCE
        )

        def header_factory() -> Dict[str, str]:
            # Workspace token is fetched first, then the management token.
            workspace_token = workspace_source.get_token()
            management_token = management_source.get_token()

            headers: Dict[str, str] = {}
            headers[
                HttpHeader.AUTHORIZATION.value
            ] = f"{workspace_token.token_type} {workspace_token.access_token}"
            headers[
                self.DATABRICKS_AZURE_SP_TOKEN_HEADER
            ] = management_token.access_token

            # The workspace resource ID header is optional.
            if self.azure_workspace_resource_id:
                headers[
                    self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
                ] = self.azure_workspace_resource_id

            return headers

        return header_factory

src/databricks/sql/auth/common.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from enum import Enum
2+
import logging
3+
from typing import Optional, List
4+
from urllib.parse import urlparse
5+
from databricks.sql.common.http import DatabricksHttpClient, HttpMethod
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
class AuthType(Enum):
    """String identifiers that the ``auth_type`` setting is compared against."""

    # Both OAuth values are accepted by the same branch of get_auth_provider.
    DATABRICKS_OAUTH = "databricks-oauth"
    AZURE_OAUTH = "azure-oauth"
    # Selects the Azure service principal credential provider in get_auth_provider.
    AZURE_SP_M2M = "azure-sp-m2m"
14+
15+
16+
class AzureAppId(Enum):
    """
    Pairs of (hostname domain fragment, Azure login application ID).

    get_effective_azure_login_app_id walks the members in definition order and
    returns the application ID of the first member whose fragment occurs in the
    hostname.
    """

    # DEV and STAGING must stay ahead of PROD: the PROD fragment is contained in
    # both of theirs, so it would otherwise match first.
    DEV = (".dev.azuredatabricks.net", "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc")
    STAGING = (".staging.azuredatabricks.net", "4a67d088-db5c-48f1-9ff2-0aace800ae68")
    PROD = (".azuredatabricks.net", "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d")
20+
21+
22+
class ClientContext:
    """
    Plain container for the settings used to choose and configure an auth provider.

    Every constructor argument is stored unchanged on an attribute of the same
    name; no validation or normalisation is performed here.
    """

    def __init__(
        self,
        hostname: str,
        access_token: Optional[str] = None,
        auth_type: Optional[str] = None,
        oauth_scopes: Optional[List[str]] = None,
        oauth_client_id: Optional[str] = None,
        azure_client_id: Optional[str] = None,
        azure_client_secret: Optional[str] = None,
        azure_tenant_id: Optional[str] = None,
        azure_workspace_resource_id: Optional[str] = None,
        oauth_redirect_port_range: Optional[List[int]] = None,
        use_cert_as_auth: Optional[str] = None,
        tls_client_cert_file: Optional[str] = None,
        oauth_persistence=None,
        credentials_provider=None,
    ):
        # Target host and the basic auth selectors.
        self.hostname: str = hostname
        self.access_token: Optional[str] = access_token
        self.auth_type: Optional[str] = auth_type

        # OAuth settings.
        self.oauth_scopes: Optional[List[str]] = oauth_scopes
        self.oauth_client_id: Optional[str] = oauth_client_id

        # Azure service principal settings.
        self.azure_client_id: Optional[str] = azure_client_id
        self.azure_client_secret: Optional[str] = azure_client_secret
        self.azure_tenant_id: Optional[str] = azure_tenant_id
        self.azure_workspace_resource_id: Optional[str] = azure_workspace_resource_id

        # Remaining OAuth / TLS options and caller-supplied hooks.
        self.oauth_redirect_port_range: Optional[List[int]] = oauth_redirect_port_range
        self.use_cert_as_auth: Optional[str] = use_cert_as_auth
        self.tls_client_cert_file: Optional[str] = tls_client_cert_file
        self.oauth_persistence = oauth_persistence
        self.credentials_provider = credentials_provider
54+
55+
56+
def get_effective_azure_login_app_id(hostname) -> str:
    """
    Return the Azure login application ID to use for ``hostname``.

    The hostname is compared, case-insensitively, against the domain fragment of
    every ``AzureAppId`` member in definition order; the application ID of the
    first member whose fragment occurs in the hostname is returned.

    Args:
        hostname: Workspace hostname. May include a scheme, port or path, since
            the fragment is searched for anywhere in the string. ``None`` and
            ``""`` are treated as "no match".

    Returns:
        The matching application ID, or the ``AzureAppId.PROD`` application ID
        when no fragment matches.
    """
    # Host names are case-insensitive, while the fragments in AzureAppId are
    # lower-case; normalise before matching so "X.DEV.AZUREDATABRICKS.NET" is
    # still recognised as a DEV host.
    normalized_host = (hostname or "").lower()

    for azure_app_id in AzureAppId:
        domain, app_id = azure_app_id.value
        if domain in normalized_host:
            return app_id

    # default databricks resource id
    return AzureAppId.PROD.value[1]
70+
71+
72+
def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
    """
    Look up the Azure tenant ID of a workspace from its AAD login redirect.

    Sends ``GET <host>/aad/auth`` without following redirects and reads the
    tenant ID from the ``Location`` header of the response, which is expected to
    have the form
    ``https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...``
    (the domain may differ per Azure cloud, e.g. ``login.microsoftonline.us``).

    Args:
        host: Workspace host. Trailing slashes are ignored and ``https://`` is
            assumed when no scheme is present.
        http_client: Client used to issue the request. Defaults to
            ``DatabricksHttpClient.get_instance()``.

    Returns:
        The tenant ID, i.e. the first path segment of the redirect target.

    Raises:
        ValueError: If the response is not a 3xx redirect, has no ``Location``
            header, or the redirect path does not contain a tenant ID.
    """

    if http_client is None:
        http_client = DatabricksHttpClient.get_instance()

    # Build a well-formed URL regardless of how the host was spelled: avoid a
    # double slash before "aad/auth" and make sure a scheme is present.
    base_url = host.rstrip("/")
    if "://" not in base_url:
        base_url = f"https://{base_url}"
    login_url = f"{base_url}/aad/auth"

    logger.debug("Loading tenant ID from %s", login_url)
    with http_client.execute(HttpMethod.GET, login_url, allow_redirects=False) as resp:
        if resp.status_code // 100 != 3:
            raise ValueError(
                f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}"
            )
        entra_id_endpoint = resp.headers.get("Location")

    if entra_id_endpoint is None:
        raise ValueError(f"No Location header in response from {login_url}")

    # The path starts with "/", so index 0 of the split is "" and index 1 is the
    # tenant ID. An empty segment there (path "/" or "//...") is rejected rather
    # than returned as an empty tenant ID.
    url = urlparse(entra_id_endpoint)
    path_segments = url.path.split("/")
    if len(path_segments) < 2 or not path_segments[1]:
        raise ValueError(f"Invalid path in Location header: {url.path}")
    return path_segments[1]

0 commit comments

Comments
 (0)