Skip to content

Commit eadbcf2

Browse files
committed
Support reference arguments in tracked methods
1 parent 457274f commit eadbcf2

File tree

10 files changed

+358
-189
lines changed

10 files changed

+358
-189
lines changed

macros/src/lib.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
extern crate proc_macro;
22

3+
/// Return an error at the given item.
34
macro_rules! bail {
45
($item:expr, $fmt:literal $($tts:tt)*) => {
56
return Err(Error::new_spanned(
@@ -12,24 +13,26 @@ macro_rules! bail {
1213
mod memoize;
1314
mod track;
1415

15-
use proc_macro::TokenStream;
16-
use quote::quote;
16+
use proc_macro::TokenStream as BoundaryStream;
17+
use proc_macro2::TokenStream;
18+
use quote::{quote, quote_spanned};
19+
use syn::spanned::Spanned;
1720
use syn::{parse_quote, Error, Result};
1821

1922
/// Memoize a pure function.
2023
#[proc_macro_attribute]
21-
pub fn memoize(_: TokenStream, stream: TokenStream) -> TokenStream {
24+
pub fn memoize(_: BoundaryStream, stream: BoundaryStream) -> BoundaryStream {
2225
let func = syn::parse_macro_input!(stream as syn::ItemFn);
23-
memoize::expand(func)
26+
memoize::expand(&func)
2427
.unwrap_or_else(|err| err.to_compile_error())
2528
.into()
2629
}
2730

2831
/// Make a type trackable.
2932
#[proc_macro_attribute]
30-
pub fn track(_: TokenStream, stream: TokenStream) -> TokenStream {
33+
pub fn track(_: BoundaryStream, stream: BoundaryStream) -> BoundaryStream {
3134
let block = syn::parse_macro_input!(stream as syn::ItemImpl);
32-
track::expand(block)
35+
track::expand(&block)
3336
.unwrap_or_else(|err| err.to_compile_error())
3437
.into()
3538
}

macros/src/memoize.rs

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,33 @@
11
use super::*;
22

33
/// Memoize a function.
4-
pub fn expand(mut func: syn::ItemFn) -> Result<proc_macro2::TokenStream> {
4+
pub fn expand(item: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
5+
// Preprocess and validate the function.
6+
let function = prepare(&item)?;
7+
8+
// Rewrite the function's body to memoize it.
9+
process(&function)
10+
}
11+
12+
/// Details about a function that should be memoized.
13+
struct Function {
14+
item: syn::ItemFn,
15+
name: syn::Ident,
16+
args: Vec<syn::Ident>,
17+
types: Vec<syn::Type>,
18+
output: syn::Type,
19+
}
20+
21+
/// Preprocess and validate a function.
22+
fn prepare(function: &syn::ItemFn) -> Result<Function> {
523
let mut args = vec![];
624
let mut types = vec![];
7-
for input in &func.sig.inputs {
25+
26+
for input in &function.sig.inputs {
827
let typed = match input {
928
syn::FnArg::Typed(typed) => typed,
1029
syn::FnArg::Receiver(_) => {
11-
bail!(input, "methods are not supported")
30+
bail!(function, "methods are not supported")
1231
}
1332
};
1433

@@ -19,32 +38,53 @@ pub fn expand(mut func: syn::ItemFn) -> Result<proc_macro2::TokenStream> {
1938
ident,
2039
subpat: None,
2140
..
22-
}) => ident,
41+
}) => ident.clone(),
2342
pat => bail!(pat, "only simple identifiers are supported"),
2443
};
2544

26-
let ty = typed.ty.as_ref();
45+
let ty = typed.ty.as_ref().clone();
2746
args.push(name);
2847
types.push(ty);
2948
}
3049

50+
let output = match &function.sig.output {
51+
syn::ReturnType::Default => {
52+
bail!(function.sig, "function must have a return type")
53+
}
54+
syn::ReturnType::Type(_, ty) => ty.as_ref().clone(),
55+
};
56+
57+
Ok(Function {
58+
item: function.clone(),
59+
name: function.sig.ident.clone(),
60+
args,
61+
types,
62+
output,
63+
})
64+
}
65+
66+
/// Rewrite a function's body to memoize it.
67+
fn process(function: &Function) -> Result<TokenStream> {
3168
// Construct a tuple from all arguments.
69+
let args = &function.args;
3270
let arg_tuple = quote! { (#(#args,)*) };
3371

3472
// Construct assertions that the arguments fulfill the necessary bounds.
35-
let bounds = types.iter().map(|ty| {
73+
let bounds = function.types.iter().map(|ty| {
3674
quote! {
3775
::comemo::internal::assert_hashable_or_trackable::<#ty>();
3876
}
3977
});
4078

4179
// Construct the inner closure.
42-
let body = &func.block;
43-
let closure = quote! { |#arg_tuple| #body };
80+
let output = &function.output;
81+
let body = &function.item.block;
82+
let closure = quote! { |#arg_tuple| -> #output #body };
4483

4584
// Adjust the function's body.
46-
let name = func.sig.ident.to_string();
47-
func.block = parse_quote! { {
85+
let mut wrapped = function.item.clone();
86+
let name = function.name.to_string();
87+
wrapped.block = parse_quote! { {
4888
#(#bounds;)*
4989
::comemo::internal::cached(
5090
#name,
@@ -53,5 +93,5 @@ pub fn expand(mut func: syn::ItemFn) -> Result<proc_macro2::TokenStream> {
5393
)
5494
} };
5595

56-
Ok(quote! { #func })
96+
Ok(quote! { #wrapped })
5797
}

0 commit comments

Comments
 (0)