1
+ // Copyright 2022 Alibaba Cloud. All rights reserved.
1
2
// Copyright (c) 2020 Ant Financial
2
3
//
3
4
// SPDX-License-Identifier: Apache-2.0
6
7
use std:: collections:: HashMap ;
7
8
use std:: convert:: TryInto ;
8
9
use std:: os:: unix:: io:: RawFd ;
10
+ use std:: sync:: atomic:: { AtomicU32 , Ordering } ;
9
11
use std:: sync:: { Arc , Mutex } ;
10
12
11
13
use async_trait:: async_trait;
@@ -14,19 +16,23 @@ use tokio::{self, sync::mpsc, task};
14
16
15
17
use crate :: common:: client_connect;
16
18
use crate :: error:: { Error , Result } ;
17
- use crate :: proto:: { Code , Codec , GenMessage , Message , Request , Response , MESSAGE_TYPE_RESPONSE } ;
19
+ use crate :: proto:: {
20
+ Code , Codec , GenMessage , Message , Request , Response , FLAG_REMOTE_CLOSED , FLAG_REMOTE_OPEN ,
21
+ MESSAGE_TYPE_DATA , MESSAGE_TYPE_RESPONSE ,
22
+ } ;
18
23
use crate :: r#async:: connection:: * ;
19
24
use crate :: r#async:: shutdown;
20
- use crate :: r#async:: stream:: { ResultReceiver , ResultSender } ;
25
+ use crate :: r#async:: stream:: {
26
+ Kind , MessageReceiver , MessageSender , ResultReceiver , ResultSender , StreamInner ,
27
+ } ;
21
28
use crate :: r#async:: utils;
22
29
23
- type RequestSender = mpsc:: Sender < ( GenMessage , ResultSender ) > ;
24
- type RequestReceiver = mpsc:: Receiver < ( GenMessage , ResultSender ) > ;
25
-
26
30
/// A ttrpc Client (async).
27
31
#[ derive( Clone ) ]
28
32
pub struct Client {
29
- req_tx : RequestSender ,
33
+ req_tx : MessageSender ,
34
+ next_stream_id : Arc < AtomicU32 > ,
35
+ streams : Arc < Mutex < HashMap < u32 , ResultSender > > > ,
30
36
}
31
37
32
38
impl Client {
@@ -39,26 +45,40 @@ impl Client {
39
45
pub fn new ( fd : RawFd ) -> Client {
40
46
let stream = utils:: new_unix_stream_from_raw_fd ( fd) ;
41
47
42
- let ( req_tx, rx) : ( RequestSender , RequestReceiver ) = mpsc:: channel ( 100 ) ;
48
+ let ( req_tx, rx) : ( MessageSender , MessageReceiver ) = mpsc:: channel ( 100 ) ;
43
49
44
- let delegate = ClientBuilder { rx : Some ( rx) } ;
50
+ let req_map = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
51
+ let delegate = ClientBuilder {
52
+ rx : Some ( rx) ,
53
+ streams : req_map. clone ( ) ,
54
+ } ;
45
55
46
56
let conn = Connection :: new ( stream, delegate) ;
47
57
tokio:: spawn ( async move { conn. run ( ) . await } ) ;
48
58
49
- Client { req_tx }
59
+ Client {
60
+ req_tx,
61
+ next_stream_id : Arc :: new ( AtomicU32 :: new ( 1 ) ) ,
62
+ streams : req_map,
63
+ }
50
64
}
51
65
52
66
/// Requsts a unary request and returns with response.
53
67
pub async fn request ( & self , req : Request ) -> Result < Response > {
54
68
let timeout_nano = req. timeout_nano ;
55
- let msg: GenMessage = Message :: new_request ( 0 , req)
69
+ let stream_id = self . next_stream_id . fetch_add ( 2 , Ordering :: Relaxed ) ;
70
+
71
+ let msg: GenMessage = Message :: new_request ( stream_id, req)
56
72
. try_into ( )
57
73
. map_err ( |e : protobuf:: error:: ProtobufError | Error :: Others ( e. to_string ( ) ) ) ?;
58
74
59
75
let ( tx, mut rx) : ( ResultSender , ResultReceiver ) = mpsc:: channel ( 100 ) ;
76
+
77
+ // TODO: check return.
78
+ self . streams . lock ( ) . unwrap ( ) . insert ( stream_id, tx) ;
79
+
60
80
self . req_tx
61
- . send ( ( msg, tx ) )
81
+ . send ( msg)
62
82
. await
63
83
. map_err ( |e| Error :: Others ( format ! ( "Send packet to sender error {:?}" , e) ) ) ?;
64
84
@@ -87,6 +107,44 @@ impl Client {
87
107
88
108
Ok ( res)
89
109
}
110
+
111
+ /// Creates a StreamInner instance.
112
+ pub async fn new_stream (
113
+ & self ,
114
+ req : Request ,
115
+ streaming_client : bool ,
116
+ streaming_server : bool ,
117
+ ) -> Result < StreamInner > {
118
+ let stream_id = self . next_stream_id . fetch_add ( 2 , Ordering :: Relaxed ) ;
119
+
120
+ let mut msg: GenMessage = Message :: new_request ( stream_id, req)
121
+ . try_into ( )
122
+ . map_err ( |e : protobuf:: error:: ProtobufError | Error :: Others ( e. to_string ( ) ) ) ?;
123
+
124
+ if streaming_client {
125
+ msg. header . add_flags ( FLAG_REMOTE_OPEN ) ;
126
+ } else {
127
+ msg. header . add_flags ( FLAG_REMOTE_CLOSED ) ;
128
+ }
129
+
130
+ let ( tx, rx) : ( ResultSender , ResultReceiver ) = mpsc:: channel ( 100 ) ;
131
+ // TODO: check return
132
+ self . streams . lock ( ) . unwrap ( ) . insert ( stream_id, tx) ;
133
+ self . req_tx
134
+ . send ( msg)
135
+ . await
136
+ . map_err ( |e| Error :: Others ( format ! ( "Send packet to sender error {:?}" , e) ) ) ?;
137
+
138
+ Ok ( StreamInner :: new (
139
+ stream_id,
140
+ self . req_tx . clone ( ) ,
141
+ rx,
142
+ streaming_client,
143
+ streaming_server,
144
+ Kind :: Client ,
145
+ self . streams . clone ( ) ,
146
+ ) )
147
+ }
90
148
}
91
149
92
150
struct ClientClose {
@@ -104,7 +162,8 @@ impl Drop for ClientClose {
104
162
105
163
#[ derive( Debug ) ]
106
164
struct ClientBuilder {
107
- rx : Option < RequestReceiver > ,
165
+ rx : Option < MessageReceiver > ,
166
+ streams : Arc < Mutex < HashMap < u32 , ResultSender > > > ,
108
167
}
109
168
110
169
impl Builder for ClientBuilder {
@@ -113,52 +172,43 @@ impl Builder for ClientBuilder {
113
172
114
173
fn build ( & mut self ) -> ( Self :: Reader , Self :: Writer ) {
115
174
let ( notifier, waiter) = shutdown:: new ( ) ;
116
- let req_map = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
117
175
(
118
176
ClientReader {
119
177
shutdown_waiter : waiter,
120
- req_map : req_map . clone ( ) ,
178
+ streams : self . streams . clone ( ) ,
121
179
} ,
122
180
ClientWriter {
123
- stream_id : 1 ,
124
181
rx : self . rx . take ( ) . unwrap ( ) ,
125
182
shutdown_notifier : notifier,
126
- req_map,
183
+
184
+ streams : self . streams . clone ( ) ,
127
185
} ,
128
186
)
129
187
}
130
188
}
131
189
132
190
struct ClientWriter {
133
- stream_id : u32 ,
134
- rx : RequestReceiver ,
191
+ rx : MessageReceiver ,
135
192
shutdown_notifier : shutdown:: Notifier ,
136
- req_map : Arc < Mutex < HashMap < u32 , ResultSender > > > ,
193
+
194
+ streams : Arc < Mutex < HashMap < u32 , ResultSender > > > ,
137
195
}
138
196
139
197
#[ async_trait]
140
198
impl WriterDelegate for ClientWriter {
141
199
async fn recv ( & mut self ) -> Option < GenMessage > {
142
- if let Some ( ( mut msg, resp_tx) ) = self . rx . recv ( ) . await {
143
- let current_stream_id = self . stream_id ;
144
- msg. header . set_stream_id ( current_stream_id) ;
145
- self . stream_id += 2 ;
146
- {
147
- let mut map = self . req_map . lock ( ) . unwrap ( ) ;
148
- map. insert ( current_stream_id, resp_tx) ;
149
- }
150
- return Some ( msg) ;
151
- } else {
152
- return None ;
153
- }
200
+ self . rx . recv ( ) . await
154
201
}
155
202
156
203
async fn disconnect ( & self , msg : & GenMessage , e : Error ) {
204
+ // TODO:
205
+ // At this point, a new request may have been received.
157
206
let resp_tx = {
158
- let mut map = self . req_map . lock ( ) . unwrap ( ) ;
207
+ let mut map = self . streams . lock ( ) . unwrap ( ) ;
159
208
map. remove ( & msg. header . stream_id )
160
209
} ;
161
210
211
+ // TODO: if None
162
212
if let Some ( resp_tx) = resp_tx {
163
213
let e = Error :: Socket ( format ! ( "{:?}" , e) ) ;
164
214
resp_tx
@@ -174,8 +224,8 @@ impl WriterDelegate for ClientWriter {
174
224
}
175
225
176
226
struct ClientReader {
227
+ streams : Arc < Mutex < HashMap < u32 , ResultSender > > > ,
177
228
shutdown_waiter : shutdown:: Waiter ,
178
- req_map : Arc < Mutex < HashMap < u32 , ResultSender > > > ,
179
229
}
180
230
181
231
#[ async_trait]
@@ -191,8 +241,8 @@ impl ReaderDelegate for ClientReader {
191
241
let _ = sender. await ;
192
242
193
243
// Take all items out of `req_map`.
194
- let mut map = std:: mem:: take ( & mut * self . req_map . lock ( ) . unwrap ( ) ) ;
195
- // Terminate outstanding RPC requests with the error.
244
+ let mut map = std:: mem:: take ( & mut * self . streams . lock ( ) . unwrap ( ) ) ;
245
+ // Terminate undone RPC requests with the error.
196
246
for ( _stream_id, resp_tx) in map. drain ( ) {
197
247
if let Err ( _e) = resp_tx. send ( Err ( e. clone ( ) ) ) . await {
198
248
warn ! ( "Failed to terminate pending RPC: the request has returned" ) ;
@@ -203,35 +253,56 @@ impl ReaderDelegate for ClientReader {
203
253
async fn exit ( & self ) { }
204
254
205
255
async fn handle_msg ( & self , msg : GenMessage ) {
206
- let req_map = self . req_map . clone ( ) ;
256
+ let req_map = self . streams . clone ( ) ;
207
257
tokio:: spawn ( async move {
208
- let resp_tx2 ;
209
- {
210
- let mut map = req_map. lock ( ) . unwrap ( ) ;
211
- let resp_tx = match map . get ( & msg . header . stream_id ) {
212
- Some ( tx ) => tx ,
213
- None => {
214
- debug ! ( "Receiver got unknown packet {:?}" , msg ) ;
215
- return ;
258
+ let resp_tx = match msg . header . type_ {
259
+ MESSAGE_TYPE_RESPONSE => {
260
+ match req_map. lock ( ) . unwrap ( ) . remove ( & msg . header . stream_id ) {
261
+ Some ( tx ) => tx ,
262
+ None => {
263
+ debug ! ( "Receiver got unknown response packet {:?}" , msg ) ;
264
+ return ;
265
+ }
216
266
}
217
- } ;
218
-
219
- resp_tx2 = resp_tx. clone ( ) ;
220
- map. remove ( & msg. header . stream_id ) ; // Forget the result, just remove.
221
- }
222
-
223
- if msg. header . type_ != MESSAGE_TYPE_RESPONSE {
224
- resp_tx2
225
- . send ( Err ( Error :: Others ( format ! (
226
- "Recver got malformed packet {:?}" ,
227
- msg
228
- ) ) ) )
229
- . await
230
- . unwrap_or_else ( |_e| error ! ( "The request has returned" ) ) ;
231
- return ;
232
- }
233
-
234
- resp_tx2
267
+ }
268
+ MESSAGE_TYPE_DATA => {
269
+ if ( msg. header . flags & FLAG_REMOTE_CLOSED ) == FLAG_REMOTE_CLOSED {
270
+ match req_map. lock ( ) . unwrap ( ) . remove ( & msg. header . stream_id ) {
271
+ Some ( tx) => tx. clone ( ) ,
272
+ None => {
273
+ debug ! ( "Receiver got unknown data packet {:?}" , msg) ;
274
+ return ;
275
+ }
276
+ }
277
+ } else {
278
+ match req_map. lock ( ) . unwrap ( ) . get ( & msg. header . stream_id ) {
279
+ Some ( tx) => tx. clone ( ) ,
280
+ None => {
281
+ debug ! ( "Receiver got unknown data packet {:?}" , msg) ;
282
+ return ;
283
+ }
284
+ }
285
+ }
286
+ }
287
+ _ => {
288
+ let resp_tx = match req_map. lock ( ) . unwrap ( ) . remove ( & msg. header . stream_id ) {
289
+ Some ( tx) => tx,
290
+ None => {
291
+ debug ! ( "Receiver got unknown packet {:?}" , msg) ;
292
+ return ;
293
+ }
294
+ } ;
295
+ resp_tx
296
+ . send ( Err ( Error :: Others ( format ! (
297
+ "Recver got malformed packet {:?}" ,
298
+ msg
299
+ ) ) ) )
300
+ . await
301
+ . unwrap_or_else ( |_e| error ! ( "The request has returned" ) ) ;
302
+ return ;
303
+ }
304
+ } ;
305
+ resp_tx
235
306
. send ( Ok ( msg) )
236
307
. await
237
308
. unwrap_or_else ( |_e| error ! ( "The request has returned" ) ) ;
0 commit comments