Skip to content

Commit fb2be47

Browse files
committed
refactor: Adapt to use of SupportedVersions for endpoints
Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
1 parent fd5d889 commit fb2be47

File tree

8 files changed

+71
-33
lines changed

8 files changed

+71
-33
lines changed

crates/matrix-sdk-base/src/store/integration_tests.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,9 +476,10 @@ impl StateStoreIntegrationTests for DynStateStore {
476476
}
477477

478478
async fn test_server_info_saving(&self) {
479-
let versions = &[MatrixVersion::V1_1, MatrixVersion::V1_2, MatrixVersion::V1_11];
479+
let versions =
480+
BTreeSet::from([MatrixVersion::V1_1, MatrixVersion::V1_2, MatrixVersion::V1_11]);
480481
let server_info = ServerInfo::new(
481-
versions.iter().map(|version| version.to_string()).collect(),
482+
versions.iter().map(|version| version.as_str().unwrap().to_owned()).collect(),
482483
[("org.matrix.experimental".to_owned(), true)].into(),
483484
Some(WellKnownResponse {
484485
homeserver: HomeserverInfo::new("matrix.example.com".to_owned()),

crates/matrix-sdk-base/src/store/traits.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,10 +1008,10 @@ impl ServerInfo {
10081008
///
10091009
/// Note: Matrix versions that Ruma cannot parse, or does not know about,
10101010
/// are discarded.
1011-
pub fn known_versions(&self) -> Vec<MatrixVersion> {
1011+
pub fn known_versions(&self) -> BTreeSet<MatrixVersion> {
10121012
get_supported_versions::Response::new(self.versions.clone())
1013-
.known_versions()
1014-
.collect::<Vec<_>>()
1013+
.as_supported_versions()
1014+
.versions
10151015
}
10161016
}
10171017

crates/matrix-sdk/src/authentication/matrix/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,20 +105,20 @@ impl MatrixAuth {
105105
idp_id: Option<&str>,
106106
) -> Result<String> {
107107
let homeserver = self.client.homeserver();
108-
let server_versions = self.client.server_versions().await?;
108+
let supported_versions = self.client.supported_versions().await?;
109109

110110
let request = if let Some(id) = idp_id {
111111
sso_login_with_provider::v3::Request::new(id.to_owned(), redirect_url.to_owned())
112112
.try_into_http_request::<Vec<u8>>(
113113
homeserver.as_str(),
114114
SendAccessToken::None,
115-
&server_versions,
115+
&supported_versions,
116116
)
117117
} else {
118118
sso_login::v3::Request::new(redirect_url.to_owned()).try_into_http_request::<Vec<u8>>(
119119
homeserver.as_str(),
120120
SendAccessToken::None,
121-
&server_versions,
121+
&supported_versions,
122122
)
123123
};
124124

crates/matrix-sdk/src/authentication/oauth/qrcode/rendezvous_channel.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,20 @@ impl RendezvousChannel {
114114
client: HttpClient,
115115
rendezvous_server: &Url,
116116
) -> Result<Self, HttpError> {
117-
use ruma::api::client::rendezvous::create_rendezvous_session;
117+
use std::collections::BTreeSet;
118+
119+
use ruma::api::{client::rendezvous::create_rendezvous_session, SupportedVersions};
118120

119121
let request = create_rendezvous_session::unstable::Request::default();
120122
let response = client
121-
.send(request, None, rendezvous_server.to_string(), None, &[], Default::default())
123+
.send(
124+
request,
125+
None,
126+
rendezvous_server.to_string(),
127+
None,
128+
&SupportedVersions { versions: BTreeSet::new(), features: vec![] },
129+
Default::default(),
130+
)
122131
.await?;
123132

124133
let rendezvous_url = response.url;

crates/matrix-sdk/src/client/builder/homeserver_config.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
use ruma::{
1616
api::{
1717
client::discovery::{discover_homeserver, get_supported_versions},
18-
MatrixVersion,
18+
MatrixVersion, SupportedVersions,
1919
},
2020
OwnedServerName, ServerName,
2121
};
@@ -185,7 +185,7 @@ async fn discover_homeserver(
185185
Some(RequestConfig::short_retry()),
186186
server.to_string(),
187187
None,
188-
&[MatrixVersion::V1_0],
188+
&SupportedVersions { versions: [MatrixVersion::V1_0].into(), features: vec![] },
189189
Default::default(),
190190
)
191191
.await
@@ -209,7 +209,7 @@ pub(super) async fn get_supported_versions(
209209
Some(RequestConfig::short_retry()),
210210
homeserver_url.to_string(),
211211
None,
212-
&[MatrixVersion::V1_0],
212+
&SupportedVersions { versions: [MatrixVersion::V1_0].into(), features: vec![] },
213213
Default::default(),
214214
)
215215
.await

crates/matrix-sdk/src/client/builder/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ mod homeserver_config;
1717

1818
#[cfg(feature = "sqlite")]
1919
use std::path::Path;
20-
use std::{fmt, sync::Arc};
20+
use std::{collections::BTreeSet, fmt, sync::Arc};
2121

2222
use homeserver_config::*;
2323
#[cfg(feature = "e2e-encryption")]
@@ -101,7 +101,7 @@ pub struct ClientBuilder {
101101
store_config: BuilderStoreConfig,
102102
request_config: RequestConfig,
103103
respect_login_well_known: bool,
104-
server_versions: Option<Box<[MatrixVersion]>>,
104+
server_versions: Option<BTreeSet<MatrixVersion>>,
105105
handle_refresh_tokens: bool,
106106
base_client: Option<BaseClient>,
107107
#[cfg(feature = "e2e-encryption")]

crates/matrix-sdk/src/client/mod.rs

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
// limitations under the License.
1616

1717
use std::{
18-
collections::{btree_map, BTreeMap},
18+
collections::{btree_map, BTreeMap, BTreeSet},
1919
fmt::{self, Debug},
2020
future::{ready, Future},
2121
pin::Pin,
2222
sync::{Arc, Mutex as StdMutex, RwLock as StdRwLock, Weak},
2323
};
2424

25+
use as_variant::as_variant;
2526
use caches::ClientCaches;
2627
use eyeball::{SharedObservable, Subscriber};
2728
use eyeball_im::{Vector, VectorDiff};
@@ -63,7 +64,7 @@ use ruma::{
6364
user_directory::search_users,
6465
},
6566
error::FromHttpResponseError,
66-
MatrixVersion, OutgoingRequest,
67+
MatrixVersion, OutgoingRequest, SupportedVersions,
6768
},
6869
assign,
6970
push::Ruleset,
@@ -1791,7 +1792,7 @@ impl Client {
17911792
config,
17921793
homeserver,
17931794
access_token.as_deref(),
1794-
&self.server_versions().await?,
1795+
&self.supported_versions().await?,
17951796
send_progress,
17961797
)
17971798
.await
@@ -1818,7 +1819,7 @@ impl Client {
18181819
request_config,
18191820
self.homeserver().to_string(),
18201821
None,
1821-
&[MatrixVersion::V1_0],
1822+
&SupportedVersions { versions: [MatrixVersion::V1_0].into(), features: vec![] },
18221823
Default::default(),
18231824
)
18241825
.await?;
@@ -1845,7 +1846,7 @@ impl Client {
18451846
Some(RequestConfig::short_retry()),
18461847
server_url_string,
18471848
None,
1848-
&[MatrixVersion::V1_0],
1849+
&SupportedVersions { versions: [MatrixVersion::V1_0].into(), features: vec![] },
18491850
Default::default(),
18501851
)
18511852
.await;
@@ -1928,10 +1929,10 @@ impl Client {
19281929
// Fill both unstable features and server versions at once.
19291930
let mut versions = server_info.known_versions();
19301931
if versions.is_empty() {
1931-
versions.push(MatrixVersion::V1_0);
1932+
versions.insert(MatrixVersion::V1_0);
19321933
}
19331934

1934-
guarded_server_info.server_versions = CachedValue::Cached(versions.into());
1935+
guarded_server_info.server_versions = CachedValue::Cached(versions);
19351936
guarded_server_info.unstable_features = CachedValue::Cached(server_info.unstable_features);
19361937
guarded_server_info.well_known = CachedValue::Cached(server_info.well_known);
19371938

@@ -1958,11 +1959,32 @@ impl Client {
19581959
/// println!("The homeserver supports Matrix 1.1: {supports_1_1:?}");
19591960
/// # anyhow::Ok(()) };
19601961
/// ```
1961-
pub async fn server_versions(&self) -> HttpResult<Box<[MatrixVersion]>> {
1962+
pub async fn server_versions(&self) -> HttpResult<BTreeSet<MatrixVersion>> {
19621963
self.get_or_load_and_cache_server_info(|server_info| server_info.server_versions.clone())
19631964
.await
19641965
}
19651966

1967+
pub(crate) async fn supported_versions(&self) -> HttpResult<SupportedVersions> {
1968+
self.get_or_load_and_cache_server_info(|server_info| {
1969+
match server_info
1970+
.server_versions
1971+
.as_cached_value()
1972+
.zip(server_info.unstable_features.as_cached_value())
1973+
{
1974+
Some((versions, features)) => CachedValue::Cached(SupportedVersions {
1975+
versions: versions.iter().copied().collect(),
1976+
features: features
1977+
.iter()
1978+
.filter(|(_, enabled)| **enabled)
1979+
.map(|(feature, _)| feature.clone())
1980+
.collect(),
1981+
}),
1982+
None => CachedValue::NotSet,
1983+
}
1984+
})
1985+
.await
1986+
}
1987+
19661988
/// Get the unstable features supported by the homeserver by fetching them
19671989
/// from the server or the cache.
19681990
///
@@ -2794,7 +2816,7 @@ impl WeakClient {
27942816
#[derive(Clone)]
27952817
struct ClientServerInfo {
27962818
/// The Matrix versions the server supports (known ones only).
2797-
server_versions: CachedValue<Box<[MatrixVersion]>>,
2819+
server_versions: CachedValue<BTreeSet<MatrixVersion>>,
27982820

27992821
/// The unstable features and their on/off state on the server.
28002822
unstable_features: CachedValue<BTreeMap<String, bool>>,
@@ -2815,6 +2837,11 @@ enum CachedValue<Value> {
28152837
}
28162838

28172839
impl<Value> CachedValue<Value> {
2840+
/// Return the cached value, it if it exists.
2841+
fn as_cached_value(&self) -> Option<&Value> {
2842+
as_variant!(self, CachedValue::Cached)
2843+
}
2844+
28182845
/// Unwraps the cached value, returning it if it exists.
28192846
///
28202847
/// # Panics

crates/matrix-sdk/src/http_client/mod.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
use std::{
1616
any::type_name,
17+
borrow::Cow,
1718
fmt::Debug,
1819
num::NonZeroUsize,
1920
sync::{
@@ -29,7 +30,7 @@ use eyeball::SharedObservable;
2930
use http::Method;
3031
use ruma::api::{
3132
error::{FromHttpResponseError, IntoHttpError},
32-
AuthScheme, MatrixVersion, OutgoingRequest, SendAccessToken,
33+
AuthScheme, OutgoingRequest, SendAccessToken, SupportedVersions,
3334
};
3435
use tokio::sync::{Semaphore, SemaphorePermit};
3536
use tracing::{debug, field::debug, instrument, trace};
@@ -101,17 +102,17 @@ impl HttpClient {
101102
config: RequestConfig,
102103
homeserver: String,
103104
access_token: Option<&str>,
104-
server_versions: &[MatrixVersion],
105+
supported_versions: &SupportedVersions,
105106
) -> Result<http::Request<Bytes>, IntoHttpError>
106107
where
107108
R: OutgoingRequest + Debug,
108109
{
109110
trace!(request_type = type_name::<R>(), "Serializing request");
110111

111-
let server_versions = if config.force_matrix_version.is_some() {
112-
config.force_matrix_version.as_slice()
112+
let supported_versions = if let Some(matrix_version) = config.force_matrix_version {
113+
Cow::Owned(SupportedVersions { versions: [matrix_version].into(), features: vec![] })
113114
} else {
114-
server_versions
115+
Cow::Borrowed(supported_versions)
115116
};
116117

117118
let send_access_token = match access_token {
@@ -126,15 +127,15 @@ impl HttpClient {
126127
};
127128

128129
let request = request
129-
.try_into_http_request::<BytesMut>(&homeserver, send_access_token, server_versions)?
130+
.try_into_http_request::<BytesMut>(&homeserver, send_access_token, &supported_versions)?
130131
.map(|body| body.freeze());
131132

132133
Ok(request)
133134
}
134135

135136
#[allow(clippy::too_many_arguments)]
136137
#[instrument(
137-
skip(self, request, config, homeserver, access_token, server_versions, send_progress),
138+
skip(self, request, config, homeserver, access_token, supported_versions, send_progress),
138139
fields(uri, method, request_size, request_id, status, response_size, sentry_event_id)
139140
)]
140141
pub async fn send<R>(
@@ -143,7 +144,7 @@ impl HttpClient {
143144
config: Option<RequestConfig>,
144145
homeserver: String,
145146
access_token: Option<&str>,
146-
server_versions: &[MatrixVersion],
147+
supported_versions: &SupportedVersions,
147148
send_progress: SharedObservable<TransmissionProgress>,
148149
) -> Result<R::IncomingResponse, HttpError>
149150
where
@@ -177,7 +178,7 @@ impl HttpClient {
177178
}
178179

179180
let request = self
180-
.serialize_request(request, config, homeserver, access_token, server_versions)
181+
.serialize_request(request, config, homeserver, access_token, supported_versions)
181182
.map_err(HttpError::IntoHttp)?;
182183

183184
let method = request.method();

0 commit comments

Comments
 (0)