Skip to content

Commit 3c06628

Browse files
committed
Tweak the different provider input structs
Signed-off-by: Michael X. Grey <grey@openrobotics.org>
1 parent 7d18d8a commit 3c06628

File tree

8 files changed

+87
-62
lines changed

8 files changed

+87
-62
lines changed

src/callback.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717

1818
use crate::{
19-
BlockingCallback, AsyncCallback, Channel, InnerChannel, ChannelQueue,
19+
BlockingCallback, AsyncCallback, Channel, ChannelQueue,
2020
OperationRoster, StreamPack, Input, Provider, ProvideOnce,
2121
AddOperation, OperateCallback, ManageInput, OperationError,
2222
OrBroken, OperateTask,
@@ -141,10 +141,11 @@ impl<'a> CallbackRequest<'a> {
141141
fn get_channel<Streams: StreamPack>(
142142
&mut self,
143143
session: Entity,
144-
) -> Result<Channel<Streams>, OperationError> {
144+
) -> Result<(Channel, Streams::Channel), OperationError> {
145145
let sender = self.world.get_resource_or_insert_with(|| ChannelQueue::new()).sender.clone();
146-
let channel = InnerChannel::new(self.source, session, sender);
147-
channel.into_specific(&self.world)
146+
let channel = Channel::new(self.source, session, sender);
147+
let streams = channel.for_streams::<Streams>(&self.world)?;
148+
Ok((channel, streams))
148149
}
149150
}
150151

@@ -222,14 +223,14 @@ where
222223
fn call(&mut self, mut input: CallbackRequest) -> Result<(), OperationError> {
223224
let Input { session, data: request } = input.get_request()?;
224225

225-
let channel = input.get_channel(session)?;
226+
let (channel, streams) = input.get_channel::<Streams>(session)?;
226227

227228
if !self.initialized {
228229
self.system.initialize(&mut input.world);
229230
}
230231

231232
let task = self.system.run(AsyncCallback {
232-
request, channel, source: input.source, session,
233+
request, streams, channel, source: input.source, session,
233234
}, &mut input.world);
234235
self.system.apply_deferred(&mut input.world);
235236

@@ -342,9 +343,9 @@ where
342343
fn as_callback(mut self) -> Callback<Self::Request, Self::Response, Self::Streams> {
343344
let callback = move |mut input: CallbackRequest| {
344345
let Input { session, data: request } = input.get_request::<Self::Request>()?;
345-
let channel = input.get_channel(session)?;
346+
let (channel, streams) = input.get_channel::<Streams>(session)?;
346347
let task = (self)(AsyncCallback {
347-
request, channel, source: input.source, session,
348+
request, streams, channel, source: input.source, session,
348349
});
349350
input.give_task(session, task)
350351
};

src/channel.rs

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,11 @@ use crate::{
3030
};
3131

3232
#[derive(Clone)]
33-
pub struct Channel<Streams: StreamPack = ()> {
34-
/// Stream channels that will let you send stream information. This will
35-
/// usually be a [`StreamChannel`] or a (possibly nested) tuple of
36-
/// `StreamChannel`s, whichever matches the [`StreamPack`] description.
37-
pub streams: Streams::Channel,
33+
pub struct Channel {
3834
inner: Arc<InnerChannel>,
39-
_ignore: std::marker::PhantomData<Streams>,
4035
}
4136

42-
impl<Streams: StreamPack> Channel<Streams> {
37+
impl Channel {
4338
pub fn query<P: Provider>(&self, request: P::Request, provider: P) -> Promise<P::Response>
4439
where
4540
P::Request: 'static + Send + Sync,
@@ -70,6 +65,23 @@ impl<Streams: StreamPack> Channel<Streams> {
7065

7166
promise
7267
}
68+
69+
pub(crate) fn for_streams<Streams: StreamPack>(
70+
&self,
71+
world: &World,
72+
) -> Result<Streams::Channel, OperationError> {
73+
Ok(Streams::make_channel(&self.inner, world))
74+
}
75+
76+
pub(crate) fn new(
77+
source: Entity,
78+
session: Entity,
79+
sender: CbSender<ChannelItem>,
80+
) -> Self {
81+
Self {
82+
inner: Arc::new(InnerChannel { source, session, sender }),
83+
}
84+
}
7385
}
7486

7587
#[derive(Clone)]
@@ -87,23 +99,6 @@ impl InnerChannel {
8799
pub fn sender(&self) -> &CbSender<ChannelItem> {
88100
&self.sender
89101
}
90-
91-
pub(crate) fn into_specific<Streams: StreamPack>(
92-
self,
93-
world: &World,
94-
) -> Result<Channel<Streams>, OperationError> {
95-
let inner = Arc::new(self);
96-
let streams = Streams::make_channel(&inner, world);
97-
Ok(Channel { inner, streams, _ignore: Default::default() })
98-
}
99-
100-
pub(crate) fn new(
101-
source: Entity,
102-
session: Entity,
103-
sender: CbSender<ChannelItem>,
104-
) -> Self {
105-
InnerChannel { source, session, sender }
106-
}
107102
}
108103

109104
pub(crate) type ChannelItem = Box<dyn FnOnce(&mut World, &mut OperationRoster) + Send>;

src/impulse/map.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use std::future::Future;
2525
use crate::{
2626
Impulsive, OperationSetup, OperationRequest, SingleTargetStorage, StreamPack,
2727
InputBundle, OperationResult, OrBroken, Input, ManageInput,
28-
ChannelQueue, BlockingMap, AsyncMap, InnerChannel, OperateTask, ActiveTasksStorage,
28+
ChannelQueue, BlockingMap, AsyncMap, Channel, OperateTask, ActiveTasksStorage,
2929
CallBlockingMapOnce, CallAsyncMapOnce,
3030
};
3131

@@ -169,11 +169,11 @@ where
169169
let target = source_mut.get::<SingleTargetStorage>().or_broken()?.get();
170170
let f = source_mut.take::<AsyncMapOnceStorage<F>>().or_broken()?.f;
171171

172-
let channel = InnerChannel::new(source, session, sender.clone());
173-
let channel = channel.into_specific(&world)?;
172+
let channel = Channel::new(source, session, sender.clone());
173+
let streams = channel.for_streams::<Streams>(&world)?;
174174

175175
let task = AsyncComputeTaskPool::get().spawn(f.call(AsyncMap {
176-
request, channel, source, session,
176+
request, streams, channel, source, session,
177177
}));
178178

179179
let task_source = world.spawn(()).id();

src/lib.rs

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ use bevy::prelude::{Entity, In};
105105
/// struct Precision(i32);
106106
///
107107
/// fn rounding_service(
108-
/// In(BlockingService{request, provider, ..}): InBlockingService<f64>,
108+
/// In(BlockingService{request, provider, ..}): BlockingServiceInput<f64>,
109109
/// service_precision: Query<&Precision>,
110110
/// global_precision: Res<Precision>,
111111
/// ) -> f64 {
@@ -140,10 +140,13 @@ pub type BlockingServiceInput<Request, Streams = ()> = In<BlockingService<Reques
140140
pub struct AsyncService<Request, Streams: StreamPack = ()> {
141141
/// The input data of the request
142142
pub request: Request,
143+
/// Stream channels that will let you send stream information. This will
144+
/// usually be a [`StreamChannel`] or a (possibly nested) tuple of
145+
/// `StreamChannel`s, whichever matches the [`StreamPack`] description.
146+
pub streams: Streams::Channel,
143147
/// The channel that allows querying and syncing with the world while the
144-
/// service runs asynchronously. Use the [`Channel::streams`] method to
145-
/// send stream output data from the service.
146-
pub channel: Channel<Streams>,
148+
/// service runs asynchronously.
149+
pub channel: Channel,
147150
/// The entity providing the service
148151
pub provider: Entity,
149152
/// The node in a workflow or impulse chain that asked for the service
@@ -185,10 +188,13 @@ pub type BlockingCallbackInput<Request, Streams = ()> = In<BlockingCallback<Requ
185188
pub struct AsyncCallback<Request, Streams: StreamPack = ()> {
186189
/// The input data of the request
187190
pub request: Request,
191+
/// Stream channels that will let you send stream information. This will
192+
/// usually be a [`StreamChannel`] or a (possibly nested) tuple of
193+
/// `StreamChannel`s, whichever matches the [`StreamPack`] description.
194+
pub streams: Streams::Channel,
188195
/// The channel that allows querying and syncing with the world while the
189-
/// service runs asynchronously. Use the [`Channel::streams`] method to
190-
/// send stream output data from the service.
191-
pub channel: Channel<Streams>,
196+
/// service runs asynchronously.
197+
pub channel: Channel,
192198
/// The node in a workflow or impulse chain that asked for the callback
193199
pub source: Entity,
194200
/// The unique session ID for the workflow
@@ -213,9 +219,6 @@ pub struct BlockingMap<Request, Streams: StreamPack = ()> {
213219
pub session: Entity,
214220
}
215221

216-
/// Use this to reduce the bracket noise when you need `In<`[`BlockingMap<R, S>`]`>`.
217-
pub type BlockingMapInput<Request, Streams = ()> = In<BlockingMap<Request, Streams>>;
218-
219222
/// Use AsyncMap to indicate that your function is an async map. A Map is not
220223
/// associated with any entity, and it cannot be a Bevy System. These
221224
/// restrictions allow them to be processed more efficiently.
@@ -226,15 +229,15 @@ pub type BlockingMapInput<Request, Streams = ()> = In<BlockingMap<Request, Strea
226229
pub struct AsyncMap<Request, Streams: StreamPack = ()> {
227230
/// The input data of the request
228231
pub request: Request,
232+
/// Stream channels that will let you send stream information. This will
233+
/// usually be a [`StreamChannel`] or a (possibly nested) tuple of
234+
/// `StreamChannel`s, whichever matches the [`StreamPack`] description.
235+
pub streams: Streams::Channel,
229236
/// The channel that allows querying and syncing with the world while the
230-
/// service runs asynchronously. Use the [`Channel::streams`] method to
231-
/// send stream output data from the service.
232-
pub channel: Channel<Streams>,
237+
/// service runs asynchronously.
238+
pub channel: Channel,
233239
/// The node in a workflow or impulse chain that asked for the callback
234240
pub source: Entity,
235241
/// The unique session ID for the workflow
236242
pub session: Entity,
237243
}
238-
239-
/// Use this to reduce bracket noise when you need `In<`[`AsyncMap<R, S>`]`>`.
240-
pub type AsyncMapInput<Request, Streams = ()> = In<AsyncMap<Request, Streams>>;

src/map.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,10 @@ pub struct BlockingMapMarker;
9191

9292
impl<F, Request, Response, Streams> AsMap<(Request, Response, Streams, BlockingMapMarker)> for F
9393
where
94-
F: FnMut(BlockingMap<Request>) -> Response + 'static + Send + Sync,
94+
F: FnMut(BlockingMap<Request, Streams>) -> Response + 'static + Send + Sync,
9595
Request: 'static + Send + Sync,
9696
Response: 'static + Send + Sync,
97+
Streams: StreamPack,
9798
{
9899
type MapType = BlockingMapDef<MapDef<F>, Request, Response, Streams>;
99100
fn as_map(self) -> Self::MapType {

src/operation/operate_map.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use bevy::{
2323
use std::future::Future;
2424

2525
use crate::{
26-
BlockingMap, AsyncMap, Operation, ChannelQueue, InnerChannel,
26+
BlockingMap, AsyncMap, Operation, ChannelQueue, Channel,
2727
SingleTargetStorage, StreamPack, Input, ManageInput, OperationCleanup,
2828
CallBlockingMap, CallAsyncMap, SingleInputStorage, OperationResult,
2929
OrBroken, OperationSetup, OperationRequest, OperateTask, ActiveTasksStorage,
@@ -180,11 +180,11 @@ where
180180
let mut f = source_mut.get_mut::<AsyncMapStorage<F>>().or_broken()?
181181
.f.take().or_broken()?;
182182

183-
let channel = InnerChannel::new(source, session, sender.clone());
184-
let channel = channel.into_specific(&world)?;
183+
let channel = Channel::new(source, session, sender.clone());
184+
let streams = channel.for_streams::<Streams>(&world)?;
185185

186186
let task = AsyncComputeTaskPool::get().spawn(f.call(AsyncMap {
187-
request, channel, source, session,
187+
request, streams, channel, source, session,
188188
}));
189189
world.get_entity_mut(source).or_broken()?
190190
.get_mut::<AsyncMapStorage<F>>().or_broken()?

src/service/async_srv.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use crate::{
1919
AsyncService, AsyncServiceInput, IntoService, ServiceTrait, ServiceBundle, ServiceRequest,
20-
InnerChannel, ChannelQueue, OperationRoster, Blocker,
20+
Channel, ChannelQueue, OperationRoster, Blocker,
2121
StreamPack, ServiceBuilder, ChooseAsyncServiceDelivery, OperationRequest,
2222
OperationError, OrBroken, ManageInput, Input, OperateTask,
2323
SingleTargetStorage, dispose_for_despawned_service,
@@ -220,8 +220,9 @@ where
220220
};
221221

222222
let sender = world.get_resource_or_insert_with(|| ChannelQueue::new()).sender.clone();
223-
let channel = InnerChannel::new(source, session, sender.clone()).into_specific(world)?;
224-
let job = service.run(AsyncService { request, channel, provider, source, session }, world);
223+
let channel = Channel::new(source, session, sender.clone());
224+
let streams = channel.for_streams::<Streams>(world)?;
225+
let job = service.run(AsyncService { request, streams, channel, provider, source, session }, world);
225226
service.apply_deferred(world);
226227

227228
if let Some(mut service_storage) = world.get_mut::<AsyncServiceStorage<Request, Streams, Task>>(provider) {

src/stream.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ mod tests {
754754
|In(input): AsyncServiceInput<u32, StreamOf<u32>>| {
755755
async move {
756756
for i in 0..input.request {
757-
input.channel.streams.send(StreamOf(i));
757+
input.streams.send(StreamOf(i));
758758
}
759759
return input.request;
760760
}
@@ -779,13 +779,37 @@ mod tests {
779779
|In(input): AsyncCallbackInput<u32, StreamOf<u32>>| {
780780
async move {
781781
for i in 0..input.request {
782-
input.channel.streams.send(StreamOf(i));
782+
input.streams.send(StreamOf(i));
783783
}
784784
return input.request;
785785
}
786786
}
787787
).as_callback();
788788

789789
test_counting_stream(count_async_callback, &mut context);
790+
791+
let count_blocking_map = (
792+
|input: BlockingMap<u32, StreamOf<u32>>| {
793+
for i in 0..input.request {
794+
input.streams.send(StreamOf(i));
795+
}
796+
return input.request;
797+
}
798+
).as_map();
799+
800+
test_counting_stream(count_blocking_map, &mut context);
801+
802+
let count_async_map = (
803+
|input: AsyncMap<u32, StreamOf<u32>>| {
804+
async move {
805+
for i in 0..input.request {
806+
input.streams.send(StreamOf(i));
807+
}
808+
return input.request;
809+
}
810+
}
811+
).as_map();
812+
813+
test_counting_stream(count_async_map, &mut context);
790814
}
791815
}

0 commit comments

Comments
 (0)