24
24
//! * Immutable input and no support for recomputation given mutated inputs.
25
25
//! * Correspondingly, no requirement that the *return* types implement `Eq` or
26
26
//! `Hash`.
27
+ //! * Supports `#[break_cycles_with = <default-value>]`, which generates a
28
+ //! function that returns `<default-value>` if a cycle is detected.
27
29
//!
28
30
//! There are more substantial differences with Salsa 2022 - this was written
29
31
//! based on Salsa 0.16. We don't need to match exactly the API, but the
75
77
/// //
76
78
/// // When called on the `&dyn QueryGroupName` or directly on the concrete type, the functions
77
79
/// // will be memoized, and the return value will be cached, automatically.
80
+ /// //
81
+ /// // Some functions may need to gracefully handle cycles, in which case they should be
82
+ /// // annotated with `#[break_cycles_with = <default_value>]`. This will generate a function
83
+ /// // that returns <default_value> if a cycle is detected, but will _not_ cache the result.
84
+ /// // All `#[break_cycles_with = ..]` functions must appear before all
85
+ /// // non-`#[break_cycles_with = ..]` functions.
86
+ /// #[break_cycles_with = ReturnType::default()]
87
+ /// fn may_be_cyclic(&self, arg: ArgType) -> ReturnType;
88
+ ///
78
89
/// fn some_function(&self, arg: ArgType) -> ReturnType;
79
90
/// }
80
91
/// // The concrete type for the storage of inputs and memoized values.
83
94
/// }
84
95
///
85
96
/// // The non-memoized implementation of the memoized functions
97
+ /// fn may_be_cyclic(db: &dyn QueryGroupName, arg: ArgType) -> ReturnType {
98
+ /// // ...
99
+ /// }
100
+ ///
86
101
/// fn some_function(db: &dyn QueryGroupName, arg: ArgType) -> ReturnType {
87
102
/// // ...
88
103
/// }
@@ -152,6 +167,19 @@ macro_rules! query_group {
152
167
#[ input]
153
168
fn $input_function: ident( & self $( , ) ?) -> $input_type: ty;
154
169
) *
170
+ $(
171
+ // TODO(jeanpierreda): Ideally would like to preserve doc comments here, but it introduces a
172
+ // parsing ambiguity with how the macro is currently structured.
173
+ // $(#[doc = $break_cycles_doc:literal])*
174
+ #[ break_cycles_with = $break_cycles_default_value: expr]
175
+ fn $break_cycles_function: ident(
176
+ & self
177
+ $(
178
+ , $break_cycles_arg: ident : $break_cycles_arg_type: ty
179
+ ) *
180
+ $( , ) ?
181
+ ) -> $break_cycles_return_type: ty;
182
+ ) *
155
183
$(
156
184
// TODO(jeanpierreda): Ideally would like to preserve doc comments here, but it introduces a
157
185
// parsing ambiguity with how the macro is currently structured.
@@ -174,6 +202,14 @@ macro_rules! query_group {
174
202
$( #[ doc = $input_doc] ) *
175
203
fn $input_function( & self ) -> $input_type
176
204
; ) *
205
+ $(
206
+ fn $break_cycles_function(
207
+ & self ,
208
+ $(
209
+ $break_cycles_arg : $break_cycles_arg_type
210
+ ) ,*
211
+ ) -> $break_cycles_return_type
212
+ ; ) *
177
213
$(
178
214
fn $function(
179
215
& self ,
@@ -186,9 +222,15 @@ macro_rules! query_group {
186
222
187
223
// Now we can generate a database struct that contains the lookup tables.
188
224
$struct_vis struct $database_struct $( <$( $type_param) ,* >) ? {
225
+ __unwinding_cycles: :: core:: cell:: Cell <u32 >,
189
226
$(
190
227
$input_function: $input_type,
191
228
) *
229
+ $(
230
+ // Note that we store $break_cycles_return_type here, not Option<$break_cycles_return_type>.
231
+ // This is because we don't cache failed calls.
232
+ $break_cycles_function: $crate:: internal:: MemoizationTable <( $( $break_cycles_arg_type, ) * ) , $break_cycles_return_type>,
233
+ ) *
192
234
$(
193
235
$function: $crate:: internal:: MemoizationTable <( $( $arg_type, ) * ) , $return_type>,
194
236
) *
@@ -204,6 +246,25 @@ macro_rules! query_group {
204
246
( & self . $input_function) . clone( )
205
247
}
206
248
) *
249
+ $(
250
+ fn $break_cycles_function(
251
+ & self ,
252
+ $(
253
+ $break_cycles_arg : $break_cycles_arg_type
254
+ ) ,*
255
+ ) -> $break_cycles_return_type {
256
+ self . $break_cycles_function. break_cycles_internal_memoized_call(
257
+ ( $(
258
+ $break_cycles_arg,
259
+ ) * ) ,
260
+ |( $( $break_cycles_arg, ) * ) | {
261
+ // Force the use of &dyn $trait, so that we don't rule out separate compilation later.
262
+ $break_cycles_function( self as & dyn $trait, $( $break_cycles_arg) ,* )
263
+ } ,
264
+ & self . __unwinding_cycles,
265
+ ) . unwrap_or( $break_cycles_default_value)
266
+ }
267
+ ) *
207
268
$(
208
269
fn $function(
209
270
& self ,
@@ -215,11 +276,11 @@ macro_rules! query_group {
215
276
( $(
216
277
$arg,
217
278
) * ) ,
218
- |args| {
219
- let ( $( $arg, ) * ) = args;
279
+ |( $( $arg, ) * ) | {
220
280
// Force the use of &dyn $trait, so that we don't rule out separate compilation later.
221
281
$function( self as & dyn $trait, $( $arg) ,* )
222
- }
282
+ } ,
283
+ & self . __unwinding_cycles,
223
284
)
224
285
}
225
286
) *
@@ -228,9 +289,13 @@ macro_rules! query_group {
228
289
impl $( <$( $type_param) ,* >) ? $database_struct $( <$( $type_param) ,* >) ? {
229
290
$struct_vis fn new( $( $input_function: $input_type) ,* ) -> Self {
230
291
Self {
292
+ __unwinding_cycles: :: core:: cell:: Cell :: new( 0 ) ,
231
293
$(
232
294
$input_function,
233
295
) *
296
+ $(
297
+ $break_cycles_function: Default :: default ( ) ,
298
+ ) *
234
299
$(
235
300
$function: Default :: default ( ) ,
236
301
) *
@@ -242,16 +307,23 @@ macro_rules! query_group {
242
307
243
308
#[ doc( hidden) ]
244
309
pub mod internal {
245
- use std:: cell:: RefCell ;
246
- use std:: collections:: { HashMap , HashSet } ;
310
+ use std:: cell:: { Cell , RefCell } ;
311
+ use std:: collections:: HashMap ;
247
312
use std:: hash:: Hash ;
313
+
314
+ #[ derive( Copy , Clone , PartialEq , Eq ) ]
315
+ enum FoundCycle {
316
+ No ,
317
+ Yes ,
318
+ }
319
+
248
320
pub struct MemoizationTable < Args , Return >
249
321
where
250
322
Args : Clone + Eq + Hash ,
251
323
Return : Clone ,
252
324
{
253
325
memoized : RefCell < HashMap < Args , Return > > ,
254
- active : RefCell < HashSet < Args > > ,
326
+ active : RefCell < HashMap < Args , FoundCycle > > ,
255
327
}
256
328
257
329
// Separate `impl` instead of `#[derive(Default)]` because the `derive` would
@@ -262,7 +334,26 @@ pub mod internal {
262
334
Return : Clone ,
263
335
{
264
336
fn default ( ) -> Self {
265
- Self { memoized : RefCell :: new ( HashMap :: new ( ) ) , active : RefCell :: new ( HashSet :: new ( ) ) }
337
+ Self { memoized : RefCell :: new ( HashMap :: new ( ) ) , active : RefCell :: new ( HashMap :: new ( ) ) }
338
+ }
339
+ }
340
+
341
+ impl < Args , Return > MemoizationTable < Args , Return >
342
+ where
343
+ Args : Clone + Eq + Hash ,
344
+ Return : Clone ,
345
+ {
346
+ pub fn internal_memoized_call < F > (
347
+ & self ,
348
+ args : Args ,
349
+ f : F ,
350
+ unwinding_cycles : & Cell < u32 > ,
351
+ ) -> Return
352
+ where
353
+ F : FnOnce ( Args ) -> Return ,
354
+ {
355
+ self . break_cycles_internal_memoized_call ( args, f, unwinding_cycles)
356
+ . expect ( "Cycle detected: a memoized function depends on its own return value" )
266
357
}
267
358
}
268
359
@@ -271,31 +362,55 @@ pub mod internal {
271
362
Args : Clone + Eq + Hash ,
272
363
Return : Clone ,
273
364
{
274
- pub fn internal_memoized_call < F > ( & self , args : Args , f : F ) -> Return
365
+ pub fn break_cycles_internal_memoized_call < F > (
366
+ & self ,
367
+ args : Args ,
368
+ f : F ,
369
+ unwinding_cycles : & Cell < u32 > ,
370
+ ) -> Option < Return >
275
371
where
276
372
F : FnOnce ( Args ) -> Return ,
277
373
{
278
374
if let Some ( return_value) = self . memoized . borrow ( ) . get ( & args) {
279
- return return_value. clone ( ) ;
375
+ return Some ( return_value. clone ( ) ) ;
280
376
}
281
- if self . active . borrow ( ) . contains ( & args) {
282
- panic ! ( "Cycle detected: a memoized function depends on its own return value" ) ;
377
+ if let Some ( found_cycle) = self . active . borrow_mut ( ) . get_mut ( & args) {
378
+ // We're in a cycle.
379
+ if * found_cycle == FoundCycle :: No {
380
+ // Only increase the count if we haven't hit this cycle before.
381
+ unwinding_cycles. set ( unwinding_cycles. get ( ) + 1 ) ;
382
+ }
383
+ * found_cycle = FoundCycle :: Yes ;
384
+ return None ;
283
385
}
284
- let args_cloned = args. clone ( ) ;
285
- self . active . borrow_mut ( ) . insert ( args_cloned) ;
386
+ self . active . borrow_mut ( ) . insert ( args. clone ( ) , FoundCycle :: No ) ;
286
387
let return_value = f ( args. clone ( ) ) ;
287
- self . active . borrow_mut ( ) . remove ( & args) ;
288
- let return_value_cloned = return_value. clone ( ) ;
289
- self . memoized . borrow_mut ( ) . insert ( args, return_value_cloned) ;
290
- return_value
388
+ let found_cycle = self
389
+ . active
390
+ . borrow_mut ( )
391
+ . remove ( & args)
392
+ . expect ( "This call frame inserted args and nobody removed them" ) ;
393
+
394
+ if found_cycle == FoundCycle :: Yes {
395
+ // We did hit ourselves in a cycle but now we've broken out of it.
396
+ // If we hit ourselves multiple times, we were careful to only increment this
397
+ // count once.
398
+ unwinding_cycles. set ( unwinding_cycles. get ( ) - 1 ) ;
399
+ }
400
+ if unwinding_cycles. get ( ) == 0 {
401
+ // No cycles, we can safely cache the result knowing that we haven't depended on
402
+ // any cycle default values.
403
+ self . memoized . borrow_mut ( ) . insert ( args, return_value. clone ( ) ) ;
404
+ }
405
+ Some ( return_value)
291
406
}
292
407
}
293
408
}
294
409
295
410
#[ cfg( test) ]
296
411
pub mod tests {
297
412
use googletest:: prelude:: * ;
298
- use std:: cell:: Cell ;
413
+ use std:: cell:: { Cell , RefCell } ;
299
414
use std:: rc:: Rc ;
300
415
301
416
#[ gtest]
@@ -389,6 +504,111 @@ pub mod tests {
389
504
db. add10 ( 1 ) ;
390
505
}
391
506
507
+ #[ gtest]
508
+ fn test_break_cycles_with_option ( ) {
509
+ crate :: query_group! {
510
+ pub trait Add10 {
511
+ #[ break_cycles_with = None ]
512
+ fn add10( & self , arg: i32 ) -> Option <i32 >;
513
+ }
514
+ pub struct Database ;
515
+ }
516
+ fn add10 ( db : & dyn Add10 , arg : i32 ) -> Option < i32 > {
517
+ db. add10 ( arg)
518
+ }
519
+ let db = Database :: new ( ) ;
520
+ assert_eq ! ( db. add10( 1 ) , None ) ;
521
+ }
522
+
523
+ #[ gtest]
524
+ fn test_break_cycles_with_sentinel ( ) {
525
+ crate :: query_group! {
526
+ pub trait Add10 {
527
+ #[ break_cycles_with = -1 ]
528
+ fn add10( & self , arg: i32 ) -> i32 ;
529
+ }
530
+ pub struct Database ;
531
+ }
532
+ fn add10 ( db : & dyn Add10 , arg : i32 ) -> i32 {
533
+ db. add10 ( arg)
534
+ }
535
+ let db = Database :: new ( ) ;
536
+ assert_eq ! ( db. add10( 1 ) , -1 ) ;
537
+ }
538
+
539
+ #[ gtest]
540
+ fn test_calls_in_cycle_are_not_memoized ( ) {
541
+ crate :: query_group! {
542
+ pub trait Table {
543
+ #[ input]
544
+ fn logging( & self ) -> Rc <RefCell <Vec <String >>>;
545
+
546
+ #[ input]
547
+ fn records( & self ) -> & ' static [ Record ] ;
548
+
549
+ #[ break_cycles_with = false ]
550
+ fn is_unsafe( & self , name: & ' static str ) -> bool ;
551
+
552
+ fn record( & self , name: & ' static str ) -> Record ;
553
+ }
554
+ pub struct Database ;
555
+ }
556
+
557
+ #[ derive( Clone ) ]
558
+ struct Record {
559
+ name : & ' static str ,
560
+ is_unsafe : bool ,
561
+ fields : & ' static [ & ' static str ] ,
562
+ }
563
+
564
+ // Returns whether or not a record is unsafe, checking recursively.
565
+ fn is_unsafe ( db : & dyn Table , name : & ' static str ) -> bool {
566
+ let record = db. record ( name) ;
567
+ let outcome =
568
+ record. is_unsafe || record. fields . iter ( ) . any ( |& field| db. is_unsafe ( field) ) ;
569
+ db. logging ( ) . borrow_mut ( ) . push ( format ! ( "is_unsafe({name}) = {outcome}" ) ) ;
570
+ outcome
571
+ }
572
+
573
+ // Helper function so we can refer to records by name instead of by index.
574
+ fn record ( db : & dyn Table , name : & ' static str ) -> Record {
575
+ db. records ( )
576
+ . iter ( )
577
+ . find ( |record| record. name == name)
578
+ . expect ( "Record not found" )
579
+ . clone ( )
580
+ }
581
+
582
+ let logging = Rc :: default ( ) ;
583
+
584
+ let db = Database :: new (
585
+ Rc :: clone ( & logging) ,
586
+ & [
587
+ Record { name : "A" , is_unsafe : false , fields : & [ "B" , "Unsafe" ] } ,
588
+ Record { name : "B" , is_unsafe : false , fields : & [ "A" ] } ,
589
+ Record { name : "Unsafe" , is_unsafe : true , fields : & [ ] } ,
590
+ ] ,
591
+ ) ;
592
+ // When checking if A is unsafe, it will first ask B, which will try to ask A
593
+ // again, defaulting to false. So B says "I guess I'm safe", but _doesn't_
594
+ // memoize that result. A will then see that it has Unsafe which is unsafe, so A
595
+ // will memoize itself as unsafe. But when we go to ask B if it's unsafe now, it
596
+ // will have correctly _not_ memoized that it's safe, and so it will ask
597
+ // A again, which will again say "I am unsafe", and so B will correctly memoize
598
+ // that it's unsafe.
599
+ assert ! ( db. is_unsafe( "A" ) ) ;
600
+ assert ! ( db. is_unsafe( "B" ) ) ;
601
+ assert_eq ! (
602
+ logging. borrow( ) . clone( ) ,
603
+ vec![
604
+ "is_unsafe(B) = false" . to_string( ) , // this is the cycle-default value
605
+ "is_unsafe(Unsafe) = true" . to_string( ) ,
606
+ "is_unsafe(A) = true" . to_string( ) ,
607
+ "is_unsafe(B) = true" . to_string( ) , // as we can see, the default wasn't memoized
608
+ ]
609
+ ) ;
610
+ }
611
+
392
612
#[ gtest]
393
613
fn test_finite_recursion ( ) {
394
614
crate :: query_group! {
0 commit comments