Skip to content

Commit 1ebcc38

Browse files
committed
f Expose PaymentInfoStorage::lock and impl ..Guard
1 parent fbb7c97 commit 1ebcc38

File tree

3 files changed

+116
-21
lines changed

3 files changed

+116
-21
lines changed

src/event.rs

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ impl Writeable for EventQueueSerWrapper<'_> {
182182
}
183183
}
184184

185-
pub(crate) struct EventHandler<K: Deref, L: Deref>
185+
pub(crate) struct EventHandler<K: Deref + Clone, L: Deref>
186186
where
187187
K::Target: KVStorePersister + KVStoreUnpersister,
188188
L::Target: Logger,
@@ -198,7 +198,7 @@ where
198198
_config: Arc<Config>,
199199
}
200200

201-
impl<K: Deref, L: Deref> EventHandler<K, L>
201+
impl<K: Deref + Clone, L: Deref> EventHandler<K, L>
202202
where
203203
K::Target: KVStorePersister + KVStoreUnpersister,
204204
L::Target: Logger,
@@ -223,7 +223,7 @@ where
223223
}
224224
}
225225

226-
impl<K: Deref, L: Deref> LdkEventHandler for EventHandler<K, L>
226+
impl<K: Deref + Clone, L: Deref> LdkEventHandler for EventHandler<K, L>
227227
where
228228
K::Target: KVStorePersister + KVStoreUnpersister,
229229
L::Target: Logger,
@@ -364,25 +364,24 @@ where
364364
PaymentPurpose::SpontaneousPayment(preimage) => (Some(preimage), None),
365365
};
366366

367-
let payment_info =
368-
if let Some(mut payment_info) = self.payment_store.get(&payment_hash) {
367+
let mut locked_store = self.payment_store.lock().unwrap();
368+
locked_store
369+
.entry(payment_hash)
370+
.and_modify(|payment_info| {
369371
payment_info.status = PaymentStatus::Succeeded;
370372
payment_info.preimage = payment_preimage;
371373
payment_info.secret = payment_secret;
372374
payment_info.amount_msat = Some(amount_msat);
373-
payment_info
374-
} else {
375-
PaymentInfo {
376-
preimage: payment_preimage,
377-
payment_hash,
378-
secret: payment_secret,
379-
amount_msat: Some(amount_msat),
380-
direction: PaymentDirection::Inbound,
381-
status: PaymentStatus::Succeeded,
382-
}
383-
};
375+
})
376+
.or_insert(PaymentInfo {
377+
preimage: payment_preimage,
378+
payment_hash,
379+
secret: payment_secret,
380+
amount_msat: Some(amount_msat),
381+
direction: PaymentDirection::Inbound,
382+
status: PaymentStatus::Succeeded,
383+
});
384384

385-
self.payment_store.insert(payment_info).expect("Failed to access payment store");
386385
self.event_queue
387386
.add_event(Event::PaymentReceived { payment_hash, amount_msat })
388387
.expect("Failed to push to event queue");

src/payment_store.rs

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ use lightning::ln::{PaymentHash, PaymentPreimage, PaymentSecret};
66
use lightning::util::persist::KVStorePersister;
77
use lightning::{impl_writeable_tlv_based, impl_writeable_tlv_based_enum};
88

9-
use std::collections::HashMap;
9+
use std::collections::hash_map;
10+
use std::collections::{HashMap, HashSet};
1011
use std::iter::FromIterator;
1112
use std::ops::Deref;
12-
use std::sync::Mutex;
13+
use std::sync::{Mutex, MutexGuard};
1314

1415
/// Represents a payment.
1516
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -71,18 +72,23 @@ impl_writeable_tlv_based_enum!(PaymentStatus,
7172
/// The payment information will be persisted under this prefix.
7273
pub(crate) const PAYMENT_INFO_PERSISTENCE_PREFIX: &str = "payments";
7374

74-
pub(crate) struct PaymentInfoStorage<K: Deref>
75+
pub(crate) struct PaymentInfoStorage<K: Deref + Clone>
7576
where
7677
K::Target: KVStorePersister + KVStoreUnpersister,
7778
{
7879
payments: Mutex<HashMap<PaymentHash, PaymentInfo>>,
7980
persister: K,
8081
}
8182

82-
impl<K: Deref> PaymentInfoStorage<K>
83+
impl<K: Deref + Clone> PaymentInfoStorage<K>
8384
where
8485
K::Target: KVStorePersister + KVStoreUnpersister,
8586
{
87+
pub(crate) fn new(persister: K) -> Self {
88+
let payments = Mutex::new(HashMap::new());
89+
Self { payments, persister }
90+
}
91+
8692
pub(crate) fn from_payments(payments: Vec<PaymentInfo>, persister: K) -> Self {
8793
let payments = Mutex::new(HashMap::from_iter(
8894
payments.into_iter().map(|payment_info| (payment_info.payment_hash, payment_info)),
@@ -107,6 +113,11 @@ where
107113
return Ok(());
108114
}
109115

116+
pub(crate) fn lock(&self) -> Result<PaymentInfoGuard<K>, ()> {
117+
let locked_store = self.payments.lock().map_err(|_| ())?;
118+
Ok(PaymentInfoGuard::new(locked_store, self.persister.clone()))
119+
}
120+
110121
pub(crate) fn remove(&self, payment_hash: &PaymentHash) -> Result<(), Error> {
111122
let key = format!(
112123
"{}/{}",
@@ -143,3 +154,79 @@ where
143154
Ok(())
144155
}
145156
}
157+
158+
pub(crate) struct PaymentInfoGuard<'a, K: Deref>
159+
where
160+
K::Target: KVStorePersister + KVStoreUnpersister,
161+
{
162+
inner: MutexGuard<'a, HashMap<PaymentHash, PaymentInfo>>,
163+
touched_keys: HashSet<PaymentHash>,
164+
persister: K,
165+
}
166+
167+
impl<'a, K: Deref> PaymentInfoGuard<'a, K>
168+
where
169+
K::Target: KVStorePersister + KVStoreUnpersister,
170+
{
171+
pub fn new(inner: MutexGuard<'a, HashMap<PaymentHash, PaymentInfo>>, persister: K) -> Self {
172+
let touched_keys = HashSet::new();
173+
Self { inner, touched_keys, persister }
174+
}
175+
176+
pub fn entry(
177+
&mut self, payment_hash: PaymentHash,
178+
) -> hash_map::Entry<PaymentHash, PaymentInfo> {
179+
self.touched_keys.insert(payment_hash);
180+
self.inner.entry(payment_hash)
181+
}
182+
}
183+
184+
impl<'a, K: Deref> Drop for PaymentInfoGuard<'a, K>
185+
where
186+
K::Target: KVStorePersister + KVStoreUnpersister,
187+
{
188+
fn drop(&mut self) {
189+
for key in self.touched_keys.iter() {
190+
let store_key =
191+
format!("{}/{}", PAYMENT_INFO_PERSISTENCE_PREFIX, hex_utils::to_string(&key.0));
192+
193+
match self.inner.entry(*key) {
194+
hash_map::Entry::Vacant(_) => {
195+
self.persister.unpersist(&store_key).expect("Persistence failed");
196+
}
197+
hash_map::Entry::Occupied(e) => {
198+
self.persister.persist(&store_key, e.get()).expect("Persistence failed");
199+
}
200+
};
201+
}
202+
}
203+
}
204+
205+
#[cfg(test)]
206+
mod tests {
207+
use super::*;
208+
use crate::tests::test_utils::TestPersister;
209+
use std::sync::Arc;
210+
211+
#[test]
212+
fn persistence_guard_persists_on_drop() {
213+
let persister = Arc::new(TestPersister::new());
214+
let payment_info_store = PaymentInfoStorage::new(Arc::clone(&persister));
215+
216+
let payment_hash = PaymentHash([42u8; 32]);
217+
assert!(!payment_info_store.contains(&payment_hash));
218+
219+
let payment_info = PaymentInfo {
220+
payment_hash,
221+
preimage: None,
222+
secret: None,
223+
amount_msat: None,
224+
direction: PaymentDirection::Inbound,
225+
status: PaymentStatus::Pending,
226+
};
227+
228+
assert!(!persister.get_and_clear_did_persist());
229+
payment_info_store.lock().unwrap().entry(payment_hash).or_insert(payment_info);
230+
assert!(persister.get_and_clear_did_persist());
231+
}
232+
}

src/tests/test_utils.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::io_utils::KVStoreUnpersister;
12
use lightning::util::persist::KVStorePersister;
23
use lightning::util::ser::Writeable;
34

@@ -53,3 +54,11 @@ impl KVStorePersister for TestPersister {
5354
Ok(())
5455
}
5556
}
57+
58+
impl KVStoreUnpersister for TestPersister {
59+
fn unpersist(&self, key: &str) -> std::io::Result<bool> {
60+
let mut persisted_bytes_lock = self.persisted_bytes.lock().unwrap();
61+
self.did_persist.store(true, Ordering::SeqCst);
62+
Ok(persisted_bytes_lock.remove(key).is_some())
63+
}
64+
}

0 commit comments

Comments
 (0)