@@ -2,7 +2,7 @@ use anyhow::Context as _;
2
2
use assert_matches:: assert_matches;
3
3
use matrix_sdk_base:: store:: RoomLoadSettings ;
4
4
use matrix_sdk_test:: async_test;
5
- use oauth2:: { ClientId , CsrfToken , PkceCodeChallenge , RedirectUrl } ;
5
+ use oauth2:: { ClientId , CsrfToken , PkceCodeChallenge , RedirectUrl , Scope } ;
6
6
use ruma:: {
7
7
api:: client:: discovery:: get_authorization_server_metadata:: v1:: Prompt , device_id,
8
8
owned_device_id, user_id, DeviceId , ServerName ,
@@ -56,6 +56,7 @@ async fn check_authorization_url(
56
56
device_id : Option < & DeviceId > ,
57
57
expected_prompt : Option < & str > ,
58
58
expected_login_hint : Option < & str > ,
59
+ additional_scopes : Option < Vec < Scope > > ,
59
60
) {
60
61
tracing:: debug!( "authorization data URL = {}" , authorization_data. url) ;
61
62
@@ -85,15 +86,45 @@ async fn check_authorization_url(
85
86
num_expected -= 1 ;
86
87
}
87
88
"scope" => {
88
- let expected_start = "urn:matrix:org.matrix.msc2967.client:api:* urn:matrix:org.matrix.msc2967.client:device:" ;
89
- assert ! ( val. starts_with( expected_start) ) ;
90
- assert ! ( val. len( ) > expected_start. len( ) ) ;
89
+ let actual_scopes: Vec < String > = val. split ( ' ' ) . map ( String :: from) . collect ( ) ;
90
+
91
+ assert ! ( actual_scopes. len( ) >= 2 , "Expected at least two scopes" ) ;
92
+
93
+ assert ! (
94
+ actual_scopes
95
+ . contains( & "urn:matrix:org.matrix.msc2967.client:api:*" . to_owned( ) ) ,
96
+ "Expected Matrix API scope not found in scopes"
97
+ ) ;
91
98
92
99
// Only check the device ID if we know it. If it's generated randomly we don't
93
100
// know it.
94
101
if let Some ( device_id) = device_id {
95
- assert ! ( val. ends_with( device_id. as_str( ) ) ) ;
96
- assert_eq ! ( val. len( ) , expected_start. len( ) + device_id. as_str( ) . len( ) ) ;
102
+ let device_id_scope =
103
+ format ! ( "urn:matrix:org.matrix.msc2967.client:device:{device_id}" ) ;
104
+ assert ! (
105
+ actual_scopes. contains( & device_id_scope) ,
106
+ "Expected device ID scope not found in scopes"
107
+ )
108
+ } else {
109
+ assert ! (
110
+ actual_scopes
111
+ . iter( )
112
+ . any( |s| s. starts_with( "urn:matrix:org.matrix.msc2967.client:device:" ) ) ,
113
+ "Expected device ID scope not found in scopes"
114
+ ) ;
115
+ }
116
+
117
+ if let Some ( additional_scopes) = & additional_scopes {
118
+ // Check if the additional scopes are present in the actual scopes.
119
+ let expected_len = 2 + additional_scopes. len ( ) ;
120
+ assert_eq ! ( actual_scopes. len( ) , expected_len, "Expected {expected_len} scopes" , ) ;
121
+
122
+ for scope in additional_scopes {
123
+ assert ! (
124
+ actual_scopes. contains( scope) ,
125
+ "Expected additional scope not found in scopes: {scope:?}" ,
126
+ ) ;
127
+ }
97
128
}
98
129
99
130
num_expected -= 1 ;
@@ -146,7 +177,7 @@ async fn test_high_level_login() -> anyhow::Result<()> {
146
177
147
178
// When getting the OIDC login URL.
148
179
let authorization_data = oauth
149
- . login ( redirect_uri. clone ( ) , None , Some ( registration_data) )
180
+ . login ( redirect_uri. clone ( ) , None , Some ( registration_data) , None )
150
181
. prompt ( vec ! [ Prompt :: Create ] )
151
182
. build ( )
152
183
. await
@@ -169,12 +200,16 @@ async fn test_high_level_login_cancellation() -> anyhow::Result<()> {
169
200
// Given a client ready to complete login.
170
201
let ( oauth, server, mut redirect_uri, registration_data) = mock_environment ( ) . await . unwrap ( ) ;
171
202
172
- let authorization_data =
173
- oauth. login ( redirect_uri. clone ( ) , None , Some ( registration_data) ) . build ( ) . await . unwrap ( ) ;
203
+ let authorization_data = oauth
204
+ . login ( redirect_uri. clone ( ) , None , Some ( registration_data) , None )
205
+ . build ( )
206
+ . await
207
+ . unwrap ( ) ;
174
208
175
209
assert_eq ! ( oauth. client_id( ) . map( |id| id. as_str( ) ) , Some ( "test_client_id" ) ) ;
176
210
177
- check_authorization_url ( & authorization_data, & oauth, & server. uri ( ) , None , None , None ) . await ;
211
+ check_authorization_url ( & authorization_data, & oauth, & server. uri ( ) , None , None , None , None )
212
+ . await ;
178
213
179
214
// When completing login with a cancellation callback.
180
215
redirect_uri. set_query ( Some ( & format ! (
@@ -200,12 +235,16 @@ async fn test_high_level_login_invalid_state() -> anyhow::Result<()> {
200
235
// Given a client ready to complete login.
201
236
let ( oauth, server, mut redirect_uri, registration_data) = mock_environment ( ) . await . unwrap ( ) ;
202
237
203
- let authorization_data =
204
- oauth. login ( redirect_uri. clone ( ) , None , Some ( registration_data) ) . build ( ) . await . unwrap ( ) ;
238
+ let authorization_data = oauth
239
+ . login ( redirect_uri. clone ( ) , None , Some ( registration_data) , None )
240
+ . build ( )
241
+ . await
242
+ . unwrap ( ) ;
205
243
206
244
assert_eq ! ( oauth. client_id( ) . map( |id| id. as_str( ) ) , Some ( "test_client_id" ) ) ;
207
245
208
- check_authorization_url ( & authorization_data, & oauth, & server. uri ( ) , None , None , None ) . await ;
246
+ check_authorization_url ( & authorization_data, & oauth, & server. uri ( ) , None , None , None , None )
247
+ . await ;
209
248
210
249
// When completing login with an old/tampered state.
211
250
redirect_uri. set_query ( Some ( "code=42&state=imposter_alert" ) ) ;
@@ -229,7 +268,7 @@ async fn test_login_url() -> anyhow::Result<()> {
229
268
let server_uri = server. uri ( ) ;
230
269
231
270
let oauth_server = server. oauth ( ) ;
232
- oauth_server. mock_server_metadata ( ) . ok ( ) . expect ( 3 ) . mount ( ) . await ;
271
+ oauth_server. mock_server_metadata ( ) . ok ( ) . expect ( 4 ) . mount ( ) . await ;
233
272
234
273
let client = server. client_builder ( ) . registered_with_oauth ( ) . build ( ) . await ;
235
274
let oauth = client. oauth ( ) ;
@@ -239,15 +278,26 @@ async fn test_login_url() -> anyhow::Result<()> {
239
278
let redirect_uri_str = REDIRECT_URI_STRING ;
240
279
let redirect_uri = Url :: parse ( redirect_uri_str) ?;
241
280
281
+ let additional_scopes =
282
+ vec ! [ Scope :: new( "urn:test:scope1" . to_owned( ) ) , Scope :: new( "urn:test:scope2" . to_owned( ) ) ] ;
283
+
242
284
// No extra parameters.
243
285
let authorization_data =
244
- oauth. login ( redirect_uri. clone ( ) , Some ( device_id. clone ( ) ) , None ) . build ( ) . await ?;
245
- check_authorization_url ( & authorization_data, & oauth, & server_uri, Some ( & device_id) , None , None )
246
- . await ;
286
+ oauth. login ( redirect_uri. clone ( ) , Some ( device_id. clone ( ) ) , None , None ) . build ( ) . await ?;
287
+ check_authorization_url (
288
+ & authorization_data,
289
+ & oauth,
290
+ & server_uri,
291
+ Some ( & device_id) ,
292
+ None ,
293
+ None ,
294
+ None ,
295
+ )
296
+ . await ;
247
297
248
298
// With prompt parameter.
249
299
let authorization_data = oauth
250
- . login ( redirect_uri. clone ( ) , Some ( device_id. clone ( ) ) , None )
300
+ . login ( redirect_uri. clone ( ) , Some ( device_id. clone ( ) ) , None , None )
251
301
. prompt ( vec ! [ Prompt :: Create ] )
252
302
. build ( )
253
303
. await ?;
@@ -258,12 +308,13 @@ async fn test_login_url() -> anyhow::Result<()> {
258
308
Some ( & device_id) ,
259
309
Some ( "create" ) ,
260
310
None ,
311
+ None ,
261
312
)
262
313
. await ;
263
314
264
315
// With user_id_hint parameter.
265
316
let authorization_data = oauth
266
- . login ( redirect_uri. clone ( ) , Some ( device_id. clone ( ) ) , None )
317
+ . login ( redirect_uri. clone ( ) , Some ( device_id. clone ( ) ) , None , None )
267
318
. user_id_hint ( user_id ! ( "@joe:example.org" ) )
268
319
. build ( )
269
320
. await ?;
@@ -274,6 +325,23 @@ async fn test_login_url() -> anyhow::Result<()> {
274
325
Some ( & device_id) ,
275
326
None ,
276
327
Some ( "mxid:@joe:example.org" ) ,
328
+ None ,
329
+ )
330
+ . await ;
331
+
332
+ // With additional scopes.
333
+ let authorization_data = oauth
334
+ . login ( redirect_uri. clone ( ) , Some ( device_id. clone ( ) ) , None , Some ( additional_scopes. clone ( ) ) )
335
+ . build ( )
336
+ . await ?;
337
+ check_authorization_url (
338
+ & authorization_data,
339
+ & oauth,
340
+ & server_uri,
341
+ Some ( & device_id) ,
342
+ None ,
343
+ None ,
344
+ Some ( additional_scopes) ,
277
345
)
278
346
. await ;
279
347
0 commit comments