|
1 | 1 | use std::any::{Any, TypeId};
|
2 |
| -use std::cell::{Cell, RefCell}; |
| 2 | +use std::cell::RefCell; |
| 3 | +use std::collections::HashMap; |
3 | 4 | use std::fmt::Debug;
|
4 |
| -use std::hash::Hash; |
5 | 5 |
|
6 | 6 | use siphasher::sip128::{Hasher128, SipHasher};
|
7 | 7 |
|
8 | 8 | use crate::constraint::Join;
|
9 | 9 | use crate::input::Input;
|
10 | 10 | use crate::internal::Family;
|
11 | 11 |
|
| 12 | +thread_local! { |
| 13 | + /// The global, dynamic cache shared by all memoized functions. |
| 14 | + static CACHE: RefCell<Cache> = RefCell::new(Cache::default()); |
| 15 | +} |
| 16 | + |
12 | 17 | /// Execute a function or use a cached result for it.
|
13 |
| -pub fn memoized<In, Out, F>(name: &'static str, unique: TypeId, input: In, func: F) -> Out |
| 18 | +pub fn memoized<In, Out, F>(name: &'static str, id: TypeId, input: In, func: F) -> Out |
14 | 19 | where
|
15 | 20 | In: Input,
|
16 | 21 | Out: Debug + Clone + 'static,
|
17 | 22 | F: for<'f> FnOnce(<In::Tracked as Family<'f>>::Out) -> Out,
|
18 | 23 | {
|
19 |
| - // Compute the hash of the input's key part. |
20 |
| - let hash = { |
21 |
| - let mut state = SipHasher::new(); |
22 |
| - unique.hash(&mut state); |
23 |
| - input.key(&mut state); |
24 |
| - state.finish128().as_u128() |
25 |
| - }; |
26 |
| - |
27 |
| - let mut hit = true; |
28 |
| - let output = CACHE.with(|cache| { |
29 |
| - cache.lookup::<In, Out>(hash, &input).unwrap_or_else(|| { |
30 |
| - DEPTH.with(|v| v.set(v.get() + 1)); |
31 |
| - let constraint = In::Constraint::default(); |
32 |
| - let (tracked, outer) = input.retrack(&constraint); |
33 |
| - let output = func(tracked); |
34 |
| - outer.join(&constraint); |
35 |
| - cache.insert::<In, Out>(hash, Constrained { |
36 |
| - output: output.clone(), |
37 |
| - constraint, |
38 |
| - }); |
39 |
| - hit = false; |
40 |
| - DEPTH.with(|v| v.set(v.get() - 1)); |
41 |
| - output |
42 |
| - }) |
43 |
| - }); |
| 24 | + CACHE.with(|cache| { |
| 25 | + // Compute the hash of the input's key part. |
| 26 | + let key = { |
| 27 | + let mut state = SipHasher::new(); |
| 28 | + input.key(&mut state); |
| 29 | + let hash = state.finish128().as_u128(); |
| 30 | + (id, hash) |
| 31 | + }; |
| 32 | + |
| 33 | + let mut hit = true; |
| 34 | + let mut borrowed = cache.borrow_mut(); |
44 | 35 |
|
45 |
| - let depth = DEPTH.with(|v| v.get()); |
46 |
| - let label = if hit { "[hit]" } else { "[miss]" }; |
47 |
| - eprintln!("{depth} {name:<12} {label:<7} {output:?}"); |
| 36 | + // Check whether there is a cached entry. |
| 37 | + let output = match borrowed.lookup::<In, Out>(key, &input) { |
| 38 | + Some(output) => output, |
| 39 | + None => { |
| 40 | + hit = false; |
| 41 | + borrowed.depth += 1; |
| 42 | + drop(borrowed); |
48 | 43 |
|
49 |
| - output |
| 44 | + // Point all tracked parts of the input to these constraints. |
| 45 | + let constraint = In::Constraint::default(); |
| 46 | + let (tracked, outer) = input.retrack(&constraint); |
| 47 | + |
| 48 | + // Execute the function with the new constraints hooked in. |
| 49 | + let output = func(tracked); |
| 50 | + |
| 51 | + // Add the new constraints to the previous outer ones. |
| 52 | + outer.join(&constraint); |
| 53 | + |
| 54 | + // Insert the result into the cache. |
| 55 | + borrowed = cache.borrow_mut(); |
| 56 | + borrowed.insert::<In, Out>(key, constraint, output.clone()); |
| 57 | + borrowed.depth -= 1; |
| 58 | + |
| 59 | + output |
| 60 | + } |
| 61 | + }; |
| 62 | + |
| 63 | + // Print details. |
| 64 | + let depth = borrowed.depth; |
| 65 | + let label = if hit { "[hit]" } else { "[miss]" }; |
| 66 | + eprintln!("{depth} {name:<12} {label:<7} {output:?}"); |
| 67 | + |
| 68 | + output |
| 69 | + }) |
50 | 70 | }
|
51 | 71 |
|
52 |
| -thread_local! { |
53 |
| - /// The global, dynamic cache shared by all memoized functions. |
54 |
| - static CACHE: Cache = Cache::default(); |
| 72 | +/// Configure the caching behaviour. |
| 73 | +pub fn config(config: Config) { |
| 74 | + CACHE.with(|cache| cache.borrow_mut().config = config); |
| 75 | +} |
55 | 76 |
|
56 |
| - /// The current depth of the memoized call stack. |
57 |
| - static DEPTH: Cell<usize> = Cell::new(0); |
| 77 | +/// Configuration for caching behaviour. |
| 78 | +pub struct Config { |
| 79 | + max_age: u32, |
58 | 80 | }
|
59 | 81 |
|
60 |
| -/// An untyped cache. |
61 |
| -#[derive(Default)] |
62 |
| -struct Cache { |
63 |
| - map: RefCell<Vec<Entry>>, |
| 82 | +impl Config { |
| 83 | + /// The maximum number of evictions an entry can survive without having been |
| 84 | + /// used in between. |
| 85 | + pub fn max_age(mut self, age: u32) -> Self { |
| 86 | + self.max_age = age; |
| 87 | + self |
| 88 | + } |
64 | 89 | }
|
65 | 90 |
|
66 |
| -/// An entry in the cache. |
67 |
| -struct Entry { |
68 |
| - hash: u128, |
69 |
| - constrained: Box<dyn Any>, |
| 91 | +impl Default for Config { |
| 92 | + fn default() -> Self { |
| 93 | + Self { max_age: 5 } |
| 94 | + } |
70 | 95 | }
|
71 | 96 |
|
72 |
| -/// A value with a constraint. |
73 |
| -struct Constrained<T, C> { |
74 |
| - output: T, |
75 |
| - constraint: C, |
| 97 | +/// Evict cache entries that haven't been used in a while. |
| 98 | +pub fn evict() { |
| 99 | + CACHE.with(|cache| { |
| 100 | + let mut cache = cache.borrow_mut(); |
| 101 | + let max = cache.config.max_age; |
| 102 | + cache.map.retain(|_, entries| { |
| 103 | + entries.retain_mut(|entry| { |
| 104 | + entry.age += 1; |
| 105 | + entry.age <= max |
| 106 | + }); |
| 107 | + !entries.is_empty() |
| 108 | + }); |
| 109 | + }); |
| 110 | +} |
| 111 | + |
| 112 | +/// The global cache. |
| 113 | +#[derive(Default)] |
| 114 | +struct Cache { |
| 115 | + /// Maps from function IDs + hashes to memoized results. |
| 116 | + map: HashMap<(TypeId, u128), Vec<Entry>>, |
| 117 | + /// The current depth of the memoized call stack. |
| 118 | + depth: usize, |
| 119 | + /// The current configuration. |
| 120 | + config: Config, |
76 | 121 | }
|
77 | 122 |
|
78 | 123 | impl Cache {
|
79 | 124 | /// Look for a matching entry in the cache.
|
80 |
| - fn lookup<In, Out>(&self, hash: u128, input: &In) -> Option<Out> |
| 125 | + fn lookup<In, Out>(&mut self, key: (TypeId, u128), input: &In) -> Option<Out> |
81 | 126 | where
|
82 | 127 | In: Input,
|
83 | 128 | Out: Clone + 'static,
|
84 | 129 | {
|
85 | 130 | self.map
|
86 |
| - .borrow() |
87 |
| - .iter() |
88 |
| - .filter(|entry| entry.hash == hash) |
89 |
| - .map(|entry| { |
90 |
| - entry |
91 |
| - .constrained |
92 |
| - .downcast_ref::<Constrained<Out, In::Constraint>>() |
93 |
| - .expect("comemo: a hash collision occurred") |
94 |
| - }) |
95 |
| - .find(|output| input.valid(&output.constraint)) |
96 |
| - .map(|output| output.output.clone()) |
| 131 | + .get_mut(&key)? |
| 132 | + .iter_mut() |
| 133 | + .find_map(|entry| entry.lookup::<In, Out>(input)) |
97 | 134 | }
|
98 | 135 |
|
99 | 136 | /// Insert an entry into the cache.
|
100 |
| - fn insert<In, Out>(&self, hash: u128, constrained: Constrained<Out, In::Constraint>) |
| 137 | + fn insert<In, Out>( |
| 138 | + &mut self, |
| 139 | + key: (TypeId, u128), |
| 140 | + constraint: In::Constraint, |
| 141 | + output: Out, |
| 142 | + ) where |
| 143 | + In: Input, |
| 144 | + Out: 'static, |
| 145 | + { |
| 146 | + self.map |
| 147 | + .entry(key) |
| 148 | + .or_default() |
| 149 | + .push(Entry::new::<In, Out>(constraint, output)); |
| 150 | + } |
| 151 | +} |
| 152 | + |
| 153 | +/// A memoized result. |
| 154 | +struct Entry { |
| 155 | + /// The memoized function's constrained output. |
| 156 | + /// |
| 157 | + /// This is of type `Constrained<In::Constraint, Out>`. |
| 158 | + constrained: Box<dyn Any>, |
| 159 | + /// How many evictions have passed since the entry has been last used. |
| 160 | + age: u32, |
| 161 | +} |
| 162 | + |
| 163 | +/// A value with a constraint. |
| 164 | +struct Constrained<C, T> { |
| 165 | + /// The constraint which must be fulfilled for the output to be used. |
| 166 | + constraint: C, |
| 167 | + /// The memoized function's output. |
| 168 | + output: T, |
| 169 | +} |
| 170 | + |
| 171 | +impl Entry { |
| 172 | + /// Create a new entry. |
| 173 | + fn new<In, Out>(constraint: In::Constraint, output: Out) -> Self |
101 | 174 | where
|
102 | 175 | In: Input,
|
103 | 176 | Out: 'static,
|
104 | 177 | {
|
105 |
| - let entry = Entry { hash, constrained: Box::new(constrained) }; |
106 |
| - self.map.borrow_mut().push(entry); |
| 178 | + Self { |
| 179 | + constrained: Box::new(Constrained { constraint, output }), |
| 180 | + age: 0, |
| 181 | + } |
| 182 | + } |
| 183 | + |
| 184 | + /// Return the entry's output if it is valid for the given input. |
| 185 | + fn lookup<In, Out>(&mut self, input: &In) -> Option<Out> |
| 186 | + where |
| 187 | + In: Input, |
| 188 | + Out: Clone + 'static, |
| 189 | + { |
| 190 | + let Constrained::<In::Constraint, Out> { constraint, output } = |
| 191 | + self.constrained.downcast_ref().expect("wrong entry type"); |
| 192 | + input.valid(constraint).then(|| { |
| 193 | + self.age = 0; |
| 194 | + output.clone() |
| 195 | + }) |
107 | 196 | }
|
108 | 197 | }
|
0 commit comments