@@ -251,7 +251,15 @@ pub async fn store_update(state: &State, update: Update) -> Result<()> {
251
251
252
252
// Once the accumulator reaches a complete state for a specific slot
253
253
// we can build the message states
254
- build_message_states ( state, accumulator_messages, wormhole_merkle_state) . await ?;
254
+ let message_states = build_message_states ( accumulator_messages, wormhole_merkle_state) ?;
255
+
256
+ let message_state_keys = message_states
257
+ . iter ( )
258
+ . map ( |message_state| message_state. key ( ) )
259
+ . collect :: < HashSet < _ > > ( ) ;
260
+
261
+ tracing:: info!( len = message_states. len( ) , "Storing Message States." ) ;
262
+ state. store_message_states ( message_states) . await ?;
255
263
256
264
// Update the aggregate state
257
265
let mut aggregate_state = state. aggregate_state . write ( ) . await ;
@@ -266,6 +274,7 @@ pub async fn store_update(state: &State, update: Update) -> Result<()> {
266
274
. await ?;
267
275
}
268
276
Some ( latest) if slot > latest => {
277
+ state. prune_removed_keys ( message_state_keys) . await ;
269
278
aggregate_state. latest_completed_slot . replace ( slot) ;
270
279
state
271
280
. api_update_tx
@@ -296,18 +305,17 @@ pub async fn store_update(state: &State, update: Update) -> Result<()> {
296
305
Ok ( ( ) )
297
306
}
298
307
299
- #[ tracing:: instrument( skip( state, accumulator_messages, wormhole_merkle_state) ) ]
300
- async fn build_message_states (
301
- state : & State ,
308
+ #[ tracing:: instrument( skip( accumulator_messages, wormhole_merkle_state) ) ]
309
+ fn build_message_states (
302
310
accumulator_messages : AccumulatorMessages ,
303
311
wormhole_merkle_state : WormholeMerkleState ,
304
- ) -> Result < ( ) > {
312
+ ) -> Result < Vec < MessageState > > {
305
313
let wormhole_merkle_message_states_proofs =
306
314
construct_message_states_proofs ( & accumulator_messages, & wormhole_merkle_state) ?;
307
315
308
316
let current_time: UnixTimestamp = SystemTime :: now ( ) . duration_since ( UNIX_EPOCH ) ?. as_secs ( ) as _ ;
309
317
310
- let message_states = accumulator_messages
318
+ accumulator_messages
311
319
. raw_messages
312
320
. into_iter ( )
313
321
. enumerate ( )
@@ -326,13 +334,7 @@ async fn build_message_states(
326
334
current_time,
327
335
) )
328
336
} )
329
- . collect :: < Result < Vec < _ > > > ( ) ?;
330
-
331
- tracing:: info!( len = message_states. len( ) , "Storing Message States." ) ;
332
-
333
- state. store_message_states ( message_states) . await ?;
334
-
335
- Ok ( ( ) )
337
+ . collect :: < Result < Vec < _ > > > ( )
336
338
}
337
339
338
340
async fn get_verified_price_feeds < S > (
@@ -677,6 +679,87 @@ mod test {
677
679
}
678
680
}
679
681
682
+ /// On this test we will initially have two price feeds. Then we will send an update with only
683
+ /// price feed 1 (without price feed 2) and make sure that price feed 2 is not stored anymore.
684
+ #[ tokio:: test]
685
+ pub async fn test_getting_price_ids_works_fine_after_price_removal ( ) {
686
+ let ( state, mut update_rx) = setup_state ( 10 ) . await ;
687
+
688
+ let price_feed_1 = create_dummy_price_feed_message ( 100 , 10 , 9 ) ;
689
+ let price_feed_2 = create_dummy_price_feed_message ( 200 , 10 , 9 ) ;
690
+
691
+ // Populate the state
692
+ store_multiple_concurrent_valid_updates (
693
+ state. clone ( ) ,
694
+ generate_update (
695
+ vec ! [
696
+ Message :: PriceFeedMessage ( price_feed_1) ,
697
+ Message :: PriceFeedMessage ( price_feed_2) ,
698
+ ] ,
699
+ 10 ,
700
+ 20 ,
701
+ ) ,
702
+ )
703
+ . await ;
704
+
705
+ // Check that the update_rx channel has received a message
706
+ assert_eq ! (
707
+ update_rx. recv( ) . await ,
708
+ Some ( AggregationEvent :: New { slot: 10 } )
709
+ ) ;
710
+
711
+ // Check the price ids are stored correctly
712
+ assert_eq ! (
713
+ get_price_feed_ids( & * state) . await ,
714
+ vec![
715
+ PriceIdentifier :: new( [ 100 ; 32 ] ) ,
716
+ PriceIdentifier :: new( [ 200 ; 32 ] )
717
+ ]
718
+ . into_iter( )
719
+ . collect( )
720
+ ) ;
721
+
722
+ // Check that price feed 2 exists
723
+ assert ! ( get_price_feeds_with_update_data(
724
+ & * state,
725
+ & [ PriceIdentifier :: new( [ 200 ; 32 ] ) ] ,
726
+ RequestTime :: Latest ,
727
+ )
728
+ . await
729
+ . is_ok( ) ) ;
730
+
731
+ // Now send an update with only price feed 1 (without price feed 2)
732
+ // and make sure that price feed 2 is not stored anymore.
733
+ let price_feed_1 = create_dummy_price_feed_message ( 100 , 12 , 10 ) ;
734
+
735
+ // Populate the state
736
+ store_multiple_concurrent_valid_updates (
737
+ state. clone ( ) ,
738
+ generate_update ( vec ! [ Message :: PriceFeedMessage ( price_feed_1) ] , 15 , 30 ) ,
739
+ )
740
+ . await ;
741
+
742
+ // Check that the update_rx channel has received a message
743
+ assert_eq ! (
744
+ update_rx. recv( ) . await ,
745
+ Some ( AggregationEvent :: New { slot: 15 } )
746
+ ) ;
747
+
748
+ // Check that price feed 2 does not exist anymore
749
+ assert_eq ! (
750
+ get_price_feed_ids( & * state) . await ,
751
+ vec![ PriceIdentifier :: new( [ 100 ; 32 ] ) , ] . into_iter( ) . collect( )
752
+ ) ;
753
+
754
+ assert ! ( get_price_feeds_with_update_data(
755
+ & * state,
756
+ & [ PriceIdentifier :: new( [ 200 ; 32 ] ) ] ,
757
+ RequestTime :: Latest ,
758
+ )
759
+ . await
760
+ . is_err( ) ) ;
761
+ }
762
+
680
763
#[ tokio:: test]
681
764
pub async fn test_metadata_times_and_readiness_work ( ) {
682
765
// The receiver channel should stay open for the state to work
0 commit comments