@@ -2876,4 +2876,52 @@ mod tests {
2876
2876
// For (None)
2877
2877
assert_eq ! ( filter_addresses( None ) , None ) ;
2878
2878
}
2879
+
2880
+ #[ test]
2881
+ #[ cfg( feature = "std" ) ]
2882
+ fn test_process_events_multithreaded ( ) {
2883
+ use std:: time:: { Duration , Instant } ;
2884
+ // Test that `process_events` getting called on multiple threads doesn't generate too many
2885
+ // loop iterations.
2886
+ // Each time `process_events` goes around the loop we call
2887
+ // `get_and_clear_pending_msg_events`, which we count using the `TestMessageHandler`.
2888
+ // Because the loop should go around once more after a call which fails to take the
2889
+ // single-threaded lock, if we write zero to the counter before calling `process_events` we
2890
+ // should never observe there having been more than 2 loop iterations.
2891
+ // Further, because the last thread to exit will call `process_events` before returning, we
2892
+ // should always have at least one count at the end.
2893
+ let cfg = Arc :: new ( create_peermgr_cfgs ( 1 ) ) ;
2894
+ // Until we have std::thread::scoped we have to unsafe { turn off the borrow checker }.
2895
+ let peer = Arc :: new ( create_network ( 1 , unsafe { & * ( & * cfg as * const _ ) as & ' static _ } ) . pop ( ) . unwrap ( ) ) ;
2896
+
2897
+ let exit_flag = Arc :: new ( AtomicBool :: new ( false ) ) ;
2898
+ macro_rules! spawn_thread { ( ) => { {
2899
+ let thread_cfg = Arc :: clone( & cfg) ;
2900
+ let thread_peer = Arc :: clone( & peer) ;
2901
+ let thread_exit = Arc :: clone( & exit_flag) ;
2902
+ std:: thread:: spawn( move || {
2903
+ while !thread_exit. load( Ordering :: Acquire ) {
2904
+ thread_cfg[ 0 ] . chan_handler. message_fetch_counter. store( 0 , Ordering :: Release ) ;
2905
+ thread_peer. process_events( ) ;
2906
+ std:: thread:: sleep( Duration :: from_micros( 1 ) ) ;
2907
+ }
2908
+ } )
2909
+ } } }
2910
+
2911
+ let thread_a = spawn_thread ! ( ) ;
2912
+ let thread_b = spawn_thread ! ( ) ;
2913
+ let thread_c = spawn_thread ! ( ) ;
2914
+
2915
+ let start_time = Instant :: now ( ) ;
2916
+ while start_time. elapsed ( ) < Duration :: from_millis ( 100 ) {
2917
+ let val = cfg[ 0 ] . chan_handler . message_fetch_counter . load ( Ordering :: Acquire ) ;
2918
+ assert ! ( val <= 2 ) ;
2919
+ }
2920
+
2921
+ exit_flag. store ( true , Ordering :: Release ) ;
2922
+ thread_a. join ( ) . unwrap ( ) ;
2923
+ thread_b. join ( ) . unwrap ( ) ;
2924
+ thread_c. join ( ) . unwrap ( ) ;
2925
+ assert ! ( cfg[ 0 ] . chan_handler. message_fetch_counter. load( Ordering :: Acquire ) >= 1 ) ;
2926
+ }
2879
2927
}
0 commit comments