Skip to content

Commit 4d3a992

Browse files
authored
Fix heisenbug by using petgraph (#1712)
1 parent 4f29788 commit 4d3a992

File tree

4 files changed

+118
-108
lines changed

4 files changed

+118
-108
lines changed

Cargo.lock

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ nohash-hasher = "0.2"
175175
once_cell = "1.16"
176176
parking_lot = { version = "0.12.1", features = ["send_guard", "arc_lock"] }
177177
paste = "1.0"
178+
petgraph = { version = "0.6.5", default-features = false }
178179
pin-project-lite = "0.2.9"
179180
postgres-types = "0.2.5"
180181
pretty_assertions = "1.4"

crates/schema/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ lazy_static.workspace = true
1717
thiserror.workspace = true
1818
unicode-ident.workspace = true
1919
unicode-normalization.workspace = true
20+
petgraph.workspace = true
2021
serde_json.workspace = true
2122
smallvec.workspace = true
2223
hashbrown.workspace = true

crates/schema/src/type_for_generate.rs

Lines changed: 99 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
//! `AlgebraicType` extensions for generating client code.
22
33
use enum_as_inner::EnumAsInner;
4+
use petgraph::{
5+
algo::tarjan_scc,
6+
visit::{GraphBase, IntoNeighbors, IntoNodeIdentifiers, NodeIndexable},
7+
};
48
use smallvec::SmallVec;
59
use spacetimedb_data_structures::{
610
error_stream::{CollectAllErrors, CombineErrors, ErrorStream},
711
map::{HashMap, HashSet},
812
};
913
use spacetimedb_lib::{AlgebraicType, ProductTypeElement};
1014
use spacetimedb_sats::{typespace::TypeRefError, AlgebraicTypeRef, ArrayType, SumTypeVariant, Typespace};
11-
use std::{ops::Index, sync::Arc};
15+
use std::{cell::RefCell, ops::Index, sync::Arc};
1216

1317
use crate::{
1418
error::{IdentifierError, PrettyAlgebraicType},
@@ -151,6 +155,11 @@ pub enum AlgebraicTypeDef {
151155
PlainEnum(PlainEnumTypeDef),
152156
}
153157

158+
thread_local! {
159+
/// Scratch set reused by `AlgebraicTypeDef::extract_refs` to deduplicate refs without building a new set per call.
160+
static EXTRACT_REFS_BUF: RefCell<HashSet<AlgebraicTypeRef>> = RefCell::new(HashSet::new());
161+
}
162+
154163
impl AlgebraicTypeDef {
155164
/// Check if a def is recursive.
156165
pub fn is_recursive(&self) -> bool {
@@ -161,21 +170,26 @@ impl AlgebraicTypeDef {
161170
}
162171
}
163172

164-
/// Extract all `AlgebraicTypeRef`s that are used in this type into the buffer.
165-
fn extract_refs(&self, buf: &mut HashSet<AlgebraicTypeRef>) {
166-
match self {
167-
AlgebraicTypeDef::Product(ProductTypeDef { elements, .. }) => {
168-
for (_, ty) in elements.iter() {
169-
ty.extract_refs(buf);
173+
/// Extract all `AlgebraicTypeRef`s that are used in this type into a buffer.
174+
/// The buffer may be in arbitrary order, but will not contain duplicates.
175+
fn extract_refs(&self) -> SmallVec<[AlgebraicTypeRef; 16]> {
176+
EXTRACT_REFS_BUF.with_borrow_mut(|buf| {
177+
buf.clear();
178+
match self {
179+
AlgebraicTypeDef::Product(ProductTypeDef { elements, .. }) => {
180+
for (_, use_) in elements.iter() {
181+
use_.extract_refs(buf);
182+
}
170183
}
171-
}
172-
AlgebraicTypeDef::Sum(SumTypeDef { variants, .. }) => {
173-
for (_, ty) in variants.iter() {
174-
ty.extract_refs(buf);
184+
AlgebraicTypeDef::Sum(SumTypeDef { variants, .. }) => {
185+
for (_, use_) in variants.iter() {
186+
use_.extract_refs(buf);
187+
}
175188
}
189+
AlgebraicTypeDef::PlainEnum(_) => {}
176190
}
177-
AlgebraicTypeDef::PlainEnum(_) => {}
178-
}
191+
buf.drain().collect()
192+
})
179193
}
180194

181195
/// Mark a def recursive.
@@ -608,114 +622,79 @@ impl TypespaceForGenerateBuilder<'_> {
608622
/// Cycles passing through definitions are allowed.
609623
/// This function is called after all definitions have been processed.
610624
fn mark_allowed_cycles(&mut self) {
611-
let mut to_process = self.is_def.clone();
612-
let mut scratch = HashSet::new();
613-
// We reuse this here as well.
614-
self.currently_touching.clear();
625+
let strongly_connected_components: Vec<Vec<AlgebraicTypeRef>> = tarjan_scc(&*self);
626+
for component in strongly_connected_components {
627+
if component.len() == 1 {
628+
// petgraph returns a single-element component for every node outside a larger cycle, not distinguishing between
629+
// self-referential and non-self-referential nodes. Skip these here; direct self-references are fixed up below.
630+
continue;
631+
}
632+
for ref_ in component {
633+
self.result
634+
.defs
635+
.get_mut(&ref_)
636+
.expect("all defs should be processed by now")
637+
.mark_recursive();
638+
}
639+
}
615640

616-
while let Some(ref_) = to_process.iter().next().cloned() {
617-
self.mark_allowed_cycles_rec(None, ref_, &mut to_process, &mut scratch);
641+
// Now, fix up directly self-referential nodes.
642+
for (ref_, def_) in &mut self.result.defs {
643+
let ref_ = *ref_;
644+
if def_.is_recursive() {
645+
continue;
646+
}
647+
let refs = def_.extract_refs();
648+
if refs.contains(&ref_) {
649+
def_.mark_recursive();
650+
}
618651
}
619652
}
653+
}
620654

621-
/// Recursively mark allowed cycles.
622-
fn mark_allowed_cycles_rec(
623-
&mut self,
624-
parent: Option<&ParentChain>,
625-
def: AlgebraicTypeRef,
626-
to_process: &mut HashSet<AlgebraicTypeRef>,
627-
scratch: &mut HashSet<AlgebraicTypeRef>,
628-
) {
629-
// Mark who we're touching right now.
630-
let not_already_present = self.currently_touching.insert(def);
631-
assert!(
632-
not_already_present,
633-
"mark_allowed_cycles_rec should never be called on a ref that is already being touched"
634-
);
655+
// We implement some `petgraph` traits for `TypespaceForGenerateBuilder` to allow using
656+
// petgraph's implementation of Tarjan's strongly-connected-components algorithm.
657+
// This is used in `mark_allowed_cycles`.
658+
// We don't implement all the traits, only the ones we need.
659+
// The traits are intended to be used *after* all defs have been processed.
635660

636-
// Figure out who to look at.
637-
// Note: this skips over refs in the original typespace that
638-
// didn't point to definitions; those have already been removed.
639-
scratch.clear();
640-
let to_examine = scratch;
641-
self.result.defs[&def].extract_refs(to_examine);
642-
643-
// Update the parent chain with the current def, for passing to children.
644-
let chain = ParentChain { parent, ref_: def };
645-
646-
// First, check for finished cycles.
647-
for element in to_examine.iter() {
648-
if self.currently_touching.contains(element) {
649-
// We have a cycle.
650-
for parent_ref in chain.iter() {
651-
// For each def participating in the cycle, mark it as recursive.
652-
self.result
653-
.defs
654-
.get_mut(&parent_ref)
655-
.expect("all defs should have been processed by now")
656-
.mark_recursive();
657-
// It's tempting to also remove `parent_ref` from `to_process` here,
658-
// but that's wrong, because it might participate in other cycles.
659-
660-
// We want to mark the start of the cycle as recursive too.
661-
// If we've just done that, break.
662-
if parent_ref == *element {
663-
break;
664-
}
665-
}
666-
}
667-
}
661+
impl GraphBase for TypespaceForGenerateBuilder<'_> {
662+
/// Specifically, definition IDs.
663+
type NodeId = AlgebraicTypeRef;
668664

669-
// Now that we've marked everything possible, we need to recurse.
670-
// Need a buffer to iterate from because we reuse `to_examine` in children.
671-
// This will usually not allocate. Most defs have less than 16 refs.
672-
let to_recurse = to_examine
673-
.iter()
674-
.cloned()
675-
.filter(|element| to_process.contains(element) && !self.currently_touching.contains(element))
676-
.collect::<SmallVec<[AlgebraicTypeRef; 16]>>();
677-
678-
// Recurse.
679-
let scratch = to_examine;
680-
for element in to_recurse {
681-
self.mark_allowed_cycles_rec(Some(&chain), element, to_process, scratch);
682-
}
665+
/// Definition `.0` uses definition `.1`.
666+
type EdgeId = (AlgebraicTypeRef, AlgebraicTypeRef);
667+
}
668+
impl NodeIndexable for TypespaceForGenerateBuilder<'_> {
669+
fn node_bound(&self) -> usize {
670+
self.typespace.types.len()
671+
}
683672

684-
// We're done with this def.
685-
// Clean up our state.
686-
let was_present = self.currently_touching.remove(&def);
687-
assert!(
688-
was_present,
689-
"mark_allowed_cycles_rec is finishing, we should be touching that ref."
690-
);
691-
// Only remove a def from `to_process` once we've explored all the paths leaving it.
692-
to_process.remove(&def);
673+
fn to_index(&self, a: Self::NodeId) -> usize {
674+
a.idx()
693675
}
694-
}
695676

696-
/// A chain of parent type definitions.
697-
/// If type T uses type U, then T is a parent of U.
698-
struct ParentChain<'a> {
699-
parent: Option<&'a ParentChain<'a>>,
700-
ref_: AlgebraicTypeRef,
701-
}
702-
impl<'a> ParentChain<'a> {
703-
fn iter(&'a self) -> ParentChainIter<'a> {
704-
ParentChainIter { current: Some(self) }
677+
fn from_index(&self, i: usize) -> Self::NodeId {
678+
AlgebraicTypeRef(i as _)
705679
}
706680
}
681+
impl<'a> IntoNodeIdentifiers for &'a TypespaceForGenerateBuilder<'a> {
682+
type NodeIdentifiers = std::iter::Cloned<hashbrown::hash_set::Iter<'a, spacetimedb_sats::AlgebraicTypeRef>>;
707683

708-
/// An iterator over a chain of parent type definitions.
709-
struct ParentChainIter<'a> {
710-
current: Option<&'a ParentChain<'a>>,
684+
fn node_identifiers(self) -> Self::NodeIdentifiers {
685+
self.is_def.iter().cloned()
686+
}
711687
}
712-
impl Iterator for ParentChainIter<'_> {
713-
type Item = AlgebraicTypeRef;
688+
impl<'a> IntoNeighbors for &'a TypespaceForGenerateBuilder<'a> {
689+
type Neighbors = <SmallVec<[AlgebraicTypeRef; 16]> as IntoIterator>::IntoIter;
714690

715-
fn next(&mut self) -> Option<Self::Item> {
716-
let current = self.current?;
717-
self.current = current.parent;
718-
Some(current.ref_)
691+
fn neighbors(self, a: Self::NodeId) -> Self::Neighbors {
692+
self.result
693+
.defs
694+
.get(&a)
695+
.expect("all defs should have been processed by now")
696+
.extract_refs()
697+
.into_iter()
719698
}
720699
}
721700

@@ -805,7 +784,7 @@ mod tests {
805784
}
806785

807786
#[test]
808-
fn test_detects_cycles() {
787+
fn test_detects_cycles_1() {
809788
let cyclic_1 = Typespace::new(vec![AlgebraicType::Ref(AlgebraicTypeRef(0))]);
810789
let mut for_generate = TypespaceForGenerate::builder(&cyclic_1, []);
811790
let err1 = for_generate.parse_use(&AlgebraicType::Ref(AlgebraicTypeRef(0)));
@@ -814,7 +793,10 @@ mod tests {
814793
err1,
815794
ClientCodegenError::TypeRefError(TypeRefError::RecursiveTypeRef(AlgebraicTypeRef(0)))
816795
);
796+
}
817797

798+
#[test]
799+
fn test_detects_cycles_2() {
818800
let cyclic_2 = Typespace::new(vec![
819801
AlgebraicType::Ref(AlgebraicTypeRef(1)),
820802
AlgebraicType::Ref(AlgebraicTypeRef(0)),
@@ -826,7 +808,10 @@ mod tests {
826808
err2,
827809
ClientCodegenError::TypeRefError(TypeRefError::RecursiveTypeRef(AlgebraicTypeRef(0)))
828810
);
811+
}
829812

813+
#[test]
814+
fn test_detects_cycles_3() {
830815
let cyclic_3 = Typespace::new(vec![
831816
AlgebraicType::Ref(AlgebraicTypeRef(1)),
832817
AlgebraicType::product([("field", AlgebraicType::Ref(AlgebraicTypeRef(0)))]),
@@ -842,7 +827,10 @@ mod tests {
842827
let table = result.defs().get(&AlgebraicTypeRef(1)).expect("should be defined");
843828

844829
assert!(table.is_recursive(), "recursion not detected? table: {table:?}");
830+
}
845831

832+
#[test]
833+
fn test_detects_cycles_4() {
846834
let cyclic_4 = Typespace::new(vec![
847835
AlgebraicType::product([("field", AlgebraicTypeRef(1).into())]),
848836
AlgebraicType::product([("field", AlgebraicTypeRef(2).into())]),
@@ -878,7 +866,10 @@ mod tests {
878866
!result[AlgebraicTypeRef(4)].is_recursive(),
879867
"recursion detected incorrectly"
880868
);
869+
}
881870

871+
#[test]
872+
fn test_detects_cycles_5() {
882873
// Branching cycles.
883874
let cyclic_5 = Typespace::new(vec![
884875
// cyclic component.

0 commit comments

Comments
 (0)