@@ -427,22 +427,55 @@ fn is_session_overshared_for_user(
427
427
let recipient_device_ids: BTreeSet < & DeviceId > =
428
428
recipient_devices. iter ( ) . map ( |d| d. device_id ( ) ) . collect ( ) ;
429
429
430
+ let mut shared: Vec < & DeviceId > = Vec :: new ( ) ;
431
+
432
+ // This duplicates a conservative subset of the logic in
433
+ // `OutboundGroupSession::is_shared_with`, because we
434
+ // don't have corresponding DeviceData at hand
435
+ fn is_actually_shared ( info : & ShareInfo ) -> bool {
436
+ match info {
437
+ ShareInfo :: Shared ( _) => true ,
438
+ ShareInfo :: Withheld ( _) => false ,
439
+ }
440
+ }
441
+
442
+ // Collect the devices that have definitely received the session already
430
443
let guard = outbound_session. shared_with_set . read ( ) ;
444
+ if let Some ( for_user) = guard. get ( user_id) {
445
+ shared. extend ( for_user. iter ( ) . filter_map ( |( d, info) | {
446
+ if is_actually_shared ( info) {
447
+ Some ( AsRef :: < DeviceId > :: as_ref ( d) )
448
+ } else {
449
+ None
450
+ }
451
+ } ) ) ;
452
+ }
453
+
454
+ // To be conservative, also collect the devices that would still receive the
455
+ // session from a pending to-device request if we don't rotate beforehand
456
+ let guard = outbound_session. to_share_with_set . read ( ) ;
457
+ for ( _txid, share_infos) in guard. values ( ) {
458
+ if let Some ( for_user) = share_infos. get ( user_id) {
459
+ shared. extend ( for_user. iter ( ) . filter_map ( |( d, info) | {
460
+ if is_actually_shared ( info) {
461
+ Some ( AsRef :: < DeviceId > :: as_ref ( d) )
462
+ } else {
463
+ None
464
+ }
465
+ } ) ) ;
466
+ }
467
+ }
431
468
432
- let Some ( shared) = guard . get ( user_id ) else {
469
+ if shared. is_empty ( ) {
433
470
return false ;
434
- } ;
471
+ }
435
472
436
- // Devices that received this session
437
- let shared: BTreeSet < & DeviceId > = shared
438
- . iter ( )
439
- . filter ( |( _, info) | matches ! ( info, ShareInfo :: Shared ( _) ) )
440
- . map ( |( d, _) | d. as_ref ( ) )
441
- . collect ( ) ;
473
+ let shared: BTreeSet < & DeviceId > = shared. into_iter ( ) . collect ( ) ;
442
474
443
475
// The set difference between
444
476
//
445
- // 1. Devices that had previously received the session, and
477
+ // 1. Devices that had previously received (or are queued to receive) the
478
+ // session, and
446
479
// 2. Devices that would now receive the session
447
480
//
448
481
// Represents newly deleted or blacklisted devices. If this
@@ -729,17 +762,21 @@ mod tests {
729
762
} ,
730
763
} ;
731
764
use ruma:: {
732
- device_id, events:: room:: history_visibility:: HistoryVisibility , room_id, TransactionId ,
765
+ device_id,
766
+ events:: { dummy:: ToDeviceDummyEventContent , room:: history_visibility:: HistoryVisibility } ,
767
+ room_id, TransactionId ,
733
768
} ;
734
769
use serde_json:: json;
735
770
736
771
use crate :: {
737
772
error:: SessionRecipientCollectionError ,
738
- olm:: OutboundGroupSession ,
773
+ olm:: { OutboundGroupSession , ShareInfo } ,
739
774
session_manager:: {
740
775
group_sessions:: share_strategy:: collect_session_recipients, CollectStrategy ,
741
776
} ,
777
+ store:: caches:: SequenceNumber ,
742
778
testing:: simulate_key_query_response_for_verification,
779
+ types:: requests:: ToDeviceRequest ,
743
780
CrossSigningKeyExport , EncryptionSettings , LocalTrust , OlmError , OlmMachine ,
744
781
} ;
745
782
@@ -2136,6 +2173,61 @@ mod tests {
2136
2173
assert ! ( share_result. should_rotate) ;
2137
2174
}
2138
2175
2176
+ /// Test that the session is rotated if a devices has a pending
2177
+ /// to-device request that would share the keys with it.
2178
+ #[ async_test]
2179
+ async fn test_should_rotate_based_on_device_with_pending_request_excluded ( ) {
2180
+ let machine = test_machine ( ) . await ;
2181
+ import_known_users_to_test_machine ( & machine) . await ;
2182
+
2183
+ let encryption_settings = all_devices_strategy_settings ( ) ;
2184
+ let group_session = create_test_outbound_group_session ( & machine, & encryption_settings) ;
2185
+ let sender_key = machine. identity_keys ( ) . curve25519 ;
2186
+
2187
+ let dan_user = KeyDistributionTestData :: dan_id ( ) ;
2188
+ let dan_dev1 = KeyDistributionTestData :: dan_signed_device_id ( ) ;
2189
+ let dan_dev2 = KeyDistributionTestData :: dan_unsigned_device_id ( ) ;
2190
+
2191
+ // Share the session with device 1
2192
+ group_session. mark_shared_with ( dan_user, dan_dev1, sender_key) . await ;
2193
+
2194
+ {
2195
+ // Add a pending request to share with device 2
2196
+ let share_infos = BTreeMap :: from ( [ (
2197
+ dan_user. to_owned ( ) ,
2198
+ BTreeMap :: from ( [ (
2199
+ dan_dev2. to_owned ( ) ,
2200
+ ShareInfo :: new_shared ( sender_key, 0 , SequenceNumber :: default ( ) ) ,
2201
+ ) ] ) ,
2202
+ ) ] ) ;
2203
+
2204
+ let txid = TransactionId :: new ( ) ;
2205
+ let req = Arc :: new ( ToDeviceRequest :: for_recipients (
2206
+ dan_user,
2207
+ vec ! [ dan_dev2. to_owned( ) ] ,
2208
+ & ruma:: events:: AnyToDeviceEventContent :: Dummy ( ToDeviceDummyEventContent ) ,
2209
+ txid. clone ( ) ,
2210
+ ) ) ;
2211
+ group_session. add_request ( txid, req, share_infos) ;
2212
+ }
2213
+
2214
+ // Remove device 2
2215
+ let keys_query = KeyDistributionTestData :: dan_keys_query_response_device_loggedout ( ) ;
2216
+ machine. mark_request_as_sent ( & TransactionId :: new ( ) , & keys_query) . await . unwrap ( ) ;
2217
+
2218
+ // Share again
2219
+ let share_result = collect_session_recipients (
2220
+ machine. store ( ) ,
2221
+ vec ! [ KeyDistributionTestData :: dan_id( ) ] . into_iter ( ) ,
2222
+ & encryption_settings,
2223
+ & group_session,
2224
+ )
2225
+ . await
2226
+ . unwrap ( ) ;
2227
+
2228
+ assert ! ( share_result. should_rotate) ;
2229
+ }
2230
+
2139
2231
/// Test that the session is not rotated if a devices is removed
2140
2232
/// but was already withheld from receiving the session.
2141
2233
#[ async_test]
0 commit comments