Skip to content

Commit 0856f4e

Browse files
authored
refactor(crypto): Properly encapsulate internal OutboundGroupSession state
Previously, the `share_strategy` was breaking the abstraction provided by `OutboundGroupSession` by accessing its internal fields in an inconsistent and adhoc way. Now all fields are private and a proper abstraction was added to access the required state in a consistent API.
1 parent ae4cdda commit 0856f4e

File tree

4 files changed

+134
-138
lines changed

4 files changed

+134
-138
lines changed

crates/matrix-sdk-crypto/src/gossiping/machine.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ impl GossipMachine {
645645
// at. For this, we need an outbound session because this
646646
// information is recorded there.
647647
} else if let Some(outbound) = outbound_session {
648-
match outbound.is_shared_with(&device.inner) {
648+
match outbound.sharing_view().get_share_state(&device.inner) {
649649
ShareState::Shared { message_index, olm_wedging_index: _ } => {
650650
Ok(Some(message_index))
651651
}

crates/matrix-sdk-crypto/src/olm/group_sessions/outbound.rs

Lines changed: 117 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ use std::{
1616
cmp::max,
1717
collections::{BTreeMap, BTreeSet},
1818
fmt,
19+
ops::Bound,
1920
sync::{
2021
atomic::{AtomicBool, AtomicU64, Ordering},
21-
Arc,
22+
Arc, RwLockReadGuard,
2223
},
2324
time::Duration,
2425
};
@@ -70,7 +71,7 @@ const ONE_WEEK: Duration = Duration::from_secs(60 * 60 * 24 * 7);
7071
const ROTATION_PERIOD: Duration = ONE_WEEK;
7172
const ROTATION_MESSAGES: u64 = 100;
7273

73-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
74+
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
7475
/// Information about whether a session was shared with a device.
7576
pub(crate) enum ShareState {
7677
/// The session was not shared with the device.
@@ -157,11 +158,8 @@ pub struct OutboundGroupSession {
157158
shared: Arc<AtomicBool>,
158159
invalidated: Arc<AtomicBool>,
159160
settings: Arc<EncryptionSettings>,
160-
pub(crate) shared_with_set:
161-
Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>>>,
162-
#[allow(clippy::type_complexity)]
163-
pub(crate) to_share_with_set:
164-
Arc<StdRwLock<BTreeMap<OwnedTransactionId, (Arc<ToDeviceRequest>, ShareInfoSet)>>>,
161+
shared_with_set: Arc<StdRwLock<ShareInfoSet>>,
162+
to_share_with_set: Arc<StdRwLock<ToShareMap>>,
165163
}
166164

167165
/// A a map of userid/device it to a `ShareInfo`.
@@ -170,6 +168,8 @@ pub struct OutboundGroupSession {
170168
/// room key.
171169
pub type ShareInfoSet = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>;
172170

171+
type ToShareMap = BTreeMap<OwnedTransactionId, (Arc<ToDeviceRequest>, ShareInfoSet)>;
172+
173173
/// Struct holding info about the share state of a outbound group session.
174174
#[derive(Clone, Debug, Serialize, Deserialize)]
175175
pub enum ShareInfo {
@@ -206,6 +206,90 @@ pub struct SharedWith {
206206
pub olm_wedging_index: SequenceNumber,
207207
}
208208

209+
/// A read-only view into the device sharing state of an
210+
/// [`OutboundGroupSession`].
211+
pub(crate) struct SharingView<'a> {
212+
shared_with_set: RwLockReadGuard<'a, ShareInfoSet>,
213+
to_share_with_set: RwLockReadGuard<'a, ToShareMap>,
214+
}
215+
216+
impl SharingView<'_> {
217+
/// Has the session been shared with the given user/device pair (or if not,
218+
/// is there such a request pending).
219+
pub(crate) fn get_share_state(&self, device: &DeviceData) -> ShareState {
220+
self.iter_shares(Some(device.user_id()), Some(device.device_id()))
221+
.map(|(_, _, info)| match info {
222+
ShareInfo::Shared(info) => {
223+
if device.curve25519_key() == Some(info.sender_key) {
224+
ShareState::Shared {
225+
message_index: info.message_index,
226+
olm_wedging_index: info.olm_wedging_index,
227+
}
228+
} else {
229+
ShareState::SharedButChangedSenderKey
230+
}
231+
}
232+
ShareInfo::Withheld(_) => ShareState::NotShared,
233+
})
234+
// Return the most "definitive" ShareState found (in case there
235+
// are multiple entries for the same device).
236+
.max()
237+
.unwrap_or(ShareState::NotShared)
238+
}
239+
240+
/// Has the session been withheld for the given user/device pair (or if not,
241+
/// is there such a request pending).
242+
pub(crate) fn is_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
243+
self.iter_shares(Some(device.user_id()), Some(device.device_id()))
244+
.any(|(_, _, info)| matches!(info, ShareInfo::Withheld(c) if c == code))
245+
}
246+
247+
/// Enumerate all sent or pending sharing requests for the given device (or
248+
/// for all devices if not specified). This can yield the same device
249+
/// multiple times.
250+
pub(crate) fn iter_shares<'b, 'c>(
251+
&self,
252+
user_id: Option<&'b UserId>,
253+
device_id: Option<&'c DeviceId>,
254+
) -> impl Iterator<Item = (&UserId, &DeviceId, &ShareInfo)> + use<'_, 'b, 'c> {
255+
fn iter_share_info_set<'a, 'b, 'c>(
256+
set: &'a ShareInfoSet,
257+
user_ids: (Bound<&'b UserId>, Bound<&'b UserId>),
258+
device_ids: (Bound<&'c DeviceId>, Bound<&'c DeviceId>),
259+
) -> impl Iterator<Item = (&'a UserId, &'a DeviceId, &'a ShareInfo)> + use<'a, 'b, 'c>
260+
{
261+
set.range::<UserId, _>(user_ids).flat_map(move |(uid, d)| {
262+
d.range::<DeviceId, _>(device_ids)
263+
.map(|(id, info)| (uid.as_ref(), id.as_ref(), info))
264+
})
265+
}
266+
267+
let user_ids = user_id
268+
.map(|u| (Bound::Included(u), Bound::Included(u)))
269+
.unwrap_or((Bound::Unbounded, Bound::Unbounded));
270+
let device_ids = device_id
271+
.map(|d| (Bound::Included(d), Bound::Included(d)))
272+
.unwrap_or((Bound::Unbounded, Bound::Unbounded));
273+
274+
let already_shared = iter_share_info_set(&self.shared_with_set, user_ids, device_ids);
275+
let pending = self
276+
.to_share_with_set
277+
.values()
278+
.flat_map(move |(_, set)| iter_share_info_set(set, user_ids, device_ids));
279+
already_shared.chain(pending)
280+
}
281+
282+
/// Enumerate all users that have received the session, or have pending
283+
/// requests to receive it. This can yield the same user multiple times,
284+
/// so you may want to `collect()` the result into a `BTreeSet`.
285+
pub(crate) fn shared_with_users(&self) -> impl Iterator<Item = &UserId> {
286+
self.iter_shares(None, None).filter_map(|(u, _, info)| match info {
287+
ShareInfo::Shared(_) => Some(u),
288+
ShareInfo::Withheld(_) => None,
289+
})
290+
}
291+
}
292+
209293
impl OutboundGroupSession {
210294
pub(super) fn session_config(
211295
algorithm: &EventEncryptionAlgorithm,
@@ -541,78 +625,16 @@ impl OutboundGroupSession {
541625
)
542626
}
543627

544-
/// Has or will the session be shared with the given user/device pair.
545-
pub(crate) fn is_shared_with(&self, device: &DeviceData) -> ShareState {
546-
// Check if we shared the session.
547-
let shared_state = self.shared_with_set.read().get(device.user_id()).and_then(|d| {
548-
d.get(device.device_id()).map(|s| match s {
549-
ShareInfo::Shared(s) => {
550-
if device.curve25519_key() == Some(s.sender_key) {
551-
ShareState::Shared {
552-
message_index: s.message_index,
553-
olm_wedging_index: s.olm_wedging_index,
554-
}
555-
} else {
556-
ShareState::SharedButChangedSenderKey
557-
}
558-
}
559-
ShareInfo::Withheld(_) => ShareState::NotShared,
560-
})
561-
});
562-
563-
if let Some(state) = shared_state {
564-
state
565-
} else {
566-
// If we haven't shared the session, check if we're going to share
567-
// the session.
568-
569-
// Find the first request that contains the given user id and
570-
// device ID.
571-
let shared = self.to_share_with_set.read().values().find_map(|(_, share_info)| {
572-
let d = share_info.get(device.user_id())?;
573-
let info = d.get(device.device_id())?;
574-
Some(match info {
575-
ShareInfo::Shared(info) => {
576-
if device.curve25519_key() == Some(info.sender_key) {
577-
ShareState::Shared {
578-
message_index: info.message_index,
579-
olm_wedging_index: info.olm_wedging_index,
580-
}
581-
} else {
582-
ShareState::SharedButChangedSenderKey
583-
}
584-
}
585-
ShareInfo::Withheld(_) => ShareState::NotShared,
586-
})
587-
});
588-
589-
shared.unwrap_or(ShareState::NotShared)
628+
/// Create a read-only view into the device sharing state of this session.
629+
/// This view includes pending requests, so it is not guaranteed that the
630+
/// represented state has been fully propagated yet.
631+
pub(crate) fn sharing_view(&self) -> SharingView<'_> {
632+
SharingView {
633+
shared_with_set: self.shared_with_set.read(),
634+
to_share_with_set: self.to_share_with_set.read(),
590635
}
591636
}
592637

593-
pub(crate) fn is_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
594-
self.shared_with_set
595-
.read()
596-
.get(device.user_id())
597-
.and_then(|d| {
598-
let info = d.get(device.device_id())?;
599-
Some(matches!(info, ShareInfo::Withheld(c) if c == code))
600-
})
601-
.unwrap_or_else(|| {
602-
// If we haven't yet withheld, check if we're going to withheld
603-
// the session.
604-
605-
// Find the first request that contains the given user id and
606-
// device ID.
607-
self.to_share_with_set.read().values().any(|(_, share_info)| {
608-
share_info
609-
.get(device.user_id())
610-
.and_then(|d| d.get(device.device_id()))
611-
.is_some_and(|info| matches!(info, ShareInfo::Withheld(c) if c == code))
612-
})
613-
})
614-
}
615-
616638
/// Mark the session as shared with the given user/device pair, starting
617639
/// from some message index.
618640
#[cfg(test)]
@@ -782,7 +804,7 @@ mod tests {
782804
uint, EventEncryptionAlgorithm,
783805
};
784806

785-
use super::{EncryptionSettings, ROTATION_MESSAGES, ROTATION_PERIOD};
807+
use super::{EncryptionSettings, ShareState, ROTATION_MESSAGES, ROTATION_PERIOD};
786808
use crate::CollectStrategy;
787809

788810
#[test]
@@ -811,6 +833,24 @@ mod tests {
811833
assert_eq!(settings.rotation_period_msgs, 500);
812834
}
813835

836+
/// Ensure that the `ShareState` PartialOrd instance orders according to
837+
/// specificity of the value.
838+
#[test]
839+
fn test_share_state_ordering() {
840+
let values = [
841+
ShareState::NotShared,
842+
ShareState::SharedButChangedSenderKey,
843+
ShareState::Shared { message_index: 1, olm_wedging_index: Default::default() },
844+
];
845+
// Make sure our test case of possible variants is exhaustive
846+
match values[0] {
847+
ShareState::NotShared
848+
| ShareState::SharedButChangedSenderKey
849+
| ShareState::Shared { .. } => {}
850+
}
851+
assert!(values.is_sorted());
852+
}
853+
814854
#[cfg(any(target_os = "linux", target_os = "macos", target_arch = "wasm32"))]
815855
mod expiration {
816856
use std::{sync::atomic::Ordering, time::Duration};

crates/matrix-sdk-crypto/src/session_manager/group_sessions/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ impl GroupSessionCache {
122122

123123
/// Returns whether any session is withheld with the given device and code.
124124
fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
125-
self.sessions.read().values().any(|s| s.is_withheld_to(device, code))
125+
self.sessions.read().values().any(|s| s.sharing_view().is_withheld_to(device, code))
126126
}
127127

128128
fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
@@ -500,7 +500,7 @@ impl GroupSessionManager {
500500
if code == &WithheldCode::NoOlm {
501501
device.was_withheld_code_sent() || self.sessions.has_session_withheld_to(device, code)
502502
} else {
503-
group_session.is_withheld_to(device, code)
503+
group_session.sharing_view().is_withheld_to(device, code)
504504
}
505505
}
506506

@@ -696,7 +696,7 @@ impl GroupSessionManager {
696696
let devices: Vec<_> = devices
697697
.into_iter()
698698
.flat_map(|(_, d)| {
699-
d.into_iter().filter(|d| match outbound.is_shared_with(d) {
699+
d.into_iter().filter(|d| match outbound.sharing_view().get_share_state(d) {
700700
ShareState::NotShared => true,
701701
ShareState::Shared { message_index: _, olm_wedging_index } => {
702702
// If the recipient device's Olm wedging index is higher

crates/matrix-sdk-crypto/src/session_manager/group_sessions/share_strategy.rs

Lines changed: 13 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
use std::{
1616
collections::{BTreeMap, BTreeSet, HashMap},
1717
default::Default,
18-
ops::Deref,
1918
};
2019

2120
use itertools::{Either, Itertools};
@@ -216,10 +215,8 @@ pub(crate) async fn collect_recipients_for_share_strategy(
216215
// users that should get the session but is in the set of users that
217216
// received the session.
218217
if let Some(outbound) = outbound {
219-
let users_shared_with: BTreeSet<OwnedUserId> =
220-
outbound.shared_with_set.read().keys().cloned().collect();
221-
let users_shared_with: BTreeSet<&UserId> =
222-
users_shared_with.iter().map(Deref::deref).collect();
218+
let view = outbound.sharing_view();
219+
let users_shared_with = view.shared_with_users().collect::<BTreeSet<_>>();
223220
let left_users = users_shared_with.difference(&users).collect::<BTreeSet<_>>();
224221
if !left_users.is_empty() {
225222
trace!(?left_users, "Some users have left the chat: session must be rotated");
@@ -427,61 +424,20 @@ fn is_session_overshared_for_user(
427424
let recipient_device_ids: BTreeSet<&DeviceId> =
428425
recipient_devices.iter().map(|d| d.device_id()).collect();
429426

430-
let mut shared: Vec<&DeviceId> = Vec::new();
431-
432-
// This duplicates a conservative subset of the logic in
433-
// `OutboundGroupSession::is_shared_with`, because we
434-
// don't have corresponding DeviceData at hand
435-
fn is_actually_shared(info: &ShareInfo) -> bool {
436-
match info {
437-
ShareInfo::Shared(_) => true,
438-
ShareInfo::Withheld(_) => false,
439-
}
440-
}
441-
442-
// Collect the devices that have definitely received the session already
443-
let guard = outbound_session.shared_with_set.read();
444-
if let Some(for_user) = guard.get(user_id) {
445-
shared.extend(for_user.iter().filter_map(|(d, info)| {
446-
if is_actually_shared(info) {
447-
Some(AsRef::<DeviceId>::as_ref(d))
427+
let view = outbound_session.sharing_view();
428+
let newly_deleted_or_blacklisted: BTreeSet<&DeviceId> = view
429+
.iter_shares(Some(user_id), None)
430+
.filter_map(|(_user_id, device_id, info)| {
431+
// If a devices who we've shared the session with before is not in the
432+
// list of devices that should receive the session, we need to rotate.
433+
// We also collect all of those device IDs to log them out.
434+
if matches!(info, ShareInfo::Shared(_)) && !recipient_device_ids.contains(device_id) {
435+
Some(device_id)
448436
} else {
449437
None
450438
}
451-
}));
452-
}
453-
454-
// To be conservative, also collect the devices that would still receive the
455-
// session from a pending to-device request if we don't rotate beforehand
456-
let guard = outbound_session.to_share_with_set.read();
457-
for (_txid, share_infos) in guard.values() {
458-
if let Some(for_user) = share_infos.get(user_id) {
459-
shared.extend(for_user.iter().filter_map(|(d, info)| {
460-
if is_actually_shared(info) {
461-
Some(AsRef::<DeviceId>::as_ref(d))
462-
} else {
463-
None
464-
}
465-
}));
466-
}
467-
}
468-
469-
if shared.is_empty() {
470-
return false;
471-
}
472-
473-
let shared: BTreeSet<&DeviceId> = shared.into_iter().collect();
474-
475-
// The set difference between
476-
//
477-
// 1. Devices that had previously received (or are queued to receive) the
478-
// session, and
479-
// 2. Devices that would now receive the session
480-
//
481-
// Represents newly deleted or blacklisted devices. If this
482-
// set is non-empty, we must rotate.
483-
let newly_deleted_or_blacklisted =
484-
shared.difference(&recipient_device_ids).collect::<BTreeSet<_>>();
439+
})
440+
.collect();
485441

486442
let should_rotate = !newly_deleted_or_blacklisted.is_empty();
487443
if should_rotate {

0 commit comments

Comments
 (0)