@@ -26,6 +26,8 @@ use lightning::util::ser::{Readable, ReadableArgs, Writeable, Writer};
26
26
use bitcoin:: blockdata:: locktime:: absolute:: LockTime ;
27
27
use bitcoin:: secp256k1:: PublicKey ;
28
28
use bitcoin:: OutPoint ;
29
+ use core:: future:: Future ;
30
+ use core:: task:: { Poll , Waker } ;
29
31
use rand:: { thread_rng, Rng } ;
30
32
use std:: collections:: VecDeque ;
31
33
use std:: ops:: Deref ;
@@ -125,7 +127,8 @@ pub struct EventQueue<K: KVStore + Sync + Send, L: Deref>
125
127
where
126
128
L :: Target : Logger ,
127
129
{
128
- queue : Mutex < VecDeque < Event > > ,
130
+ queue : Arc < Mutex < VecDeque < Event > > > ,
131
+ waker : Arc < Mutex < Option < Waker > > > ,
129
132
notifier : Condvar ,
130
133
kv_store : Arc < K > ,
131
134
logger : L ,
@@ -136,9 +139,10 @@ where
136
139
L :: Target : Logger ,
137
140
{
138
141
pub ( crate ) fn new ( kv_store : Arc < K > , logger : L ) -> Self {
139
- let queue: Mutex < VecDeque < Event > > = Mutex :: new ( VecDeque :: new ( ) ) ;
142
+ let queue = Arc :: new ( Mutex :: new ( VecDeque :: new ( ) ) ) ;
143
+ let waker = Arc :: new ( Mutex :: new ( None ) ) ;
140
144
let notifier = Condvar :: new ( ) ;
141
- Self { queue, notifier, kv_store, logger }
145
+ Self { queue, waker , notifier, kv_store, logger }
142
146
}
143
147
144
148
pub ( crate ) fn add_event ( & self , event : Event ) -> Result < ( ) , Error > {
@@ -149,6 +153,10 @@ where
149
153
}
150
154
151
155
self . notifier . notify_one ( ) ;
156
+
157
+ if let Some ( waker) = self . waker . lock ( ) . unwrap ( ) . take ( ) {
158
+ waker. wake ( ) ;
159
+ }
152
160
Ok ( ( ) )
153
161
}
154
162
@@ -157,6 +165,10 @@ where
157
165
locked_queue. front ( ) . map ( |e| e. clone ( ) )
158
166
}
159
167
168
+ pub ( crate ) async fn next_event_async ( & self ) -> Event {
169
+ EventFuture { event_queue : Arc :: clone ( & self . queue ) , waker : Arc :: clone ( & self . waker ) } . await
170
+ }
171
+
160
172
pub ( crate ) fn wait_next_event ( & self ) -> Event {
161
173
let locked_queue =
162
174
self . notifier . wait_while ( self . queue . lock ( ) . unwrap ( ) , |queue| queue. is_empty ( ) ) . unwrap ( ) ;
@@ -170,6 +182,10 @@ where
170
182
self . persist_queue ( & locked_queue) ?;
171
183
}
172
184
self . notifier . notify_one ( ) ;
185
+
186
+ if let Some ( waker) = self . waker . lock ( ) . unwrap ( ) . take ( ) {
187
+ waker. wake ( ) ;
188
+ }
173
189
Ok ( ( ) )
174
190
}
175
191
@@ -207,9 +223,10 @@ where
207
223
) -> Result < Self , lightning:: ln:: msgs:: DecodeError > {
208
224
let ( kv_store, logger) = args;
209
225
let read_queue: EventQueueDeserWrapper = Readable :: read ( reader) ?;
210
- let queue: Mutex < VecDeque < Event > > = Mutex :: new ( read_queue. 0 ) ;
226
+ let queue = Arc :: new ( Mutex :: new ( read_queue. 0 ) ) ;
227
+ let waker = Arc :: new ( Mutex :: new ( None ) ) ;
211
228
let notifier = Condvar :: new ( ) ;
212
- Ok ( Self { queue, notifier, kv_store, logger } )
229
+ Ok ( Self { queue, waker , notifier, kv_store, logger } )
213
230
}
214
231
}
215
232
@@ -240,6 +257,26 @@ impl Writeable for EventQueueSerWrapper<'_> {
240
257
}
241
258
}
242
259
260
+ struct EventFuture {
261
+ event_queue : Arc < Mutex < VecDeque < Event > > > ,
262
+ waker : Arc < Mutex < Option < Waker > > > ,
263
+ }
264
+
265
+ impl Future for EventFuture {
266
+ type Output = Event ;
267
+
268
+ fn poll (
269
+ self : core:: pin:: Pin < & mut Self > , cx : & mut core:: task:: Context < ' _ > ,
270
+ ) -> core:: task:: Poll < Self :: Output > {
271
+ if let Some ( event) = self . event_queue . lock ( ) . unwrap ( ) . front ( ) {
272
+ Poll :: Ready ( event. clone ( ) )
273
+ } else {
274
+ * self . waker . lock ( ) . unwrap ( ) = Some ( cx. waker ( ) . clone ( ) ) ;
275
+ Poll :: Pending
276
+ }
277
+ }
278
+ }
279
+
243
280
pub ( crate ) struct EventHandler < K : KVStore + Sync + Send , L : Deref >
244
281
where
245
282
L :: Target : Logger ,
@@ -796,12 +833,14 @@ where
796
833
mod tests {
797
834
use super :: * ;
798
835
use lightning:: util:: test_utils:: { TestLogger , TestStore } ;
836
+ use std:: sync:: atomic:: { AtomicU16 , Ordering } ;
837
+ use std:: time:: Duration ;
799
838
800
- #[ test]
801
- fn event_queue_persistence ( ) {
839
+ #[ tokio :: test]
840
+ async fn event_queue_persistence ( ) {
802
841
let store = Arc :: new ( TestStore :: new ( false ) ) ;
803
842
let logger = Arc :: new ( TestLogger :: new ( ) ) ;
804
- let event_queue = EventQueue :: new ( Arc :: clone ( & store) , Arc :: clone ( & logger) ) ;
843
+ let event_queue = Arc :: new ( EventQueue :: new ( Arc :: clone ( & store) , Arc :: clone ( & logger) ) ) ;
805
844
assert_eq ! ( event_queue. next_event( ) , None ) ;
806
845
807
846
let expected_event = Event :: ChannelReady {
@@ -814,6 +853,7 @@ mod tests {
814
853
// Check we get the expected event and that it is returned until we mark it handled.
815
854
for _ in 0 ..5 {
816
855
assert_eq ! ( event_queue. wait_next_event( ) , expected_event) ;
856
+ assert_eq ! ( event_queue. next_event_async( ) . await , expected_event) ;
817
857
assert_eq ! ( event_queue. next_event( ) , Some ( expected_event. clone( ) ) ) ;
818
858
}
819
859
@@ -832,4 +872,96 @@ mod tests {
832
872
event_queue. event_handled ( ) . unwrap ( ) ;
833
873
assert_eq ! ( event_queue. next_event( ) , None ) ;
834
874
}
875
+
876
+ #[ tokio:: test]
877
+ async fn event_queue_concurrency ( ) {
878
+ let store = Arc :: new ( TestStore :: new ( false ) ) ;
879
+ let logger = Arc :: new ( TestLogger :: new ( ) ) ;
880
+ let event_queue = Arc :: new ( EventQueue :: new ( Arc :: clone ( & store) , Arc :: clone ( & logger) ) ) ;
881
+ assert_eq ! ( event_queue. next_event( ) , None ) ;
882
+
883
+ let expected_event = Event :: ChannelReady {
884
+ channel_id : ChannelId ( [ 23u8 ; 32 ] ) ,
885
+ user_channel_id : UserChannelId ( 2323 ) ,
886
+ counterparty_node_id : None ,
887
+ } ;
888
+
889
+ // Check `next_event_async` won't return if the queue is empty and always rather timeout.
890
+ tokio:: select! {
891
+ _ = tokio:: time:: sleep( Duration :: from_secs( 1 ) ) => {
892
+ // Timeout
893
+ }
894
+ _ = event_queue. next_event_async( ) => {
895
+ panic!( ) ;
896
+ }
897
+ }
898
+
899
+ assert_eq ! ( event_queue. next_event( ) , None ) ;
900
+ // Check we get the expected number of events when polling/enqueuing concurrently.
901
+ let enqueued_events = AtomicU16 :: new ( 0 ) ;
902
+ let received_events = AtomicU16 :: new ( 0 ) ;
903
+ let mut delayed_enqueue = false ;
904
+
905
+ for _ in 0 ..25 {
906
+ event_queue. add_event ( expected_event. clone ( ) ) . unwrap ( ) ;
907
+ enqueued_events. fetch_add ( 1 , Ordering :: SeqCst ) ;
908
+ }
909
+
910
+ loop {
911
+ tokio:: select! {
912
+ _ = tokio:: time:: sleep( Duration :: from_millis( 10 ) ) , if !delayed_enqueue => {
913
+ event_queue. add_event( expected_event. clone( ) ) . unwrap( ) ;
914
+ enqueued_events. fetch_add( 1 , Ordering :: SeqCst ) ;
915
+ delayed_enqueue = true ;
916
+ }
917
+ e = event_queue. next_event_async( ) => {
918
+ assert_eq!( e, expected_event) ;
919
+ event_queue. event_handled( ) . unwrap( ) ;
920
+ received_events. fetch_add( 1 , Ordering :: SeqCst ) ;
921
+
922
+ event_queue. add_event( expected_event. clone( ) ) . unwrap( ) ;
923
+ enqueued_events. fetch_add( 1 , Ordering :: SeqCst ) ;
924
+ }
925
+ e = event_queue. next_event_async( ) => {
926
+ assert_eq!( e, expected_event) ;
927
+ event_queue. event_handled( ) . unwrap( ) ;
928
+ received_events. fetch_add( 1 , Ordering :: SeqCst ) ;
929
+ }
930
+ }
931
+
932
+ if delayed_enqueue
933
+ && received_events. load ( Ordering :: SeqCst ) == enqueued_events. load ( Ordering :: SeqCst )
934
+ {
935
+ break ;
936
+ }
937
+ }
938
+ assert_eq ! ( event_queue. next_event( ) , None ) ;
939
+
940
+ // Check we operate correctly, even when mixing and matching blocking and async API calls.
941
+ let ( tx, mut rx) = tokio:: sync:: watch:: channel ( ( ) ) ;
942
+ let thread_queue = Arc :: clone ( & event_queue) ;
943
+ let thread_event = expected_event. clone ( ) ;
944
+ std:: thread:: spawn ( move || {
945
+ let e = thread_queue. wait_next_event ( ) ;
946
+ assert_eq ! ( e, thread_event) ;
947
+ thread_queue. event_handled ( ) . unwrap ( ) ;
948
+ tx. send ( ( ) ) . unwrap ( ) ;
949
+ } ) ;
950
+
951
+ let thread_queue = Arc :: clone ( & event_queue) ;
952
+ let thread_event = expected_event. clone ( ) ;
953
+ std:: thread:: spawn ( move || {
954
+ // Sleep a bit before we enqueue the events everybody is waiting for.
955
+ std:: thread:: sleep ( Duration :: from_millis ( 20 ) ) ;
956
+ thread_queue. add_event ( thread_event. clone ( ) ) . unwrap ( ) ;
957
+ thread_queue. add_event ( thread_event. clone ( ) ) . unwrap ( ) ;
958
+ } ) ;
959
+
960
+ let e = event_queue. next_event_async ( ) . await ;
961
+ assert_eq ! ( e, expected_event. clone( ) ) ;
962
+ event_queue. event_handled ( ) . unwrap ( ) ;
963
+
964
+ rx. changed ( ) . await . unwrap ( ) ;
965
+ assert_eq ! ( event_queue. next_event( ) , None ) ;
966
+ }
835
967
}
0 commit comments