diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index 84b1a660d8..758cca7330 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -182,76 +182,3 @@ pub trait Executor<'c>: Send + Debug + Sized { where 'c: 'e; } - -/// A type that may be executed against a database connection. -/// -/// Implemented for the following: -/// -/// * [`&str`](std::str) -/// * [`Query`](super::query::Query) -/// -pub trait Execute<'q, DB: Database>: Send + Sized { - /// Gets the SQL that will be executed. - fn sql(&self) -> &'q str; - - /// Gets the previously cached statement, if available. - fn statement(&self) -> Option<&DB::Statement<'q>>; - - /// Returns the arguments to be bound against the query string. - /// - /// Returning `Ok(None)` for `Arguments` indicates to use a "simple" query protocol and to not - /// prepare the query. Returning `Ok(Some(Default::default()))` is an empty arguments object that - /// will be prepared (and cached) before execution. - /// - /// Returns `Err` if encoding any of the arguments failed. - fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError>; - - /// Returns `true` if the statement should be cached. - fn persistent(&self) -> bool; -} - -// NOTE: `Execute` is explicitly not implemented for String and &String to make it slightly more -// involved to write `conn.execute(format!("SELECT {val}"))` -impl<'q, DB: Database> Execute<'q, DB> for &'q str { - #[inline] - fn sql(&self) -> &'q str { - self - } - - #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { - None - } - - #[inline] - fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - Ok(None) - } - - #[inline] - fn persistent(&self) -> bool { - true - } -} - -impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option<::Arguments<'q>>) { - #[inline] - fn sql(&self) -> &'q str { - self.0 - } - - #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { - None - } - - #[inline] - fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - Ok(self.1.take()) - } - - #[inline] - fn persistent(&self) -> bool { - true - } -} diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index cc0122c907..80631f6a4d 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -74,6 +74,7 @@ pub mod net; pub mod query_as; pub mod query_builder; pub mod query_scalar; +pub mod sql_str; pub mod raw_sql; pub mod row; diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index 60f509c342..7bd424bf16 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -8,15 +8,16 @@ use crate::arguments::{Arguments, IntoArguments}; use crate::database::{Database, HasStatementCache}; use crate::encode::Encode; use crate::error::{BoxDynError, Error}; -use crate::executor::{Execute, Executor}; +use crate::executor::{Executor}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::statement::Statement; use crate::types::Type; /// A single SQL query as a prepared statement. Returned by [`query()`]. #[must_use = "query must be executed to affect database"] -pub struct Query<'q, DB: Database, A> { - pub(crate) statement: Either<&'q str, &'q DB::Statement<'q>>, - pub(crate) arguments: Option>, +pub struct Query<'q, 'a, DB: Database> { + pub(crate) statement: Either>, + pub(crate) arguments: Option, BoxDynError>>, pub(crate) database: PhantomData, pub(crate) persistent: bool, } @@ -33,46 +34,32 @@ pub struct Query<'q, DB: Database, A> { /// before `.try_map()`. This is also to prevent adding superfluous binds to the result of /// `query!()` et al. #[must_use = "query must be executed to affect database"] -pub struct Map<'q, DB: Database, F, A> { - inner: Query<'q, DB, A>, +pub struct Map<'q, 'a, DB: Database, F> { + inner: Query<'q, 'a, DB>, mapper: F, } -impl<'q, DB, A> Execute<'q, DB> for Query<'q, DB, A> +impl<'q, 'a, DB> Query<'q, 'a, DB> where - DB: Database, - A: Send + IntoArguments<'q, DB>, + DB: Database + HasStatementCache, { - #[inline] - fn sql(&self) -> &'q str { - match self.statement { - Either::Right(statement) => statement.sql(), - Either::Left(sql) => sql, - } - } - - fn statement(&self) -> Option<&DB::Statement<'q>> { - match self.statement { - Either::Right(statement) => Some(statement), - Either::Left(_) => None, - } - } - - #[inline] - fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - self.arguments - .take() - .transpose() - .map(|option| option.map(IntoArguments::into_arguments)) - } - - #[inline] - fn persistent(&self) -> bool { - self.persistent + /// If `true`, the statement will get prepared once and cached to the + /// connection's statement cache. + /// + /// If queried once with the flag set to `true`, all subsequent queries + /// matching the one with the flag will use the cached statement until the + /// cache is cleared. + /// + /// If `false`, the prepared statement will be closed after execution. + /// + /// Default: `true`. + pub fn persistent(mut self, value: bool) -> Self { + self.persistent = value; + self } } -impl<'q, DB: Database> Query<'q, DB, ::Arguments<'q>> { +impl<'q, 'a, DB: Database> Query<'q, 'a, DB> { /// Bind a value for use with this SQL query. /// /// If the number of times this is called does not match the number of bind parameters that @@ -120,31 +107,10 @@ impl<'q, DB: Database> Query<'q, DB, ::Arguments<'q>> { } } -impl<'q, DB, A> Query<'q, DB, A> -where - DB: Database + HasStatementCache, -{ - /// If `true`, the statement will get prepared once and cached to the - /// connection's statement cache. - /// - /// If queried once with the flag set to `true`, all subsequent queries - /// matching the one with the flag will use the cached statement until the - /// cache is cleared. - /// - /// If `false`, the prepared statement will be closed after execution. - /// - /// Default: `true`. - pub fn persistent(mut self, value: bool) -> Self { - self.persistent = value; - self - } -} - -impl<'q, DB, A: Send> Query<'q, DB, A> -where - DB: Database, - A: 'q + IntoArguments<'q, DB>, -{ +impl<'q, 'a, DB> Query<'q, 'a, DB> + where + DB: Database, + { /// Map each row in the result to another type. /// /// See [`try_map`](Query::try_map) for a fallible version of this method. @@ -155,7 +121,7 @@ where pub fn map( self, mut f: F, - ) -> Map<'q, DB, impl FnMut(DB::Row) -> Result + Send, A> + ) -> Map<'q, 'a, DB, impl FnMut(DB::Row) -> Result + Send> where F: FnMut(DB::Row) -> O + Send, O: Unpin, @@ -168,7 +134,7 @@ where /// The [`query_as`](super::query_as::query_as) method will construct a mapped query using /// a [`FromRow`](super::from_row::FromRow) implementation. #[inline] - pub fn try_map(self, f: F) -> Map<'q, DB, F, A> + pub fn try_map(self, f: F) -> Map<'q, 'a, DB, F> where F: FnMut(DB::Row) -> Result + Send, O: Unpin, @@ -184,33 +150,17 @@ where pub async fn execute<'e, 'c: 'e, E>(self, executor: E) -> Result where 'q: 'e, - A: 'e, + 'a: 'e, E: Executor<'c, Database = DB>, { executor.execute(self).await } - /// Execute multiple queries and return the rows affected from each query, in a stream. - #[inline] - #[deprecated = "Only the SQLite driver supports multiple statements in one prepared statement and that behavior is deprecated. Use `sqlx::raw_sql()` instead. See https://github.com/launchbadge/sqlx/issues/3108 for discussion."] - pub async fn execute_many<'e, 'c: 'e, E>( - self, - executor: E, - ) -> BoxStream<'e, Result> - where - 'q: 'e, - A: 'e, - E: Executor<'c, Database = DB>, - { - executor.execute_many(self) - } - /// Execute the query and return the generated results as a stream. #[inline] pub fn fetch<'e, 'c: 'e, E>(self, executor: E) -> BoxStream<'e, Result> where 'q: 'e, - A: 'e, E: Executor<'c, Database = DB>, { executor.fetch(self) @@ -229,7 +179,6 @@ where ) -> BoxStream<'e, Result, Error>> where 'q: 'e, - A: 'e, E: Executor<'c, Database = DB>, { executor.fetch_many(self) @@ -246,7 +195,6 @@ where pub async fn fetch_all<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> where 'q: 'e, - A: 'e, E: Executor<'c, Database = DB>, { executor.fetch_all(self).await @@ -268,7 +216,6 @@ where pub async fn fetch_one<'e, 'c: 'e, E>(self, executor: E) -> Result where 'q: 'e, - A: 'e, E: Executor<'c, Database = DB>, { executor.fetch_one(self).await @@ -290,45 +237,51 @@ where pub async fn fetch_optional<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> where 'q: 'e, - A: 'e, E: Executor<'c, Database = DB>, { executor.fetch_optional(self).await } } -impl<'q, DB, F: Send, A: Send> Execute<'q, DB> for Map<'q, DB, F, A> +#[doc(hidden)] +impl<'q, 'a, DB> Query<'q, 'a, DB> where DB: Database, - A: IntoArguments<'q, DB>, { #[inline] fn sql(&self) -> &'q str { - self.inner.sql() + match &self.statement { + Either::Right(statement) => statement.sql(), + Either::Left(sql) => sql.as_str(), + } } - #[inline] fn statement(&self) -> Option<&DB::Statement<'q>> { - self.inner.statement() + match self.statement { + Either::Right(statement) => Some(statement), + Either::Left(_) => None, + } } #[inline] fn take_arguments(&mut self) -> Result::Arguments<'q>>, BoxDynError> { - self.inner.take_arguments() + self.arguments + .take() + .transpose() + .map(|option| option.map(IntoArguments::into_arguments)) } #[inline] - fn persistent(&self) -> bool { - self.inner.arguments.is_some() + fn is_persistent(&self) -> bool { + self.persistent } } -impl<'q, DB, F, O, A> Map<'q, DB, F, A> +impl<'q, 'a, DB, F, O> Map<'q, 'a, DB, F> where DB: Database, F: FnMut(DB::Row) -> Result + Send, O: Send + Unpin, - A: 'q + Send + IntoArguments<'q, DB>, { /// Map each row in the result to another type. /// @@ -340,7 +293,7 @@ where pub fn map( self, mut g: G, - ) -> Map<'q, DB, impl FnMut(DB::Row) -> Result + Send, A> + ) -> Map<'q, 'a, DB, impl FnMut(DB::Row) -> Result + Send> where G: FnMut(O) -> P + Send, P: Unpin, @@ -356,7 +309,7 @@ where pub fn try_map( self, mut g: G, - ) -> Map<'q, DB, impl FnMut(DB::Row) -> Result + Send, A> + ) -> Map<'q, 'a, DB, impl FnMut(DB::Row) -> Result + Send> where G: FnMut(O) -> Result + Send, P: Unpin, @@ -497,9 +450,9 @@ where } /// Execute a single SQL query as a prepared statement (explicitly created). -pub fn query_statement<'q, DB>( +pub fn query_statement<'q, 'a, DB>( statement: &'q DB::Statement<'q>, -) -> Query<'q, DB, ::Arguments<'_>> +) -> Query<'q, 'a, DB> where DB: Database, { @@ -512,17 +465,17 @@ where } /// Execute a single SQL query as a prepared statement (explicitly created), with the given arguments. -pub fn query_statement_with<'q, DB, A>( +pub fn query_statement_with<'q, 'a, DB, A>( statement: &'q DB::Statement<'q>, arguments: A, -) -> Query<'q, DB, A> +) -> Query<'q, 'a, DB> where DB: Database, A: IntoArguments<'q, DB>, { Query { database: PhantomData, - arguments: Some(Ok(arguments)), + arguments: Some(Ok(arguments.into_arguments())), statement: Either::Right(statement), persistent: true, } @@ -652,14 +605,15 @@ where /// /// As an additional benefit, query parameters are usually sent in a compact binary encoding instead of a human-readable /// text encoding, which saves bandwidth. -pub fn query(sql: &str) -> Query<'_, DB, ::Arguments<'_>> +pub fn query<'a, DB, SQL>(sql: SQL) -> Query<'static, 'a, DB> where DB: Database, + SQL: SqlSafeStr, { Query { database: PhantomData, arguments: Some(Ok(Default::default())), - statement: Either::Left(sql), + statement: Either::Left(sql.into_sql_str()), persistent: true, } } @@ -667,27 +621,27 @@ where /// Execute a SQL query as a prepared statement (transparently cached), with the given arguments. /// /// See [`query()`][query] for details, such as supported syntax. -pub fn query_with<'q, DB, A>(sql: &'q str, arguments: A) -> Query<'q, DB, A> +pub fn query_with<'a, DB, SQL, A>(sql: SQL, arguments: A) -> Query<'static, 'a, DB> where DB: Database, - A: IntoArguments<'q, DB>, + A: IntoArguments<'a, DB>, { query_with_result(sql, Ok(arguments)) } /// Same as [`query_with`] but is initialized with a Result of arguments instead -pub fn query_with_result<'q, DB, A>( - sql: &'q str, - arguments: Result, -) -> Query<'q, DB, A> +pub fn query_with_result<'a, DB, SQL>( + sql: SQL, + arguments: Result, BoxDynError>, +) -> Query<'static, 'a, DB> where DB: Database, - A: IntoArguments<'q, DB>, + SQL: SqlSafeStr, { Query { database: PhantomData, arguments: Some(arguments), - statement: Either::Left(sql), + statement: Either::Left(sql.into_sql_str()), persistent: true, } } diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index fbc7fab55b..a2984c3f27 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -11,6 +11,7 @@ use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; use crate::from_row::FromRow; use crate::query::{query, query_statement, query_statement_with, query_with_result, Query}; +use crate::sql_str::SqlSafeStr; use crate::types::Type; /// A single SQL query as a prepared statement, mapping results using [`FromRow`]. @@ -339,7 +340,7 @@ where /// /// ``` #[inline] -pub fn query_as<'q, DB, O>(sql: &'q str) -> QueryAs<'q, DB, O, ::Arguments<'q>> +pub fn query_as<'q, DB, SQL, O>(sql: SQL) -> QueryAs<'q, DB, O, ::Arguments<'q>> where DB: Database, O: for<'r> FromRow<'r, DB::Row>, @@ -357,9 +358,10 @@ where /// /// For details about type mapping from [`FromRow`], see [`query_as()`]. #[inline] -pub fn query_as_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryAs<'q, DB, O, A> +pub fn query_as_with<'q, DB, SQL, O, A>(sql: SQL, arguments: A) -> QueryAs<'q, DB, O, A> where DB: Database, + SQL: SqlSafeStr<'q>, A: IntoArguments<'q, DB>, O: for<'r> FromRow<'r, DB::Row>, { diff --git a/sqlx-core/src/query_builder.rs b/sqlx-core/src/query_builder.rs index b071ff8a47..28f06c3787 100644 --- a/sqlx-core/src/query_builder.rs +++ b/sqlx-core/src/query_builder.rs @@ -3,7 +3,7 @@ use std::fmt::Display; use std::fmt::Write; use std::marker::PhantomData; - +use std::sync::Arc; use crate::arguments::{Arguments, IntoArguments}; use crate::database::Database; use crate::encode::Encode; @@ -13,6 +13,7 @@ use crate::query_as::QueryAs; use crate::query_scalar::QueryScalar; use crate::types::Type; use crate::Either; +use crate::sql_str::AssertSqlSafe; /// A builder type for constructing queries at runtime. /// @@ -25,7 +26,9 @@ pub struct QueryBuilder<'args, DB> where DB: Database, { - query: String, + // Using `Arc` allows us to share the query string allocation with the database driver. + // It's only copied if the driver retains ownership after execution. + query: Arc, init_len: usize, arguments: Option<::Arguments<'args>>, } @@ -85,6 +88,16 @@ where "QueryBuilder must be reset before reuse after `.build()`" ); } + + fn query_mut(&mut self) -> &mut String { + assert!( + self.arguments.is_some(), + "QueryBuilder must be reset before reuse after `.build()`" + ); + + Arc::get_mut(&mut self.query) + .expect("BUG: query must not be shared at this point in time") + } /// Append a SQL fragment to the query. /// @@ -116,7 +129,7 @@ where pub fn push(&mut self, sql: impl Display) -> &mut Self { self.sanity_check(); - write!(self.query, "{sql}").expect("error formatting `sql`"); + write!(self.query_mut(), "{sql}").expect("error formatting `sql`"); self } @@ -158,7 +171,7 @@ where arguments.add(value).expect("Failed to add argument"); arguments - .format_placeholder(&mut self.query) + .format_placeholder(self.query_mut()) .expect("error in format_placeholder"); self @@ -448,12 +461,10 @@ where pub fn build(&mut self) -> Query<'_, DB, ::Arguments<'args>> { self.sanity_check(); - Query { - statement: Either::Left(&self.query), - arguments: self.arguments.take().map(Ok), - database: PhantomData, - persistent: true, - } + crate::query::query_with( + AssertSqlSafe(&self.query), + self.arguments.take().expect("BUG: just ran sanity_check") + ) } /// Produce an executable query from this builder. diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index c131adcca3..c097b9b18b 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -4,6 +4,7 @@ use futures_util::{StreamExt, TryFutureExt, TryStreamExt}; use crate::arguments::IntoArguments; use crate::database::{Database, HasStatementCache}; +use crate::decode::Decode; use crate::encode::Encode; use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; @@ -11,6 +12,7 @@ use crate::from_row::FromRow; use crate::query_as::{ query_as, query_as_with_result, query_statement_as, query_statement_as_with, QueryAs, }; +use crate::sql_str::SqlSafeStr; use crate::types::Type; /// A single SQL query as a prepared statement which extracts only the first column of each row. @@ -318,12 +320,13 @@ where /// # } /// ``` #[inline] -pub fn query_scalar<'q, DB, O>( - sql: &'q str, +pub fn query_scalar<'q, DB, SQL, O>( + sql: SQL, ) -> QueryScalar<'q, DB, O, ::Arguments<'q>> where DB: Database, - (O,): for<'r> FromRow<'r, DB::Row>, + SQL: SqlSafeStr<'q>, + O: Type + for<'r> Decode<'r, DB>, { QueryScalar { inner: query_as(sql), @@ -337,11 +340,12 @@ where /// /// For details about prepared statements and allowed SQL syntax, see [`query()`][crate::query::query]. #[inline] -pub fn query_scalar_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryScalar<'q, DB, O, A> +pub fn query_scalar_with<'q, DB, SQL, O, A>(sql: SQL, arguments: A) -> QueryScalar<'q, DB, O, A> where DB: Database, + SQL: SqlSafeStr<'q>, A: IntoArguments<'q, DB>, - (O,): for<'r> FromRow<'r, DB::Row>, + O: Type + for<'r> Decode<'r, DB>, { query_scalar_with_result(sql, Ok(arguments)) } @@ -368,7 +372,7 @@ pub fn query_statement_scalar<'q, DB, O>( ) -> QueryScalar<'q, DB, O, ::Arguments<'_>> where DB: Database, - (O,): for<'r> FromRow<'r, DB::Row>, + O: Type + for<'r> Decode<'r, DB>, { QueryScalar { inner: query_statement_as(statement), @@ -383,7 +387,7 @@ pub fn query_statement_scalar_with<'q, DB, O, A>( where DB: Database, A: IntoArguments<'q, DB>, - (O,): for<'r> FromRow<'r, DB::Row>, + O: Type + for<'r> Decode<'r, DB>, { QueryScalar { inner: query_statement_as_with(statement, arguments), diff --git a/sqlx-core/src/sql_str.rs b/sqlx-core/src/sql_str.rs new file mode 100644 index 0000000000..58ed23d6d3 --- /dev/null +++ b/sqlx-core/src/sql_str.rs @@ -0,0 +1,182 @@ +use std::borrow::Borrow; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +/// A SQL string that is safe to execute on a database connection. +/// +/// A "safe" SQL string is one that is unlikely to contain a [SQL injection vulnerability][injection]. +/// +/// In practice, this means a string type that is unlikely to contain dynamic data or user input. +/// +/// `&'static str` is the only string type that satisfies the requirements of this trait +/// (ignoring [`String::leak()`] which has niche use-cases) and so is the only string type that +/// natively implements this trait by default. +/// +/// For other string types, use [`AssertSqlSafe`] to assert this property. +/// This is the only intended way to pass an owned `String` to [`query()`] and its related functions +/// as well as [`raw_sql()`]. +/// +/// The maintainers of SQLx take no responsibility for any data leaks or loss resulting from misuse +/// of this API. +/// +/// ### Motivation +/// This is designed to act as a speed bump against naively using `format!()` to add dynamic data +/// or user input to a query, which is a classic vector for SQL injection as SQLx does not +/// provide any sort of escaping or sanitization (which would have to be specially implemented +/// for each database flavor/locale). +/// +/// The recommended way to incorporate dynamic data or user input in a query is to use +/// bind parameters, which requires the query to execute as a prepared statement. +/// See [`query()`] for details. +/// +/// This trait and [`AssertSqlSafe`] are intentionally analogous to +/// [`std::panic::UnwindSafe`] and [`std::panic::AssertUnwindSafe`], respectively. +/// +/// [injection]: https://en.wikipedia.org/wiki/SQL_injection +/// [`query()`]: crate::query::query +/// [`raw_sql()`]: crate::raw_sql::raw_sql +pub trait SqlSafeStr { + /// Convert `self` to a [`SqlStr`]. + fn into_sql_str(self) -> SqlStr; +} + +impl SqlSafeStr for &'static str { + #[inline] + + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Static(self)) + } +} + +/// Assert that a query string is safe to execute on a database connection. +/// +/// Using this API means that **you** have made sure that the string contents do not contain a +/// [SQL injection vulnerability][injection]. It means that, if the string was constructed +/// dynamically, and/or from user input, you have taken care to sanitize the input yourself. +/// SQLx does not provide any sort of sanitization; the design of SQLx prefers the use +/// of prepared statements for dynamic input. +/// +/// The maintainers of SQLx take no responsibility for any data leaks or loss resulting from misuse +/// of this API. **Use at your own risk.** +/// +/// Note that `&'static str` implements [`SqlSafeStr`] directly and so does not need to be wrapped +/// with this type. +/// +/// [injection]: https://en.wikipedia.org/wiki/SQL_injection +pub struct AssertSqlSafe(pub T); + +/// Note: copies the string. +/// +/// It is recommended to pass one of the supported owned string types instead. +impl<'a> SqlSafeStr for AssertSqlSafe<&'a str> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Arced(self.0.into())) + } +} +impl SqlSafeStr for AssertSqlSafe { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Owned(self.0)) + } +} + +impl SqlSafeStr for AssertSqlSafe> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Boxed(self.0)) + } +} + +// Note: this is not implemented for `Rc` because it would make `QueryString: !Send`. +impl SqlSafeStr for AssertSqlSafe> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Arced(self.into())) + } +} + +/// A SQL string that is ready to execute on a database connection. +/// +/// This is essentially `Cow<'static, str>` but which can be constructed from additional types +/// without copying. +/// +/// See [`SqlSafeStr`] for details. +#[derive(Debug)] +pub struct SqlStr(Repr); + +#[derive(Debug)] +enum Repr { + /// We need a variant to memoize when we already have a static string, so we don't copy it. + Static(&'static str), + /// Thanks to the new niche in `String`, this doesn't increase the size beyond 3 words. + /// We essentially get all these variants for free. + Owned(String), + Boxed(Box), + Arced(Arc), + /// Allows for dynamic shared ownership with `query_builder`. + ArcString(Arc), +} + +impl Clone for SqlStr { + fn clone(&self) -> Self { + Self(match &self.0 { + Repr::Static(s) => Repr::Static(s), + Repr::Arced(s) => Repr::Arced(s.clone()), + _ => Repr::Arced(self.as_str().into()), + }) + } +} + +impl SqlSafeStr for SqlStr { + #[inline] + fn into_sql_str(self) -> SqlStr { + self + } +} + +impl SqlStr { + pub(crate) fn from_arc_string(arc: Arc) -> Self { + SqlStr(Repr::ArcString(arc)) + } + + /// Borrow the inner query string. + #[inline] + pub fn as_str(&self) -> &str { + match &self.0 { + Repr::Static(s) => s, + Repr::Owned(s) => s, + Repr::Boxed(s) => s, + Repr::Arced(s) => s, + Repr::ArcString(s) => s, + } + } +} + +impl AsRef for SqlStr { + #[inline] + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl Borrow for SqlStr { + #[inline] + fn borrow(&self) -> &str { + self.as_str() + } +} + +impl PartialEq for SqlStr where T: AsRef { + fn eq(&self, other: &T) -> bool { + self.as_str() == other.as_ref() + } +} + +impl Eq for SqlStr {} + +impl Hash for SqlStr { + fn hash(&self, state: &mut H) { + self.as_str().hash(state) + } +}