Skip to content

Commit f49439b

Browse files
committed
Support multiple arguments
1 parent 2418d70 commit f49439b

File tree

8 files changed

+152
-43
lines changed

8 files changed

+152
-43
lines changed

macros/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ mod memoize;
1313
mod track;
1414

1515
use proc_macro::TokenStream;
16-
use quote::quote;
17-
use syn::{parse_quote, Error, Result};
16+
use quote::{quote, quote_spanned};
17+
use syn::{parse_quote, spanned::Spanned, Error, Result};
1818

1919
/// Memoize a pure function.
2020
#[proc_macro_attribute]

macros/src/memoize.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use super::*;
33
/// Memoize a function.
44
pub fn expand(func: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
55
let mut args = vec![];
6+
let mut types = vec![];
67
for input in &func.sig.inputs {
78
let typed = match input {
89
syn::FnArg::Typed(typed) => typed,
@@ -17,18 +18,29 @@ pub fn expand(func: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
1718
};
1819

1920
args.push(name);
21+
types.push(typed.ty.as_ref());
2022
}
2123

24+
let mut inner = func.clone();
25+
let arg_tuple = quote! { (#(#args,)*) };
26+
let type_tuple = quote! { (#(#types,)*) };
27+
inner.sig.inputs = parse_quote! { #arg_tuple: #type_tuple };
28+
29+
let bounds = args.iter().zip(&types).map(|(arg, ty)| {
30+
quote_spanned! {
31+
arg.span() => ::comemo::internal::assert_hashable_or_trackable::<#ty>();
32+
}
33+
});
34+
2235
let mut outer = func.clone();
2336
let name = &func.sig.ident;
24-
let arg = &args[0];
25-
2637
outer.block = parse_quote! { {
27-
#func
38+
#inner
39+
#(#bounds;)*
2840
::comemo::internal::CACHE.with(|cache|
2941
cache.query(
3042
stringify!(#name),
31-
#arg,
43+
::comemo::internal::Args(#arg_tuple),
3244
#name,
3345
)
3446
)

src/cache.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl Cache {
3636
where
3737
In: Input,
3838
Out: Debug + Clone + 'static,
39-
F: for<'f> Fn(<In::Hooked as Family<'f>>::Out) -> Out,
39+
F: for<'f> Fn(<In::Tracked as Family<'f>>::Out) -> Out,
4040
{
4141
// Compute the hash of the input's key part.
4242
let hash = {
@@ -48,8 +48,7 @@ impl Cache {
4848
let mut hit = true;
4949
let output = self.lookup::<In, Out>(hash, &input).unwrap_or_else(|| {
5050
let constraint = In::Constraint::default();
51-
let input = input.hook_up(&constraint);
52-
let value = func(input);
51+
let value = func(input.track(&constraint));
5352
let constrained = Constrained { value: value.clone(), constraint };
5453
self.insert::<In, Out>(hash, constrained);
5554
hit = false;

src/input.rs

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,38 @@ use std::marker::PhantomData;
44
use crate::internal::Family;
55
use crate::track::{from_parts, to_parts, Track, Trackable, Tracked};
66

7-
/// An argument to a cached function.
7+
/// Ensure a type is suitable as input.
8+
pub fn assert_hashable_or_trackable<T: Input>() {}
9+
10+
/// An input to a cached function.
811
pub trait Input {
9-
/// Describes an instance of the argument.
12+
/// Describes an instance of this input.
1013
type Constraint: Default + 'static;
1114

12-
/// The argument with constraints hooked in.
13-
type Hooked: for<'f> Family<'f>;
15+
/// The input with constraints hooked in.
16+
type Tracked: for<'f> Family<'f>;
1417

15-
/// Hash the argument if it is a _key_ argument
18+
/// Hash the _key_ parts of the input.
1619
fn key<H: Hasher>(&self, state: &mut H);
1720

18-
/// Validate the argument if it is a _tracked_ argument.
21+
/// Validate the _tracked_ parts of the input.
1922
fn valid(&self, constraint: &Self::Constraint) -> bool;
2023

21-
/// Hook up the given constraints if this is a _tracked_ argument.
22-
fn hook_up<'f>(
24+
/// Hook up the given constraint to the _tracked_ parts of the input.
25+
fn track<'f>(
2326
self,
2427
constraint: &'f Self::Constraint,
25-
) -> <Self::Hooked as Family<'f>>::Out
28+
) -> <Self::Tracked as Family<'f>>::Out
2629
where
2730
Self: 'f;
2831
}
2932

3033
impl<T: Hash> Input for T {
31-
/// No constraint for hashed arguments.
34+
/// No constraint for hashed inputs.
3235
type Constraint = ();
3336

3437
/// The hooked-up type is just `Self`.
35-
type Hooked = IdFamily<Self>;
38+
type Tracked = IdFamily<Self>;
3639

3740
fn key<H: Hasher>(&self, state: &mut H) {
3841
Hash::hash(self, state);
@@ -42,45 +45,93 @@ impl<T: Hash> Input for T {
4245
true
4346
}
4447

45-
fn hook_up<'f>(self, _: &'f ()) -> Self
48+
fn track<'f>(self, _: &'f ()) -> Self
4649
where
4750
Self: 'f,
4851
{
4952
self
5053
}
5154
}
5255

56+
/// Identity type constructor.
57+
pub struct IdFamily<T>(PhantomData<T>);
58+
59+
impl<T> Family<'_> for IdFamily<T> {
60+
type Out = T;
61+
}
62+
5363
impl<'a, T: Track> Input for Tracked<'a, T> {
54-
/// Forwarded constraint from Trackable implementation.
64+
/// Forward constraint from `Trackable` implementation.
5565
type Constraint = <T as Trackable>::Constraint;
5666

5767
/// The hooked-up type is `Tracked<'f, T>`.
58-
type Hooked = TrackedFamily<T>;
68+
type Tracked = TrackedFamily<T>;
5969

6070
fn key<H: Hasher>(&self, _: &mut H) {}
6171

6272
fn valid(&self, constraint: &Self::Constraint) -> bool {
6373
Trackable::valid(to_parts(*self).0, constraint)
6474
}
6575

66-
fn hook_up<'f>(self, constraint: &'f Self::Constraint) -> Tracked<'f, T>
76+
fn track<'f>(self, constraint: &'f Self::Constraint) -> Tracked<'f, T>
6777
where
6878
Self: 'f,
6979
{
7080
from_parts(to_parts(self).0, Some(constraint))
7181
}
7282
}
7383

74-
/// Identity type constructor.
75-
pub struct IdFamily<T>(PhantomData<T>);
76-
77-
impl<T> Family<'_> for IdFamily<T> {
78-
type Out = T;
79-
}
80-
8184
/// 'f -> Tracked<'f, T> type constructor.
8285
pub struct TrackedFamily<T>(PhantomData<T>);
8386

8487
impl<'f, T: Track + 'f> Family<'f> for TrackedFamily<T> {
8588
type Out = Tracked<'f, T>;
8689
}
90+
91+
/// Wrapper for multiple inputs.
92+
pub struct Args<T>(pub T);
93+
94+
/// Lifetime to tuple of arguments type constructor.
95+
pub struct ArgsFamily<T>(PhantomData<T>);
96+
97+
macro_rules! args_input {
98+
($($idx:tt: $letter:ident),*) => {
99+
#[allow(unused_variables)]
100+
impl<$($letter: Input),*> Input for Args<($($letter,)*)> {
101+
type Constraint = ($($letter::Constraint,)*);
102+
type Tracked = ArgsFamily<($($letter,)*)>;
103+
104+
fn key<H: Hasher>(&self, state: &mut H) {
105+
$((self.0).$idx.key(state);)*
106+
}
107+
108+
fn valid(&self, constraint: &Self::Constraint) -> bool {
109+
true $(&& (self.0).$idx.valid(&constraint.$idx))*
110+
}
111+
112+
fn track<'f>(
113+
self,
114+
constraint: &'f Self::Constraint,
115+
) -> <Self::Tracked as Family<'f>>::Out
116+
where
117+
Self: 'f,
118+
{
119+
($((self.0).$idx.track(&constraint.$idx),)*)
120+
}
121+
}
122+
123+
#[allow(unused_parens)]
124+
impl<'f, $($letter: Input),*> Family<'f> for ArgsFamily<($($letter,)*)> {
125+
type Out = ($(<$letter::Tracked as Family<'f>>::Out,)*);
126+
}
127+
};
128+
}
129+
130+
args_input! {}
131+
args_input! { 0: A }
132+
args_input! { 0: A, 1: B }
133+
args_input! { 0: A, 1: B, 2: C }
134+
args_input! { 0: A, 1: B, 2: C, 3: D }
135+
args_input! { 0: A, 1: B, 2: C, 3: D, 4: E }
136+
args_input! { 0: A, 1: B, 2: C, 3: D, 4: E, 5: F }
137+
args_input! { 0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G }

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub use comemo_macros::{memoize, track};
1313
pub mod internal {
1414
pub use crate::cache::CACHE;
1515
pub use crate::hash::HashConstraint;
16+
pub use crate::input::{assert_hashable_or_trackable, Args};
1617
pub use crate::track::{from_parts, to_parts, Trackable};
1718

1819
/// Helper trait for lifetime type families.

src/main.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,21 @@ fn main() {
2929

3030
/// Format the image's size humanly readable.
3131
fn describe(image: Tracked<Image>) -> &'static str {
32-
fn describe(image: Tracked<Image>) -> &'static str {
32+
fn describe((image,): (Tracked<Image>,)) -> &'static str {
3333
if image.width() > 50 || image.height() > 50 {
3434
"The image is big!"
3535
} else {
3636
"The image is small!"
3737
}
3838
}
3939

40-
::comemo::internal::CACHE
41-
.with(|cache| cache.query(stringify!(describe), image, describe))
40+
::comemo::internal::CACHE.with(|cache| {
41+
cache.query(
42+
stringify!(describe),
43+
::comemo::internal::Args((image,)),
44+
describe,
45+
)
46+
})
4247
}
4348

4449
const _: () = {

tests/image.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,20 @@ use comemo::{Track, Tracked};
44
fn test_image() {
55
let mut image = Image::new(20, 40);
66

7-
// [Miss] The cache is empty.
8-
describe(image.track());
9-
10-
// [Hit] Everything stayed the same.
11-
describe(image.track());
7+
describe(image.track()); // [Miss] The cache is empty.
8+
describe(image.track()); // [Hit] Everything stayed the same.
129

1310
image.resize(80, 30);
1411

15-
// [Miss] The image's width and height are different.
16-
describe(image.track());
12+
describe(image.track()); // [Miss] Width and height changed.
13+
select(image.track(), "width"); // [Miss] First call.
14+
select(image.track(), "height"); // [Miss]
1715

1816
image.resize(80, 70);
1917
image.pixels.fill(255);
2018

21-
// [Hit] The last call only read the width and it stayed the same.
22-
describe(image.track());
19+
describe(image.track()); // [Hit] Width is > 50 stayed the same.
20+
select(image.track(), "width"); // [Hit] Width stayed the same.
2321
}
2422

2523
/// Format the image's size humanly readable.
@@ -32,6 +30,16 @@ fn describe(image: Tracked<Image>) -> &'static str {
3230
}
3331
}
3432

33+
/// Select either width or height.
34+
#[comemo::memoize]
35+
fn select(image: Tracked<Image>, what: &str) -> u32 {
36+
match what {
37+
"width" => image.width(),
38+
"height" => image.height(),
39+
_ => panic!("there is nothing else!"),
40+
}
41+
}
42+
3543
#[comemo::track]
3644
impl Image {
3745
fn width(&self) -> u32 {

tests/simple.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#[test]
2+
fn test_simple() {
3+
empty(); // [Miss] The cache is empty.
4+
empty(); // [Hit] Always a hit from now on.
5+
empty(); // [Hit] Always a hit from now on.
6+
7+
double(2); // [Miss] The cache is empty.
8+
double(4); // [Miss] Different number.
9+
double(2); // [Hit] Same number as initially.
10+
11+
sum(2, 4); // [Miss] The cache is empty.
12+
sum(2, 3); // [Miss] Different numbers.
13+
sum(2, 3); // [Hit] Same numbers
14+
sum(4, 2); // [Miss] Different numbers.
15+
}
16+
17+
/// Build a string.
18+
#[comemo::memoize]
19+
fn empty() -> String {
20+
format!("The world is {}", "big")
21+
}
22+
23+
/// Double a number.
24+
#[comemo::memoize]
25+
fn double(x: u32) -> u32 {
26+
2 * x
27+
}
28+
29+
/// Compute the sum of two numbers.
30+
#[comemo::memoize]
31+
fn sum(a: u32, b: u32) -> u32 {
32+
a + b
33+
}

0 commit comments

Comments
 (0)