Skip to content

fix: allow missing id and missing params #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ description = "Simple, modern, ergonomic JSON-RPC 2.0 router built with tower an
keywords = ["json-rpc", "jsonrpc", "json"]
categories = ["web-programming::http-server", "web-programming::websocket"]

version = "0.1.1"
version = "0.1.2"
edition = "2021"
rust-version = "1.81"
authors = ["init4", "James Prestwich"]
Expand Down
9 changes: 6 additions & 3 deletions src/pubsub/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,12 @@ where
break;
};

let Ok(req) = Request::try_from(item) else {
tracing::warn!("inbound request is malformatted");
continue
let req = match Request::try_from(item) {
Ok(req) => req,
Err(err) => {
tracing::warn!(%err, "inbound request is malformatted");
continue
}
};

let span = debug_span!("ipc request handling", id = req.id(), method = req.method());
Expand Down
56 changes: 38 additions & 18 deletions src/types/req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub struct Request {
///
/// This field is generated by deserializing to a [`RawValue`] and then
/// calculating the offset of the backing slice within the `bytes` field.
id: Range<usize>,
id: Option<Range<usize>>,
/// A range of the `bytes` field that represents the method field of the
/// JSON-RPC request.
///
Expand All @@ -49,7 +49,7 @@ pub struct Request {
///
/// This field is generated by deserializing to a [`RawValue`] and then
/// calculating the offset of the backing slice within the `bytes` field.
params: Range<usize>,
params: Option<Range<usize>>,
}

impl core::fmt::Debug for Request {
Expand All @@ -67,11 +67,11 @@ impl core::fmt::Debug for Request {
#[derive(serde::Deserialize)]
struct DeserHelper<'a> {
#[serde(borrow)]
id: &'a RawValue,
id: Option<&'a RawValue>,
#[serde(borrow)]
method: &'a RawValue,
#[serde(borrow)]
params: &'a RawValue,
params: Option<&'a RawValue>,
}

impl TryFrom<Bytes> for Request {
Expand All @@ -80,12 +80,19 @@ impl TryFrom<Bytes> for Request {
fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
let DeserHelper { id, method, params } = serde_json::from_slice(bytes.as_ref())?;

let id = find_range!(bytes, id.get());
// Ensure the id is not too long
let id_len = id.end - id.start;
if id_len > ID_LEN_LIMIT {
return Err(RequestError::IdTooLarge(id_len));
}
let id = if let Some(id) = id {
let id = find_range!(bytes, id.get());

// Ensure the id is not too long
let id_len = id.end - id.start;
if id_len > ID_LEN_LIMIT {
return Err(RequestError::IdTooLarge(id_len));
}

Some(id)
} else {
None
};

// Ensure method is a string, and not too long, and trim the quotes
// from it
Expand All @@ -101,7 +108,7 @@ impl TryFrom<Bytes> for Request {
return Err(RequestError::MethodTooLarge(method_len));
}

let params = find_range!(bytes, params.get());
let params = params.map(|params| find_range!(bytes, params.get()));

Ok(Self {
bytes,
Expand All @@ -122,11 +129,20 @@ impl TryFrom<tokio_tungstenite::tungstenite::Utf8Bytes> for Request {
}

impl Request {
/// Return a reference to the serialized ID field.
/// Return a reference to the serialized ID field. If the ID field is
/// missing, this will return `"null"`, ensuring that response correctly
/// have a null ID, as per [the JSON-RPC spec].
///
/// [the JSON-RPC spec]: https://www.jsonrpc.org/specification#response_object
pub fn id(&self) -> &str {
// SAFETY: `id` is guaranteed to be valid JSON,
// and a valid slice of `bytes`.
unsafe { core::str::from_utf8_unchecked(self.bytes.get_unchecked(self.id.clone())) }
self.id
.as_ref()
.map(|range| {
// SAFETY: `range` is guaranteed to be valid JSON, and a valid
// slice of `bytes`.
unsafe { core::str::from_utf8_unchecked(self.bytes.get_unchecked(range.clone())) }
})
.unwrap_or("null")
}

/// Return an owned version of the serialized ID field.
Expand Down Expand Up @@ -161,9 +177,13 @@ impl Request {

/// Return a reference to the serialized params field.
pub fn params(&self) -> &str {
// SAFETY: `params` is guaranteed to be valid JSON,
// and a valid slice of `bytes`.
unsafe { core::str::from_utf8_unchecked(self.bytes.get_unchecked(self.params.clone())) }
if let Some(range) = &self.params {
// SAFETY: `range` is guaranteed to be valid JSON, and a valid
// slice of `bytes`.
unsafe { core::str::from_utf8_unchecked(self.bytes.get_unchecked(range.clone())) }
} else {
"null"
}
}

/// Deserialize the params field into a type.
Expand Down
69 changes: 58 additions & 11 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,63 @@ pub fn test_router() -> ajj::Router<()> {

/// Test clients
pub trait TestClient {
async fn send<S: serde::Serialize>(&mut self, method: &str, params: &S);
fn next_id(&mut self) -> usize;

async fn send_raw<S: serde::Serialize>(&mut self, msg: &S);

async fn recv<D: serde::de::DeserializeOwned>(&mut self) -> D;

async fn send<S: serde::Serialize>(&mut self, method: &str, params: &S) -> usize {
let id = self.next_id();
self.send_raw(&serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params,
}))
.await;
id
}
}

/// basic tests of the test router
pub async fn basic_tests<T: TestClient>(mut client: T) {
client.send("ping", &()).await;
test_ping(&mut client).await;

let next: Value = client.recv().await;
assert_eq!(
next,
serde_json::json!({"id": 0, "jsonrpc": "2.0", "result": "pong"})
);
test_double(&mut client).await;

test_notify(&mut client).await;

test_missing_id(&mut client).await;
}

async fn test_missing_id<T: TestClient>(client: &mut T) {
client
.send_raw(&serde_json::json!(
{"jsonrpc": "2.0", "method": "ping"}
))
.await;

client.send("double", &5).await;
let next: Value = client.recv().await;
assert_eq!(
next,
serde_json::json!({"id": 1, "jsonrpc": "2.0", "result": 10})
serde_json::json!({
"jsonrpc": "2.0",
"result": "pong",
"id": null,
})
);
}

client.send("notify", &()).await;
async fn test_notify<T: TestClient>(client: &mut T) {
let id = client.send("notify", &()).await;

let now = std::time::Instant::now();

let next: Value = client.recv().await;
assert_eq!(
next,
serde_json::json!({"id": 2, "jsonrpc": "2.0", "result": null})
serde_json::json!({"id": id, "jsonrpc": "2.0", "result": null})
);

let next: Value = client.recv().await;
Expand All @@ -66,3 +94,22 @@ pub async fn basic_tests<T: TestClient>(mut client: T) {
serde_json::json!({"method": "notify", "result": "notified"})
);
}

async fn test_double<T: TestClient>(client: &mut T) {
let id = client.send("double", &5).await;
let next: Value = client.recv().await;
assert_eq!(
next,
serde_json::json!({"id": id, "jsonrpc": "2.0", "result": 10})
);
}

async fn test_ping<T: TestClient>(client: &mut T) {
let id = client.send("ping", &()).await;

let next: Value = client.recv().await;
assert_eq!(
next,
serde_json::json!({"id": id, "jsonrpc": "2.0", "result": "pong"})
);
}
15 changes: 4 additions & 11 deletions tests/ipc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,24 +64,17 @@ impl IpcClient {
async fn recv_inner(&mut self) -> serde_json::Value {
self.recv_half.next().await.unwrap()
}
}

impl TestClient for IpcClient {
fn next_id(&mut self) -> usize {
let id = self.id;
self.id += 1;
id
}
}

impl TestClient for IpcClient {
async fn send<S: serde::Serialize>(&mut self, method: &str, params: &S) {
let id = self.next_id();
self.send_inner(&serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params,
}))
.await;
async fn send_raw<S: serde::Serialize>(&mut self, msg: &S) {
self.send_inner(msg).await;
}

async fn recv<D: serde::de::DeserializeOwned>(&mut self) -> D {
Expand Down
19 changes: 6 additions & 13 deletions tests/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async fn serve_ws() -> ServerShutdown {

struct WsClient {
socket: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
id: u64,
id: usize,
}

impl WsClient {
Expand All @@ -37,24 +37,17 @@ impl WsClient {
_ => panic!("unexpected message type"),
}
}
}

fn next_id(&mut self) -> u64 {
impl TestClient for WsClient {
fn next_id(&mut self) -> usize {
let id = self.id;
self.id += 1;
id
}
}

impl TestClient for WsClient {
async fn send<S: serde::Serialize>(&mut self, method: &str, params: &S) {
let id = self.next_id();
self.send_inner(&serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params,
}))
.await;
async fn send_raw<S: serde::Serialize>(&mut self, msg: &S) {
self.send_inner(msg).await;
}

async fn recv<D: serde::de::DeserializeOwned>(&mut self) -> D {
Expand Down
Loading