1
1
//! `AlgebraicType` extensions for generating client code.
2
2
3
3
use enum_as_inner:: EnumAsInner ;
4
+ use petgraph:: {
5
+ algo:: tarjan_scc,
6
+ visit:: { GraphBase , IntoNeighbors , IntoNodeIdentifiers , NodeIndexable } ,
7
+ } ;
4
8
use smallvec:: SmallVec ;
5
9
use spacetimedb_data_structures:: {
6
10
error_stream:: { CollectAllErrors , CombineErrors , ErrorStream } ,
7
11
map:: { HashMap , HashSet } ,
8
12
} ;
9
13
use spacetimedb_lib:: { AlgebraicType , ProductTypeElement } ;
10
14
use spacetimedb_sats:: { typespace:: TypeRefError , AlgebraicTypeRef , ArrayType , SumTypeVariant , Typespace } ;
11
- use std:: { ops:: Index , sync:: Arc } ;
15
+ use std:: { cell :: RefCell , ops:: Index , sync:: Arc } ;
12
16
13
17
use crate :: {
14
18
error:: { IdentifierError , PrettyAlgebraicType } ,
@@ -151,6 +155,11 @@ pub enum AlgebraicTypeDef {
151
155
PlainEnum ( PlainEnumTypeDef ) ,
152
156
}
153
157
158
+ thread_local ! {
159
+ /// Used to efficiently extract refs from a def.
160
+ static EXTRACT_REFS_BUF : RefCell <HashSet <AlgebraicTypeRef >> = RefCell :: new( HashSet :: new( ) ) ;
161
+ }
162
+
154
163
impl AlgebraicTypeDef {
155
164
/// Check if a def is recursive.
156
165
pub fn is_recursive ( & self ) -> bool {
@@ -161,21 +170,26 @@ impl AlgebraicTypeDef {
161
170
}
162
171
}
163
172
164
- /// Extract all `AlgebraicTypeRef`s that are used in this type into the buffer.
165
- fn extract_refs ( & self , buf : & mut HashSet < AlgebraicTypeRef > ) {
166
- match self {
167
- AlgebraicTypeDef :: Product ( ProductTypeDef { elements, .. } ) => {
168
- for ( _, ty) in elements. iter ( ) {
169
- ty. extract_refs ( buf) ;
173
+ /// Extract all `AlgebraicTypeRef`s that are used in this type into a buffer.
174
+ /// The buffer may be in arbitrary order, but will not contain duplicates.
175
+ fn extract_refs ( & self ) -> SmallVec < [ AlgebraicTypeRef ; 16 ] > {
176
+ EXTRACT_REFS_BUF . with_borrow_mut ( |buf| {
177
+ buf. clear ( ) ;
178
+ match self {
179
+ AlgebraicTypeDef :: Product ( ProductTypeDef { elements, .. } ) => {
180
+ for ( _, use_) in elements. iter ( ) {
181
+ use_. extract_refs ( buf) ;
182
+ }
170
183
}
171
- }
172
- AlgebraicTypeDef :: Sum ( SumTypeDef { variants , .. } ) => {
173
- for ( _ , ty ) in variants . iter ( ) {
174
- ty . extract_refs ( buf ) ;
184
+ AlgebraicTypeDef :: Sum ( SumTypeDef { variants , .. } ) => {
185
+ for ( _ , use_ ) in variants . iter ( ) {
186
+ use_ . extract_refs ( buf ) ;
187
+ }
175
188
}
189
+ AlgebraicTypeDef :: PlainEnum ( _) => { }
176
190
}
177
- AlgebraicTypeDef :: PlainEnum ( _ ) => { }
178
- }
191
+ buf . drain ( ) . collect ( )
192
+ } )
179
193
}
180
194
181
195
/// Mark a def recursive.
@@ -608,114 +622,79 @@ impl TypespaceForGenerateBuilder<'_> {
608
622
/// Cycles passing through definitions are allowed.
609
623
/// This function is called after all definitions have been processed.
610
624
fn mark_allowed_cycles ( & mut self ) {
611
- let mut to_process = self . is_def . clone ( ) ;
612
- let mut scratch = HashSet :: new ( ) ;
613
- // We reuse this here as well.
614
- self . currently_touching . clear ( ) ;
625
+ let strongly_connected_components: Vec < Vec < AlgebraicTypeRef > > = tarjan_scc ( & * self ) ;
626
+ for component in strongly_connected_components {
627
+ if component. len ( ) == 1 {
628
+ // petgraph's implementation returns a vector for all nodes, not distinguishing between
629
+ // self-referential and non-self-referential nodes. Skip these here; they are fixed up below.
630
+ continue ;
631
+ }
632
+ for ref_ in component {
633
+ self . result
634
+ . defs
635
+ . get_mut ( & ref_)
636
+ . expect ( "all defs should be processed by now" )
637
+ . mark_recursive ( ) ;
638
+ }
639
+ }
615
640
616
- while let Some ( ref_) = to_process. iter ( ) . next ( ) . cloned ( ) {
617
- self . mark_allowed_cycles_rec ( None , ref_, & mut to_process, & mut scratch) ;
641
+ // Now, fix up directly self-referential nodes.
642
+ for ( ref_, def_) in & mut self . result . defs {
643
+ let ref_ = * ref_;
644
+ if def_. is_recursive ( ) {
645
+ continue ;
646
+ }
647
+ let refs = def_. extract_refs ( ) ;
648
+ if refs. contains ( & ref_) {
649
+ def_. mark_recursive ( ) ;
650
+ }
618
651
}
619
652
}
653
+ }
620
654
621
- /// Recursively mark allowed cycles.
622
- fn mark_allowed_cycles_rec (
623
- & mut self ,
624
- parent : Option < & ParentChain > ,
625
- def : AlgebraicTypeRef ,
626
- to_process : & mut HashSet < AlgebraicTypeRef > ,
627
- scratch : & mut HashSet < AlgebraicTypeRef > ,
628
- ) {
629
- // Mark who we're touching right now.
630
- let not_already_present = self . currently_touching . insert ( def) ;
631
- assert ! (
632
- not_already_present,
633
- "mark_allowed_cycles_rec should never be called on a ref that is already being touched"
634
- ) ;
655
+ // We implement some `petgraph` traits for `TypespaceForGenerateBuilder` to allow using
656
+ // petgraph's implementation of Tarjan's strongly-connected-components algorithm.
657
+ // This is used in `mark_allowed_cycles`.
658
+ // We don't implement all the traits, only the ones we need.
659
+ // The traits are intended to be used *after* all defs have been processed.
635
660
636
- // Figure out who to look at.
637
- // Note: this skips over refs in the original typespace that
638
- // didn't point to definitions; those have already been removed.
639
- scratch. clear ( ) ;
640
- let to_examine = scratch;
641
- self . result . defs [ & def] . extract_refs ( to_examine) ;
642
-
643
- // Update the parent chain with the current def, for passing to children.
644
- let chain = ParentChain { parent, ref_ : def } ;
645
-
646
- // First, check for finished cycles.
647
- for element in to_examine. iter ( ) {
648
- if self . currently_touching . contains ( element) {
649
- // We have a cycle.
650
- for parent_ref in chain. iter ( ) {
651
- // For each def participating in the cycle, mark it as recursive.
652
- self . result
653
- . defs
654
- . get_mut ( & parent_ref)
655
- . expect ( "all defs should have been processed by now" )
656
- . mark_recursive ( ) ;
657
- // It's tempting to also remove `parent_ref` from `to_process` here,
658
- // but that's wrong, because it might participate in other cycles.
659
-
660
- // We want to mark the start of the cycle as recursive too.
661
- // If we've just done that, break.
662
- if parent_ref == * element {
663
- break ;
664
- }
665
- }
666
- }
667
- }
661
+ impl GraphBase for TypespaceForGenerateBuilder < ' _ > {
662
+ /// Specifically, definition IDs.
663
+ type NodeId = AlgebraicTypeRef ;
668
664
669
- // Now that we've marked everything possible, we need to recurse.
670
- // Need a buffer to iterate from because we reuse `to_examine` in children.
671
- // This will usually not allocate. Most defs have less than 16 refs.
672
- let to_recurse = to_examine
673
- . iter ( )
674
- . cloned ( )
675
- . filter ( |element| to_process. contains ( element) && !self . currently_touching . contains ( element) )
676
- . collect :: < SmallVec < [ AlgebraicTypeRef ; 16 ] > > ( ) ;
677
-
678
- // Recurse.
679
- let scratch = to_examine;
680
- for element in to_recurse {
681
- self . mark_allowed_cycles_rec ( Some ( & chain) , element, to_process, scratch) ;
682
- }
665
+ /// Definition `.0` uses definition `.1`.
666
+ type EdgeId = ( AlgebraicTypeRef , AlgebraicTypeRef ) ;
667
+ }
668
+ impl NodeIndexable for TypespaceForGenerateBuilder < ' _ > {
669
+ fn node_bound ( & self ) -> usize {
670
+ self . typespace . types . len ( )
671
+ }
683
672
684
- // We're done with this def.
685
- // Clean up our state.
686
- let was_present = self . currently_touching . remove ( & def) ;
687
- assert ! (
688
- was_present,
689
- "mark_allowed_cycles_rec is finishing, we should be touching that ref."
690
- ) ;
691
- // Only remove a def from `to_process` once we've explored all the paths leaving it.
692
- to_process. remove ( & def) ;
673
+ fn to_index ( & self , a : Self :: NodeId ) -> usize {
674
+ a. idx ( )
693
675
}
694
- }
695
676
696
- /// A chain of parent type definitions.
697
- /// If type T uses type U, then T is a parent of U.
698
- struct ParentChain < ' a > {
699
- parent : Option < & ' a ParentChain < ' a > > ,
700
- ref_ : AlgebraicTypeRef ,
701
- }
702
- impl < ' a > ParentChain < ' a > {
703
- fn iter ( & ' a self ) -> ParentChainIter < ' a > {
704
- ParentChainIter { current : Some ( self ) }
677
+ fn from_index ( & self , i : usize ) -> Self :: NodeId {
678
+ AlgebraicTypeRef ( i as _ )
705
679
}
706
680
}
681
+ impl < ' a > IntoNodeIdentifiers for & ' a TypespaceForGenerateBuilder < ' a > {
682
+ type NodeIdentifiers = std:: iter:: Cloned < hashbrown:: hash_set:: Iter < ' a , spacetimedb_sats:: AlgebraicTypeRef > > ;
707
683
708
- /// An iterator over a chain of parent type definitions.
709
- struct ParentChainIter < ' a > {
710
- current : Option < & ' a ParentChain < ' a > > ,
684
+ fn node_identifiers ( self ) -> Self :: NodeIdentifiers {
685
+ self . is_def . iter ( ) . cloned ( )
686
+ }
711
687
}
712
- impl Iterator for ParentChainIter < ' _ > {
713
- type Item = AlgebraicTypeRef ;
688
+ impl < ' a > IntoNeighbors for & ' a TypespaceForGenerateBuilder < ' a > {
689
+ type Neighbors = < SmallVec < [ AlgebraicTypeRef ; 16 ] > as IntoIterator > :: IntoIter ;
714
690
715
- fn next ( & mut self ) -> Option < Self :: Item > {
716
- let current = self . current ?;
717
- self . current = current. parent ;
718
- Some ( current. ref_ )
691
+ fn neighbors ( self , a : Self :: NodeId ) -> Self :: Neighbors {
692
+ self . result
693
+ . defs
694
+ . get ( & a)
695
+ . expect ( "all defs should have been processed by now" )
696
+ . extract_refs ( )
697
+ . into_iter ( )
719
698
}
720
699
}
721
700
@@ -805,7 +784,7 @@ mod tests {
805
784
}
806
785
807
786
#[ test]
808
- fn test_detects_cycles ( ) {
787
+ fn test_detects_cycles_1 ( ) {
809
788
let cyclic_1 = Typespace :: new ( vec ! [ AlgebraicType :: Ref ( AlgebraicTypeRef ( 0 ) ) ] ) ;
810
789
let mut for_generate = TypespaceForGenerate :: builder ( & cyclic_1, [ ] ) ;
811
790
let err1 = for_generate. parse_use ( & AlgebraicType :: Ref ( AlgebraicTypeRef ( 0 ) ) ) ;
@@ -814,7 +793,10 @@ mod tests {
814
793
err1,
815
794
ClientCodegenError :: TypeRefError ( TypeRefError :: RecursiveTypeRef ( AlgebraicTypeRef ( 0 ) ) )
816
795
) ;
796
+ }
817
797
798
+ #[ test]
799
+ fn test_detects_cycles_2 ( ) {
818
800
let cyclic_2 = Typespace :: new ( vec ! [
819
801
AlgebraicType :: Ref ( AlgebraicTypeRef ( 1 ) ) ,
820
802
AlgebraicType :: Ref ( AlgebraicTypeRef ( 0 ) ) ,
@@ -826,7 +808,10 @@ mod tests {
826
808
err2,
827
809
ClientCodegenError :: TypeRefError ( TypeRefError :: RecursiveTypeRef ( AlgebraicTypeRef ( 0 ) ) )
828
810
) ;
811
+ }
829
812
813
+ #[ test]
814
+ fn test_detects_cycles_3 ( ) {
830
815
let cyclic_3 = Typespace :: new ( vec ! [
831
816
AlgebraicType :: Ref ( AlgebraicTypeRef ( 1 ) ) ,
832
817
AlgebraicType :: product( [ ( "field" , AlgebraicType :: Ref ( AlgebraicTypeRef ( 0 ) ) ) ] ) ,
@@ -842,7 +827,10 @@ mod tests {
842
827
let table = result. defs ( ) . get ( & AlgebraicTypeRef ( 1 ) ) . expect ( "should be defined" ) ;
843
828
844
829
assert ! ( table. is_recursive( ) , "recursion not detected? table: {table:?}" ) ;
830
+ }
845
831
832
+ #[ test]
833
+ fn test_detects_cycles_4 ( ) {
846
834
let cyclic_4 = Typespace :: new ( vec ! [
847
835
AlgebraicType :: product( [ ( "field" , AlgebraicTypeRef ( 1 ) . into( ) ) ] ) ,
848
836
AlgebraicType :: product( [ ( "field" , AlgebraicTypeRef ( 2 ) . into( ) ) ] ) ,
@@ -878,7 +866,10 @@ mod tests {
878
866
!result[ AlgebraicTypeRef ( 4 ) ] . is_recursive( ) ,
879
867
"recursion detected incorrectly"
880
868
) ;
869
+ }
881
870
871
+ #[ test]
872
+ fn test_detects_cycles_5 ( ) {
882
873
// Branching cycles.
883
874
let cyclic_5 = Typespace :: new ( vec ! [
884
875
// cyclic component.
0 commit comments