diff --git a/.editorconfig b/.editorconfig index 28b112186..5e5b724ee 100644 --- a/.editorconfig +++ b/.editorconfig @@ -9,8 +9,16 @@ trim_trailing_whitespace=true max_line_length=120 insert_final_newline=true -[.travis.yml] +[{.travis.yml,appveyor.yml}] indent_style=space indent_size=2 tab_width=8 end_of_line=lf + +[*.stderr] +indent_style=none +indent_size=none +end_of_line=none +charset=none +trim_trailing_whitespace=none +insert_final_newline=none diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 000000000..7cc7c9725 --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,65 @@ +stages: + - checkstyle + - test +variables: &default-vars + GIT_STRATEGY: fetch + GIT_DEPTH: 100 + CARGO_INCREMENTAL: 0 + +.test_and_build: &test_and_build + script: + - cargo build --all + - cargo test --all + +.only: &only + only: + - triggers + - tags + - master + - schedules + - web + - /^[0-9]+$/ + +.docker-env: &docker-env + image: paritytech/ci-linux:production + before_script: + - rustup show + - cargo --version + - sccache -s + variables: + <<: *default-vars + CARGO_TARGET_DIR: "/ci-cache/${CI_PROJECT_NAME}/targets/${CI_COMMIT_REF_NAME}/${CI_JOB_NAME}" + retry: + max: 2 + when: + - runner_system_failure + - unknown_failure + - api_failure + interruptible: true + tags: + - linux-docker + +# check style +checkstyle-linux-stable: + stage: checkstyle + <<: *only + <<: *docker-env + script: + - rustup component add rustfmt clippy + - cargo fmt --all -- --check + - cargo clippy + allow_failure: true + +# test rust stable +test-linux-stable: + stage: test + <<: *docker-env + <<: *only + <<: *test_and_build + +test-mac-stable: + stage: test + <<: *test_and_build + <<: *only + tags: + - osx diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 061da8d0e..000000000 --- a/.travis.yml +++ /dev/null @@ -1,43 +0,0 @@ -sudo: false -language: rust -branches: - only: - - master - - /^parity-.*$/ - -cache: cargo - -matrix: - fast_finish: false - include: - - os: linux - rust: stable - - os: linux - rust: beta - - os: linux - rust: nightly - - os: osx - rust: stable - - os: windows - rust: stable - allow_failures: - - rust: nightly - -script: - - cargo build --all - - cargo test --all - -after_success: | - [ $TRAVIS_OS_NAME == 'linux' ] && - [ $TRAVIS_BRANCH = master ] && - [ $TRAVIS_PULL_REQUEST = false ] && - [ $TRAVIS_RUST_VERSION = stable ] && - cargo doc --all --no-deps && - echo '' > target/doc/index.html && - pip install --user ghp-import && - /home/travis/.local/bin/ghp-import -n target/doc && - git push -fq https://${GH_TOKEN}@github.com/${TRAVIS_REPO_SLUG}.git gh-pages - -env: - global: - - secure: "QA4Rw78VSsP1vH2Yve1eAVpjG32HH9DZZ79xrhxZXp34wKoemp+aGqaFN/8uXPfsXshlYxtMCTT6M9OiWTTLvku5tI5kBsDneu8mLut7eBZHVniYSp2SbKpTeqfpGMDHoCR0WD9AlWDn9Elm6txbghXjrxhCMg8gkhhsLGnQt/ARFF1wRHnXT0TjJg8fQtd+/OK0TaRfknx1RptruaznxfUi3DBwzDdzaMMZfd3VjWR1hPFRpDSL0mM+l6OjNrLbCeiR//k3lV4rpIhedsz0ODjfW2Hdk63qCaLJsXCkG1Bcuf/FYbYC+osm5SrHhGA1j2EgazWcLA6Wkzt15KPOR/HirNj+PCiS0YbGKM5Ac5LT6m6q0iYSF/pq1+jDurcSwBwYrTOY6X2FZCZQBfTP/4qnSjWgGPOkzBSMS6BNEBDQZgdc3xCASXadj7waF4Y4UGD0bDPuBtXopI4ppKLqSa7CsvKz6TX2yW0UVgUuQ5/jz/S+fkcz74o016d5x027yjaxAu/Z8fQFLSaBtiFU8sBzA+MDU3apFgjsYXiaGYZ8gDrp7WjbfHNYfBAMEHHKY4toywB5Vi8zJxF+Wn1n4hkvb/kDqSV9giFmWEg321U+pAGNAH4yY25tIJqS8gT89cz4oQJp7aWjA3Ke01e104yqqZU+N+CSyZHEeksdPt8=" diff --git a/Cargo.toml b/Cargo.toml index 4e73179d7..b509827d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,11 @@ [workspace] members = [ "core", + "core-client", + "core-client/transports", "http", "ipc", 
- "macros", - "minihttp", + "derive", "pubsub", "pubsub/more-examples", "server-utils", diff --git a/README.md b/README.md index 72cc82648..45a018116 100644 --- a/README.md +++ b/README.md @@ -3,33 +3,38 @@ Rust implementation of JSON-RPC 2.0 Specification. Transport-agnostic `core` and transport servers for `http`, `ipc`, `websockets` and `tcp`. -[![Build Status][travis-image]][travis-url] +**New!** Support for [clients](#Client-support). -[travis-image]: https://travis-ci.org/paritytech/jsonrpc.svg?branch=master -[travis-url]: https://travis-ci.org/paritytech/jsonrpc - -[Documentation](http://paritytech.github.io/jsonrpc/) +[Documentation](https://docs.rs/jsonrpc-core/) ## Sub-projects - [jsonrpc-core](./core) [![crates.io][core-image]][core-url] +- [jsonrpc-core-client](./core-client) [![crates.io][core-client-image]][core-client-url] - [jsonrpc-http-server](./http) [![crates.io][http-server-image]][http-server-url] -- [jsonrpc-minihttp-server](./minihttp) -- [jsonrpc-ipc-server](./ipc) +- [jsonrpc-ipc-server](./ipc) [![crates.io][ipc-server-image]][ipc-server-url] - [jsonrpc-tcp-server](./tcp) [![crates.io][tcp-server-image]][tcp-server-url] -- [jsonrpc-ws-server](./ws) -- [jsonrpc-stdio-server](./stdio) -- [jsonrpc-macros](./macros) [![crates.io][macros-image]][macros-url] +- [jsonrpc-ws-server](./ws) [![crates.io][ws-server-image]][ws-server-url] +- [jsonrpc-stdio-server](./stdio) [![crates.io][stdio-server-image]][stdio-server-url] +- [jsonrpc-derive](./derive) [![crates.io][derive-image]][derive-url] - [jsonrpc-server-utils](./server-utils) [![crates.io][server-utils-image]][server-utils-url] - [jsonrpc-pubsub](./pubsub) [![crates.io][pubsub-image]][pubsub-url] [core-image]: https://img.shields.io/crates/v/jsonrpc-core.svg [core-url]: https://crates.io/crates/jsonrpc-core +[core-client-image]: https://img.shields.io/crates/v/jsonrpc-core-client.svg +[core-client-url]: https://crates.io/crates/jsonrpc-core-client [http-server-image]: https://img.shields.io/crates/v/jsonrpc-http-server.svg [http-server-url]: https://crates.io/crates/jsonrpc-http-server +[ipc-server-image]: https://img.shields.io/crates/v/jsonrpc-ipc-server.svg +[ipc-server-url]: https://crates.io/crates/jsonrpc-ipc-server [tcp-server-image]: https://img.shields.io/crates/v/jsonrpc-tcp-server.svg [tcp-server-url]: https://crates.io/crates/jsonrpc-tcp-server -[macros-image]: https://img.shields.io/crates/v/jsonrpc-macros.svg -[macros-url]: https://crates.io/crates/jsonrpc-macros +[ws-server-image]: https://img.shields.io/crates/v/jsonrpc-ws-server.svg +[ws-server-url]: https://crates.io/crates/jsonrpc-ws-server +[stdio-server-image]: https://img.shields.io/crates/v/jsonrpc-stdio-server.svg +[stdio-server-url]: https://crates.io/crates/jsonrpc-stdio-server +[derive-image]: https://img.shields.io/crates/v/jsonrpc-derive.svg +[derive-url]: https://crates.io/crates/jsonrpc-derive [server-utils-image]: https://img.shields.io/crates/v/jsonrpc-server-utils.svg [server-utils-url]: https://crates.io/crates/jsonrpc-server-utils [pubsub-image]: https://img.shields.io/crates/v/jsonrpc-pubsub.svg @@ -38,22 +43,19 @@ Transport-agnostic `core` and transport servers for `http`, `ipc`, `websockets` ## Examples - [core](./core/examples) -- [macros](./macros/examples) +- [derive](./derive/examples) - [pubsub](./pubsub/examples) ### Basic Usage (with HTTP transport) ```rust -extern crate jsonrpc_core; -extern crate jsonrpc_minihttp_server; - -use jsonrpc_core::{IoHandler, Value, Params}; -use jsonrpc_minihttp_server::{ServerBuilder}; +use 
jsonrpc_http_server::jsonrpc_core::{IoHandler, Value, Params}; +use jsonrpc_http_server::ServerBuilder; fn main() { - let mut io = IoHandler::new(); - io.add_method("say_hello", |_params: Params| { - Ok(Value::String("hello".to_string())) + let mut io = IoHandler::default(); + io.add_method("say_hello", |_params: Params| async { + Ok(Value::String("hello".to_owned())) }); let server = ServerBuilder::new(io) @@ -61,25 +63,21 @@ fn main() { .start_http(&"127.0.0.1:3030".parse().unwrap()) .unwrap(); - server.wait().unwrap(); + server.wait(); } ``` -### Basic usage with macros +### Basic usage with derive ```rust -extern crate jsonrpc_core; -#[macro_use] -extern crate jsonrpc_macros; - use jsonrpc_core::Result; +use jsonrpc_derive::rpc; -build_rpc_trait! { - pub trait Rpc { - /// Adds two numbers and returns a result - #[rpc(name = "add")] - fn add(&self, u64, u64) -> Result; - } +#[rpc] +pub trait Rpc { + /// Adds two numbers and returns a result + #[rpc(name = "add")] + fn add(&self, u64, u64) -> Result; } pub struct RpcImpl; @@ -89,8 +87,59 @@ impl Rpc for RpcImpl { } } - fn main() { let mut io = jsonrpc_core::IoHandler::new(); io.extend_with(RpcImpl.to_delegate()) } +``` + +### Client support + +```rust +use jsonrpc_core_client::transports::local; +use jsonrpc_core::{Error, IoHandler, Result}; +use jsonrpc_derive::rpc; + +/// Rpc trait +#[rpc] +pub trait Rpc { + /// Returns a protocol version + #[rpc(name = "protocolVersion")] + fn protocol_version(&self) -> Result; + + /// Adds two numbers and returns a result + #[rpc(name = "add", alias("callAsyncMetaAlias"))] + fn add(&self, a: u64, b: u64) -> Result; + + /// Performs asynchronous operation + #[rpc(name = "callAsync")] + fn call(&self, a: u64) -> FutureResult; +} + +struct RpcImpl; + +impl Rpc for RpcImpl { + fn protocol_version(&self) -> Result { + Ok("version1".into()) + } + + fn add(&self, a: u64, b: u64) -> Result { + Ok(a + b) + } + + fn call(&self, _: u64) -> FutureResult { + future::ok("OK".to_owned()) + } +} + +fn main() { + let mut io = IoHandler::new(); + io.extend_with(RpcImpl.to_delegate()); + + let fut = { + let (client, server) = local::connect::(io); + client.add(5, 6).map(|res| println!("5 + 6 = {}", res)).join(server) + }; + fut.wait().unwrap(); +} +``` diff --git a/_automate/bump_version.sh b/_automate/bump_version.sh new file mode 100755 index 000000000..a078520fc --- /dev/null +++ b/_automate/bump_version.sh @@ -0,0 +1,21 @@ +#!/bin/sh + +set -xeu + +VERSION=$1 +PREV_DEPS=$2 +NEW_DEPS=$3 + +ack "^version = \"" -l | \ + grep toml | \ + xargs sed -i "s/^version = \".*/version = \"$VERSION\"/" + +ack "{ version = \"$PREV_DEPS" -l | \ + grep toml | \ + xargs sed -i "s/{ version = \"$PREV_DEPS/{ version = \"$NEW_DEPS/" + +ack " = \"$PREV_DEPS" -l | \ + grep md | \ + xargs sed -i "s/ = \"$PREV_DEPS/ = \"$NEW_DEPS/" + +cargo check diff --git a/_automate/publish.sh b/_automate/publish.sh new file mode 100755 index 000000000..ec389b923 --- /dev/null +++ b/_automate/publish.sh @@ -0,0 +1,117 @@ +#!/bin/bash + +set -eu + +ORDER=(core server-utils tcp ws http ipc stdio pubsub core-client/transports core-client derive test) + +function read_toml () { + NAME="" + VERSION="" + NAME=$(grep "^name" ./Cargo.toml | sed -e 's/.*"\(.*\)"/\1/') + VERSION=$(grep "^version" ./Cargo.toml | sed -e 's/.*"\(.*\)"/\1/') +} +function remote_version () { + REMOTE_VERSION="" + REMOTE_VERSION=$(cargo search "$NAME" | grep "^$NAME =" | sed -e 's/.*"\(.*\)".*/\1/') +} + +# First display the plan +for CRATE_DIR in ${ORDER[@]}; do + cd $CRATE_DIR > 
/dev/null + read_toml + echo "$NAME@$VERSION" + cd - > /dev/null +done + +read -p ">>>> Really publish?. Press [enter] to continue. " + +set -x + +cargo clean + +set +x + +# Then actually perform publishing. +for CRATE_DIR in ${ORDER[@]}; do + cd $CRATE_DIR > /dev/null + read_toml + remote_version + # Seems the latest version matches, skip by default. + if [ "$REMOTE_VERSION" = "$VERSION" ] || [[ "$REMOTE_VERSION" > "$VERSION" ]]; then + RET="" + echo "Seems like $NAME@$REMOTE_VERSION is already published. Continuing in 5s. " + read -t 5 -p ">>>> Type [r][enter] to retry, or [enter] to continue... " RET || true + if [ "$RET" != "r" ]; then + echo "Skipping $NAME@$VERSION" + cd - > /dev/null + continue + fi + fi + + # Attempt to publish (allow retries) + while : ; do + # give the user an opportunity to abort or skip before publishing + echo "🚀 Publishing $NAME@$VERSION..." + sleep 3 + + set +e && set -x + cargo publish $@ + RES=$? + set +x && set -e + # Check if it succeeded + if [ "$RES" != "0" ]; then + CHOICE="" + echo "##### Publishing $NAME failed" + read -p ">>>>> Type [s][enter] to skip, or [enter] to retry.. " CHOICE + if [ "$CHOICE" = "s" ]; then + break + fi + else + break + fi + done + + # Wait again to make sure that the new version is published and available. + echo "Waiting for $NAME@$VERSION to become available at the registry..." + while : ; do + sleep 3 + remote_version + if [ "$REMOTE_VERSION" = "$VERSION" ]; then + echo "🥳 $NAME@$VERSION published succesfully." + sleep 3 + break + else + echo "#### Got $NAME@$REMOTE_VERSION but expected $NAME@$VERSION. Retrying..." + fi + done + cd - > /dev/null +done + +# Make tags in one go +set -x +git fetch --tags +set +x + +for CRATE_DIR in ${ORDER[@]}; do + cd $CRATE_DIR > /dev/null + read_toml + echo "Tagging $NAME@$VERSION" + set -x + git tag -a "$NAME-$VERSION" -m "$NAME $VERSION" || true + set +x + cd - > /dev/null +done + +set -x +sleep 3 +git push --tags +set +x + +cd core > /dev/null +read_toml +cd - > /dev/null +echo "Tagging jsonrpc@$VERSION" +set -x +git tag -a v$VERSION -m "Version $VERSION" +sleep 3 +git push --tags diff --git a/core-client/Cargo.toml b/core-client/Cargo.toml new file mode 100644 index 000000000..e4340540a --- /dev/null +++ b/core-client/Cargo.toml @@ -0,0 +1,33 @@ +[package] +authors = ["Parity Technologies "] +description = "Transport agnostic JSON-RPC 2.0 client implementation." 
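
The two `_automate` scripts above are presumably meant to be chained during a release and run from the repository root. A minimal sketch of such a run, with purely illustrative version numbers; note that `publish.sh` forwards any extra arguments straight to `cargo publish` via `$@`:

```bash
# Hypothetical release run (example version numbers only).
# Rewrite `version = "..."` in every Cargo.toml, bump intra-workspace and
# README dependency versions from 18.0.0 to 18.0.1, then run `cargo check`.
./_automate/bump_version.sh 18.0.1 18.0.0 18.0.1
git commit -am "Bump version to 18.0.1"

# Publish the crates in dependency order, waiting for each one to appear
# on the registry, then tag every crate and push the tags.
./_automate/publish.sh
```
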
+documentation = "https://docs.rs/jsonrpc-core-client/" +edition = "2018" +homepage = "https://github.com/paritytech/jsonrpc" +keywords = ["jsonrpc", "json-rpc", "json", "rpc", "serde"] +license = "MIT" +name = "jsonrpc-core-client" +repository = "https://github.com/paritytech/jsonrpc" +version = "18.0.0" + +categories = [ + "asynchronous", + "network-programming", + "web-programming::http-client", + "web-programming::http-server", + "web-programming::websocket", +] + +[features] +tls = ["jsonrpc-client-transports/tls"] +http = ["jsonrpc-client-transports/http"] +ws = ["jsonrpc-client-transports/ws"] +ipc = ["jsonrpc-client-transports/ipc"] +arbitrary_precision = ["jsonrpc-client-transports/arbitrary_precision"] + +[dependencies] +jsonrpc-client-transports = { version = "18.0.0", path = "./transports", default-features = false } +futures = { version = "0.3", features = [ "compat" ] } + +[badges] +travis-ci = { repository = "paritytech/jsonrpc", branch = "master"} diff --git a/core-client/src/lib.rs b/core-client/src/lib.rs new file mode 100644 index 000000000..0ffd970b7 --- /dev/null +++ b/core-client/src/lib.rs @@ -0,0 +1,11 @@ +//! JSON-RPC client implementation primitives. +//! +//! By default this crate does not implement any transports, +//! use corresponding features (`tls`, `http` or `ws`) to opt-in for them. +//! +//! See documentation of [`jsonrpc-client-transports`](https://docs.rs/jsonrpc-client-transports) for more details. + +#![deny(missing_docs)] + +pub use futures; +pub use jsonrpc_client_transports::*; diff --git a/core-client/transports/Cargo.toml b/core-client/transports/Cargo.toml new file mode 100644 index 000000000..3ca892697 --- /dev/null +++ b/core-client/transports/Cargo.toml @@ -0,0 +1,63 @@ +[package] +authors = ["Parity Technologies "] +description = "Transport agnostic JSON-RPC 2.0 client implementation." 
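
As the `core-client` docs above note, the facade crate pulls in no transports by default; they are opt-in Cargo features forwarded to `jsonrpc-client-transports`. A minimal sketch of a hypothetical downstream manifest enabling the HTTP and WebSocket transports:

```toml
# Hypothetical downstream Cargo.toml entry; feature names match the
# core-client manifest above (tls, http, ws, ipc, arbitrary_precision).
[dependencies]
jsonrpc-core-client = { version = "18.0.0", features = ["http", "ws"] }
```
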
+documentation = "https://docs.rs/jsonrpc-client-transports/" +edition = "2018" +homepage = "https://github.com/paritytech/jsonrpc" +keywords = ["jsonrpc", "json-rpc", "json", "rpc", "serde"] +license = "MIT" +name = "jsonrpc-client-transports" +repository = "https://github.com/paritytech/jsonrpc" +version = "18.0.0" + +categories = [ + "asynchronous", + "network-programming", + "web-programming::http-client", + "web-programming::http-server", + "web-programming::websocket", +] + +[features] +default = ["http", "tls", "ws"] +tls = ["hyper-tls", "http"] +http = ["hyper", "tokio/full"] +ws = [ + "websocket", + "tokio", + "futures/compat" +] +ipc = [ + "parity-tokio-ipc", + "jsonrpc-server-utils", + "tokio", +] +arbitrary_precision = ["serde_json/arbitrary_precision", "jsonrpc-core/arbitrary_precision"] + +[dependencies] +derive_more = "0.99" +futures = "0.3" +jsonrpc-core = { version = "18.0.0", path = "../../core" } +jsonrpc-pubsub = { version = "18.0.0", path = "../../pubsub" } +log = "0.4" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +url = "1.7" + +hyper = { version = "0.14", features = ["client", "http1"], optional = true } +hyper-tls = { version = "0.5", optional = true } +jsonrpc-server-utils = { version = "18.0.0", path = "../../server-utils", optional = true } +parity-tokio-ipc = { version = "0.9", optional = true } +tokio = { version = "1", optional = true } +websocket = { version = "0.24", optional = true } +flate2 = "0.2" + +[dev-dependencies] +assert_matches = "1.1" +jsonrpc-http-server = { version = "18.0.0", path = "../../http" } +jsonrpc-ipc-server = { version = "18.0.0", path = "../../ipc" } +lazy_static = "1.0" +env_logger = "0.7" + +[badges] +travis-ci = { repository = "paritytech/jsonrpc", branch = "master" } diff --git a/core-client/transports/src/lib.rs b/core-client/transports/src/lib.rs new file mode 100644 index 000000000..8b1112945 --- /dev/null +++ b/core-client/transports/src/lib.rs @@ -0,0 +1,479 @@ +//! JSON-RPC client implementation. + +#![deny(missing_docs)] + +use jsonrpc_core::futures::channel::{mpsc, oneshot}; +use jsonrpc_core::futures::{ + self, + task::{Context, Poll}, + Future, Stream, StreamExt, +}; +use jsonrpc_core::{Error, Params}; +use serde::de::DeserializeOwned; +use serde::Serialize; +use serde_json::Value; +use std::marker::PhantomData; +use std::pin::Pin; + +pub mod transports; + +#[cfg(test)] +mod logger; + +/// The errors returned by the client. +#[derive(Debug, derive_more::Display)] +pub enum RpcError { + /// An error returned by the server. + #[display(fmt = "Server returned rpc error {}", _0)] + JsonRpcError(Error), + /// Failure to parse server response. + #[display(fmt = "Failed to parse server response as {}: {}", _0, _1)] + ParseError(String, Box), + /// Request timed out. + #[display(fmt = "Request timed out")] + Timeout, + /// A general client error. + #[display(fmt = "Client error: {}", _0)] + Client(String), + /// Not rpc specific errors. + #[display(fmt = "{}", _0)] + Other(Box), +} + +impl std::error::Error for RpcError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match *self { + Self::JsonRpcError(ref e) => Some(e), + Self::ParseError(_, ref e) => Some(&**e), + Self::Other(ref e) => Some(&**e), + _ => None, + } + } +} + +impl From for RpcError { + fn from(error: Error) -> Self { + RpcError::JsonRpcError(error) + } +} + +/// A result returned by the client. +pub type RpcResult = Result; + +/// An RPC call message. +struct CallMessage { + /// The RPC method name. 
+ method: String, + /// The RPC method parameters. + params: Params, + /// The oneshot channel to send the result of the rpc + /// call to. + sender: oneshot::Sender>, +} + +/// An RPC notification. +struct NotifyMessage { + /// The RPC method name. + method: String, + /// The RPC method paramters. + params: Params, +} + +/// An RPC subscription. +struct Subscription { + /// The subscribe method name. + subscribe: String, + /// The subscribe method parameters. + subscribe_params: Params, + /// The name of the notification. + notification: String, + /// The unsubscribe method name. + unsubscribe: String, +} + +/// An RPC subscribe message. +struct SubscribeMessage { + /// The subscription to subscribe to. + subscription: Subscription, + /// The channel to send notifications to. + sender: mpsc::UnboundedSender>, +} + +/// A message sent to the `RpcClient`. +enum RpcMessage { + /// Make an RPC call. + Call(CallMessage), + /// Send a notification. + Notify(NotifyMessage), + /// Subscribe to a notification. + Subscribe(SubscribeMessage), +} + +impl From for RpcMessage { + fn from(msg: CallMessage) -> Self { + RpcMessage::Call(msg) + } +} + +impl From for RpcMessage { + fn from(msg: NotifyMessage) -> Self { + RpcMessage::Notify(msg) + } +} + +impl From for RpcMessage { + fn from(msg: SubscribeMessage) -> Self { + RpcMessage::Subscribe(msg) + } +} + +/// A channel to a `RpcClient`. +#[derive(Clone)] +pub struct RpcChannel(mpsc::UnboundedSender); + +impl RpcChannel { + fn send(&self, msg: RpcMessage) -> Result<(), mpsc::TrySendError> { + self.0.unbounded_send(msg) + } +} + +impl From> for RpcChannel { + fn from(sender: mpsc::UnboundedSender) -> Self { + RpcChannel(sender) + } +} + +/// The future returned by the rpc call. +pub type RpcFuture = oneshot::Receiver>; + +/// The stream returned by a subscribe. +pub type SubscriptionStream = mpsc::UnboundedReceiver>; + +/// A typed subscription stream. +pub struct TypedSubscriptionStream { + _marker: PhantomData, + returns: &'static str, + stream: SubscriptionStream, +} + +impl TypedSubscriptionStream { + /// Creates a new `TypedSubscriptionStream`. + pub fn new(stream: SubscriptionStream, returns: &'static str) -> Self { + TypedSubscriptionStream { + _marker: PhantomData, + returns, + stream, + } + } +} + +impl Stream for TypedSubscriptionStream { + type Item = RpcResult; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let result = futures::ready!(self.stream.poll_next_unpin(cx)); + match result { + Some(Ok(value)) => Some( + serde_json::from_value::(value) + .map_err(|error| RpcError::ParseError(self.returns.into(), Box::new(error))), + ), + None => None, + Some(Err(err)) => Some(Err(err)), + } + .into() + } +} + +/// Client for raw JSON RPC requests +#[derive(Clone)] +pub struct RawClient(RpcChannel); + +impl From for RawClient { + fn from(channel: RpcChannel) -> Self { + RawClient(channel) + } +} + +impl RawClient { + /// Call RPC method with raw JSON. + pub fn call_method(&self, method: &str, params: Params) -> impl Future> { + let (sender, receiver) = oneshot::channel(); + let msg = CallMessage { + method: method.into(), + params, + sender, + }; + let result = self.0.send(msg.into()); + async move { + let () = result.map_err(|e| RpcError::Other(Box::new(e)))?; + + receiver.await.map_err(|e| RpcError::Other(Box::new(e)))? + } + } + + /// Send RPC notification with raw JSON. 
+ pub fn notify(&self, method: &str, params: Params) -> RpcResult<()> { + let msg = NotifyMessage { + method: method.into(), + params, + }; + match self.0.send(msg.into()) { + Ok(()) => Ok(()), + Err(error) => Err(RpcError::Other(Box::new(error))), + } + } + + /// Subscribe to topic with raw JSON. + pub fn subscribe( + &self, + subscribe: &str, + subscribe_params: Params, + notification: &str, + unsubscribe: &str, + ) -> RpcResult { + let (sender, receiver) = mpsc::unbounded(); + let msg = SubscribeMessage { + subscription: Subscription { + subscribe: subscribe.into(), + subscribe_params, + notification: notification.into(), + unsubscribe: unsubscribe.into(), + }, + sender, + }; + + self.0 + .send(msg.into()) + .map(|()| receiver) + .map_err(|e| RpcError::Other(Box::new(e))) + } +} + +/// Client for typed JSON RPC requests +#[derive(Clone)] +pub struct TypedClient(RawClient); + +impl From for TypedClient { + fn from(channel: RpcChannel) -> Self { + TypedClient(channel.into()) + } +} + +impl TypedClient { + /// Create a new `TypedClient`. + pub fn new(raw_cli: RawClient) -> Self { + TypedClient(raw_cli) + } + + /// Call RPC with serialization of request and deserialization of response. + pub fn call_method( + &self, + method: &str, + returns: &str, + args: T, + ) -> impl Future> { + let returns = returns.to_owned(); + let args = + serde_json::to_value(args).expect("Only types with infallible serialisation can be used for JSON-RPC"); + let params = match args { + Value::Array(vec) => Ok(Params::Array(vec)), + Value::Null => Ok(Params::None), + Value::Object(map) => Ok(Params::Map(map)), + _ => Err(RpcError::Client( + "RPC params should serialize to a JSON array, JSON object or null".into(), + )), + }; + let result = params.map(|params| self.0.call_method(method, params)); + + async move { + let value: Value = result?.await?; + + log::debug!("response: {:?}", value); + + serde_json::from_value::(value).map_err(|error| RpcError::ParseError(returns, Box::new(error))) + } + } + + /// Call RPC with serialization of request only. + pub fn notify(&self, method: &str, args: T) -> RpcResult<()> { + let args = + serde_json::to_value(args).expect("Only types with infallible serialisation can be used for JSON-RPC"); + let params = match args { + Value::Array(vec) => Params::Array(vec), + Value::Null => Params::None, + _ => { + return Err(RpcError::Client( + "RPC params should serialize to a JSON array, or null".into(), + )) + } + }; + + self.0.notify(method, params) + } + + /// Subscribe with serialization of request and deserialization of response. 
+ pub fn subscribe( + &self, + subscribe: &str, + subscribe_params: T, + topic: &str, + unsubscribe: &str, + returns: &'static str, + ) -> RpcResult> { + let args = serde_json::to_value(subscribe_params) + .expect("Only types with infallible serialisation can be used for JSON-RPC"); + + let params = match args { + Value::Array(vec) => Params::Array(vec), + Value::Null => Params::None, + _ => { + return Err(RpcError::Client( + "RPC params should serialize to a JSON array, or null".into(), + )) + } + }; + + self.0 + .subscribe(subscribe, params, topic, unsubscribe) + .map(move |stream| TypedSubscriptionStream::new(stream, returns)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::transports::local; + use crate::{RpcChannel, TypedClient}; + use jsonrpc_core::futures::{future, FutureExt}; + use jsonrpc_core::{self as core, IoHandler}; + use jsonrpc_pubsub::{PubSubHandler, Subscriber, SubscriptionId}; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + + #[derive(Clone)] + struct AddClient(TypedClient); + + impl From for AddClient { + fn from(channel: RpcChannel) -> Self { + AddClient(channel.into()) + } + } + + impl AddClient { + fn add(&self, a: u64, b: u64) -> impl Future> { + self.0.call_method("add", "u64", (a, b)) + } + + fn completed(&self, success: bool) -> RpcResult<()> { + self.0.notify("completed", (success,)) + } + } + + #[test] + fn test_client_terminates() { + crate::logger::init_log(); + let mut handler = IoHandler::new(); + handler.add_sync_method("add", |params: Params| { + let (a, b) = params.parse::<(u64, u64)>()?; + let res = a + b; + Ok(jsonrpc_core::to_value(res).unwrap()) + }); + + let (tx, rx) = std::sync::mpsc::channel(); + let (client, rpc_client) = local::connect::(handler); + let fut = async move { + let res = client.add(3, 4).await?; + let res = client.add(res, 5).await?; + assert_eq!(res, 12); + tx.send(()).unwrap(); + Ok(()) as RpcResult<_> + }; + let pool = futures::executor::ThreadPool::builder().pool_size(1).create().unwrap(); + pool.spawn_ok(rpc_client.map(|x| x.unwrap())); + pool.spawn_ok(fut.map(|x| x.unwrap())); + rx.recv().unwrap() + } + + #[test] + fn should_send_notification() { + crate::logger::init_log(); + let (tx, rx) = std::sync::mpsc::sync_channel(1); + let mut handler = IoHandler::new(); + handler.add_notification("completed", move |params: Params| { + let (success,) = params.parse::<(bool,)>().expect("expected to receive one boolean"); + assert_eq!(success, true); + tx.send(()).unwrap(); + }); + + let (client, rpc_client) = local::connect::(handler); + client.completed(true).unwrap(); + let pool = futures::executor::ThreadPool::builder().pool_size(1).create().unwrap(); + pool.spawn_ok(rpc_client.map(|x| x.unwrap())); + rx.recv().unwrap() + } + + #[test] + fn should_handle_subscription() { + crate::logger::init_log(); + // given + let (finish, finished) = std::sync::mpsc::sync_channel(1); + let mut handler = PubSubHandler::::default(); + let called = Arc::new(AtomicBool::new(false)); + let called2 = called.clone(); + handler.add_subscription( + "hello", + ("subscribe_hello", move |params, _meta, subscriber: Subscriber| { + assert_eq!(params, core::Params::None); + let sink = subscriber + .assign_id(SubscriptionId::Number(5)) + .expect("assigned subscription id"); + let finish = finish.clone(); + std::thread::spawn(move || { + for i in 0..3 { + std::thread::sleep(std::time::Duration::from_millis(100)); + let value = serde_json::json!({ + "subscription": 5, + "result": vec![i], + }); + let _ = 
sink.notify(serde_json::from_value(value).unwrap()); + } + finish.send(()).unwrap(); + }); + }), + ("unsubscribe_hello", move |id, _meta| { + // Should be called because session is dropped. + called2.store(true, Ordering::SeqCst); + assert_eq!(id, SubscriptionId::Number(5)); + future::ready(Ok(core::Value::Bool(true))) + }), + ); + + // when + let (tx, rx) = std::sync::mpsc::channel(); + let (client, rpc_client) = local::connect_with_pubsub::(handler); + let received = Arc::new(std::sync::Mutex::new(vec![])); + let r2 = received.clone(); + let fut = async move { + let mut stream = + client.subscribe::<_, (u32,)>("subscribe_hello", (), "hello", "unsubscribe_hello", "u32")?; + let result = stream.next().await; + r2.lock().unwrap().push(result.expect("Expected at least one item.")); + tx.send(()).unwrap(); + Ok(()) as RpcResult<_> + }; + + let pool = futures::executor::ThreadPool::builder().pool_size(1).create().unwrap(); + pool.spawn_ok(rpc_client.map(|_| ())); + pool.spawn_ok(fut.map(|x| x.unwrap())); + + rx.recv().unwrap(); + assert!( + !received.lock().unwrap().is_empty(), + "Expected at least one received item." + ); + // The session is being dropped only when another notification is received. + // TODO [ToDr] we should unsubscribe as soon as the stream is dropped instead! + finished.recv().unwrap(); + assert_eq!(called.load(Ordering::SeqCst), true, "Unsubscribe not called."); + } +} diff --git a/core-client/transports/src/logger.rs b/core-client/transports/src/logger.rs new file mode 100644 index 000000000..812a06524 --- /dev/null +++ b/core-client/transports/src/logger.rs @@ -0,0 +1,25 @@ +use env_logger::Builder; +use lazy_static::lazy_static; +use log::LevelFilter; +use std::env; + +lazy_static! { + static ref LOG_DUMMY: bool = { + let mut builder = Builder::new(); + builder.filter(None, LevelFilter::Info); + + if let Ok(log) = env::var("RUST_LOG") { + builder.parse_filters(&log); + } + + if let Ok(_) = builder.try_init() { + println!("logger initialized"); + } + true + }; +} + +/// Intialize log with default settings +pub fn init_log() { + let _ = *LOG_DUMMY; +} diff --git a/core-client/transports/src/transports/duplex.rs b/core-client/transports/src/transports/duplex.rs new file mode 100644 index 000000000..28751f64d --- /dev/null +++ b/core-client/transports/src/transports/duplex.rs @@ -0,0 +1,325 @@ +//! Duplex transport + +use futures::channel::{mpsc, oneshot}; +use futures::{ + task::{Context, Poll}, + Future, Sink, Stream, StreamExt, +}; +use jsonrpc_core::Id; +use jsonrpc_pubsub::SubscriptionId; +use log::debug; +use serde_json::Value; +use std::collections::HashMap; +use std::collections::VecDeque; +use std::pin::Pin; + +use super::RequestBuilder; +use crate::{RpcChannel, RpcError, RpcMessage, RpcResult}; + +struct Subscription { + /// Subscription id received when subscribing. + id: Option, + /// A method name used for notification. + notification: String, + /// Rpc method to unsubscribe. + unsubscribe: String, + /// Where to send messages to. + channel: mpsc::UnboundedSender>, +} + +impl Subscription { + fn new(channel: mpsc::UnboundedSender>, notification: String, unsubscribe: String) -> Self { + Subscription { + id: None, + notification, + unsubscribe, + channel, + } + } +} + +enum PendingRequest { + Call(oneshot::Sender>), + Subscription(Subscription), +} + +/// The Duplex handles sending and receiving asynchronous +/// messages through an underlying transport. +pub struct Duplex { + request_builder: RequestBuilder, + /// Channel from the client. 
+ channel: Option>, + /// Requests that haven't received a response yet. + pending_requests: HashMap, + /// A map from the subscription name to the subscription. + subscriptions: HashMap<(SubscriptionId, String), Subscription>, + /// Incoming messages from the underlying transport. + stream: Pin>, + /// Unprocessed incoming messages. + incoming: VecDeque<(Id, RpcResult, Option, Option)>, + /// Unprocessed outgoing messages. + outgoing: VecDeque, + /// Outgoing messages from the underlying transport. + sink: Pin>, +} + +impl Duplex { + /// Creates a new `Duplex`. + fn new(sink: Pin>, stream: Pin>, channel: mpsc::UnboundedReceiver) -> Self { + log::debug!("open"); + Duplex { + request_builder: RequestBuilder::new(), + channel: Some(channel), + pending_requests: Default::default(), + subscriptions: Default::default(), + stream, + incoming: Default::default(), + outgoing: Default::default(), + sink, + } + } +} + +/// Creates a new `Duplex`, along with a channel to communicate +pub fn duplex(sink: Pin>, stream: Pin>) -> (Duplex, RpcChannel) +where + TSink: Sink, + TStream: Stream, +{ + let (sender, receiver) = mpsc::unbounded(); + let client = Duplex::new(sink, stream, receiver); + (client, sender.into()) +} + +impl Future for Duplex +where + TSink: Sink, + TStream: Stream, +{ + type Output = RpcResult<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + // Handle requests from the client. + log::debug!("handle requests from client"); + loop { + // Check that the client channel is open + let channel = match self.channel.as_mut() { + Some(channel) => channel, + None => break, + }; + let msg = match channel.poll_next_unpin(cx) { + Poll::Ready(Some(msg)) => msg, + Poll::Ready(None) => { + // When the channel is dropped we still need to finish + // outstanding requests. + self.channel.take(); + break; + } + Poll::Pending => break, + }; + let request_str = match msg { + RpcMessage::Call(msg) => { + let (id, request_str) = self.request_builder.call_request(&msg); + + if self + .pending_requests + .insert(id.clone(), PendingRequest::Call(msg.sender)) + .is_some() + { + log::error!("reuse of request id {:?}", id); + } + request_str + } + RpcMessage::Subscribe(msg) => { + let crate::Subscription { + subscribe, + subscribe_params, + notification, + unsubscribe, + } = msg.subscription; + let (id, request_str) = self.request_builder.subscribe_request(subscribe, subscribe_params); + log::debug!("subscribing to {}", notification); + let subscription = Subscription::new(msg.sender, notification, unsubscribe); + if self + .pending_requests + .insert(id.clone(), PendingRequest::Subscription(subscription)) + .is_some() + { + log::error!("reuse of request id {:?}", id); + } + request_str + } + RpcMessage::Notify(msg) => self.request_builder.notification(&msg), + }; + log::debug!("outgoing: {}", request_str); + self.outgoing.push_back(request_str); + } + + // Handle stream. + // Reads from stream and queues to incoming queue. + log::debug!("handle stream"); + loop { + let response_str = match self.stream.as_mut().poll_next(cx) { + Poll::Ready(Some(response_str)) => response_str, + Poll::Ready(None) => { + // The websocket connection was closed so the client + // can be shutdown. Reopening closed connections must + // be handled by the transport. + debug!("connection closed"); + return Poll::Ready(Ok(())); + } + Poll::Pending => break, + }; + log::debug!("incoming: {}", response_str); + // we only send one request at the time, so there can only be one response. 
+ let (id, result, method, sid) = super::parse_response(&response_str)?; + log::debug!( + "id: {:?} (sid: {:?}) result: {:?} method: {:?}", + id, + sid, + result, + method + ); + self.incoming.push_back((id, result, method, sid)); + } + + // Handle incoming queue. + log::debug!("handle incoming"); + loop { + match self.incoming.pop_front() { + Some((id, result, method, sid)) => { + let sid_and_method = sid.and_then(|sid| method.map(|method| (sid, method))); + // Handle the response to a pending request. + match self.pending_requests.remove(&id) { + // It's a regular Req-Res call, so just answer. + Some(PendingRequest::Call(tx)) => { + tx.send(result) + .map_err(|_| RpcError::Client("oneshot channel closed".into()))?; + continue; + } + // It was a subscription request, + // turn it into a proper subscription. + Some(PendingRequest::Subscription(mut subscription)) => { + let sid = result.as_ref().ok().and_then(|res| SubscriptionId::parse_value(res)); + let method = subscription.notification.clone(); + + if let Some(sid) = sid { + subscription.id = Some(sid.clone()); + if self + .subscriptions + .insert((sid.clone(), method.clone()), subscription) + .is_some() + { + log::warn!( + "Overwriting existing subscription under {:?} ({:?}). \ + Seems that server returned the same subscription id.", + sid, + method, + ); + } + } else { + let err = RpcError::Client(format!( + "Subscription {:?} ({:?}) rejected: {:?}", + id, method, result, + )); + + if subscription.channel.unbounded_send(result).is_err() { + log::warn!("{}, but the reply channel has closed.", err); + } + } + continue; + } + // It's not a pending request nor a notification + None if sid_and_method.is_none() => { + log::warn!("Got unexpected response with id {:?} ({:?})", id, sid_and_method); + continue; + } + // just fall-through in case it's a notification + None => {} + }; + + let sid_and_method = if let Some(x) = sid_and_method { + x + } else { + continue; + }; + + if let Some(subscription) = self.subscriptions.get_mut(&sid_and_method) { + let res = subscription.channel.unbounded_send(result); + if res.is_err() { + let subscription = self + .subscriptions + .remove(&sid_and_method) + .expect("Subscription was just polled; qed"); + let sid = subscription.id.expect( + "Every subscription that ends up in `self.subscriptions` has id already \ + assigned; assignment happens during response to subscribe request.", + ); + let (_id, request_str) = + self.request_builder.unsubscribe_request(subscription.unsubscribe, sid); + log::debug!("outgoing: {}", request_str); + self.outgoing.push_back(request_str); + log::debug!("unsubscribed from {:?}", sid_and_method); + } + } else { + log::warn!("Received unexpected subscription notification: {:?}", sid_and_method); + } + } + None => break, + } + } + + // Handle outgoing queue. + // Writes queued messages to sink. + log::debug!("handle outgoing"); + loop { + let err = || Err(RpcError::Client("closing".into())); + match self.sink.as_mut().poll_ready(cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(_)) => return err().into(), + _ => break, + } + match self.outgoing.pop_front() { + Some(request) => { + if self.sink.as_mut().start_send(request).is_err() { + // the channel is disconnected. 
+ return err().into(); + } + } + None => break, + } + } + log::debug!("handle sink"); + let sink_empty = match self.sink.as_mut().poll_flush(cx) { + Poll::Ready(Ok(())) => true, + Poll::Ready(Err(_)) => true, + Poll::Pending => false, + }; + + log::debug!("{:?}", self); + // Return ready when the future is complete + if self.channel.is_none() + && self.outgoing.is_empty() + && self.incoming.is_empty() + && self.pending_requests.is_empty() + && self.subscriptions.is_empty() + && sink_empty + { + log::debug!("close"); + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } +} + +impl std::fmt::Debug for Duplex { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!(f, "channel is none: {}", self.channel.is_none())?; + writeln!(f, "outgoing: {}", self.outgoing.len())?; + writeln!(f, "incoming: {}", self.incoming.len())?; + writeln!(f, "pending_requests: {}", self.pending_requests.len())?; + writeln!(f, "subscriptions: {}", self.subscriptions.len())?; + Ok(()) + } +} diff --git a/core-client/transports/src/transports/http.rs b/core-client/transports/src/transports/http.rs new file mode 100644 index 000000000..99def1f39 --- /dev/null +++ b/core-client/transports/src/transports/http.rs @@ -0,0 +1,315 @@ +//! HTTP client +//! +//! HTTPS support is enabled with the `tls` feature. + +use super::RequestBuilder; +use crate::{RpcChannel, RpcError, RpcMessage, RpcResult}; +use futures::{future, Future, FutureExt, StreamExt, TryFutureExt}; +use hyper::{http, Client, Request, Uri}; +use flate2::read::GzDecoder; +use std::io::Read; + +/// Create a HTTP Client +pub async fn connect_with_options(url: &str, allow_gzip: bool) -> RpcResult +where + TClient: From, +{ + let url: Uri = url.parse().map_err(|e| RpcError::Other(Box::new(e)))?; + + let (client_api, client_worker) = do_connect(url, allow_gzip).await; + tokio::spawn(client_worker); + + Ok(TClient::from(client_api)) +} + +/// Create a HTTP Client +pub async fn connect(url: &str) -> RpcResult +where + TClient: From, +{ + connect_with_options(url, false).await +} + +async fn do_connect(url: Uri, allow_gzip: bool) -> (RpcChannel, impl Future) { + let max_parallel = 8; + + #[cfg(feature = "tls")] + let connector = hyper_tls::HttpsConnector::new(); + #[cfg(feature = "tls")] + let client = Client::builder().build::<_, hyper::Body>(connector); + + #[cfg(not(feature = "tls"))] + let client = Client::new(); + // Keep track of internal request IDs when building subsequent requests + let mut request_builder = RequestBuilder::new(); + + let (sender, receiver) = futures::channel::mpsc::unbounded(); + + let fut = receiver + .filter_map(move |msg: RpcMessage| { + future::ready(match msg { + RpcMessage::Call(call) => { + let (_, request) = request_builder.call_request(&call); + Some((request, Some(call.sender))) + } + RpcMessage::Notify(notify) => Some((request_builder.notification(¬ify), None)), + RpcMessage::Subscribe(_) => { + log::warn!("Unsupported `RpcMessage` type `Subscribe`."); + None + } + }) + }) + .map(move |(request, sender)| { + let request = Request::post(&url) + .header( + http::header::CONTENT_TYPE, + http::header::HeaderValue::from_static("application/json"), + ) + .header( + http::header::ACCEPT, + http::header::HeaderValue::from_static("application/json"), + ) + .header( + http::header::ACCEPT_ENCODING, + http::header::HeaderValue::from_static(if allow_gzip { "gzip" } else { "identity" }), + ) + .body(request.into()) + .expect("Uri and request headers are valid; qed"); + + client + .request(request) + .then(|response| async 
move { (response, sender) }) + }) + .buffer_unordered(max_parallel) + .for_each(|(response, sender)| async { + let result = match response { + Ok(ref res) if !res.status().is_success() => { + log::trace!("http result status {}", res.status()); + Err(RpcError::Client(format!( + "Unexpected response status code: {}", + res.status() + ))) + } + Err(err) => Err(RpcError::Other(Box::new(err))), + Ok(res) => { + let is_gzip_response = res.headers().get(http::header::CONTENT_ENCODING).unwrap_or(&http::header::HeaderValue::from_static("identity")) == "gzip"; + hyper::body::to_bytes(res.into_body()) + .map_err(|e| RpcError::ParseError(e.to_string(), Box::new(e))) + .await + .and_then(|bytes| { + if is_gzip_response { + let mut decoder = GzDecoder::new(bytes.as_ref()) + .map_err(|e| RpcError::ParseError(e.to_string(), Box::new(e)))?; + let mut buf = String::new(); + let _ = decoder.read_to_string(&mut buf); + Ok(buf.into()) + } else { + Ok(bytes) + } + }) + } + }; + + if let Some(sender) = sender { + let response = result + .and_then(|response| { + let response_str = String::from_utf8_lossy(response.as_ref()).into_owned(); + super::parse_response(&response_str) + }) + .and_then(|r| r.1); + if let Err(err) = sender.send(response) { + log::warn!("Error resuming asynchronous request: {:?}", err); + } + } + }); + + (sender.into(), fut) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::*; + use assert_matches::assert_matches; + use jsonrpc_core::{Error, ErrorCode, IoHandler, Params, Value}; + use jsonrpc_http_server::*; + + fn id(t: T) -> T { + t + } + + struct TestServer { + uri: String, + server: Option, + } + + impl TestServer { + fn serve ServerBuilder>(alter: F) -> Self { + let builder = ServerBuilder::new(io()).rest_api(RestApi::Unsecure); + + let server = alter(builder).start_http(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let uri = format!("http://{}", server.address()); + + TestServer { + uri, + server: Some(server), + } + } + + fn stop(&mut self) { + let server = self.server.take(); + if let Some(server) = server { + server.close(); + } + } + } + + fn io() -> IoHandler { + let mut io = IoHandler::default(); + io.add_sync_method("hello", |params: Params| match params.parse::<(String,)>() { + Ok((msg,)) => Ok(Value::String(format!("hello {}", msg))), + _ => Ok(Value::String("world".into())), + }); + io.add_sync_method("fail", |_: Params| Err(Error::new(ErrorCode::ServerError(-34)))); + io.add_notification("notify", |params: Params| { + let (value,) = params.parse::<(u64,)>().expect("expected one u64 as param"); + assert_eq!(value, 12); + }); + + io + } + + #[derive(Clone)] + struct TestClient(TypedClient); + + impl From for TestClient { + fn from(channel: RpcChannel) -> Self { + TestClient(channel.into()) + } + } + + impl TestClient { + fn hello(&self, msg: &'static str) -> impl Future> { + self.0.call_method("hello", "String", (msg,)) + } + fn fail(&self) -> impl Future> { + self.0.call_method("fail", "()", ()) + } + fn notify(&self, value: u64) -> RpcResult<()> { + self.0.notify("notify", (value,)) + } + } + + #[test] + fn should_work() { + crate::logger::init_log(); + + // given + let server = TestServer::serve(id); + + // when + let run = async { + let client: TestClient = connect(&server.uri).await?; + let result = client.hello("http").await?; + + // then + assert_eq!("hello http", result); + Ok(()) as RpcResult<_> + }; + + tokio::runtime::Runtime::new().unwrap().block_on(run).unwrap(); + } + + #[test] + fn should_send_notification() { + crate::logger::init_log(); + + // given + 
let server = TestServer::serve(id); + + // when + let run = async { + let client: TestClient = connect(&server.uri).await.unwrap(); + client.notify(12).unwrap(); + }; + + tokio::runtime::Runtime::new().unwrap().block_on(run); + // Ensure that server has not been moved into runtime + drop(server); + } + + #[test] + fn handles_invalid_uri() { + crate::logger::init_log(); + + // given + let invalid_uri = "invalid uri"; + + // when + let fut = connect(invalid_uri); + let res: RpcResult = tokio::runtime::Runtime::new().unwrap().block_on(fut); + + // then + assert_matches!( + res.map(|_cli| unreachable!()), Err(RpcError::Other(err)) => { + assert_eq!(format!("{}", err), "invalid uri character"); + } + ); + } + + #[test] + fn handles_server_error() { + crate::logger::init_log(); + + // given + let server = TestServer::serve(id); + + // when + let run = async { + let client: TestClient = connect(&server.uri).await?; + client.fail().await + }; + let res = tokio::runtime::Runtime::new().unwrap().block_on(run); + + // then + if let Err(RpcError::JsonRpcError(err)) = res { + assert_eq!( + err, + Error { + code: ErrorCode::ServerError(-34), + message: "Server error".into(), + data: None + } + ) + } else { + panic!("Expected JsonRpcError. Received {:?}", res) + } + } + + #[test] + fn handles_connection_refused_error() { + // given + let mut server = TestServer::serve(id); + // stop server so that we get a connection refused + server.stop(); + + let run = async { + let client: TestClient = connect(&server.uri).await?; + let res = client.hello("http").await; + + if let Err(RpcError::Other(err)) = res { + if let Some(err) = err.downcast_ref::() { + assert!(err.is_connect(), "Expected Connection Error, got {:?}", err) + } else { + panic!("Expected a hyper::Error") + } + } else { + panic!("Expected JsonRpcError. Received {:?}", res) + } + + Ok(()) as RpcResult<_> + }; + + tokio::runtime::Runtime::new().unwrap().block_on(run).unwrap(); + } +} diff --git a/core-client/transports/src/transports/ipc.rs b/core-client/transports/src/transports/ipc.rs new file mode 100644 index 000000000..d50022597 --- /dev/null +++ b/core-client/transports/src/transports/ipc.rs @@ -0,0 +1,114 @@ +//! JSON-RPC IPC client implementation using Unix Domain Sockets on UNIX-likes +//! and Named Pipes on Windows. + +use crate::transports::duplex::duplex; +use crate::{RpcChannel, RpcError}; +use futures::{SinkExt, StreamExt, TryStreamExt}; +use jsonrpc_server_utils::codecs::StreamCodec; +use jsonrpc_server_utils::tokio; +use jsonrpc_server_utils::tokio_util::codec::Decoder as _; +use parity_tokio_ipc::Endpoint; +use std::path::Path; + +/// Connect to a JSON-RPC IPC server. 
+pub async fn connect, Client: From>(path: P) -> Result { + let connection = Endpoint::connect(path) + .await + .map_err(|e| RpcError::Other(Box::new(e)))?; + let (sink, stream) = StreamCodec::stream_incoming().framed(connection).split(); + let sink = sink.sink_map_err(|e| RpcError::Other(Box::new(e))); + let stream = stream.map_err(|e| log::error!("IPC stream error: {}", e)); + + let (client, sender) = duplex( + Box::pin(sink), + Box::pin( + stream + .take_while(|x| futures::future::ready(x.is_ok())) + .map(|x| x.expect("Stream is closed upon first error.")), + ), + ); + + tokio::spawn(client); + + Ok(sender.into()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::*; + use jsonrpc_core::{Error, ErrorCode, IoHandler, Params, Value}; + use jsonrpc_ipc_server::ServerBuilder; + use parity_tokio_ipc::dummy_endpoint; + use serde_json::map::Map; + + #[test] + fn should_call_one() { + let sock_path = dummy_endpoint(); + + let mut io = IoHandler::new(); + io.add_method("greeting", |params| async { + let map_obj = match params { + Params::Map(obj) => obj, + _ => return Err(Error::invalid_params("missing object")), + }; + let name = match map_obj.get("name") { + Some(val) => val.as_str().unwrap(), + None => return Err(Error::invalid_params("no name")), + }; + Ok(Value::String(format!("Hello {}!", name))) + }); + let builder = ServerBuilder::new(io); + let _server = builder.start(&sock_path).expect("Couldn't open socket"); + + let client_fut = async move { + let client: RawClient = connect(sock_path).await.unwrap(); + let mut map = Map::new(); + map.insert("name".to_string(), "Jeffry".into()); + let fut = client.call_method("greeting", Params::Map(map)); + + match fut.await { + Ok(val) => assert_eq!(&val, "Hello Jeffry!"), + Err(err) => panic!("IPC RPC call failed: {}", err), + } + }; + tokio::runtime::Runtime::new().unwrap().block_on(client_fut); + } + + #[test] + fn should_fail_without_server() { + let test_fut = async move { + match connect::<_, RawClient>(dummy_endpoint()).await { + Err(..) => {} + Ok(..) => panic!("Should not be able to connect to an IPC socket that's not open"), + } + }; + + tokio::runtime::Runtime::new().unwrap().block_on(test_fut); + } + + #[test] + fn should_handle_server_error() { + let sock_path = dummy_endpoint(); + + let mut io = IoHandler::new(); + io.add_method("greeting", |_params| async { Err(Error::invalid_params("test error")) }); + let builder = ServerBuilder::new(io); + let _server = builder.start(&sock_path).expect("Couldn't open socket"); + + let client_fut = async move { + let client: RawClient = connect(sock_path).await.unwrap(); + let mut map = Map::new(); + map.insert("name".to_string(), "Jeffry".into()); + let fut = client.call_method("greeting", Params::Map(map)); + + match fut.await { + Err(RpcError::JsonRpcError(err)) => assert_eq!(err.code, ErrorCode::InvalidParams), + Ok(_) => panic!("Expected the call to fail"), + _ => panic!("Unexpected error type"), + } + }; + + tokio::runtime::Runtime::new().unwrap().block_on(client_fut); + } +} diff --git a/core-client/transports/src/transports/local.rs b/core-client/transports/src/transports/local.rs new file mode 100644 index 000000000..da91a554b --- /dev/null +++ b/core-client/transports/src/transports/local.rs @@ -0,0 +1,209 @@ +//! Rpc client implementation for `Deref>`. 
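
The IPC tests above drive a `RawClient`, but `ipc::connect` only requires `Client: From<RpcChannel>`, so the typed-wrapper pattern from `lib.rs` works over IPC as well. A minimal sketch, assuming the `ipc` feature is enabled, a tokio runtime is available, and an IPC server exposing an `add` method is already listening at the (hypothetical) `sock_path`:

```rust
use jsonrpc_core_client::transports::ipc;
use jsonrpc_core_client::{RpcChannel, RpcError, TypedClient};

// Hypothetical typed wrapper, mirroring the `AddClient` from the lib.rs tests.
#[derive(Clone)]
struct AddClient(TypedClient);

impl From<RpcChannel> for AddClient {
    fn from(channel: RpcChannel) -> Self {
        AddClient(channel.into())
    }
}

impl AddClient {
    async fn add(&self, a: u64, b: u64) -> Result<u64, RpcError> {
        self.0.call_method("add", "u64", (a, b)).await
    }
}

fn main() {
    // Assumed socket path of an already-running IPC server that registers `add`.
    let sock_path = "/tmp/jsonrpc.ipc";
    tokio::runtime::Runtime::new().unwrap().block_on(async {
        let client: AddClient = ipc::connect(sock_path).await.expect("couldn't connect");
        let sum = client.add(3, 4).await.expect("rpc call failed");
        println!("3 + 4 = {}", sum);
    });
}
```
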
+ +use crate::{RpcChannel, RpcError, RpcResult}; +use futures::channel::mpsc; +use futures::{ + task::{Context, Poll}, + Future, Sink, SinkExt, Stream, StreamExt, +}; +use jsonrpc_core::{BoxFuture, MetaIoHandler, Metadata, Middleware}; +use jsonrpc_pubsub::Session; +use std::ops::Deref; +use std::pin::Pin; +use std::sync::Arc; + +/// Implements a rpc client for `MetaIoHandler`. +pub struct LocalRpc { + handler: THandler, + meta: TMetadata, + buffered: Buffered, + queue: (mpsc::UnboundedSender, mpsc::UnboundedReceiver), +} + +enum Buffered { + Request(BoxFuture>), + Response(String), + None, +} + +impl LocalRpc +where + TMetadata: Metadata, + TMiddleware: Middleware, + THandler: Deref>, +{ + /// Creates a new `LocalRpc` with default metadata. + pub fn new(handler: THandler) -> Self + where + TMetadata: Default, + { + Self::with_metadata(handler, Default::default()) + } + + /// Creates a new `LocalRpc` with given handler and metadata. + pub fn with_metadata(handler: THandler, meta: TMetadata) -> Self { + Self { + handler, + meta, + buffered: Buffered::None, + queue: mpsc::unbounded(), + } + } +} + +impl Stream for LocalRpc +where + TMetadata: Metadata + Unpin, + TMiddleware: Middleware + Unpin, + THandler: Deref> + Unpin, +{ + type Item = String; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.queue.1.poll_next_unpin(cx) + } +} + +impl LocalRpc +where + TMetadata: Metadata + Unpin, + TMiddleware: Middleware + Unpin, + THandler: Deref> + Unpin, +{ + fn poll_buffered(&mut self, cx: &mut Context) -> Poll> { + let response = match self.buffered { + Buffered::Request(ref mut r) => futures::ready!(r.as_mut().poll(cx)), + _ => None, + }; + if let Some(response) = response { + self.buffered = Buffered::Response(response); + } + + self.send_response().into() + } + + fn send_response(&mut self) -> Result<(), RpcError> { + if let Buffered::Response(r) = std::mem::replace(&mut self.buffered, Buffered::None) { + self.queue.0.start_send(r).map_err(|e| RpcError::Other(Box::new(e)))?; + } + Ok(()) + } +} + +impl Sink for LocalRpc +where + TMetadata: Metadata + Unpin, + TMiddleware: Middleware + Unpin, + THandler: Deref> + Unpin, +{ + type Error = RpcError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + futures::ready!(self.poll_buffered(cx))?; + futures::ready!(self.queue.0.poll_ready(cx)) + .map_err(|e| RpcError::Other(Box::new(e))) + .into() + } + + fn start_send(mut self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> { + let future = self.handler.handle_request(&item, self.meta.clone()); + self.buffered = Buffered::Request(Box::pin(future)); + Ok(()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + futures::ready!(self.poll_buffered(cx))?; + futures::ready!(self.queue.0.poll_flush_unpin(cx)) + .map_err(|e| RpcError::Other(Box::new(e))) + .into() + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + futures::ready!(self.queue.0.poll_close_unpin(cx)) + .map_err(|e| RpcError::Other(Box::new(e))) + .into() + } +} + +/// Connects to a `Deref` specifying a custom middleware implementation. 
+pub fn connect_with_metadata_and_middleware( + handler: THandler, + meta: TMetadata, +) -> (TClient, impl Future>) +where + TClient: From, + TMiddleware: Middleware + Unpin, + THandler: Deref> + Unpin, + TMetadata: Metadata + Unpin, +{ + let (sink, stream) = LocalRpc::with_metadata(handler, meta).split(); + let (rpc_client, sender) = crate::transports::duplex(Box::pin(sink), Box::pin(stream)); + let client = TClient::from(sender); + (client, rpc_client) +} + +/// Connects to a `Deref`. +pub fn connect_with_metadata( + handler: THandler, + meta: TMetadata, +) -> (TClient, impl Future>) +where + TClient: From, + TMetadata: Metadata + Unpin, + THandler: Deref> + Unpin, +{ + connect_with_metadata_and_middleware(handler, meta) +} + +/// Connects to a `Deref` specifying a custom middleware implementation. +pub fn connect_with_middleware( + handler: THandler, +) -> (TClient, impl Future>) +where + TClient: From, + TMetadata: Metadata + Default + Unpin, + TMiddleware: Middleware + Unpin, + THandler: Deref> + Unpin, +{ + connect_with_metadata_and_middleware(handler, Default::default()) +} + +/// Connects to a `Deref`. +pub fn connect(handler: THandler) -> (TClient, impl Future>) +where + TClient: From, + TMetadata: Metadata + Default + Unpin, + THandler: Deref> + Unpin, +{ + connect_with_middleware(handler) +} + +/// Metadata for LocalRpc. +pub type LocalMeta = Arc; + +/// Connects with pubsub specifying a custom middleware implementation. +pub fn connect_with_pubsub_and_middleware( + handler: THandler, +) -> (TClient, impl Future>) +where + TClient: From, + TMiddleware: Middleware + Unpin, + THandler: Deref> + Unpin, +{ + let (tx, rx) = mpsc::unbounded(); + let meta = Arc::new(Session::new(tx)); + let (sink, stream) = LocalRpc::with_metadata(handler, meta).split(); + let stream = futures::stream::select(stream, rx); + let (rpc_client, sender) = crate::transports::duplex(Box::pin(sink), Box::pin(stream)); + let client = TClient::from(sender); + (client, rpc_client) +} + +/// Connects with pubsub. +pub fn connect_with_pubsub(handler: THandler) -> (TClient, impl Future>) +where + TClient: From, + THandler: Deref> + Unpin, +{ + connect_with_pubsub_and_middleware(handler) +} diff --git a/core-client/transports/src/transports/mod.rs b/core-client/transports/src/transports/mod.rs new file mode 100644 index 000000000..9df534a0d --- /dev/null +++ b/core-client/transports/src/transports/mod.rs @@ -0,0 +1,215 @@ +//! 
Client transport implementations + +use jsonrpc_core::{Call, Error, Id, MethodCall, Notification, Params, Version}; +use jsonrpc_pubsub::SubscriptionId; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::{CallMessage, NotifyMessage, RpcError}; + +pub mod duplex; +#[cfg(feature = "http")] +pub mod http; +#[cfg(feature = "ipc")] +pub mod ipc; +pub mod local; +#[cfg(feature = "ws")] +pub mod ws; + +pub use duplex::duplex; + +/// Creates JSON-RPC requests +pub struct RequestBuilder { + id: u64, +} + +impl RequestBuilder { + /// Create a new RequestBuilder + pub fn new() -> Self { + RequestBuilder { id: 0 } + } + + fn next_id(&mut self) -> Id { + let id = self.id; + self.id = id + 1; + Id::Num(id) + } + + /// Build a single request with the next available id + fn single_request(&mut self, method: String, params: Params) -> (Id, String) { + let id = self.next_id(); + let request = jsonrpc_core::Request::Single(Call::MethodCall(MethodCall { + jsonrpc: Some(Version::V2), + method, + params, + id: id.clone(), + })); + ( + id, + serde_json::to_string(&request).expect("Request serialization is infallible; qed"), + ) + } + + fn call_request(&mut self, msg: &CallMessage) -> (Id, String) { + self.single_request(msg.method.clone(), msg.params.clone()) + } + + fn subscribe_request(&mut self, subscribe: String, subscribe_params: Params) -> (Id, String) { + self.single_request(subscribe, subscribe_params) + } + + fn unsubscribe_request(&mut self, unsubscribe: String, sid: SubscriptionId) -> (Id, String) { + self.single_request(unsubscribe, Params::Array(vec![Value::from(sid)])) + } + + fn notification(&mut self, msg: &NotifyMessage) -> String { + let request = jsonrpc_core::Request::Single(Call::Notification(Notification { + jsonrpc: Some(Version::V2), + method: msg.method.clone(), + params: msg.params.clone(), + })); + serde_json::to_string(&request).expect("Request serialization is infallible; qed") + } +} + +/// Parse raw string into a single JSON value, together with the request Id. +/// +/// This method will attempt to parse a JSON-RPC response object (either `Failure` or `Success`) +/// and a `Notification` (for Subscriptions). +/// Note that if you have more specific expectations about the returned type and don't want +/// to handle all of them it might be best to deserialize on your own. +pub fn parse_response( + response: &str, +) -> Result<(Id, Result, Option, Option), RpcError> { + jsonrpc_core::serde_from_str::(response) + .map_err(|e| RpcError::ParseError(e.to_string(), Box::new(e))) + .map(|response| { + let id = response.id().unwrap_or(Id::Null); + let sid = response.subscription_id(); + let method = response.method(); + let value: Result = response.into(); + let result = value.map_err(RpcError::JsonRpcError); + (id, result, method, sid) + }) +} + +/// A type representing all possible values sent from the server to the client. +#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +#[serde(untagged)] +pub enum ClientResponse { + /// A regular JSON-RPC request output (single response). + Output(jsonrpc_core::Output), + /// A notification. + Notification(jsonrpc_core::Notification), +} + +impl ClientResponse { + /// Get the id of the response (if any). + pub fn id(&self) -> Option { + match *self { + ClientResponse::Output(ref output) => Some(output.id().clone()), + ClientResponse::Notification(_) => None, + } + } + + /// Get the method name if the output is a notification. 
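The `RequestBuilder` above only assembles `jsonrpc_core` request types and serializes them. A standalone sketch of what `single_request` produces for the first id it hands out:

```rust
use jsonrpc_core::{Call, Id, MethodCall, Params, Request, Version};

fn main() {
    // Mirrors `RequestBuilder::single_request`: ids are assigned sequentially,
    // starting at 0.
    let request = Request::Single(Call::MethodCall(MethodCall {
        jsonrpc: Some(Version::V2),
        method: "say_hello".to_owned(),
        params: Params::Array(vec![42.into(), 23.into()]),
        id: Id::Num(0),
    }));
    let raw = serde_json::to_string(&request).expect("request serialization is infallible");
    assert_eq!(
        raw,
        r#"{"jsonrpc":"2.0","method":"say_hello","params":[42,23],"id":0}"#
    );
}
```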
+ pub fn method(&self) -> Option { + match *self { + ClientResponse::Notification(ref n) => Some(n.method.to_owned()), + ClientResponse::Output(_) => None, + } + } + + /// Parses the response into a subscription id. + pub fn subscription_id(&self) -> Option { + match *self { + ClientResponse::Notification(ref n) => match &n.params { + jsonrpc_core::Params::Map(map) => match map.get("subscription") { + Some(value) => SubscriptionId::parse_value(value), + None => None, + }, + _ => None, + }, + _ => None, + } + } +} + +impl From for Result { + fn from(res: ClientResponse) -> Self { + match res { + ClientResponse::Output(output) => output.into(), + ClientResponse::Notification(n) => match &n.params { + Params::Map(map) => { + let subscription = map.get("subscription"); + let result = map.get("result"); + let error = map.get("error"); + + match (subscription, result, error) { + (Some(_), Some(result), _) => Ok(result.to_owned()), + (Some(_), _, Some(error)) => { + let error = serde_json::from_value::(error.to_owned()) + .ok() + .unwrap_or_else(Error::parse_error); + Err(error) + } + _ => Ok(n.params.into()), + } + } + _ => Ok(n.params.into()), + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use jsonrpc_core::{Failure, Notification, Output, Params, Success, Value, Version}; + + #[test] + fn notification_deserialize() { + let dsr = r#"{"jsonrpc":"2.0","method":"hello","params":[10]}"#; + + let deserialized: ClientResponse = jsonrpc_core::serde_from_str(dsr).unwrap(); + assert_eq!( + deserialized, + ClientResponse::Notification(Notification { + jsonrpc: Some(Version::V2), + method: "hello".into(), + params: Params::Array(vec![Value::from(10)]), + }) + ); + } + + #[test] + fn success_deserialize() { + let dsr = r#"{"jsonrpc":"2.0","result":1,"id":1}"#; + + let deserialized: ClientResponse = jsonrpc_core::serde_from_str(dsr).unwrap(); + assert_eq!( + deserialized, + ClientResponse::Output(Output::Success(Success { + jsonrpc: Some(Version::V2), + id: Id::Num(1), + result: 1.into(), + })) + ); + } + + #[test] + fn failure_output_deserialize() { + let dfo = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":1}"#; + + let deserialized: ClientResponse = jsonrpc_core::serde_from_str(dfo).unwrap(); + assert_eq!( + deserialized, + ClientResponse::Output(Output::Failure(Failure { + jsonrpc: Some(Version::V2), + error: Error::parse_error(), + id: Id::Num(1) + })) + ); + } +} diff --git a/core-client/transports/src/transports/ws.rs b/core-client/transports/src/transports/ws.rs new file mode 100644 index 000000000..976d48748 --- /dev/null +++ b/core-client/transports/src/transports/ws.rs @@ -0,0 +1,200 @@ +//! JSON-RPC websocket client implementation. +use std::collections::VecDeque; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::{RpcChannel, RpcError}; +use websocket::{ClientBuilder, OwnedMessage}; + +/// Connect to a JSON-RPC websocket server. +/// +/// Uses an unbounded channel to queue outgoing rpc messages. +/// +/// Returns `Err` if the `url` is invalid. +pub fn try_connect(url: &str) -> Result>, RpcError> +where + T: From, +{ + let client_builder = ClientBuilder::new(url).map_err(|e| RpcError::Other(Box::new(e)))?; + Ok(do_connect(client_builder)) +} + +/// Connect to a JSON-RPC websocket server. +/// +/// Uses an unbounded channel to queue outgoing rpc messages. 
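The `From<ClientResponse>` conversion earlier in this hunk keys on the `subscription` and `result` (or `error`) entries of a notification's `params`. A small sketch of the raw shape it expects, using only `serde_json` (the method name and values are made up):

```rust
use serde_json::{json, Value};

fn main() {
    let raw = r#"{"jsonrpc":"2.0","method":"hello_sub","params":{"subscription":5,"result":"hi"}}"#;
    let msg: Value = serde_json::from_str(raw).expect("valid JSON");

    // `subscription_id()` reads `params.subscription` ...
    assert_eq!(msg["params"]["subscription"], json!(5));
    // ... and the conversion yields `Ok(params.result)` when it is present.
    assert_eq!(msg["params"]["result"], json!("hi"));
}
```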
+pub fn connect(url: &url::Url) -> impl Future> +where + T: From, +{ + let client_builder = ClientBuilder::from_url(url); + do_connect(client_builder) +} + +fn do_connect(client_builder: ClientBuilder) -> impl Future> +where + T: From, +{ + use futures::compat::{Future01CompatExt, Sink01CompatExt, Stream01CompatExt}; + use futures::{SinkExt, StreamExt, TryFutureExt, TryStreamExt}; + use websocket::futures::Stream; + + client_builder + .async_connect(None) + .compat() + .map_err(|error| RpcError::Other(Box::new(error))) + .map_ok(|(client, _)| { + let (sink, stream) = client.split(); + + let sink = sink.sink_compat().sink_map_err(|e| RpcError::Other(Box::new(e))); + let stream = stream.compat().map_err(|e| RpcError::Other(Box::new(e))); + let (sink, stream) = WebsocketClient::new(sink, stream).split(); + let (sink, stream) = ( + Box::pin(sink), + Box::pin( + stream + .take_while(|x| futures::future::ready(x.is_ok())) + .map(|x| x.expect("Stream is closed upon first error.")), + ), + ); + let (rpc_client, sender) = super::duplex(sink, stream); + let rpc_client = rpc_client.map_err(|error| log::error!("{:?}", error)); + tokio::spawn(rpc_client); + + sender.into() + }) +} + +struct WebsocketClient { + sink: TSink, + stream: TStream, + queue: VecDeque, +} + +impl WebsocketClient +where + TSink: futures::Sink + Unpin, + TStream: futures::Stream> + Unpin, + TError: std::error::Error + Send + 'static, +{ + pub fn new(sink: TSink, stream: TStream) -> Self { + Self { + sink, + stream, + queue: VecDeque::new(), + } + } + + // Drains the internal buffer and attempts to forward as much of the items + // as possible to the underlying sink + fn try_empty_buffer(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + + match Pin::new(&mut this.sink).poll_ready(cx) { + Poll::Ready(value) => value?, + Poll::Pending => return Poll::Pending, + } + + while let Some(item) = this.queue.pop_front() { + Pin::new(&mut this.sink).start_send(item)?; + + if !this.queue.is_empty() { + match Pin::new(&mut this.sink).poll_ready(cx) { + Poll::Ready(value) => value?, + Poll::Pending => return Poll::Pending, + } + } + } + + Poll::Ready(Ok(())) + } +} + +// This mostly forwards to the underlying sink but also adds an unbounded queue +// for when the underlying sink is incapable of receiving more items. +// See https://docs.rs/futures-util/0.3.8/futures_util/sink/struct.Buffer.html +// for the variant with a fixed-size buffer. 
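The comment above points to `futures_util`'s `Buffer` as the fixed-capacity counterpart of this unbounded queue. A small sketch of that combinator, purely for illustration and unrelated to the websocket code itself:

```rust
use futures::channel::mpsc;
use futures::{executor::block_on, SinkExt, StreamExt};

fn main() {
    block_on(async {
        let (tx, mut rx) = mpsc::channel::<String>(4);
        // `SinkExt::buffer` keeps at most 8 pending items on the sending side,
        // whereas `WebsocketClient` grows its `VecDeque` without bound.
        let mut tx = tx.buffer(8);
        tx.send("ping".to_owned()).await.expect("receiver alive");
        assert_eq!(rx.next().await, Some("ping".to_owned()));
    });
}
```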
+impl futures::Sink for WebsocketClient +where + TSink: futures::Sink + Unpin, + TStream: futures::Stream> + Unpin, +{ + type Error = RpcError; + + fn start_send(mut self: Pin<&mut Self>, request: String) -> Result<(), Self::Error> { + let request = OwnedMessage::Text(request); + + if self.queue.is_empty() { + let this = Pin::into_inner(self); + Pin::new(&mut this.sink).start_send(request) + } else { + self.queue.push_back(request); + Ok(()) + } + } + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + + if this.queue.is_empty() { + return Pin::new(&mut this.sink).poll_ready(cx); + } + + let _ = Pin::new(this).try_empty_buffer(cx)?; + + Poll::Ready(Ok(())) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + + match Pin::new(&mut *this).try_empty_buffer(cx) { + Poll::Ready(value) => value?, + Poll::Pending => return Poll::Pending, + } + debug_assert!(this.queue.is_empty()); + + Pin::new(&mut this.sink).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + + match Pin::new(&mut *this).try_empty_buffer(cx) { + Poll::Ready(value) => value?, + Poll::Pending => return Poll::Pending, + } + debug_assert!(this.queue.is_empty()); + + Pin::new(&mut this.sink).poll_close(cx) + } +} + +impl futures::Stream for WebsocketClient +where + TSink: futures::Sink + Unpin, + TStream: futures::Stream> + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + loop { + match Pin::new(&mut this.stream).poll_next(cx) { + Poll::Ready(Some(Ok(message))) => match message { + OwnedMessage::Text(data) => return Poll::Ready(Some(Ok(data))), + OwnedMessage::Binary(data) => log::info!("server sent binary data {:?}", data), + OwnedMessage::Ping(p) => this.queue.push_front(OwnedMessage::Pong(p)), + OwnedMessage::Pong(_) => {} + OwnedMessage::Close(c) => this.queue.push_front(OwnedMessage::Close(c)), + }, + Poll::Ready(None) => { + // TODO try to reconnect (#411). + return Poll::Ready(None); + } + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(RpcError::Other(Box::new(error))))), + } + } + } +} diff --git a/core/Cargo.toml b/core/Cargo.toml index e9471ca02..b1ba659e9 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -1,13 +1,14 @@ [package] +authors = ["Parity Technologies "] description = "Transport agnostic rust implementation of JSON-RPC 2.0 Specification." +documentation = "https://docs.rs/jsonrpc-core/" +edition = "2018" homepage = "https://github.com/paritytech/jsonrpc" -repository = "https://github.com/paritytech/jsonrpc" +keywords = ["jsonrpc", "json-rpc", "json", "rpc", "serde"] license = "MIT" name = "jsonrpc-core" -version = "9.0.0" -authors = ["debris "] -keywords = ["jsonrpc", "json-rpc", "json", "rpc", "serde"] -documentation = "https://paritytech.github.io/jsonrpc/jsonrpc_core/index.html" +repository = "https://github.com/paritytech/jsonrpc" +version = "18.0.0" categories = [ "asynchronous", @@ -19,10 +20,19 @@ categories = [ [dependencies] log = "0.4" -futures = "~0.1.6" +# FIXME: Currently a lot of jsonrpc-* crates depend on entire `futures` being +# re-exported but it's not strictly required for this crate. 
Either adapt the +# remaining crates or settle for this re-export to be a single, common dependency +futures = { version = "0.3", optional = true } +futures-util = { version = "0.3", default-features = false, features = ["std"] } +futures-executor = { version = "0.3", optional = true } serde = "1.0" serde_json = "1.0" serde_derive = "1.0" +[features] +default = ["futures-executor", "futures"] +arbitrary_precision = ["serde_json/arbitrary_precision"] + [badges] travis-ci = { repository = "paritytech/jsonrpc", branch = "master"} diff --git a/core/README.md b/core/README.md deleted file mode 100644 index 7ec5b9d46..000000000 --- a/core/README.md +++ /dev/null @@ -1,63 +0,0 @@ -# jsonrpc-core -Transport agnostic rust implementation of JSON-RPC 2.0 Specification. - -[Documentation](http://paritytech.github.io/jsonrpc/jsonrpc_core/index.html) - -- [x] - server side -- [x] - client side - -## Example - -`Cargo.toml` - - -``` -[dependencies] -jsonrpc-core = "4.0" -``` - -`main.rs` - -```rust -extern crate jsonrpc_core; - -use jsonrpc_core::*; - -fn main() { - let mut io = IoHandler::default(); - io.add_method("say_hello", |_params: Params| { - Ok(Value::String("hello".into())) - }); - - let request = r#"{"jsonrpc": "2.0", "method": "say_hello", "params": [42, 23], "id": 1}"#; - let response = r#"{"jsonrpc":"2.0","result":"hello","id":1}"#; - - assert_eq!(io.handle_request_sync(request), Some(response.to_owned())); -} -``` - -### Asynchronous responses - -`main.rs` - -```rust -extern crate jsonrpc_core; - -use jsonrpc_core::*; -use jsonrpc_core::futures::Future; - -fn main() { - let io = IoHandler::new(); - io.add_async_method("say_hello", |_params: Params| { - futures::finished(Value::String("hello".into())) - }); - - let request = r#"{"jsonrpc": "2.0", "method": "say_hello", "params": [42, 23], "id": 1}"#; - let response = r#"{"jsonrpc":"2.0","result":"hello","id":1}"#; - - assert_eq!(io.handle_request(request).wait().unwrap(), Some(response.to_owned())); -} -``` - -### Publish-Subscribe -See examples directory. 
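With the README above removed, its asynchronous example maps onto the new API roughly as follows (a sketch using the `futures-executor` re-export; the updated `core/examples/async.rs` just below is the canonical version):

```rust
use jsonrpc_core::{IoHandler, Params, Value};

fn main() {
    jsonrpc_core::futures_executor::block_on(async {
        let mut io = IoHandler::new();
        io.add_method("say_hello", |_params: Params| async {
            Ok(Value::String("hello".into()))
        });

        let request = r#"{"jsonrpc": "2.0", "method": "say_hello", "params": [42, 23], "id": 1}"#;
        let response = r#"{"jsonrpc":"2.0","result":"hello","id":1}"#;
        assert_eq!(io.handle_request(request).await, Some(response.to_owned()));
    });
}
```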
diff --git a/core/examples/async.rs b/core/examples/async.rs index 29d0b4f9c..b6397b240 100644 --- a/core/examples/async.rs +++ b/core/examples/async.rs @@ -1,17 +1,16 @@ -extern crate jsonrpc_core; - use jsonrpc_core::*; -use jsonrpc_core::futures::Future; fn main() { - let mut io = IoHandler::new(); + futures_executor::block_on(async { + let mut io = IoHandler::new(); - io.add_method("say_hello", |_: Params| { - futures::finished(Value::String("Hello World!".to_owned())) - }); + io.add_method("say_hello", |_: Params| async { + Ok(Value::String("Hello World!".to_owned())) + }); - let request = r#"{"jsonrpc": "2.0", "method": "say_hello", "params": [42, 23], "id": 1}"#; - let response = r#"{"jsonrpc":"2.0","result":"hello","id":1}"#; + let request = r#"{"jsonrpc": "2.0", "method": "say_hello", "params": [42, 23], "id": 1}"#; + let response = r#"{"jsonrpc":"2.0","result":"hello","id":1}"#; - assert_eq!(io.handle_request(request).wait().unwrap(), Some(response.to_owned())); + assert_eq!(io.handle_request(request).await, Some(response.to_owned())); + }); } diff --git a/core/examples/basic.rs b/core/examples/basic.rs index 22300d575..f81c24b12 100644 --- a/core/examples/basic.rs +++ b/core/examples/basic.rs @@ -1,13 +1,9 @@ -extern crate jsonrpc_core; - use jsonrpc_core::*; fn main() { let mut io = IoHandler::new(); - io.add_method("say_hello", |_: Params| { - Ok(Value::String("Hello World!".to_owned())) - }); + io.add_sync_method("say_hello", |_: Params| Ok(Value::String("Hello World!".to_owned()))); let request = r#"{"jsonrpc": "2.0", "method": "say_hello", "params": [42, 23], "id": 1}"#; let response = r#"{"jsonrpc":"2.0","result":"hello","id":1}"#; diff --git a/core/examples/meta.rs b/core/examples/meta.rs index e1d8db59a..abc3c6b55 100644 --- a/core/examples/meta.rs +++ b/core/examples/meta.rs @@ -1,7 +1,4 @@ -extern crate jsonrpc_core; - use jsonrpc_core::*; -use jsonrpc_core::futures::Future; #[derive(Clone, Default)] struct Meta(usize); @@ -10,7 +7,7 @@ impl Metadata for Meta {} pub fn main() { let mut io = MetaIoHandler::default(); - io.add_method_with_meta("say_hello", |_params: Params, meta: Meta| { + io.add_method_with_meta("say_hello", |_params: Params, meta: Meta| async move { Ok(Value::String(format!("Hello World: {}", meta.0))) }); @@ -19,7 +16,7 @@ pub fn main() { let headers = 5; assert_eq!( - io.handle_request(request, Meta(headers)).wait().unwrap(), + io.handle_request_sync(request, Meta(headers)), Some(response.to_owned()) ); } diff --git a/core/examples/middlewares.rs b/core/examples/middlewares.rs index 60a09fde4..6d25352af 100644 --- a/core/examples/middlewares.rs +++ b/core/examples/middlewares.rs @@ -1,10 +1,8 @@ -extern crate jsonrpc_core; - -use std::time::Instant; -use std::sync::atomic::{self, AtomicUsize}; +use jsonrpc_core::futures_util::{future::Either, FutureExt}; use jsonrpc_core::*; -use jsonrpc_core::futures::Future; -use jsonrpc_core::futures::future::Either; +use std::future::Future; +use std::sync::atomic::{self, AtomicUsize}; +use std::time::Instant; #[derive(Clone, Debug)] struct Meta(usize); @@ -16,15 +14,16 @@ impl Middleware for MyMiddleware { type Future = FutureResponse; type CallFuture = middleware::NoopCallFuture; - fn on_request(&self, request: Request, meta: Meta, next: F) -> Either where + fn on_request(&self, request: Request, meta: Meta, next: F) -> Either + where F: FnOnce(Request, Meta) -> X + Send, - X: Future, Error=()> + Send + 'static, + X: Future> + Send + 'static, { let start = Instant::now(); let request_number = 
self.0.fetch_add(1, atomic::Ordering::SeqCst); println!("Processing request {}: {:?}, {:?}", request_number, request, meta); - Either::A(Box::new(next(request, meta).map(move |res| { + Either::Left(Box::pin(next(request, meta).map(move |res| { println!("Processing took: {:?}", start.elapsed()); res }))) @@ -34,7 +33,7 @@ impl Middleware for MyMiddleware { pub fn main() { let mut io = MetaIoHandler::with_middleware(MyMiddleware::default()); - io.add_method_with_meta("say_hello", |_params: Params, meta: Meta| { + io.add_method_with_meta("say_hello", |_params: Params, meta: Meta| async move { Ok(Value::String(format!("Hello World: {}", meta.0))) }); @@ -43,7 +42,7 @@ pub fn main() { let headers = 5; assert_eq!( - io.handle_request(request, Meta(headers)).wait().unwrap(), + io.handle_request_sync(request, Meta(headers)), Some(response.to_owned()) ); } diff --git a/core/examples/params.rs b/core/examples/params.rs new file mode 100644 index 000000000..353ab770a --- /dev/null +++ b/core/examples/params.rs @@ -0,0 +1,21 @@ +use jsonrpc_core::*; +use serde_derive::Deserialize; + +#[derive(Deserialize)] +struct HelloParams { + name: String, +} + +fn main() { + let mut io = IoHandler::new(); + + io.add_method("say_hello", |params: Params| async move { + let parsed: HelloParams = params.parse().unwrap(); + Ok(Value::String(format!("hello, {}", parsed.name))) + }); + + let request = r#"{"jsonrpc": "2.0", "method": "say_hello", "params": { "name": "world" }, "id": 1}"#; + let response = r#"{"jsonrpc":"2.0","result":"hello, world","id":1}"#; + + assert_eq!(io.handle_request_sync(request), Some(response.to_owned())); +} diff --git a/core/src/calls.rs b/core/src/calls.rs index 997f7276f..dfc8dcb16 100644 --- a/core/src/calls.rs +++ b/core/src/calls.rs @@ -1,8 +1,8 @@ +use crate::types::{Error, Params, Value}; +use crate::BoxFuture; use std::fmt; +use std::future::Future; use std::sync::Arc; -use types::{Params, Value, Error}; -use futures::{Future, IntoFuture}; -use BoxFuture; /// Metadata trait pub trait Metadata: Clone + Send + 'static {} @@ -11,10 +11,34 @@ impl Metadata for Option {} impl Metadata for Box {} impl Metadata for Arc {} +/// A future-conversion trait. +pub trait WrapFuture { + /// Convert itself into a boxed future. + fn into_future(self) -> BoxFuture>; +} + +impl WrapFuture for Result { + fn into_future(self) -> BoxFuture> { + Box::pin(async { self }) + } +} + +impl WrapFuture for BoxFuture> { + fn into_future(self) -> BoxFuture> { + self + } +} + +/// A synchronous or asynchronous method. +pub trait RpcMethodSync: Send + Sync + 'static { + /// Call method + fn call(&self, params: Params) -> BoxFuture>; +} + /// Asynchronous Method pub trait RpcMethodSimple: Send + Sync + 'static { /// Output future - type Out: Future + Send; + type Out: Future> + Send; /// Call method fn call(&self, params: Params) -> Self::Out; } @@ -22,7 +46,7 @@ pub trait RpcMethodSimple: Send + Sync + 'static { /// Asynchronous Method with Metadata pub trait RpcMethod: Send + Sync + 'static { /// Call method - fn call(&self, params: Params, meta: T) -> BoxFuture; + fn call(&self, params: Params, meta: T) -> BoxFuture>; } /// Notification @@ -41,9 +65,9 @@ pub trait RpcNotification: Send + Sync + 'static { #[derive(Clone)] pub enum RemoteProcedure { /// A method call - Method(Arc>), + Method(Arc>), /// A notification - Notification(Arc>), + Notification(Arc>), /// An alias to other method, Alias(String), } @@ -54,23 +78,34 @@ impl fmt::Debug for RemoteProcedure { match *self { Method(..) 
=> write!(fmt, ""), Notification(..) => write!(fmt, ""), - Alias(ref alias) => write!(fmt, "alias => {:?}", alias) + Alias(ref alias) => write!(fmt, "alias => {:?}", alias), } } } -impl RpcMethodSimple for F where - F: Fn(Params) -> I, - X: Future, - I: IntoFuture, +impl RpcMethodSimple for F +where + F: Fn(Params) -> X, + X: Future>, { type Out = X; fn call(&self, params: Params) -> Self::Out { + self(params) + } +} + +impl RpcMethodSync for F +where + F: Fn(Params) -> X, + X: WrapFuture, +{ + fn call(&self, params: Params) -> BoxFuture> { self(params).into_future() } } -impl RpcNotificationSimple for F where +impl RpcNotificationSimple for F +where F: Fn(Params), { fn execute(&self, params: Params) { @@ -78,18 +113,19 @@ impl RpcNotificationSimple for F where } } -impl RpcMethod for F where +impl RpcMethod for F +where T: Metadata, - F: Fn(Params, T) -> I, - I: IntoFuture, - X: Future, + F: Fn(Params, T) -> X, + X: Future>, { - fn call(&self, params: Params, meta: T) -> BoxFuture { - Box::new(self(params, meta).into_future()) + fn call(&self, params: Params, meta: T) -> BoxFuture> { + Box::pin(self(params, meta)) } } -impl RpcNotification for F where +impl RpcNotification for F +where T: Metadata, F: Fn(Params, T), { diff --git a/core/src/delegates.rs b/core/src/delegates.rs new file mode 100644 index 000000000..15b60e65a --- /dev/null +++ b/core/src/delegates.rs @@ -0,0 +1,198 @@ +//! Delegate rpc calls + +use std::collections::HashMap; +use std::future::Future; +use std::sync::Arc; + +use crate::calls::{Metadata, RemoteProcedure, RpcMethod, RpcNotification}; +use crate::types::{Error, Params, Value}; +use crate::BoxFuture; + +struct DelegateAsyncMethod { + delegate: Arc, + closure: F, +} + +impl RpcMethod for DelegateAsyncMethod +where + M: Metadata, + F: Fn(&T, Params) -> I, + I: Future> + Send + 'static, + T: Send + Sync + 'static, + F: Send + Sync + 'static, +{ + fn call(&self, params: Params, _meta: M) -> BoxFuture> { + let closure = &self.closure; + Box::pin(closure(&self.delegate, params)) + } +} + +struct DelegateMethodWithMeta { + delegate: Arc, + closure: F, +} + +impl RpcMethod for DelegateMethodWithMeta +where + M: Metadata, + F: Fn(&T, Params, M) -> I, + I: Future> + Send + 'static, + T: Send + Sync + 'static, + F: Send + Sync + 'static, +{ + fn call(&self, params: Params, meta: M) -> BoxFuture> { + let closure = &self.closure; + Box::pin(closure(&self.delegate, params, meta)) + } +} + +struct DelegateNotification { + delegate: Arc, + closure: F, +} + +impl RpcNotification for DelegateNotification +where + M: Metadata, + F: Fn(&T, Params) + 'static, + F: Send + Sync + 'static, + T: Send + Sync + 'static, +{ + fn execute(&self, params: Params, _meta: M) { + let closure = &self.closure; + closure(&self.delegate, params) + } +} + +struct DelegateNotificationWithMeta { + delegate: Arc, + closure: F, +} + +impl RpcNotification for DelegateNotificationWithMeta +where + M: Metadata, + F: Fn(&T, Params, M) + 'static, + F: Send + Sync + 'static, + T: Send + Sync + 'static, +{ + fn execute(&self, params: Params, meta: M) { + let closure = &self.closure; + closure(&self.delegate, params, meta) + } +} + +/// A set of RPC methods and notifications tied to single `delegate` struct. 
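A usage sketch for the `IoDelegate` introduced here (its `add_*` methods follow below). The `Calculator` type and its `add` method are invented for illustration; the registration pattern mirrors the `test_extending_by_multiple_delegates` test further down:

```rust
use jsonrpc_core::{BoxFuture, IoDelegate, IoHandler, Params, Result, Value};
use std::sync::Arc;

struct Calculator;
impl Calculator {
    fn add(&self, params: Params) -> BoxFuture<Result<Value>> {
        Box::pin(async move {
            let (a, b): (u64, u64) = params.parse()?;
            Ok(Value::from(a + b))
        })
    }
}

fn main() {
    let mut delegate = IoDelegate::new(Arc::new(Calculator));
    delegate.add_method("add", Calculator::add);

    // `IoDelegate` implements `IntoIterator`, so it can be folded into a handler.
    let mut io = IoHandler::new();
    io.extend_with(delegate);

    let request = r#"{"jsonrpc":"2.0","method":"add","params":[1,2],"id":1}"#;
    assert_eq!(
        io.handle_request_sync(request),
        Some(r#"{"jsonrpc":"2.0","result":3,"id":1}"#.to_owned())
    );
}
```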
+pub struct IoDelegate +where + T: Send + Sync + 'static, + M: Metadata, +{ + delegate: Arc, + methods: HashMap>, +} + +impl IoDelegate +where + T: Send + Sync + 'static, + M: Metadata, +{ + /// Creates new `IoDelegate` + pub fn new(delegate: Arc) -> Self { + IoDelegate { + delegate, + methods: HashMap::new(), + } + } + + /// Adds an alias to existing method. + /// NOTE: Aliases are not transitive, i.e. you cannot create alias to an alias. + pub fn add_alias(&mut self, from: &str, to: &str) { + self.methods.insert(from.into(), RemoteProcedure::Alias(to.into())); + } + + /// Adds async method to the delegate. + pub fn add_method(&mut self, name: &str, method: F) + where + F: Fn(&T, Params) -> I, + I: Future> + Send + 'static, + F: Send + Sync + 'static, + { + self.methods.insert( + name.into(), + RemoteProcedure::Method(Arc::new(DelegateAsyncMethod { + delegate: self.delegate.clone(), + closure: method, + })), + ); + } + + /// Adds async method with metadata to the delegate. + pub fn add_method_with_meta(&mut self, name: &str, method: F) + where + F: Fn(&T, Params, M) -> I, + I: Future> + Send + 'static, + F: Send + Sync + 'static, + { + self.methods.insert( + name.into(), + RemoteProcedure::Method(Arc::new(DelegateMethodWithMeta { + delegate: self.delegate.clone(), + closure: method, + })), + ); + } + + /// Adds notification to the delegate. + pub fn add_notification(&mut self, name: &str, notification: F) + where + F: Fn(&T, Params), + F: Send + Sync + 'static, + { + self.methods.insert( + name.into(), + RemoteProcedure::Notification(Arc::new(DelegateNotification { + delegate: self.delegate.clone(), + closure: notification, + })), + ); + } + + /// Adds notification with metadata to the delegate. + pub fn add_notification_with_meta(&mut self, name: &str, notification: F) + where + F: Fn(&T, Params, M), + F: Send + Sync + 'static, + { + self.methods.insert( + name.into(), + RemoteProcedure::Notification(Arc::new(DelegateNotificationWithMeta { + delegate: self.delegate.clone(), + closure: notification, + })), + ); + } +} + +impl crate::io::IoHandlerExtension for IoDelegate +where + T: Send + Sync + 'static, + M: Metadata, +{ + fn augment>(self, handler: &mut crate::MetaIoHandler) { + handler.extend_with(self.methods) + } +} + +impl IntoIterator for IoDelegate +where + T: Send + Sync + 'static, + M: Metadata, +{ + type Item = (String, RemoteProcedure); + type IntoIter = std::collections::hash_map::IntoIter>; + + fn into_iter(self) -> Self::IntoIter { + self.methods.into_iter() + } +} diff --git a/core/src/io.rs b/core/src/io.rs index 43d8cc4da..c67ff2620 100644 --- a/core/src/io.rs +++ b/core/src/io.rs @@ -1,48 +1,42 @@ -use std::sync::Arc; -use std::collections::HashMap; +use std::collections::{ + hash_map::{IntoIter, Iter}, + HashMap, +}; +use std::future::Future; use std::ops::{Deref, DerefMut}; +use std::pin::Pin; +use std::sync::Arc; -use serde_json; -use futures::{self, future, Future}; +use futures_util::{self, future, FutureExt}; -use calls::{RemoteProcedure, Metadata, RpcMethodSimple, RpcMethod, RpcNotificationSimple, RpcNotification}; -use middleware::{self, Middleware}; -use types::{Error, ErrorCode, Version}; -use types::{Request, Response, Call, Output}; +use crate::calls::{ + Metadata, RemoteProcedure, RpcMethod, RpcMethodSimple, RpcMethodSync, RpcNotification, RpcNotificationSimple, +}; +use crate::middleware::{self, Middleware}; +use crate::types::{Call, Output, Request, Response}; +use crate::types::{Error, ErrorCode, Version}; /// A type representing middleware or RPC 
response before serialization. -pub type FutureResponse = Box, Error=()> + Send>; +pub type FutureResponse = Pin> + Send>>; /// A type representing middleware or RPC call output. -pub type FutureOutput = Box, Error=()> + Send>; +pub type FutureOutput = Pin> + Send>>; /// A type representing future string response. pub type FutureResult = future::Map< - future::Either, ()>, FutureRpcResult>, + future::Either>, FutureRpcResult>, fn(Option) -> Option, >; /// A type representing a result of a single method call. -pub type FutureRpcOutput = future::Either< - F, - future::Either< - FutureOutput, - future::FutureResult, ()>, - >, ->; +pub type FutureRpcOutput = future::Either>>>; /// A type representing an optional `Response` for RPC `Request`. pub type FutureRpcResult = future::Either< F, future::Either< - future::Map< - FutureRpcOutput, - fn(Option) -> Option, - >, - future::Map< - future::JoinAll>>, - fn(Vec>) -> Option, - >, + future::Map, fn(Option) -> Option>, + future::Map>, fn(Vec>) -> Option>, >, >; @@ -64,17 +58,15 @@ impl Default for Compatibility { } impl Compatibility { - fn is_version_valid(&self, version: Option) -> bool { - match (*self, version) { - (Compatibility::V1, None) | - (Compatibility::V2, Some(Version::V2)) | - (Compatibility::Both, _) => true, - _ => false, - } + fn is_version_valid(self, version: Option) -> bool { + matches!( + (self, version), + (Compatibility::V1, None) | (Compatibility::V2, Some(Version::V2)) | (Compatibility::Both, _) + ) } - fn default_version(&self) -> Option { - match *self { + fn default_version(self) -> Option { + match self { Compatibility::V1 => None, Compatibility::V2 | Compatibility::Both => Some(Version::V2), } @@ -84,7 +76,7 @@ impl Compatibility { /// Request handler /// /// By default compatible only with jsonrpc v2 -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct MetaIoHandler = middleware::Noop> { middleware: S, compatibility: Compatibility, @@ -97,24 +89,41 @@ impl Default for MetaIoHandler { } } +impl> IntoIterator for MetaIoHandler { + type Item = (String, RemoteProcedure); + type IntoIter = IntoIter>; + + fn into_iter(self) -> Self::IntoIter { + self.methods.into_iter() + } +} + +impl<'a, T: Metadata, S: Middleware> IntoIterator for &'a MetaIoHandler { + type Item = (&'a String, &'a RemoteProcedure); + type IntoIter = Iter<'a, String, RemoteProcedure>; + + fn into_iter(self) -> Self::IntoIter { + self.methods.iter() + } +} + impl MetaIoHandler { /// Creates new `MetaIoHandler` compatible with specified protocol version. pub fn with_compatibility(compatibility: Compatibility) -> Self { MetaIoHandler { - compatibility: compatibility, + compatibility, middleware: Default::default(), methods: Default::default(), } } } - impl> MetaIoHandler { /// Creates new `MetaIoHandler` pub fn new(compatibility: Compatibility, middleware: S) -> Self { MetaIoHandler { - compatibility: compatibility, - middleware: middleware, + compatibility, + middleware, methods: Default::default(), } } @@ -123,86 +132,93 @@ impl> MetaIoHandler { pub fn with_middleware(middleware: S) -> Self { MetaIoHandler { compatibility: Default::default(), - middleware: middleware, + middleware, methods: Default::default(), } } /// Adds an alias to a method. pub fn add_alias(&mut self, alias: &str, other: &str) { - self.methods.insert( - alias.into(), - RemoteProcedure::Alias(other.into()), - ); + self.methods.insert(alias.into(), RemoteProcedure::Alias(other.into())); + } + + /// Adds new supported synchronous method. + /// + /// A backward-compatible wrapper. 
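The `Compatibility` handling above also covers JSON-RPC 1.0-style requests without a `"jsonrpc"` field. A sketch mirroring the `test_io_handler_1dot0` test further down:

```rust
use jsonrpc_core::{Compatibility, IoHandler, Params, Value};

fn main() {
    // `Compatibility::Both` accepts requests with or without `"jsonrpc":"2.0"`,
    // and the response omits the version field when the request did.
    let mut io = IoHandler::with_compatibility(Compatibility::Both);
    io.add_sync_method("say_hello", |_: Params| Ok(Value::String("hello".into())));

    let request = r#"{"method": "say_hello", "params": [42, 23], "id": 1}"#;
    assert_eq!(
        io.handle_request_sync(request),
        Some(r#"{"result":"hello","id":1}"#.to_owned())
    );
}
```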
+ pub fn add_sync_method(&mut self, name: &str, method: F) + where + F: RpcMethodSync, + { + self.add_method(name, move |params| method.call(params)) } - /// Adds new supported asynchronous method - pub fn add_method(&mut self, name: &str, method: F) where + /// Adds new supported asynchronous method. + pub fn add_method(&mut self, name: &str, method: F) + where F: RpcMethodSimple, { - self.add_method_with_meta(name, move |params, _meta| { - method.call(params) - }) + self.add_method_with_meta(name, move |params, _meta| method.call(params)) } /// Adds new supported notification - pub fn add_notification(&mut self, name: &str, notification: F) where + pub fn add_notification(&mut self, name: &str, notification: F) + where F: RpcNotificationSimple, { self.add_notification_with_meta(name, move |params, _meta| notification.execute(params)) } /// Adds new supported asynchronous method with metadata support. - pub fn add_method_with_meta(&mut self, name: &str, method: F) where + pub fn add_method_with_meta(&mut self, name: &str, method: F) + where F: RpcMethod, { - self.methods.insert( - name.into(), - RemoteProcedure::Method(Arc::new(method)), - ); + self.methods + .insert(name.into(), RemoteProcedure::Method(Arc::new(method))); } /// Adds new supported notification with metadata support. - pub fn add_notification_with_meta(&mut self, name: &str, notification: F) where + pub fn add_notification_with_meta(&mut self, name: &str, notification: F) + where F: RpcNotification, { - self.methods.insert( - name.into(), - RemoteProcedure::Notification(Arc::new(notification)), - ); + self.methods + .insert(name.into(), RemoteProcedure::Notification(Arc::new(notification))); } /// Extend this `MetaIoHandler` with methods defined elsewhere. - pub fn extend_with(&mut self, methods: F) where - F: Into>> + pub fn extend_with(&mut self, methods: F) + where + F: IntoIterator)>, { - self.methods.extend(methods.into()) + self.methods.extend(methods) } /// Handle given request synchronously - will block until response is available. /// If you have any asynchronous methods in your RPC it is much wiser to use /// `handle_request` instead and deal with asynchronous requests in a non-blocking fashion. + #[cfg(feature = "futures-executor")] pub fn handle_request_sync(&self, request: &str, meta: T) -> Option { - self.handle_request(request, meta).wait().expect("Handler calls can never fail.") + futures_executor::block_on(self.handle_request(request, meta)) } /// Handle given request asynchronously. pub fn handle_request(&self, request: &str, meta: T) -> FutureResult { - use self::future::Either::{A, B}; + use self::future::Either::{Left, Right}; fn as_string(response: Option) -> Option { let res = response.map(write_response); - debug!(target: "rpc", "Response: {}.", match res { - Some(ref res) => res, - None => "None", - }); + debug!(target: "rpc", "Response: {}.", res.as_ref().unwrap_or(&"None".to_string())); res } trace!(target: "rpc", "Request: {}.", request); let request = read_request(request); let result = match request { - Err(error) => A(futures::finished(Some(Response::from(error, self.compatibility.default_version())))), - Ok(request) => B(self.handle_rpc_request(request, meta)), + Err(error) => Left(future::ready(Some(Response::from( + error, + self.compatibility.default_version(), + )))), + Ok(request) => Right(self.handle_rpc_request(request, meta)), }; result.map(as_string) @@ -210,14 +226,14 @@ impl> MetaIoHandler { /// Handle deserialized RPC request. 
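The metadata-aware entry points keep the same shape under the new futures API. A sketch with a made-up `Session` metadata type, analogous to the updated `core/examples/meta.rs`:

```rust
use jsonrpc_core::{MetaIoHandler, Metadata, Params, Value};

#[derive(Clone, Default)]
struct Session(u64);
impl Metadata for Session {}

fn main() {
    let mut io = MetaIoHandler::<Session>::default();
    io.add_method_with_meta("whoami", |_params: Params, meta: Session| async move {
        Ok(Value::String(format!("session {}", meta.0)))
    });

    let request = r#"{"jsonrpc":"2.0","method":"whoami","params":[],"id":1}"#;
    assert_eq!(
        io.handle_request_sync(request, Session(7)),
        Some(r#"{"jsonrpc":"2.0","result":"session 7","id":1}"#.to_owned())
    );
}
```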
pub fn handle_rpc_request(&self, request: Request, meta: T) -> FutureRpcResult { - use self::future::Either::{A, B}; + use self::future::Either::{Left, Right}; fn output_as_response(output: Option) -> Option { output.map(Response::Single) } fn outputs_as_batch(outs: Vec>) -> Option { - let outs: Vec<_> = outs.into_iter().filter_map(|v| v).collect(); + let outs: Vec<_> = outs.into_iter().flatten().collect(); if outs.is_empty() { None } else { @@ -225,22 +241,27 @@ impl> MetaIoHandler { } } - self.middleware.on_request(request, meta, |request, meta| match request { - Request::Single(call) => { - A(self.handle_call(call, meta).map(output_as_response as fn(Option) -> - Option)) - }, - Request::Batch(calls) => { - let futures: Vec<_> = calls.into_iter().map(move |call| self.handle_call(call, meta.clone())).collect(); - B(futures::future::join_all(futures).map(outputs_as_batch as fn(Vec>) -> - Option)) - }, - }) + self.middleware + .on_request(request, meta, |request, meta| match request { + Request::Single(call) => Left( + self.handle_call(call, meta) + .map(output_as_response as fn(Option) -> Option), + ), + Request::Batch(calls) => { + let futures: Vec<_> = calls + .into_iter() + .map(move |call| self.handle_call(call, meta.clone())) + .collect(); + Right( + future::join_all(futures).map(outputs_as_batch as fn(Vec>) -> Option), + ) + } + }) } /// Handle single call asynchronously. pub fn handle_call(&self, call: Call, meta: T) -> FutureRpcOutput { - use self::future::Either::{A, B}; + use self::future::Either::{Left, Right}; self.middleware.on_call(call, meta, |call, meta| match call { Call::MethodCall(method) => { @@ -249,10 +270,7 @@ impl> MetaIoHandler { let jsonrpc = method.jsonrpc; let valid_version = self.compatibility.is_version_valid(jsonrpc); - let call_method = |method: &Arc>| { - let method = method.clone(); - futures::lazy(move || method.call(params, meta)) - }; + let call_method = |method: &Arc>| method.call(params, meta); let result = match (valid_version, self.methods.get(&method.method)) { (false, _) => Err(Error::invalid_version()), @@ -265,44 +283,130 @@ impl> MetaIoHandler { }; match result { - Ok(result) => A(Box::new( - result.then(move |result| futures::finished(Some(Output::from(result, id, jsonrpc)))) + Ok(result) => Left(Box::pin( + result.then(move |result| future::ready(Some(Output::from(result, id, jsonrpc)))), ) as _), - Err(err) => B(futures::finished(Some(Output::from(Err(err), id, jsonrpc)))), + Err(err) => Right(future::ready(Some(Output::from(Err(err), id, jsonrpc)))), } - }, + } Call::Notification(notification) => { let params = notification.params; let jsonrpc = notification.jsonrpc; if !self.compatibility.is_version_valid(jsonrpc) { - return B(futures::finished(None)); + return Right(future::ready(None)); } match self.methods.get(¬ification.method) { Some(&RemoteProcedure::Notification(ref notification)) => { notification.execute(params, meta); - }, + } Some(&RemoteProcedure::Alias(ref alias)) => { if let Some(&RemoteProcedure::Notification(ref notification)) = self.methods.get(alias) { notification.execute(params, meta); } - }, - _ => {}, + } + _ => {} } - B(futures::finished(None)) - }, - Call::Invalid { id } => { - B(futures::finished(Some(Output::invalid_request(id, self.compatibility.default_version())))) - }, + Right(future::ready(None)) + } + Call::Invalid { id } => Right(future::ready(Some(Output::invalid_request( + id, + self.compatibility.default_version(), + )))), }) } + + /// Returns an iterator visiting all methods in arbitrary order. 
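`handle_rpc_request` above joins the per-call futures and drops `None` outputs, so notifications inside a batch simply disappear from the response. A sketch (method names are illustrative):

```rust
use jsonrpc_core::{IoHandler, Params, Value};

fn main() {
    let mut io = IoHandler::new();
    io.add_sync_method("say_hello", |_: Params| Ok(Value::String("hello".into())));
    io.add_notification("notify", |_params: Params| {});

    // One call plus one notification: only the call yields an `Output`,
    // so the batch response contains a single element.
    let request = r#"[{"jsonrpc":"2.0","method":"say_hello","params":[],"id":1},{"jsonrpc":"2.0","method":"notify","params":[42]}]"#;
    assert_eq!(
        io.handle_request_sync(request),
        Some(r#"[{"jsonrpc":"2.0","result":"hello","id":1}]"#.to_owned())
    );
}
```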
+ pub fn iter(&self) -> impl Iterator)> { + self.methods.iter() + } +} + +/// A type that can augment `MetaIoHandler`. +/// +/// This allows your code to accept generic extensions for `IoHandler` +/// and compose them to create the RPC server. +pub trait IoHandlerExtension { + /// Extend given `handler` with additional methods. + fn augment>(self, handler: &mut MetaIoHandler); +} + +macro_rules! impl_io_handler_extension { + ($( $x:ident, )*) => { + impl IoHandlerExtension for ($( $x, )*) where + M: Metadata, + $( + $x: IoHandlerExtension, + )* + { + #[allow(unused)] + fn augment>(self, handler: &mut MetaIoHandler) { + #[allow(non_snake_case)] + let ( + $( $x, )* + ) = self; + $( + $x.augment(handler); + )* + } + } + } +} + +impl_io_handler_extension!(); +impl_io_handler_extension!(A,); +impl_io_handler_extension!(A, B,); +impl_io_handler_extension!(A, B, C,); +impl_io_handler_extension!(A, B, C, D,); +impl_io_handler_extension!(A, B, C, D, E,); +impl_io_handler_extension!(A, B, C, D, E, F,); +impl_io_handler_extension!(A, B, C, D, E, F, G,); +impl_io_handler_extension!(A, B, C, D, E, F, G, H,); +impl_io_handler_extension!(A, B, C, D, E, F, G, H, I,); +impl_io_handler_extension!(A, B, C, D, E, F, G, H, I, J,); +impl_io_handler_extension!(A, B, C, D, E, F, G, H, I, J, K,); +impl_io_handler_extension!(A, B, C, D, E, F, G, H, I, J, K, L,); + +impl IoHandlerExtension for Vec<(String, RemoteProcedure)> { + fn augment>(self, handler: &mut MetaIoHandler) { + handler.methods.extend(self) + } +} + +impl IoHandlerExtension for HashMap> { + fn augment>(self, handler: &mut MetaIoHandler) { + handler.methods.extend(self) + } +} + +impl> IoHandlerExtension for MetaIoHandler { + fn augment>(self, handler: &mut MetaIoHandler) { + handler.methods.extend(self.methods) + } +} + +impl> IoHandlerExtension for Option { + fn augment>(self, handler: &mut MetaIoHandler) { + if let Some(x) = self { + x.augment(handler) + } + } } /// Simplified `IoHandler` with no `Metadata` associated with each request. -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default)] pub struct IoHandler(MetaIoHandler); +impl IntoIterator for IoHandler { + type Item = as IntoIterator>::Item; + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + // Type inference helper impl IoHandler { /// Creates new `IoHandler` without any metadata. @@ -335,6 +439,7 @@ impl IoHandler { /// Handle given request synchronously - will block until response is available. /// If you have any asynchronous methods in your RPC it is much wiser to use /// `handle_request` instead and deal with asynchronous requests in a non-blocking fashion. 
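A sketch of the composition pattern `IoHandlerExtension` enables. The builder function and the method names are invented for illustration; the `augment` call relies on the same deref coercion exercised by the `test_extending_by_multiple_delegates` test below:

```rust
use jsonrpc_core::{IoHandler, IoHandlerExtension, Params, Value};

fn build_handler(extra: impl IoHandlerExtension) -> IoHandler {
    let mut io = IoHandler::new();
    io.add_sync_method("ping", |_: Params| Ok(Value::String("pong".into())));
    // Folds the extension's methods into the handler; tuples of extensions
    // (up to 12) and `Option`s of extensions work the same way.
    extra.augment(&mut io);
    io
}

fn main() {
    // `IoHandler` itself implements `IoHandlerExtension`.
    let mut more = IoHandler::new();
    more.add_sync_method("version", |_: Params| Ok(Value::String("18.0.0".into())));

    let io = build_handler(more);
    let request = r#"{"jsonrpc":"2.0","method":"version","params":[],"id":1}"#;
    assert!(io.handle_request_sync(request).is_some());
}
```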
+ #[cfg(feature = "futures-executor")] pub fn handle_request_sync(&self, request: &str) -> Option { self.0.handle_request_sync(request, M::default()) } @@ -360,8 +465,14 @@ impl From for MetaIoHandler<()> { } } +impl IoHandlerExtension for IoHandler { + fn augment>(self, handler: &mut MetaIoHandler) { + handler.methods.extend(self.0.methods) + } +} + fn read_request(request_str: &str) -> Result { - serde_json::from_str(request_str).map_err(|_| Error::new(ErrorCode::ParseError)) + crate::serde_from_str(request_str).map_err(|_| Error::new(ErrorCode::ParseError)) } fn write_response(response: Response) -> String { @@ -371,17 +482,14 @@ fn write_response(response: Response) -> String { #[cfg(test)] mod tests { - use futures; - use types::{Value}; - use super::{IoHandler, Compatibility}; + use super::{Compatibility, IoHandler}; + use crate::types::Value; #[test] fn test_io_handler() { let mut io = IoHandler::new(); - io.add_method("say_hello", |_| { - Ok(Value::String("hello".to_string())) - }); + io.add_method("say_hello", |_| async { Ok(Value::String("hello".to_string())) }); let request = r#"{"jsonrpc": "2.0", "method": "say_hello", "params": [42, 23], "id": 1}"#; let response = r#"{"jsonrpc":"2.0","result":"hello","id":1}"#; @@ -393,9 +501,7 @@ mod tests { fn test_io_handler_1dot0() { let mut io = IoHandler::with_compatibility(Compatibility::Both); - io.add_method("say_hello", |_| { - Ok(Value::String("hello".to_string())) - }); + io.add_method("say_hello", |_| async { Ok(Value::String("hello".to_string())) }); let request = r#"{"method": "say_hello", "params": [42, 23], "id": 1}"#; let response = r#"{"result":"hello","id":1}"#; @@ -407,9 +513,7 @@ mod tests { fn test_async_io_handler() { let mut io = IoHandler::new(); - io.add_method("say_hello", |_| { - futures::finished(Value::String("hello".to_string())) - }); + io.add_method("say_hello", |_| async { Ok(Value::String("hello".to_string())) }); let request = r#"{"jsonrpc": "2.0", "method": "say_hello", "params": [42, 23], "id": 1}"#; let response = r#"{"jsonrpc":"2.0","result":"hello","id":1}"#; @@ -419,8 +523,8 @@ mod tests { #[test] fn test_notification() { - use std::sync::Arc; use std::sync::atomic; + use std::sync::Arc; let mut io = IoHandler::new(); @@ -448,12 +552,9 @@ mod tests { #[test] fn test_method_alias() { let mut io = IoHandler::new(); - io.add_method("say_hello", |_| { - Ok(Value::String("hello".to_string())) - }); + io.add_method("say_hello", |_| async { Ok(Value::String("hello".to_string())) }); io.add_alias("say_hello_alias", "say_hello"); - let request = r#"{"jsonrpc": "2.0", "method": "say_hello_alias", "params": [42, 23], "id": 1}"#; let response = r#"{"jsonrpc":"2.0","result":"hello","id":1}"#; @@ -462,8 +563,8 @@ mod tests { #[test] fn test_notification_alias() { - use std::sync::Arc; use std::sync::atomic; + use std::sync::Arc; let mut io = IoHandler::new(); @@ -479,10 +580,29 @@ mod tests { assert_eq!(called.load(atomic::Ordering::SeqCst), true); } + #[test] + fn test_batch_notification() { + use std::sync::atomic; + use std::sync::Arc; + + let mut io = IoHandler::new(); + + let called = Arc::new(atomic::AtomicBool::new(false)); + let c = called.clone(); + io.add_notification("say_hello", move |_| { + c.store(true, atomic::Ordering::SeqCst); + }); + + let request = r#"[{"jsonrpc": "2.0", "method": "say_hello", "params": [42, 23]}]"#; + assert_eq!(io.handle_request_sync(request), None); + assert_eq!(called.load(atomic::Ordering::SeqCst), true); + } + #[test] fn test_send_sync() { - fn is_send_sync(_obj: T) -> bool 
where - T: Send + Sync + fn is_send_sync(_obj: T) -> bool + where + T: Send + Sync, { true } @@ -491,4 +611,30 @@ mod tests { assert!(is_send_sync(io)) } + + #[test] + fn test_extending_by_multiple_delegates() { + use super::IoHandlerExtension; + use crate::delegates::IoDelegate; + use std::sync::Arc; + + struct Test; + impl Test { + fn abc(&self, _p: crate::Params) -> crate::BoxFuture> { + Box::pin(async { Ok(5.into()) }) + } + } + + let mut io = IoHandler::new(); + let mut del1 = IoDelegate::new(Arc::new(Test)); + del1.add_method("rpc_test", Test::abc); + let mut del2 = IoDelegate::new(Arc::new(Test)); + del2.add_method("rpc_test", Test::abc); + + fn augment(x: X, io: &mut IoHandler) { + x.augment(io); + } + + augment((del1, del2), &mut io); + } } diff --git a/core/src/lib.rs b/core/src/lib.rs index e7fa46731..1828b806a 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -3,48 +3,77 @@ //! Right now it supports only server side handling requests. //! //! ```rust -//! extern crate jsonrpc_core; +//! use jsonrpc_core::IoHandler; +//! use jsonrpc_core::Value; +//! let mut io = IoHandler::new(); +//! io.add_sync_method("say_hello", |_| { +//! Ok(Value::String("Hello World!".into())) +//! }); //! -//! use jsonrpc_core::*; -//! use jsonrpc_core::futures::Future; +//! let request = r#"{"jsonrpc": "2.0", "method": "say_hello", "params": [42, 23], "id": 1}"#; +//! let response = r#"{"jsonrpc":"2.0","result":"Hello World!","id":1}"#; //! -//! fn main() { -//! let mut io = IoHandler::new(); -//! io.add_method("say_hello", |_| { -//! Ok(Value::String("Hello World!".into())) -//! }); -//! -//! let request = r#"{"jsonrpc": "2.0", "method": "say_hello", "params": [42, 23], "id": 1}"#; -//! let response = r#"{"jsonrpc":"2.0","result":"Hello World!","id":1}"#; -//! -//! assert_eq!(io.handle_request(request).wait().unwrap(), Some(response.to_string())); -//! } +//! assert_eq!(io.handle_request_sync(request), Some(response.to_string())); //! ``` -#![warn(missing_docs)] +#![deny(missing_docs)] -#[macro_use] extern crate log; -#[macro_use] extern crate serde_derive; -extern crate serde; +use std::pin::Pin; -pub extern crate futures; +#[macro_use] +extern crate log; +#[macro_use] +extern crate serde_derive; +#[cfg(feature = "futures")] +pub use futures; +#[cfg(feature = "futures-executor")] +pub use futures_executor; +pub use futures_util; + +#[doc(hidden)] +pub extern crate serde; #[doc(hidden)] pub extern crate serde_json; mod calls; mod io; +pub mod delegates; pub mod middleware; pub mod types; +/// A Result type. +pub type Result = std::result::Result; + /// A `Future` trait object. -pub type BoxFuture = Box + Send>; +pub type BoxFuture = Pin + Send>>; -/// A Result type. 
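`BoxFuture` is now a pinned, boxed `std::future::Future`, so handler code boxes its futures with `Box::pin` rather than `Box::new`. A minimal sketch, driven through the `futures-executor` re-export:

```rust
use jsonrpc_core::{futures_executor, BoxFuture, Result, Value};

// Any `async` block can be turned into a `BoxFuture` with `Box::pin`.
fn respond() -> BoxFuture<Result<Value>> {
    Box::pin(async { Ok(Value::String("done".into())) })
}

fn main() {
    let value = futures_executor::block_on(respond());
    assert_eq!(value, Ok(Value::String("done".into())));
}
```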
-pub type Result = ::std::result::Result; +pub use crate::calls::{ + Metadata, RemoteProcedure, RpcMethod, RpcMethodSimple, RpcMethodSync, RpcNotification, RpcNotificationSimple, + WrapFuture, +}; +pub use crate::delegates::IoDelegate; +pub use crate::io::{ + Compatibility, FutureOutput, FutureResponse, FutureResult, FutureRpcResult, IoHandler, IoHandlerExtension, + MetaIoHandler, +}; +pub use crate::middleware::{Middleware, Noop as NoopMiddleware}; +pub use crate::types::*; + +use serde_json::Error as SerdeError; -pub use calls::{RemoteProcedure, Metadata, RpcMethodSimple, RpcMethod, RpcNotificationSimple, RpcNotification}; -pub use io::{Compatibility, IoHandler, MetaIoHandler, FutureOutput, FutureResult, FutureResponse, FutureRpcResult}; -pub use middleware::{Middleware, Noop as NoopMiddleware}; -pub use types::*; +/// workaround for https://github.com/serde-rs/json/issues/505 +/// Arbitrary precision confuses serde when deserializing into untagged enums, +/// this is a workaround +pub fn serde_from_str<'a, T>(input: &'a str) -> std::result::Result +where + T: serde::de::Deserialize<'a>, +{ + if cfg!(feature = "arbitrary_precision") { + let val = serde_json::from_str::(input)?; + T::deserialize(val) + } else { + serde_json::from_str::(input) + } +} diff --git a/core/src/middleware.rs b/core/src/middleware.rs index 699e2853e..308a787c4 100644 --- a/core/src/middleware.rs +++ b/core/src/middleware.rs @@ -1,133 +1,142 @@ //! `IoHandler` middlewares -use calls::Metadata; -use types::{Request, Response, Call, Output}; -use futures::{future::Either, Future}; +use crate::calls::Metadata; +use crate::types::{Call, Output, Request, Response}; +use futures_util::future::Either; +use std::future::Future; +use std::pin::Pin; /// RPC middleware pub trait Middleware: Send + Sync + 'static { /// A returned request future. - type Future: Future, Error=()> + Send + 'static; + type Future: Future> + Send + 'static; /// A returned call future. - type CallFuture: Future, Error=()> + Send + 'static; + type CallFuture: Future> + Send + 'static; /// Method invoked on each request. /// Allows you to either respond directly (without executing RPC call) /// or do any additional work before and/or after processing the request. - fn on_request(&self, request: Request, meta: M, next: F) -> Either where - F: FnOnce(Request, M) -> X + Send, - X: Future, Error=()> + Send + 'static, + fn on_request(&self, request: Request, meta: M, next: F) -> Either + where + F: Fn(Request, M) -> X + Send + Sync, + X: Future> + Send + 'static, { - Either::B(next(request, meta)) + Either::Right(next(request, meta)) } /// Method invoked on each call inside a request. /// /// Allows you to either handle the call directly (without executing RPC call). - fn on_call(&self, call: Call, meta: M, next: F) -> Either where - F: FnOnce(Call, M) -> X + Send, - X: Future, Error=()> + Send + 'static, + fn on_call(&self, call: Call, meta: M, next: F) -> Either + where + F: Fn(Call, M) -> X + Send + Sync, + X: Future> + Send + 'static, { - Either::B(next(call, meta)) + Either::Right(next(call, meta)) } } /// Dummy future used as a noop result of middleware. -pub type NoopFuture = Box, Error=()> + Send>; +pub type NoopFuture = Pin> + Send>>; /// Dummy future used as a noop call result of middleware. 
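The `serde_from_str` helper introduced above exists because `arbitrary_precision` confuses `serde_json` on untagged enums; callers use it exactly like `serde_json::from_str`. The `Id` below is just one example of such an untagged type:

```rust
use jsonrpc_core::{serde_from_str, Id};

fn main() {
    // Routes deserialization through `serde_json::Value` when the
    // `arbitrary_precision` feature is enabled, and falls back to a
    // plain `from_str` otherwise.
    let id: Id = serde_from_str("42").expect("valid id");
    assert_eq!(id, Id::Num(42));
}
```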
-pub type NoopCallFuture = Box, Error=()> + Send>; +pub type NoopCallFuture = Pin> + Send>>; /// No-op middleware implementation -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default)] pub struct Noop; impl Middleware for Noop { type Future = NoopFuture; type CallFuture = NoopCallFuture; } -impl, B: Middleware> - Middleware for (A, B) -{ +impl, B: Middleware> Middleware for (A, B) { type Future = Either; type CallFuture = Either; - fn on_request(&self, request: Request, meta: M, process: F) -> Either where - F: FnOnce(Request, M) -> X + Send, - X: Future, Error=()> + Send + 'static, + fn on_request(&self, request: Request, meta: M, process: F) -> Either + where + F: Fn(Request, M) -> X + Send + Sync, + X: Future> + Send + 'static, { - repack(self.0.on_request(request, meta, move |request, meta| { - self.1.on_request(request, meta, process) + repack(self.0.on_request(request, meta, |request, meta| { + self.1.on_request(request, meta, &process) })) } - fn on_call(&self, call: Call, meta: M, process: F) -> Either where - F: FnOnce(Call, M) -> X + Send, - X: Future, Error=()> + Send + 'static, + fn on_call(&self, call: Call, meta: M, process: F) -> Either + where + F: Fn(Call, M) -> X + Send + Sync, + X: Future> + Send + 'static, { - repack(self.0.on_call(call, meta, move |call, meta| { - self.1.on_call(call, meta, process) - })) + repack( + self.0 + .on_call(call, meta, |call, meta| self.1.on_call(call, meta, &process)), + ) } } -impl, B: Middleware, C: Middleware> - Middleware for (A, B, C) -{ +impl, B: Middleware, C: Middleware> Middleware for (A, B, C) { type Future = Either>; type CallFuture = Either>; - fn on_request(&self, request: Request, meta: M, process: F) -> Either where - F: FnOnce(Request, M) -> X + Send, - X: Future, Error=()> + Send + 'static, + fn on_request(&self, request: Request, meta: M, process: F) -> Either + where + F: Fn(Request, M) -> X + Send + Sync, + X: Future> + Send + 'static, { - repack(self.0.on_request(request, meta, move |request, meta| { - repack(self.1.on_request(request, meta, move |request, meta| { - self.2.on_request(request, meta, process) + repack(self.0.on_request(request, meta, |request, meta| { + repack(self.1.on_request(request, meta, |request, meta| { + self.2.on_request(request, meta, &process) })) })) } - fn on_call(&self, call: Call, meta: M, process: F) -> Either where - F: FnOnce(Call, M) -> X + Send, - X: Future, Error=()> + Send + 'static, + fn on_call(&self, call: Call, meta: M, process: F) -> Either + where + F: Fn(Call, M) -> X + Send + Sync, + X: Future> + Send + 'static, { - repack(self.0.on_call(call, meta, move |call, meta| { - repack(self.1.on_call(call, meta, move |call, meta| { - self.2.on_call(call, meta, process) - })) + repack(self.0.on_call(call, meta, |call, meta| { + repack( + self.1 + .on_call(call, meta, |call, meta| self.2.on_call(call, meta, &process)), + ) })) } } -impl, B: Middleware, C: Middleware, D: Middleware> - Middleware for (A, B, C, D) +impl, B: Middleware, C: Middleware, D: Middleware> Middleware + for (A, B, C, D) { type Future = Either>>; type CallFuture = Either>>; - fn on_request(&self, request: Request, meta: M, process: F) -> Either where - F: FnOnce(Request, M) -> X + Send, - X: Future, Error=()> + Send + 'static, + fn on_request(&self, request: Request, meta: M, process: F) -> Either + where + F: Fn(Request, M) -> X + Send + Sync, + X: Future> + Send + 'static, { - repack(self.0.on_request(request, meta, move |request, meta| { - repack(self.1.on_request(request, meta, move |request, meta| { - 
repack(self.2.on_request(request, meta, move |request, meta| { - self.3.on_request(request, meta, process) + repack(self.0.on_request(request, meta, |request, meta| { + repack(self.1.on_request(request, meta, |request, meta| { + repack(self.2.on_request(request, meta, |request, meta| { + self.3.on_request(request, meta, &process) })) })) })) } - fn on_call(&self, call: Call, meta: M, process: F) -> Either where - F: FnOnce(Call, M) -> X + Send, - X: Future, Error=()> + Send + 'static, + fn on_call(&self, call: Call, meta: M, process: F) -> Either + where + F: Fn(Call, M) -> X + Send + Sync, + X: Future> + Send + 'static, { - repack(self.0.on_call(call, meta, move |call, meta| { - repack(self.1.on_call(call, meta, move |call, meta| { - repack(self.2.on_call(call, meta, move |call, meta| { - self.3.on_call(call, meta, process) - })) + repack(self.0.on_call(call, meta, |call, meta| { + repack(self.1.on_call(call, meta, |call, meta| { + repack( + self.2 + .on_call(call, meta, |call, meta| self.3.on_call(call, meta, &process)), + ) })) })) } @@ -136,8 +145,8 @@ impl, B: Middleware, C: Middleware, D: Middl #[inline(always)] fn repack(result: Either>) -> Either, X> { match result { - Either::A(a) => Either::A(Either::A(a)), - Either::B(Either::A(b)) => Either::A(Either::B(b)), - Either::B(Either::B(x)) => Either::B(x), + Either::Left(a) => Either::Left(Either::Left(a)), + Either::Right(Either::Left(b)) => Either::Left(Either::Right(b)), + Either::Right(Either::Right(x)) => Either::Right(x), } } diff --git a/core/src/types/error.rs b/core/src/types/error.rs index 3ee72a1bf..9da0f916b 100644 --- a/core/src/types/error.rs +++ b/core/src/types/error.rs @@ -1,7 +1,8 @@ //! jsonrpc errors +use super::Value; use serde::de::{Deserialize, Deserializer}; use serde::ser::{Serialize, Serializer}; -use super::Value; +use std::fmt; /// JSONRPC error code #[derive(Debug, PartialEq, Clone)] @@ -18,7 +19,7 @@ pub enum ErrorCode { /// Internal JSON-RPC error. InternalError, /// Reserved for implementation-defined server-errors. 
- ServerError(i64) + ServerError(i64), } impl ErrorCode { @@ -30,7 +31,7 @@ impl ErrorCode { ErrorCode::MethodNotFound => -32601, ErrorCode::InvalidParams => -32602, ErrorCode::InternalError => -32603, - ErrorCode::ServerError(code) => code + ErrorCode::ServerError(code) => code, } } @@ -63,29 +64,34 @@ impl From for ErrorCode { impl<'a> Deserialize<'a> for ErrorCode { fn deserialize(deserializer: D) -> Result - where D: Deserializer<'a> { - let code: i64 = try!(Deserialize::deserialize(deserializer)); + where + D: Deserializer<'a>, + { + let code: i64 = Deserialize::deserialize(deserializer)?; Ok(ErrorCode::from(code)) } } impl Serialize for ErrorCode { fn serialize(&self, serializer: S) -> Result - where S: Serializer { + where + S: Serializer, + { serializer.serialize_i64(self.code()) } } /// Error object as defined in Spec #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] pub struct Error { /// Code pub code: ErrorCode, /// Message pub message: String, /// Optional data - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, } impl Error { @@ -93,8 +99,8 @@ impl Error { pub fn new(code: ErrorCode) -> Self { Error { message: code.description(), - code: code, - data: None + code, + data: None, } } @@ -114,7 +120,8 @@ impl Error { } /// Creates new `InvalidParams` - pub fn invalid_params(message: M) -> Self where + pub fn invalid_params(message: M) -> Self + where M: Into, { Error { @@ -124,6 +131,19 @@ impl Error { } } + /// Creates `InvalidParams` for given parameter, with details. + pub fn invalid_params_with_details(message: M, details: T) -> Error + where + M: Into, + T: fmt::Debug, + { + Error { + code: ErrorCode::InvalidParams, + message: format!("Invalid parameters: {}", message.into()), + data: Some(Value::String(format!("{:?}", details))), + } + } + /// Creates new `InternalError` pub fn internal_error() -> Self { Self::new(ErrorCode::InternalError) @@ -138,3 +158,11 @@ impl Error { } } } + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}: {}", self.code.description(), self.message) + } +} + +impl std::error::Error for Error {} diff --git a/core/src/types/id.rs b/core/src/types/id.rs index b2abd3d1d..79a3e7897 100644 --- a/core/src/types/id.rs +++ b/core/src/types/id.rs @@ -2,6 +2,7 @@ /// Request Id #[derive(Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] #[serde(untagged)] pub enum Id { /// No id (notification) @@ -33,12 +34,22 @@ mod tests { let s = r#"[null, 0, 2, "3"]"#; let deserialized: Vec = serde_json::from_str(s).unwrap(); - assert_eq!(deserialized, vec![Id::Null, Id::Num(0), Id::Num(2), Id::Str("3".into())]); + assert_eq!( + deserialized, + vec![Id::Null, Id::Num(0), Id::Num(2), Id::Str("3".into())] + ); } #[test] fn id_serialization() { - let d = vec![Id::Null, Id::Num(0), Id::Num(2), Id::Num(3), Id::Str("3".to_owned()), Id::Str("test".to_owned())]; + let d = vec![ + Id::Null, + Id::Num(0), + Id::Num(2), + Id::Num(3), + Id::Str("3".to_owned()), + Id::Str("test".to_owned()), + ]; let serialized = serde_json::to_string(&d).unwrap(); assert_eq!(serialized, r#"[null,0,2,3,"3","test"]"#); } diff --git a/core/src/types/mod.rs b/core/src/types/mod.rs index d50c989fc..cb4a28e95 100644 --- a/core/src/types/mod.rs +++ b/core/src/types/mod.rs @@ -7,13 +7,13 @@ pub mod request; pub mod response; pub mod version; -pub use serde_json::Value; 
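// A short, self-contained sketch of the two error helpers added above
// (`invalid_params_with_details` and the `Display` impl); the message and the
// details value are illustrative only:
fn reject_unexpected_params() -> jsonrpc_core::Error {
    let err = jsonrpc_core::Error::invalid_params_with_details(
        "no parameters were expected",
        vec![1, 2, 3], // any `fmt::Debug` value; it is stored as a string in `err.data`
    );
    // The new `Display` impl renders "<code description>: <message>", so the error
    // can be logged directly.
    eprintln!("rejecting call: {}", err);
    err
}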
-pub use serde_json::value::to_value; pub use serde_json::to_string; +pub use serde_json::value::to_value; +pub use serde_json::Value; -pub use self::error::{ErrorCode, Error}; +pub use self::error::{Error, ErrorCode}; pub use self::id::Id; pub use self::params::Params; -pub use self::request::{Request, Call, MethodCall, Notification}; -pub use self::response::{Output, Response, Success, Failure}; +pub use self::request::{Call, MethodCall, Notification, Request}; +pub use self::response::{Failure, Output, Response, Success}; pub use self::version::Version; diff --git a/core/src/types/params.rs b/core/src/types/params.rs index 4d92a999f..2eac38606 100644 --- a/core/src/types/params.rs +++ b/core/src/types/params.rs @@ -1,13 +1,13 @@ //! jsonrpc params field -use serde::de::{DeserializeOwned}; -use serde_json; +use serde::de::DeserializeOwned; use serde_json::value::from_value; -use super::{Value, Error}; +use super::{Error, Value}; /// Request parameters #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] #[serde(untagged)] pub enum Params { /// No parameters @@ -20,24 +20,39 @@ pub enum Params { impl Params { /// Parse incoming `Params` into expected types. - pub fn parse(self) -> Result where D: DeserializeOwned { - let value = match self { + pub fn parse(self) -> Result + where + D: DeserializeOwned, + { + let value: Value = self.into(); + from_value(value).map_err(|e| Error::invalid_params(format!("Invalid params: {}.", e))) + } + + /// Check for no params, returns Err if any params + pub fn expect_no_params(self) -> Result<(), Error> { + match self { + Params::None => Ok(()), + Params::Array(ref v) if v.is_empty() => Ok(()), + p => Err(Error::invalid_params_with_details("No parameters were expected", p)), + } + } +} + +impl From for Value { + fn from(params: Params) -> Value { + match params { Params::Array(vec) => Value::Array(vec), Params::Map(map) => Value::Object(map), - Params::None => Value::Null - }; - - from_value(value).map_err(|e| { - Error::invalid_params(format!("Invalid params: {}.", e)) - }) + Params::None => Value::Null, + } } } #[cfg(test)] mod tests { - use serde_json; use super::Params; - use types::{Value, Error, ErrorCode}; + use crate::types::{Error, ErrorCode, Value}; + use serde_json; #[test] fn params_deserialization() { @@ -47,12 +62,20 @@ mod tests { let mut map = serde_json::Map::new(); map.insert("key".to_string(), Value::String("value".to_string())); - assert_eq!(Params::Array(vec![ - Value::Null, Value::Bool(true), Value::from(-1), Value::from(4), - Value::from(2.3), Value::String("hello".to_string()), - Value::Array(vec![Value::from(0)]), Value::Object(map), - Value::Array(vec![]), - ]), deserialized); + assert_eq!( + Params::Array(vec![ + Value::Null, + Value::Bool(true), + Value::from(-1), + Value::from(4), + Value::from(2.3), + Value::String("hello".to_string()), + Value::Array(vec![Value::from(0)]), + Value::Object(map), + Value::Array(vec![]), + ]), + deserialized + ); } #[test] @@ -69,10 +92,22 @@ mod tests { // then assert_eq!(err1.code, ErrorCode::InvalidParams); - assert_eq!(err1.message, "Invalid params: invalid type: boolean `true`, expected a string."); + assert_eq!( + err1.message, + "Invalid params: invalid type: boolean `true`, expected a string." 
+ ); assert_eq!(err1.data, None); assert_eq!(err2.code, ErrorCode::InvalidParams); - assert_eq!(err2.message, "Invalid params: invalid length 2, expected a tuple of size 3."); + assert_eq!( + err2.message, + "Invalid params: invalid length 2, expected a tuple of size 3." + ); assert_eq!(err2.data, None); } + + #[test] + fn single_param_parsed_as_tuple() { + let params: (u64,) = Params::Array(vec![Value::from(1)]).parse().unwrap(); + assert_eq!(params, (1,)); + } } diff --git a/core/src/types/request.rs b/core/src/types/request.rs index 31518415f..adbf0eccd 100644 --- a/core/src/types/request.rs +++ b/core/src/types/request.rs @@ -3,7 +3,7 @@ use super::{Id, Params, Version}; /// Represents jsonrpc request which is a method call. -#[derive(Debug, PartialEq, Deserialize, Serialize)] +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub struct MethodCall { /// A String specifying the version of the JSON-RPC protocol. @@ -21,7 +21,7 @@ pub struct MethodCall { } /// Represents jsonrpc request which is a notification. -#[derive(Debug, PartialEq, Deserialize, Serialize)] +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub struct Notification { /// A String specifying the version of the JSON-RPC protocol. @@ -35,7 +35,7 @@ pub struct Notification { } /// Represents single jsonrpc call. -#[derive(Debug, PartialEq, Deserialize, Serialize)] +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] #[serde(untagged)] pub enum Call { /// Call method @@ -71,13 +71,14 @@ impl From for Call { } /// Represents jsonrpc request. -#[derive(Debug, PartialEq, Deserialize, Serialize)] +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] #[serde(untagged)] pub enum Request { /// Single request (call) Single(Call), /// Batch of requests (calls) - Batch(Vec) + Batch(Vec), } #[cfg(test)] @@ -94,11 +95,14 @@ mod tests { jsonrpc: Some(Version::V2), method: "update".to_owned(), params: Params::Array(vec![Value::from(1), Value::from(2)]), - id: Id::Num(1) + id: Id::Num(1), }; let serialized = serde_json::to_string(&m).unwrap(); - assert_eq!(serialized, r#"{"jsonrpc":"2.0","method":"update","params":[1,2],"id":1}"#); + assert_eq!( + serialized, + r#"{"jsonrpc":"2.0","method":"update","params":[1,2],"id":1}"# + ); } #[test] @@ -109,7 +113,7 @@ mod tests { let n = Notification { jsonrpc: Some(Version::V2), method: "update".to_owned(), - params: Params::Array(vec![Value::from(1), Value::from(2)]) + params: Params::Array(vec![Value::from(1), Value::from(2)]), }; let serialized = serde_json::to_string(&n).unwrap(); @@ -124,7 +128,7 @@ mod tests { let n = Call::Notification(Notification { jsonrpc: Some(Version::V2), method: "update".to_owned(), - params: Params::Array(vec![Value::from(1)]) + params: Params::Array(vec![Value::from(1)]), }); let serialized = serde_json::to_string(&n).unwrap(); @@ -140,18 +144,20 @@ mod tests { jsonrpc: Some(Version::V2), method: "update".to_owned(), params: Params::Array(vec![Value::from(1), Value::from(2)]), - id: Id::Num(1) + id: Id::Num(1), }), Call::Notification(Notification { jsonrpc: Some(Version::V2), method: "update".to_owned(), - params: Params::Array(vec![Value::from(1)]) - }) + params: Params::Array(vec![Value::from(1)]), + }), ]); let serialized = serde_json::to_string(&batch).unwrap(); - assert_eq!(serialized, r#"[{"jsonrpc":"2.0","method":"update","params":[1,2],"id":1},{"jsonrpc":"2.0","method":"update","params":[1]}]"#); - + assert_eq!( + serialized, + 
r#"[{"jsonrpc":"2.0","method":"update","params":[1,2],"id":1},{"jsonrpc":"2.0","method":"update","params":[1]}]"# + ); } #[test] @@ -162,24 +168,30 @@ mod tests { let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [1,2]}"#; let deserialized: Notification = serde_json::from_str(s).unwrap(); - assert_eq!(deserialized, Notification { - jsonrpc: Some(Version::V2), - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1), Value::from(2)]) - }); + assert_eq!( + deserialized, + Notification { + jsonrpc: Some(Version::V2), + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1), Value::from(2)]) + } + ); let s = r#"{"jsonrpc": "2.0", "method": "foobar"}"#; let deserialized: Notification = serde_json::from_str(s).unwrap(); - assert_eq!(deserialized, Notification { - jsonrpc: Some(Version::V2), - method: "foobar".to_owned(), - params: Params::None, - }); + assert_eq!( + deserialized, + Notification { + jsonrpc: Some(Version::V2), + method: "foobar".to_owned(), + params: Params::None, + } + ); let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [1,2], "id": 1}"#; let deserialized: Result = serde_json::from_str(s); - assert!(deserialized.is_err()) + assert!(deserialized.is_err()); } #[test] @@ -188,48 +200,62 @@ mod tests { let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [1]}"#; let deserialized: Call = serde_json::from_str(s).unwrap(); - assert_eq!(deserialized, Call::Notification(Notification { - jsonrpc: Some(Version::V2), - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1)]) - })); + assert_eq!( + deserialized, + Call::Notification(Notification { + jsonrpc: Some(Version::V2), + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1)]) + }) + ); let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [1], "id": 1}"#; let deserialized: Call = serde_json::from_str(s).unwrap(); - assert_eq!(deserialized, Call::MethodCall(MethodCall { - jsonrpc: Some(Version::V2), - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1)]), - id: Id::Num(1) - })); + assert_eq!( + deserialized, + Call::MethodCall(MethodCall { + jsonrpc: Some(Version::V2), + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1)]), + id: Id::Num(1) + }) + ); let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [], "id": 1}"#; let deserialized: Call = serde_json::from_str(s).unwrap(); - assert_eq!(deserialized, Call::MethodCall(MethodCall { - jsonrpc: Some(Version::V2), - method: "update".to_owned(), - params: Params::Array(vec![]), - id: Id::Num(1) - })); + assert_eq!( + deserialized, + Call::MethodCall(MethodCall { + jsonrpc: Some(Version::V2), + method: "update".to_owned(), + params: Params::Array(vec![]), + id: Id::Num(1) + }) + ); let s = r#"{"jsonrpc": "2.0", "method": "update", "params": null, "id": 1}"#; let deserialized: Call = serde_json::from_str(s).unwrap(); - assert_eq!(deserialized, Call::MethodCall(MethodCall { - jsonrpc: Some(Version::V2), - method: "update".to_owned(), - params: Params::None, - id: Id::Num(1) - })); - + assert_eq!( + deserialized, + Call::MethodCall(MethodCall { + jsonrpc: Some(Version::V2), + method: "update".to_owned(), + params: Params::None, + id: Id::Num(1) + }) + ); let s = r#"{"jsonrpc": "2.0", "method": "update", "id": 1}"#; let deserialized: Call = serde_json::from_str(s).unwrap(); - assert_eq!(deserialized, Call::MethodCall(MethodCall { - jsonrpc: Some(Version::V2), - method: "update".to_owned(), - params: Params::None, - id: Id::Num(1) - })); + 
assert_eq!( + deserialized, + Call::MethodCall(MethodCall { + jsonrpc: Some(Version::V2), + method: "update".to_owned(), + params: Params::None, + id: Id::Num(1) + }) + ); } #[test] @@ -238,20 +264,23 @@ mod tests { let s = r#"[{}, {"jsonrpc": "2.0", "method": "update", "params": [1,2], "id": 1},{"jsonrpc": "2.0", "method": "update", "params": [1]}]"#; let deserialized: Request = serde_json::from_str(s).unwrap(); - assert_eq!(deserialized, Request::Batch(vec![ - Call::Invalid { id: Id::Null }, - Call::MethodCall(MethodCall { - jsonrpc: Some(Version::V2), - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1), Value::from(2)]), - id: Id::Num(1) - }), - Call::Notification(Notification { - jsonrpc: Some(Version::V2), - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1)]) - }) - ])) + assert_eq!( + deserialized, + Request::Batch(vec![ + Call::Invalid { id: Id::Null }, + Call::MethodCall(MethodCall { + jsonrpc: Some(Version::V2), + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1), Value::from(2)]), + id: Id::Num(1) + }), + Call::Notification(Notification { + jsonrpc: Some(Version::V2), + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1)]) + }) + ]) + ) } #[test] @@ -260,8 +289,9 @@ mod tests { let s = r#"{"id":120,"method":"my_method","params":["foo", "bar"],"extra_field":[]}"#; let deserialized: Request = serde_json::from_str(s).unwrap(); + match deserialized { - Request::Single(Call::Invalid { id: Id::Num(120) }) => {}, + Request::Single(Call::Invalid { id: Id::Num(120) }) => {} _ => panic!("Request wrongly deserialized: {:?}", deserialized), } } diff --git a/core/src/types/response.rs b/core/src/types/response.rs index 23406834a..5c45e5c04 100644 --- a/core/src/types/response.rs +++ b/core/src/types/response.rs @@ -1,33 +1,36 @@ //! jsonrpc response -use super::{Id, Value, Error, ErrorCode, Version}; -use {Result as CoreResult}; +use super::{Error, ErrorCode, Id, Value, Version}; +use crate::Result as CoreResult; /// Successful response #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] pub struct Success { /// Protocol version - #[serde(skip_serializing_if = "Option::is_none")] + #[serde(skip_serializing_if = "Option::is_none")] pub jsonrpc: Option, /// Result pub result: Value, /// Correlation id - pub id: Id + pub id: Id, } /// Unsuccessful response #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] pub struct Failure { /// Protocol Version - #[serde(skip_serializing_if = "Option::is_none")] + #[serde(skip_serializing_if = "Option::is_none")] pub jsonrpc: Option, /// Error pub error: Error, /// Correlation id - pub id: Id + pub id: Id, } /// Represents output - failure or success #[derive(Debug, PartialEq, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] #[serde(untagged)] pub enum Output { /// Success @@ -40,24 +43,16 @@ impl Output { /// Creates new output given `Result`, `Id` and `Version`. pub fn from(result: CoreResult, id: Id, jsonrpc: Option) -> Self { match result { - Ok(result) => Output::Success(Success { - id: id, - jsonrpc: jsonrpc, - result: result, - }), - Err(error) => Output::Failure(Failure { - id: id, - jsonrpc: jsonrpc, - error: error, - }), + Ok(result) => Output::Success(Success { jsonrpc, result, id }), + Err(error) => Output::Failure(Failure { jsonrpc, error, id }), } } /// Creates new failure output indicating malformed request. 
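// A hedged sketch of how `from` above is typically driven (the ids and payloads
// are illustrative; `CoreResult` is the crate's `Result<_, Error>` alias):
fn example_outputs() -> (Output, Output) {
    let ok = Output::from(Ok(Value::from(1)), Id::Num(1), Some(Version::V2));
    let err = Output::from(Err(Error::internal_error()), Id::Num(2), Some(Version::V2));
    (ok, err) // the first serializes with a `result` field, the second with an `error` field
}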
pub fn invalid_request(id: Id, jsonrpc: Option) -> Self { Output::Failure(Failure { - id: id, - jsonrpc: jsonrpc, + id, + jsonrpc, error: Error::new(ErrorCode::InvalidRequest), }) } @@ -90,13 +85,14 @@ impl From for CoreResult { } /// Synchronous response -#[derive(Debug, PartialEq, Deserialize, Serialize)] +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] #[serde(untagged)] pub enum Response { /// Single response Single(Output), /// Response to batch request (batch of responses) - Batch(Vec) + Batch(Vec), } impl Response { @@ -104,9 +100,21 @@ impl Response { pub fn from(error: Error, jsonrpc: Option) -> Self { Failure { id: Id::Null, - jsonrpc: jsonrpc, - error: error, - }.into() + jsonrpc, + error, + } + .into() + } + + /// Deserialize `Response` from given JSON string. + /// + /// This method will handle an empty string as empty batch response. + pub fn from_json(s: &str) -> Result { + if s.is_empty() { + Ok(Response::Batch(vec![])) + } else { + crate::serde_from_str(s) + } } } @@ -130,7 +138,7 @@ fn success_output_serialize() { let so = Output::Success(Success { jsonrpc: Some(Version::V2), result: Value::from(1), - id: Id::Num(1) + id: Id::Num(1), }); let serialized = serde_json::to_string(&so).unwrap(); @@ -145,11 +153,14 @@ fn success_output_deserialize() { let dso = r#"{"jsonrpc":"2.0","result":1,"id":1}"#; let deserialized: Output = serde_json::from_str(dso).unwrap(); - assert_eq!(deserialized, Output::Success(Success { - jsonrpc: Some(Version::V2), - result: Value::from(1), - id: Id::Num(1) - })); + assert_eq!( + deserialized, + Output::Success(Success { + jsonrpc: Some(Version::V2), + result: Value::from(1), + id: Id::Num(1) + }) + ); } #[test] @@ -159,11 +170,14 @@ fn failure_output_serialize() { let fo = Output::Failure(Failure { jsonrpc: Some(Version::V2), error: Error::parse_error(), - id: Id::Num(1) + id: Id::Num(1), }); let serialized = serde_json::to_string(&fo).unwrap(); - assert_eq!(serialized, r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":1}"#); + assert_eq!( + serialized, + r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":1}"# + ); } #[test] @@ -173,11 +187,14 @@ fn failure_output_serialize_jsonrpc_1() { let fo = Output::Failure(Failure { jsonrpc: None, error: Error::parse_error(), - id: Id::Num(1) + id: Id::Num(1), }); let serialized = serde_json::to_string(&fo).unwrap(); - assert_eq!(serialized, r#"{"error":{"code":-32700,"message":"Parse error"},"id":1}"#); + assert_eq!( + serialized, + r#"{"error":{"code":-32700,"message":"Parse error"},"id":1}"# + ); } #[test] @@ -187,11 +204,14 @@ fn failure_output_deserialize() { let dfo = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":1}"#; let deserialized: Output = serde_json::from_str(dfo).unwrap(); - assert_eq!(deserialized, Output::Failure(Failure { - jsonrpc: Some(Version::V2), - error: Error::parse_error(), - id: Id::Num(1) - })); + assert_eq!( + deserialized, + Output::Failure(Failure { + jsonrpc: Some(Version::V2), + error: Error::parse_error(), + id: Id::Num(1) + }) + ); } #[test] @@ -202,11 +222,14 @@ fn single_response_deserialize() { let dsr = r#"{"jsonrpc":"2.0","result":1,"id":1}"#; let deserialized: Response = serde_json::from_str(dsr).unwrap(); - assert_eq!(deserialized, Response::Single(Output::Success(Success { - jsonrpc: Some(Version::V2), - result: Value::from(1), - id: Id::Num(1) - }))); + assert_eq!( + deserialized, + Response::Single(Output::Success(Success { + jsonrpc: 
Some(Version::V2), + result: Value::from(1), + id: Id::Num(1) + })) + ); } #[test] @@ -217,16 +240,57 @@ fn batch_response_deserialize() { let dbr = r#"[{"jsonrpc":"2.0","result":1,"id":1},{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":1}]"#; let deserialized: Response = serde_json::from_str(dbr).unwrap(); - assert_eq!(deserialized, Response::Batch(vec![ - Output::Success(Success { - jsonrpc: Some(Version::V2), - result: Value::from(1), - id: Id::Num(1) - }), - Output::Failure(Failure { - jsonrpc: Some(Version::V2), - error: Error::parse_error(), - id: Id::Num(1) - }) - ])); + assert_eq!( + deserialized, + Response::Batch(vec![ + Output::Success(Success { + jsonrpc: Some(Version::V2), + result: Value::from(1), + id: Id::Num(1) + }), + Output::Failure(Failure { + jsonrpc: Some(Version::V2), + error: Error::parse_error(), + id: Id::Num(1) + }) + ]) + ); +} + +#[test] +fn handle_incorrect_responses() { + use serde_json; + + let dsr = r#" +{ + "id": 2, + "jsonrpc": "2.0", + "result": "0x62d3776be72cc7fa62cad6fe8ed873d9bc7ca2ee576e400d987419a3f21079d5", + "error": { + "message": "VM Exception while processing transaction: revert", + "code": -32000, + "data": {} + } +}"#; + + let deserialized: Result = serde_json::from_str(dsr); + assert!( + deserialized.is_err(), + "Expected error when deserializing invalid payload." + ); +} + +#[test] +fn should_parse_empty_response_as_batch() { + use serde_json; + + let dsr = r#""#; + + let deserialized1: Result = serde_json::from_str(dsr); + let deserialized2: Result = Response::from_json(dsr); + assert!( + deserialized1.is_err(), + "Empty string is not valid JSON, so we should get an error." + ); + assert_eq!(deserialized2.unwrap(), Response::Batch(vec![])); } diff --git a/core/src/types/version.rs b/core/src/types/version.rs index a3e42fdb0..e3f51b3cb 100644 --- a/core/src/types/version.rs +++ b/core/src/types/version.rs @@ -1,6 +1,6 @@ //! 
jsonrpc version field -use serde::{Serialize, Serializer, Deserialize, Deserializer}; use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::fmt; @@ -8,21 +8,25 @@ use std::fmt; #[derive(Debug, PartialEq, Clone, Copy, Hash, Eq)] pub enum Version { /// JSONRPC 2.0 - V2 + V2, } impl Serialize for Version { fn serialize(&self, serializer: S) -> Result - where S: Serializer { - match self { - &Version::V2 => serializer.serialize_str("2.0") + where + S: Serializer, + { + match *self { + Version::V2 => serializer.serialize_str("2.0"), } } } impl<'a> Deserialize<'a> for Version { fn deserialize(deserializer: D) -> Result - where D: Deserializer<'a> { + where + D: Deserializer<'a>, + { deserializer.deserialize_identifier(VersionVisitor) } } @@ -36,10 +40,13 @@ impl<'a> Visitor<'a> for VersionVisitor { formatter.write_str("a string") } - fn visit_str(self, value: &str) -> Result where E: de::Error { + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { match value { "2.0" => Ok(Version::V2), - _ => Err(de::Error::custom("invalid version")) + _ => Err(de::Error::custom("invalid version")), } } } diff --git a/derive/Cargo.toml b/derive/Cargo.toml new file mode 100644 index 000000000..41ae32c10 --- /dev/null +++ b/derive/Cargo.toml @@ -0,0 +1,29 @@ +[package] +authors = ["Parity Technologies "] +documentation = "https://docs.rs/jsonrpc-derive/" +description = "High level, typed wrapper for `jsonrpc-core`" +edition = "2018" +homepage = "https://github.com/paritytech/jsonrpc" +license = "MIT" +name = "jsonrpc-derive" +repository = "https://github.com/paritytech/jsonrpc" +version = "18.0.0" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "1.0", features = ["full", "extra-traits", "visit", "fold"] } +proc-macro2 = "1.0" +quote = "1.0.6" +proc-macro-crate = "0.1.4" + +[dev-dependencies] +assert_matches = "1.3" +jsonrpc-core = { version = "18.0.0", path = "../core" } +jsonrpc-core-client = { version = "18.0.0", path = "../core-client" } +jsonrpc-pubsub = { version = "18.0.0", path = "../pubsub" } +jsonrpc-tcp-server = { version = "18.0.0", path = "../tcp" } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +trybuild = "1.0" diff --git a/derive/examples/client-local.rs b/derive/examples/client-local.rs new file mode 100644 index 000000000..a0fdf6fe1 --- /dev/null +++ b/derive/examples/client-local.rs @@ -0,0 +1,62 @@ +use jsonrpc_core::{ + futures::{self, FutureExt}, + BoxFuture, IoHandler, Result, +}; +use jsonrpc_core_client::transports::local; +use jsonrpc_derive::rpc; + +/// Rpc trait +#[rpc] +pub trait Rpc { + /// Returns a protocol version + #[rpc(name = "protocolVersion")] + fn protocol_version(&self) -> Result; + + /// Adds two numbers and returns a result + #[rpc(name = "add", alias("callAsyncMetaAlias"))] + fn add(&self, a: u64, b: u64) -> Result; + + /// Performs asynchronous operation + #[rpc(name = "callAsync")] + fn call(&self, a: u64) -> BoxFuture>; +} + +struct RpcImpl; + +impl Rpc for RpcImpl { + fn protocol_version(&self) -> Result { + Ok("version1".into()) + } + + fn add(&self, a: u64, b: u64) -> Result { + Ok(a + b) + } + + fn call(&self, _: u64) -> BoxFuture> { + Box::pin(futures::future::ready(Ok("OK".to_owned()))) + } +} + +fn main() { + futures::executor::block_on(async { + let mut io = IoHandler::new(); + io.extend_with(RpcImpl.to_delegate()); + println!("Starting local server"); + let (client, server) = local::connect(io); + let client = use_client(client).fuse(); + let server 
= server.fuse(); + + futures::pin_mut!(client); + futures::pin_mut!(server); + + futures::select! { + _server = server => {}, + _client = client => {}, + } + }); +} + +async fn use_client(client: RpcClient) { + let res = client.add(5, 6).await.unwrap(); + println!("5 + 6 = {}", res); +} diff --git a/derive/examples/generic-trait-bounds.rs b/derive/examples/generic-trait-bounds.rs new file mode 100644 index 000000000..6ac09705c --- /dev/null +++ b/derive/examples/generic-trait-bounds.rs @@ -0,0 +1,60 @@ +use serde::{Deserialize, Serialize}; + +use jsonrpc_core::{futures::future, BoxFuture, IoHandler, IoHandlerExtension, Result}; +use jsonrpc_derive::rpc; + +// One is both parameter and a result so requires both Serialize and DeserializeOwned +// Two is only a parameter so only requires DeserializeOwned +// Three is only a result so only requires Serialize +#[rpc(server)] +pub trait Rpc { + /// Get One type. + #[rpc(name = "getOne")] + fn one(&self) -> Result; + + /// Adds two numbers and returns a result + #[rpc(name = "setTwo")] + fn set_two(&self, a: Two) -> Result<()>; + + #[rpc(name = "getThree")] + fn get_three(&self) -> Result; + + /// Performs asynchronous operation + #[rpc(name = "beFancy")] + fn call(&self, a: One) -> BoxFuture>; +} + +struct RpcImpl; + +#[derive(Serialize, Deserialize)] +struct InAndOut { + foo: u64, +} +#[derive(Deserialize)] +struct In {} +#[derive(Serialize)] +struct Out {} + +impl Rpc for RpcImpl { + fn one(&self) -> Result { + Ok(InAndOut { foo: 1u64 }) + } + + fn set_two(&self, _x: In) -> Result<()> { + Ok(()) + } + + fn get_three(&self) -> Result { + Ok(Out {}) + } + + fn call(&self, num: InAndOut) -> BoxFuture> { + Box::pin(future::ready(Ok((InAndOut { foo: num.foo + 999 }, num.foo)))) + } +} + +fn main() { + let mut io = IoHandler::new(); + + RpcImpl.to_delegate().augment(&mut io); +} diff --git a/derive/examples/generic-trait.rs b/derive/examples/generic-trait.rs new file mode 100644 index 000000000..8e0972bc8 --- /dev/null +++ b/derive/examples/generic-trait.rs @@ -0,0 +1,43 @@ +use jsonrpc_core; + +use jsonrpc_core::{BoxFuture, IoHandler, Result}; +use jsonrpc_derive::rpc; + +#[rpc] +pub trait Rpc { + /// Get One type. + #[rpc(name = "getOne")] + fn one(&self) -> Result; + + /// Adds two numbers and returns a result + #[rpc(name = "setTwo")] + fn set_two(&self, a: Two) -> Result<()>; + + /// Performs asynchronous operation + #[rpc(name = "beFancy")] + fn call(&self, a: One) -> BoxFuture>; +} + +struct RpcImpl; + +impl Rpc for RpcImpl { + fn one(&self) -> Result { + Ok(100) + } + + fn set_two(&self, x: String) -> Result<()> { + println!("{}", x); + Ok(()) + } + + fn call(&self, num: u64) -> BoxFuture> { + Box::pin(jsonrpc_core::futures::future::ready(Ok((num + 999, "hello".into())))) + } +} + +fn main() { + let mut io = IoHandler::new(); + let rpc = RpcImpl; + + io.extend_with(rpc.to_delegate()) +} diff --git a/derive/examples/meta-macros.rs b/derive/examples/meta-macros.rs new file mode 100644 index 000000000..9ea7ea69e --- /dev/null +++ b/derive/examples/meta-macros.rs @@ -0,0 +1,91 @@ +use std::collections::BTreeMap; + +use jsonrpc_core::futures::future; +use jsonrpc_core::{BoxFuture, MetaIoHandler, Metadata, Params, Result, Value}; +use jsonrpc_derive::rpc; + +#[derive(Clone)] +struct Meta(String); +impl Metadata for Meta {} + +#[rpc] +pub trait Rpc { + type Metadata; + + /// Get One type. + #[rpc(name = "getOne")] + fn one(&self) -> Result; + + /// Adds two numbers and returns a result. 
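    // With the default positional parameter style, the generated handler for the
    // method below answers wire traffic shaped roughly like this (illustrative
    // payload, consistent with the request/response tests earlier in this change):
    //
    //   --> {"jsonrpc":"2.0","method":"add","params":[1,2],"id":1}
    //   <-- {"jsonrpc":"2.0","result":3,"id":1}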
+ #[rpc(name = "add")] + fn add(&self, a: u64, b: u64) -> Result; + + /// Multiplies two numbers. Second number is optional. + #[rpc(name = "mul")] + fn mul(&self, a: u64, b: Option) -> Result; + + /// Retrieves and debug prints the underlying `Params` object. + #[rpc(name = "raw", params = "raw")] + fn raw(&self, params: Params) -> Result; + + /// Performs an asynchronous operation. + #[rpc(name = "callAsync")] + fn call(&self, a: u64) -> BoxFuture>; + + /// Performs an asynchronous operation with meta. + #[rpc(meta, name = "callAsyncMeta", alias("callAsyncMetaAlias"))] + fn call_meta(&self, a: Self::Metadata, b: BTreeMap) -> BoxFuture>; + + /// Handles a notification. + #[rpc(name = "notify")] + fn notify(&self, a: u64); +} + +struct RpcImpl; +impl Rpc for RpcImpl { + type Metadata = Meta; + + fn one(&self) -> Result { + Ok(100) + } + + fn add(&self, a: u64, b: u64) -> Result { + Ok(a + b) + } + + fn mul(&self, a: u64, b: Option) -> Result { + Ok(a * b.unwrap_or(1)) + } + + fn raw(&self, params: Params) -> Result { + Ok(format!("Got: {:?}", params)) + } + + fn call(&self, x: u64) -> BoxFuture> { + Box::pin(future::ready(Ok(format!("OK: {}", x)))) + } + + fn call_meta(&self, meta: Self::Metadata, map: BTreeMap) -> BoxFuture> { + Box::pin(future::ready(Ok(format!("From: {}, got: {:?}", meta.0, map)))) + } + + fn notify(&self, a: u64) { + println!("Received `notify` with value: {}", a); + } +} + +fn main() { + let mut io = MetaIoHandler::default(); + let rpc = RpcImpl; + + io.extend_with(rpc.to_delegate()); + + let server = + jsonrpc_tcp_server::ServerBuilder::with_meta_extractor(io, |context: &jsonrpc_tcp_server::RequestContext| { + Meta(format!("{}", context.peer_addr)) + }) + .start(&"0.0.0.0:3030".parse().unwrap()) + .expect("Server must start with no issues"); + + server.wait() +} diff --git a/derive/examples/pubsub-macros.rs b/derive/examples/pubsub-macros.rs new file mode 100644 index 000000000..ec8479a27 --- /dev/null +++ b/derive/examples/pubsub-macros.rs @@ -0,0 +1,88 @@ +use std::collections::HashMap; +use std::sync::{atomic, Arc, RwLock}; +use std::thread; + +use jsonrpc_core::{Error, ErrorCode, Result}; +use jsonrpc_derive::rpc; +use jsonrpc_pubsub::typed; +use jsonrpc_pubsub::{PubSubHandler, Session, SubscriptionId}; + +#[rpc] +pub trait Rpc { + type Metadata; + + /// Hello subscription + #[pubsub(subscription = "hello", subscribe, name = "hello_subscribe", alias("hello_sub"))] + fn subscribe(&self, meta: Self::Metadata, subscriber: typed::Subscriber, param: u64); + + /// Unsubscribe from hello subscription. 
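    // Rough shape of the resulting subscription traffic (illustrative only; the
    // exact notification envelope is defined by `jsonrpc-pubsub`, not by this trait):
    //
    //   --> {"jsonrpc":"2.0","method":"hello_subscribe","params":[10],"id":1}
    //   <-- {"jsonrpc":"2.0","result":0,"id":1}          // assigned SubscriptionId
    //   <-- {"jsonrpc":"2.0","method":"hello","params":{"subscription":0,"result":"Hello World!"}}
    //   --> {"jsonrpc":"2.0","method":"hello_unsubscribe","params":[0],"id":2}
    //   <-- {"jsonrpc":"2.0","result":true,"id":2}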
+ #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] + fn unsubscribe(&self, meta: Option, subscription: SubscriptionId) -> Result; +} + +#[derive(Default)] +struct RpcImpl { + uid: atomic::AtomicUsize, + active: Arc>>>, +} +impl Rpc for RpcImpl { + type Metadata = Arc; + + fn subscribe(&self, _meta: Self::Metadata, subscriber: typed::Subscriber, param: u64) { + if param != 10 { + subscriber + .reject(Error { + code: ErrorCode::InvalidParams, + message: "Rejecting subscription - invalid parameters provided.".into(), + data: None, + }) + .unwrap(); + return; + } + + let id = self.uid.fetch_add(1, atomic::Ordering::SeqCst); + let sub_id = SubscriptionId::Number(id as u64); + let sink = subscriber.assign_id(sub_id.clone()).unwrap(); + self.active.write().unwrap().insert(sub_id, sink); + } + + fn unsubscribe(&self, _meta: Option, id: SubscriptionId) -> Result { + let removed = self.active.write().unwrap().remove(&id); + if removed.is_some() { + Ok(true) + } else { + Err(Error { + code: ErrorCode::InvalidParams, + message: "Invalid subscription.".into(), + data: None, + }) + } + } +} + +fn main() { + let mut io = PubSubHandler::default(); + let rpc = RpcImpl::default(); + let active_subscriptions = rpc.active.clone(); + + thread::spawn(move || loop { + { + let subscribers = active_subscriptions.read().unwrap(); + for sink in subscribers.values() { + let _ = sink.notify(Ok("Hello World!".into())); + } + } + thread::sleep(::std::time::Duration::from_secs(1)); + }); + + io.extend_with(rpc.to_delegate()); + + let server = + jsonrpc_tcp_server::ServerBuilder::with_meta_extractor(io, |context: &jsonrpc_tcp_server::RequestContext| { + Arc::new(Session::new(context.sender.clone())) + }) + .start(&"0.0.0.0:3030".parse().unwrap()) + .expect("Server must start with no issues"); + + server.wait() +} diff --git a/derive/examples/std.rs b/derive/examples/std.rs new file mode 100644 index 000000000..3c2c1e640 --- /dev/null +++ b/derive/examples/std.rs @@ -0,0 +1,58 @@ +//! A simple example +#![deny(missing_docs)] +use jsonrpc_core::futures::{self, future, TryFutureExt}; +use jsonrpc_core::{BoxFuture, IoHandler, Result}; +use jsonrpc_core_client::transports::local; +use jsonrpc_derive::rpc; + +/// Rpc trait +#[rpc] +pub trait Rpc { + /// Returns a protocol version. + #[rpc(name = "protocolVersion")] + fn protocol_version(&self) -> Result; + + /// Adds two numbers and returns a result. + #[rpc(name = "add", alias("callAsyncMetaAlias"))] + fn add(&self, a: u64, b: u64) -> Result; + + /// Performs asynchronous operation. + #[rpc(name = "callAsync")] + fn call(&self, a: u64) -> BoxFuture>; + + /// Handles a notification. + #[rpc(name = "notify")] + fn notify(&self, a: u64); +} + +struct RpcImpl; + +impl Rpc for RpcImpl { + fn protocol_version(&self) -> Result { + Ok("version1".into()) + } + + fn add(&self, a: u64, b: u64) -> Result { + Ok(a + b) + } + + fn call(&self, _: u64) -> BoxFuture> { + Box::pin(future::ready(Ok("OK".to_owned()))) + } + + fn notify(&self, a: u64) { + println!("Received `notify` with value: {}", a); + } +} + +fn main() { + let mut io = IoHandler::new(); + io.extend_with(RpcImpl.to_delegate()); + + let (client, server) = local::connect::(io); + let fut = client.add(5, 6).map_ok(|res| println!("5 + 6 = {}", res)); + + futures::executor::block_on(async move { futures::join!(fut, server) }) + .0 + .unwrap(); +} diff --git a/derive/src/lib.rs b/derive/src/lib.rs new file mode 100644 index 000000000..2a8109a69 --- /dev/null +++ b/derive/src/lib.rs @@ -0,0 +1,229 @@ +//! 
High level, typed wrapper for `jsonrpc_core`. +//! +//! Enables creation of "Service" objects grouping a set of RPC methods together in a typed manner. +//! +//! Example +//! +//! ``` +//! use jsonrpc_core::{IoHandler, Result, BoxFuture}; +//! use jsonrpc_core::futures::future; +//! use jsonrpc_derive::rpc; +//! +//! #[rpc(server)] +//! pub trait Rpc { +//! #[rpc(name = "protocolVersion")] +//! fn protocol_version(&self) -> Result; +//! +//! #[rpc(name = "add")] +//! fn add(&self, a: u64, b: u64) -> Result; +//! +//! #[rpc(name = "callAsync")] +//! fn call(&self, a: u64) -> BoxFuture>; +//! } +//! +//! struct RpcImpl; +//! impl Rpc for RpcImpl { +//! fn protocol_version(&self) -> Result { +//! Ok("version1".into()) +//! } +//! +//! fn add(&self, a: u64, b: u64) -> Result { +//! Ok(a + b) +//! } +//! +//! fn call(&self, _: u64) -> BoxFuture> { +//! Box::pin(future::ready(Ok("OK".to_owned()).into())) +//! } +//! } +//! +//! fn main() { +//! let mut io = IoHandler::new(); +//! let rpc = RpcImpl; +//! +//! io.extend_with(rpc.to_delegate()); +//! } +//! ``` +//! +//! Pub/Sub Example +//! +//! Each subscription must have `subscribe` and `unsubscribe` methods. They can +//! have any name but must be annotated with `subscribe` or `unsubscribe` and +//! have a matching unique subscription name. +//! +//! ``` +//! use std::sync::{atomic, Arc, RwLock}; +//! use std::collections::HashMap; +//! +//! use jsonrpc_core::{Error, ErrorCode, Result}; +//! use jsonrpc_derive::rpc; +//! use jsonrpc_pubsub::{Session, PubSubHandler, SubscriptionId, typed::{Subscriber, Sink}}; +//! +//! #[rpc] +//! pub trait Rpc { +//! type Metadata; +//! +//! /// Hello subscription +//! #[pubsub( +//! subscription = "hello", +//! subscribe, +//! name = "hello_subscribe", +//! alias("hello_sub") +//! )] +//! fn subscribe(&self, _: Self::Metadata, _: Subscriber, param: u64); +//! +//! /// Unsubscribe from hello subscription. +//! #[pubsub( +//! subscription = "hello", +//! unsubscribe, +//! name = "hello_unsubscribe" +//! )] +//! fn unsubscribe(&self, _: Option, _: SubscriptionId) -> Result; +//! } +//! +//! +//! #[derive(Default)] +//! struct RpcImpl { +//! uid: atomic::AtomicUsize, +//! active: Arc>>>, +//! } +//! impl Rpc for RpcImpl { +//! type Metadata = Arc; +//! +//! fn subscribe(&self, _meta: Self::Metadata, subscriber: Subscriber, param: u64) { +//! if param != 10 { +//! subscriber.reject(Error { +//! code: ErrorCode::InvalidParams, +//! message: "Rejecting subscription - invalid parameters provided.".into(), +//! data: None, +//! }).unwrap(); +//! return; +//! } +//! +//! let id = self.uid.fetch_add(1, atomic::Ordering::SeqCst); +//! let sub_id = SubscriptionId::Number(id as u64); +//! let sink = subscriber.assign_id(sub_id.clone()).unwrap(); +//! self.active.write().unwrap().insert(sub_id, sink); +//! } +//! +//! fn unsubscribe(&self, _meta: Option, id: SubscriptionId) -> Result { +//! let removed = self.active.write().unwrap().remove(&id); +//! if removed.is_some() { +//! Ok(true) +//! } else { +//! Err(Error { +//! code: ErrorCode::InvalidParams, +//! message: "Invalid subscription.".into(), +//! data: None, +//! }) +//! } +//! } +//! } +//! +//! fn main() { +//! let mut io = jsonrpc_core::MetaIoHandler::default(); +//! io.extend_with(RpcImpl::default().to_delegate()); +//! +//! let server_builder = jsonrpc_tcp_server::ServerBuilder::with_meta_extractor( +//! io, +//! |request: &jsonrpc_tcp_server::RequestContext| Arc::new(Session::new(request.sender.clone())) +//! ); +//! let server = server_builder +//! 
.start(&"127.0.0.1:3030".parse().unwrap()) +//! .expect("Unable to start TCP server"); +//! +//! // The server spawns a separate thread. Dropping the `server` handle causes it to close. +//! // Uncomment the line below to keep the server running in your example. +//! // server.wait(); +//! } +//! ``` +//! +//! Client Example +//! +//! ``` +//! use jsonrpc_core_client::transports::local; +//! use jsonrpc_core::futures::{self, future}; +//! use jsonrpc_core::{IoHandler, Result, BoxFuture}; +//! use jsonrpc_derive::rpc; +//! +//! /// Rpc trait +//! #[rpc] +//! pub trait Rpc { +//! /// Returns a protocol version +//! #[rpc(name = "protocolVersion")] +//! fn protocol_version(&self) -> Result; +//! +//! /// Adds two numbers and returns a result +//! #[rpc(name = "add", alias("callAsyncMetaAlias"))] +//! fn add(&self, a: u64, b: u64) -> Result; +//! +//! /// Performs asynchronous operation +//! #[rpc(name = "callAsync")] +//! fn call(&self, a: u64) -> BoxFuture>; +//! } +//! +//! struct RpcImpl; +//! +//! impl Rpc for RpcImpl { +//! fn protocol_version(&self) -> Result { +//! Ok("version1".into()) +//! } +//! +//! fn add(&self, a: u64, b: u64) -> Result { +//! Ok(a + b) +//! } +//! +//! fn call(&self, _: u64) -> BoxFuture> { +//! Box::pin(future::ready(Ok("OK".to_owned()))) +//! } +//! } +//! +//! fn main() { +//! let exec = futures::executor::ThreadPool::new().unwrap(); +//! exec.spawn_ok(run()) +//! } +//! async fn run() { +//! let mut io = IoHandler::new(); +//! io.extend_with(RpcImpl.to_delegate()); +//! +//! let (client, server) = local::connect::(io); +//! let res = client.add(5, 6).await.unwrap(); +//! println!("5 + 6 = {}", res); +//! server.await.unwrap() +//! } +//! +//! ``` + +#![recursion_limit = "256"] +#![warn(missing_docs)] + +extern crate proc_macro; + +use proc_macro::TokenStream; +use syn::parse_macro_input; + +mod options; +mod params_style; +mod rpc_attr; +mod rpc_trait; +mod to_client; +mod to_delegate; + +/// Apply `#[rpc]` to a trait, and a `to_delegate` method is generated which +/// wires up methods decorated with `#[rpc]` or `#[pubsub]` attributes. +/// Attach the delegate to an `IoHandler` and the methods are now callable +/// via JSON-RPC. 
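// A hedged summary of the arguments the attribute accepts, as parsed by
// `DeriveOptions::try_from` in `options.rs` below:
//
//   #[rpc]                            // no arguments: generate both client and server
//   #[rpc(client)]                    // only the typed client (`gen_client` / `*Client`)
//   #[rpc(server)]                    // only the server trait with `to_delegate`
//   #[rpc(client, params = "named")]  // named params are currently client-only;
//                                     // the server supports "positional" (default) and "raw"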
+#[proc_macro_attribute] +pub fn rpc(args: TokenStream, input: TokenStream) -> TokenStream { + let input_toks = parse_macro_input!(input as syn::Item); + let args = syn::parse_macro_input!(args as syn::AttributeArgs); + + let options = match options::DeriveOptions::try_from(args) { + Ok(options) => options, + Err(error) => return error.to_compile_error().into(), + }; + + match rpc_trait::rpc_impl(input_toks, &options) { + Ok(output) => output.into(), + Err(err) => err.to_compile_error().into(), + } +} diff --git a/derive/src/options.rs b/derive/src/options.rs new file mode 100644 index 000000000..fe52c07dc --- /dev/null +++ b/derive/src/options.rs @@ -0,0 +1,71 @@ +use std::str::FromStr; + +use crate::params_style::ParamStyle; +use crate::rpc_attr::path_eq_str; + +const CLIENT_META_WORD: &str = "client"; +const SERVER_META_WORD: &str = "server"; +const PARAMS_META_KEY: &str = "params"; + +#[derive(Debug)] +pub struct DeriveOptions { + pub enable_client: bool, + pub enable_server: bool, + pub params_style: ParamStyle, +} + +impl DeriveOptions { + pub fn new(enable_client: bool, enable_server: bool, params_style: ParamStyle) -> Self { + DeriveOptions { + enable_client, + enable_server, + params_style, + } + } + + pub fn try_from(args: syn::AttributeArgs) -> Result { + let mut options = DeriveOptions::new(false, false, ParamStyle::default()); + for arg in args { + if let syn::NestedMeta::Meta(meta) = arg { + match meta { + syn::Meta::Path(ref p) => { + match p + .get_ident() + .ok_or(syn::Error::new_spanned( + p, + format!("Expecting identifier `{}` or `{}`", CLIENT_META_WORD, SERVER_META_WORD), + ))? + .to_string() + .as_ref() + { + CLIENT_META_WORD => options.enable_client = true, + SERVER_META_WORD => options.enable_server = true, + _ => {} + }; + } + syn::Meta::NameValue(nv) => { + if path_eq_str(&nv.path, PARAMS_META_KEY) { + if let syn::Lit::Str(ref lit) = nv.lit { + options.params_style = ParamStyle::from_str(&lit.value()) + .map_err(|e| syn::Error::new_spanned(nv.clone(), e))?; + } + } else { + return Err(syn::Error::new_spanned(nv, "Unexpected RPC attribute key")); + } + } + _ => return Err(syn::Error::new_spanned(meta, "Unexpected use of RPC attribute macro")), + } + } + } + if !options.enable_client && !options.enable_server { + // if nothing provided default to both + options.enable_client = true; + options.enable_server = true; + } + if options.enable_server && options.params_style == ParamStyle::Named { + // This is not allowed at this time + panic!("Server code generation only supports `params = \"positional\"` (default) or `params = \"raw\" at this time.") + } + Ok(options) + } +} diff --git a/derive/src/params_style.rs b/derive/src/params_style.rs new file mode 100644 index 000000000..fe8b9dfe7 --- /dev/null +++ b/derive/src/params_style.rs @@ -0,0 +1,34 @@ +use std::str::FromStr; + +const POSITIONAL: &str = "positional"; +const NAMED: &str = "named"; +const RAW: &str = "raw"; + +#[derive(Clone, Debug, PartialEq)] +pub enum ParamStyle { + Positional, + Named, + Raw, +} + +impl Default for ParamStyle { + fn default() -> Self { + Self::Positional + } +} + +impl FromStr for ParamStyle { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + POSITIONAL => Ok(Self::Positional), + NAMED => Ok(Self::Named), + RAW => Ok(Self::Raw), + _ => Err(format!( + "Invalid value for params key. 
Must be one of [{}, {}, {}]", + POSITIONAL, NAMED, RAW + )), + } + } +} diff --git a/derive/src/rpc_attr.rs b/derive/src/rpc_attr.rs new file mode 100644 index 000000000..b66967efb --- /dev/null +++ b/derive/src/rpc_attr.rs @@ -0,0 +1,307 @@ +use crate::params_style::ParamStyle; +use std::str::FromStr; +use syn::{ + visit::{self, Visit}, + Error, Result, +}; + +#[derive(Clone, Debug)] +pub struct RpcMethodAttribute { + pub attr: syn::Attribute, + pub name: String, + pub aliases: Vec, + pub kind: AttributeKind, + pub params_style: Option, // None means do not override the top level default +} + +#[derive(Clone, Debug)] +pub enum AttributeKind { + Rpc { + has_metadata: bool, + returns: Option, + is_notification: bool, + }, + PubSub { + subscription_name: String, + kind: PubSubMethodKind, + }, +} + +#[derive(Clone, Debug)] +pub enum PubSubMethodKind { + Subscribe, + Unsubscribe, +} + +const RPC_ATTR_NAME: &str = "rpc"; +const RPC_NAME_KEY: &str = "name"; +const SUBSCRIPTION_NAME_KEY: &str = "subscription"; +const ALIASES_KEY: &str = "alias"; +const PUB_SUB_ATTR_NAME: &str = "pubsub"; +const METADATA_META_WORD: &str = "meta"; +const RAW_PARAMS_META_WORD: &str = "raw_params"; // to be deprecated and replaced with `params = "raw"` +const SUBSCRIBE_META_WORD: &str = "subscribe"; +const UNSUBSCRIBE_META_WORD: &str = "unsubscribe"; +const RETURNS_META_WORD: &str = "returns"; +const PARAMS_STYLE_KEY: &str = "params"; + +const MULTIPLE_RPC_ATTRIBUTES_ERR: &str = "Expected only a single rpc attribute per method"; +const INVALID_ATTR_PARAM_NAMES_ERR: &str = "Invalid attribute parameter(s):"; +const MISSING_NAME_ERR: &str = "rpc attribute should have a name e.g. `name = \"method_name\"`"; +const MISSING_SUB_NAME_ERR: &str = "pubsub attribute should have a subscription name"; +const BOTH_SUB_AND_UNSUB_ERR: &str = "pubsub attribute annotated with both subscribe and unsubscribe"; +const NEITHER_SUB_OR_UNSUB_ERR: &str = "pubsub attribute not annotated with either subscribe or unsubscribe"; + +impl RpcMethodAttribute { + pub fn parse_attr(method: &syn::TraitItemMethod) -> Result> { + let output = &method.sig.output; + let attrs = method + .attrs + .iter() + .filter_map(|attr| Self::parse_meta(attr, &output)) + .collect::>>()?; + + if attrs.len() <= 1 { + Ok(attrs.first().cloned()) + } else { + Err(Error::new_spanned(method, MULTIPLE_RPC_ATTRIBUTES_ERR)) + } + } + + fn parse_meta(attr: &syn::Attribute, output: &syn::ReturnType) -> Option> { + match attr.parse_meta().and_then(validate_attribute_meta) { + Ok(ref meta) => { + let attr_kind = match path_to_str(meta.path()).as_deref() { + Some(RPC_ATTR_NAME) => Some(Self::parse_rpc(meta, output)), + Some(PUB_SUB_ATTR_NAME) => Some(Self::parse_pubsub(meta)), + _ => None, + }; + attr_kind.map(|kind| { + kind.and_then(|kind| { + get_meta_list(meta) + .and_then(|ml| get_name_value(RPC_NAME_KEY, ml)) + .map_or(Err(Error::new_spanned(attr, MISSING_NAME_ERR)), |name| { + let aliases = get_meta_list(&meta).map_or(Vec::new(), |ml| get_aliases(ml)); + let raw_params = + get_meta_list(meta).map_or(false, |ml| has_meta_word(RAW_PARAMS_META_WORD, ml)); + let params_style = match raw_params { + true => { + // "`raw_params` will be deprecated in a future release. 
Use `params = \"raw\" instead`" + Ok(Some(ParamStyle::Raw)) + } + false => get_meta_list(meta).map_or(Ok(None), |ml| get_params_style(ml).map(Some)), + }?; + Ok(RpcMethodAttribute { + attr: attr.clone(), + name, + aliases, + kind, + params_style, + }) + }) + }) + }) + } + Err(err) => Some(Err(err)), + } + } + + fn parse_rpc(meta: &syn::Meta, output: &syn::ReturnType) -> Result { + let has_metadata = get_meta_list(meta).map_or(false, |ml| has_meta_word(METADATA_META_WORD, ml)); + let returns = get_meta_list(meta).and_then(|ml| get_name_value(RETURNS_META_WORD, ml)); + let is_notification = match output { + syn::ReturnType::Default => true, + syn::ReturnType::Type(_, ret) => { + matches!(**ret, syn::Type::Tuple(ref tup) if tup.elems.empty_or_trailing()) + } + }; + + if is_notification && returns.is_some() { + return Err(syn::Error::new_spanned(output, &"Notifications must return ()")); + } + + Ok(AttributeKind::Rpc { + has_metadata, + returns, + is_notification, + }) + } + + fn parse_pubsub(meta: &syn::Meta) -> Result { + let name_and_list = + get_meta_list(meta).and_then(|ml| get_name_value(SUBSCRIPTION_NAME_KEY, ml).map(|name| (name, ml))); + + name_and_list.map_or(Err(Error::new_spanned(meta, MISSING_SUB_NAME_ERR)), |(sub_name, ml)| { + let is_subscribe = has_meta_word(SUBSCRIBE_META_WORD, ml); + let is_unsubscribe = has_meta_word(UNSUBSCRIBE_META_WORD, ml); + let kind = match (is_subscribe, is_unsubscribe) { + (true, false) => Ok(PubSubMethodKind::Subscribe), + (false, true) => Ok(PubSubMethodKind::Unsubscribe), + (true, true) => Err(Error::new_spanned(meta, BOTH_SUB_AND_UNSUB_ERR)), + (false, false) => Err(Error::new_spanned(meta, NEITHER_SUB_OR_UNSUB_ERR)), + }; + kind.map(|kind| AttributeKind::PubSub { + subscription_name: sub_name, + kind, + }) + }) + } + + pub fn is_pubsub(&self) -> bool { + match self.kind { + AttributeKind::PubSub { .. } => true, + AttributeKind::Rpc { .. } => false, + } + } + + pub fn is_notification(&self) -> bool { + match self.kind { + AttributeKind::Rpc { is_notification, .. } => is_notification, + AttributeKind::PubSub { .. 
} => false, + } + } +} + +fn validate_attribute_meta(meta: syn::Meta) -> Result { + #[derive(Default)] + struct Visitor { + meta_words: Vec, + name_value_names: Vec, + meta_list_names: Vec, + } + impl<'a> Visit<'a> for Visitor { + fn visit_meta(&mut self, meta: &syn::Meta) { + if let Some(ident) = path_to_str(meta.path()) { + match meta { + syn::Meta::Path(_) => self.meta_words.push(ident), + syn::Meta::List(_) => self.meta_list_names.push(ident), + syn::Meta::NameValue(_) => self.name_value_names.push(ident), + } + } + } + } + + let mut visitor = Visitor::default(); + visit::visit_meta(&mut visitor, &meta); + + let ident = path_to_str(meta.path()); + match ident.as_deref() { + Some(RPC_ATTR_NAME) => { + validate_idents(&meta, &visitor.meta_words, &[METADATA_META_WORD, RAW_PARAMS_META_WORD])?; + validate_idents( + &meta, + &visitor.name_value_names, + &[RPC_NAME_KEY, RETURNS_META_WORD, PARAMS_STYLE_KEY], + )?; + validate_idents(&meta, &visitor.meta_list_names, &[ALIASES_KEY]) + } + Some(PUB_SUB_ATTR_NAME) => { + validate_idents( + &meta, + &visitor.meta_words, + &[SUBSCRIBE_META_WORD, UNSUBSCRIBE_META_WORD, RAW_PARAMS_META_WORD], + )?; + validate_idents(&meta, &visitor.name_value_names, &[SUBSCRIPTION_NAME_KEY, RPC_NAME_KEY])?; + validate_idents(&meta, &visitor.meta_list_names, &[ALIASES_KEY]) + } + _ => Ok(meta), // ignore other attributes - compiler will catch unknown ones + } +} + +fn validate_idents(meta: &syn::Meta, attr_idents: &[String], valid: &[&str]) -> Result { + let invalid_meta_words: Vec<_> = attr_idents + .iter() + .filter(|w| !valid.iter().any(|v| v == w)) + .cloned() + .collect(); + if invalid_meta_words.is_empty() { + Ok(meta.clone()) + } else { + let expected = format!("Expected '{}'", valid.join(", ")); + let msg = format!( + "{} '{}'. 
{}", + INVALID_ATTR_PARAM_NAMES_ERR, + invalid_meta_words.join(", "), + expected + ); + Err(Error::new_spanned(meta, msg)) + } +} + +fn get_meta_list(meta: &syn::Meta) -> Option<&syn::MetaList> { + if let syn::Meta::List(ml) = meta { + Some(ml) + } else { + None + } +} + +fn get_name_value(key: &str, ml: &syn::MetaList) -> Option { + ml.nested.iter().find_map(|nested| { + if let syn::NestedMeta::Meta(syn::Meta::NameValue(mnv)) = nested { + if path_eq_str(&mnv.path, key) { + if let syn::Lit::Str(ref lit) = mnv.lit { + Some(lit.value()) + } else { + None + } + } else { + None + } + } else { + None + } + }) +} + +fn has_meta_word(word: &str, ml: &syn::MetaList) -> bool { + ml.nested.iter().any(|nested| { + if let syn::NestedMeta::Meta(syn::Meta::Path(p)) = nested { + path_eq_str(&p, word) + } else { + false + } + }) +} + +fn get_aliases(ml: &syn::MetaList) -> Vec { + ml.nested + .iter() + .find_map(|nested| { + if let syn::NestedMeta::Meta(syn::Meta::List(list)) = nested { + if path_eq_str(&list.path, ALIASES_KEY) { + Some(list) + } else { + None + } + } else { + None + } + }) + .map_or(Vec::new(), |list| { + list.nested + .iter() + .filter_map(|nm| { + if let syn::NestedMeta::Lit(syn::Lit::Str(alias)) = nm { + Some(alias.value()) + } else { + None + } + }) + .collect() + }) +} + +fn get_params_style(ml: &syn::MetaList) -> Result { + get_name_value(PARAMS_STYLE_KEY, ml).map_or(Ok(ParamStyle::default()), |s| { + ParamStyle::from_str(&s).map_err(|e| Error::new_spanned(ml, e)) + }) +} + +pub fn path_eq_str(path: &syn::Path, s: &str) -> bool { + path.get_ident().map_or(false, |i| i == s) +} + +fn path_to_str(path: &syn::Path) -> Option { + Some(path.get_ident()?.to_string()) +} diff --git a/derive/src/rpc_trait.rs b/derive/src/rpc_trait.rs new file mode 100644 index 000000000..25bc004b6 --- /dev/null +++ b/derive/src/rpc_trait.rs @@ -0,0 +1,285 @@ +use crate::options::DeriveOptions; +use crate::params_style::ParamStyle; +use crate::rpc_attr::{AttributeKind, PubSubMethodKind, RpcMethodAttribute}; +use crate::to_client::generate_client_module; +use crate::to_delegate::{generate_trait_item_method, MethodRegistration, RpcMethod}; +use proc_macro2::{Span, TokenStream}; +use quote::{quote, ToTokens}; +use std::collections::HashMap; +use syn::{ + fold::{self, Fold}, + parse_quote, + punctuated::Punctuated, + Error, Ident, Result, Token, +}; + +const METADATA_TYPE: &str = "Metadata"; + +const MISSING_SUBSCRIBE_METHOD_ERR: &str = "Can't find subscribe method, expected a method annotated with `subscribe` \ + e.g. `#[pubsub(subscription = \"hello\", subscribe, name = \"hello_subscribe\")]`"; + +const MISSING_UNSUBSCRIBE_METHOD_ERR: &str = + "Can't find unsubscribe method, expected a method annotated with `unsubscribe` \ + e.g. `#[pubsub(subscription = \"hello\", unsubscribe, name = \"hello_unsubscribe\")]`"; + +pub const USING_NAMED_PARAMS_WITH_SERVER_ERR: &str = + "`params = \"named\"` can only be used to generate a client (on a trait annotated with #[rpc(client)]). 
\ + At this time the server does not support named parameters."; + +const RPC_MOD_NAME_PREFIX: &str = "rpc_impl_"; + +struct RpcTrait { + has_pubsub_methods: bool, + methods: Vec, + has_metadata: bool, +} + +impl<'a> Fold for RpcTrait { + fn fold_trait_item_method(&mut self, method: syn::TraitItemMethod) -> syn::TraitItemMethod { + let mut foldable_method = method.clone(); + // strip rpc attributes + foldable_method.attrs.retain(|a| { + let rpc_method = self.methods.iter().find(|m| m.trait_item == method); + rpc_method.map_or(true, |rpc| rpc.attr.attr != *a) + }); + fold::fold_trait_item_method(self, foldable_method) + } + + fn fold_trait_item_type(&mut self, mut ty: syn::TraitItemType) -> syn::TraitItemType { + if ty.ident == METADATA_TYPE { + self.has_metadata = true; + if self.has_pubsub_methods { + ty.bounds.push(parse_quote!(_jsonrpc_pubsub::PubSubMetadata)) + } else { + ty.bounds.push(parse_quote!(_jsonrpc_core::Metadata)) + } + return ty; + } + ty + } +} + +fn compute_method_registrations(item_trait: &syn::ItemTrait) -> Result<(Vec, Vec)> { + let methods_result: Result> = item_trait + .items + .iter() + .filter_map(|trait_item| { + if let syn::TraitItem::Method(method) = trait_item { + match RpcMethodAttribute::parse_attr(method) { + Ok(Some(attr)) => Some(Ok(RpcMethod::new(attr, method.clone()))), + Ok(None) => None, // non rpc annotated trait method + Err(err) => Some(Err(syn::Error::new_spanned(method, err))), + } + } else { + None + } + }) + .collect(); + let methods = methods_result?; + + let mut pubsub_method_pairs: HashMap, Option)> = HashMap::new(); + let mut method_registrations: Vec = Vec::new(); + + for method in methods.iter() { + match &method.attr().kind { + AttributeKind::Rpc { + has_metadata, + is_notification, + .. + } => { + if *is_notification { + method_registrations.push(MethodRegistration::Notification { + method: method.clone(), + has_metadata: *has_metadata, + }) + } else { + method_registrations.push(MethodRegistration::Standard { + method: method.clone(), + has_metadata: *has_metadata, + }) + } + } + AttributeKind::PubSub { + subscription_name, + kind, + } => { + let (ref mut sub, ref mut unsub) = pubsub_method_pairs + .entry(subscription_name.clone()) + .or_insert((vec![], None)); + match kind { + PubSubMethodKind::Subscribe => sub.push(method.clone()), + PubSubMethodKind::Unsubscribe => { + if unsub.is_none() { + *unsub = Some(method.clone()) + } else { + return Err(syn::Error::new_spanned( + &method.trait_item, + format!( + "Subscription '{}' unsubscribe method is already defined", + subscription_name + ), + )); + } + } + } + } + } + } + + for (name, pair) in pubsub_method_pairs { + match pair { + (subscribers, Some(unsubscribe)) => { + if subscribers.is_empty() { + return Err(syn::Error::new_spanned( + &unsubscribe.trait_item, + format!("subscription '{}'. 
{}", name, MISSING_SUBSCRIBE_METHOD_ERR), + )); + } + + let mut subscriber_args = subscribers.iter().filter_map(|s| s.subscriber_arg()); + if let Some(subscriber_arg) = subscriber_args.next() { + for next_method_arg in subscriber_args { + if next_method_arg != subscriber_arg { + return Err(syn::Error::new_spanned( + &next_method_arg, + format!( + "Inconsistent signature for 'Subscriber' argument: {}, previously defined: {}", + next_method_arg.clone().into_token_stream(), + subscriber_arg.into_token_stream() + ), + )); + } + } + } + + method_registrations.push(MethodRegistration::PubSub { + name: name.clone(), + subscribes: subscribers.clone(), + unsubscribe: unsubscribe.clone(), + }); + } + (_, None) => { + return Err(syn::Error::new_spanned( + &item_trait, + format!("subscription '{}'. {}", name, MISSING_UNSUBSCRIBE_METHOD_ERR), + )); + } + } + } + + Ok((method_registrations, methods)) +} + +fn generate_server_module( + method_registrations: &[MethodRegistration], + item_trait: &syn::ItemTrait, + methods: &[RpcMethod], +) -> Result { + let has_pubsub_methods = methods.iter().any(RpcMethod::is_pubsub); + + let mut rpc_trait = RpcTrait { + methods: methods.to_owned(), + has_pubsub_methods, + has_metadata: false, + }; + let mut rpc_server_trait = fold::fold_item_trait(&mut rpc_trait, item_trait.clone()); + + let to_delegate_method = generate_trait_item_method( + &method_registrations, + &rpc_server_trait, + rpc_trait.has_metadata, + has_pubsub_methods, + )?; + + rpc_server_trait.items.push(syn::TraitItem::Method(to_delegate_method)); + + let trait_bounds: Punctuated = parse_quote!(Sized + Send + Sync + 'static); + rpc_server_trait.supertraits.extend(trait_bounds); + + let optional_pubsub_import = if has_pubsub_methods { + crate_name("jsonrpc-pubsub").map(|pubsub_name| quote!(use #pubsub_name as _jsonrpc_pubsub;)) + } else { + Ok(quote!()) + }?; + + let rpc_server_module = quote! { + /// The generated server module. 
+ pub mod gen_server { + #optional_pubsub_import + use self::_jsonrpc_core::futures as _futures; + use super::*; + + #rpc_server_trait + } + }; + + Ok(rpc_server_module) +} + +fn rpc_wrapper_mod_name(rpc_trait: &syn::ItemTrait) -> syn::Ident { + let name = rpc_trait.ident.clone(); + let mod_name = format!("{}{}", RPC_MOD_NAME_PREFIX, name.to_string()); + syn::Ident::new(&mod_name, proc_macro2::Span::call_site()) +} + +fn has_named_params(methods: &[RpcMethod]) -> bool { + methods + .iter() + .any(|method| method.attr.params_style == Some(ParamStyle::Named)) +} + +pub fn crate_name(name: &str) -> Result { + proc_macro_crate::crate_name(name) + .map(|name| Ident::new(&name, Span::call_site())) + .map_err(|e| Error::new(Span::call_site(), &e)) +} + +pub fn rpc_impl(input: syn::Item, options: &DeriveOptions) -> Result { + let rpc_trait = match input { + syn::Item::Trait(item_trait) => item_trait, + item => { + return Err(syn::Error::new_spanned( + item, + "The #[rpc] custom attribute only works with trait declarations", + )); + } + }; + + let (method_registrations, methods) = compute_method_registrations(&rpc_trait)?; + + let name = rpc_trait.ident.clone(); + let mod_name_ident = rpc_wrapper_mod_name(&rpc_trait); + + let core_name = crate_name("jsonrpc-core")?; + + let mut submodules = Vec::new(); + let mut exports = Vec::new(); + if options.enable_client { + let rpc_client_module = generate_client_module(&method_registrations, &rpc_trait, options)?; + submodules.push(rpc_client_module); + let client_name = syn::Ident::new(&format!("{}Client", name), proc_macro2::Span::call_site()); + exports.push(quote! { + pub use self::#mod_name_ident::gen_client; + pub use self::#mod_name_ident::gen_client::Client as #client_name; + }); + } + if options.enable_server { + if has_named_params(&methods) { + return Err(syn::Error::new_spanned(rpc_trait, USING_NAMED_PARAMS_WITH_SERVER_ERR)); + } + let rpc_server_module = generate_server_module(&method_registrations, &rpc_trait, &methods)?; + submodules.push(rpc_server_module); + exports.push(quote! { + pub use self::#mod_name_ident::gen_server::#name; + }); + } + Ok(quote!( + mod #mod_name_ident { + use #core_name as _jsonrpc_core; + use super::*; + + #(#submodules)* + } + #(#exports)* + )) +} diff --git a/derive/src/to_client.rs b/derive/src/to_client.rs new file mode 100644 index 000000000..6de28a745 --- /dev/null +++ b/derive/src/to_client.rs @@ -0,0 +1,336 @@ +use crate::options::DeriveOptions; +use crate::params_style::ParamStyle; +use crate::rpc_attr::AttributeKind; +use crate::rpc_trait::crate_name; +use crate::to_delegate::{generate_where_clause_serialization_predicates, MethodRegistration}; +use proc_macro2::{Ident, TokenStream}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::Result; + +pub fn generate_client_module( + methods: &[MethodRegistration], + item_trait: &syn::ItemTrait, + options: &DeriveOptions, +) -> Result { + let client_methods = generate_client_methods(methods, &options)?; + let generics = &item_trait.generics; + let where_clause = generate_where_clause_serialization_predicates(&item_trait, true); + let where_clause2 = where_clause.clone(); + let markers = generics + .params + .iter() + .filter_map(|param| match param { + syn::GenericParam::Type(syn::TypeParam { ident, .. 
}) => Some(ident), + _ => None, + }) + .enumerate() + .map(|(i, ty)| { + let field_name = "_".to_string() + &i.to_string(); + let field = Ident::new(&field_name, ty.span()); + (field, ty) + }); + let (markers_decl, markers_impl): (Vec<_>, Vec<_>) = markers + .map(|(field, ty)| { + ( + quote! { + #field: std::marker::PhantomData<#ty> + }, + quote! { + #field: std::marker::PhantomData + }, + ) + }) + .unzip(); + let client_name = crate_name("jsonrpc-core-client")?; + let client = quote! { + /// The generated client module. + pub mod gen_client { + use #client_name as _jsonrpc_core_client; + use super::*; + use _jsonrpc_core::{ + Call, Error, ErrorCode, Id, MethodCall, Params, Request, + Response, Version, + }; + use _jsonrpc_core::serde_json::{self, Value}; + use _jsonrpc_core_client::futures::{Future, FutureExt, channel::{mpsc, oneshot}}; + use _jsonrpc_core_client::{RpcChannel, RpcResult, RpcFuture, TypedClient, TypedSubscriptionStream}; + + /// The Client. + #[derive(Clone)] + pub struct Client#generics { + inner: TypedClient, + #(#markers_decl),* + } + + impl#generics Client#generics + where + #(#where_clause),* + { + /// Creates a new `Client`. + pub fn new(sender: RpcChannel) -> Self { + Client { + inner: sender.into(), + #(#markers_impl),* + } + } + + #(#client_methods)* + } + + impl#generics From for Client#generics + where + #(#where_clause2),* + { + fn from(channel: RpcChannel) -> Self { + Client::new(channel.into()) + } + } + } + }; + + Ok(client) +} + +fn generate_client_methods(methods: &[MethodRegistration], options: &DeriveOptions) -> Result> { + let mut client_methods = vec![]; + for method in methods { + match method { + MethodRegistration::Standard { method, .. } => { + let attrs = get_doc_comments(&method.trait_item.attrs); + let rpc_name = method.name(); + let name = &method.trait_item.sig.ident; + let args = compute_args(&method.trait_item); + let arg_names = compute_arg_identifiers(&args)?; + let returns = match &method.attr.kind { + AttributeKind::Rpc { returns, .. } => compute_returns(&method.trait_item, returns)?, + AttributeKind::PubSub { .. } => continue, + }; + let returns_str = quote!(#returns).to_string(); + + let args_serialized = match method + .attr + .params_style + .clone() + .unwrap_or_else(|| options.params_style.clone()) + { + ParamStyle::Named => { + quote! { // use object style serialization with field names taken from the function param names + serde_json::json!({ + #(stringify!(#arg_names): #arg_names,)* + }) + } + } + ParamStyle::Positional => quote! { // use tuple style serialization + (#(#arg_names,)*) + }, + ParamStyle::Raw => match arg_names.first() { + Some(arg_name) => quote! {#arg_name}, + None => quote! {serde_json::Value::Null}, + }, + }; + + let client_method = syn::parse_quote! 
{ + #(#attrs)* + pub fn #name(&self, #args) -> impl Future> { + let args = #args_serialized; + self.inner.call_method(#rpc_name, #returns_str, args) + } + }; + client_methods.push(client_method); + } + MethodRegistration::PubSub { + name: subscription, + subscribes, + unsubscribe, + } => { + for subscribe in subscribes { + let attrs = get_doc_comments(&subscribe.trait_item.attrs); + let name = &subscribe.trait_item.sig.ident; + let mut args = compute_args(&subscribe.trait_item).into_iter(); + let returns = compute_subscription_type(&args.next().unwrap()); + let returns_str = quote!(#returns).to_string(); + let args = args.collect(); + let arg_names = compute_arg_identifiers(&args)?; + let subscribe = subscribe.name(); + let unsubscribe = unsubscribe.name(); + let client_method = syn::parse_quote!( + #(#attrs)* + pub fn #name(&self, #args) -> RpcResult> { + let args_tuple = (#(#arg_names,)*); + self.inner.subscribe(#subscribe, args_tuple, #subscription, #unsubscribe, #returns_str) + } + ); + client_methods.push(client_method); + } + } + MethodRegistration::Notification { method, .. } => { + let attrs = get_doc_comments(&method.trait_item.attrs); + let rpc_name = method.name(); + let name = &method.trait_item.sig.ident; + let args = compute_args(&method.trait_item); + let arg_names = compute_arg_identifiers(&args)?; + let client_method = syn::parse_quote! { + #(#attrs)* + pub fn #name(&self, #args) -> RpcResult<()> { + let args_tuple = (#(#arg_names,)*); + self.inner.notify(#rpc_name, args_tuple) + } + }; + client_methods.push(client_method); + } + } + } + Ok(client_methods) +} + +fn get_doc_comments(attrs: &[syn::Attribute]) -> Vec { + let mut doc_comments = vec![]; + for attr in attrs { + match attr { + syn::Attribute { + path: syn::Path { segments, .. }, + .. + } => match &segments[0] { + syn::PathSegment { ident, .. } => { + if *ident == "doc" { + doc_comments.push(attr.to_owned()); + } + } + }, + } + } + doc_comments +} + +fn compute_args(method: &syn::TraitItemMethod) -> Punctuated { + let mut args = Punctuated::new(); + for arg in &method.sig.inputs { + let ty = match arg { + syn::FnArg::Typed(syn::PatType { ty, .. }) => ty, + _ => continue, + }; + let segments = match &**ty { + syn::Type::Path(syn::TypePath { + path: syn::Path { ref segments, .. }, + .. + }) => segments, + _ => continue, + }; + let syn::PathSegment { ident, .. } = &segments[0]; + let ident = ident; + if *ident == "Self" { + continue; + } + args.push(arg.to_owned()); + } + args +} + +fn compute_arg_identifiers(args: &Punctuated) -> Result> { + let mut arg_names = vec![]; + for arg in args { + let pat = match arg { + syn::FnArg::Typed(syn::PatType { pat, .. }) => pat, + _ => continue, + }; + let ident = match **pat { + syn::Pat::Ident(syn::PatIdent { ref ident, .. 
}) => ident, + syn::Pat::Wild(ref wild) => { + let span = wild.underscore_token.spans[0]; + let msg = "No wildcard patterns allowed in rpc trait."; + return Err(syn::Error::new(span, msg)); + } + _ => continue, + }; + arg_names.push(ident); + } + Ok(arg_names) +} + +fn compute_returns(method: &syn::TraitItemMethod, returns: &Option) -> Result { + let returns: Option = match returns { + Some(returns) => Some(syn::parse_str(returns)?), + None => None, + }; + let returns = match returns { + None => try_infer_returns(&method.sig.output), + _ => returns, + }; + let returns = match returns { + Some(returns) => returns, + None => { + let span = method.attrs[0].pound_token.spans[0]; + let msg = "Missing returns attribute."; + return Err(syn::Error::new(span, msg)); + } + }; + Ok(returns) +} + +fn try_infer_returns(output: &syn::ReturnType) -> Option { + let extract_path_segments = |ty: &syn::Type| match ty { + syn::Type::Path(syn::TypePath { + path: syn::Path { segments, .. }, + .. + }) => Some(segments.clone()), + _ => None, + }; + + match output { + syn::ReturnType::Type(_, ty) => { + let segments = extract_path_segments(&**ty)?; + let check_segment = |seg: &syn::PathSegment| match seg { + syn::PathSegment { ident, arguments, .. } => { + let id = ident.to_string(); + let inner = get_first_type_argument(arguments); + if id.ends_with("Result") { + Ok(inner) + } else { + Err(inner) + } + } + }; + // Try out first argument (Result) or nested types like: + // BoxFuture> + match check_segment(&segments[0]) { + Ok(returns) => Some(returns?), + Err(inner) => { + let segments = extract_path_segments(&inner?)?; + check_segment(&segments[0]).ok().flatten() + } + } + } + _ => None, + } +} + +fn get_first_type_argument(args: &syn::PathArguments) -> Option { + match args { + syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments { args, .. }) => { + if !args.is_empty() { + match &args[0] { + syn::GenericArgument::Type(ty) => Some(ty.clone()), + _ => None, + } + } else { + None + } + } + _ => None, + } +} + +fn compute_subscription_type(arg: &syn::FnArg) -> syn::Type { + let ty = match arg { + syn::FnArg::Typed(cap) => match *cap.ty { + syn::Type::Path(ref path) => { + let last = &path.path.segments[&path.path.segments.len() - 1]; + get_first_type_argument(&last.arguments) + } + _ => None, + }, + _ => None, + }; + ty.expect("a subscription needs a return type") +} diff --git a/derive/src/to_delegate.rs b/derive/src/to_delegate.rs new file mode 100644 index 000000000..7e9148a82 --- /dev/null +++ b/derive/src/to_delegate.rs @@ -0,0 +1,510 @@ +use std::collections::HashSet; + +use crate::params_style::ParamStyle; +use crate::rpc_attr::RpcMethodAttribute; +use quote::quote; +use syn::{ + parse_quote, + punctuated::Punctuated, + visit::{self, Visit}, + Result, Token, +}; + +pub enum MethodRegistration { + Standard { + method: RpcMethod, + has_metadata: bool, + }, + PubSub { + name: String, + subscribes: Vec, + unsubscribe: RpcMethod, + }, + Notification { + method: RpcMethod, + has_metadata: bool, + }, +} + +impl MethodRegistration { + fn generate(&self) -> Result { + match self { + MethodRegistration::Standard { method, has_metadata } => { + let rpc_name = &method.name(); + let add_method = if *has_metadata { + quote!(add_method_with_meta) + } else { + quote!(add_method) + }; + let closure = method.generate_delegate_closure(false)?; + let add_aliases = method.generate_add_aliases(); + + Ok(quote! 
{ + del.#add_method(#rpc_name, #closure); + #add_aliases + }) + } + MethodRegistration::PubSub { + name, + subscribes, + unsubscribe, + } => { + let unsub_name = unsubscribe.name(); + let unsub_method_ident = unsubscribe.ident(); + let unsub_closure = quote! { + move |base, id, meta| { + use self::_futures::{FutureExt, TryFutureExt}; + self::_jsonrpc_core::WrapFuture::into_future( + Self::#unsub_method_ident(base, meta, id) + ) + .map_ok(|value| _jsonrpc_core::to_value(value) + .expect("Expected always-serializable type; qed")) + .map_err(Into::into) + } + }; + + let mut add_subscriptions = proc_macro2::TokenStream::new(); + + for subscribe in subscribes.iter() { + let sub_name = subscribe.name(); + let sub_closure = subscribe.generate_delegate_closure(true)?; + let sub_aliases = subscribe.generate_add_aliases(); + + add_subscriptions = quote! { + #add_subscriptions + del.add_subscription( + #name, + (#sub_name, #sub_closure), + (#unsub_name, #unsub_closure), + ); + #sub_aliases + }; + } + + let unsub_aliases = unsubscribe.generate_add_aliases(); + + Ok(quote! { + #add_subscriptions + #unsub_aliases + }) + } + MethodRegistration::Notification { method, has_metadata } => { + let name = &method.name(); + let add_notification = if *has_metadata { + quote!(add_notification_with_meta) + } else { + quote!(add_notification) + }; + let closure = method.generate_delegate_closure(false)?; + let add_aliases = method.generate_add_aliases(); + + Ok(quote! { + del.#add_notification(#name, #closure); + #add_aliases + }) + } + } + } +} + +const SUBSCRIBER_TYPE_IDENT: &str = "Subscriber"; +const METADATA_CLOSURE_ARG: &str = "meta"; +const SUBSCRIBER_CLOSURE_ARG: &str = "subscriber"; + +// tuples are limited to 16 fields: the maximum supported by `serde::Deserialize` +const TUPLE_FIELD_NAMES: [&str; 16] = [ + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", +]; + +pub fn generate_trait_item_method( + methods: &[MethodRegistration], + trait_item: &syn::ItemTrait, + has_metadata: bool, + has_pubsub_methods: bool, +) -> Result { + let io_delegate_type = if has_pubsub_methods { + quote!(_jsonrpc_pubsub::IoDelegate) + } else { + quote!(_jsonrpc_core::IoDelegate) + }; + let add_methods = methods + .iter() + .map(MethodRegistration::generate) + .collect::>>()?; + let to_delegate_body = quote! { + let mut del = #io_delegate_type::new(self.into()); + #(#add_methods)* + del + }; + + let method: syn::TraitItemMethod = if has_metadata { + parse_quote! { + /// Create an `IoDelegate`, wiring rpc calls to the trait methods. + fn to_delegate(self) -> #io_delegate_type { + #to_delegate_body + } + } + } else { + parse_quote! { + /// Create an `IoDelegate`, wiring rpc calls to the trait methods. 
+ fn to_delegate(self) -> #io_delegate_type { + #to_delegate_body + } + } + }; + + let predicates = generate_where_clause_serialization_predicates(&trait_item, false); + let mut method = method; + method.sig.generics.make_where_clause().predicates.extend(predicates); + Ok(method) +} + +#[derive(Clone)] +pub struct RpcMethod { + pub attr: RpcMethodAttribute, + pub trait_item: syn::TraitItemMethod, +} + +impl RpcMethod { + pub fn new(attr: RpcMethodAttribute, trait_item: syn::TraitItemMethod) -> RpcMethod { + RpcMethod { attr, trait_item } + } + + pub fn attr(&self) -> &RpcMethodAttribute { + &self.attr + } + + pub fn name(&self) -> &str { + &self.attr.name + } + + pub fn ident(&self) -> &syn::Ident { + &self.trait_item.sig.ident + } + + pub fn is_pubsub(&self) -> bool { + self.attr.is_pubsub() + } + + pub fn subscriber_arg(&self) -> Option { + self.trait_item + .sig + .inputs + .iter() + .filter_map(|arg| match arg { + syn::FnArg::Typed(ty) => Some(*ty.ty.clone()), + _ => None, + }) + .find(|ty| { + if let syn::Type::Path(path) = ty { + if path.path.segments.iter().any(|s| s.ident == SUBSCRIBER_TYPE_IDENT) { + return true; + } + } + false + }) + } + + fn generate_delegate_closure(&self, is_subscribe: bool) -> Result { + let mut param_types: Vec<_> = self + .trait_item + .sig + .inputs + .iter() + .cloned() + .filter_map(|arg| match arg { + syn::FnArg::Typed(ty) => Some(*ty.ty), + _ => None, + }) + .collect(); + + // special args are those which are not passed directly via rpc params: metadata, subscriber + let special_args = Self::special_args(¶m_types); + param_types.retain(|ty| !special_args.iter().any(|(_, sty)| sty == ty)); + if param_types.len() > TUPLE_FIELD_NAMES.len() { + return Err(syn::Error::new_spanned( + &self.trait_item, + &format!("Maximum supported number of params is {}", TUPLE_FIELD_NAMES.len()), + )); + } + let tuple_fields: &Vec<_> = &(TUPLE_FIELD_NAMES + .iter() + .take(param_types.len()) + .map(|name| ident(name)) + .collect()); + let param_types = ¶m_types; + let parse_params = { + // last arguments that are `Option`-s are optional 'trailing' arguments + let trailing_args_num = param_types.iter().rev().take_while(|t| is_option_type(t)).count(); + + if trailing_args_num != 0 { + self.params_with_trailing(trailing_args_num, param_types, tuple_fields) + } else if param_types.is_empty() { + quote! { let params = params.expect_no_params(); } + } else if self.attr.params_style == Some(ParamStyle::Raw) { + quote! { let params: _jsonrpc_core::Result<_> = Ok((params,)); } + } else if self.attr.params_style == Some(ParamStyle::Positional) { + quote! { let params = params.parse::<(#(#param_types, )*)>(); } + } else { + unimplemented!("Server side named parameters are not implemented"); + } + }; + + let method_ident = self.ident(); + let result = &self.trait_item.sig.output; + let extra_closure_args: &Vec<_> = &special_args.iter().cloned().map(|arg| arg.0).collect(); + let extra_method_types: &Vec<_> = &special_args.iter().cloned().map(|arg| arg.1).collect(); + + let closure_args = quote! { base, params, #(#extra_closure_args), * }; + let method_sig = quote! { fn(&Self, #(#extra_method_types, ) * #(#param_types), *) #result }; + let method_call = quote! { (base, #(#extra_closure_args, )* #(#tuple_fields), *) }; + let match_params = if is_subscribe { + quote! 
{ + Ok((#(#tuple_fields, )*)) => { + let subscriber = _jsonrpc_pubsub::typed::Subscriber::new(subscriber); + (method)#method_call + }, + Err(e) => { + let _ = subscriber.reject(e); + return + } + } + } else if self.attr.is_notification() { + quote! { + Ok((#(#tuple_fields, )*)) => { + (method)#method_call + }, + Err(_) => return, + } + } else { + quote! { + Ok((#(#tuple_fields, )*)) => { + use self::_futures::{FutureExt, TryFutureExt}; + let fut = self::_jsonrpc_core::WrapFuture::into_future((method)#method_call) + .map_ok(|value| _jsonrpc_core::to_value(value) + .expect("Expected always-serializable type; qed")) + .map_err(Into::into as fn(_) -> _jsonrpc_core::Error); + _futures::future::Either::Left(fut) + }, + Err(e) => _futures::future::Either::Right(_futures::future::ready(Err(e))), + } + }; + + Ok(quote! { + move |#closure_args| { + let method = &(Self::#method_ident as #method_sig); + #parse_params + match params { + #match_params + } + } + }) + } + + fn special_args(param_types: &[syn::Type]) -> Vec<(syn::Ident, syn::Type)> { + let meta_arg = param_types.first().and_then(|ty| { + if *ty == parse_quote!(Self::Metadata) { + Some(ty.clone()) + } else { + None + } + }); + let subscriber_arg = param_types.get(1).and_then(|ty| { + if let syn::Type::Path(path) = ty { + if path.path.segments.iter().any(|s| s.ident == SUBSCRIBER_TYPE_IDENT) { + Some(ty.clone()) + } else { + None + } + } else { + None + } + }); + + let mut special_args = Vec::new(); + if let Some(meta) = meta_arg { + special_args.push((ident(METADATA_CLOSURE_ARG), meta)); + } + if let Some(subscriber) = subscriber_arg { + special_args.push((ident(SUBSCRIBER_CLOSURE_ARG), subscriber)); + } + special_args + } + + fn params_with_trailing( + &self, + trailing_args_num: usize, + param_types: &[syn::Type], + tuple_fields: &[syn::Ident], + ) -> proc_macro2::TokenStream { + let total_args_num = param_types.len(); + let required_args_num = total_args_num - trailing_args_num; + + let switch_branches = (0..=trailing_args_num) + .map(|passed_trailing_args_num| { + let passed_args_num = required_args_num + passed_trailing_args_num; + let passed_param_types = ¶m_types[..passed_args_num]; + let passed_tuple_fields = &tuple_fields[..passed_args_num]; + let missed_args_num = total_args_num - passed_args_num; + let missed_params_values = ::std::iter::repeat(quote! { None }) + .take(missed_args_num) + .collect::>(); + + if passed_args_num == 0 { + quote! { + #passed_args_num => params.expect_no_params() + .map(|_| (#(#missed_params_values, ) *)) + .map_err(Into::into) + } + } else { + quote! { + #passed_args_num => params.parse::<(#(#passed_param_types, )*)>() + .map(|(#(#passed_tuple_fields,)*)| + (#(#passed_tuple_fields, )* #(#missed_params_values, )*)) + .map_err(Into::into) + } + } + }) + .collect::>(); + + quote! 
{ + let passed_args_num = match params { + _jsonrpc_core::Params::Array(ref v) => Ok(v.len()), + _jsonrpc_core::Params::None => Ok(0), + _ => Err(_jsonrpc_core::Error::invalid_params("`params` should be an array")) + }; + + let params = passed_args_num.and_then(|passed_args_num| { + match passed_args_num { + _ if passed_args_num < #required_args_num => Err(_jsonrpc_core::Error::invalid_params( + format!("`params` should have at least {} argument(s)", #required_args_num))), + #(#switch_branches),*, + _ => Err(_jsonrpc_core::Error::invalid_params_with_details( + format!("Expected from {} to {} parameters.", #required_args_num, #total_args_num), + format!("Got: {}", passed_args_num))), + } + }); + } + } + + fn generate_add_aliases(&self) -> proc_macro2::TokenStream { + let name = self.name(); + let add_aliases: Vec<_> = self + .attr + .aliases + .iter() + .map(|alias| quote! { del.add_alias(#alias, #name); }) + .collect(); + quote! { #(#add_aliases)* } + } +} + +fn ident(s: &str) -> syn::Ident { + syn::Ident::new(s, proc_macro2::Span::call_site()) +} + +fn is_option_type(ty: &syn::Type) -> bool { + if let syn::Type::Path(path) = ty { + path.path.segments.first().map_or(false, |t| t.ident == "Option") + } else { + false + } +} + +pub fn generate_where_clause_serialization_predicates( + item_trait: &syn::ItemTrait, + client: bool, +) -> Vec { + #[derive(Default)] + struct FindTyParams { + trait_generics: HashSet, + server_to_client_type_params: HashSet, + client_to_server_type_params: HashSet, + visiting_return_type: bool, + visiting_fn_arg: bool, + visiting_subscriber_arg: bool, + } + impl<'ast> Visit<'ast> for FindTyParams { + fn visit_type_param(&mut self, ty_param: &'ast syn::TypeParam) { + self.trait_generics.insert(ty_param.ident.clone()); + } + fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) { + self.visiting_return_type = true; + visit::visit_return_type(self, return_type); + self.visiting_return_type = false + } + fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) { + self.visiting_subscriber_arg = + self.visiting_subscriber_arg || (self.visiting_fn_arg && segment.ident == SUBSCRIBER_TYPE_IDENT); + visit::visit_path_segment(self, segment); + self.visiting_subscriber_arg = self.visiting_subscriber_arg && segment.ident != SUBSCRIBER_TYPE_IDENT; + } + fn visit_ident(&mut self, ident: &'ast syn::Ident) { + if self.trait_generics.contains(&ident) { + if self.visiting_return_type || self.visiting_subscriber_arg { + self.server_to_client_type_params.insert(ident.clone()); + } + if self.visiting_fn_arg && !self.visiting_subscriber_arg { + self.client_to_server_type_params.insert(ident.clone()); + } + } + } + fn visit_fn_arg(&mut self, arg: &'ast syn::FnArg) { + self.visiting_fn_arg = true; + visit::visit_fn_arg(self, arg); + self.visiting_fn_arg = false; + } + } + let mut visitor = FindTyParams::default(); + visitor.visit_item_trait(item_trait); + + let additional_where_clause = item_trait.generics.where_clause.clone(); + + item_trait + .generics + .type_params() + .map(|ty| { + let ty_path = syn::TypePath { + qself: None, + path: ty.ident.clone().into(), + }; + let mut bounds: Punctuated = parse_quote!(Send + Sync + 'static); + // add json serialization trait bounds + if client { + if visitor.server_to_client_type_params.contains(&ty.ident) { + bounds.push(parse_quote!(_jsonrpc_core::serde::de::DeserializeOwned)) + } + if visitor.client_to_server_type_params.contains(&ty.ident) { + bounds.push(parse_quote!(_jsonrpc_core::serde::Serialize)) + } + } else { + 
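+ // Server side: types flowing back to the client (return values, subscription
+ // items) need `Serialize`, while argument types received from the client need
+ // `DeserializeOwned` - the mirror image of the client bounds above.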
if visitor.server_to_client_type_params.contains(&ty.ident) { + bounds.push(parse_quote!(_jsonrpc_core::serde::Serialize)) + } + if visitor.client_to_server_type_params.contains(&ty.ident) { + bounds.push(parse_quote!(_jsonrpc_core::serde::de::DeserializeOwned)) + } + } + + // add the trait bounds specified by the user in where clause. + if let Some(ref where_clause) = additional_where_clause { + for predicate in where_clause.predicates.iter() { + if let syn::WherePredicate::Type(where_ty) = predicate { + if let syn::Type::Path(ref predicate) = where_ty.bounded_ty { + if *predicate == ty_path { + bounds.extend(where_ty.bounds.clone().into_iter()); + } + } + } + } + } + + syn::WherePredicate::Type(syn::PredicateType { + lifetimes: None, + bounded_ty: syn::Type::Path(ty_path), + colon_token: ::default(), + bounds, + }) + }) + .collect() +} diff --git a/derive/tests/client.rs b/derive/tests/client.rs new file mode 100644 index 000000000..8ac820ab0 --- /dev/null +++ b/derive/tests/client.rs @@ -0,0 +1,128 @@ +use assert_matches::assert_matches; +use jsonrpc_core::futures::{self, FutureExt, TryFutureExt}; +use jsonrpc_core::{IoHandler, Result}; +use jsonrpc_core_client::transports::local; +use jsonrpc_derive::rpc; + +mod client_server { + use super::*; + + #[rpc(params = "positional")] + pub trait Rpc { + #[rpc(name = "add")] + fn add(&self, a: u64, b: u64) -> Result; + + #[rpc(name = "notify")] + fn notify(&self, foo: u64); + } + + struct RpcServer; + + impl Rpc for RpcServer { + fn add(&self, a: u64, b: u64) -> Result { + Ok(a + b) + } + + fn notify(&self, foo: u64) { + println!("received {}", foo); + } + } + + #[test] + fn client_server_roundtrip() { + let mut handler = IoHandler::new(); + handler.extend_with(RpcServer.to_delegate()); + let (client, rpc_client) = local::connect::(handler); + let fut = client + .clone() + .add(3, 4) + .map_ok(move |res| client.notify(res).map(move |_| res)) + .map(|res| { + self::assert_matches!(res, Ok(Ok(7))); + }); + let exec = futures::executor::ThreadPool::builder().pool_size(1).create().unwrap(); + exec.spawn_ok(async move { + futures::join!(fut, rpc_client).1.unwrap(); + }); + } +} + +mod named_params { + use super::*; + use jsonrpc_core::Params; + use serde_json::json; + + #[rpc(client, params = "named")] + pub trait Rpc { + #[rpc(name = "call_with_named")] + fn call_with_named(&self, number: u64, string: String, json: Value) -> Result; + + #[rpc(name = "notify", params = "raw")] + fn notify(&self, payload: Value); + } + + #[test] + fn client_generates_correct_named_params_payload() { + use jsonrpc_core::futures::{FutureExt, TryFutureExt}; + + let expected = json!({ // key names are derived from function parameter names in the trait + "number": 3, + "string": String::from("test string"), + "json": { + "key": ["value"] + } + }); + + let mut handler = IoHandler::new(); + handler.add_sync_method("call_with_named", |params: Params| Ok(params.into())); + + let (client, rpc_client) = local::connect::(handler); + let fut = client + .clone() + .call_with_named(3, String::from("test string"), json!({"key": ["value"]})) + .map_ok(move |res| client.notify(res.clone()).map(move |_| res)) + .map(move |res| { + self::assert_matches!(res, Ok(Ok(x)) if x == expected); + }); + let exec = futures::executor::ThreadPool::builder().pool_size(1).create().unwrap(); + exec.spawn_ok(async move { futures::join!(fut, rpc_client).1.unwrap() }); + } +} + +mod raw_params { + use super::*; + use jsonrpc_core::Params; + use serde_json::json; + + #[rpc(client)] + pub trait Rpc { + 
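+ /// With `params = "raw"` the client passes its single argument through as the
+ /// whole `params` value, without wrapping it in an array or object (see the
+ /// `client_generates_correct_raw_params_payload` test below).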
#[rpc(name = "call_raw", params = "raw")] + fn call_raw_single_param(&self, params: Value) -> Result; + + #[rpc(name = "notify", params = "raw")] + fn notify(&self, payload: Value); + } + + #[test] + fn client_generates_correct_raw_params_payload() { + let expected = json!({ + "sub_object": { + "key": ["value"] + } + }); + + let mut handler = IoHandler::new(); + handler.add_sync_method("call_raw", |params: Params| Ok(params.into())); + + let (client, rpc_client) = local::connect::(handler); + let fut = client + .clone() + .call_raw_single_param(expected.clone()) + .map_ok(move |res| client.notify(res.clone()).map(move |_| res)) + .map(move |res| { + self::assert_matches!(res, Ok(Ok(x)) if x == expected); + }); + let exec = futures::executor::ThreadPool::builder().pool_size(1).create().unwrap(); + exec.spawn_ok(async move { futures::join!(fut, rpc_client).1.unwrap() }); + } +} diff --git a/derive/tests/macros.rs b/derive/tests/macros.rs new file mode 100644 index 000000000..1f4483672 --- /dev/null +++ b/derive/tests/macros.rs @@ -0,0 +1,256 @@ +use jsonrpc_core::types::params::Params; +use jsonrpc_core::{IoHandler, Response}; +use jsonrpc_derive::rpc; +use serde_json; + +pub enum MyError {} +impl From for jsonrpc_core::Error { + fn from(_e: MyError) -> Self { + unreachable!() + } +} + +type Result = ::std::result::Result; + +#[rpc] +pub trait Rpc { + /// Returns a protocol version. + #[rpc(name = "protocolVersion")] + fn protocol_version(&self) -> Result; + + /// Negates number and returns a result. + #[rpc(name = "neg")] + fn neg(&self, a: i64) -> Result; + + /// Adds two numbers and returns a result. + #[rpc(name = "add", alias("add_alias1", "add_alias2"))] + fn add(&self, a: u64, b: u64) -> Result; + + /// Retrieves and debug prints the underlying `Params` object. + #[rpc(name = "raw", params = "raw")] + fn raw(&self, params: Params) -> Result; + + /// Handles a notification. 
+ #[rpc(name = "notify")] + fn notify(&self, a: u64); +} + +#[derive(Default)] +struct RpcImpl; + +impl Rpc for RpcImpl { + fn protocol_version(&self) -> Result { + Ok("version1".into()) + } + + fn neg(&self, a: i64) -> Result { + Ok(-a) + } + + fn add(&self, a: u64, b: u64) -> Result { + Ok(a + b) + } + + fn raw(&self, _params: Params) -> Result { + Ok("OK".into()) + } + + fn notify(&self, a: u64) { + println!("Received `notify` with value: {}", a); + } +} + +#[test] +fn should_accept_empty_array_as_no_params() { + let mut io = IoHandler::new(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"protocolVersion","params":[]}"#; + let req2 = r#"{"jsonrpc":"2.0","id":1,"method":"protocolVersion","params":null}"#; + let req3 = r#"{"jsonrpc":"2.0","id":1,"method":"protocolVersion"}"#; + + let res1 = io.handle_request_sync(req1); + let res2 = io.handle_request_sync(req2); + let res3 = io.handle_request_sync(req3); + let expected = r#"{ + "jsonrpc": "2.0", + "result": "version1", + "id": 1 + }"#; + let expected: Response = serde_json::from_str(expected).unwrap(); + + // then + let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap(); + assert_eq!(expected, result1); + + let result2: Response = serde_json::from_str(&res2.unwrap()).unwrap(); + assert_eq!(expected, result2); + + let result3: Response = serde_json::from_str(&res3.unwrap()).unwrap(); + assert_eq!(expected, result3); +} + +#[test] +fn should_accept_single_param() { + let mut io = IoHandler::new(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"neg","params":[1]}"#; + + let res1 = io.handle_request_sync(req1); + + // then + let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap(); + assert_eq!( + result1, + serde_json::from_str( + r#"{ + "jsonrpc": "2.0", + "result": -1, + "id": 1 + }"# + ) + .unwrap() + ); +} + +#[test] +fn should_accept_multiple_params() { + let mut io = IoHandler::new(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"add","params":[1, 2]}"#; + + let res1 = io.handle_request_sync(req1); + + // then + let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap(); + assert_eq!( + result1, + serde_json::from_str( + r#"{ + "jsonrpc": "2.0", + "result": 3, + "id": 1 + }"# + ) + .unwrap() + ); +} + +#[test] +fn should_use_method_name_aliases() { + let mut io = IoHandler::new(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"add_alias1","params":[1, 2]}"#; + let req2 = r#"{"jsonrpc":"2.0","id":1,"method":"add_alias2","params":[1, 2]}"#; + + let res1 = io.handle_request_sync(req1); + let res2 = io.handle_request_sync(req2); + + // then + let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap(); + assert_eq!( + result1, + serde_json::from_str( + r#"{ + "jsonrpc": "2.0", + "result": 3, + "id": 1 + }"# + ) + .unwrap() + ); + + let result2: Response = serde_json::from_str(&res2.unwrap()).unwrap(); + assert_eq!( + result2, + serde_json::from_str( + r#"{ + "jsonrpc": "2.0", + "result": 3, + "id": 1 + }"# + ) + .unwrap() + ); +} + +#[test] +fn should_accept_any_raw_params() { + let mut io = IoHandler::new(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"raw","params":[1, 
2]}"#; + let req2 = r#"{"jsonrpc":"2.0","id":1,"method":"raw","params":{"foo":"bar"}}"#; + let req3 = r#"{"jsonrpc":"2.0","id":1,"method":"raw","params":null}"#; + let req4 = r#"{"jsonrpc":"2.0","id":1,"method":"raw"}"#; + + let res1 = io.handle_request_sync(req1); + let res2 = io.handle_request_sync(req2); + let res3 = io.handle_request_sync(req3); + let res4 = io.handle_request_sync(req4); + let expected = r#"{ + "jsonrpc": "2.0", + "result": "OK", + "id": 1 + }"#; + let expected: Response = serde_json::from_str(expected).unwrap(); + + // then + let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap(); + assert_eq!(expected, result1); + + let result2: Response = serde_json::from_str(&res2.unwrap()).unwrap(); + assert_eq!(expected, result2); + + let result3: Response = serde_json::from_str(&res3.unwrap()).unwrap(); + assert_eq!(expected, result3); + + let result4: Response = serde_json::from_str(&res4.unwrap()).unwrap(); + assert_eq!(expected, result4); +} + +#[test] +fn should_accept_only_notifications() { + let mut io = IoHandler::new(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let req1 = r#"{"jsonrpc":"2.0","method":"notify","params":[1]}"#; + let req2 = r#"{"jsonrpc":"2.0","id":1,"method":"notify","params":[1]}"#; + + let res1 = io.handle_request_sync(req1); + let res2 = io.handle_request_sync(req2); + + // then + assert!(res1.is_none()); + + let result2: Response = serde_json::from_str(&res2.unwrap()).unwrap(); + assert_eq!( + result2, + serde_json::from_str( + r#"{ + "jsonrpc": "2.0", + "error": { + "code": -32601, + "message": "Method not found" + }, + "id":1 + }"# + ) + .unwrap() + ); +} diff --git a/derive/tests/pubsub-macros.rs b/derive/tests/pubsub-macros.rs new file mode 100644 index 000000000..71f7615fd --- /dev/null +++ b/derive/tests/pubsub-macros.rs @@ -0,0 +1,147 @@ +use jsonrpc_core; +use jsonrpc_pubsub; +use serde_json; +#[macro_use] +extern crate jsonrpc_derive; + +use jsonrpc_core::futures::channel::mpsc; +use jsonrpc_pubsub::typed::Subscriber; +use jsonrpc_pubsub::{PubSubHandler, PubSubMetadata, Session, SubscriptionId}; +use std::sync::Arc; + +pub enum MyError {} +impl From for jsonrpc_core::Error { + fn from(_e: MyError) -> Self { + unreachable!() + } +} + +type Result = ::std::result::Result; + +#[rpc] +pub trait Rpc { + type Metadata; + + /// Hello subscription. + #[pubsub(subscription = "hello", subscribe, name = "hello_subscribe", alias("hello_alias"))] + fn subscribe(&self, a: Self::Metadata, b: Subscriber, c: u32, d: Option); + + /// Hello subscription through different method. + #[pubsub(subscription = "hello", subscribe, name = "hello_subscribe_second")] + fn subscribe_second(&self, a: Self::Metadata, b: Subscriber, e: String); + + /// Unsubscribe from hello subscription. + #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] + fn unsubscribe(&self, a: Option, b: SubscriptionId) -> Result; + + /// A regular rpc method alongside pubsub. + #[rpc(name = "add")] + fn add(&self, a: u64, b: u64) -> Result; + + /// A notification alongside pubsub. 
+ #[rpc(name = "notify")] + fn notify(&self, a: u64); +} + +#[derive(Default)] +struct RpcImpl; + +impl Rpc for RpcImpl { + type Metadata = Metadata; + + fn subscribe(&self, _meta: Self::Metadata, subscriber: Subscriber, _pre: u32, _trailing: Option) { + let _sink = subscriber.assign_id(SubscriptionId::Number(5)); + } + + fn subscribe_second(&self, _meta: Self::Metadata, subscriber: Subscriber, _e: String) { + let _sink = subscriber.assign_id(SubscriptionId::Number(6)); + } + + fn unsubscribe(&self, _meta: Option, _id: SubscriptionId) -> Result { + Ok(true) + } + + fn add(&self, a: u64, b: u64) -> Result { + Ok(a + b) + } + + fn notify(&self, a: u64) { + println!("Received `notify` with value: {}", a); + } +} + +#[derive(Clone, Default)] +struct Metadata; +impl jsonrpc_core::Metadata for Metadata {} +impl PubSubMetadata for Metadata { + fn session(&self) -> Option> { + let (tx, _rx) = mpsc::unbounded(); + Some(Arc::new(Session::new(tx))) + } +} + +#[test] +fn test_invalid_trailing_pubsub_params() { + let mut io = PubSubHandler::default(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let meta = Metadata; + let req = r#"{"jsonrpc":"2.0","id":1,"method":"hello_subscribe","params":[]}"#; + let res = io.handle_request_sync(req, meta); + let expected = r#"{ + "jsonrpc": "2.0", + "error": { + "code": -32602, + "message": "`params` should have at least 1 argument(s)" + }, + "id": 1 + }"#; + + let expected: jsonrpc_core::Response = serde_json::from_str(expected).unwrap(); + let result: jsonrpc_core::Response = serde_json::from_str(&res.unwrap()).unwrap(); + assert_eq!(expected, result); +} + +#[test] +fn test_subscribe_with_alias() { + let mut io = PubSubHandler::default(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let meta = Metadata; + let req = r#"{"jsonrpc":"2.0","id":1,"method":"hello_alias","params":[1]}"#; + let res = io.handle_request_sync(req, meta); + let expected = r#"{ + "jsonrpc": "2.0", + "result": 5, + "id": 1 + }"#; + + let expected: jsonrpc_core::Response = serde_json::from_str(expected).unwrap(); + let result: jsonrpc_core::Response = serde_json::from_str(&res.unwrap()).unwrap(); + assert_eq!(expected, result); +} + +#[test] +fn test_subscribe_alternate_method() { + let mut io = PubSubHandler::default(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let meta = Metadata; + let req = r#"{"jsonrpc":"2.0","id":1,"method":"hello_subscribe_second","params":["Data"]}"#; + let res = io.handle_request_sync(req, meta); + let expected = r#"{ + "jsonrpc": "2.0", + "result": 6, + "id": 1 + }"#; + + let expected: jsonrpc_core::Response = serde_json::from_str(expected).unwrap(); + let result: jsonrpc_core::Response = serde_json::from_str(&res.unwrap()).unwrap(); + assert_eq!(expected, result); +} diff --git a/derive/tests/run-pass/client_only.rs b/derive/tests/run-pass/client_only.rs new file mode 100644 index 000000000..25f7712bb --- /dev/null +++ b/derive/tests/run-pass/client_only.rs @@ -0,0 +1,30 @@ +extern crate jsonrpc_core; +extern crate jsonrpc_core_client; +#[macro_use] +extern crate jsonrpc_derive; + +use jsonrpc_core::IoHandler; +use jsonrpc_core::futures::{self, TryFutureExt}; +use jsonrpc_core_client::transports::local; + +#[rpc(client)] +pub trait Rpc { + /// Returns a protocol version + #[rpc(name = "protocolVersion")] + fn protocol_version(&self) -> Result; + + /// Adds two numbers and returns a result + #[rpc(name = "add", alias("callAsyncMetaAlias"))] + fn 
add(&self, a: u64, b: u64) -> Result; +} + +fn main() { + let fut = { + let handler = IoHandler::new(); + let (client, _rpc_client) = local::connect::(handler); + client + .add(5, 6) + .map_ok(|res| println!("5 + 6 = {}", res)) + }; + let _ = futures::executor::block_on(fut); +} diff --git a/derive/tests/run-pass/client_with_generic_trait_bounds.rs b/derive/tests/run-pass/client_with_generic_trait_bounds.rs new file mode 100644 index 000000000..b0e780c19 --- /dev/null +++ b/derive/tests/run-pass/client_with_generic_trait_bounds.rs @@ -0,0 +1,39 @@ +use jsonrpc_core::futures::future; +use jsonrpc_core::{IoHandler, Result, BoxFuture}; +use jsonrpc_derive::rpc; +use std::collections::BTreeMap; + +#[rpc] +pub trait Rpc +where + One: Ord, + Two: Ord + Eq, +{ + /// Adds two numbers and returns a result + #[rpc(name = "setTwo")] + fn set_two(&self, a: Two) -> Result>; + + /// Performs asynchronous operation + #[rpc(name = "beFancy")] + fn call(&self, a: One) -> BoxFuture>; +} + +struct RpcImpl; + +impl Rpc for RpcImpl { + fn set_two(&self, x: String) -> Result> { + println!("{}", x); + Ok(Default::default()) + } + + fn call(&self, num: u64) -> BoxFuture> { + Box::pin(future::ready(Ok((num + 999, "hello".into())))) + } +} + +fn main() { + let mut io = IoHandler::new(); + let rpc = RpcImpl; + + io.extend_with(rpc.to_delegate()) +} diff --git a/derive/tests/run-pass/pubsub-dependency-not-required-for-vanilla-rpc.rs b/derive/tests/run-pass/pubsub-dependency-not-required-for-vanilla-rpc.rs new file mode 100644 index 000000000..cae0bac2d --- /dev/null +++ b/derive/tests/run-pass/pubsub-dependency-not-required-for-vanilla-rpc.rs @@ -0,0 +1,32 @@ +use jsonrpc_derive::rpc; +use jsonrpc_core::{Result, IoHandler}; + +#[rpc] +pub trait Rpc { + /// Returns a protocol version + #[rpc(name = "protocolVersion")] + fn protocol_version(&self) -> Result; + + /// Adds two numbers and returns a result + #[rpc(name = "add", alias("callAsyncMetaAlias"))] + fn add(&self, a: u64, b: u64) -> Result; +} + +struct RpcImpl; + +impl Rpc for RpcImpl { + fn protocol_version(&self) -> Result { + Ok("version1".into()) + } + + fn add(&self, a: u64, b: u64) -> Result { + Ok(a + b) + } +} + +fn main() { + let mut io = IoHandler::new(); + let rpc = RpcImpl; + + io.extend_with(rpc.to_delegate()) +} diff --git a/derive/tests/run-pass/pubsub-subscription-generic-type-with-deserialize.rs b/derive/tests/run-pass/pubsub-subscription-generic-type-with-deserialize.rs new file mode 100644 index 000000000..45b8bedc2 --- /dev/null +++ b/derive/tests/run-pass/pubsub-subscription-generic-type-with-deserialize.rs @@ -0,0 +1,49 @@ +use jsonrpc_derive::rpc; +use serde::{Serialize, Deserialize}; + +use std::sync::Arc; +use jsonrpc_core::Result; +use jsonrpc_pubsub::{typed::Subscriber, SubscriptionId, Session, PubSubHandler}; + +#[derive(Serialize, Deserialize)] +pub struct Wrapper { + inner: T, + inner2: U, +} + +#[rpc] +pub trait Rpc { + type Metadata; + + /// Hello subscription + #[pubsub(subscription = "hello", subscribe, name = "hello_subscribe", alias("hello_sub"))] + fn subscribe(&self, _: Self::Metadata, _: Subscriber>); + + /// Unsubscribe from hello subscription. 
+ #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] + fn unsubscribe(&self, a: Option, b: SubscriptionId) -> Result; +} + +#[derive(Serialize, Deserialize)] +struct SerializeAndDeserialize { + foo: String, +} + +struct RpcImpl; +impl Rpc for RpcImpl { + type Metadata = Arc; + + fn subscribe(&self, _: Self::Metadata, _: Subscriber>) { + unimplemented!(); + } + + fn unsubscribe(&self, _: Option, _: SubscriptionId) -> Result { + unimplemented!(); + } +} + +fn main() { + let mut io = PubSubHandler::default(); + let rpc = RpcImpl; + io.extend_with(rpc.to_delegate()); +} diff --git a/derive/tests/run-pass/pubsub-subscription-type-with-deserialize.rs b/derive/tests/run-pass/pubsub-subscription-type-with-deserialize.rs new file mode 100644 index 000000000..b3e4f7bde --- /dev/null +++ b/derive/tests/run-pass/pubsub-subscription-type-with-deserialize.rs @@ -0,0 +1,43 @@ +use jsonrpc_derive::rpc; +use serde::{Serialize, Deserialize}; + +use std::sync::Arc; +use jsonrpc_core::Result; +use jsonrpc_pubsub::{typed::Subscriber, SubscriptionId, Session, PubSubHandler}; + +#[rpc] +pub trait Rpc { + type Metadata; + + /// Hello subscription + #[pubsub(subscription = "hello", subscribe, name = "hello_subscribe", alias("hello_sub"))] + fn subscribe(&self, _: Self::Metadata, _: Subscriber); + + /// Unsubscribe from hello subscription. + #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] + fn unsubscribe(&self, a: Option, b: SubscriptionId) -> Result; +} + +#[derive(Serialize, Deserialize)] +struct SerializeAndDeserialize { + foo: String, +} + +struct RpcImpl; +impl Rpc for RpcImpl { + type Metadata = Arc; + + fn subscribe(&self, _: Self::Metadata, _: Subscriber) { + unimplemented!(); + } + + fn unsubscribe(&self, _: Option, _: SubscriptionId) -> Result { + unimplemented!(); + } +} + +fn main() { + let mut io = PubSubHandler::default(); + let rpc = RpcImpl; + io.extend_with(rpc.to_delegate()); +} diff --git a/derive/tests/run-pass/pubsub-subscription-type-without-deserialize.rs b/derive/tests/run-pass/pubsub-subscription-type-without-deserialize.rs new file mode 100644 index 000000000..e0024595e --- /dev/null +++ b/derive/tests/run-pass/pubsub-subscription-type-without-deserialize.rs @@ -0,0 +1,44 @@ +use jsonrpc_derive::rpc; +use serde::Serialize; + +use std::sync::Arc; +use jsonrpc_core::Result; +use jsonrpc_pubsub::{typed::Subscriber, SubscriptionId, Session, PubSubHandler}; + +#[rpc(server)] +pub trait Rpc { + type Metadata; + + /// Hello subscription + #[pubsub(subscription = "hello", subscribe, name = "hello_subscribe", alias("hello_sub"))] + fn subscribe(&self, _: Self::Metadata, _: Subscriber); + + /// Unsubscribe from hello subscription. 
+ #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] + fn unsubscribe(&self, a: Option, b: SubscriptionId) -> Result; +} + +// One way serialization +#[derive(Serialize)] +struct SerializeOnly { + foo: String, +} + +struct RpcImpl; +impl Rpc for RpcImpl { + type Metadata = Arc; + + fn subscribe(&self, _: Self::Metadata, _: Subscriber) { + unimplemented!(); + } + + fn unsubscribe(&self, _: Option, _: SubscriptionId) -> Result { + unimplemented!(); + } +} + +fn main() { + let mut io = PubSubHandler::default(); + let rpc = RpcImpl; + io.extend_with(rpc.to_delegate()); +} diff --git a/derive/tests/run-pass/server_only.rs b/derive/tests/run-pass/server_only.rs new file mode 100644 index 000000000..ffa142ba9 --- /dev/null +++ b/derive/tests/run-pass/server_only.rs @@ -0,0 +1,35 @@ +extern crate jsonrpc_core; +#[macro_use] +extern crate jsonrpc_derive; + +use jsonrpc_core::{Result, IoHandler}; + +#[rpc(server)] +pub trait Rpc { + /// Returns a protocol version + #[rpc(name = "protocolVersion")] + fn protocol_version(&self) -> Result; + + /// Adds two numbers and returns a result + #[rpc(name = "add", alias("callAsyncMetaAlias"))] + fn add(&self, a: u64, b: u64) -> Result; +} + +struct RpcImpl; + +impl Rpc for RpcImpl { + fn protocol_version(&self) -> Result { + Ok("version1".into()) + } + + fn add(&self, a: u64, b: u64) -> Result { + Ok(a + b) + } +} + +fn main() { + let mut io = IoHandler::new(); + let rpc = RpcImpl; + + io.extend_with(rpc.to_delegate()) +} diff --git a/derive/tests/trailing.rs b/derive/tests/trailing.rs new file mode 100644 index 000000000..b391e3062 --- /dev/null +++ b/derive/tests/trailing.rs @@ -0,0 +1,197 @@ +use jsonrpc_core::{IoHandler, Response, Result}; +use jsonrpc_derive::rpc; +use serde_json; + +#[rpc] +pub trait Rpc { + /// Multiplies two numbers. Second number is optional. 
+ #[rpc(name = "mul")] + fn mul(&self, a: u64, b: Option) -> Result; + + /// Echos back the message, example of a single param trailing + #[rpc(name = "echo")] + fn echo(&self, a: Option) -> Result; + + /// Adds up to three numbers and returns a result + #[rpc(name = "add_multi")] + fn add_multi(&self, a: Option, b: Option, c: Option) -> Result; +} + +#[derive(Default)] +struct RpcImpl; + +impl Rpc for RpcImpl { + fn mul(&self, a: u64, b: Option) -> Result { + Ok(a * b.unwrap_or(1)) + } + + fn echo(&self, x: Option) -> Result { + Ok(x.unwrap_or("".into())) + } + + fn add_multi(&self, a: Option, b: Option, c: Option) -> Result { + Ok(a.unwrap_or_default() + b.unwrap_or_default() + c.unwrap_or_default()) + } +} + +#[test] +fn should_accept_trailing_param() { + let mut io = IoHandler::new(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let req = r#"{"jsonrpc":"2.0","id":1,"method":"mul","params":[2, 2]}"#; + let res = io.handle_request_sync(req); + + // then + let result: Response = serde_json::from_str(&res.unwrap()).unwrap(); + assert_eq!( + result, + serde_json::from_str( + r#"{ + "jsonrpc": "2.0", + "result": 4, + "id": 1 + }"# + ) + .unwrap() + ); +} + +#[test] +fn should_accept_missing_trailing_param() { + let mut io = IoHandler::new(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let req = r#"{"jsonrpc":"2.0","id":1,"method":"mul","params":[2]}"#; + let res = io.handle_request_sync(req); + + // then + let result: Response = serde_json::from_str(&res.unwrap()).unwrap(); + assert_eq!( + result, + serde_json::from_str( + r#"{ + "jsonrpc": "2.0", + "result": 2, + "id": 1 + }"# + ) + .unwrap() + ); +} + +#[test] +fn should_accept_single_trailing_param() { + let mut io = IoHandler::new(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"echo","params":["hello"]}"#; + let req2 = r#"{"jsonrpc":"2.0","id":1,"method":"echo","params":[]}"#; + + let res1 = io.handle_request_sync(req1); + let res2 = io.handle_request_sync(req2); + + // then + let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap(); + assert_eq!( + result1, + serde_json::from_str( + r#"{ + "jsonrpc": "2.0", + "result": "hello", + "id": 1 + }"# + ) + .unwrap() + ); + + let result2: Response = serde_json::from_str(&res2.unwrap()).unwrap(); + assert_eq!( + result2, + serde_json::from_str( + r#"{ + "jsonrpc": "2.0", + "result": "", + "id": 1 + }"# + ) + .unwrap() + ); +} + +#[test] +fn should_accept_multiple_trailing_params() { + let mut io = IoHandler::new(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"add_multi","params":[]}"#; + let req2 = r#"{"jsonrpc":"2.0","id":1,"method":"add_multi","params":[1]}"#; + let req3 = r#"{"jsonrpc":"2.0","id":1,"method":"add_multi","params":[1, 2]}"#; + let req4 = r#"{"jsonrpc":"2.0","id":1,"method":"add_multi","params":[1, 2, 3]}"#; + + let res1 = io.handle_request_sync(req1); + let res2 = io.handle_request_sync(req2); + let res3 = io.handle_request_sync(req3); + let res4 = io.handle_request_sync(req4); + + // then + let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap(); + assert_eq!( + result1, + serde_json::from_str( + r#"{ + "jsonrpc": "2.0", + "result": 0, + "id": 1 + }"# + ) + .unwrap() + ); + + let result2: Response = serde_json::from_str(&res2.unwrap()).unwrap(); + assert_eq!( + result2, + serde_json::from_str( + r#"{ + 
"jsonrpc": "2.0", + "result": 1, + "id": 1 + }"# + ) + .unwrap() + ); + + let result3: Response = serde_json::from_str(&res3.unwrap()).unwrap(); + assert_eq!( + result3, + serde_json::from_str( + r#"{ + "jsonrpc": "2.0", + "result": 3, + "id": 1 + }"# + ) + .unwrap() + ); + + let result4: Response = serde_json::from_str(&res4.unwrap()).unwrap(); + assert_eq!( + result4, + serde_json::from_str( + r#"{ + "jsonrpc": "2.0", + "result": 6, + "id": 1 + }"# + ) + .unwrap() + ); +} diff --git a/derive/tests/trybuild.rs b/derive/tests/trybuild.rs new file mode 100644 index 000000000..f8c2e2e43 --- /dev/null +++ b/derive/tests/trybuild.rs @@ -0,0 +1,6 @@ +#[test] +fn compile_test() { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/*.rs"); + t.pass("tests/run-pass/*.rs"); +} diff --git a/derive/tests/ui/attr-invalid-meta-list-names.rs b/derive/tests/ui/attr-invalid-meta-list-names.rs new file mode 100644 index 000000000..7d8860065 --- /dev/null +++ b/derive/tests/ui/attr-invalid-meta-list-names.rs @@ -0,0 +1,10 @@ +use jsonrpc_derive::rpc; + +#[rpc] +pub trait Rpc { + /// Returns a protocol version + #[rpc(name = "protocolVersion", Xalias("alias"))] + fn protocol_version(&self) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/attr-invalid-meta-list-names.stderr b/derive/tests/ui/attr-invalid-meta-list-names.stderr new file mode 100644 index 000000000..b1beb38c4 --- /dev/null +++ b/derive/tests/ui/attr-invalid-meta-list-names.stderr @@ -0,0 +1,7 @@ +error: Invalid attribute parameter(s): 'Xalias'. Expected 'alias' + --> $DIR/attr-invalid-meta-list-names.rs:5:2 + | +5 | / /// Returns a protocol version +6 | | #[rpc(name = "protocolVersion", Xalias("alias"))] +7 | | fn protocol_version(&self) -> Result; + | |_________________________________________________^ diff --git a/derive/tests/ui/attr-invalid-meta-words.rs b/derive/tests/ui/attr-invalid-meta-words.rs new file mode 100644 index 000000000..e4eb6876e --- /dev/null +++ b/derive/tests/ui/attr-invalid-meta-words.rs @@ -0,0 +1,10 @@ +use jsonrpc_derive::rpc; + +#[rpc] +pub trait Rpc { + /// Returns a protocol version + #[rpc(name = "protocolVersion", Xmeta)] + fn protocol_version(&self) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/attr-invalid-meta-words.stderr b/derive/tests/ui/attr-invalid-meta-words.stderr new file mode 100644 index 000000000..b76def12b --- /dev/null +++ b/derive/tests/ui/attr-invalid-meta-words.stderr @@ -0,0 +1,7 @@ +error: Invalid attribute parameter(s): 'Xmeta'. Expected 'meta, raw_params' + --> $DIR/attr-invalid-meta-words.rs:5:2 + | +5 | / /// Returns a protocol version +6 | | #[rpc(name = "protocolVersion", Xmeta)] +7 | | fn protocol_version(&self) -> Result; + | |_________________________________________________^ diff --git a/derive/tests/ui/attr-invalid-name-values.rs b/derive/tests/ui/attr-invalid-name-values.rs new file mode 100644 index 000000000..16cdb4b6c --- /dev/null +++ b/derive/tests/ui/attr-invalid-name-values.rs @@ -0,0 +1,10 @@ +use jsonrpc_derive::rpc; + +#[rpc] +pub trait Rpc { + /// Returns a protocol version + #[rpc(Xname = "protocolVersion")] + fn protocol_version(&self) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/attr-invalid-name-values.stderr b/derive/tests/ui/attr-invalid-name-values.stderr new file mode 100644 index 000000000..f308254c1 --- /dev/null +++ b/derive/tests/ui/attr-invalid-name-values.stderr @@ -0,0 +1,7 @@ +error: Invalid attribute parameter(s): 'Xname'. 
Expected 'name, returns, params' + --> $DIR/attr-invalid-name-values.rs:5:2 + | +5 | / /// Returns a protocol version +6 | | #[rpc(Xname = "protocolVersion")] +7 | | fn protocol_version(&self) -> Result; + | |_________________________________________________^ diff --git a/derive/tests/ui/attr-missing-rpc-name.rs b/derive/tests/ui/attr-missing-rpc-name.rs new file mode 100644 index 000000000..044d9ae7a --- /dev/null +++ b/derive/tests/ui/attr-missing-rpc-name.rs @@ -0,0 +1,10 @@ +use jsonrpc_derive::rpc; + +#[rpc] +pub trait Rpc { + /// Returns a protocol version + #[rpc] + fn protocol_version(&self) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/attr-missing-rpc-name.stderr b/derive/tests/ui/attr-missing-rpc-name.stderr new file mode 100644 index 000000000..d919e31de --- /dev/null +++ b/derive/tests/ui/attr-missing-rpc-name.stderr @@ -0,0 +1,7 @@ +error: rpc attribute should have a name e.g. `name = "method_name"` + --> $DIR/attr-missing-rpc-name.rs:5:2 + | +5 | / /// Returns a protocol version +6 | | #[rpc] +7 | | fn protocol_version(&self) -> Result; + | |_________________________________________________^ diff --git a/derive/tests/ui/attr-named-params-on-server.rs b/derive/tests/ui/attr-named-params-on-server.rs new file mode 100644 index 000000000..074995642 --- /dev/null +++ b/derive/tests/ui/attr-named-params-on-server.rs @@ -0,0 +1,10 @@ +use jsonrpc_derive::rpc; + +#[rpc] +pub trait Rpc { + /// Returns a protocol version + #[rpc(name = "add", params = "named")] + fn add(&self, a: u32, b: u32) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/attr-named-params-on-server.stderr b/derive/tests/ui/attr-named-params-on-server.stderr new file mode 100644 index 000000000..41ccc852a --- /dev/null +++ b/derive/tests/ui/attr-named-params-on-server.stderr @@ -0,0 +1,9 @@ +error: `params = "named"` can only be used to generate a client (on a trait annotated with #[rpc(client)]). At this time the server does not support named parameters. 
+ --> $DIR/attr-named-params-on-server.rs:4:1 + | +4 | / pub trait Rpc { +5 | | /// Returns a protocol version +6 | | #[rpc(name = "add", params = "named")] +7 | | fn add(&self, a: u32, b: u32) -> Result; +8 | | } + | |_^ diff --git a/derive/tests/ui/multiple-rpc-attributes.rs b/derive/tests/ui/multiple-rpc-attributes.rs new file mode 100644 index 000000000..a6fde45aa --- /dev/null +++ b/derive/tests/ui/multiple-rpc-attributes.rs @@ -0,0 +1,11 @@ +use jsonrpc_derive::rpc; + +#[rpc] +pub trait Rpc { + /// Returns a protocol version + #[rpc(name = "protocolVersion")] + #[rpc(name = "protocolVersionAgain")] + fn protocol_version(&self) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/multiple-rpc-attributes.stderr b/derive/tests/ui/multiple-rpc-attributes.stderr new file mode 100644 index 000000000..dddb37788 --- /dev/null +++ b/derive/tests/ui/multiple-rpc-attributes.stderr @@ -0,0 +1,8 @@ +error: Expected only a single rpc attribute per method + --> $DIR/multiple-rpc-attributes.rs:5:2 + | +5 | / /// Returns a protocol version +6 | | #[rpc(name = "protocolVersion")] +7 | | #[rpc(name = "protocolVersionAgain")] +8 | | fn protocol_version(&self) -> Result; + | |_________________________________________________^ diff --git a/derive/tests/ui/pubsub/attr-both-subscribe-and-unsubscribe.rs b/derive/tests/ui/pubsub/attr-both-subscribe-and-unsubscribe.rs new file mode 100644 index 000000000..a8946eb1b --- /dev/null +++ b/derive/tests/ui/pubsub/attr-both-subscribe-and-unsubscribe.rs @@ -0,0 +1,19 @@ +#[macro_use] +extern crate jsonrpc_derive; +extern crate jsonrpc_core; +extern crate jsonrpc_pubsub; + +#[rpc] +pub trait Rpc { + type Metadata; + + /// Hello subscription + #[pubsub(subscription = "hello", subscribe, unsubscribe, name = "hello_subscribe", alias("hello_sub"))] + fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + + /// Unsubscribe from hello subscription. + #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] + fn unsubscribe(&self, _: Option, _: SubscriptionId) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/pubsub/attr-both-subscribe-and-unsubscribe.stderr b/derive/tests/ui/pubsub/attr-both-subscribe-and-unsubscribe.stderr new file mode 100644 index 000000000..bab557883 --- /dev/null +++ b/derive/tests/ui/pubsub/attr-both-subscribe-and-unsubscribe.stderr @@ -0,0 +1,11 @@ +error: pubsub attribute annotated with both subscribe and unsubscribe + --> $DIR/attr-both-subscribe-and-unsubscribe.rs:10:2 + | +10 | /// Hello subscription + | _____^ +11 | | #[pubsub(subscription = "hello", subscribe, unsubscribe, name = "hello_subscribe", alias("hello_sub"))] +12 | | fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + | |_________________________________________________________________________________^ + +error: aborting due to previous error + diff --git a/derive/tests/ui/pubsub/attr-invalid-meta-list-names.rs b/derive/tests/ui/pubsub/attr-invalid-meta-list-names.rs new file mode 100644 index 000000000..75412aaf4 --- /dev/null +++ b/derive/tests/ui/pubsub/attr-invalid-meta-list-names.rs @@ -0,0 +1,19 @@ +#[macro_use] +extern crate jsonrpc_derive; +extern crate jsonrpc_core; +extern crate jsonrpc_pubsub; + +#[rpc] +pub trait Rpc { + type Metadata; + + /// Hello subscription + #[pubsub(subscription = "hello", subscribe, name = "hello_subscribe", Xalias("hello_sub"))] + fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + + /// Unsubscribe from hello subscription. 
+ #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] + fn unsubscribe(&self, _: Option, _: SubscriptionId) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/pubsub/attr-invalid-meta-list-names.stderr b/derive/tests/ui/pubsub/attr-invalid-meta-list-names.stderr new file mode 100644 index 000000000..2f7ddd478 --- /dev/null +++ b/derive/tests/ui/pubsub/attr-invalid-meta-list-names.stderr @@ -0,0 +1,11 @@ +error: Invalid attribute parameter(s): 'Xalias'. Expected 'alias' + --> $DIR/attr-invalid-meta-list-names.rs:10:2 + | +10 | /// Hello subscription + | _____^ +11 | | #[pubsub(subscription = "hello", subscribe, name = "hello_subscribe", Xalias("hello_sub"))] +12 | | fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + | |_________________________________________________________________________________^ + +error: aborting due to previous error + diff --git a/derive/tests/ui/pubsub/attr-invalid-meta-words.rs b/derive/tests/ui/pubsub/attr-invalid-meta-words.rs new file mode 100644 index 000000000..d3a6d8372 --- /dev/null +++ b/derive/tests/ui/pubsub/attr-invalid-meta-words.rs @@ -0,0 +1,19 @@ +#[macro_use] +extern crate jsonrpc_derive; +extern crate jsonrpc_core; +extern crate jsonrpc_pubsub; + +#[rpc] +pub trait Rpc { + type Metadata; + + /// Hello subscription + #[pubsub(subscription = "hello", Xsubscribe, name = "hello_subscribe", alias("hello_sub"))] + fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + + /// Unsubscribe from hello subscription. + #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] + fn unsubscribe(&self, _: Option, _: SubscriptionId) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/pubsub/attr-invalid-meta-words.stderr b/derive/tests/ui/pubsub/attr-invalid-meta-words.stderr new file mode 100644 index 000000000..cca937234 --- /dev/null +++ b/derive/tests/ui/pubsub/attr-invalid-meta-words.stderr @@ -0,0 +1,11 @@ +error: Invalid attribute parameter(s): 'Xsubscribe'. Expected 'subscribe, unsubscribe' + --> $DIR/attr-invalid-meta-words.rs:10:2 + | +10 | /// Hello subscription + | _____^ +11 | | #[pubsub(subscription = "hello", Xsubscribe, name = "hello_subscribe", alias("hello_sub"))] +12 | | fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + | |_________________________________________________________________________________^ + +error: aborting due to previous error + diff --git a/derive/tests/ui/pubsub/attr-invalid-name-values.rs b/derive/tests/ui/pubsub/attr-invalid-name-values.rs new file mode 100644 index 000000000..3933ad67d --- /dev/null +++ b/derive/tests/ui/pubsub/attr-invalid-name-values.rs @@ -0,0 +1,19 @@ +#[macro_use] +extern crate jsonrpc_derive; +extern crate jsonrpc_core; +extern crate jsonrpc_pubsub; + +#[rpc] +pub trait Rpc { + type Metadata; + + /// Hello subscription + #[pubsub(Xsubscription = "hello", subscribe, Xname = "hello_subscribe", alias("hello_sub"))] + fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + + /// Unsubscribe from hello subscription. 
+ #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] + fn unsubscribe(&self, _: Option, _: SubscriptionId) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/pubsub/attr-invalid-name-values.stderr b/derive/tests/ui/pubsub/attr-invalid-name-values.stderr new file mode 100644 index 000000000..fa38e1a82 --- /dev/null +++ b/derive/tests/ui/pubsub/attr-invalid-name-values.stderr @@ -0,0 +1,11 @@ +error: Invalid attribute parameter(s): 'Xsubscription, Xname'. Expected 'subscription, name' + --> $DIR/attr-invalid-name-values.rs:10:2 + | +10 | /// Hello subscription + | _____^ +11 | | #[pubsub(Xsubscription = "hello", subscribe, Xname = "hello_subscribe", alias("hello_sub"))] +12 | | fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + | |_________________________________________________________________________________^ + +error: aborting due to previous error + diff --git a/derive/tests/ui/pubsub/attr-missing-subscription-name.rs b/derive/tests/ui/pubsub/attr-missing-subscription-name.rs new file mode 100644 index 000000000..4139b7155 --- /dev/null +++ b/derive/tests/ui/pubsub/attr-missing-subscription-name.rs @@ -0,0 +1,19 @@ +#[macro_use] +extern crate jsonrpc_derive; +extern crate jsonrpc_core; +extern crate jsonrpc_pubsub; + +#[rpc] +pub trait Rpc { + type Metadata; + + /// Hello subscription + #[pubsub(subscribe, name = "hello_subscribe", alias("hello_sub"))] + fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + + /// Unsubscribe from hello subscription. + #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] + fn unsubscribe(&self, _: Option, _: SubscriptionId) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/pubsub/attr-missing-subscription-name.stderr b/derive/tests/ui/pubsub/attr-missing-subscription-name.stderr new file mode 100644 index 000000000..ab1638f51 --- /dev/null +++ b/derive/tests/ui/pubsub/attr-missing-subscription-name.stderr @@ -0,0 +1,11 @@ +error: pubsub attribute should have a subscription name + --> $DIR/attr-missing-subscription-name.rs:10:2 + | +10 | /// Hello subscription + | _____^ +11 | | #[pubsub(subscribe, name = "hello_subscribe", alias("hello_sub"))] +12 | | fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + | |_________________________________________________________________________________^ + +error: aborting due to previous error + diff --git a/derive/tests/ui/pubsub/attr-neither-subscribe-or-unsubscribe.rs b/derive/tests/ui/pubsub/attr-neither-subscribe-or-unsubscribe.rs new file mode 100644 index 000000000..c79512c82 --- /dev/null +++ b/derive/tests/ui/pubsub/attr-neither-subscribe-or-unsubscribe.rs @@ -0,0 +1,19 @@ +#[macro_use] +extern crate jsonrpc_derive; +extern crate jsonrpc_core; +extern crate jsonrpc_pubsub; + +#[rpc] +pub trait Rpc { + type Metadata; + + /// Hello subscription + #[pubsub(subscription = "hello", name = "hello_subscribe", alias("hello_sub"))] + fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + + /// Unsubscribe from hello subscription. 
+ #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] + fn unsubscribe(&self, _: Option, _: SubscriptionId) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/pubsub/attr-neither-subscribe-or-unsubscribe.stderr b/derive/tests/ui/pubsub/attr-neither-subscribe-or-unsubscribe.stderr new file mode 100644 index 000000000..505627965 --- /dev/null +++ b/derive/tests/ui/pubsub/attr-neither-subscribe-or-unsubscribe.stderr @@ -0,0 +1,11 @@ +error: pubsub attribute not annotated with either subscribe or unsubscribe + --> $DIR/attr-neither-subscribe-or-unsubscribe.rs:10:2 + | +10 | /// Hello subscription + | _____^ +11 | | #[pubsub(subscription = "hello", name = "hello_subscribe", alias("hello_sub"))] +12 | | fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + | |_________________________________________________________________________________^ + +error: aborting due to previous error + diff --git a/derive/tests/ui/pubsub/missing-subscribe.rs b/derive/tests/ui/pubsub/missing-subscribe.rs new file mode 100644 index 000000000..ae91c4050 --- /dev/null +++ b/derive/tests/ui/pubsub/missing-subscribe.rs @@ -0,0 +1,17 @@ +#[macro_use] +extern crate jsonrpc_derive; +extern crate jsonrpc_core; +extern crate jsonrpc_pubsub; + +#[rpc] +pub trait Rpc { + type Metadata; + + // note that a subscribe method is missing + + /// Unsubscribe from hello subscription. + #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] + fn unsubscribe(&self, _: Option, _: SubscriptionId) -> Result; +} + +fn main() {} diff --git a/derive/tests/ui/pubsub/missing-subscribe.stderr b/derive/tests/ui/pubsub/missing-subscribe.stderr new file mode 100644 index 000000000..cc7d14f23 --- /dev/null +++ b/derive/tests/ui/pubsub/missing-subscribe.stderr @@ -0,0 +1,11 @@ +error: subscription 'hello'. Can't find subscribe method, expected a method annotated with `subscribe` e.g. `#[pubsub(subscription = "hello", subscribe, name = "hello_subscribe")]` + --> $DIR/missing-subscribe.rs:12:2 + | +12 | /// Unsubscribe from hello subscription. + | _____^ +13 | | #[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")] +14 | | fn unsubscribe(&self, _: Option, _: SubscriptionId) -> Result; + | |________________________________________________________________________________________^ + +error: aborting due to previous error + diff --git a/derive/tests/ui/pubsub/missing-unsubscribe.rs b/derive/tests/ui/pubsub/missing-unsubscribe.rs new file mode 100644 index 000000000..093751390 --- /dev/null +++ b/derive/tests/ui/pubsub/missing-unsubscribe.rs @@ -0,0 +1,17 @@ +#[macro_use] +extern crate jsonrpc_derive; +extern crate jsonrpc_core; +extern crate jsonrpc_pubsub; + +#[rpc] +pub trait Rpc { + type Metadata; + + /// Hello subscription + #[pubsub(subscription = "hello", subscribe, name = "hello_subscribe", alias("hello_sub"))] + fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64); + + // note that the unsubscribe method is missing +} + +fn main() {} diff --git a/derive/tests/ui/pubsub/missing-unsubscribe.stderr b/derive/tests/ui/pubsub/missing-unsubscribe.stderr new file mode 100644 index 000000000..56b8b6dfb --- /dev/null +++ b/derive/tests/ui/pubsub/missing-unsubscribe.stderr @@ -0,0 +1,11 @@ +error: subscription 'hello'. Can't find unsubscribe method, expected a method annotated with `unsubscribe` e.g. 
`#[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")]`
+  --> $DIR/missing-unsubscribe.rs:10:2
+   |
+10 |    /// Hello subscription
+   |  _____^
+11 | |  #[pubsub(subscription = "hello", subscribe, name = "hello_subscribe", alias("hello_sub"))]
+12 | |  fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber, _: u64);
+   | |_________________________________________________________________________________^
+
+error: aborting due to previous error
+
diff --git a/derive/tests/ui/pubsub/mixed-subscriber-signatures.rs b/derive/tests/ui/pubsub/mixed-subscriber-signatures.rs
new file mode 100644
index 000000000..25948adbf
--- /dev/null
+++ b/derive/tests/ui/pubsub/mixed-subscriber-signatures.rs
@@ -0,0 +1,24 @@
+#[macro_use]
+extern crate jsonrpc_derive;
+extern crate jsonrpc_core;
+extern crate jsonrpc_pubsub;
+
+#[rpc]
+pub trait Rpc {
+	type Metadata;
+
+	/// Hello subscription
+	#[pubsub(subscription = "hello", subscribe, name = "hello_subscribe", alias("hello_sub"))]
+	fn subscribe(&self, _: Self::Metadata, _: typed::Subscriber<String>, _: u64);
+
+	/// Hello subscription with probably different impl and mismatched Subscriber type
+	#[pubsub(subscription = "hello", subscribe, name = "hello_anotherSubscribe")]
+	fn another_subscribe(&self, _: Self::Metadata, _: typed::Subscriber<usize>, _: u64);
+
+	/// Unsubscribe from hello subscription.
+	#[pubsub(subscription = "hello", unsubscribe, name = "hello_unsubscribe")]
+	fn unsubscribe(&self, _: Option<Self::Metadata>, _: SubscriptionId) -> Result;
+	// note that the unsubscribe method is missing
+}
+
+fn main() {}
diff --git a/derive/tests/ui/pubsub/mixed-subscriber-signatures.stderr b/derive/tests/ui/pubsub/mixed-subscriber-signatures.stderr
new file mode 100644
index 000000000..3a0e93791
--- /dev/null
+++ b/derive/tests/ui/pubsub/mixed-subscriber-signatures.stderr
@@ -0,0 +1,5 @@
+error: Inconsistent signature for 'Subscriber' argument: typed :: Subscriber < usize >, previously defined: typed :: Subscriber < String >
+  --> $DIR/mixed-subscriber-signatures.rs:16:52
+   |
+16 |     fn another_subscribe(&self, _: Self::Metadata, _: typed::Subscriber<usize>, _: u64);
+   |                                                        ^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/derive/tests/ui/too-many-params.rs b/derive/tests/ui/too-many-params.rs
new file mode 100644
index 000000000..326c77dab
--- /dev/null
+++ b/derive/tests/ui/too-many-params.rs
@@ -0,0 +1,14 @@
+use jsonrpc_derive::rpc;
+
+#[rpc]
+pub trait Rpc {
+	/// Has too many params
+	#[rpc(name = "tooManyParams")]
+	fn to_many_params(
+		&self,
+		a: u64, b: u64, c: u64, d: u64, e: u64, f: u64, g: u64, h: u64, i: u64, j: u64,
+		k: u64, l: u64, m: u64, n: u64, o: u64, p: u64, q: u64,
+	) -> Result;
+}
+
+fn main() {}
diff --git a/derive/tests/ui/too-many-params.stderr b/derive/tests/ui/too-many-params.stderr
new file mode 100644
index 000000000..3c23fdf72
--- /dev/null
+++ b/derive/tests/ui/too-many-params.stderr
@@ -0,0 +1,11 @@
+error: Maximum supported number of params is 16
+  --> $DIR/too-many-params.rs:5:2
+   |
+5  | /     /// Has too many params
+6  | |     #[rpc(name = "tooManyParams")]
+7  | |     fn to_many_params(
+8  | |         &self,
+9  | |         a: u64, b: u64, c: u64, d: u64, e: u64, f: u64, g: u64, h: u64, i: u64, j: u64,
+10 | |         k: u64, l: u64, m: u64, n: u64, o: u64, p: u64, q: u64,
+11 | |     ) -> Result;
+   | |________________________^
diff --git a/derive/tests/ui/trait-attr-named-params-on-server.rs b/derive/tests/ui/trait-attr-named-params-on-server.rs
new file mode 100644
index 000000000..302768fcf
--- /dev/null
+++ b/derive/tests/ui/trait-attr-named-params-on-server.rs
@@ -0,0 +1,7 @@
+use
jsonrpc_derive::rpc; + +#[rpc(server, params = "named")] +pub trait Rpc { +} + +fn main() {} diff --git a/derive/tests/ui/trait-attr-named-params-on-server.stderr b/derive/tests/ui/trait-attr-named-params-on-server.stderr new file mode 100644 index 000000000..c44d44465 --- /dev/null +++ b/derive/tests/ui/trait-attr-named-params-on-server.stderr @@ -0,0 +1,7 @@ +error: custom attribute panicked + --> $DIR/trait-attr-named-params-on-server.rs:3:1 + | +3 | #[rpc(server, params = "named")] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: message: Server code generation only supports `params = "positional"` (default) or `params = "raw" at this time. diff --git a/http/Cargo.toml b/http/Cargo.toml index 2781d7805..276c5f78d 100644 --- a/http/Cargo.toml +++ b/http/Cargo.toml @@ -1,21 +1,27 @@ [package] +authors = ["Parity Technologies "] description = "Rust http server using JSONRPC 2.0." +documentation = "https://docs.rs/jsonrpc-http-server/" +edition = "2018" homepage = "https://github.com/paritytech/jsonrpc" -repository = "https://github.com/paritytech/jsonrpc" +keywords = ["jsonrpc", "json-rpc", "json", "rpc", "server"] license = "MIT" name = "jsonrpc-http-server" -version = "9.0.0" -authors = ["debris "] -keywords = ["jsonrpc", "json-rpc", "json", "rpc", "server"] -documentation = "https://paritytech.github.io/jsonrpc/jsonrpc_http_server/index.html" +repository = "https://github.com/paritytech/jsonrpc" +version = "18.0.0" [dependencies] -hyper = "0.12" -jsonrpc-core = { version = "9.0", path = "../core" } -jsonrpc-server-utils = { version = "9.0", path = "../server-utils" } +futures = "0.3" +hyper = { version = "0.14", features = ["http1", "tcp", "server", "stream"] } +jsonrpc-core = { version = "18.0.0", path = "../core" } +jsonrpc-server-utils = { version = "18.0.0", path = "../server-utils" } log = "0.4" net2 = "0.2" +parking_lot = "0.11.0" unicase = "2.0" +[dev-dependencies] +env_logger = "0.7" + [badges] travis-ci = { repository = "paritytech/jsonrpc", branch = "master"} diff --git a/http/README.md b/http/README.md index 40c62df05..c80a2e65c 100644 --- a/http/README.md +++ b/http/README.md @@ -9,14 +9,12 @@ Rust http server using JSON-RPC 2.0. 
``` [dependencies] -jsonrpc-http-server = { git = "https://github.com/paritytech/jsonrpc" } +jsonrpc-http-server = "15.0" ``` `main.rs` ```rust -extern crate jsonrpc_http_server; - use jsonrpc_http_server::*; use jsonrpc_http_server::jsonrpc_core::*; diff --git a/http/examples/http_async.rs b/http/examples/http_async.rs index e2bed681a..c243bb875 100644 --- a/http/examples/http_async.rs +++ b/http/examples/http_async.rs @@ -1,12 +1,10 @@ -extern crate jsonrpc_http_server; - -use jsonrpc_http_server::{ServerBuilder, DomainsValidation, AccessControlAllowOrigin}; use jsonrpc_http_server::jsonrpc_core::*; +use jsonrpc_http_server::{AccessControlAllowOrigin, DomainsValidation, ServerBuilder}; fn main() { let mut io = IoHandler::default(); io.add_method("say_hello", |_params| { - futures::finished(Value::String("hello".to_owned())) + futures::future::ready(Ok(Value::String("hello".to_owned()))) }); let server = ServerBuilder::new(io) @@ -16,4 +14,3 @@ fn main() { server.wait(); } - diff --git a/http/examples/http_meta.rs b/http/examples/http_meta.rs index 5338f1da0..50d412148 100644 --- a/http/examples/http_meta.rs +++ b/http/examples/http_meta.rs @@ -1,8 +1,5 @@ -extern crate jsonrpc_http_server; -extern crate unicase; - -use jsonrpc_http_server::{ServerBuilder, hyper, RestApi, cors::AccessControlAllowHeaders}; use jsonrpc_http_server::jsonrpc_core::*; +use jsonrpc_http_server::{cors::AccessControlAllowHeaders, hyper, RestApi, ServerBuilder}; #[derive(Default, Clone)] struct Meta { @@ -16,23 +13,23 @@ fn main() { io.add_method_with_meta("say_hello", |_params: Params, meta: Meta| { let auth = meta.auth.unwrap_or_else(String::new); - if auth.as_str() == "let-me-in" { + futures::future::ready(if auth.as_str() == "let-me-in" { Ok(Value::String("Hello World!".to_owned())) } else { - Ok(Value::String("Please send a valid Bearer token in Authorization header.".to_owned())) - } + Ok(Value::String( + "Please send a valid Bearer token in Authorization header.".to_owned(), + )) + }) }); let server = ServerBuilder::new(io) - .cors_allow_headers(AccessControlAllowHeaders::Only( - vec![ - "Authorization".to_owned(), - ]) - ) + .cors_allow_headers(AccessControlAllowHeaders::Only(vec!["Authorization".to_owned()])) .rest_api(RestApi::Unsecure) // You can also implement `MetaExtractor` trait and pass a struct here. 
.meta_extractor(|req: &hyper::Request| { - let auth = req.headers().get(hyper::header::AUTHORIZATION) + let auth = req + .headers() + .get(hyper::header::AUTHORIZATION) .map(|h| h.to_str().unwrap_or("").to_owned()); Meta { auth } @@ -42,4 +39,3 @@ fn main() { server.wait(); } - diff --git a/http/examples/http_middleware.rs b/http/examples/http_middleware.rs index 6caea36a7..68629d78c 100644 --- a/http/examples/http_middleware.rs +++ b/http/examples/http_middleware.rs @@ -1,15 +1,11 @@ -extern crate jsonrpc_http_server; - -use jsonrpc_http_server::{ - hyper, ServerBuilder, DomainsValidation, AccessControlAllowOrigin, Response, RestApi -}; -use jsonrpc_http_server::jsonrpc_core::{IoHandler, Value}; use jsonrpc_http_server::jsonrpc_core::futures; +use jsonrpc_http_server::jsonrpc_core::{IoHandler, Value}; +use jsonrpc_http_server::{hyper, AccessControlAllowOrigin, DomainsValidation, Response, RestApi, ServerBuilder}; fn main() { let mut io = IoHandler::default(); io.add_method("say_hello", |_params| { - futures::finished(Value::String("hello".to_owned())) + futures::future::ready(Ok(Value::String("hello".to_owned()))) }); let server = ServerBuilder::new(io) @@ -27,4 +23,3 @@ fn main() { server.wait(); } - diff --git a/http/examples/server.rs b/http/examples/server.rs index 11c7407b9..a8bdfce5d 100644 --- a/http/examples/server.rs +++ b/http/examples/server.rs @@ -1,13 +1,11 @@ -extern crate jsonrpc_http_server; - -use jsonrpc_http_server::{ServerBuilder, DomainsValidation, AccessControlAllowOrigin, RestApi}; use jsonrpc_http_server::jsonrpc_core::*; +use jsonrpc_http_server::{AccessControlAllowOrigin, DomainsValidation, RestApi, ServerBuilder}; fn main() { + env_logger::init(); + let mut io = IoHandler::default(); - io.add_method("say_hello", |_params: Params| { - Ok(Value::String("hello".to_string())) - }); + io.add_sync_method("say_hello", |_params: Params| Ok(Value::String("hello".to_string()))); let server = ServerBuilder::new(io) .threads(3) @@ -18,4 +16,3 @@ fn main() { server.wait(); } - diff --git a/http/src/handler.rs b/http/src/handler.rs index 63e6631cb..0466ee844 100644 --- a/http/src/handler.rs +++ b/http/src/handler.rs @@ -1,27 +1,29 @@ -use Rpc; +use crate::WeakRpc; -use std::{fmt, mem, str}; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; +use std::task::{self, Poll}; +use std::{fmt, mem, str}; -use hyper::{self, service::Service, Body, Method}; use hyper::header::{self, HeaderMap, HeaderValue}; +use hyper::{self, service::Service, Body, Method}; -use jsonrpc::{self as core, middleware, FutureResult, Metadata, Middleware}; -use jsonrpc::futures::{Future, Poll, Async, Stream, future}; -use jsonrpc::serde_json; -use response::Response; -use server_utils::cors; +use crate::jsonrpc::serde_json; +use crate::jsonrpc::{self as core, middleware, Metadata, Middleware}; +use crate::response::Response; +use crate::server_utils::cors; -use {utils, RequestMiddleware, RequestMiddlewareAction, CorsDomains, AllowedHosts, RestApi}; +use crate::{utils, AllowedHosts, CorsDomains, RequestMiddleware, RequestMiddlewareAction, RestApi}; /// jsonrpc http request handler. 
pub struct ServerHandler = middleware::Noop> { - jsonrpc_handler: Rpc, + jsonrpc_handler: WeakRpc, allowed_hosts: AllowedHosts, cors_domains: CorsDomains, cors_max_age: Option, cors_allowed_headers: cors::AccessControlAllowHeaders, - middleware: Arc, + middleware: Arc, rest_api: RestApi, health_api: Option<(String, String)>, max_request_body_size: usize, @@ -31,12 +33,12 @@ pub struct ServerHandler = middleware::Noop> impl> ServerHandler { /// Create new request handler. pub fn new( - jsonrpc_handler: Rpc, + jsonrpc_handler: WeakRpc, cors_domains: CorsDomains, cors_max_age: Option, cors_allowed_headers: cors::AccessControlAllowHeaders, allowed_hosts: AllowedHosts, - middleware: Arc, + middleware: Arc, rest_api: RestApi, health_api: Option<(String, String)>, max_request_body_size: usize, @@ -57,28 +59,38 @@ impl> ServerHandler { } } -impl> Service for ServerHandler { - type ReqBody = Body; - type ResBody = Body; +impl> Service> for ServerHandler +where + S::Future: Unpin, + S::CallFuture: Unpin, + M: Unpin, +{ + type Response = hyper::Response; type Error = hyper::Error; type Future = Handler; - fn call(&mut self, request: hyper::Request) -> Self::Future { + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> task::Poll> { + task::Poll::Ready(Ok(())) + } + + fn call(&mut self, request: hyper::Request) -> Self::Future { let is_host_allowed = utils::is_host_allowed(&request, &self.allowed_hosts); let action = self.middleware.on_request(request); let (should_validate_hosts, should_continue_on_invalid_cors, response) = match action { - RequestMiddlewareAction::Proceed { should_continue_on_invalid_cors, request }=> ( - true, should_continue_on_invalid_cors, Err(request) - ), - RequestMiddlewareAction::Respond { should_validate_hosts, response } => ( - should_validate_hosts, false, Ok(response) - ), + RequestMiddlewareAction::Proceed { + should_continue_on_invalid_cors, + request, + } => (true, should_continue_on_invalid_cors, Err(request)), + RequestMiddlewareAction::Respond { + should_validate_hosts, + response, + } => (should_validate_hosts, false, Ok(response)), }; // Validate host if should_validate_hosts && !is_host_allowed { - return Handler::Error(Some(Response::host_not_allowed())); + return Handler::Err(Some(Response::host_not_allowed())); } // Replace response with the one returned by middleware. @@ -111,38 +123,37 @@ impl> Service for ServerHandler { pub enum Handler> { Rpc(RpcHandler), - Error(Option), - Middleware(Box, Error = hyper::Error> + Send>), + Err(Option), + Middleware(Pin>> + Send>>), } -impl> Future for Handler { - type Item = hyper::Response; - type Error = hyper::Error; - - fn poll(&mut self) -> Poll { - match *self { - Handler::Rpc(ref mut handler) => handler.poll(), - Handler::Middleware(ref mut middleware) => middleware.poll(), - Handler::Error(ref mut response) => Ok(Async::Ready( - response.take().expect("Response always Some initialy. Returning `Ready` so will never be polled again; qed").into() - )), +impl> Future for Handler +where + S::Future: Unpin, + S::CallFuture: Unpin, + M: Unpin, +{ + type Output = hyper::Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + match Pin::into_inner(self) { + Handler::Rpc(ref mut handler) => Pin::new(handler).poll(cx), + Handler::Middleware(ref mut middleware) => Pin::new(middleware).poll(cx), + Handler::Err(ref mut response) => Poll::Ready(Ok(response + .take() + .expect("Response always Some initialy. 
Returning `Ready` so will never be polled again; qed") + .into())), } } } -enum RpcPollState where - F: Future, Error = ()>, - G: Future, Error = ()>, -{ - Ready(RpcHandlerState), - NotReady(RpcHandlerState), +enum RpcPollState { + Ready(RpcHandlerState), + NotReady(RpcHandlerState), } -impl RpcPollState where - F: Future, Error = ()>, - G: Future, Error = ()>, -{ - fn decompose(self) -> (RpcHandlerState, bool) { +impl RpcPollState { + fn decompose(self) -> (RpcHandlerState, bool) { use self::RpcPollState::*; match self { Ready(handler) => (handler, true), @@ -151,15 +162,7 @@ impl RpcPollState where } } -type FutureResponse = future::Map< - future::Either, ()>, core::FutureRpcResult>, - fn(Option) -> Response, ->; - -enum RpcHandlerState where - F: Future, Error = ()>, - G: Future, Error = ()>, -{ +enum RpcHandlerState { ReadingHeaders { request: hyper::Request, cors_domains: CorsDomains, @@ -182,23 +185,20 @@ enum RpcHandlerState where metadata: M, }, Writing(Response), - Waiting(FutureResult), - WaitingForResponse(FutureResponse), + Waiting(Pin> + Send>>), + WaitingForResponse(Pin + Send>>), Done, } -impl fmt::Debug for RpcHandlerState where - F: Future, Error = ()>, - G: Future, Error = ()>, -{ +impl fmt::Debug for RpcHandlerState { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { use self::RpcHandlerState::*; match *self { - ReadingHeaders {..} => write!(fmt, "ReadingHeaders"), - ReadingBody {..} => write!(fmt, "ReadingBody"), - ProcessRest {..} => write!(fmt, "ProcessRest"), - ProcessHealth {..} => write!(fmt, "ProcessHealth"), + ReadingHeaders { .. } => write!(fmt, "ReadingHeaders"), + ReadingBody { .. } => write!(fmt, "ReadingBody"), + ProcessRest { .. } => write!(fmt, "ProcessRest"), + ProcessHealth { .. } => write!(fmt, "ProcessHealth"), Writing(ref res) => write!(fmt, "Writing({:?})", res), WaitingForResponse(_) => write!(fmt, "WaitingForResponse"), Waiting(_) => write!(fmt, "Waiting"), @@ -208,8 +208,8 @@ impl fmt::Debug for RpcHandlerState where } pub struct RpcHandler> { - jsonrpc_handler: Rpc, - state: RpcHandlerState, + jsonrpc_handler: WeakRpc, + state: RpcHandlerState, is_options: bool, cors_allow_origin: cors::AllowCors, cors_allow_headers: cors::AllowCors>, @@ -220,69 +220,70 @@ pub struct RpcHandler> { keep_alive: bool, } -impl> Future for RpcHandler { - type Item = hyper::Response; - type Error = hyper::Error; +impl> Future for RpcHandler +where + S::Future: Unpin, + S::CallFuture: Unpin, + M: Unpin, +{ + type Output = hyper::Result>; - fn poll(&mut self) -> Poll { - let new_state = match mem::replace(&mut self.state, RpcHandlerState::Done) { + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let this = Pin::into_inner(self); + + let new_state = match mem::replace(&mut this.state, RpcHandlerState::Done) { RpcHandlerState::ReadingHeaders { - request, cors_domains, cors_headers, continue_on_invalid_cors, keep_alive, + request, + cors_domains, + cors_headers, + continue_on_invalid_cors, + keep_alive, } => { // Read cors header - self.cors_allow_origin = utils::cors_allow_origin(&request, &cors_domains); - self.cors_allow_headers = utils::cors_allow_headers(&request, &cors_headers); - self.keep_alive = utils::keep_alive(&request, keep_alive); - self.is_options = *request.method() == Method::OPTIONS; + this.cors_allow_origin = utils::cors_allow_origin(&request, &cors_domains); + this.cors_allow_headers = utils::cors_allow_headers(&request, &cors_headers); + this.keep_alive = utils::keep_alive(&request, keep_alive); + this.is_options = 
*request.method() == Method::OPTIONS; // Read other headers - RpcPollState::Ready(self.read_headers(request, continue_on_invalid_cors)) - }, - RpcHandlerState::ReadingBody { body, request, metadata, uri, } => { - match self.process_body(body, request, uri, metadata) { - Err(BodyError::Utf8(ref e)) => { - let mesg = format!("utf-8 encoding error at byte {} in request body", e.valid_up_to()); - let resp = Response::bad_request(mesg); - RpcPollState::Ready(RpcHandlerState::Writing(resp)) - } - Err(BodyError::TooLarge) => { - let resp = Response::too_large("request body size exceeds allowed maximum"); - RpcPollState::Ready(RpcHandlerState::Writing(resp)) - } - Err(BodyError::Hyper(e)) => return Err(e), - Ok(state) => state, + RpcPollState::Ready(this.read_headers(request, continue_on_invalid_cors)) + } + RpcHandlerState::ReadingBody { + body, + request, + metadata, + uri, + } => match this.process_body(body, request, uri, metadata, cx) { + Err(BodyError::Utf8(ref e)) => { + let mesg = format!("utf-8 encoding error at byte {} in request body", e.valid_up_to()); + let resp = Response::bad_request(mesg); + RpcPollState::Ready(RpcHandlerState::Writing(resp)) } - }, - RpcHandlerState::ProcessRest { uri, metadata } => { - self.process_rest(uri, metadata)? - }, - RpcHandlerState::ProcessHealth { method, metadata } => { - self.process_health(method, metadata)? - }, - RpcHandlerState::WaitingForResponse(mut waiting) => { - match waiting.poll() { - Ok(Async::Ready(response)) => RpcPollState::Ready(RpcHandlerState::Writing(response.into())), - Ok(Async::NotReady) => RpcPollState::NotReady(RpcHandlerState::WaitingForResponse(waiting)), - Err(e) => RpcPollState::Ready(RpcHandlerState::Writing( - Response::internal_error(format!("{:?}", e)) - )), + Err(BodyError::TooLarge) => { + let resp = Response::too_large("request body size exceeds allowed maximum"); + RpcPollState::Ready(RpcHandlerState::Writing(resp)) } + Err(BodyError::Hyper(e)) => return Poll::Ready(Err(e)), + Ok(state) => state, + }, + RpcHandlerState::ProcessRest { uri, metadata } => this.process_rest(uri, metadata)?, + RpcHandlerState::ProcessHealth { method, metadata } => this.process_health(method, metadata)?, + RpcHandlerState::WaitingForResponse(mut waiting) => match Pin::new(&mut waiting).poll(cx) { + Poll::Ready(response) => RpcPollState::Ready(RpcHandlerState::Writing(response)), + Poll::Pending => RpcPollState::NotReady(RpcHandlerState::WaitingForResponse(waiting)), }, RpcHandlerState::Waiting(mut waiting) => { - match waiting.poll() { - Ok(Async::Ready(response)) => { + match Pin::new(&mut waiting).poll(cx) { + Poll::Ready(response) => { RpcPollState::Ready(RpcHandlerState::Writing(match response { // Notification, just return empty response. 
None => Response::ok(String::new()), // Add new line to have nice output when using CLI clients (curl) Some(result) => Response::ok(format!("{}\n", result)), - }.into())) - }, - Ok(Async::NotReady) => RpcPollState::NotReady(RpcHandlerState::Waiting(waiting)), - Err(e) => RpcPollState::Ready(RpcHandlerState::Writing( - Response::internal_error(format!("{:?}", e)) - )), + })) + } + Poll::Pending => RpcPollState::NotReady(RpcHandlerState::Waiting(waiting)), } - }, + } state => RpcPollState::NotReady(state), }; @@ -290,27 +291,27 @@ impl> Future for RpcHandler { match new_state { RpcHandlerState::Writing(res) => { let mut response: hyper::Response = res.into(); - let cors_allow_origin = mem::replace(&mut self.cors_allow_origin, cors::AllowCors::Invalid); - let cors_allow_headers = mem::replace(&mut self.cors_allow_headers, cors::AllowCors::Invalid); + let cors_allow_origin = mem::replace(&mut this.cors_allow_origin, cors::AllowCors::Invalid); + let cors_allow_headers = mem::replace(&mut this.cors_allow_headers, cors::AllowCors::Invalid); Self::set_response_headers( response.headers_mut(), - self.is_options, - self.cors_max_age, + this.is_options, + this.cors_max_age, cors_allow_origin.into(), cors_allow_headers.into(), - self.keep_alive, + this.keep_alive, ); - Ok(Async::Ready(response)) - }, + Poll::Ready(Ok(response)) + } state => { - self.state = state; + this.state = state; if is_ready { - self.poll() + Pin::new(this).poll(cx) } else { - Ok(Async::NotReady) + Poll::Pending } - }, + } } } } @@ -329,12 +330,12 @@ impl From for BodyError { } } -impl> RpcHandler { - fn read_headers( - &self, - request: hyper::Request, - continue_on_invalid_cors: bool, - ) -> RpcHandlerState { +impl> RpcHandler +where + S::Future: Unpin, + S::CallFuture: Unpin, +{ + fn read_headers(&self, request: hyper::Request, continue_on_invalid_cors: bool) -> RpcHandlerState { if self.cors_allow_origin == cors::AllowCors::Invalid && !continue_on_invalid_cors { return RpcHandlerState::Writing(Response::invalid_allow_origin()); } @@ -344,57 +345,57 @@ impl> RpcHandler { } // Read metadata - let metadata = self.jsonrpc_handler.extractor.read_metadata(&request); + let handler = match self.jsonrpc_handler.upgrade() { + Some(handler) => handler, + None => return RpcHandlerState::Writing(Response::closing()), + }; + let metadata = handler.extractor.read_metadata(&request); // Proceed match *request.method() { // Validate the ContentType header // to prevent Cross-Origin XHRs with text/plain Method::POST if Self::is_json(request.headers().get("content-type")) => { - let uri = if self.rest_api != RestApi::Disabled { Some(request.uri().clone()) } else { None }; + let uri = if self.rest_api != RestApi::Disabled { + Some(request.uri().clone()) + } else { + None + }; RpcHandlerState::ReadingBody { metadata, request: Default::default(), uri, body: request.into_body(), } - }, + } Method::POST if self.rest_api == RestApi::Unsecure && request.uri().path().split('/').count() > 2 => { RpcHandlerState::ProcessRest { metadata, uri: request.uri().clone(), } - }, + } // Just return error for unsupported content type - Method::POST => { - RpcHandlerState::Writing(Response::unsupported_content_type()) - }, + Method::POST => RpcHandlerState::Writing(Response::unsupported_content_type()), // Don't validate content type on options - Method::OPTIONS => { - RpcHandlerState::Writing(Response::empty()) - }, + Method::OPTIONS => RpcHandlerState::Writing(Response::empty()), // Respond to health API request if there is one configured. 
Method::GET if self.health_api.as_ref().map(|x| &*x.0) == Some(request.uri().path()) => { RpcHandlerState::ProcessHealth { metadata, - method: self.health_api.as_ref() - .map(|x| x.1.clone()) - .expect("Health api is defined since the URI matched."), + method: self + .health_api + .as_ref() + .map(|x| x.1.clone()) + .expect("Health api is defined since the URI matched."), } - }, + } // Disallow other methods. - _ => { - RpcHandlerState::Writing(Response::method_not_allowed()) - }, + _ => RpcHandlerState::Writing(Response::method_not_allowed()), } } - fn process_health( - &self, - method: String, - metadata: M, - ) -> Result, hyper::Error> { - use self::core::types::{Call, MethodCall, Version, Params, Request, Id, Output, Success, Failure}; + fn process_health(&self, method: String, metadata: M) -> Result, hyper::Error> { + use self::core::types::{Call, Failure, Id, MethodCall, Output, Params, Request, Success, Version}; // Create a request let call = Request::Single(Call::MethodCall(MethodCall { @@ -404,32 +405,32 @@ impl> RpcHandler { id: Id::Num(1), })); - return Ok(RpcPollState::Ready(RpcHandlerState::WaitingForResponse( - future::Either::B(self.jsonrpc_handler.handler.handle_rpc_request(call, metadata)) - .map(|res| match res { + let response = match self.jsonrpc_handler.upgrade() { + Some(h) => h.handler.handle_rpc_request(call, metadata), + None => return Ok(RpcPollState::Ready(RpcHandlerState::Writing(Response::closing()))), + }; + + Ok(RpcPollState::Ready(RpcHandlerState::WaitingForResponse(Box::pin( + async { + match response.await { Some(core::Response::Single(Output::Success(Success { result, .. }))) => { - let result = serde_json::to_string(&result) - .expect("Serialization of result is infallible;qed"); + let result = serde_json::to_string(&result).expect("Serialization of result is infallible;qed"); Response::ok(result) - }, + } Some(core::Response::Single(Output::Failure(Failure { error, .. 
}))) => { - let result = serde_json::to_string(&error) - .expect("Serialization of error is infallible;qed"); + let result = serde_json::to_string(&error).expect("Serialization of error is infallible;qed"); Response::service_unavailable(result) - }, + } e => Response::internal_error(format!("Invalid response for health request: {:?}", e)), - }) - ))); + } + }, + )))) } - fn process_rest( - &self, - uri: hyper::Uri, - metadata: M, - ) -> Result, hyper::Error> { - use self::core::types::{Call, MethodCall, Version, Params, Request, Id, Value}; + fn process_rest(&self, uri: hyper::Uri, metadata: M) -> Result, hyper::Error> { + use self::core::types::{Call, Id, MethodCall, Params, Request, Value, Version}; // skip the initial / let mut it = uri.path().split('/').skip(1); @@ -452,12 +453,16 @@ impl> RpcHandler { id: Id::Num(1), })); - return Ok(RpcPollState::Ready(RpcHandlerState::Waiting( - future::Either::B(self.jsonrpc_handler.handler.handle_rpc_request(call, metadata)) - .map(|res| res.map(|x| serde_json::to_string(&x) - .expect("Serialization of response is infallible;qed") - )) - ))); + let response = match self.jsonrpc_handler.upgrade() { + Some(h) => h.handler.handle_rpc_request(call, metadata), + None => return Ok(RpcPollState::Ready(RpcHandlerState::Writing(Response::closing()))), + }; + + Ok(RpcPollState::Ready(RpcHandlerState::Waiting(Box::pin(async { + response + .await + .map(|x| serde_json::to_string(&x).expect("Serialization of response is infallible;qed")) + })))) } fn process_body( @@ -466,21 +471,27 @@ impl> RpcHandler { mut request: Vec, uri: Option, metadata: M, - ) -> Result, BodyError> { + cx: &mut task::Context<'_>, + ) -> Result, BodyError> { + use futures::Stream; + loop { - match body.poll()? { - Async::Ready(Some(chunk)) => { - if request.len().checked_add(chunk.len()).map(|n| n > self.max_request_body_size).unwrap_or(true) { - return Err(BodyError::TooLarge) + let pinned_body = Pin::new(&mut body); + match pinned_body.poll_next(cx)? { + Poll::Ready(Some(chunk)) => { + if request + .len() + .checked_add(chunk.len()) + .map(|n| n > self.max_request_body_size) + .unwrap_or(true) + { + return Err(BodyError::TooLarge); } request.extend_from_slice(&*chunk) - }, - Async::Ready(None) => { + } + Poll::Ready(None) => { if let (Some(uri), true) = (uri, request.is_empty()) { - return Ok(RpcPollState::Ready(RpcHandlerState::ProcessRest { - uri, - metadata, - })); + return Ok(RpcPollState::Ready(RpcHandlerState::ProcessRest { uri, metadata })); } let content = match str::from_utf8(&request) { @@ -488,22 +499,25 @@ impl> RpcHandler { Err(err) => { // Return utf error. 
return Err(BodyError::Utf8(err)); - }, + } + }; + + let response = match self.jsonrpc_handler.upgrade() { + Some(h) => h.handler.handle_request(content, metadata), + None => return Ok(RpcPollState::Ready(RpcHandlerState::Writing(Response::closing()))), }; // Content is ready - return Ok(RpcPollState::Ready(RpcHandlerState::Waiting( - self.jsonrpc_handler.handler.handle_request(content, metadata) - ))); - }, - Async::NotReady => { + return Ok(RpcPollState::Ready(RpcHandlerState::Waiting(Box::pin(response)))); + } + Poll::Pending => { return Ok(RpcPollState::NotReady(RpcHandlerState::ReadingBody { body, request, metadata, uri, })); - }, + } } } } @@ -525,7 +539,8 @@ impl> RpcHandler { .cloned() .collect::>(); let max_len = if val.is_empty() { 0 } else { val.len() - 2 }; - HeaderValue::from_bytes(&val[..max_len]).expect("Concatenation of valid headers with `, ` is still valid; qed") + HeaderValue::from_bytes(&val[..max_len]) + .expect("Concatenation of valid headers with `, ` is still valid; qed") }; let allowed = concat(&[as_header(Method::OPTIONS), as_header(Method::POST)]); @@ -543,7 +558,7 @@ impl> RpcHandler { if let Some(cma) = cors_max_age { headers.append( header::ACCESS_CONTROL_MAX_AGE, - HeaderValue::from_str(&cma.to_string()).expect("`u32` will always parse; qed") + HeaderValue::from_str(&cma.to_string()).expect("`u32` will always parse; qed"), ); } @@ -563,9 +578,47 @@ impl> RpcHandler { /// message. fn is_json(content_type: Option<&header::HeaderValue>) -> bool { match content_type.and_then(|val| val.to_str().ok()) { - Some("application/json") => true, - Some("application/json; charset=utf-8") => true, + Some(ref content) + if content.eq_ignore_ascii_case("application/json") + || content.eq_ignore_ascii_case("application/json; charset=utf-8") + || content.eq_ignore_ascii_case("application/json;charset=utf-8") => + { + true + } _ => false, } } } + +#[cfg(test)] +mod test { + use super::{hyper, RpcHandler}; + use jsonrpc_core::middleware::Noop; + + #[test] + fn test_case_insensitive_content_type() { + let request = hyper::Request::builder() + .header("content-type", "Application/Json; charset=UTF-8") + .body(()) + .unwrap(); + + let request2 = hyper::Request::builder() + .header("content-type", "Application/Json;charset=UTF-8") + .body(()) + .unwrap(); + + assert_eq!( + request.headers().get("content-type").unwrap(), + &"Application/Json; charset=UTF-8" + ); + + assert_eq!( + RpcHandler::<(), Noop>::is_json(request.headers().get("content-type")), + true + ); + assert_eq!( + RpcHandler::<(), Noop>::is_json(request2.headers().get("content-type")), + true + ); + } +} diff --git a/http/src/lib.rs b/http/src/lib.rs index 8185afd6a..976c95791 100644 --- a/http/src/lib.rs +++ b/http/src/lib.rs @@ -1,62 +1,62 @@ //! jsonrpc http server. //! //! ```no_run -//! extern crate jsonrpc_core; -//! extern crate jsonrpc_http_server; -//! //! use jsonrpc_core::*; //! use jsonrpc_http_server::*; //! //! fn main() { -//! let mut io = IoHandler::new(); -//! io.add_method("say_hello", |_: Params| { -//! Ok(Value::String("hello".to_string())) -//! }); +//! let mut io = IoHandler::new(); +//! io.add_sync_method("say_hello", |_: Params| { +//! Ok(Value::String("hello".to_string())) +//! }); //! -//! let _server = ServerBuilder::new(io) -//! .start_http(&"127.0.0.1:3030".parse().unwrap()) -//! .expect("Unable to start RPC server"); +//! let _server = ServerBuilder::new(io) +//! .start_http(&"127.0.0.1:3030".parse().unwrap()) +//! .expect("Unable to start RPC server"); //! -//! _server.wait(); +//! 
_server.wait(); //! } //! ``` -#![warn(missing_docs)] +#![deny(missing_docs)] -extern crate unicase; -extern crate jsonrpc_server_utils as server_utils; -extern crate net2; +use jsonrpc_server_utils as server_utils; -pub extern crate jsonrpc_core; -pub extern crate hyper; +pub use hyper; +pub use jsonrpc_core; #[macro_use] extern crate log; mod handler; mod response; -mod utils; #[cfg(test)] mod tests; +mod utils; +use std::convert::Infallible; +use std::future::Future; use std::io; -use std::sync::{mpsc, Arc}; use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::{mpsc, Arc, Weak}; use std::thread; -use hyper::{server, Body}; +use parking_lot::Mutex; + +use crate::jsonrpc::MetaIoHandler; +use crate::server_utils::reactor::{Executor, UninitializedExecutor}; +use futures::{channel::oneshot, future}; +use hyper::Body; use jsonrpc_core as jsonrpc; -use jsonrpc::MetaIoHandler; -use jsonrpc::futures::{self, Future, Stream, future}; -use jsonrpc::futures::sync::oneshot; -use server_utils::reactor::{Executor, UninitializedExecutor}; - -pub use server_utils::hosts::{Host, DomainsValidation}; -pub use server_utils::cors::{self, AccessControlAllowOrigin, Origin, AllowCors}; -pub use server_utils::{tokio, SuspendableStream}; -pub use handler::ServerHandler; -pub use utils::{is_host_allowed, cors_allow_origin, cors_allow_headers}; -pub use response::Response; + +pub use crate::handler::ServerHandler; +pub use crate::response::Response; +pub use crate::server_utils::cors::{self, AccessControlAllowOrigin, AllowCors, Origin}; +pub use crate::server_utils::hosts::{DomainsValidation, Host}; +pub use crate::server_utils::reactor::TaskExecutor; +pub use crate::server_utils::{tokio, SuspendableStream}; +pub use crate::utils::{cors_allow_headers, cors_allow_origin, is_host_allowed}; /// Action undertaken by a middleware. pub enum RequestMiddlewareAction { @@ -73,15 +73,15 @@ pub enum RequestMiddlewareAction { /// Should standard hosts validation be performed? should_validate_hosts: bool, /// a future for server response - response: Box, Error=hyper::Error> + Send>, - } + response: Pin>> + Send>>, + }, } impl From for RequestMiddlewareAction { fn from(o: Response) -> Self { RequestMiddlewareAction::Respond { should_validate_hosts: true, - response: Box::new(futures::future::ok(o.into())), + response: Box::pin(async { Ok(o.into()) }), } } } @@ -90,7 +90,7 @@ impl From> for RequestMiddlewareAction { fn from(response: hyper::Response) -> Self { RequestMiddlewareAction::Respond { should_validate_hosts: true, - response: Box::new(futures::future::ok(response)), + response: Box::pin(async { Ok(response) }), } } } @@ -110,7 +110,8 @@ pub trait RequestMiddleware: Send + Sync + 'static { fn on_request(&self, request: hyper::Request) -> RequestMiddlewareAction; } -impl RequestMiddleware for F where +impl RequestMiddleware for F +where F: Fn(hyper::Request) -> RequestMiddlewareAction + Sync + Send + 'static, { fn on_request(&self, request: hyper::Request) -> RequestMiddlewareAction { @@ -135,7 +136,8 @@ pub trait MetaExtractor: Sync + Send + 'static { fn read_metadata(&self, _: &hyper::Request) -> M; } -impl MetaExtractor for F where +impl MetaExtractor for F +where M: jsonrpc::Metadata, F: Fn(&hyper::Request) -> M + Sync + Send + 'static, { @@ -157,7 +159,7 @@ pub struct Rpc = jsonrpc::m /// RPC Handler pub handler: Arc>, /// Metadata extractor - pub extractor: Arc>, + pub extractor: Arc>, } impl> Clone for Rpc { @@ -169,6 +171,46 @@ impl> Clone for Rpc { } } +impl> Rpc { + /// Downgrade the `Rpc` to `WeakRpc`. 
+ /// + /// Downgrades internal `Arc`s to `Weak` references. + pub fn downgrade(&self) -> WeakRpc { + WeakRpc { + handler: Arc::downgrade(&self.handler), + extractor: Arc::downgrade(&self.extractor), + } + } +} +/// A weak handle to the RPC server. +/// +/// Since request handling futures are spawned directly on the executor, +/// whenever the server is closed we want to make sure that existing +/// tasks are not blocking the server and are dropped as soon as the server stops. +pub struct WeakRpc = jsonrpc::middleware::Noop> { + handler: Weak>, + extractor: Weak>, +} + +impl> Clone for WeakRpc { + fn clone(&self) -> Self { + WeakRpc { + handler: self.handler.clone(), + extractor: self.extractor.clone(), + } + } +} + +impl> WeakRpc { + /// Upgrade the handle to a strong one (`Rpc`) if possible. + pub fn upgrade(&self) -> Option> { + let handler = self.handler.upgrade()?; + let extractor = self.extractor.upgrade()?; + + Some(Rpc { handler, extractor }) + } +} + type AllowedHosts = Option>; type CorsDomains = Option>; @@ -194,8 +236,8 @@ pub enum RestApi { pub struct ServerBuilder = jsonrpc::middleware::Noop> { handler: Arc>, executor: UninitializedExecutor, - meta_extractor: Arc>, - request_middleware: Arc, + meta_extractor: Arc>, + request_middleware: Arc, cors_domains: CorsDomains, cors_max_age: Option, allowed_headers: cors::AccessControlAllowHeaders, @@ -207,26 +249,38 @@ pub struct ServerBuilder = max_request_body_size: usize, } -impl> ServerBuilder { +impl> ServerBuilder +where + S::Future: Unpin, + S::CallFuture: Unpin, + M: Unpin, +{ /// Creates new `ServerBuilder` for given `IoHandler`. /// /// By default: /// 1. Server is not sending any CORS headers. /// 2. Server is validating `Host` header. - pub fn new(handler: T) -> Self where - T: Into> + pub fn new(handler: T) -> Self + where + T: Into>, { Self::with_meta_extractor(handler, NoopExtractor) } } -impl> ServerBuilder { +impl> ServerBuilder +where + S::Future: Unpin, + S::CallFuture: Unpin, + M: Unpin, +{ /// Creates new `ServerBuilder` for given `IoHandler`. /// /// By default: /// 1. Server is not sending any CORS headers. /// 2. Server is validating `Host` header. - pub fn with_meta_extractor(handler: T, extractor: E) -> Self where + pub fn with_meta_extractor(handler: T, extractor: E) -> Self + where T: Into>, E: MetaExtractor, { @@ -250,7 +304,7 @@ impl> ServerBuilder { /// Utilize existing event loop executor to poll RPC results. /// /// Applies only to 1 of the threads. Other threads will spawn their own Event Loops. - pub fn event_loop_executor(mut self, executor: tokio::runtime::TaskExecutor) -> Self { + pub fn event_loop_executor(mut self, executor: TaskExecutor) -> Self { self.executor = UninitializedExecutor::Shared(executor); self } @@ -271,7 +325,8 @@ impl> ServerBuilder { /// Error returned from the method will be converted to status `500` response. /// /// Expects a tuple with `(, )`. - pub fn health_api(mut self, health_api: T) -> Self where + pub fn health_api(mut self, health_api: T) -> Self + where T: Into>, A: Into, B: Into, @@ -284,7 +339,7 @@ impl> ServerBuilder { /// /// Default is true. pub fn keep_alive(mut self, val: bool) -> Self { - self.keep_alive = val; + self.keep_alive = val; self } @@ -292,6 +347,7 @@ impl> ServerBuilder { /// /// Panics when set to `0`. #[cfg(not(unix))] + #[allow(unused_mut)] pub fn threads(mut self, _threads: usize) -> Self { warn!("Multi-threaded server is not available on Windows. 
Falling back to single thread."); self @@ -300,6 +356,12 @@ impl> ServerBuilder { /// Sets number of threads of the server to run. /// /// Panics when set to `0`. + /// The first thread will use provided `Executor` instance + /// and all other threads will use `UninitializedExecutor` to spawn + /// a new runtime for futures. + /// So it's also possible to run a multi-threaded server by + /// passing the default `tokio::runtime` executor to this builder + /// and setting `threads` to 1. #[cfg(unix)] pub fn threads(mut self, threads: usize) -> Self { self.threads = threads; @@ -314,8 +376,7 @@ impl> ServerBuilder { /// Configure CORS `AccessControlMaxAge` header returned. /// - /// Passing `Some(millis)` informs the client that the CORS preflight request is not necessary - /// for at least `millis` ms. + /// Informs the client that the CORS preflight request is not necessary for `cors_max_age` seconds. /// Disabled by default. pub fn cors_max_age>>(mut self, cors_max_age: T) -> Self { self.cors_max_age = cors_max_age.into(); @@ -324,7 +385,7 @@ impl> ServerBuilder { /// Configure the CORS `AccessControlAllowHeaders` header which are allowed. pub fn cors_allow_headers(mut self, allowed_headers: cors::AccessControlAllowHeaders) -> Self { - self.allowed_headers = allowed_headers.into(); + self.allowed_headers = allowed_headers; self } @@ -376,10 +437,12 @@ impl> ServerBuilder { let (local_addr_tx, local_addr_rx) = mpsc::channel(); let (close, shutdown_signal) = oneshot::channel(); + let (done_tx, done_rx) = oneshot::channel(); let eloop = self.executor.init_with_name("http.worker0")?; let req_max_size = self.max_request_body_size; + // The first threads `Executor` is initialised differently from the others serve( - (shutdown_signal, local_addr_tx), + (shutdown_signal, local_addr_tx, done_tx), eloop.executor(), addr.to_owned(), cors_domains.clone(), @@ -394,61 +457,78 @@ impl> ServerBuilder { reuse_port, req_max_size, ); - let handles = (0..self.threads - 1).map(|i| { - let (local_addr_tx, local_addr_rx) = mpsc::channel(); - let (close, shutdown_signal) = oneshot::channel(); - let eloop = UninitializedExecutor::Unspawned.init_with_name(format!("http.worker{}", i + 1))?; - serve( - (shutdown_signal, local_addr_tx), - eloop.executor(), - addr.to_owned(), - cors_domains.clone(), - cors_max_age, - allowed_headers.clone(), - request_middleware.clone(), - allowed_hosts.clone(), - jsonrpc_handler.clone(), - rest_api, - health_api.clone(), - keep_alive, - reuse_port, - req_max_size, - ); - Ok((eloop, close, local_addr_rx)) - }).collect::>>()?; + let handles = (0..self.threads - 1) + .map(|i| { + let (local_addr_tx, local_addr_rx) = mpsc::channel(); + let (close, shutdown_signal) = oneshot::channel(); + let (done_tx, done_rx) = oneshot::channel(); + let eloop = UninitializedExecutor::Unspawned.init_with_name(format!("http.worker{}", i + 1))?; + serve( + (shutdown_signal, local_addr_tx, done_tx), + eloop.executor(), + addr.to_owned(), + cors_domains.clone(), + cors_max_age, + allowed_headers.clone(), + request_middleware.clone(), + allowed_hosts.clone(), + jsonrpc_handler.clone(), + rest_api, + health_api.clone(), + keep_alive, + reuse_port, + req_max_size, + ); + Ok((eloop, close, local_addr_rx, done_rx)) + }) + .collect::>>()?; // Wait for server initialization let local_addr = recv_address(local_addr_rx); // Wait for other threads as well. 
- let mut handles = handles.into_iter().map(|(eloop, close, local_addr_rx)| { - let _ = recv_address(local_addr_rx)?; - Ok((eloop, close)) - }).collect::)>>()?; - handles.push((eloop, close)); - let (executors, close) = handles.into_iter().unzip(); + let mut handles: Vec<(Executor, oneshot::Sender<()>, oneshot::Receiver<()>)> = handles + .into_iter() + .map(|(eloop, close, local_addr_rx, done_rx)| { + let _ = recv_address(local_addr_rx)?; + Ok((eloop, close, done_rx)) + }) + .collect::>>()?; + handles.push((eloop, close, done_rx)); + + let (executors, done_rxs) = handles + .into_iter() + .fold((vec![], vec![]), |mut acc, (eloop, closer, done_rx)| { + acc.0.push((eloop, closer)); + acc.1.push(done_rx); + acc + }); Ok(Server { address: local_addr?, - executor: Some(executors), - close: Some(close), + executors: Arc::new(Mutex::new(Some(executors))), + done: Some(done_rxs), }) } } fn recv_address(local_addr_rx: mpsc::Receiver>) -> io::Result { - local_addr_rx.recv().map_err(|_| { - io::Error::new(io::ErrorKind::Interrupted, "") - })? + local_addr_rx + .recv() + .map_err(|_| io::Error::new(io::ErrorKind::Interrupted, ""))? } fn serve>( - signals: (oneshot::Receiver<()>, mpsc::Sender>), - executor: tokio::runtime::TaskExecutor, + signals: ( + oneshot::Receiver<()>, + mpsc::Sender>, + oneshot::Sender<()>, + ), + executor: TaskExecutor, addr: SocketAddr, cors_domains: CorsDomains, cors_max_age: Option, allowed_headers: cors::AccessControlAllowHeaders, - request_middleware: Arc, + request_middleware: Arc, allowed_hosts: AllowedHosts, jsonrpc_handler: Rpc, rest_api: RestApi, @@ -456,11 +536,13 @@ fn serve>( keep_alive: bool, reuse_port: bool, max_request_body_size: usize, -) { - let (shutdown_signal, local_addr_tx) = signals; - executor.spawn(future::lazy(move || { - let handle = tokio::reactor::Handle::current(); - +) where + S::Future: Unpin, + S::CallFuture: Unpin, + M: Unpin, +{ + let (shutdown_signal, local_addr_tx, done_tx) = signals; + executor.spawn(async move { let bind = move || { let listener = match addr { SocketAddr::V4(_) => net2::TcpBuilder::new_v4()?, @@ -470,69 +552,94 @@ fn serve>( listener.reuse_address(true)?; listener.bind(&addr)?; let listener = listener.listen(1024)?; - let listener = tokio::net::TcpListener::from_std(listener, &handle)?; + let local_addr = listener.local_addr()?; + + // NOTE: Future-proof by explicitly setting the listener socket to + // non-blocking mode of operation (future Tokio/Hyper versions + // require for the callers to do that manually) + listener.set_nonblocking(true)?; + // HACK: See below. + #[cfg(windows)] + let raw_socket = std::os::windows::io::AsRawSocket::as_raw_socket(&listener); + #[cfg(not(windows))] + let raw_socket = (); + + let server_builder = + hyper::Server::from_tcp(listener).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; // Add current host to allowed headers. // NOTE: we need to use `l.local_addr()` instead of `addr` // it might be different! 
- let local_addr = listener.local_addr()?; - - Ok((listener, local_addr)) + Ok((server_builder, local_addr, raw_socket)) }; let bind_result = match bind() { - Ok((listener, local_addr)) => { + Ok((server_builder, local_addr, raw_socket)) => { // Send local address match local_addr_tx.send(Ok(local_addr)) { - Ok(_) => futures::future::ok((listener, local_addr)), + Ok(_) => Ok((server_builder, local_addr, raw_socket)), Err(_) => { - warn!("Thread {:?} unable to reach receiver, closing server", thread::current().name()); - futures::future::err(()) - }, + warn!( + "Thread {:?} unable to reach receiver, closing server", + thread::current().name() + ); + Err(()) + } } - }, + } Err(err) => { // Send error let _send_result = local_addr_tx.send(Err(err)); - futures::future::err(()) + Err(()) } }; - bind_result.and_then(move |(listener, local_addr)| { - let allowed_hosts = server_utils::hosts::update(allowed_hosts, &local_addr); - - let mut http = server::conn::Http::new(); - http.keep_alive(keep_alive); - let tcp_stream = SuspendableStream::new(listener.incoming()); - - tcp_stream - .for_each(move |socket| { - let service = ServerHandler::new( - jsonrpc_handler.clone(), - cors_domains.clone(), - cors_max_age, - allowed_headers.clone(), - allowed_hosts.clone(), - request_middleware.clone(), - rest_api, - health_api.clone(), - max_request_body_size, - keep_alive, - ); - tokio::spawn(http.serve_connection(socket, service) - .map_err(|e| error!("Error serving connection: {:?}", e))); - Ok(()) - }) - .map_err(|e| { - warn!("Incoming streams error, closing sever: {:?}", e); - }) - .select(shutdown_signal.map_err(|e| { - debug!("Shutdown signaller dropped, closing server: {:?}", e); - })) - .map(|_| ()) - .map_err(|_| ()) - }) - })); + let (server_builder, local_addr, _raw_socket) = bind_result?; + + let allowed_hosts = server_utils::hosts::update(allowed_hosts, &local_addr); + + let server_builder = server_builder + .http1_keepalive(keep_alive) + .tcp_nodelay(true) + // Explicitly attempt to recover from accept errors (e.g. too many + // files opened) instead of erroring out the entire server. + .tcp_sleep_on_accept_errors(true); + + let service_fn = hyper::service::make_service_fn(move |_addr_stream| { + let service = ServerHandler::new( + jsonrpc_handler.downgrade(), + cors_domains.clone(), + cors_max_age, + allowed_headers.clone(), + allowed_hosts.clone(), + request_middleware.clone(), + rest_api, + health_api.clone(), + max_request_body_size, + keep_alive, + ); + async { Ok::<_, Infallible>(service) } + }); + + let server = server_builder.serve(service_fn).with_graceful_shutdown(async { + if let Err(err) = shutdown_signal.await { + debug!("Shutdown signaller dropped, closing server: {:?}", err); + } + }); + + if let Err(err) = server.await { + error!("Error running HTTP server: {:?}", err); + } + + // FIXME: Work around TCP listener socket not being properly closed + // in mio v0.6. This runs the std::net::TcpListener's destructor, + // which closes the underlying OS socket. + // Remove this once we migrate to Tokio 1.0. 
+ #[cfg(windows)] + let _: std::net::TcpListener = unsafe { std::os::windows::io::FromRawSocket::from_raw_socket(_raw_socket) }; + + done_tx.send(()) + }); } #[cfg(unix)] @@ -540,7 +647,7 @@ fn configure_port(reuse: bool, tcp: &net2::TcpBuilder) -> io::Result<()> { use net2::unix::*; if reuse { - try!(tcp.reuse_port(true)); + tcp.reuse_port(true)?; } Ok(()) @@ -548,17 +655,36 @@ fn configure_port(reuse: bool, tcp: &net2::TcpBuilder) -> io::Result<()> { #[cfg(not(unix))] fn configure_port(_reuse: bool, _tcp: &net2::TcpBuilder) -> io::Result<()> { - Ok(()) + Ok(()) +} + +/// Handle used to close the server. Can be cloned and passed around to different threads and be used +/// to close a server that is `wait()`ing. + +#[derive(Clone)] +pub struct CloseHandle(Arc)>>>>); + +impl CloseHandle { + /// Shutdown a running server + pub fn close(self) { + if let Some(executors) = self.0.lock().take() { + for (executor, closer) in executors { + // First send shutdown signal so we can proceed with underlying select + let _ = closer.send(()); + executor.close(); + } + } + } } +type Executors = Arc)>>>>; /// jsonrpc http server instance pub struct Server { address: SocketAddr, - executor: Option>, - close: Option>>, + executors: Executors, + done: Option>>, } -const PROOF: &'static str = "Server is always Some until self is consumed."; impl Server { /// Returns address of this server pub fn address(&self) -> &SocketAddr { @@ -566,28 +692,33 @@ impl Server { } /// Closes the server. - pub fn close(mut self) { - for close in self.close.take().expect(PROOF) { - let _ = close.send(()); - } - - for executor in self.executor.take().expect(PROOF) { - executor.close(); - } + pub fn close(self) { + self.close_handle().close() } /// Will block, waiting for the server to finish. pub fn wait(mut self) { - for executor in self.executor.take().expect(PROOF) { - executor.wait(); + self.wait_internal(); + } + + /// Get a handle that allows us to close the server from a different thread and/or while the + /// server is `wait()`ing. + pub fn close_handle(&self) -> CloseHandle { + CloseHandle(self.executors.clone()) + } + + fn wait_internal(&mut self) { + if let Some(receivers) = self.done.take() { + // NOTE: Gracefully handle the case where we may wait on a *nested* + // local task pool (for now, wait on a dedicated, spawned thread) + let _ = std::thread::spawn(move || futures::executor::block_on(future::try_join_all(receivers))).join(); } } } impl Drop for Server { fn drop(&mut self) { - self.executor.take().map(|executors| { - for executor in executors { executor.close(); } - }); + self.close_handle().close(); + self.wait_internal(); } } diff --git a/http/src/response.rs b/http/src/response.rs index 50814b68a..a66fec9ca 100644 --- a/http/src/response.rs +++ b/http/src/response.rs @@ -1,6 +1,6 @@ //! Basic Request/Response structures used internally. 
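The rewritten `serve` above wires three signals through every worker: a `oneshot` close receiver that feeds hyper's `with_graceful_shutdown`, an `mpsc` sender that reports the bound local address back to `start`, and a `done` sender fired once the server future resolves, which `Server::wait` later joins on via `futures::executor::block_on(try_join_all(..))` on a helper thread. A stripped-down sketch of that wiring for a single worker; it assumes hyper 0.14 (with the `server`, `http1` and `tcp` features), tokio 1 and futures 0.3, and omits the listener, CORS and middleware setup:

```rust
use std::convert::Infallible;

use futures::channel::oneshot;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server};

#[tokio::main]
async fn main() {
    // Close signal (what a CloseHandle would fire) and a "worker finished" signal.
    let (close_tx, close_rx) = oneshot::channel::<()>();
    let (done_tx, done_rx) = oneshot::channel::<()>();

    // One service instance per incoming connection, as with make_service_fn above.
    let make_svc = make_service_fn(|_conn| async {
        Ok::<_, Infallible>(service_fn(|_req: Request<Body>| async {
            Ok::<_, Infallible>(Response::new(Body::from("{}")))
        }))
    });

    let server = Server::bind(&"127.0.0.1:0".parse().unwrap())
        .http1_keepalive(true)
        .serve(make_svc)
        .with_graceful_shutdown(async {
            // Either an explicit close() or dropping the sender ends the accept loop.
            let _ = close_rx.await;
        });

    tokio::spawn(async move {
        if let Err(err) = server.await {
            eprintln!("Error running HTTP server: {:?}", err);
        }
        // Tell whoever is wait()-ing that this worker has fully shut down.
        let _ = done_tx.send(());
    });

    // What close() followed by wait() boils down to for one worker:
    let _ = close_tx.send(());
    let _ = done_rx.await;
}
```

Collecting one `done` receiver per worker and joining them all is what lets `wait()` return only after every listener has actually stopped accepting connections.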
-pub use hyper::{self, Method, Body, StatusCode, header::HeaderValue}; +pub use hyper::{self, header::HeaderValue, Body, Method, StatusCode}; /// Simple server response structure #[derive(Debug)] @@ -42,7 +42,7 @@ impl Response { Response { code: StatusCode::SERVICE_UNAVAILABLE, content_type: HeaderValue::from_static("application/json; charset=utf-8"), - content: format!("{}", msg.into()), + content: msg.into(), } } @@ -96,7 +96,7 @@ impl Response { Response { code: StatusCode::BAD_REQUEST, content_type: plain_text(), - content: msg.into() + content: msg.into(), } } @@ -105,7 +105,16 @@ impl Response { Response { code: StatusCode::PAYLOAD_TOO_LARGE, content_type: plain_text(), - content: msg.into() + content: msg.into(), + } + } + + /// Create a 500 response when server is closing. + pub(crate) fn closing() -> Self { + Response { + code: StatusCode::SERVICE_UNAVAILABLE, + content_type: plain_text(), + content: "Server is closing.".into(), } } } diff --git a/http/src/tests.rs b/http/src/tests.rs index 3f063d22d..302642599 100644 --- a/http/src/tests.rs +++ b/http/src/tests.rs @@ -1,22 +1,27 @@ -extern crate jsonrpc_core; +use jsonrpc_core; -use std::str::Lines; -use std::net::TcpStream; +use self::jsonrpc_core::{Error, ErrorCode, IoHandler, Params, Value}; use std::io::{Read, Write}; -use self::jsonrpc_core::{IoHandler, Params, Value, Error, ErrorCode}; +use std::net::TcpStream; +use std::str::Lines; +use std::time::Duration; -use self::jsonrpc_core::futures::{self, Future}; +use self::jsonrpc_core::futures; use super::*; fn serve_hosts(hosts: Vec) -> Server { ServerBuilder::new(IoHandler::default()) - .cors(DomainsValidation::AllowOnly(vec![AccessControlAllowOrigin::Value("parity.io".into())])) + .cors(DomainsValidation::AllowOnly(vec![AccessControlAllowOrigin::Value( + "parity.io".into(), + )])) .allowed_hosts(DomainsValidation::AllowOnly(hosts)) .start_http(&"127.0.0.1:0".parse().unwrap()) .unwrap() } -fn id(t: T) -> T { t } +fn id(t: T) -> T { + t +} fn serve ServerBuilder>(alter: F) -> Server { let builder = ServerBuilder::new(io()) @@ -28,49 +33,40 @@ fn serve ServerBuilder>(alter: F) -> Server { .rest_api(RestApi::Secure) .health_api(("/health", "hello_async")); - alter(builder) - .start_http(&"127.0.0.1:0".parse().unwrap()) - .unwrap() + alter(builder).start_http(&"127.0.0.1:0".parse().unwrap()).unwrap() } fn serve_allow_headers(cors_allow_headers: cors::AccessControlAllowHeaders) -> Server { let mut io = IoHandler::default(); - io.add_method("hello", |params: Params| { - match params.parse::<(u64, )>() { - Ok((num, )) => Ok(Value::String(format!("world: {}", num))), - _ => Ok(Value::String("world".into())), - } + io.add_sync_method("hello", |params: Params| match params.parse::<(u64,)>() { + Ok((num,)) => Ok(Value::String(format!("world: {}", num))), + _ => Ok(Value::String("world".into())), }); ServerBuilder::new(io) - .cors( - DomainsValidation::AllowOnly(vec![ - AccessControlAllowOrigin::Value("parity.io".into()), - AccessControlAllowOrigin::Null, - ]) - ) + .cors(DomainsValidation::AllowOnly(vec![ + AccessControlAllowOrigin::Value("parity.io".into()), + AccessControlAllowOrigin::Null, + ])) .cors_allow_headers(cors_allow_headers) .start_http(&"127.0.0.1:0".parse().unwrap()) .unwrap() } fn io() -> IoHandler { - use std::{thread, time}; - let mut io = IoHandler::default(); - io.add_method("hello", |params: Params| { - match params.parse::<(u64, )>() { - Ok((num, )) => Ok(Value::String(format!("world: {}", num))), - _ => Ok(Value::String("world".into())), - } + 
io.add_sync_method("hello", |params: Params| match params.parse::<(u64,)>() { + Ok((num,)) => Ok(Value::String(format!("world: {}", num))), + _ => Ok(Value::String("world".into())), }); - io.add_method("fail", |_: Params| Err(Error::new(ErrorCode::ServerError(-34)))); + io.add_sync_method("fail", |_: Params| Err(Error::new(ErrorCode::ServerError(-34)))); io.add_method("hello_async", |_params: Params| { - futures::finished(Value::String("world".into())) + futures::future::ready(Ok(Value::String("world".into()))) }); io.add_method("hello_async2", |_params: Params| { - let (c, p) = futures::oneshot(); + use futures::TryFutureExt; + let (c, p) = futures::channel::oneshot::channel(); thread::spawn(move || { - thread::sleep(time::Duration::from_millis(10)); + thread::sleep(Duration::from_millis(10)); c.send(Value::String("world".into())).unwrap(); }); p.map_err(|_| Error::invalid_request()) @@ -94,7 +90,7 @@ fn read_block(lines: &mut Lines) -> String { Some(v) => { block.push_str(v); block.push_str("\n"); - }, + } } } block @@ -112,11 +108,7 @@ fn request(server: Server, request: &str) -> Response { let headers = read_block(&mut lines); let body = read_block(&mut lines); - Response { - status: status, - headers: headers, - body: body, - } + Response { status, headers, body } } #[test] @@ -125,19 +117,23 @@ fn should_return_method_not_allowed_for_get() { let server = serve(id); // when - let response = request(server, + let response = request( + server, "\ - GET / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Connection: close\r\n\ - \r\n\ - I shouldn't be read.\r\n\ - " + GET / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Connection: close\r\n\ + \r\n\ + I shouldn't be read.\r\n\ + ", ); // then assert_eq!(response.status, "HTTP/1.1 405 Method Not Allowed".to_owned()); - assert_eq!(response.body, "Used HTTP Method is not allowed. POST or OPTIONS is required\n".to_owned()); + assert_eq!( + response.body, + "Used HTTP Method is not allowed. POST or OPTIONS is required\n".to_owned() + ); } #[test] @@ -146,14 +142,15 @@ fn should_handle_health_endpoint() { let server = serve(id); // when - let response = request(server, + let response = request( + server, "\ - GET /health HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Connection: close\r\n\ - \r\n\ - I shouldn't be read.\r\n\ - " + GET /health HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Connection: close\r\n\ + \r\n\ + I shouldn't be read.\r\n\ + ", ); // then @@ -167,14 +164,15 @@ fn should_handle_health_endpoint_failure() { let server = serve(|builder| builder.health_api(("/api/health", "fail"))); // when - let response = request(server, + let response = request( + server, "\ - GET /api/health HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Connection: close\r\n\ - \r\n\ - I shouldn't be read.\r\n\ - " + GET /api/health HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Connection: close\r\n\ + \r\n\ + I shouldn't be read.\r\n\ + ", ); // then @@ -188,19 +186,23 @@ fn should_return_unsupported_media_type_if_not_json() { let server = serve(id); // when - let response = request(server, + let response = request( + server, "\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Connection: close\r\n\ - \r\n\ - {}\r\n\ - " + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Connection: close\r\n\ + \r\n\ + {}\r\n\ + ", ); // then assert_eq!(response.status, "HTTP/1.1 415 Unsupported Media Type".to_owned()); - assert_eq!(response.body, "Supplied content type is not allowed. 
Content-Type: application/json is required\n".to_owned()); + assert_eq!( + response.body, + "Supplied content type is not allowed. Content-Type: application/json is required\n".to_owned() + ); } #[test] @@ -210,16 +212,21 @@ fn should_return_error_for_malformed_request() { // when let req = r#"{"jsonrpc":"3.0","method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then @@ -234,16 +241,21 @@ fn should_return_error_for_malformed_request2() { // when let req = r#"{"jsonrpc":"2.0","metho1d":""}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then @@ -258,16 +270,21 @@ fn should_return_empty_response_for_notification() { // when let req = r#"{"jsonrpc":"2.0","method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then @@ -275,7 +292,6 @@ fn should_return_empty_response_for_notification() { assert_eq!(response.body, "".to_owned()); } - #[test] fn should_return_method_not_found() { // given @@ -283,16 +299,21 @@ fn should_return_method_not_found() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then @@ -307,23 +328,34 @@ fn should_add_cors_allow_origins() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: http://parity.io\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: http://parity.io\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); 
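The `io()` fixture above reflects the jsonrpc-core API split: plain handlers move to `add_sync_method` and return `Result` directly, while `add_method` now expects a `futures` 0.3 future, either an immediately ready value or a `oneshot` receiver mapped into an RPC error. A small self-contained sketch of those three shapes, assuming `jsonrpc-core` 18 and the `futures` crate:

```rust
use futures::TryFutureExt;
use jsonrpc_core::{Error, IoHandler, Params, Value};
use std::{thread, time::Duration};

fn handler() -> IoHandler {
    let mut io = IoHandler::default();

    // Synchronous method: return Result directly.
    io.add_sync_method("hello", |_params: Params| Ok(Value::String("world".into())));

    // Async method with an immediately ready result.
    io.add_method("hello_async", |_params: Params| {
        futures::future::ready(Ok(Value::String("world".into())))
    });

    // Async method backed by work on another thread, bridged with a oneshot.
    io.add_method("hello_async2", |_params: Params| {
        let (tx, rx) = futures::channel::oneshot::channel();
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(10));
            let _ = tx.send(Value::String("world".into()));
        });
        rx.map_err(|_| Error::invalid_request())
    });

    io
}

fn main() {
    let io = handler();
    let req = r#"{"jsonrpc":"2.0","id":1,"method":"hello"}"#;
    println!("{:?}", io.handle_request_sync(req));
}
```

The `hello_async2` method in the tests uses exactly this thread-plus-oneshot bridge, with a short sleep to force the genuinely asynchronous path.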
// then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); assert_eq!(response.body, method_not_found()); - assert!(response.headers.contains("access-control-allow-origin: http://parity.io"), "Headers missing in {}", response.headers); + assert!( + response + .headers + .contains("access-control-allow-origin: http://parity.io"), + "Headers missing in {}", + response.headers + ); } #[test] @@ -333,24 +365,39 @@ fn should_add_cors_max_age_headers() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: http://parity.io\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: http://parity.io\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); assert_eq!(response.body, method_not_found()); - assert!(response.headers.contains("access-control-allow-origin: http://parity.io"), "Headers missing in {}", response.headers); - assert!(response.headers.contains("access-control-max-age: 1000"), "Headers missing in {}", response.headers); + assert!( + response + .headers + .contains("access-control-allow-origin: http://parity.io"), + "Headers missing in {}", + response.headers + ); + assert!( + response.headers.contains("access-control-max-age: 1000"), + "Headers missing in {}", + response.headers + ); } #[test] @@ -360,17 +407,22 @@ fn should_not_add_cors_allow_origins() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: fake.io\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: fake.io\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then @@ -378,8 +430,6 @@ fn should_not_add_cors_allow_origins() { assert_eq!(response.body, cors_invalid_allow_origin()); } - - #[test] fn should_not_process_the_request_in_case_of_invalid_allow_origin() { // given @@ -387,17 +437,22 @@ fn should_not_process_the_request_in_case_of_invalid_allow_origin() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"hello"}"#; - let response = request(server, - &format!("\ - OPTIONS / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: fake.io\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + OPTIONS / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: fake.io\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then @@ -405,27 +460,35 @@ fn should_not_process_the_request_in_case_of_invalid_allow_origin() { assert_eq!(response.body, cors_invalid_allow_origin()); } - #[test] fn should_return_proper_headers_on_options() { // given let server = serve(id); // when - 
let response = request(server, + let response = request( + server, "\ - OPTIONS / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Connection: close\r\n\ - Content-Length: 0\r\n\ - \r\n\ - " + OPTIONS / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Connection: close\r\n\ + Content-Length: 0\r\n\ + \r\n\ + ", ); // then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); - assert!(response.headers.contains("allow: OPTIONS, POST"), "Headers missing in {}", response.headers); - assert!(response.headers.contains("accept: application/json"), "Headers missing in {}", response.headers); + assert!( + response.headers.contains("allow: OPTIONS, POST"), + "Headers missing in {}", + response.headers + ); + assert!( + response.headers.contains("accept: application/json"), + "Headers missing in {}", + response.headers + ); assert_eq!(response.body, ""); } @@ -436,23 +499,32 @@ fn should_add_cors_allow_origin_for_null_origin() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: null\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: null\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); assert_eq!(response.body, method_not_found()); - assert!(response.headers.contains("access-control-allow-origin: null"), "Headers missing in {}", response.headers); + assert!( + response.headers.contains("access-control-allow-origin: null"), + "Headers missing in {}", + response.headers + ); } #[test] @@ -462,23 +534,32 @@ fn should_add_cors_allow_origin_for_null_origin_when_all() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: null\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: null\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); assert_eq!(response.body, method_not_found()); - assert!(response.headers.contains("access-control-allow-origin: null"), "Headers missing in {}", response.headers); + assert!( + response.headers.contains("access-control-allow-origin: null"), + "Headers missing in {}", + response.headers + ); } #[test] @@ -488,16 +569,17 @@ fn should_not_allow_request_larger_than_max() { .start_http(&"127.0.0.1:0".parse().unwrap()) .unwrap(); - let response = request(server, + let response = request( + server, "\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Connection: close\r\n\ - Content-Length: 8\r\n\ - Content-Type: application/json\r\n\ - \r\n\ - 12345678\r\n\ - " + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Connection: close\r\n\ + Content-Length: 8\r\n\ + Content-Type: application/json\r\n\ + \r\n\ + 12345678\r\n\ + ", ); assert_eq!(response.status, "HTTP/1.1 413 Payload Too Large".to_owned()); } @@ 
-509,16 +591,21 @@ fn should_reject_invalid_hosts() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then @@ -533,15 +620,20 @@ fn should_reject_missing_host() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then @@ -556,16 +648,21 @@ fn should_allow_if_host_is_valid() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: parity.io\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: parity.io\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then @@ -576,31 +673,36 @@ fn should_allow_if_host_is_valid() { #[test] fn should_respond_configured_allowed_hosts_to_options() { // given - let allowed = vec![ - "X-Allowed".to_owned(), - "X-AlsoAllowed".to_owned(), - ]; + let allowed = vec!["X-Allowed".to_owned(), "X-AlsoAllowed".to_owned()]; let custom = cors::AccessControlAllowHeaders::Only(allowed.clone()); let server = serve_allow_headers(custom); // when - let response = request(server, - &format!("\ - OPTIONS / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: http://parity.io\r\n\ - Access-Control-Request-Headers: {}\r\n\ - Content-Length: 0\r\n\ - Content-Type: application/json\r\n\ - Connection: close\r\n\ - \r\n\ - ", &allowed.join(", ")) + let response = request( + server, + &format!( + "\ + OPTIONS / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: http://parity.io\r\n\ + Access-Control-Request-Headers: {}\r\n\ + Content-Length: 0\r\n\ + Content-Type: application/json\r\n\ + Connection: close\r\n\ + \r\n\ + ", + &allowed.join(", ") + ), ); // then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); let expected = format!("access-control-allow-headers: {}", &allowed.join(", ")); - assert!(response.headers.contains(&expected), "Headers missing in {}", response.headers); + assert!( + response.headers.contains(&expected), + "Headers missing in {}", + response.headers + ); } #[test] @@ -609,22 +711,28 @@ fn should_not_contain_default_cors_allow_headers() { let server = serve(id); // when - let response = request(server, - &format!("\ - OPTIONS / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: http://parity.io\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: 0\r\n\ - \r\n\ - ") + let response = request( + server, + &format!( + "\ + 
OPTIONS / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: http://parity.io\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: 0\r\n\ + \r\n\ + " + ), ); // then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); - assert!(!response.headers.contains("access-control-allow-headers:"), - "Header should not be in {}", response.headers); + assert!( + !response.headers.contains("access-control-allow-headers:"), + "Header should not be in {}", + response.headers + ); } #[test] @@ -633,106 +741,123 @@ fn should_respond_valid_to_default_allowed_headers() { let server = serve(id); // when - let response = request(server, - &format!("\ - OPTIONS / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: http://parity.io\r\n\ - Content-Length: 0\r\n\ - Content-Type: application/json\r\n\ - Connection: close\r\n\ - Access-Control-Request-Headers: Accept, Content-Type, Origin\r\n\ - \r\n\ - ") + let response = request( + server, + &format!( + "\ + OPTIONS / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: http://parity.io\r\n\ + Content-Length: 0\r\n\ + Content-Type: application/json\r\n\ + Connection: close\r\n\ + Access-Control-Request-Headers: Accept, Content-Type, Origin\r\n\ + \r\n\ + " + ), ); // then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); let expected = "access-control-allow-headers: Accept, Content-Type, Origin"; - assert!(response.headers.contains(expected), "Headers missing in {}", response.headers); + assert!( + response.headers.contains(expected), + "Headers missing in {}", + response.headers + ); } #[test] fn should_by_default_respond_valid_to_any_request_headers() { // given - let allowed = vec![ - "X-Abc".to_owned(), - "X-123".to_owned(), - ]; + let allowed = vec!["X-Abc".to_owned(), "X-123".to_owned()]; let custom = cors::AccessControlAllowHeaders::Only(allowed.clone()); let server = serve_allow_headers(custom); // when - let response = request(server, - &format!("\ - OPTIONS / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: http://parity.io\r\n\ - Content-Length: 0\r\n\ - Content-Type: application/json\r\n\ - Connection: close\r\n\ - Access-Control-Request-Headers: {}\r\n\ - \r\n\ - ", &allowed.join(", ")) + let response = request( + server, + &format!( + "\ + OPTIONS / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: http://parity.io\r\n\ + Content-Length: 0\r\n\ + Content-Type: application/json\r\n\ + Connection: close\r\n\ + Access-Control-Request-Headers: {}\r\n\ + \r\n\ + ", + &allowed.join(", ") + ), ); // then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); let expected = format!("access-control-allow-headers: {}", &allowed.join(", ")); - assert!(response.headers.contains(&expected), "Headers missing in {}", response.headers); + assert!( + response.headers.contains(&expected), + "Headers missing in {}", + response.headers + ); } #[test] fn should_respond_valid_to_configured_allow_headers() { // given - let allowed = vec![ - "X-Allowed".to_owned(), - "X-AlsoAllowed".to_owned(), - ]; + let allowed = vec!["X-Allowed".to_owned(), "X-AlsoAllowed".to_owned()]; let custom = cors::AccessControlAllowHeaders::Only(allowed.clone()); let server = serve_allow_headers(custom); // when - let response = request(server, - &format!("\ - OPTIONS / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: http://parity.io\r\n\ - Content-Length: 0\r\n\ - Content-Type: application/json\r\n\ - Connection: close\r\n\ - Access-Control-Request-Headers: {}\r\n\ - \r\n\ - ", &allowed.join(", ")) + let response = request( + server, + 
&format!( + "\ + OPTIONS / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: http://parity.io\r\n\ + Content-Length: 0\r\n\ + Content-Type: application/json\r\n\ + Connection: close\r\n\ + Access-Control-Request-Headers: {}\r\n\ + \r\n\ + ", + &allowed.join(", ") + ), ); // then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); let expected = format!("access-control-allow-headers: {}", &allowed.join(", ")); - assert!(response.headers.contains(&expected), "Headers missing in {}", response.headers); + assert!( + response.headers.contains(&expected), + "Headers missing in {}", + response.headers + ); } #[test] fn should_respond_invalid_if_non_allowed_header_used() { // given - let custom = cors::AccessControlAllowHeaders::Only( - vec![ - "X-Allowed".to_owned(), - ]); + let custom = cors::AccessControlAllowHeaders::Only(vec!["X-Allowed".to_owned()]); let server = serve_allow_headers(custom); // when - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: http://parity.io\r\n\ - Content-Length: 0\r\n\ - Content-Type: application/json\r\n\ - Connection: close\r\n\ - X-Not-Allowed: not allowed\r\n\ - \r\n\ - ") + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: http://parity.io\r\n\ + Content-Length: 0\r\n\ + Content-Type: application/json\r\n\ + Connection: close\r\n\ + X-Not-Allowed: not allowed\r\n\ + \r\n\ + " + ), ); // then @@ -743,26 +868,29 @@ fn should_respond_invalid_if_non_allowed_header_used() { #[test] fn should_respond_valid_if_allowed_header_used() { // given - let custom = cors::AccessControlAllowHeaders::Only( - vec![ - "X-Allowed".to_owned(), - ]); + let custom = cors::AccessControlAllowHeaders::Only(vec!["X-Allowed".to_owned()]); let server = serve_allow_headers(custom); let addr = server.address().clone(); // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"hello"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: localhost:{}\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - X-Allowed: Foobar\r\n\ - \r\n\ - {}\r\n\ - ", addr.port(), req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + X-Allowed: Foobar\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ), ); // then @@ -773,26 +901,29 @@ fn should_respond_valid_if_allowed_header_used() { #[test] fn should_respond_valid_if_case_insensitive_allowed_header_used() { // given - let custom = cors::AccessControlAllowHeaders::Only( - vec![ - "X-Allowed".to_owned(), - ]); + let custom = cors::AccessControlAllowHeaders::Only(vec!["X-Allowed".to_owned()]); let server = serve_allow_headers(custom); let addr = server.address().clone(); // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"hello"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: localhost:{}\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - X-AlLoWed: Foobar\r\n\ - \r\n\ - {}\r\n\ - ", addr.port(), req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + X-AlLoWed: Foobar\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + 
), ); // then @@ -803,32 +934,32 @@ fn should_respond_valid_if_case_insensitive_allowed_header_used() { #[test] fn should_respond_valid_on_case_mismatches_in_allowed_headers() { // given - let allowed = vec![ - "X-Allowed".to_owned(), - "X-AlsoAllowed".to_owned(), - ]; + let allowed = vec!["X-Allowed".to_owned(), "X-AlsoAllowed".to_owned()]; let custom = cors::AccessControlAllowHeaders::Only(allowed.clone()); let server = serve_allow_headers(custom); // when - let response = request(server, - &format!("\ - OPTIONS / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: http://parity.io\r\n\ - Content-Length: 0\r\n\ - Content-Type: application/json\r\n\ - Connection: close\r\n\ - Access-Control-Request-Headers: x-ALLoweD, x-alSOaLloWeD\r\n\ - \r\n\ - ") + let response = request( + server, + &format!( + "\ + OPTIONS / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: http://parity.io\r\n\ + Content-Length: 0\r\n\ + Content-Type: application/json\r\n\ + Connection: close\r\n\ + Access-Control-Request-Headers: x-ALLoweD, x-alSOaLloWeD\r\n\ + \r\n\ + " + ), ); // then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); - let contained = response.headers.contains( - "access-control-allow-headers: x-ALLoweD, x-alSOaLloWeD" - ); + let contained = response + .headers + .contains("access-control-allow-headers: x-ALLoweD, x-alSOaLloWeD"); assert!(contained, "Headers missing in {}", response.headers); } @@ -840,46 +971,54 @@ fn should_respond_valid_to_any_requested_header() { let headers = "Something, Anything, Xyz, 123, _?"; // when - let response = request(server, - &format!("\ - OPTIONS / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: http://parity.io\r\n\ - Content-Length: 0\r\n\ - Content-Type: application/json\r\n\ - Connection: close\r\n\ - Access-Control-Request-Headers: {}\r\n\ - \r\n\ - ", headers) + let response = request( + server, + &format!( + "\ + OPTIONS / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: http://parity.io\r\n\ + Content-Length: 0\r\n\ + Content-Type: application/json\r\n\ + Connection: close\r\n\ + Access-Control-Request-Headers: {}\r\n\ + \r\n\ + ", + headers + ), ); // then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); let expected = format!("access-control-allow-headers: {}", headers); - assert!(response.headers.contains(&expected), "Headers missing in {}", response.headers); + assert!( + response.headers.contains(&expected), + "Headers missing in {}", + response.headers + ); } #[test] fn should_forbid_invalid_request_headers() { // given - let custom = cors::AccessControlAllowHeaders::Only( - vec![ - "X-Allowed".to_owned(), - ]); + let custom = cors::AccessControlAllowHeaders::Only(vec!["X-Allowed".to_owned()]); let server = serve_allow_headers(custom); // when - let response = request(server, - &format!("\ - OPTIONS / HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: http://parity.io\r\n\ - Content-Length: 0\r\n\ - Content-Type: application/json\r\n\ - Connection: close\r\n\ - Access-Control-Request-Headers: *\r\n\ - \r\n\ - ") + let response = request( + server, + &format!( + "\ + OPTIONS / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: http://parity.io\r\n\ + Content-Length: 0\r\n\ + Content-Type: application/json\r\n\ + Connection: close\r\n\ + Access-Control-Request-Headers: *\r\n\ + \r\n\ + " + ), ); // then @@ -896,23 +1035,29 @@ fn should_respond_valid_to_wildcard_if_any_header_allowed() { let server = serve_allow_headers(cors::AccessControlAllowHeaders::Any); // when - let response = request(server, - &format!("\ - OPTIONS 
/ HTTP/1.1\r\n\ - Host: 127.0.0.1:8080\r\n\ - Origin: http://parity.io\r\n\ - Content-Length: 0\r\n\ - Content-Type: application/json\r\n\ - Connection: close\r\n\ - Access-Control-Request-Headers: *\r\n\ - \r\n\ - ") + let response = request( + server, + &format!( + "\ + OPTIONS / HTTP/1.1\r\n\ + Host: 127.0.0.1:8080\r\n\ + Origin: http://parity.io\r\n\ + Content-Length: 0\r\n\ + Content-Type: application/json\r\n\ + Connection: close\r\n\ + Access-Control-Request-Headers: *\r\n\ + \r\n\ + " + ), ); // then assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); - assert!(response.headers.contains("access-control-allow-headers: *"), - "Headers missing in {}", response.headers); + assert!( + response.headers.contains("access-control-allow-headers: *"), + "Headers missing in {}", + response.headers + ); } #[test] @@ -922,16 +1067,21 @@ fn should_allow_application_json_utf8() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: parity.io\r\n\ - Connection: close\r\n\ - Content-Type: application/json; charset=utf-8\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: parity.io\r\n\ + Connection: close\r\n\ + Content-Type: application/json; charset=utf-8\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + req.as_bytes().len(), + req + ), ); // then @@ -947,16 +1097,22 @@ fn should_always_allow_the_bind_address() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: {}\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", addr, req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: {}\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + addr, + req.as_bytes().len(), + req + ), ); // then @@ -972,16 +1128,22 @@ fn should_always_allow_the_bind_address_as_localhost() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: localhost:{}\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", addr.port(), req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ), ); // then @@ -997,16 +1159,22 @@ fn should_handle_sync_requests_correctly() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"hello"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: localhost:{}\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", addr.port(), req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ), ); // then @@ -1022,16 +1190,22 @@ fn should_handle_async_requests_with_immediate_response_correctly() { // when let req = 
r#"{"jsonrpc":"2.0","id":1,"method":"hello_async"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: localhost:{}\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", addr.port(), req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ), ); // then @@ -1047,16 +1221,22 @@ fn should_handle_async_requests_correctly() { // when let req = r#"{"jsonrpc":"2.0","id":1,"method":"hello_async2"}"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: localhost:{}\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", addr.port(), req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ), ); // then @@ -1072,16 +1252,22 @@ fn should_handle_sync_batch_requests_correctly() { // when let req = r#"[{"jsonrpc":"2.0","id":1,"method":"hello"}]"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: localhost:{}\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", addr.port(), req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ), ); // then @@ -1097,16 +1283,53 @@ fn should_handle_rest_request_with_params() { // when let req = ""; - let response = request(server, - &format!("\ - POST /hello/5 HTTP/1.1\r\n\ - Host: localhost:{}\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", addr.port(), req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST /hello/5 HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ), + ); + + // then + assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); + assert_eq!(response.body, world_5()); +} + +#[test] +fn should_handle_rest_request_with_case_insensitive_content_type() { + // given + let server = serve(id); + let addr = server.address().clone(); + + // when + let req = ""; + let response = request( + server, + &format!( + "\ + POST /hello/5 HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Connection: close\r\n\ + Content-Type: Application/JSON; charset=UTF-8\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ), ); // then @@ -1122,20 +1345,29 @@ fn should_return_error_in_case_of_unsecure_rest_and_no_method() { // when let req = ""; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: localhost:{}\r\n\ - Connection: close\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", addr.port(), req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ 
+ Connection: close\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ), ); // then assert_eq!(response.status, "HTTP/1.1 415 Unsupported Media Type".to_owned()); - assert_eq!(&response.body, "Supplied content type is not allowed. Content-Type: application/json is required\n"); + assert_eq!( + &response.body, + "Supplied content type is not allowed. Content-Type: application/json is required\n" + ); } #[test] @@ -1146,25 +1378,53 @@ fn should_return_connection_header() { // when let req = r#"[{"jsonrpc":"2.0","id":1,"method":"hello"}]"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: localhost:{}\r\n\ - Connection: close\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", addr.port(), req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Connection: close\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ), ); // then - assert!(response.headers.contains("connection: close"), - "Headers missing in {}", response.headers); + assert!( + response.headers.contains("connection: close"), + "Headers missing in {}", + response.headers + ); assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); assert_eq!(response.body, world_batch()); } +#[test] +fn close_handle_makes_wait_return() { + let server = serve(id); + let close_handle = server.close_handle(); + + let (tx, rx) = mpsc::channel(); + + thread::spawn(move || { + tx.send(server.wait()).unwrap(); + }); + + thread::sleep(Duration::from_secs(3)); + + close_handle.close(); + + rx.recv_timeout(Duration::from_secs(10)) + .expect("Expected server to close"); +} + #[test] fn should_close_connection_without_keep_alive() { // given @@ -1173,20 +1433,29 @@ fn should_close_connection_without_keep_alive() { // when let req = r#"[{"jsonrpc":"2.0","id":1,"method":"hello"}]"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: localhost:{}\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", addr.port(), req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ), ); // then - assert!(response.headers.contains("connection: close"), - "Header missing in {}", response.headers); + assert!( + response.headers.contains("connection: close"), + "Header missing in {}", + response.headers + ); assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); assert_eq!(response.body, world_batch()); } @@ -1199,26 +1468,104 @@ fn should_respond_with_close_even_if_client_wants_to_keep_alive() { // when let req = r#"[{"jsonrpc":"2.0","id":1,"method":"hello"}]"#; - let response = request(server, - &format!("\ - POST / HTTP/1.1\r\n\ - Host: localhost:{}\r\n\ - Connection: keep-alive\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}\r\n\ - ", addr.port(), req.as_bytes().len(), req) + let response = request( + server, + &format!( + "\ + POST / HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Connection: keep-alive\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ), ); // then - 
assert!(response.headers.contains("connection: close"), - "Headers missing in {}", response.headers); + assert!( + response.headers.contains("connection: close"), + "Headers missing in {}", + response.headers + ); assert_eq!(response.status, "HTTP/1.1 200 OK".to_owned()); assert_eq!(response.body, world_batch()); } +#[test] +fn should_drop_io_handler_when_server_is_closed() { + use std::sync::{Arc, Mutex}; + // given + let (weak, _req) = { + let my_ref = Arc::new(Mutex::new(5)); + let weak = Arc::downgrade(&my_ref); + let mut io = IoHandler::default(); + io.add_sync_method("hello", move |_| { + Ok(Value::String(format!("{}", my_ref.lock().unwrap()))) + }); + let server = ServerBuilder::new(io) + .start_http(&"127.0.0.1:0".parse().unwrap()) + .unwrap(); + + let addr = server.address().clone(); + + // when + let req = TcpStream::connect(addr).unwrap(); + server.close(); + (weak, req) + }; + + // then + for _ in 1..1000 { + if weak.upgrade().is_none() { + return; + } + std::thread::sleep(std::time::Duration::from_millis(10)); + } + + panic!("expected server to be closed and io handler to be dropped") +} + +#[test] +fn should_not_close_server_when_serving_errors() { + // given + let server = serve(|builder| builder.keep_alive(false)); + let addr = server.address().clone(); + // when + let req = "{}"; + let request = format!( + "\ + POST / HTTP/1.1\r\n\ + Host: localhost:{}\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + 😈: 😈\r\n\ + \r\n\ + {}\r\n\ + ", + addr.port(), + req.as_bytes().len(), + req + ); + + let mut req = TcpStream::connect(addr).unwrap(); + req.write_all(request.as_bytes()).unwrap(); + let mut response = String::new(); + req.read_to_string(&mut response).unwrap(); + assert!(!response.is_empty(), "Response should not be empty: {}", response); + + // then make a second request and it must not fail. 
+ let mut req = TcpStream::connect(addr).unwrap(); + req.write_all(request.as_bytes()).unwrap(); + let mut response = String::new(); + req.read_to_string(&mut response).unwrap(); + assert!(!response.is_empty(), "Response should not be empty: {}", response); +} fn invalid_host() -> String { "Provided Host header is not whitelisted.\n".into() @@ -1233,18 +1580,18 @@ fn cors_invalid_allow_headers() -> String { } fn method_not_found() -> String { - "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32601,\"message\":\"Method not found\"},\"id\":1}\n".into() + "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32601,\"message\":\"Method not found\"},\"id\":1}\n".into() } fn invalid_request() -> String { - "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32600,\"message\":\"Invalid request\"},\"id\":null}\n".into() + "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32600,\"message\":\"Invalid request\"},\"id\":null}\n".into() } fn world() -> String { - "{\"jsonrpc\":\"2.0\",\"result\":\"world\",\"id\":1}\n".into() + "{\"jsonrpc\":\"2.0\",\"result\":\"world\",\"id\":1}\n".into() } fn world_5() -> String { - "{\"jsonrpc\":\"2.0\",\"result\":\"world: 5\",\"id\":1}\n".into() + "{\"jsonrpc\":\"2.0\",\"result\":\"world: 5\",\"id\":1}\n".into() } fn world_batch() -> String { - "[{\"jsonrpc\":\"2.0\",\"result\":\"world\",\"id\":1}]\n".into() + "[{\"jsonrpc\":\"2.0\",\"result\":\"world\",\"id\":1}]\n".into() } diff --git a/http/src/utils.rs b/http/src/utils.rs index 0640a4076..6140b508b 100644 --- a/http/src/utils.rs +++ b/http/src/utils.rs @@ -1,6 +1,6 @@ use hyper::{self, header}; -use server_utils::{cors, hosts}; +use crate::server_utils::{cors, hosts}; /// Extracts string value of a single header in request. fn read_header<'a>(req: &'a hyper::Request, header_name: &str) -> Option<&'a str> { @@ -8,22 +8,26 @@ fn read_header<'a>(req: &'a hyper::Request, header_name: &str) -> O } /// Returns `true` if Host header in request matches a list of allowed hosts. -pub fn is_host_allowed( - request: &hyper::Request, - allowed_hosts: &Option>, -) -> bool { +pub fn is_host_allowed(request: &hyper::Request, allowed_hosts: &Option>) -> bool { hosts::is_host_valid(read_header(request, "host"), allowed_hosts) } /// Returns a CORS AllowOrigin header that should be returned with that request. pub fn cors_allow_origin( request: &hyper::Request, - cors_domains: &Option> + cors_domains: &Option>, ) -> cors::AllowCors { - cors::get_cors_allow_origin(read_header(request, "origin"), read_header(request, "host"), cors_domains).map(|origin| { + cors::get_cors_allow_origin( + read_header(request, "origin"), + read_header(request, "host"), + cors_domains, + ) + .map(|origin| { use self::cors::AccessControlAllowOrigin::*; match origin { - Value(ref val) => header::HeaderValue::from_str(val).unwrap_or(header::HeaderValue::from_static("null")), + Value(ref val) => { + header::HeaderValue::from_str(val).unwrap_or_else(|_| header::HeaderValue::from_static("null")) + } Null => header::HeaderValue::from_static("null"), Any => header::HeaderValue::from_static("*"), } @@ -33,41 +37,29 @@ pub fn cors_allow_origin( /// Returns the CORS AllowHeaders header that should be returned with that request. 
pub fn cors_allow_headers( request: &hyper::Request, - cors_allow_headers: &cors::AccessControlAllowHeaders + cors_allow_headers: &cors::AccessControlAllowHeaders, ) -> cors::AllowCors> { - let headers = request.headers().keys() - .map(|name| name.as_str()); - let requested_headers = request.headers() + let headers = request.headers().keys().map(|name| name.as_str()); + let requested_headers = request + .headers() .get_all("access-control-request-headers") .iter() .filter_map(|val| val.to_str().ok()) .flat_map(|val| val.split(", ")) - .flat_map(|val| val.split(",")); + .flat_map(|val| val.split(',')); - cors::get_cors_allow_headers( - headers, - requested_headers, - cors_allow_headers.into(), - |name| header::HeaderValue::from_str(name) - .unwrap_or_else(|_| header::HeaderValue::from_static("unknown")) - ) + cors::get_cors_allow_headers(headers, requested_headers, cors_allow_headers, |name| { + header::HeaderValue::from_str(name).unwrap_or_else(|_| header::HeaderValue::from_static("unknown")) + }) } /// Returns an optional value of `Connection` header that should be included in the response. /// The second parameter defines if server is configured with keep-alive option. /// Return value of `true` indicates that no `Connection` header should be returned, /// `false` indicates `Connection: close`. -pub fn keep_alive( - request: &hyper::Request, - keep_alive: bool, -) -> bool { +pub fn keep_alive(request: &hyper::Request, keep_alive: bool) -> bool { read_header(request, "connection") - .map(|val| match (keep_alive, val) { - // indicate that connection should be closed - (false, _) | (_, "close") => false, - // don't include any headers otherwise - _ => true, - }) + .map(|val| !matches!((keep_alive, val), (false, _) | (_, "close"))) // if the client header is not present, close connection if we don't keep_alive .unwrap_or(keep_alive) } diff --git a/ipc/Cargo.toml b/ipc/Cargo.toml index 16b1e228e..1bb5d66ba 100644 --- a/ipc/Cargo.toml +++ b/ipc/Cargo.toml @@ -1,27 +1,29 @@ [package] -name = "jsonrpc-ipc-server" +authors = ["Parity Technologies "] description = "IPC server for JSON-RPC" -version = "9.0.0" -authors = ["Nikolay Volf "] -license = "MIT" +documentation = "https://docs.rs/jsonrpc-ipc-server/" +edition = "2018" homepage = "https://github.com/paritytech/jsonrpc" +license = "MIT" +name = "jsonrpc-ipc-server" repository = "https://github.com/paritytech/jsonrpc" -documentation = "https://paritytech.github.io/jsonrpc/json_ipc_server/index.html" +version = "18.0.0" [dependencies] +futures = "0.3" log = "0.4" -tokio-service = "0.1" -jsonrpc-core = { version = "9.0", path = "../core" } -jsonrpc-server-utils = { version = "9.0", path = "../server-utils" } -parity-tokio-ipc = "0.1" -parking_lot = "0.6" +tower-service = "0.3" +jsonrpc-core = { version = "18.0.0", path = "../core" } +jsonrpc-server-utils = { version = "18.0.0", path = "../server-utils", default-features = false } +parity-tokio-ipc = "0.9" +parking_lot = "0.11.0" [dev-dependencies] -env_logger = "0.6" +env_logger = "0.7" lazy_static = "1.0" [target.'cfg(not(windows))'.dev-dependencies] -tokio-uds = "0.2" +tokio = { version = "1", default-features = false, features = ["net", "time", "rt-multi-thread"] } [badges] travis-ci = { repository = "paritytech/jsonrpc", branch = "master"} diff --git a/ipc/README.md b/ipc/README.md index 14552d40e..b709f8071 100644 --- a/ipc/README.md +++ b/ipc/README.md @@ -9,7 +9,7 @@ IPC server (Windows & Linux) for JSON-RPC 2.0. 
``` [dependencies] -jsonrpc-ipc-server = { git = "https://github.com/paritytech/jsonrpc" } +jsonrpc-ipc-server = "15.0" ``` `main.rs` @@ -17,7 +17,7 @@ jsonrpc-ipc-server = { git = "https://github.com/paritytech/jsonrpc" } ```rust extern crate jsonrpc_ipc_server; -use jsonrpc_ipc_server::Server; +use jsonrpc_ipc_server::ServerBuilder; use jsonrpc_ipc_server::jsonrpc_core::*; fn main() { @@ -26,8 +26,9 @@ fn main() { Ok(Value::String("hello".into())) }); - let server = Server::new("/tmp/json-ipc-test.ipc", io).unwrap(); - ::std::thread::spawn(move || server.run()); + let builder = ServerBuilder::new(io); + let server = builder.start("/tmp/json-ipc-test.ipc").expect("Couldn't open socket"); + server.wait(); } ``` diff --git a/ipc/examples/ipc.rs b/ipc/examples/ipc.rs index 89a45c0ce..483794887 100644 --- a/ipc/examples/ipc.rs +++ b/ipc/examples/ipc.rs @@ -1,12 +1,11 @@ -extern crate jsonrpc_ipc_server; +use jsonrpc_ipc_server; use jsonrpc_ipc_server::jsonrpc_core::*; fn main() { let mut io = MetaIoHandler::<()>::default(); - io.add_method("say_hello", |_params| { - Ok(Value::String("hello".to_string())) - }); + io.add_sync_method("say_hello", |_params| Ok(Value::String("hello".to_string()))); let _server = jsonrpc_ipc_server::ServerBuilder::new(io) - .start("/tmp/parity-example.ipc").expect("Server should start ok"); + .start("/tmp/parity-example.ipc") + .expect("Server should start ok"); } diff --git a/ipc/src/lib.rs b/ipc/src/lib.rs index deff37bf2..ab9e73130 100644 --- a/ipc/src/lib.rs +++ b/ipc/src/lib.rs @@ -1,28 +1,29 @@ //! Cross-platform JSON-RPC IPC transport. -#![warn(missing_docs)] +#![deny(missing_docs)] -extern crate jsonrpc_server_utils as server_utils; -extern crate parity_tokio_ipc; -extern crate tokio_service; -extern crate parking_lot; +use jsonrpc_server_utils as server_utils; -pub extern crate jsonrpc_core; +pub use jsonrpc_core; -#[macro_use] extern crate log; +#[macro_use] +extern crate log; -#[cfg(test)] #[macro_use] extern crate lazy_static; -#[cfg(test)] extern crate env_logger; -#[cfg(test)] mod logger; +#[cfg(test)] +#[macro_use] +extern crate lazy_static; + +#[cfg(test)] +mod logger; -mod server; -mod select_with_weak; mod meta; +mod select_with_weak; +mod server; use jsonrpc_core as jsonrpc; -pub use meta::{MetaExtractor, NoopExtractor, RequestContext}; -pub use server::{Server, ServerBuilder, CloseHandle,SecurityAttributes}; +pub use crate::meta::{MetaExtractor, NoopExtractor, RequestContext}; +pub use crate::server::{CloseHandle, SecurityAttributes, Server, ServerBuilder}; -pub use self::server_utils::{tokio, codecs::Separator}; -pub use self::server_utils::session::{SessionStats, SessionId}; +pub use self::server_utils::session::{SessionId, SessionStats}; +pub use self::server_utils::{codecs::Separator, tokio}; diff --git a/ipc/src/logger.rs b/ipc/src/logger.rs index 02d5a8ce2..9b885a72d 100644 --- a/ipc/src/logger.rs +++ b/ipc/src/logger.rs @@ -1,8 +1,8 @@ #![allow(dead_code)] -use std::env; -use log::LevelFilter; use env_logger::Builder; +use log::LevelFilter; +use std::env; lazy_static! { static ref LOG_DUMMY: bool = { @@ -10,7 +10,7 @@ lazy_static! 
{ builder.filter(None, LevelFilter::Info); if let Ok(log) = env::var("RUST_LOG") { - builder.parse(&log); + builder.parse_filters(&log); } if let Ok(_) = builder.try_init() { diff --git a/ipc/src/meta.rs b/ipc/src/meta.rs index 60e173f4a..497eaf086 100644 --- a/ipc/src/meta.rs +++ b/ipc/src/meta.rs @@ -1,24 +1,27 @@ -use jsonrpc::futures::sync::mpsc; -use jsonrpc::Metadata; -use server_utils::session; +use std::path::Path; + +use crate::jsonrpc::futures::channel::mpsc; +use crate::jsonrpc::Metadata; +use crate::server_utils::session; /// Request context pub struct RequestContext<'a> { /// Session ID pub session_id: session::SessionId, /// Remote UDS endpoint - pub endpoint_addr: &'a ::parity_tokio_ipc::RemoteId, + pub endpoint_addr: &'a Path, /// Direct pipe sender - pub sender: mpsc::Sender, + pub sender: mpsc::UnboundedSender, } /// Metadata extractor (per session) -pub trait MetaExtractor : Send + Sync + 'static { +pub trait MetaExtractor: Send + Sync + 'static { /// Extracts metadata from request context fn extract(&self, context: &RequestContext) -> M; } -impl MetaExtractor for F where +impl MetaExtractor for F +where M: Metadata, F: Fn(&RequestContext) -> M + Send + Sync + 'static, { @@ -30,5 +33,7 @@ impl MetaExtractor for F where /// Noop-extractor pub struct NoopExtractor; impl MetaExtractor for NoopExtractor { - fn extract(&self, _context: &RequestContext) -> M { M::default() } + fn extract(&self, _context: &RequestContext) -> M { + M::default() + } } diff --git a/ipc/src/select_with_weak.rs b/ipc/src/select_with_weak.rs index 204059c27..43c90b783 100644 --- a/ipc/src/select_with_weak.rs +++ b/ipc/src/select_with_weak.rs @@ -1,14 +1,25 @@ -use jsonrpc::futures::{Poll, Async}; -use jsonrpc::futures::stream::{Stream, Fuse}; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +use futures::stream::{Fuse, Stream}; pub trait SelectWithWeakExt: Stream { fn select_with_weak(self, other: S) -> SelectWithWeak - where S: Stream, Self: Sized; + where + S: Stream, + Self: Sized; } -impl SelectWithWeakExt for T where T: Stream { +impl SelectWithWeakExt for T +where + T: Stream, +{ fn select_with_weak(self, other: S) -> SelectWithWeak - where S: Stream, Self: Sized { + where + S: Stream, + Self: Sized, + { new(self, other) } } @@ -29,9 +40,11 @@ pub struct SelectWithWeak { } fn new(stream1: S1, stream2: S2) -> SelectWithWeak - where S1: Stream, - S2: Stream +where + S1: Stream, + S2: Stream, { + use futures::StreamExt; SelectWithWeak { strong: stream1.fuse(), weak: stream2.fuse(), @@ -40,36 +53,37 @@ fn new(stream1: S1, stream2: S2) -> SelectWithWeak } impl Stream for SelectWithWeak - where S1: Stream, - S2: Stream +where + S1: Stream + Unpin, + S2: Stream + Unpin, { type Item = S1::Item; - type Error = S1::Error; - fn poll(&mut self) -> Poll, S1::Error> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = Pin::into_inner(self); let mut checked_strong = false; loop { - if self.use_strong { - match self.strong.poll()? 
{ - Async::Ready(Some(item)) => { - self.use_strong = false; - return Ok(Some(item).into()) - }, - Async::Ready(None) => return Ok(None.into()), - Async::NotReady => { + if this.use_strong { + match Pin::new(&mut this.strong).poll_next(cx) { + Poll::Ready(Some(item)) => { + this.use_strong = false; + return Poll::Ready(Some(item)); + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => { if !checked_strong { - self.use_strong = false; + this.use_strong = false; } else { - return Ok(Async::NotReady) + return Poll::Pending; } } } checked_strong = true; } else { - self.use_strong = true; - match self.weak.poll()? { - Async::Ready(Some(item)) => return Ok(Some(item).into()), - Async::Ready(None) | Async::NotReady => (), + this.use_strong = true; + match Pin::new(&mut this.weak).poll_next(cx) { + Poll::Ready(Some(item)) => return Poll::Ready(Some(item)), + Poll::Ready(None) | Poll::Pending => (), } } } diff --git a/ipc/src/server.rs b/ipc/src/server.rs index 568805249..1f1411a47 100644 --- a/ipc/src/server.rs +++ b/ipc/src/server.rs @@ -1,23 +1,20 @@ -#![allow(deprecated)] - -use std; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; - -use tokio_service::{self, Service as TokioService}; -use jsonrpc::futures::{future, Future, Stream, Sink}; -use jsonrpc::futures::sync::{mpsc, oneshot}; -use jsonrpc::{middleware, FutureResult, Metadata, MetaIoHandler, Middleware}; - -use server_utils::{ - tokio_codec::Framed, - tokio::{self, runtime::TaskExecutor, reactor::Handle}, - reactor, session, codecs, -}; +use std::task::{Context, Poll}; + +use crate::jsonrpc::futures::channel::mpsc; +use crate::jsonrpc::{middleware, MetaIoHandler, Metadata, Middleware}; +use crate::meta::{MetaExtractor, NoopExtractor, RequestContext}; +use crate::select_with_weak::SelectWithWeakExt; +use futures::channel::oneshot; +use futures::StreamExt; +use parity_tokio_ipc::Endpoint; use parking_lot::Mutex; +use tower_service::Service as _; + +use crate::server_utils::{codecs, reactor, reactor::TaskExecutor, session, tokio_util}; -use meta::{MetaExtractor, NoopExtractor, RequestContext}; -use select_with_weak::SelectWithWeakExt; -use parity_tokio_ipc::Endpoint; pub use parity_tokio_ipc::SecurityAttributes; /// IPC server session @@ -29,29 +26,36 @@ pub struct Service = middleware::Noop> { impl> Service { /// Create new IPC server session with given handler and metadata. 
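The `select_with_weak` changes above follow the general futures 0.1 → 0.3 stream migration: `poll()` returning `Async`/error pairs becomes `poll_next` on a pinned stream with an explicit task context, and there is no separate error channel. A minimal standalone sketch of the new shape (the `Once` type is illustrative, not part of the crate):

```rust
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::stream::Stream;

// Yields one item, then terminates.
struct Once(Option<u32>);

impl Stream for Once {
    type Item = u32;

    // futures 0.3 signature: pinned receiver, explicit `Context`, no error type.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(Pin::into_inner(self).0.take())
    }
}

fn main() {
    use futures::StreamExt;
    let items: Vec<u32> = futures::executor::block_on(Once(Some(1)).collect());
    assert_eq!(items, vec![1]);
}
```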
pub fn new(handler: Arc>, meta: M) -> Self { - Service { handler: handler, meta: meta } + Service { handler, meta } } } -impl> tokio_service::Service for Service { - type Request = String; +impl> tower_service::Service for Service +where + S::Future: Unpin, + S::CallFuture: Unpin, +{ type Response = Option; - type Error = (); - type Future = FutureResult; + type Future = Pin> + Send>>; - fn call(&self, req: Self::Request) -> Self::Future { + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: String) -> Self::Future { + use futures::FutureExt; trace!(target: "ipc", "Received request: {}", req); - self.handler.handle_request(&req, self.meta.clone()) + Box::pin(self.handler.handle_request(&req, self.meta.clone()).map(Ok)) } } /// IPC server builder pub struct ServerBuilder = middleware::Noop> { handler: Arc>, - meta_extractor: Arc>, - session_stats: Option>, + meta_extractor: Arc>, + session_stats: Option>, executor: reactor::UninitializedExecutor, incoming_separator: codecs::Separator, outgoing_separator: codecs::Separator, @@ -59,18 +63,28 @@ pub struct ServerBuilder = middleware::Noop> client_buffer_size: usize, } -impl> ServerBuilder { +impl> ServerBuilder +where + S::Future: Unpin, + S::CallFuture: Unpin, +{ /// Creates new IPC server build given the `IoHandler`. - pub fn new(io_handler: T) -> ServerBuilder where + pub fn new(io_handler: T) -> ServerBuilder + where T: Into>, { Self::with_meta_extractor(io_handler, NoopExtractor) } } -impl> ServerBuilder { +impl> ServerBuilder +where + S::Future: Unpin, + S::CallFuture: Unpin, +{ /// Creates new IPC server build given the `IoHandler` and metadata extractor. - pub fn with_meta_extractor(io_handler: T, extractor: E) -> ServerBuilder where + pub fn with_meta_extractor(io_handler: T, extractor: E) -> ServerBuilder + where T: Into>, E: MetaExtractor, { @@ -93,7 +107,8 @@ impl> ServerBuilder { } /// Sets session metadata extractor. 
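The `tokio_service::Service` → `tower_service::Service` move above changes the trait shape: the request type becomes a trait parameter, `poll_ready` is added, and the future must output a `Result`. A minimal sketch of that shape outside the crate (the `Echo` service is illustrative only; `tower-service` 0.3 and `futures` are assumed, matching the dependencies listed earlier):

```rust
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use tower_service::Service;

struct Echo;

impl Service<String> for Echo {
    type Response = Option<String>;
    type Error = ();
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    // Always ready: this toy service has no backpressure to report.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: String) -> Self::Future {
        Box::pin(async move { Ok::<_, ()>(Some(req)) })
    }
}

fn main() {
    let mut svc = Echo;
    let response = futures::executor::block_on(svc.call("ping".to_owned()));
    assert_eq!(response, Ok(Some("ping".to_owned())));
}
```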
- pub fn session_meta_extractor(mut self, meta_extractor: X) -> Self where + pub fn session_meta_extractor(mut self, meta_extractor: X) -> Self + where X: MetaExtractor, { self.meta_extractor = Arc::new(meta_extractor); @@ -135,12 +150,13 @@ impl> ServerBuilder { let incoming_separator = self.incoming_separator; let outgoing_separator = self.outgoing_separator; let (stop_signal, stop_receiver) = oneshot::channel(); - let (start_signal, start_receiver) = oneshot::channel(); - let (wait_signal, wait_receiver) = oneshot::channel(); + // NOTE: These channels are only waited upon in synchronous fashion + let (start_signal, start_receiver) = std::sync::mpsc::channel(); + let (wait_signal, wait_receiver) = std::sync::mpsc::channel(); let security_attributes = self.security_attributes; let client_buffer_size = self.client_buffer_size; - executor.spawn(future::lazy(move || { + let fut = async move { let mut endpoint = Endpoint::new(endpoint_addr); endpoint.set_security_attributes(security_attributes); @@ -151,82 +167,81 @@ impl> ServerBuilder { } } - let endpoint_handle = Handle::current(); - let connections = match endpoint.incoming(&endpoint_handle) { + let endpoint_addr = endpoint.path().to_owned(); + let connections = match endpoint.incoming() { Ok(connections) => connections, Err(e) => { - start_signal.send(Err(e)).expect("Cannot fail since receiver never dropped before receiving"); - return future::Either::A(future::ok(())); + start_signal + .send(Err(e)) + .expect("Cannot fail since receiver never dropped before receiving"); + return; } }; let mut id = 0u64; - let server = connections.for_each(move |(io_stream, remote_id)| { + use futures::TryStreamExt; + let server = connections.map_ok(move |io_stream| { id = id.wrapping_add(1); let session_id = id; let session_stats = session_stats.clone(); trace!(target: "ipc", "Accepted incoming IPC connection: {}", session_id); - session_stats.as_ref().map(|stats| stats.open_session(session_id)); + if let Some(stats) = session_stats.as_ref() { + stats.open_session(session_id) + } - let (sender, receiver) = mpsc::channel(16); + let (sender, receiver) = mpsc::unbounded(); let meta = meta_extractor.extract(&RequestContext { - endpoint_addr: &remote_id, + endpoint_addr: endpoint_addr.as_ref(), session_id, sender, }); - let service = Service::new(rpc_handler.clone(), meta); - let (writer, reader) = Framed::new( - io_stream, - codecs::StreamCodec::new( - incoming_separator.clone(), - outgoing_separator.clone(), - ), - ).split(); + let mut service = Service::new(rpc_handler.clone(), meta); + let codec = codecs::StreamCodec::new(incoming_separator.clone(), outgoing_separator.clone()); + let framed = tokio_util::codec::Decoder::framed(codec, io_stream); + let (writer, reader) = futures::StreamExt::split(framed); + let responses = reader - .map(move |req| { - service.call(req) - .then(|result| { - match result { - Err(_) => { - future::ok(None) - } - Ok(some_result) => future::ok(some_result), - } - }) - .map_err(|_:()| std::io::ErrorKind::Other.into()) + .map_ok(move |req| { + service + .call(req) + // Ignore service errors + .map(|x| Ok(x.ok().flatten())) }) - .buffer_unordered(client_buffer_size) - .filter_map(|x| x) + .try_buffer_unordered(client_buffer_size) + // Filter out previously ignored service errors as `None`s + .try_filter_map(futures::future::ok) // we use `select_with_weak` here, instead of `select`, to close the stream // as soon as the ipc pipe is closed - .select_with_weak(receiver.map_err(|e| { - warn!(target: "ipc", "Notification error: 
{:?}", e); - std::io::ErrorKind::Other.into() - })); + .select_with_weak(receiver.map(Ok)); - let writer = writer.send_all(responses).then(move |_| { + responses.forward(writer).then(move |_| { trace!(target: "ipc", "Peer: service finished"); - session_stats.as_ref().map(|stats| stats.close_session(session_id)); - Ok(()) - }); - - tokio::spawn(writer); + if let Some(stats) = session_stats.as_ref() { + stats.close_session(session_id) + } - Ok(()) + async { Ok(()) } + }) }); - start_signal.send(Ok(())).expect("Cannot fail since receiver never dropped before receiving"); - - let stop = stop_receiver.map_err(|_| std::io::ErrorKind::Interrupted.into()); - future::Either::B( - server.select(stop) - .map(|_| { - let _ = wait_signal.send(()); - () - }) - .map_err(|_| ()) - ) - })); + start_signal + .send(Ok(())) + .expect("Cannot fail since receiver never dropped before receiving"); + let stop = stop_receiver.map_err(|_| std::io::ErrorKind::Interrupted); + let stop = Box::pin(stop); + + let server = server.try_buffer_unordered(1024).for_each(|_| async {}); + + let result = futures::future::select(Box::pin(server), stop).await; + // We drop the server first to prevent a situation where main thread terminates + // before the server is properly dropped (see #504 for more details) + drop(result); + let _ = wait_signal.send(()); + }; + + use futures::FutureExt; + let fut = Box::pin(fut.map(drop)); + executor.executor().spawn(fut); let handle = InnerHandles { executor: Some(executor), @@ -234,22 +249,22 @@ impl> ServerBuilder { path: path.to_owned(), }; - match start_receiver.wait().expect("Message should always be sent") { + use futures::TryFutureExt; + match start_receiver.recv().expect("Message should always be sent") { Ok(()) => Ok(Server { handles: Arc::new(Mutex::new(handle)), wait_handle: Some(wait_receiver), }), - Err(e) => Err(e) + Err(e) => Err(e), } } } - /// IPC Server handle #[derive(Debug)] pub struct Server { handles: Arc>, - wait_handle: Option>, + wait_handle: Option>, } impl Server { @@ -267,12 +282,12 @@ impl Server { /// Wait for the server to finish pub fn wait(mut self) { - self.wait_handle.take().map(|wait_receiver| wait_receiver.wait()); + if let Some(wait_receiver) = self.wait_handle.take() { + let _ = wait_receiver.recv(); + } } - } - #[derive(Debug)] struct InnerHandles { executor: Option, @@ -283,7 +298,9 @@ struct InnerHandles { impl InnerHandles { pub fn close(&mut self) { let _ = self.stop.take().map(|stop| stop.send(())); - self.executor.take().map(|executor| executor.close()); + if let Some(executor) = self.executor.take() { + executor.close() + } let _ = ::std::fs::remove_file(&self.path); // ignore error, file could have been gone somewhere } } @@ -309,31 +326,16 @@ impl CloseHandle { #[cfg(test)] #[cfg(not(windows))] mod tests { - extern crate tokio_uds; + use super::*; + use jsonrpc_core::Value; + use std::os::unix::net::UnixStream; use std::thread; - use std::sync::Arc; - use std::time; - use std::time::{Instant, Duration}; - use super::{ServerBuilder, Server}; - use jsonrpc::{MetaIoHandler, Value}; - use jsonrpc::futures::{Future, future, Stream, Sink}; - use jsonrpc::futures::sync::{mpsc, oneshot}; - use self::tokio_uds::UnixStream; - use parking_lot::Mutex; - use server_utils::{ - tokio_codec::Decoder, - tokio::{self, timer::Delay} - }; - use server_utils::codecs; - use meta::{MetaExtractor, RequestContext, NoopExtractor}; - use super::SecurityAttributes; + use std::time::{self, Duration}; fn server_builder() -> ServerBuilder { let mut io = 
MetaIoHandler::<()>::default(); - io.add_method("say_hello", |_params| { - Ok(Value::String("hello".to_string())) - }); + io.add_sync_method("say_hello", |_params| Ok(Value::String("hello".to_string()))); ServerBuilder::new(io) } @@ -344,126 +346,137 @@ mod tests { } fn dummy_request_str(path: &str, data: &str) -> String { - let stream_future = UnixStream::connect(path); - let reply = stream_future.and_then(|stream| { - let stream = codecs::StreamCodec::stream_incoming() - .framed(stream); - let reply = stream - .send(data.to_owned()) - .and_then(move |stream| { - stream.into_future().map_err(|(err, _)| err) - }) - .and_then(|(reply, _)| { - future::ok(reply.expect("there should be one reply")) - }); - reply - }); + use futures::SinkExt; + + let reply = async move { + use tokio::net::UnixStream; - reply.wait().expect("wait for reply") + let stream: UnixStream = UnixStream::connect(path).await?; + let codec = codecs::StreamCodec::stream_incoming(); + let mut stream = tokio_util::codec::Decoder::framed(codec, stream); + stream.send(data.to_owned()).await?; + let (reply, _) = stream.into_future().await; + + reply.expect("there should be one reply") + }; + + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(reply).expect("wait for reply") } #[test] fn start() { - ::logger::init_log(); + crate::logger::init_log(); let mut io = MetaIoHandler::<()>::default(); - io.add_method("say_hello", |_params| { - Ok(Value::String("hello".to_string())) - }); + io.add_sync_method("say_hello", |_params| Ok(Value::String("hello".to_string()))); let server = ServerBuilder::new(io); - let _server = server.start("/tmp/test-ipc-20000") + let _server = server + .start("/tmp/test-ipc-20000") .expect("Server must run with no issues"); } #[test] fn connect() { - ::logger::init_log(); + crate::logger::init_log(); let path = "/tmp/test-ipc-30000"; let _server = run(path); - UnixStream::connect(path).wait().expect("Socket should connect"); + UnixStream::connect(path).expect("Socket should connect"); } #[test] fn request() { - ::logger::init_log(); + crate::logger::init_log(); let path = "/tmp/test-ipc-40000"; let server = run(path); - let (stop_signal, stop_receiver) = oneshot::channel(); + let (stop_signal, stop_receiver) = std::sync::mpsc::channel(); let t = thread::spawn(move || { let result = dummy_request_str( path, "{\"jsonrpc\": \"2.0\", \"method\": \"say_hello\", \"params\": [42, 23], \"id\": 1}", - ); + ); stop_signal.send(result).unwrap(); }); t.join().unwrap(); - let _ = stop_receiver.map(move |result: String| { - assert_eq!( - result, - "{\"jsonrpc\":\"2.0\",\"result\":\"hello\",\"id\":1}", - "Response does not exactly match the expected response", - ); - server.close(); - }).wait(); + let result = stop_receiver.recv().unwrap(); + + assert_eq!( + result, "{\"jsonrpc\":\"2.0\",\"result\":\"hello\",\"id\":1}", + "Response does not exactly match the expected response", + ); + server.close(); } #[test] fn req_parallel() { - ::logger::init_log(); + crate::logger::init_log(); let path = "/tmp/test-ipc-45000"; let server = run(path); - let (stop_signal, stop_receiver) = mpsc::channel(400); + let (stop_signal, stop_receiver) = futures::channel::mpsc::channel(400); let mut handles = Vec::new(); for _ in 0..4 { let path = path.clone(); let mut stop_signal = stop_signal.clone(); - handles.push( - thread::spawn(move || { - for _ in 0..100 { - let result = dummy_request_str( - &path, - "{\"jsonrpc\": \"2.0\", \"method\": \"say_hello\", \"params\": [42, 23], \"id\": 1}", - ); - 
stop_signal.try_send(result).unwrap(); - } - }) - ); + handles.push(thread::spawn(move || { + for _ in 0..100 { + let result = dummy_request_str( + &path, + "{\"jsonrpc\": \"2.0\", \"method\": \"say_hello\", \"params\": [42, 23], \"id\": 1}", + ); + stop_signal.try_send(result).unwrap(); + } + })); } for handle in handles.drain(..) { handle.join().unwrap(); } - let _ = stop_receiver.map(|result| { - assert_eq!( - result, - "{\"jsonrpc\":\"2.0\",\"result\":\"hello\",\"id\":1}", - "Response does not exactly match the expected response", - ); - }).take(400).collect().wait(); + thread::spawn(move || { + let fut = stop_receiver + .map(|result| { + assert_eq!( + result, "{\"jsonrpc\":\"2.0\",\"result\":\"hello\",\"id\":1}", + "Response does not exactly match the expected response", + ); + }) + .take(400) + .for_each(|_| async {}); + futures::executor::block_on(fut); + }) + .join() + .unwrap(); server.close(); } #[test] fn close() { - ::logger::init_log(); + crate::logger::init_log(); let path = "/tmp/test-ipc-50000"; let server = run(path); server.close(); - assert!(::std::fs::metadata(path).is_err(), "There should be no socket file left"); - assert!(UnixStream::connect(path).wait().is_err(), "Connection to the closed socket should fail"); + assert!( + ::std::fs::metadata(path).is_err(), + "There should be no socket file left" + ); + assert!( + UnixStream::connect(path).is_err(), + "Connection to the closed socket should fail" + ); } fn huge_response_test_str() -> String { let mut result = String::from("begin_hello"); result.push_str("begin_hello"); - for _ in 0..16384 { result.push(' '); } + for _ in 0..16384 { + result.push(' '); + } result.push_str("end_hello"); result } @@ -478,13 +491,11 @@ mod tests { #[test] fn test_huge_response() { - ::logger::init_log(); + crate::logger::init_log(); let path = "/tmp/test-ipc-60000"; let mut io = MetaIoHandler::<()>::default(); - io.add_method("say_huge_hello", |_params| { - Ok(Value::String(huge_response_test_str())) - }); + io.add_sync_method("say_huge_hello", |_params| Ok(Value::String(huge_response_test_str()))); let builder = ServerBuilder::new(io); let server = builder.start(path).expect("Server must run with no issues"); @@ -500,14 +511,19 @@ mod tests { }); t.join().unwrap(); - let _ = stop_receiver.map(move |result: String| { - assert_eq!( - result, - huge_response_test_json(), - "Response does not exactly match the expected response", - ); - server.close(); - }).wait(); + thread::spawn(move || { + futures::executor::block_on(async move { + let result = stop_receiver.await.unwrap(); + assert_eq!( + result, + huge_response_test_json(), + "Response does not exactly match the expected response", + ); + server.close(); + }); + }) + .join() + .unwrap(); } #[test] @@ -524,7 +540,7 @@ mod tests { } struct SessionEndExtractor { - drop_receivers: Arc>>>, + drop_receivers: Arc>>>, } impl MetaExtractor> for SessionEndExtractor { @@ -538,40 +554,47 @@ mod tests { } } - ::logger::init_log(); + crate::logger::init_log(); let path = "/tmp/test-ipc-30009"; - let (signal, receiver) = mpsc::channel(16); + let (signal, receiver) = futures::channel::mpsc::channel(16); let session_metadata_extractor = SessionEndExtractor { - drop_receivers: Arc::new(Mutex::new(signal)) + drop_receivers: Arc::new(Mutex::new(signal)), }; let io = MetaIoHandler::>::default(); let builder = ServerBuilder::with_meta_extractor(io, session_metadata_extractor); let server = builder.start(path).expect("Server must run with no issues"); { - let _ = 
UnixStream::connect(path).wait().expect("Socket should connect"); + let _ = UnixStream::connect(path).expect("Socket should connect"); } - receiver.into_future() - .map_err(|_| ()) - .and_then(|drop_receiver| drop_receiver.0.unwrap().map_err(|_| ())) - .wait().unwrap(); + thread::spawn(move || { + futures::executor::block_on(async move { + let (drop_receiver, ..) = receiver.into_future().await; + drop_receiver.unwrap().await.unwrap(); + }); + }) + .join() + .unwrap(); server.close(); } #[test] fn close_handle() { - ::logger::init_log(); + crate::logger::init_log(); let path = "/tmp/test-ipc-90000"; let server = run(path); let handle = server.close_handle(); handle.close(); - assert!(UnixStream::connect(path).wait().is_err(), "Connection to the closed socket should fail"); + assert!( + UnixStream::connect(path).is_err(), + "Connection to the closed socket should fail" + ); } #[test] fn close_when_waiting() { - ::logger::init_log(); + crate::logger::init_log(); let path = "/tmp/test-ipc-70000"; let server = run(path); let close_handle = server.close_handle(); @@ -586,26 +609,25 @@ mod tests { tx.send(true).expect("failed to report that the server has stopped"); }); - let delay = Delay::new(Instant::now() + Duration::from_millis(500)) - .map(|_| false) - .map_err(|err| panic!("{:?}", err)); - - let result_fut = rx - .map_err(|_| ()) - .select(delay) - .then(move |result| { - match result { - Ok((result, _)) => { - assert_eq!(result, true, "Wait timeout exceeded"); - assert!(UnixStream::connect(path).wait().is_err(), - "Connection to the closed socket should fail"); - Ok(()) - }, - Err(_) => Err(()), + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async move { + let timeout = tokio::time::sleep(Duration::from_millis(500)); + futures::pin_mut!(timeout); + + match futures::future::select(rx, timeout).await { + futures::future::Either::Left((result, _)) => { + assert!(result.is_ok(), "Rx failed"); + assert_eq!(result, Ok(true), "Wait timeout exceeded"); + assert!( + UnixStream::connect(path).is_err(), + "Connection to the closed socket should fail" + ); + Ok(()) } - }); - - tokio::run(result_fut); + futures::future::Either::Right(_) => Err("timed out"), + } + }) + .unwrap(); } #[test] diff --git a/macros/Cargo.toml b/macros/Cargo.toml deleted file mode 100644 index f45e6b543..000000000 --- a/macros/Cargo.toml +++ /dev/null @@ -1,22 +0,0 @@ -[package] -description = "Helper macros for jsonrpc-core" -homepage = "https://github.com/paritytech/jsonrpc" -repository = "https://github.com/paritytech/jsonrpc" -license = "MIT" -name = "jsonrpc-macros" -version = "9.0.0" -authors = ["rphmeier "] -keywords = ["jsonrpc", "json-rpc", "json", "rpc", "macros"] -documentation = "https://paritytech.github.io/jsonrpc/jsonrpc_macros/index.html" - -[dependencies] -serde = "1.0" -jsonrpc-core = { version = "9.0", path = "../core" } -jsonrpc-pubsub = { version = "9.0", path = "../pubsub" } - -[dev-dependencies] -serde_json = "1.0" -jsonrpc-tcp-server = { version = "9.0", path = "../tcp" } - -[badges] -travis-ci = { repository = "paritytech/jsonrpc", branch = "master"} diff --git a/macros/examples/generic-trait-bounds.rs b/macros/examples/generic-trait-bounds.rs deleted file mode 100644 index 101a6f92a..000000000 --- a/macros/examples/generic-trait-bounds.rs +++ /dev/null @@ -1,69 +0,0 @@ -extern crate serde; -extern crate jsonrpc_core; -#[macro_use] -extern crate jsonrpc_macros; - -use serde::de::DeserializeOwned; -use jsonrpc_core::{IoHandler, Error, Result}; -use 
jsonrpc_core::futures::future::{self, FutureResult}; - -// Two only requires DeserializeOwned -build_rpc_trait! { - pub trait Rpc where - Two: DeserializeOwned, - { - /// Get One type. - #[rpc(name = "getOne")] - fn one(&self) -> Result; - - /// Adds two numbers and returns a result - #[rpc(name = "setTwo")] - fn set_two(&self, Two) -> Result<()>; - - /// Performs asynchronous operation - #[rpc(name = "beFancy")] - fn call(&self, One) -> FutureResult<(One, u64), Error>; - } -} - -build_rpc_trait! { - pub trait Rpc2<> where - Two: DeserializeOwned, - { - /// Adds two numbers and returns a result - #[rpc(name = "setTwo")] - fn set_two(&self, Two) -> Result<()>; - } -} - -struct RpcImpl; - -impl Rpc for RpcImpl { - fn one(&self) -> Result { - Ok(100) - } - - fn set_two(&self, x: String) -> Result<()> { - println!("{}", x); - Ok(()) - } - - fn call(&self, num: u64) -> FutureResult<(u64, u64), Error> { - ::future::finished((num + 999, num)) - } -} - -impl Rpc2 for RpcImpl { - fn set_two(&self, _: String) -> Result<()> { - unimplemented!() - } -} - - -fn main() { - let mut io = IoHandler::new(); - - io.extend_with(Rpc::to_delegate(RpcImpl)); - io.extend_with(Rpc2::to_delegate(RpcImpl)); -} - diff --git a/macros/examples/generic-trait.rs b/macros/examples/generic-trait.rs deleted file mode 100644 index a8f6362cd..000000000 --- a/macros/examples/generic-trait.rs +++ /dev/null @@ -1,47 +0,0 @@ -extern crate jsonrpc_core; -#[macro_use] -extern crate jsonrpc_macros; - -use jsonrpc_core::{IoHandler, Error, Result}; -use jsonrpc_core::futures::future::{self, FutureResult}; - -build_rpc_trait! { - pub trait Rpc { - /// Get One type. - #[rpc(name = "getOne")] - fn one(&self) -> Result; - - /// Adds two numbers and returns a result - #[rpc(name = "setTwo")] - fn set_two(&self, Two) -> Result<()>; - - /// Performs asynchronous operation - #[rpc(name = "beFancy")] - fn call(&self, One) -> FutureResult<(One, Two), Error>; - } -} - -struct RpcImpl; - -impl Rpc for RpcImpl { - fn one(&self) -> Result { - Ok(100) - } - - fn set_two(&self, x: String) -> Result<()> { - println!("{}", x); - Ok(()) - } - - fn call(&self, num: u64) -> FutureResult<(u64, String), Error> { - ::future::finished((num + 999, "hello".into())) - } -} - - -fn main() { - let mut io = IoHandler::new(); - let rpc = RpcImpl; - - io.extend_with(rpc.to_delegate()) -} diff --git a/macros/examples/meta-macros.rs b/macros/examples/meta-macros.rs deleted file mode 100644 index 446018104..000000000 --- a/macros/examples/meta-macros.rs +++ /dev/null @@ -1,73 +0,0 @@ -extern crate jsonrpc_core; -#[macro_use] -extern crate jsonrpc_macros; -extern crate jsonrpc_tcp_server; - -use std::collections::BTreeMap; - -use jsonrpc_core::{futures, MetaIoHandler, Metadata, Error, Value, Result}; -use jsonrpc_core::futures::future::FutureResult; - -#[derive(Clone)] -struct Meta(String); -impl Metadata for Meta {} - -build_rpc_trait! { - pub trait Rpc { - type Metadata; - - /// Adds two numbers and returns a result - #[rpc(name = "add")] - fn add(&self, u64, u64) -> Result; - - /// Multiplies two numbers. Second number is optional. 
- #[rpc(name = "mul")] - fn mul(&self, u64, jsonrpc_macros::Trailing) -> Result; - - /// Performs asynchronous operation - #[rpc(name = "callAsync")] - fn call(&self, u64) -> FutureResult; - - /// Performs asynchronous operation with meta - #[rpc(meta, name = "callAsyncMeta", alias = [ "callAsyncMetaAlias", ])] - fn call_meta(&self, Self::Metadata, BTreeMap) -> FutureResult; - } -} - -struct RpcImpl; -impl Rpc for RpcImpl { - type Metadata = Meta; - - fn add(&self, a: u64, b: u64) -> Result { - Ok(a + b) - } - - fn mul(&self, a: u64, b: jsonrpc_macros::Trailing) -> Result { - Ok(a * b.unwrap_or(1)) - } - - fn call(&self, x: u64) -> FutureResult { - futures::finished(format!("OK: {}", x)) - } - - fn call_meta(&self, meta: Self::Metadata, map: BTreeMap) -> FutureResult { - futures::finished(format!("From: {}, got: {:?}", meta.0, map)) - } -} - - -fn main() { - let mut io = MetaIoHandler::default(); - let rpc = RpcImpl; - - io.extend_with(rpc.to_delegate()); - - let server = jsonrpc_tcp_server::ServerBuilder - ::with_meta_extractor(io, |context: &jsonrpc_tcp_server::RequestContext| { - Meta(format!("{}", context.peer_addr)) - }) - .start(&"0.0.0.0:3030".parse().unwrap()) - .expect("Server must start with no issues"); - - server.wait() -} diff --git a/macros/examples/pubsub-macros.rs b/macros/examples/pubsub-macros.rs deleted file mode 100644 index 12d4a834d..000000000 --- a/macros/examples/pubsub-macros.rs +++ /dev/null @@ -1,105 +0,0 @@ -extern crate jsonrpc_core; -extern crate jsonrpc_pubsub; -#[macro_use] -extern crate jsonrpc_macros; -extern crate jsonrpc_tcp_server; - -use std::thread; -use std::sync::{atomic, Arc, RwLock}; -use std::collections::HashMap; - -use jsonrpc_core::{Error, ErrorCode, Result}; -use jsonrpc_core::futures::Future; -use jsonrpc_pubsub::{Session, PubSubHandler, SubscriptionId}; - -use jsonrpc_macros::pubsub; - -build_rpc_trait! { - pub trait Rpc { - type Metadata; - - /// Adds two numbers and returns a result - #[rpc(name = "add")] - fn add(&self, u64, u64) -> Result; - - #[pubsub(name = "hello")] { - /// Hello subscription - #[rpc(name = "hello_subscribe", alias = ["hello_sub", ])] - fn subscribe(&self, Self::Metadata, pubsub::Subscriber, u64); - - /// Unsubscribe from hello subscription. 
- #[rpc(name = "hello_unsubscribe")] - fn unsubscribe(&self, SubscriptionId) -> Result; - } - } -} - -#[derive(Default)] -struct RpcImpl { - uid: atomic::AtomicUsize, - active: Arc>>>, -} -impl Rpc for RpcImpl { - type Metadata = Arc; - - fn add(&self, a: u64, b: u64) -> Result { - Ok(a + b) - } - - fn subscribe(&self, _meta: Self::Metadata, subscriber: pubsub::Subscriber, param: u64) { - if param != 10 { - subscriber.reject(Error { - code: ErrorCode::InvalidParams, - message: "Rejecting subscription - invalid parameters provided.".into(), - data: None, - }).unwrap(); - return; - } - - let id = self.uid.fetch_add(1, atomic::Ordering::SeqCst); - let sub_id = SubscriptionId::Number(id as u64); - let sink = subscriber.assign_id(sub_id.clone()).unwrap(); - self.active.write().unwrap().insert(sub_id, sink); - } - - fn unsubscribe(&self, id: SubscriptionId) -> Result { - let removed = self.active.write().unwrap().remove(&id); - if removed.is_some() { - Ok(true) - } else { - Err(Error { - code: ErrorCode::InvalidParams, - message: "Invalid subscription.".into(), - data: None, - }) - } - } -} - - -fn main() { - let mut io = PubSubHandler::default(); - let rpc = RpcImpl::default(); - let active_subscriptions = rpc.active.clone(); - - thread::spawn(move || { - loop { - { - let subscribers = active_subscriptions.read().unwrap(); - for sink in subscribers.values() { - let _ = sink.notify(Ok("Hello World!".into())).wait(); - } - } - thread::sleep(::std::time::Duration::from_secs(1)); - } - }); - - io.extend_with(rpc.to_delegate()); - - let server = jsonrpc_tcp_server::ServerBuilder - ::with_meta_extractor(io, |context: &jsonrpc_tcp_server::RequestContext| Arc::new(Session::new(context.sender.clone()))) - .start(&"0.0.0.0:3030".parse().unwrap()) - .expect("Server must start with no issues"); - - server.wait() -} diff --git a/macros/examples/std.rs b/macros/examples/std.rs deleted file mode 100644 index dbf5d97c3..000000000 --- a/macros/examples/std.rs +++ /dev/null @@ -1,46 +0,0 @@ -extern crate jsonrpc_core; -#[macro_use] -extern crate jsonrpc_macros; - -use jsonrpc_core::{IoHandler, Error, Result}; -use jsonrpc_core::futures::future::{self, FutureResult}; - -build_rpc_trait! { - pub trait Rpc { - /// Returns a protocol version - #[rpc(name = "protocolVersion")] - fn protocol_version(&self) -> Result; - - /// Adds two numbers and returns a result - #[rpc(name = "add")] - fn add(&self, u64, u64) -> Result; - - /// Performs asynchronous operation - #[rpc(name = "callAsync")] - fn call(&self, u64) -> FutureResult; - } -} - -struct RpcImpl; - -impl Rpc for RpcImpl { - fn protocol_version(&self) -> Result { - Ok("version1".into()) - } - - fn add(&self, a: u64, b: u64) -> Result { - Ok(a + b) - } - - fn call(&self, _: u64) -> FutureResult { - future::ok("OK".to_owned()) - } -} - - -fn main() { - let mut io = IoHandler::new(); - let rpc = RpcImpl; - - io.extend_with(rpc.to_delegate()) -} diff --git a/macros/src/auto_args.rs b/macros/src/auto_args.rs deleted file mode 100644 index f9af4df49..000000000 --- a/macros/src/auto_args.rs +++ /dev/null @@ -1,703 +0,0 @@ -// because we reuse the type names as idents in the macros as a dirty hack to -// work around `concat_idents!` being unstable. -#![allow(non_snake_case)] - -///! Automatically serialize and deserialize parameters around a strongly-typed function. 
- -use jsonrpc_core::{Error, Params, Value, Metadata, Result}; -use jsonrpc_core::futures::{self, Future, IntoFuture}; -use jsonrpc_core::futures::future::{self, Either}; -use jsonrpc_pubsub::{PubSubMetadata, Subscriber}; -use pubsub; -use serde::Serialize; -use serde::de::DeserializeOwned; -use util::{invalid_params, expect_no_params, to_value}; - -/// Auto-generates an RPC trait from trait definition. -/// -/// This just copies out all the methods, docs, and adds another -/// function `to_delegate` which will automatically wrap each strongly-typed -/// function in a wrapper which handles parameter and output type serialization. -/// -/// RPC functions may come in a couple forms: synchronous, async and async with metadata. -/// These are parsed with the custom `#[rpc]` attribute, which must follow -/// documentation. -/// -/// ## The #[rpc] attribute -/// -/// Valid forms: -/// - `#[rpc(name = "name_here")]` (an async rpc function which should be bound to the given name) -/// - `#[rpc(meta, name = "name_here")]` (an async rpc function with metadata which should be bound to the given name) -/// -/// Synchronous function format: -/// `fn foo(&self, Param1, Param2, Param3) -> Result`. -/// -/// Asynchronous RPC functions must come in this form: -/// `fn foo(&self, Param1, Param2, Param3) -> BoxFuture; -/// -/// Asynchronous RPC functions with metadata must come in this form: -/// `fn foo(&self, Self::Metadata, Param1, Param2, Param3) -> BoxFuture; -/// -/// Anything else will be rejected by the code generator. -/// -/// ## The #[pubsub] attribute -/// -/// Valid form: -/// ```rust,ignore -/// #[pubsub(name = "hello")] { -/// #[rpc(name = "hello_subscribe")] -/// fn subscribe(&self, Self::Metadata, pubsub::Subscriber, u64); -/// #[rpc(name = "hello_unsubscribe")] -/// fn unsubscribe(&self, SubscriptionId) -> Result; -/// } -/// ``` -/// -/// The attribute is used to create a new pair of subscription methods -/// (if underlying transport supports that.) - - -#[macro_export] -macro_rules! metadata { - () => { - /// Requests metadata - type Metadata: $crate::jsonrpc_core::Metadata; - }; - ( - $( $sub_name: ident )+ - ) => { - /// Requests metadata - type Metadata: $crate::jsonrpc_pubsub::PubSubMetadata; - }; -} - -#[macro_export] -macro_rules! build_rpc_trait { - ( - $(#[$t_attr: meta])* - pub trait $name:ident $(<$( $generics:ident ),*> - $( - where - $( $generics2:ident : $bounds:tt $( + $morebounds:tt )* ,)+ - )* )* - { - $( $rest: tt )+ - } - ) => { - build_rpc_trait! { - @WITH_BOUNDS - $(#[$t_attr])* - pub trait $name $(< - // first generic parameters with both bounds - $( $generics ,)* - @BOUNDS - // then specialised ones - $( $( $generics2 : $bounds $( + $morebounds )* )* )* - > )* { - $( $rest )+ - } - } - }; - - // entry-point. todo: make another for traits w/ bounds. - ( - @WITH_BOUNDS - $(#[$t_attr: meta])* - pub trait $name:ident $(< - $( $simple_generics:ident ,)* - @BOUNDS - $( $generics:ident $(: $bounds:tt $( + $morebounds:tt )* )* ),* - >)* { - $( - $( #[doc=$m_doc:expr] )* - #[ rpc( $($t:tt)* ) ] - fn $m_name: ident ( $( $p: tt )* ) -> $result: tt <$out: ty $(, $error: ty)* >; - )* - } - ) => { - $(#[$t_attr])* - pub trait $name $(<$( $simple_generics ,)* $( $generics , )*>)* : Sized + Send + Sync + 'static { - build_rpc_trait!( - GENERATE_FUNCTIONS - $( - $(#[doc=$m_doc])* - fn $m_name ( $( $p )* ) -> $result <$out $(, $error) *>; - )* - ); - - /// Transform this into an `IoDelegate`, automatically wrapping - /// the parameters. 
- fn to_delegate(self) -> $crate::IoDelegate - where $( - $($simple_generics: Send + Sync + 'static + $crate::Serialize + $crate::DeserializeOwned ,)* - $($generics: Send + Sync + 'static $( + $bounds $( + $morebounds )* )* ),* - )* - { - let mut del = $crate::IoDelegate::new(self.into()); - $( - build_rpc_trait!(WRAP del => - ( $($t)* ) - fn $m_name ( $( $p )* ) -> $result <$out $(, $error)* > - ); - )* - del - } - } - }; - - // entry-point for trait with metadata methods - ( - @WITH_BOUNDS - $(#[$t_attr: meta])* - pub trait $name: ident $(< - $( $simple_generics:ident ,)* - @BOUNDS - $($generics:ident $( : $bounds:tt $( + $morebounds:tt )* )* ),* - >)* { - type Metadata; - - $( - $( #[ doc=$m_doc:expr ] )* - #[ rpc( $($t:tt)* ) ] - fn $m_name: ident ( $( $p: tt )* ) -> $result: tt <$out: ty $(, $error_std: ty) *>; - )* - - $( - #[ pubsub( $($pubsub_t:tt)+ ) ] { - $( #[ doc= $sub_doc:expr ] )* - #[ rpc( $($sub_t:tt)* ) ] - fn $sub_name: ident ( $($sub_p: tt)* ); - $( #[ doc= $unsub_doc:expr ] )* - #[ rpc( $($unsub_t:tt)* ) ] - fn $unsub_name: ident ( $($unsub_p: tt)* ) -> $sub_result: tt <$sub_out: ty $(, $error_unsub: ty)* >; - } - )* - - } - ) => { - $(#[$t_attr])* - pub trait $name $(<$( $simple_generics ,)* $( $generics , )* >)* : Sized + Send + Sync + 'static { - // Metadata bound differs for traits with subscription methods. - metadata! ( - $( $sub_name )* - ); - - build_rpc_trait!(GENERATE_FUNCTIONS - $( - $(#[doc=$m_doc])* - fn $m_name ( $( $p )* ) -> $result <$out $(, $error_std) *>; - )* - ); - - $( - $(#[doc=$sub_doc])* - fn $sub_name ( $($sub_p)* ); - $(#[doc=$unsub_doc])* - fn $unsub_name ( $($unsub_p)* ) -> $sub_result <$sub_out $(, $error_unsub)* >; - )* - - /// Transform this into an `IoDelegate`, automatically wrapping - /// the parameters. 
- fn to_delegate(self) -> $crate::IoDelegate - where $( - $($simple_generics: Send + Sync + 'static + $crate::Serialize + $crate::DeserializeOwned ),* - $($generics: Send + Sync + 'static $( + $bounds $( + $morebounds )* )* ),* - )* - { - let mut del = $crate::IoDelegate::new(self.into()); - $( - build_rpc_trait!(WRAP del => - ( $($t)* ) - fn $m_name ( $( $p )* ) -> $result <$out $(, $error_std)* > - ); - )* - $( - build_rpc_trait!(WRAP del => - pubsub: ( $($pubsub_t)* ) - subscribe: ( $($sub_t)* ) - fn $sub_name ( $($sub_p)* ); - unsubscribe: ( $($unsub_t)* ) - fn $unsub_name ( $($unsub_p)* ) -> $sub_result <$sub_out $(, $error_unsub)* >; - ); - )* - del - } - } - }; - - (GENERATE_FUNCTIONS - $( - $( #[doc=$m_doc:expr] )* - fn $m_name: ident (&self $(, $p: ty)* ) -> $result: ty; - )* - ) => { - $( - $(#[doc=$m_doc])* - fn $m_name (&self $(, _: $p )* ) -> $result; - )* - }; - - ( WRAP $del: expr => - (meta, name = $name: expr $(, alias = [ $( $alias: expr, )+ ])*) - fn $method: ident (&self, Self::Metadata $(, $param: ty)*) -> $result: tt <$out: ty $(, $error: ty)* > - ) => { - $del.add_method_with_meta($name, move |base, params, meta| { - $crate::WrapMeta::wrap_rpc(&(Self::$method as fn(&_, Self::Metadata $(, $param)*) -> $result <$out $(, $error)* >), base, params, meta) - }); - $( - $( - $del.add_alias($alias, $name); - )+ - )* - }; - - ( WRAP $del: expr => - pubsub: (name = $name: expr) - subscribe: (name = $subscribe: expr $(, alias = [ $( $sub_alias: expr, )+ ])*) - fn $sub_method: ident (&self, Self::Metadata $(, $sub_p: ty)+); - unsubscribe: (name = $unsubscribe: expr $(, alias = [ $( $unsub_alias: expr, )+ ])*) - fn $unsub_method: ident (&self $(, $unsub_p: ty)+) -> $result: tt <$out: ty $(, $error_unsub: ty)* >; - ) => { - $del.add_subscription( - $name, - ($subscribe, move |base, params, meta, subscriber| { - $crate::WrapSubscribe::wrap_rpc( - &(Self::$sub_method as fn(&_, Self::Metadata $(, $sub_p)*)), - base, - params, - meta, - subscriber, - ) - }), - ($unsubscribe, move |base, id| { - use $crate::jsonrpc_core::futures::{IntoFuture, Future}; - Self::$unsub_method(base, id).into_future() - .map($crate::to_value) - .map_err(Into::into) - }), - ); - - $( - $( - $del.add_alias($sub_alias, $subscribe); - )* - )* - $( - $( - $del.add_alias($unsub_alias, $unsubscribe); - )* - )* - }; - - ( WRAP $del: expr => - (name = $name: expr $(, alias = [ $( $alias: expr, )+ ])*) - fn $method: ident (&self $(, $param: ty)*) -> $result: tt <$out: ty $(, $error: ty)* > - ) => { - $del.add_method($name, move |base, params| { - $crate::WrapAsync::wrap_rpc(&(Self::$method as fn(&_ $(, $param)*) -> $result <$out $(, $error)*>), base, params) - }); - $( - $( - $del.add_alias($alias, $name); - )+ - )* - }; -} - -/// A wrapper type without an implementation of `Deserialize` -/// which allows a special implementation of `Wrap` for functions -/// that take a trailing default parameter. -pub struct Trailing(Option); - -impl Into> for Trailing { - fn into(self) -> Option { - self.0 - } -} - -impl From> for Trailing { - fn from(o: Option) -> Self { - Trailing(o) - } -} - -impl Trailing { - /// Returns a underlying value if present or provided value. - pub fn unwrap_or(self, other: T) -> T { - self.0.unwrap_or(other) - } - - /// Returns an underlying value or computes it if not present. - pub fn unwrap_or_else T>(self, f: F) -> T { - self.0.unwrap_or_else(f) - } -} - -impl Trailing { - /// Returns an underlying value or the default value. 
- pub fn unwrap_or_default(self) -> T { - self.0.unwrap_or_default() - } -} - -type WrappedFuture = future::MapErr< - future::Map Value>, - fn(E) -> Error ->; -type WrapResult = Either< - WrappedFuture, - future::FutureResult, ->; - -fn as_future(el: I) -> WrappedFuture where - OUT: Serialize, - E: Into, - F: Future, - I: IntoFuture -{ - el.into_future() - .map(to_value as fn(OUT) -> Value) - .map_err(Into::into as fn(E) -> Error) -} - -/// Wrapper trait for asynchronous RPC functions. -pub trait WrapAsync { - /// Output type. - type Out: IntoFuture; - - /// Invokes asynchronous RPC method. - fn wrap_rpc(&self, base: &B, params: Params) -> Self::Out; -} - -/// Wrapper trait for meta RPC functions. -pub trait WrapMeta { - /// Output type. - type Out: IntoFuture; - /// Invokes asynchronous RPC method with Metadata. - fn wrap_rpc(&self, base: &B, params: Params, meta: M) -> Self::Out; -} - -/// Wrapper trait for subscribe RPC functions. -pub trait WrapSubscribe { - /// Invokes subscription. - fn wrap_rpc(&self, base: &B, params: Params, meta: M, subscriber: Subscriber); -} - -// special impl for no parameters. -impl WrapAsync for fn(&B) -> I where - B: Send + Sync + 'static, - OUT: Serialize + 'static, - E: Into + 'static, - F: Future + Send + 'static, - I: IntoFuture, -{ - type Out = WrapResult; - - fn wrap_rpc(&self, base: &B, params: Params) -> Self::Out { - match expect_no_params(params) { - Ok(()) => Either::A(as_future((self)(base))), - Err(e) => Either::B(futures::failed(e)), - } - } -} - -impl WrapMeta for fn(&B, M) -> I where - M: Metadata, - B: Send + Sync + 'static, - OUT: Serialize + 'static, - E: Into + 'static, - F: Future + Send + 'static, - I: IntoFuture, -{ - type Out = WrapResult; - - fn wrap_rpc(&self, base: &B, params: Params, meta: M) -> Self::Out { - match expect_no_params(params) { - Ok(()) => Either::A(as_future((self)(base, meta))), - Err(e) => Either::B(futures::failed(e)), - } - } -} - -impl WrapSubscribe for fn(&B, M, pubsub::Subscriber) where - M: PubSubMetadata, - B: Send + Sync + 'static, - OUT: Serialize, -{ - fn wrap_rpc(&self, base: &B, params: Params, meta: M, subscriber: Subscriber) { - match expect_no_params(params) { - Ok(()) => (self)(base, meta, pubsub::Subscriber::new(subscriber)), - Err(e) => { - let _ = subscriber.reject(e); - }, - } - } -} - -// creates a wrapper implementation which deserializes the parameters, -// calls the function with concrete type, and serializes the output. -macro_rules! 
wrap { - ($($x: ident),+) => { - - // asynchronous implementation - impl < - BASE: Send + Sync + 'static, - OUT: Serialize + 'static, - $($x: DeserializeOwned,)+ - ERR: Into + 'static, - X: Future + Send + 'static, - Z: IntoFuture, - > WrapAsync for fn(&BASE, $($x,)+ ) -> Z { - type Out = WrapResult; - fn wrap_rpc(&self, base: &BASE, params: Params) -> Self::Out { - match params.parse::<($($x,)+)>() { - Ok(($($x,)+)) => Either::A(as_future((self)(base, $($x,)+))), - Err(e) => Either::B(futures::failed(e)), - } - } - } - - // asynchronous implementation with meta - impl < - BASE: Send + Sync + 'static, - META: Metadata, - OUT: Serialize + 'static, - $($x: DeserializeOwned,)+ - ERR: Into + 'static, - X: Future + Send + 'static, - Z: IntoFuture, - > WrapMeta for fn(&BASE, META, $($x,)+) -> Z { - type Out = WrapResult; - fn wrap_rpc(&self, base: &BASE, params: Params, meta: META) -> Self::Out { - match params.parse::<($($x,)+)>() { - Ok(($($x,)+)) => Either::A(as_future((self)(base, meta, $($x,)+))), - Err(e) => Either::B(futures::failed(e)), - } - } - } - - // subscribe implementation - impl < - BASE: Send + Sync + 'static, - META: PubSubMetadata, - OUT: Serialize, - $($x: DeserializeOwned,)+ - > WrapSubscribe for fn(&BASE, META, pubsub::Subscriber, $($x,)+) { - fn wrap_rpc(&self, base: &BASE, params: Params, meta: META, subscriber: Subscriber) { - match params.parse::<($($x,)+)>() { - Ok(($($x,)+)) => (self)(base, meta, pubsub::Subscriber::new(subscriber), $($x,)+), - Err(e) => { - let _ = subscriber.reject(e); - }, - } - } - } - } -} - -fn params_len(params: &Params) -> Result { - match *params { - Params::Array(ref v) => Ok(v.len()), - Params::None => Ok(0), - _ => Err(invalid_params("`params` should be an array", "")), - } -} - -fn require_len(params: &Params, required: usize) -> Result { - let len = params_len(params)?; - if len < required { - return Err(invalid_params(&format!("`params` should have at least {} argument(s)", required), "")); - } - Ok(len) -} - -fn parse_trailing_param(params: Params) -> Result<(Option, )> { - let len = try!(params_len(¶ms)); - let id = match len { - 0 => Ok((None,)), - 1 => params.parse::<(T,)>().map(|(x, )| (Some(x), )), - _ => Err(invalid_params("Expecting only one optional parameter.", "")), - }; - - id -} - -// special impl for no parameters other than block parameter. 
-impl WrapAsync for fn(&B, Trailing) -> I where - B: Send + Sync + 'static, - OUT: Serialize + 'static, - T: DeserializeOwned, - E: Into + 'static, - F: Future + Send + 'static, - I: IntoFuture, -{ - type Out = WrapResult; - fn wrap_rpc(&self, base: &B, params: Params) -> Self::Out { - let id = parse_trailing_param(params); - - match id { - Ok((id,)) => Either::A(as_future((self)(base, Trailing(id)))), - Err(e) => Either::B(futures::failed(e)), - } - } -} - -impl WrapMeta for fn(&B, M, Trailing) -> I where - M: Metadata, - B: Send + Sync + 'static, - OUT: Serialize + 'static, - T: DeserializeOwned, - E: Into + 'static, - F: Future + Send + 'static, - I: IntoFuture, -{ - type Out = WrapResult; - fn wrap_rpc(&self, base: &B, params: Params, meta: M) -> Self::Out { - let id = parse_trailing_param(params); - - match id { - Ok((id,)) => Either::A(as_future((self)(base, meta, Trailing(id)))), - Err(e) => Either::B(futures::failed(e)), - } - } -} - -impl WrapSubscribe for fn(&B, M, pubsub::Subscriber, Trailing) where - M: PubSubMetadata, - B: Send + Sync + 'static, - OUT: Serialize, - T: DeserializeOwned, -{ - fn wrap_rpc(&self, base: &B, params: Params, meta: M, subscriber: Subscriber) { - let id = parse_trailing_param(params); - - match id { - Ok((id,)) => (self)(base, meta, pubsub::Subscriber::new(subscriber), Trailing(id)), - Err(e) => { - let _ = subscriber.reject(e); - }, - } - } -} - -// similar to `wrap!`, but handles a single default trailing parameter -// accepts an additional argument indicating the number of non-trailing parameters. -macro_rules! wrap_with_trailing { - ($num: expr, $($x: ident),+) => { - // asynchronous implementation - impl < - BASE: Send + Sync + 'static, - OUT: Serialize + 'static, - $($x: DeserializeOwned,)+ - TRAILING: DeserializeOwned, - ERR: Into + 'static, - X: Future + Send + 'static, - Z: IntoFuture, - > WrapAsync for fn(&BASE, $($x,)+ Trailing) -> Z { - type Out = WrapResult; - fn wrap_rpc(&self, base: &BASE, params: Params) -> Self::Out { - let len = match require_len(¶ms, $num) { - Ok(len) => len, - Err(e) => return Either::B(futures::failed(e)), - }; - - let params = match len - $num { - 0 => params.parse::<($($x,)+)>() - .map(|($($x,)+)| ($($x,)+ None)).map_err(Into::into), - 1 => params.parse::<($($x,)+ TRAILING)>() - .map(|($($x,)+ id)| ($($x,)+ Some(id))).map_err(Into::into), - _ => Err(invalid_params(&format!("Expected {} or {} parameters.", $num, $num + 1), format!("Got: {}", len))), - }; - - match params { - Ok(($($x,)+ id)) => Either::A(as_future((self)(base, $($x,)+ Trailing(id)))), - Err(e) => Either::B(futures::failed(e)), - } - } - } - - // asynchronous implementation with meta - impl < - BASE: Send + Sync + 'static, - META: Metadata, - OUT: Serialize + 'static, - $($x: DeserializeOwned,)+ - TRAILING: DeserializeOwned, - ERR: Into + 'static, - X: Future + Send + 'static, - Z: IntoFuture, - > WrapMeta for fn(&BASE, META, $($x,)+ Trailing) -> Z { - type Out = WrapResult; - fn wrap_rpc(&self, base: &BASE, params: Params, meta: META) -> Self::Out { - let len = match require_len(¶ms, $num) { - Ok(len) => len, - Err(e) => return Either::B(futures::failed(e)), - }; - - let params = match len - $num { - 0 => params.parse::<($($x,)+)>() - .map(|($($x,)+)| ($($x,)+ None)).map_err(Into::into), - 1 => params.parse::<($($x,)+ TRAILING)>() - .map(|($($x,)+ id)| ($($x,)+ Some(id))).map_err(Into::into), - _ => Err(invalid_params(&format!("Expected {} or {} parameters.", $num, $num + 1), format!("Got: {}", len))), - }; - - match params { - Ok(($($x,)+ id)) => 
Either::A(as_future((self)(base, meta, $($x,)+ Trailing(id)))), - Err(e) => Either::B(futures::failed(e)), - } - } - } - - // subscribe implementation - impl < - BASE: Send + Sync + 'static, - META: PubSubMetadata, - OUT: Serialize, - $($x: DeserializeOwned,)+ - TRAILING: DeserializeOwned, - > WrapSubscribe for fn(&BASE, META, pubsub::Subscriber, $($x,)+ Trailing) { - fn wrap_rpc(&self, base: &BASE, params: Params, meta: META, subscriber: Subscriber) { - let len = match require_len(¶ms, $num) { - Ok(len) => len, - Err(e) => { - let _ = subscriber.reject(e); - return; - }, - }; - - let params = match len - $num { - 0 => params.parse::<($($x,)+)>() - .map(|($($x,)+)| ($($x,)+ None)), - 1 => params.parse::<($($x,)+ TRAILING)>() - .map(|($($x,)+ id)| ($($x,)+ Some(id))), - _ => { - let _ = subscriber.reject(invalid_params(&format!("Expected {} or {} parameters.", $num, $num + 1), format!("Got: {}", len))); - return; - }, - }; - - match params { - Ok(($($x,)+ id)) => (self)(base, meta, pubsub::Subscriber::new(subscriber), $($x,)+ Trailing(id)), - Err(e) => { - let _ = subscriber.reject(e); - return; - }, - } - } - } - } -} - -wrap!(A, B, C, D, E, F); -wrap!(A, B, C, D, E); -wrap!(A, B, C, D); -wrap!(A, B, C); -wrap!(A, B); -wrap!(A); - -wrap_with_trailing!(6, A, B, C, D, E, F); -wrap_with_trailing!(5, A, B, C, D, E); -wrap_with_trailing!(4, A, B, C, D); -wrap_with_trailing!(3, A, B, C); -wrap_with_trailing!(2, A, B); -wrap_with_trailing!(1, A); diff --git a/macros/src/delegates.rs b/macros/src/delegates.rs deleted file mode 100644 index 245536333..000000000 --- a/macros/src/delegates.rs +++ /dev/null @@ -1,213 +0,0 @@ -use std::sync::Arc; -use std::collections::HashMap; - -use jsonrpc_core::{Params, Value, Error}; -use jsonrpc_core::{BoxFuture, Metadata, RemoteProcedure, RpcMethod, RpcNotification}; -use jsonrpc_core::futures::IntoFuture; - -use jsonrpc_pubsub::{self, SubscriptionId, Subscriber, PubSubMetadata}; - -struct DelegateAsyncMethod { - delegate: Arc, - closure: F, -} - -impl RpcMethod for DelegateAsyncMethod where - M: Metadata, - F: Fn(&T, Params) -> I, - I: IntoFuture, - T: Send + Sync + 'static, - F: Send + Sync + 'static, - I::Future: Send + 'static, -{ - fn call(&self, params: Params, _meta: M) -> BoxFuture { - let closure = &self.closure; - Box::new(closure(&self.delegate, params).into_future()) - } -} - -struct DelegateMethodWithMeta { - delegate: Arc, - closure: F, -} - -impl RpcMethod for DelegateMethodWithMeta where - M: Metadata, - F: Fn(&T, Params, M) -> I, - I: IntoFuture, - T: Send + Sync + 'static, - F: Send + Sync + 'static, - I::Future: Send + 'static, -{ - fn call(&self, params: Params, meta: M) -> BoxFuture { - let closure = &self.closure; - Box::new(closure(&self.delegate, params, meta).into_future()) - } -} - -struct DelegateNotification { - delegate: Arc, - closure: F, -} - -impl RpcNotification for DelegateNotification where - F: Fn(&T, Params) + 'static, - F: Send + Sync + 'static, - T: Send + Sync + 'static, - M: Metadata, -{ - fn execute(&self, params: Params, _meta: M) { - let closure = &self.closure; - closure(&self.delegate, params) - } -} - -struct DelegateSubscribe { - delegate: Arc, - closure: F, -} - -impl jsonrpc_pubsub::SubscribeRpcMethod for DelegateSubscribe where - M: PubSubMetadata, - F: Fn(&T, Params, M, Subscriber), - T: Send + Sync + 'static, - F: Send + Sync + 'static, -{ - fn call(&self, params: Params, meta: M, subscriber: Subscriber) { - let closure = &self.closure; - closure(&self.delegate, params, meta, subscriber) - } -} - -struct 
DelegateUnsubscribe { - delegate: Arc, - closure: F, -} - -impl jsonrpc_pubsub::UnsubscribeRpcMethod for DelegateUnsubscribe where - F: Fn(&T, SubscriptionId) -> I, - I: IntoFuture, - T: Send + Sync + 'static, - F: Send + Sync + 'static, - I::Future: Send + 'static, -{ - type Out = I::Future; - fn call(&self, id: SubscriptionId) -> Self::Out { - let closure = &self.closure; - closure(&self.delegate, id).into_future() - } -} - -/// A set of RPC methods and notifications tied to single `delegate` struct. -pub struct IoDelegate where - T: Send + Sync + 'static, - M: Metadata, -{ - delegate: Arc, - methods: HashMap>, -} - -impl IoDelegate where - T: Send + Sync + 'static, - M: Metadata, -{ - /// Creates new `IoDelegate` - pub fn new(delegate: Arc) -> Self { - IoDelegate { - delegate: delegate, - methods: HashMap::new(), - } - } - - /// Adds an alias to existing method. - /// NOTE: Aliases are not transitive, i.e. you cannot create alias to an alias. - pub fn add_alias(&mut self, from: &str, to: &str) { - self.methods.insert(from.into(), RemoteProcedure::Alias(to.into())); - } - - /// Adds async method to the delegate. - pub fn add_method(&mut self, name: &str, method: F) where - F: Fn(&T, Params) -> I, - I: IntoFuture, - F: Send + Sync + 'static, - I::Future: Send + 'static, - { - self.methods.insert(name.into(), RemoteProcedure::Method(Arc::new( - DelegateAsyncMethod { - delegate: self.delegate.clone(), - closure: method, - } - ))); - } - - /// Adds async method with metadata to the delegate. - pub fn add_method_with_meta(&mut self, name: &str, method: F) where - F: Fn(&T, Params, M) -> I, - I: IntoFuture, - F: Send + Sync + 'static, - I::Future: Send + 'static, - { - self.methods.insert(name.into(), RemoteProcedure::Method(Arc::new( - DelegateMethodWithMeta { - delegate: self.delegate.clone(), - closure: method, - } - ))); - } - - /// Adds notification to the delegate. - pub fn add_notification(&mut self, name: &str, notification: F) where - F: Fn(&T, Params), - F: Send + Sync + 'static, - { - self.methods.insert(name.into(), RemoteProcedure::Notification(Arc::new( - DelegateNotification { - delegate: self.delegate.clone(), - closure: notification, - } - ))); - } -} - -impl IoDelegate where - T: Send + Sync + 'static, - M: PubSubMetadata, -{ - /// Adds subscription to the delegate. - pub fn add_subscription( - &mut self, - name: &str, - subscribe: (&str, Sub), - unsubscribe: (&str, Unsub), - ) where - Sub: Fn(&T, Params, M, Subscriber), - Sub: Send + Sync + 'static, - Unsub: Fn(&T, SubscriptionId) -> I, - I: IntoFuture, - Unsub: Send + Sync + 'static, - I::Future: Send + 'static, - { - let (sub, unsub) = jsonrpc_pubsub::new_subscription( - name, - DelegateSubscribe { - delegate: self.delegate.clone(), - closure: subscribe.1, - }, - DelegateUnsubscribe { - delegate: self.delegate.clone(), - closure: unsubscribe.1, - } - ); - self.add_method_with_meta(subscribe.0, move |_, params, meta| sub.call(params, meta)); - self.add_method_with_meta(unsubscribe.0, move |_, params, meta| unsub.call(params, meta)); - } -} - -impl Into>> for IoDelegate where - T: Send + Sync + 'static, - M: Metadata, -{ - fn into(self) -> HashMap> { - self.methods - } -} diff --git a/macros/src/lib.rs b/macros/src/lib.rs deleted file mode 100644 index 6d6989c03..000000000 --- a/macros/src/lib.rs +++ /dev/null @@ -1,70 +0,0 @@ -//! High level, typed wrapper for `jsonrpc_core`. -//! -//! Enables creation of "Service" objects grouping a set of RPC methods together in a typed manner. -//! -//! Example -//! -//! ``` -//! 
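The `IoDelegate` above is essentially an `Arc<T>` paired with user closures, stored in a name-keyed map and invoked by method name. A stripped-down sketch of that delegate pattern with plain `std` types (synchronous, and with made-up names that are not part of the library):

```rust
use std::collections::HashMap;
use std::sync::Arc;

// The "delegate": shared application state the registered closures borrow.
struct Counter {
    start: u64,
}

// Each registered method is a boxed closure that has already captured the Arc.
type Method = Box<dyn Fn(u64) -> u64 + Send + Sync>;

struct Delegate<T> {
    inner: Arc<T>,
    methods: HashMap<String, Method>,
}

impl<T: Send + Sync + 'static> Delegate<T> {
    fn new(inner: Arc<T>) -> Self {
        Delegate { inner, methods: HashMap::new() }
    }

    /// Store a `(&T, param) -> result` closure under `name`, bound to the shared state.
    fn add_method<F>(&mut self, name: &str, f: F)
    where
        F: Fn(&T, u64) -> u64 + Send + Sync + 'static,
    {
        let inner = self.inner.clone();
        self.methods.insert(name.to_owned(), Box::new(move |p| f(&inner, p)));
    }

    fn call(&self, name: &str, param: u64) -> Option<u64> {
        self.methods.get(name).map(|m| m(param))
    }
}

fn main() {
    let mut delegate = Delegate::new(Arc::new(Counter { start: 10 }));
    delegate.add_method("add", |c, p| c.start + p);
    assert_eq!(delegate.call("add", 5), Some(15));
    assert_eq!(delegate.call("missing", 5), None);
}
```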
extern crate jsonrpc_core; -//! #[macro_use] extern crate jsonrpc_macros; -//! use jsonrpc_core::{IoHandler, Error, Result}; -//! use jsonrpc_core::futures::future::{self, FutureResult}; -//! build_rpc_trait! { -//! pub trait Rpc { -//! /// Returns a protocol version -//! #[rpc(name = "protocolVersion")] -//! fn protocol_version(&self) -> Result; -//! -//! /// Adds two numbers and returns a result -//! #[rpc(name = "add")] -//! fn add(&self, u64, u64) -> Result; -//! -//! /// Performs asynchronous operation -//! #[rpc(name = "callAsync")] -//! fn call(&self, u64) -> FutureResult; -//! } -//! } -//! struct RpcImpl; -//! impl Rpc for RpcImpl { -//! fn protocol_version(&self) -> Result { -//! Ok("version1".into()) -//! } -//! -//! fn add(&self, a: u64, b: u64) -> Result { -//! Ok(a + b) -//! } -//! -//! fn call(&self, _: u64) -> FutureResult { -//! future::ok("OK".to_owned()).into() -//! } -//! } -//! -//! fn main() { -//! let mut io = IoHandler::new(); -//! let rpc = RpcImpl; -//! -//! io.extend_with(rpc.to_delegate()); -//! } -//! ``` - -#![warn(missing_docs)] - -pub extern crate jsonrpc_core; -pub extern crate jsonrpc_pubsub; -extern crate serde; - -mod auto_args; -mod delegates; -mod util; - -pub mod pubsub; - -#[doc(hidden)] -pub use auto_args::{WrapAsync, WrapMeta, WrapSubscribe}; - -#[doc(hidden)] -pub use serde::{de::DeserializeOwned, Serialize}; - -pub use auto_args::Trailing; -pub use delegates::IoDelegate; -pub use util::to_value; diff --git a/macros/src/pubsub.rs b/macros/src/pubsub.rs deleted file mode 100644 index 5007feaab..000000000 --- a/macros/src/pubsub.rs +++ /dev/null @@ -1,131 +0,0 @@ -//! PUB-SUB auto-serializing structures. - -use std::marker::PhantomData; - -use jsonrpc_core as core; -use jsonrpc_pubsub as pubsub; -use serde; -use util::to_value; - -use self::core::futures::{self, Sink as FuturesSink, sync}; - -pub use self::pubsub::SubscriptionId; - -/// New PUB-SUB subcriber. -#[derive(Debug)] -pub struct Subscriber { - subscriber: pubsub::Subscriber, - _data: PhantomData<(T, E)>, -} - -impl Subscriber { - /// Wrap non-typed subscriber. - pub fn new(subscriber: pubsub::Subscriber) -> Self { - Subscriber { - subscriber: subscriber, - _data: PhantomData, - } - } - - /// Create new subscriber for tests. - pub fn new_test>(method: M) -> ( - Self, - sync::oneshot::Receiver>, - sync::mpsc::Receiver, - ) { - let (subscriber, id, subscription) = pubsub::Subscriber::new_test(method); - (Subscriber::new(subscriber), id, subscription) - } - - /// Reject subscription with given error. - pub fn reject(self, error: core::Error) -> Result<(), ()> { - self.subscriber.reject(error) - } - - /// Assign id to this subscriber. - /// This method consumes `Subscriber` and returns `Sink` - /// if the connection is still open or error otherwise. - pub fn assign_id(self, id: SubscriptionId) -> Result, ()> { - let sink = self.subscriber.assign_id(id.clone())?; - Ok(Sink { - id: id, - sink: sink, - buffered: None, - _data: PhantomData, - }) - } -} - -/// Subscriber sink. -#[derive(Debug, Clone)] -pub struct Sink { - sink: pubsub::Sink, - id: SubscriptionId, - buffered: Option, - _data: PhantomData<(T, E)>, -} - -impl Sink { - /// Sends a notification to the subscriber. 
- pub fn notify(&self, val: Result) -> pubsub::SinkResult { - self.sink.notify(self.val_to_params(val)) - } - - fn val_to_params(&self, val: Result) -> core::Params { - let id = self.id.clone().into(); - let val = val.map(to_value).map_err(to_value); - - core::Params::Map(vec![ - ("subscription".to_owned(), id), - match val { - Ok(val) => ("result".to_owned(), val), - Err(err) => ("error".to_owned(), err), - }, - ].into_iter().collect()) - } - - fn poll(&mut self) -> futures::Poll<(), pubsub::TransportError> { - if let Some(item) = self.buffered.take() { - let result = self.sink.start_send(item)?; - if let futures::AsyncSink::NotReady(item) = result { - self.buffered = Some(item); - } - } - - if self.buffered.is_some() { - Ok(futures::Async::NotReady) - } else { - Ok(futures::Async::Ready(())) - } - } -} - -impl futures::sink::Sink for Sink { - type SinkItem = Result; - type SinkError = pubsub::TransportError; - - fn start_send(&mut self, item: Self::SinkItem) -> futures::StartSend { - // Make sure to always try to process the buffered entry. - // Since we're just a proxy to real `Sink` we don't need - // to schedule a `Task` wakeup. It will be done downstream. - if self.poll()?.is_not_ready() { - return Ok(futures::AsyncSink::NotReady(item)); - } - - let val = self.val_to_params(item); - self.buffered = Some(val); - self.poll()?; - - Ok(futures::AsyncSink::Ready) - } - - fn poll_complete(&mut self) -> futures::Poll<(), Self::SinkError> { - self.poll()?; - self.sink.poll_complete() - } - - fn close(&mut self) -> futures::Poll<(), Self::SinkError> { - self.poll()?; - self.sink.close() - } -} diff --git a/macros/src/util.rs b/macros/src/util.rs deleted file mode 100644 index 4b06f5afe..000000000 --- a/macros/src/util.rs +++ /dev/null @@ -1,28 +0,0 @@ -//! Param & Value utilities - -use std::fmt; -use jsonrpc_core::{self as core, Error, Params, ErrorCode, Value}; -use serde; - -/// Returns an `InvalidParams` for given parameter. -pub fn invalid_params(param: &str, details: T) -> Error where T: fmt::Debug { - Error { - code: ErrorCode::InvalidParams, - message: format!("Couldn't parse parameters: {}", param), - data: Some(Value::String(format!("{:?}", details))), - } -} - -/// Validates if the method was invoked without any params. -pub fn expect_no_params(params: Params) -> core::Result<()> { - match params { - Params::None => Ok(()), - Params::Array(ref v) if v.is_empty() => Ok(()), - p => Err(invalid_params("No parameters were expected", p)), - } -} - -/// Converts a serializable value into `Value`. -pub fn to_value(value: T) -> Value where T: serde::Serialize { - core::to_value(value).expect("Expected always-serializable type.") -} diff --git a/macros/tests/macros.rs b/macros/tests/macros.rs deleted file mode 100644 index b155127a7..000000000 --- a/macros/tests/macros.rs +++ /dev/null @@ -1,73 +0,0 @@ - -extern crate serde_json; -extern crate jsonrpc_core; -#[macro_use] -extern crate jsonrpc_macros; - -use jsonrpc_core::{IoHandler, Response}; - -pub enum MyError {} -impl From for jsonrpc_core::Error { - fn from(_e: MyError) -> Self { - unreachable!() - } -} - -type Result = ::std::result::Result; - -build_rpc_trait! 
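`val_to_params` above packs every notification as a map holding the subscription id plus either a `result` or an `error` entry. A small sketch of that wire shape built with `serde_json` (the helper function is hypothetical, not part of the crate):

```rust
use serde_json::{json, Value};

/// Build the `params` object of a JSON-RPC subscription notification.
fn subscription_params(subscription_id: u64, outcome: Result<Value, Value>) -> Value {
    match outcome {
        Ok(result) => json!({ "subscription": subscription_id, "result": result }),
        Err(error) => json!({ "subscription": subscription_id, "error": error }),
    }
}

fn main() {
    let ok = subscription_params(5, Ok(json!(10)));
    assert_eq!(ok, json!({ "subscription": 5, "result": 10 }));

    let err = subscription_params(5, Err(json!({ "code": -32000, "message": "boom" })));
    assert_eq!(err["error"]["code"], json!(-32000));
}
```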
{ - pub trait Rpc { - /// Returns a protocol version - #[rpc(name = "protocolVersion")] - fn protocol_version(&self) -> Result; - - /// Adds two numbers and returns a result - #[rpc(name = "add")] - fn add(&self, u64, u64) -> Result; - } -} - -#[derive(Default)] -struct RpcImpl; - -impl Rpc for RpcImpl { - fn protocol_version(&self) -> Result { - Ok("version1".into()) - } - - fn add(&self, a: u64, b: u64) -> Result { - Ok(a + b) - } -} - -#[test] -fn should_accept_empty_array_as_no_params() { - let mut io = IoHandler::new(); - let rpc = RpcImpl::default(); - io.extend_with(rpc.to_delegate()); - - // when - let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"protocolVersion","params":[]}"#; - let req2 = r#"{"jsonrpc":"2.0","id":1,"method":"protocolVersion","params":null}"#; - let req3 = r#"{"jsonrpc":"2.0","id":1,"method":"protocolVersion"}"#; - - let res1 = io.handle_request_sync(req1); - let res2 = io.handle_request_sync(req2); - let res3 = io.handle_request_sync(req3); - let expected = r#"{ - "jsonrpc": "2.0", - "result": "version1", - "id": 1 - }"#; - let expected: Response = serde_json::from_str(expected).unwrap(); - - // then - let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap(); - assert_eq!(expected, result1); - - let result2: Response = serde_json::from_str(&res2.unwrap()).unwrap(); - assert_eq!(expected, result2); - - let result3: Response = serde_json::from_str(&res3.unwrap()).unwrap(); - assert_eq!(expected, result3); -} diff --git a/macros/tests/pubsub-macros.rs b/macros/tests/pubsub-macros.rs deleted file mode 100644 index 609321486..000000000 --- a/macros/tests/pubsub-macros.rs +++ /dev/null @@ -1,85 +0,0 @@ -extern crate serde_json; -extern crate jsonrpc_core; -extern crate jsonrpc_pubsub; -#[macro_use] -extern crate jsonrpc_macros; - -use std::sync::Arc; -use jsonrpc_core::futures::sync::mpsc; -use jsonrpc_pubsub::{PubSubHandler, SubscriptionId, Session, PubSubMetadata}; -use jsonrpc_macros::{pubsub, Trailing}; - -pub enum MyError {} -impl From for jsonrpc_core::Error { - fn from(_e: MyError) -> Self { - unreachable!() - } -} - -type Result = ::std::result::Result; - -build_rpc_trait! { - pub trait Rpc { - type Metadata; - - #[pubsub(name = "hello")] { - /// Hello subscription - #[rpc(name = "hello_subscribe")] - fn subscribe(&self, Self::Metadata, pubsub::Subscriber, u32, Trailing); - - /// Unsubscribe from hello subscription. 
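The test above compares responses as parsed `Response` values rather than raw strings, so key order and whitespace differences cannot cause spurious failures. The same idea with plain `serde_json::Value`, for cases where the typed `Response` struct is not at hand (a sketch, not the crate's test helper):

```rust
use serde_json::{json, Value};

fn assert_same_response(actual: &str, expected: &str) {
    let actual: Value = serde_json::from_str(actual).expect("actual response is valid JSON");
    let expected: Value = serde_json::from_str(expected).expect("expected response is valid JSON");
    assert_eq!(actual, expected);
}

fn main() {
    // Key order and whitespace differ, but the parsed values are equal.
    assert_same_response(
        r#"{"jsonrpc":"2.0","result":"version1","id":1}"#,
        r#"{ "id": 1, "result": "version1", "jsonrpc": "2.0" }"#,
    );

    // The values themselves still have to match.
    let a: Value = serde_json::from_str(r#"{"id":1}"#).unwrap();
    assert_ne!(a, json!({ "id": 2 }));
}
```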
- #[rpc(name = "hello_unsubscribe")] - fn unsubscribe(&self, SubscriptionId) -> Result; - } - } -} - -#[derive(Default)] -struct RpcImpl; - -impl Rpc for RpcImpl { - type Metadata = Metadata; - - fn subscribe(&self, _meta: Self::Metadata, subscriber: pubsub::Subscriber, _pre: u32, _trailing: Trailing) { - let _sink = subscriber.assign_id(SubscriptionId::Number(5)); - } - - fn unsubscribe(&self, _id: SubscriptionId) -> Result { - Ok(true) - } -} - -#[derive(Clone, Default)] -struct Metadata; -impl jsonrpc_core::Metadata for Metadata {} -impl PubSubMetadata for Metadata { - fn session(&self) -> Option> { - let (tx, _rx) = mpsc::channel(1); - Some(Arc::new(Session::new(tx))) - } -} - -#[test] -fn test_invalid_trailing_pubsub_params() { - let mut io = PubSubHandler::default(); - let rpc = RpcImpl::default(); - io.extend_with(rpc.to_delegate()); - - // when - let meta = Metadata; - let req = r#"{"jsonrpc":"2.0","id":1,"method":"hello_subscribe","params":[]}"#; - let res = io.handle_request_sync(req, meta); - let expected = r#"{ - "jsonrpc": "2.0", - "error": { - "code": -32602, - "message": "Couldn't parse parameters: `params` should have at least 1 argument(s)", - "data": "\"\"" - }, - "id": 1 - }"#; - - let expected: jsonrpc_core::Response = serde_json::from_str(expected).unwrap(); - let result: jsonrpc_core::Response = serde_json::from_str(&res.unwrap()).unwrap(); - assert_eq!(expected, result); -} diff --git a/minihttp/Cargo.toml b/minihttp/Cargo.toml deleted file mode 100644 index 84d52a64c..000000000 --- a/minihttp/Cargo.toml +++ /dev/null @@ -1,27 +0,0 @@ -[package] -description = "Blazing fast http server for JSON-RPC 2.0." -homepage = "https://github.com/paritytech/jsonrpc" -repository = "https://github.com/paritytech/jsonrpc" -license = "MIT" -name = "jsonrpc-minihttp-server" -version = "9.0.0" -authors = ["tomusdrw "] -keywords = ["jsonrpc", "json-rpc", "json", "rpc", "server"] -documentation = "https://paritytech.github.io/jsonrpc/jsonrpc_minihttp_server/index.html" - -[dependencies] -bytes = "0.4" -jsonrpc-core = { version = "9.0", path = "../core" } -jsonrpc-server-utils = { version = "9.0", path = "../server-utils" } -log = "0.4" -parking_lot = "0.6" -tokio-minihttp = { git = "https://github.com/tomusdrw/tokio-minihttp" } -tokio-proto = { git = "https://github.com/tomusdrw/tokio-proto" } -tokio-service = "0.1" - -[dev-dependencies] -env_logger = "0.6" -reqwest = "0.6" - -[badges] -travis-ci = { repository = "paritytech/jsonrpc", branch = "master"} diff --git a/minihttp/README.md b/minihttp/README.md deleted file mode 100644 index fa692b078..000000000 --- a/minihttp/README.md +++ /dev/null @@ -1,37 +0,0 @@ -# jsonrpc-minihttp-server -Blazing fast HTTP server for JSON-RPC 2.0. 
- -[Documentation](http://paritytech.github.io/jsonrpc/jsonrpc_http_server/index.html) - -## Example - -`Cargo.toml` - -``` -[dependencies] -jsonrpc-minihttp-server = { git = "https://github.com/paritytech/jsonrpc" } -``` - -`main.rs` - -```rust -extern crate jsonrpc_minihttp_server; - -use jsonrpc_minihttp_server::*; -use jsonrpc_minihttp_server::jsonrpc_core::*; -use jsonrpc_minihttp_server::cors::AccessControlAllowOrigin; - -fn main() { - let mut io = IoHandler::default(); - io.add_method("say_hello", |_| { - Ok(Value::String("hello".into())) - }); - - let server = ServerBuilder::new(io) - .cors(DomainsValidation::AllowOnly(vec![AccessControlAllowOrigin::Null])) - .start_http(&"127.0.0.1:3030".parse().unwrap()) - .expect("Unable to start RPC server"); - - server.wait().unwrap(); -} -``` diff --git a/minihttp/examples/http_async.rs b/minihttp/examples/http_async.rs deleted file mode 100644 index 1fbe52e48..000000000 --- a/minihttp/examples/http_async.rs +++ /dev/null @@ -1,19 +0,0 @@ -extern crate jsonrpc_minihttp_server; - -use jsonrpc_minihttp_server::{cors, ServerBuilder, DomainsValidation}; -use jsonrpc_minihttp_server::jsonrpc_core::*; - -fn main() { - let mut io = IoHandler::default(); - io.add_method("say_hello", |_params| { - futures::finished(Value::String("hello".to_owned())) - }); - - let server = ServerBuilder::new(io) - .cors(DomainsValidation::AllowOnly(vec![cors::AccessControlAllowOrigin::Null])) - .start_http(&"127.0.0.1:3030".parse().unwrap()) - .expect("Unable to start RPC server"); - - server.wait().unwrap(); -} - diff --git a/minihttp/examples/http_meta.rs b/minihttp/examples/http_meta.rs deleted file mode 100644 index 630c9fb11..000000000 --- a/minihttp/examples/http_meta.rs +++ /dev/null @@ -1,26 +0,0 @@ -extern crate jsonrpc_minihttp_server; - -use jsonrpc_minihttp_server::{cors, ServerBuilder, DomainsValidation, Req}; -use jsonrpc_minihttp_server::jsonrpc_core::*; - -#[derive(Clone, Default)] -struct Meta(usize); -impl Metadata for Meta {} - -fn main() { - let mut io = MetaIoHandler::default(); - io.add_method_with_meta("say_hello", |_params: Params, meta: Meta| { - futures::finished(Value::String(format!("hello: {}", meta.0))) - }); - - let server = ServerBuilder::new(io) - .meta_extractor(|req: &Req| { - Meta(req.header("Origin").map(|v| v.len()).unwrap_or_default()) - }) - .cors(DomainsValidation::AllowOnly(vec![cors::AccessControlAllowOrigin::Null])) - .start_http(&"127.0.0.1:3030".parse().unwrap()) - .expect("Unable to start RPC server"); - - server.wait().unwrap(); -} - diff --git a/minihttp/examples/server.rs b/minihttp/examples/server.rs deleted file mode 100644 index 63a3e6065..000000000 --- a/minihttp/examples/server.rs +++ /dev/null @@ -1,20 +0,0 @@ -extern crate jsonrpc_minihttp_server; - -use jsonrpc_minihttp_server::{cors, ServerBuilder, DomainsValidation}; -use jsonrpc_minihttp_server::jsonrpc_core::*; - -fn main() { - let mut io = IoHandler::default(); - io.add_method("say_hello", |_params: Params| { - Ok(Value::String("hello".to_string())) - }); - - let server = ServerBuilder::new(io) - .threads(3) - .cors(DomainsValidation::AllowOnly(vec![cors::AccessControlAllowOrigin::Null])) - .start_http(&"127.0.0.1:3030".parse().unwrap()) - .expect("Unable to start RPC server"); - - server.wait().unwrap(); -} - diff --git a/minihttp/src/lib.rs b/minihttp/src/lib.rs deleted file mode 100644 index 9f7920fd6..000000000 --- a/minihttp/src/lib.rs +++ /dev/null @@ -1,358 +0,0 @@ -//! jsonrpc http server. -//! -//! ```no_run -//! extern crate jsonrpc_core; -//! 
extern crate jsonrpc_minihttp_server; -//! -//! use jsonrpc_core::*; -//! use jsonrpc_minihttp_server::*; -//! -//! fn main() { -//! let mut io = IoHandler::new(); -//! io.add_method("say_hello", |_: Params| { -//! Ok(Value::String("hello".to_string())) -//! }); -//! -//! let _server = ServerBuilder::new(io).start_http(&"127.0.0.1:3030".parse().unwrap()); -//! } -//! ``` - -#![warn(missing_docs)] - -extern crate bytes; -extern crate jsonrpc_server_utils; -extern crate parking_lot; -extern crate tokio_proto; -extern crate tokio_service; -extern crate tokio_minihttp; - -pub extern crate jsonrpc_core; - -#[macro_use] -extern crate log; - -mod req; -mod res; -#[cfg(test)] -mod tests; - -use std::io; -use std::sync::{Arc, mpsc}; -use std::net::SocketAddr; -use std::thread; -use parking_lot::RwLock; -use jsonrpc_core as jsonrpc; -use jsonrpc::futures::{self, future, Future}; -use jsonrpc::{FutureResult, MetaIoHandler, Response, Output}; -use jsonrpc_server_utils::hosts; - -pub use jsonrpc_server_utils::cors; -pub use jsonrpc_server_utils::hosts::{Host, DomainsValidation}; -pub use req::Req; - -/// Extracts metadata from the HTTP request. -pub trait MetaExtractor: Sync + Send + 'static { - /// Read the metadata from the request - fn read_metadata(&self, _: &req::Req) -> M; -} - -impl MetaExtractor for F where - M: jsonrpc::Metadata, - F: Fn(&req::Req) -> M + Sync + Send + 'static, -{ - fn read_metadata(&self, req: &req::Req) -> M { - (*self)(req) - } -} - -#[derive(Default)] -struct NoopExtractor; -impl MetaExtractor for NoopExtractor { - fn read_metadata(&self, _: &req::Req) -> M { - M::default() - } -} - -/// Convenient JSON-RPC HTTP Server builder. -pub struct ServerBuilder = jsonrpc::middleware::Noop> { - jsonrpc_handler: Arc>, - meta_extractor: Arc>, - cors_domains: Option>, - allowed_hosts: Option>, - threads: usize, -} - -const SENDER_PROOF: &'static str = "Server initialization awaits local address."; - -impl> ServerBuilder { - /// Creates new `ServerBuilder` for given `IoHandler`. - /// - /// By default: - /// 1. Server is not sending any CORS headers. - /// 2. Server is validating `Host` header. - pub fn new(handler: T) -> Self where T: Into> { - Self::with_meta_extractor(handler, NoopExtractor) - } -} - -impl> ServerBuilder { - /// Creates new `ServerBuilder` for given `IoHandler` and meta extractor. - /// - /// By default: - /// 1. Server is not sending any CORS headers. - /// 2. Server is validating `Host` header. - pub fn with_meta_extractor(handler: T, extractor: E) -> Self where - T: Into>, - E: MetaExtractor, - { - ServerBuilder { - jsonrpc_handler: Arc::new(handler.into()), - meta_extractor: Arc::new(extractor), - cors_domains: None, - allowed_hosts: None, - threads: 1, - } - } - - /// Sets number of threads of the server to run. (not available for windows) - /// Panics when set to `0`. - pub fn threads(mut self, threads: usize) -> Self { - assert!(threads > 0); - self.threads = threads; - self - } - - /// Configures a list of allowed CORS origins. - pub fn cors(mut self, cors_domains: DomainsValidation) -> Self { - self.cors_domains = cors_domains.into(); - self - } - - /// Configures metadata extractor - pub fn meta_extractor>(mut self, extractor: T) -> Self { - self.meta_extractor = Arc::new(extractor); - self - } - - /// Allow connections only with `Host` header set to binding address. - pub fn allow_only_bind_host(mut self) -> Self { - self.allowed_hosts = Some(Vec::new()); - self - } - - /// Specify a list of valid `Host` headers. Binding address is allowed automatically. 
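The blanket `MetaExtractor` impl above is what lets `ServerBuilder::meta_extractor` accept a bare closure such as `|req: &Req| Meta(...)`. The pattern, reduced to a standalone sketch with placeholder types (none of these names come from the server crate):

```rust
/// A placeholder request type standing in for the server's `Req`.
struct Request {
    origin: Option<String>,
}

#[derive(Debug, PartialEq)]
struct Meta(usize);

trait Extract<M>: Send + Sync + 'static {
    fn extract(&self, req: &Request) -> M;
}

// Blanket impl: any closure with the matching shape is already an extractor.
impl<M, F> Extract<M> for F
where
    F: Fn(&Request) -> M + Send + Sync + 'static,
{
    fn extract(&self, req: &Request) -> M {
        (self)(req)
    }
}

/// Roughly what a builder does with whatever extractor it was given.
fn read<M, E: Extract<M>>(extractor: &E, req: &Request) -> M {
    extractor.extract(req)
}

fn main() {
    let by_origin_len = |req: &Request| Meta(req.origin.as_ref().map(|o| o.len()).unwrap_or(0));

    let req = Request { origin: Some("http://parity.io".into()) };
    assert_eq!(read(&by_origin_len, &req), Meta(16));

    let bare = Request { origin: None };
    assert_eq!(read(&by_origin_len, &bare), Meta(0));
}
```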
- pub fn allowed_hosts(mut self, allowed_hosts: DomainsValidation) -> Self { - self.allowed_hosts = allowed_hosts.into(); - self - } - - /// Start this JSON-RPC HTTP server trying to bind to specified `SocketAddr`. - pub fn start_http(self, addr: &SocketAddr) -> io::Result { - let cors_domains = self.cors_domains; - let allowed_hosts = self.allowed_hosts; - let handler = self.jsonrpc_handler; - let meta_extractor = self.meta_extractor; - let threads = self.threads; - - let (local_addr_tx, local_addr_rx) = mpsc::channel(); - let (close, shutdown_signal) = futures::sync::oneshot::channel(); - let addr = addr.to_owned(); - let handle = thread::spawn(move || { - let run = move || { - let hosts = Arc::new(RwLock::new(allowed_hosts.clone())); - let hosts2 = hosts.clone(); - let mut hosts_setter = hosts2.write(); - - let mut server = tokio_proto::TcpServer::new(tokio_minihttp::Http, addr); - server.threads(threads); - let server = server.bind(move || Ok(RpcService { - handler: handler.clone(), - meta_extractor: meta_extractor.clone(), - hosts: hosts.read().clone(), - cors_domains: cors_domains.clone(), - }))?; - - let local_addr = server.local_addr()?; - // Add current host to allowed headers. - // NOTE: we need to use `local_address` instead of `addr` - // it might be different! - *hosts_setter = hosts::update(allowed_hosts, &local_addr); - - Ok((server, local_addr)) - }; - - match run() { - Ok((server, local_addr)) => { - // Send local address - local_addr_tx.send(Ok(local_addr)).expect(SENDER_PROOF); - - // Start the server and wait for shutdown signal - server.run_until(shutdown_signal.map_err(|_| { - warn!("Shutdown signaller dropped, closing server."); - })).expect("Expected clean shutdown.") - }, - Err(err) => { - // Send error - local_addr_tx.send(Err(err)).expect(SENDER_PROOF); - } - } - }); - - // Wait for server initialization - let local_addr: io::Result = local_addr_rx.recv().map_err(|_| { - io::Error::new(io::ErrorKind::Interrupted, "") - })?; - - Ok(Server { - address: local_addr?, - handle: Some(handle), - close: Some(close), - }) - } -} - -/// Tokio-proto JSON-RPC HTTP Service -pub struct RpcService> { - handler: Arc>, - meta_extractor: Arc>, - hosts: Option>, - cors_domains: Option>, -} - -fn is_json(content_type: Option<&str>) -> bool { - match content_type { - None => false, - Some(ref content_type) => { - let json = "application/json"; - content_type.eq_ignore_ascii_case(json) - } - } -} - -impl> tokio_service::Service for RpcService { - type Request = tokio_minihttp::Request; - type Response = tokio_minihttp::Response; - type Error = io::Error; - type Future = future::Either< - future::FutureResult, - RpcResponse, - >; - - fn call(&self, request: Self::Request) -> Self::Future { - use self::future::Either; - - let request = req::Req::new(request); - // Validate HTTP Method - let is_options = request.method() == req::Method::Options; - if !is_options && request.method() != req::Method::Post { - return Either::A(future::ok( - res::method_not_allowed() - )); - } - - // Validate allowed hosts - let host = request.header("Host"); - if !hosts::is_host_valid(host.clone(), &self.hosts) { - return Either::A(future::ok( - res::invalid_host() - )); - } - - // Extract CORS headers - let origin = request.header("Origin"); - let cors = cors::get_cors_allow_origin(origin, host, &self.cors_domains); - - // Validate cors header - if let cors::AllowCors::Invalid = cors { - return Either::A(future::ok( - res::invalid_allow_origin() - )); - } - - // Don't process data if it's OPTIONS - if 
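`start_http` above spawns the event loop on a background thread and reports the bound address, or the bind error, back over a channel before handing out a `Server` handle. A minimal sketch of that startup handshake, assuming a plain `std::net::TcpListener` in place of the tokio stack:

```rust
use std::net::{SocketAddr, TcpListener};
use std::sync::mpsc;
use std::thread;

fn start(addr: &str) -> std::io::Result<(SocketAddr, thread::JoinHandle<()>)> {
    let (addr_tx, addr_rx) = mpsc::channel();
    let addr = addr.to_owned();

    let handle = thread::spawn(move || {
        // Bind inside the server thread and report the outcome back to the builder.
        match TcpListener::bind(&addr) {
            Ok(listener) => {
                // Port 0 means "pick a free port", so report the *local* address.
                let local = listener.local_addr().expect("bound listener has an address");
                addr_tx.send(Ok(local)).expect("builder is waiting for the address");
                // A real server would now accept connections until told to stop.
            }
            Err(err) => {
                addr_tx.send(Err(err)).expect("builder is waiting for the address");
            }
        }
    });

    let local_addr = addr_rx.recv().map_err(|_| {
        std::io::Error::new(std::io::ErrorKind::Interrupted, "server thread died before binding")
    })??;
    Ok((local_addr, handle))
}

fn main() -> std::io::Result<()> {
    let (addr, handle) = start("127.0.0.1:0")?;
    println!("listening on {}", addr);
    handle.join().expect("server thread panicked");
    Ok(())
}
```

Reporting the local address (rather than the requested one) is what makes binding to port 0 usable in tests.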
is_options { - return Either::A(future::ok( - res::options(cors.into()) - )); - } - - // Validate content type - let content_type = request.header("Content-type"); - if !is_json(content_type) { - return Either::A(future::ok( - res::invalid_content_type() - )); - } - - // Extract metadata - let metadata = self.meta_extractor.read_metadata(&request); - - // Read & handle request - let data = request.body(); - let future = self.handler.handle_request(data, metadata); - Either::B(RpcResponse { - future: future, - cors: cors.into(), - }) - } -} - -/// RPC response wrapper -pub struct RpcResponse where - F: Future, Error = ()>, - G: Future, Error = ()>, -{ - future: FutureResult, - cors: Option, -} - -impl Future for RpcResponse where - F: Future, Error = ()>, - G: Future, Error = ()>, -{ - type Item = tokio_minihttp::Response; - type Error = io::Error; - - fn poll(&mut self) -> futures::Poll { - use self::futures::Async::*; - - match self.future.poll() { - Err(_) => Ok(Ready(res::internal_error())), - Ok(NotReady) => Ok(NotReady), - Ok(Ready(result)) => { - let result = format!("{}\n", result.unwrap_or_default()); - Ok(Ready(res::new(&result, self.cors.take()))) - }, - } - } -} - -/// jsonrpc http server instance -pub struct Server { - address: SocketAddr, - handle: Option>, - close: Option>, -} - -impl Server { - /// Returns addresses of this server - pub fn address(&self) -> &SocketAddr { - &self.address - } - - /// Closes the server. - pub fn close(mut self) { - let _ = self.close.take().expect("Close is always set before self is consumed.").send(()); - } - - /// Will block, waiting for the server to finish. - pub fn wait(mut self) -> thread::Result<()> { - self.handle.take().expect("Handle is always set before set is consumed.").join() - } -} - -impl Drop for Server { - fn drop(&mut self) { - self.close.take().map(|close| close.send(())); - } -} diff --git a/minihttp/src/req.rs b/minihttp/src/req.rs deleted file mode 100644 index 5405226f2..000000000 --- a/minihttp/src/req.rs +++ /dev/null @@ -1,55 +0,0 @@ -//! Convenient Request wrapper used internally. - -use tokio_minihttp; -use bytes::Bytes; - -/// HTTP Method used -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Method { - /// POST - Post, - /// OPTIONS - Options, - /// Other method - Other -} - -/// Request -pub struct Req { - request: tokio_minihttp::Request, - body: Bytes, -} - -impl Req { - /// Creates new `Req` object - pub fn new(request: tokio_minihttp::Request) -> Self { - let body = request.body(); - Req { - request: request, - body: body, - } - } - - /// Returns request method - pub fn method(&self) -> Method { - // RFC 2616: The method is case-sensitive - match self.request.method() { - "OPTIONS" => Method::Options, - "POST" => Method::Post, - _ => Method::Other, - } - } - - /// Returns value of first header with given name. - /// `None` if header is not found or value is not utf-8 encoded - pub fn header(&self, name: &str) -> Option<&str> { - self.request.headers() - .find(|header| header.0.eq_ignore_ascii_case(name)) - .and_then(|header| ::std::str::from_utf8(header.1).ok()) - } - - /// Returns body of the request as a string - pub fn body(&self) -> &str { - ::std::str::from_utf8(&self.body).unwrap_or("") - } -} diff --git a/minihttp/src/res.rs b/minihttp/src/res.rs deleted file mode 100644 index 6eb61052c..000000000 --- a/minihttp/src/res.rs +++ /dev/null @@ -1,85 +0,0 @@ -//! 
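`RpcService::call` above rejects a request as early as possible: wrong HTTP method, non-whitelisted `Host`, disallowed origin, then wrong content type; only a request that clears every gate reaches the JSON-RPC handler. A compact sketch of that gatekeeping order (plain structs stand in for the real request type, and the CORS step is left out):

```rust
struct Request<'a> {
    method: &'a str,
    host: Option<&'a str>,
    content_type: Option<&'a str>,
}

fn validate(req: &Request, allowed_hosts: &[&str]) -> Result<(), &'static str> {
    // 1. Only POST (and OPTIONS, handled separately) is accepted.
    if req.method != "POST" {
        return Err("Used HTTP Method is not allowed.");
    }
    // 2. The Host header must be whitelisted.
    match req.host {
        Some(host) if allowed_hosts.iter().any(|h| h.eq_ignore_ascii_case(host)) => {}
        _ => return Err("Provided Host header is not whitelisted."),
    }
    // 3. The body must be declared as JSON.
    match req.content_type {
        Some(ct) if ct.eq_ignore_ascii_case("application/json") => Ok(()),
        _ => Err("Content-Type: application/json is required."),
    }
}

fn main() {
    let ok = Request {
        method: "POST",
        host: Some("parity.io"),
        content_type: Some("application/json"),
    };
    assert!(validate(&ok, &["parity.io"]).is_ok());

    let bad = Request { method: "GET", ..ok };
    assert_eq!(validate(&bad, &["parity.io"]), Err("Used HTTP Method is not allowed."));
}
```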
Convenient Response utils used internally - -use tokio_minihttp::Response; -use jsonrpc_server_utils::cors; - -const SERVER: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); - -pub fn options(cors: Option) -> Response { - let mut response = new("", cors); - response - .header("Allow", "OPTIONS, POST") - .header("Accept", "application/json"); - response -} - -pub fn method_not_allowed() -> Response { - let mut response = Response::new(); - response - .status_code(405, "Method Not Allowed") - .server(SERVER) - .header("Content-Type", "text/plain") - .body("Used HTTP Method is not allowed. POST or OPTIONS is required.\n"); - response -} - -pub fn invalid_host() -> Response { - let mut response = Response::new(); - response - .status_code(403, "Forbidden") - .server(SERVER) - .header("Content-Type", "text/plain") - .body("Provided Host header is not whitelisted.\n"); - response -} - -pub fn internal_error() -> Response { - let mut response = Response::new(); - response - .status_code(500, "Internal Error") - .server(SERVER) - .header("Content-Type", "text/plain") - .body("Interal Server Error has occured."); - response -} - -pub fn invalid_allow_origin() -> Response { - let mut response = Response::new(); - response - .status_code(403, "Forbidden") - .server(SERVER) - .header("Content-Type", "text/plain") - .body("Origin of the request is not whitelisted. CORS headers would not be sent and any side-effects were cancelled as well.\n"); - response -} - -pub fn invalid_content_type() -> Response { - let mut response = Response::new(); - response - .status_code(415, "Unsupported Media Type") - .server(SERVER) - .header("Content-Type", "text/plain") - .body("Supplied content type is not allowed. Content-Type: application/json is required.\n"); - response -} - -pub fn new(body: &str, cors: Option) -> Response { - let mut response = Response::new(); - response - .header("Content-Type", "application/json") - .server(SERVER) - .body(body); - - if let Some(cors) = cors { - response - .header("Access-Control-Allow-Methods", "OPTIONS, POST") - .header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept") - .header("Access-Control-Allow-Origin", match cors { - cors::AccessControlAllowOrigin::Null => "null", - cors::AccessControlAllowOrigin::Any => "*", - cors::AccessControlAllowOrigin::Value(ref val) => val, - }) - .header("Vary", "Origin"); - } - response -} diff --git a/minihttp/src/tests.rs b/minihttp/src/tests.rs deleted file mode 100644 index 31f954396..000000000 --- a/minihttp/src/tests.rs +++ /dev/null @@ -1,497 +0,0 @@ -extern crate jsonrpc_core; -extern crate env_logger; -extern crate reqwest; - -use std::io::Read; -use self::reqwest::{StatusCode, Method}; -use self::reqwest::header::{self, Headers}; -use self::jsonrpc_core::{IoHandler, Params, Value, Error}; -use self::jsonrpc_core::futures::{self, Future}; -use super::{ServerBuilder, Server, cors, hosts}; - -fn serve_hosts(hosts: Vec) -> Server { - let _ = env_logger::try_init(); - - ServerBuilder::new(IoHandler::default()) - .cors(hosts::DomainsValidation::AllowOnly(vec![cors::AccessControlAllowOrigin::Value("http://parity.io".into())])) - .allowed_hosts(hosts::DomainsValidation::AllowOnly(hosts)) - .start_http(&"127.0.0.1:0".parse().unwrap()) - .unwrap() -} - -fn serve() -> Server { - use std::thread; - - let _ = env_logger::try_init(); - let mut io = IoHandler::default(); - io.add_method("hello", |_params: Params| Ok(Value::String("world".into()))); - io.add_method("hello_async", |_params: Params| { - 
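`res::new` above attaches the `Access-Control-Allow-Origin` header only when CORS validation produced a value, and always pairs it with `Vary: Origin` so caches keep per-origin responses apart. A sketch of that mapping (the enum mirrors `AccessControlAllowOrigin` but is redefined here purely for illustration):

```rust
#[derive(Debug)]
enum AllowOrigin {
    Null,
    Any,
    Value(String),
}

/// Headers to append to a response, given the CORS decision.
fn cors_headers(cors: Option<AllowOrigin>) -> Vec<(&'static str, String)> {
    match cors {
        // No decision means the request carried no Origin header: send nothing extra.
        None => Vec::new(),
        Some(origin) => {
            let value = match origin {
                AllowOrigin::Null => "null".to_owned(),
                AllowOrigin::Any => "*".to_owned(),
                AllowOrigin::Value(v) => v,
            };
            vec![
                ("Access-Control-Allow-Methods", "OPTIONS, POST".to_owned()),
                ("Access-Control-Allow-Origin", value),
                ("Vary", "Origin".to_owned()),
            ]
        }
    }
}

fn main() {
    assert!(cors_headers(None).is_empty());
    let headers = cors_headers(Some(AllowOrigin::Value("http://parity.io".into())));
    assert_eq!(headers[1].1, "http://parity.io");
}
```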
futures::finished(Value::String("world".into())) - }); - io.add_method("hello_async2", |_params: Params| { - let (c, p) = futures::oneshot(); - thread::spawn(move || { - thread::sleep(::std::time::Duration::from_millis(10)); - c.send(Value::String("world".into())).unwrap(); - }); - p.map_err(|_| Error::invalid_request()) - }); - - ServerBuilder::new(io) - .cors(hosts::DomainsValidation::AllowOnly(vec![ - cors::AccessControlAllowOrigin::Value("http://parity.io".into()), - cors::AccessControlAllowOrigin::Null, - ])) - .start_http(&"127.0.0.1:0".parse().unwrap()) - .unwrap() -} - -struct Response { - pub status: reqwest::StatusCode, - pub body: String, - pub headers: Headers, -} - -fn request(server: Server, method: Method, headers: Headers, body: &'static str) -> Response { - let client = reqwest::Client::new().unwrap(); - let mut res = client.request(method, &format!("http://{}", server.address())) - .headers(headers) - .body(body) - .send() - .unwrap(); - - let mut body = String::new(); - res.read_to_string(&mut body).unwrap(); - - Response { - status: res.status().clone(), - body: body, - headers: res.headers().clone(), - } -} - -#[test] -fn should_return_method_not_allowed_for_get() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Get, - Headers::new(), - "I shouldn't be read.", - ); - - // then - assert_eq!(response.status, StatusCode::MethodNotAllowed); - assert_eq!(response.body, "Used HTTP Method is not allowed. POST or OPTIONS is required.\n".to_owned()); -} - -#[test] -fn should_ignore_media_type_if_options() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Options, - Headers::new(), - "", - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!( - response.headers.get::(), - Some(&reqwest::header::Allow(vec![Method::Options, Method::Post])) - ); - assert!(response.headers.get::().is_some()); - assert_eq!(response.body, ""); -} - - -#[test] -fn should_return_403_for_options_if_origin_is_invalid() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - { - let mut headers = content_type_json(); - headers.append_raw("origin", b"http://invalid.io".to_vec()); - headers - }, - "" - ); - - // then - assert_eq!(response.status, StatusCode::Forbidden); - assert_eq!(response.body, cors_invalid_allow_origin()); -} - -#[test] -fn should_return_unsupported_media_type_if_not_json() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - Headers::new(), - "{}", - ); - - // then - assert_eq!(response.status, StatusCode::UnsupportedMediaType); - assert_eq!(response.body, "Supplied content type is not allowed. 
Content-Type: application/json is required.\n".to_owned()); -} - -fn content_type_json() -> Headers { - let mut headers = Headers::new(); - headers.set_raw("content-type", vec![b"application/json".to_vec()]); - headers -} - -#[test] -fn should_return_error_for_malformed_request() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - content_type_json(), - r#"{"jsonrpc":"3.0","method":"x"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, invalid_request()); -} - -#[test] -fn should_return_error_for_malformed_request2() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - content_type_json(), - r#"{"jsonrpc":"2.0","metho1d":""}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, invalid_request()); -} - -#[test] -fn should_return_empty_response_for_notification() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - content_type_json(), - r#"{"jsonrpc":"2.0","method":"x"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, "\n".to_owned()); -} - - -#[test] -fn should_return_method_not_found() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - content_type_json(), - r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, method_not_found()); -} - -#[test] -fn should_add_cors_allow_origins() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - { - let mut headers = content_type_json(); - headers.set(header::Origin::new("http", "parity.io", None)); - headers - }, - r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, method_not_found()); - assert_eq!( - response.headers.get::(), - Some(&reqwest::header::AccessControlAllowOrigin::Value("http://parity.io".into())) - ); -} - -#[test] -fn should_add_cors_allow_origins_for_options() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Options, - { - let mut headers = content_type_json(); - headers.set(header::Origin::new("http", "parity.io", None)); - headers - }, - r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, "".to_owned()); - println!("{:?}", response.headers); - assert_eq!( - response.headers.get::(), - Some(&reqwest::header::AccessControlAllowOrigin::Value("http://parity.io".into())) - ); -} - -#[test] -fn should_not_process_request_with_invalid_allow_origin() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - { - let mut headers = content_type_json(); - headers.set(header::Origin::new("http", "fake.io", None)); - headers - }, - r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Forbidden); - assert_eq!(response.body, cors_invalid_allow_origin()); -} - -#[test] -fn should_add_cors_allow_origin_for_null_origin() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - { - let mut headers = content_type_json(); - headers.append_raw("origin", b"null".to_vec()); - headers - }, - r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#, - ); - - // then - 
assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, method_not_found()); - assert_eq!( - response.headers.get::().cloned(), - Some(reqwest::header::AccessControlAllowOrigin::Null) - ); -} - -#[test] -fn should_reject_invalid_hosts() { - // given - let server = serve_hosts(vec!["parity.io".into()]); - - // when - let response = request(server, - Method::Post, - { - let mut headers = content_type_json(); - headers.set_raw("Host", vec![b"127.0.0.1:8080".to_vec()]); - headers - }, - r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Forbidden); - assert_eq!(response.body, invalid_host()); -} - -#[test] -fn should_allow_if_host_is_valid() { - // given - let server = serve_hosts(vec!["parity.io".into()]); - - // when - let response = request(server, - Method::Post, - { - let mut headers = content_type_json(); - headers.set_raw("Host", vec![b"parity.io".to_vec()]); - headers - }, - r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, method_not_found()); -} - -#[test] -fn should_always_allow_the_bind_address() { - // given - let server = serve_hosts(vec!["parity.io".into()]); - let addr = server.address().clone(); - - // when - let response = request(server, - Method::Post, - { - let mut headers = content_type_json(); - headers.set_raw("Host", vec![format!("{}", addr).as_bytes().to_vec()]); - headers - }, - r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, method_not_found()); -} - -#[test] -fn should_always_allow_the_bind_address_as_localhost() { - // given - let server = serve_hosts(vec![]); - let addr = server.address().clone(); - - // when - let response = request(server, - Method::Post, - { - let mut headers = content_type_json(); - headers.set_raw("Host", vec![format!("localhost:{}", addr.port()).as_bytes().to_vec()]); - headers - }, - r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, method_not_found()); -} - -#[test] -fn should_handle_sync_requests_correctly() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - content_type_json(), - r#"{"jsonrpc":"2.0","id":1,"method":"hello"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, world()); -} - -#[test] -fn should_handle_async_requests_with_immediate_response_correctly() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - content_type_json(), - r#"{"jsonrpc":"2.0","id":1,"method":"hello_async"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, world()); -} - -#[test] -fn should_handle_async_requests_correctly() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - content_type_json(), - r#"{"jsonrpc":"2.0","id":1,"method":"hello_async2"}"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, world()); -} - -#[test] -fn should_handle_sync_batch_requests_correctly() { - // given - let server = serve(); - - // when - let response = request(server, - Method::Post, - content_type_json(), - r#"[{"jsonrpc":"2.0","id":1,"method":"hello"}]"#, - ); - - // then - assert_eq!(response.status, StatusCode::Ok); - assert_eq!(response.body, world_batch()); -} - -fn 
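The tests above pin down a useful rule: whatever `Host` whitelist is configured, the address the server actually bound to, and `localhost` with the bound port, are always accepted, because the server appends them to the list at startup. A standalone sketch of building and checking that effective whitelist (function names are illustrative, not the `jsonrpc-server-utils` API):

```rust
use std::net::SocketAddr;

/// Extend a user-supplied host whitelist with the address the server actually bound to.
fn effective_hosts(user_hosts: &[&str], bound: SocketAddr) -> Vec<String> {
    let mut hosts: Vec<String> = user_hosts.iter().map(|h| h.to_string()).collect();
    hosts.push(bound.to_string());
    hosts.push(format!("localhost:{}", bound.port()));
    hosts
}

fn is_host_valid(host: Option<&str>, hosts: &[String]) -> bool {
    match host {
        None => false,
        Some(host) => hosts.iter().any(|h| h.eq_ignore_ascii_case(host)),
    }
}

fn main() {
    let bound: SocketAddr = "127.0.0.1:3030".parse().unwrap();
    let hosts = effective_hosts(&["parity.io"], bound);

    assert!(is_host_valid(Some("parity.io"), &hosts));
    assert!(is_host_valid(Some("127.0.0.1:3030"), &hosts));
    assert!(is_host_valid(Some("localhost:3030"), &hosts));
    assert!(!is_host_valid(Some("127.0.0.1:8080"), &hosts));
}
```

Appending the bound address after binding, rather than the requested one, keeps this correct when binding to port 0.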
invalid_host() -> String { - "Provided Host header is not whitelisted.\n".into() -} - -fn cors_invalid_allow_origin() -> String { - "Origin of the request is not whitelisted. CORS headers would not be sent and any side-effects were cancelled as well.\n".into() -} - -fn method_not_found() -> String { - "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32601,\"message\":\"Method not found\"},\"id\":1}\n".into() -} - -fn invalid_request() -> String { - "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32600,\"message\":\"Invalid request\"},\"id\":null}\n".into() -} -fn world() -> String { - "{\"jsonrpc\":\"2.0\",\"result\":\"world\",\"id\":1}\n".into() -} -fn world_batch() -> String { - "[{\"jsonrpc\":\"2.0\",\"result\":\"world\",\"id\":1}]\n".into() -} diff --git a/pubsub/Cargo.toml b/pubsub/Cargo.toml index 52372970a..89c462655 100644 --- a/pubsub/Cargo.toml +++ b/pubsub/Cargo.toml @@ -1,21 +1,27 @@ [package] +authors = ["Parity Technologies "] description = "Publish-Subscribe extension for jsonrpc." +documentation = "https://docs.rs/jsonrpc-pubsub/" +edition = "2018" homepage = "https://github.com/paritytech/jsonrpc" -repository = "https://github.com/paritytech/jsonrpc" +keywords = ["jsonrpc", "json-rpc", "json", "rpc", "macros"] license = "MIT" name = "jsonrpc-pubsub" -version = "9.0.0" -authors = ["tomusdrw "] -keywords = ["jsonrpc", "json-rpc", "json", "rpc", "macros"] -documentation = "https://paritytech.github.io/jsonrpc/jsonrpc_pubsub/index.html" +repository = "https://github.com/paritytech/jsonrpc" +version = "18.0.0" [dependencies] +futures = { version = "0.3", features = ["thread-pool"] } +jsonrpc-core = { version = "18.0.0", path = "../core" } +lazy_static = "1.4" log = "0.4" -parking_lot = "0.6" -jsonrpc-core = { version = "9.0", path = "../core" } +parking_lot = "0.11.0" +rand = "0.7" +serde = "1.0" [dev-dependencies] -jsonrpc-tcp-server = { version = "9.0", path = "../tcp" } +jsonrpc-tcp-server = { version = "18.0.0", path = "../tcp" } +serde_json = "1.0" [badges] travis-ci = { repository = "paritytech/jsonrpc", branch = "master"} diff --git a/pubsub/examples/pubsub.rs b/pubsub/examples/pubsub.rs index 758baa2af..60d17cde6 100644 --- a/pubsub/examples/pubsub.rs +++ b/pubsub/examples/pubsub.rs @@ -1,15 +1,9 @@ -extern crate jsonrpc_core; -extern crate jsonrpc_pubsub; -extern crate jsonrpc_tcp_server; - -use std::{time, thread}; -use std::sync::Arc; +use std::sync::{atomic, Arc}; +use std::{thread, time}; use jsonrpc_core::*; use jsonrpc_pubsub::{PubSubHandler, Session, Subscriber, SubscriptionId}; -use jsonrpc_tcp_server::{ServerBuilder, RequestContext}; - -use jsonrpc_core::futures::Future; +use jsonrpc_tcp_server::{RequestContext, ServerBuilder}; /// To test the server: /// @@ -20,30 +14,38 @@ use jsonrpc_core::futures::Future; /// ``` fn main() { let mut io = PubSubHandler::new(MetaIoHandler::default()); - io.add_method("say_hello", |_params: Params| { - Ok(Value::String("hello".to_string())) - }); + io.add_sync_method("say_hello", |_params: Params| Ok(Value::String("hello".to_string()))); + let is_done = Arc::new(atomic::AtomicBool::default()); + let is_done2 = is_done.clone(); io.add_subscription( "hello", - ("subscribe_hello", |params: Params, _, subscriber: Subscriber| { + ("subscribe_hello", move |params: Params, _, subscriber: Subscriber| { if params != Params::None { - subscriber.reject(Error { - code: ErrorCode::ParseError, - message: "Invalid parameters. 
Subscription rejected.".into(), - data: None, - }).unwrap(); + subscriber + .reject(Error { + code: ErrorCode::ParseError, + message: "Invalid parameters. Subscription rejected.".into(), + data: None, + }) + .unwrap(); return; } - let sink = subscriber.assign_id(SubscriptionId::Number(5)).unwrap(); - // or subscriber.reject(Error {} ); - // or drop(subscriber) + let is_done = is_done.clone(); thread::spawn(move || { + let sink = subscriber.assign_id(SubscriptionId::Number(5)).unwrap(); + // or subscriber.reject(Error {} ); + // or drop(subscriber) + loop { + if is_done.load(atomic::Ordering::SeqCst) { + return; + } + thread::sleep(time::Duration::from_millis(100)); - match sink.notify(Params::Array(vec![Value::Number(10.into())])).wait() { - Ok(_) => {}, + match sink.notify(Params::Array(vec![Value::Number(10.into())])) { + Ok(_) => {} Err(_) => { println!("Subscription has ended, finishing."); break; @@ -52,8 +54,9 @@ fn main() { } }); }), - ("remove_hello", |_id: SubscriptionId| { + ("remove_hello", move |_id: SubscriptionId, _| { println!("Closing subscription"); + is_done2.store(true, atomic::Ordering::SeqCst); futures::future::ok(Value::Bool(true)) }), ); @@ -65,4 +68,3 @@ fn main() { server.wait(); } - diff --git a/pubsub/examples/pubsub_simple.rs b/pubsub/examples/pubsub_simple.rs index fe85f3f44..bfd5bae1d 100644 --- a/pubsub/examples/pubsub_simple.rs +++ b/pubsub/examples/pubsub_simple.rs @@ -1,49 +1,46 @@ -extern crate jsonrpc_core; -extern crate jsonrpc_pubsub; -extern crate jsonrpc_tcp_server; - -use std::{time, thread}; use std::sync::Arc; +use std::{thread, time}; use jsonrpc_core::*; use jsonrpc_pubsub::{PubSubHandler, Session, Subscriber, SubscriptionId}; -use jsonrpc_tcp_server::{ServerBuilder, RequestContext}; - -use jsonrpc_core::futures::Future; +use jsonrpc_tcp_server::{RequestContext, ServerBuilder}; /// To test the server: /// /// ```bash -/// $ netcat localhost 3030 - -/// {"id":1,"jsonrpc":"2.0","method":"hello_subscribe","params":[10]} +/// $ netcat localhost 3030 +/// > {"id":1,"jsonrpc":"2.0","method":"subscribe_hello","params":null} +/// < {"id":1,"jsonrpc":"2.0","result":5,"id":1} +/// < {"jsonrpc":"2.0","method":"hello","params":[10]} /// /// ``` fn main() { let mut io = PubSubHandler::new(MetaIoHandler::default()); - io.add_method("say_hello", |_params: Params| { - Ok(Value::String("hello".to_string())) - }); + io.add_sync_method("say_hello", |_params: Params| Ok(Value::String("hello".to_string()))); io.add_subscription( "hello", ("subscribe_hello", |params: Params, _, subscriber: Subscriber| { if params != Params::None { - subscriber.reject(Error { - code: ErrorCode::ParseError, - message: "Invalid parameters. Subscription rejected.".into(), - data: None, - }).unwrap(); + subscriber + .reject(Error { + code: ErrorCode::ParseError, + message: "Invalid parameters. 
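The updated example shares an `Arc<AtomicBool>` between the notification loop and the `remove_hello` handler, so unsubscribing actually stops the background thread instead of leaving it spinning. The same shutdown pattern in a self-contained, `std`-only sketch (an `mpsc` channel stands in for the subscription sink):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{mpsc, Arc};
use std::thread;
use std::time::Duration;

fn main() {
    let is_done = Arc::new(AtomicBool::new(false));
    let (sink, notifications) = mpsc::channel::<u64>();

    let worker = {
        let is_done = is_done.clone();
        thread::spawn(move || loop {
            // The unsubscribe side flips the flag; the loop notices and stops.
            if is_done.load(Ordering::SeqCst) {
                return;
            }
            thread::sleep(Duration::from_millis(10));
            if sink.send(10).is_err() {
                // The receiver (the "subscriber") is gone: stop as well.
                return;
            }
        })
    };

    // Pretend the client stays subscribed long enough to receive a few notifications.
    for _ in 0..3 {
        println!("notification: {}", notifications.recv().unwrap());
    }

    // "remove_hello": signal the loop to finish, then wait for it.
    is_done.store(true, Ordering::SeqCst);
    worker.join().unwrap();
}
```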
Subscription rejected.".into(), + data: None, + }) + .unwrap(); return; } - let sink = subscriber.assign_id(SubscriptionId::Number(5)).unwrap(); - // or subscriber.reject(Error {} ); - // or drop(subscriber) thread::spawn(move || { + let sink = subscriber.assign_id(SubscriptionId::Number(5)).unwrap(); + // or subscriber.reject(Error {} ); + // or drop(subscriber) + loop { thread::sleep(time::Duration::from_millis(100)); - match sink.notify(Params::Array(vec![Value::Number(10.into())])).wait() { - Ok(_) => {}, + match sink.notify(Params::Array(vec![Value::Number(10.into())])) { + Ok(_) => {} Err(_) => { println!("Subscription has ended, finishing."); break; @@ -52,16 +49,17 @@ fn main() { } }); }), - ("remove_hello", |_id: SubscriptionId| { + ("remove_hello", |_id: SubscriptionId, _| { println!("Closing subscription"); futures::future::ok(Value::Bool(true)) }), ); - let server = ServerBuilder::with_meta_extractor(io, |context: &RequestContext| Arc::new(Session::new(context.sender.clone()))) - .start(&"127.0.0.1:3030".parse().unwrap()) - .expect("Unable to start RPC server"); + let server = ServerBuilder::with_meta_extractor(io, |context: &RequestContext| { + Arc::new(Session::new(context.sender.clone())) + }) + .start(&"127.0.0.1:3030".parse().unwrap()) + .expect("Unable to start RPC server"); server.wait(); } - diff --git a/pubsub/more-examples/Cargo.toml b/pubsub/more-examples/Cargo.toml index ce9c3c0f4..f999583ae 100644 --- a/pubsub/more-examples/Cargo.toml +++ b/pubsub/more-examples/Cargo.toml @@ -3,12 +3,12 @@ name = "jsonrpc-pubsub-examples" description = "Examples of Publish-Subscribe extension for jsonrpc." homepage = "https://github.com/paritytech/jsonrpc" repository = "https://github.com/paritytech/jsonrpc" -version = "9.0.0" +version = "18.0.0" authors = ["tomusdrw "] license = "MIT" [dependencies] -jsonrpc-core = { version = "9.0", path = "../../core" } -jsonrpc-pubsub = { version = "9.0", path = "../" } -jsonrpc-ws-server = { version = "9.0", path = "../../ws" } -jsonrpc-ipc-server = { version = "9.0", path = "../../ipc" } +jsonrpc-core = { version = "18.0.0", path = "../../core" } +jsonrpc-pubsub = { version = "18.0.0", path = "../" } +jsonrpc-ws-server = { version = "18.0.0", path = "../../ws" } +jsonrpc-ipc-server = { version = "18.0.0", path = "../../ipc" } diff --git a/pubsub/more-examples/examples/pubsub_ipc.rs b/pubsub/more-examples/examples/pubsub_ipc.rs index 8ffb1e3ba..d8798c971 100644 --- a/pubsub/more-examples/examples/pubsub_ipc.rs +++ b/pubsub/more-examples/examples/pubsub_ipc.rs @@ -1,15 +1,13 @@ extern crate jsonrpc_core; -extern crate jsonrpc_pubsub; extern crate jsonrpc_ipc_server; +extern crate jsonrpc_pubsub; -use std::{time, thread}; use std::sync::Arc; +use std::{thread, time}; use jsonrpc_core::*; +use jsonrpc_ipc_server::{RequestContext, ServerBuilder, SessionId, SessionStats}; use jsonrpc_pubsub::{PubSubHandler, Session, Subscriber, SubscriptionId}; -use jsonrpc_ipc_server::{ServerBuilder, RequestContext, SessionStats, SessionId}; - -use jsonrpc_core::futures::Future; /// To test the server: /// @@ -20,30 +18,31 @@ use jsonrpc_core::futures::Future; /// ``` fn main() { let mut io = PubSubHandler::new(MetaIoHandler::default()); - io.add_method("say_hello", |_params: Params| { - Ok(Value::String("hello".to_string())) - }); + io.add_sync_method("say_hello", |_params: Params| Ok(Value::String("hello".to_string()))); io.add_subscription( "hello", ("subscribe_hello", |params: Params, _, subscriber: Subscriber| { if params != Params::None { - 
subscriber.reject(Error { - code: ErrorCode::ParseError, - message: "Invalid parameters. Subscription rejected.".into(), - data: None, - }).unwrap(); + subscriber + .reject(Error { + code: ErrorCode::ParseError, + message: "Invalid parameters. Subscription rejected.".into(), + data: None, + }) + .unwrap(); return; } - let sink = subscriber.assign_id(SubscriptionId::Number(5)).unwrap(); - // or subscriber.reject(Error {} ); - // or drop(subscriber) thread::spawn(move || { + let sink = subscriber.assign_id(SubscriptionId::Number(5)).unwrap(); + // or subscriber.reject(Error {} ); + // or drop(subscriber) + loop { thread::sleep(time::Duration::from_millis(100)); - match sink.notify(Params::Array(vec![Value::Number(10.into())])).wait() { - Ok(_) => {}, + match sink.notify(Params::Array(vec![Value::Number(10.into())])) { + Ok(_) => {} Err(_) => { println!("Subscription has ended, finishing."); break; @@ -52,16 +51,18 @@ fn main() { } }); }), - ("remove_hello", |_id: SubscriptionId| -> Result { + ("remove_hello", |_id: SubscriptionId, _meta| { println!("Closing subscription"); - Ok(Value::Bool(true)) + futures::future::ready(Ok(Value::Bool(true))) }), ); - let server = ServerBuilder::with_meta_extractor(io, |context: &RequestContext| Arc::new(Session::new(context.sender.clone()))) - .session_stats(Stats) - .start("./test.ipc") - .expect("Unable to start RPC server"); + let server = ServerBuilder::with_meta_extractor(io, |context: &RequestContext| { + Arc::new(Session::new(context.sender.clone())) + }) + .session_stats(Stats) + .start("./test.ipc") + .expect("Unable to start RPC server"); server.wait(); } @@ -76,4 +77,3 @@ impl SessionStats for Stats { println!("Closing session: {}", id); } } - diff --git a/pubsub/more-examples/examples/pubsub_ws.rs b/pubsub/more-examples/examples/pubsub_ws.rs index bbec658f5..463bb1182 100644 --- a/pubsub/more-examples/examples/pubsub_ws.rs +++ b/pubsub/more-examples/examples/pubsub_ws.rs @@ -2,14 +2,12 @@ extern crate jsonrpc_core; extern crate jsonrpc_pubsub; extern crate jsonrpc_ws_server; -use std::{time, thread}; use std::sync::Arc; +use std::{thread, time}; use jsonrpc_core::*; use jsonrpc_pubsub::{PubSubHandler, Session, Subscriber, SubscriptionId}; -use jsonrpc_ws_server::{ServerBuilder, RequestContext}; - -use jsonrpc_core::futures::Future; +use jsonrpc_ws_server::{RequestContext, ServerBuilder}; /// Use following node.js code to test: /// @@ -36,30 +34,31 @@ use jsonrpc_core::futures::Future; /// ``` fn main() { let mut io = PubSubHandler::new(MetaIoHandler::default()); - io.add_method("say_hello", |_params: Params| { - Ok(Value::String("hello".to_string())) - }); + io.add_sync_method("say_hello", |_params: Params| Ok(Value::String("hello".to_string()))); io.add_subscription( "hello", ("subscribe_hello", |params: Params, _, subscriber: Subscriber| { if params != Params::None { - subscriber.reject(Error { - code: ErrorCode::ParseError, - message: "Invalid parameters. Subscription rejected.".into(), - data: None, - }).unwrap(); + subscriber + .reject(Error { + code: ErrorCode::ParseError, + message: "Invalid parameters. 
Subscription rejected.".into(), + data: None, + }) + .unwrap(); return; } - let sink = subscriber.assign_id(SubscriptionId::Number(5)).unwrap(); - // or subscriber.reject(Error {} ); - // or drop(subscriber) thread::spawn(move || { + let sink = subscriber.assign_id(SubscriptionId::Number(5)).unwrap(); + // or subscriber.reject(Error {} ); + // or drop(subscriber) + loop { thread::sleep(time::Duration::from_millis(1000)); - match sink.notify(Params::Array(vec![Value::Number(10.into())])).wait() { - Ok(_) => {}, + match sink.notify(Params::Array(vec![Value::Number(10.into())])) { + Ok(_) => {} Err(_) => { println!("Subscription has ended, finishing."); break; @@ -68,15 +67,19 @@ fn main() { } }); }), - ("remove_hello", |_id: SubscriptionId| -> BoxFuture { - println!("Closing subscription"); - Box::new(futures::future::ok(Value::Bool(true))) - }), + ( + "remove_hello", + |_id: SubscriptionId, _meta| -> BoxFuture> { + println!("Closing subscription"); + Box::pin(futures::future::ready(Ok(Value::Bool(true)))) + }, + ), ); - let server = ServerBuilder::with_meta_extractor(io, |context: &RequestContext| Arc::new(Session::new(context.sender()))) - .start(&"127.0.0.1:3030".parse().unwrap()) - .expect("Unable to start RPC server"); + let server = + ServerBuilder::with_meta_extractor(io, |context: &RequestContext| Arc::new(Session::new(context.sender()))) + .start(&"127.0.0.1:3030".parse().unwrap()) + .expect("Unable to start RPC server"); let _ = server.wait(); } diff --git a/pubsub/src/delegates.rs b/pubsub/src/delegates.rs new file mode 100644 index 000000000..7df77e079 --- /dev/null +++ b/pubsub/src/delegates.rs @@ -0,0 +1,151 @@ +use std::marker::PhantomData; +use std::sync::Arc; + +use crate::core::futures::Future; +use crate::core::{self, Metadata, Params, RemoteProcedure, RpcMethod, Value}; +use crate::handler::{SubscribeRpcMethod, UnsubscribeRpcMethod}; +use crate::subscription::{new_subscription, Subscriber}; +use crate::types::{PubSubMetadata, SubscriptionId}; + +struct DelegateSubscription { + delegate: Arc, + closure: F, +} + +impl SubscribeRpcMethod for DelegateSubscription +where + M: PubSubMetadata, + F: Fn(&T, Params, M, Subscriber), + T: Send + Sync + 'static, + F: Send + Sync + 'static, +{ + fn call(&self, params: Params, meta: M, subscriber: Subscriber) { + let closure = &self.closure; + closure(&self.delegate, params, meta, subscriber) + } +} + +impl UnsubscribeRpcMethod for DelegateSubscription +where + M: PubSubMetadata, + F: Fn(&T, SubscriptionId, Option) -> I, + I: Future> + Send + 'static, + T: Send + Sync + 'static, + F: Send + Sync + 'static, +{ + type Out = I; + fn call(&self, id: SubscriptionId, meta: Option) -> Self::Out { + let closure = &self.closure; + closure(&self.delegate, id, meta) + } +} + +/// Wire up rpc subscriptions to `delegate` struct +pub struct IoDelegate +where + T: Send + Sync + 'static, + M: Metadata, +{ + inner: core::IoDelegate, + delegate: Arc, + _data: PhantomData, +} + +impl IoDelegate +where + T: Send + Sync + 'static, + M: PubSubMetadata, +{ + /// Creates new `PubSubIoDelegate`, wrapping the core IoDelegate + pub fn new(delegate: Arc) -> Self { + IoDelegate { + inner: core::IoDelegate::new(delegate.clone()), + delegate, + _data: PhantomData, + } + } + + /// Adds subscription to the delegate. 
+ pub fn add_subscription(&mut self, name: &str, subscribe: (&str, Sub), unsubscribe: (&str, Unsub)) + where + Sub: Fn(&T, Params, M, Subscriber), + Sub: Send + Sync + 'static, + Unsub: Fn(&T, SubscriptionId, Option) -> I, + I: Future> + Send + 'static, + Unsub: Send + Sync + 'static, + { + let (sub, unsub) = new_subscription( + name, + DelegateSubscription { + delegate: self.delegate.clone(), + closure: subscribe.1, + }, + DelegateSubscription { + delegate: self.delegate.clone(), + closure: unsubscribe.1, + }, + ); + self.inner + .add_method_with_meta(subscribe.0, move |_, params, meta| sub.call(params, meta)); + self.inner + .add_method_with_meta(unsubscribe.0, move |_, params, meta| unsub.call(params, meta)); + } + + /// Adds an alias to existing method. + pub fn add_alias(&mut self, from: &str, to: &str) { + self.inner.add_alias(from, to) + } + + // TODO [ToDr] Consider sync? + /// Adds async method to the delegate. + pub fn add_method(&mut self, name: &str, method: F) + where + F: Fn(&T, Params) -> I, + I: Future> + Send + 'static, + F: Send + Sync + 'static, + { + self.inner.add_method(name, method) + } + + /// Adds async method with metadata to the delegate. + pub fn add_method_with_meta(&mut self, name: &str, method: F) + where + F: Fn(&T, Params, M) -> I, + I: Future> + Send + 'static, + F: Send + Sync + 'static, + { + self.inner.add_method_with_meta(name, method) + } + + /// Adds notification to the delegate. + pub fn add_notification(&mut self, name: &str, notification: F) + where + F: Fn(&T, Params), + F: Send + Sync + 'static, + { + self.inner.add_notification(name, notification) + } +} + +impl core::IoHandlerExtension for IoDelegate +where + T: Send + Sync + 'static, + M: Metadata, +{ + fn augment>(self, handler: &mut core::MetaIoHandler) { + handler.extend_with(self.inner) + } +} + +impl IntoIterator for IoDelegate +where + T: Send + Sync + 'static, + M: Metadata, +{ + type Item = (String, RemoteProcedure); + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} diff --git a/pubsub/src/handler.rs b/pubsub/src/handler.rs index eb7b0b243..912f7701f 100644 --- a/pubsub/src/handler.rs +++ b/pubsub/src/handler.rs @@ -1,8 +1,8 @@ -use core; -use core::futures::{Future, IntoFuture}; +use crate::core; +use crate::core::futures::Future; -use types::{PubSubMetadata, SubscriptionId}; -use subscription::{Subscriber, new_subscription}; +use crate::subscription::{new_subscription, Subscriber}; +use crate::types::{PubSubMetadata, SubscriptionId}; /// Subscribe handler pub trait SubscribeRpcMethod: Send + Sync + 'static { @@ -10,7 +10,8 @@ pub trait SubscribeRpcMethod: Send + Sync + 'static { fn call(&self, params: core::Params, meta: M, subscriber: Subscriber); } -impl SubscribeRpcMethod for F where +impl SubscribeRpcMethod for F +where F: Fn(core::Params, M, Subscriber) + Send + Sync + 'static, M: PubSubMetadata, { @@ -20,21 +21,23 @@ impl SubscribeRpcMethod for F where } /// Unsubscribe handler -pub trait UnsubscribeRpcMethod: Send + Sync + 'static { +pub trait UnsubscribeRpcMethod: Send + Sync + 'static { /// Output type - type Out: Future + Send + 'static; + type Out: Future> + Send + 'static; /// Called when client is requesting to cancel existing subscription. - fn call(&self, id: SubscriptionId) -> Self::Out; + /// + /// Metadata is not available if the session was closed without unsubscribing. 
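// Illustrative sketch, not part of this patch: wiring a plain struct into the pubsub
// `IoDelegate`. `HelloApi`, the method names and the fixed subscription id are invented for
// the example; the closures follow the `Sub`/`Unsub` bounds of `add_subscription` above.
use std::sync::Arc;

use jsonrpc_core::{Params, Value};
use jsonrpc_pubsub::{IoDelegate, PubSubMetadata, Subscriber, SubscriptionId};

struct HelloApi;

fn hello_delegate<M: PubSubMetadata>() -> IoDelegate<HelloApi, M> {
	let mut delegate = IoDelegate::new(Arc::new(HelloApi));
	delegate.add_subscription(
		"hello",
		("subscribe_hello", |_api: &HelloApi, _params: Params, _meta: M, subscriber: Subscriber| {
			// A real API would validate `params` before accepting the subscription.
			let _sink = subscriber.assign_id(SubscriptionId::Number(1));
		}),
		("unsubscribe_hello", |_api: &HelloApi, _id: SubscriptionId, _meta| async {
			Ok(Value::Bool(true))
		}),
	);
	delegate
}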
+ fn call(&self, id: SubscriptionId, meta: Option) -> Self::Out; } -impl UnsubscribeRpcMethod for F where - F: Fn(SubscriptionId) -> I + Send + Sync + 'static, - I: IntoFuture, - I::Future: Send + 'static, +impl UnsubscribeRpcMethod for F +where + F: Fn(SubscriptionId, Option) -> I + Send + Sync + 'static, + I: Future> + Send + 'static, { - type Out = I::Future; - fn call(&self, id: SubscriptionId) -> Self::Out { - (*self)(id).into_future() + type Out = I; + fn call(&self, id: SubscriptionId, meta: Option) -> Self::Out { + (*self)(id, meta) } } @@ -54,20 +57,14 @@ impl Default for PubSubHandler { impl> PubSubHandler { /// Creates new `PubSubHandler` pub fn new(handler: core::MetaIoHandler) -> Self { - PubSubHandler { - handler: handler, - } + PubSubHandler { handler } } /// Adds new subscription. - pub fn add_subscription( - &mut self, - notification: &str, - subscribe: (&str, F), - unsubscribe: (&str, G), - ) where + pub fn add_subscription(&mut self, notification: &str, subscribe: (&str, F), unsubscribe: (&str, G)) + where F: SubscribeRpcMethod, - G: UnsubscribeRpcMethod, + G: UnsubscribeRpcMethod, { let (sub, unsub) = new_subscription(notification, subscribe.1, unsubscribe.1); self.handler.add_method_with_meta(subscribe.0, sub); @@ -97,24 +94,23 @@ impl> Into> #[cfg(test)] mod tests { - use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; - use core; - use core::futures::future; - use core::futures::sync::mpsc; - use subscription::{Session, Subscriber}; - use types::{PubSubMetadata, SubscriptionId}; + use crate::core; + use crate::core::futures::channel::mpsc; + use crate::core::futures::future; + use crate::subscription::{Session, Subscriber}; + use crate::types::{PubSubMetadata, SubscriptionId}; use super::PubSubHandler; - #[derive(Clone, Default)] - struct Metadata; + #[derive(Clone)] + struct Metadata(Arc); impl core::Metadata for Metadata {} impl PubSubMetadata for Metadata { fn session(&self) -> Option> { - let (tx, _rx) = mpsc::channel(1); - Some(Arc::new(Session::new(tx))) + Some(self.0.clone()) } } @@ -130,7 +126,7 @@ mod tests { assert_eq!(params, core::Params::None); let _sink = subscriber.assign_id(SubscriptionId::Number(5)); }), - ("unsubscribe_hello", move |id| { + ("unsubscribe_hello", move |id, _meta| { // Should be called because session is dropped. called2.store(true, Ordering::SeqCst); assert_eq!(id, SubscriptionId::Number(5)); @@ -139,7 +135,8 @@ mod tests { ); // when - let meta = Metadata; + let (tx, _rx) = mpsc::unbounded(); + let meta = Metadata(Arc::new(Session::new(tx))); let req = r#"{"jsonrpc":"2.0","id":1,"method":"subscribe_hello","params":null}"#; let res = handler.handle_request_sync(req, meta); @@ -148,5 +145,4 @@ mod tests { assert_eq!(res, Some(response.into())); assert_eq!(called.load(Ordering::SeqCst), true); } - } diff --git a/pubsub/src/lib.rs b/pubsub/src/lib.rs index 9e773df76..145304486 100644 --- a/pubsub/src/lib.rs +++ b/pubsub/src/lib.rs @@ -1,17 +1,21 @@ //! 
Publish-Subscribe extension for JSON-RPC -#![warn(missing_docs)] +#![deny(missing_docs)] -extern crate jsonrpc_core as core; -extern crate parking_lot; +use jsonrpc_core as core; #[macro_use] extern crate log; +mod delegates; mod handler; +pub mod manager; +pub mod oneshot; mod subscription; +pub mod typed; mod types; +pub use self::delegates::IoDelegate; pub use self::handler::{PubSubHandler, SubscribeRpcMethod, UnsubscribeRpcMethod}; -pub use self::subscription::{Session, Sink, Subscriber, new_subscription}; -pub use self::types::{PubSubMetadata, SubscriptionId, TransportError, SinkResult}; +pub use self::subscription::{new_subscription, Session, Sink, Subscriber}; +pub use self::types::{PubSubMetadata, SinkResult, SubscriptionId, TransportError}; diff --git a/pubsub/src/manager.rs b/pubsub/src/manager.rs new file mode 100644 index 000000000..1948dde10 --- /dev/null +++ b/pubsub/src/manager.rs @@ -0,0 +1,370 @@ +//! The SubscriptionManager used to manage subscription based RPCs. +//! +//! The manager provides four main things in terms of functionality: +//! +//! 1. The ability to create unique subscription IDs through the +//! use of the `IdProvider` trait. Two implementations are availble +//! out of the box, a `NumericIdProvider` and a `RandomStringIdProvider`. +//! +//! 2. An executor with which to drive `Future`s to completion. +//! +//! 3. A way to add new subscriptions. Subscriptions should come in the form +//! of a `Stream`. These subscriptions will be transformed into notifications +//! by the manager, which can be consumed by the client. +//! +//! 4. A way to cancel any currently active subscription. + +use std::collections::HashMap; +use std::iter; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +use crate::core::futures::channel::oneshot; +use crate::core::futures::{self, task, Future, FutureExt, TryFutureExt}; +use crate::{ + typed::{Sink, Subscriber}, + SubscriptionId, +}; + +use log::{error, warn}; +use parking_lot::Mutex; +use rand::distributions::Alphanumeric; +use rand::{thread_rng, Rng}; + +/// Cloneable `Spawn` handle. +pub type TaskExecutor = Arc; + +type ActiveSubscriptions = Arc>>>; + +/// Trait used to provide unique subscription IDs. +pub trait IdProvider { + /// A unique ID used to identify a subscription. + type Id: Default + Into; + + /// Returns the next ID for the subscription. + fn next_id(&self) -> Self::Id; +} + +/// Provides a thread-safe incrementing integer which +/// can be used as a subscription ID. +#[derive(Clone, Debug)] +pub struct NumericIdProvider { + current_id: Arc, +} + +impl NumericIdProvider { + /// Create a new NumericIdProvider. + pub fn new() -> Self { + Default::default() + } + + /// Create a new NumericIdProvider starting from + /// the given ID. + pub fn with_id(id: AtomicUsize) -> Self { + Self { + current_id: Arc::new(id), + } + } +} + +impl IdProvider for NumericIdProvider { + type Id = u64; + + fn next_id(&self) -> Self::Id { + self.current_id.fetch_add(1, Ordering::AcqRel) as u64 + } +} + +impl Default for NumericIdProvider { + fn default() -> Self { + NumericIdProvider { + current_id: Arc::new(AtomicUsize::new(1)), + } + } +} + +/// Used to generate random strings for use as +/// subscription IDs. +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] +pub struct RandomStringIdProvider { + len: usize, +} + +impl RandomStringIdProvider { + /// Create a new RandomStringIdProvider. 
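// Illustrative sketch, not part of this patch: a custom `IdProvider` producing prefixed,
// incrementing string ids. `PrefixedIdProvider` and the "sub-" prefix are invented for the
// example; such a provider can be plugged in via `SubscriptionManager::with_id_provider`.
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use jsonrpc_pubsub::manager::IdProvider;

#[derive(Clone, Default)]
struct PrefixedIdProvider {
	counter: Arc<AtomicUsize>,
}

impl IdProvider for PrefixedIdProvider {
	// `String` satisfies the trait bound because it converts into `SubscriptionId::String`.
	type Id = String;

	fn next_id(&self) -> Self::Id {
		format!("sub-{}", self.counter.fetch_add(1, Ordering::AcqRel))
	}
}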
+ pub fn new() -> Self { + Default::default() + } + + /// Create a new RandomStringIdProvider, which will generate + /// random id strings of the given length. + pub fn with_len(len: usize) -> Self { + Self { len } + } +} + +impl IdProvider for RandomStringIdProvider { + type Id = String; + + fn next_id(&self) -> Self::Id { + let mut rng = thread_rng(); + let id: String = iter::repeat(()) + .map(|()| rng.sample(Alphanumeric)) + .take(self.len) + .collect(); + id + } +} + +impl Default for RandomStringIdProvider { + fn default() -> Self { + Self { len: 16 } + } +} + +/// Subscriptions manager. +/// +/// Takes care of assigning unique subscription ids and +/// driving the sinks into completion. +#[derive(Clone)] +pub struct SubscriptionManager { + id_provider: I, + active_subscriptions: ActiveSubscriptions, + executor: TaskExecutor, +} + +impl SubscriptionManager { + /// Creates a new SubscriptionManager. + /// + /// Uses `RandomStringIdProvider` as the ID provider. + pub fn new(executor: TaskExecutor) -> Self { + Self { + id_provider: RandomStringIdProvider::default(), + active_subscriptions: Default::default(), + executor, + } + } +} + +impl SubscriptionManager { + /// Creates a new SubscriptionManager with the specified + /// ID provider. + pub fn with_id_provider(id_provider: I, executor: TaskExecutor) -> Self { + Self { + id_provider, + active_subscriptions: Default::default(), + executor, + } + } + + /// Borrows the internal task executor. + /// + /// This can be used to spawn additional tasks on the underlying event loop. + pub fn executor(&self) -> &TaskExecutor { + &self.executor + } + + /// Creates new subscription for given subscriber. + /// + /// Second parameter is a function that converts Subscriber Sink into a Future. + /// This future will be driven to completion by the underlying event loop + pub fn add(&self, subscriber: Subscriber, into_future: G) -> SubscriptionId + where + G: FnOnce(Sink) -> F, + F: Future + Send + 'static, + { + let id = self.id_provider.next_id(); + let subscription_id: SubscriptionId = id.into(); + if let Ok(sink) = subscriber.assign_id(subscription_id.clone()) { + let (tx, rx) = oneshot::channel(); + let f = into_future(sink).fuse(); + let rx = rx.map_err(|e| warn!("Error timing out: {:?}", e)).fuse(); + let future = async move { + futures::pin_mut!(f); + futures::pin_mut!(rx); + futures::select! { + a = f => a, + _ = rx => (), + } + }; + + self.active_subscriptions.lock().insert(subscription_id.clone(), tx); + if self.executor.spawn_obj(task::FutureObj::new(Box::pin(future))).is_err() { + error!("Failed to spawn RPC subscription task"); + } + } + + subscription_id + } + + /// Cancel subscription. + /// + /// Returns true if subscription existed or false otherwise. + pub fn cancel(&self, id: SubscriptionId) -> bool { + if let Some(tx) = self.active_subscriptions.lock().remove(&id) { + let _ = tx.send(()); + return true; + } + + false + } +} + +impl SubscriptionManager { + /// Creates a new SubscriptionManager. + pub fn with_executor(executor: TaskExecutor) -> Self { + Self { + id_provider: Default::default(), + active_subscriptions: Default::default(), + executor, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::typed::Subscriber; + use futures::{executor, stream}; + use futures::{FutureExt, StreamExt}; + + // Executor shared by all tests. + // + // This shared executor is used to prevent `Too many open files` errors + // on systems with a lot of cores. + lazy_static::lazy_static! 
{ + static ref EXECUTOR: executor::ThreadPool = executor::ThreadPool::new() + .expect("Failed to create thread pool executor for tests"); + } + + pub struct TestTaskExecutor; + impl task::Spawn for TestTaskExecutor { + fn spawn_obj(&self, future: task::FutureObj<'static, ()>) -> Result<(), task::SpawnError> { + EXECUTOR.spawn_obj(future) + } + + fn status(&self) -> Result<(), task::SpawnError> { + EXECUTOR.status() + } + } + + #[test] + fn making_a_numeric_id_provider_works() { + let provider = NumericIdProvider::new(); + let expected_id = 1; + let actual_id = provider.next_id(); + + assert_eq!(actual_id, expected_id); + } + + #[test] + fn default_numeric_id_provider_works() { + let provider: NumericIdProvider = Default::default(); + let expected_id = 1; + let actual_id = provider.next_id(); + + assert_eq!(actual_id, expected_id); + } + + #[test] + fn numeric_id_provider_with_id_works() { + let provider = NumericIdProvider::with_id(AtomicUsize::new(5)); + let expected_id = 5; + let actual_id = provider.next_id(); + + assert_eq!(actual_id, expected_id); + } + + #[test] + fn random_string_provider_returns_id_with_correct_default_len() { + let provider = RandomStringIdProvider::new(); + let expected_len = 16; + let actual_len = provider.next_id().len(); + + assert_eq!(actual_len, expected_len); + } + + #[test] + fn random_string_provider_returns_id_with_correct_user_given_len() { + let expected_len = 10; + let provider = RandomStringIdProvider::with_len(expected_len); + let actual_len = provider.next_id().len(); + + assert_eq!(actual_len, expected_len); + } + + #[test] + fn new_subscription_manager_defaults_to_random_string_provider() { + let manager = SubscriptionManager::new(Arc::new(TestTaskExecutor)); + let subscriber = Subscriber::::new_test("test_subTest").0; + let stream = stream::iter(vec![Ok(Ok(1))]); + + let id = manager.add(subscriber, move |sink| stream.forward(sink).map(|_| ())); + + assert!(matches!(id, SubscriptionId::String(_))) + } + + #[test] + fn new_subscription_manager_works_with_numeric_id_provider() { + let id_provider = NumericIdProvider::default(); + let manager = SubscriptionManager::with_id_provider(id_provider, Arc::new(TestTaskExecutor)); + + let subscriber = Subscriber::::new_test("test_subTest").0; + let stream = stream::iter(vec![Ok(Ok(1))]); + + let id = manager.add(subscriber, move |sink| stream.forward(sink).map(|_| ())); + + assert!(matches!(id, SubscriptionId::Number(_))) + } + + #[test] + fn new_subscription_manager_works_with_random_string_provider() { + let id_provider = RandomStringIdProvider::default(); + let manager = SubscriptionManager::with_id_provider(id_provider, Arc::new(TestTaskExecutor)); + + let subscriber = Subscriber::::new_test("test_subTest").0; + let stream = stream::iter(vec![Ok(Ok(1))]); + + let id = manager.add(subscriber, move |sink| stream.forward(sink).map(|_| ())); + + assert!(matches!(id, SubscriptionId::String(_))) + } + + #[test] + fn subscription_is_canceled_if_it_existed() { + let manager = SubscriptionManager::::with_executor(Arc::new(TestTaskExecutor)); + // Need to bind receiver here (unlike the other tests) or else the subscriber + // will think the client has disconnected and not update `active_subscriptions` + let (subscriber, _recv, _) = Subscriber::::new_test("test_subTest"); + + let (mut tx, rx) = futures::channel::mpsc::channel(8); + tx.start_send(1).unwrap(); + let id = manager.add(subscriber, move |sink| { + let rx = rx.map(|v| Ok(Ok(v))); + rx.forward(sink).map(|_| ()) + }); + + let is_cancelled = 
manager.cancel(id); + assert!(is_cancelled); + } + + #[test] + fn subscription_is_not_canceled_because_it_didnt_exist() { + let manager = SubscriptionManager::new(Arc::new(TestTaskExecutor)); + + let id: SubscriptionId = 23u32.into(); + let is_cancelled = manager.cancel(id); + let is_not_cancelled = !is_cancelled; + + assert!(is_not_cancelled); + } + + #[test] + fn is_send_sync() { + fn send_sync() {} + + send_sync::(); + } +} diff --git a/pubsub/src/oneshot.rs b/pubsub/src/oneshot.rs new file mode 100644 index 000000000..6e6f3cdf7 --- /dev/null +++ b/pubsub/src/oneshot.rs @@ -0,0 +1,100 @@ +//! A futures oneshot channel that can be used for rendezvous. + +use crate::core::futures::{self, channel::oneshot, future, Future, FutureExt, TryFutureExt}; +use std::ops::{Deref, DerefMut}; + +/// Create a new future-base rendezvous channel. +/// +/// The returned `Sender` and `Receiver` objects are wrapping +/// the regular `futures::channel::oneshot` counterparts and have the same functionality. +/// Additionaly `Sender::send_and_wait` allows you to send a message to the channel +/// and get a future that resolves when the message is consumed. +pub fn channel() -> (Sender, Receiver) { + let (sender, receiver) = oneshot::channel(); + let (receipt_tx, receipt_rx) = oneshot::channel(); + + ( + Sender { + sender, + receipt: receipt_rx, + }, + Receiver { + receiver, + receipt: Some(receipt_tx), + }, + ) +} + +/// A sender part of the channel. +#[derive(Debug)] +pub struct Sender { + sender: oneshot::Sender, + receipt: oneshot::Receiver<()>, +} + +impl Sender { + /// Consume the sender and queue up an item to send. + /// + /// This method returns right away and never blocks, + /// there is no guarantee though that the message is received + /// by the other end. + pub fn send(self, t: T) -> Result<(), T> { + self.sender.send(t) + } + + /// Consume the sender and send an item. + /// + /// The returned future will resolve when the message is received + /// on the other end. Note that polling the future is actually not required + /// to send the message as that happens synchronously. + /// The future resolves to error in case the receiving end was dropped before + /// being able to process the message. + pub fn send_and_wait(self, t: T) -> impl Future> { + let Self { sender, receipt } = self; + + if sender.send(t).is_err() { + return future::Either::Left(future::ready(Err(()))); + } + + future::Either::Right(receipt.map_err(|_| ())) + } +} + +impl Deref for Sender { + type Target = oneshot::Sender; + + fn deref(&self) -> &Self::Target { + &self.sender + } +} + +impl DerefMut for Sender { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.sender + } +} + +/// Receiving end of the channel. +/// +/// When this object is `polled` and the result is `Ready` +/// the other end (`Sender`) is also notified about the fact +/// that the item has been consumed and the future returned +/// by `send_and_wait` resolves. 
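// Illustrative sketch, not part of this patch: rendezvous via `send_and_wait`, assuming the
// `futures` 0.3 crate is available for `block_on` and `join!`. The `u32` payload and the
// value `5` are arbitrary; the receipt future resolves only once the receiver consumed it.
use jsonrpc_pubsub::oneshot;

fn rendezvous_example() {
	let (tx, rx) = oneshot::channel::<u32>();
	futures::executor::block_on(async move {
		let delivered = tx.send_and_wait(5);
		// Poll both ends together: the receiver gets the value and acknowledges it,
		// which in turn resolves the sender's receipt future.
		let (receipt, received) = futures::join!(delivered, rx);
		assert_eq!(received.unwrap(), 5);
		assert!(receipt.is_ok());
	});
}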
+#[must_use = "futures do nothing unless polled"] +#[derive(Debug)] +pub struct Receiver { + receiver: oneshot::Receiver, + receipt: Option>, +} + +impl Future for Receiver { + type Output = as Future>::Output; + + fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut futures::task::Context) -> futures::task::Poll { + let r = futures::ready!(self.receiver.poll_unpin(cx))?; + if let Some(receipt) = self.receipt.take() { + let _ = receipt.send(()); + } + Ok(r).into() + } +} diff --git a/pubsub/src/subscription.rs b/pubsub/src/subscription.rs index 7e34778a3..14e6ddcbb 100644 --- a/pubsub/src/subscription.rs +++ b/pubsub/src/subscription.rs @@ -1,23 +1,33 @@ //! Subscription primitives. -use std::fmt; +use parking_lot::Mutex; use std::collections::HashMap; +use std::fmt; +use std::pin::Pin; use std::sync::Arc; -use parking_lot::Mutex; -use core::{self, BoxFuture}; -use core::futures::{self, future, Sink as FuturesSink, Future}; -use core::futures::sync::{mpsc, oneshot}; +use crate::core::futures::channel::mpsc; +use crate::core::futures::{ + self, future, + task::{Context, Poll}, + Future, Sink as FuturesSink, TryFutureExt, +}; +use crate::core::{self, BoxFuture}; + +use crate::handler::{SubscribeRpcMethod, UnsubscribeRpcMethod}; +use crate::types::{PubSubMetadata, SinkResult, SubscriptionId, TransportError, TransportSender}; -use handler::{SubscribeRpcMethod, UnsubscribeRpcMethod}; -use types::{PubSubMetadata, SubscriptionId, TransportSender, TransportError, SinkResult}; +lazy_static::lazy_static! { + static ref UNSUBSCRIBE_POOL: futures::executor::ThreadPool = futures::executor::ThreadPool::new() + .expect("Unable to spawn background pool for unsubscribe tasks."); +} /// RPC client session /// Keeps track of active subscriptions and unsubscribes from them upon dropping. pub struct Session { - active_subscriptions: Mutex>>, + active_subscriptions: Mutex>>, transport: TransportSender, - on_drop: Mutex>>, + on_drop: Mutex>>, } impl fmt::Debug for Session { @@ -56,10 +66,14 @@ impl Session { } /// Adds new active subscription - fn add_subscription(&self, name: &str, id: &SubscriptionId, remove: F) where + fn add_subscription(&self, name: &str, id: &SubscriptionId, remove: F) + where F: Fn(SubscriptionId) + Send + 'static, { - let ret = self.active_subscriptions.lock().insert((id.clone(), name.into()), Box::new(remove)); + let ret = self + .active_subscriptions + .lock() + .insert((id.clone(), name.into()), Box::new(remove)); if let Some(remove) = ret { warn!("SubscriptionId collision. Unsubscribing previous client."); remove(id.clone()); @@ -67,8 +81,11 @@ impl Session { } /// Removes existing subscription. - fn remove_subscription(&self, name: &str, id: &SubscriptionId) { - self.active_subscriptions.lock().remove(&(id.clone(), name.into())); + fn remove_subscription(&self, name: &str, id: &SubscriptionId) -> bool { + self.active_subscriptions + .lock() + .remove(&(id.clone(), name.into())) + .is_some() } } @@ -97,40 +114,37 @@ impl Sink { /// Sends a notification to a client. 
pub fn notify(&self, val: core::Params) -> SinkResult { let val = self.params_to_string(val); - self.transport.clone().send(val.0) + self.transport.clone().unbounded_send(val) } - fn params_to_string(&self, val: core::Params) -> (String, core::Params) { + fn params_to_string(&self, val: core::Params) -> String { let notification = core::Notification { jsonrpc: Some(core::Version::V2), method: self.notification.clone(), params: val, }; - ( - core::to_string(¬ification).expect("Notification serialization never fails."), - notification.params, - ) + core::to_string(¬ification).expect("Notification serialization never fails.") } } -impl FuturesSink for Sink { - type SinkItem = core::Params; - type SinkError = TransportError; +impl FuturesSink for Sink { + type Error = TransportError; - fn start_send(&mut self, item: Self::SinkItem) -> futures::StartSend { - let (val, params) = self.params_to_string(item); - self.transport.start_send(val).map(|result| match result { - futures::AsyncSink::Ready => futures::AsyncSink::Ready, - futures::AsyncSink::NotReady(_) => futures::AsyncSink::NotReady(params), - }) + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.transport).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: core::Params) -> Result<(), Self::Error> { + let val = self.params_to_string(item); + Pin::new(&mut self.transport).start_send(val) } - fn poll_complete(&mut self) -> futures::Poll<(), Self::SinkError> { - self.transport.poll_complete() + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.transport).poll_flush(cx) } - fn close(&mut self) -> futures::Poll<(), Self::SinkError> { - self.transport.close() + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.transport).poll_close(cx) } } @@ -140,20 +154,22 @@ impl FuturesSink for Sink { pub struct Subscriber { notification: String, transport: TransportSender, - sender: oneshot::Sender>, + sender: crate::oneshot::Sender>, } impl Subscriber { /// Creates new subscriber. /// /// Should only be used for tests. - pub fn new_test>(method: T) -> ( + pub fn new_test>( + method: T, + ) -> ( Self, - oneshot::Receiver>, - mpsc::Receiver, + crate::oneshot::Receiver>, + mpsc::UnboundedReceiver, ) { - let (sender, id_receiver) = oneshot::channel(); - let (transport, transport_receiver) = mpsc::channel(1); + let (sender, id_receiver) = crate::oneshot::channel(); + let (transport, transport_receiver) = mpsc::unbounded(); let subscriber = Subscriber { notification: method.into(), @@ -165,41 +181,72 @@ impl Subscriber { } /// Consumes `Subscriber` and assigns unique id to a requestor. + /// /// Returns `Err` if request has already terminated. pub fn assign_id(self, id: SubscriptionId) -> Result { - self.sender.send(Ok(id)).map_err(|_| ())?; + let Self { + notification, + transport, + sender, + } = self; + sender + .send(Ok(id)) + .map(|_| Sink { + notification, + transport, + }) + .map_err(|_| ()) + } - Ok(Sink { - notification: self.notification, - transport: self.transport, + /// Consumes `Subscriber` and assigns unique id to a requestor. + /// + /// The returned `Future` resolves when the subscriber receives subscription id. + /// Resolves to `Err` if request has already terminated. 
+ pub fn assign_id_async(self, id: SubscriptionId) -> impl Future> { + let Self { + notification, + transport, + sender, + } = self; + sender.send_and_wait(Ok(id)).map_ok(|_| Sink { + notification, + transport, }) } /// Rejects this subscription request with given error. + /// /// Returns `Err` if request has already terminated. pub fn reject(self, error: core::Error) -> Result<(), ()> { - self.sender.send(Err(error)).map_err(|_| ())?; - Ok(()) + self.sender.send(Err(error)).map_err(|_| ()) } -} + /// Rejects this subscription request with given error. + /// + /// The returned `Future` resolves when the rejection is sent to the client. + /// Resolves to `Err` if request has already terminated. + pub fn reject_async(self, error: core::Error) -> impl Future> { + self.sender.send_and_wait(Err(error)).map_ok(|_| ()).map_err(|_| ()) + } +} /// Creates new subscribe and unsubscribe RPC methods -pub fn new_subscription(notification: &str, subscribe: F, unsubscribe: G) -> (Subscribe, Unsubscribe) where +pub fn new_subscription(notification: &str, subscribe: F, unsubscribe: G) -> (Subscribe, Unsubscribe) +where M: PubSubMetadata, F: SubscribeRpcMethod, - G: UnsubscribeRpcMethod, + G: UnsubscribeRpcMethod, { let unsubscribe = Arc::new(unsubscribe); let subscribe = Subscribe { notification: notification.to_owned(), - subscribe: subscribe, unsubscribe: unsubscribe.clone(), + subscribe, }; let unsubscribe = Unsubscribe { notification: notification.into(), - unsubscribe: unsubscribe, + unsubscribe, }; (subscribe, unsubscribe) @@ -228,15 +275,16 @@ pub struct Subscribe { unsubscribe: Arc, } -impl core::RpcMethod for Subscribe where +impl core::RpcMethod for Subscribe +where M: PubSubMetadata, F: SubscribeRpcMethod, - G: UnsubscribeRpcMethod, + G: UnsubscribeRpcMethod, { - fn call(&self, params: core::Params, meta: M) -> BoxFuture { + fn call(&self, params: core::Params, meta: M) -> BoxFuture> { match meta.session() { Some(session) => { - let (tx, rx) = oneshot::channel(); + let (tx, rx) = crate::oneshot::channel(); // Register the subscription let subscriber = Subscriber { @@ -248,22 +296,26 @@ impl core::RpcMethod for Subscribe where let unsub = self.unsubscribe.clone(); let notification = self.notification.clone(); - let subscribe_future = rx - .map_err(|_| subscription_rejected()) - .and_then(move |result| { - futures::done(match result { - Ok(id) => { - session.add_subscription(¬ification, &id, move |id| { - let _ = unsub.call(id).wait(); + let subscribe_future = rx.map_err(|_| subscription_rejected()).and_then(move |result| { + futures::future::ready(match result { + Ok(id) => { + session.add_subscription(¬ification, &id, move |id| { + // TODO [#570] [ToDr] We currently run unsubscribe tasks on a shared thread pool. + // In the future we should use some kind of `::spawn` method + // that spawns a task on an existing executor or pass the spawner handle here. 
+ let f = unsub.call(id, None); + UNSUBSCRIBE_POOL.spawn_ok(async move { + let _ = f.await; }); - Ok(id.into()) - }, - Err(e) => Err(e), - }) - }); - Box::new(subscribe_future) - }, - None => Box::new(future::err(subscriptions_unavailable())), + }); + Ok(id.into()) + } + Err(e) => Err(e), + }) + }); + Box::pin(subscribe_future) + } + None => Box::pin(future::err(subscriptions_unavailable())), } } } @@ -274,42 +326,43 @@ pub struct Unsubscribe { unsubscribe: Arc, } -impl core::RpcMethod for Unsubscribe where +impl core::RpcMethod for Unsubscribe +where M: PubSubMetadata, - G: UnsubscribeRpcMethod, + G: UnsubscribeRpcMethod, { - fn call(&self, params: core::Params, meta: M) -> BoxFuture { + fn call(&self, params: core::Params, meta: M) -> BoxFuture> { let id = match params { - core::Params::Array(ref vec) if vec.len() == 1 => { - SubscriptionId::parse_value(&vec[0]) - }, + core::Params::Array(ref vec) if vec.len() == 1 => SubscriptionId::parse_value(&vec[0]), _ => None, }; match (meta.session(), id) { (Some(session), Some(id)) => { - session.remove_subscription(&self.notification, &id); - Box::new(self.unsubscribe.call(id)) - }, - (Some(_), None) => Box::new(future::err(core::Error::invalid_params("Expected subscription id."))), - _ => Box::new(future::err(subscriptions_unavailable())), + if session.remove_subscription(&self.notification, &id) { + Box::pin(self.unsubscribe.call(id, Some(meta))) + } else { + Box::pin(future::err(core::Error::invalid_params("Invalid subscription id."))) + } + } + (Some(_), None) => Box::pin(future::err(core::Error::invalid_params("Expected subscription id."))), + _ => Box::pin(future::err(subscriptions_unavailable())), } } } #[cfg(test)] mod tests { - use std::sync::Arc; + use crate::core; + use crate::core::futures::channel::mpsc; + use crate::core::RpcMethod; + use crate::types::{PubSubMetadata, SubscriptionId}; use std::sync::atomic::{AtomicBool, Ordering}; - use core; - use core::RpcMethod; - use core::futures::{Async, Future, Stream}; - use core::futures::sync::{mpsc, oneshot}; - use types::{SubscriptionId, PubSubMetadata}; + use std::sync::Arc; - use super::{Session, Sink, Subscriber, new_subscription}; + use super::{new_subscription, Session, Sink, Subscriber}; - fn session() -> (Session, mpsc::Receiver) { - let (tx, rx) = mpsc::channel(1); + fn session() -> (Session, mpsc::UnboundedReceiver) { + let (tx, rx) = mpsc::unbounded(); (Session::new(tx), rx) } @@ -345,13 +398,36 @@ mod tests { }); // when - session.remove_subscription("test", &id); + let removed = session.remove_subscription("test", &id); drop(session); // then + assert_eq!(removed, true); assert_eq!(called.load(Ordering::SeqCst), false); } + #[test] + fn should_not_remove_subscription_if_invalid() { + // given + let id = SubscriptionId::Number(1); + let called = Arc::new(AtomicBool::new(false)); + let called2 = called.clone(); + let other_session = session().0; + let session = session().0; + session.add_subscription("test", &id, move |id| { + assert_eq!(id, SubscriptionId::Number(1)); + called2.store(true, Ordering::SeqCst); + }); + + // when + let removed = other_session.remove_subscription("test", &id); + drop(session); + + // then + assert_eq!(removed, false); + assert_eq!(called.load(Ordering::SeqCst), true); + } + #[test] fn should_unregister_in_case_of_collision() { // given @@ -374,52 +450,52 @@ mod tests { #[test] fn should_send_notification_to_the_transport() { // given - let (tx, mut rx) = mpsc::channel(1); + let (tx, mut rx) = mpsc::unbounded(); let sink = Sink { notification: 
"test".into(), transport: tx, }; // when - sink.notify(core::Params::Array(vec![core::Value::Number(10.into())])).wait().unwrap(); + sink.notify(core::Params::Array(vec![core::Value::Number(10.into())])) + .unwrap(); + let val = rx.try_next().unwrap(); // then - assert_eq!( - rx.poll().unwrap(), - Async::Ready(Some(r#"{"jsonrpc":"2.0","method":"test","params":[10]}"#.into())) - ); + assert_eq!(val, Some(r#"{"jsonrpc":"2.0","method":"test","params":[10]}"#.into())); } #[test] fn should_assign_id() { // given - let (transport, _) = mpsc::channel(1); - let (tx, mut rx) = oneshot::channel(); + let (transport, _) = mpsc::unbounded(); + let (tx, rx) = crate::oneshot::channel(); let subscriber = Subscriber { notification: "test".into(), - transport: transport, + transport, sender: tx, }; // when - let sink = subscriber.assign_id(SubscriptionId::Number(5)).unwrap(); + let sink = subscriber.assign_id_async(SubscriptionId::Number(5)); // then - assert_eq!( - rx.poll().unwrap(), - Async::Ready(Ok(SubscriptionId::Number(5))) - ); - assert_eq!(sink.notification, "test".to_owned()); + futures::executor::block_on(async move { + let id = rx.await; + assert_eq!(id, Ok(Ok(SubscriptionId::Number(5)))); + let sink = sink.await.unwrap(); + assert_eq!(sink.notification, "test".to_owned()); + }) } #[test] fn should_reject() { // given - let (transport, _) = mpsc::channel(1); - let (tx, mut rx) = oneshot::channel(); + let (transport, _) = mpsc::unbounded(); + let (tx, rx) = crate::oneshot::channel(); let subscriber = Subscriber { notification: "test".into(), - transport: transport, + transport, sender: tx, }; let error = core::Error { @@ -429,48 +505,97 @@ mod tests { }; // when - subscriber.reject(error.clone()).unwrap(); + let reject = subscriber.reject_async(error.clone()); // then - assert_eq!( - rx.poll().unwrap(), - Async::Ready(Err(error)) - ); + futures::executor::block_on(async move { + assert_eq!(rx.await.unwrap(), Err(error)); + reject.await.unwrap(); + }); } - #[derive(Clone, Default)] - struct Metadata; + #[derive(Clone)] + struct Metadata(Arc); impl core::Metadata for Metadata {} impl PubSubMetadata for Metadata { fn session(&self) -> Option> { - Some(Arc::new(session().0)) + Some(self.0.clone()) + } + } + impl Default for Metadata { + fn default() -> Self { + Self(Arc::new(session().0)) } } #[test] fn should_subscribe() { // given - let called = Arc::new(AtomicBool::new(false)); - let called2 = called.clone(); let (subscribe, _) = new_subscription( "test".into(), - move |params, _meta, _subscriber| { + move |params, _meta, subscriber: Subscriber| { assert_eq!(params, core::Params::None); - called2.store(true, Ordering::SeqCst); + let _sink = subscriber.assign_id(SubscriptionId::Number(5)).unwrap(); }, - |_id| Ok(core::Value::Bool(true)), + |_id, _meta| async { Ok(core::Value::Bool(true)) }, ); - let meta = Metadata; // when + let meta = Metadata::default(); let result = subscribe.call(core::Params::None, meta); // then - assert_eq!(called.load(Ordering::SeqCst), true); - assert_eq!(result.wait(), Err(core::Error { - code: core::ErrorCode::ServerError(-32091), - message: "Subscription rejected".into(), - data: None, - })); + assert_eq!(futures::executor::block_on(result), Ok(serde_json::json!(5))); + } + + #[test] + fn should_unsubscribe() { + // given + const SUB_ID: u64 = 5; + let (subscribe, unsubscribe) = new_subscription( + "test".into(), + move |params, _meta, subscriber: Subscriber| { + assert_eq!(params, core::Params::None); + let _sink = 
subscriber.assign_id(SubscriptionId::Number(SUB_ID)).unwrap(); + }, + |_id, _meta| async { Ok(core::Value::Bool(true)) }, + ); + + // when + let meta = Metadata::default(); + futures::executor::block_on(subscribe.call(core::Params::None, meta.clone())).unwrap(); + let result = unsubscribe.call(core::Params::Array(vec![serde_json::json!(SUB_ID)]), meta); + + // then + assert_eq!(futures::executor::block_on(result), Ok(serde_json::json!(true))); + } + + #[test] + fn should_not_unsubscribe_if_invalid() { + // given + const SUB_ID: u64 = 5; + let (subscribe, unsubscribe) = new_subscription( + "test".into(), + move |params, _meta, subscriber: Subscriber| { + assert_eq!(params, core::Params::None); + let _sink = subscriber.assign_id(SubscriptionId::Number(SUB_ID)).unwrap(); + }, + |_id, _meta| async { Ok(core::Value::Bool(true)) }, + ); + + // when + let meta = Metadata::default(); + futures::executor::block_on(subscribe.call(core::Params::None, meta.clone())).unwrap(); + let result = unsubscribe.call(core::Params::Array(vec![serde_json::json!(SUB_ID + 1)]), meta); + + // then + assert_eq!( + futures::executor::block_on(result), + Err(core::Error { + code: core::ErrorCode::InvalidParams, + message: "Invalid subscription id.".into(), + data: None, + }) + ); } } diff --git a/pubsub/src/typed.rs b/pubsub/src/typed.rs new file mode 100644 index 000000000..4830d5342 --- /dev/null +++ b/pubsub/src/typed.rs @@ -0,0 +1,136 @@ +//! PUB-SUB auto-serializing structures. + +use std::marker::PhantomData; +use std::pin::Pin; + +use crate::subscription; +use crate::types::{SinkResult, SubscriptionId, TransportError}; + +use crate::core::futures::task::{Context, Poll}; +use crate::core::futures::{self, channel}; +use crate::core::{self, Error, Params, Value}; + +/// New PUB-SUB subscriber. +#[derive(Debug)] +pub struct Subscriber { + subscriber: subscription::Subscriber, + _data: PhantomData<(T, E)>, +} + +impl Subscriber { + /// Wrap non-typed subscriber. + pub fn new(subscriber: subscription::Subscriber) -> Self { + Subscriber { + subscriber, + _data: PhantomData, + } + } + + /// Create new subscriber for tests. + pub fn new_test>( + method: M, + ) -> ( + Self, + crate::oneshot::Receiver>, + channel::mpsc::UnboundedReceiver, + ) { + let (subscriber, id, subscription) = subscription::Subscriber::new_test(method); + (Subscriber::new(subscriber), id, subscription) + } + + /// Reject subscription with given error. + pub fn reject(self, error: Error) -> Result<(), ()> { + self.subscriber.reject(error) + } + + /// Reject subscription with given error. + /// + /// The returned future will resolve when the response is sent to the client. + pub async fn reject_async(self, error: Error) -> Result<(), ()> { + self.subscriber.reject_async(error).await + } + + /// Assign id to this subscriber. + /// This method consumes `Subscriber` and returns `Sink` + /// if the connection is still open or error otherwise. + pub fn assign_id(self, id: SubscriptionId) -> Result, ()> { + let sink = self.subscriber.assign_id(id.clone())?; + Ok(Sink { + id, + sink, + _data: PhantomData, + }) + } + + /// Assign id to this subscriber. + /// This method consumes `Subscriber` and resolves to `Sink` + /// if the connection is still open and the id has been sent or to error otherwise. + pub async fn assign_id_async(self, id: SubscriptionId) -> Result, ()> { + let sink = self.subscriber.assign_id_async(id.clone()).await?; + Ok(Sink { + id, + sink, + _data: PhantomData, + }) + } +} + +/// Subscriber sink. 
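// Illustrative sketch, not part of this patch: the typed `Subscriber`/`Sink` pair wraps the
// raw subscription types and serializes notifications. The method name "hello" and the
// asserted JSON fragment are assumptions for the example; `new_test` is the test constructor.
use jsonrpc_pubsub::typed::Subscriber;
use jsonrpc_pubsub::SubscriptionId;

fn typed_notification_example() {
	let (subscriber, _id_rx, mut raw_notifications) = Subscriber::<u64>::new_test("hello");
	let sink = subscriber
		.assign_id(SubscriptionId::Number(5))
		.expect("request is still alive");
	// `Ok` values end up under "result", errors under "error", next to the subscription id.
	sink.notify(Ok(10)).expect("transport is still open");
	let json = raw_notifications.try_next().unwrap().unwrap();
	assert!(json.contains(r#""result":10"#));
}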
+#[derive(Debug, Clone)] +pub struct Sink { + sink: subscription::Sink, + id: SubscriptionId, + _data: PhantomData<(T, E)>, +} + +impl Sink { + /// Sends a notification to the subscriber. + pub fn notify(&self, val: Result) -> SinkResult { + self.sink.notify(self.val_to_params(val)) + } + + fn to_value(value: V) -> Value + where + V: serde::Serialize, + { + core::to_value(value).expect("Expected always-serializable type.") + } + + fn val_to_params(&self, val: Result) -> Params { + let id = self.id.clone().into(); + let val = val.map(Self::to_value).map_err(Self::to_value); + + Params::Map( + vec![ + ("subscription".to_owned(), id), + match val { + Ok(val) => ("result".to_owned(), val), + Err(err) => ("error".to_owned(), err), + }, + ] + .into_iter() + .collect(), + ) + } +} + +impl futures::sink::Sink> for Sink { + type Error = TransportError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.sink).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: Result) -> Result<(), Self::Error> { + let val = self.val_to_params(item); + Pin::new(&mut self.sink).start_send(val) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.sink).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.sink).poll_close(cx) + } +} diff --git a/pubsub/src/types.rs b/pubsub/src/types.rs index 4c58a771d..f6606801c 100644 --- a/pubsub/src/types.rs +++ b/pubsub/src/types.rs @@ -1,17 +1,21 @@ +use crate::core; +use crate::core::futures::channel::mpsc; use std::sync::Arc; -use core; -use core::futures::sync::mpsc; -use subscription::Session; +use crate::subscription::Session; /// Raw transport sink for specific client. -pub type TransportSender = mpsc::Sender; +pub type TransportSender = mpsc::UnboundedSender; /// Raw transport error. -pub type TransportError = mpsc::SendError; +pub type TransportError = mpsc::SendError; /// Subscription send result. -pub type SinkResult = core::futures::sink::Send; +pub type SinkResult = Result<(), mpsc::TrySendError>; /// Metadata extension for pub-sub method handling. +/// +/// NOTE storing `PubSubMetadata` (or rather storing `Arc`) in +/// any other place outside of the handler will prevent `unsubscribe` methods +/// to be called in case the `Session` is dropped (i.e. transport connection is closed). pub trait PubSubMetadata: core::Metadata { /// Returns session object associated with given request/client. /// `None` indicates that sessions are not supported on the used transport. @@ -31,12 +35,13 @@ impl PubSubMetadata for Option { } /// Unique subscription id. +/// /// NOTE Assigning same id to different requests will cause the previous request to be unsubscribed. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum SubscriptionId { - /// U64 number + /// A numerical ID, represented by a `u64`. Number(u64), - /// String + /// A non-numerical ID, for example a hash. String(String), } @@ -57,12 +62,6 @@ impl From for SubscriptionId { } } -impl From for SubscriptionId { - fn from(other: u64) -> Self { - SubscriptionId::Number(other) - } -} - impl From for core::Value { fn from(sub: SubscriptionId) -> Self { match sub { @@ -72,30 +71,83 @@ impl From for core::Value { } } +macro_rules! 
impl_from_num { + ($num:ty) => { + impl From<$num> for SubscriptionId { + fn from(other: $num) -> Self { + SubscriptionId::Number(other.into()) + } + } + }; +} + +impl_from_num!(u8); +impl_from_num!(u16); +impl_from_num!(u32); +impl_from_num!(u64); + #[cfg(test)] mod tests { - use core::Value; use super::SubscriptionId; + use crate::core::Value; + + #[test] + fn should_convert_between_number_value_and_subscription_id() { + let val = Value::Number(5.into()); + let res = SubscriptionId::parse_value(&val); + + assert_eq!(res, Some(SubscriptionId::Number(5))); + assert_eq!(Value::from(res.unwrap()), val); + } + + #[test] + fn should_convert_between_string_value_and_subscription_id() { + let val = Value::String("asdf".into()); + let res = SubscriptionId::parse_value(&val); + + assert_eq!(res, Some(SubscriptionId::String("asdf".into()))); + assert_eq!(Value::from(res.unwrap()), val); + } + + #[test] + fn should_convert_between_null_value_and_subscription_id() { + let val = Value::Null; + let res = SubscriptionId::parse_value(&val); + assert_eq!(res, None); + } + + #[test] + fn should_convert_from_u8_to_subscription_id() { + let val = 5u8; + let res: SubscriptionId = val.into(); + assert_eq!(res, SubscriptionId::Number(5)); + } + + #[test] + fn should_convert_from_u16_to_subscription_id() { + let val = 5u16; + let res: SubscriptionId = val.into(); + assert_eq!(res, SubscriptionId::Number(5)); + } + + #[test] + fn should_convert_from_u32_to_subscription_id() { + let val = 5u32; + let res: SubscriptionId = val.into(); + assert_eq!(res, SubscriptionId::Number(5)); + } + + #[test] + fn should_convert_from_u64_to_subscription_id() { + let val = 5u64; + let res: SubscriptionId = val.into(); + assert_eq!(res, SubscriptionId::Number(5)); + } #[test] - fn should_convert_between_value_and_subscription_id() { - // given - let val1 = Value::Number(5.into()); - let val2 = Value::String("asdf".into()); - let val3 = Value::Null; - - // when - let res1 = SubscriptionId::parse_value(&val1); - let res2 = SubscriptionId::parse_value(&val2); - let res3 = SubscriptionId::parse_value(&val3); - - // then - assert_eq!(res1, Some(SubscriptionId::Number(5))); - assert_eq!(res2, Some(SubscriptionId::String("asdf".into()))); - assert_eq!(res3, None); - - // and back - assert_eq!(Value::from(res1.unwrap()), val1); - assert_eq!(Value::from(res2.unwrap()), val2); + fn should_convert_from_string_to_subscription_id() { + let val = "String".to_string(); + let res: SubscriptionId = val.into(); + assert_eq!(res, SubscriptionId::String("String".to_string())); } } diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 000000000..86fe606e0 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,2 @@ +max_width = 120 +hard_tabs = true \ No newline at end of file diff --git a/server-utils/Cargo.toml b/server-utils/Cargo.toml index 194cc28a9..c03b74f06 100644 --- a/server-utils/Cargo.toml +++ b/server-utils/Cargo.toml @@ -1,23 +1,26 @@ [package] +authors = ["Parity Technologies "] description = "Server utils for jsonrpc-core crate." 
-name = "jsonrpc-server-utils" -version = "9.0.0" -authors = ["tomusdrw "] -license = "MIT" -keywords = ["jsonrpc", "json-rpc", "json", "rpc", "serde"] -documentation = "https://paritytech.github.io/jsonrpc/jsonrpc_core/index.html" +documentation = "https://docs.rs/jsonrpc-server-utils/" +edition = "2018" homepage = "https://github.com/paritytech/jsonrpc" +keywords = ["jsonrpc", "json-rpc", "json", "rpc", "serde"] +license = "MIT" +name = "jsonrpc-server-utils" repository = "https://github.com/paritytech/jsonrpc" +version = "18.0.0" [dependencies] -bytes = "0.4" +bytes = "1.0" +futures = "0.3" globset = "0.4" -jsonrpc-core = { version = "9.0", path = "../core" } +jsonrpc-core = { version = "18.0.0", path = "../core" } lazy_static = "1.1.0" log = "0.4" -num_cpus = "1.8" -tokio = { version = "0.1" } -tokio-codec = { version = "0.1" } +tokio = { version = "1", features = ["rt-multi-thread", "io-util", "time", "net"] } +tokio-util = { version = "0.6", features = ["codec"] } +tokio-stream = { version = "0.1", features = ["net"] } + unicase = "2.0" [badges] diff --git a/server-utils/src/cors.rs b/server-utils/src/cors.rs index 78b8157d6..78eab604c 100644 --- a/server-utils/src/cors.rs +++ b/server-utils/src/cors.rs @@ -1,11 +1,9 @@ //! CORS handling utility functions -extern crate unicase; - -use std::{fmt, ops}; -use hosts::{Host, Port}; -use matcher::{Matcher, Pattern}; +use crate::hosts::{Host, Port}; +use crate::matcher::{Matcher, Pattern}; use std::collections::HashSet; -pub use self::unicase::Ascii; +use std::{fmt, ops}; +pub use unicase::Ascii; /// Origin Protocol #[derive(Clone, Hash, Debug, PartialEq, Eq)] @@ -39,10 +37,10 @@ impl Origin { let matcher = Matcher::new(&string); Origin { - protocol: protocol, - host: host, + protocol, + host, as_string: string, - matcher: matcher, + matcher, } } @@ -116,11 +114,15 @@ pub enum AccessControlAllowOrigin { impl fmt::Display for AccessControlAllowOrigin { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", match *self { - AccessControlAllowOrigin::Any => "*", - AccessControlAllowOrigin::Null => "null", - AccessControlAllowOrigin::Value(ref val) => val, - }) + write!( + f, + "{}", + match *self { + AccessControlAllowOrigin::Any => "*", + AccessControlAllowOrigin::Null => "null", + AccessControlAllowOrigin::Value(ref val) => val, + } + ) } } @@ -156,7 +158,8 @@ pub enum AllowCors { impl AllowCors { /// Maps `Ok` variant of `AllowCors`. 
- pub fn map(self, f: F) -> AllowCors where + pub fn map(self, f: F) -> AllowCors + where F: FnOnce(T) -> O, { use self::AllowCors::*; @@ -184,7 +187,7 @@ impl Into> for AllowCors { pub fn get_cors_allow_origin( origin: Option<&str>, host: Option<&str>, - allowed: &Option> + allowed: &Option>, ) -> AllowCors { match origin { None => AllowCors::NotRequired, @@ -203,44 +206,40 @@ pub fn get_cors_allow_origin( match allowed.as_ref() { None if *origin == "null" => AllowCors::Ok(AccessControlAllowOrigin::Null), None => AllowCors::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))), - Some(ref allowed) if *origin == "null" => { - allowed.iter().find(|cors| **cors == AccessControlAllowOrigin::Null).cloned() - .map(AllowCors::Ok) - .unwrap_or(AllowCors::Invalid) - }, - Some(ref allowed) => { - allowed.iter().find(|cors| { - match **cors { - AccessControlAllowOrigin::Any => true, - AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => - { - true - }, - _ => false - } + Some(ref allowed) if *origin == "null" => allowed + .iter() + .find(|cors| **cors == AccessControlAllowOrigin::Null) + .cloned() + .map(AllowCors::Ok) + .unwrap_or(AllowCors::Invalid), + Some(ref allowed) => allowed + .iter() + .find(|cors| match **cors { + AccessControlAllowOrigin::Any => true, + AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => true, + _ => false, }) .map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin))) - .map(AllowCors::Ok).unwrap_or(AllowCors::Invalid) - }, + .map(AllowCors::Ok) + .unwrap_or(AllowCors::Invalid), } - }, + } } } /// Validates if the `AccessControlAllowedHeaders` in the request are allowed. pub fn get_cors_allow_headers, O, F: Fn(T) -> O>( - mut headers: impl Iterator, - requested_headers: impl Iterator, + mut headers: impl Iterator, + requested_headers: impl Iterator, cors_allow_headers: &AccessControlAllowHeaders, - to_result: F + to_result: F, ) -> AllowCors> { // Check if the header fields which were sent in the request are allowed if let AccessControlAllowHeaders::Only(only) = cors_allow_headers { - let are_all_allowed = headers - .all(|header| { - let name = &Ascii::new(header.as_ref()); - only.iter().any(|h| &Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name) - }); + let are_all_allowed = headers.all(|header| { + let name = &Ascii::new(header.as_ref()); + only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name) + }); if !are_all_allowed { return AllowCors::Invalid; @@ -252,20 +251,20 @@ pub fn get_cors_allow_headers, O, F: Fn(T) -> O>( AccessControlAllowHeaders::Any => { let headers = requested_headers.map(to_result).collect(); (false, headers) - }, + } AccessControlAllowHeaders::Only(only) => { let mut filtered = false; let headers: Vec<_> = requested_headers .filter(|header| { let name = &Ascii::new(header.as_ref()); filtered = true; - only.iter().any(|h| &Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name) + only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name) }) .map(to_result) .collect(); (filtered, headers) - }, + } }; if headers.is_empty() { @@ -279,8 +278,8 @@ pub fn get_cors_allow_headers, O, F: Fn(T) -> O>( } } -/// Returns headers which are always allowed. lazy_static! { + /// Returns headers which are always allowed. 
static ref ALWAYS_ALLOWED_HEADERS: HashSet> = { let mut hs = HashSet::new(); hs.insert(Ascii::new("Accept")); @@ -303,17 +302,29 @@ mod tests { use std::iter; use super::*; - use hosts::Host; + use crate::hosts::Host; #[test] fn should_parse_origin() { use self::OriginProtocol::*; assert_eq!(Origin::parse("http://parity.io"), Origin::new(Http, "parity.io", None)); - assert_eq!(Origin::parse("https://parity.io:8443"), Origin::new(Https, "parity.io", Some(8443))); - assert_eq!(Origin::parse("chrome-extension://124.0.0.1"), Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None)); - assert_eq!(Origin::parse("parity.io/somepath"), Origin::new(Http, "parity.io", None)); - assert_eq!(Origin::parse("127.0.0.1:8545/somepath"), Origin::new(Http, "127.0.0.1", Some(8545))); + assert_eq!( + Origin::parse("https://parity.io:8443"), + Origin::new(Https, "parity.io", Some(8443)) + ); + assert_eq!( + Origin::parse("chrome-extension://124.0.0.1"), + Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None) + ); + assert_eq!( + Origin::parse("parity.io/somepath"), + Origin::new(Http, "parity.io", None) + ); + assert_eq!( + Origin::parse("127.0.0.1:8545/somepath"), + Origin::new(Http, "127.0.0.1", Some(8545)) + ); } #[test] @@ -435,7 +446,10 @@ mod tests { let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Any])); // then - assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))); + assert_eq!( + res, + AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into())) + ); } #[test] @@ -445,11 +459,7 @@ mod tests { let host = None; // when - let res = get_cors_allow_origin( - origin, - host, - &Some(vec![AccessControlAllowOrigin::Null]), - ); + let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null])); // then assert_eq!(res, AllowCors::NotRequired); @@ -462,11 +472,7 @@ mod tests { let host = None; // when - let res = get_cors_allow_origin( - origin, - host, - &Some(vec![AccessControlAllowOrigin::Null]), - ); + let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null])); // then assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Null)); @@ -482,11 +488,17 @@ mod tests { let res = get_cors_allow_origin( origin, host, - &Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into()), AccessControlAllowOrigin::Value("http://parity.io".into())]), + &Some(vec![ + AccessControlAllowOrigin::Value("http://ethereum.org".into()), + AccessControlAllowOrigin::Value("http://parity.io".into()), + ]), ); // then - assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))); + assert_eq!( + res, + AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into())) + ); } #[test] @@ -497,8 +509,8 @@ mod tests { let origin3 = Some("chrome-extension://test".into()); let host = None; let allowed = Some(vec![ - AccessControlAllowOrigin::Value("http://*.io".into()), - AccessControlAllowOrigin::Value("chrome-extension://*".into()) + AccessControlAllowOrigin::Value("http://*.io".into()), + AccessControlAllowOrigin::Value("chrome-extension://*".into()), ]); // when @@ -507,17 +519,21 @@ mod tests { let res3 = get_cors_allow_origin(origin3, host, &allowed); // then - assert_eq!(res1, AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))); + assert_eq!( + res1, + AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into())) + ); assert_eq!(res2, AllowCors::Invalid); - assert_eq!(res3, 
AllowCors::Ok(AccessControlAllowOrigin::Value("chrome-extension://test".into()))); + assert_eq!( + res3, + AllowCors::Ok(AccessControlAllowOrigin::Value("chrome-extension://test".into())) + ); } #[test] fn should_return_invalid_if_header_not_allowed() { // given - let cors_allow_headers = AccessControlAllowHeaders::Only(vec![ - "x-allowed".to_owned(), - ]); + let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]); let headers = vec!["Access-Control-Request-Headers"]; let requested = vec!["x-not-allowed"]; @@ -531,29 +547,25 @@ mod tests { #[test] fn should_return_valid_if_header_allowed() { // given - let allowed = vec![ - "x-allowed".to_owned(), - ]; + let allowed = vec!["x-allowed".to_owned()]; let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone()); let headers = vec!["Access-Control-Request-Headers"]; let requested = vec!["x-allowed"]; // when - let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers.into(), |x| (*x).to_owned()); + let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers.into(), |x| { + (*x).to_owned() + }); // then - let allowed = vec![ - "x-allowed".to_owned(), - ]; + let allowed = vec!["x-allowed".to_owned()]; assert_eq!(res, AllowCors::Ok(allowed)); } #[test] fn should_return_no_allowed_headers_if_none_in_request() { // given - let allowed = vec![ - "x-allowed".to_owned(), - ]; + let allowed = vec!["x-allowed".to_owned()]; let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone()); let headers: Vec = vec![]; @@ -576,5 +588,4 @@ mod tests { // then assert_eq!(res, AllowCors::NotRequired); } - } diff --git a/server-utils/src/hosts.rs b/server-utils/src/hosts.rs index 50196fccc..ab2270ea6 100644 --- a/server-utils/src/hosts.rs +++ b/server-utils/src/hosts.rs @@ -1,10 +1,10 @@ //! Host header validation. +use crate::matcher::{Matcher, Pattern}; use std::collections::HashSet; use std::net::SocketAddr; -use matcher::{Matcher, Pattern}; -const SPLIT_PROOF: &'static str = "split always returns non-empty iterator."; +const SPLIT_PROOF: &str = "split always returns non-empty iterator."; /// Port pattern #[derive(Clone, Hash, PartialEq, Eq, Debug)] @@ -14,7 +14,7 @@ pub enum Port { /// Port specified as a wildcard pattern Pattern(String), /// Fixed numeric port - Fixed(u16) + Fixed(u16), } impl From> for Port { @@ -56,10 +56,10 @@ impl Host { let matcher = Matcher::new(&string); Host { - hostname: hostname, - port: port, + hostname, + port, as_string: string, - matcher: matcher, + matcher, } } @@ -71,10 +71,10 @@ impl Host { let host = hostname.next().expect(SPLIT_PROOF); let port = match hostname.next() { None => Port::None, - Some(port) => match port.clone().parse::().ok() { + Some(port) => match port.parse::().ok() { Some(num) => Port::Fixed(num), None => Port::Pattern(port.into()), - } + }, }; Host::new(host, port) @@ -153,35 +153,51 @@ pub fn is_host_valid(host: Option<&str>, allowed_hosts: &Option>) -> b None => true, Some(ref allowed_hosts) => match host { None => false, - Some(ref host) => { - allowed_hosts.iter().any(|h| h.matches(host)) - } - } + Some(ref host) => allowed_hosts.iter().any(|h| h.matches(host)), + }, } } /// Updates given list of hosts with the address. 
pub fn update(hosts: Option>, address: &SocketAddr) -> Option> { + use std::net::{IpAddr, Ipv4Addr}; + hosts.map(|current_hosts| { let mut new_hosts = current_hosts.into_iter().collect::>(); - let address = address.to_string(); - new_hosts.insert(address.clone().into()); - new_hosts.insert(address.replace("127.0.0.1", "localhost").into()); + let address_string = address.to_string(); + + if address.ip() == IpAddr::V4(Ipv4Addr::UNSPECIFIED) { + new_hosts.insert(address_string.replace("0.0.0.0", "127.0.0.1").into()); + new_hosts.insert(address_string.replace("0.0.0.0", "localhost").into()); + } else if address.ip() == IpAddr::V4(Ipv4Addr::LOCALHOST) { + new_hosts.insert(address_string.replace("127.0.0.1", "localhost").into()); + } + + new_hosts.insert(address_string.into()); new_hosts.into_iter().collect() }) } #[cfg(test)] mod tests { - use super::{Host, is_host_valid}; + use super::{is_host_valid, Host}; #[test] fn should_parse_host() { assert_eq!(Host::parse("http://parity.io"), Host::new("parity.io", None)); - assert_eq!(Host::parse("https://parity.io:8443"), Host::new("parity.io", Some(8443))); - assert_eq!(Host::parse("chrome-extension://124.0.0.1"), Host::new("124.0.0.1", None)); + assert_eq!( + Host::parse("https://parity.io:8443"), + Host::new("parity.io", Some(8443)) + ); + assert_eq!( + Host::parse("chrome-extension://124.0.0.1"), + Host::new("124.0.0.1", None) + ); assert_eq!(Host::parse("parity.io/somepath"), Host::new("parity.io", None)); - assert_eq!(Host::parse("127.0.0.1:8545/somepath"), Host::new("127.0.0.1", Some(8545))); + assert_eq!( + Host::parse("127.0.0.1:8545/somepath"), + Host::new("127.0.0.1", Some(8545)) + ); } #[test] @@ -204,28 +220,19 @@ mod tests { #[test] fn should_accept_if_on_the_list() { - let valid = is_host_valid( - Some("parity.io"), - &Some(vec!["parity.io".into()]), - ); + let valid = is_host_valid(Some("parity.io"), &Some(vec!["parity.io".into()])); assert_eq!(valid, true); } #[test] fn should_accept_if_on_the_list_with_port() { - let valid = is_host_valid( - Some("parity.io:443"), - &Some(vec!["parity.io:443".into()]), - ); + let valid = is_host_valid(Some("parity.io:443"), &Some(vec!["parity.io:443".into()])); assert_eq!(valid, true); } #[test] fn should_support_wildcards() { - let valid = is_host_valid( - Some("parity.web3.site:8180"), - &Some(vec!["*.web3.site:*".into()]), - ); + let valid = is_host_valid(Some("parity.web3.site:8180"), &Some(vec!["*.web3.site:*".into()])); assert_eq!(valid, true); } } diff --git a/server-utils/src/lib.rs b/server-utils/src/lib.rs index b4e64ce65..3d6cd5cc3 100644 --- a/server-utils/src/lib.rs +++ b/server-utils/src/lib.rs @@ -1,6 +1,6 @@ //! JSON-RPC servers utilities. 
-#![warn(missing_docs)] +#![deny(missing_docs)] #[macro_use] extern crate log; @@ -8,27 +8,22 @@ extern crate log; #[macro_use] extern crate lazy_static; -extern crate globset; -extern crate jsonrpc_core as core; -extern crate bytes; -extern crate num_cpus; - -pub extern crate tokio; -pub extern crate tokio_codec; +pub use tokio; +pub use tokio_stream; +pub use tokio_util; pub mod cors; pub mod hosts; -pub mod session; -pub mod reactor; mod matcher; +pub mod reactor; +pub mod session; mod stream_codec; mod suspendable_stream; -pub use suspendable_stream::SuspendableStream; -pub use matcher::Pattern; +pub use crate::matcher::Pattern; +pub use crate::suspendable_stream::SuspendableStream; /// Codecs utilities pub mod codecs { - pub use stream_codec::{StreamCodec, Separator}; + pub use crate::stream_codec::{Separator, StreamCodec}; } - diff --git a/server-utils/src/matcher.rs b/server-utils/src/matcher.rs index 060ef99f9..397b5eba8 100644 --- a/server-utils/src/matcher.rs +++ b/server-utils/src/matcher.rs @@ -1,4 +1,4 @@ -use globset::{GlobMatcher, GlobBuilder}; +use globset::{GlobBuilder, GlobMatcher}; use std::{fmt, hash}; /// Pattern that can be matched to string. @@ -18,7 +18,7 @@ impl Matcher { .map(|g| g.compile_matcher()) .map_err(|e| warn!("Invalid glob pattern for {}: {:?}", string, e)) .ok(), - string.into() + string.into(), ) } } @@ -40,7 +40,10 @@ impl fmt::Debug for Matcher { } impl hash::Hash for Matcher { - fn hash(&self, state: &mut H) where H: hash::Hasher { + fn hash(&self, state: &mut H) + where + H: hash::Hasher, + { self.1.hash(state) } } diff --git a/server-utils/src/reactor.rs b/server-utils/src/reactor.rs index 63083d734..041328528 100644 --- a/server-utils/src/reactor.rs +++ b/server-utils/src/reactor.rs @@ -1,18 +1,21 @@ //! Event Loop Executor +//! //! Either spawns a new event loop, or re-uses provided one. +//! Spawned event loop is always single threaded (mostly for +//! historical/backward compatibility reasons) despite the fact +//! that `tokio::runtime` can be multi-threaded. -use std::{io, thread}; -use std::sync::mpsc; -use tokio; -use num_cpus; +use std::io; -use core::futures::{self, Future}; +use tokio::runtime; +/// Task executor for Tokio 0.2 runtime. +pub type TaskExecutor = tokio::runtime::Handle; /// Possibly uninitialized event loop executor. #[derive(Debug)] pub enum UninitializedExecutor { /// Shared instance of executor. - Shared(tokio::runtime::TaskExecutor), + Shared(TaskExecutor), /// Event Loop should be spawned by the transport. Unspawned, } @@ -40,28 +43,20 @@ impl UninitializedExecutor { #[derive(Debug)] pub enum Executor { /// Shared instance - Shared(tokio::runtime::TaskExecutor), + Shared(TaskExecutor), /// Spawned Event Loop Spawned(RpcEventLoop), } impl Executor { /// Get tokio executor associated with this event loop. - pub fn executor(&self) -> tokio::runtime::TaskExecutor { - match *self { + pub fn executor(&self) -> TaskExecutor { + match self { Executor::Shared(ref executor) => executor.clone(), Executor::Spawned(ref eloop) => eloop.executor(), } } - /// Spawn a future onto the Tokio runtime. - pub fn spawn(&self, future: F) - where - F: Future + Send + 'static, - { - self.executor().spawn(future) - } - /// Closes underlying event loop (if any!). pub fn close(self) { if let Executor::Spawned(eloop) = self { @@ -80,9 +75,9 @@ impl Executor { /// A handle to running event loop. Dropping the handle will cause event loop to finish. 
#[derive(Debug)] pub struct RpcEventLoop { - executor: tokio::runtime::TaskExecutor, - close: Option>, - handle: Option>, + executor: TaskExecutor, + close: Option>, + runtime: Option, } impl Drop for RpcEventLoop { @@ -99,68 +94,68 @@ impl RpcEventLoop { /// Spawns a new named thread with the `EventLoop`. pub fn with_name(name: Option) -> io::Result { - let (stop, stopped) = futures::oneshot(); - let (tx, rx) = mpsc::channel(); - let mut tb = thread::Builder::new(); + let (stop, stopped) = futures::channel::oneshot::channel(); + + let mut tb = runtime::Builder::new_multi_thread(); + tb.worker_threads(1); + tb.enable_all(); + if let Some(name) = name { - tb = tb.name(name); + tb.thread_name(name); } - let handle = tb.spawn(move || { - let mut tp_builder = tokio::executor::thread_pool::Builder::new(); - - let pool_size = match num_cpus::get_physical() { - 1 => 1, - 2...4 => 2, - _ => 3, - }; - - tp_builder - .pool_size(pool_size) - .name_prefix("jsonrpc-eventloop-"); - - let runtime = tokio::runtime::Builder::new() - .threadpool_builder(tp_builder) - .build(); - - match runtime { - Ok(mut runtime) => { - tx.send(Ok(runtime.executor())).expect("Rx is blocking upper thread."); - let terminate = futures::empty().select(stopped) - .map(|_| ()) - .map_err(|_| ()); - runtime.spawn(terminate); - runtime.shutdown_on_idle().wait().unwrap(); - }, - Err(err) => { - tx.send(Err(err)).expect("Rx is blocking upper thread."); - } - } - }).expect("Couldn't spawn a thread."); - - let exec = rx.recv().expect("tx is transfered to a newly spawned thread."); - - exec.map(|executor| RpcEventLoop { + let runtime = tb.build()?; + let executor = runtime.handle().to_owned(); + + runtime.spawn(async { + let _ = stopped.await; + }); + + Ok(RpcEventLoop { executor, close: Some(stop), - handle: Some(handle), + runtime: Some(runtime), }) } /// Get executor for this event loop. - pub fn executor(&self) -> tokio::runtime::TaskExecutor { - self.executor.clone() + pub fn executor(&self) -> runtime::Handle { + self.runtime + .as_ref() + .expect("Runtime is only None if we're being dropped; qed") + .handle() + .clone() } /// Blocks current thread and waits until the event loop is finished. - pub fn wait(mut self) -> thread::Result<()> { - self.handle.take().expect("Handle is always set before self is consumed.").join() + pub fn wait(mut self) -> Result<(), ()> { + // Dropping Tokio 0.2 runtime waits for all spawned tasks to terminate + let runtime = self.runtime.take().ok_or(())?; + drop(runtime); + Ok(()) } /// Finishes this event loop. pub fn close(mut self) { - let _ = self.close.take().expect("Close is always set before self is consumed.").send(()).map_err(|e| { - warn!("Event Loop is already finished. {:?}", e); - }); + let _ = self + .close + .take() + .expect("Close is always set before self is consumed.") + .send(()) + .map_err(|e| { + warn!("Event Loop is already finished. 
{:?}", e); + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn make_sure_rpc_event_loop_is_send_and_sync() { + fn is_send_and_sync() {} + + is_send_and_sync::(); } } diff --git a/server-utils/src/stream_codec.rs b/server-utils/src/stream_codec.rs index 9a20a7311..92033891e 100644 --- a/server-utils/src/stream_codec.rs +++ b/server-utils/src/stream_codec.rs @@ -1,6 +1,5 @@ -use std::{io, str}; -use tokio_codec::{Decoder, Encoder}; use bytes::BytesMut; +use std::{io, str}; /// Separator for enveloping messages in streaming codecs #[derive(Debug, Clone)] @@ -35,20 +34,17 @@ impl StreamCodec { /// New custom stream codec pub fn new(incoming_separator: Separator, outgoing_separator: Separator) -> Self { StreamCodec { - incoming_separator: incoming_separator, - outgoing_separator: outgoing_separator, + incoming_separator, + outgoing_separator, } } } fn is_whitespace(byte: u8) -> bool { - match byte { - 0x0D | 0x0A | 0x20 | 0x09 => true, - _ => false, - } + matches!(byte, 0x0D | 0x0A | 0x20 | 0x09) } -impl Decoder for StreamCodec { +impl tokio_util::codec::Decoder for StreamCodec { type Item = String; type Error = io::Error; @@ -56,7 +52,7 @@ impl Decoder for StreamCodec { if let Separator::Byte(separator) = self.incoming_separator { if let Some(i) = buf.as_ref().iter().position(|&b| b == separator) { let line = buf.split_to(i); - buf.split_to(1); + let _ = buf.split_to(1); match str::from_utf8(&line.as_ref()) { Ok(s) => Ok(Some(s.to_string())), @@ -80,14 +76,11 @@ impl Decoder for StreamCodec { start_idx = idx; } depth += 1; - } - else if (byte == b'}' || byte == b']') && !in_str { + } else if (byte == b'}' || byte == b']') && !in_str { depth -= 1; - } - else if byte == b'"' && !is_escaped { + } else if byte == b'"' && !is_escaped { in_str = !in_str; - } - else if is_whitespace(byte) { + } else if is_whitespace(byte) { whitespaces += 1; } if byte == b'\\' && !is_escaped && in_str { @@ -99,8 +92,10 @@ impl Decoder for StreamCodec { if depth == 0 && idx != start_idx && idx - start_idx + 1 > whitespaces { let bts = buf.split_to(idx + 1); match String::from_utf8(bts.as_ref().to_vec()) { - Ok(val) => { return Ok(Some(val)) }, - Err(_) => { return Ok(None); } // skip non-utf requests (TODO: log error?) + Ok(val) => return Ok(Some(val)), + Err(_) => { + return Ok(None); + } // skip non-utf requests (TODO: log error?) 
}; } } @@ -109,8 +104,7 @@ impl Decoder for StreamCodec { } } -impl Encoder for StreamCodec { - type Item = String; +impl tokio_util::codec::Encoder for StreamCodec { type Error = io::Error; fn encode(&mut self, msg: String, buf: &mut BytesMut) -> io::Result<()> { @@ -127,8 +121,8 @@ impl Encoder for StreamCodec { mod tests { use super::StreamCodec; - use tokio_codec::Decoder; - use bytes::{BytesMut, BufMut}; + use bytes::{BufMut, BytesMut}; + use tokio_util::codec::Decoder; #[test] fn simple_encode() { @@ -137,7 +131,8 @@ mod tests { let mut codec = StreamCodec::stream_incoming(); - let request = codec.decode(&mut buf) + let request = codec + .decode(&mut buf) .expect("There should be no error in simple test") .expect("There should be at least one request in simple test"); @@ -151,23 +146,27 @@ mod tests { let mut codec = StreamCodec::stream_incoming(); - let request = codec.decode(&mut buf) + let request = codec + .decode(&mut buf) .expect("There should be no error in first escape test") .expect("There should be a request in first escape test"); assert_eq!(request, r#"{ test: "\"\\" }"#); - let request2 = codec.decode(&mut buf) + let request2 = codec + .decode(&mut buf) .expect("There should be no error in 2nd escape test") .expect("There should be a request in 2nd escape test"); assert_eq!(request2, r#"{ test: "\ " }"#); - let request3 = codec.decode(&mut buf) + let request3 = codec + .decode(&mut buf) .expect("There should be no error in 3rd escape test") .expect("There should be a request in 3rd escape test"); assert_eq!(request3, r#"{ test: "\}" }"#); - let request4 = codec.decode(&mut buf) + let request4 = codec + .decode(&mut buf) .expect("There should be no error in 4th escape test") .expect("There should be a request in 4th escape test"); assert_eq!(request4, r#"[ test: "\]" ]"#); @@ -180,26 +179,33 @@ mod tests { let mut codec = StreamCodec::stream_incoming(); - let request = codec.decode(&mut buf) + let request = codec + .decode(&mut buf) .expect("There should be no error in first whitespace test") .expect("There should be a request in first whitespace test"); assert_eq!(request, "{ test: 1 }"); - let request2 = codec.decode(&mut buf) + let request2 = codec + .decode(&mut buf) .expect("There should be no error in first 2nd test") .expect("There should be aa request in 2nd whitespace test"); // TODO: maybe actually trim it out assert_eq!(request2, "\n\n\n\n{ test: 2 }"); - let request3 = codec.decode(&mut buf) + let request3 = codec + .decode(&mut buf) .expect("There should be no error in first 3rd test") .expect("There should be a request in 3rd whitespace test"); assert_eq!(request3, "\n\r{\n test: 3 }"); - let request4 = codec.decode(&mut buf) + let request4 = codec + .decode(&mut buf) .expect("There should be no error in first 4th test"); - assert!(request4.is_none(), "There should be no 4th request because it contains only whitespaces"); + assert!( + request4.is_none(), + "There should be no 4th request because it contains only whitespaces" + ); } #[test] @@ -209,17 +215,20 @@ mod tests { let mut codec = StreamCodec::stream_incoming(); - let request = codec.decode(&mut buf) + let request = codec + .decode(&mut buf) .expect("There should be no error in first fragmented test") .expect("There should be at least one request in first fragmented test"); assert_eq!(request, "{ test: 1 }"); - codec.decode(&mut buf) + codec + .decode(&mut buf) .expect("There should be no error in second fragmented test") .expect("There should be at least one request in second fragmented test"); 
assert_eq!(String::from_utf8(buf.as_ref().to_vec()).unwrap(), "{ tes"); buf.put_slice(b"t: 3 }"); - let request = codec.decode(&mut buf) + let request = codec + .decode(&mut buf) .expect("There should be no error in third fragmented test") .expect("There should be at least one request in third fragmented test"); assert_eq!(request, "{ test: 3 }"); @@ -247,7 +256,8 @@ mod tests { let mut codec = StreamCodec::stream_incoming(); - let parsed_request = codec.decode(&mut buf) + let parsed_request = codec + .decode(&mut buf) .expect("There should be no error in huge test") .expect("There should be at least one request huge test"); assert_eq!(request, parsed_request); @@ -260,10 +270,12 @@ mod tests { let mut codec = StreamCodec::default(); - let request = codec.decode(&mut buf) + let request = codec + .decode(&mut buf) .expect("There should be no error in simple test") .expect("There should be at least one request in simple test"); - let request2 = codec.decode(&mut buf) + let request2 = codec + .decode(&mut buf) .expect("There should be no error in simple test") .expect("There should be at least one request in simple test"); diff --git a/server-utils/src/suspendable_stream.rs b/server-utils/src/suspendable_stream.rs index 3cdb2f27c..96af9c91a 100644 --- a/server-utils/src/suspendable_stream.rs +++ b/server-utils/src/suspendable_stream.rs @@ -1,7 +1,8 @@ -use std::time::{Duration, Instant}; -use tokio::timer::Delay; +use std::future::Future; use std::io; -use tokio::prelude::*; +use std::pin::Pin; +use std::task::Poll; +use std::time::{Duration, Instant}; /// `Incoming` is a stream of incoming sockets /// Polling the stream may return a temporary io::Error (for instance if we can't open the connection because of "too many open files" limit) @@ -16,7 +17,7 @@ pub struct SuspendableStream { next_delay: Duration, initial_delay: Duration, max_delay: Duration, - timeout: Option, + suspended_until: Option, } impl SuspendableStream { @@ -28,65 +29,70 @@ impl SuspendableStream { next_delay: Duration::from_millis(20), initial_delay: Duration::from_millis(10), max_delay: Duration::from_secs(5), - timeout: None, + suspended_until: None, } } } -impl Stream for SuspendableStream - where S: Stream +impl futures::Stream for SuspendableStream +where + S: futures::Stream> + Unpin, { type Item = I; - type Error = (); - fn poll(&mut self) -> Result>, ()> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { loop { - if let Some(mut timeout) = self.timeout.take() { - match timeout.poll() { - Ok(Async::Ready(_)) => {} - Ok(Async::NotReady) => { - self.timeout = Some(timeout); - return Ok(Async::NotReady); - } - Err(err) => { - warn!("Timeout error {:?}", err); - task::current().notify(); - return Ok(Async::NotReady); + // If we encountered a connection error before then we suspend + // polling from the underlying stream for a bit + if let Some(deadline) = &mut self.suspended_until { + let deadline = tokio::time::Instant::from_std(*deadline); + let sleep = tokio::time::sleep_until(deadline); + futures::pin_mut!(sleep); + match sleep.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(()) => { + self.suspended_until = None; } } } - match self.stream.poll() { - Ok(item) => { + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => { + if self.next_delay > self.initial_delay { + self.next_delay = self.initial_delay; + } + return Poll::Ready(None); + } + Poll::Ready(Some(Ok(item))) => { if self.next_delay > 
self.initial_delay { self.next_delay = self.initial_delay; } - return Ok(item) + + return Poll::Ready(Some(item)); } - Err(ref err) => { + Poll::Ready(Some(Err(ref err))) => { if connection_error(err) { warn!("Connection Error: {:?}", err); - continue + continue; } self.next_delay = if self.next_delay < self.max_delay { self.next_delay * 2 } else { self.next_delay }; - warn!("Error accepting connection: {}", err); - warn!("The server will stop accepting connections for {:?}", self.next_delay); - self.timeout = Some(Delay::new(Instant::now() + self.next_delay)); + debug!("Error accepting connection: {}", err); + debug!("The server will stop accepting connections for {:?}", self.next_delay); + self.suspended_until = Some(Instant::now() + self.next_delay); } } } } } - /// assert that the error was a connection error fn connection_error(e: &io::Error) -> bool { - e.kind() == io::ErrorKind::ConnectionRefused || - e.kind() == io::ErrorKind::ConnectionAborted || - e.kind() == io::ErrorKind::ConnectionReset + e.kind() == io::ErrorKind::ConnectionRefused + || e.kind() == io::ErrorKind::ConnectionAborted + || e.kind() == io::ErrorKind::ConnectionReset } - diff --git a/stdio/Cargo.toml b/stdio/Cargo.toml index 811390c35..fe39581ea 100644 --- a/stdio/Cargo.toml +++ b/stdio/Cargo.toml @@ -1,25 +1,25 @@ [package] -name = "jsonrpc-stdio-server" +authors = ["Parity Technologies "] description = "STDIN/STDOUT server for JSON-RPC" -version = "0.1.0" -authors = ["cmichi "] -license = "MIT" +documentation = "https://docs.rs/jsonrpc-stdio-server/" +edition = "2018" homepage = "https://github.com/paritytech/jsonrpc" +license = "MIT" +name = "jsonrpc-stdio-server" repository = "https://github.com/paritytech/jsonrpc" -documentation = "https://paritytech.github.io/jsonrpc/jsonrpc_stdio_server/index.html" +version = "18.0.0" [dependencies] -futures = "0.1.23" -jsonrpc-core = { version = "9.0", path = "../core" } +futures = "0.3" +jsonrpc-core = { version = "18.0.0", path = "../core" } log = "0.4" -tokio = "0.1.7" -tokio-codec = "0.1.0" -tokio-io = "0.1.7" -tokio-stdin-stdout = "0.1.4" +tokio = { version = "1", features = ["io-std", "io-util"] } +tokio-util = { version = "0.6", features = ["codec"] } [dev-dependencies] +tokio = { version = "1", features = ["rt", "macros"] } lazy_static = "1.0" -env_logger = "0.6" +env_logger = "0.7" [badges] travis-ci = { repository = "paritytech/jsonrpc", branch = "master"} diff --git a/stdio/README.md b/stdio/README.md index f4497901f..bd8152526 100644 --- a/stdio/README.md +++ b/stdio/README.md @@ -10,15 +10,13 @@ Takes one request per line and outputs each response on a new line. 
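Since `build()` now returns a future instead of blocking on the old `tokio::run` (see the rewritten `stdio/src/lib.rs` below), the server is driven by whatever Tokio 1.x runtime the caller provides. A minimal sketch of the new usage, mirroring the updated `stdio/examples/stdio.rs` in this diff; the `#[tokio::main]` attribute assumes the `rt`/`macros` features pulled in via the dev-dependencies:

```rust
use jsonrpc_stdio_server::jsonrpc_core::{IoHandler, Value};
use jsonrpc_stdio_server::ServerBuilder;

#[tokio::main]
async fn main() {
    let mut io = IoHandler::default();
    // Synchronous handlers now go through `add_sync_method`.
    io.add_sync_method("say_hello", |_params| Ok(Value::String("hello".to_owned())));

    // `build()` returns a future; it resolves once STDIN reaches EOF.
    let server = ServerBuilder::new(io).build();
    server.await;
}
```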
``` [dependencies] -jsonrpc-stdio-server = { git = "https://github.com/paritytech/jsonrpc" } +jsonrpc-stdio-server = "15.0" ``` `main.rs` ```rust -extern crate jsonrpc_stdio_server; - -use jsonrpc_stdio_server::server; +use jsonrpc_stdio_server::ServerBuilder; use jsonrpc_stdio_server::jsonrpc_core::*; fn main() { diff --git a/stdio/examples/stdio.rs b/stdio/examples/stdio.rs index 1a907c9d0..bd2bc2caa 100644 --- a/stdio/examples/stdio.rs +++ b/stdio/examples/stdio.rs @@ -1,13 +1,11 @@ -extern crate jsonrpc_stdio_server; - -use jsonrpc_stdio_server::ServerBuilder; use jsonrpc_stdio_server::jsonrpc_core::*; +use jsonrpc_stdio_server::ServerBuilder; -fn main() { +#[tokio::main] +async fn main() { let mut io = IoHandler::default(); - io.add_method("say_hello", |_params| { - Ok(Value::String("hello".to_owned())) - }); + io.add_sync_method("say_hello", |_params| Ok(Value::String("hello".to_owned()))); - ServerBuilder::new(io).build(); + let server = ServerBuilder::new(io).build(); + server.await; } diff --git a/stdio/src/lib.rs b/stdio/src/lib.rs index cfeb90f3f..6183aceb2 100644 --- a/stdio/src/lib.rs +++ b/stdio/src/lib.rs @@ -1,76 +1,98 @@ //! jsonrpc server using stdin/stdout //! //! ```no_run -//! extern crate jsonrpc_stdio_server; //! //! use jsonrpc_stdio_server::ServerBuilder; //! use jsonrpc_stdio_server::jsonrpc_core::*; //! -//! fn main() { -//! let mut io = IoHandler::default(); -//! io.add_method("say_hello", |_params| { -//! Ok(Value::String("hello".to_owned())) -//! }); +//! #[tokio::main] +//! async fn main() { +//! let mut io = IoHandler::default(); +//! io.add_sync_method("say_hello", |_params| { +//! Ok(Value::String("hello".to_owned())) +//! }); //! -//! ServerBuilder::new(io).build(); +//! let server = ServerBuilder::new(io).build(); +//! server.await; //! } //! ``` -#![warn(missing_docs)] +#![deny(missing_docs)] -extern crate futures; -extern crate tokio; -extern crate tokio_codec; -extern crate tokio_io; -extern crate tokio_stdin_stdout; -#[macro_use] extern crate log; +use std::future::Future; +use std::sync::Arc; -pub extern crate jsonrpc_core; +#[macro_use] +extern crate log; -use std::sync::Arc; -use tokio::prelude::{Future, Stream}; -use tokio_codec::{FramedRead, FramedWrite, LinesCodec}; -use jsonrpc_core::{IoHandler}; +pub use jsonrpc_core; +pub use tokio; + +use jsonrpc_core::{MetaIoHandler, Metadata, Middleware}; +use tokio_util::codec::{FramedRead, LinesCodec}; /// Stdio server builder -pub struct ServerBuilder { - handler: Arc, +pub struct ServerBuilder = jsonrpc_core::NoopMiddleware> { + handler: Arc>, } -impl ServerBuilder { +impl> ServerBuilder +where + M: Default, + T::Future: Unpin, + T::CallFuture: Unpin, +{ /// Returns a new server instance - pub fn new(handler: T) -> Self where T: Into { - ServerBuilder { handler: Arc::new(handler.into()) } + pub fn new(handler: impl Into>) -> Self { + ServerBuilder { + handler: Arc::new(handler.into()), + } } + /// Returns a server future that needs to be polled in order to make progress. + /// /// Will block until EOF is read or until an error occurs. /// The server reads from STDIN line-by-line, one request is taken /// per line and each response is written to STDOUT on a new line. 
- pub fn build(&self) { - let stdin = tokio_stdin_stdout::stdin(0); - let stdout = tokio_stdin_stdout::stdout(0).make_sendable(); + pub fn build(&self) -> impl Future + 'static { + let handler = self.handler.clone(); - let framed_stdin = FramedRead::new(stdin, LinesCodec::new()); - let framed_stdout = FramedWrite::new(stdout, LinesCodec::new()); + async move { + let stdin = tokio::io::stdin(); + let mut stdout = tokio::io::stdout(); - let handler = self.handler.clone(); - let future = framed_stdin - .and_then(move |line| process(&handler, line).map_err(|_| unreachable!())) - .forward(framed_stdout) - .map(|_| ()) - .map_err(|e| panic!("{:?}", e)); + let mut framed_stdin = FramedRead::new(stdin, LinesCodec::new()); - tokio::run(future); + use futures::StreamExt; + while let Some(request) = framed_stdin.next().await { + match request { + Ok(line) => { + let res = Self::process(&handler, line).await; + let mut sanitized = res.replace('\n', ""); + sanitized.push('\n'); + use tokio::io::AsyncWriteExt; + if let Err(e) = stdout.write_all(sanitized.as_bytes()).await { + log::warn!("Error writing response: {:?}", e); + } + } + Err(e) => { + log::warn!("Error reading line: {:?}", e); + } + } + } + } } -} -/// Process a request asynchronously -fn process(io: &Arc, input: String) -> impl Future + Send { - io.handle_request(&input).map(move |result| match result { - Some(res) => res, - None => { - info!("JSON RPC request produced no response: {:?}", input); - String::from("") - } - }) + /// Process a request asynchronously + fn process(io: &Arc>, input: String) -> impl Future + Send { + use jsonrpc_core::futures::FutureExt; + let f = io.handle_request(&input, Default::default()); + f.map(move |result| match result { + Some(res) => res, + None => { + info!("JSON RPC request produced no response: {:?}", input); + String::from("") + } + }) + } } diff --git a/tcp/Cargo.toml b/tcp/Cargo.toml index 401f09930..3b5eefe33 100644 --- a/tcp/Cargo.toml +++ b/tcp/Cargo.toml @@ -1,23 +1,24 @@ [package] -name = "jsonrpc-tcp-server" +authors = ["Parity Technologies "] description = "TCP/IP server for JSON-RPC" -version = "9.0.0" -authors = ["NikVolf "] -license = "MIT" +documentation = "https://docs.rs/jsonrpc-tcp-server/" +edition = "2018" homepage = "https://github.com/paritytech/jsonrpc" +license = "MIT" +name = "jsonrpc-tcp-server" repository = "https://github.com/paritytech/jsonrpc" -documentation = "https://paritytech.github.io/jsonrpc/jsonrpc_tcp_server/index.html" +version = "18.0.0" [dependencies] +jsonrpc-core = { version = "18.0.0", path = "../core" } +jsonrpc-server-utils = { version = "18.0.0", path = "../server-utils" } log = "0.4" -parking_lot = "0.6" -tokio-service = "0.1" -jsonrpc-core = { version = "9.0", path = "../core" } -jsonrpc-server-utils = { version = "9.0", path = "../server-utils" } +parking_lot = "0.11.0" +tower-service = "0.3" [dev-dependencies] lazy_static = "1.0" -env_logger = "0.6" +env_logger = "0.7" [badges] travis-ci = { repository = "paritytech/jsonrpc", branch = "master"} diff --git a/tcp/README.md b/tcp/README.md index e850da952..7e12cff57 100644 --- a/tcp/README.md +++ b/tcp/README.md @@ -9,14 +9,12 @@ TCP server for JSON-RPC 2.0. 
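The reworked `tcp/src/dispatch.rs` below keeps one unbounded sender per connected peer, so the server can still push messages to a peer outside the normal request/response cycle. A minimal sketch of that flow, assuming the builder API exercised by the tests further down (`dispatcher()`, `start()`, `wait()`); the socket addresses are placeholders, and a real deployment would record peer addresses via a custom `MetaExtractor`:

```rust
use jsonrpc_tcp_server::jsonrpc_core::{IoHandler, Value};
use jsonrpc_tcp_server::ServerBuilder;

fn main() {
    let mut io = IoHandler::default();
    io.add_sync_method("say_hello", |_params| Ok(Value::String("hello".to_owned())));

    let builder = ServerBuilder::new(io);
    // Grab a dispatcher handle before starting; it shares the per-peer channel map.
    let dispatcher = builder.dispatcher();

    let server = builder
        .start(&"127.0.0.1:3030".parse().unwrap())
        .expect("Server must start with no issues");

    // Pushing to an address with no live connection returns
    // `Err(PushMessageError::NoSuchPeer)`; with a peer address recorded by a
    // `MetaExtractor`, the message is delivered over that peer's socket.
    let _ = dispatcher.push_message(&"127.0.0.1:9999".parse().unwrap(), "ping".to_owned());

    server.wait();
}
```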
``` [dependencies] -jsonrpc-tcp-server = { git = "https://github.com/paritytech/jsonrpc" } +jsonrpc-tcp-server = "15.0" ``` `main.rs` ```rust -extern crate jsonrpc_tcp_server; - use jsonrpc_tcp_server::*; use jsonrpc_tcp_server::jsonrpc_core::*; diff --git a/tcp/examples/tcp.rs b/tcp/examples/tcp.rs index ce3ffb473..0c51a5c17 100644 --- a/tcp/examples/tcp.rs +++ b/tcp/examples/tcp.rs @@ -1,12 +1,11 @@ -extern crate jsonrpc_tcp_server; -extern crate env_logger; -use jsonrpc_tcp_server::ServerBuilder; +use env_logger; use jsonrpc_tcp_server::jsonrpc_core::*; +use jsonrpc_tcp_server::ServerBuilder; fn main() { env_logger::init(); let mut io = IoHandler::default(); - io.add_method("say_hello", |_params| { + io.add_sync_method("say_hello", |_params| { println!("Processing"); Ok(Value::String("hello".to_owned())) }); @@ -17,4 +16,3 @@ fn main() { server.wait() } - diff --git a/tcp/src/dispatch.rs b/tcp/src/dispatch.rs index ad0756e1d..664e97a33 100644 --- a/tcp/src/dispatch.rs +++ b/tcp/src/dispatch.rs @@ -1,30 +1,26 @@ -use std; use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::{Arc}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Poll; -use jsonrpc::futures::{Stream, Poll, Async, Sink, Future}; -use jsonrpc::futures::sync::mpsc; +use crate::futures::{channel::mpsc, Stream}; use parking_lot::Mutex; -pub type SenderChannels = Mutex>>; +pub type SenderChannels = Mutex>>; -pub struct PeerMessageQueue { +pub struct PeerMessageQueue { up: S, - receiver: mpsc::Receiver, + receiver: Option>, _addr: SocketAddr, } -impl PeerMessageQueue { - pub fn new( - response_stream: S, - receiver: mpsc::Receiver, - addr: SocketAddr, - ) -> Self { +impl PeerMessageQueue { + pub fn new(response_stream: S, receiver: mpsc::UnboundedReceiver, addr: SocketAddr) -> Self { PeerMessageQueue { up: response_stream, - receiver: receiver, + receiver: Some(receiver), _addr: addr, } } @@ -36,11 +32,11 @@ pub enum PushMessageError { /// Invalid peer NoSuchPeer, /// Send error - Send(mpsc::SendError) + Send(mpsc::TrySendError), } -impl From> for PushMessageError { - fn from(send_err: mpsc::SendError) -> Self { +impl From> for PushMessageError { + fn from(send_err: mpsc::TrySendError) -> Self { PushMessageError::Send(send_err) } } @@ -54,9 +50,7 @@ pub struct Dispatcher { impl Dispatcher { /// Creates a new dispatcher pub fn new(channels: Arc) -> Self { - Dispatcher { - channels: channels, - } + Dispatcher { channels } } /// Pushes message to given peer @@ -65,13 +59,10 @@ impl Dispatcher { match channels.get_mut(peer_addr) { Some(channel) => { - // todo: maybe async here later? - try!(channel.send(msg).wait().map_err(|e| PushMessageError::from(e))); + channel.unbounded_send(msg).map_err(PushMessageError::from)?; Ok(()) - }, - None => { - return Err(PushMessageError::NoSuchPeer); } + None => Err(PushMessageError::NoSuchPeer), } } @@ -86,31 +77,43 @@ impl Dispatcher { } } -impl> Stream for PeerMessageQueue { - - type Item = String; - type Error = std::io::Error; - - fn poll(&mut self) -> Poll, std::io::Error> { +impl> + Unpin> Stream for PeerMessageQueue { + type Item = std::io::Result; + + // The receiver will never return `Ok(Async::Ready(None))` + // Because the sender is kept in `SenderChannels` and it will never be dropped until `the stream` is resolved. + // + // Thus, that is the reason we terminate if `up_closed && receiver == Async::NotReady`. + // + // However, it is possible to have a race between `poll` and `push_work` if the connection is dropped. 
+ // Therefore, the receiver is then dropped when the connection is dropped and an error is propagated when + // a `send` attempt is made on that channel. + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { // check if we have response pending - match self.up.poll() { - Ok(Async::Ready(Some(val))) => { - return Ok(Async::Ready(Some(val))); - }, - Ok(Async::Ready(None)) => { - // this will ensure that this polling will end when incoming i/o stream ends - return Ok(Async::Ready(None)); - }, - _ => {} - } + let this = Pin::into_inner(self); + + let up_closed = match Pin::new(&mut this.up).poll_next(cx) { + Poll::Ready(Some(Ok(item))) => return Poll::Ready(Some(Ok(item))), + Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))), + Poll::Ready(None) => true, + Poll::Pending => false, + }; - match self.receiver.poll() { - Ok(result) => Ok(result), - Err(send_err) => { - // not sure if it can ever happen - warn!("MPSC send error: {:?}", send_err); - Err(std::io::Error::from(std::io::ErrorKind::Other)) + let mut rx = match &mut this.receiver { + None => { + debug_assert!(up_closed); + return Poll::Ready(None); + } + Some(rx) => rx, + }; + + match Pin::new(&mut rx).poll_next(cx) { + Poll::Ready(Some(item)) => Poll::Ready(Some(Ok(item))), + Poll::Ready(None) | Poll::Pending if up_closed => { + this.receiver = None; + Poll::Ready(None) } + Poll::Ready(None) | Poll::Pending => Poll::Pending, } } } diff --git a/tcp/src/lib.rs b/tcp/src/lib.rs index ae089c680..69c7c4f6d 100644 --- a/tcp/src/lib.rs +++ b/tcp/src/lib.rs @@ -1,49 +1,50 @@ //! jsonrpc server over tcp/ip //! //! ```no_run -//! extern crate jsonrpc_core; -//! extern crate jsonrpc_tcp_server; -//! //! use jsonrpc_core::*; //! use jsonrpc_tcp_server::ServerBuilder; //! //! fn main() { -//! let mut io = IoHandler::default(); -//! io.add_method("say_hello", |_params| { -//! Ok(Value::String("hello".to_string())) -//! }); -//! let server = ServerBuilder::new(io) -//! .start(&"0.0.0.0:0".parse().unwrap()) -//! .expect("Server must start with no issues."); +//! let mut io = IoHandler::default(); +//! io.add_sync_method("say_hello", |_params| { +//! Ok(Value::String("hello".to_string())) +//! }); +//! let server = ServerBuilder::new(io) +//! .start(&"0.0.0.0:0".parse().unwrap()) +//! .expect("Server must start with no issues."); //! -//! server.wait(); +//! server.wait(); //! } //! 
``` -#![warn(missing_docs)] +#![deny(missing_docs)] -extern crate jsonrpc_server_utils as server_utils; -extern crate parking_lot; -extern crate tokio_service; +use jsonrpc_server_utils as server_utils; -pub extern crate jsonrpc_core; +pub use jsonrpc_core; -#[macro_use] extern crate log; +#[macro_use] +extern crate log; -#[cfg(test)] #[macro_use] extern crate lazy_static; -#[cfg(test)] extern crate env_logger; +#[cfg(test)] +#[macro_use] +extern crate lazy_static; mod dispatch; mod meta; mod server; mod service; -#[cfg(test)] mod logger; -#[cfg(test)] mod tests; +#[cfg(test)] +mod logger; +#[cfg(test)] +mod tests; use jsonrpc_core as jsonrpc; -pub use dispatch::{Dispatcher, PushMessageError}; -pub use meta::{MetaExtractor, RequestContext}; -pub use server::{ServerBuilder, Server}; -pub use self::server_utils::{tokio, codecs::Separator}; +pub(crate) use crate::jsonrpc::futures; + +pub use self::server_utils::{codecs::Separator, tokio}; +pub use crate::dispatch::{Dispatcher, PushMessageError}; +pub use crate::meta::{MetaExtractor, RequestContext}; +pub use crate::server::{Server, ServerBuilder}; diff --git a/tcp/src/logger.rs b/tcp/src/logger.rs index 74f748364..6edd87759 100644 --- a/tcp/src/logger.rs +++ b/tcp/src/logger.rs @@ -1,6 +1,6 @@ -use std::env; -use log::LevelFilter; use env_logger::Builder; +use log::LevelFilter; +use std::env; lazy_static! { static ref LOG_DUMMY: bool = { @@ -8,7 +8,7 @@ lazy_static! { builder.filter(None, LevelFilter::Info); if let Ok(log) = env::var("RUST_LOG") { - builder.parse(&log); + builder.parse_filters(&log); } if let Ok(_) = builder.try_init() { diff --git a/tcp/src/meta.rs b/tcp/src/meta.rs index 795fe9ead..7cf2d4586 100644 --- a/tcp/src/meta.rs +++ b/tcp/src/meta.rs @@ -1,23 +1,23 @@ use std::net::SocketAddr; -use jsonrpc::futures::sync::mpsc; -use jsonrpc::Metadata; +use crate::jsonrpc::{futures::channel::mpsc, Metadata}; /// Request context pub struct RequestContext { /// Peer Address pub peer_addr: SocketAddr, /// Peer Sender channel - pub sender: mpsc::Sender, + pub sender: mpsc::UnboundedSender, } /// Metadata extractor (per session) -pub trait MetaExtractor : Send + Sync { +pub trait MetaExtractor: Send + Sync { /// Extracts metadata from request context fn extract(&self, context: &RequestContext) -> M; } -impl MetaExtractor for F where +impl MetaExtractor for F +where M: Metadata, F: Fn(&RequestContext) -> M + Send + Sync, { @@ -29,5 +29,7 @@ impl MetaExtractor for F where /// Noop-extractor pub struct NoopExtractor; impl MetaExtractor for NoopExtractor { - fn extract(&self, _context: &RequestContext) -> M { M::default() } + fn extract(&self, _context: &RequestContext) -> M { + M::default() + } } diff --git a/tcp/src/server.rs b/tcp/src/server.rs index b5e6e7fbd..40df9ec5f 100644 --- a/tcp/src/server.rs +++ b/tcp/src/server.rs @@ -1,44 +1,51 @@ -use std; +use std::io; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::Arc; -use tokio_service::Service as TokioService; +use tower_service::Service as _; -use jsonrpc::{middleware, MetaIoHandler, Metadata, Middleware}; -use jsonrpc::futures::{future, Future, Stream, Sink}; -use jsonrpc::futures::sync::{mpsc, oneshot}; -use server_utils::{ - tokio_codec::Framed, - tokio, reactor, codecs, - SuspendableStream -}; +use crate::futures::{self, future}; +use crate::jsonrpc::{middleware, MetaIoHandler, Metadata, Middleware}; +use crate::server_utils::tokio_stream::wrappers::TcpListenerStream; +use crate::server_utils::{codecs, reactor, tokio, tokio_util::codec::Framed, SuspendableStream}; -use 
dispatch::{Dispatcher, SenderChannels, PeerMessageQueue}; -use meta::{MetaExtractor, RequestContext, NoopExtractor}; -use service::Service; +use crate::dispatch::{Dispatcher, PeerMessageQueue, SenderChannels}; +use crate::meta::{MetaExtractor, NoopExtractor, RequestContext}; +use crate::service::Service; /// TCP server builder pub struct ServerBuilder = middleware::Noop> { executor: reactor::UninitializedExecutor, handler: Arc>, - meta_extractor: Arc>, + meta_extractor: Arc>, channels: Arc, incoming_separator: codecs::Separator, outgoing_separator: codecs::Separator, } -impl + 'static> ServerBuilder { +impl + 'static> ServerBuilder +where + S::Future: Unpin, + S::CallFuture: Unpin, +{ /// Creates new `ServerBuilder` wih given `IoHandler` - pub fn new(handler: T) -> Self where + pub fn new(handler: T) -> Self + where T: Into>, { Self::with_meta_extractor(handler, NoopExtractor) } } -impl + 'static> ServerBuilder { +impl + 'static> ServerBuilder +where + S::Future: Unpin, + S::CallFuture: Unpin, +{ /// Creates new `ServerBuilder` wih given `IoHandler` - pub fn with_meta_extractor(handler: T, extractor: E) -> Self where + pub fn with_meta_extractor(handler: T, extractor: E) -> Self + where T: Into>, E: MetaExtractor + 'static, { @@ -53,7 +60,7 @@ impl + 'static> ServerBuilder { } /// Utilize existing event loop executor. - pub fn event_loop_executor(mut self, handle: tokio::runtime::TaskExecutor) -> Self { + pub fn event_loop_executor(mut self, handle: reactor::TaskExecutor) -> Self { self.executor = reactor::UninitializedExecutor::Shared(handle); self } @@ -72,7 +79,7 @@ impl + 'static> ServerBuilder { } /// Starts a new server - pub fn start(self, addr: &SocketAddr) -> std::io::Result { + pub fn start(self, addr: &SocketAddr) -> io::Result { let meta_extractor = self.meta_extractor.clone(); let rpc_handler = self.handler.clone(); let channels = self.channels.clone(); @@ -80,98 +87,95 @@ impl + 'static> ServerBuilder { let outgoing_separator = self.outgoing_separator; let address = addr.to_owned(); let (tx, rx) = std::sync::mpsc::channel(); - let (stop_tx, stop_rx) = oneshot::channel(); + let (stop_tx, stop_rx) = futures::channel::oneshot::channel(); let executor = self.executor.initialize()?; - executor.spawn(future::lazy(move || { - let start = move || { - let listener = tokio::net::TcpListener::bind(&address)?; - let connections = SuspendableStream::new(listener.incoming()); - - let server = connections.for_each(move |socket| { - let peer_addr = socket.peer_addr().expect("Unable to determine socket peer address"); + use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt}; + executor.executor().spawn(async move { + let start = async { + let listener = tokio::net::TcpListener::bind(&address).await?; + let listener = TcpListenerStream::new(listener); + let connections = SuspendableStream::new(listener); + + let server = connections.map(|socket| { + let peer_addr = match socket.peer_addr() { + Ok(addr) => addr, + Err(e) => { + warn!(target: "tcp", "Unable to determine socket peer address, ignoring connection {}", e); + return future::Either::Left(async { io::Result::Ok(()) }); + } + }; trace!(target: "tcp", "Accepted incoming connection from {}", &peer_addr); - let (sender, receiver) = mpsc::channel(65536); + let (sender, receiver) = futures::channel::mpsc::unbounded(); let context = RequestContext { - peer_addr: peer_addr, + peer_addr, sender: sender.clone(), }; let meta = meta_extractor.extract(&context); - let service = Service::new(peer_addr, rpc_handler.clone(), meta); - let (writer, 
reader) = Framed::new( - socket, - codecs::StreamCodec::new( - incoming_separator.clone(), - outgoing_separator.clone(), - ), - ).split(); - - let responses = reader.and_then( - move |req| service.call(req).then(|response| match response { - Err(e) => { - warn!(target: "tcp", "Error while processing request: {:?}", e); - future::ok(String::new()) - }, - Ok(None) => { - trace!(target: "tcp", "JSON RPC request produced no response"); - future::ok(String::new()) - }, - Ok(Some(response_data)) => { - trace!(target: "tcp", "Sent response: {}", &response_data); - future::ok(response_data) - } - }) - ); - - let peer_message_queue = { + let mut service = Service::new(peer_addr, rpc_handler.clone(), meta); + let (mut writer, reader) = Framed::new( + socket, + codecs::StreamCodec::new(incoming_separator.clone(), outgoing_separator.clone()), + ) + .split(); + + // Work around https://github.com/rust-lang/rust/issues/64552 by boxing the stream type + let responses: Pin> + Send>> = + Box::pin(reader.and_then(move |req| { + service.call(req).then(|response| match response { + Err(e) => { + warn!(target: "tcp", "Error while processing request: {:?}", e); + future::ok(String::new()) + } + Ok(None) => { + trace!(target: "tcp", "JSON RPC request produced no response"); + future::ok(String::new()) + } + Ok(Some(response_data)) => { + trace!(target: "tcp", "Sent response: {}", &response_data); + future::ok(response_data) + } + }) + })); + + let mut peer_message_queue = { let mut channels = channels.lock(); - channels.insert(peer_addr.clone(), sender.clone()); + channels.insert(peer_addr, sender); - PeerMessageQueue::new( - responses, - receiver, - peer_addr.clone(), - ) + PeerMessageQueue::new(responses, receiver, peer_addr) }; let shared_channels = channels.clone(); - let writer = writer.send_all(peer_message_queue).then(move |_| { + let writer = async move { + writer.send_all(&mut peer_message_queue).await?; trace!(target: "tcp", "Peer {}: service finished", peer_addr); let mut channels = shared_channels.lock(); channels.remove(&peer_addr); Ok(()) - }); - - tokio::spawn(writer); + }; - Ok(()) + future::Either::Right(writer) }); Ok(server) }; - let stop = stop_rx.map_err(|_| ()); - match start() { + match start.await { Ok(server) => { tx.send(Ok(())).expect("Rx is blocking parent thread."); - future::Either::A(server.select(stop) - .map(|_| ()) - .map_err(|(e, _)| { - error!("Error while executing the server: {:?}", e); - })) - }, + let server = server.buffer_unordered(1024).for_each(|_| async {}); + + future::select(Box::pin(server), stop_rx).await; + } Err(e) => { tx.send(Err(e)).expect("Rx is blocking parent thread."); - future::Either::B(stop - .map_err(|e| { - error!("Error while executing the server: {:?}", e); - })) - }, + let _ = stop_rx.await; + } } - })); + }); let res = rx.recv().expect("Response is always sent before tx is dropped."); @@ -190,7 +194,7 @@ impl + 'static> ServerBuilder { /// TCP Server handle pub struct Server { executor: Option, - stop: Option>, + stop: Option>, } impl Server { @@ -209,6 +213,20 @@ impl Server { impl Drop for Server { fn drop(&mut self) { let _ = self.stop.take().map(|sg| sg.send(())); - self.executor.take().map(|executor| executor.close()); + if let Some(executor) = self.executor.take() { + executor.close() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn server_is_send_and_sync() { + fn is_send_and_sync() {} + + is_send_and_sync::(); } } diff --git a/tcp/src/service.rs b/tcp/src/service.rs index 532f2d411..a085fe9aa 100644 --- 
a/tcp/src/service.rs +++ b/tcp/src/service.rs @@ -1,9 +1,11 @@ -use std::sync::Arc; +use std::future::Future; use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; -use tokio_service; - -use jsonrpc::{middleware, FutureResult, Metadata, MetaIoHandler, Middleware}; +use crate::futures; +use crate::jsonrpc::{middleware, MetaIoHandler, Metadata, Middleware}; pub struct Service = middleware::Noop> { handler: Arc>, @@ -13,24 +15,35 @@ pub struct Service = middleware::Noop> { impl> Service { pub fn new(peer_addr: SocketAddr, handler: Arc>, meta: M) -> Self { - Service { peer_addr: peer_addr, handler: handler, meta: meta } + Service { + handler, + peer_addr, + meta, + } } } -impl> tokio_service::Service for Service { +impl> tower_service::Service for Service +where + S::Future: Unpin, + S::CallFuture: Unpin, +{ // These types must match the corresponding protocol types: - type Request = String; type Response = Option; - // For non-streaming protocols, service errors are always io::Error type Error = (); // The future for computing the response; box it for simplicity. - type Future = FutureResult; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } // Produce a future for computing a response from a request. - fn call(&self, req: Self::Request) -> Self::Future { + fn call(&mut self, req: String) -> Self::Future { + use futures::FutureExt; trace!(target: "tcp", "Accepted request from peer {}: {}", &self.peer_addr, req); - self.handler.handle_request(&req, self.meta.clone()) + Box::pin(self.handler.handle_request(&req, self.meta.clone()).map(Ok)) } } diff --git a/tcp/src/tests.rs b/tcp/src/tests.rs index 21326a4f2..36d7c45c8 100644 --- a/tcp/src/tests.rs +++ b/tcp/src/tests.rs @@ -1,79 +1,69 @@ -use std::net::{SocketAddr, Shutdown}; +use std::net::SocketAddr; use std::str::FromStr; use std::sync::Arc; -use std::time::{Instant, Duration}; +use std::time::Duration; -use jsonrpc::{MetaIoHandler, Value, Metadata}; -use jsonrpc::futures::{self, Future, future}; +use jsonrpc_core::{MetaIoHandler, Metadata, Value}; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; -use server_utils::tokio::{ - timer::Delay, - net::TcpStream, - io::{self}, - self, -}; +use crate::futures; +use crate::server_utils::tokio::{self, net::TcpStream}; use parking_lot::Mutex; -use ServerBuilder; -use MetaExtractor; -use RequestContext; +use crate::MetaExtractor; +use crate::RequestContext; +use crate::ServerBuilder; fn casual_server() -> ServerBuilder { let mut io = MetaIoHandler::<()>::default(); - io.add_method("say_hello", |_params| { - Ok(Value::String("hello".to_string())) - }); + io.add_sync_method("say_hello", |_params| Ok(Value::String("hello".to_string()))); ServerBuilder::new(io) } +fn run_future(fut: impl std::future::Future + Send) -> O { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(fut) +} + #[test] fn doc_test() { - ::logger::init_log(); + crate::logger::init_log(); let mut io = MetaIoHandler::<()>::default(); - io.add_method("say_hello", |_params| { - Ok(Value::String("hello".to_string())) - }); + io.add_sync_method("say_hello", |_params| Ok(Value::String("hello".to_string()))); let server = ServerBuilder::new(io); - server.start(&SocketAddr::from_str("0.0.0.0:17770").unwrap()) + server + .start(&SocketAddr::from_str("0.0.0.0:17770").unwrap()) .expect("Server must run with no issues") .close() } #[test] fn doc_test_connect() { - ::logger::init_log(); + 
crate::logger::init_log(); let addr: SocketAddr = "127.0.0.1:17775".parse().unwrap(); let server = casual_server(); let _server = server.start(&addr).expect("Server must run with no issues"); - let stream = TcpStream::connect(&addr) - .and_then(move |_stream| { - Ok(()) - }) - .map_err(|err| panic!("Server connection error: {:?}", err)); - - tokio::run(stream); + run_future(async move { TcpStream::connect(&addr).await }).expect("Server connection error"); } #[test] fn disconnect() { - ::logger::init_log(); + crate::logger::init_log(); let addr: SocketAddr = "127.0.0.1:17777".parse().unwrap(); let server = casual_server(); let dispatcher = server.dispatcher(); let _server = server.start(&addr).expect("Server must run with no issues"); - let stream = TcpStream::connect(&addr) - .and_then(move |stream| { - assert_eq!(stream.peer_addr().unwrap(), addr); - stream.shutdown(::std::net::Shutdown::Both) - }) - .map_err(|err| panic!("Error disconnecting: {:?}", err)); - - tokio::run(stream); + run_future(async move { + let mut stream = TcpStream::connect(&addr).await.unwrap(); + assert_eq!(stream.peer_addr().unwrap(), addr); + stream.shutdown().await.unwrap(); + }); ::std::thread::sleep(::std::time::Duration::from_millis(50)); @@ -81,23 +71,22 @@ fn disconnect() { } fn dummy_request(addr: &SocketAddr, data: Vec) -> Vec { - let (ret_tx, ret_rx) = futures::sync::oneshot::channel(); - - let stream = TcpStream::connect(addr) - .and_then(move |stream| { - io::write_all(stream, data) - }) - .and_then(|(stream, _data)| { - stream.shutdown(Shutdown::Write).unwrap(); - io::read_to_end(stream, vec![]) - }) - .and_then(move |(_stream, read_buf)| { - ret_tx.send(read_buf).map_err(|err| panic!("Unable to send {:?}", err)) - }) - .map_err(|err| panic!("Error connecting or closing connection: {:?}", err));; - - tokio::run(stream); - ret_rx.wait().expect("Unable to receive result") + let (ret_tx, ret_rx) = std::sync::mpsc::channel(); + + let stream = async move { + let mut stream = TcpStream::connect(addr).await?; + stream.write_all(&data).await?; + stream.shutdown().await?; + let mut read_buf = vec![]; + let _ = stream.read_to_end(&mut read_buf).await; + + let _ = ret_tx.send(read_buf).map_err(|err| panic!("Unable to send {:?}", err)); + + Ok::<(), Box>(()) + }; + + run_future(stream).unwrap(); + ret_rx.recv().expect("Unable to receive result") } fn dummy_request_str(addr: &SocketAddr, data: Vec) -> String { @@ -106,7 +95,7 @@ fn dummy_request_str(addr: &SocketAddr, data: Vec) -> String { #[test] fn doc_test_handle() { - ::logger::init_log(); + crate::logger::init_log(); let addr: SocketAddr = "127.0.0.1:17780".parse().unwrap(); let server = casual_server(); @@ -115,20 +104,19 @@ fn doc_test_handle() { let result = dummy_request_str( &addr, b"{\"jsonrpc\": \"2.0\", \"method\": \"say_hello\", \"params\": [42, 23], \"id\": 1}\n"[..].to_owned(), - ); + ); assert_eq!( - result, - "{\"jsonrpc\":\"2.0\",\"result\":\"hello\",\"id\":1}\n", + result, "{\"jsonrpc\":\"2.0\",\"result\":\"hello\",\"id\":1}\n", "Response does not exactly much the expected response", - ); + ); } #[test] fn req_parallel() { use std::thread; - ::logger::init_log(); + crate::logger::init_log(); let addr: SocketAddr = "127.0.0.1:17782".parse().unwrap(); let server = casual_server(); let _server = server.start(&addr).expect("Server must run with no issues"); @@ -136,22 +124,20 @@ fn req_parallel() { let mut handles = Vec::new(); for _ in 0..6 { let addr = addr.clone(); - handles.push( - thread::spawn(move || { - for _ in 0..100 { - let result = 
dummy_request_str( - &addr, - b"{\"jsonrpc\": \"2.0\", \"method\": \"say_hello\", \"params\": [42, 23], \"id\": 1}\n"[..].to_owned(), - ); - - assert_eq!( - result, - "{\"jsonrpc\":\"2.0\",\"result\":\"hello\",\"id\":1}\n", - "Response does not exactly much the expected response", - ); - } - }) - ); + handles.push(thread::spawn(move || { + for _ in 0..100 { + let result = dummy_request_str( + &addr, + b"{\"jsonrpc\": \"2.0\", \"method\": \"say_hello\", \"params\": [42, 23], \"id\": 1}\n"[..] + .to_owned(), + ); + + assert_eq!( + result, "{\"jsonrpc\":\"2.0\",\"result\":\"hello\",\"id\":1}\n", + "Response does not exactly much the expected response", + ); + } + })); } for handle in handles.drain(..) { @@ -166,7 +152,9 @@ pub struct SocketMetadata { impl Default for SocketMetadata { fn default() -> Self { - SocketMetadata { addr: "0.0.0.0:0".parse().unwrap() } + SocketMetadata { + addr: "0.0.0.0:0".parse().unwrap(), + } } } @@ -176,7 +164,7 @@ impl SocketMetadata { } } -impl Metadata for SocketMetadata { } +impl Metadata for SocketMetadata {} impl From for SocketMetadata { fn from(addr: SocketAddr) -> SocketMetadata { @@ -195,29 +183,27 @@ impl MetaExtractor for PeerMetaExtractor { fn meta_server() -> ServerBuilder { let mut io = MetaIoHandler::::default(); io.add_method_with_meta("say_hello", |_params, meta: SocketMetadata| { - future::ok(Value::String(format!("hello, {}", meta.addr()))) + jsonrpc_core::futures::future::ready(Ok(Value::String(format!("hello, {}", meta.addr())))) }); ServerBuilder::new(io).session_meta_extractor(PeerMetaExtractor) } #[test] fn peer_meta() { - ::logger::init_log(); + crate::logger::init_log(); let addr: SocketAddr = "127.0.0.1:17785".parse().unwrap(); let server = meta_server(); let _server = server.start(&addr).expect("Server must run with no issues"); let result = dummy_request_str( &addr, - b"{\"jsonrpc\": \"2.0\", \"method\": \"say_hello\", \"params\": [42, 23], \"id\": 1}\n"[..].to_owned() - ); + b"{\"jsonrpc\": \"2.0\", \"method\": \"say_hello\", \"params\": [42, 23], \"id\": 1}\n"[..].to_owned(), + ); println!("{}", result); // contains random port, so just smoky comparing response length - assert!( - result.len() == 58 || result.len() == 59 - ); + assert!(result.len() == 58 || result.len() == 59); } #[derive(Default)] @@ -236,85 +222,75 @@ impl MetaExtractor for PeerListMetaExtractor { #[test] fn message() { // MASSIVE SETUP - ::logger::init_log(); + crate::logger::init_log(); let addr: SocketAddr = "127.0.0.1:17790".parse().unwrap(); let mut io = MetaIoHandler::::default(); io.add_method_with_meta("say_hello", |_params, _: SocketMetadata| { - future::ok(Value::String("hello".to_owned())) + jsonrpc_core::futures::future::ready(Ok(Value::String("hello".to_owned()))) }); let extractor = PeerListMetaExtractor::default(); let peer_list = extractor.peers.clone(); - let server = ServerBuilder::new(io) - .session_meta_extractor(extractor); + let server = ServerBuilder::new(io).session_meta_extractor(extractor); let dispatcher = server.dispatcher(); let _server = server.start(&addr).expect("Server must run with no issues"); - let delay = Delay::new(Instant::now() + Duration::from_millis(500)) - .map_err(|err| panic!("{:?}", err)); - let message = "ping"; let executed_dispatch = Arc::new(Mutex::new(false)); let executed_request = Arc::new(Mutex::new(false)); let executed_dispatch_move = executed_dispatch.clone(); let executed_request_move = executed_request.clone(); - // CLIENT RUN - let stream = TcpStream::connect(&addr) - .and_then(|stream| { - 
future::ok(stream).join(delay) - }) - .and_then(move |stream| { - let peer_addr = peer_list.lock()[0].clone(); - dispatcher.push_message( - &peer_addr, - message.to_owned(), - ).expect("Should be sent with no errors"); - trace!(target: "tcp", "Dispatched message for {}", peer_addr); - future::ok(stream) - }) - .and_then(move |(stream, _)| { - // Read message plus newline appended by codec. - io::read_exact(stream, vec![0u8; message.len() + 1]) - }) - .and_then(move |(stream, read_buf)| { - trace!(target: "tcp", "Read ping message"); - let ping_signal = read_buf[..].to_vec(); - - assert_eq!( - format!("{}\n", message), - String::from_utf8(ping_signal).expect("String should be utf-8"), - "Sent request does not match received by the peer", - ); - // ensure that the above assert was actually triggered - *executed_dispatch_move.lock() = true; - - future::ok(stream) - }) - .and_then(|stream| { - // make request AFTER message dispatches - let data = b"{\"jsonrpc\": \"2.0\", \"method\": \"say_hello\", \"params\": [42, 23], \"id\": 1}\n"; - io::write_all(stream, &data[..]) - }) - .and_then(|(stream, _)| { - stream.shutdown(Shutdown::Write).unwrap(); - io::read_to_end(stream, Vec::new()) - }) - .and_then(move |(_, read_buf)| { - trace!(target: "tcp", "Read response message"); - let response_signal = read_buf[..].to_vec(); - assert_eq!( - "{\"jsonrpc\":\"2.0\",\"result\":\"hello\",\"id\":1}\n", - String::from_utf8(response_signal).expect("String should be utf-8"), - "Response does not match the expected handling", - ); - *executed_request_move.lock() = true; + let client = async move { + let stream = TcpStream::connect(&addr); + let delay = tokio::time::sleep(Duration::from_millis(500)); + let (stream, _) = futures::join!(stream, delay); + let mut stream = stream?; + + let peer_addr = peer_list.lock()[0].clone(); + dispatcher + .push_message(&peer_addr, message.to_owned()) + .expect("Should be sent with no errors"); + trace!(target: "tcp", "Dispatched message for {}", peer_addr); + + // Read message plus newline appended by codec. 
+ let mut read_buf = vec![0u8; message.len() + 1]; + let _ = stream.read_exact(&mut read_buf).await?; + + trace!(target: "tcp", "Read ping message"); + let ping_signal = read_buf[..].to_vec(); + + assert_eq!( + format!("{}\n", message), + String::from_utf8(ping_signal).expect("String should be utf-8"), + "Sent request does not match received by the peer", + ); + // ensure that the above assert was actually triggered + *executed_dispatch_move.lock() = true; + + // make request AFTER message dispatches + let data = b"{\"jsonrpc\": \"2.0\", \"method\": \"say_hello\", \"params\": [42, 23], \"id\": 1}\n"; + stream.write_all(&data[..]).await?; + + stream.shutdown().await.unwrap(); + let mut read_buf = vec![]; + let _ = stream.read_to_end(&mut read_buf).await?; + + trace!(target: "tcp", "Read response message"); + let response_signal = read_buf[..].to_vec(); + assert_eq!( + "{\"jsonrpc\":\"2.0\",\"result\":\"hello\",\"id\":1}\n", + String::from_utf8(response_signal).expect("String should be utf-8"), + "Response does not match the expected handling", + ); + *executed_request_move.lock() = true; + + // delay + Ok::<(), Box>(()) + }; - future::ok(()) - }) - .map_err(|err| panic!("Dispach message error: {:?}", err)); + run_future(client).unwrap(); - tokio::run(stream); assert!(*executed_dispatch.lock()); assert!(*executed_request.lock()); } diff --git a/test/Cargo.toml b/test/Cargo.toml index ca7496812..e3f150450 100644 --- a/test/Cargo.toml +++ b/test/Cargo.toml @@ -1,17 +1,25 @@ [package] name = "jsonrpc-test" description = "Simple test framework for JSON-RPC." -version = "9.0.0" +version = "18.0.0" authors = ["Tomasz DrwiÄ™ga "] license = "MIT" homepage = "https://github.com/paritytech/jsonrpc" repository = "https://github.com/paritytech/jsonrpc" -documentation = "https://paritytech.github.io/jsonrpc/jsonrpc_test/index.html" +documentation = "https://docs.rs/jsonrpc-test/" +edition = "2018" [dependencies] -jsonrpc-core = { path = "../core", version = "9.0" } +jsonrpc-core = { version = "18.0.0", path = "../core" } +jsonrpc-core-client = { version = "18.0.0", path = "../core-client" } +jsonrpc-pubsub = { version = "18.0.0", path = "../pubsub" } +log = "0.4" serde = "1.0" serde_json = "1.0" +[features] +arbitrary_precision = ["jsonrpc-core-client/arbitrary_precision", "serde_json/arbitrary_precision", "jsonrpc-core/arbitrary_precision"] + [dev-dependencies] -jsonrpc-macros = { path = "../macros", version = "9.0" } +jsonrpc-derive = { version = "18.0.0", path = "../derive" } + diff --git a/test/src/lib.rs b/test/src/lib.rs index 6392eabd5..3143cc2c5 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -1,26 +1,23 @@ //! An utility package to test jsonrpc-core based projects. //! //! ``` -//! #[macro_use] -//! extern crate jsonrpc_macros; +//! use jsonrpc_derive::rpc; +//! use jsonrpc_test as test; //! -//! extern crate jsonrpc_core as core; -//! extern crate jsonrpc_test as test; +//! use jsonrpc_core::{Result, Error, IoHandler}; //! -//! use core::Result; -//! -//! build_rpc_trait! { -//! pub trait Test { -//! #[rpc(name = "rpc_some_method")] -//! fn some_method(&self, u64) -> Result; -//! } +//! #[rpc] +//! pub trait Test { +//! #[rpc(name = "rpc_some_method")] +//! fn some_method(&self, a: u64) -> Result; //! } //! +//! //! struct Dummy; //! impl Test for Dummy { -//! fn some_method(&self, x: u64) -> Result { -//! Ok(x * 2) -//! } +//! fn some_method(&self, x: u64) -> Result { +//! Ok(x * 2) +//! } //! } //! //! fn main() { @@ -32,10 +29,10 @@ //! //! 
// You can also test RPC created without macros: //! let rpc = { -//! let mut io = core::IoHandler::new(); -//! io.add_method("rpc_test_method", |_| { -//! Err(core::Error::internal_error()) -//! }); +//! let mut io = IoHandler::new(); +//! io.add_sync_method("rpc_test_method", |_| { +//! Err(Error::internal_error()) +//! }); //! test::Rpc::from(io) //! }; //! @@ -46,13 +43,9 @@ //! } //! ``` -#[warn(missing_docs)] +#![deny(missing_docs)] extern crate jsonrpc_core as rpc; -extern crate serde; -extern crate serde_json; - -use std::collections::HashMap; /// Test RPC options. #[derive(Default, Debug)] @@ -70,24 +63,45 @@ pub struct Rpc { pub options: Options, } +/// Encoding format. +pub enum Encoding { + /// Encodes params using `serde::to_string`. + Compact, + /// Encodes params using `serde::to_string_pretty`. + Pretty, +} + impl From for Rpc { fn from(io: rpc::IoHandler) -> Self { - Rpc { io, ..Default::default() } + Rpc { + io, + ..Default::default() + } } } impl Rpc { /// Create a new RPC instance from a single delegate. - pub fn new(delegate: D) -> Self where - D: Into>>, + pub fn new(delegate: D) -> Self + where + D: IntoIterator)>, { let mut io = rpc::IoHandler::new(); io.extend_with(delegate); io.into() } + /// Perform a single, synchronous method call and return pretty-printed value + pub fn request(&self, method: &str, params: &T) -> String + where + T: serde::Serialize, + { + self.make_request(method, params, Encoding::Pretty) + } + /// Perform a single, synchronous method call. - pub fn request(&self, method: &str, params: &T) -> String where + pub fn make_request(&self, method: &str, params: &T, encoding: Encoding) -> String + where T: serde::Serialize, { use self::rpc::types::response; @@ -98,16 +112,23 @@ impl Rpc { serde_json::to_string_pretty(params).expect("Serialization should be infallible."), ); - let response = self.io + let response = self + .io .handle_request_sync(&request) .expect("We are sending a method call not notification."); // extract interesting part from the response - let extracted = match serde_json::from_str(&response).expect("We will always get a single output.") { - response::Output::Success(response::Success { result, .. }) => serde_json::to_string_pretty(&result), - response::Output::Failure(response::Failure { error, .. }) => serde_json::to_string_pretty(&error), - }.expect("Serialization is infallible; qed"); - + let extracted = match rpc::serde_from_str(&response).expect("We will always get a single output.") { + response::Output::Success(response::Success { result, .. }) => match encoding { + Encoding::Compact => serde_json::to_string(&result), + Encoding::Pretty => serde_json::to_string_pretty(&result), + }, + response::Output::Failure(response::Failure { error, .. 
}) => match encoding { + Encoding::Compact => serde_json::to_string(&error), + Encoding::Pretty => serde_json::to_string_pretty(&error), + }, + } + .expect("Serialization is infallible; qed"); println!("\n{}\n --> {}\n", request, extracted); @@ -120,20 +141,28 @@ mod tests { use super::*; #[test] - fn should_test_simple_method() { + fn should_test_request_is_pretty() { // given let rpc = { let mut io = rpc::IoHandler::new(); - io.add_method("test_method", |_| { - Ok(rpc::Value::Number(5.into())) - }); + io.add_sync_method("test_method", |_| Ok(rpc::Value::Array(vec![5.into(), 10.into()]))); Rpc::from(io) }; // when - assert_eq!( - rpc.request("test_method", &[5u64]), - r#"5"# - ); + assert_eq!(rpc.request("test_method", &[5u64]), "[\n 5,\n 10\n]"); + } + + #[test] + fn should_test_make_request_compact() { + // given + let rpc = { + let mut io = rpc::IoHandler::new(); + io.add_sync_method("test_method", |_| Ok(rpc::Value::Array(vec![5.into(), 10.into()]))); + Rpc::from(io) + }; + + // when + assert_eq!(rpc.make_request("test_method", &[5u64], Encoding::Compact), "[5,10]"); } } diff --git a/ws/Cargo.toml b/ws/Cargo.toml index 2a6e857c5..6c9008fd3 100644 --- a/ws/Cargo.toml +++ b/ws/Cargo.toml @@ -1,21 +1,22 @@ [package] -name = "jsonrpc-ws-server" +authors = ["Parity Technologies "] description = "WebSockets server for JSON-RPC" -version = "9.0.0" -authors = ["tomusdrw "] -license = "MIT" +documentation = "https://docs.rs/jsonrpc-ws-server/" +edition = "2018" homepage = "https://github.com/paritytech/jsonrpc" +license = "MIT" +name = "jsonrpc-ws-server" repository = "https://github.com/paritytech/jsonrpc" -documentation = "https://paritytech.github.io/jsonrpc/json_ws_server/index.html" +version = "18.0.0" [dependencies] -error-chain = "0.12" -jsonrpc-core = { version = "9.0", path = "../core" } -jsonrpc-server-utils = { version = "9.0", path = "../server-utils" } +futures = "0.3" +jsonrpc-core = { version = "18.0.0", path = "../core" } +jsonrpc-server-utils = { version = "18.0.0", path = "../server-utils" } log = "0.4" -parking_lot = "0.6" +parking_lot = "0.11.0" slab = "0.4" -ws = { git = "https://github.com/tomusdrw/ws-rs" } +parity-ws = "0.11" [badges] travis-ci = { repository = "paritytech/jsonrpc", branch = "master"} diff --git a/ws/README.md b/ws/README.md index b95124c44..fed02adcd 100644 --- a/ws/README.md +++ b/ws/README.md @@ -9,14 +9,12 @@ WebSockets server for JSON-RPC 2.0. 
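The `jsonrpc-test` changes above split the old pretty-printing `request` into `request` (unchanged output) and `make_request`, which takes the new `Encoding` enum. A minimal usage sketch based on those hunks; the method name `double` and its handler are placeholders, not part of the diff:

```rust
use jsonrpc_core::{IoHandler, Value};
use jsonrpc_test::{Encoding, Rpc};

fn main() {
    let mut io = IoHandler::new();
    // `add_sync_method` is the new registration call for plain synchronous handlers.
    io.add_sync_method("double", |_params| Ok(Value::Array(vec![5.into(), 10.into()])));
    let rpc = Rpc::from(io);

    // `request` keeps the old pretty-printed behaviour...
    println!("{}", rpc.request("double", &[5u64]));
    // ...while `make_request` lets the caller pick the encoding explicitly.
    assert_eq!(rpc.make_request("double", &[5u64], Encoding::Compact), "[5,10]");
}
```

Existing callers of `request` keep their pretty-printed output, so only code that wants compact responses needs to switch to `make_request`.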
``` [dependencies] -jsonrpc-ws-server = { git = "https://github.com/paritytech/jsonrpc" } +jsonrpc-ws-server = "15.0" ``` `main.rs` ```rust -extern crate jsonrpc_ws_server; - use jsonrpc_ws_server::*; use jsonrpc_ws_server::jsonrpc_core::*; diff --git a/ws/examples/ws.rs b/ws/examples/ws.rs index dee6359c2..69fd9a494 100644 --- a/ws/examples/ws.rs +++ b/ws/examples/ws.rs @@ -1,11 +1,9 @@ -extern crate jsonrpc_ws_server; - -use jsonrpc_ws_server::ServerBuilder; use jsonrpc_ws_server::jsonrpc_core::*; +use jsonrpc_ws_server::ServerBuilder; fn main() { let mut io = IoHandler::default(); - io.add_method("say_hello", |_params| { + io.add_sync_method("say_hello", |_params| { println!("Processing"); Ok(Value::String("hello".to_owned())) }); diff --git a/ws/src/error.rs b/ws/src/error.rs index 2602eb770..553c27c50 100644 --- a/ws/src/error.rs +++ b/ws/src/error.rs @@ -1,28 +1,52 @@ -#![allow(missing_docs)] +use std::{error, fmt, io, result}; -use std::io; +use crate::ws; -use ws; +/// WebSockets Server Error +#[derive(Debug)] +pub enum Error { + /// Io Error + Io(io::Error), + /// WebSockets Error + WsError(ws::Error), + /// Connection Closed + ConnectionClosed, +} + +/// WebSockets Server Result +pub type Result = result::Result; -error_chain! { - foreign_links { - Io(io::Error); +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> result::Result<(), fmt::Error> { + match self { + Error::ConnectionClosed => write!(f, "Action on closed connection."), + Error::WsError(err) => write!(f, "WebSockets Error: {}", err), + Error::Io(err) => write!(f, "Io Error: {}", err), + } } +} - errors { - /// Attempted action on closed connection. - ConnectionClosed { - description("connection is closed"), - display("Action on closed connection."), +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match self { + Error::Io(io) => Some(io), + Error::WsError(ws) => Some(ws), + Error::ConnectionClosed => None, } } } +impl From for Error { + fn from(err: io::Error) -> Self { + Error::Io(err) + } +} + impl From for Error { fn from(err: ws::Error) -> Self { match err.kind { - ws::ErrorKind::Io(e) => e.into(), - _ => Error::with_chain(err, "WebSockets Error"), + ws::ErrorKind::Io(err) => Error::Io(err), + _ => Error::WsError(err), } } } diff --git a/ws/src/lib.rs b/ws/src/lib.rs index 905884cfd..5c1cfc3ca 100644 --- a/ws/src/lib.rs +++ b/ws/src/lib.rs @@ -1,16 +1,12 @@ //! `WebSockets` server. 
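With `error-chain` removed, `jsonrpc-ws-server` errors become a plain enum that implements `std::error::Error`. A hypothetical caller-side sketch of matching the new variants; the `report` helper is made up for illustration:

```rust
use jsonrpc_ws_server::Error;

// Hypothetical helper: maps each variant of the new error enum to a log line.
fn report(err: Error) {
    match err {
        // Formerly error-chain's `ErrorKind::ConnectionClosed`.
        Error::ConnectionClosed => eprintln!("action on a closed connection"),
        Error::Io(io) => eprintln!("i/o error: {}", io),
        Error::WsError(ws) => eprintln!("websocket error: {}", ws),
    }
}

fn main() {
    report(Error::ConnectionClosed);
}
```

Because the new type implements `std::error::Error` with `source()`, it also composes with `?` and `Box<dyn std::error::Error>` without any error-chain glue.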
-#![warn(missing_docs)] +#![deny(missing_docs)] -extern crate jsonrpc_server_utils as server_utils; -extern crate parking_lot; -extern crate slab; +use jsonrpc_server_utils as server_utils; -pub extern crate ws; -pub extern crate jsonrpc_core; +pub use jsonrpc_core; +pub use parity_ws as ws; -#[macro_use] -extern crate error_chain; #[macro_use] extern crate log; @@ -24,12 +20,12 @@ mod tests; use jsonrpc_core as core; -pub use self::error::{Error, ErrorKind, Result}; -pub use self::metadata::{RequestContext, MetaExtractor, NoopExtractor}; -pub use self::session::{RequestMiddleware, MiddlewareAction}; -pub use self::server::{CloseHandle, Server}; +pub use self::error::{Error, Result}; +pub use self::metadata::{MetaExtractor, NoopExtractor, RequestContext}; +pub use self::server::{Broadcaster, CloseHandle, Server}; pub use self::server_builder::ServerBuilder; pub use self::server_utils::cors::Origin; -pub use self::server_utils::hosts::{Host, DomainsValidation}; -pub use self::server_utils::tokio; +pub use self::server_utils::hosts::{DomainsValidation, Host}; pub use self::server_utils::session::{SessionId, SessionStats}; +pub use self::server_utils::tokio; +pub use self::session::{MiddlewareAction, RequestMiddleware}; diff --git a/ws/src/metadata.rs b/ws/src/metadata.rs index 60835c1d4..25b1bed82 100644 --- a/ws/src/metadata.rs +++ b/ws/src/metadata.rs @@ -1,13 +1,16 @@ use std::fmt; +use std::future::Future; +use std::pin::Pin; use std::sync::{atomic, Arc}; +use std::task::{Context, Poll}; -use core::{self, futures}; -use core::futures::sync::mpsc; -use server_utils::{session, tokio::runtime::TaskExecutor}; -use ws; +use crate::core; +use crate::core::futures::channel::mpsc; +use crate::server_utils::{reactor::TaskExecutor, session}; +use crate::ws; -use error; -use {Origin}; +use crate::error; +use crate::Origin; /// Output of WebSocket connection. Use this to send messages to the other endpoint. #[derive(Clone)] @@ -19,24 +22,22 @@ pub struct Sender { impl Sender { /// Creates a new `Sender`. pub fn new(out: ws::Sender, active: Arc) -> Self { - Sender { - out: out, - active: active, - } + Sender { out, active } } fn check_active(&self) -> error::Result<()> { if self.active.load(atomic::Ordering::SeqCst) { Ok(()) } else { - bail!(error::ErrorKind::ConnectionClosed) + Err(error::Error::ConnectionClosed) } } /// Sends a message over the connection. /// Will return error if the connection is not active any more. pub fn send(&self, msg: M) -> error::Result<()> - where M: Into + where + M: Into, { self.check_active()?; self.out.send(msg)?; @@ -45,8 +46,9 @@ impl Sender { /// Sends a message over the endpoints of all connections. /// Will return error if the connection is not active any more. - pub fn broadcast(&self, msg: M) -> error::Result<()> where - M: Into + pub fn broadcast(&self, msg: M) -> error::Result<()> + where + M: Into, { self.check_active()?; self.out.broadcast(msg)?; @@ -79,10 +81,10 @@ pub struct RequestContext { impl RequestContext { /// Get this session as a `Sink` spawning a new future /// in the underlying event loop. 
- pub fn sender(&self) -> mpsc::Sender { + pub fn sender(&self) -> mpsc::UnboundedSender { let out = self.out.clone(); - let (sender, receiver) = mpsc::channel(1); - self.executor.spawn(SenderFuture(out, receiver)); + let (sender, receiver) = mpsc::unbounded(); + self.executor.spawn(SenderFuture(out, Box::new(receiver))); sender } } @@ -103,7 +105,8 @@ pub trait MetaExtractor: Send + Sync + 'static { fn extract(&self, _context: &RequestContext) -> M; } -impl MetaExtractor for F where +impl MetaExtractor for F +where M: core::Metadata, F: Fn(&RequestContext) -> M + Send + Sync + 'static, { @@ -116,32 +119,30 @@ impl MetaExtractor for F where #[derive(Debug, Clone)] pub struct NoopExtractor; impl MetaExtractor for NoopExtractor { - fn extract(&self, _context: &RequestContext) -> M { M::default() } + fn extract(&self, _context: &RequestContext) -> M { + M::default() + } } -struct SenderFuture(Sender, mpsc::Receiver); -impl futures::Future for SenderFuture { - type Item = (); - type Error = (); +struct SenderFuture(Sender, Box + Send + Unpin>); + +impl Future for SenderFuture { + type Output = (); - fn poll(&mut self) -> futures::Poll { - use self::futures::Stream; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + use futures::Stream; + let this = Pin::into_inner(self); loop { - let item = self.1.poll()?; - match item { - futures::Async::NotReady => { - return Ok(futures::Async::NotReady); - }, - futures::Async::Ready(None) => { - return Ok(futures::Async::Ready(())); - }, - futures::Async::Ready(Some(val)) => { - if let Err(e) = self.0.send(val) { + match Pin::new(&mut this.1).poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(()), + Poll::Ready(Some(val)) => { + if let Err(e) = this.0.send(val) { warn!("Error sending a subscription update: {:?}", e); - return Ok(futures::Async::Ready(())); + return Poll::Ready(()); } - }, + } } } } diff --git a/ws/src/server.rs b/ws/src/server.rs index 74a2f6435..e5d795f2b 100644 --- a/ws/src/server.rs +++ b/ws/src/server.rs @@ -1,18 +1,18 @@ -use std::fmt; use std::net::SocketAddr; use std::sync::{Arc, Mutex}; use std::thread; +use std::{cmp, fmt}; -use core; -use server_utils::cors::Origin; -use server_utils::hosts::{self, Host}; -use server_utils::reactor::{UninitializedExecutor, Executor}; -use server_utils::session::SessionStats; -use ws; +use crate::core; +use crate::server_utils::cors::Origin; +use crate::server_utils::hosts::{self, Host}; +use crate::server_utils::reactor::{Executor, UninitializedExecutor}; +use crate::server_utils::session::SessionStats; +use crate::ws; -use error::{Error, Result}; -use metadata; -use session; +use crate::error::{Error, Result}; +use crate::metadata; +use crate::session; /// `WebSockets` server implementation. pub struct Server { @@ -23,13 +23,13 @@ pub struct Server { } impl fmt::Debug for Server { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Server") .field("addr", &self.addr) .field("handle", &self.handle) .field("executor", &self.executor) .finish() - } + } } impl Server { @@ -38,28 +38,43 @@ impl Server { &self.addr } + /// Returns a Broadcaster that can be used to send messages on all connections. + pub fn broadcaster(&self) -> Broadcaster { + Broadcaster { + broadcaster: self.broadcaster.clone(), + } + } + /// Starts a new `WebSocket` server in separate thread. /// Returns a `Server` handle which closes the server when droped. 
pub fn start>( addr: &SocketAddr, handler: Arc>, - meta_extractor: Arc>, + meta_extractor: Arc>, allowed_origins: Option>, allowed_hosts: Option>, - request_middleware: Option>, - stats: Option>, + request_middleware: Option>, + stats: Option>, executor: UninitializedExecutor, max_connections: usize, max_payload_bytes: usize, - ) -> Result { + max_in_buffer_capacity: usize, + max_out_buffer_capacity: usize, + ) -> Result + where + S::Future: Unpin, + S::CallFuture: Unpin, + { let config = { let mut config = ws::Settings::default(); config.max_connections = max_connections; // don't accept super large requests - config.max_in_buffer = max_payload_bytes; + config.max_fragment_size = max_payload_bytes; + config.in_buffer_capacity_hard_limit = max_in_buffer_capacity; + config.out_buffer_capacity_hard_limit = max_out_buffer_capacity; // don't grow non-final fragments (to prevent DOS) config.fragments_grow = false; - config.fragments_capacity = max_payload_bytes / config.fragment_size; + config.fragments_capacity = cmp::max(1, max_payload_bytes / config.fragment_size); // accept only handshakes beginning with GET config.method_strict = true; // require masking @@ -78,7 +93,13 @@ impl Server { // Create WebSocket let ws = ws::Builder::new().with_settings(config).build(session::Factory::new( - handler, meta_extractor, allowed_origins, allowed_hosts, request_middleware, stats, executor + handler, + meta_extractor, + allowed_origins, + allowed_hosts, + request_middleware, + stats, + executor, ))?; let broadcaster = ws.broadcaster(); @@ -88,14 +109,12 @@ impl Server { debug!("Bound to local address: {}", local_addr); // Spawn a thread with event loop - let handle = thread::spawn(move || { - match ws.run().map_err(Error::from) { - Err(error) => { - error!("Error while running websockets server. Details: {:?}", error); - Err(error) - }, - Ok(_server) => Ok(()), + let handle = thread::spawn(move || match ws.run().map_err(Error::from) { + Err(error) => { + error!("Error while running websockets server. Details: {:?}", error); + Err(error) } + Ok(_server) => Ok(()), }); // Return a handle @@ -103,7 +122,7 @@ impl Server { addr: local_addr, handle: Some(handle), executor: Arc::new(Mutex::new(Some(eloop))), - broadcaster: broadcaster, + broadcaster, }) } } @@ -111,7 +130,11 @@ impl Server { impl Server { /// Consumes the server and waits for completion pub fn wait(mut self) -> Result<()> { - self.handle.take().expect("Handle is always Some at start.").join().expect("Non-panic exit") + self.handle + .take() + .expect("Handle is always Some at start.") + .join() + .expect("Non-panic exit") } /// Closes the server and waits for it to finish @@ -136,7 +159,6 @@ impl Drop for Server { } } - /// A handle that allows closing of a server even if it owned by a thread blocked in `wait`. #[derive(Clone)] pub struct CloseHandle { @@ -148,6 +170,31 @@ impl CloseHandle { /// Closes the `Server`. pub fn close(self) { let _ = self.broadcaster.shutdown(); - self.executor.lock().unwrap().take().map(|executor| executor.close()); + if let Some(executor) = self.executor.lock().unwrap().take() { + executor.close() + } + } +} + +/// A Broadcaster that can be used to send messages on all connections. +#[derive(Clone)] +pub struct Broadcaster { + broadcaster: ws::Sender, +} + +impl Broadcaster { + /// Send a message to the endpoints of all connections. 
+ #[inline] + pub fn send(&self, msg: M) -> Result<()> + where + M: Into, + { + match self.broadcaster.send(msg).map_err(Error::from) { + Err(error) => { + error!("Error while running sending. Details: {:?}", error); + Err(error) + } + Ok(_server) => Ok(()), + } } } diff --git a/ws/src/server_builder.rs b/ws/src/server_builder.rs index 2a9760f44..2bfdcaadc 100644 --- a/ws/src/server_builder.rs +++ b/ws/src/server_builder.rs @@ -1,43 +1,54 @@ use std::net::SocketAddr; use std::sync::Arc; -use core; -use server_utils; -use server_utils::cors::Origin; -use server_utils::hosts::{Host, DomainsValidation}; -use server_utils::reactor::UninitializedExecutor; -use server_utils::session::SessionStats; - -use error::Result; -use metadata::{MetaExtractor, NoopExtractor}; -use server::Server; -use session; +use crate::core; +use crate::server_utils::cors::Origin; +use crate::server_utils::hosts::{DomainsValidation, Host}; +use crate::server_utils::reactor::{self, UninitializedExecutor}; +use crate::server_utils::session::SessionStats; + +use crate::error::Result; +use crate::metadata::{MetaExtractor, NoopExtractor}; +use crate::server::Server; +use crate::session; /// Builder for `WebSockets` server pub struct ServerBuilder> { handler: Arc>, - meta_extractor: Arc>, + meta_extractor: Arc>, allowed_origins: Option>, allowed_hosts: Option>, - request_middleware: Option>, - session_stats: Option>, + request_middleware: Option>, + session_stats: Option>, executor: UninitializedExecutor, max_connections: usize, max_payload_bytes: usize, + max_in_buffer_capacity: usize, + max_out_buffer_capacity: usize, } -impl> ServerBuilder { +impl> ServerBuilder +where + S::Future: Unpin, + S::CallFuture: Unpin, +{ /// Creates new `ServerBuilder` - pub fn new(handler: T) -> Self where + pub fn new(handler: T) -> Self + where T: Into>, { Self::with_meta_extractor(handler, NoopExtractor) } } -impl> ServerBuilder { +impl> ServerBuilder +where + S::Future: Unpin, + S::CallFuture: Unpin, +{ /// Creates new `ServerBuilder` - pub fn with_meta_extractor(handler: T, extractor: E) -> Self where + pub fn with_meta_extractor(handler: T, extractor: E) -> Self + where T: Into>, E: MetaExtractor, { @@ -51,11 +62,13 @@ impl> ServerBuilder { executor: UninitializedExecutor::Unspawned, max_connections: 100, max_payload_bytes: 5 * 1024 * 1024, + max_in_buffer_capacity: 10 * 1024 * 1024, + max_out_buffer_capacity: 10 * 1024 * 1024, } } /// Utilize existing event loop executor to poll RPC results. - pub fn event_loop_executor(mut self, executor: server_utils::tokio::runtime::TaskExecutor) -> Self { + pub fn event_loop_executor(mut self, executor: reactor::TaskExecutor) -> Self { self.executor = UninitializedExecutor::Shared(executor); self } @@ -105,6 +118,20 @@ impl> ServerBuilder { self } + /// The maximum size to which the incoming buffer can grow. + /// Default: 10,485,760 + pub fn max_in_buffer_capacity(mut self, max_in_buffer_capacity: usize) -> Self { + self.max_in_buffer_capacity = max_in_buffer_capacity; + self + } + + /// The maximum size to which the outgoing buffer can grow. + /// Default: 10,485,760 + pub fn max_out_buffer_capacity(mut self, max_out_buffer_capacity: usize) -> Self { + self.max_out_buffer_capacity = max_out_buffer_capacity; + self + } + /// Starts a new `WebSocket` server in separate thread. /// Returns a `Server` handle which closes the server when droped. 
pub fn start(self, addr: &SocketAddr) -> Result { @@ -119,7 +146,51 @@ impl> ServerBuilder { self.executor, self.max_connections, self.max_payload_bytes, + self.max_in_buffer_capacity, + self.max_out_buffer_capacity, ) } +} +#[cfg(test)] +mod tests { + use super::*; + + fn basic_server_builder() -> ServerBuilder<(), jsonrpc_core::middleware::Noop> { + let io = core::IoHandler::default(); + ServerBuilder::new(io) + } + #[test] + fn config_usize_vals_have_correct_defaults() { + let server = basic_server_builder(); + + assert_eq!(server.max_connections, 100); + assert_eq!(server.max_payload_bytes, 5 * 1024 * 1024); + assert_eq!(server.max_in_buffer_capacity, 10 * 1024 * 1024); + assert_eq!(server.max_out_buffer_capacity, 10 * 1024 * 1024); + } + + #[test] + fn config_usize_vals_can_be_set() { + let server = basic_server_builder(); + + // We can set them individually + let server = server.max_connections(10); + assert_eq!(server.max_connections, 10); + + let server = server.max_payload(29); + assert_eq!(server.max_payload_bytes, 29); + + let server = server.max_in_buffer_capacity(38); + assert_eq!(server.max_in_buffer_capacity, 38); + + let server = server.max_out_buffer_capacity(47); + assert_eq!(server.max_out_buffer_capacity, 47); + + // Setting values consecutively does not impact other values + assert_eq!(server.max_connections, 10); + assert_eq!(server.max_payload_bytes, 29); + assert_eq!(server.max_in_buffer_capacity, 38); + assert_eq!(server.max_out_buffer_capacity, 47); + } } diff --git a/ws/src/session.rs b/ws/src/session.rs index 512a10e0c..eef7b0e9e 100644 --- a/ws/src/session.rs +++ b/ws/src/session.rs @@ -1,22 +1,25 @@ -use std; +use std::future::Future; +use std::pin::Pin; use std::sync::{atomic, Arc}; +use std::task::{Context, Poll}; -use core; -use core::futures::{Async, Future, Poll}; -use core::futures::sync::oneshot; +use crate::core; +use futures::channel::oneshot; +use futures::future; +use futures::FutureExt; use parking_lot::Mutex; use slab::Slab; -use server_utils::Pattern; -use server_utils::cors::Origin; -use server_utils::hosts::Host; -use server_utils::session::{SessionId, SessionStats}; -use server_utils::tokio::runtime::TaskExecutor; -use ws; +use crate::server_utils::cors::Origin; +use crate::server_utils::hosts::Host; +use crate::server_utils::reactor::TaskExecutor; +use crate::server_utils::session::{SessionId, SessionStats}; +use crate::server_utils::Pattern; +use crate::ws; -use error; -use metadata; +use crate::error; +use crate::metadata; /// Middleware to intercept server requests. 
/// You can either terminate the request (by returning a response) @@ -26,7 +29,8 @@ pub trait RequestMiddleware: Send + Sync + 'static { fn process(&self, req: &ws::Request) -> MiddlewareAction; } -impl RequestMiddleware for F where +impl RequestMiddleware for F +where F: Fn(&ws::Request) -> Option + Send + Sync + 'static, { fn process(&self, req: &ws::Request) -> MiddlewareAction { @@ -73,7 +77,11 @@ impl MiddlewareAction { impl From> for MiddlewareAction { fn from(opt: Option) -> Self { match opt { - Some(res) => MiddlewareAction::Respond { response: res, validate_origin: true, validate_hosts: true }, + Some(res) => MiddlewareAction::Respond { + response: res, + validate_origin: true, + validate_hosts: true, + }, None => MiddlewareAction::Proceed, } } @@ -109,21 +117,25 @@ impl LivenessPoll { (index, rx) }; - LivenessPoll { task_slab: task_slab, slab_handle: index, rx: rx } + LivenessPoll { + task_slab, + slab_handle: index, + rx, + } } } impl Future for LivenessPoll { - type Item = (); - type Error = (); + type Output = (); - fn poll(&mut self) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = Pin::into_inner(self); // if the future resolves ok then we've been signalled to return. // it should never be cancelled, but if it was the session definitely // isn't live. - match self.rx.poll() { - Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), - Ok(Async::NotReady) => Ok(Async::NotReady), + match Pin::new(&mut this.rx).poll(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, } } } @@ -139,11 +151,11 @@ pub struct Session> { active: Arc, context: metadata::RequestContext, handler: Arc>, - meta_extractor: Arc>, + meta_extractor: Arc>, allowed_origins: Option>, allowed_hosts: Option>, - request_middleware: Option>, - stats: Option>, + request_middleware: Option>, + stats: Option>, metadata: Option, executor: TaskExecutor, task_slab: Arc, @@ -152,7 +164,9 @@ pub struct Session> { impl> Drop for Session { fn drop(&mut self) { self.active.store(false, atomic::Ordering::SeqCst); - self.stats.as_ref().map(|stats| stats.close_session(self.context.session_id)); + if let Some(stats) = self.stats.as_ref() { + stats.close_session(self.context.session_id) + } // signal to all still-live tasks that the session has been dropped. 
for (_index, task) in self.task_slab.lock().iter_mut() { @@ -194,7 +208,11 @@ impl> Session { } } -impl> ws::Handler for Session { +impl> ws::Handler for Session +where + S::Future: Unpin, + S::CallFuture: Unpin, +{ fn on_request(&mut self, req: &ws::Request) -> ws::Result { // Run middleware let action = if let Some(ref middleware) = self.request_middleware { @@ -218,8 +236,12 @@ impl> ws::Handler for Session { } } - self.context.origin = origin.and_then(|origin| ::std::str::from_utf8(origin).ok()).map(Into::into); - self.context.protocols = req.protocols().ok() + self.context.origin = origin + .and_then(|origin| ::std::str::from_utf8(origin).ok()) + .map(Into::into); + self.context.protocols = req + .protocols() + .ok() .map(|protos| protos.into_iter().map(Into::into).collect()) .unwrap_or_else(Vec::new); self.metadata = Some(self.meta_extractor.extract(&self.context)); @@ -238,7 +260,10 @@ impl> ws::Handler for Session { fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> { let req = msg.as_text()?; let out = self.context.out.clone(); - let metadata = self.metadata.clone().expect("Metadata is always set in on_request; qed"); + let metadata = self + .metadata + .clone() + .expect("Metadata is always set in on_request; qed"); // TODO: creation requires allocating a `oneshot` channel and acquiring a // mutex. we could alternatively do this lazily upon first poll if @@ -246,28 +271,27 @@ impl> ws::Handler for Session { let poll_liveness = LivenessPoll::create(self.task_slab.clone()); let active_lock = self.active.clone(); - let future = self.handler.handle_request(req, metadata) - .map(move |response| { - if !active_lock.load(atomic::Ordering::SeqCst) { - return; - } - if let Some(result) = response { - let res = out.send(result); - match res { - Err(error::Error(error::ErrorKind::ConnectionClosed, _)) => { - active_lock.store(false, atomic::Ordering::SeqCst); - }, - Err(e) => { - warn!("Error while sending response: {:?}", e); - }, - _ => {}, + let response = self.handler.handle_request(req, metadata); + + let future = response.map(move |response| { + if !active_lock.load(atomic::Ordering::SeqCst) { + return; + } + if let Some(result) = response { + let res = out.send(result); + match res { + Err(error::Error::ConnectionClosed) => { + active_lock.store(false, atomic::Ordering::SeqCst); } + Err(e) => { + warn!("Error while sending response: {:?}", e); + } + _ => {} } - }) - .select(poll_liveness) - .map(|_| ()) - .map_err(|_| ()); + } + }); + let future = future::select(future, poll_liveness); self.executor.spawn(future); Ok(()) @@ -277,43 +301,49 @@ impl> ws::Handler for Session { pub struct Factory> { session_id: SessionId, handler: Arc>, - meta_extractor: Arc>, + meta_extractor: Arc>, allowed_origins: Option>, allowed_hosts: Option>, - request_middleware: Option>, - stats: Option>, + request_middleware: Option>, + stats: Option>, executor: TaskExecutor, } impl> Factory { pub fn new( handler: Arc>, - meta_extractor: Arc>, + meta_extractor: Arc>, allowed_origins: Option>, allowed_hosts: Option>, - request_middleware: Option>, - stats: Option>, + request_middleware: Option>, + stats: Option>, executor: TaskExecutor, ) -> Self { Factory { session_id: 0, - handler: handler, - meta_extractor: meta_extractor, - allowed_origins: allowed_origins, - allowed_hosts: allowed_hosts, - request_middleware: request_middleware, - stats: stats, - executor: executor, + handler, + meta_extractor, + allowed_origins, + allowed_hosts, + request_middleware, + stats, + executor, } } } -impl> ws::Factory 
for Factory { +impl> ws::Factory for Factory +where + S::Future: Unpin, + S::CallFuture: Unpin, +{ type Handler = Session; fn connection_made(&mut self, sender: ws::Sender) -> Self::Handler { self.session_id += 1; - self.stats.as_ref().map(|stats| stats.open_session(self.session_id)); + if let Some(executor) = self.stats.as_ref() { + executor.open_session(self.session_id) + } let active = Arc::new(atomic::AtomicBool::new(true)); Session { @@ -338,7 +368,8 @@ impl> ws::Factory for Factory { } } -fn header_is_allowed(allowed: &Option>, header: Option<&[u8]>) -> bool where +fn header_is_allowed(allowed: &Option>, header: Option<&[u8]>) -> bool +where T: Pattern, { let header = header.map(std::str::from_utf8); @@ -352,22 +383,21 @@ fn header_is_allowed(allowed: &Option>, header: Option<&[u8]>) -> bool (Some(Ok(val)), Some(values)) => { for v in values { if v.matches(val) { - return true + return true; } } false - }, + } // Disallow in other cases _ => false, } } - fn forbidden(title: &str, message: &str) -> ws::Response { let mut forbidden = ws::Response::new(403, "Forbidden", format!("{}\n{}\n", title, message).into_bytes()); { let headers = forbidden.headers_mut(); - headers.push(("Connection".to_owned(), "close".as_bytes().to_vec())); + headers.push(("Connection".to_owned(), b"close".to_vec())); } forbidden } diff --git a/ws/src/tests.rs b/ws/src/tests.rs index 2667ebd31..0d2928460 100644 --- a/ws/src/tests.rs +++ b/ws/src/tests.rs @@ -1,18 +1,17 @@ use std::io::{Read, Write}; -use std::net::{TcpStream, Ipv4Addr}; +use std::net::{Ipv4Addr, TcpStream}; use std::str::Lines; -use std::sync::{mpsc, Arc}; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{mpsc, Arc}; use std::thread; use std::time::Duration; -use core; -use core::futures::Future; -use server_utils::hosts::DomainsValidation; -use ws; +use crate::core; +use crate::server_utils::hosts::DomainsValidation; +use crate::ws; -use server::Server; -use server_builder::ServerBuilder; +use crate::server::Server; +use crate::server_builder::ServerBuilder; struct Response { status: String, @@ -28,9 +27,9 @@ impl Response { let body = Self::read_block(&mut lines); Response { - status: status, + status, _headers: headers, - body: body, + body, } } @@ -43,7 +42,7 @@ impl Response { Some(v) => { block.push_str(v); block.push_str("\n"); - }, + } } } block @@ -61,29 +60,27 @@ fn request(server: Server, request: &str) -> Response { } fn serve(port: u16) -> (Server, Arc) { - use std::time::Duration; - use core::futures::sync::oneshot; - + use futures::{channel::oneshot, future, FutureExt}; let pending = Arc::new(AtomicUsize::new(0)); let counter = pending.clone(); let mut io = core::IoHandler::default(); - io.add_method("hello", |_params: core::Params| Ok(core::Value::String("world".into()))); + io.add_sync_method("hello", |_params: core::Params| Ok(core::Value::String("world".into()))); io.add_method("hello_async", |_params: core::Params| { - core::futures::finished(core::Value::String("world".into())) + future::ready(Ok(core::Value::String("world".into()))) }); io.add_method("record_pending", move |_params: core::Params| { counter.fetch_add(1, Ordering::SeqCst); let (send, recv) = oneshot::channel(); - ::std::thread::spawn(move || { - ::std::thread::sleep(Duration::from_millis(500)); + std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(500)); let _ = send.send(()); }); let counter = counter.clone(); - recv.then(move |res| { + recv.map(move |res| { if res.is_ok() { counter.fetch_sub(1, Ordering::SeqCst); } @@ 
-113,15 +110,16 @@ fn should_disallow_not_whitelisted_origins() { let (server, _) = serve(30001); // when - let response = request(server, + let response = request( + server, "\ - GET / HTTP/1.1\r\n\ - Host: 127.0.0.1:30001\r\n\ - Origin: http://test.io\r\n\ - Connection: close\r\n\ - \r\n\ - I shouldn't be read.\r\n\ - " + GET / HTTP/1.1\r\n\ + Host: 127.0.0.1:30001\r\n\ + Origin: http://test.io\r\n\ + Connection: close\r\n\ + \r\n\ + I shouldn't be read.\r\n\ + ", ); // then @@ -134,14 +132,15 @@ fn should_disallow_not_whitelisted_hosts() { let (server, _) = serve(30002); // when - let response = request(server, + let response = request( + server, "\ - GET / HTTP/1.1\r\n\ - Host: myhost:30002\r\n\ - Connection: close\r\n\ - \r\n\ - I shouldn't be read.\r\n\ - " + GET / HTTP/1.1\r\n\ + Host: myhost:30002\r\n\ + Connection: close\r\n\ + \r\n\ + I shouldn't be read.\r\n\ + ", ); // then @@ -154,15 +153,16 @@ fn should_allow_whitelisted_origins() { let (server, _) = serve(30003); // when - let response = request(server, + let response = request( + server, "\ - GET / HTTP/1.1\r\n\ - Host: 127.0.0.1:30003\r\n\ - Origin: https://parity.io\r\n\ - Connection: close\r\n\ - \r\n\ - {}\r\n\ - " + GET / HTTP/1.1\r\n\ + Host: 127.0.0.1:30003\r\n\ + Origin: https://parity.io\r\n\ + Connection: close\r\n\ + \r\n\ + {}\r\n\ + ", ); // then @@ -175,15 +175,16 @@ fn should_intercept_in_middleware() { let (server, _) = serve(30004); // when - let response = request(server, + let response = request( + server, "\ - GET /intercepted HTTP/1.1\r\n\ - Host: 127.0.0.1:30004\r\n\ - Origin: https://parity.io\r\n\ - Connection: close\r\n\ - \r\n\ - {}\r\n\ - " + GET /intercepted HTTP/1.1\r\n\ + Host: 127.0.0.1:30004\r\n\ + Origin: https://parity.io\r\n\ + Connection: close\r\n\ + \r\n\ + {}\r\n\ + ", ); // then @@ -193,22 +194,25 @@ fn should_intercept_in_middleware() { #[test] fn drop_session_should_cancel() { - use ws::{connect, CloseCode}; + use crate::ws::{connect, CloseCode}; // given let (_server, incomplete) = serve(30005); // when connect("ws://127.0.0.1:30005", |out| { - out.send(r#"{"jsonrpc":"2.0", "method":"record_pending", "params": [], "id": 1}"#).unwrap(); + out.send(r#"{"jsonrpc":"2.0", "method":"record_pending", "params": [], "id": 1}"#) + .unwrap(); let incomplete = incomplete.clone(); - move |_| { + move |_| { assert_eq!(incomplete.load(Ordering::SeqCst), 0); - out.send(r#"{"jsonrpc":"2.0", "method":"record_pending", "params": [], "id": 2}"#).unwrap(); + out.send(r#"{"jsonrpc":"2.0", "method":"record_pending", "params": [], "id": 2}"#) + .unwrap(); out.close(CloseCode::Normal) } - }).unwrap(); + }) + .unwrap(); // then let mut i = 0; @@ -217,7 +221,6 @@ fn drop_session_should_cancel() { i += 1; } assert_eq!(incomplete.load(Ordering::SeqCst), 1); - } #[test] @@ -240,6 +243,8 @@ fn close_handle_makes_wait_return() { thread::sleep(Duration::from_secs(1)); close_handle.close(); - let result = rx.recv_timeout(Duration::from_secs(10)).expect("Expected server to close"); + let result = rx + .recv_timeout(Duration::from_secs(10)) + .expect("Expected server to close"); assert!(result.is_ok()); }
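Taken together, the `ws` changes add two buffer-capacity knobs to `ServerBuilder` and expose a `Broadcaster` handle on `Server`. A minimal sketch combining them, assuming the crate layout from this diff; the bind address, method name, and broadcast payload are placeholders:

```rust
use jsonrpc_ws_server::jsonrpc_core::{IoHandler, Value};
use jsonrpc_ws_server::ServerBuilder;

fn main() {
    let mut io = IoHandler::default();
    io.add_sync_method("say_hello", |_params| Ok(Value::String("hello".to_owned())));

    let server = ServerBuilder::new(io)
        // Existing limits...
        .max_connections(100)
        .max_payload(5 * 1024 * 1024)
        // ...plus the two new hard caps on per-connection buffers.
        .max_in_buffer_capacity(10 * 1024 * 1024)
        .max_out_buffer_capacity(10 * 1024 * 1024)
        .start(&"127.0.0.1:3030".parse().unwrap())
        .expect("WebSocket server should start");

    // `broadcaster()` returns a cloneable handle that pushes a message to
    // every open connection.
    let broadcaster = server.broadcaster();
    broadcaster
        .send(r#"{"jsonrpc":"2.0","method":"server_started","params":[]}"#)
        .expect("broadcast should succeed");

    server.wait().expect("server thread should exit cleanly");
}
```

The values passed above match the defaults asserted in `config_usize_vals_have_correct_defaults`, so they only need to be set explicitly when tuning.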