Skip to content

Commit ca7a226

Browse files
committed
Draft an example implementation of JoinedValue
Signed-off-by: Michael X. Grey <mxgrey@intrinsic.ai>
1 parent 5dcae45 commit ca7a226

File tree

4 files changed

+319
-22
lines changed

4 files changed

+319
-22
lines changed

src/buffer/any_buffer.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,18 @@ pub struct AnyBuffer {
4949
pub(crate) interface: &'static (dyn AnyBufferAccessInterface + Send + Sync)
5050
}
5151

52+
impl AnyBuffer {
53+
/// The buffer ID for this key.
54+
pub fn id(&self) -> Entity {
55+
self.source
56+
}
57+
58+
/// Get the type ID of the messages that this buffer supports.
59+
pub fn message_type_id(&self) -> TypeId {
60+
self.interface.message_type_id()
61+
}
62+
}
63+
5264
impl std::fmt::Debug for AnyBuffer {
5365
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5466
f.debug_struct("AnyBuffer")

src/buffer/buffer_map.rs

Lines changed: 288 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,35 @@ use smallvec::SmallVec;
2828
use bevy_ecs::prelude::{Entity, World};
2929

3030
use crate::{
31-
AnyBuffer, OperationError, OperationResult, OperationRoster, Buffered, Gate,
32-
Joined, Accessed, BufferKeyBuilder, AnyBufferKey,
31+
AnyBuffer, AddOperation, Chain, OperationError, OperationResult, OperationRoster, Buffered, Gate,
32+
Join, Joined, Accessed, BufferKeyBuilder, AnyBufferKey, Builder, Output, UnusedTarget, GateState,
33+
add_listener_to_source,
3334
};
3435

35-
#[derive(Clone)]
36+
#[derive(Clone, Default)]
3637
pub struct BufferMap {
3738
inner: HashMap<Cow<'static, str>, AnyBuffer>,
3839
}
3940

41+
impl BufferMap {
42+
/// Insert a named buffer into the map.
43+
pub fn insert(
44+
&mut self,
45+
name: Cow<'static, str>,
46+
buffer: impl Into<AnyBuffer>,
47+
) {
48+
self.inner.insert(name, buffer.into());
49+
}
50+
51+
/// Get one of the buffers from the map by its name.
52+
pub fn get(&self, name: &str) -> Option<&AnyBuffer> {
53+
self.inner.get(name)
54+
}
55+
}
56+
4057
/// This error is used when the buffers provided for an input are not compatible
4158
/// with the layout.
42-
#[derive(ThisError, Debug, Clone)]
59+
#[derive(ThisError, Debug, Clone, Default)]
4360
#[error("the incoming buffer map is incompatible with the layout")]
4461
pub struct IncompatibleLayout {
4562
/// Names of buffers that were missing from the incoming buffer map.
@@ -48,6 +65,38 @@ pub struct IncompatibleLayout {
4865
pub incompatible_buffers: Vec<BufferIncompatibility>,
4966
}
5067

68+
impl IncompatibleLayout {
69+
/// Check whether a named buffer is compatible with a specific type.
70+
pub fn require_buffer<T: 'static>(
71+
&mut self,
72+
expected_name: &str,
73+
buffers: &BufferMap,
74+
) {
75+
if let Some((name, buffer)) = buffers.inner.get_key_value(expected_name) {
76+
if buffer.message_type_id() != TypeId::of::<T>() {
77+
self.incompatible_buffers.push(BufferIncompatibility {
78+
name: name.clone(),
79+
expected: TypeId::of::<T>(),
80+
received: buffer.message_type_id(),
81+
});
82+
}
83+
} else {
84+
self.missing_buffers.push(Cow::Owned(expected_name.to_owned()));
85+
}
86+
}
87+
88+
/// Convert the instance into a result. If any field has content in it, then
89+
/// this will produce an [`Err`]. Otherwise if no incompatibilities are
90+
/// present, this will produce an [`Ok`].
91+
pub fn into_result(self) -> Result<(), IncompatibleLayout> {
92+
if self.missing_buffers.is_empty() && self.incompatible_buffers.is_empty() {
93+
Ok(())
94+
} else {
95+
Err(self)
96+
}
97+
}
98+
}
99+
51100
/// Difference between the expected and received types of a named buffer.
52101
#[derive(Debug, Clone)]
53102
pub struct BufferIncompatibility {
@@ -69,41 +118,82 @@ pub trait BufferMapLayout: Sized {
69118
world: &World,
70119
) -> Result<usize, OperationError>;
71120

121+
fn ensure_active_session(
122+
buffers: &BufferMap,
123+
session: Entity,
124+
world: &mut World,
125+
) -> OperationResult;
126+
72127
fn add_listener(
73128
buffers: &BufferMap,
74129
listener: Entity,
75130
world: &mut World,
76-
) -> OperationResult;
131+
) -> OperationResult {
132+
for buffer in buffers.inner.values() {
133+
add_listener_to_source(buffer.id(), listener, world)?;
134+
}
135+
Ok(())
136+
}
77137

78138
fn gate_action(
79139
buffers: &BufferMap,
80140
session: Entity,
81141
action: Gate,
82142
world: &mut World,
83143
roster: &mut OperationRoster,
84-
) -> OperationResult;
85-
86-
fn as_input(buffers: &BufferMap) -> SmallVec<[Entity; 8]>;
144+
) -> OperationResult {
145+
for buffer in buffers.inner.values() {
146+
GateState::apply(buffer.id(), session, action, world, roster)?;
147+
}
148+
Ok(())
149+
}
87150

88-
fn ensure_active_session(
89-
buffers: &BufferMap,
90-
session: Entity,
91-
world: &mut World,
92-
) -> OperationResult;
151+
fn as_input(buffers: &BufferMap) -> SmallVec<[Entity; 8]> {
152+
let mut inputs = SmallVec::new();
153+
for buffer in buffers.inner.values() {
154+
inputs.push(buffer.id());
155+
}
156+
inputs
157+
}
93158
}
94159

95-
pub trait JoinedValue: BufferMapLayout {
96-
fn buffered_count(
97-
buffers: &BufferMap,
98-
session: Entity,
99-
world: &World,
100-
) -> Result<usize, OperationError>;
160+
pub trait JoinedValue: 'static + BufferMapLayout + Send + Sync{
161+
/// This associated type must represent a buffer map layout that is
162+
/// guaranteed to be compatible for this JoinedValue. Failure to implement
163+
/// this trait accordingly will result in panics.
164+
type Buffers: Into<BufferMap>;
101165

102166
fn pull(
103167
buffers: &BufferMap,
104168
session: Entity,
105-
world: &World,
169+
world: &mut World,
106170
) -> Result<Self, OperationError>;
171+
172+
fn join_into<'w, 's, 'a, 'b>(
173+
buffers: Self::Buffers,
174+
builder: &'b mut Builder<'w, 's, 'a>,
175+
) -> Chain<'w, 's, 'a, 'b, Self> {
176+
Self::try_join_into(buffers.into(), builder).unwrap()
177+
}
178+
179+
fn try_join_into<'w, 's, 'a, 'b>(
180+
buffers: BufferMap,
181+
builder: &'b mut Builder<'w, 's, 'a>,
182+
) -> Result<Chain<'w, 's, 'a, 'b, Self>, IncompatibleLayout> {
183+
let scope = builder.scope();
184+
let buffers = BufferedMap::<Self>::new(buffers)?;
185+
buffers.verify_scope(scope);
186+
187+
let join = builder.commands.spawn(()).id();
188+
let target = builder.commands.spawn(UnusedTarget).id();
189+
builder.commands.add(AddOperation::new(
190+
Some(scope),
191+
join,
192+
Join::new(buffers, target),
193+
));
194+
195+
Ok(Output::new(scope, target).chain(builder))
196+
}
107197
}
108198

109199
/// Trait to describe a layout of buffer keys
@@ -126,6 +216,13 @@ struct BufferedMap<K> {
126216
_ignore: std::marker::PhantomData<fn(K)>,
127217
}
128218

219+
impl<K: BufferMapLayout> BufferedMap<K> {
220+
fn new(map: BufferMap) -> Result<Self, IncompatibleLayout> {
221+
K::is_compatible(&map)?;
222+
Ok(Self { map, _ignore: Default::default() })
223+
}
224+
}
225+
129226
impl<K> Clone for BufferedMap<K> {
130227
fn clone(&self) -> Self {
131228
Self { map: self.map.clone(), _ignore: Default::default() }
@@ -274,3 +371,174 @@ impl BufferMapLayout for AnyBufferKeyMap {
274371
Ok(())
275372
}
276373
}
374+
375+
#[cfg(test)]
376+
mod tests {
377+
use std::borrow::Cow;
378+
379+
use crate::{
380+
prelude::*,
381+
testing::*,
382+
OperationResult, OperationError, OrBroken, InspectBuffer, ManageBuffer, BufferMap,
383+
};
384+
385+
use bevy_ecs::prelude::World;
386+
387+
#[derive(Clone)]
388+
struct TestJoinedValue {
389+
integer: i64,
390+
float: f64,
391+
string: String,
392+
}
393+
394+
impl BufferMapLayout for TestJoinedValue {
395+
fn is_compatible(buffers: &BufferMap) -> Result<(), IncompatibleLayout> {
396+
let mut compatibility = IncompatibleLayout::default();
397+
compatibility.require_buffer::<i64>("integer", buffers);
398+
compatibility.require_buffer::<f64>("float", buffers);
399+
compatibility.require_buffer::<String>("string", buffers);
400+
compatibility.into_result()
401+
}
402+
403+
fn buffered_count(
404+
buffers: &BufferMap,
405+
session: Entity,
406+
world: &World,
407+
) -> Result<usize, OperationError> {
408+
let integer_count = world
409+
.get_entity(buffers.get("integer").unwrap().id())
410+
.or_broken()?
411+
.buffered_count::<i64>(session)?;
412+
413+
let float_count = world
414+
.get_entity(buffers.get("float").unwrap().id())
415+
.or_broken()?
416+
.buffered_count::<f64>(session)?;
417+
418+
let string_count = world
419+
.get_entity(buffers.get("string").unwrap().id())
420+
.or_broken()?
421+
.buffered_count::<String>(session)?;
422+
423+
Ok(
424+
[
425+
integer_count,
426+
float_count,
427+
string_count,
428+
]
429+
.iter()
430+
.min()
431+
.copied()
432+
.unwrap_or(0)
433+
)
434+
}
435+
436+
fn ensure_active_session(
437+
buffers: &BufferMap,
438+
session: Entity,
439+
world: &mut World,
440+
) -> OperationResult {
441+
world
442+
.get_entity_mut(buffers.get("integer").unwrap().id())
443+
.or_broken()?
444+
.ensure_session::<i64>(session)?;
445+
446+
world
447+
.get_entity_mut(buffers.get("float").unwrap().id())
448+
.or_broken()?
449+
.ensure_session::<f64>(session)?;
450+
451+
world
452+
.get_entity_mut(buffers.get("string").unwrap().id())
453+
.or_broken()?
454+
.ensure_session::<String>(session)?;
455+
456+
Ok(())
457+
}
458+
}
459+
460+
impl JoinedValue for TestJoinedValue {
461+
type Buffers = TestJoinedValueBuffers;
462+
463+
fn pull(
464+
buffers: &BufferMap,
465+
session: Entity,
466+
world: &mut World,
467+
) -> Result<Self, OperationError> {
468+
let integer = world
469+
.get_entity_mut(buffers.get("integer").unwrap().id())
470+
.or_broken()?
471+
.pull_from_buffer::<i64>(session)?;
472+
473+
let float = world
474+
.get_entity_mut(buffers.get("float").unwrap().id())
475+
.or_broken()?
476+
.pull_from_buffer::<f64>(session)?;
477+
478+
let string = world
479+
.get_entity_mut(buffers.get("string").unwrap().id())
480+
.or_broken()?
481+
.pull_from_buffer::<String>(session)?;
482+
483+
Ok(Self { integer, float, string })
484+
}
485+
}
486+
487+
struct TestJoinedValueBuffers {
488+
integer: Buffer<i64>,
489+
float: Buffer<f64>,
490+
string: Buffer<String>,
491+
}
492+
493+
impl From<TestJoinedValueBuffers> for BufferMap {
494+
fn from(value: TestJoinedValueBuffers) -> Self {
495+
let mut buffers = BufferMap::default();
496+
buffers.insert(Cow::Borrowed("integer"), value.integer);
497+
buffers.insert(Cow::Borrowed("float"), value.float);
498+
buffers.insert(Cow::Borrowed("string"), value.string);
499+
buffers
500+
}
501+
}
502+
503+
#[test]
504+
fn test_joined_value() {
505+
let mut context = TestingContext::minimal_plugins();
506+
507+
let workflow = context.spawn_io_workflow(|scope, builder| {
508+
let buffer_i64 = builder.create_buffer(BufferSettings::default());
509+
let buffer_f64 = builder.create_buffer(BufferSettings::default());
510+
let buffer_string = builder.create_buffer(BufferSettings::default());
511+
512+
let mut buffers = BufferMap::default();
513+
buffers.insert(Cow::Borrowed("integer"), buffer_i64);
514+
buffers.insert(Cow::Borrowed("float"), buffer_f64);
515+
buffers.insert(Cow::Borrowed("string"), buffer_string);
516+
517+
scope
518+
.input
519+
.chain(builder)
520+
.fork_unzip((
521+
|chain: Chain<_>| chain.connect(buffer_i64.input_slot()),
522+
|chain: Chain<_>| chain.connect(buffer_f64.input_slot()),
523+
|chain: Chain<_>| chain.connect(buffer_string.input_slot()),
524+
));
525+
526+
builder.try_join_into(buffers).unwrap().connect(scope.terminate);
527+
});
528+
529+
let mut promise = context.command(
530+
|commands| commands.request(
531+
(5_i64, 3.14_f64, "hello".to_string()),
532+
workflow,
533+
)
534+
.take_response()
535+
);
536+
537+
context.run_with_conditions(&mut promise, Duration::from_secs(2));
538+
let value: TestJoinedValue = promise.take().available().unwrap();
539+
assert_eq!(value.integer, 5);
540+
assert_eq!(value.float, 3.14);
541+
assert_eq!(value.string, "hello");
542+
assert!(context.no_unhandled_errors());
543+
}
544+
}

0 commit comments

Comments
 (0)