Skip to content

Commit b20d4b7

Browse files
committed
Less lifetimes and global cache
1 parent 7f25460 commit b20d4b7

File tree

9 files changed

+254
-242
lines changed

9 files changed

+254
-242
lines changed

macros/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ mod memoize;
1313
mod track;
1414

1515
use proc_macro::TokenStream;
16-
use proc_macro2::Span;
1716
use quote::quote;
1817
use syn::{parse_quote, Error, Result};
1918

macros/src/memoize.rs

Lines changed: 7 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use super::*;
22

33
/// Memoize a function.
44
pub fn expand(func: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
5-
let name = func.sig.ident.to_string();
5+
let name = &func.sig.ident;
66

77
let mut args = vec![];
88
let mut types = vec![];
@@ -23,21 +23,12 @@ pub fn expand(func: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
2323
types.push(typed.ty.as_ref());
2424
}
2525

26-
let ret = match &func.sig.output {
27-
syn::ReturnType::Default => {
28-
bail!(func.sig, "function must have a return type")
29-
}
30-
syn::ReturnType::Type(.., ty) => ty.as_ref(),
31-
};
32-
33-
let mut inner = func.clone();
34-
inner.sig.ident = syn::Ident::new("inner", Span::call_site());
35-
3626
if args.len() != 1 {
3727
bail!(func, "expected exactly one argument");
3828
}
3929

4030
let arg = args[0];
31+
/*
4132
let ty = types[0];
4233
let track = match ty {
4334
syn::Type::Path(path) => {
@@ -59,56 +50,14 @@ pub fn expand(func: &syn::ItemFn) -> Result<proc_macro2::TokenStream> {
5950
}
6051
_ => bail!(ty, "expected type of the form `Tracked<_>`"),
6152
};
62-
63-
let trackable = quote! {
64-
<#track as ::comemo::internal::Trackable<'static>>
65-
};
66-
67-
let body = quote! {
68-
type Cache = ::core::cell::RefCell<
69-
::std::vec::Vec<(#trackable::Tracker, #ret)>
70-
>;
71-
72-
thread_local! {
73-
static CACHE: Cache = Default::default();
74-
}
75-
76-
let mut hit = true;
77-
let output = CACHE.with(|cache| {
78-
cache
79-
.borrow()
80-
.iter()
81-
.find(|(tracker, _)| {
82-
let (#arg, _) = ::comemo::internal::to_parts(#arg);
83-
#trackable::valid(#arg, tracker)
84-
})
85-
.map(|&(_, output)| output)
86-
});
87-
88-
let output = output.unwrap_or_else(|| {
89-
let tracker = ::core::default::Default::default();
90-
let (#arg, _) = ::comemo::internal::to_parts(#arg);
91-
let #arg = ::comemo::internal::from_parts(#arg, Some(&tracker));
92-
let output = inner(#arg);
93-
CACHE.with(|cache| cache.borrow_mut().push((tracker, output)));
94-
hit = false;
95-
output
96-
});
97-
98-
println!(
99-
"{} {} {}",
100-
#name,
101-
if hit { "[hit]: " } else { "[miss]:" },
102-
output,
103-
);
104-
105-
output
106-
};
53+
*/
10754

10855
let mut outer = func.clone();
10956
outer.block = parse_quote! { {
110-
#inner
111-
{ #body }
57+
#func
58+
::comemo::internal::CACHE.with(|cache|
59+
cache.query(stringify!(#name), #name, #arg)
60+
)
11261
} };
11362

11463
Ok(quote! { #outer })

macros/src/track.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,24 +46,29 @@ pub fn expand(block: &syn::ItemImpl) -> Result<proc_macro2::TokenStream> {
4646
let track_impl = quote! {
4747
use super::*;
4848

49-
impl<'a> ::comemo::Track<'a> for #ty {}
50-
impl<'a> ::comemo::internal::Trackable<'a> for #ty {
49+
impl ::comemo::Track for #ty {}
50+
impl ::comemo::internal::Trackable for #ty {
5151
type Tracker = Tracker;
52-
type Surface = Surface<'a>;
52+
type Surface = SurfaceFamily;
5353

5454
fn valid(&self, tracker: &Self::Tracker) -> bool {
5555
#(#tracked_valids)&&*
5656
}
5757

58-
fn surface<'s>(tracked: &'s Tracked<'a, #ty>) -> &'s Self::Surface
58+
fn surface<'a, 'r>(tracked: &'r Tracked<'a, #ty>) -> &'r Surface<'a>
5959
where
60-
Self: Track<'a>,
60+
Self: Track,
6161
{
6262
// Safety: Surface is repr(transparent).
63-
unsafe { &*(tracked as *const _ as *const Self::Surface) }
63+
unsafe { &*(tracked as *const _ as *const _) }
6464
}
6565
}
6666

67+
pub enum SurfaceFamily {}
68+
impl<'a> ::comemo::internal::Family<'a> for SurfaceFamily {
69+
type Out = Surface<'a>;
70+
}
71+
6772
#[repr(transparent)]
6873
pub struct Surface<'a>(Tracked<'a, #ty>);
6974

src/cache.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
use std::any::Any;
2+
use std::cell::RefCell;
3+
use std::fmt::Debug;
4+
5+
use crate::track::{from_parts, to_parts, Track, Trackable, Tracked};
6+
7+
thread_local! {
8+
/// The global, dynamic cache shared by all memoized functions.
9+
pub static CACHE: Cache = Cache::default();
10+
}
11+
12+
/// An untyped cache.
13+
#[derive(Default)]
14+
pub struct Cache {
15+
map: RefCell<Vec<Box<dyn Any>>>,
16+
}
17+
18+
/// An entry in the cache.
19+
struct Entry<Tracker, R> {
20+
tracker: Tracker,
21+
output: R,
22+
}
23+
24+
impl Cache {
25+
/// Execute `f` or use a cached result for it.
26+
pub fn query<F, T, R>(&self, name: &'static str, f: F, tracked: Tracked<T>) -> R
27+
where
28+
F: Fn(Tracked<T>) -> R,
29+
T: Track,
30+
R: Debug + Clone + 'static,
31+
{
32+
let mut hit = true;
33+
let output = self.lookup(tracked).unwrap_or_else(|| {
34+
let tracker = T::Tracker::default();
35+
let (inner, _) = to_parts(tracked);
36+
let tracked = from_parts(inner, Some(&tracker));
37+
let output = f(tracked);
38+
self.insert::<T, R>(tracker, output.clone());
39+
hit = false;
40+
output
41+
});
42+
43+
let label = if hit { "[hit]: " } else { "[miss]:" };
44+
eprintln!("{name} {label} {output:?}");
45+
46+
output
47+
}
48+
49+
/// Look for a matching entry in the cache.
50+
fn lookup<T, R>(&self, tracked: Tracked<T>) -> Option<R>
51+
where
52+
T: Track,
53+
R: Clone + 'static,
54+
{
55+
let (inner, _) = to_parts(tracked);
56+
self.map
57+
.borrow()
58+
.iter()
59+
.filter_map(|boxed| boxed.downcast_ref::<Entry<T::Tracker, R>>())
60+
.find(|entry| Trackable::valid(inner, &entry.tracker))
61+
.map(|entry| entry.output.clone())
62+
}
63+
64+
/// Insert an entry into the cache.
65+
fn insert<T, R>(&self, tracker: T::Tracker, output: R)
66+
where
67+
T: Track,
68+
R: 'static,
69+
{
70+
let entry = Entry { tracker, output };
71+
self.map.borrow_mut().push(Box::new(entry));
72+
}
73+
}

src/hash.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
use std::hash::Hash;
2+
3+
use siphasher::sip128::{Hasher128, SipHasher};
4+
use std::num::NonZeroU128;
5+
6+
/// Produce a non zero 128-bit hash of a value.
7+
pub fn siphash<T: Hash>(value: &T) -> NonZeroU128 {
8+
let mut state = SipHasher::new();
9+
value.hash(&mut state);
10+
state
11+
.finish128()
12+
.as_u128()
13+
.try_into()
14+
.unwrap_or(NonZeroU128::new(u128::MAX).unwrap())
15+
}

src/internal.rs

Lines changed: 0 additions & 66 deletions
This file was deleted.

src/lib.rs

Lines changed: 9 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,15 @@
11
//! Tracked memoization.
22
3-
// These are implementation details. Do not rely on them!
4-
#[doc(hidden)]
5-
pub mod internal;
3+
mod cache;
4+
mod hash;
5+
mod track;
66

7+
pub use crate::track::{Track, Tracked};
78
pub use comemo_macros::{memoize, track};
89

9-
use std::ops::Deref;
10-
11-
/// A trackable type.
12-
///
13-
/// This is implemented by types that have an implementation block annoted with
14-
/// [`#[track]`](track).
15-
pub trait Track<'a>: internal::Trackable<'a> {
16-
/// Start tracking a value.
17-
fn track(&'a self) -> Tracked<'a, Self> {
18-
Tracked { inner: self, tracker: None }
19-
}
20-
}
21-
22-
/// Tracks accesses to a value.
23-
///
24-
/// Encapsulates a reference to a value and tracks all accesses to it.
25-
/// The only methods accessible on `Tracked<T>` are those defined in an impl
26-
/// block for `T` annotated with [`#[track]`](track).
27-
///
28-
/// ```
29-
/// use comemo::Track;
30-
///
31-
/// let image = Image::random(20, 40);
32-
/// let sentence = describe(image.track());
33-
/// println!("{sentence}");
34-
/// ```
35-
pub struct Tracked<'a, T>
36-
where
37-
T: Track<'a>,
38-
{
39-
/// A reference to the tracked value.
40-
inner: &'a T,
41-
/// A tracker which stores constraints for T. It is filled by the tracked
42-
/// methods on T's generated surface type.
43-
///
44-
/// Starts out as `None` and is set to a stack-stored tracker in the
45-
/// preamble of memoized functions.
46-
tracker: Option<&'a T::Tracker>,
47-
}
48-
49-
// The type `Tracked<T>` automatically dereferences to T's generated surface
50-
// type. This makes all tracked methods available, but leaves all other ones
51-
// unaccessible.
52-
impl<'a, T> Deref for Tracked<'a, T>
53-
where
54-
T: Track<'a>,
55-
{
56-
type Target = T::Surface;
57-
58-
fn deref(&self) -> &Self::Target {
59-
T::surface(self)
60-
}
61-
}
62-
63-
impl<'a, T> Copy for Tracked<'a, T> where T: Track<'a> {}
64-
65-
impl<'a, T> Clone for Tracked<'a, T>
66-
where
67-
T: Track<'a>,
68-
{
69-
fn clone(&self) -> Self {
70-
*self
71-
}
10+
/// These are implementation details. Do not rely on them!
11+
#[doc(hidden)]
12+
pub mod internal {
13+
pub use crate::cache::CACHE;
14+
pub use crate::track::{from_parts, to_parts, AccessTracker, Family, Trackable};
7215
}

0 commit comments

Comments
 (0)