Skip to content

Commit 457274f

Browse files
committed
Tracked methods with arguments
1 parent 2f3b75e commit 457274f

File tree

12 files changed

+462
-129
lines changed

12 files changed

+462
-129
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ edition = "2021"
66
[dependencies]
77
comemo-macros = { path = "./macros" }
88
siphasher = "0.3"
9+
10+
[dev-dependencies]
11+
unscanny = "0.1"

macros/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ mod memoize;
1313
mod track;
1414

1515
use proc_macro::TokenStream;
16-
use quote::{quote, quote_spanned};
17-
use syn::spanned::Spanned;
16+
use quote::quote;
1817
use syn::{parse_quote, Error, Result};
1918

2019
/// Memoize a pure function.

macros/src/memoize.rs

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,35 @@ pub fn expand(mut func: syn::ItemFn) -> Result<proc_macro2::TokenStream> {
1212
}
1313
};
1414

15-
let name = match &*typed.pat {
16-
syn::Pat::Ident(ident) => ident,
17-
_ => bail!(typed.pat, "only simple identifiers are supported"),
15+
let name = match typed.pat.as_ref() {
16+
syn::Pat::Ident(syn::PatIdent {
17+
by_ref: None,
18+
mutability: None,
19+
ident,
20+
subpat: None,
21+
..
22+
}) => ident,
23+
pat => bail!(pat, "only simple identifiers are supported"),
1824
};
1925

26+
let ty = typed.ty.as_ref();
2027
args.push(name);
21-
types.push(typed.ty.as_ref());
28+
types.push(ty);
2229
}
2330

2431
// Construct a tuple from all arguments.
2532
let arg_tuple = quote! { (#(#args,)*) };
2633

2734
// Construct assertions that the arguments fulfill the necessary bounds.
28-
let bounds = args.iter().zip(&types).map(|(arg, ty)| {
29-
quote_spanned! {
30-
arg.span() => ::comemo::internal::assert_hashable_or_trackable::<#ty>();
35+
let bounds = types.iter().map(|ty| {
36+
quote! {
37+
::comemo::internal::assert_hashable_or_trackable::<#ty>();
3138
}
3239
});
3340

3441
// Construct the inner closure.
3542
let body = &func.block;
36-
let inner = quote! { |#arg_tuple| #body };
43+
let closure = quote! { |#arg_tuple| #body };
3744

3845
// Adjust the function's body.
3946
let name = func.sig.ident.to_string();
@@ -42,7 +49,7 @@ pub fn expand(mut func: syn::ItemFn) -> Result<proc_macro2::TokenStream> {
4249
::comemo::internal::cached(
4350
#name,
4451
::comemo::internal::Args(#arg_tuple),
45-
#inner,
52+
#closure,
4653
)
4754
} };
4855

macros/src/track.rs

Lines changed: 118 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,39 +10,68 @@ pub fn expand(block: syn::ItemImpl) -> Result<proc_macro2::TokenStream> {
1010
methods.push(method(&item)?);
1111
}
1212

13-
let tracked_fields = methods.iter().map(|method| {
14-
let name = &method.sig.ident;
15-
let ty = match &method.sig.output {
16-
syn::ReturnType::Default => unreachable!(),
17-
syn::ReturnType::Type(_, ty) => ty.as_ref(),
18-
};
19-
quote! { #name: ::comemo::internal::HashConstraint<#ty>, }
13+
let tracked_valids = methods.iter().map(|method| {
14+
let name = &method.name;
15+
let args = &method.args;
16+
if args.is_empty() {
17+
quote! { constraint.#name.valid(&self.#name()) }
18+
} else {
19+
quote! {
20+
constraint.#name
21+
.valid(|(#(#args,)*)| self.#name(#(#args.clone(),)*))
22+
}
23+
}
2024
});
2125

2226
let tracked_methods = methods.iter().map(|method| {
23-
let name = &method.sig.ident;
24-
let mut method = (*method).clone();
25-
if matches!(method.vis, syn::Visibility::Inherited) {
26-
method.vis = parse_quote! { pub(super) };
27+
let mut wrapper = method.item.clone();
28+
if matches!(wrapper.vis, syn::Visibility::Inherited) {
29+
wrapper.vis = parse_quote! { pub(super) };
2730
}
28-
method.block = parse_quote! { {
29-
let (inner, constraint) = ::comemo::internal::to_parts(self.0);
30-
let output = inner.#name();
31+
32+
let name = &method.name;
33+
let args = &method.args;
34+
let set = if args.is_empty() {
35+
quote! { constraint.#name.set(&output) }
36+
} else {
37+
quote! { constraint.#name.set((#(#args,)*), &output) }
38+
};
39+
40+
// Construct assertions that the arguments fulfill the necessary bounds.
41+
let bounds = method.types.iter().map(|ty| {
42+
quote! {
43+
::comemo::internal::assert_clone_and_partial_eq::<#ty>();
44+
}
45+
});
46+
47+
wrapper.block = parse_quote! { {
48+
#(#bounds;)*
49+
let (value, constraint) = ::comemo::internal::to_parts(self.0);
50+
let output = value.#name(#(#args.clone(),)*);
3151
if let Some(constraint) = &constraint {
32-
constraint.#name.set(&output);
52+
#set;
3353
}
3454
output
3555
} };
36-
method
56+
57+
wrapper
3758
});
3859

39-
let tracked_valids = methods.iter().map(|method| {
40-
let name = &method.sig.ident;
41-
quote! {
42-
constraint.#name.valid(&self.#name())
60+
let tracked_fields = methods.iter().map(|method| {
61+
let name = &method.name;
62+
let types = &method.types;
63+
if types.is_empty() {
64+
quote! { #name: ::comemo::internal::HashConstraint, }
65+
} else {
66+
quote! { #name: ::comemo::internal::FuncConstraint<(#(#types,)*)>, }
4367
}
4468
});
4569

70+
let join_calls = methods.iter().map(|method| {
71+
let name = &method.name;
72+
quote! { self.#name.join(&inner.#name); }
73+
});
74+
4675
let track_impl = quote! {
4776
use super::*;
4877

@@ -77,6 +106,12 @@ pub fn expand(block: syn::ItemImpl) -> Result<proc_macro2::TokenStream> {
77106
pub struct Constraint {
78107
#(#tracked_fields)*
79108
}
109+
110+
impl ::comemo::internal::Join for Constraint {
111+
fn join(&self, inner: &Self) {
112+
#(#join_calls)*
113+
}
114+
}
80115
};
81116

82117
Ok(quote! {
@@ -85,8 +120,15 @@ pub fn expand(block: syn::ItemImpl) -> Result<proc_macro2::TokenStream> {
85120
})
86121
}
87122

123+
struct Method {
124+
item: syn::ImplItemMethod,
125+
name: syn::Ident,
126+
args: Vec<syn::Ident>,
127+
types: Vec<syn::Type>,
128+
}
129+
88130
/// Extract and validate a method.
89-
fn method(item: &syn::ImplItem) -> Result<&syn::ImplItemMethod> {
131+
fn method(item: &syn::ImplItem) -> Result<Method> {
90132
let method = match item {
91133
syn::ImplItem::Method(method) => method,
92134
_ => bail!(item, "only methods are supported"),
@@ -98,6 +140,27 @@ fn method(item: &syn::ImplItem) -> Result<&syn::ImplItemMethod> {
98140
_ => bail!(method.vis, "only private and public methods are supported"),
99141
}
100142

143+
if let Some(unsafety) = method.sig.unsafety {
144+
bail!(unsafety, "unsafe methods are not supported");
145+
}
146+
147+
if let Some(asyncness) = method.sig.asyncness {
148+
bail!(asyncness, "async methods are not supported");
149+
}
150+
151+
if let Some(constness) = method.sig.constness {
152+
bail!(constness, "const methods are not supported");
153+
}
154+
155+
for param in method.sig.generics.params.iter() {
156+
match param {
157+
syn::GenericParam::Const(_) | syn::GenericParam::Type(_) => {
158+
bail!(param, "method must not be generic")
159+
}
160+
syn::GenericParam::Lifetime(_) => {}
161+
}
162+
}
163+
101164
let mut inputs = method.sig.inputs.iter();
102165
let receiver = match inputs.next() {
103166
Some(syn::FnArg::Receiver(recv)) => recv,
@@ -108,20 +171,46 @@ fn method(item: &syn::ImplItem) -> Result<&syn::ImplItemMethod> {
108171
bail!(receiver, "must take self by shared reference");
109172
}
110173

111-
if inputs.next().is_some() {
112-
bail!(
113-
method.sig,
114-
"currently, only methods without extra arguments are supported"
115-
);
174+
let mut args = vec![];
175+
let mut types = vec![];
176+
for input in inputs {
177+
let typed = match input {
178+
syn::FnArg::Typed(typed) => typed,
179+
syn::FnArg::Receiver(_) => continue,
180+
};
181+
182+
let name = match typed.pat.as_ref() {
183+
syn::Pat::Ident(syn::PatIdent {
184+
by_ref: None,
185+
mutability: None,
186+
ident,
187+
subpat: None,
188+
..
189+
}) => ident.clone(),
190+
pat => bail!(pat, "only simple identifiers are supported"),
191+
};
192+
193+
let ty = (*typed.ty).clone();
194+
match ty {
195+
syn::Type::ImplTrait(_) => bail!(ty, "method must not be generic"),
196+
_ => {}
197+
}
198+
199+
args.push(name);
200+
types.push(ty);
116201
}
117202

118-
let output = &method.sig.output;
119-
match output {
203+
match method.sig.output {
120204
syn::ReturnType::Default => {
121205
bail!(method.sig, "method must have a return type")
122206
}
123207
syn::ReturnType::Type(..) => {}
124208
}
125209

126-
Ok(method)
210+
Ok(Method {
211+
item: method.clone(),
212+
name: method.sig.ident.clone(),
213+
args,
214+
types,
215+
})
127216
}

src/cache.rs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use std::any::{Any, TypeId};
2-
use std::cell::RefCell;
2+
use std::cell::{Cell, RefCell};
33
use std::fmt::Debug;
44
use std::hash::Hash;
55

66
use siphasher::sip128::{Hasher128, SipHasher};
77

8+
use crate::constraint::Join;
89
use crate::input::Input;
910
use crate::internal::Family;
1011

@@ -26,41 +27,51 @@ where
2627
let mut hit = true;
2728
let output = CACHE.with(|cache| {
2829
cache.lookup::<In, Out>(hash, &input).unwrap_or_else(|| {
30+
DEPTH.with(|v| v.set(v.get() + 1));
2931
let constraint = In::Constraint::default();
30-
let value = func(input.track(&constraint));
31-
let constrained = Constrained { value: value.clone(), constraint };
32-
cache.insert::<In, Out>(hash, constrained);
32+
let (tracked, outer) = input.retrack(&constraint);
33+
let output = func(tracked);
34+
outer.join(&constraint);
35+
cache.insert::<In, Out>(hash, Constrained {
36+
output: output.clone(),
37+
constraint,
38+
});
3339
hit = false;
34-
value
40+
DEPTH.with(|v| v.set(v.get() - 1));
41+
output
3542
})
3643
});
3744

45+
let depth = DEPTH.with(|v| v.get());
3846
let label = if hit { "[hit]" } else { "[miss]" };
39-
eprintln!("{name:<9} {label:<7} {output:?}");
47+
eprintln!("{depth} {name:<9} {label:<7} {output:?}");
4048

4149
output
4250
}
4351

4452
thread_local! {
4553
/// The global, dynamic cache shared by all memoized functions.
46-
pub static CACHE: Cache = Cache::default();
54+
static CACHE: Cache = Cache::default();
55+
56+
/// The current depth of the memoized call stack.
57+
static DEPTH: Cell<usize> = Cell::new(0);
4758
}
4859

4960
/// An untyped cache.
5061
#[derive(Default)]
51-
pub struct Cache {
62+
struct Cache {
5263
map: RefCell<Vec<Entry>>,
5364
}
5465

5566
/// An entry in the cache.
5667
struct Entry {
5768
hash: u128,
58-
output: Box<dyn Any>,
69+
constrained: Box<dyn Any>,
5970
}
6071

6172
/// A value with a constraint.
6273
struct Constrained<T, C> {
63-
value: T,
74+
output: T,
6475
constraint: C,
6576
}
6677

@@ -77,21 +88,21 @@ impl Cache {
7788
.filter(|entry| entry.hash == hash)
7889
.map(|entry| {
7990
entry
80-
.output
91+
.constrained
8192
.downcast_ref::<Constrained<Out, In::Constraint>>()
8293
.expect("comemo: a hash collision occurred")
8394
})
8495
.find(|output| input.valid(&output.constraint))
85-
.map(|output| output.value.clone())
96+
.map(|output| output.output.clone())
8697
}
8798

8899
/// Insert an entry into the cache.
89-
fn insert<In, Out>(&self, hash: u128, output: Constrained<Out, In::Constraint>)
100+
fn insert<In, Out>(&self, hash: u128, constrained: Constrained<Out, In::Constraint>)
90101
where
91102
In: Input,
92103
Out: 'static,
93104
{
94-
let entry = Entry { hash, output: Box::new(output) };
105+
let entry = Entry { hash, constrained: Box::new(constrained) };
95106
self.map.borrow_mut().push(entry);
96107
}
97108
}

0 commit comments

Comments
 (0)