Skip to content

Implement support for diesel::Instrumentation for all provided connec… #164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ jobs:

- name: Set environment variables
shell: bash
if: matrix.rust == 'nightly'
if: matrix.rust != 'nightly'
run: |
echo "RUSTFLAGS=--cap-lints=warn" >> $GITHUB_ENV
echo "RUSTFLAGS=-D warnings" >> $GITHUB_ENV
echo "RUSTDOCFLAGS=-D warnings" >> $GITHUB_ENV

- uses: ilammy/setup-nasm@v1
if: matrix.backend == 'postgres' && matrix.os == 'windows-2019'
Expand Down Expand Up @@ -234,7 +235,7 @@ jobs:
find ~/.cargo/registry -iname "*clippy.toml" -delete

- name: Run clippy
run: cargo +stable clippy --all
run: cargo +stable clippy --all --all-features

- name: Check formating
run: cargo +stable fmt --all -- --check
Expand Down
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/
## [Unreleased]

* Added type `diesel_async::pooled_connection::mobc::PooledConnection`
* MySQL/MariaDB now use `CLIENT_FOUND_ROWS` capability to allow consistent behavior with PostgreSQL regarding return value of UPDATe commands.
* MySQL/MariaDB now use `CLIENT_FOUND_ROWS` capability to allow consistent behaviour with PostgreSQL regarding return value of UPDATe commands.
* The minimal supported rust version is now 1.78.0
* Add a `SyncConnectionWrapper` type that turns a sync connection into an async one. This enables SQLite support for diesel-async
* Add support for `diesel::connection::Instrumentation` to support logging and other instrumentation for any of the provided connection impls.

## [0.4.1] - 2023-09-01

Expand Down
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ cfg-if = "1"
chrono = "0.4"
diesel = { version = "2.2.0", default-features = false, features = ["chrono"] }
diesel_migrations = "2.2.0"
assert_matches = "1.0.1"

[features]
default = []
Expand Down Expand Up @@ -83,8 +84,8 @@ features = [
"r2d2",
]
no-default-features = true
rustc-args = ["--cfg", "doc_cfg"]
rustdoc-args = ["--cfg", "doc_cfg"]
rustc-args = ["--cfg", "docsrs"]
rustdoc-args = ["--cfg", "docsrs"]

[workspace]
members = [
Expand Down
16 changes: 5 additions & 11 deletions src/async_connection_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ mod implementation {
pub struct AsyncConnectionWrapper<C, B> {
inner: C,
runtime: B,
instrumentation: Option<Box<dyn Instrumentation>>,
}

impl<C, B> From<C> for AsyncConnectionWrapper<C, B>
Expand All @@ -119,7 +118,6 @@ mod implementation {
Self {
inner,
runtime: B::get_runtime(),
instrumentation: None,
}
}
}
Expand Down Expand Up @@ -150,11 +148,7 @@ mod implementation {
let runtime = B::get_runtime();
let f = C::establish(database_url);
let inner = runtime.block_on(f)?;
Ok(Self {
inner,
runtime,
instrumentation: None,
})
Ok(Self { inner, runtime })
}

fn execute_returning_count<T>(&mut self, source: &T) -> diesel::QueryResult<usize>
Expand All @@ -165,18 +159,18 @@ mod implementation {
self.runtime.block_on(f)
}

fn transaction_state(
&mut self,
fn transaction_state(
&mut self,
) -> &mut <Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{
self.inner.transaction_state()
}

fn instrumentation(&mut self) -> &mut dyn Instrumentation {
&mut self.instrumentation
self.inner.instrumentation()
}

fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
self.instrumentation = Some(Box::new(instrumentation));
self.inner.set_instrumentation(instrumentation);
}
}

Expand Down
9 changes: 8 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#![cfg_attr(doc_cfg, feature(doc_cfg, doc_auto_cfg))]
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
//! Diesel-async provides async variants of diesel related query functionality
//!
//! diesel-async is an extension to diesel itself. It is designed to be used together
Expand Down Expand Up @@ -69,6 +69,7 @@
#![warn(missing_docs)]

use diesel::backend::Backend;
use diesel::connection::Instrumentation;
use diesel::query_builder::{AsQuery, QueryFragment, QueryId};
use diesel::result::Error;
use diesel::row::Row;
Expand Down Expand Up @@ -347,4 +348,10 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send {
fn _silence_lint_on_execute_future(_: Self::ExecuteFuture<'_, '_>) {}
#[doc(hidden)]
fn _silence_lint_on_load_future(_: Self::LoadFuture<'_, '_>) {}

#[doc(hidden)]
fn instrumentation(&mut self) -> &mut dyn Instrumentation;

/// Set a specific [`Instrumentation`] implementation for this connection
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation);
}
144 changes: 107 additions & 37 deletions src/mysql/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::stmt_cache::{PrepareCallback, StmtCache};
use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
use diesel::connection::statement_cache::{MaybeCached, StatementCacheKey};
use diesel::connection::Instrumentation;
use diesel::connection::InstrumentationEvent;
use diesel::connection::StrQueryHelper;
use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
use diesel::query_builder::QueryBuilder;
use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
Expand All @@ -26,12 +29,28 @@ pub struct AsyncMysqlConnection {
conn: mysql_async::Conn,
stmt_cache: StmtCache<Mysql, Statement>,
transaction_manager: AnsiTransactionManager,
instrumentation: std::sync::Mutex<Option<Box<dyn Instrumentation>>>,
}

#[async_trait::async_trait]
impl SimpleAsyncConnection for AsyncMysqlConnection {
async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
Ok(self.conn.query_drop(query).await.map_err(ErrorHelper)?)
self.instrumentation()
.on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
query,
)));
let result = self
.conn
.query_drop(query)
.await
.map_err(ErrorHelper)
.map_err(Into::into);
self.instrumentation()
.on_connection_event(InstrumentationEvent::finish_query(
&StrQueryHelper::new(query),
result.as_ref().err(),
));
result
}
}

Expand All @@ -53,20 +72,18 @@ impl AsyncConnection for AsyncMysqlConnection {
type TransactionManager = AnsiTransactionManager;

async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
let opts = Opts::from_url(database_url)
.map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
let builder = OptsBuilder::from_opts(opts)
.init(CONNECTION_SETUP_QUERIES.to_vec())
.stmt_cache_size(0) // We have our own cache
.client_found_rows(true); // This allows a consistent behavior between MariaDB/MySQL and PostgreSQL (and is already set in `diesel`)

let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;

Ok(AsyncMysqlConnection {
conn,
stmt_cache: StmtCache::new(),
transaction_manager: AnsiTransactionManager::default(),
})
let mut instrumentation = diesel::connection::get_default_instrumentation();
instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
database_url,
));
let r = Self::establish_connection_inner(database_url).await;
instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
database_url,
r.as_ref().err(),
));
let mut conn = r?;
conn.instrumentation = std::sync::Mutex::new(instrumentation);
Ok(conn)
}

fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
Expand All @@ -80,7 +97,10 @@ impl AsyncConnection for AsyncMysqlConnection {
let stmt_for_exec = match stmt {
MaybeCached::Cached(ref s) => (*s).clone(),
MaybeCached::CannotCache(ref s) => s.clone(),
_ => todo!(),
_ => unreachable!(
"Diesel has only two variants here at the time of writing.\n\
If you ever see this error message please open in issue in the diesel-async issue tracker"
),
};

let (tx, rx) = futures_channel::mpsc::channel(0);
Expand Down Expand Up @@ -152,6 +172,19 @@ impl AsyncConnection for AsyncMysqlConnection {
fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
&mut self.transaction_manager
}

fn instrumentation(&mut self) -> &mut dyn Instrumentation {
self.instrumentation
.get_mut()
.unwrap_or_else(|p| p.into_inner())
}

fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
*self
.instrumentation
.get_mut()
.unwrap_or_else(|p| p.into_inner()) = Some(Box::new(instrumentation));
}
}

#[inline(always)]
Expand Down Expand Up @@ -195,6 +228,9 @@ impl AsyncMysqlConnection {
conn,
stmt_cache: StmtCache::new(),
transaction_manager: AnsiTransactionManager::default(),
instrumentation: std::sync::Mutex::new(
diesel::connection::get_default_instrumentation(),
),
};

for stmt in CONNECTION_SETUP_QUERIES {
Expand All @@ -219,6 +255,10 @@ impl AsyncMysqlConnection {
T: QueryFragment<Mysql> + QueryId,
F: Future<Output = QueryResult<R>> + Send,
{
self.instrumentation()
.on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
&query,
)));
let mut bind_collector = RawBytesBindCollector::<Mysql>::new();
let bind_collector = query
.collect_binds(&mut bind_collector, &mut (), &Mysql)
Expand All @@ -228,6 +268,7 @@ impl AsyncMysqlConnection {
ref mut conn,
ref mut stmt_cache,
ref mut transaction_manager,
ref mut instrumentation,
..
} = self;

Expand All @@ -242,28 +283,37 @@ impl AsyncMysqlConnection {
} = bind_collector?;
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
let sql = sql?;
let cache_key = if let Some(query_id) = query_id {
StatementCacheKey::Type(query_id)
} else {
StatementCacheKey::Sql {
sql: sql.clone(),
bind_types: metadata.clone(),
}
let inner = async {
let cache_key = if let Some(query_id) = query_id {
StatementCacheKey::Type(query_id)
} else {
StatementCacheKey::Sql {
sql: sql.clone(),
bind_types: metadata.clone(),
}
};

let (stmt, conn) = stmt_cache
.cached_prepared_statement(
cache_key,
sql.clone(),
is_safe_to_cache_prepared,
&metadata,
conn,
instrumentation,
)
.await?;
callback(conn, stmt, ToSqlHelper { metadata, binds }).await
};

let (stmt, conn) = stmt_cache
.cached_prepared_statement(
cache_key,
sql,
is_safe_to_cache_prepared,
&metadata,
conn,
)
.await?;
update_transaction_manager_status(
callback(conn, stmt, ToSqlHelper { metadata, binds }).await,
transaction_manager,
)
let r = update_transaction_manager_status(inner.await, transaction_manager);
instrumentation
.get_mut()
.unwrap_or_else(|p| p.into_inner())
.on_connection_event(InstrumentationEvent::finish_query(
&StrQueryHelper::new(&sql),
r.as_ref().err(),
));
r
}
.boxed()
}
Expand Down Expand Up @@ -300,6 +350,26 @@ impl AsyncMysqlConnection {

Ok(())
}

async fn establish_connection_inner(
database_url: &str,
) -> Result<AsyncMysqlConnection, ConnectionError> {
let opts = Opts::from_url(database_url)
.map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
let builder = OptsBuilder::from_opts(opts)
.init(CONNECTION_SETUP_QUERIES.to_vec())
.stmt_cache_size(0) // We have our own cache
.client_found_rows(true); // This allows a consistent behavior between MariaDB/MySQL and PostgreSQL (and is already set in `diesel`)

let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;

Ok(AsyncMysqlConnection {
conn,
stmt_cache: StmtCache::new(),
transaction_manager: AnsiTransactionManager::default(),
instrumentation: std::sync::Mutex::new(None),
})
}
}

#[cfg(any(
Expand Down
Loading
Loading