@@ -321,12 +321,20 @@ func (d *dummyStateSpent) IsTerminal() bool {
321
321
return true
322
322
}
323
323
324
- func assertState [Event any , Env Environment ](t * testing.T ,
325
- m * StateMachine [Event , Env ], expectedState State [Event , Env ]) {
324
+ // assertState asserts that the state machine is currently in the expected
325
+ // state type and returns the state cast to that type.
326
+ func assertState [Event any , Env Environment , S State [Event , Env ]](t * testing.T ,
327
+ m * StateMachine [Event , Env ], expectedState S ) S {
326
328
327
329
state , err := m .CurrentState ()
328
330
require .NoError (t , err )
329
331
require .IsType (t , expectedState , state )
332
+
333
+ // Perform the type assertion to return the concrete type.
334
+ concreteState , ok := state .(S )
335
+ require .True (t , ok , "state type assertion failed" )
336
+
337
+ return concreteState
330
338
}
331
339
332
340
func assertStateTransitions [Event any , Env Environment ](
@@ -639,18 +647,15 @@ func TestStateMachineConfMapper(t *testing.T) {
639
647
assertStateTransitions (t , stateSub , expectedStates )
640
648
641
649
// Final state assertion.
642
- finalState , err := stateMachine .CurrentState ()
643
- require .NoError (t , err )
644
- require .IsType (t , & dummyStateConfirmed {}, finalState )
650
+ finalState := assertState (t , & stateMachine , & dummyStateConfirmed {})
645
651
646
652
// Assert that the details from the confirmation event were correctly
647
653
// propagated to the final state.
648
- finalStateDetails := finalState .(* dummyStateConfirmed )
649
654
require .Equal (t ,
650
- * simulatedConf .BlockHash , finalStateDetails .blockHash ,
655
+ * simulatedConf .BlockHash , finalState .blockHash ,
651
656
)
652
657
require .Equal (t ,
653
- simulatedConf .BlockHeight , finalStateDetails .blockHeight ,
658
+ simulatedConf .BlockHeight , finalState .blockHeight ,
654
659
)
655
660
656
661
adapters .AssertExpectations (t )
@@ -719,18 +724,15 @@ func TestStateMachineSpendMapper(t *testing.T) {
719
724
assertStateTransitions (t , stateSub , expectedStates )
720
725
721
726
// Final state assertion.
722
- finalState , err := stateMachine .CurrentState ()
723
- require .NoError (t , err )
724
- require .IsType (t , & dummyStateSpent {}, finalState )
727
+ finalState := assertState (t , & stateMachine , & dummyStateSpent {})
725
728
726
729
// Assert that the details from the spend event were correctly
727
730
// propagated to the final state.
728
- finalStateDetails := finalState .(* dummyStateSpent )
729
731
require .Equal (t ,
730
- * simulatedSpend .SpenderTxHash , finalStateDetails .spenderTxHash ,
732
+ * simulatedSpend .SpenderTxHash , finalState .spenderTxHash ,
731
733
)
732
734
require .Equal (t ,
733
- simulatedSpend .SpendingHeight , finalStateDetails .spendingHeight ,
735
+ simulatedSpend .SpendingHeight , finalState .spendingHeight ,
734
736
)
735
737
736
738
adapters .AssertExpectations (t )
0 commit comments