Skip to content

derive Joined to implement JoinedValue and BufferMapLayout #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
1137c08
derive joined
koonpeng Feb 5, 2025
9bc7cfe
remove use of unwrap
koonpeng Feb 5, 2025
83d6595
use full path
koonpeng Feb 5, 2025
5dbfa63
cleanup
koonpeng Feb 5, 2025
68681e4
derive BufferKeyMap
koonpeng Feb 6, 2025
1deada0
comments
koonpeng Feb 6, 2025
b612a05
Merge remote-tracking branch 'origin/buffer_map' into koonpeng/derive…
koonpeng Feb 11, 2025
af2caf7
support generics
koonpeng Feb 12, 2025
cff7138
add `select_buffers` to avoid the need to directly reference generate…
koonpeng Feb 12, 2025
bb88c8d
remove support for customizing buffers, select_buffers allows any buf…
koonpeng Feb 12, 2025
957f17d
revert changes to AnyBuffer
koonpeng Feb 12, 2025
dde0e6d
add helper attribute to customized generated buffer struct ident
koonpeng Feb 13, 2025
01af23f
remove builder argument in select_buffers; add test for select_buffer…
koonpeng Feb 13, 2025
25ec4f2
wip buffer_type helper but without auto downcasting, doesnt work due …
koonpeng Feb 13, 2025
943ea33
check for any
koonpeng Feb 13, 2025
37ee9b9
allow buffer_downcast to downcast back to the original Buffer<T>
koonpeng Feb 14, 2025
a00f09f
move test_select_buffers_json
koonpeng Feb 14, 2025
e61c3ff
put unused code into its own mod; to_phantom_data uses fn(...) to all…
koonpeng Feb 14, 2025
5299c25
rename helper attributes
koonpeng Feb 17, 2025
81687f9
Merge remote-tracking branch 'origin/buffer_map' into koonpeng/derive…
koonpeng Feb 17, 2025
b110e87
Test for generics in buffers, and fix Clone/Copy semantics
mxgrey Feb 17, 2025
31b3927
Fix style
mxgrey Feb 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 82 additions & 23 deletions macros/src/buffer.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_quote, Ident, ItemStruct, Type};
use syn::{parse_quote, Field, Generics, Ident, ItemStruct, Type, TypePath};

use crate::Result;

pub(crate) fn impl_joined_value(input_struct: &ItemStruct) -> Result<TokenStream> {
let struct_ident = &input_struct.ident;
let (impl_generics, ty_generics, where_clause) = input_struct.generics.split_for_impl();
let (field_ident, field_type) = get_fields_map(&input_struct.fields)?;
let BufferStructConfig {
struct_name: buffer_struct_ident,
} = BufferStructConfig::from_data_struct(&input_struct);
let StructConfig {
buffer_struct_name: buffer_struct_ident,
} = StructConfig::from_data_struct(&input_struct);
let buffer_struct_vis = &input_struct.vis;

let (field_ident, _, field_config) = get_fields_map(&input_struct.fields)?;
let buffer_type: Vec<&Type> = field_config
.iter()
.map(|config| &config.buffer_type)
.collect();

let buffer_struct: ItemStruct = parse_quote! {
#[derive(Clone)]
#[allow(non_camel_case_types)]
#buffer_struct_vis struct #buffer_struct_ident #impl_generics #where_clause {
#(
#buffer_struct_vis #field_ident: ::bevy_impulse::Buffer<#field_type>,
#buffer_struct_vis #field_ident: #buffer_type,
)*
}
};
Expand All @@ -35,14 +40,13 @@ pub(crate) fn impl_joined_value(input_struct: &ItemStruct) -> Result<TokenStream

impl #impl_generics #struct_ident #ty_generics #where_clause {
fn select_buffers(
builder: &mut ::bevy_impulse::Builder,
#(
#field_ident: impl ::bevy_impulse::Bufferable<BufferType = ::bevy_impulse::Buffer<#field_type>>,
#field_ident: #buffer_type,
)*
) -> #buffer_struct_ident #ty_generics {
#buffer_struct_ident {
#(
#field_ident: #field_ident.into_buffer(builder),
#field_ident,
)*
}
}
Expand All @@ -56,14 +60,30 @@ pub(crate) fn impl_joined_value(input_struct: &ItemStruct) -> Result<TokenStream
Ok(gen.into())
}

struct BufferStructConfig {
struct_name: Ident,
/// Converts a list of generics to a [`PhantomData`] TypePath.
/// e.g. `::std::marker::PhantomData<(T,)>`
// Currently unused but could be used in the future
fn _to_phantom_data(generics: &Generics) -> TypePath {
let lifetimes: Vec<Type> = generics
.lifetimes()
.map(|lt| {
let lt = &lt.lifetime;
let ty: Type = parse_quote! { & #lt () };
ty
})
.collect();
let ty_params: Vec<&Ident> = generics.type_params().map(|ty| &ty.ident).collect();
parse_quote! { ::std::marker::PhantomData<(#(#lifetimes,)* #(#ty_params,)*)> }
}

struct StructConfig {
buffer_struct_name: Ident,
}

impl BufferStructConfig {
impl StructConfig {
fn from_data_struct(data_struct: &ItemStruct) -> Self {
let mut config = Self {
struct_name: format_ident!("__bevy_impulse_{}_Buffers", data_struct.ident),
buffer_struct_name: format_ident!("__bevy_impulse_{}_Buffers", data_struct.ident),
};

let attr = data_struct
Expand All @@ -73,8 +93,8 @@ impl BufferStructConfig {

if let Some(attr) = attr {
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("struct_name") {
config.struct_name = meta.value()?.parse()?;
if meta.path.is_ident("buffer_struct_name") {
config.buffer_struct_name = meta.value()?.parse()?;
}
Ok(())
})
Expand All @@ -86,20 +106,53 @@ impl BufferStructConfig {
}
}

fn get_fields_map(fields: &syn::Fields) -> Result<(Vec<&Ident>, Vec<&Type>)> {
struct FieldConfig {
buffer_type: Type,
}

impl FieldConfig {
fn from_field(field: &Field) -> Self {
let ty = &field.ty;
let mut config = Self {
buffer_type: parse_quote! { ::bevy_impulse::Buffer<#ty> },
};

let attr = field
.attrs
.iter()
.find(|attr| attr.path().is_ident("buffers"));

if let Some(attr) = attr {
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("buffer_type") {
config.buffer_type = meta.value()?.parse()?;
}
Ok(())
})
// panic if attribute is malformed, this will result in a compile error which is intended.
.unwrap();
}

config
}
}

fn get_fields_map(fields: &syn::Fields) -> Result<(Vec<&Ident>, Vec<&Type>, Vec<FieldConfig>)> {
match fields {
syn::Fields::Named(data) => {
let mut idents = Vec::new();
let mut types = Vec::new();
let mut configs = Vec::new();
for field in &data.named {
let ident = field
.ident
.as_ref()
.ok_or("expected named fields".to_string())?;
idents.push(ident);
types.push(&field.ty);
configs.push(FieldConfig::from_field(field));
}
Ok((idents, types))
Ok((idents, types, configs))
}
_ => return Err("expected named fields".to_string()),
}
Expand All @@ -114,7 +167,11 @@ fn impl_buffer_map_layout(
) -> Result<proc_macro2::TokenStream> {
let struct_ident = &buffer_struct.ident;
let (impl_generics, ty_generics, where_clause) = buffer_struct.generics.split_for_impl();
let (field_ident, field_type) = get_fields_map(&item_struct.fields)?;
let (field_ident, _, field_config) = get_fields_map(&item_struct.fields)?;
let buffer_type: Vec<&Type> = field_config
.iter()
.map(|config| &config.buffer_type)
.collect();
let map_key: Vec<String> = field_ident.iter().map(|v| v.to_string()).collect();

Ok(quote! {
Expand All @@ -129,16 +186,18 @@ fn impl_buffer_map_layout(
fn try_from_buffer_map(buffers: &::bevy_impulse::BufferMap) -> Result<Self, ::bevy_impulse::IncompatibleLayout> {
let mut compatibility = ::bevy_impulse::IncompatibleLayout::default();
#(
let #field_ident = if let Ok(buffer) = compatibility.require_message_type::<#field_type>(#map_key, buffers) {
let #field_ident = if let Ok(buffer) = compatibility.require_buffer_type::<#buffer_type>(#map_key, buffers) {
buffer
} else {
return Err(compatibility);
};
)*

Ok(Self {#(
#field_ident,
)*})
Ok(Self {
#(
#field_ident,
)*
})
}
}
}
Expand All @@ -155,7 +214,7 @@ fn impl_joined(
let struct_ident = &joined_struct.ident;
let item_struct_ident = &item_struct.ident;
let (impl_generics, ty_generics, where_clause) = item_struct.generics.split_for_impl();
let (field_ident, _) = get_fields_map(&item_struct.fields)?;
let (field_ident, _, _) = get_fields_map(&item_struct.fields)?;

Ok(quote! {
impl #impl_generics ::bevy_impulse::Joined for #struct_ident #ty_generics #where_clause {
Expand Down
4 changes: 4 additions & 0 deletions src/buffer/any_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ impl AnyBuffer {
.ok()
.map(|x| *x)
}

pub fn as_any_buffer(&self) -> Self {
self.clone().into()
}
}

impl<T: 'static + Send + Sync + Any> From<Buffer<T>> for AnyBuffer {
Expand Down
87 changes: 81 additions & 6 deletions src/buffer/buffer_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,15 @@ impl Accessed for BufferMap {
mod tests {
use crate::{prelude::*, testing::*, BufferMap};

#[derive(Clone, JoinedValue)]
#[derive(JoinedValue)]
struct TestJoinedValue<T: Send + Sync + 'static + Clone> {
integer: i64,
float: f64,
string: String,
generic: T,
#[buffers(buffer_type = AnyBuffer)]
#[allow(unused)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would make sense to pass a real value through this AnyBuffer and test that it comes out correctly on the other side instead of leaving it unused.

any: AnyMessageBox,
}

#[test]
Expand All @@ -332,26 +335,32 @@ mod tests {
let buffer_f64 = builder.create_buffer(BufferSettings::default());
let buffer_string = builder.create_buffer(BufferSettings::default());
let buffer_generic = builder.create_buffer(BufferSettings::default());
let buffer_any = builder.create_buffer(BufferSettings::default());

let mut buffers = BufferMap::default();
buffers.insert("integer", buffer_i64);
buffers.insert("float", buffer_f64);
buffers.insert("string", buffer_string);
buffers.insert("generic", buffer_generic);
buffers.insert("any", buffer_any);

scope.input.chain(builder).fork_unzip((
|chain: Chain<_>| chain.connect(buffer_i64.input_slot()),
|chain: Chain<_>| chain.connect(buffer_f64.input_slot()),
|chain: Chain<_>| chain.connect(buffer_string.input_slot()),
|chain: Chain<_>| chain.connect(buffer_generic.input_slot()),
|chain: Chain<_>| chain.connect(buffer_any.input_slot()),
));

builder.try_join(&buffers).unwrap().connect(scope.terminate);
});

let mut promise = context.command(|commands| {
commands
.request((5_i64, 3.14_f64, "hello".to_string(), "world"), workflow)
.request(
(5_i64, 3.14_f64, "hello".to_string(), "world", ()),
workflow,
)
.take_response()
});

Expand All @@ -373,33 +382,99 @@ mod tests {
let buffer_f64 = builder.create_buffer(BufferSettings::default());
let buffer_string = builder.create_buffer(BufferSettings::default());
let buffer_generic = builder.create_buffer(BufferSettings::default());
let buffer_any = builder.create_buffer::<()>(BufferSettings::default());

scope.input.chain(builder).fork_unzip((
|chain: Chain<_>| chain.connect(buffer_i64.input_slot()),
|chain: Chain<_>| chain.connect(buffer_f64.input_slot()),
|chain: Chain<_>| chain.connect(buffer_string.input_slot()),
|chain: Chain<_>| chain.connect(buffer_generic.input_slot()),
|chain: Chain<_>| chain.connect(buffer_any.input_slot()),
));

let buffers = TestJoinedValue::select_buffers(
builder,
buffer_i64,
buffer_f64,
buffer_string,
buffer_generic,
buffer_any.into(),
);

builder.join(buffers).connect(scope.terminate);
});

let mut promise = context.command(|commands| {
commands
.request(
(5_i64, 3.14_f64, "hello".to_string(), "world", ()),
workflow,
)
.take_response()
});

context.run_with_conditions(&mut promise, Duration::from_secs(2));
let value: TestJoinedValue<&'static str> = promise.take().available().unwrap();
assert_eq!(value.integer, 5);
assert_eq!(value.float, 3.14);
assert_eq!(value.string, "hello");
assert_eq!(value.generic, "world");
assert!(context.no_unhandled_errors());
}

#[test]
fn test_select_buffers_json() {
let mut context = TestingContext::minimal_plugins();

let workflow = context.spawn_io_workflow(|scope, builder| {
let buffer_i64 =
JsonBuffer::from(builder.create_buffer::<i64>(BufferSettings::default()));
let buffer_f64 =
JsonBuffer::from(builder.create_buffer::<f64>(BufferSettings::default()));
let buffer_string =
JsonBuffer::from(builder.create_buffer::<String>(BufferSettings::default()));
let buffer_generic =
JsonBuffer::from(builder.create_buffer::<String>(BufferSettings::default()));
let buffer_any =
JsonBuffer::from(builder.create_buffer::<()>(BufferSettings::default()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may want to move this test into the json_buffer.rs module since JsonBuffer requires the diagram feature, but buffer_map.rs doesn't.


let buffers = TestJoinedValue::select_buffers(
buffer_i64.downcast_for_message().unwrap(),
buffer_f64.downcast_for_message().unwrap(),
buffer_string.downcast_for_message().unwrap(),
buffer_generic.downcast_for_message().unwrap(),
buffer_any.downcast_buffer().unwrap(),
);

scope.input.chain(builder).fork_unzip((
|chain: Chain<_>| chain.connect(buffers.integer.input_slot()),
|chain: Chain<_>| chain.connect(buffers.float.input_slot()),
|chain: Chain<_>| chain.connect(buffers.string.input_slot()),
|chain: Chain<_>| chain.connect(buffers.generic.input_slot()),
|chain: Chain<_>| {
chain.connect(buffers.any.downcast_for_message().unwrap().input_slot())
},
));

builder.join(buffers).connect(scope.terminate);
});

let mut promise = context.command(|commands| {
commands
.request((5_i64, 3.14_f64, "hello".to_string(), "world"), workflow)
.request(
(
5_i64,
3.14_f64,
"hello".to_string(),
"world".to_string(),
(),
),
workflow,
)
.take_response()
});

context.run_with_conditions(&mut promise, Duration::from_secs(2));
let value: TestJoinedValue<&'static str> = promise.take().available().unwrap();
let value: TestJoinedValue<String> = promise.take().available().unwrap();
assert_eq!(value.integer, 5);
assert_eq!(value.float, 3.14);
assert_eq!(value.string, "hello");
Expand All @@ -408,7 +483,7 @@ mod tests {
}

#[derive(Clone, JoinedValue)]
#[buffers(struct_name = FooBuffers)]
#[buffers(buffer_struct_name = FooBuffers)]
struct TestDeriveWithConfig {}

#[test]
Expand Down
Loading