diff --git a/hrana-client-proto/src/batch.rs b/hrana-client-proto/src/batch.rs new file mode 100644 index 0000000..377d1cd --- /dev/null +++ b/hrana-client-proto/src/batch.rs @@ -0,0 +1,82 @@ +use serde::{Deserialize, Serialize}; + +use crate::{stmt::StmtResult, Error, Stmt}; + +/// The request type for a `Batch` +#[derive(Serialize, Debug)] +pub struct BatchReq { + pub stream_id: i32, + pub batch: Batch, +} + +/// A `Batch` allows to group multiple `Stmt` to be executed together. Execution of the steps in +/// the batch is controlled with the `BatchCond`. +#[derive(Serialize, Debug, Default)] +pub struct Batch { + steps: Vec, +} + +impl Batch { + /// Creates a new, empty batch + pub fn new() -> Self { + Self { steps: Vec::new() } + } + + /// Adds a step to the batch, with an optional condition. + /// + /// The `condition` specifies whether of not this batch step should be executed, based on the + /// execution of previous steps. + /// If `condition` is `None`, then the step is executed unconditionally. + /// + /// ## Example: + /// ```ignore + /// let mut batch = Batch::new(); + /// // A step that is executed unconditionally + /// batch.step(None, Stmt::new("SELECT * FROM user", true)); + /// // A step that is executed only if the first step was an error + /// batch.step(Some(BatchCond::Error { step: 0 }), Stmt::new("ROLLBACK", false)); + /// ``` + pub fn step(&mut self, condition: Option, stmt: Stmt) { + self.steps.push(BatchStep { condition, stmt }); + } +} + +/// An execution step in a `Batch` +#[derive(Serialize, Debug)] +pub struct BatchStep { + condition: Option, + stmt: Stmt, +} + +/// Represents a condition determining whether a batch step should be executed. +#[derive(Serialize, Debug)] +pub enum BatchCond { + /// Evaluated to true is step `step` was a success + Ok { step: i32 }, + /// Evaluated to true is step `step` was a error + Error { step: i32 }, + /// Evaluates to the negation of `cond` + Not { cond: Box }, + /// Evaluates to the conjunction of `conds` + And { conds: Vec }, + /// Evaluates to the disjunction of `conds` + Or { conds: Vec }, +} + +/// The response type for a `BatchReq` request +#[derive(Deserialize, Debug)] +pub struct BatchResp { + pub result: BatchResult, +} + +/// The result of the execution of a batch. +/// For a given step `i`, is possible for both `step_results[i]` and `step_errors[i]` to be +/// `None`, if that step was skipped because of a negative condition. +/// For a given step `i`, it is not possible for `step_results[i]` and `step_errors[i]` to be `Some` at the same time +#[derive(Deserialize, Debug)] +pub struct BatchResult { + /// The success result for all steps in the batch + pub step_results: Vec>, + /// The error result for all steps in the batch + pub step_errors: Vec>, +} diff --git a/hrana-client-proto/src/lib.rs b/hrana-client-proto/src/lib.rs index b0e51b6..9998b1e 100644 --- a/hrana-client-proto/src/lib.rs +++ b/hrana-client-proto/src/lib.rs @@ -1,199 +1,19 @@ -//! Messages in the Hrana protocol. -//! -//! Please consult the Hrana specification in the `docs/` directory for more information. +///! # hrana protocol +///! This crate defines the hrana protocal types. The hrana protocol is documented [here](https://github.com/libsql/sqld/blob/main/docs/HRANA_SPEC.md). use std::fmt; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; -#[derive(Serialize, Debug)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum ClientMsg { - Hello { jwt: Option }, - Request { request_id: i32, request: Request }, -} - -#[derive(Deserialize, Debug)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum ServerMsg { - HelloOk {}, - HelloError { error: Error }, - ResponseOk { request_id: i32, response: Response }, - ResponseError { request_id: i32, error: Error }, -} - -#[derive(Serialize, Debug)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum Request { - OpenStream(OpenStreamReq), - CloseStream(CloseStreamReq), - Execute(ExecuteReq), - Batch(BatchReq), -} - -#[derive(Deserialize, Debug)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum Response { - OpenStream(OpenStreamResp), - CloseStream(CloseStreamResp), - Execute(ExecuteResp), - Batch(BatchResp), -} - -#[derive(Serialize, Debug)] -pub struct OpenStreamReq { - pub stream_id: i32, -} - -#[derive(Deserialize, Debug)] -pub struct OpenStreamResp {} - -#[derive(Serialize, Debug)] -pub struct CloseStreamReq { - pub stream_id: i32, -} - -#[derive(Deserialize, Debug)] -pub struct CloseStreamResp {} - -#[derive(Serialize, Debug)] -pub struct ExecuteReq { - pub stream_id: i32, - pub stmt: Stmt, -} - -#[derive(Deserialize, Debug)] -pub struct ExecuteResp { - pub result: StmtResult, -} - -#[derive(Serialize, Debug)] -pub struct Stmt { - pub sql: String, - #[serde(default)] - pub args: Vec, - #[serde(default)] - pub named_args: Vec, - pub want_rows: bool, -} - -impl Stmt { - pub fn new(sql: impl Into, want_rows: bool) -> Self { - let sql = sql.into(); - Self { - sql, - want_rows, - named_args: Vec::new(), - args: Vec::new(), - } - } - - pub fn bind(&mut self, val: Value) { - self.args.push(val); - } - - pub fn bind_named(&mut self, name: String, value: Value) { - self.named_args.push(NamedArg { name, value }); - } -} - -#[derive(Serialize, Debug)] -pub struct NamedArg { - pub name: String, - pub value: Value, -} - -#[derive(Deserialize, Clone, Debug)] -pub struct StmtResult { - pub cols: Vec, - pub rows: Vec>, - pub affected_row_count: u64, - #[serde(with = "option_i64_as_str")] - pub last_insert_rowid: Option, -} - -#[derive(Deserialize, Clone, Debug)] -pub struct Col { - pub name: Option, -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum Value { - Null, - Integer { - #[serde(with = "i64_as_str")] - value: i64, - }, - Float { - value: f64, - }, - Text { - value: String, - }, - Blob { - #[serde(with = "bytes_as_base64", rename = "base64")] - value: Vec, - }, -} - -#[derive(Serialize, Debug)] -pub struct BatchReq { - pub stream_id: i32, - pub batch: Batch, -} - -#[derive(Serialize, Debug, Default)] -pub struct Batch { - steps: Vec, -} - -impl Batch { - pub fn new() -> Self { - Self { steps: Vec::new() } - } - - pub fn step(&mut self, condition: Option, stmt: Stmt) { - self.steps.push(BatchStep { condition, stmt }); - } -} - -#[derive(Serialize, Debug)] -pub struct BatchStep { - condition: Option, - stmt: Stmt, -} - -#[derive(Serialize, Debug)] -pub enum BatchCond { - Ok { step: i32 }, - Error { step: i32 }, - Not { cond: Box }, - And { conds: Vec }, - Or { conds: Vec }, -} +mod batch; +mod message; +mod serde_utils; +mod stmt; +mod value; -#[derive(Deserialize, Debug)] -pub struct BatchResp { - pub result: BatchResult, -} - -#[derive(Deserialize, Debug)] -pub struct BatchResult { - pub step_results: Vec>, - pub step_errors: Vec>, -} - -impl From> for Value -where - T: Into, -{ - fn from(value: Option) -> Self { - match value { - None => Self::Null, - Some(t) => t.into(), - } - } -} +pub use batch::*; +pub use message::*; +pub use stmt::*; +pub use value::Value; #[derive(Deserialize, Debug, Clone)] pub struct Error { @@ -207,141 +27,3 @@ impl fmt::Display for Error { } impl std::error::Error for Error {} - -mod i64_as_str { - use serde::{de, ser}; - use serde::{de::Error as _, Serialize as _}; - - pub fn serialize(value: &i64, ser: S) -> Result { - value.to_string().serialize(ser) - } - - pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result { - let str_value = <&'de str as de::Deserialize>::deserialize(de)?; - str_value.parse().map_err(|_| { - D::Error::invalid_value( - de::Unexpected::Str(str_value), - &"decimal integer as a string", - ) - }) - } -} - -mod option_i64_as_str { - use serde::{de, de::Error as _}; - - pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result, D::Error> { - let str_value = as de::Deserialize>::deserialize(de)?; - str_value - .map(|s| { - s.parse().map_err(|_| { - D::Error::invalid_value(de::Unexpected::Str(s), &"decimal integer as a string") - }) - }) - .transpose() - } -} - -mod bytes_as_base64 { - use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine as _}; - use serde::{de, ser}; - use serde::{de::Error as _, Serialize as _}; - - pub fn serialize(value: &Vec, ser: S) -> Result { - STANDARD_NO_PAD.encode(value).serialize(ser) - } - - pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result, D::Error> { - let str_value = <&'de str as de::Deserialize>::deserialize(de)?; - STANDARD_NO_PAD - .decode(str_value.trim_end_matches('=')) - .map_err(|_| { - D::Error::invalid_value( - de::Unexpected::Str(str_value), - &"binary data encoded as base64", - ) - }) - } -} - -impl std::fmt::Display for Value { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Value::Null => write!(f, "null"), - Value::Integer { value: n } => write!(f, "{n}"), - Value::Float { value: d } => write!(f, "{d}"), - Value::Text { value: s } => write!(f, "{}", serde_json::json!(s)), - Value::Blob { value: b } => { - use base64::{prelude::BASE64_STANDARD_NO_PAD, Engine}; - let b = BASE64_STANDARD_NO_PAD.encode(b); - write!(f, "{{\"base64\": {b}}}") - } - } - } -} - -impl From<()> for Value { - fn from(_: ()) -> Value { - Value::Null - } -} - -macro_rules! impl_from_value { - ($typename: ty, $variant: ident) => { - impl From<$typename> for Value { - fn from(t: $typename) -> Value { - Value::$variant { value: t.into() } - } - } - }; -} - -impl_from_value!(String, Text); -impl_from_value!(&String, Text); -impl_from_value!(&str, Text); - -impl_from_value!(i8, Integer); -impl_from_value!(i16, Integer); -impl_from_value!(i32, Integer); -impl_from_value!(i64, Integer); - -impl_from_value!(u8, Integer); -impl_from_value!(u16, Integer); -impl_from_value!(u32, Integer); - -impl_from_value!(f32, Float); -impl_from_value!(f64, Float); - -impl_from_value!(Vec, Blob); - -macro_rules! impl_value_try_from { - ($variant: ident, $typename: ty) => { - impl TryFrom for $typename { - type Error = String; - fn try_from(v: Value) -> Result<$typename, Self::Error> { - match v { - Value::$variant { value: v } => v.try_into().map_err(|e| format!("{e}")), - other => Err(format!( - "cannot transform {other:?} to {}", - stringify!($variant) - )), - } - } - } - }; -} - -impl_value_try_from!(Text, String); - -impl_value_try_from!(Integer, i8); -impl_value_try_from!(Integer, i16); -impl_value_try_from!(Integer, i32); -impl_value_try_from!(Integer, i64); -impl_value_try_from!(Integer, u8); -impl_value_try_from!(Integer, u16); -impl_value_try_from!(Integer, u32); -impl_value_try_from!(Integer, u64); - -impl_value_try_from!(Float, f64); - -impl_value_try_from!(Blob, Vec); diff --git a/hrana-client-proto/src/message.rs b/hrana-client-proto/src/message.rs new file mode 100644 index 0000000..9c3f4f6 --- /dev/null +++ b/hrana-client-proto/src/message.rs @@ -0,0 +1,66 @@ +use serde::{Deserialize, Serialize}; + +use crate::batch::{BatchReq, BatchResp}; +use crate::stmt::StmtResult; +use crate::{Error, Stmt}; + +#[derive(Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ClientMsg { + Hello { jwt: Option }, + Request { request_id: i32, request: Request }, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ServerMsg { + HelloOk {}, + HelloError { error: Error }, + ResponseOk { request_id: i32, response: Response }, + ResponseError { request_id: i32, error: Error }, +} + +#[derive(Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Request { + OpenStream(OpenStreamReq), + CloseStream(CloseStreamReq), + Execute(ExecuteReq), + Batch(BatchReq), +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Response { + OpenStream(OpenStreamResp), + CloseStream(CloseStreamResp), + Execute(ExecuteResp), + Batch(BatchResp), +} + +#[derive(Serialize, Debug)] +pub struct OpenStreamReq { + pub stream_id: i32, +} + +#[derive(Deserialize, Debug)] +pub struct OpenStreamResp {} + +#[derive(Serialize, Debug)] +pub struct CloseStreamReq { + pub stream_id: i32, +} + +#[derive(Deserialize, Debug)] +pub struct CloseStreamResp {} + +#[derive(Serialize, Debug)] +pub struct ExecuteReq { + pub stream_id: i32, + pub stmt: Stmt, +} + +#[derive(Deserialize, Debug)] +pub struct ExecuteResp { + pub result: StmtResult, +} diff --git a/hrana-client-proto/src/serde_utils.rs b/hrana-client-proto/src/serde_utils.rs new file mode 100644 index 0000000..894782a --- /dev/null +++ b/hrana-client-proto/src/serde_utils.rs @@ -0,0 +1,55 @@ +pub mod i64_as_str { + use serde::{de, ser}; + use serde::{de::Error as _, Serialize as _}; + + pub fn serialize(value: &i64, ser: S) -> Result { + value.to_string().serialize(ser) + } + + pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result { + let str_value = <&'de str as de::Deserialize>::deserialize(de)?; + str_value.parse().map_err(|_| { + D::Error::invalid_value( + de::Unexpected::Str(str_value), + &"decimal integer as a string", + ) + }) + } +} + +pub mod option_i64_as_str { + use serde::{de, de::Error as _}; + + pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result, D::Error> { + let str_value = as de::Deserialize>::deserialize(de)?; + str_value + .map(|s| { + s.parse().map_err(|_| { + D::Error::invalid_value(de::Unexpected::Str(s), &"decimal integer as a string") + }) + }) + .transpose() + } +} + +pub mod bytes_as_base64 { + use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine as _}; + use serde::{de, ser}; + use serde::{de::Error as _, Serialize as _}; + + pub fn serialize(value: &Vec, ser: S) -> Result { + STANDARD_NO_PAD.encode(value).serialize(ser) + } + + pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result, D::Error> { + let str_value = <&'de str as de::Deserialize>::deserialize(de)?; + STANDARD_NO_PAD + .decode(str_value.trim_end_matches('=')) + .map_err(|_| { + D::Error::invalid_value( + de::Unexpected::Str(str_value), + &"binary data encoded as base64", + ) + }) + } +} diff --git a/hrana-client-proto/src/stmt.rs b/hrana-client-proto/src/stmt.rs new file mode 100644 index 0000000..1722e72 --- /dev/null +++ b/hrana-client-proto/src/stmt.rs @@ -0,0 +1,82 @@ +use serde::{Deserialize, Serialize}; + +use crate::serde_utils::option_i64_as_str; +use crate::Value; + +/// Represents a SQL statement to be executed over the hrana protocol +#[derive(Serialize, Debug)] +pub struct Stmt { + sql: String, + #[serde(default)] + args: Vec, + #[serde(default)] + named_args: Vec, + want_rows: bool, +} + +impl Stmt { + /// Creates a new statement from a SQL string. + /// `want_rows` determines whether the reponse to this statement should return rows. + pub fn new(sql: impl Into, want_rows: bool) -> Self { + let sql = sql.into(); + Self { + sql, + want_rows, + named_args: Vec::new(), + args: Vec::new(), + } + } + + /// Bind the next positional parameter to this statement. + /// + /// ## Example: + /// + /// ```ignore + /// let mut statement = Statement::new("SELECT * FROM users WHERE username=?", true); + /// statement.bind("adhoc"); + /// ``` + pub fn bind(&mut self, val: impl Into) { + self.args.push(val.into()); + } + + /// Bind a named parameter to this statement. + /// + /// ## Example: + /// + /// ```ignore + /// let mut statement = Statement::new("SELECT * FROM users WHERE username=$username", true); + /// statement.bind("$username", "adhoc" }); + /// ``` + pub fn bind_named(&mut self, name: impl Into, value: impl Into) { + self.named_args.push(NamedArg { + name: name.into(), + value: value.into(), + }); + } +} + +#[derive(Serialize, Debug)] +struct NamedArg { + name: String, + value: Value, +} + +/// Result type for the successful execution of a `Stmt` +#[derive(Deserialize, Clone, Debug)] +pub struct StmtResult { + /// List of column descriptors + pub cols: Vec, + /// List of row values + pub rows: Vec>, + /// Number of rows affected by the query + pub affected_row_count: u64, + #[serde(with = "option_i64_as_str")] + pub last_insert_rowid: Option, +} + +/// A column description +#[derive(Deserialize, Clone, Debug)] +pub struct Col { + /// Name of the column + pub name: Option, +} diff --git a/hrana-client-proto/src/value.rs b/hrana-client-proto/src/value.rs new file mode 100644 index 0000000..75efe5d --- /dev/null +++ b/hrana-client-proto/src/value.rs @@ -0,0 +1,117 @@ +use serde::{Deserialize, Serialize}; + +use crate::serde_utils::{bytes_as_base64, i64_as_str}; + +/// A libsql Value type +#[derive(Serialize, Deserialize, Clone, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Value { + Null, + Integer { + #[serde(with = "i64_as_str")] + value: i64, + }, + Float { + value: f64, + }, + Text { + value: String, + }, + Blob { + #[serde(with = "bytes_as_base64", rename = "base64")] + value: Vec, + }, +} + +impl From> for Value +where + T: Into, +{ + fn from(value: Option) -> Self { + match value { + None => Self::Null, + Some(t) => t.into(), + } + } +} + +impl From<()> for Value { + fn from(_: ()) -> Value { + Value::Null + } +} +macro_rules! impl_from_value { + ($typename: ty, $variant: ident) => { + impl From<$typename> for Value { + fn from(t: $typename) -> Value { + Value::$variant { value: t.into() } + } + } + }; +} + +impl_from_value!(String, Text); +impl_from_value!(&String, Text); +impl_from_value!(&str, Text); + +impl_from_value!(i8, Integer); +impl_from_value!(i16, Integer); +impl_from_value!(i32, Integer); +impl_from_value!(i64, Integer); + +impl_from_value!(u8, Integer); +impl_from_value!(u16, Integer); +impl_from_value!(u32, Integer); + +impl_from_value!(f32, Float); +impl_from_value!(f64, Float); + +impl_from_value!(Vec, Blob); + +macro_rules! impl_value_try_from { + ($variant: ident, $typename: ty) => { + impl TryFrom for $typename { + type Error = String; + fn try_from(v: Value) -> Result<$typename, Self::Error> { + match v { + Value::$variant { value: v } => v.try_into().map_err(|e| format!("{e}")), + other => Err(format!( + "cannot transform {other:?} to {}", + stringify!($variant) + )), + } + } + } + }; +} + +impl_value_try_from!(Text, String); + +impl_value_try_from!(Integer, i8); +impl_value_try_from!(Integer, i16); +impl_value_try_from!(Integer, i32); +impl_value_try_from!(Integer, i64); +impl_value_try_from!(Integer, u8); +impl_value_try_from!(Integer, u16); +impl_value_try_from!(Integer, u32); +impl_value_try_from!(Integer, u64); + +impl_value_try_from!(Float, f64); + +impl_value_try_from!(Blob, Vec); + +impl std::fmt::Display for Value { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Value::Null => write!(f, "null"), + Value::Integer { value: n } => write!(f, "{n}"), + Value::Float { value: d } => write!(f, "{d}"), + Value::Text { value: s } => write!(f, "{}", serde_json::json!(s)), + Value::Blob { value: b } => { + use base64::{prelude::BASE64_STANDARD_NO_PAD, Engine}; + let b = BASE64_STANDARD_NO_PAD.encode(b); + write!(f, "{{\"base64\": {b}}}") + } + } + } +}