use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};
use core::cell::Cell;
use sqlx_core::{
    database::Database,
    encode::{Encode, IsNull},
    error::BoxDynError,
    types::Type,
};

// Not exported, but `pub` because it is returned by `bind_iter` in the extension trait.
//
// The iterator is stored in a `Cell<Option<I>>` so that `encode_by_ref`, which only
// receives `&self`, can still take ownership of it; encoding the same `PgBindIter`
// twice will panic.
pub struct PgBindIter<I>(Cell<Option<I>>);

/// Iterator extension trait enabling iterators to encode themselves as Postgres arrays.
///
/// Thanks to the blanket impl of `PgHasArrayType` for all references, the iterators
/// can yield borrowed items instead of needing to clone or copy them, and encoding
/// still works.
///
/// Previously, the example below would have needed three separate arrays: one pass
/// per field to collect the items into an array, and then another pass over each
/// array to encode it.
///
/// With `bind_iter`, each field is iterated over exactly once and nothing is
/// collected, which reduces both time and memory usage and allows much more
/// flexibility in the underlying collection.
///
/// ```rust,no_run
/// # async fn test_bind_iter() -> Result<(), sqlx::error::BoxDynError> {
/// # use sqlx::types::chrono::{DateTime, Utc};
/// # use sqlx::Connection;
/// # fn people() -> &'static [Person] {
/// #     &[]
/// # }
/// # let mut conn = <sqlx::Postgres as sqlx::Database>::Connection::connect("dummyurl").await?;
/// use sqlx::postgres::PgBindIterExt;
///
/// #[derive(sqlx::FromRow)]
/// struct Person {
///     id: i64,
///     name: String,
///     birthdate: DateTime<Utc>,
/// }
///
/// # let people: &[Person] = people();
/// sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
///     .bind(people.iter().map(|p| p.id).bind_iter())
///     .bind(people.iter().map(|p| &p.name).bind_iter())
///     .bind(people.iter().map(|p| &p.birthdate).bind_iter())
///     .execute(&mut conn)
///     .await?;
///
/// # Ok(())
/// # }
/// ```
pub trait PgBindIterExt: Iterator + Sized {
    fn bind_iter(self) -> PgBindIter<Self>;
}

impl<I: Iterator + Sized> PgBindIterExt for I {
    fn bind_iter(self) -> PgBindIter<I> {
        PgBindIter(Cell::new(Some(self)))
    }
}

impl<I> Type<Postgres> for PgBindIter<I>
where
    I: Iterator,
    <I as Iterator>::Item: Type<Postgres> + PgHasArrayType,
{
    fn type_info() -> <Postgres as Database>::TypeInfo {
        <I as Iterator>::Item::array_type_info()
    }
    fn compatible(ty: &PgTypeInfo) -> bool {
        <I as Iterator>::Item::array_compatible(ty)
    }
}

impl<'q, I> PgBindIter<I>
where
    I: Iterator,
    <I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
{
    fn encode_inner(
        // need ownership to iterate
        mut iter: I,
        buf: &mut PgArgumentBuffer,
    ) -> Result<IsNull, BoxDynError> {
        let lower_size_hint = iter.size_hint().0;
        let first = iter.next();
        let type_info = first
            .as_ref()
            .and_then(Encode::produces)
            .unwrap_or_else(<I as Iterator>::Item::type_info);

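        // Write the header of Postgres's binary array format: number of dimensions,
        // flags, element type OID, then (for the single dimension) an element count
        // and lower bound, followed by each element as a length-prefixed value.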
        buf.extend(&1_i32.to_be_bytes()); // number of dimensions
        buf.extend(&0_i32.to_be_bytes()); // flags

        match type_info.0 {
            PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
            PgType::DeclareArrayOf(array) => buf.patch_array_type(array),

            ty => {
                buf.extend(&ty.oid().0.to_be_bytes());
            }
        }

        let len_start = buf.len();
        buf.extend(0_i32.to_be_bytes()); // len (unknown so far)
        buf.extend(1_i32.to_be_bytes()); // lower bound

        match first {
            Some(first) => buf.encode(first)?,
            None => return Ok(IsNull::No),
        }

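        // The element count is transmitted as an `i32`, so at most `i32::MAX` elements
        // can be encoded; the first element is already written, hence `i32::MAX - 1` below.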
        let mut count = 1_i32;
        const MAX: usize = i32::MAX as usize - 1;

        for value in (&mut iter).take(MAX) {
            buf.encode(value)?;
            count += 1;
        }

        const OVERFLOW: usize = i32::MAX as usize + 1;
        if iter.next().is_some() {
            let iter_size = std::cmp::max(lower_size_hint, OVERFLOW);
            return Err(format!("encoded iterator is too large for Postgres: {iter_size}").into());
        }

        // set the length now that we know what it is.
        buf[len_start..(len_start + 4)].copy_from_slice(&count.to_be_bytes());

        Ok(IsNull::No)
    }
}

impl<'q, I> Encode<'q, Postgres> for PgBindIter<I>
where
    I: Iterator,
    <I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
{
    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
        Self::encode_inner(self.0.take().expect("PgBindIter is only used once"), buf)
    }
    fn encode(self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError>
    where
        Self: Sized,
    {
        Self::encode_inner(
            self.0.into_inner().expect("PgBindIter is only used once"),
            buf,
        )
    }
}
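
Because `bind_iter` accepts any iterator whose items encode to the same Postgres type, the source does not have to be a slice or `Vec`. Below is a minimal sketch of the same `unnest` pattern driven by a `BTreeMap`, assuming a hypothetical `tag_count(tag text primary key, count int8 not null)` table; the function and table names are illustrative only.

```rust,no_run
use sqlx::postgres::PgBindIterExt;
use sqlx::PgConnection;
use std::collections::BTreeMap;

// Upsert every (tag, count) pair in one round trip, without collecting the map
// into intermediate Vecs: keys are bound by reference, values are copied.
async fn upsert_tag_counts(
    conn: &mut PgConnection,
    counts: &BTreeMap<String, i64>,
) -> Result<(), sqlx::Error> {
    sqlx::query(
        "insert into tag_count(tag, count) select * from unnest($1, $2) \
         on conflict (tag) do update set count = excluded.count",
    )
    .bind(counts.keys().bind_iter())
    .bind(counts.values().copied().bind_iter())
    .execute(conn)
    .await?;
    Ok(())
}
```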