@@ -16,9 +16,10 @@ use std::{
16
16
cmp:: max,
17
17
collections:: { BTreeMap , BTreeSet } ,
18
18
fmt,
19
+ ops:: Bound ,
19
20
sync:: {
20
21
atomic:: { AtomicBool , AtomicU64 , Ordering } ,
21
- Arc ,
22
+ Arc , RwLockReadGuard ,
22
23
} ,
23
24
time:: Duration ,
24
25
} ;
@@ -70,7 +71,7 @@ const ONE_WEEK: Duration = Duration::from_secs(60 * 60 * 24 * 7);
70
71
const ROTATION_PERIOD : Duration = ONE_WEEK ;
71
72
const ROTATION_MESSAGES : u64 = 100 ;
72
73
73
- #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
74
+ #[ derive( Debug , Clone , Copy , PartialEq , Eq , PartialOrd , Ord ) ]
74
75
/// Information about whether a session was shared with a device.
75
76
pub ( crate ) enum ShareState {
76
77
/// The session was not shared with the device.
@@ -157,11 +158,8 @@ pub struct OutboundGroupSession {
157
158
shared : Arc < AtomicBool > ,
158
159
invalidated : Arc < AtomicBool > ,
159
160
settings : Arc < EncryptionSettings > ,
160
- pub ( crate ) shared_with_set :
161
- Arc < StdRwLock < BTreeMap < OwnedUserId , BTreeMap < OwnedDeviceId , ShareInfo > > > > ,
162
- #[ allow( clippy:: type_complexity) ]
163
- pub ( crate ) to_share_with_set :
164
- Arc < StdRwLock < BTreeMap < OwnedTransactionId , ( Arc < ToDeviceRequest > , ShareInfoSet ) > > > ,
161
+ shared_with_set : Arc < StdRwLock < ShareInfoSet > > ,
162
+ to_share_with_set : Arc < StdRwLock < ToShareMap > > ,
165
163
}
166
164
167
165
/// A a map of userid/device it to a `ShareInfo`.
@@ -170,6 +168,8 @@ pub struct OutboundGroupSession {
170
168
/// room key.
171
169
pub type ShareInfoSet = BTreeMap < OwnedUserId , BTreeMap < OwnedDeviceId , ShareInfo > > ;
172
170
171
+ type ToShareMap = BTreeMap < OwnedTransactionId , ( Arc < ToDeviceRequest > , ShareInfoSet ) > ;
172
+
173
173
/// Struct holding info about the share state of a outbound group session.
174
174
#[ derive( Clone , Debug , Serialize , Deserialize ) ]
175
175
pub enum ShareInfo {
@@ -206,6 +206,90 @@ pub struct SharedWith {
206
206
pub olm_wedging_index : SequenceNumber ,
207
207
}
208
208
209
+ /// A read-only view into the device sharing state of an
210
+ /// [`OutboundGroupSession`].
211
+ pub ( crate ) struct SharingView < ' a > {
212
+ shared_with_set : RwLockReadGuard < ' a , ShareInfoSet > ,
213
+ to_share_with_set : RwLockReadGuard < ' a , ToShareMap > ,
214
+ }
215
+
216
+ impl SharingView < ' _ > {
217
+ /// Has the session been shared with the given user/device pair (or if not,
218
+ /// is there such a request pending).
219
+ pub ( crate ) fn get_share_state ( & self , device : & DeviceData ) -> ShareState {
220
+ self . iter_shares ( Some ( device. user_id ( ) ) , Some ( device. device_id ( ) ) )
221
+ . map ( |( _, _, info) | match info {
222
+ ShareInfo :: Shared ( info) => {
223
+ if device. curve25519_key ( ) == Some ( info. sender_key ) {
224
+ ShareState :: Shared {
225
+ message_index : info. message_index ,
226
+ olm_wedging_index : info. olm_wedging_index ,
227
+ }
228
+ } else {
229
+ ShareState :: SharedButChangedSenderKey
230
+ }
231
+ }
232
+ ShareInfo :: Withheld ( _) => ShareState :: NotShared ,
233
+ } )
234
+ // Return the most "definitive" ShareState found (in case there
235
+ // are multiple entries for the same device).
236
+ . max ( )
237
+ . unwrap_or ( ShareState :: NotShared )
238
+ }
239
+
240
+ /// Has the session been withheld for the given user/device pair (or if not,
241
+ /// is there such a request pending).
242
+ pub ( crate ) fn is_withheld_to ( & self , device : & DeviceData , code : & WithheldCode ) -> bool {
243
+ self . iter_shares ( Some ( device. user_id ( ) ) , Some ( device. device_id ( ) ) )
244
+ . any ( |( _, _, info) | matches ! ( info, ShareInfo :: Withheld ( c) if c == code) )
245
+ }
246
+
247
+ /// Enumerate all sent or pending sharing requests for the given device (or
248
+ /// for all devices if not specified). This can yield the same device
249
+ /// multiple times.
250
+ pub ( crate ) fn iter_shares < ' b , ' c > (
251
+ & self ,
252
+ user_id : Option < & ' b UserId > ,
253
+ device_id : Option < & ' c DeviceId > ,
254
+ ) -> impl Iterator < Item = ( & UserId , & DeviceId , & ShareInfo ) > + use < ' _ , ' b , ' c > {
255
+ fn iter_share_info_set < ' a , ' b , ' c > (
256
+ set : & ' a ShareInfoSet ,
257
+ user_ids : ( Bound < & ' b UserId > , Bound < & ' b UserId > ) ,
258
+ device_ids : ( Bound < & ' c DeviceId > , Bound < & ' c DeviceId > ) ,
259
+ ) -> impl Iterator < Item = ( & ' a UserId , & ' a DeviceId , & ' a ShareInfo ) > + use < ' a , ' b , ' c >
260
+ {
261
+ set. range :: < UserId , _ > ( user_ids) . flat_map ( move |( uid, d) | {
262
+ d. range :: < DeviceId , _ > ( device_ids)
263
+ . map ( |( id, info) | ( uid. as_ref ( ) , id. as_ref ( ) , info) )
264
+ } )
265
+ }
266
+
267
+ let user_ids = user_id
268
+ . map ( |u| ( Bound :: Included ( u) , Bound :: Included ( u) ) )
269
+ . unwrap_or ( ( Bound :: Unbounded , Bound :: Unbounded ) ) ;
270
+ let device_ids = device_id
271
+ . map ( |d| ( Bound :: Included ( d) , Bound :: Included ( d) ) )
272
+ . unwrap_or ( ( Bound :: Unbounded , Bound :: Unbounded ) ) ;
273
+
274
+ let already_shared = iter_share_info_set ( & self . shared_with_set , user_ids, device_ids) ;
275
+ let pending = self
276
+ . to_share_with_set
277
+ . values ( )
278
+ . flat_map ( move |( _, set) | iter_share_info_set ( set, user_ids, device_ids) ) ;
279
+ already_shared. chain ( pending)
280
+ }
281
+
282
+ /// Enumerate all users that have received the session, or have pending
283
+ /// requests to receive it. This can yield the same user multiple times,
284
+ /// so you may want to `collect()` the result into a `BTreeSet`.
285
+ pub ( crate ) fn shared_with_users ( & self ) -> impl Iterator < Item = & UserId > {
286
+ self . iter_shares ( None , None ) . filter_map ( |( u, _, info) | match info {
287
+ ShareInfo :: Shared ( _) => Some ( u) ,
288
+ ShareInfo :: Withheld ( _) => None ,
289
+ } )
290
+ }
291
+ }
292
+
209
293
impl OutboundGroupSession {
210
294
pub ( super ) fn session_config (
211
295
algorithm : & EventEncryptionAlgorithm ,
@@ -541,78 +625,16 @@ impl OutboundGroupSession {
541
625
)
542
626
}
543
627
544
- /// Has or will the session be shared with the given user/device pair.
545
- pub ( crate ) fn is_shared_with ( & self , device : & DeviceData ) -> ShareState {
546
- // Check if we shared the session.
547
- let shared_state = self . shared_with_set . read ( ) . get ( device. user_id ( ) ) . and_then ( |d| {
548
- d. get ( device. device_id ( ) ) . map ( |s| match s {
549
- ShareInfo :: Shared ( s) => {
550
- if device. curve25519_key ( ) == Some ( s. sender_key ) {
551
- ShareState :: Shared {
552
- message_index : s. message_index ,
553
- olm_wedging_index : s. olm_wedging_index ,
554
- }
555
- } else {
556
- ShareState :: SharedButChangedSenderKey
557
- }
558
- }
559
- ShareInfo :: Withheld ( _) => ShareState :: NotShared ,
560
- } )
561
- } ) ;
562
-
563
- if let Some ( state) = shared_state {
564
- state
565
- } else {
566
- // If we haven't shared the session, check if we're going to share
567
- // the session.
568
-
569
- // Find the first request that contains the given user id and
570
- // device ID.
571
- let shared = self . to_share_with_set . read ( ) . values ( ) . find_map ( |( _, share_info) | {
572
- let d = share_info. get ( device. user_id ( ) ) ?;
573
- let info = d. get ( device. device_id ( ) ) ?;
574
- Some ( match info {
575
- ShareInfo :: Shared ( info) => {
576
- if device. curve25519_key ( ) == Some ( info. sender_key ) {
577
- ShareState :: Shared {
578
- message_index : info. message_index ,
579
- olm_wedging_index : info. olm_wedging_index ,
580
- }
581
- } else {
582
- ShareState :: SharedButChangedSenderKey
583
- }
584
- }
585
- ShareInfo :: Withheld ( _) => ShareState :: NotShared ,
586
- } )
587
- } ) ;
588
-
589
- shared. unwrap_or ( ShareState :: NotShared )
628
+ /// Create a read-only view into the device sharing state of this session.
629
+ /// This view includes pending requests, so it is not guaranteed that the
630
+ /// represented state has been fully propagated yet.
631
+ pub ( crate ) fn sharing_view ( & self ) -> SharingView < ' _ > {
632
+ SharingView {
633
+ shared_with_set : self . shared_with_set . read ( ) ,
634
+ to_share_with_set : self . to_share_with_set . read ( ) ,
590
635
}
591
636
}
592
637
593
- pub ( crate ) fn is_withheld_to ( & self , device : & DeviceData , code : & WithheldCode ) -> bool {
594
- self . shared_with_set
595
- . read ( )
596
- . get ( device. user_id ( ) )
597
- . and_then ( |d| {
598
- let info = d. get ( device. device_id ( ) ) ?;
599
- Some ( matches ! ( info, ShareInfo :: Withheld ( c) if c == code) )
600
- } )
601
- . unwrap_or_else ( || {
602
- // If we haven't yet withheld, check if we're going to withheld
603
- // the session.
604
-
605
- // Find the first request that contains the given user id and
606
- // device ID.
607
- self . to_share_with_set . read ( ) . values ( ) . any ( |( _, share_info) | {
608
- share_info
609
- . get ( device. user_id ( ) )
610
- . and_then ( |d| d. get ( device. device_id ( ) ) )
611
- . is_some_and ( |info| matches ! ( info, ShareInfo :: Withheld ( c) if c == code) )
612
- } )
613
- } )
614
- }
615
-
616
638
/// Mark the session as shared with the given user/device pair, starting
617
639
/// from some message index.
618
640
#[ cfg( test) ]
@@ -782,7 +804,7 @@ mod tests {
782
804
uint, EventEncryptionAlgorithm ,
783
805
} ;
784
806
785
- use super :: { EncryptionSettings , ROTATION_MESSAGES , ROTATION_PERIOD } ;
807
+ use super :: { EncryptionSettings , ShareState , ROTATION_MESSAGES , ROTATION_PERIOD } ;
786
808
use crate :: CollectStrategy ;
787
809
788
810
#[ test]
@@ -811,6 +833,24 @@ mod tests {
811
833
assert_eq ! ( settings. rotation_period_msgs, 500 ) ;
812
834
}
813
835
836
+ /// Ensure that the `ShareState` PartialOrd instance orders according to
837
+ /// specificity of the value.
838
+ #[ test]
839
+ fn test_share_state_ordering ( ) {
840
+ let values = [
841
+ ShareState :: NotShared ,
842
+ ShareState :: SharedButChangedSenderKey ,
843
+ ShareState :: Shared { message_index : 1 , olm_wedging_index : Default :: default ( ) } ,
844
+ ] ;
845
+ // Make sure our test case of possible variants is exhaustive
846
+ match values[ 0 ] {
847
+ ShareState :: NotShared
848
+ | ShareState :: SharedButChangedSenderKey
849
+ | ShareState :: Shared { .. } => { }
850
+ }
851
+ assert ! ( values. is_sorted( ) ) ;
852
+ }
853
+
814
854
#[ cfg( any( target_os = "linux" , target_os = "macos" , target_arch = "wasm32" ) ) ]
815
855
mod expiration {
816
856
use std:: { sync:: atomic:: Ordering , time:: Duration } ;
0 commit comments