Skip to content

Commit 34201c6

Browse files
authored
Fix AlgebraicTypeLayout::is_compatible_with (#2932)
1 parent b63216a commit 34201c6

File tree

9 files changed

+103
-59
lines changed

9 files changed

+103
-59
lines changed

crates/sats/src/algebraic_value/ser.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use crate::bsatn::decode;
2+
use crate::de::DeserializeSeed;
13
use crate::ser::{self, ForwardNamedToSeqProduct, Serialize};
2-
use crate::{i256, u256};
3-
use crate::{AlgebraicType, AlgebraicValue, ArrayValue, F32, F64};
4+
use crate::{i256, u256, WithTypespace};
5+
use crate::{AlgebraicValue, ArrayValue, F32, F64};
46
use core::convert::Infallible;
57
use core::mem::MaybeUninit;
68
use core::ptr;
@@ -81,18 +83,25 @@ impl ser::Serializer for ValueSerializer {
8183
value.serialize(self).map(|v| AlgebraicValue::sum(tag, v))
8284
}
8385

84-
unsafe fn serialize_bsatn(self, ty: &AlgebraicType, mut bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
85-
let res = AlgebraicValue::decode(ty, &mut bsatn);
86+
unsafe fn serialize_bsatn<Ty>(self, ty: &Ty, mut bsatn: &[u8]) -> Result<Self::Ok, Self::Error>
87+
where
88+
for<'a, 'de> WithTypespace<'a, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
89+
{
90+
let res = decode(ty, &mut bsatn);
8691
// SAFETY: Caller promised that `res.is_ok()`.
87-
Ok(unsafe { res.unwrap_unchecked() })
92+
let val = unsafe { res.unwrap_unchecked() };
93+
Ok(val.into())
8894
}
8995

90-
unsafe fn serialize_bsatn_in_chunks<'a, I: Iterator<Item = &'a [u8]>>(
96+
unsafe fn serialize_bsatn_in_chunks<'a, Ty, I: Iterator<Item = &'a [u8]>>(
9197
self,
92-
ty: &crate::AlgebraicType,
98+
ty: &Ty,
9399
total_bsatn_len: usize,
94100
chunks: I,
95-
) -> Result<Self::Ok, Self::Error> {
101+
) -> Result<Self::Ok, Self::Error>
102+
where
103+
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
104+
{
96105
// SAFETY: Caller promised `total_bsatn_len == chunks.map(|c| c.len()).sum() <= isize::MAX`.
97106
unsafe {
98107
concat_byte_chunks_buf(total_bsatn_len, chunks, |bsatn| {

crates/sats/src/algebraic_value_hash.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ impl Hash for ArrayValue {
122122

123123
type HR = Result<(), DecodeError>;
124124

125-
pub fn hash_bsatn<'a>(state: &mut impl Hasher, ty: &AlgebraicType, de: Deserializer<'_, impl BufReader<'a>>) -> HR {
125+
fn hash_bsatn<'a>(state: &mut impl Hasher, ty: &AlgebraicType, de: Deserializer<'_, impl BufReader<'a>>) -> HR {
126126
match ty {
127127
AlgebraicType::Ref(_) => unreachable!("hash_bsatn does not have a typespace"),
128128
AlgebraicType::Sum(ty) => hash_bsatn_sum(state, ty, de),
@@ -166,7 +166,11 @@ fn hash_bsatn_prod<'a>(state: &mut impl Hasher, ty: &ProductType, mut de: Deseri
166166
}
167167

168168
/// Hashes every elem in the BSATN-encoded array value.
169-
fn hash_bsatn_array<'a>(state: &mut impl Hasher, ty: &AlgebraicType, de: Deserializer<'_, impl BufReader<'a>>) -> HR {
169+
pub fn hash_bsatn_array<'a>(
170+
state: &mut impl Hasher,
171+
ty: &AlgebraicType,
172+
de: Deserializer<'_, impl BufReader<'a>>,
173+
) -> HR {
170174
// The BSATN is length-prefixed.
171175
// `Hash for &[T]` also does length-prefixing.
172176
match ty {
@@ -236,9 +240,9 @@ fn hash_bsatn_de<'a, T: Hash + Deserialize<'a>>(
236240

237241
#[cfg(test)]
238242
mod tests {
243+
use super::hash_bsatn;
239244
use crate::{
240245
bsatn::{to_vec, Deserializer},
241-
hash_bsatn,
242246
proptest::generate_typed_value,
243247
AlgebraicType, AlgebraicValue,
244248
};

crates/sats/src/bsatn/ser.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use crate::buffer::BufWriter;
2+
use crate::de::DeserializeSeed;
23
use crate::ser::{self, Error, ForwardNamedToSeqProduct, SerializeArray, SerializeSeqProduct};
3-
use crate::AlgebraicValue;
44
use crate::{i256, u256};
5+
use crate::{AlgebraicValue, WithTypespace};
56
use core::fmt;
67

78
/// Defines the BSATN serialization data format.
@@ -159,20 +160,26 @@ impl<W: BufWriter> ser::Serializer for Serializer<'_, W> {
159160
value.serialize(self)
160161
}
161162

162-
unsafe fn serialize_bsatn(self, ty: &crate::AlgebraicType, bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
163-
debug_assert!(AlgebraicValue::decode(ty, &mut { bsatn }).is_ok());
163+
unsafe fn serialize_bsatn<Ty>(self, ty: &Ty, bsatn: &[u8]) -> Result<Self::Ok, Self::Error>
164+
where
165+
for<'a, 'de> WithTypespace<'a, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
166+
{
167+
debug_assert!(crate::bsatn::decode(ty, &mut { bsatn }).is_ok());
164168
self.writer.put_slice(bsatn);
165169
Ok(())
166170
}
167171

168-
unsafe fn serialize_bsatn_in_chunks<'a, I: Clone + Iterator<Item = &'a [u8]>>(
172+
unsafe fn serialize_bsatn_in_chunks<'a, Ty, I: Clone + Iterator<Item = &'a [u8]>>(
169173
self,
170-
ty: &crate::AlgebraicType,
174+
ty: &Ty,
171175
total_bsatn_len: usize,
172176
bsatn: I,
173-
) -> Result<Self::Ok, Self::Error> {
177+
) -> Result<Self::Ok, Self::Error>
178+
where
179+
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
180+
{
174181
debug_assert!(total_bsatn_len <= isize::MAX as usize);
175-
debug_assert!(AlgebraicValue::decode(ty, &mut &*concat_bytes_slow(total_bsatn_len, bsatn.clone())).is_ok());
182+
debug_assert!(crate::bsatn::decode(ty, &mut &*concat_bytes_slow(total_bsatn_len, bsatn.clone())).is_ok());
176183

177184
for chunk in bsatn {
178185
self.writer.put_slice(chunk);

crates/sats/src/layout.rs

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ use crate::{
1313
},
1414
i256, impl_deserialize, impl_serialize,
1515
sum_type::{OPTION_NONE_TAG, OPTION_SOME_TAG},
16-
u256, AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, SumType, SumTypeVariant,
17-
SumValue, WithTypespace,
16+
u256, AlgebraicType, AlgebraicValue, ArrayType, ProductType, ProductTypeElement, ProductValue, SumType,
17+
SumTypeVariant, SumValue, WithTypespace,
1818
};
1919
use core::ops::{Index, Mul};
2020
use core::{mem, ops::Deref};
@@ -199,8 +199,8 @@ impl AlgebraicTypeLayout {
199199
// but we don't care to avoid that and optimize right now,
200200
// as this is only executed during upgrade / migration,
201201
// and that doesn't need to be fast right now.
202-
let old = AlgebraicTypeLayout::from(old.deref().clone());
203-
let new = AlgebraicTypeLayout::from(new.deref().clone());
202+
let old = AlgebraicTypeLayout::from(old.elem_ty.deref().clone());
203+
let new = AlgebraicTypeLayout::from(new.elem_ty.deref().clone());
204204
old.is_compatible_with(&new)
205205
}
206206
(Self::VarLen(VarLenType::String), Self::VarLen(VarLenType::String)) => true,
@@ -515,11 +515,11 @@ impl HasLayout for PrimitiveType {
515515
pub enum VarLenType {
516516
/// The string type corresponds to `AlgebraicType::String`.
517517
String,
518-
/// An array type. The whole outer `AlgebraicType` is stored here.
518+
/// An array type. The inner `AlgebraicType` is stored here.
519519
///
520-
/// Storing the whole `AlgebraicType` here allows us to directly call BSATN ser/de,
521-
/// and to report type errors.
522-
Array(Box<AlgebraicType>),
520+
/// Previously, the outer type, i.e., `AlgebraicType::Array` was stored.
521+
/// However, this is both less efficient and more bug-prone.
522+
Array(ArrayType),
523523
}
524524

525525
#[cfg(feature = "memory-usage")]
@@ -554,7 +554,7 @@ impl From<AlgebraicType> for AlgebraicTypeLayout {
554554
AlgebraicType::Product(prod) => AlgebraicTypeLayout::Product(prod.into()),
555555

556556
AlgebraicType::String => AlgebraicTypeLayout::VarLen(VarLenType::String),
557-
AlgebraicType::Array(_) => AlgebraicTypeLayout::VarLen(VarLenType::Array(Box::new(ty))),
557+
AlgebraicType::Array(array) => AlgebraicTypeLayout::VarLen(VarLenType::Array(array)),
558558

559559
AlgebraicType::Bool => AlgebraicTypeLayout::Bool,
560560
AlgebraicType::I8 => AlgebraicTypeLayout::I8,
@@ -690,19 +690,11 @@ impl AlgebraicTypeLayout {
690690
/// It is intended for use in error paths, where performance is a secondary concern.
691691
pub fn algebraic_type(&self) -> AlgebraicType {
692692
match self {
693-
AlgebraicTypeLayout::Primitive(prim) => prim.algebraic_type(),
694-
AlgebraicTypeLayout::VarLen(var_len) => var_len.algebraic_type(),
695-
AlgebraicTypeLayout::Product(prod) => AlgebraicType::Product(prod.view().product_type()),
696-
AlgebraicTypeLayout::Sum(sum) => AlgebraicType::Sum(sum.sum_type()),
697-
}
698-
}
699-
}
700-
701-
impl VarLenType {
702-
fn algebraic_type(&self) -> AlgebraicType {
703-
match self {
704-
VarLenType::String => AlgebraicType::String,
705-
VarLenType::Array(ty) => ty.as_ref().clone(),
693+
Self::Primitive(prim) => prim.algebraic_type(),
694+
Self::VarLen(VarLenType::String) => AlgebraicType::String,
695+
Self::VarLen(VarLenType::Array(array)) => AlgebraicType::Array(array.clone()),
696+
Self::Product(prod) => AlgebraicType::Product(prod.view().product_type()),
697+
Self::Sum(sum) => AlgebraicType::Sum(sum.sum_type()),
706698
}
707699
}
708700
}
@@ -828,7 +820,9 @@ impl<'de> DeserializeSeed<'de> for &AlgebraicTypeLayout {
828820
AlgebraicTypeLayout::Primitive(PrimitiveType::U256) => u256::deserialize(de).map(Into::into),
829821
AlgebraicTypeLayout::Primitive(PrimitiveType::F32) => f32::deserialize(de).map(Into::into),
830822
AlgebraicTypeLayout::Primitive(PrimitiveType::F64) => f64::deserialize(de).map(Into::into),
831-
AlgebraicTypeLayout::VarLen(VarLenType::Array(ty)) => WithTypespace::empty(&**ty).deserialize(de),
823+
AlgebraicTypeLayout::VarLen(VarLenType::Array(ty)) => {
824+
WithTypespace::empty(ty).deserialize(de).map(AlgebraicValue::Array)
825+
}
832826
AlgebraicTypeLayout::VarLen(VarLenType::String) => <Box<str>>::deserialize(de).map(Into::into),
833827
}
834828
}
@@ -1124,4 +1118,15 @@ mod test {
11241118
}
11251119
}
11261120
}
1121+
1122+
#[test]
1123+
fn infinite_recursion_in_is_compatible_with_with_array_type() {
1124+
let ty = AlgebraicTypeLayout::from(AlgebraicType::array(AlgebraicType::U64));
1125+
// This would previously cause an infinite recursion / stack overflow
1126+
// due to the setup where `AlgebraicTypeLayout::VarLen(Array(x))` stored
1127+
// `x = Box::new(AlgebraicType::Array(elem_ty))`.
1128+
// The method `AlgebraicTypeLayout::is_compatible_with` was not set up to handle that.
1129+
// To avoid such bugs in the future, `x` is now `elem_ty` instead.
1130+
assert!(ty.is_compatible_with(&ty));
1131+
}
11271132
}

crates/sats/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ pub use crate as sats;
7777
pub use algebraic_type::AlgebraicType;
7878
pub use algebraic_type_ref::AlgebraicTypeRef;
7979
pub use algebraic_value::{i256, u256, AlgebraicValue, F32, F64};
80-
pub use algebraic_value_hash::hash_bsatn;
80+
pub use algebraic_value_hash::hash_bsatn_array;
8181
pub use array_type::ArrayType;
8282
pub use array_value::ArrayValue;
8383
pub use product_type::ProductType;

crates/sats/src/satn.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
use crate::de::DeserializeSeed;
12
use crate::time_duration::TimeDuration;
23
use crate::timestamp::Timestamp;
3-
use crate::{i256, u256};
4+
use crate::{i256, u256, AlgebraicValue, WithTypespace};
45
use crate::{ser, ProductType, ProductTypeElement};
56
use core::fmt;
67
use core::fmt::Write as _;
@@ -706,17 +707,23 @@ impl<'a, 'f> ser::Serializer for PsqlFormatter<'a, 'f> {
706707
self.fmt.serialize_variant(tag, name, value)
707708
}
708709

709-
unsafe fn serialize_bsatn(self, ty: &crate::AlgebraicType, bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
710+
unsafe fn serialize_bsatn<Ty>(self, ty: &Ty, bsatn: &[u8]) -> Result<Self::Ok, Self::Error>
711+
where
712+
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
713+
{
710714
// SAFETY: Forward the caller requirements of this method to the one we are calling.
711715
unsafe { self.fmt.serialize_bsatn(ty, bsatn) }
712716
}
713717

714-
unsafe fn serialize_bsatn_in_chunks<'c, I: Clone + Iterator<Item = &'c [u8]>>(
718+
unsafe fn serialize_bsatn_in_chunks<'c, Ty, I: Clone + Iterator<Item = &'c [u8]>>(
715719
self,
716-
ty: &crate::AlgebraicType,
720+
ty: &Ty,
717721
total_bsatn_len: usize,
718722
bsatn: I,
719-
) -> Result<Self::Ok, Self::Error> {
723+
) -> Result<Self::Ok, Self::Error>
724+
where
725+
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
726+
{
720727
// SAFETY: Forward the caller requirements of this method to the one we are calling.
721728
unsafe { self.fmt.serialize_bsatn_in_chunks(ty, total_bsatn_len, bsatn) }
722729
}

crates/sats/src/ser.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ mod impls;
55
#[cfg(feature = "serde")]
66
pub mod serde;
77

8-
use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter, AlgebraicType};
8+
use crate::de::DeserializeSeed;
9+
use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter};
10+
use crate::{AlgebraicValue, WithTypespace};
911
use core::marker::PhantomData;
1012
use core::{convert::Infallible, fmt};
1113
use ethnum::{i256, u256};
@@ -130,9 +132,13 @@ pub trait Serializer: Sized {
130132
///
131133
/// # Safety
132134
///
133-
/// - `AlgebraicValue::decode(ty, &mut bsatn).is_ok()`.
135+
/// - `decode(ty, &mut bsatn).is_ok()`.
134136
/// That is, `bsatn` encodes a valid element of `ty`.
135-
unsafe fn serialize_bsatn(self, ty: &AlgebraicType, bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
137+
/// It's up to the caller to arrange `Ty` such that this holds.
138+
unsafe fn serialize_bsatn<Ty>(self, ty: &Ty, bsatn: &[u8]) -> Result<Self::Ok, Self::Error>
139+
where
140+
for<'a, 'de> WithTypespace<'a, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
141+
{
136142
// TODO(Centril): Consider instead deserializing the `bsatn` through a
137143
// deserializer that serializes into `self` directly.
138144

@@ -168,14 +174,18 @@ pub trait Serializer: Sized {
168174
///
169175
/// - `total_bsatn_len == bsatn.map(|c| c.len()).sum() <= isize::MAX`
170176
/// - Let `buf` be defined as above, i.e., the bytes of `bsatn` concatenated.
171-
/// Then `AlgebraicValue::decode(ty, &mut buf).is_ok()`.
177+
/// Then `decode(ty, &mut buf).is_ok()`.
172178
/// That is, `buf` encodes a valid element of `ty`.
173-
unsafe fn serialize_bsatn_in_chunks<'a, I: Clone + Iterator<Item = &'a [u8]>>(
179+
/// It's up to the caller to arrange `Ty` such that this holds.
180+
unsafe fn serialize_bsatn_in_chunks<'a, Ty, I: Clone + Iterator<Item = &'a [u8]>>(
174181
self,
175-
ty: &AlgebraicType,
182+
ty: &Ty,
176183
total_bsatn_len: usize,
177184
bsatn: I,
178-
) -> Result<Self::Ok, Self::Error> {
185+
) -> Result<Self::Ok, Self::Error>
186+
where
187+
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
188+
{
179189
// TODO(Centril): Unlike above, in this case we must at minimum concatenate `bsatn`
180190
// before we can do the piping mentioned above, but that's better than
181191
// serializing to `AlgebraicValue` first, so consider that.

crates/table/src/bflatn_from.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use spacetimedb_sats::{
1717
align_to, AlgebraicTypeLayout, HasLayout as _, ProductTypeLayoutView, RowTypeLayout, SumTypeLayout, VarLenType,
1818
},
1919
ser::{SerializeNamedProduct, Serializer},
20-
u256, AlgebraicType,
20+
u256, ArrayType,
2121
};
2222

2323
/// Serializes the row in `page` where the fixed part starts at `fixed_offset`
@@ -243,7 +243,7 @@ pub(crate) unsafe fn serialize_value<S: Serializer>(
243243
}
244244
AlgebraicTypeLayout::VarLen(VarLenType::Array(ty)) => {
245245
// SAFETY: `value` was valid at `ty` and `VarLenRef`s won't be dangling.
246-
unsafe { serialize_bsatn(ser, bytes, page, blob_store, curr_offset, ty) }
246+
unsafe { serialize_array(ser, bytes, page, blob_store, curr_offset, ty) }
247247
}
248248
}
249249
}
@@ -285,13 +285,13 @@ unsafe fn serialize_string<S: Serializer>(
285285
}
286286
}
287287

288-
unsafe fn serialize_bsatn<S: Serializer>(
288+
unsafe fn serialize_array<S: Serializer>(
289289
ser: S,
290290
bytes: &Bytes,
291291
page: &Page,
292292
blob_store: &dyn BlobStore,
293293
curr_offset: CurrOffset<'_>,
294-
ty: &AlgebraicType,
294+
ty: &ArrayType,
295295
) -> Result<S::Ok, S::Error> {
296296
// SAFETY: `value` was valid at and aligned for `ty`.
297297
// These `ty` store a `vlr: VarLenRef` as their fixed value.

crates/table/src/row_hash.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ unsafe fn hash_value(
160160
}
161161
}
162162
AlgebraicTypeLayout::VarLen(VarLenType::Array(ty)) => {
163+
let ty = &ty.elem_ty;
164+
163165
// SAFETY: `value` was valid at and aligned for `ty`.
164166
// These `ty` store a `vlr: VarLenRef` as their value,
165167
// so the range is valid and properly aligned for `VarLenRef`.
@@ -168,7 +170,7 @@ unsafe fn hash_value(
168170
unsafe {
169171
run_vlo_bytes(page, bytes, blob_store, curr_offset, |mut bsatn| {
170172
let de = Deserializer::new(&mut bsatn);
171-
spacetimedb_sats::hash_bsatn(hasher, ty, de).unwrap();
173+
spacetimedb_sats::hash_bsatn_array(hasher, ty, de).unwrap();
172174
});
173175
}
174176
}

0 commit comments

Comments
 (0)