Skip to content

Commit 2bad764

Browse files
try prost
ref: #173 Co-authored-by: Xuewei Niu <niuxuewei.nxw@antgroup.com> Signed-off-by: jokemanfire <hu.dingyang@zte.com.cn>
1 parent 56a5a0a commit 2bad764

File tree

10 files changed

+115
-43
lines changed

10 files changed

+115
-43
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ default = ["sync"]
4141
async = ["dep:async-trait", "dep:tokio", "dep:futures", "dep:tokio-vsock"]
4242
sync = []
4343
prost = ["dep:prost", "dep:prost-build"]
44+
rustprotobuf = []
4445

4546
[package.metadata.docs.rs]
4647
all-features = true

src/asynchronous/client.rs

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,23 +68,34 @@ impl Client {
6868
pub async fn request(&self, req: Request) -> Result<Response> {
6969
let timeout_nano = req.timeout_nano;
7070
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);
71-
72-
let msg: GenMessage = Message::new_request(stream_id, req)?
73-
.try_into()
74-
.map_err(|err: std::io::Error| Error::Others(err.to_string()))?;
71+
let msg: GenMessage;
72+
#[cfg(not(feature = "prost"))]
73+
{
74+
msg = Message::new_request(stream_id, req)?
75+
.try_into()
76+
.map_err(|err: protobuf::Error| Error::Others(err.to_string()))?;
77+
}
78+
79+
#[cfg(feature = "prost")]
80+
{
81+
msg = Message::new_request(stream_id, req)?
82+
.try_into()
83+
.map_err(|err: std::io::Error| Error::Others(err.to_string()))?;
84+
}
7585

7686
let (tx, mut rx): (ResultSender, ResultReceiver) = mpsc::channel(100);
77-
87+
7888
self.streams
7989
.lock()
8090
.map_err(|_| Error::Others("Failed to acquire lock on streams".to_string()))?
8191
.insert(stream_id, tx);
82-
92+
8393
self.req_tx
8494
.send(SendingMessage::new(msg))
8595
.await
8696
.map_err(|_| Error::LocalClosed)?;
87-
97+
98+
#[allow(clippy::unnecessary_lazy_evaluations)]
8899
let result = if timeout_nano == 0 {
89100
rx.recv()
90101
.await
@@ -134,20 +145,28 @@ impl Client {
134145
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);
135146
let is_req_payload_empty = req.payload.is_empty();
136147

148+
#[cfg(not(feature = "prost"))]
137149
let mut msg: GenMessage = Message::new_request(stream_id, req)?
138150
.try_into()
139-
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;
151+
.map_err(|err: protobuf::Error| Error::Others(err.to_string()))?;
152+
140153
#[cfg(feature = "prost")]
141-
let mut msg: GenMessage = Message::new_request(stream_id, req)
154+
let mut msg: GenMessage = Message::new_request(stream_id, req)?
142155
.try_into()
143156
.map_err(|err: std::io::Error| Error::Others(err.to_string()))?;
144157

145158
if streaming_client {
146159
if !is_req_payload_empty {
160+
#[cfg(not(feature = "prost"))]
147161
return Err(get_rpc_status(
148162
Code::INVALID_ARGUMENT,
149163
"Creating a ClientStream and sending payload at the same time is not allowed",
150164
));
165+
#[cfg(feature = "prost")]
166+
return Err(get_rpc_status(
167+
Code::Unknown,
168+
"Creating a ClientStream and sending payload at the same time is not allowed",
169+
));
151170
}
152171
msg.header.add_flags(FLAG_REMOTE_OPEN | FLAG_NO_DATA);
153172
} else {

src/asynchronous/server.rs

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use async_trait::async_trait;
1818
use futures::stream::Stream;
1919
use futures::StreamExt as _;
2020
use nix::unistd;
21+
#[cfg(not(feature = "prost"))]
2122
use protobuf::Message as _;
2223
use tokio::{
2324
self,
@@ -455,9 +456,14 @@ impl HandlerContext {
455456
Ok(opt_msg) => match opt_msg {
456457
Some(mut resp) => {
457458
// Server: check size before sending to client
459+
#[cfg(not(feature = "prost"))]
458460
if let Err(e) = check_oversize(resp.compute_size() as usize, true) {
459461
resp = e.into();
460462
}
463+
#[cfg(feature = "prost")]
464+
if let Err(e) = check_oversize(resp.size() as usize, true) {
465+
resp = e.into();
466+
}
461467

462468
Self::respond(self.tx.clone(), stream_id, resp)
463469
.await
@@ -744,20 +750,21 @@ impl HandlerContext {
744750
get_status(Code::Unknown, e)
745751
})?;
746752
}
747-
#[cfg(not(feature = "prost"))]
748-
task.await
749-
.unwrap_or_else(|e| Err(Error::Others(format!("stream {path} task got error {e:?}"))))
750-
.map_err(|e| get_status(Code::UNKNOWN, e))
751753

752754
#[cfg(feature = "prost")]
753-
task.await
754-
.unwrap_or_else(|e| {
755-
Err(Error::Others(format!(
756-
"stream {} task got error {:?}",
757-
path, e
758-
)))
759-
})
760-
.map_err(|e| get_status(Code::Unknown, e))
755+
return task.await
756+
.unwrap_or_else(|e| {
757+
Err(Error::Others(format!(
758+
"stream {} task got error {:?}",
759+
path, e
760+
)))
761+
})
762+
.map_err(|e| get_status(Code::Unknown, e));
763+
764+
#[cfg(not(feature = "prost"))]
765+
return task.await
766+
.unwrap_or_else(|e| Err(Error::Others(format!("stream {path} task got error {e:?}"))))
767+
.map_err(|e| get_status(Code::UNKNOWN, e));
761768
}
762769

763770
async fn respond(tx: MessageSender, stream_id: u32, resp: Response) -> Result<()> {

src/error.rs

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
// limitations under the License.
1414

1515
//! Error and Result of ttrpc and relevant functions, macros.
16-
17-
use crate::proto::{Code, Response, Status};
16+
#[allow(unused_imports)]
17+
use crate::proto::{self, Code, Response, Status};
1818
use std::result;
1919
use thiserror::Error;
2020

@@ -53,12 +53,30 @@ impl From<Error> for Response {
5353
let status = if let Error::RpcStatus(stat) = e {
5454
stat
5555
} else {
56-
get_status(Code::UNKNOWN, e)
56+
#[cfg(not(feature = "prost"))]
57+
{
58+
get_status(Code::UNKNOWN, e)
59+
}
60+
61+
#[cfg(feature = "prost")]
62+
{
63+
get_status(Code::Unknown, e)
64+
}
5765
};
5866

59-
let mut res = Response::new();
60-
res.set_status(status);
61-
res
67+
#[cfg(not(feature = "prost"))]
68+
{
69+
let mut res = Response::new();
70+
res.set_status(status);
71+
res
72+
}
73+
#[cfg(feature = "prost")]
74+
{
75+
Response {
76+
status: Some(status),
77+
..Default::default()
78+
}
79+
}
6280
}
6381
}
6482

src/lib.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ pub use crate::error::{get_status, Error, Result};
6262

6363
cfg_sync! {
6464
pub mod sync;
65-
#[doc(hidden)]
66-
pub use sync::response_to_channel;
6765
#[doc(inline)]
6866
pub use sync::{MethodHandler, TtrpcContext};
6967
pub use sync::Client;

src/proto.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,14 @@ pub(crate) fn check_oversize(len: usize, return_rpc_error: bool) -> TtResult<()>
3535
len, MESSAGE_LENGTH_MAX
3636
);
3737
let e = if return_rpc_error {
38-
get_rpc_status(Code::INVALID_ARGUMENT, msg)
38+
#[cfg(not(feature = "prost"))]
39+
{
40+
get_rpc_status(Code::INVALID_ARGUMENT, msg)
41+
}
42+
#[cfg(feature = "prost")]
43+
{
44+
get_rpc_status(Code::Unknown, msg)
45+
}
3946
} else {
4047
Error::Others(msg)
4148
};
@@ -517,6 +524,7 @@ mod tests {
517524
}
518525

519526
#[cfg(feature = "async")]
527+
#[cfg(not(feature = "prost"))]
520528
#[tokio::test]
521529
async fn async_gen_message() {
522530
// Test packet which exceeds maximum message size
@@ -557,6 +565,7 @@ mod tests {
557565
}
558566

559567
#[cfg(feature = "async")]
568+
#[cfg(not(feature = "prost"))]
560569
#[tokio::test]
561570
async fn async_message() {
562571
// Test packet which exceeds maximum message size

src/sync/client.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#[cfg(unix)]
1818
use std::os::unix::io::RawFd;
1919

20-
use protobuf::Message;
2120
use std::collections::HashMap;
2221
use std::sync::mpsc;
2322
use std::sync::{Arc, Mutex};
@@ -157,7 +156,11 @@ impl Client {
157156
})
158157
}
159158
pub fn request(&self, req: Request) -> Result<Response> {
160-
check_oversize(req.compute_size() as usize, false)?;
159+
#[cfg(feature = "prost")]
160+
check_oversize(req.payload.len(), false)?;
161+
162+
#[cfg(not(feature = "prost"))]
163+
check_oversize(req.payload.len(), false)?;
161164

162165
let buf = req.encode().map_err(err_to_others_err!(e, ""))?;
163166
// Notice: pure client problem can't be rpc error

src/sync/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,4 @@ pub use client::Client;
1717
pub use server::Server;
1818

1919
#[doc(hidden)]
20-
pub use utils::response_to_channel;
2120
pub use utils::{MethodHandler, TtrpcContext};

src/sync/server.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
2020
use std::time::Duration;
2121

22-
use nix::sys::socket::{self, *};
23-
use nix::unistd::*;
2422
#[cfg(feature = "prost")]
2523
use prost::Message;
2624
#[cfg(not(feature = "prost"))]
@@ -32,8 +30,9 @@ use std::sync::{Arc, Mutex};
3230
use std::thread;
3331
use std::thread::JoinHandle;
3432

35-
use super::utils::{response_error_to_channel, response_to_channel};
36-
use crate::context;
33+
use super::utils::response_error_to_channel;
34+
use crate::sync::utils::response_to_channel;
35+
use crate::{context, Status};
3736
use crate::error::{get_status, Error, Result};
3837
use crate::proto::{Code, MessageHeader, Request, Response, MESSAGE_TYPE_REQUEST};
3938
use crate::sync::channel::{read_message, write_message};
@@ -167,7 +166,8 @@ fn start_method_handler_thread(
167166
if mh.type_ != MESSAGE_TYPE_REQUEST {
168167
continue;
169168
}
170-
let mut req;
169+
#[allow(unused_mut)]
170+
let mut req: Request = Request::default();
171171
#[cfg(not(feature = "prost"))]
172172
{
173173
let mut s = CodedInputStream::from_bytes(&buf);
@@ -186,7 +186,6 @@ fn start_method_handler_thread(
186186
}
187187
#[cfg(feature = "prost")]
188188
{
189-
req = Request::default();
190189
if let Err(x) = req.merge(&buf as &[u8]) {
191190
let status = get_status(Code::InvalidArgument, x.to_string());
192191
let res = Response {

src/sync/utils.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44
//
55

66
use crate::error::{Error, Result};
7+
#[allow(unused_imports)]
78
use crate::proto::{
89
check_oversize, Codec, MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE,
910
};
10-
#[cfg(feature = "prost")]
11-
use prost::Message;
12-
#[cfg(not(feature = "prost"))]
13-
use protobuf::Message;
11+
1412
use std::collections::HashMap;
1513

1614
#[cfg(not(feature = "prost"))]
@@ -40,6 +38,27 @@ pub fn response_to_channel(
4038
Ok(())
4139
}
4240

41+
#[cfg(feature = "prost")]
42+
pub fn response_to_channel(
43+
stream_id: u32,
44+
res: Response,
45+
tx: std::sync::mpsc::Sender<(MessageHeader, Vec<u8>)>,
46+
) -> Result<()> {
47+
let mut buffer = Vec::new();
48+
<Response as prost::Message>::encode(&res, &mut buffer).map_err(err_to_others_err!(e, ""))?;
49+
let mh = MessageHeader {
50+
length: buffer.len() as u32,
51+
stream_id,
52+
type_: MESSAGE_TYPE_RESPONSE,
53+
flags: 0,
54+
};
55+
56+
tx.send((mh, buffer)).map_err(err_to_others_err!(e, ""))?;
57+
58+
Ok(())
59+
}
60+
61+
4362
pub fn response_error_to_channel(
4463
stream_id: u32,
4564
e: Error,

0 commit comments

Comments
 (0)