Skip to content

Commit 12cda62

Browse files
Googlercopybara-github
authored andcommitted
Add #[break_cycles_with = ..] attribute to memoized.rs
This will allow for gracefully handling DFS-like memoized calls where cycles should safely be ignored. This is motivated by RsTypeKind::Record creation, where I want to traverse fields to see if they contain unsafe types like ptrs. PiperOrigin-RevId: 715527804 Change-Id: Ida465ba81b80fad0ba884cafadd3283efd2a24a1
1 parent 21d6cba commit 12cda62

File tree

1 file changed

+238
-18
lines changed

1 file changed

+238
-18
lines changed

common/memoized.rs

Lines changed: 238 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
//! * Immutable input and no support for recomputation given mutated inputs.
2525
//! * Correspondingly, no requirement that the *return* types implement `Eq` or
2626
//! `Hash`.
27+
//! * Supports `#[break_cycles_with = <default-value>]`, which generates a
28+
//! function that returns <default-value> if a cycle is detected.
2729
//!
2830
//! There are more substantial differences with Salsa 2022 - this was written
2931
//! based on Salsa 0.16. We don't need to match exactly the API, but the
@@ -75,6 +77,15 @@
7577
/// //
7678
/// // When called on the `&dyn QueryGroupName` or directly on the concrete type, the functions
7779
/// // will be memoized, and the return value will be cached, automatically.
80+
/// //
81+
/// // Some functions may need to gracefully handle cycles, in which case they should be
82+
/// // annotated with `#[break_cycles_with = <default_value>]`. This will generate a function
83+
/// // that returns <default_value> if a cycle is detected, but will _not_ cache the result.
84+
/// // All `#[break_cycles_with = ..]` functions must appear before all
85+
/// // non-`#[break_cycles_with = ..]` functions.
86+
/// #[break_cycles_with = ReturnType::default()]
87+
/// fn may_be_cyclic(&self, arg: ArgType) -> ReturnType;
88+
///
7889
/// fn some_function(&self, arg: ArgType) -> ReturnType;
7990
/// }
8091
/// // The concrete type for the storage of inputs and memoized values.
@@ -83,6 +94,10 @@
8394
/// }
8495
///
8596
/// // The non-memoized implementation of the memoized functions
97+
/// fn may_by_cyclic(db: &dyn QueryGroupName, arg: ArgType) -> ReturnType {
98+
/// // ...
99+
/// }
100+
///
86101
/// fn some_function(db: &dyn QueryGroupName, arg: ArgType) -> ReturnType {
87102
/// // ...
88103
/// }
@@ -152,6 +167,19 @@ macro_rules! query_group {
152167
#[input]
153168
fn $input_function:ident(&self $(,)?) -> $input_type:ty;
154169
)*
170+
$(
171+
// TODO(jeanpierreda): Ideally would like to preserve doc comments here, but it introduces a
172+
// parsing ambiguity with how the macro is currently structured.
173+
// $(#[doc = $break_cycles_doc:literal])*
174+
#[break_cycles_with = $break_cycles_default_value:expr]
175+
fn $break_cycles_function:ident(
176+
&self
177+
$(
178+
, $break_cycles_arg:ident : $break_cycles_arg_type:ty
179+
)*
180+
$(,)?
181+
) -> $break_cycles_return_type:ty;
182+
)*
155183
$(
156184
// TODO(jeanpierreda): Ideally would like to preserve doc comments here, but it introduces a
157185
// parsing ambiguity with how the macro is currently structured.
@@ -174,6 +202,14 @@ macro_rules! query_group {
174202
$(#[doc = $input_doc])*
175203
fn $input_function(&self) -> $input_type
176204
;)*
205+
$(
206+
fn $break_cycles_function(
207+
&self,
208+
$(
209+
$break_cycles_arg : $break_cycles_arg_type
210+
),*
211+
) -> $break_cycles_return_type
212+
;)*
177213
$(
178214
fn $function(
179215
&self,
@@ -186,9 +222,15 @@ macro_rules! query_group {
186222

187223
// Now we can generate a database struct that contains the lookup tables.
188224
$struct_vis struct $database_struct $(<$($type_param),*>)? {
225+
__unwinding_cycles: ::core::cell::Cell<u32>,
189226
$(
190227
$input_function: $input_type,
191228
)*
229+
$(
230+
// Note that we store $break_cycles_return_type here, not Option<$break_cycles_return_type>.
231+
// This is because we don't cache failed calls.
232+
$break_cycles_function: $crate::internal::MemoizationTable<($($break_cycles_arg_type,)*), $break_cycles_return_type>,
233+
)*
192234
$(
193235
$function: $crate::internal::MemoizationTable<($($arg_type,)*), $return_type>,
194236
)*
@@ -204,6 +246,25 @@ macro_rules! query_group {
204246
(&self.$input_function).clone()
205247
}
206248
)*
249+
$(
250+
fn $break_cycles_function(
251+
&self,
252+
$(
253+
$break_cycles_arg : $break_cycles_arg_type
254+
),*
255+
) -> $break_cycles_return_type {
256+
self.$break_cycles_function.break_cycles_internal_memoized_call(
257+
($(
258+
$break_cycles_arg,
259+
)*),
260+
|($($break_cycles_arg,)*)| {
261+
// Force the use of &dyn $trait, so that we don't rule out separate compilation later.
262+
$break_cycles_function(self as &dyn $trait, $($break_cycles_arg),*)
263+
},
264+
&self.__unwinding_cycles,
265+
).unwrap_or($break_cycles_default_value)
266+
}
267+
)*
207268
$(
208269
fn $function(
209270
&self,
@@ -215,11 +276,11 @@ macro_rules! query_group {
215276
($(
216277
$arg,
217278
)*),
218-
|args| {
219-
let ($($arg,)*) = args;
279+
|($($arg,)*)| {
220280
// Force the use of &dyn $trait, so that we don't rule out separate compilation later.
221281
$function(self as &dyn $trait, $($arg),*)
222-
}
282+
},
283+
&self.__unwinding_cycles,
223284
)
224285
}
225286
)*
@@ -228,9 +289,13 @@ macro_rules! query_group {
228289
impl $(<$($type_param),*>)? $database_struct $(<$($type_param),*>)? {
229290
$struct_vis fn new($($input_function: $input_type),*) -> Self {
230291
Self {
292+
__unwinding_cycles: ::core::cell::Cell::new(0),
231293
$(
232294
$input_function,
233295
)*
296+
$(
297+
$break_cycles_function: Default::default(),
298+
)*
234299
$(
235300
$function: Default::default(),
236301
)*
@@ -242,16 +307,23 @@ macro_rules! query_group {
242307

243308
#[doc(hidden)]
244309
pub mod internal {
245-
use std::cell::RefCell;
246-
use std::collections::{HashMap, HashSet};
310+
use std::cell::{Cell, RefCell};
311+
use std::collections::HashMap;
247312
use std::hash::Hash;
313+
314+
#[derive(Copy, Clone, PartialEq, Eq)]
315+
enum FoundCycle {
316+
No,
317+
Yes,
318+
}
319+
248320
pub struct MemoizationTable<Args, Return>
249321
where
250322
Args: Clone + Eq + Hash,
251323
Return: Clone,
252324
{
253325
memoized: RefCell<HashMap<Args, Return>>,
254-
active: RefCell<HashSet<Args>>,
326+
active: RefCell<HashMap<Args, FoundCycle>>,
255327
}
256328

257329
// Separate `impl` instead of `#[derive(Default)]` because the `derive` would
@@ -262,7 +334,26 @@ pub mod internal {
262334
Return: Clone,
263335
{
264336
fn default() -> Self {
265-
Self { memoized: RefCell::new(HashMap::new()), active: RefCell::new(HashSet::new()) }
337+
Self { memoized: RefCell::new(HashMap::new()), active: RefCell::new(HashMap::new()) }
338+
}
339+
}
340+
341+
impl<Args, Return> MemoizationTable<Args, Return>
342+
where
343+
Args: Clone + Eq + Hash,
344+
Return: Clone,
345+
{
346+
pub fn internal_memoized_call<F>(
347+
&self,
348+
args: Args,
349+
f: F,
350+
unwinding_cycles: &Cell<u32>,
351+
) -> Return
352+
where
353+
F: FnOnce(Args) -> Return,
354+
{
355+
self.break_cycles_internal_memoized_call(args, f, unwinding_cycles)
356+
.expect("Cycle detected: a memoized function depends on its own return value")
266357
}
267358
}
268359

@@ -271,31 +362,55 @@ pub mod internal {
271362
Args: Clone + Eq + Hash,
272363
Return: Clone,
273364
{
274-
pub fn internal_memoized_call<F>(&self, args: Args, f: F) -> Return
365+
pub fn break_cycles_internal_memoized_call<F>(
366+
&self,
367+
args: Args,
368+
f: F,
369+
unwinding_cycles: &Cell<u32>,
370+
) -> Option<Return>
275371
where
276372
F: FnOnce(Args) -> Return,
277373
{
278374
if let Some(return_value) = self.memoized.borrow().get(&args) {
279-
return return_value.clone();
375+
return Some(return_value.clone());
280376
}
281-
if self.active.borrow().contains(&args) {
282-
panic!("Cycle detected: a memoized function depends on its own return value");
377+
if let Some(found_cycle) = self.active.borrow_mut().get_mut(&args) {
378+
// We're in a cycle.
379+
if *found_cycle == FoundCycle::No {
380+
// Only increase the count if we haven't hit this cycle before.
381+
unwinding_cycles.set(unwinding_cycles.get() + 1);
382+
}
383+
*found_cycle = FoundCycle::Yes;
384+
return None;
283385
}
284-
let args_cloned = args.clone();
285-
self.active.borrow_mut().insert(args_cloned);
386+
self.active.borrow_mut().insert(args.clone(), FoundCycle::No);
286387
let return_value = f(args.clone());
287-
self.active.borrow_mut().remove(&args);
288-
let return_value_cloned = return_value.clone();
289-
self.memoized.borrow_mut().insert(args, return_value_cloned);
290-
return_value
388+
let found_cycle = self
389+
.active
390+
.borrow_mut()
391+
.remove(&args)
392+
.expect("This call frame inserted args and nobody removed them");
393+
394+
if found_cycle == FoundCycle::Yes {
395+
// We did hit outselves in a cycle but now we've broken out of it.
396+
// If we hit ourselves multiple times, we were careful to only increment this
397+
// count once.
398+
unwinding_cycles.set(unwinding_cycles.get() - 1);
399+
}
400+
if unwinding_cycles.get() == 0 {
401+
// No cycles, we can safely cache the result knowing that we haven't depended on
402+
// any cycle default values.
403+
self.memoized.borrow_mut().insert(args, return_value.clone());
404+
}
405+
Some(return_value)
291406
}
292407
}
293408
}
294409

295410
#[cfg(test)]
296411
pub mod tests {
297412
use googletest::prelude::*;
298-
use std::cell::Cell;
413+
use std::cell::{Cell, RefCell};
299414
use std::rc::Rc;
300415

301416
#[gtest]
@@ -389,6 +504,111 @@ pub mod tests {
389504
db.add10(1);
390505
}
391506

507+
#[gtest]
508+
fn test_break_cycles_with_option() {
509+
crate::query_group! {
510+
pub trait Add10 {
511+
#[break_cycles_with = None]
512+
fn add10(&self, arg: i32) -> Option<i32>;
513+
}
514+
pub struct Database;
515+
}
516+
fn add10(db: &dyn Add10, arg: i32) -> Option<i32> {
517+
db.add10(arg)
518+
}
519+
let db = Database::new();
520+
assert_eq!(db.add10(1), None);
521+
}
522+
523+
#[gtest]
524+
fn test_break_cycles_with_sentinel() {
525+
crate::query_group! {
526+
pub trait Add10 {
527+
#[break_cycles_with = -1]
528+
fn add10(&self, arg: i32) -> i32;
529+
}
530+
pub struct Database;
531+
}
532+
fn add10(db: &dyn Add10, arg: i32) -> i32 {
533+
db.add10(arg)
534+
}
535+
let db = Database::new();
536+
assert_eq!(db.add10(1), -1);
537+
}
538+
539+
#[gtest]
540+
fn test_calls_in_cycle_are_not_memoized() {
541+
crate::query_group! {
542+
pub trait Table {
543+
#[input]
544+
fn logging(&self) -> Rc<RefCell<Vec<String>>>;
545+
546+
#[input]
547+
fn records(&self) -> &'static [Record];
548+
549+
#[break_cycles_with = false]
550+
fn is_unsafe(&self, name: &'static str) -> bool;
551+
552+
fn record(&self, name: &'static str) -> Record;
553+
}
554+
pub struct Database;
555+
}
556+
557+
#[derive(Clone)]
558+
struct Record {
559+
name: &'static str,
560+
is_unsafe: bool,
561+
fields: &'static [&'static str],
562+
}
563+
564+
// Returns whether or not a record is unsafe, checking recursively.
565+
fn is_unsafe(db: &dyn Table, name: &'static str) -> bool {
566+
let record = db.record(name);
567+
let outcome =
568+
record.is_unsafe || record.fields.iter().any(|&field| db.is_unsafe(field));
569+
db.logging().borrow_mut().push(format!("is_unsafe({name}) = {outcome}"));
570+
outcome
571+
}
572+
573+
// Helper function so we can refer to records by name instead of by index.
574+
fn record(db: &dyn Table, name: &'static str) -> Record {
575+
db.records()
576+
.iter()
577+
.find(|record| record.name == name)
578+
.expect("Record not found")
579+
.clone()
580+
}
581+
582+
let logging = Rc::default();
583+
584+
let db = Database::new(
585+
Rc::clone(&logging),
586+
&[
587+
Record { name: "A", is_unsafe: false, fields: &["B", "Unsafe"] },
588+
Record { name: "B", is_unsafe: false, fields: &["A"] },
589+
Record { name: "Unsafe", is_unsafe: true, fields: &[] },
590+
],
591+
);
592+
// When checking if A is unsafe, it will first ask B, which will try to ask A
593+
// again, defaulting to false. So B says "I guess I'm safe", but _doesn't_
594+
// memoize that result. A will then see that it has Unsafe which is unsafe, so A
595+
// will memoize itself as unsafe. But when we go to ask B if it's unsafe now, it
596+
// will have correctly _not_ memoized that it's safe, and so it will ask
597+
// A again, which will again say "I am unsafe", and so B will correctly memoize
598+
// that it's unsafe.
599+
assert!(db.is_unsafe("A"));
600+
assert!(db.is_unsafe("B"));
601+
assert_eq!(
602+
logging.borrow().clone(),
603+
vec![
604+
"is_unsafe(B) = false".to_string(), // this is the cycle-default value
605+
"is_unsafe(Unsafe) = true".to_string(),
606+
"is_unsafe(A) = true".to_string(),
607+
"is_unsafe(B) = true".to_string(), // as we can see, the default wasn't memoized
608+
]
609+
);
610+
}
611+
392612
#[gtest]
393613
fn test_finite_recursion() {
394614
crate::query_group! {

0 commit comments

Comments
 (0)