diff --git a/Cargo.toml b/Cargo.toml index bbcd7aa..2f097f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ description = "Simple, modern, ergonomic JSON-RPC 2.0 router built with tower an keywords = ["json-rpc", "jsonrpc", "json"] categories = ["web-programming::http-server", "web-programming::websocket"] -version = "0.1.0" +version = "0.1.1" edition = "2021" rust-version = "1.81" authors = ["init4", "James Prestwich"] @@ -66,3 +66,7 @@ inherits = "dev" strip = true debug = false incremental = false + +[dev-dependencies] +tempfile = "3.15.0" +tracing-subscriber = "0.3.19" diff --git a/src/lib.rs b/src/lib.rs index 36b516b..c1ef79f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -139,6 +139,9 @@ pub use primitives::{BorrowedRpcObject, MethodId, RpcBorrow, RpcObject, RpcRecv, #[cfg(feature = "pubsub")] pub mod pubsub; +#[doc(hidden)] // for tests +#[cfg(feature = "ipc")] +pub use pubsub::ReadJsonStream; mod routes; pub(crate) use routes::{BoxedIntoRoute, ErasedIntoRoute, Method, Route}; diff --git a/src/pubsub/ipc.rs b/src/pubsub/ipc.rs index 525195d..c21cd25 100644 --- a/src/pubsub/ipc.rs +++ b/src/pubsub/ipc.rs @@ -98,7 +98,8 @@ const CAPACITY: usize = 4096; /// A stream of JSON-RPC items, read from an [`AsyncRead`] stream. #[derive(Debug)] #[pin_project::pin_project] -pub(crate) struct ReadJsonStream { +#[doc(hidden)] +pub struct ReadJsonStream { /// The underlying reader. #[pin] reader: T, diff --git a/src/pubsub/mod.rs b/src/pubsub/mod.rs index 24873a8..febe8fb 100644 --- a/src/pubsub/mod.rs +++ b/src/pubsub/mod.rs @@ -90,6 +90,9 @@ #[cfg(feature = "ipc")] mod ipc; +#[cfg(feature = "ipc")] +#[doc(hidden)] +pub use ipc::ReadJsonStream; mod shared; pub use shared::{ConnectionId, ServerShutdown, DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT}; diff --git a/src/routes/ctx.rs b/src/routes/ctx.rs index f763f25..88cff3c 100644 --- a/src/routes/ctx.rs +++ b/src/routes/ctx.rs @@ -68,8 +68,7 @@ impl HandlerCtx { /// Notify a client of an event. pub async fn notify(&self, t: &T) -> Result<(), NotifyError> { if let Some(notifications) = self.notifications.as_ref() { - let ser = serde_json::to_string(t)?; - let rv = serde_json::value::to_raw_value(&ser)?; + let rv = serde_json::value::to_raw_value(t)?; notifications.send(rv).await?; } diff --git a/src/types/req.rs b/src/types/req.rs index 5f53c89..6c26a43 100644 --- a/src/types/req.rs +++ b/src/types/req.rs @@ -134,13 +134,31 @@ impl Request { RawValue::from_string(self.id().to_string()).expect("valid json") } - /// Return a reference to the serialized method field. + /// Return a reference to the method str, deserialized. + /// + /// This is the method without the preceding and trailing quotes. E.g. if + /// the method is `foo`, this will return `&"foo"`. pub fn method(&self) -> &str { - // SAFETY: `method` is guaranteed to be valid JSON, + // SAFETY: `method` is guaranteed to be valid UTF-8, // and a valid slice of `bytes`. unsafe { core::str::from_utf8_unchecked(self.bytes.get_unchecked(self.method.clone())) } } + /// Return a reference to the raw method str, with preceding and trailing + /// quotes. This is effectively the method as a [`RawValue`]. + /// + /// E.g. if the method is `foo`, this will return `&r#""foo""#`. + pub fn raw_method(&self) -> &str { + // SAFETY: `params` is guaranteed to be valid JSON, + // and a valid slice of `bytes`. + unsafe { + core::str::from_utf8_unchecked( + self.bytes + .get_unchecked(self.method.start - 1..self.method.end + 1), + ) + } + } + /// Return a reference to the serialized params field. pub fn params(&self) -> &str { // SAFETY: `params` is guaranteed to be valid JSON, diff --git a/tests/common/mod.rs b/tests/common/mod.rs new file mode 100644 index 0000000..18e12cf --- /dev/null +++ b/tests/common/mod.rs @@ -0,0 +1,68 @@ +use ajj::{HandlerCtx, Router}; +use serde_json::Value; +use tokio::time; + +/// Instantiate a router for testing. +pub fn test_router() -> ajj::Router<()> { + Router::<()>::new() + .route("ping", || async move { Ok::<_, ()>("pong") }) + .route( + "double", + |params: usize| async move { Ok::<_, ()>(params * 2) }, + ) + .route("notify", |ctx: HandlerCtx| async move { + tokio::task::spawn(async move { + time::sleep(time::Duration::from_millis(100)).await; + + let _ = ctx + .notify(&serde_json::json!({ + "method": "notify", + "result": "notified" + })) + .await; + }); + + Ok::<_, ()>(()) + }) +} + +/// Test clients +pub trait TestClient { + async fn send(&mut self, method: &str, params: &S); + async fn recv(&mut self) -> D; +} + +/// basic tests of the test router +pub async fn basic_tests(mut client: T) { + client.send("ping", &()).await; + + let next: Value = client.recv().await; + assert_eq!( + next, + serde_json::json!({"id": 0, "jsonrpc": "2.0", "result": "pong"}) + ); + + client.send("double", &5).await; + let next: Value = client.recv().await; + assert_eq!( + next, + serde_json::json!({"id": 1, "jsonrpc": "2.0", "result": 10}) + ); + + client.send("notify", &()).await; + + let now = std::time::Instant::now(); + + let next: Value = client.recv().await; + assert_eq!( + next, + serde_json::json!({"id": 2, "jsonrpc": "2.0", "result": null}) + ); + + let next: Value = client.recv().await; + assert!(now.elapsed().as_millis() >= 100); + assert_eq!( + next, + serde_json::json!({"method": "notify", "result": "notified"}) + ); +} diff --git a/tests/ipc.rs b/tests/ipc.rs new file mode 100644 index 0000000..6acb228 --- /dev/null +++ b/tests/ipc.rs @@ -0,0 +1,98 @@ +mod common; +use common::{test_router, TestClient}; + +use ajj::pubsub::{Connect, ReadJsonStream, ServerShutdown}; +use futures_util::StreamExt; +use interprocess::local_socket::{ + self as ls, + tokio::{prelude::LocalSocketStream, RecvHalf, SendHalf}, + traits::tokio::Stream, + ListenerOptions, +}; +use serde_json::Value; +use tempfile::{NamedTempFile, TempPath}; +use tokio::io::AsyncWriteExt; + +pub(crate) fn to_name(path: &std::ffi::OsStr) -> std::io::Result> { + if cfg!(windows) && !path.as_encoded_bytes().starts_with(br"\\.\pipe\") { + ls::ToNsName::to_ns_name::(path) + } else { + ls::ToFsName::to_fs_name::(path) + } +} + +async fn serve_ipc() -> (ServerShutdown, TempPath) { + let router = test_router(); + + let temp = NamedTempFile::new().unwrap().into_temp_path(); + let name = to_name(temp.as_os_str()).unwrap(); + + dbg!(&name); + dbg!(std::fs::remove_file(&temp).unwrap()); + + let shutdown = ListenerOptions::new() + .name(name) + .serve(router) + .await + .unwrap(); + (shutdown, temp) +} + +struct IpcClient { + recv_half: ReadJsonStream, + send_half: SendHalf, + id: usize, +} + +impl IpcClient { + async fn new(temp: &TempPath) -> Self { + let name = to_name(temp.as_os_str()).unwrap(); + let (recv_half, send_half) = LocalSocketStream::connect(name).await.unwrap().split(); + Self { + recv_half: recv_half.into(), + send_half, + id: 0, + } + } + + async fn send_inner(&mut self, msg: &S) { + let s = serde_json::to_string(msg).unwrap(); + + self.send_half.write_all(s.as_bytes()).await.unwrap(); + } + + async fn recv_inner(&mut self) -> serde_json::Value { + self.recv_half.next().await.unwrap() + } + + fn next_id(&mut self) -> usize { + let id = self.id; + self.id += 1; + id + } +} + +impl TestClient for IpcClient { + async fn send(&mut self, method: &str, params: &S) { + let id = self.next_id(); + self.send_inner(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": params, + })) + .await; + } + + async fn recv(&mut self) -> D { + serde_json::from_value(self.recv_inner().await).unwrap() + } +} + +#[tokio::test] +async fn basic_ipc() { + let (_server, temp) = serve_ipc().await; + let client = IpcClient::new(&temp).await; + + common::basic_tests(client).await; +} diff --git a/tests/ws.rs b/tests/ws.rs new file mode 100644 index 0000000..83bac25 --- /dev/null +++ b/tests/ws.rs @@ -0,0 +1,77 @@ +mod common; +use common::{test_router, TestClient}; + +use ajj::pubsub::{Connect, ServerShutdown}; +use futures_util::{SinkExt, StreamExt}; +use std::net::{Ipv4Addr, SocketAddr}; +use tokio_tungstenite::{ + tungstenite::{client::IntoClientRequest, Message}, + MaybeTlsStream, WebSocketStream, +}; + +const WS_SOCKET: SocketAddr = + SocketAddr::new(std::net::IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 3383); +const WS_SOCKET_STR: &str = "ws://127.0.0.1:3383"; + +async fn serve_ws() -> ServerShutdown { + let router = test_router(); + WS_SOCKET.serve(router).await.unwrap() +} + +struct WsClient { + socket: WebSocketStream>, + id: u64, +} + +impl WsClient { + async fn send_inner(&mut self, msg: &S) { + self.socket + .send(Message::Text(serde_json::to_string(msg).unwrap().into())) + .await + .unwrap(); + } + + async fn recv_inner(&mut self) -> D { + match self.socket.next().await.unwrap().unwrap() { + Message::Text(text) => serde_json::from_str(&text).unwrap(), + _ => panic!("unexpected message type"), + } + } + + fn next_id(&mut self) -> u64 { + let id = self.id; + self.id += 1; + id + } +} + +impl TestClient for WsClient { + async fn send(&mut self, method: &str, params: &S) { + let id = self.next_id(); + self.send_inner(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": params, + })) + .await; + } + + async fn recv(&mut self) -> D { + self.recv_inner().await + } +} + +async fn ws_client() -> WsClient { + let request = WS_SOCKET_STR.into_client_request().unwrap(); + let (socket, _) = tokio_tungstenite::connect_async(request).await.unwrap(); + + WsClient { socket, id: 0 } +} + +#[tokio::test] +async fn basic_ws() { + let _server = serve_ws().await; + let client = ws_client().await; + common::basic_tests(client).await; +}