Skip to content

[Identity] Enable brokered auth in DAC #40335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ class InteractiveBrowserBrokerCredential(_InteractiveBrowserCredential):
are required to also provide its window handle, so that the sign in UI window will properly pop up on top
of your window.
:keyword bool use_default_broker_account: Enables automatically using the default broker account for
authentication instead of prompting the user with an account picker. Defaults to False.
authentication instead of prompting the user with an account picker. This is currently only supported on Windows
and WSL. Defaults to False.
:keyword bool enable_msa_passthrough: Determines whether Microsoft Account (MSA) passthrough is enabled. Note, this
is only needed for select legacy first-party applications. Defaults to False.
:keyword bool disable_instance_discovery: Determines whether or not instance discovery is performed when attempting
Expand All @@ -78,6 +79,7 @@ def __init__(self, **kwargs: Any) -> None:
self._parent_window_handle = kwargs.pop("parent_window_handle", None)
self._enable_msa_passthrough = kwargs.pop("enable_msa_passthrough", False)
self._use_default_broker_account = kwargs.pop("use_default_broker_account", False)
self._disable_interactive_fallback = kwargs.pop("disable_interactive_fallback", False)
super().__init__(**kwargs)

@wrap_exceptions
Expand All @@ -93,6 +95,7 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> Dict:
http_method=pop["resource_request_method"], url=pop["resource_request_url"], nonce=pop["nonce"]
)
if sys.platform.startswith("win") or is_wsl():
result = {}
if self._use_default_broker_account:
try:
result = app.acquire_token_interactive(
Expand All @@ -110,6 +113,10 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> Dict:
return result
except socket.error:
pass

if self._disable_interactive_fallback:
self._check_result(result)

try:
result = app.acquire_token_interactive(
scopes=scopes_list,
Expand All @@ -124,14 +131,8 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> Dict:
)
except socket.error as ex:
raise CredentialUnavailableError(message="Couldn't start an HTTP server.") from ex
if "access_token" not in result and "error_description" in result:
if within_dac.get():
raise CredentialUnavailableError(message=result["error_description"])
raise ClientAuthenticationError(message=result.get("error_description"))
if "access_token" not in result:
if within_dac.get():
raise CredentialUnavailableError(message="Failed to authenticate user")
raise ClientAuthenticationError(message="Failed to authenticate user")

self._check_result(result)
else:
try:
result = app.acquire_token_interactive(
Expand All @@ -157,16 +158,19 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> Dict:
parent_window_handle=self._parent_window_handle,
enable_msa_passthrough=self._enable_msa_passthrough,
)
if "access_token" in result:
return result
if "error_description" in result:
if within_dac.get():
# pylint: disable=raise-missing-from
raise CredentialUnavailableError(message=result["error_description"])
# pylint: disable=raise-missing-from
raise ClientAuthenticationError(message=result.get("error_description"))
self._check_result(result)
return result

def _check_result(self, result: Dict[str, Any]) -> None:
if "access_token" not in result and "error_description" in result:
if within_dac.get():
raise CredentialUnavailableError(message=result["error_description"])
raise ClientAuthenticationError(message=result.get("error_description"))
if "access_token" not in result:
if within_dac.get():
raise CredentialUnavailableError(message="Failed to authenticate user")
raise ClientAuthenticationError(message="Failed to authenticate user")

def _get_app(self, **kwargs: Any) -> msal.ClientApplication:
tenant_id = resolve_tenant(
self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
- Valid values are `EnvironmentCredential`, `WorkloadIdentityCredential`, `ManagedIdentityCredential`, `AzureCliCredential`, `AzurePowershellCredential`, `AzureDeveloperCliCredential`, and `InteractiveBrowserCredential`. ([#41709](https://github.com/Azure/azure-sdk-for-python/pull/41709))
- Re-enabled `VisualStudioCodeCredential` - Previously deprecated `VisualStudioCodeCredential` has been re-implemented to work with the VS Code Azure Resources extension instead of the deprecated Azure Account extension. This requires the `azure-identity-broker` package to be installed for authentication. ([#41822](https://github.com/Azure/azure-sdk-for-python/pull/41822))
- `VisualStudioCodeCredential` is now included in the `DefaultAzureCredential` token chain by default.

- `DefaultAzureCredential` now supports authentication with the currently signed-in Windows account, provided the `azure-identity-broker` package is installed. This auth mechanism is added at the end of the `DefaultAzureCredential` credential chain. ([#40335](https://github.com/Azure/azure-sdk-for-python/pull/40335))

### Breaking Changes

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import sys
from typing import Any

import msal
from azure.core.credentials import AccessToken, AccessTokenInfo, SupportsTokenInfo
from .._exceptions import CredentialUnavailableError
from .._internal.utils import get_broker_credential, is_wsl


class BrokerCredential(SupportsTokenInfo):
"""A broker credential that handles prerequisite checking and falls back appropriately.

This credential checks if the azure-identity-broker package is available and the platform
is supported. If both conditions are met, it uses the real broker credential. Otherwise,
it raises CredentialUnavailableError with an appropriate message.
"""

def __init__(self, **kwargs: Any) -> None:

self._tenant_id = kwargs.pop("tenant_id", None)
self._client_id = kwargs.pop("client_id", None)
self._broker_credential = None
self._unavailable_message = None

# Check prerequisites and initialize the appropriate credential
broker_credential_class = get_broker_credential()
if broker_credential_class and (sys.platform.startswith("win") or is_wsl()):
# The silent auth flow for brokered auth is available on Windows/WSL with the broker package
try:
broker_credential_args = {
"tenant_id": self._tenant_id,
"parent_window_handle": msal.PublicClientApplication.CONSOLE_WINDOW_HANDLE,
"use_default_broker_account": True,
"disable_interactive_fallback": True,
**kwargs,
}
if self._client_id:
broker_credential_args["client_id"] = self._client_id
self._broker_credential = broker_credential_class(**broker_credential_args)
except Exception as ex: # pylint: disable=broad-except
self._unavailable_message = f"InteractiveBrowserBrokerCredential initialization failed: {ex}"
else:
# Determine the specific reason for unavailability
if broker_credential_class is None:
self._unavailable_message = (
"InteractiveBrowserBrokerCredential unavailable. "
"The 'azure-identity-broker' package is required to use brokered authentication."
)
else:
self._unavailable_message = (
"InteractiveBrowserBrokerCredential unavailable. "
"Brokered authentication is only supported on Windows and WSL platforms."
)

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
if self._broker_credential:
return self._broker_credential.get_token(*scopes, **kwargs)
raise CredentialUnavailableError(message=self._unavailable_message)

def get_token_info(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
if self._broker_credential:
return self._broker_credential.get_token_info(*scopes, **kwargs)
raise CredentialUnavailableError(message=self._unavailable_message)

def __enter__(self) -> "BrokerCredential":
if self._broker_credential:
self._broker_credential.__enter__()
return self

def __exit__(self, *args):
if self._broker_credential:
self._broker_credential.__exit__(*args)

def close(self) -> None:
self.__exit__()
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo, TokenCredential
from .._constants import EnvironmentVariables
from .._internal import get_default_authority, normalize_authority, within_dac, process_credential_exclusions
from .._internal.utils import get_default_authority, normalize_authority, within_dac, process_credential_exclusions
from .azure_powershell import AzurePowerShellCredential
from .broker import BrokerCredential
from .browser import InteractiveBrowserCredential
from .chained import ChainedTokenCredential
from .environment import EnvironmentCredential
Expand Down Expand Up @@ -42,6 +43,9 @@ class DefaultAzureCredential(ChainedTokenCredential):
5. The identity currently logged in to the Azure CLI.
6. The identity currently logged in to Azure PowerShell.
7. The identity currently logged in to the Azure Developer CLI.
8. Brokered authentication. On Windows and WSL, this uses an authentication broker such as
Web Account Manager (WAM). On other platforms or when the azure-identity-broker package is not installed,
this credential will raise CredentialUnavailableError.

This default behavior is configurable with keyword arguments.

Expand All @@ -64,9 +68,14 @@ class DefaultAzureCredential(ChainedTokenCredential):
**False**.
:keyword bool exclude_interactive_browser_credential: Whether to exclude interactive browser authentication (see
:class:`~azure.identity.InteractiveBrowserCredential`). Defaults to **True**.
:keyword bool exclude_broker_credential: Whether to exclude the broker credential from the credential chain.
Defaults to **False**. When False, the broker credential is always included in the chain but will raise
CredentialUnavailableError if the azure-identity-broker package is not installed.
:keyword str interactive_browser_tenant_id: Tenant ID to use when authenticating a user through
:class:`~azure.identity.InteractiveBrowserCredential`. Defaults to the value of environment variable
AZURE_TENANT_ID, if any. If unspecified, users will authenticate in their home tenants.
:keyword str broker_tenant_id: The tenant ID to use when using brokered authentication. Defaults to the value of
environment variable AZURE_TENANT_ID, if any. If unspecified, users will authenticate in their home tenants.
:keyword str managed_identity_client_id: The client ID of a user-assigned managed identity. Defaults to the value
of the environment variable AZURE_CLIENT_ID, if any. If not specified, a system-assigned identity will be used.
:keyword str workload_identity_client_id: The client ID of an identity assigned to the pod. Defaults to the value
Expand All @@ -75,6 +84,8 @@ class DefaultAzureCredential(ChainedTokenCredential):
Defaults to the value of environment variable AZURE_TENANT_ID, if any.
:keyword str interactive_browser_client_id: The client ID to be used in interactive browser credential. If not
specified, users will authenticate to an Azure development application.
:keyword str broker_client_id: The client ID to be used in brokered authentication. If not specified, users will
authenticate to an Azure development application.
:keyword str shared_cache_username: Preferred username for :class:`~azure.identity.SharedTokenCacheCredential`.
Defaults to the value of environment variable AZURE_USERNAME, if any.
:keyword str shared_cache_tenant_id: Preferred tenant for :class:`~azure.identity.SharedTokenCacheCredential`.
Expand Down Expand Up @@ -117,6 +128,9 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
)
interactive_browser_client_id = kwargs.pop("interactive_browser_client_id", None)

broker_tenant_id = kwargs.pop("broker_tenant_id", os.environ.get(EnvironmentVariables.AZURE_TENANT_ID))
broker_client_id = kwargs.pop("broker_client_id", None)

shared_cache_username = kwargs.pop("shared_cache_username", os.environ.get(EnvironmentVariables.AZURE_USERNAME))
shared_cache_tenant_id = kwargs.pop(
"shared_cache_tenant_id", os.environ.get(EnvironmentVariables.AZURE_TENANT_ID)
Expand Down Expand Up @@ -147,6 +161,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
},
"visual_studio_code": {
"exclude_param": "exclude_visual_studio_code_credential",
"env_name": "visualstudiocodecredential",
"default_exclude": False,
},
"cli": {
Expand All @@ -169,6 +184,11 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
"env_name": "interactivebrowsercredential",
"default_exclude": True,
},
"broker": {
"exclude_param": "exclude_broker_credential",
"env_name": "interactivebrowserbrokercredential",
"default_exclude": False,
},
}

# Extract user-provided exclude flags and set defaults
Expand All @@ -192,6 +212,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
exclude_developer_cli_credential = exclude_flags["developer_cli"]
exclude_powershell_credential = exclude_flags["powershell"]
exclude_interactive_browser_credential = exclude_flags["interactive_browser"]
exclude_broker_credential = exclude_flags["broker"]

credentials: List[SupportsTokenInfo] = []
within_dac.set(True)
Expand Down Expand Up @@ -242,6 +263,12 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
)
else:
credentials.append(InteractiveBrowserCredential(tenant_id=interactive_browser_tenant_id, **kwargs))
if not exclude_broker_credential:
broker_credential_args = {"tenant_id": broker_tenant_id, **kwargs}
if broker_client_id:
broker_credential_args["client_id"] = broker_client_id
credentials.append(BrokerCredential(**broker_credential_args))

within_dac.set(False)
super(DefaultAzureCredential, self).__init__(*credentials)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import os
import platform
import logging
from contextvars import ContextVar
from string import ascii_letters, digits
Expand Down Expand Up @@ -209,3 +210,11 @@ def get_broker_credential() -> Optional[type]:
return InteractiveBrowserBrokerCredential
except ImportError:
return None


def is_wsl() -> bool:
# This is how MSAL checks for WSL.
uname = platform.uname()
platform_name = getattr(uname, "system", uname[0]).lower()
release = getattr(uname, "release", uname[2]).lower()
return platform_name == "linux" and "microsoft" in release
Loading