Skip to content

Commit a6bf08d

Browse files
committed
Implement more general buffer key downcasting
Signed-off-by: Michael X. Grey <greyxmike@gmail.com>
1 parent c419ab3 commit a6bf08d

File tree

4 files changed

+97
-28
lines changed

4 files changed

+97
-28
lines changed

src/buffer.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ impl<T> BufferKey<T> {
251251
self.tag.session
252252
}
253253

254-
pub(crate) fn tag(&self) -> &BufferKeyTag {
254+
pub fn tag(&self) -> &BufferKeyTag {
255255
&self.tag
256256
}
257257

@@ -286,11 +286,8 @@ impl<T> std::fmt::Debug for BufferKey<T> {
286286

287287
/// The identifying information for a buffer key. This does not indicate
288288
/// anything about the type of messages that the buffer can contain.
289-
///
290-
/// This struct will be internal to the crate until we decide to make
291-
/// [`BufferAccessLifecycle`] a public struct.
292289
#[derive(Clone)]
293-
pub(crate) struct BufferKeyTag {
290+
pub struct BufferKeyTag {
294291
pub buffer: Entity,
295292
pub session: Entity,
296293
pub accessor: Entity,

src/buffer/any_buffer.rs

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::{
2121
any::{Any, TypeId},
2222
collections::{hash_map::Entry, HashMap},
2323
ops::RangeBounds,
24-
sync::{Arc, Mutex, OnceLock},
24+
sync::{Mutex, OnceLock},
2525
};
2626

2727
use bevy_ecs::{
@@ -34,10 +34,10 @@ use thiserror::Error as ThisError;
3434
use smallvec::SmallVec;
3535

3636
use crate::{
37-
add_listener_to_source, Accessed, Buffer, BufferAccessLifecycle, BufferAccessMut,
38-
BufferAccessors, BufferError, BufferKey, BufferKeyTag, BufferLocation, BufferStorage,
39-
Bufferable, Buffered, Builder, DrainBuffer, Gate, GateState, InspectBuffer, Joined,
40-
ManageBuffer, NotifyBufferUpdate, OperationError, OperationResult, OperationRoster, OrBroken,
37+
add_listener_to_source, Accessed, Buffer, BufferAccessMut, BufferAccessors, BufferError,
38+
BufferKey, BufferKeyTag, BufferLocation, BufferStorage, Bufferable, Buffered, Builder,
39+
DrainBuffer, Gate, GateState, InspectBuffer, Joined, ManageBuffer, NotifyBufferUpdate,
40+
OperationError, OperationResult, OperationRoster, OrBroken,
4141
};
4242

4343
/// A [`Buffer`] whose message type has been anonymized. Joining with this buffer
@@ -143,18 +143,29 @@ pub struct AnyBufferKey {
143143
}
144144

145145
impl AnyBufferKey {
146-
/// Downcast this into a concrete [`BufferKey`] type.
147-
pub fn downcast<T: 'static>(&self) -> Option<BufferKey<T>> {
148-
if TypeId::of::<T>() == self.interface.message_type_id() {
146+
/// Downcast this into a concrete [`BufferKey`] for the specified message type.
147+
///
148+
/// To downcast to a specialized kind of key, use [`Self::downcast_buffer_key`] instead.
149+
pub fn downcast_for_message<Message: 'static>(self) -> Option<BufferKey<Message>> {
150+
if TypeId::of::<Message>() == self.interface.message_type_id() {
149151
Some(BufferKey {
150-
tag: self.tag.clone(),
152+
tag: self.tag,
151153
_ignore: Default::default(),
152154
})
153155
} else {
154156
None
155157
}
156158
}
157159

160+
/// Downcast this into a different special buffer key representation, such
161+
/// as a `JsonBufferKey`.
162+
pub fn downcast_buffer_key<KeyType: 'static>(self) -> Option<KeyType> {
163+
self.interface.key_downcast(TypeId::of::<KeyType>())?(self.tag)
164+
.downcast::<KeyType>()
165+
.ok()
166+
.map(|x| *x)
167+
}
168+
158169
/// The buffer ID of this key.
159170
pub fn id(&self) -> Entity {
160171
self.tag.buffer
@@ -780,6 +791,10 @@ pub trait AnyBufferAccessInterface {
780791

781792
fn buffer_downcast(&self, buffer_type: TypeId) -> Option<BufferDowncastRef>;
782793

794+
fn register_key_downcast(&self, key_type: TypeId, f: KeyDowncastBox);
795+
796+
fn key_downcast(&self, key_type: TypeId) -> Option<KeyDowncastRef>;
797+
783798
fn pull(
784799
&self,
785800
entity_mut: &mut EntityWorldMut,
@@ -800,16 +815,23 @@ pub trait AnyBufferAccessInterface {
800815

801816
pub type BufferDowncastBox = Box<dyn Fn(BufferLocation) -> Box<dyn Any> + Send + Sync>;
802817
pub type BufferDowncastRef = &'static (dyn Fn(BufferLocation) -> Box<dyn Any> + Send + Sync);
818+
pub type KeyDowncastBox = Box<dyn Fn(BufferKeyTag) -> Box<dyn Any> + Send + Sync>;
819+
pub type KeyDowncastRef = &'static (dyn Fn(BufferKeyTag) -> Box<dyn Any> + Send + Sync);
803820

804821
struct AnyBufferAccessImpl<T> {
805822
buffer_downcasts: Mutex<HashMap<TypeId, BufferDowncastRef>>,
823+
key_downcasts: Mutex<HashMap<TypeId, KeyDowncastRef>>,
806824
_ignore: std::marker::PhantomData<fn(T)>,
807825
}
808826

809827
impl<T: 'static + Send + Sync> AnyBufferAccessImpl<T> {
810828
fn new() -> Self {
811829
let mut buffer_downcasts: HashMap<_, BufferDowncastRef> = HashMap::new();
812830

831+
// SAFETY: These leaks are okay because we will only ever instantiate
832+
// AnyBufferAccessImpl once per generic argument T, which puts a firm
833+
// ceiling on how many of these callbacks will get leaked.
834+
813835
// Automatically register a downcast into AnyBuffer
814836
buffer_downcasts.insert(
815837
TypeId::of::<AnyBuffer>(),
@@ -821,8 +843,22 @@ impl<T: 'static + Send + Sync> AnyBufferAccessImpl<T> {
821843
})),
822844
);
823845

846+
let mut key_downcasts: HashMap<_, KeyDowncastRef> = HashMap::new();
847+
848+
// Automatically register a downcast to AnyBufferKey
849+
key_downcasts.insert(
850+
TypeId::of::<AnyBufferKey>(),
851+
Box::leak(Box::new(|tag| -> Box<dyn Any> {
852+
Box::new(AnyBufferKey {
853+
tag,
854+
interface: AnyBuffer::interface_for::<T>(),
855+
})
856+
})),
857+
);
858+
824859
Self {
825860
buffer_downcasts: Mutex::new(buffer_downcasts),
861+
key_downcasts: Mutex::new(key_downcasts),
826862
_ignore: Default::default(),
827863
}
828864
}
@@ -862,6 +898,19 @@ impl<T: 'static + Send + Sync + Any> AnyBufferAccessInterface for AnyBufferAcces
862898
.copied()
863899
}
864900

901+
fn register_key_downcast(&self, key_type: TypeId, f: KeyDowncastBox) {
902+
let mut downcasts = self.key_downcasts.lock().unwrap();
903+
904+
if let Entry::Vacant(entry) = downcasts.entry(key_type) {
905+
// We should only leak this in to the register once per type
906+
entry.insert(Box::leak(f));
907+
}
908+
}
909+
910+
fn key_downcast(&self, key_type: TypeId) -> Option<KeyDowncastRef> {
911+
self.key_downcasts.lock().unwrap().get(&key_type).copied()
912+
}
913+
865914
fn pull(
866915
&self,
867916
entity_mut: &mut EntityWorldMut,

src/buffer/buffer_access_lifecycle.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use crate::{emit_disposal, ChannelItem, Disposal, OperationRoster};
2929
/// we would be needlessly doing a reachability check every time the key gets
3030
/// cloned.
3131
#[derive(Clone)]
32-
pub(crate) struct BufferAccessLifecycle {
32+
pub struct BufferAccessLifecycle {
3333
scope: Entity,
3434
accessor: Entity,
3535
session: Entity,

src/buffer/json_buffer.rs

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::{
2121
any::TypeId,
2222
collections::HashMap,
2323
ops::RangeBounds,
24-
sync::{Arc, Mutex, OnceLock},
24+
sync::{Mutex, OnceLock},
2525
};
2626

2727
use bevy_ecs::{
@@ -143,20 +143,19 @@ pub struct JsonBufferKey {
143143

144144
impl JsonBufferKey {
145145
/// Downcast this into a concrete [`BufferKey`] for the specified message type.
146-
pub fn downcast_for_message<T: 'static>(&self) -> Option<BufferKey<T>> {
147-
if TypeId::of::<T>() == self.interface.any_access_interface().message_type_id() {
148-
Some(BufferKey {
149-
tag: self.tag.clone(),
150-
_ignore: Default::default(),
151-
})
152-
} else {
153-
None
154-
}
146+
///
147+
/// To downcast to a specialized kind of key, use [`Self::downcast_buffer_key`] instead.
148+
pub fn downcast_for_message<T: 'static>(self) -> Option<BufferKey<T>> {
149+
self.as_any_buffer_key().downcast_for_message()
150+
}
151+
152+
pub fn downcast_buffer_key<KeyType: 'static>(self) -> Option<KeyType> {
153+
self.as_any_buffer_key().downcast_buffer_key()
155154
}
156155

157156
/// Cast this into an [`AnyBufferKey`]
158-
pub fn as_any_buffer_key(&self) -> AnyBufferKey {
159-
self.clone().into()
157+
pub fn as_any_buffer_key(self) -> AnyBufferKey {
158+
self.into()
160159
}
161160

162161
fn deep_clone(&self) -> Self {
@@ -816,6 +815,16 @@ impl<T: 'static + Send + Sync + Serialize + DeserializeOwned> JsonBufferAccessIm
816815
}),
817816
);
818817

818+
any_interface.register_key_downcast(
819+
TypeId::of::<JsonBufferKey>(),
820+
Box::new(|tag| {
821+
Box::new(JsonBufferKey {
822+
tag,
823+
interface: Self::get_interface(),
824+
})
825+
}),
826+
);
827+
819828
// Create and cache the json buffer access interface
820829
static INTERFACE_MAP: OnceLock<
821830
Mutex<HashMap<TypeId, &'static (dyn JsonBufferAccessInterface + Send + Sync)>>,
@@ -1319,7 +1328,21 @@ mod tests {
13191328
let _original_from_json: Buffer<TestMessage> =
13201329
json_buffer.downcast_for_message().unwrap();
13211330

1322-
builder.connect(scope.input, scope.terminate);
1331+
scope
1332+
.input
1333+
.chain(builder)
1334+
.with_access(buffer)
1335+
.map_block(|(data, key)| {
1336+
let any_key: AnyBufferKey = key.clone().into();
1337+
let json_key: JsonBufferKey = any_key.clone().downcast_buffer_key().unwrap();
1338+
let _original_from_any: BufferKey<TestMessage> =
1339+
any_key.downcast_for_message().unwrap();
1340+
let _original_from_json: BufferKey<TestMessage> =
1341+
json_key.downcast_for_message().unwrap();
1342+
1343+
data
1344+
})
1345+
.connect(scope.terminate);
13231346
});
13241347

13251348
let mut promise = context.command(|commands| commands.request(1, workflow).take_response());

0 commit comments

Comments
 (0)