Skip to content

Commit 2f3b75e

Browse files
committed
Improve diagnostics
1 parent f49439b commit 2f3b75e

File tree

9 files changed

+96
-86
lines changed

9 files changed

+96
-86
lines changed

macros/src/lib.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ mod track;
1414

1515
use proc_macro::TokenStream;
1616
use quote::{quote, quote_spanned};
17-
use syn::{parse_quote, spanned::Spanned, Error, Result};
17+
use syn::spanned::Spanned;
18+
use syn::{parse_quote, Error, Result};
1819

1920
/// Memoize a pure function.
2021
#[proc_macro_attribute]
2122
pub fn memoize(_: TokenStream, stream: TokenStream) -> TokenStream {
2223
let func = syn::parse_macro_input!(stream as syn::ItemFn);
23-
memoize::expand(&func)
24+
memoize::expand(func)
2425
.unwrap_or_else(|err| err.to_compile_error())
2526
.into()
2627
}
@@ -29,7 +30,7 @@ pub fn memoize(_: TokenStream, stream: TokenStream) -> TokenStream {
2930
#[proc_macro_attribute]
3031
pub fn track(_: TokenStream, stream: TokenStream) -> TokenStream {
3132
let block = syn::parse_macro_input!(stream as syn::ItemImpl);
32-
track::expand(&block)
33+
track::expand(block)
3334
.unwrap_or_else(|err| err.to_compile_error())
3435
.into()
3536
}

macros/src/memoize.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::*;
22

33
/// Memoize a function.
4-
pub fn expand(func: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
4+
pub fn expand(mut func: syn::ItemFn) -> Result<proc_macro2::TokenStream> {
55
let mut args = vec![];
66
let mut types = vec![];
77
for input in &func.sig.inputs {
@@ -21,30 +21,30 @@ pub fn expand(func: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
2121
types.push(typed.ty.as_ref());
2222
}
2323

24-
let mut inner = func.clone();
24+
// Construct a tuple from all arguments.
2525
let arg_tuple = quote! { (#(#args,)*) };
26-
let type_tuple = quote! { (#(#types,)*) };
27-
inner.sig.inputs = parse_quote! { #arg_tuple: #type_tuple };
2826

27+
// Construct assertions that the arguments fulfill the necessary bounds.
2928
let bounds = args.iter().zip(&types).map(|(arg, ty)| {
3029
quote_spanned! {
3130
arg.span() => ::comemo::internal::assert_hashable_or_trackable::<#ty>();
3231
}
3332
});
3433

35-
let mut outer = func.clone();
36-
let name = &func.sig.ident;
37-
outer.block = parse_quote! { {
38-
#inner
34+
// Construct the inner closure.
35+
let body = &func.block;
36+
let inner = quote! { |#arg_tuple| #body };
37+
38+
// Adjust the function's body.
39+
let name = func.sig.ident.to_string();
40+
func.block = parse_quote! { {
3941
#(#bounds;)*
40-
::comemo::internal::CACHE.with(|cache|
41-
cache.query(
42-
stringify!(#name),
43-
::comemo::internal::Args(#arg_tuple),
44-
#name,
45-
)
42+
::comemo::internal::cached(
43+
#name,
44+
::comemo::internal::Args(#arg_tuple),
45+
#inner,
4646
)
4747
} };
4848

49-
Ok(quote! { #outer })
49+
Ok(quote! { #func })
5050
}

macros/src/track.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::*;
22

33
/// Make a type trackable.
4-
pub fn expand(block: &syn::ItemImpl) -> Result<proc_macro2::TokenStream> {
4+
pub fn expand(block: syn::ItemImpl) -> Result<proc_macro2::TokenStream> {
55
let ty = &block.self_ty;
66

77
// Extract and validate the methods.

src/cache.rs

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,46 @@
1-
use std::any::Any;
1+
use std::any::{Any, TypeId};
22
use std::cell::RefCell;
33
use std::fmt::Debug;
4+
use std::hash::Hash;
45

56
use siphasher::sip128::{Hasher128, SipHasher};
67

78
use crate::input::Input;
89
use crate::internal::Family;
910

11+
/// Execute a function or use a cached result for it.
12+
pub fn cached<In, Out, F>(name: &str, input: In, func: F) -> Out
13+
where
14+
In: Input,
15+
Out: Debug + Clone + 'static,
16+
F: for<'f> Fn(<In::Tracked as Family<'f>>::Out) -> Out + 'static,
17+
{
18+
// Compute the hash of the input's key part.
19+
let hash = {
20+
let mut state = SipHasher::new();
21+
TypeId::of::<F>().hash(&mut state);
22+
input.key(&mut state);
23+
state.finish128().as_u128()
24+
};
25+
26+
let mut hit = true;
27+
let output = CACHE.with(|cache| {
28+
cache.lookup::<In, Out>(hash, &input).unwrap_or_else(|| {
29+
let constraint = In::Constraint::default();
30+
let value = func(input.track(&constraint));
31+
let constrained = Constrained { value: value.clone(), constraint };
32+
cache.insert::<In, Out>(hash, constrained);
33+
hit = false;
34+
value
35+
})
36+
});
37+
38+
let label = if hit { "[hit]" } else { "[miss]" };
39+
eprintln!("{name:<9} {label:<7} {output:?}");
40+
41+
output
42+
}
43+
1044
thread_local! {
1145
/// The global, dynamic cache shared by all memoized functions.
1246
pub static CACHE: Cache = Cache::default();
@@ -31,36 +65,6 @@ struct Constrained<T, C> {
3165
}
3266

3367
impl Cache {
34-
/// Execute `f` or use a cached result for it.
35-
pub fn query<In, Out, F>(&self, name: &str, input: In, func: F) -> Out
36-
where
37-
In: Input,
38-
Out: Debug + Clone + 'static,
39-
F: for<'f> Fn(<In::Tracked as Family<'f>>::Out) -> Out,
40-
{
41-
// Compute the hash of the input's key part.
42-
let hash = {
43-
let mut state = SipHasher::new();
44-
input.key(&mut state);
45-
state.finish128().as_u128()
46-
};
47-
48-
let mut hit = true;
49-
let output = self.lookup::<In, Out>(hash, &input).unwrap_or_else(|| {
50-
let constraint = In::Constraint::default();
51-
let value = func(input.track(&constraint));
52-
let constrained = Constrained { value: value.clone(), constraint };
53-
self.insert::<In, Out>(hash, constrained);
54-
hit = false;
55-
value
56-
});
57-
58-
let label = if hit { "[hit]" } else { "[miss]" };
59-
eprintln!("{name:<9} {label:<7} {output:?}");
60-
61-
output
62-
}
63-
6468
/// Look for a matching entry in the cache.
6569
fn lookup<In, Out>(&self, hash: u128, input: &In) -> Option<Out>
6670
where
@@ -75,7 +79,7 @@ impl Cache {
7579
entry
7680
.output
7781
.downcast_ref::<Constrained<Out, In::Constraint>>()
78-
.expect("comemo: hash collision")
82+
.expect("comemo: a hash collision occurred")
7983
})
8084
.find(|output| input.valid(&output.constraint))
8185
.map(|output| output.value.clone())

src/hash.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ impl<T: Hash> HashConstraint<T> {
2424
}
2525

2626
/// Produce a non zero 128-bit hash of a value.
27-
pub fn hash<T: Hash>(value: &T) -> u128 {
27+
fn hash<T: Hash>(value: &T) -> u128 {
2828
let mut state = SipHasher::new();
2929
value.hash(&mut state);
3030
state.finish128().as_u128()

src/input.rs

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ use crate::track::{from_parts, to_parts, Track, Trackable, Tracked};
88
pub fn assert_hashable_or_trackable<T: Input>() {}
99

1010
/// An input to a cached function.
11+
///
12+
/// This is implemented for hashable types, `Tracked<_>` types and `Args<(...)>`
13+
/// types containing tuples up to length twelve.
1114
pub trait Input {
1215
/// Describes an instance of this input.
1316
type Constraint: Default + 'static;
@@ -81,7 +84,7 @@ impl<'a, T: Track> Input for Tracked<'a, T> {
8184
}
8285
}
8386

84-
/// 'f -> Tracked<'f, T> type constructor.
87+
/// Type constructor for `'f -> Tracked<'f, T>`.
8588
pub struct TrackedFamily<T>(PhantomData<T>);
8689

8790
impl<'f, T: Track + 'f> Family<'f> for TrackedFamily<T> {
@@ -91,17 +94,17 @@ impl<'f, T: Track + 'f> Family<'f> for TrackedFamily<T> {
9194
/// Wrapper for multiple inputs.
9295
pub struct Args<T>(pub T);
9396

94-
/// Lifetime to tuple of arguments type constructor.
97+
/// Type constructor that maps a lifetime to tuple of arguments.
9598
pub struct ArgsFamily<T>(PhantomData<T>);
9699

97100
macro_rules! args_input {
98-
($($idx:tt: $letter:ident),*) => {
101+
($($param:tt $idx:tt ),*) => {
99102
#[allow(unused_variables)]
100-
impl<$($letter: Input),*> Input for Args<($($letter,)*)> {
101-
type Constraint = ($($letter::Constraint,)*);
102-
type Tracked = ArgsFamily<($($letter,)*)>;
103+
impl<$($param: Input),*> Input for Args<($($param,)*)> {
104+
type Constraint = ($($param::Constraint,)*);
105+
type Tracked = ArgsFamily<($($param,)*)>;
103106

104-
fn key<H: Hasher>(&self, state: &mut H) {
107+
fn key<T: Hasher>(&self, state: &mut T) {
105108
$((self.0).$idx.key(state);)*
106109
}
107110

@@ -121,17 +124,22 @@ macro_rules! args_input {
121124
}
122125

123126
#[allow(unused_parens)]
124-
impl<'f, $($letter: Input),*> Family<'f> for ArgsFamily<($($letter,)*)> {
125-
type Out = ($(<$letter::Tracked as Family<'f>>::Out,)*);
127+
impl<'f, $($param: Input),*> Family<'f> for ArgsFamily<($($param,)*)> {
128+
type Out = ($(<$param::Tracked as Family<'f>>::Out,)*);
126129
}
127130
};
128131
}
129132

130133
args_input! {}
131-
args_input! { 0: A }
132-
args_input! { 0: A, 1: B }
133-
args_input! { 0: A, 1: B, 2: C }
134-
args_input! { 0: A, 1: B, 2: C, 3: D }
135-
args_input! { 0: A, 1: B, 2: C, 3: D, 4: E }
136-
args_input! { 0: A, 1: B, 2: C, 3: D, 4: E, 5: F }
137-
args_input! { 0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G }
134+
args_input! { A 0 }
135+
args_input! { A 0, B 1 }
136+
args_input! { A 0, B 1, C 2 }
137+
args_input! { A 0, B 1, C 2, D 3 }
138+
args_input! { A 0, B 1, C 2, D 3, E 4 }
139+
args_input! { A 0, B 1, C 2, D 3, E 4, F 5 }
140+
args_input! { A 0, B 1, C 2, D 3, E 4, F 5, G 6 }
141+
args_input! { A 0, B 1, C 2, D 3, E 4, F 5, G 6, H 7 }
142+
args_input! { A 0, B 1, C 2, D 3, E 4, F 5, G 6, H 7, I 8 }
143+
args_input! { A 0, B 1, C 2, D 3, E 4, F 5, G 6, H 7, I 8, J 9 }
144+
args_input! { A 0, B 1, C 2, D 3, E 4, F 5, G 6, H 7, I 8, J 9, K 10 }
145+
args_input! { A 0, B 1, C 2, D 3, E 4, F 5, G 6, H 7, I 8, J 9, K 10, L 11 }

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ pub use comemo_macros::{memoize, track};
1111
/// These are implementation details. Do not rely on them!
1212
#[doc(hidden)]
1313
pub mod internal {
14-
pub use crate::cache::CACHE;
14+
pub use crate::cache::cached;
1515
pub use crate::hash::HashConstraint;
1616
pub use crate::input::{assert_hashable_or_trackable, Args};
17-
pub use crate::track::{from_parts, to_parts, Trackable};
17+
pub use crate::track::{to_parts, Trackable};
1818

1919
/// Helper trait for lifetime type families.
2020
pub trait Family<'a> {

src/main.rs

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,18 @@ fn main() {
2929

3030
/// Format the image's size humanly readable.
3131
fn describe(image: Tracked<Image>) -> &'static str {
32-
fn describe((image,): (Tracked<Image>,)) -> &'static str {
33-
if image.width() > 50 || image.height() > 50 {
34-
"The image is big!"
35-
} else {
36-
"The image is small!"
37-
}
38-
}
39-
40-
::comemo::internal::CACHE.with(|cache| {
41-
cache.query(
42-
stringify!(describe),
43-
::comemo::internal::Args((image,)),
44-
describe,
45-
)
46-
})
32+
::comemo::internal::assert_hashable_or_trackable::<Tracked<Image>>();
33+
::comemo::internal::cached(
34+
"describe",
35+
::comemo::internal::Args((image,)),
36+
|(image,)| {
37+
if image.width() > 50 || image.height() > 50 {
38+
"The image is big!"
39+
} else {
40+
"The image is small!"
41+
}
42+
},
43+
)
4744
}
4845

4946
const _: () = {

tests/image.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ fn test_image() {
1111

1212
describe(image.track()); // [Miss] Width and height changed.
1313
select(image.track(), "width"); // [Miss] First call.
14-
select(image.track(), "height"); // [Miss]
14+
select(image.track(), "height"); // [Miss] Different 2nd argument.
1515

1616
image.resize(80, 70);
1717
image.pixels.fill(255);

0 commit comments

Comments
 (0)