Skip to content

Commit 82aa013

Browse files
committed
Allow tracked or hashed argument
1 parent 1d78ec3 commit 82aa013

File tree

7 files changed

+139
-79
lines changed

7 files changed

+139
-79
lines changed

macros/src/memoize.rs

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@ use super::*;
22

33
/// Memoize a function.
44
pub fn expand(func: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
5-
let name = &func.sig.ident;
6-
75
let mut args = vec![];
8-
let mut types = vec![];
96
for input in &func.sig.inputs {
107
let typed = match input {
118
syn::FnArg::Typed(typed) => typed,
@@ -14,47 +11,26 @@ pub fn expand(func: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
1411
}
1512
};
1613

17-
let ident = match &*typed.pat {
14+
let name = match &*typed.pat {
1815
syn::Pat::Ident(ident) => ident,
1916
_ => bail!(typed.pat, "only simple identifiers are supported"),
2017
};
2118

22-
args.push(ident);
23-
types.push(typed.ty.as_ref());
19+
args.push(name);
2420
}
2521

26-
if args.len() != 1 {
27-
bail!(func, "expected exactly one argument");
28-
}
29-
30-
let arg = args[0];
31-
let ty = types[0];
32-
let _inner = match ty {
33-
syn::Type::Path(path) => {
34-
let segs = &path.path.segments;
35-
if segs.len() != 1 {
36-
bail!(ty, "expected exactly one path segment")
37-
}
38-
let args = match &segs[0].arguments {
39-
syn::PathArguments::AngleBracketed(args) => &args.args,
40-
_ => bail!(ty, "expected `Tracked<_>` type"),
41-
};
42-
if args.len() != 1 {
43-
bail!(args, "expected exactly one generic argument")
44-
}
45-
match &args[0] {
46-
syn::GenericArgument::Type(ty) => ty,
47-
ty => bail!(ty, "expected type argument"),
48-
}
49-
}
50-
_ => bail!(ty, "expected type of the form `Tracked<_>`"),
51-
};
52-
5322
let mut outer = func.clone();
23+
let name = &func.sig.ident;
24+
let arg = &args[0];
25+
5426
outer.block = parse_quote! { {
5527
#func
5628
::comemo::internal::CACHE.with(|cache|
57-
cache.query(stringify!(#name), #name, #arg)
29+
cache.query(
30+
stringify!(#name),
31+
#arg,
32+
#name,
33+
)
5834
)
5935
} };
6036

macros/src/track.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ pub fn expand(block: &syn::ItemImpl) -> Result<proc_macro2::TokenStream> {
5252
type Surface = SurfaceFamily;
5353

5454
fn valid(&self, constraint: &Self::Constraint) -> bool {
55-
#(#tracked_valids)&&*
55+
true #(&& #tracked_valids)*
5656
}
5757

5858
fn surface<'a, 'r>(tracked: &'r Tracked<'a, #ty>) -> &'r Surface<'a>

src/cache.rs

Lines changed: 115 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
use std::any::Any;
22
use std::cell::RefCell;
33
use std::fmt::Debug;
4+
use std::hash::{Hash, Hasher};
5+
use std::marker::PhantomData;
6+
7+
use siphasher::sip128::{Hasher128, SipHasher};
48

59
use crate::track::{from_parts, to_parts, Track, Trackable, Tracked};
610

@@ -12,32 +16,45 @@ thread_local! {
1216
/// An untyped cache.
1317
#[derive(Default)]
1418
pub struct Cache {
15-
map: RefCell<Vec<Box<dyn Any>>>,
19+
map: RefCell<Vec<Entry>>,
1620
}
1721

1822
/// An entry in the cache.
19-
struct Entry<C, R> {
23+
struct Entry {
24+
hash: u128,
25+
output: Box<dyn Any>,
26+
}
27+
28+
/// A value with a constraint.
29+
struct Constrained<T, C> {
30+
value: T,
2031
constraint: C,
21-
output: R,
2232
}
2333

2434
impl Cache {
2535
/// Execute `f` or use a cached result for it.
26-
pub fn query<F, T, R>(&self, name: &'static str, f: F, tracked: Tracked<T>) -> R
36+
pub fn query<In, Out, F>(&self, name: &str, input: In, func: F) -> Out
2737
where
28-
F: Fn(Tracked<T>) -> R,
29-
T: Track,
30-
R: Debug + Clone + 'static,
38+
In: Input,
39+
Out: Debug + Clone + 'static,
40+
F: for<'f> Fn(<In::Focus as Family<'f>>::Out) -> Out,
3141
{
42+
// Compute the hash of the input's key part.
43+
let hash = {
44+
let mut state = SipHasher::new();
45+
input.hash(&mut state);
46+
state.finish128().as_u128()
47+
};
48+
3249
let mut hit = true;
33-
let output = self.lookup::<T, R>(tracked).unwrap_or_else(|| {
34-
let constraint = T::Constraint::default();
35-
let (inner, _) = to_parts(tracked);
36-
let tracked = from_parts(inner, Some(&constraint));
37-
let output = f(tracked);
38-
self.insert::<T, R>(constraint, output.clone());
50+
let output = self.lookup::<In, Out>(hash, &input).unwrap_or_else(|| {
51+
let constraint = In::Constraint::default();
52+
let input = input.focus(&constraint);
53+
let value = func(input);
54+
let constrained = Constrained { value: value.clone(), constraint };
55+
self.insert::<In, Out>(hash, constrained);
3956
hit = false;
40-
output
57+
value
4158
});
4259

4360
let label = if hit { "[hit]" } else { "[miss]" };
@@ -47,27 +64,99 @@ impl Cache {
4764
}
4865

4966
/// Look for a matching entry in the cache.
50-
fn lookup<T, R>(&self, tracked: Tracked<T>) -> Option<R>
67+
fn lookup<In, Out>(&self, hash: u128, input: &In) -> Option<Out>
5168
where
52-
T: Track,
53-
R: Clone + 'static,
69+
In: Input,
70+
Out: Clone + 'static,
5471
{
55-
let (inner, _) = to_parts(tracked);
5672
self.map
5773
.borrow()
5874
.iter()
59-
.filter_map(|boxed| boxed.downcast_ref::<Entry<T::Constraint, R>>())
60-
.find(|entry| Trackable::valid(inner, &entry.constraint))
61-
.map(|entry| entry.output.clone())
75+
.filter(|entry| entry.hash == hash)
76+
.map(|entry| {
77+
entry
78+
.output
79+
.downcast_ref::<Constrained<Out, In::Constraint>>()
80+
.expect("comemo: hash collision")
81+
})
82+
.find(|output| input.valid(&output.constraint))
83+
.map(|output| output.value.clone())
6284
}
6385

6486
/// Insert an entry into the cache.
65-
fn insert<T, R>(&self, constraint: T::Constraint, output: R)
87+
fn insert<In, Out>(&self, hash: u128, output: Constrained<Out, In::Constraint>)
6688
where
67-
T: Track,
68-
R: 'static,
89+
In: Input,
90+
Out: 'static,
6991
{
70-
let entry = Entry { constraint, output };
71-
self.map.borrow_mut().push(Box::new(entry));
92+
let entry = Entry { hash, output: Box::new(output) };
93+
self.map.borrow_mut().push(entry);
7294
}
7395
}
96+
97+
pub trait Input {
98+
type Constraint: Default + 'static;
99+
type Focus: for<'f> Family<'f>;
100+
101+
fn hash<H: Hasher>(&self, state: &mut H);
102+
fn valid(&self, constraint: &Self::Constraint) -> bool;
103+
fn focus<'f>(
104+
self,
105+
constraint: &'f Self::Constraint,
106+
) -> <Self::Focus as Family<'f>>::Out
107+
where
108+
Self: 'f;
109+
}
110+
111+
impl<T: Hash> Input for T {
112+
type Constraint = ();
113+
type Focus = HashFamily<T>;
114+
115+
fn hash<H: Hasher>(&self, state: &mut H) {
116+
Hash::hash(self, state);
117+
}
118+
119+
fn valid(&self, _: &()) -> bool {
120+
true
121+
}
122+
123+
fn focus<'f>(self, _: &'f ()) -> Self
124+
where
125+
Self: 'f,
126+
{
127+
self
128+
}
129+
}
130+
131+
pub struct HashFamily<T>(PhantomData<T>);
132+
impl<T> Family<'_> for HashFamily<T> {
133+
type Out = T;
134+
}
135+
136+
impl<'a, T: Track> Input for Tracked<'a, T> {
137+
type Constraint = <T as Trackable>::Constraint;
138+
type Focus = TrackedFamily<T>;
139+
140+
fn hash<H: Hasher>(&self, _: &mut H) {}
141+
142+
fn valid(&self, constraint: &Self::Constraint) -> bool {
143+
Trackable::valid(to_parts(*self).0, constraint)
144+
}
145+
146+
fn focus<'f>(self, constraint: &'f Self::Constraint) -> Tracked<'f, T>
147+
where
148+
Self: 'f,
149+
{
150+
from_parts(to_parts(self).0, Some(constraint))
151+
}
152+
}
153+
154+
pub struct TrackedFamily<T>(PhantomData<T>);
155+
impl<'f, T: Track + 'f> Family<'f> for TrackedFamily<T> {
156+
type Out = Tracked<'f, T>;
157+
}
158+
159+
pub trait Family<'a> {
160+
/// The surface with lifetime.
161+
type Out;
162+
}

src/hash.rs

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,31 @@
11
use std::cell::Cell;
22
use std::hash::Hash;
33
use std::marker::PhantomData;
4-
use std::num::NonZeroU128;
54

65
use siphasher::sip128::{Hasher128, SipHasher};
76

8-
/// Produce a non zero 128-bit hash of a value.
9-
pub fn siphash<T: Hash>(value: &T) -> NonZeroU128 {
10-
let mut state = SipHasher::new();
11-
value.hash(&mut state);
12-
state
13-
.finish128()
14-
.as_u128()
15-
.try_into()
16-
.unwrap_or(NonZeroU128::new(u128::MAX).unwrap())
17-
}
18-
197
/// Defines a constraint for a value through its hash.
208
#[derive(Default)]
219
pub struct HashConstraint<T: Hash> {
22-
cell: Cell<Option<NonZeroU128>>,
10+
cell: Cell<Option<u128>>,
2311
marker: PhantomData<T>,
2412
}
2513

2614
impl<T: Hash> HashConstraint<T> {
2715
/// Set the constraint for the value.
2816
pub fn set(&self, value: &T) {
29-
self.cell.set(Some(siphash(value)));
17+
self.cell.set(Some(hash(value)));
3018
}
3119

3220
/// Whether the value fulfills the constraint.
3321
pub fn valid(&self, value: &T) -> bool {
34-
self.cell.get().map_or(true, |v| v == siphash(value))
22+
self.cell.get().map_or(true, |v| v == hash(value))
3523
}
3624
}
25+
26+
/// Produce a non zero 128-bit hash of a value.
27+
pub fn hash<T: Hash>(value: &T) -> u128 {
28+
let mut state = SipHasher::new();
29+
value.hash(&mut state);
30+
state.finish128().as_u128()
31+
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ pub use comemo_macros::{memoize, track};
1212
pub mod internal {
1313
pub use crate::cache::CACHE;
1414
pub use crate::hash::HashConstraint;
15-
pub use crate::track::{to_parts, Family, Trackable};
15+
pub use crate::track::{from_parts, to_parts, Family, Trackable};
1616
}

src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ fn describe(image: Tracked<Image>) -> &'static str {
3838
}
3939

4040
::comemo::internal::CACHE
41-
.with(|cache| cache.query(stringify!(describe), describe, image))
41+
.with(|cache| cache.query(stringify!(describe), image, describe))
4242
}
4343

4444
const _: () = {

src/track.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ pub trait Track: Trackable {
7575
}
7676

7777
/// Non-exposed parts of the `Track` trait.
78-
pub trait Trackable {
78+
pub trait Trackable: 'static {
7979
/// Describes an instance of type.
8080
type Constraint: Default + 'static;
8181

0 commit comments

Comments
 (0)