1
+ use std:: fmt;
2
+
1
3
use chrono:: { DateTime , Utc } ;
4
+ use futures:: future:: { join, join_all} ;
2
5
use primitives:: {
3
6
balances:: { Balances , CheckedState } ,
4
7
Address , ChannelId , UnifiedNum ,
@@ -11,7 +14,7 @@ use tokio_postgres::{
11
14
use super :: { DbPool , PoolError } ;
12
15
use thiserror:: Error ;
13
16
14
- static UPDATE_ACCOUNTING_STATEMENT : & str = "INSERT INTO accounting(channel_id, side, address, amount, updated, created) VALUES($1, $2, $3, $4, $5, $6) ON CONFLICT ON CONSTRAINT accounting_pkey DO UPDATE SET amount = accounting.amount + $4 , updated = $6 WHERE accounting.channel_id = $1 AND accounting.side = $2 AND accounting.address = $3 RETURNING channel_id, side, address, amount, updated, created" ;
17
+ static UPDATE_ACCOUNTING_STATEMENT : & str = "INSERT INTO accounting(channel_id, side, address, amount, updated, created) VALUES($1, $2, $3, $4, NULL, NOW()) ON CONFLICT ON CONSTRAINT accounting_pkey DO UPDATE SET amount = accounting.amount + EXCLUDED.amount , updated = NOW() WHERE accounting.channel_id = $1 AND accounting.side = $2 AND accounting.address = $3 RETURNING channel_id, side, address, amount, updated, created" ;
15
18
16
19
#[ derive( Debug , Error ) ]
17
20
pub enum Error {
@@ -57,12 +60,16 @@ pub enum Side {
57
60
Spender ,
58
61
}
59
62
60
- pub enum SpendError {
61
- Pool ( PoolError ) ,
62
- NoRecordsUpdated ,
63
+ impl fmt:: Display for Side {
64
+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
65
+ match self {
66
+ Side :: Earner => write ! ( f, "Earner" ) ,
67
+ Side :: Spender => write ! ( f, "Spender" ) ,
68
+ }
69
+ }
63
70
}
64
71
65
- /// ```text
72
+ /// ```sql
66
73
/// SELECT channel_id, side, address, amount, updated, created FROM accounting WHERE channel_id = $1 AND address = $2 AND side = $3
67
74
/// ```
68
75
pub async fn get_accounting (
@@ -110,14 +117,8 @@ pub async fn update_accounting(
110
117
let client = pool. get ( ) . await ?;
111
118
let statement = client. prepare ( UPDATE_ACCOUNTING_STATEMENT ) . await ?;
112
119
113
- let now = Utc :: now ( ) ;
114
- let updated: Option < DateTime < Utc > > = None ;
115
-
116
120
let row = client
117
- . query_one (
118
- & statement,
119
- & [ & channel_id, & side, & address, & amount, & updated, & now] ,
120
- )
121
+ . query_one ( & statement, & [ & channel_id, & side, & address, & amount] )
121
122
. await ?;
122
123
123
124
Ok ( Accounting :: from ( & row) )
@@ -126,52 +127,63 @@ pub async fn update_accounting(
126
127
/// `delta_balances` defines the Balances that need to be added to the spending or earnings of the `Accounting`s.
127
128
/// It will **not** override the whole `Accounting` value
128
129
/// Returns a tuple of `(Vec<Earners Accounting>, Vec<Spenders Accounting>)`
130
+ ///
131
+ /// # Error
132
+ ///
133
+ /// It will return an error if any of the updates fails but it would have updated the rest of them.
134
+ ///
135
+ /// This way we ensure that even if a single or multiple Accounting updates fail,
136
+ /// we will still pay out the rest of them.
129
137
pub async fn spend_amount (
130
138
pool : DbPool ,
131
139
channel_id : ChannelId ,
132
140
delta_balances : Balances < CheckedState > ,
133
141
) -> Result < ( Vec < Accounting > , Vec < Accounting > ) , PoolError > {
134
- let client = pool. get ( ) . await ?;
135
-
136
- let statement = client. prepare ( UPDATE_ACCOUNTING_STATEMENT ) . await ?;
142
+ let client = & pool. get ( ) . await ?;
137
143
138
- let now = Utc :: now ( ) ;
139
- let updated: Option < DateTime < Utc > > = None ;
140
-
141
- let ( mut earners, mut spenders) = ( vec ! [ ] , vec ! [ ] ) ;
144
+ let statement = client. prepare_cached ( UPDATE_ACCOUNTING_STATEMENT ) . await ?;
142
145
143
146
// Earners
144
- for ( earner, amount) in delta_balances. earners {
145
- let row = client
146
- . query_one (
147
- & statement,
148
- & [ & channel_id, & Side :: Earner , & earner, & amount, & updated, & now] ,
149
- )
150
- . await ?;
147
+ let earners_futures = delta_balances. earners . into_iter ( ) . map ( |( earner, amount) | {
148
+ let statement = statement. clone ( ) ;
151
149
152
- earners. push ( Accounting :: from ( & row) )
153
- }
150
+ async move {
151
+ client
152
+ . query_one ( & statement, & [ & channel_id, & Side :: Earner , & earner, & amount] )
153
+ . await
154
+ . map ( |row| Accounting :: from ( & row) )
155
+ }
156
+ } ) ;
154
157
155
158
// Spenders
156
- for ( spender, amount) in delta_balances. spenders {
157
- let row = client
158
- . query_one (
159
- & statement,
160
- & [
161
- & channel_id,
162
- & Side :: Spender ,
163
- & spender,
164
- & amount,
165
- & updated,
166
- & now,
167
- ] ,
168
- )
169
- . await ?;
159
+ let spenders_futures = delta_balances
160
+ . spenders
161
+ . into_iter ( )
162
+ . map ( |( spender, amount) | {
163
+ let statement = statement. clone ( ) ;
164
+
165
+ async move {
166
+ client
167
+ . query_one (
168
+ & statement,
169
+ & [ & channel_id, & Side :: Spender , & spender, & amount] ,
170
+ )
171
+ . await
172
+ . map ( |row| Accounting :: from ( & row) )
173
+ }
174
+ } ) ;
170
175
171
- spenders. push ( Accounting :: from ( & row) )
172
- }
176
+ let earners = join_all ( earners_futures) ;
177
+ let spenders = join_all ( spenders_futures) ;
178
+
179
+ // collect all the Accounting updates into Vectors
180
+ let ( earners, spenders) = join ( earners, spenders) . await ;
173
181
174
- Ok ( ( earners, spenders) )
182
+ // Return an error if any of the Accounting updates failed
183
+ Ok ( (
184
+ earners. into_iter ( ) . collect :: < Result < _ , _ > > ( ) ?,
185
+ spenders. into_iter ( ) . collect :: < Result < _ , _ > > ( ) ?,
186
+ ) )
175
187
}
176
188
177
189
#[ cfg( test) ]
@@ -531,7 +543,7 @@ mod test {
531
543
earners_acc
532
544
. iter ( )
533
545
. find ( |a| a. address == earner)
534
- . unwrap ( )
546
+ . expect ( "Should find Accounting" )
535
547
. clone ( ) ,
536
548
false ,
537
549
) ;
@@ -540,7 +552,7 @@ mod test {
540
552
earners_acc
541
553
. iter ( )
542
554
. find ( |a| a. address == other_earner)
543
- . unwrap ( )
555
+ . expect ( "Should find Accounting" )
544
556
. clone ( ) ,
545
557
false ,
546
558
) ;
@@ -551,7 +563,7 @@ mod test {
551
563
spenders_acc
552
564
. iter ( )
553
565
. find ( |a| a. address == spender)
554
- . unwrap ( )
566
+ . expect ( "Should find Accounting" )
555
567
. clone ( ) ,
556
568
false ,
557
569
) ;
@@ -560,7 +572,7 @@ mod test {
560
572
spenders_acc
561
573
. iter ( )
562
574
. find ( |a| a. address == other_spender)
563
- . unwrap ( )
575
+ . expect ( "Should find Accounting" )
564
576
. clone ( ) ,
565
577
false ,
566
578
) ;
0 commit comments