Skip to content

Commit 470a69f

Browse files
committed
Propagate constraints
1 parent cede211 commit 470a69f

File tree

8 files changed

+103
-106
lines changed

8 files changed

+103
-106
lines changed

macros/src/track.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ fn create(
226226
#(#wrapper_methods)*
227227
}
228228

229-
#[derive(Default)]
229+
#[derive(Debug, Default)]
230230
pub struct #constraint {
231231
#(#constraint_fields)*
232232
}
@@ -260,6 +260,7 @@ fn create_wrapper(method: &Method) -> TokenStream {
260260
let name = &method.sig.ident;
261261
let args = &method.args;
262262
quote! {
263+
#[track_caller]
263264
#vis #sig {
264265
let input = (#(#args.to_owned(),)*);
265266
let (value, constraint) = ::comemo::internal::to_parts(self.0);

src/cache.rs

Lines changed: 27 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -14,52 +14,6 @@ thread_local! {
1414
static CACHE: RefCell<Cache> = RefCell::new(Cache::default());
1515
}
1616

17-
/// Configure the caching and eviction behaviour.
18-
pub fn config(config: Config) {
19-
CACHE.with(|cache| cache.borrow_mut().config = config);
20-
}
21-
22-
/// Configuration for caching and eviction behaviour.
23-
pub struct Config {
24-
max_age: u32,
25-
}
26-
27-
impl Config {
28-
/// The maximum number of evictions an entry can survive without having been
29-
/// used in between.
30-
///
31-
/// Default: 5
32-
pub fn max_age(mut self, age: u32) -> Self {
33-
self.max_age = age;
34-
self
35-
}
36-
}
37-
38-
impl Default for Config {
39-
fn default() -> Self {
40-
Self { max_age: 5 }
41-
}
42-
}
43-
44-
/// Evict cache entries that haven't been used in a while.
45-
///
46-
/// The eviction behaviour can be customized with the [`config`] function.
47-
/// Currently, comemo does not evict the cache automatically (this might
48-
/// change in the future).
49-
pub fn evict() {
50-
CACHE.with(|cache| {
51-
let mut cache = cache.borrow_mut();
52-
let max = cache.config.max_age;
53-
cache.map.retain(|_, entries| {
54-
entries.retain_mut(|entry| {
55-
entry.age += 1;
56-
entry.age <= max
57-
});
58-
!entries.is_empty()
59-
});
60-
});
61-
}
62-
6317
/// Execute a function or use a cached result for it.
6418
pub fn memoized<In, Out, F>(id: TypeId, input: In, func: F) -> Out
6519
where
@@ -76,55 +30,57 @@ where
7630
(id, hash)
7731
};
7832

79-
// Check if there is a cached output.
80-
let mut borrow = cache.borrow_mut();
81-
if let Some(output) = borrow.lookup::<In, Out>(key, &input) {
82-
return output;
83-
}
84-
85-
borrow.depth += 1;
86-
drop(borrow);
87-
8833
// Point all tracked parts of the input to these constraints.
8934
let constraint = In::Constraint::default();
90-
let (tracked, outer) = input.retrack(&constraint);
35+
36+
// Check if there is a cached output.
37+
if let Some(constrained) = cache.borrow().lookup::<In, Out>(key, &input) {
38+
// Add the cached constraints to the outer ones.
39+
let (_, outer) = input.retrack(&constraint);
40+
outer.join(&constrained.constraint);
41+
return constrained.output.clone();
42+
}
9143

9244
// Execute the function with the new constraints hooked in.
93-
let output = func(tracked);
45+
let (input, outer) = input.retrack(&constraint);
46+
let output = func(input);
9447

95-
// Add the new constraints to the previous outer ones.
48+
// Add the new constraints to the outer ones.
9649
outer.join(&constraint);
9750

9851
// Insert the result into the cache.
99-
borrow = cache.borrow_mut();
100-
borrow.insert::<In, Out>(key, constraint, output.clone());
101-
borrow.depth -= 1;
52+
cache.borrow_mut().insert::<In, Out>(key, constraint, output.clone());
10253

10354
output
10455
})
10556
}
10657

58+
/// Completely clear the cache.
59+
pub fn clear() {
60+
CACHE.with(|cache| cache.borrow_mut().map.clear());
61+
}
62+
10763
/// The global cache.
10864
#[derive(Default)]
10965
struct Cache {
11066
/// Maps from function IDs + hashes to memoized results.
11167
map: HashMap<(TypeId, u128), Vec<Entry>>,
112-
/// The current depth of the memoized call stack.
113-
depth: usize,
114-
/// The current configuration.
115-
config: Config,
11668
}
11769

11870
impl Cache {
11971
/// Look for a matching entry in the cache.
120-
fn lookup<In, Out>(&mut self, key: (TypeId, u128), input: &In) -> Option<Out>
72+
fn lookup<In, Out>(
73+
&self,
74+
key: (TypeId, u128),
75+
input: &In,
76+
) -> Option<&Constrained<In::Constraint, Out>>
12177
where
12278
In: Input,
12379
Out: Clone + 'static,
12480
{
12581
self.map
126-
.get_mut(&key)?
127-
.iter_mut()
82+
.get(&key)?
83+
.iter()
12884
.find_map(|entry| entry.lookup::<In, Out>(input))
12985
}
13086

@@ -151,8 +107,6 @@ struct Entry {
151107
///
152108
/// This is of type `Constrained<In::Constraint, Out>`.
153109
constrained: Box<dyn Any>,
154-
/// How many evictions have passed since the entry has been last used.
155-
age: u32,
156110
}
157111

158112
/// A value with a constraint.
@@ -172,22 +126,18 @@ impl Entry {
172126
{
173127
Self {
174128
constrained: Box::new(Constrained { constraint, output }),
175-
age: 0,
176129
}
177130
}
178131

179132
/// Return the entry's output if it is valid for the given input.
180-
fn lookup<In, Out>(&mut self, input: &In) -> Option<Out>
133+
fn lookup<In, Out>(&self, input: &In) -> Option<&Constrained<In::Constraint, Out>>
181134
where
182135
In: Input,
183136
Out: Clone + 'static,
184137
{
185-
let Constrained::<In::Constraint, Out> { constraint, output } =
138+
let constrained: &Constrained<In::Constraint, Out> =
186139
self.constrained.downcast_ref().expect("wrong entry type");
187140

188-
input.valid(constraint).then(|| {
189-
self.age = 0;
190-
output.clone()
191-
})
141+
input.valid(&constrained.constraint).then(|| constrained)
192142
}
193143
}

src/constraint.rs

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::cell::{Cell, RefCell};
2+
use std::fmt::Debug;
23
use std::hash::Hash;
34

45
use siphasher::sip128::{Hasher128, SipHasher};
@@ -10,6 +11,7 @@ pub trait Join<T = Self> {
1011
}
1112

1213
impl<T: Join> Join<T> for Option<&T> {
14+
#[inline]
1315
fn join(&self, inner: &T) {
1416
if let Some(outer) = self {
1517
outer.join(inner);
@@ -18,18 +20,28 @@ impl<T: Join> Join<T> for Option<&T> {
1820
}
1921

2022
/// Defines a constraint for a tracked method without arguments.
21-
#[derive(Default)]
23+
#[derive(Debug, Default)]
2224
pub struct SoloConstraint {
2325
cell: Cell<Option<u128>>,
2426
}
2527

2628
impl SoloConstraint {
2729
/// Set the constraint for the value.
30+
#[inline]
31+
#[track_caller]
2832
pub fn set(&self, _: (), hash: u128) {
29-
self.cell.set(Some(hash));
33+
// If there's already a constraint, it must match.
34+
// This assertion can fail if a tracked function isn't pure
35+
// (which violates comemo's contract).
36+
if let Some(existing) = self.cell.get() {
37+
check(hash, existing);
38+
} else {
39+
self.cell.set(Some(hash));
40+
}
3041
}
3142

3243
/// Whether the value fulfills the constraint.
44+
#[inline]
3345
pub fn valid<F>(&self, f: F) -> bool
3446
where
3547
F: Fn(()) -> u128,
@@ -39,15 +51,16 @@ impl SoloConstraint {
3951
}
4052

4153
impl Join for SoloConstraint {
54+
#[inline]
4255
fn join(&self, inner: &Self) {
43-
let inner = inner.cell.get();
44-
if inner.is_some() {
45-
self.cell.set(inner);
56+
if let Some(hash) = inner.cell.get() {
57+
self.set((), hash);
4658
}
4759
}
4860
}
4961

5062
/// Defines a constraint for a tracked method with arguments.
63+
#[derive(Debug)]
5164
pub struct MultiConstraint<In> {
5265
calls: RefCell<Vec<(In, u128)>>,
5366
}
@@ -57,45 +70,65 @@ where
5770
In: Clone + PartialEq,
5871
{
5972
/// Enter a constraint for a pair of inputs and output.
73+
#[inline]
74+
#[track_caller]
6075
pub fn set(&self, input: In, hash: u128) {
6176
let mut calls = self.calls.borrow_mut();
62-
if calls.iter().all(|item| item.0 != input) {
77+
if let Some(item) = calls.iter().find(|item| item.0 == input) {
78+
check(item.1, hash);
79+
} else {
6380
calls.push((input, hash));
6481
}
6582
}
6683

6784
/// Whether the method satisfies as all input-output pairs.
85+
#[inline]
6886
pub fn valid<F>(&self, f: F) -> bool
6987
where
7088
F: Fn(&In) -> u128,
7189
{
72-
let calls = self.calls.borrow();
73-
calls.iter().all(|(input, hash)| *hash == f(input))
90+
self.calls.borrow().iter().all(|(input, hash)| *hash == f(input))
7491
}
7592
}
7693

7794
impl<In> Join for MultiConstraint<In>
7895
where
79-
In: Clone + PartialEq,
96+
In: Debug + Clone + PartialEq,
8097
{
98+
#[inline]
8199
fn join(&self, inner: &Self) {
82100
let mut calls = self.calls.borrow_mut();
83-
let inner = inner.calls.borrow();
84-
for (input, hash) in inner.iter() {
85-
if calls.iter().all(|item| &item.0 != input) {
101+
for (input, hash) in inner.calls.borrow().iter() {
102+
if let Some(item) = calls.iter().find(|item| &item.0 == input) {
103+
check(item.1, *hash);
104+
} else {
86105
calls.push((input.clone(), *hash));
87106
}
88107
}
89108
}
90109
}
91110

92111
impl<In> Default for MultiConstraint<In> {
112+
#[inline]
93113
fn default() -> Self {
94114
Self { calls: RefCell::new(vec![]) }
95115
}
96116
}
97117

118+
/// Check for a constraint violation.
119+
#[inline]
120+
#[track_caller]
121+
fn check(left: u128, right: u128) {
122+
if left != right {
123+
panic!(
124+
"comemo: found conflicting constraints. \
125+
is this tracked function pure?"
126+
)
127+
}
128+
}
129+
98130
/// Produce a 128-bit hash of a value.
131+
#[inline]
99132
pub fn hash<T: Hash>(value: &T) -> u128 {
100133
let mut state = SipHasher::new();
101134
value.hash(&mut state);

0 commit comments

Comments
 (0)