Skip to content

email login_hint support when login_with_email_allowed is activated #4568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/data-model/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ rand.workspace = true
regex.workspace = true
woothee.workspace = true
ruma-common.workspace = true
lettre.workspace = true

mas-iana.workspace = true
mas-jose.workspace = true
oauth2-types.workspace = true

70 changes: 63 additions & 7 deletions crates/data-model/src/oauth2/authorization_grant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

use std::str::FromStr as _;

use chrono::{DateTime, Utc};
use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{
Expand Down Expand Up @@ -142,6 +144,7 @@ impl AuthorizationGrantStage {

pub enum LoginHint<'a> {
MXID(&'a UserId),
Email(lettre::Address),
None,
}

Expand Down Expand Up @@ -172,14 +175,31 @@ impl std::ops::Deref for AuthorizationGrant {
}

impl AuthorizationGrant {
/// Parse a `login_hint`
///
/// Returns `LoginHint::MXID` for valid mxid 'mxid:@john.doe:example.com'
///
/// Returns `LoginHint::Email` for valid email 'john.doe@example.com' if
/// email supports is enabled
///
/// Otherwise returns `LoginHint::None`
#[must_use]
pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint {
pub fn parse_login_hint(&self, homeserver: &str, login_with_email_allowed: bool) -> LoginHint {
let Some(login_hint) = &self.login_hint else {
return LoginHint::None;
};

// Return none if the format is incorrect
let Some((prefix, value)) = login_hint.split_once(':') else {
// If email supports for login_hint is enabled
if login_with_email_allowed {
// Validate the email
let Ok(address) = lettre::Address::from_str(login_hint) else {
// Return none if the format is incorrect
return LoginHint::None;
};
return LoginHint::Email(address);
}
// Unknown hint type, treat as none
return LoginHint::None;
};

Expand Down Expand Up @@ -288,7 +308,7 @@ mod tests {
..AuthorizationGrant::sample(now, &mut rng)
};

let hint = grant.parse_login_hint("example.com");
let hint = grant.parse_login_hint("example.com", false);

assert!(matches!(hint, LoginHint::None));
}
Expand All @@ -306,11 +326,47 @@ mod tests {
..AuthorizationGrant::sample(now, &mut rng)
};

let hint = grant.parse_login_hint("example.com");
let hint = grant.parse_login_hint("example.com", false);

assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user"));
}

#[test]
fn valid_login_hint_with_email() {
#[allow(clippy::disallowed_methods)]
let mut rng = thread_rng();

#[allow(clippy::disallowed_methods)]
let now = Utc::now();

let grant = AuthorizationGrant {
login_hint: Some(String::from("example@user")),
..AuthorizationGrant::sample(now, &mut rng)
};

let hint = grant.parse_login_hint("example.com", true);

assert!(matches!(hint, LoginHint::Email(email) if email.to_string() == "example@user"));
}

#[test]
fn valid_login_hint_with_email_when_login_with_email_not_allowed() {
#[allow(clippy::disallowed_methods)]
let mut rng = thread_rng();

#[allow(clippy::disallowed_methods)]
let now = Utc::now();

let grant = AuthorizationGrant {
login_hint: Some(String::from("example@user")),
..AuthorizationGrant::sample(now, &mut rng)
};

let hint = grant.parse_login_hint("example.com", false);

assert!(matches!(hint, LoginHint::None));
}

#[test]
fn invalid_login_hint() {
#[allow(clippy::disallowed_methods)]
Expand All @@ -324,7 +380,7 @@ mod tests {
..AuthorizationGrant::sample(now, &mut rng)
};

let hint = grant.parse_login_hint("example.com");
let hint = grant.parse_login_hint("example.com", false);

assert!(matches!(hint, LoginHint::None));
}
Expand All @@ -342,7 +398,7 @@ mod tests {
..AuthorizationGrant::sample(now, &mut rng)
};

let hint = grant.parse_login_hint("example.com");
let hint = grant.parse_login_hint("example.com", false);

assert!(matches!(hint, LoginHint::None));
}
Expand All @@ -360,7 +416,7 @@ mod tests {
..AuthorizationGrant::sample(now, &mut rng)
};

let hint = grant.parse_login_hint("example.com");
let hint = grant.parse_login_hint("example.com", false);

assert!(matches!(hint, LoginHint::None));
}
Expand Down
20 changes: 16 additions & 4 deletions crates/handlers/src/views/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ pub(crate) async fn get(
&mut rng,
&templates,
&homeserver,
&site_config,
)
.await
}
Expand Down Expand Up @@ -177,6 +178,7 @@ pub(crate) async fn post(
&mut rng,
&templates,
&homeserver,
&site_config,
)
.await;
}
Expand All @@ -187,7 +189,7 @@ pub(crate) async fn post(
.unwrap_or(&form.username);

// First, lookup the user
let Some(user) = get_user_by_email_or_by_username(site_config, &mut repo, username).await?
let Some(user) = get_user_by_email_or_by_username(&site_config, &mut repo, username).await?
else {
let form_state = form_state.with_error_on_form(FormError::InvalidCredentials);
PASSWORD_LOGIN_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
Expand All @@ -201,6 +203,7 @@ pub(crate) async fn post(
&mut rng,
&templates,
&homeserver,
&site_config,
)
.await;
};
Expand All @@ -220,6 +223,7 @@ pub(crate) async fn post(
&mut rng,
&templates,
&homeserver,
&site_config,
)
.await;
}
Expand All @@ -240,6 +244,7 @@ pub(crate) async fn post(
&mut rng,
&templates,
&homeserver,
&site_config,
)
.await;
};
Expand Down Expand Up @@ -283,6 +288,7 @@ pub(crate) async fn post(
&mut rng,
&templates,
&homeserver,
&site_config,
)
.await;
}
Expand Down Expand Up @@ -339,7 +345,7 @@ pub(crate) async fn post(
}

async fn get_user_by_email_or_by_username<R: RepositoryAccess>(
site_config: SiteConfig,
site_config: &SiteConfig,
repo: &mut R,
username_or_email: &str,
) -> Result<Option<mas_data_model::User>, R::Error> {
Expand All @@ -364,6 +370,7 @@ fn handle_login_hint(
mut ctx: LoginContext,
next: &PostAuthContext,
homeserver: &dyn HomeserverConnection,
site_config: &SiteConfig,
) -> LoginContext {
let form_state = ctx.form_state_mut();

Expand All @@ -373,8 +380,12 @@ fn handle_login_hint(
}

if let PostAuthContextInner::ContinueAuthorizationGrant { ref grant } = next.ctx {
let value = match grant.parse_login_hint(homeserver.homeserver()) {
let value = match grant.parse_login_hint(
homeserver.homeserver(),
site_config.login_with_email_allowed,
) {
LoginHint::MXID(mxid) => Some(mxid.localpart().to_owned()),
LoginHint::Email(email) => Some(email.to_string()),
LoginHint::None => None,
};
form_state.set_value(LoginFormField::Username, value);
Expand All @@ -393,6 +404,7 @@ async fn render(
rng: impl Rng,
templates: &Templates,
homeserver: &dyn HomeserverConnection,
site_config: &SiteConfig,
) -> Result<Response, InternalError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
let providers = repo.upstream_oauth_provider().all_enabled().await?;
Expand All @@ -406,7 +418,7 @@ async fn render(
.await
.map_err(InternalError::from_anyhow)?;
let ctx = if let Some(next) = next {
let ctx = handle_login_hint(ctx, &next, homeserver);
let ctx = handle_login_hint(ctx, &next, homeserver, site_config);
ctx.with_post_action(next)
} else {
ctx
Expand Down
Loading