-
Notifications
You must be signed in to change notification settings - Fork 429
feat(auth): add allow_sign_up environment variable for OpenID Connect #7267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
885fa46
7f775a0
bc2aecd
e31c5ac
6266b8f
a1a3a79
8e2338a
7f46097
5d741d3
9447844
f4f05f9
61a76e3
6a29926
6d6c32f
3c5af88
958b189
38764ad
adc2904
4779933
2224356
0e5d9f4
ab6a888
c6b9d02
d768c66
00f6a0c
e3fe39f
3a8df07
6623e1e
5f47e3f
a634637
bfc1ad9
ef1de94
aa367a4
c30432c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
TBD_OAUTH2_CLIENT_ID_PREFIX = "TBD_OAUTH2_CLIENT_ID_" | ||
TBD_OAUTH2_USER_ID_PREFIX = "TBD_OAUTH2_USER_ID_" | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
validate_password_format, | ||
) | ||
from phoenix.db import enums, models | ||
from phoenix.db.constants import TBD_OAUTH2_CLIENT_ID_PREFIX, TBD_OAUTH2_USER_ID_PREFIX | ||
from phoenix.server.api.auth import IsAdmin, IsLocked, IsNotReadOnly | ||
from phoenix.server.api.context import Context | ||
from phoenix.server.api.exceptions import Conflict, NotFound, Unauthorized | ||
|
@@ -93,16 +94,25 @@ async def create_user( | |
input: CreateUserInput, | ||
) -> UserMutationPayload: | ||
validate_email_format(email := input.email) | ||
validate_password_format(password := input.password) | ||
salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH) | ||
password_hash = await info.context.hash_password(password, salt) | ||
user = models.User( | ||
reset_password=True, | ||
username=input.username, | ||
email=email, | ||
password_hash=password_hash, | ||
password_salt=salt, | ||
) | ||
if input.password: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of relying on the password being blank, can we make add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree that we need to make this more declarative. The code block below is not self evident and requires tribal knowledge. The UI is also going to have to build affordances for this for it to actually work so using an enum makes the most sense here to me. |
||
validate_password_format(password := input.password) | ||
salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH) | ||
password_hash = await info.context.hash_password(password, salt) | ||
user = models.User( | ||
reset_password=True, | ||
username=input.username, | ||
email=email, | ||
password_hash=password_hash, | ||
password_salt=salt, | ||
) | ||
else: | ||
user = models.User( | ||
reset_password=False, | ||
username=input.username, | ||
email=email, | ||
oauth2_client_id=f"{TBD_OAUTH2_CLIENT_ID_PREFIX}{secrets.token_hex(4)}", | ||
oauth2_user_id=f"{TBD_OAUTH2_USER_ID_PREFIX}{secrets.token_hex(4)}", | ||
) | ||
async with AsyncExitStack() as stack: | ||
session = await stack.enter_async_context(info.context.db()) | ||
user_role_id = await session.scalar(_select_role_id_by_name(input.role.value)) | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -34,6 +34,7 @@ | |||||||||||||||||
) | ||||||||||||||||||
from phoenix.config import get_env_disable_rate_limit | ||||||||||||||||||
from phoenix.db import models | ||||||||||||||||||
from phoenix.db.constants import TBD_OAUTH2_CLIENT_ID_PREFIX, TBD_OAUTH2_USER_ID_PREFIX | ||||||||||||||||||
from phoenix.db.enums import UserRole | ||||||||||||||||||
from phoenix.server.bearer_auth import create_access_and_refresh_tokens | ||||||||||||||||||
from phoenix.server.oauth2 import OAuth2Client | ||||||||||||||||||
|
@@ -169,12 +170,13 @@ async def create_tokens( | |||||||||||||||||
user_info = _parse_user_info(user_info) | ||||||||||||||||||
try: | ||||||||||||||||||
async with request.app.state.db() as session: | ||||||||||||||||||
user = await _ensure_user_exists_and_is_up_to_date( | ||||||||||||||||||
user = await _process_oauth2_user( | ||||||||||||||||||
session, | ||||||||||||||||||
oauth2_client_id=str(oauth2_client.client_id), | ||||||||||||||||||
user_info=user_info, | ||||||||||||||||||
allow_sign_up=oauth2_client.allow_sign_up, | ||||||||||||||||||
) | ||||||||||||||||||
except EmailAlreadyInUse as error: | ||||||||||||||||||
except (EmailAlreadyInUse, SignInNotAllowed) as error: | ||||||||||||||||||
return _redirect_to_login(request=request, error=str(error)) | ||||||||||||||||||
access_token, refresh_token = await create_access_and_refresh_tokens( | ||||||||||||||||||
user=user, | ||||||||||||||||||
|
@@ -198,13 +200,21 @@ async def create_tokens( | |||||||||||||||||
return response | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
@dataclass | ||||||||||||||||||
@dataclass(frozen=True) | ||||||||||||||||||
class UserInfo: | ||||||||||||||||||
idp_user_id: str | ||||||||||||||||||
email: str | ||||||||||||||||||
username: Optional[str] | ||||||||||||||||||
profile_picture_url: Optional[str] | ||||||||||||||||||
|
||||||||||||||||||
def __post_init__(self) -> None: | ||||||||||||||||||
object.__setattr__(self, "idp_user_id", self.idp_user_id.strip()) | ||||||||||||||||||
object.__setattr__(self, "email", self.email.strip()) | ||||||||||||||||||
if username := self.username: | ||||||||||||||||||
object.__setattr__(self, "username", username.strip()) | ||||||||||||||||||
if profile_picture_url := self.profile_picture_url: | ||||||||||||||||||
object.__setattr__(self, "profile_picture_url", profile_picture_url.strip()) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's not clear to me why these local variables are more clear, since they are just a rearrangement of letters
|
||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _validate_token_data(token_data: dict[str, Any]) -> None: | ||||||||||||||||||
""" | ||||||||||||||||||
|
@@ -235,17 +245,144 @@ def _parse_user_info(user_info: dict[str, Any]) -> UserInfo: | |||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
async def _ensure_user_exists_and_is_up_to_date( | ||||||||||||||||||
session: AsyncSession, /, *, oauth2_client_id: str, user_info: UserInfo | ||||||||||||||||||
async def _process_oauth2_user( | ||||||||||||||||||
session: AsyncSession, | ||||||||||||||||||
/, | ||||||||||||||||||
*, | ||||||||||||||||||
oauth2_client_id: str, | ||||||||||||||||||
user_info: UserInfo, | ||||||||||||||||||
allow_sign_up: bool, | ||||||||||||||||||
) -> models.User: | ||||||||||||||||||
""" | ||||||||||||||||||
Processes an OAuth2 user, either signing in an existing user or creating/updating one. | ||||||||||||||||||
|
||||||||||||||||||
This function handles two main scenarios based on the allow_sign_up parameter: | ||||||||||||||||||
1. When sign-up is not allowed (allow_sign_up=False): | ||||||||||||||||||
- Checks if the user exists and can sign in with the given OAuth2 credentials | ||||||||||||||||||
- Updates placeholder OAuth2 credentials if needed (e.g., temporary IDs) | ||||||||||||||||||
- If the user doesn't exist or has a password set, raises SignInNotAllowed | ||||||||||||||||||
2. When sign-up is allowed (allow_sign_up=True): | ||||||||||||||||||
- Finds the user by OAuth2 credentials (client_id and user_id) | ||||||||||||||||||
- Creates a new user if one doesn't exist, with default member role | ||||||||||||||||||
- Updates the user's email if it has changed | ||||||||||||||||||
- Handles username conflicts by adding a random suffix if needed | ||||||||||||||||||
|
||||||||||||||||||
The allow_sign_up parameter is typically controlled by the PHOENIX_OAUTH2_{IDP_NAME}_ALLOW_SIGN_UP | ||||||||||||||||||
environment variable for the specific identity provider. | ||||||||||||||||||
|
||||||||||||||||||
Args: | ||||||||||||||||||
session: The database session | ||||||||||||||||||
oauth2_client_id: The ID of the OAuth2 client | ||||||||||||||||||
user_info: User information from the OAuth2 provider | ||||||||||||||||||
allow_sign_up: Whether to allow creating new users | ||||||||||||||||||
|
||||||||||||||||||
Returns: | ||||||||||||||||||
The user object | ||||||||||||||||||
|
||||||||||||||||||
Raises: | ||||||||||||||||||
SignInNotAllowed: When sign-in is not allowed for the user (user doesn't exist or has a password) | ||||||||||||||||||
EmailAlreadyInUse: When the email is already in use by another account | ||||||||||||||||||
""" # noqa: E501 | ||||||||||||||||||
if not allow_sign_up: | ||||||||||||||||||
return await _get_existing_oauth2_user( | ||||||||||||||||||
session, | ||||||||||||||||||
oauth2_client_id=oauth2_client_id, | ||||||||||||||||||
user_info=user_info, | ||||||||||||||||||
) | ||||||||||||||||||
return await _create_or_update_user( | ||||||||||||||||||
session, | ||||||||||||||||||
oauth2_client_id=oauth2_client_id, | ||||||||||||||||||
user_info=user_info, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
async def _get_existing_oauth2_user( | ||||||||||||||||||
session: AsyncSession, | ||||||||||||||||||
/, | ||||||||||||||||||
*, | ||||||||||||||||||
oauth2_client_id: str, | ||||||||||||||||||
user_info: UserInfo, | ||||||||||||||||||
) -> models.User: | ||||||||||||||||||
""" | ||||||||||||||||||
Signs in an existing user with OAuth2 credentials by looking up the user by email. | ||||||||||||||||||
|
||||||||||||||||||
This function attempts to find a user with the provided email and verifies that: | ||||||||||||||||||
1. The user exists | ||||||||||||||||||
2. The user does not have a password set (password_hash is None) | ||||||||||||||||||
3. The user has OAuth2 credentials set | ||||||||||||||||||
4. The user's OAuth2 credentials match the provided ones, or are temporary placeholders | ||||||||||||||||||
|
||||||||||||||||||
If the user has temporary OAuth2 credentials (prefixed with TBD_OAUTH2_CLIENT_ID_ or | ||||||||||||||||||
TBD_OAUTH2_USER_ID_), these are updated with the actual credentials from the OAuth2 provider. | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These comments look out of date. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you're right. very sharp eyes |
||||||||||||||||||
|
||||||||||||||||||
Args: | ||||||||||||||||||
session: The database session | ||||||||||||||||||
oauth2_client_id: The ID of the OAuth2 client | ||||||||||||||||||
user_info: User information from the OAuth2 provider | ||||||||||||||||||
|
||||||||||||||||||
Returns: | ||||||||||||||||||
The signed-in user | ||||||||||||||||||
|
||||||||||||||||||
Raises: | ||||||||||||||||||
SignInNotAllowed: When sign-in is not allowed for the user (user doesn't exist, has a | ||||||||||||||||||
password, or has mismatched OAuth2 credentials) | ||||||||||||||||||
""" # noqa: E501 | ||||||||||||||||||
email = user_info.email | ||||||||||||||||||
stmt = select(models.User).filter_by(email=email).options(joinedload(models.User.role)) | ||||||||||||||||||
user = await session.scalar(stmt) | ||||||||||||||||||
if ( | ||||||||||||||||||
user is None | ||||||||||||||||||
or user.password_hash is not None | ||||||||||||||||||
or user.oauth2_client_id is None | ||||||||||||||||||
or user.oauth2_user_id is None | ||||||||||||||||||
or ( | ||||||||||||||||||
user.oauth2_user_id != user_info.idp_user_id | ||||||||||||||||||
and not user.oauth2_user_id.startswith(TBD_OAUTH2_USER_ID_PREFIX) | ||||||||||||||||||
) | ||||||||||||||||||
or ( | ||||||||||||||||||
user.oauth2_client_id != oauth2_client_id | ||||||||||||||||||
and not user.oauth2_client_id.startswith(TBD_OAUTH2_CLIENT_ID_PREFIX) | ||||||||||||||||||
) | ||||||||||||||||||
): | ||||||||||||||||||
raise SignInNotAllowed(f"Sign in is not allowed for {email}.") | ||||||||||||||||||
if user.oauth2_client_id.startswith(TBD_OAUTH2_CLIENT_ID_PREFIX): | ||||||||||||||||||
user.oauth2_client_id = oauth2_client_id | ||||||||||||||||||
if user.oauth2_user_id.startswith(TBD_OAUTH2_USER_ID_PREFIX): | ||||||||||||||||||
user.oauth2_user_id = user_info.idp_user_id | ||||||||||||||||||
if user in session.dirty: | ||||||||||||||||||
await session.flush() | ||||||||||||||||||
return user | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
async def _create_or_update_user( | ||||||||||||||||||
session: AsyncSession, | ||||||||||||||||||
/, | ||||||||||||||||||
*, | ||||||||||||||||||
oauth2_client_id: str, | ||||||||||||||||||
user_info: UserInfo, | ||||||||||||||||||
) -> models.User: | ||||||||||||||||||
""" | ||||||||||||||||||
Creates a new user or updates an existing one with OAuth2 credentials. | ||||||||||||||||||
|
||||||||||||||||||
Args: | ||||||||||||||||||
session: The database session | ||||||||||||||||||
oauth2_client_id: The ID of the OAuth2 client | ||||||||||||||||||
user_info: User information from the OAuth2 provider | ||||||||||||||||||
|
||||||||||||||||||
Returns: | ||||||||||||||||||
The created or updated user | ||||||||||||||||||
|
||||||||||||||||||
Raises: | ||||||||||||||||||
EmailAlreadyInUse: When the email is already in use by another account | ||||||||||||||||||
""" | ||||||||||||||||||
user = await _get_user( | ||||||||||||||||||
session, | ||||||||||||||||||
oauth2_client_id=oauth2_client_id, | ||||||||||||||||||
idp_user_id=user_info.idp_user_id, | ||||||||||||||||||
) | ||||||||||||||||||
if user is None: | ||||||||||||||||||
user = await _create_user(session, oauth2_client_id=oauth2_client_id, user_info=user_info) | ||||||||||||||||||
elif user.email != user_info.email: | ||||||||||||||||||
if user.email != user_info.email: | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change looks like it's intended to handle a case where the user is newly created, but there is a mismatch between the newly created user's email and the email from the OIDC token. But that shouldn't be possible since we use that exact email when creating the user. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yea, this one is unintentional There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure how it got changed. good catch |
||||||||||||||||||
user = await _update_user_email(session, user_id=user.id, email=user_info.email) | ||||||||||||||||||
return user | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -366,6 +503,10 @@ class EmailAlreadyInUse(Exception): | |||||||||||||||||
pass | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
class SignInNotAllowed(Exception): | ||||||||||||||||||
pass | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse: | ||||||||||||||||||
""" | ||||||||||||||||||
Creates a RedirectResponse to the login page to display an error message. | ||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I follow what these are?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I get it moving down - but I think you should add a docstring here sincde as raw constants it's hard to tell what they do.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For context, these are dummy values that are being added to satisfy the
exactly_one_auth_method
check constraint:phoenix/src/phoenix/db/models.py
Line 1085 in f882faf
I think it makes sense to relax that constraint since it too restrictive in a world where an OAuth2 user can be added before their OAuth2 client and user IDs have been recorded.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then the dummy values become unnecessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It sounds like you would prefer to keep the constraint and use dummy values @mikeldking?