Skip to content

Commit 7f25460

Browse files
committed
Basic fully macro-generated memoization
1 parent 0aa5c1b commit 7f25460

File tree

7 files changed

+298
-204
lines changed

7 files changed

+298
-204
lines changed

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,3 @@ edition = "2021"
66
[dependencies]
77
comemo-macros = { path = "./macros" }
88
siphasher = "0.3"
9-
once_cell = "1"

examples/image.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,17 @@ fn describe(image: Tracked<Image>) -> &'static str {
3131
}
3232
}
3333

34+
#[comemo::track]
35+
impl Image {
36+
fn width(&self) -> u32 {
37+
self.width
38+
}
39+
40+
fn height(&self) -> u32 {
41+
self.height
42+
}
43+
}
44+
3445
/// A raster image.
3546
struct Image {
3647
width: u32,
@@ -53,14 +64,3 @@ impl Image {
5364
// Resize the actual image ...
5465
}
5566
}
56-
57-
#[comemo::track]
58-
impl Image {
59-
fn width(&self) -> u32 {
60-
self.width
61-
}
62-
63-
fn height(&self) -> u32 {
64-
self.height
65-
}
66-
}

macros/src/memoize.rs

Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use super::*;
22

3-
/// Make a type trackable.
3+
/// Memoize a function.
44
pub fn expand(func: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
55
let name = func.sig.ident.to_string();
66

77
let mut args = vec![];
8+
let mut types = vec![];
89
for input in &func.sig.inputs {
910
let typed = match input {
1011
syn::FnArg::Typed(typed) => typed,
@@ -19,41 +20,96 @@ pub fn expand(func: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
1920
};
2021

2122
args.push(ident);
23+
types.push(typed.ty.as_ref());
2224
}
2325

26+
let ret = match &func.sig.output {
27+
syn::ReturnType::Default => {
28+
bail!(func.sig, "function must have a return type")
29+
}
30+
syn::ReturnType::Type(.., ty) => ty.as_ref(),
31+
};
32+
2433
let mut inner = func.clone();
2534
inner.sig.ident = syn::Ident::new("inner", Span::call_site());
2635

27-
let cts = args.iter().map(|arg| {
28-
quote! {
29-
Validate::constraint(&#arg)
36+
if args.len() != 1 {
37+
bail!(func, "expected exactly one argument");
38+
}
39+
40+
let arg = args[0];
41+
let ty = types[0];
42+
let track = match ty {
43+
syn::Type::Path(path) => {
44+
let segs = &path.path.segments;
45+
if segs.len() != 1 {
46+
bail!(ty, "expected exactly one path segment")
47+
}
48+
let args = match &segs[0].arguments {
49+
syn::PathArguments::AngleBracketed(args) => &args.args,
50+
_ => bail!(ty, "expected `Tracked<_>` type"),
51+
};
52+
if args.len() != 1 {
53+
bail!(args, "expected exactly one generic argument")
54+
}
55+
match &args[0] {
56+
syn::GenericArgument::Type(ty) => ty,
57+
ty => bail!(ty, "expected type argument"),
58+
}
3059
}
31-
});
60+
_ => bail!(ty, "expected type of the form `Tracked<_>`"),
61+
};
3262

33-
let mut outer = func.clone();
34-
outer.block = parse_quote! { { #inner {
35-
use std::sync::atomic::{AtomicUsize, Ordering};
36-
use comemo::Validate;
63+
let trackable = quote! {
64+
<#track as ::comemo::internal::Trackable<'static>>
65+
};
66+
67+
let body = quote! {
68+
type Cache = ::core::cell::RefCell<
69+
::std::vec::Vec<(#trackable::Tracker, #ret)>
70+
>;
3771

38-
static NR: AtomicUsize = AtomicUsize::new(1);
39-
let nr = NR.fetch_add(1, Ordering::SeqCst);
40-
let cts = (#(#cts,)*);
72+
thread_local! {
73+
static CACHE: Cache = Default::default();
74+
}
4175

42-
println!("{:?}", cts);
76+
let mut hit = true;
77+
let output = CACHE.with(|cache| {
78+
cache
79+
.borrow()
80+
.iter()
81+
.find(|(tracker, _)| {
82+
let (#arg, _) = ::comemo::internal::to_parts(#arg);
83+
#trackable::valid(#arg, tracker)
84+
})
85+
.map(|&(_, output)| output)
86+
});
4387

44-
let mut hit = false;
45-
let result = inner(#(#args),*);
88+
let output = output.unwrap_or_else(|| {
89+
let tracker = ::core::default::Default::default();
90+
let (#arg, _) = ::comemo::internal::to_parts(#arg);
91+
let #arg = ::comemo::internal::from_parts(#arg, Some(&tracker));
92+
let output = inner(#arg);
93+
CACHE.with(|cache| cache.borrow_mut().push((tracker, output)));
94+
hit = false;
95+
output
96+
});
4697

4798
println!(
48-
"{} {} {} {}",
99+
"{} {} {}",
49100
#name,
50-
nr,
51101
if hit { "[hit]: " } else { "[miss]:" },
52-
result,
102+
output,
53103
);
54104

55-
result
56-
} } };
105+
output
106+
};
107+
108+
let mut outer = func.clone();
109+
outer.block = parse_quote! { {
110+
#inner
111+
{ #body }
112+
} };
57113

58114
Ok(quote! { #outer })
59115
}

macros/src/track.rs

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,64 +10,76 @@ pub fn expand(block: &syn::ItemImpl) -> Result<proc_macro2::TokenStream> {
1010
methods.push(method(&item)?);
1111
}
1212

13+
let tracked_fields = methods.iter().map(|method| {
14+
let name = &method.sig.ident;
15+
let ty = match &method.sig.output {
16+
syn::ReturnType::Default => unreachable!(),
17+
syn::ReturnType::Type(_, ty) => ty.as_ref(),
18+
};
19+
quote! { #name: ::comemo::internal::AccessTracker<#ty>, }
20+
});
21+
1322
let tracked_methods = methods.iter().map(|method| {
14-
let mut method = (*method).clone();
1523
let name = &method.sig.ident;
24+
let mut method = (*method).clone();
25+
if matches!(method.vis, syn::Visibility::Inherited) {
26+
method.vis = parse_quote! { pub(super) };
27+
}
1628
method.block = parse_quote! { {
17-
let output = self.0.#name();
18-
let slot = &mut self.constraint().#name;
19-
let ct = Validate::constraint(&output);
20-
if slot.is_none() {
21-
assert_eq!(*slot, Some(ct), "comemo: method is not pure");
29+
let (inner, tracker) = ::comemo::internal::to_parts(self.0);
30+
let output = inner.#name();
31+
if let Some(tracker) = &tracker {
32+
tracker.#name.track(&output);
2233
}
23-
*slot = Some(ct);
2434
output
2535
} };
2636
method
2737
});
2838

29-
let tracked_fields = methods.iter().map(|method| {
39+
let tracked_valids = methods.iter().map(|method| {
3040
let name = &method.sig.ident;
31-
let ty = match &method.sig.output {
32-
syn::ReturnType::Default => unreachable!(),
33-
syn::ReturnType::Type(_, ty) => ty.as_ref(),
34-
};
35-
quote! { #name: Option<<#ty as Validate>::Constraint>, }
41+
quote! {
42+
tracker.#name.valid(&self.#name())
43+
}
3644
});
3745

3846
let track_impl = quote! {
39-
use comemo::Validate;
47+
use super::*;
4048

41-
struct Surface<'a>(&'a #ty);
42-
43-
impl Surface<'_> {
44-
#(#tracked_methods)*
49+
impl<'a> ::comemo::Track<'a> for #ty {}
50+
impl<'a> ::comemo::internal::Trackable<'a> for #ty {
51+
type Tracker = Tracker;
52+
type Surface = Surface<'a>;
4553

46-
fn constraint(&self) -> &mut Constraint {
47-
todo!()
54+
fn valid(&self, tracker: &Self::Tracker) -> bool {
55+
#(#tracked_valids)&&*
4856
}
49-
}
5057

51-
impl<'a> From<&'a #ty> for Surface<'a> {
52-
fn from(val: &'a #ty) -> Self {
53-
Self(val)
58+
fn surface<'s>(tracked: &'s Tracked<'a, #ty>) -> &'s Self::Surface
59+
where
60+
Self: Track<'a>,
61+
{
62+
// Safety: Surface is repr(transparent).
63+
unsafe { &*(tracked as *const _ as *const Self::Surface) }
5464
}
5565
}
5666

57-
#[derive(Debug, Default)]
58-
struct Constraint {
59-
#(#tracked_fields)*
67+
#[repr(transparent)]
68+
pub struct Surface<'a>(Tracked<'a, #ty>);
69+
70+
impl Surface<'_> {
71+
#(#tracked_methods)*
6072
}
6173

62-
impl<'a> comemo::Track<'a> for #ty {
63-
type Surface = Surface<'a>;
64-
type Constraint = Constraint;
74+
#[derive(Default)]
75+
pub struct Tracker {
76+
#(#tracked_fields)*
6577
}
6678
};
6779

6880
Ok(quote! {
6981
#block
70-
const _: () = { #track_impl };
82+
const _: () = { mod private { #track_impl } };
7183
})
7284
}
7385

@@ -78,6 +90,12 @@ fn method(item: &syn::ImplItem) -> Result<&syn::ImplItemMethod> {
7890
_ => bail!(item, "only methods are supported"),
7991
};
8092

93+
match method.vis {
94+
syn::Visibility::Inherited => {}
95+
syn::Visibility::Public(_) => {}
96+
_ => bail!(method.vis, "only private and public methods are supported"),
97+
}
98+
8199
let mut inputs = method.sig.inputs.iter();
82100
let receiver = match inputs.next() {
83101
Some(syn::FnArg::Receiver(recv)) => recv,

src/internal.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
use std::cell::Cell;
2+
use std::hash::Hash;
3+
use std::marker::PhantomData;
4+
use std::num::NonZeroU128;
5+
6+
use siphasher::sip128::{Hasher128, SipHasher};
7+
8+
use super::{Track, Tracked};
9+
10+
/// Destructure a tracker into its parts.
11+
pub fn to_parts<'a, T>(tracked: Tracked<'a, T>) -> (&'a T, Option<&'a T::Tracker>)
12+
where
13+
T: Track<'a>,
14+
{
15+
(tracked.inner, tracked.tracker)
16+
}
17+
18+
/// Create a tracker from its parts.
19+
pub fn from_parts<'a, T>(inner: &'a T, tracker: Option<&'a T::Tracker>) -> Tracked<'a, T>
20+
where
21+
T: Track<'a>,
22+
{
23+
Tracked { inner, tracker }
24+
}
25+
26+
/// Non-exposed parts of the `Track` trait.
27+
pub trait Trackable<'a>: Sized + 'a {
28+
/// Keeps track of accesses to the value.
29+
type Tracker: Default;
30+
31+
/// The tracked API surface of this type.
32+
type Surface;
33+
34+
/// Whether an instance fulfills the given tracker's constraints.
35+
fn valid(&self, tracker: &Self::Tracker) -> bool;
36+
37+
/// Cast a reference from `Tracked` to this type's surface.
38+
fn surface<'s>(tracked: &'s Tracked<'a, Self>) -> &'s Self::Surface
39+
where
40+
Self: Track<'a>;
41+
}
42+
43+
/// Tracks accesses to a value.
44+
#[derive(Default)]
45+
pub struct AccessTracker<T: Hash>(Cell<Option<NonZeroU128>>, PhantomData<T>);
46+
47+
impl<T: Hash> AccessTracker<T> {
48+
pub fn track(&self, value: &T) {
49+
self.0.set(Some(siphash(value)));
50+
}
51+
52+
pub fn valid(&self, value: &T) -> bool {
53+
self.0.get().map_or(true, |v| v == siphash(value))
54+
}
55+
}
56+
57+
/// Produce a non zero 128-bit hash of a value.
58+
fn siphash<T: Hash>(value: &T) -> NonZeroU128 {
59+
let mut state = SipHasher::new();
60+
value.hash(&mut state);
61+
state
62+
.finish128()
63+
.as_u128()
64+
.try_into()
65+
.unwrap_or(NonZeroU128::new(u128::MAX).unwrap())
66+
}

0 commit comments

Comments
 (0)