Skip to content

Commit 7b4fd89

Browse files
committed
Decimal arrays and ranges
Signed-off-by: itowlson <ivan.towlson@fermyon.com>
1 parent 9cfb7ce commit 7b4fd89

File tree

3 files changed

+151
-16
lines changed

3 files changed

+151
-16
lines changed

crates/factor-outbound-pg/src/client.rs

Lines changed: 131 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Syn
172172
.with_context(|| format!("invalid decimal {v}"))?;
173173
Ok(Box::new(dec))
174174
}
175-
ParameterValue::Range32((lower, upper)) => {
175+
ParameterValue::RangeInt32((lower, upper)) => {
176176
let lbound = lower.map(|(value, kind)| {
177177
postgres_range::RangeBound::new(value, range_bound_kind(kind))
178178
});
@@ -182,7 +182,7 @@ fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Syn
182182
let r = postgres_range::Range::new(lbound, ubound);
183183
Ok(Box::new(r))
184184
}
185-
ParameterValue::Range64((lower, upper)) => {
185+
ParameterValue::RangeInt64((lower, upper)) => {
186186
let lbound = lower.map(|(value, kind)| {
187187
postgres_range::RangeBound::new(value, range_bound_kind(kind))
188188
});
@@ -192,8 +192,48 @@ fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Syn
192192
let r = postgres_range::Range::new(lbound, ubound);
193193
Ok(Box::new(r))
194194
}
195+
ParameterValue::RangeDecimal((lower, upper)) => {
196+
let lbound = match lower {
197+
None => None,
198+
Some((value, kind)) => {
199+
let dec = rust_decimal::Decimal::from_str_exact(value)
200+
.with_context(|| format!("invalid decimal {value}"))?;
201+
let dec = RangeableDecimal(dec);
202+
Some(postgres_range::RangeBound::new(
203+
dec,
204+
range_bound_kind(*kind),
205+
))
206+
}
207+
};
208+
let ubound = match upper {
209+
None => None,
210+
Some((value, kind)) => {
211+
let dec = rust_decimal::Decimal::from_str_exact(value)
212+
.with_context(|| format!("invalid decimal {value}"))?;
213+
let dec = RangeableDecimal(dec);
214+
Some(postgres_range::RangeBound::new(
215+
dec,
216+
range_bound_kind(*kind),
217+
))
218+
}
219+
};
220+
let r = postgres_range::Range::new(lbound, ubound);
221+
Ok(Box::new(r))
222+
}
195223
ParameterValue::ArrayInt32(vs) => Ok(Box::new(vs.to_owned())),
196224
ParameterValue::ArrayInt64(vs) => Ok(Box::new(vs.to_owned())),
225+
ParameterValue::ArrayDecimal(vs) => {
226+
let decs = vs
227+
.iter()
228+
.map(|v| match v {
229+
None => Ok(None),
230+
Some(v) => rust_decimal::Decimal::from_str_exact(v)
231+
.with_context(|| format!("invalid decimal {v}"))
232+
.map(Some),
233+
})
234+
.collect::<anyhow::Result<Vec<_>>>()?;
235+
Ok(Box::new(decs))
236+
}
197237
ParameterValue::ArrayStr(vs) => Ok(Box::new(vs.to_owned())),
198238
ParameterValue::Interval(v) => Ok(Box::new(Interval(*v))),
199239
ParameterValue::DbNull => Ok(Box::new(PgNull)),
@@ -238,11 +278,14 @@ fn convert_data_type(pg_type: &Type) -> DbDataType {
238278
Type::UUID => DbDataType::Uuid,
239279
Type::JSONB => DbDataType::Jsonb,
240280
Type::NUMERIC => DbDataType::Decimal,
241-
Type::INT4_RANGE => DbDataType::Range32,
242-
Type::INT8_RANGE => DbDataType::Range64,
281+
Type::INT4_RANGE => DbDataType::RangeInt32,
282+
Type::INT8_RANGE => DbDataType::RangeInt64,
283+
Type::NUM_RANGE => DbDataType::RangeDecimal,
243284
Type::INT4_ARRAY => DbDataType::ArrayInt32,
244285
Type::INT8_ARRAY => DbDataType::ArrayInt64,
286+
Type::NUMERIC_ARRAY => DbDataType::ArrayDecimal,
245287
Type::TEXT_ARRAY | Type::VARCHAR_ARRAY | Type::BPCHAR_ARRAY => DbDataType::ArrayStr,
288+
Type::INTERVAL => DbDataType::Interval,
246289
_ => {
247290
tracing::debug!("Couldn't convert Postgres type {} to WIT", pg_type.name(),);
248291
DbDataType::Other
@@ -367,7 +410,7 @@ fn convert_entry(row: &Row, index: usize) -> anyhow::Result<DbValue> {
367410
Some(v) => {
368411
let lower = v.lower().map(tuplify_range_bound);
369412
let upper = v.upper().map(tuplify_range_bound);
370-
DbValue::Range32((lower, upper))
413+
DbValue::RangeInt32((lower, upper))
371414
}
372415
None => DbValue::DbNull,
373416
}
@@ -378,7 +421,22 @@ fn convert_entry(row: &Row, index: usize) -> anyhow::Result<DbValue> {
378421
Some(v) => {
379422
let lower = v.lower().map(tuplify_range_bound);
380423
let upper = v.upper().map(tuplify_range_bound);
381-
DbValue::Range64((lower, upper))
424+
DbValue::RangeInt64((lower, upper))
425+
}
426+
None => DbValue::DbNull,
427+
}
428+
}
429+
&Type::NUM_RANGE => {
430+
let value: Option<postgres_range::Range<RangeableDecimal>> = row.try_get(index)?;
431+
match value {
432+
Some(v) => {
433+
let lower = v
434+
.lower()
435+
.map(|b| tuplify_range_bound_map(b, |d| d.0.to_string()));
436+
let upper = v
437+
.upper()
438+
.map(|b| tuplify_range_bound_map(b, |d| d.0.to_string()));
439+
DbValue::RangeDecimal((lower, upper))
382440
}
383441
None => DbValue::DbNull,
384442
}
@@ -397,6 +455,16 @@ fn convert_entry(row: &Row, index: usize) -> anyhow::Result<DbValue> {
397455
None => DbValue::DbNull,
398456
}
399457
}
458+
&Type::NUMERIC_ARRAY => {
459+
let value: Option<Vec<Option<rust_decimal::Decimal>>> = row.try_get(index)?;
460+
match value {
461+
Some(v) => {
462+
let dstrs = v.iter().map(|opt| opt.map(|d| d.to_string())).collect();
463+
DbValue::ArrayDecimal(dstrs)
464+
}
465+
None => DbValue::DbNull,
466+
}
467+
}
400468
&Type::TEXT_ARRAY | &Type::VARCHAR_ARRAY | &Type::BPCHAR_ARRAY => {
401469
let value: Option<Vec<Option<String>>> = row.try_get(index)?;
402470
match value {
@@ -429,6 +497,13 @@ fn tuplify_range_bound<S: postgres_range::BoundSided, T: Copy>(
429497
(value.value, wit_bound_kind(value.type_))
430498
}
431499

500+
fn tuplify_range_bound_map<S: postgres_range::BoundSided, T, U>(
501+
value: &postgres_range::RangeBound<S, T>,
502+
map_fn: impl Fn(&T) -> U,
503+
) -> (U, v4::RangeBoundKind) {
504+
(map_fn(&value.value), wit_bound_kind(value.type_))
505+
}
506+
432507
fn wit_bound_kind(bound_type: postgres_range::BoundType) -> v4::RangeBoundKind {
433508
match bound_type {
434509
postgres_range::BoundType::Inclusive => v4::RangeBoundKind::Inclusive,
@@ -589,3 +664,53 @@ impl std::fmt::Debug for IntervalLengthError {
589664
std::fmt::Display::fmt(self, f)
590665
}
591666
}
667+
668+
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
669+
struct RangeableDecimal(rust_decimal::Decimal);
670+
671+
impl ToSql for RangeableDecimal {
672+
tokio_postgres::types::to_sql_checked!();
673+
674+
fn to_sql(
675+
&self,
676+
ty: &Type,
677+
out: &mut tokio_postgres::types::private::BytesMut,
678+
) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>>
679+
where
680+
Self: Sized,
681+
{
682+
self.0.to_sql(ty, out)
683+
}
684+
685+
fn accepts(ty: &Type) -> bool
686+
where
687+
Self: Sized,
688+
{
689+
<rust_decimal::Decimal as ToSql>::accepts(ty)
690+
}
691+
}
692+
693+
impl FromSql<'_> for RangeableDecimal {
694+
fn from_sql(
695+
ty: &Type,
696+
raw: &'_ [u8],
697+
) -> std::result::Result<Self, Box<dyn std::error::Error + Sync + Send>> {
698+
let d = <rust_decimal::Decimal as FromSql>::from_sql(ty, raw)?;
699+
Ok(Self(d))
700+
}
701+
702+
fn accepts(ty: &Type) -> bool {
703+
<rust_decimal::Decimal as FromSql>::accepts(ty)
704+
}
705+
}
706+
707+
impl postgres_range::Normalizable for RangeableDecimal {
708+
fn normalize<S>(
709+
bound: postgres_range::RangeBound<S, Self>,
710+
) -> postgres_range::RangeBound<S, Self>
711+
where
712+
S: postgres_range::BoundSided,
713+
{
714+
bound
715+
}
716+
}

crates/world/src/conversions.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,12 @@ mod rdbms_types {
120120
pg4::DbValue::Uuid(_) => pg3::DbValue::Unsupported,
121121
pg4::DbValue::Jsonb(_) => pg3::DbValue::Unsupported,
122122
pg4::DbValue::Decimal(_) => pg3::DbValue::Unsupported,
123-
pg4::DbValue::Range32(_) => pg3::DbValue::Unsupported,
124-
pg4::DbValue::Range64(_) => pg3::DbValue::Unsupported,
123+
pg4::DbValue::RangeInt32(_) => pg3::DbValue::Unsupported,
124+
pg4::DbValue::RangeInt64(_) => pg3::DbValue::Unsupported,
125+
pg4::DbValue::RangeDecimal(_) => pg3::DbValue::Unsupported,
125126
pg4::DbValue::ArrayInt32(_) => pg3::DbValue::Unsupported,
126127
pg4::DbValue::ArrayInt64(_) => pg3::DbValue::Unsupported,
128+
pg4::DbValue::ArrayDecimal(_) => pg3::DbValue::Unsupported,
127129
pg4::DbValue::ArrayStr(_) => pg3::DbValue::Unsupported,
128130
pg4::DbValue::Interval(_) => pg3::DbValue::Unsupported,
129131
pg4::DbValue::DbNull => pg3::DbValue::DbNull,
@@ -187,10 +189,12 @@ mod rdbms_types {
187189
pg4::DbDataType::Uuid => pg3::DbDataType::Other,
188190
pg4::DbDataType::Jsonb => pg3::DbDataType::Other,
189191
pg4::DbDataType::Decimal => pg3::DbDataType::Other,
190-
pg4::DbDataType::Range32 => pg3::DbDataType::Other,
191-
pg4::DbDataType::Range64 => pg3::DbDataType::Other,
192+
pg4::DbDataType::RangeInt32 => pg3::DbDataType::Other,
193+
pg4::DbDataType::RangeInt64 => pg3::DbDataType::Other,
194+
pg4::DbDataType::RangeDecimal => pg3::DbDataType::Other,
192195
pg4::DbDataType::ArrayInt32 => pg3::DbDataType::Other,
193196
pg4::DbDataType::ArrayInt64 => pg3::DbDataType::Other,
197+
pg4::DbDataType::ArrayDecimal => pg3::DbDataType::Other,
194198
pg4::DbDataType::ArrayStr => pg3::DbDataType::Other,
195199
pg4::DbDataType::Interval => pg3::DbDataType::Other,
196200
pg4::DbDataType::Other => pg3::DbDataType::Other,

wit/deps/spin-postgres@4.0.0/postgres.wit

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@ interface postgres {
2828
uuid,
2929
jsonb,
3030
decimal,
31-
range32,
32-
range64,
31+
range-int32,
32+
range-int64,
33+
range-decimal,
3334
array-int32,
3435
array-int64,
36+
array-decimal,
3537
array-str,
3638
interval,
3739
other,
@@ -58,10 +60,12 @@ interface postgres {
5860
uuid(string),
5961
jsonb(list<u8>),
6062
decimal(string), // I admit defeat. Base 10
61-
range32(tuple<option<tuple<s32, range-bound-kind>>, option<tuple<s32, range-bound-kind>>>),
62-
range64(tuple<option<tuple<s64, range-bound-kind>>, option<tuple<s64, range-bound-kind>>>),
63+
range-int32(tuple<option<tuple<s32, range-bound-kind>>, option<tuple<s32, range-bound-kind>>>),
64+
range-int64(tuple<option<tuple<s64, range-bound-kind>>, option<tuple<s64, range-bound-kind>>>),
65+
range-decimal(tuple<option<tuple<string, range-bound-kind>>, option<tuple<string, range-bound-kind>>>),
6366
array-int32(list<option<s32>>),
6467
array-int64(list<option<s64>>),
68+
array-decimal(list<option<string>>),
6569
array-str(list<option<string>>),
6670
interval(interval),
6771
db-null,
@@ -89,10 +93,12 @@ interface postgres {
8993
uuid(string),
9094
jsonb(list<u8>),
9195
decimal(string), // base 10
92-
range32(tuple<option<tuple<s32, range-bound-kind>>, option<tuple<s32, range-bound-kind>>>),
93-
range64(tuple<option<tuple<s64, range-bound-kind>>, option<tuple<s64, range-bound-kind>>>),
96+
range-int32(tuple<option<tuple<s32, range-bound-kind>>, option<tuple<s32, range-bound-kind>>>),
97+
range-int64(tuple<option<tuple<s64, range-bound-kind>>, option<tuple<s64, range-bound-kind>>>),
98+
range-decimal(tuple<option<tuple<string, range-bound-kind>>, option<tuple<string, range-bound-kind>>>),
9499
array-int32(list<option<s32>>),
95100
array-int64(list<option<s64>>),
101+
array-decimal(list<option<string>>),
96102
array-str(list<option<string>>),
97103
interval(interval),
98104
db-null,

0 commit comments

Comments
 (0)