@@ -6,10 +6,11 @@ use lightning::ln::{PaymentHash, PaymentPreimage, PaymentSecret};
6
6
use lightning:: util:: persist:: KVStorePersister ;
7
7
use lightning:: { impl_writeable_tlv_based, impl_writeable_tlv_based_enum} ;
8
8
9
- use std:: collections:: HashMap ;
9
+ use std:: collections:: hash_map;
10
+ use std:: collections:: { HashMap , HashSet } ;
10
11
use std:: iter:: FromIterator ;
11
12
use std:: ops:: Deref ;
12
- use std:: sync:: Mutex ;
13
+ use std:: sync:: { Mutex , MutexGuard } ;
13
14
14
15
/// Represents a payment.
15
16
#[ derive( Clone , Debug , PartialEq , Eq ) ]
@@ -71,18 +72,23 @@ impl_writeable_tlv_based_enum!(PaymentStatus,
71
72
/// The payment information will be persisted under this prefix.
72
73
pub ( crate ) const PAYMENT_INFO_PERSISTENCE_PREFIX : & str = "payments" ;
73
74
74
- pub ( crate ) struct PaymentInfoStorage < K : Deref >
75
+ pub ( crate ) struct PaymentInfoStorage < K : Deref + Clone >
75
76
where
76
77
K :: Target : KVStorePersister + KVStoreUnpersister ,
77
78
{
78
79
payments : Mutex < HashMap < PaymentHash , PaymentInfo > > ,
79
80
persister : K ,
80
81
}
81
82
82
- impl < K : Deref > PaymentInfoStorage < K >
83
+ impl < K : Deref + Clone > PaymentInfoStorage < K >
83
84
where
84
85
K :: Target : KVStorePersister + KVStoreUnpersister ,
85
86
{
87
+ pub ( crate ) fn new ( persister : K ) -> Self {
88
+ let payments = Mutex :: new ( HashMap :: new ( ) ) ;
89
+ Self { payments, persister }
90
+ }
91
+
86
92
pub ( crate ) fn from_payments ( payments : Vec < PaymentInfo > , persister : K ) -> Self {
87
93
let payments = Mutex :: new ( HashMap :: from_iter (
88
94
payments. into_iter ( ) . map ( |payment_info| ( payment_info. payment_hash , payment_info) ) ,
@@ -107,6 +113,11 @@ where
107
113
return Ok ( ( ) ) ;
108
114
}
109
115
116
+ pub ( crate ) fn lock ( & self ) -> Result < PaymentInfoGuard < K > , ( ) > {
117
+ let locked_store = self . payments . lock ( ) . map_err ( |_| ( ) ) ?;
118
+ Ok ( PaymentInfoGuard :: new ( locked_store, self . persister . clone ( ) ) )
119
+ }
120
+
110
121
pub ( crate ) fn remove ( & self , payment_hash : & PaymentHash ) -> Result < ( ) , Error > {
111
122
let key = format ! (
112
123
"{}/{}" ,
@@ -143,3 +154,79 @@ where
143
154
Ok ( ( ) )
144
155
}
145
156
}
157
+
158
+ pub ( crate ) struct PaymentInfoGuard < ' a , K : Deref >
159
+ where
160
+ K :: Target : KVStorePersister + KVStoreUnpersister ,
161
+ {
162
+ inner : MutexGuard < ' a , HashMap < PaymentHash , PaymentInfo > > ,
163
+ touched_keys : HashSet < PaymentHash > ,
164
+ persister : K ,
165
+ }
166
+
167
+ impl < ' a , K : Deref > PaymentInfoGuard < ' a , K >
168
+ where
169
+ K :: Target : KVStorePersister + KVStoreUnpersister ,
170
+ {
171
+ pub fn new ( inner : MutexGuard < ' a , HashMap < PaymentHash , PaymentInfo > > , persister : K ) -> Self {
172
+ let touched_keys = HashSet :: new ( ) ;
173
+ Self { inner, touched_keys, persister }
174
+ }
175
+
176
+ pub fn entry (
177
+ & mut self , payment_hash : PaymentHash ,
178
+ ) -> hash_map:: Entry < PaymentHash , PaymentInfo > {
179
+ self . touched_keys . insert ( payment_hash) ;
180
+ self . inner . entry ( payment_hash)
181
+ }
182
+ }
183
+
184
+ impl < ' a , K : Deref > Drop for PaymentInfoGuard < ' a , K >
185
+ where
186
+ K :: Target : KVStorePersister + KVStoreUnpersister ,
187
+ {
188
+ fn drop ( & mut self ) {
189
+ for key in self . touched_keys . iter ( ) {
190
+ let store_key =
191
+ format ! ( "{}/{}" , PAYMENT_INFO_PERSISTENCE_PREFIX , hex_utils:: to_string( & key. 0 ) ) ;
192
+
193
+ match self . inner . entry ( * key) {
194
+ hash_map:: Entry :: Vacant ( _) => {
195
+ self . persister . unpersist ( & store_key) . expect ( "Persistence failed" ) ;
196
+ }
197
+ hash_map:: Entry :: Occupied ( e) => {
198
+ self . persister . persist ( & store_key, e. get ( ) ) . expect ( "Persistence failed" ) ;
199
+ }
200
+ } ;
201
+ }
202
+ }
203
+ }
204
+
205
+ #[ cfg( test) ]
206
+ mod tests {
207
+ use super :: * ;
208
+ use crate :: tests:: test_utils:: TestPersister ;
209
+ use std:: sync:: Arc ;
210
+
211
+ #[ test]
212
+ fn persistence_guard_persists_on_drop ( ) {
213
+ let persister = Arc :: new ( TestPersister :: new ( ) ) ;
214
+ let payment_info_store = PaymentInfoStorage :: new ( Arc :: clone ( & persister) ) ;
215
+
216
+ let payment_hash = PaymentHash ( [ 42u8 ; 32 ] ) ;
217
+ assert ! ( !payment_info_store. contains( & payment_hash) ) ;
218
+
219
+ let payment_info = PaymentInfo {
220
+ payment_hash,
221
+ preimage : None ,
222
+ secret : None ,
223
+ amount_msat : None ,
224
+ direction : PaymentDirection :: Inbound ,
225
+ status : PaymentStatus :: Pending ,
226
+ } ;
227
+
228
+ assert ! ( !persister. get_and_clear_did_persist( ) ) ;
229
+ payment_info_store. lock ( ) . unwrap ( ) . entry ( payment_hash) . or_insert ( payment_info) ;
230
+ assert ! ( persister. get_and_clear_did_persist( ) ) ;
231
+ }
232
+ }
0 commit comments