Skip to content

Reduce amount of code in AsyncPgConnection functions that have query type as generic parameter #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 72 additions & 33 deletions src/pg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,10 @@ impl AsyncConnection for AsyncPgConnection {
T: AsQuery + 'query,
T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
{
let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
let query = source.as_query();
let load_future = self.with_prepared_statement(query, |conn, stmt, binds| async move {
let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;

Ok(res
.map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
.map_ok(PgRow::new)
.boxed())
});
let load_future = self.with_prepared_statement(query, load_prepared);

drive_future(connection_future, load_future).boxed()
self.run_with_connection_future(load_future)
}

fn execute_returning_count<'conn, 'query, T>(
Expand All @@ -177,19 +169,8 @@ impl AsyncConnection for AsyncPgConnection {
where
T: QueryFragment<Self::Backend> + QueryId + 'query,
{
let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
let execute = self.with_prepared_statement(source, |conn, stmt, binds| async move {
let binds = binds
.iter()
.map(|b| b as &(dyn ToSql + Sync))
.collect::<Vec<_>>();

let res = tokio_postgres::Client::execute(&conn, &stmt, &binds as &[_])
.await
.map_err(ErrorHelper)?;
Ok(res as usize)
});
drive_future(connection_future, execute).boxed()
let execute = self.with_prepared_statement(source, execute_prepared);
self.run_with_connection_future(execute)
}

fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
Expand All @@ -212,6 +193,35 @@ impl Drop for AsyncPgConnection {
}
}

async fn load_prepared(
conn: Arc<tokio_postgres::Client>,
stmt: Statement,
binds: Vec<ToSqlHelper>,
) -> QueryResult<BoxStream<'static, QueryResult<PgRow>>> {
let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;

Ok(res
.map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
.map_ok(PgRow::new)
.boxed())
}

async fn execute_prepared(
conn: Arc<tokio_postgres::Client>,
stmt: Statement,
binds: Vec<ToSqlHelper>,
) -> QueryResult<usize> {
let binds = binds
.iter()
.map(|b| b as &(dyn ToSql + Sync))
.collect::<Vec<_>>();

let res = tokio_postgres::Client::execute(&conn, &stmt, &binds as &[_])
.await
.map_err(ErrorHelper)?;
Ok(res as usize)
}

#[inline(always)]
fn update_transaction_manager_status<T>(
query_result: QueryResult<T>,
Expand Down Expand Up @@ -335,14 +345,22 @@ impl AsyncPgConnection {
Ok(())
}

fn run_with_connection_future<'a, R: 'a>(
&self,
future: impl Future<Output = QueryResult<R>> + Send + 'a,
) -> BoxFuture<'a, QueryResult<R>> {
let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
drive_future(connection_future, future).boxed()
}

fn with_prepared_statement<'a, T, F, R>(
&mut self,
query: T,
callback: impl FnOnce(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F + Send + 'a,
callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
) -> BoxFuture<'a, QueryResult<R>>
where
T: QueryFragment<diesel::pg::Pg> + QueryId,
F: Future<Output = QueryResult<R>> + Send,
F: Future<Output = QueryResult<R>> + Send + 'a,
R: Send,
{
// we explicilty descruct the query here before going into the async block
Expand All @@ -352,14 +370,9 @@ impl AsyncPgConnection {
// which both are `Send`.
// We also collect the query id (essentially an integer) and the safe_to_cache flag here
// so there is no need to even access the query in the async block below
let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&diesel::pg::Pg);
let mut query_builder = PgQueryBuilder::default();
let sql = query
.to_sql(&mut query_builder, &Pg)
.map(|_| query_builder.finish());

let mut bind_collector = RawBytesBindCollector::<diesel::pg::Pg>::new();
let query_id = T::query_id();

// we don't resolve custom types here yet, we do that later
// in the async block below as we might need to perform lookup
Expand All @@ -368,16 +381,42 @@ impl AsyncPgConnection {
// We apply this workaround to prevent requiring all the diesel
// serialization code to beeing async
let mut metadata_lookup = PgAsyncMetadataLookup::new();
let collect_bind_result =
query.collect_binds(&mut bind_collector, &mut metadata_lookup, &Pg);

// The code that doesn't need the `T` generic parameter is in a separate function to reduce LLVM IR lines
self.with_prepared_statement_after_sql_built(
callback,
query.is_safe_to_cache_prepared(&Pg),
T::query_id(),
query.to_sql(&mut query_builder, &Pg),
query.collect_binds(&mut bind_collector, &mut metadata_lookup, &Pg),
query_builder,
bind_collector,
metadata_lookup,
)
}

fn with_prepared_statement_after_sql_built<'a, F, R>(
&mut self,
callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
is_safe_to_cache_prepared: QueryResult<bool>,
query_id: Option<std::any::TypeId>,
to_sql_result: QueryResult<()>,
collect_bind_result: QueryResult<()>,
query_builder: PgQueryBuilder,
mut bind_collector: RawBytesBindCollector<Pg>,
metadata_lookup: PgAsyncMetadataLookup,
) -> BoxFuture<'a, QueryResult<R>>
where
F: Future<Output = QueryResult<R>> + Send + 'a,
R: Send,
{
let raw_connection = self.conn.clone();
let stmt_cache = self.stmt_cache.clone();
let metadata_cache = self.metadata_cache.clone();
let tm = self.transaction_state.clone();

async move {
let sql = sql?;
let sql = to_sql_result.map(|_| query_builder.finish())?;
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
collect_bind_result?;
// Check whether we need to resolve some types at all
Expand Down
Loading