Skip to content

Commit c2e0232

Browse files
committed
Memoizable methods
1 parent 001a04f commit c2e0232

File tree

8 files changed

+227
-46
lines changed

8 files changed

+227
-46
lines changed

macros/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ mod track;
1515

1616
use proc_macro::TokenStream as BoundaryStream;
1717
use proc_macro2::TokenStream;
18-
use quote::{quote, quote_spanned};
18+
use quote::{quote, quote_spanned, ToTokens};
1919
use syn::spanned::Spanned;
2020
use syn::{parse_quote, Error, Result};
2121

macros/src/memoize.rs

Lines changed: 74 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,38 +13,67 @@ pub fn expand(item: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
1313
struct Function {
1414
item: syn::ItemFn,
1515
name: syn::Ident,
16-
args: Vec<syn::Ident>,
17-
types: Vec<syn::Type>,
16+
args: Vec<Argument>,
1817
output: syn::Type,
1918
}
2019

20+
/// An argument to a memoized function.
21+
enum Argument {
22+
Ident(syn::Ident),
23+
Receiver(syn::Token![self]),
24+
}
25+
26+
impl ToTokens for Argument {
27+
fn to_tokens(&self, tokens: &mut TokenStream) {
28+
match self {
29+
Self::Ident(ident) => ident.to_tokens(tokens),
30+
Self::Receiver(token) => token.to_tokens(tokens),
31+
}
32+
}
33+
}
34+
2135
/// Preprocess and validate a function.
2236
fn prepare(function: &syn::ItemFn) -> Result<Function> {
2337
let mut args = vec![];
24-
let mut types = vec![];
2538

2639
for input in &function.sig.inputs {
27-
let typed = match input {
28-
syn::FnArg::Typed(typed) => typed,
29-
syn::FnArg::Receiver(_) => {
30-
bail!(function, "methods are not supported")
40+
match input {
41+
syn::FnArg::Receiver(recv) => {
42+
if recv.mutability.is_some() {
43+
bail!(recv, "memoized functions cannot have mutable parameters");
44+
}
45+
46+
args.push(Argument::Receiver(recv.self_token));
47+
}
48+
syn::FnArg::Typed(typed) => {
49+
let name = match typed.pat.as_ref() {
50+
syn::Pat::Ident(syn::PatIdent {
51+
by_ref: None,
52+
mutability: None,
53+
ident,
54+
subpat: None,
55+
..
56+
}) => ident.clone(),
57+
pat => bail!(pat, "only simple identifiers are supported"),
58+
};
59+
60+
let ty = typed.ty.as_ref().clone();
61+
match ty {
62+
syn::Type::Reference(syn::TypeReference {
63+
mutability: Some(_),
64+
..
65+
}) => {
66+
bail!(
67+
typed.ty,
68+
"memoized functions cannot have mutable parameters"
69+
)
70+
}
71+
_ => {}
72+
}
73+
74+
args.push(Argument::Ident(name));
3175
}
32-
};
33-
34-
let name = match typed.pat.as_ref() {
35-
syn::Pat::Ident(syn::PatIdent {
36-
by_ref: None,
37-
mutability: None,
38-
ident,
39-
subpat: None,
40-
..
41-
}) => ident.clone(),
42-
pat => bail!(pat, "only simple identifiers are supported"),
43-
};
44-
45-
let ty = typed.ty.as_ref().clone();
46-
args.push(name);
47-
types.push(ty);
76+
}
4877
}
4978

5079
let output = match &function.sig.output {
@@ -58,36 +87,49 @@ fn prepare(function: &syn::ItemFn) -> Result<Function> {
5887
item: function.clone(),
5988
name: function.sig.ident.clone(),
6089
args,
61-
types,
6290
output,
6391
})
6492
}
6593

6694
/// Rewrite a function's body to memoize it.
6795
fn process(function: &Function) -> Result<TokenStream> {
96+
// Construct assertions that the arguments fulfill the necessary bounds.
97+
let bounds = function.args.iter().map(|arg| {
98+
quote_spanned! { function.item.span() =>
99+
::comemo::internal::assert_hashable_or_trackable(&#arg);
100+
}
101+
});
102+
68103
// Construct a tuple from all arguments.
69-
let args = &function.args;
104+
let args = function.args.iter().map(|arg| match arg {
105+
Argument::Ident(id) => id.to_token_stream(),
106+
Argument::Receiver(token) => quote! {
107+
::comemo::internal::hash(&#token)
108+
},
109+
});
70110
let arg_tuple = quote! { (#(#args,)*) };
71111

72-
// Construct assertions that the arguments fulfill the necessary bounds.
73-
let bounds = function.types.iter().map(|ty| {
74-
quote! {
75-
::comemo::internal::assert_hashable_or_trackable::<#ty>();
76-
}
112+
// Construct a tuple for all parameters.
113+
let params = function.args.iter().map(|arg| match arg {
114+
Argument::Ident(id) => id.to_token_stream(),
115+
Argument::Receiver(_) => quote! { _ },
77116
});
117+
let param_tuple = quote! { (#(#params,)*) };
78118

79119
// Construct the inner closure.
80120
let output = &function.output;
81121
let body = &function.item.block;
82-
let closure = quote! { |#arg_tuple| -> #output #body };
122+
let closure = quote! { |#param_tuple| -> #output #body };
83123

84124
// Adjust the function's body.
85125
let mut wrapped = function.item.clone();
86126
let name = function.name.to_string();
87127
wrapped.block = parse_quote! { {
128+
struct __ComemoUnique;
88129
#(#bounds;)*
89-
::comemo::internal::cached(
130+
::comemo::internal::memoized(
90131
#name,
132+
::core::any::TypeId::of::<__ComemoUnique>(),
91133
::comemo::internal::Args(#arg_tuple),
92134
#closure,
93135
)

src/cache.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@ use crate::input::Input;
1010
use crate::internal::Family;
1111

1212
/// Execute a function or use a cached result for it.
13-
pub fn cached<In, Out, F>(name: &str, input: In, func: F) -> Out
13+
pub fn memoized<In, Out, F>(name: &'static str, unique: TypeId, input: In, func: F) -> Out
1414
where
1515
In: Input,
1616
Out: Debug + Clone + 'static,
17-
F: for<'f> Fn(<In::Tracked as Family<'f>>::Out) -> Out + 'static,
17+
F: for<'f> FnOnce(<In::Tracked as Family<'f>>::Out) -> Out,
1818
{
1919
// Compute the hash of the input's key part.
2020
let hash = {
2121
let mut state = SipHasher::new();
22-
TypeId::of::<F>().hash(&mut state);
22+
unique.hash(&mut state);
2323
input.key(&mut state);
2424
state.finish128().as_u128()
2525
};

src/input.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::internal::Family;
66
use crate::track::{from_parts, to_parts, Track, Trackable, Tracked};
77

88
/// Ensure a type is suitable as input.
9-
pub fn assert_hashable_or_trackable<In: Input>() {}
9+
pub fn assert_hashable_or_trackable<In: Input>(_: &In) {}
1010

1111
/// An input to a cached function.
1212
///

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub use comemo_macros::{memoize, track};
1313
/// These are implementation details. Do not rely on them!
1414
#[doc(hidden)]
1515
pub mod internal {
16-
pub use crate::cache::cached;
16+
pub use crate::cache::memoized;
1717
pub use crate::constraint::{hash, Join, MultiConstraint, SoloConstraint};
1818
pub use crate::input::{assert_hashable_or_trackable, Args};
1919
pub use crate::track::{to_parts, Trackable};

src/main.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
use comemo::{Track, Tracked};
44

55
// TODO
6-
// - Bring over Prehashed
76
// - Tracked return value from tracked method
8-
// - Memoized methods
97
// - Reporting and evicting
108

119
fn main() {
@@ -32,9 +30,11 @@ fn main() {
3230

3331
/// Format the image's size humanly readable.
3432
fn describe(image: Tracked<Image>) -> &'static str {
35-
::comemo::internal::assert_hashable_or_trackable::<Tracked<Image>>();
36-
::comemo::internal::cached(
33+
struct __ComemoUnique;
34+
::comemo::internal::assert_hashable_or_trackable(&image);
35+
::comemo::internal::memoized(
3736
"describe",
37+
::core::any::TypeId::of::<__ComemoUnique>(),
3838
::comemo::internal::Args((image,)),
3939
|(image,)| {
4040
if image.width() > 50 || image.height() > 50 {

tests/kinds.rs

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::hash::Hash;
2+
13
use comemo::Track;
24

35
#[test]
@@ -23,6 +25,11 @@ fn test_kinds() {
2325
unconditional(tracky); // [Miss] The combined length changed.
2426
conditional(tracky, "World"); // [Miss] "World" is now shorter.
2527
ignorant(tracky, "Ignorant"); // [Hit] Doesn't depend on `tester`.
28+
29+
Taker("Hello".into()).take(); // [Miss] Never called.
30+
Taker("Hello".into()).copy(); // [Miss] Never called.
31+
Taker("World".into()).take(); // [Miss] Different value.
32+
Taker("Hello".into()).take(); // [Hit] Same value.
2633
}
2734

2835
/// Always accesses data from both arguments.
@@ -37,14 +44,17 @@ fn unconditional(tester: Tracky) -> &'static str {
3744

3845
/// Accesses data from both arguments conditionally.
3946
#[comemo::memoize]
40-
fn conditional(tester: Tracky, name: &str) -> String {
41-
tester.double_ref(name).to_string()
47+
fn conditional<T>(tester: Tracky, name: T) -> String
48+
where
49+
T: AsRef<str> + Hash,
50+
{
51+
tester.double_ref(name.as_ref()).to_string()
4252
}
4353

4454
/// Accesses only data from the second argument.
4555
#[comemo::memoize]
46-
fn ignorant(tester: Tracky, name: &str) -> String {
47-
tester.arg_ref(name).to_string()
56+
fn ignorant(tester: Tracky, name: impl AsRef<str> + Hash) -> String {
57+
tester.arg_ref(name.as_ref()).to_string()
4858
}
4959

5060
/// Test with type alias.
@@ -81,3 +91,18 @@ impl Tester {
8191
/// A non-copy struct that is passed by value to a tracked method.
8292
#[derive(Clone, PartialEq)]
8393
struct Heavy(String);
94+
95+
#[derive(Hash)]
96+
struct Taker(String);
97+
98+
impl Taker {
99+
#[comemo::memoize]
100+
fn copy(&self) -> String {
101+
self.0.clone()
102+
}
103+
104+
#[comemo::memoize]
105+
fn take(self) -> String {
106+
self.0
107+
}
108+
}

tests/layout.rs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
use comemo::{Prehashed, Track, Tracked};
2+
3+
#[test]
4+
fn test_layout() {
5+
let par = Paragraph(vec![
6+
TextRun {
7+
font: "Helvetica".into(),
8+
text: "HELLO ".into(),
9+
},
10+
TextRun {
11+
font: "Futura".into(),
12+
text: "WORLD!".into(),
13+
},
14+
]);
15+
16+
let mut fonts = Fonts::default();
17+
fonts.insert("Helvetica", Style::Normal, vec![110; 75398]);
18+
fonts.insert("Futura", Style::Italic, vec![55; 12453]);
19+
20+
// [Miss] The cache is empty.
21+
par.layout(fonts.track());
22+
fonts.insert("Verdana", Style::Normal, vec![99; 12554]);
23+
24+
// [Hit] Verdana isn't used.
25+
par.layout(fonts.track());
26+
fonts.insert("Helvetica", Style::Bold, vec![120; 98532]);
27+
28+
// [Miss] Helvetica changed.
29+
par.layout(fonts.track());
30+
}
31+
32+
/// A paragraph composed from text runs.
33+
#[derive(Debug, Hash)]
34+
struct Paragraph(Vec<TextRun>);
35+
36+
impl Paragraph {
37+
/// A memoized method.
38+
#[comemo::memoize]
39+
fn layout(&self, fonts: Tracked<Fonts>) -> String {
40+
let mut result = String::new();
41+
for run in &self.0 {
42+
let font = fonts.select(&run.font).unwrap();
43+
for c in run.text.chars() {
44+
result.push(font.map(c));
45+
}
46+
}
47+
result
48+
}
49+
}
50+
51+
/// A run of text with consistent font.
52+
#[derive(Debug, Hash)]
53+
struct TextRun {
54+
font: String,
55+
text: String,
56+
}
57+
58+
/// Holds all fonts.
59+
///
60+
/// As font data is large and costly to hash, we use the `Prehashed` wrapper.
61+
/// Otherwise, every call to `Fonts::select` would hash the returned font from
62+
/// scratch.
63+
#[derive(Default)]
64+
struct Fonts(Vec<Prehashed<Font>>);
65+
66+
impl Fonts {
67+
/// Insert a new with name and data.
68+
fn insert(&mut self, name: impl Into<String>, style: Style, data: Vec<u8>) {
69+
let name = name.into();
70+
self.0.retain(|font| font.name != name);
71+
self.0.push(Prehashed::new(Font { name, style, data }))
72+
}
73+
}
74+
75+
#[comemo::track]
76+
impl Fonts {
77+
/// Select a font by name.
78+
fn select(&self, name: &str) -> Option<&Prehashed<Font>> {
79+
self.0.iter().find(|font| font.name == name)
80+
}
81+
}
82+
83+
/// A large binary font.
84+
#[derive(Hash)]
85+
struct Font {
86+
name: String,
87+
data: Vec<u8>,
88+
style: Style,
89+
}
90+
91+
impl Font {
92+
/// Map a character.
93+
fn map(&self, c: char) -> char {
94+
let base = match self.style {
95+
Style::Normal => 0x41,
96+
Style::Bold => 0x1D400,
97+
Style::Italic => 0x1D434,
98+
};
99+
100+
if c.is_ascii_alphabetic() {
101+
std::char::from_u32(base + (c as u32 - 0x41)).unwrap()
102+
} else {
103+
c
104+
}
105+
}
106+
}
107+
108+
/// A font style.
109+
#[derive(Hash)]
110+
enum Style {
111+
Normal,
112+
Italic,
113+
Bold,
114+
}

0 commit comments

Comments
 (0)