Skip to content

Commit 97de034

Browse files
authored
Add PgBindIter for encoding and use it as the implementation encoding &[T] (#3651)
* Add PgBindIter for encoding and use it as the implementation encoding &[T] * Implement suggestions from review * Add docs to PgBindIter and test to ensure it works for owned and borrowed types * Use extension trait for iterators to allow code to flow better. Make struct private. Don't reference unneeded generic T. Make doc tests compile. * Fix doc function * Fix doc test to actually compile * Use Cell<Option<I>> instead of Clone bound
1 parent 60f67db commit 97de034

File tree

4 files changed

+215
-29
lines changed

4 files changed

+215
-29
lines changed

sqlx-postgres/src/bind_iter.rs

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};
2+
use core::cell::Cell;
3+
use sqlx_core::{
4+
database::Database,
5+
encode::{Encode, IsNull},
6+
error::BoxDynError,
7+
types::Type,
8+
};
9+
10+
// not exported but pub because it is used in the extension trait
11+
pub struct PgBindIter<I>(Cell<Option<I>>);
12+
13+
/// Iterator extension trait enabling iterators to encode arrays in Postgres.
14+
///
15+
/// Because of the blanket impl of `PgHasArrayType` for all references
16+
/// we can borrow instead of needing to clone or copy in the iterators
17+
/// and it still works
18+
///
19+
/// Previously, 3 separate arrays would be needed in this example which
20+
/// requires iterating 3 times to collect items into the array and then
21+
/// iterating over them again to encode.
22+
///
23+
/// This now requires only iterating over the array once for each field
24+
/// while using less memory giving both speed and memory usage improvements
25+
/// along with allowing much more flexibility in the underlying collection.
26+
///
27+
/// ```rust,no_run
28+
/// # async fn test_bind_iter() -> Result<(), sqlx::error::BoxDynError> {
29+
/// # use sqlx::types::chrono::{DateTime, Utc};
30+
/// # use sqlx::Connection;
31+
/// # fn people() -> &'static [Person] {
32+
/// # &[]
33+
/// # }
34+
/// # let mut conn = <sqlx::Postgres as sqlx::Database>::Connection::connect("dummyurl").await?;
35+
/// use sqlx::postgres::PgBindIterExt;
36+
///
37+
/// #[derive(sqlx::FromRow)]
38+
/// struct Person {
39+
/// id: i64,
40+
/// name: String,
41+
/// birthdate: DateTime<Utc>,
42+
/// }
43+
///
44+
/// # let people: &[Person] = people();
45+
/// sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
46+
/// .bind(people.iter().map(|p| p.id).bind_iter())
47+
/// .bind(people.iter().map(|p| &p.name).bind_iter())
48+
/// .bind(people.iter().map(|p| &p.birthdate).bind_iter())
49+
/// .execute(&mut conn)
50+
/// .await?;
51+
///
52+
/// # Ok(())
53+
/// # }
54+
/// ```
55+
pub trait PgBindIterExt: Iterator + Sized {
56+
fn bind_iter(self) -> PgBindIter<Self>;
57+
}
58+
59+
impl<I: Iterator + Sized> PgBindIterExt for I {
60+
fn bind_iter(self) -> PgBindIter<I> {
61+
PgBindIter(Cell::new(Some(self)))
62+
}
63+
}
64+
65+
impl<I> Type<Postgres> for PgBindIter<I>
66+
where
67+
I: Iterator,
68+
<I as Iterator>::Item: Type<Postgres> + PgHasArrayType,
69+
{
70+
fn type_info() -> <Postgres as Database>::TypeInfo {
71+
<I as Iterator>::Item::array_type_info()
72+
}
73+
fn compatible(ty: &PgTypeInfo) -> bool {
74+
<I as Iterator>::Item::array_compatible(ty)
75+
}
76+
}
77+
78+
impl<'q, I> PgBindIter<I>
79+
where
80+
I: Iterator,
81+
<I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
82+
{
83+
fn encode_inner(
84+
// need ownership to iterate
85+
mut iter: I,
86+
buf: &mut PgArgumentBuffer,
87+
) -> Result<IsNull, BoxDynError> {
88+
let lower_size_hint = iter.size_hint().0;
89+
let first = iter.next();
90+
let type_info = first
91+
.as_ref()
92+
.and_then(Encode::produces)
93+
.unwrap_or_else(<I as Iterator>::Item::type_info);
94+
95+
buf.extend(&1_i32.to_be_bytes()); // number of dimensions
96+
buf.extend(&0_i32.to_be_bytes()); // flags
97+
98+
match type_info.0 {
99+
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
100+
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
101+
102+
ty => {
103+
buf.extend(&ty.oid().0.to_be_bytes());
104+
}
105+
}
106+
107+
let len_start = buf.len();
108+
buf.extend(0_i32.to_be_bytes()); // len (unknown so far)
109+
buf.extend(1_i32.to_be_bytes()); // lower bound
110+
111+
match first {
112+
Some(first) => buf.encode(first)?,
113+
None => return Ok(IsNull::No),
114+
}
115+
116+
let mut count = 1_i32;
117+
const MAX: usize = i32::MAX as usize - 1;
118+
119+
for value in (&mut iter).take(MAX) {
120+
buf.encode(value)?;
121+
count += 1;
122+
}
123+
124+
const OVERFLOW: usize = i32::MAX as usize + 1;
125+
if iter.next().is_some() {
126+
let iter_size = std::cmp::max(lower_size_hint, OVERFLOW);
127+
return Err(format!("encoded iterator is too large for Postgres: {iter_size}").into());
128+
}
129+
130+
// set the length now that we know what it is.
131+
buf[len_start..(len_start + 4)].copy_from_slice(&count.to_be_bytes());
132+
133+
Ok(IsNull::No)
134+
}
135+
}
136+
137+
impl<'q, I> Encode<'q, Postgres> for PgBindIter<I>
138+
where
139+
I: Iterator,
140+
<I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
141+
{
142+
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
143+
Self::encode_inner(self.0.take().expect("PgBindIter is only used once"), buf)
144+
}
145+
fn encode(self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError>
146+
where
147+
Self: Sized,
148+
{
149+
Self::encode_inner(
150+
self.0.into_inner().expect("PgBindIter is only used once"),
151+
buf,
152+
)
153+
}
154+
}

sqlx-postgres/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::executor::Executor;
77

88
mod advisory_lock;
99
mod arguments;
10+
mod bind_iter;
1011
mod column;
1112
mod connection;
1213
mod copy;
@@ -47,6 +48,7 @@ pub(crate) use sqlx_core::driver_prelude::*;
4748

4849
pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey};
4950
pub use arguments::{PgArgumentBuffer, PgArguments};
51+
pub use bind_iter::PgBindIterExt;
5052
pub use column::PgColumn;
5153
pub use connection::PgConnection;
5254
pub use copy::{PgCopyIn, PgPoolCopyExt};

sqlx-postgres/src/types/array.rs

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use std::borrow::Cow;
55
use crate::decode::Decode;
66
use crate::encode::{Encode, IsNull};
77
use crate::error::BoxDynError;
8-
use crate::type_info::PgType;
98
use crate::types::Oid;
109
use crate::types::Type;
1110
use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
@@ -156,39 +155,14 @@ where
156155
T: Encode<'q, Postgres> + Type<Postgres>,
157156
{
158157
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
159-
let type_info = self
160-
.first()
161-
.and_then(Encode::produces)
162-
.unwrap_or_else(T::type_info);
163-
164-
buf.extend(&1_i32.to_be_bytes()); // number of dimensions
165-
buf.extend(&0_i32.to_be_bytes()); // flags
166-
167-
// element type
168-
match type_info.0 {
169-
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
170-
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
171-
172-
ty => {
173-
buf.extend(&ty.oid().0.to_be_bytes());
174-
}
175-
}
176-
177-
let array_len = i32::try_from(self.len()).map_err(|_| {
158+
// do the length check early to avoid doing unnecessary work
159+
i32::try_from(self.len()).map_err(|_| {
178160
format!(
179161
"encoded array length is too large for Postgres: {}",
180162
self.len()
181163
)
182164
})?;
183-
184-
buf.extend(array_len.to_be_bytes()); // len
185-
buf.extend(&1_i32.to_be_bytes()); // lower bound
186-
187-
for element in self.iter() {
188-
buf.encode(element)?;
189-
}
190-
191-
Ok(IsNull::No)
165+
crate::PgBindIterExt::bind_iter(self.iter()).encode(buf)
192166
}
193167
}
194168

tests/postgres/postgres.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,6 +2069,62 @@ async fn test_issue_3052() {
20692069
}
20702070

20712071
#[sqlx_macros::test]
2072+
async fn test_bind_iter() -> anyhow::Result<()> {
2073+
use sqlx::postgres::PgBindIterExt;
2074+
use sqlx::types::chrono::{DateTime, Utc};
2075+
2076+
let mut conn = new::<Postgres>().await?;
2077+
2078+
#[derive(sqlx::FromRow, PartialEq, Debug)]
2079+
struct Person {
2080+
id: i64,
2081+
name: String,
2082+
birthdate: DateTime<Utc>,
2083+
}
2084+
2085+
let people: Vec<Person> = vec![
2086+
Person {
2087+
id: 1,
2088+
name: "Alice".into(),
2089+
birthdate: "1984-01-01T00:00:00Z".parse().unwrap(),
2090+
},
2091+
Person {
2092+
id: 2,
2093+
name: "Bob".into(),
2094+
birthdate: "2000-01-01T00:00:00Z".parse().unwrap(),
2095+
},
2096+
];
2097+
2098+
sqlx::query(
2099+
r#"
2100+
create temporary table person(
2101+
id int8 primary key,
2102+
name text not null,
2103+
birthdate timestamptz not null
2104+
)"#,
2105+
)
2106+
.execute(&mut conn)
2107+
.await?;
2108+
2109+
let rows_affected =
2110+
sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
2111+
// owned value
2112+
.bind(people.iter().map(|p| p.id).bind_iter())
2113+
// borrowed value
2114+
.bind(people.iter().map(|p| &p.name).bind_iter())
2115+
.bind(people.iter().map(|p| &p.birthdate).bind_iter())
2116+
.execute(&mut conn)
2117+
.await?
2118+
.rows_affected();
2119+
assert_eq!(rows_affected, 2);
2120+
2121+
let p_query = sqlx::query_as::<_, Person>("select * from person order by id")
2122+
.fetch_all(&mut conn)
2123+
.await?;
2124+
2125+
assert_eq!(people, p_query);
2126+
Ok(())
2127+
}
20722128
async fn test_pg_copy_chunked() -> anyhow::Result<()> {
20732129
let mut conn = new::<Postgres>().await?;
20742130

0 commit comments

Comments
 (0)