|
4 | 4 | // SPDX-License-Identifier: AGPL-3.0-only
|
5 | 5 | // Please see LICENSE in the repository root for full details.
|
6 | 6 |
|
| 7 | +use std::sync::Arc; |
| 8 | + |
7 | 9 | use axum::{
|
8 | 10 | extract::{Path, Query, State},
|
9 | 11 | response::{IntoResponse, Redirect},
|
10 | 12 | };
|
11 | 13 | use hyper::StatusCode;
|
12 | 14 | use mas_axum_utils::{cookies::CookieJar, record_error};
|
13 |
| -use mas_data_model::UpstreamOAuthProvider; |
| 15 | +use mas_data_model::{UpstreamOAuthProvider, oauth2::LoginHint}; |
| 16 | +use mas_matrix::HomeserverConnection; |
14 | 17 | use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
|
15 | 18 | use mas_router::{PostAuthAction, UrlBuilder};
|
16 | 19 | use mas_storage::{
|
@@ -66,6 +69,7 @@ pub(crate) async fn get(
|
66 | 69 | cookie_jar: CookieJar,
|
67 | 70 | Path(provider_id): Path<Ulid>,
|
68 | 71 | Query(query): Query<OptionalPostAuthAction>,
|
| 72 | + State(homeserver): State<Arc<dyn HomeserverConnection>>, |
69 | 73 | ) -> Result<impl IntoResponse, RouteError> {
|
70 | 74 | let provider = repo
|
71 | 75 | .upstream_oauth_provider()
|
@@ -96,13 +100,11 @@ pub(crate) async fn get(
|
96 | 100 | // sees fit
|
97 | 101 | if provider.forward_login_hint {
|
98 | 102 | if let Some(PostAuthAction::ContinueAuthorizationGrant { id }) = &query.post_auth_action {
|
99 |
| - if let Some(login_hint) = repo |
100 |
| - .oauth2_authorization_grant() |
101 |
| - .lookup(*id) |
102 |
| - .await? |
103 |
| - .and_then(|grant| grant.login_hint) |
104 |
| - { |
105 |
| - data = data.with_login_hint(login_hint); |
| 103 | + if let Some(grant) = repo.oauth2_authorization_grant().lookup(*id).await? { |
| 104 | + match grant.parse_login_hint(homeserver.homeserver()) { |
| 105 | + LoginHint::MXID(mxid) => data = data.with_login_hint(mxid.to_string()), |
| 106 | + LoginHint::None => (), |
| 107 | + } |
106 | 108 | }
|
107 | 109 | }
|
108 | 110 | }
|
|
0 commit comments