Skip to content

Commit 25ec4f2

Browse files
committed
wip buffer_type helper but without auto downcasting, doesnt work due to AnyBuffer limitations
Signed-off-by: Teo Koon Peng <teokoonpeng@gmail.com>
1 parent 01af23f commit 25ec4f2

File tree

3 files changed

+124
-33
lines changed

3 files changed

+124
-33
lines changed

macros/src/buffer.rs

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
11
use proc_macro2::TokenStream;
22
use quote::{format_ident, quote};
3-
use syn::{parse_quote, Ident, ItemStruct, Type};
3+
use syn::{parse_quote, Field, Generics, Ident, ItemStruct, Type, TypePath};
44

55
use crate::Result;
66

77
pub(crate) fn impl_joined_value(input_struct: &ItemStruct) -> Result<TokenStream> {
88
let struct_ident = &input_struct.ident;
99
let (impl_generics, ty_generics, where_clause) = input_struct.generics.split_for_impl();
10-
let (field_ident, field_type) = get_fields_map(&input_struct.fields)?;
11-
let BufferStructConfig {
12-
struct_name: buffer_struct_ident,
13-
} = BufferStructConfig::from_data_struct(&input_struct);
10+
let StructConfig {
11+
buffer_struct_name: buffer_struct_ident,
12+
} = StructConfig::from_data_struct(&input_struct);
1413
let buffer_struct_vis = &input_struct.vis;
1514

15+
let (field_ident, _, field_config) = get_fields_map(&input_struct.fields)?;
16+
let buffer_type: Vec<&Type> = field_config
17+
.iter()
18+
.map(|config| &config.buffer_type)
19+
.collect();
20+
1621
let buffer_struct: ItemStruct = parse_quote! {
1722
#[derive(Clone)]
1823
#[allow(non_camel_case_types)]
1924
#buffer_struct_vis struct #buffer_struct_ident #impl_generics #where_clause {
2025
#(
21-
#buffer_struct_vis #field_ident: ::bevy_impulse::Buffer<#field_type>,
26+
#buffer_struct_vis #field_ident: #buffer_type,
2227
)*
2328
}
2429
};
@@ -36,7 +41,7 @@ pub(crate) fn impl_joined_value(input_struct: &ItemStruct) -> Result<TokenStream
3641
impl #impl_generics #struct_ident #ty_generics #where_clause {
3742
fn select_buffers(
3843
#(
39-
#field_ident: ::bevy_impulse::Buffer<#field_type>,
44+
#field_ident: #buffer_type,
4045
)*
4146
) -> #buffer_struct_ident #ty_generics {
4247
#buffer_struct_ident {
@@ -55,14 +60,30 @@ pub(crate) fn impl_joined_value(input_struct: &ItemStruct) -> Result<TokenStream
5560
Ok(gen.into())
5661
}
5762

58-
struct BufferStructConfig {
59-
struct_name: Ident,
63+
/// Converts a list of generics to a [`PhantomData`] TypePath.
64+
/// e.g. `::std::marker::PhantomData<(T,)>`
65+
// Currently unused but could be used in the future
66+
fn _to_phantom_data(generics: &Generics) -> TypePath {
67+
let lifetimes: Vec<Type> = generics
68+
.lifetimes()
69+
.map(|lt| {
70+
let lt = &lt.lifetime;
71+
let ty: Type = parse_quote! { & #lt () };
72+
ty
73+
})
74+
.collect();
75+
let ty_params: Vec<&Ident> = generics.type_params().map(|ty| &ty.ident).collect();
76+
parse_quote! { ::std::marker::PhantomData<(#(#lifetimes,)* #(#ty_params,)*)> }
77+
}
78+
79+
struct StructConfig {
80+
buffer_struct_name: Ident,
6081
}
6182

62-
impl BufferStructConfig {
83+
impl StructConfig {
6384
fn from_data_struct(data_struct: &ItemStruct) -> Self {
6485
let mut config = Self {
65-
struct_name: format_ident!("__bevy_impulse_{}_Buffers", data_struct.ident),
86+
buffer_struct_name: format_ident!("__bevy_impulse_{}_Buffers", data_struct.ident),
6687
};
6788

6889
let attr = data_struct
@@ -72,8 +93,39 @@ impl BufferStructConfig {
7293

7394
if let Some(attr) = attr {
7495
attr.parse_nested_meta(|meta| {
75-
if meta.path.is_ident("struct_name") {
76-
config.struct_name = meta.value()?.parse()?;
96+
if meta.path.is_ident("buffer_struct_name") {
97+
config.buffer_struct_name = meta.value()?.parse()?;
98+
}
99+
Ok(())
100+
})
101+
// panic if attribute is malformed, this will result in a compile error which is intended.
102+
.unwrap();
103+
}
104+
105+
config
106+
}
107+
}
108+
109+
struct FieldConfig {
110+
buffer_type: Type,
111+
}
112+
113+
impl FieldConfig {
114+
fn from_field(field: &Field) -> Self {
115+
let ty = &field.ty;
116+
let mut config = Self {
117+
buffer_type: parse_quote! { ::bevy_impulse::Buffer<#ty> },
118+
};
119+
120+
let attr = field
121+
.attrs
122+
.iter()
123+
.find(|attr| attr.path().is_ident("buffers"));
124+
125+
if let Some(attr) = attr {
126+
attr.parse_nested_meta(|meta| {
127+
if meta.path.is_ident("buffer_type") {
128+
config.buffer_type = meta.value()?.parse()?;
77129
}
78130
Ok(())
79131
})
@@ -85,20 +137,22 @@ impl BufferStructConfig {
85137
}
86138
}
87139

88-
fn get_fields_map(fields: &syn::Fields) -> Result<(Vec<&Ident>, Vec<&Type>)> {
140+
fn get_fields_map(fields: &syn::Fields) -> Result<(Vec<&Ident>, Vec<&Type>, Vec<FieldConfig>)> {
89141
match fields {
90142
syn::Fields::Named(data) => {
91143
let mut idents = Vec::new();
92144
let mut types = Vec::new();
145+
let mut configs = Vec::new();
93146
for field in &data.named {
94147
let ident = field
95148
.ident
96149
.as_ref()
97150
.ok_or("expected named fields".to_string())?;
98151
idents.push(ident);
99152
types.push(&field.ty);
153+
configs.push(FieldConfig::from_field(field));
100154
}
101-
Ok((idents, types))
155+
Ok((idents, types, configs))
102156
}
103157
_ => return Err("expected named fields".to_string()),
104158
}
@@ -113,7 +167,11 @@ fn impl_buffer_map_layout(
113167
) -> Result<proc_macro2::TokenStream> {
114168
let struct_ident = &buffer_struct.ident;
115169
let (impl_generics, ty_generics, where_clause) = buffer_struct.generics.split_for_impl();
116-
let (field_ident, field_type) = get_fields_map(&item_struct.fields)?;
170+
let (field_ident, _, field_config) = get_fields_map(&item_struct.fields)?;
171+
let buffer_type: Vec<&Type> = field_config
172+
.iter()
173+
.map(|config| &config.buffer_type)
174+
.collect();
117175
let map_key: Vec<String> = field_ident.iter().map(|v| v.to_string()).collect();
118176

119177
Ok(quote! {
@@ -128,16 +186,18 @@ fn impl_buffer_map_layout(
128186
fn try_from_buffer_map(buffers: &::bevy_impulse::BufferMap) -> Result<Self, ::bevy_impulse::IncompatibleLayout> {
129187
let mut compatibility = ::bevy_impulse::IncompatibleLayout::default();
130188
#(
131-
let #field_ident = if let Ok(buffer) = compatibility.require_message_type::<#field_type>(#map_key, buffers) {
189+
let #field_ident = if let Ok(buffer) = compatibility.require_buffer_type::<#buffer_type>(#map_key, buffers) {
132190
buffer
133191
} else {
134192
return Err(compatibility);
135193
};
136194
)*
137195

138-
Ok(Self {#(
139-
#field_ident,
140-
)*})
196+
Ok(Self {
197+
#(
198+
#field_ident,
199+
)*
200+
})
141201
}
142202
}
143203
}
@@ -154,7 +214,7 @@ fn impl_joined(
154214
let struct_ident = &joined_struct.ident;
155215
let item_struct_ident = &item_struct.ident;
156216
let (impl_generics, ty_generics, where_clause) = item_struct.generics.split_for_impl();
157-
let (field_ident, _) = get_fields_map(&item_struct.fields)?;
217+
let (field_ident, _, _) = get_fields_map(&item_struct.fields)?;
158218

159219
Ok(quote! {
160220
impl #impl_generics ::bevy_impulse::Joined for #struct_ident #ty_generics #where_clause {

src/buffer/any_buffer.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ impl AnyBuffer {
121121
.ok()
122122
.map(|x| *x)
123123
}
124+
125+
pub fn as_any_buffer(&self) -> Self {
126+
self.clone().into()
127+
}
124128
}
125129

126130
impl<T: 'static + Send + Sync + Any> From<Buffer<T>> for AnyBuffer {

src/buffer/buffer_map.rs

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -315,12 +315,15 @@ impl Accessed for BufferMap {
315315
mod tests {
316316
use crate::{prelude::*, testing::*, BufferMap};
317317

318-
#[derive(Clone, JoinedValue)]
318+
#[derive(JoinedValue)]
319319
struct TestJoinedValue<T: Send + Sync + 'static + Clone> {
320320
integer: i64,
321321
float: f64,
322322
string: String,
323323
generic: T,
324+
#[buffers(buffer_type = AnyBuffer)]
325+
#[allow(unused)]
326+
any: AnyMessageBox,
324327
}
325328

326329
#[test]
@@ -332,26 +335,32 @@ mod tests {
332335
let buffer_f64 = builder.create_buffer(BufferSettings::default());
333336
let buffer_string = builder.create_buffer(BufferSettings::default());
334337
let buffer_generic = builder.create_buffer(BufferSettings::default());
338+
let buffer_any = builder.create_buffer(BufferSettings::default());
335339

336340
let mut buffers = BufferMap::default();
337341
buffers.insert("integer", buffer_i64);
338342
buffers.insert("float", buffer_f64);
339343
buffers.insert("string", buffer_string);
340344
buffers.insert("generic", buffer_generic);
345+
buffers.insert("any", buffer_any);
341346

342347
scope.input.chain(builder).fork_unzip((
343348
|chain: Chain<_>| chain.connect(buffer_i64.input_slot()),
344349
|chain: Chain<_>| chain.connect(buffer_f64.input_slot()),
345350
|chain: Chain<_>| chain.connect(buffer_string.input_slot()),
346351
|chain: Chain<_>| chain.connect(buffer_generic.input_slot()),
352+
|chain: Chain<_>| chain.connect(buffer_any.input_slot()),
347353
));
348354

349355
builder.try_join(&buffers).unwrap().connect(scope.terminate);
350356
});
351357

352358
let mut promise = context.command(|commands| {
353359
commands
354-
.request((5_i64, 3.14_f64, "hello".to_string(), "world"), workflow)
360+
.request(
361+
(5_i64, 3.14_f64, "hello".to_string(), "world", ()),
362+
workflow,
363+
)
355364
.take_response()
356365
});
357366

@@ -373,27 +382,33 @@ mod tests {
373382
let buffer_f64 = builder.create_buffer(BufferSettings::default());
374383
let buffer_string = builder.create_buffer(BufferSettings::default());
375384
let buffer_generic = builder.create_buffer(BufferSettings::default());
385+
let buffer_any = builder.create_buffer::<()>(BufferSettings::default());
386+
387+
scope.input.chain(builder).fork_unzip((
388+
|chain: Chain<_>| chain.connect(buffer_i64.input_slot()),
389+
|chain: Chain<_>| chain.connect(buffer_f64.input_slot()),
390+
|chain: Chain<_>| chain.connect(buffer_string.input_slot()),
391+
|chain: Chain<_>| chain.connect(buffer_generic.input_slot()),
392+
|chain: Chain<_>| chain.connect(buffer_any.input_slot()),
393+
));
376394

377395
let buffers = TestJoinedValue::select_buffers(
378396
buffer_i64,
379397
buffer_f64,
380398
buffer_string,
381399
buffer_generic,
400+
buffer_any.into(),
382401
);
383402

384-
scope.input.chain(builder).fork_unzip((
385-
|chain: Chain<_>| chain.connect(buffers.integer.input_slot()),
386-
|chain: Chain<_>| chain.connect(buffers.float.input_slot()),
387-
|chain: Chain<_>| chain.connect(buffers.string.input_slot()),
388-
|chain: Chain<_>| chain.connect(buffers.generic.input_slot()),
389-
));
390-
391403
builder.join(buffers).connect(scope.terminate);
392404
});
393405

394406
let mut promise = context.command(|commands| {
395407
commands
396-
.request((5_i64, 3.14_f64, "hello".to_string(), "world"), workflow)
408+
.request(
409+
(5_i64, 3.14_f64, "hello".to_string(), "world", ()),
410+
workflow,
411+
)
397412
.take_response()
398413
});
399414

@@ -419,19 +434,25 @@ mod tests {
419434
JsonBuffer::from(builder.create_buffer::<String>(BufferSettings::default()));
420435
let buffer_generic =
421436
JsonBuffer::from(builder.create_buffer::<String>(BufferSettings::default()));
437+
let buffer_any =
438+
JsonBuffer::from(builder.create_buffer::<()>(BufferSettings::default()));
422439

423440
let buffers = TestJoinedValue::select_buffers(
424441
buffer_i64.downcast_for_message().unwrap(),
425442
buffer_f64.downcast_for_message().unwrap(),
426443
buffer_string.downcast_for_message().unwrap(),
427444
buffer_generic.downcast_for_message().unwrap(),
445+
buffer_any.downcast_buffer().unwrap(),
428446
);
429447

430448
scope.input.chain(builder).fork_unzip((
431449
|chain: Chain<_>| chain.connect(buffers.integer.input_slot()),
432450
|chain: Chain<_>| chain.connect(buffers.float.input_slot()),
433451
|chain: Chain<_>| chain.connect(buffers.string.input_slot()),
434452
|chain: Chain<_>| chain.connect(buffers.generic.input_slot()),
453+
|chain: Chain<_>| {
454+
chain.connect(buffers.any.downcast_for_message().unwrap().input_slot())
455+
},
435456
));
436457

437458
builder.join(buffers).connect(scope.terminate);
@@ -440,7 +461,13 @@ mod tests {
440461
let mut promise = context.command(|commands| {
441462
commands
442463
.request(
443-
(5_i64, 3.14_f64, "hello".to_string(), "world".to_string()),
464+
(
465+
5_i64,
466+
3.14_f64,
467+
"hello".to_string(),
468+
"world".to_string(),
469+
(),
470+
),
444471
workflow,
445472
)
446473
.take_response()
@@ -456,7 +483,7 @@ mod tests {
456483
}
457484

458485
#[derive(Clone, JoinedValue)]
459-
#[buffers(struct_name = FooBuffers)]
486+
#[buffers(buffer_struct_name = FooBuffers)]
460487
struct TestDeriveWithConfig {}
461488

462489
#[test]

0 commit comments

Comments
 (0)