1313// limitations under the License.
1414
1515use futures:: { Future , FutureExt } ;
16- use std:: { fmt:: Debug , panic :: AssertUnwindSafe , pin:: Pin } ;
16+ use std:: { fmt:: Debug , pin:: Pin } ;
1717use thiserror:: Error ;
1818use tokio:: task:: { JoinError , JoinSet } ;
1919
2020/// Boxed async operation returning Result<R, E>.
2121pub type Operation < R , E > =
2222 Box < dyn FnOnce ( ) -> Pin < Box < dyn Future < Output = Result < R , E > > + Send > > + Send > ;
2323
24- type TaskResult < R , E > = ( usize , Result < Result < R , E > , ( ) > ) ;
25-
2624/// Helper to box an operation ergonomically.
2725/// usage: op(|| async { /* ... */ -> Result<R, E> })
2826pub fn op < Fut , F , R , E > ( f : F ) -> Operation < R , E >
@@ -47,19 +45,21 @@ pub enum TxnError<E: std::error::Error + Send + Sync + 'static> {
4745 #[ source]
4846 source : JoinError ,
4947 } ,
50- #[ error( "task {index} panicked" ) ]
48+ #[ error( "task {index} panicked during execution " ) ]
5149 Panic { index : usize } ,
50+ #[ error( "internal error: missing result for operation {index}" ) ]
51+ MissingResult { index : usize } ,
5252}
5353
54- /// Run all operations concurrently and wait for all to complete.
54+ /// Run N-1 operations concurrently, then run the last operation after they complete.
5555/// If any operation fails:
5656/// 1) wait for all operations to finish
5757/// 2) call `rollback` with all indices that completed successfully
5858/// 3) return the first failure encountered
5959///
6060/// If all succeed, returns results in the original order.
6161pub async fn run_fail_end < R , E , Rollback , RFut > (
62- mut operations : Vec < Operation < R , E > > ,
62+ operations : Vec < Operation < R , E > > ,
6363 rollback : Rollback ,
6464) -> Result < Vec < R > , TxnError < E > >
6565where
@@ -69,23 +69,59 @@ where
6969 RFut : Future < Output = ( ) > + Send + ' static ,
7070{
7171 let n = operations. len ( ) ;
72- assert ! ( n >= 1 , "need at least one operation (the final)" ) ;
7372
74- let final_op = operations. pop ( ) . expect ( "len>=1 ensured" ) ;
73+ if n == 0 {
74+ return Ok ( Vec :: new ( ) ) ;
75+ }
76+ if n == 1 {
77+ if let Some ( op) = operations. into_iter ( ) . next ( ) {
78+ match op ( ) . await {
79+ Ok ( result) => return Ok ( vec ! [ result] ) ,
80+ Err ( e) => {
81+ rollback ( vec ! [ ] ) . await ;
82+ return Err ( TxnError :: Operation {
83+ index : 0 ,
84+ source : e,
85+ } ) ;
86+ }
87+ }
88+ } else {
89+ return Ok ( Vec :: new ( ) ) ;
90+ }
91+ }
92+
93+ // Split operations: first N-1 and the last one
94+ let mut operations = operations;
95+ let last_operation = match operations. pop ( ) {
96+ Some ( op) => op,
97+ None => {
98+ return Ok ( Vec :: new ( ) ) ;
99+ }
100+ } ;
101+ let last_index = n - 1 ;
75102
76- let mut set: JoinSet < TaskResult < R , E > > = JoinSet :: new ( ) ;
77- let mut successes: Vec < usize > = Vec :: with_capacity ( n - 1 ) ;
78- let mut results: Vec < Option < R > > = std :: iter :: repeat_with ( || None ) . take ( n ) . collect ( ) ;
103+ let mut set = JoinSet :: new ( ) ;
104+ let mut successes: Vec < usize > = Vec :: with_capacity ( n) ;
105+ let mut results: Vec < Option < R > > = Vec :: with_capacity ( n ) ;
79106 let mut first_error: Option < TxnError < E > > = None ;
80107
108+ for _ in 0 ..n {
109+ results. push ( None ) ;
110+ }
111+
112+ // Phase 1: Run first N-1 operations in parallel
81113 for ( i, op) in operations. into_iter ( ) . enumerate ( ) {
82114 set. spawn ( async move {
83- let fut = op ( ) ;
84- let caught = AssertUnwindSafe ( fut) . catch_unwind ( ) . await ;
85- ( i, caught. map_err ( |_| ( ) ) )
115+ let result = std:: panic:: AssertUnwindSafe ( op ( ) ) . catch_unwind ( ) . await ;
116+
117+ match result {
118+ Ok ( operation_result) => ( i, Ok ( operation_result) ) ,
119+ Err ( _panic_payload) => ( i, Err ( ( ) ) ) ,
120+ }
86121 } ) ;
87122 }
88123
124+ // Wait for all N-1 operations to complete
89125 while let Some ( joined) = set. join_next ( ) . await {
90126 match joined {
91127 Ok ( ( i, Ok ( Ok ( val) ) ) ) => {
@@ -116,40 +152,37 @@ where
116152 }
117153 }
118154
119- // If any parallel op failed, rollback successes and return error. Final op won't run.
120155 if let Some ( err) = first_error {
121- successes. sort_unstable ( ) ;
122- successes. reverse ( ) ;
123156 rollback ( successes) . await ;
124157 return Err ( err) ;
125158 }
126159
127- // All pre-final succeeded -> run final op now
128- let final_idx = n - 1 ;
129- let final_outcome = AssertUnwindSafe ( ( final_op) ( ) ) . catch_unwind ( ) . await ;
160+ // Phase 2: All N-1 operations succeeded, now run the last operation
161+ match last_operation ( ) . await {
162+ Ok ( last_result) => {
163+ results[ last_index] = Some ( last_result) ;
164+ successes. push ( last_index) ;
130165
131- match final_outcome {
132- Ok ( Ok ( val) ) => {
133- results[ final_idx] = Some ( val) ;
134- Ok ( results
135- . into_iter ( )
136- . map ( |o| o. expect ( "missing result despite success" ) )
137- . collect ( ) )
166+ // All succeeded
167+ let mut final_results = Vec :: with_capacity ( n) ;
168+ for ( i, result_opt) in results. into_iter ( ) . enumerate ( ) {
169+ match result_opt {
170+ Some ( result) => final_results. push ( result) ,
171+ None => {
172+ rollback ( successes) . await ;
173+ return Err ( TxnError :: MissingResult { index : i } ) ;
174+ }
175+ }
176+ }
177+ Ok ( final_results)
138178 }
139- Ok ( Err ( e) ) => {
140- successes. sort_unstable ( ) ;
141- successes. reverse ( ) ;
179+ Err ( e) => {
180+ // Last operation failed, rollback all successful operations
142181 rollback ( successes) . await ;
143182 Err ( TxnError :: Operation {
144- index : final_idx ,
183+ index : last_index ,
145184 source : e,
146185 } )
147186 }
148- Err ( _) => {
149- successes. sort_unstable ( ) ;
150- successes. reverse ( ) ;
151- rollback ( successes) . await ;
152- Err ( TxnError :: Panic { index : final_idx } )
153- }
154187 }
155188}
0 commit comments