Skip to content

Support for the stable scopes #4686

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions crates/data-model/src/compat/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;

static GENERATED_DEVICE_ID_LENGTH: usize = 10;
static DEVICE_SCOPE_PREFIX: &str = "urn:matrix:org.matrix.msc2967.client:device:";
static UNSTABLE_DEVICE_SCOPE_PREFIX: &str = "urn:matrix:org.matrix.msc2967.client:device:";
static STABLE_DEVICE_SCOPE_PREFIX: &str = "urn:matrix:client:device:";

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
Expand All @@ -28,24 +29,31 @@ pub enum ToScopeTokenError {
}

impl Device {
/// Get the corresponding [`ScopeToken`] for that device
/// Get the corresponding stable and unstable [`ScopeToken`] for that device
///
/// # Errors
///
/// Returns an error if the device ID contains characters that can't be
/// encoded in a scope
pub fn to_scope_token(&self) -> Result<ScopeToken, ToScopeTokenError> {
format!("{DEVICE_SCOPE_PREFIX}{}", self.id)
.parse()
.map_err(|_| ToScopeTokenError::InvalidCharacters)
pub fn to_scope_token(&self) -> Result<[ScopeToken; 2], ToScopeTokenError> {
Ok([
format!("{STABLE_DEVICE_SCOPE_PREFIX}{}", self.id)
.parse()
.map_err(|_| ToScopeTokenError::InvalidCharacters)?,
format!("{UNSTABLE_DEVICE_SCOPE_PREFIX}{}", self.id)
.parse()
.map_err(|_| ToScopeTokenError::InvalidCharacters)?,
])
}

/// Get the corresponding [`Device`] from a [`ScopeToken`]
///
/// Returns `None` if the [`ScopeToken`] is not a device scope
#[must_use]
pub fn from_scope_token(token: &ScopeToken) -> Option<Self> {
let id = token.as_str().strip_prefix(DEVICE_SCOPE_PREFIX)?;
let stable = token.as_str().strip_prefix(STABLE_DEVICE_SCOPE_PREFIX);
let unstable = token.as_str().strip_prefix(UNSTABLE_DEVICE_SCOPE_PREFIX);
let id = stable.or(unstable)?;
Some(Device::from(id.to_owned()))
}

Expand Down Expand Up @@ -89,12 +97,23 @@ mod test {
#[test]
fn test_device_id_to_from_scope_token() {
let device = Device::from("AABBCCDDEE".to_owned());
let scope_token = device.to_scope_token().unwrap();
let [stable_scope_token, unstable_scope_token] = device.to_scope_token().unwrap();
assert_eq!(
scope_token.as_str(),
stable_scope_token.as_str(),
"urn:matrix:client:device:AABBCCDDEE"
);
assert_eq!(
unstable_scope_token.as_str(),
"urn:matrix:org.matrix.msc2967.client:device:AABBCCDDEE"
);
assert_eq!(Device::from_scope_token(&scope_token), Some(device));
assert_eq!(
Device::from_scope_token(&unstable_scope_token).as_ref(),
Some(&device)
);
assert_eq!(
Device::from_scope_token(&stable_scope_token).as_ref(),
Some(&device)
);
assert_eq!(Device::from_scope_token(&OPENID), None);
}
}
2 changes: 1 addition & 1 deletion crates/handlers/src/admin/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ impl OAuth2Session {
user_id: Some(Ulid::from_bytes([0x04; 16])),
user_session_id: Some(Ulid::from_bytes([0x05; 16])),
client_id: Ulid::from_bytes([0x06; 16]),
scope: "urn:matrix:org.matrix.msc2967.client:api:*".to_owned(),
scope: "urn:matrix:client:api:*".to_owned(),
user_agent: Some("Mozilla/5.0".to_owned()),
last_active_at: Some(DateTime::default()),
last_active_ip: Some("127.0.0.1".parse().unwrap()),
Expand Down
14 changes: 2 additions & 12 deletions crates/handlers/src/graphql/query/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use mas_storage::{
compat::{CompatSessionFilter, CompatSessionRepository},
oauth2::OAuth2SessionFilter,
};
use oauth2_types::scope::Scope;

use crate::graphql::{
UserId,
Expand Down Expand Up @@ -77,20 +76,11 @@ impl SessionQuery {
))));
}

// Then, try to find an OAuth 2.0 session. Because we don't have any dedicated
// device column, we're looking up using the device scope.
// All device IDs can't necessarily be encoded as a scope. If it's not the case,
// we'll skip looking for OAuth 2.0 sessions.
let Ok(scope_token) = device.to_scope_token() else {
repo.cancel().await?;

return Ok(None);
};
let scope = Scope::from_iter([scope_token]);
// Then, try to find an OAuth 2.0 session.
let filter = OAuth2SessionFilter::new()
.for_user(&user)
.active_only()
.with_scope(&scope);
.for_device(&device);
let sessions = repo.oauth2_session().list(filter, pagination).await?;

// It's possible to have multiple active OAuth 2.0 sessions. For now, we just
Expand Down
50 changes: 39 additions & 11 deletions crates/handlers/src/oauth2/introspection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::sync::LazyLock;
use std::{collections::BTreeSet, sync::LazyLock};

use axum::{Json, extract::State, http::HeaderValue, response::IntoResponse};
use hyper::{HeaderMap, StatusCode};
Expand All @@ -24,7 +24,7 @@ use mas_storage::{
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
requests::{IntrospectionRequest, IntrospectionResponse},
scope::ScopeToken,
scope::{Scope, ScopeToken},
};
use opentelemetry::{Key, KeyValue, metrics::Counter};
use thiserror::Error;
Expand Down Expand Up @@ -190,9 +190,33 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
device_id: None,
};

const API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:org.matrix.msc2967.client:api:*");
const UNSTABLE_API_SCOPE: ScopeToken =
ScopeToken::from_static("urn:matrix:org.matrix.msc2967.client:api:*");
const STABLE_API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:client:api:*");
const SYNAPSE_ADMIN_SCOPE: ScopeToken = ScopeToken::from_static("urn:synapse:admin:*");

/// Normalize a scope by adding the stable and unstable API scopes equivalents
/// if missing
fn normalize_scope(mut scope: Scope) -> Scope {
// Here we abuse the fact that the scope is a BTreeSet to not care about
// duplicates
let mut to_add = BTreeSet::new();
for token in &*scope {
if token == &STABLE_API_SCOPE {
to_add.insert(UNSTABLE_API_SCOPE);
} else if token == &UNSTABLE_API_SCOPE {
to_add.insert(STABLE_API_SCOPE);
} else if let Some(device) = Device::from_scope_token(token) {
let tokens = device
.to_scope_token()
.expect("from/to scope token rountrip should never fail");
to_add.extend(tokens);
}
}
scope.append(&mut to_add);
scope
}

#[tracing::instrument(
name = "handlers.oauth2.introspection.post",
fields(client.id = client_authorization.client_id()),
Expand Down Expand Up @@ -311,9 +335,11 @@ pub(crate) async fn post(
],
);

let scope = normalize_scope(session.scope);

IntrospectionResponse {
active: true,
scope: Some(session.scope),
scope: Some(scope),
client_id: Some(session.client_id.to_string()),
username,
token_type: Some(OAuthTokenTypeHint::AccessToken),
Expand Down Expand Up @@ -382,9 +408,11 @@ pub(crate) async fn post(
],
);

let scope = normalize_scope(session.scope);

IntrospectionResponse {
active: true,
scope: Some(session.scope),
scope: Some(scope),
client_id: Some(session.client_id.to_string()),
username,
token_type: Some(OAuthTokenTypeHint::RefreshToken),
Expand Down Expand Up @@ -446,9 +474,9 @@ pub(crate) async fn post(
.transpose()?
};

let scope = [API_SCOPE]
let scope = [STABLE_API_SCOPE, UNSTABLE_API_SCOPE]
.into_iter()
.chain(device_scope_opt)
.chain(device_scope_opt.into_iter().flatten())
.chain(synapse_admin_scope_opt)
.collect();

Expand Down Expand Up @@ -530,9 +558,9 @@ pub(crate) async fn post(
.transpose()?
};

let scope = [API_SCOPE]
let scope = [STABLE_API_SCOPE, UNSTABLE_API_SCOPE]
.into_iter()
.chain(device_scope_opt)
.chain(device_scope_opt.into_iter().flatten())
.chain(synapse_admin_scope_opt)
.collect();

Expand Down Expand Up @@ -879,7 +907,7 @@ mod tests {
let refresh_token = response["refresh_token"].as_str().unwrap();
let device_id = response["device_id"].as_str().unwrap();
let expected_scope: Scope =
format!("urn:matrix:org.matrix.msc2967.client:api:* urn:matrix:org.matrix.msc2967.client:device:{device_id}")
format!("urn:matrix:org.matrix.msc2967.client:api:* urn:matrix:org.matrix.msc2967.client:device:{device_id} urn:matrix:client:api:* urn:matrix:client:device:{device_id}")
.parse()
.unwrap();

Expand Down Expand Up @@ -912,7 +940,7 @@ mod tests {
assert_eq!(response.token_type, Some(OAuthTokenTypeHint::AccessToken));
assert_eq!(
response.scope.map(|s| s.to_string()),
Some("urn:matrix:org.matrix.msc2967.client:api:*".to_owned())
Some("urn:matrix:client:api:* urn:matrix:org.matrix.msc2967.client:api:*".to_owned())
);
assert_eq!(response.device_id.as_deref(), Some(device_id));

Expand Down
17 changes: 15 additions & 2 deletions crates/oauth2-types/src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@

#![allow(clippy::module_name_repetitions)]

use std::{borrow::Cow, collections::BTreeSet, iter::FromIterator, ops::Deref, str::FromStr};
use std::{
borrow::Cow,
collections::BTreeSet,
iter::FromIterator,
ops::{Deref, DerefMut},
str::FromStr,
};

use serde::{Deserialize, Serialize};
use thiserror::Error;
Expand Down Expand Up @@ -121,6 +127,12 @@ impl Deref for Scope {
}
}

impl DerefMut for Scope {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl FromStr for Scope {
type Err = InvalidScope;

Expand Down Expand Up @@ -248,6 +260,7 @@ mod tests {
);

assert!(Scope::from_str("http://example.com").is_ok());
assert!(Scope::from_str("urn:matrix:org.matrix.msc2967.client:*").is_ok());
assert!(Scope::from_str("urn:matrix:client:api:*").is_ok());
assert!(Scope::from_str("urn:matrix:org.matrix.msc2967.client:api:*").is_ok());
}
}

This file was deleted.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 14 additions & 4 deletions crates/storage-pg/src/app_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,17 +499,24 @@ impl AppSessionRepository for PgAppSessionRepository<'_> {
.instrument(span)
.await?;

if let Ok(device_as_scope_token) = device.to_scope_token() {
if let Ok([stable_device_as_scope_token, unstable_device_as_scope_token]) =
device.to_scope_token()
{
let span = tracing::info_span!(
"db.app_session.finish_sessions_to_replace_device.oauth2_sessions",
{ DB_QUERY_TEXT } = tracing::field::Empty,
);
sqlx::query!(
"
UPDATE oauth2_sessions SET finished_at = $3 WHERE user_id = $1 AND $2 = ANY(scope_list) AND finished_at IS NULL
UPDATE oauth2_sessions
SET finished_at = $4
WHERE user_id = $1
AND ($2 = ANY(scope_list) OR $3 = ANY(scope_list))
AND finished_at IS NULL
",
Uuid::from(user.id),
device_as_scope_token.as_str(),
stable_device_as_scope_token.as_str(),
unstable_device_as_scope_token.as_str(),
finished_at
)
.record(&span)
Expand Down Expand Up @@ -652,7 +659,10 @@ mod tests {
.unwrap();

let device2 = Device::generate(&mut rng);
let scope = Scope::from_iter([OPENID, device2.to_scope_token().unwrap()]);
let scope: Scope = [OPENID]
.into_iter()
.chain(device2.to_scope_token().unwrap().into_iter())
.collect();

// We're moving the clock forward by 1 minute between each session to ensure
// we're getting consistent ordering in lists.
Expand Down
24 changes: 17 additions & 7 deletions crates/storage-pg/src/oauth2/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ use mas_storage::{
};
use oauth2_types::scope::{Scope, ScopeToken};
use rand::RngCore;
use sea_query::{Expr, PgFunc, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
use sea_query::{
Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
extension::postgres::PgExpr,
};
use sea_query_binder::SqlxBinder;
use sqlx::PgConnection;
use ulid::Ulid;
Expand Down Expand Up @@ -126,12 +129,19 @@ impl Filter for OAuth2SessionFilter<'_> {
.ne(Expr::all(static_clients))
}
}))
.add_option(self.device().map(|device| {
if let Ok(scope_token) = device.to_scope_token() {
Expr::val(scope_token.to_string()).eq(PgFunc::any(Expr::col((
OAuth2Sessions::Table,
OAuth2Sessions::ScopeList,
))))
.add_option(self.device().map(|device| -> SimpleExpr {
if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
Condition::any()
.add(
Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
OAuth2Sessions::Table,
OAuth2Sessions::ScopeList,
)))),
)
.add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
)))
.into()
} else {
// If the device ID can't be encoded as a scope token, match no rows
Expr::val(false).into()
Expand Down
2 changes: 1 addition & 1 deletion docs/api/spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@
"user_id": "040G2081040G2081040G208104",
"user_session_id": "050M2GA1850M2GA1850M2GA185",
"client_id": "060R30C1G60R30C1G60R30C1G6",
"scope": "urn:matrix:org.matrix.msc2967.client:api:*",
"scope": "urn:matrix:client:api:*",
"user_agent": "Mozilla/5.0",
"last_active_at": "1970-01-01T00:00:00Z",
"last_active_ip": "127.0.0.1",
Expand Down
Loading
Loading