Skip to content

Commit 78ffcb8

Browse files
committed
async: add streaming support for server
Added streaming support for server-side. Signed-off-by: wanglei01 <wllenyj@linux.alibaba.com>
1 parent 1126e43 commit 78ffcb8

File tree

3 files changed

+182
-32
lines changed

3 files changed

+182
-32
lines changed

src/asynchronous/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ mod connection;
1515
pub mod shutdown;
1616
mod unix_incoming;
1717

18+
pub use self::stream::{Kind, StreamInner};
1819
#[doc(inline)]
1920
pub use crate::r#async::client::Client;
2021
#[doc(inline)]
21-
pub use crate::r#async::server::Server;
22+
pub use crate::r#async::server::{Server, Service};
2223
#[doc(inline)]
23-
pub use utils::{MethodHandler, TtrpcContext};
24+
pub use utils::{MethodHandler, StreamHandler, TtrpcContext};

src/asynchronous/server.rs

Lines changed: 169 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// Copyright 2022 Alibaba Cloud. All rights reserved.
12
// Copyright (c) 2020 Ant Financial
23
//
34
// SPDX-License-Identifier: Apache-2.0
@@ -10,7 +11,7 @@ use std::os::unix::io::RawFd;
1011
use std::os::unix::io::{AsRawFd, FromRawFd};
1112
use std::os::unix::net::UnixListener as SysUnixListener;
1213
use std::result::Result as StdResult;
13-
use std::sync::Arc;
14+
use std::sync::{Arc, Mutex};
1415
use std::time::Duration;
1516

1617
use async_trait::async_trait;
@@ -34,22 +35,39 @@ use crate::common::{self, Domain};
3435
use crate::context;
3536
use crate::error::{get_status, Error, Result};
3637
use crate::proto::{
37-
Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status,
38-
MESSAGE_TYPE_REQUEST,
38+
Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status, FLAG_NO_DATA,
39+
FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_REQUEST,
3940
};
4041
use crate::r#async::connection::*;
4142
use crate::r#async::shutdown;
42-
use crate::r#async::stream::{MessageReceiver, MessageSender};
43+
use crate::r#async::stream::{
44+
Kind, MessageReceiver, MessageSender, ResultReceiver, ResultSender, StreamInner,
45+
};
4346
use crate::r#async::utils;
44-
use crate::r#async::{MethodHandler, TtrpcContext};
47+
use crate::r#async::{MethodHandler, StreamHandler, TtrpcContext};
4548

4649
const DEFAULT_CONN_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5000);
4750
const DEFAULT_SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10000);
4851

52+
pub struct Service {
53+
pub methods: HashMap<String, Box<dyn MethodHandler + Send + Sync>>,
54+
pub streams: HashMap<String, Arc<dyn StreamHandler + Send + Sync>>,
55+
}
56+
57+
impl Service {
58+
pub(crate) fn get_method(&self, name: &str) -> Option<&(dyn MethodHandler + Send + Sync)> {
59+
self.methods.get(name).map(|b| b.as_ref())
60+
}
61+
62+
pub(crate) fn get_stream(&self, name: &str) -> Option<Arc<dyn StreamHandler + Send + Sync>> {
63+
self.streams.get(name).cloned()
64+
}
65+
}
66+
4967
/// A ttrpc Server (async).
5068
pub struct Server {
5169
listeners: Vec<RawFd>,
52-
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
70+
services: Arc<HashMap<String, Service>>,
5371
domain: Option<Domain>,
5472

5573
shutdown: shutdown::Notifier,
@@ -60,7 +78,7 @@ impl Default for Server {
6078
fn default() -> Self {
6179
Server {
6280
listeners: Vec::with_capacity(1),
63-
methods: Arc::new(HashMap::new()),
81+
services: Arc::new(HashMap::new()),
6482
domain: None,
6583
shutdown: shutdown::with_timeout(DEFAULT_SERVER_SHUTDOWN_TIMEOUT).0,
6684
stop_listen_tx: None,
@@ -105,12 +123,9 @@ impl Server {
105123
Ok(self)
106124
}
107125

108-
pub fn register_service(
109-
mut self,
110-
methods: HashMap<String, Box<dyn MethodHandler + Send + Sync>>,
111-
) -> Server {
112-
let mut_methods = Arc::get_mut(&mut self.methods).unwrap();
113-
mut_methods.extend(methods);
126+
pub fn register_service(mut self, new: HashMap<String, Service>) -> Server {
127+
let services = Arc::get_mut(&mut self.services).unwrap();
128+
services.extend(new);
114129
self
115130
}
116131

@@ -158,7 +173,7 @@ impl Server {
158173
I: Stream<Item = std::io::Result<S>> + Unpin + Send + 'static + AsRawFd,
159174
S: AsyncRead + AsyncWrite + AsRawFd + Send + 'static,
160175
{
161-
let methods = self.methods.clone();
176+
let services = self.services.clone();
162177

163178
let shutdown_waiter = self.shutdown.subscribe();
164179

@@ -172,13 +187,13 @@ impl Server {
172187
if let Some(conn) = conn {
173188
// Accept a new connection
174189
match conn {
175-
Ok(stream) => {
176-
let fd = stream.as_raw_fd();
190+
Ok(conn) => {
191+
let fd = conn.as_raw_fd();
177192
// spawn a connection handler, would not block
178193
spawn_connection_handler(
179194
fd,
180-
stream,
181-
methods.clone(),
195+
conn,
196+
services.clone(),
182197
shutdown_waiter.clone(),
183198
).await;
184199
}
@@ -244,14 +259,15 @@ impl Server {
244259
async fn spawn_connection_handler<C>(
245260
fd: RawFd,
246261
conn: C,
247-
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
262+
services: Arc<HashMap<String, Service>>,
248263
shutdown_waiter: shutdown::Waiter,
249264
) where
250265
C: AsyncRead + AsyncWrite + AsRawFd + Send + 'static,
251266
{
252267
let delegate = ServerBuilder {
253268
fd,
254-
methods,
269+
services,
270+
streams: Arc::new(Mutex::new(HashMap::new())),
255271
shutdown_waiter,
256272
};
257273
let conn = Connection::new(conn, delegate);
@@ -279,7 +295,8 @@ impl AsRawFd for Server {
279295

280296
struct ServerBuilder {
281297
fd: RawFd,
282-
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
298+
services: Arc<HashMap<String, Service>>,
299+
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
283300
shutdown_waiter: shutdown::Waiter,
284301
}
285302

@@ -296,7 +313,8 @@ impl Builder for ServerBuilder {
296313
ServerReader {
297314
fd: self.fd,
298315
tx,
299-
methods: self.methods.clone(),
316+
services: self.services.clone(),
317+
streams: self.streams.clone(),
300318
server_shutdown: self.shutdown_waiter.clone(),
301319
handler_shutdown: disconnect_notifier,
302320
},
@@ -321,7 +339,8 @@ impl WriterDelegate for ServerWriter {
321339
struct ServerReader {
322340
fd: RawFd,
323341
tx: MessageSender,
324-
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
342+
services: Arc<HashMap<String, Service>>,
343+
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
325344
server_shutdown: shutdown::Waiter,
326345
handler_shutdown: shutdown::Notifier,
327346
}
@@ -366,7 +385,8 @@ impl ServerReader {
366385
HandlerContext {
367386
fd: self.fd,
368387
tx: self.tx.clone(),
369-
methods: self.methods.clone(),
388+
services: self.services.clone(),
389+
streams: self.streams.clone(),
370390
_handler_shutdown_waiter: self.handler_shutdown.subscribe(),
371391
}
372392
}
@@ -375,7 +395,8 @@ impl ServerReader {
375395
struct HandlerContext {
376396
fd: RawFd,
377397
tx: MessageSender,
378-
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
398+
services: Arc<HashMap<String, Service>>,
399+
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
379400
// Used for waiting handler exit.
380401
_handler_shutdown_waiter: shutdown::Waiter,
381402
}
@@ -406,11 +427,63 @@ impl HandlerContext {
406427
.ok();
407428
}
408429
None => {
409-
unimplemented!();
430+
let mut header = MessageHeader::new_data(stream_id, 0);
431+
header.set_flags(FLAG_REMOTE_CLOSED | FLAG_NO_DATA);
432+
let msg = GenMessage {
433+
header,
434+
payload: Vec::new(),
435+
};
436+
437+
self.tx
438+
.send(msg)
439+
.await
440+
.map_err(err_to_others_err!(e, "Send packet to sender error "))
441+
.ok();
410442
}
411443
},
412444
Err(status) => Self::respond_with_status(self.tx.clone(), stream_id, status).await,
413445
},
446+
MESSAGE_TYPE_DATA => {
447+
// TODO(wllenyj): Compatible with golang behavior.
448+
if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED
449+
&& !msg.payload.is_empty()
450+
{
451+
Self::respond_with_status(
452+
self.tx.clone(),
453+
stream_id,
454+
get_status(
455+
Code::INVALID_ARGUMENT,
456+
format!(
457+
"Stream id {}: data close message connot include data",
458+
stream_id
459+
),
460+
),
461+
)
462+
.await;
463+
return;
464+
}
465+
let stream_tx = self.streams.lock().unwrap().get(&stream_id).cloned();
466+
if let Some(stream_tx) = stream_tx {
467+
if let Err(e) = stream_tx.send(Ok(msg)).await {
468+
Self::respond_with_status(
469+
self.tx.clone(),
470+
stream_id,
471+
get_status(
472+
Code::INVALID_ARGUMENT,
473+
format!("Stream id {}: handling data error: {}", stream_id, e),
474+
),
475+
)
476+
.await;
477+
}
478+
} else {
479+
Self::respond_with_status(
480+
self.tx.clone(),
481+
stream_id,
482+
get_status(Code::INVALID_ARGUMENT, "Stream is no longer active"),
483+
)
484+
.await;
485+
}
486+
}
414487
_ => {
415488
// TODO: else we must ignore this for future compat. log this?
416489
// TODO(wllenyj): Compatible with golang behavior.
@@ -432,12 +505,23 @@ impl HandlerContext {
432505
let req = &req_msg.payload;
433506
trace!("Got Message request {} {}", req.service, req.method);
434507

435-
let path = utils::get_path(&req.service, &req.method);
436-
let method = self.methods.get(&path).ok_or_else(|| {
437-
get_status(Code::INVALID_ARGUMENT, format!("{} does not exist", &path))
508+
let srv = self.services.get(&req.service).ok_or_else(|| {
509+
get_status(
510+
Code::INVALID_ARGUMENT,
511+
format!("{} service does not exist", &req.service),
512+
)
438513
})?;
439514

440-
return self.handle_method(method.as_ref(), req_msg).await;
515+
if let Some(method) = srv.get_method(&req.method) {
516+
return self.handle_method(method, req_msg).await;
517+
}
518+
if let Some(stream) = srv.get_stream(&req.method) {
519+
return self.handle_stream(stream, req_msg).await;
520+
}
521+
Err(get_status(
522+
Code::UNIMPLEMENTED,
523+
format!("{} method", &req.method),
524+
))
441525
}
442526

443527
async fn handle_method(
@@ -484,6 +568,61 @@ impl HandlerContext {
484568
}
485569
}
486570

571+
async fn handle_stream(
572+
&self,
573+
stream: Arc<dyn StreamHandler + Send + Sync>,
574+
req_msg: Message<Request>,
575+
) -> StdResult<Option<Response>, Status> {
576+
let stream_id = req_msg.header.stream_id;
577+
let req = req_msg.payload;
578+
let path = utils::get_path(&req.service, &req.method);
579+
580+
let (tx, rx): (ResultSender, ResultReceiver) = channel(100);
581+
let stream_tx = tx.clone();
582+
self.streams.lock().unwrap().insert(stream_id, tx);
583+
584+
let _remote_close = (req_msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED;
585+
let _remote_open = (req_msg.header.flags & FLAG_REMOTE_OPEN) == FLAG_REMOTE_OPEN;
586+
let si = StreamInner::new(
587+
stream_id,
588+
self.tx.clone(),
589+
rx,
590+
true, // TODO
591+
true,
592+
Kind::Server,
593+
self.streams.clone(),
594+
);
595+
596+
let ctx = TtrpcContext {
597+
fd: self.fd,
598+
mh: req_msg.header,
599+
metadata: context::from_pb(&req.metadata),
600+
timeout_nano: req.timeout_nano,
601+
};
602+
603+
let task = spawn(async move { stream.handler(ctx, si).await });
604+
605+
if !req.payload.is_empty() {
606+
// Fake the first data message.
607+
let msg = GenMessage {
608+
header: MessageHeader::new_data(stream_id, req.payload.len() as u32),
609+
payload: req.payload,
610+
};
611+
stream_tx.send(Ok(msg)).await.map_err(|e| {
612+
error!("send stream data {} got error {:?}", path, &e);
613+
get_status(Code::UNKNOWN, e)
614+
})?;
615+
}
616+
task.await
617+
.unwrap_or_else(|e| {
618+
Err(Error::Others(format!(
619+
"stream {} task got error {:?}",
620+
path, e
621+
)))
622+
})
623+
.map_err(|e| get_status(Code::UNKNOWN, e))
624+
}
625+
487626
async fn respond(tx: MessageSender, stream_id: u32, resp: Response) -> Result<()> {
488627
let payload = resp
489628
.encode()

src/asynchronous/utils.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,16 @@ pub trait MethodHandler {
8484
async fn handler(&self, ctx: TtrpcContext, req: Request) -> Result<Response>;
8585
}
8686

87+
/// Trait that implements handler which is a proxy to the stream (async).
88+
#[async_trait]
89+
pub trait StreamHandler {
90+
async fn handler(
91+
&self,
92+
ctx: TtrpcContext,
93+
stream: crate::r#async::StreamInner,
94+
) -> Result<Option<Response>>;
95+
}
96+
8797
/// The context of ttrpc (async).
8898
#[derive(Debug)]
8999
pub struct TtrpcContext {

0 commit comments

Comments
 (0)