@@ -20,24 +20,29 @@ use crate::lsps0;
20
20
use crate :: lsps1;
21
21
use crate :: lsps2;
22
22
use crate :: prelude:: { Vec , VecDeque } ;
23
- use crate :: sync:: Mutex ;
23
+ use crate :: sync:: { Arc , Mutex } ;
24
+
25
+ use core:: future:: Future ;
26
+ use core:: task:: { Poll , Waker } ;
24
27
25
28
pub ( crate ) struct EventQueue {
26
- queue : Mutex < VecDeque < Event > > ,
29
+ queue : Arc < Mutex < VecDeque < Event > > > ,
30
+ waker : Arc < Mutex < Option < Waker > > > ,
27
31
#[ cfg( feature = "std" ) ]
28
32
condvar : std:: sync:: Condvar ,
29
33
}
30
34
31
35
impl EventQueue {
32
36
pub fn new ( ) -> Self {
33
- let queue = Mutex :: new ( VecDeque :: new ( ) ) ;
37
+ let queue = Arc :: new ( Mutex :: new ( VecDeque :: new ( ) ) ) ;
38
+ let waker = Arc :: new ( Mutex :: new ( None ) ) ;
34
39
#[ cfg( feature = "std" ) ]
35
40
{
36
41
let condvar = std:: sync:: Condvar :: new ( ) ;
37
- Self { queue, condvar }
42
+ Self { queue, waker , condvar }
38
43
}
39
44
#[ cfg( not( feature = "std" ) ) ]
40
- Self { queue }
45
+ Self { queue, waker }
41
46
}
42
47
43
48
pub fn enqueue ( & self , event : Event ) {
@@ -46,6 +51,9 @@ impl EventQueue {
46
51
queue. push_back ( event) ;
47
52
}
48
53
54
+ if let Some ( waker) = self . waker . lock ( ) . unwrap ( ) . take ( ) {
55
+ waker. wake ( ) ;
56
+ }
49
57
#[ cfg( feature = "std" ) ]
50
58
self . condvar . notify_one ( ) ;
51
59
}
@@ -54,6 +62,10 @@ impl EventQueue {
54
62
self . queue . lock ( ) . unwrap ( ) . pop_front ( )
55
63
}
56
64
65
+ pub async fn next_event_async ( & self ) -> Event {
66
+ EventFuture { event_queue : Arc :: clone ( & self . queue ) , waker : Arc :: clone ( & self . waker ) } . await
67
+ }
68
+
57
69
#[ cfg( feature = "std" ) ]
58
70
pub fn wait_next_event ( & self ) -> Event {
59
71
let mut queue =
@@ -65,6 +77,10 @@ impl EventQueue {
65
77
drop ( queue) ;
66
78
67
79
if should_notify {
80
+ if let Some ( waker) = self . waker . lock ( ) . unwrap ( ) . take ( ) {
81
+ waker. wake ( ) ;
82
+ }
83
+
68
84
self . condvar . notify_one ( ) ;
69
85
}
70
86
@@ -92,3 +108,132 @@ pub enum Event {
92
108
/// An LSPS2 (JIT Channel) server event.
93
109
LSPS2Service ( lsps2:: event:: LSPS2ServiceEvent ) ,
94
110
}
111
+
112
+ struct EventFuture {
113
+ event_queue : Arc < Mutex < VecDeque < Event > > > ,
114
+ waker : Arc < Mutex < Option < Waker > > > ,
115
+ }
116
+
117
+ impl Future for EventFuture {
118
+ type Output = Event ;
119
+
120
+ fn poll (
121
+ self : core:: pin:: Pin < & mut Self > , cx : & mut core:: task:: Context < ' _ > ,
122
+ ) -> core:: task:: Poll < Self :: Output > {
123
+ if let Some ( event) = self . event_queue . lock ( ) . unwrap ( ) . pop_front ( ) {
124
+ Poll :: Ready ( event)
125
+ } else {
126
+ * self . waker . lock ( ) . unwrap ( ) = Some ( cx. waker ( ) . clone ( ) ) ;
127
+ Poll :: Pending
128
+ }
129
+ }
130
+ }
131
+
132
+ #[ cfg( test) ]
133
+ mod tests {
134
+ use super :: * ;
135
+ use crate :: lsps0:: event:: LSPS0ClientEvent ;
136
+ use bitcoin:: secp256k1:: { PublicKey , Secp256k1 , SecretKey } ;
137
+
138
+ #[ tokio:: test]
139
+ #[ cfg( feature = "std" ) ]
140
+ async fn event_queue_works ( ) {
141
+ use core:: sync:: atomic:: { AtomicU16 , Ordering } ;
142
+ use std:: sync:: Arc ;
143
+ use std:: time:: Duration ;
144
+
145
+ let event_queue = Arc :: new ( EventQueue :: new ( ) ) ;
146
+ assert_eq ! ( event_queue. next_event( ) , None ) ;
147
+
148
+ let secp_ctx = Secp256k1 :: new ( ) ;
149
+ let counterparty_node_id =
150
+ PublicKey :: from_secret_key ( & secp_ctx, & SecretKey :: from_slice ( & [ 42 ; 32 ] ) . unwrap ( ) ) ;
151
+ let expected_event = Event :: LSPS0Client ( LSPS0ClientEvent :: ListProtocolsResponse {
152
+ counterparty_node_id,
153
+ protocols : Vec :: new ( ) ,
154
+ } ) ;
155
+
156
+ for _ in 0 ..3 {
157
+ event_queue. enqueue ( expected_event. clone ( ) ) ;
158
+ }
159
+
160
+ assert_eq ! ( event_queue. wait_next_event( ) , expected_event) ;
161
+ assert_eq ! ( event_queue. next_event_async( ) . await , expected_event) ;
162
+ assert_eq ! ( event_queue. next_event( ) , Some ( expected_event. clone( ) ) ) ;
163
+ assert_eq ! ( event_queue. next_event( ) , None ) ;
164
+
165
+ // Check `next_event_async` won't return if the queue is empty and always rather timeout.
166
+ tokio:: select! {
167
+ _ = tokio:: time:: sleep( Duration :: from_millis( 10 ) ) => {
168
+ // Timeout
169
+ }
170
+ _ = event_queue. next_event_async( ) => {
171
+ panic!( ) ;
172
+ }
173
+ }
174
+ assert_eq ! ( event_queue. next_event( ) , None ) ;
175
+
176
+ // Check we get the expected number of events when polling/enqueuing concurrently.
177
+ let enqueued_events = AtomicU16 :: new ( 0 ) ;
178
+ let received_events = AtomicU16 :: new ( 0 ) ;
179
+ let mut delayed_enqueue = false ;
180
+
181
+ for _ in 0 ..25 {
182
+ event_queue. enqueue ( expected_event. clone ( ) ) ;
183
+ enqueued_events. fetch_add ( 1 , Ordering :: SeqCst ) ;
184
+ }
185
+
186
+ loop {
187
+ tokio:: select! {
188
+ _ = tokio:: time:: sleep( Duration :: from_millis( 10 ) ) , if !delayed_enqueue => {
189
+ event_queue. enqueue( expected_event. clone( ) ) ;
190
+ enqueued_events. fetch_add( 1 , Ordering :: SeqCst ) ;
191
+ delayed_enqueue = true ;
192
+ }
193
+ e = event_queue. next_event_async( ) => {
194
+ assert_eq!( e, expected_event) ;
195
+ received_events. fetch_add( 1 , Ordering :: SeqCst ) ;
196
+
197
+ event_queue. enqueue( expected_event. clone( ) ) ;
198
+ enqueued_events. fetch_add( 1 , Ordering :: SeqCst ) ;
199
+ }
200
+ e = event_queue. next_event_async( ) => {
201
+ assert_eq!( e, expected_event) ;
202
+ received_events. fetch_add( 1 , Ordering :: SeqCst ) ;
203
+ }
204
+ }
205
+
206
+ if delayed_enqueue
207
+ && received_events. load ( Ordering :: SeqCst ) == enqueued_events. load ( Ordering :: SeqCst )
208
+ {
209
+ break ;
210
+ }
211
+ }
212
+ assert_eq ! ( event_queue. next_event( ) , None ) ;
213
+
214
+ // Check we operate correctly, even when mixing and matching blocking and async API calls.
215
+ let ( tx, mut rx) = tokio:: sync:: watch:: channel ( ( ) ) ;
216
+ let thread_queue = Arc :: clone ( & event_queue) ;
217
+ let thread_event = expected_event. clone ( ) ;
218
+ std:: thread:: spawn ( move || {
219
+ let e = thread_queue. wait_next_event ( ) ;
220
+ assert_eq ! ( e, thread_event) ;
221
+ tx. send ( ( ) ) . unwrap ( ) ;
222
+ } ) ;
223
+
224
+ let thread_queue = Arc :: clone ( & event_queue) ;
225
+ let thread_event = expected_event. clone ( ) ;
226
+ std:: thread:: spawn ( move || {
227
+ // Sleep a bit before we enqueue the events everybody is waiting for.
228
+ std:: thread:: sleep ( Duration :: from_millis ( 20 ) ) ;
229
+ thread_queue. enqueue ( thread_event. clone ( ) ) ;
230
+ thread_queue. enqueue ( thread_event. clone ( ) ) ;
231
+ } ) ;
232
+
233
+ let e = event_queue. next_event_async ( ) . await ;
234
+ assert_eq ! ( e, expected_event. clone( ) ) ;
235
+
236
+ rx. changed ( ) . await . unwrap ( ) ;
237
+ assert_eq ! ( event_queue. next_event( ) , None ) ;
238
+ }
239
+ }
0 commit comments