Skip to content

Commit 1126e43

Browse files
committed
async: add streaming support for client.
Added streaming support for client-side. Signed-off-by: wanglei01 <wllenyj@linux.alibaba.com>
1 parent 314ec95 commit 1126e43

File tree

1 file changed

+133
-62
lines changed

1 file changed

+133
-62
lines changed

src/asynchronous/client.rs

Lines changed: 133 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// Copyright 2022 Alibaba Cloud. All rights reserved.
12
// Copyright (c) 2020 Ant Financial
23
//
34
// SPDX-License-Identifier: Apache-2.0
@@ -6,6 +7,7 @@
67
use std::collections::HashMap;
78
use std::convert::TryInto;
89
use std::os::unix::io::RawFd;
10+
use std::sync::atomic::{AtomicU32, Ordering};
911
use std::sync::{Arc, Mutex};
1012

1113
use async_trait::async_trait;
@@ -14,19 +16,23 @@ use tokio::{self, sync::mpsc, task};
1416

1517
use crate::common::client_connect;
1618
use crate::error::{Error, Result};
17-
use crate::proto::{Code, Codec, GenMessage, Message, Request, Response, MESSAGE_TYPE_RESPONSE};
19+
use crate::proto::{
20+
Code, Codec, GenMessage, Message, Request, Response, FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN,
21+
MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE,
22+
};
1823
use crate::r#async::connection::*;
1924
use crate::r#async::shutdown;
20-
use crate::r#async::stream::{ResultReceiver, ResultSender};
25+
use crate::r#async::stream::{
26+
Kind, MessageReceiver, MessageSender, ResultReceiver, ResultSender, StreamInner,
27+
};
2128
use crate::r#async::utils;
2229

23-
type RequestSender = mpsc::Sender<(GenMessage, ResultSender)>;
24-
type RequestReceiver = mpsc::Receiver<(GenMessage, ResultSender)>;
25-
2630
/// A ttrpc Client (async).
2731
#[derive(Clone)]
2832
pub struct Client {
29-
req_tx: RequestSender,
33+
req_tx: MessageSender,
34+
next_stream_id: Arc<AtomicU32>,
35+
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
3036
}
3137

3238
impl Client {
@@ -39,26 +45,40 @@ impl Client {
3945
pub fn new(fd: RawFd) -> Client {
4046
let stream = utils::new_unix_stream_from_raw_fd(fd);
4147

42-
let (req_tx, rx): (RequestSender, RequestReceiver) = mpsc::channel(100);
48+
let (req_tx, rx): (MessageSender, MessageReceiver) = mpsc::channel(100);
4349

44-
let delegate = ClientBuilder { rx: Some(rx) };
50+
let req_map = Arc::new(Mutex::new(HashMap::new()));
51+
let delegate = ClientBuilder {
52+
rx: Some(rx),
53+
streams: req_map.clone(),
54+
};
4555

4656
let conn = Connection::new(stream, delegate);
4757
tokio::spawn(async move { conn.run().await });
4858

49-
Client { req_tx }
59+
Client {
60+
req_tx,
61+
next_stream_id: Arc::new(AtomicU32::new(1)),
62+
streams: req_map,
63+
}
5064
}
5165

5266
/// Requsts a unary request and returns with response.
5367
pub async fn request(&self, req: Request) -> Result<Response> {
5468
let timeout_nano = req.timeout_nano;
55-
let msg: GenMessage = Message::new_request(0, req)
69+
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);
70+
71+
let msg: GenMessage = Message::new_request(stream_id, req)
5672
.try_into()
5773
.map_err(|e: protobuf::error::ProtobufError| Error::Others(e.to_string()))?;
5874

5975
let (tx, mut rx): (ResultSender, ResultReceiver) = mpsc::channel(100);
76+
77+
// TODO: check return.
78+
self.streams.lock().unwrap().insert(stream_id, tx);
79+
6080
self.req_tx
61-
.send((msg, tx))
81+
.send(msg)
6282
.await
6383
.map_err(|e| Error::Others(format!("Send packet to sender error {:?}", e)))?;
6484

@@ -87,6 +107,44 @@ impl Client {
87107

88108
Ok(res)
89109
}
110+
111+
/// Creates a StreamInner instance.
112+
pub async fn new_stream(
113+
&self,
114+
req: Request,
115+
streaming_client: bool,
116+
streaming_server: bool,
117+
) -> Result<StreamInner> {
118+
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);
119+
120+
let mut msg: GenMessage = Message::new_request(stream_id, req)
121+
.try_into()
122+
.map_err(|e: protobuf::error::ProtobufError| Error::Others(e.to_string()))?;
123+
124+
if streaming_client {
125+
msg.header.add_flags(FLAG_REMOTE_OPEN);
126+
} else {
127+
msg.header.add_flags(FLAG_REMOTE_CLOSED);
128+
}
129+
130+
let (tx, rx): (ResultSender, ResultReceiver) = mpsc::channel(100);
131+
// TODO: check return
132+
self.streams.lock().unwrap().insert(stream_id, tx);
133+
self.req_tx
134+
.send(msg)
135+
.await
136+
.map_err(|e| Error::Others(format!("Send packet to sender error {:?}", e)))?;
137+
138+
Ok(StreamInner::new(
139+
stream_id,
140+
self.req_tx.clone(),
141+
rx,
142+
streaming_client,
143+
streaming_server,
144+
Kind::Client,
145+
self.streams.clone(),
146+
))
147+
}
90148
}
91149

92150
struct ClientClose {
@@ -104,7 +162,8 @@ impl Drop for ClientClose {
104162

105163
#[derive(Debug)]
106164
struct ClientBuilder {
107-
rx: Option<RequestReceiver>,
165+
rx: Option<MessageReceiver>,
166+
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
108167
}
109168

110169
impl Builder for ClientBuilder {
@@ -113,52 +172,43 @@ impl Builder for ClientBuilder {
113172

114173
fn build(&mut self) -> (Self::Reader, Self::Writer) {
115174
let (notifier, waiter) = shutdown::new();
116-
let req_map = Arc::new(Mutex::new(HashMap::new()));
117175
(
118176
ClientReader {
119177
shutdown_waiter: waiter,
120-
req_map: req_map.clone(),
178+
streams: self.streams.clone(),
121179
},
122180
ClientWriter {
123-
stream_id: 1,
124181
rx: self.rx.take().unwrap(),
125182
shutdown_notifier: notifier,
126-
req_map,
183+
184+
streams: self.streams.clone(),
127185
},
128186
)
129187
}
130188
}
131189

132190
struct ClientWriter {
133-
stream_id: u32,
134-
rx: RequestReceiver,
191+
rx: MessageReceiver,
135192
shutdown_notifier: shutdown::Notifier,
136-
req_map: Arc<Mutex<HashMap<u32, ResultSender>>>,
193+
194+
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
137195
}
138196

139197
#[async_trait]
140198
impl WriterDelegate for ClientWriter {
141199
async fn recv(&mut self) -> Option<GenMessage> {
142-
if let Some((mut msg, resp_tx)) = self.rx.recv().await {
143-
let current_stream_id = self.stream_id;
144-
msg.header.set_stream_id(current_stream_id);
145-
self.stream_id += 2;
146-
{
147-
let mut map = self.req_map.lock().unwrap();
148-
map.insert(current_stream_id, resp_tx);
149-
}
150-
return Some(msg);
151-
} else {
152-
return None;
153-
}
200+
self.rx.recv().await
154201
}
155202

156203
async fn disconnect(&self, msg: &GenMessage, e: Error) {
204+
// TODO:
205+
// At this point, a new request may have been received.
157206
let resp_tx = {
158-
let mut map = self.req_map.lock().unwrap();
207+
let mut map = self.streams.lock().unwrap();
159208
map.remove(&msg.header.stream_id)
160209
};
161210

211+
// TODO: if None
162212
if let Some(resp_tx) = resp_tx {
163213
let e = Error::Socket(format!("{:?}", e));
164214
resp_tx
@@ -174,8 +224,8 @@ impl WriterDelegate for ClientWriter {
174224
}
175225

176226
struct ClientReader {
227+
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
177228
shutdown_waiter: shutdown::Waiter,
178-
req_map: Arc<Mutex<HashMap<u32, ResultSender>>>,
179229
}
180230

181231
#[async_trait]
@@ -191,8 +241,8 @@ impl ReaderDelegate for ClientReader {
191241
let _ = sender.await;
192242

193243
// Take all items out of `req_map`.
194-
let mut map = std::mem::take(&mut *self.req_map.lock().unwrap());
195-
// Terminate outstanding RPC requests with the error.
244+
let mut map = std::mem::take(&mut *self.streams.lock().unwrap());
245+
// Terminate undone RPC requests with the error.
196246
for (_stream_id, resp_tx) in map.drain() {
197247
if let Err(_e) = resp_tx.send(Err(e.clone())).await {
198248
warn!("Failed to terminate pending RPC: the request has returned");
@@ -203,35 +253,56 @@ impl ReaderDelegate for ClientReader {
203253
async fn exit(&self) {}
204254

205255
async fn handle_msg(&self, msg: GenMessage) {
206-
let req_map = self.req_map.clone();
256+
let req_map = self.streams.clone();
207257
tokio::spawn(async move {
208-
let resp_tx2;
209-
{
210-
let mut map = req_map.lock().unwrap();
211-
let resp_tx = match map.get(&msg.header.stream_id) {
212-
Some(tx) => tx,
213-
None => {
214-
debug!("Receiver got unknown packet {:?}", msg);
215-
return;
258+
let resp_tx = match msg.header.type_ {
259+
MESSAGE_TYPE_RESPONSE => {
260+
match req_map.lock().unwrap().remove(&msg.header.stream_id) {
261+
Some(tx) => tx,
262+
None => {
263+
debug!("Receiver got unknown response packet {:?}", msg);
264+
return;
265+
}
216266
}
217-
};
218-
219-
resp_tx2 = resp_tx.clone();
220-
map.remove(&msg.header.stream_id); // Forget the result, just remove.
221-
}
222-
223-
if msg.header.type_ != MESSAGE_TYPE_RESPONSE {
224-
resp_tx2
225-
.send(Err(Error::Others(format!(
226-
"Recver got malformed packet {:?}",
227-
msg
228-
))))
229-
.await
230-
.unwrap_or_else(|_e| error!("The request has returned"));
231-
return;
232-
}
233-
234-
resp_tx2
267+
}
268+
MESSAGE_TYPE_DATA => {
269+
if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED {
270+
match req_map.lock().unwrap().remove(&msg.header.stream_id) {
271+
Some(tx) => tx.clone(),
272+
None => {
273+
debug!("Receiver got unknown data packet {:?}", msg);
274+
return;
275+
}
276+
}
277+
} else {
278+
match req_map.lock().unwrap().get(&msg.header.stream_id) {
279+
Some(tx) => tx.clone(),
280+
None => {
281+
debug!("Receiver got unknown data packet {:?}", msg);
282+
return;
283+
}
284+
}
285+
}
286+
}
287+
_ => {
288+
let resp_tx = match req_map.lock().unwrap().remove(&msg.header.stream_id) {
289+
Some(tx) => tx,
290+
None => {
291+
debug!("Receiver got unknown packet {:?}", msg);
292+
return;
293+
}
294+
};
295+
resp_tx
296+
.send(Err(Error::Others(format!(
297+
"Recver got malformed packet {:?}",
298+
msg
299+
))))
300+
.await
301+
.unwrap_or_else(|_e| error!("The request has returned"));
302+
return;
303+
}
304+
};
305+
resp_tx
235306
.send(Ok(msg))
236307
.await
237308
.unwrap_or_else(|_e| error!("The request has returned"));

0 commit comments

Comments
 (0)