diff --git a/app/schema.graphql b/app/schema.graphql index 92d5aaf3dd..cc4e6d3a0a 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -465,9 +465,10 @@ type CreateUserApiKeyMutationPayload { input CreateUserInput { email: String! username: String! - password: String! + password: String = null role: UserRoleInput! sendWelcomeEmail: Boolean = false + authMethod: AuthMethod = LOCAL } scalar CronExpression diff --git a/app/src/pages/settings/__generated__/NewUserDialogMutation.graphql.ts b/app/src/pages/settings/__generated__/NewUserDialogMutation.graphql.ts index d316207110..48e2d0c002 100644 --- a/app/src/pages/settings/__generated__/NewUserDialogMutation.graphql.ts +++ b/app/src/pages/settings/__generated__/NewUserDialogMutation.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<<2ac4cdb4b8001c4f40165b494129ff58>> + * @generated SignedSource<> * @lightSyntaxTransform * @nogrep */ @@ -9,10 +9,12 @@ // @ts-nocheck import { ConcreteRequest } from 'relay-runtime'; +export type AuthMethod = "LOCAL" | "OAUTH2"; export type UserRoleInput = "ADMIN" | "MEMBER"; export type CreateUserInput = { + authMethod?: AuthMethod | null; email: string; - password: string; + password?: string | null; role: UserRoleInput; sendWelcomeEmail?: boolean | null; username: string; diff --git a/src/phoenix/config.py b/src/phoenix/config.py index 1147535042..b45e46b20a 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import re @@ -839,6 +841,7 @@ class OAuth2ClientConfig: client_id: str client_secret: str oidc_config_url: str + allow_sign_up: bool @classmethod def from_env(cls, idp_name: str) -> "OAuth2ClientConfig": @@ -870,6 +873,7 @@ def from_env(cls, idp_name: str) -> "OAuth2ClientConfig": f"An OpenID Connect configuration URL must be set for the {idp_name} OAuth2 IDP " f"via the {oidc_config_url_env_var} environment variable" ) + allow_sign_up = get_env_oauth2_allow_sign_up(idp_name) parsed_oidc_config_url = urlparse(oidc_config_url) is_local_oidc_config_url = parsed_oidc_config_url.hostname in ("localhost", "127.0.0.1") if parsed_oidc_config_url.scheme != "https" and not is_local_oidc_config_url: @@ -886,17 +890,51 @@ def from_env(cls, idp_name: str) -> "OAuth2ClientConfig": client_id=client_id, client_secret=client_secret, oidc_config_url=oidc_config_url, + allow_sign_up=allow_sign_up, ) def get_env_oauth2_settings() -> list[OAuth2ClientConfig]: """ - Get OAuth2 settings from environment variables. - """ + Retrieves and validates OAuth2/OpenID Connect (OIDC) identity provider configurations from environment variables. + + This function scans the environment for OAuth2 configuration variables and returns a list of + configured identity providers. It supports multiple identity providers simultaneously. + + Environment Variable Pattern: + PHOENIX_OAUTH2_{IDP_NAME}_{CONFIG_TYPE} + + Required Environment Variables for each IDP: + - PHOENIX_OAUTH2_{IDP_NAME}_CLIENT_ID: The OAuth2 client ID issued by the identity provider + - PHOENIX_OAUTH2_{IDP_NAME}_CLIENT_SECRET: The OAuth2 client secret issued by the identity provider + - PHOENIX_OAUTH2_{IDP_NAME}_OIDC_CONFIG_URL: The OpenID Connect configuration URL (must be HTTPS) + Optional Environment Variables: + - PHOENIX_OAUTH2_{IDP_NAME}_DISPLAY_NAME: A user-friendly name for the identity provider + - PHOENIX_OAUTH2_{IDP_NAME}_ALLOW_SIGN_UP: Whether to allow new user registration (defaults to True) + When set to False, the system will check if the user exists in the database by their email address. + If the user does not exist or has a password set, they will be redirected to the login page with + an error message. + + Returns: + list[OAuth2ClientConfig]: A list of configured OAuth2 identity providers, sorted alphabetically by IDP name. + Each OAuth2ClientConfig contains the validated configuration for one identity provider. + + Raises: + ValueError: If required environment variables are missing or invalid. + Specifically, if the OIDC configuration URL is not HTTPS (except for localhost). + + Example: + To configure Google as an identity provider, set these environment variables: + PHOENIX_OAUTH2_GOOGLE_CLIENT_ID=your_client_id + PHOENIX_OAUTH2_GOOGLE_CLIENT_SECRET=your_client_secret + PHOENIX_OAUTH2_GOOGLE_OIDC_CONFIG_URL=https://accounts.google.com/.well-known/openid-configuration + PHOENIX_OAUTH2_GOOGLE_DISPLAY_NAME=Google (optional) + PHOENIX_OAUTH2_GOOGLE_ALLOW_SIGN_UP=true (optional, defaults to true) + """ # noqa: E501 idp_names = set() pattern = re.compile( - r"^PHOENIX_OAUTH2_(\w+)_(DISPLAY_NAME|CLIENT_ID|CLIENT_SECRET|OIDC_CONFIG_URL)$" + r"^PHOENIX_OAUTH2_(\w+)_(DISPLAY_NAME|CLIENT_ID|CLIENT_SECRET|OIDC_CONFIG_URL|ALLOW_SIGN_UP)$" ) for env_var in os.environ: if (match := pattern.match(env_var)) is not None and (idp_name := match.group(1).lower()): @@ -904,6 +942,27 @@ def get_env_oauth2_settings() -> list[OAuth2ClientConfig]: return [OAuth2ClientConfig.from_env(idp_name) for idp_name in sorted(idp_names)] +def get_env_oauth2_allow_sign_up(idp_name: str) -> bool: + """Retrieves the allow_sign_up setting for a specific OAuth2 identity provider. + + This function determines whether new user registration is allowed for the specified identity provider. + When set to False, the system will check if the user exists in the database by their email address. + If the user does not exist or has a password set, they will be redirected to the login page with + an error message. + + Parameters: + idp_name (str): The name of the identity provider (e.g., 'google', 'aws_cognito', 'microsoft_entra_id') + + Returns: + bool: True if new user registration is allowed (default), False otherwise + + Environment Variable: + PHOENIX_OAUTH2_{IDP_NAME}_ALLOW_SIGN_UP: Controls whether new user registration is allowed (defaults to True if not set) + """ # noqa: E501 + env_var = f"PHOENIX_OAUTH2_{idp_name}_ALLOW_SIGN_UP".upper() + return _bool_val(env_var, True) + + PHOENIX_DIR = Path(__file__).resolve().parent # Server config SERVER_DIR = PHOENIX_DIR / "server" diff --git a/src/phoenix/db/enums.py b/src/phoenix/db/enums.py index 943165c202..aa7f05857b 100644 --- a/src/phoenix/db/enums.py +++ b/src/phoenix/db/enums.py @@ -4,9 +4,8 @@ from sqlalchemy.orm import InstrumentedAttribute from phoenix.db import models -from phoenix.db.models import AuthMethod -__all__ = ["AuthMethod", "UserRole", "COLUMN_ENUMS"] +__all__ = ["UserRole", "COLUMN_ENUMS"] class UserRole(Enum): diff --git a/src/phoenix/db/facilitator.py b/src/phoenix/db/facilitator.py index 2196f2b1be..bec41633c6 100644 --- a/src/phoenix/db/facilitator.py +++ b/src/phoenix/db/facilitator.py @@ -196,7 +196,7 @@ async def _ensure_admins( ) assert admin_role_id is not None, "Admin role not found in database" for email, username in admins.items(): - values = dict( + user = models.User( user_role_id=admin_role_id, username=username, email=email, @@ -204,7 +204,8 @@ async def _ensure_admins( password_hash=secrets.token_bytes(DEFAULT_SECRET_LENGTH), reset_password=True, ) - await session.execute(insert(models.User).values(values)) + session.add(user) + await session.flush() if email_sender is None: return for exc in await gather( diff --git a/src/phoenix/db/migrations/versions/6a88424799fe_update_users_with_auth_method.py b/src/phoenix/db/migrations/versions/6a88424799fe_update_users_with_auth_method.py new file mode 100644 index 0000000000..16aab43305 --- /dev/null +++ b/src/phoenix/db/migrations/versions/6a88424799fe_update_users_with_auth_method.py @@ -0,0 +1,168 @@ +"""Add auth_method column to users table and migrate existing authentication data. + +This migration: +1. Adds a new 'auth_method' column to the users table that indicates whether a user + authenticates via local password ('LOCAL') or external OAuth2 ('OAUTH2') +2. Migrates existing authentication data to populate the new column: + - Sets 'LOCAL' for users with password_hash + - Sets 'OAUTH2' for users with OAuth2 credentials +3. Adds appropriate constraints to ensure data integrity: + - NOT NULL constraint on auth_method + - 'valid_auth_method': ensures only 'LOCAL' or 'OAUTH2' values + - 'local_auth_no_oauth': ensures LOCAL users do not have OAuth2 credentials + - 'oauth2_auth_no_password': ensures OAUTH2 users do not have password credentials +4. Removes legacy constraints that are replaced by the new column: + - 'exactly_one_auth_method': replaced by auth_method column and its constraints + - 'oauth2_client_id_and_user_id': replaced by auth_method column and its constraints +5. Drops redundant single column indices: + - 'ix_users_oauth2_client_id' and 'ix_users_oauth2_user_id' are removed as they are + redundant with the unique constraint 'uq_users_oauth2_client_id_oauth2_user_id', + which already provides the necessary composite index for lookups + +The migration uses batch_alter_table to ensure compatibility with both SQLite and PostgreSQL. +This approach allows us to: +- Add the column as nullable initially +- Update the values based on existing authentication data +- Make the column NOT NULL after populating +- Add appropriate constraints +- Remove legacy constraints +- Drop redundant indices + +The downgrade path: +1. Recreates the legacy constraints: + - 'exactly_one_auth_method': ensures exactly one auth method is set + - 'oauth2_client_id_and_user_id': ensures OAuth2 credentials are consistent +2. Removes the auth_method column and its associated constraints +3. Recreates the single column indices to maintain backward compatibility: + - 'ix_users_oauth2_client_id' + - 'ix_users_oauth2_user_id' + +Revision ID: 6a88424799fe +Revises: 8a3764fe7f1a +Create Date: 2025-05-01 08:08:22.700715 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "6a88424799fe" +down_revision: Union[str, None] = "8a3764fe7f1a" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade the database schema to include the auth_method column. + + This function: + 1. Adds the auth_method column as nullable + 2. Populates the column based on existing authentication data: + - 'LOCAL' for users with password_hash + - 'OAUTH2' for users with OAuth2 credentials + 3. Makes the column NOT NULL after populating + 4. Adds CHECK constraints to ensure data integrity: + - 'valid_auth_method': ensures only 'LOCAL' or 'OAUTH2' values + - 'local_auth_no_oauth': ensures auth_method matches credentials + - 'oauth2_auth_no_password': ensures auth_method matches credentials + 5. Removes legacy constraints that are replaced by the new column: + - 'exactly_one_auth_method' + - 'oauth2_client_id_and_user_id' + 6. Drops redundant single column indices: + - 'ix_users_oauth2_client_id' and 'ix_users_oauth2_user_id' are removed as they are + redundant with the unique constraint 'uq_users_oauth2_client_id_oauth2_user_id', + which already provides the necessary composite index for lookups + + The implementation uses batch_alter_table for compatibility with both + SQLite and PostgreSQL databases. + + Raises: + sqlalchemy.exc.SQLAlchemyError: If database operations fail + """ + with op.batch_alter_table("users") as batch_op: + # For SQLite, first add the column as nullable + batch_op.add_column(sa.Column("auth_method", sa.String, nullable=True)) + + with op.batch_alter_table("users") as batch_op: + batch_op.execute(""" + UPDATE users + SET auth_method = CASE + WHEN password_hash IS NOT NULL THEN 'LOCAL' ELSE 'OAUTH2' END + """) + # Make the column non-nullable + batch_op.alter_column("auth_method", nullable=False, existing_nullable=True) + + # Drop both old constraints as they're now redundant + # exactly_one_auth_method is covered by the new auth_method constraints + # oauth2_client_id_and_user_id is covered by the new auth_method constraints + batch_op.drop_constraint("exactly_one_auth_method", type_="check") + batch_op.drop_constraint("oauth2_client_id_and_user_id", type_="check") + + # Drop redundant single column indices, because a composite index already + # exists in the uniqueness constraint for (client_id, user_id) + batch_op.drop_index("ix_users_oauth2_client_id") + batch_op.drop_index("ix_users_oauth2_user_id") + + # Add CHECK constraint to ensure only valid values are allowed + batch_op.create_check_constraint( + "valid_auth_method", + "auth_method IN ('LOCAL', 'OAUTH2')", + ) + batch_op.create_check_constraint( + "local_auth_no_oauth", + "auth_method != 'LOCAL' OR oauth2_client_id IS NULL", + ) + batch_op.create_check_constraint( + "oauth2_auth_no_password", + "auth_method != 'OAUTH2' OR password_hash IS NULL", + ) + + +def downgrade() -> None: + """Downgrade the database schema by removing the auth_method column. + + This function: + 1. Recreates the legacy constraints that were removed in the upgrade: + - 'oauth2_client_id_and_user_id': ensures OAuth2 credentials are consistent + - 'exactly_one_auth_method': ensures exactly one auth method is set + 2. Removes the auth_method column and its associated CHECK constraints: + - 'oauth2_auth_no_password' + - 'local_auth_no_oauth' + - 'valid_auth_method' + 3. Recreates the single column indices to maintain backward compatibility: + - 'ix_users_oauth2_client_id' + - 'ix_users_oauth2_user_id' + + The implementation uses batch_alter_table to ensure compatibility with both + SQLite and PostgreSQL databases. + + Raises: + sqlalchemy.exc.SQLAlchemyError: If database operations fail + """ + # Use batch_alter_table for SQLite compatibility + # This ensures the downgrade works on both SQLite and PostgreSQL + with op.batch_alter_table("users") as batch_op: + # Drop the CHECK constraint and column + batch_op.drop_constraint("oauth2_auth_no_password", type_="check") + batch_op.drop_constraint("local_auth_no_oauth", type_="check") + batch_op.drop_constraint("valid_auth_method", type_="check") + + # Recreate single column indices + batch_op.create_index("ix_users_oauth2_user_id", ["oauth2_user_id"]) + batch_op.create_index("ix_users_oauth2_client_id", ["oauth2_client_id"]) + + # Recreate both old constraints that were dropped in upgrade + batch_op.create_check_constraint( + "oauth2_client_id_and_user_id", + "(oauth2_client_id IS NULL) = (oauth2_user_id IS NULL)", + ) + batch_op.create_check_constraint( + "exactly_one_auth_method", + "(password_hash IS NULL) != (oauth2_client_id IS NULL)", + ) + + # Remove added column + batch_op.drop_column("auth_method") diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index b5d5b0f9ff..c7430c3a65 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -1,5 +1,4 @@ from datetime import datetime, timezone -from enum import Enum from typing import Any, Iterable, Literal, Optional, Sequence, TypedDict, cast import sqlalchemy.sql as sql @@ -23,7 +22,6 @@ case, func, insert, - not_, select, text, ) @@ -42,6 +40,7 @@ from sqlalchemy.sql import Values, column, compiler, expression, literal, roles, union_all from sqlalchemy.sql.compiler import SQLCompiler from sqlalchemy.sql.functions import coalesce +from typing_extensions import TypeAlias from phoenix.config import get_env_database_schema from phoenix.datetime_utils import normalize_datetime @@ -147,9 +146,7 @@ def render_values_w_union( return compiler.process(subquery, from_linter=from_linter, **kw) -class AuthMethod(Enum): - LOCAL = "LOCAL" - OAUTH2 = "OAUTH2" +AuthMethod: TypeAlias = Literal["LOCAL", "OAUTH2"] class JSONB(JSON): @@ -1152,8 +1149,11 @@ class User(Base): password_hash: Mapped[Optional[bytes]] password_salt: Mapped[Optional[bytes]] reset_password: Mapped[bool] - oauth2_client_id: Mapped[Optional[str]] = mapped_column(index=True, nullable=True) - oauth2_user_id: Mapped[Optional[str]] = mapped_column(index=True, nullable=True) + oauth2_client_id: Mapped[Optional[str]] + oauth2_user_id: Mapped[Optional[str]] + auth_method: Mapped[AuthMethod] = mapped_column( + CheckConstraint("auth_method IN ('LOCAL', 'OAUTH2')", name="valid_auth_method") + ) created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now()) updated_at: Mapped[datetime] = mapped_column( UtcTimeStamp, server_default=func.now(), onupdate=func.now() @@ -1169,28 +1169,10 @@ class User(Base): ) api_keys: Mapped[list["ApiKey"]] = relationship("ApiKey", back_populates="user") - @hybrid_property - def auth_method(self) -> Optional[str]: - if self.password_hash is not None: - return AuthMethod.LOCAL.value - elif self.oauth2_client_id is not None: - return AuthMethod.OAUTH2.value - return None - - @auth_method.inplace.expression - @classmethod - def _auth_method_expression(cls) -> ColumnElement[Optional[str]]: - return case( - ( - not_(cls.password_hash.is_(None)), - AuthMethod.LOCAL.value, - ), - ( - not_(cls.oauth2_client_id.is_(None)), - AuthMethod.OAUTH2.value, - ), - else_=None, - ) + __mapper_args__ = { + "polymorphic_on": "auth_method", + "polymorphic_identity": None, # Base class is abstract + } __table_args__ = ( CheckConstraint( @@ -1198,12 +1180,12 @@ def _auth_method_expression(cls) -> ColumnElement[Optional[str]]: name="password_hash_and_salt", ), CheckConstraint( - "(oauth2_client_id IS NULL) = (oauth2_user_id IS NULL)", - name="oauth2_client_id_and_user_id", + "auth_method != 'LOCAL' OR oauth2_client_id IS NULL", + name="local_auth_no_oauth", ), CheckConstraint( - "(password_hash IS NULL) != (oauth2_client_id IS NULL)", - name="exactly_one_auth_method", + "auth_method != 'OAUTH2' OR password_hash IS NULL", + name="oauth2_auth_no_password", ), UniqueConstraint( "oauth2_client_id", @@ -1212,6 +1194,59 @@ def _auth_method_expression(cls) -> ColumnElement[Optional[str]]: dict(sqlite_autoincrement=True), ) + def __init__(self, **kwargs: Any) -> None: + if "auth_method" not in kwargs: + if kwargs.get("password_hash") and kwargs.get("password_salt"): + kwargs["auth_method"] = "LOCAL" + else: + kwargs["auth_method"] = "OAUTH2" + super().__init__(**kwargs) + + +class LocalUser(User): + __mapper_args__ = { + "polymorphic_identity": "LOCAL", + } + + def __init__( + self, + *, + email: str, + username: str, + password_hash: bytes, + password_salt: bytes, + reset_password: bool = True, + ) -> None: + if not password_hash or not password_salt: + raise ValueError("password_hash and password_salt are required for LocalUser") + super().__init__( + email=email, + username=username, + password_hash=password_hash, + password_salt=password_salt, + reset_password=reset_password, + auth_method="LOCAL", + ) + + +class OAuth2User(User): + __mapper_args__ = { + "polymorphic_identity": "OAUTH2", + } + + def __init__( + self, + *, + email: str, + username: str, + ) -> None: + super().__init__( + email=email, + username=username, + reset_password=False, + auth_method="OAUTH2", + ) + class PasswordResetToken(Base): __tablename__ = "password_reset_tokens" diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py index 87dc737e5f..00b9fc12af 100644 --- a/src/phoenix/server/api/mutations/user_mutations.py +++ b/src/phoenix/server/api/mutations/user_mutations.py @@ -26,8 +26,9 @@ from phoenix.db import enums, models from phoenix.server.api.auth import IsAdmin, IsLocked, IsNotReadOnly from phoenix.server.api.context import Context -from phoenix.server.api.exceptions import Conflict, NotFound, Unauthorized +from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound, Unauthorized from phoenix.server.api.input_types.UserRoleInput import UserRoleInput +from phoenix.server.api.types.AuthMethod import AuthMethod from phoenix.server.api.types.node import from_global_id_with_expected_type from phoenix.server.api.types.User import User, to_gql_user from phoenix.server.bearer_auth import PhoenixUser @@ -40,9 +41,17 @@ class CreateUserInput: email: str username: str - password: str + password: Optional[str] = None role: UserRoleInput send_welcome_email: Optional[bool] = False + auth_method: Optional[AuthMethod] = AuthMethod.LOCAL + + def __post_init__(self) -> None: + if self.auth_method is AuthMethod.OAUTH2: + if self.password: + raise BadRequest("Password is not allowed for OAuth2 authentication") + elif not self.password: + raise BadRequest("Password is required for local authentication") @strawberry.input @@ -53,9 +62,9 @@ class PatchViewerInput: def __post_init__(self) -> None: if not self.new_username and not self.new_password: - raise ValueError("At least one field must be set") + raise BadRequest("At least one field must be set") if self.new_password and not self.current_password: - raise ValueError("current_password is required when modifying password") + raise BadRequest("current_password is required when modifying password") if self.new_password: PASSWORD_REQUIREMENTS.validate(self.new_password) @@ -69,7 +78,7 @@ class PatchUserInput: def __post_init__(self) -> None: if not self.new_role and not self.new_username and not self.new_password: - raise ValueError("At least one field must be set") + raise BadRequest("At least one field must be set") if self.new_password: PASSWORD_REQUIREMENTS.validate(self.new_password) @@ -92,17 +101,25 @@ async def create_user( info: Info[Context, None], input: CreateUserInput, ) -> UserMutationPayload: - validate_email_format(email := input.email) - validate_password_format(password := input.password) - salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH) - password_hash = await info.context.hash_password(password, salt) - user = models.User( - reset_password=True, - username=input.username, - email=email, - password_hash=password_hash, - password_salt=salt, - ) + user: models.User + if input.auth_method is AuthMethod.OAUTH2: + user = models.OAuth2User( + username=input.username, + email=input.email, + ) + else: + assert input.password + validate_email_format(input.email) + validate_password_format(input.password) + salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH) + password_hash = await info.context.hash_password(input.password, salt) + user = models.LocalUser( + reset_password=True, + username=input.username, + email=input.email, + password_hash=password_hash, + password_salt=salt, + ) async with AsyncExitStack() as stack: session = await stack.enter_async_context(info.context.db()) user_role_id = await session.scalar(_select_role_id_by_name(input.role.value)) @@ -150,7 +167,7 @@ async def patch_user( raise NotFound(f"Role {input.new_role.value} not found") user.user_role_id = user_role_id if password := input.new_password: - if user.auth_method != enums.AuthMethod.LOCAL.value: + if user.auth_method != "LOCAL": raise Conflict("Cannot modify password for non-local user") validate_password_format(password) user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH) @@ -183,7 +200,7 @@ async def patch_viewer( raise NotFound("User not found") stack.enter_context(session.no_autoflush) if password := input.new_password: - if user.auth_method != enums.AuthMethod.LOCAL.value: + if user.auth_method != "LOCAL": raise Conflict("Cannot modify password for non-local user") if not ( current_password := input.current_password diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index f32e83b565..9fd403a2c0 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -32,7 +32,7 @@ validate_password_format, ) from phoenix.config import get_base_url, get_env_disable_rate_limit, get_env_host_root_path -from phoenix.db import enums, models +from phoenix.db import models from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens from phoenix.server.email.types import EmailSender from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_ip_rate_limiter @@ -207,7 +207,7 @@ async def initiate_password_reset(request: Request) -> Response: joinedload(models.User.password_reset_token).load_only(models.PasswordResetToken.id) ) ) - if user is None or user.auth_method != enums.AuthMethod.LOCAL.value: + if user is None or user.auth_method != "LOCAL": # Withold privileged information return Response(status_code=HTTP_204_NO_CONTENT) token_store: TokenStore = request.app.state.get_token_store() @@ -244,7 +244,7 @@ async def reset_password(request: Request) -> Response: assert (user_id := claims.subject) async with request.app.state.db() as session: user = await session.scalar(select(models.User).filter_by(id=int(user_id))) - if user is None or user.auth_method != enums.AuthMethod.LOCAL.value: + if user is None or user.auth_method != "LOCAL": # Withold privileged information return Response(status_code=HTTP_204_NO_CONTENT) validate_password_format(password) diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index 08399b785c..91d63fd40f 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -169,12 +169,13 @@ async def create_tokens( user_info = _parse_user_info(user_info) try: async with request.app.state.db() as session: - user = await _ensure_user_exists_and_is_up_to_date( + user = await _process_oauth2_user( session, oauth2_client_id=str(oauth2_client.client_id), user_info=user_info, + allow_sign_up=oauth2_client.allow_sign_up, ) - except EmailAlreadyInUse as error: + except (EmailAlreadyInUse, SignInNotAllowed) as error: return _redirect_to_login(request=request, error=str(error)) access_token, refresh_token = await create_access_and_refresh_tokens( user=user, @@ -198,12 +199,24 @@ async def create_tokens( return response -@dataclass +@dataclass(frozen=True) class UserInfo: idp_user_id: str email: str - username: Optional[str] - profile_picture_url: Optional[str] + username: Optional[str] = None + profile_picture_url: Optional[str] = None + + def __post_init__(self) -> None: + if not (idp_user_id := (self.idp_user_id or "").strip()): + raise ValueError("idp_user_id cannot be empty") + object.__setattr__(self, "idp_user_id", idp_user_id) + if not (email := (self.email or "").strip()): + raise ValueError("email cannot be empty") + object.__setattr__(self, "email", email) + if username := (self.username or "").strip(): + object.__setattr__(self, "username", username) + if profile_picture_url := (self.profile_picture_url or "").strip(): + object.__setattr__(self, "profile_picture_url", profile_picture_url) def _validate_token_data(token_data: dict[str, Any]) -> None: @@ -235,9 +248,148 @@ def _parse_user_info(user_info: dict[str, Any]) -> UserInfo: ) -async def _ensure_user_exists_and_is_up_to_date( - session: AsyncSession, /, *, oauth2_client_id: str, user_info: UserInfo +async def _process_oauth2_user( + session: AsyncSession, + /, + *, + oauth2_client_id: str, + user_info: UserInfo, + allow_sign_up: bool, +) -> models.User: + """ + Processes an OAuth2 user, either signing in an existing user or creating/updating one. + + This function handles two main scenarios based on the allow_sign_up parameter: + 1. When sign-up is not allowed (allow_sign_up=False): + - Checks if the user exists and can sign in with the given OAuth2 credentials + - Updates placeholder OAuth2 credentials if needed (e.g., temporary IDs) + - If the user doesn't exist or has a password set, raises SignInNotAllowed + 2. When sign-up is allowed (allow_sign_up=True): + - Finds the user by OAuth2 credentials (client_id and user_id) + - Creates a new user if one doesn't exist, with default member role + - Updates the user's email if it has changed + - Handles username conflicts by adding a random suffix if needed + + The allow_sign_up parameter is typically controlled by the PHOENIX_OAUTH2_{IDP_NAME}_ALLOW_SIGN_UP + environment variable for the specific identity provider. + + Args: + session: The database session + oauth2_client_id: The ID of the OAuth2 client + user_info: User information from the OAuth2 provider + allow_sign_up: Whether to allow creating new users + + Returns: + The user object + + Raises: + SignInNotAllowed: When sign-in is not allowed for the user (user doesn't exist or has a password) + EmailAlreadyInUse: When the email is already in use by another account + """ # noqa: E501 + if not allow_sign_up: + return await _get_existing_oauth2_user( + session, + oauth2_client_id=oauth2_client_id, + user_info=user_info, + ) + return await _create_or_update_user( + session, + oauth2_client_id=oauth2_client_id, + user_info=user_info, + ) + + +async def _get_existing_oauth2_user( + session: AsyncSession, + /, + *, + oauth2_client_id: str, + user_info: UserInfo, ) -> models.User: + """ + Signs in an existing user with OAuth2 credentials. + + This function attempts to find a user in two ways: + 1. First by OAuth2 credentials (client_id and user_id) + 2. If not found, then by email + + When found by OAuth2 credentials: + - Updates the user's email if it has changed from the IDP info + + When found by email: + - Verifies the user is an OAuth2 user (no password set) + - Verifies either: + a) The user has no OAuth2 credentials set yet, or + b) The user's OAuth2 credentials match the provided ones + - Updates the user's OAuth2 credentials if they were not set + + Args: + session: The database session + oauth2_client_id: The ID of the OAuth2 client + user_info: User information from the OAuth2 provider + + Returns: + The signed-in user + + Raises: + SignInNotAllowed: When sign-in is not allowed for the user (user doesn't exist, has a + password, or has mismatched OAuth2 credentials) + """ # noqa: E501 + if not (email := (user_info.email or "").strip()): + raise ValueError("Email is required.") + if not (oauth2_user_id := (user_info.idp_user_id or "").strip()): + raise ValueError("OAuth2 user ID is required.") + if not (oauth2_client_id := (oauth2_client_id or "").strip()): + raise ValueError("OAuth2 client ID is required.") + username = (user_info.username or "").strip() + profile_picture_url = (user_info.profile_picture_url or "").strip() + stmt = select(models.User).options(joinedload(models.User.role)) + if user := await session.scalar( + stmt.filter_by(oauth2_client_id=oauth2_client_id, oauth2_user_id=oauth2_user_id) + ): + if email and email != user.email: + user.email = email + else: + user = await session.scalar(stmt.filter_by(email=email)) + if user is None or not isinstance(user, models.OAuth2User): + raise SignInNotAllowed("Sign in is not allowed.") + if oauth2_client_id != user.oauth2_client_id: + user.oauth2_client_id = oauth2_client_id + user.oauth2_user_id = oauth2_user_id + elif not user.oauth2_user_id: + user.oauth2_user_id = oauth2_user_id + elif oauth2_user_id != user.oauth2_user_id: + raise SignInNotAllowed("Sign in is not allowed.") + if username and username != user.username: + user.username = username + if profile_picture_url != user.profile_picture_url: + user.profile_picture_url = profile_picture_url + if user in session.dirty: + await session.flush() + return user + + +async def _create_or_update_user( + session: AsyncSession, + /, + *, + oauth2_client_id: str, + user_info: UserInfo, +) -> models.User: + """ + Creates a new user or updates an existing one with OAuth2 credentials. + + Args: + session: The database session + oauth2_client_id: The ID of the OAuth2 client + user_info: User information from the OAuth2 provider + + Returns: + The created or updated user + + Raises: + EmailAlreadyInUse: When the email is already in use by another account + """ user = await _get_user( session, oauth2_client_id=oauth2_client_id, @@ -303,6 +455,7 @@ async def _create_user( email=email, profile_picture_url=user_info.profile_picture_url, reset_password=False, + auth_method="OAUTH2", ) ) assert isinstance(user_id, int) @@ -366,6 +519,10 @@ class EmailAlreadyInUse(Exception): pass +class SignInNotAllowed(Exception): + pass + + def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse: """ Creates a RedirectResponse to the login page to display an error message. diff --git a/src/phoenix/server/oauth2.py b/src/phoenix/server/oauth2.py index 25ee2c4286..65699e12bf 100644 --- a/src/phoenix/server/oauth2.py +++ b/src/phoenix/server/oauth2.py @@ -19,8 +19,13 @@ class OAuth2Client(AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp): # type:ignore[ client_cls = AsyncHttpxOAuth2Client - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self, *args: Any, allow_sign_up: bool, **kwargs: Any) -> None: super().__init__(framework=None, *args, **kwargs) + self._allow_sign_up = allow_sign_up + + @property + def allow_sign_up(self) -> bool: + return self._allow_sign_up class OAuth2Clients: @@ -35,6 +40,7 @@ def add_client(self, config: OAuth2ClientConfig) -> None: client_secret=config.client_secret, server_metadata_url=config.oidc_config_url, client_kwargs={"scope": "openid email profile"}, + allow_sign_up=config.allow_sign_up, ) assert isinstance(client, OAuth2Client) self._clients[config.idp_name] = client diff --git a/tests/integration/_helpers.py b/tests/integration/_helpers.py index 236205cadc..849b6b5bae 100644 --- a/tests/integration/_helpers.py +++ b/tests/integration/_helpers.py @@ -167,8 +167,15 @@ def create_user( *, profile: _Profile, send_welcome_email: bool = False, + local: bool = True, ) -> _User: - return _create_user(self, role=role, profile=profile, send_welcome_email=send_welcome_email) + return _create_user( + self, + role=role, + profile=profile, + send_welcome_email=send_welcome_email, + local=local, + ) def delete_users(self, *users: Union[_GqlId, _User]) -> None: return _delete_users(self, users=users) @@ -821,6 +828,7 @@ def _create_user( role: UserRoleInput, profile: _Profile, send_welcome_email: bool = False, + local: bool = True, ) -> _User: email = profile.email password = profile.password @@ -828,6 +836,8 @@ def _create_user( args = [f'email:"{email}"', f'password:"{password}"', f"role:{role.value}"] if username: args.append(f'username:"{username}"') + if not local: + args.append("authMethod:OAUTH2") args.append(f"sendWelcomeEmail:{str(send_welcome_email).lower()}") out = "user{id email role{name}}" query = "mutation{createUser(input:{" + ",".join(args) + "}){" + out + "}}" diff --git a/tests/integration/auth/conftest.py b/tests/integration/auth/conftest.py index 2725e19324..4ce01b5846 100644 --- a/tests/integration/auth/conftest.py +++ b/tests/integration/auth/conftest.py @@ -84,6 +84,19 @@ def _app( f"PHOENIX_OAUTH2_{_oidc_server}_OIDC_CONFIG_URL".upper(), f"{_oidc_server.base_url}/.well-known/openid-configuration", ), + ( + f"PHOENIX_OAUTH2_{_oidc_server}_NO_SIGN_UP_CLIENT_ID".upper(), + _oidc_server.client_id, + ), + ( + f"PHOENIX_OAUTH2_{_oidc_server}_NO_SIGN_UP_CLIENT_SECRET".upper(), + _oidc_server.client_secret, + ), + ( + f"PHOENIX_OAUTH2_{_oidc_server}_NO_SIGN_UP_OIDC_CONFIG_URL".upper(), + f"{_oidc_server.base_url}/.well-known/openid-configuration", + ), + (f"PHOENIX_OAUTH2_{_oidc_server}_NO_SIGN_UP_ALLOW_SIGN_UP".upper(), "false"), ) with ExitStack() as stack: stack.enter_context(mock.patch.dict(os.environ, values)) diff --git a/tests/integration/auth/test_auth.py b/tests/integration/auth/test_auth.py index 28b19dd512..67daf9180c 100644 --- a/tests/integration/auth/test_auth.py +++ b/tests/integration/auth/test_auth.py @@ -83,26 +83,43 @@ class TestOIDC: """Tests for OpenID Connect (OIDC) authentication flow. - This class tests the OIDC signup process, including user creation, - token generation, and handling of conflicts with existing users. - """ - - async def test_signup( + This class tests the OIDC sign-in and sign-up processes, including: + - User authentication via OIDC + - New user creation during OIDC sign-in + - Token generation and validation + - Handling of conflicts with existing users + - Configuration options like allow_sign_up + - Error handling for invalid credentials + """ # noqa: E501 + + @pytest.mark.parametrize("allow_sign_up", [True, False]) + async def test_sign_in( self, + allow_sign_up: bool, _oidc_server: _OIDCServer, ) -> None: - """Test the complete OIDC signup flow. - - Verifies that: - 1. The OAuth2 flow redirects correctly - 2. A new user is created with MEMBER role - 3. Access and refresh tokens are generated - 4. Subsequent OIDC flows generate new tokens - """ + """Test the complete OIDC sign-in flow with different allow_sign_up settings. + + This test verifies: + 1. The OAuth2 flow redirects correctly to the OIDC provider + 2. When allow_sign_up is True: + - A new user is created with MEMBER role + - Access and refresh tokens are generated + - Subsequent OIDC flows generate new tokens for the same user + 3. When allow_sign_up is False: + - Users are redirected to login with an error message + - No access tokens are granted + - If a user without a password exists, they can still sign in + """ # noqa: E501 client = _httpx_client() + url = ( + f"oauth2/{_oidc_server}/login" + if allow_sign_up + else f"oauth2/{_oidc_server}_no_sign_up/login" + ) # Start the OAuth2 flow - response = client.post(f"oauth2/{_oidc_server}/login") + response = client.post(url) assert response.status_code == 302 auth_url = response.headers["location"] cookies = dict(response.cookies) # Save cookies for reuse @@ -120,9 +137,45 @@ async def test_signup( # Complete the flow by calling the token endpoint response = client.get(callback_url) - assert response.status_code == 302 + + if not allow_sign_up: + # Verify that user is redirected to /login + assert response.status_code == 307 + assert "/login" in response.headers["location"] + + # Verify no access is granted + assert not response.cookies.get("phoenix-access-token") + assert not response.cookies.get("phoenix-refresh-token") + + # Create the user without password + admin.create_user(profile=_Profile(email, "", token_hex(8)), local=False) + + # If user go through OIDC flow again, access should be granted + response = _httpx_client(cookies=cookies).get(callback_url) + + # Verify that user is redirected not to /login + assert response.status_code == 302 + assert "/login" not in response.headers["location"] + + # Verify we got access + assert (access_token := response.cookies.get("phoenix-access-token")) + assert (refresh_token := response.cookies.get("phoenix-refresh-token")) + + # Verify that the user was created + users = {u.profile.email: u for u in admin.list_users()} + assert email in users + assert users[email].role is UserRoleInput.MEMBER + + # If user go through OIDC flow again, new access token should be created + response = _httpx_client(cookies=cookies).get(callback_url) + assert (new_access_token := response.cookies.get("phoenix-access-token")) + assert (new_refresh_token := response.cookies.get("phoenix-refresh-token")) + assert new_access_token != access_token + assert new_refresh_token != refresh_token + return # Verify we got access + assert response.status_code == 302 assert (access_token := response.cookies.get("phoenix-access-token")) assert (refresh_token := response.cookies.get("phoenix-refresh-token")) @@ -138,21 +191,30 @@ async def test_signup( assert new_access_token != access_token assert new_refresh_token != refresh_token - async def test_signup_conflict_for_local_user_with_password( + @pytest.mark.parametrize("allow_sign_up", [True, False]) + async def test_sign_in_conflict_for_local_user_with_password( self, + allow_sign_up: bool, _oidc_server: _OIDCServer, ) -> None: - """Test OIDC signup when a user with the same email already exists. - - Verifies that: - 1. The system detects the email conflict - 2. The user is redirected to the login page - 3. No access tokens are granted - """ + """Test OIDC sign-in when a user with the same email already exists with password authentication. + + This test verifies: + 1. The system detects the email conflict with an existing user + 2. The user is redirected to the login page with an appropriate error message + 3. No access tokens are granted to the OIDC user + 4. The existing user's credentials remain unchanged + 5. This behavior is consistent regardless of the allow_sign_up setting + """ # noqa: E501 client = _httpx_client() + url = ( + f"oauth2/{_oidc_server}/login" + if allow_sign_up + else f"oauth2/{_oidc_server}_no_sign_up/login" + ) # Start the OAuth2 flow - response = client.post(f"oauth2/{_oidc_server}/login") + response = client.post(url) assert response.status_code == 302 auth_url = response.headers["location"] @@ -173,7 +235,7 @@ async def test_signup_conflict_for_local_user_with_password( # Verify that user is redirected to /login response = client.get(callback_url) assert response.status_code == 307 - assert response.headers["location"].startswith("/login") + assert "/login" in response.headers["location"] # Verify no access is granted assert not response.cookies.get("phoenix-access-token") diff --git a/tests/integration/db_migrations/test_data_migration_6a88424799fe_update_users_with_auth_method.py b/tests/integration/db_migrations/test_data_migration_6a88424799fe_update_users_with_auth_method.py new file mode 100644 index 0000000000..e8f318f1df --- /dev/null +++ b/tests/integration/db_migrations/test_data_migration_6a88424799fe_update_users_with_auth_method.py @@ -0,0 +1,363 @@ +from secrets import token_hex +from typing import Literal + +import pytest +from alembic.config import Config +from sqlalchemy import Connection, Engine, text + +from . import _down, _up, _version_num + + +def test_user_auth_method_migration( + _engine: Engine, + _alembic_config: Config, + _db_backend: Literal["sqlite", "postgresql"], +) -> None: + """Test the migration that adds the auth_method column to the users table. + + This test verifies the complete migration process for adding the auth_method column, + including both upgrade and downgrade paths. It ensures data integrity throughout + the migration process and verifies that the schema changes are correctly applied. + + The test process: + 1. Initial Setup: + - Verifies clean state + - Runs initial migration + - Creates test users with different auth methods (local and OAuth2) + + 2. Migration Testing: + - Verifies pre-migration state + - Runs auth_method migration + - Verifies post-migration state and schema + - Tests new user creation with auth_method + - Tests constraint enforcement for invalid auth_method values + - Tests NOT NULL constraint for auth_method + + 3. Downgrade Testing: + - Runs downgrade migration + - Verifies schema returns to initial state + - Verifies data integrity of existing users + - Verifies original constraints are restored correctly + + Args: + _engine: Database engine fixture + _alembic_config: Alembic configuration fixture + _db_backend: Database backend type ('sqlite' or 'postgresql') + + Raises: + AssertionError: If any verification checks fail + sqlalchemy.exc.SQLAlchemyError: If database operations fail + """ + # no migrations applied yet + with pytest.raises(BaseException, match="alembic_version"): + _version_num(_engine) + + # apply migrations up to right before auth method migration + _up(_engine, _alembic_config, "8a3764fe7f1a") + + # Create test users + with _engine.connect() as conn: + # Create a user role + role_id = conn.execute( + text( + """ + INSERT INTO user_roles (name) + VALUES ('MEMBER') + RETURNING id + """ + ) + ).scalar() + assert isinstance(role_id, int) + + # Create a local auth user + local_user_id = conn.execute( + text( + """ + INSERT INTO users ( + user_role_id, username, email, password_hash, password_salt, + reset_password, oauth2_client_id, oauth2_user_id + ) + VALUES ( + :role_id, :username, :email, + :password_hash, :password_salt, false, NULL, NULL + ) + RETURNING id + """ + ), + { + "role_id": role_id, + "username": f"local_user_{token_hex(4)}", + "email": f"local_{token_hex(4)}@example.com", + "password_hash": b"test_hash", + "password_salt": b"test_salt", + }, + ).scalar() + assert isinstance(local_user_id, int) + + # Create an OAuth2 user + oauth_user_id = conn.execute( + text( + """ + INSERT INTO users ( + user_role_id, username, email, password_hash, password_salt, + reset_password, oauth2_client_id, oauth2_user_id + ) + VALUES ( + :role_id, :username, :email, + NULL, NULL, false, :client_id, :user_id + ) + RETURNING id + """ + ), + { + "role_id": role_id, + "username": f"oauth_user_{token_hex(4)}", + "email": f"oauth_{token_hex(4)}@example.com", + "client_id": f"test_client_{token_hex(4)}", + "user_id": f"test_user_{token_hex(4)}", + }, + ).scalar() + assert isinstance(oauth_user_id, int) + conn.commit() + + # Run the auth method migration + _up(_engine, _alembic_config, "6a88424799fe") + + # Test post-migration constraints + with _engine.connect() as conn: + # Test invalid auth_method value + with pytest.raises(Exception) as exc_info: + conn.execute( + text( + """ + INSERT INTO users ( + user_role_id, username, email, auth_method, + password_hash, password_salt, reset_password + ) + VALUES ( + :role_id, :username, :email, 'INVALID', + :password_hash, :password_salt, false + ) + """ + ), + { + "role_id": role_id, + "username": f"invalid_auth_{token_hex(4)}", + "email": f"invalid_auth_{token_hex(4)}@example.com", + "password_hash": b"test_hash", + "password_salt": b"test_salt", + }, + ) + conn.commit() + error_message = str(exc_info.value) + assert ( + "valid_auth_method" in error_message + ), "Expected valid_auth_method constraint violation" + conn.rollback() + + with _engine.connect() as conn: + # Test LOCAL auth with OAuth2 credentials + with pytest.raises(Exception) as exc_info: + conn.execute( + text( + """ + INSERT INTO users ( + user_role_id, username, email, auth_method, + password_hash, password_salt, reset_password, + oauth2_client_id, oauth2_user_id + ) + VALUES ( + :role_id, :username, :email, 'LOCAL', + :password_hash, :password_salt, false, + :client_id, :user_id + ) + """ + ), + { + "role_id": role_id, + "username": f"local_with_oauth_{token_hex(4)}", + "email": f"local_with_oauth_{token_hex(4)}@example.com", + "password_hash": b"test_hash", + "password_salt": b"test_salt", + "client_id": f"test_client_{token_hex(4)}", + "user_id": f"test_user_{token_hex(4)}", + }, + ) + conn.commit() + error_message = str(exc_info.value) + assert ( + "local_auth_no_oauth" in error_message + ), "Expected local_auth_no_oauth constraint violation" + conn.rollback() + + with _engine.connect() as conn: + # Test OAUTH2 auth with password credentials + with pytest.raises(Exception) as exc_info: + conn.execute( + text( + """ + INSERT INTO users ( + user_role_id, username, email, auth_method, + password_hash, password_salt, reset_password, + oauth2_client_id, oauth2_user_id + ) + VALUES ( + :role_id, :username, :email, 'OAUTH2', + :password_hash, :password_salt, false, + :client_id, :user_id + ) + """ + ), + { + "role_id": role_id, + "username": f"oauth_with_password_{token_hex(4)}", + "email": f"oauth_with_password_{token_hex(4)}@example.com", + "password_hash": b"test_hash", + "password_salt": b"test_salt", + "client_id": f"test_client_{token_hex(4)}", + "user_id": f"test_user_{token_hex(4)}", + }, + ) + conn.commit() + error_message = str(exc_info.value) + assert ( + "oauth2_auth_no_password" in error_message + ), "Expected oauth2_auth_no_password constraint violation" + conn.rollback() + + # Test downgrade + _down(_engine, _alembic_config, "8a3764fe7f1a") + + # Verify downgrade state + with _engine.connect() as conn: + # Verify users still exist and have correct data + local_user = conn.execute( + text( + """ + SELECT password_hash IS NOT NULL as has_password, + oauth2_client_id IS NOT NULL as has_oauth + FROM users + WHERE id = :id + """ + ), + {"id": local_user_id}, + ).first() + assert local_user is not None + assert bool(local_user[0]), "Local user should still have password_hash" + assert not bool(local_user[1]), "Local user should still not have oauth2_client_id" + + with _engine.connect() as conn: + oauth_user = conn.execute( + text( + """ + SELECT password_hash IS NOT NULL as has_password, + oauth2_client_id IS NOT NULL as has_oauth + FROM users + WHERE id = :id + """ + ), + {"id": oauth_user_id}, + ).first() + assert oauth_user is not None + assert not bool(oauth_user[0]), "OAuth user should still not have password_hash" + assert bool(oauth_user[1]), "OAuth user should still have oauth2_client_id" + + +def _create_local_user( + conn: Connection, + role_id: int, +) -> int: + """Create a new local authentication user. + + Creates a user with: + - Local authentication method ('LOCAL') + - Password credentials (hash and salt) + - Randomly generated username and email + - Assigned user role + - reset_password set to false + + Args: + conn: Database connection to use + role_id: ID of the user role to assign + + Returns: + int: ID of the created user + + Raises: + sqlalchemy.exc.SQLAlchemyError: If database operations fail + AssertionError: If user creation fails or returned ID is not an integer + """ + result = conn.execute( + text( + """ + INSERT INTO users ( + user_role_id, username, email, password_hash, password_salt, + reset_password, auth_method + ) + VALUES ( + :role_id, :username, :email, + :password_hash, :password_salt, false, 'LOCAL' + ) + RETURNING id + """ + ), + { + "role_id": role_id, + "username": f"new_local_user_{token_hex(4)}", + "email": f"new_local_{token_hex(4)}@example.com", + "password_hash": b"new_hash", + "password_salt": b"new_salt", + }, + ).scalar_one() + assert isinstance(result, int) + return result + + +def _create_oauth_user( + conn: Connection, + role_id: int, +) -> int: + """Create a new OAuth2 user. + + Creates a user with: + - External authentication method ('OAUTH2') + - OAuth2 credentials (client_id and user_id) + - Randomly generated username, email, and OAuth IDs + - Assigned user role + - reset_password set to false + + Args: + conn: Database connection to use + role_id: ID of the user role to assign + + Returns: + int: ID of the created user + + Raises: + sqlalchemy.exc.SQLAlchemyError: If database operations fail + AssertionError: If user creation fails or returned ID is not an integer + """ + result = conn.execute( + text( + """ + INSERT INTO users ( + user_role_id, username, email, oauth2_client_id, oauth2_user_id, + reset_password, auth_method + ) + VALUES ( + :role_id, :username, :email, + :client_id, :user_id, false, 'OAUTH2' + ) + RETURNING id + """ + ), + { + "role_id": role_id, + "username": f"new_oauth_user_{token_hex(4)}", + "email": f"new_oauth_{token_hex(4)}@example.com", + "client_id": f"new_client_{token_hex(4)}", + "user_id": f"new_user_{token_hex(4)}", + }, + ).scalar_one() + assert isinstance(result, int) + return result diff --git a/tests/integration/db_migrations/test_db_schema_6a88424799fe_update_users_with_auth_method.py b/tests/integration/db_migrations/test_db_schema_6a88424799fe_update_users_with_auth_method.py new file mode 100644 index 0000000000..3df8ef74c5 --- /dev/null +++ b/tests/integration/db_migrations/test_db_schema_6a88424799fe_update_users_with_auth_method.py @@ -0,0 +1,187 @@ +from abc import ABC, abstractmethod + +from alembic.config import Config +from sqlalchemy import Engine +from typing_extensions import assert_never + +from . import _DBBackend, _down, _get_table_schema_info, _TableSchemaInfo, _up, _verify_clean_state + +_DOWN = "8a3764fe7f1a" +_UP = "6a88424799fe" + + +class DBSchemaComparisonTest(ABC): + table_name: str + + @classmethod + @abstractmethod + def _get_current_schema_info( + cls, + db_backend: _DBBackend, + ) -> _TableSchemaInfo: ... + + @classmethod + @abstractmethod + def _get_upgraded_schema_info( + cls, + db_backend: _DBBackend, + ) -> _TableSchemaInfo: ... + + def _test_db_schema( + self, + _engine: Engine, + _alembic_config: Config, + _db_backend: _DBBackend, + ) -> None: + _verify_clean_state(_engine) + + _up(_engine, _alembic_config, _DOWN) + + current_info = self._get_current_schema_info(_db_backend) + upgraded_info = self._get_upgraded_schema_info(_db_backend) + + with _engine.connect() as conn: + initial_info = _get_table_schema_info(conn, self.table_name, _db_backend) + assert ( + initial_info == current_info + ), "Initial schema info does not match expected current schema info" + + _up(_engine, _alembic_config, _UP) + + with _engine.connect() as conn: + final_info = _get_table_schema_info(conn, self.table_name, _db_backend) + assert ( + final_info == upgraded_info + ), "Final schema info does not match expected upgraded schema info" + + _down(_engine, _alembic_config, _DOWN) + + with _engine.connect() as conn: + downgraded_info = _get_table_schema_info(conn, self.table_name, _db_backend) + assert ( + downgraded_info == current_info + ), "Downgraded schema info does not match expected current schema info" + + +class TestUsers(DBSchemaComparisonTest): + table_name = "users" + + @classmethod + def _get_current_schema_info( + cls, + db_backend: _DBBackend, + ) -> _TableSchemaInfo: + column_names = { + "id", + "user_role_id", + "username", + "email", + "password_hash", + "password_salt", + "reset_password", + "oauth2_client_id", + "oauth2_user_id", + "created_at", + "updated_at", + "profile_picture_url", + } + index_names = { + "ix_users_username", + "ix_users_email", + "ix_users_oauth2_client_id", + "ix_users_oauth2_user_id", + "ix_users_user_role_id", + } + constraint_names = { + "ck_users_`exactly_one_auth_method`", + "ck_users_`oauth2_client_id_and_user_id`", + "ck_users_`password_hash_and_salt`", + "uq_users_oauth2_client_id_oauth2_user_id", + "pk_users", + "fk_users_user_role_id_user_roles", + } + if db_backend == "postgresql": + index_names.update( + { + "pk_users", + "uq_users_oauth2_client_id_oauth2_user_id", + } + ) + elif db_backend == "sqlite": + index_names.update( + { + "sqlite_autoindex_users_1", + } + ) + else: + assert_never(db_backend) + return _TableSchemaInfo( + table_name="users", + column_names=frozenset(column_names), + index_names=frozenset(index_names), + constraint_names=frozenset(constraint_names), + ) + + @classmethod + def _get_upgraded_schema_info( + cls, + db_backend: _DBBackend, + ) -> _TableSchemaInfo: + column_names = { + "id", + "user_role_id", + "username", + "email", + "password_hash", + "password_salt", + "reset_password", + "oauth2_client_id", + "oauth2_user_id", + "created_at", + "updated_at", + "profile_picture_url", + "auth_method", + } + index_names = { + "ix_users_username", + "ix_users_email", + "ix_users_user_role_id", + } + constraint_names = { + "ck_users_`valid_auth_method`", + "ck_users_`local_auth_no_oauth`", + "ck_users_`oauth2_auth_no_password`", + "ck_users_`password_hash_and_salt`", + "uq_users_oauth2_client_id_oauth2_user_id", + "pk_users", + "fk_users_user_role_id_user_roles", + } + if db_backend == "postgresql": + index_names.update( + { + "pk_users", + "uq_users_oauth2_client_id_oauth2_user_id", + } + ) + elif db_backend == "sqlite": + index_names.update( + { + "sqlite_autoindex_users_1", + } + ) + else: + assert_never(db_backend) + return _TableSchemaInfo( + table_name="users", + column_names=frozenset(column_names), + index_names=frozenset(index_names), + constraint_names=frozenset(constraint_names), + ) + + def test_db_schema( + self, + _engine: Engine, + _alembic_config: Config, + _db_backend: _DBBackend, + ) -> None: + self._test_db_schema(_engine, _alembic_config, _db_backend) diff --git a/tests/integration/db_migrations/test_up_and_down_migrations.py b/tests/integration/db_migrations/test_up_and_down_migrations.py index 9879cea3c9..bd2e6b0d6f 100644 --- a/tests/integration/db_migrations/test_up_and_down_migrations.py +++ b/tests/integration/db_migrations/test_up_and_down_migrations.py @@ -308,3 +308,8 @@ def test_up_and_down_migrations( _up(_engine, _alembic_config, "8a3764fe7f1a") _down(_engine, _alembic_config, "bb8139330879") _up(_engine, _alembic_config, "8a3764fe7f1a") + + for _ in range(2): + _up(_engine, _alembic_config, "6a88424799fe") + _down(_engine, _alembic_config, "8a3764fe7f1a") + _up(_engine, _alembic_config, "6a88424799fe") diff --git a/tests/unit/server/api/routers/test_oauth2.py b/tests/unit/server/api/routers/test_oauth2.py new file mode 100644 index 0000000000..6aa4e7ce7e --- /dev/null +++ b/tests/unit/server/api/routers/test_oauth2.py @@ -0,0 +1,507 @@ +from secrets import token_hex +from typing import Optional + +import pytest +from sqlalchemy import insert, select +from starlette.types import ASGIApp + +from phoenix.db import models +from phoenix.server.api.routers.oauth2 import ( + SignInNotAllowed, + UserInfo, + _get_existing_oauth2_user, +) +from phoenix.server.types import DbSessionFactory + + +@pytest.mark.parametrize( + "user,oauth2_client_id,user_info,allowed", + [ + # Test Case: User with password hash cannot sign in with OAuth2 + # Verifies that users who have set a password must use password authentication + # and cannot switch to OAuth2 authentication. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=b"password_hash", + password_salt=b"password_salt", + reset_password=False, + oauth2_client_id=None, + oauth2_user_id=None, + auth_method="LOCAL", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url=None, + ), + False, + id="user_with_password_hash", + ), + # Test Case: User with matching OAuth2 credentials can sign in + # Verifies that users with matching OAuth2 credentials can successfully sign in + # without any credential updates needed. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id="123456789012-abcdef.apps.googleusercontent.com", + oauth2_user_id="118234567890123456789", + auth_method="OAUTH2", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url=None, + ), + True, + id="user_with_matching_oauth2_credentials", + ), + # Test Case: User with different OAuth2 client ID can sign in + # Verifies that users found by email can have their OAuth2 client ID updated + # when signing in with a different client ID. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id="987654321098-xyzdef.apps.googleusercontent.com", + oauth2_user_id="118234567890123456789", + auth_method="OAUTH2", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url=None, + ), + True, + id="user_with_different_oauth2_client_id", + ), + # Test Case: User with different OAuth2 user ID cannot sign in + # Verifies that users cannot sign in when their OAuth2 user ID doesn't match, + # even if the client ID matches. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id="123456789012-abcdef.apps.googleusercontent.com", + oauth2_user_id="118234567890123456789", + auth_method="OAUTH2", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890987654321", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url=None, + ), + False, + id="user_with_different_oauth2_user_id", + ), + # Test Case: User with missing OAuth2 client ID can sign in + # Verifies that users found by email can have their OAuth2 client ID set + # when signing in for the first time with OAuth2. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id=None, + oauth2_user_id="118234567890123456789", + auth_method="OAUTH2", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url=None, + ), + True, + id="user_with_missing_oauth2_client_id", + ), + # Test Case: User with missing OAuth2 user ID can sign in + # Verifies that users can sign in when their OAuth2 user ID is missing, + # if the client ID matches. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id="123456789012-abcdef.apps.googleusercontent.com", + oauth2_user_id=None, + auth_method="OAUTH2", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url=None, + ), + True, + id="user_with_missing_oauth2_user_id", + ), + # Test Case: User with missing OAuth2 client ID but different user ID can sign in + # Verifies that users found by email can have both their OAuth2 client ID and user ID + # updated when signing in for the first time with OAuth2. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id=None, + oauth2_user_id="118234567890123456789", + auth_method="OAUTH2", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890987654321", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url=None, + ), + True, + id="user_with_missing_oauth2_client_id_and_different_user_id", + ), + # Test Case: User with missing OAuth2 user ID but different client ID can sign in + # Verifies that users found by email can have both their OAuth2 client ID and user ID + # updated when signing in with different credentials. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id="987654321098-xyzdef.apps.googleusercontent.com", + oauth2_user_id=None, + auth_method="OAUTH2", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url=None, + ), + True, + id="user_with_missing_oauth2_user_id_and_different_client_id", + ), + # Test Case: User found by email with no OAuth2 credentials can sign in + # Verifies that users found by email can have their OAuth2 credentials set + # when signing in for the first time with OAuth2. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id=None, + oauth2_user_id=None, + auth_method="OAUTH2", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url=None, + ), + True, + id="user_found_by_email_no_oauth2_credentials", + ), + # Test Case: User found by email with matching OAuth2 credentials can sign in + # Verifies that users found by email with matching OAuth2 credentials + # can sign in without any credential updates. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id="123456789012-abcdef.apps.googleusercontent.com", + oauth2_user_id="118234567890123456789", + auth_method="OAUTH2", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url=None, + ), + True, + id="user_found_by_email_matching_oauth2_credentials", + ), + # Test Case: User found by email with different OAuth2 credentials can sign in + # Verifies that users found by email can have their OAuth2 credentials updated + # when signing in with different credentials. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id="987654321098-xyzdef.apps.googleusercontent.com", + oauth2_user_id="118234567890987654321", + auth_method="OAUTH2", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url=None, + ), + True, + id="user_found_by_email_different_oauth2_credentials", + ), + # Test Case: User with updated profile picture can sign in + # Verifies that users can have their profile picture URL updated + # when signing in with a new URL. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id="123456789012-abcdef.apps.googleusercontent.com", + oauth2_user_id="118234567890123456789", + auth_method="OAUTH2", + profile_picture_url="https://old-picture.com/avatar.jpg", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url="https://new-picture.com/avatar.jpg", + ), + True, + id="user_with_updated_profile_picture", + ), + # Test Case: User with changed OAuth2 client ID and profile picture can sign in + # Verifies that users can have both their OAuth2 client ID and profile picture + # updated simultaneously when signing in. + pytest.param( + models.User( + user_role_id=1, + username=token_hex(8), + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id="987654321098-xyzdef.apps.googleusercontent.com", + oauth2_user_id="118234567890123456789", + auth_method="OAUTH2", + profile_picture_url="https://old-picture.com/avatar.jpg", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url="https://new-picture.com/avatar.jpg", + ), + True, + id="user_with_changed_client_id_and_profile_picture", + ), + # Test Case: User with updated username can sign in + # Verifies that users can have their username updated + # when signing in with a new username. + pytest.param( + models.User( + user_role_id=1, + username=f"old_username{token_hex(8)}", + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id="123456789012-abcdef.apps.googleusercontent.com", + oauth2_user_id="118234567890123456789", + auth_method="OAUTH2", + profile_picture_url="https://example.com/avatar.jpg", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=f"new_username{token_hex(8)}", + profile_picture_url="https://example.com/avatar.jpg", + ), + True, + id="user_with_updated_username", + ), + # Test Case: User with removed username can sign in + # Verifies that users cannot have their username removed (set to None) + # when signing in with no username provided. + pytest.param( + models.User( + user_role_id=1, + username=f"old_username{token_hex(8)}", + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id="123456789012-abcdef.apps.googleusercontent.com", + oauth2_user_id="118234567890123456789", + auth_method="OAUTH2", + profile_picture_url="https://example.com/avatar.jpg", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url="https://example.com/avatar.jpg", + ), + True, + id="user_with_removed_username", + ), + # Test Case: User with removed profile picture can sign in + # Verifies that users can have their profile picture URL removed (set to None) + # when signing in with no profile picture URL provided. + pytest.param( + models.User( + user_role_id=1, + username=f"test_username{token_hex(8)}", + password_hash=None, + password_salt=None, + reset_password=False, + oauth2_client_id="123456789012-abcdef.apps.googleusercontent.com", + oauth2_user_id="118234567890123456789", + auth_method="OAUTH2", + profile_picture_url="https://example.com/avatar.jpg", + ), + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=f"test_username{token_hex(8)}", + profile_picture_url=None, + ), + True, + id="user_with_removed_profile_picture", + ), + # Test Case: Non-existent user cannot sign in + # Verifies that sign-in is rejected when no user is found + # by either OAuth2 credentials or email. + pytest.param( + None, + "123456789012-abcdef.apps.googleusercontent.com", + UserInfo( + idp_user_id="118234567890123456789", + email=f"{token_hex(8)}@example.com", + username=None, + profile_picture_url=None, + ), + False, + id="non_existent_user", + ), + ], +) +async def test_get_existing_oauth2_user( + app: ASGIApp, + db: DbSessionFactory, + user: Optional[models.User], + oauth2_client_id: str, + user_info: UserInfo, + allowed: bool, +) -> None: + """Test the OAuth2 user sign-in and update process. + + This test verifies the behavior of _get_existing_oauth2_user function, which handles: + 1. OAuth2 user authentication + 2. User profile updates (username, profile picture) + 3. OAuth2 credential updates (client_id, user_id) + 4. Error cases and rejections + + The test covers various scenarios: + - Users with/without passwords + - Users with matching/different OAuth2 credentials + - Users found by OAuth2 credentials or email + - Profile information updates + - Error conditions + + Args: + app: The FastAPI application instance + db: Database session factory + user: Optional existing user in the database + oauth2_client_id: OAuth2 client ID to use for sign-in + user_info: User information from the OAuth2 provider + allowed: Whether the sign-in should be allowed + """ + if user: + async with db() as session: + # For some strange reason PostgreSQL insists on UPDATE instead of + # INSERT when using session.add(user), so we INSERT manually. + await session.execute( + insert(models.User).values( + email=user_info.email, + user_role_id=user.user_role_id, + username=user.username, + password_hash=user.password_hash, + password_salt=user.password_salt, + reset_password=user.reset_password, + oauth2_client_id=user.oauth2_client_id, + oauth2_user_id=user.oauth2_user_id, + auth_method=user.auth_method, + ) + ) + async with db() as session: + if not user or not allowed: + with pytest.raises(SignInNotAllowed): + await _get_existing_oauth2_user( + session, + oauth2_client_id=oauth2_client_id, + user_info=user_info, + ) + return + oauth2_user = await _get_existing_oauth2_user( + session, + oauth2_client_id=oauth2_client_id, + user_info=user_info, + ) + # Verify the returned user object has the correct OAuth2 credentials + assert oauth2_user + assert oauth2_user.oauth2_client_id == oauth2_client_id + assert oauth2_user.oauth2_user_id == user_info.idp_user_id + + # Verify the database state after the update + async with db() as session: + db_user = await session.scalar(select(models.User).filter_by(email=user_info.email)) + assert db_user + # Verify OAuth2 credentials are correctly updated + assert db_user.oauth2_client_id == oauth2_client_id + assert db_user.oauth2_user_id == user_info.idp_user_id + # Verify profile picture URL is updated if provided in user_info + if user_info.profile_picture_url is not None: + assert db_user.profile_picture_url == user_info.profile_picture_url + # Verify username is updated if provided in user_info, otherwise remains unchanged + if user_info.username is not None: + assert db_user.username == user_info.username + elif user is not None: # If user_info.username is None, username should remain unchanged + assert user.username