Skip to content

Commit 86e806c

Browse files
committed
display email login_hint when login_with_email_allowed is activated
1 parent 98f2776 commit 86e806c

File tree

4 files changed

+75
-10
lines changed

4 files changed

+75
-10
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/data-model/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,5 @@ ruma-common.workspace = true
2828
mas-iana.workspace = true
2929
mas-jose.workspace = true
3030
oauth2-types.workspace = true
31+
# Emails
32+
lettre.workspace = true

crates/data-model/src/oauth2/authorization_grant.rs

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
// SPDX-License-Identifier: AGPL-3.0-only
55
// Please see LICENSE in the repository root for full details.
66

7+
use std::str::FromStr as _;
8+
79
use chrono::{DateTime, Utc};
810
use mas_iana::oauth::PkceCodeChallengeMethod;
911
use oauth2_types::{
@@ -142,6 +144,7 @@ impl AuthorizationGrantStage {
142144

143145
pub enum LoginHint<'a> {
144146
MXID(&'a UserId),
147+
EMAIL(&'a str),
145148
None,
146149
}
147150

@@ -173,7 +176,7 @@ impl std::ops::Deref for AuthorizationGrant {
173176

174177
impl AuthorizationGrant {
175178
#[must_use]
176-
pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint {
179+
pub fn parse_login_hint(&self, homeserver: &str, login_with_email_allowed: bool) -> LoginHint {
177180
let Some(login_hint) = &self.login_hint else {
178181
return LoginHint::None;
179182
};
@@ -197,6 +200,17 @@ impl AuthorizationGrant {
197200

198201
LoginHint::MXID(mxid)
199202
}
203+
"email" => {
204+
if !login_with_email_allowed {
205+
return LoginHint::None;
206+
}
207+
// Validate the email
208+
if lettre::Address::from_str(value).is_err() {
209+
return LoginHint::None;
210+
}
211+
212+
LoginHint::EMAIL(value)
213+
}
200214
// Unknown hint type, treat as none
201215
_ => LoginHint::None,
202216
}
@@ -288,7 +302,7 @@ mod tests {
288302
..AuthorizationGrant::sample(now, &mut rng)
289303
};
290304

291-
let hint = grant.parse_login_hint("example.com");
305+
let hint = grant.parse_login_hint("example.com", false);
292306

293307
assert!(matches!(hint, LoginHint::None));
294308
}
@@ -306,11 +320,47 @@ mod tests {
306320
..AuthorizationGrant::sample(now, &mut rng)
307321
};
308322

309-
let hint = grant.parse_login_hint("example.com");
323+
let hint = grant.parse_login_hint("example.com", false);
310324

311325
assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user"));
312326
}
313327

328+
#[test]
329+
fn valid_login_hint_with_email() {
330+
#[allow(clippy::disallowed_methods)]
331+
let mut rng = thread_rng();
332+
333+
#[allow(clippy::disallowed_methods)]
334+
let now = Utc::now();
335+
336+
let grant = AuthorizationGrant {
337+
login_hint: Some(String::from("email:example@user")),
338+
..AuthorizationGrant::sample(now, &mut rng)
339+
};
340+
341+
let hint = grant.parse_login_hint("example.com", true);
342+
343+
assert!(matches!(hint, LoginHint::EMAIL(email) if email == "example@user"));
344+
}
345+
346+
#[test]
347+
fn valid_login_hint_with_email_when_login_with_email_not_allowed() {
348+
#[allow(clippy::disallowed_methods)]
349+
let mut rng = thread_rng();
350+
351+
#[allow(clippy::disallowed_methods)]
352+
let now = Utc::now();
353+
354+
let grant = AuthorizationGrant {
355+
login_hint: Some(String::from("email:example@user")),
356+
..AuthorizationGrant::sample(now, &mut rng)
357+
};
358+
359+
let hint = grant.parse_login_hint("example.com", false);
360+
361+
assert!(matches!(hint, LoginHint::None));
362+
}
363+
314364
#[test]
315365
fn invalid_login_hint() {
316366
#[allow(clippy::disallowed_methods)]
@@ -324,7 +374,7 @@ mod tests {
324374
..AuthorizationGrant::sample(now, &mut rng)
325375
};
326376

327-
let hint = grant.parse_login_hint("example.com");
377+
let hint = grant.parse_login_hint("example.com", false);
328378

329379
assert!(matches!(hint, LoginHint::None));
330380
}
@@ -342,7 +392,7 @@ mod tests {
342392
..AuthorizationGrant::sample(now, &mut rng)
343393
};
344394

345-
let hint = grant.parse_login_hint("example.com");
395+
let hint = grant.parse_login_hint("example.com", false);
346396

347397
assert!(matches!(hint, LoginHint::None));
348398
}
@@ -360,7 +410,7 @@ mod tests {
360410
..AuthorizationGrant::sample(now, &mut rng)
361411
};
362412

363-
let hint = grant.parse_login_hint("example.com");
413+
let hint = grant.parse_login_hint("example.com", false);
364414

365415
assert!(matches!(hint, LoginHint::None));
366416
}

crates/handlers/src/views/login.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ pub(crate) async fn get(
123123
&mut rng,
124124
&templates,
125125
&homeserver,
126+
&site_config,
126127
)
127128
.await
128129
}
@@ -177,6 +178,7 @@ pub(crate) async fn post(
177178
&mut rng,
178179
&templates,
179180
&homeserver,
181+
&site_config,
180182
)
181183
.await;
182184
}
@@ -187,7 +189,7 @@ pub(crate) async fn post(
187189
.unwrap_or(&form.username);
188190

189191
// First, lookup the user
190-
let Some(user) = get_user_by_email_or_by_username(site_config, &mut repo, username).await?
192+
let Some(user) = get_user_by_email_or_by_username(&site_config, &mut repo, username).await?
191193
else {
192194
let form_state = form_state.with_error_on_form(FormError::InvalidCredentials);
193195
PASSWORD_LOGIN_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
@@ -201,6 +203,7 @@ pub(crate) async fn post(
201203
&mut rng,
202204
&templates,
203205
&homeserver,
206+
&site_config,
204207
)
205208
.await;
206209
};
@@ -220,6 +223,7 @@ pub(crate) async fn post(
220223
&mut rng,
221224
&templates,
222225
&homeserver,
226+
&site_config,
223227
)
224228
.await;
225229
}
@@ -240,6 +244,7 @@ pub(crate) async fn post(
240244
&mut rng,
241245
&templates,
242246
&homeserver,
247+
&site_config,
243248
)
244249
.await;
245250
};
@@ -283,6 +288,7 @@ pub(crate) async fn post(
283288
&mut rng,
284289
&templates,
285290
&homeserver,
291+
&site_config,
286292
)
287293
.await;
288294
}
@@ -339,7 +345,7 @@ pub(crate) async fn post(
339345
}
340346

341347
async fn get_user_by_email_or_by_username<R: RepositoryAccess>(
342-
site_config: SiteConfig,
348+
site_config: &SiteConfig,
343349
repo: &mut R,
344350
username_or_email: &str,
345351
) -> Result<Option<mas_data_model::User>, R::Error> {
@@ -364,6 +370,7 @@ fn handle_login_hint(
364370
mut ctx: LoginContext,
365371
next: &PostAuthContext,
366372
homeserver: &dyn HomeserverConnection,
373+
site_config: &SiteConfig,
367374
) -> LoginContext {
368375
let form_state = ctx.form_state_mut();
369376

@@ -373,8 +380,12 @@ fn handle_login_hint(
373380
}
374381

375382
if let PostAuthContextInner::ContinueAuthorizationGrant { ref grant } = next.ctx {
376-
let value = match grant.parse_login_hint(homeserver.homeserver()) {
383+
let value = match grant.parse_login_hint(
384+
homeserver.homeserver(),
385+
site_config.login_with_email_allowed,
386+
) {
377387
LoginHint::MXID(mxid) => Some(mxid.localpart().to_owned()),
388+
LoginHint::EMAIL(email) => Some(email.to_owned()),
378389
LoginHint::None => None,
379390
};
380391
form_state.set_value(LoginFormField::Username, value);
@@ -393,6 +404,7 @@ async fn render(
393404
rng: impl Rng,
394405
templates: &Templates,
395406
homeserver: &dyn HomeserverConnection,
407+
site_config: &SiteConfig,
396408
) -> Result<Response, InternalError> {
397409
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
398410
let providers = repo.upstream_oauth_provider().all_enabled().await?;
@@ -406,7 +418,7 @@ async fn render(
406418
.await
407419
.map_err(InternalError::from_anyhow)?;
408420
let ctx = if let Some(next) = next {
409-
let ctx = handle_login_hint(ctx, &next, homeserver);
421+
let ctx = handle_login_hint(ctx, &next, homeserver, site_config);
410422
ctx.with_post_action(next)
411423
} else {
412424
ctx

0 commit comments

Comments
 (0)