Skip to content

Commit ac7467b

Browse files
committed
Fix unreachability detection for streams out of async nodes
Signed-off-by: Michael X. Grey <grey@openrobotics.org>
1 parent c3879b3 commit ac7467b

File tree

7 files changed

+106
-24
lines changed

7 files changed

+106
-24
lines changed

src/callback.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ impl<'a> CallbackRequest<'a> {
129129
Ok(())
130130
}
131131

132-
fn give_task<Task: Future + 'static + Send>(
132+
fn give_task<Task: Future + 'static + Send, Streams: StreamPack>(
133133
&mut self,
134134
session: Entity,
135135
task: Task,
@@ -140,7 +140,7 @@ impl<'a> CallbackRequest<'a> {
140140
let sender = self.world.get_resource_or_insert_with(|| ChannelQueue::new()).sender.clone();
141141
let task = AsyncComputeTaskPool::get().spawn(task);
142142
let task_id = self.world.spawn(()).id();
143-
OperateTask::new(task_id, session, self.source, self.target, task, None, sender)
143+
OperateTask::<_, Streams>::new(task_id, session, self.source, self.target, task, None, sender)
144144
.add(self.world, self.roster);
145145
Ok(())
146146
}
@@ -244,7 +244,7 @@ where
244244
}, &mut input.world);
245245
self.system.apply_deferred(&mut input.world);
246246

247-
input.give_task(session, task)
247+
input.give_task::<_, Streams>(session, task)
248248
}
249249
}
250250

@@ -361,7 +361,7 @@ where
361361
let task = (self)(AsyncCallback {
362362
request, streams, channel, source: input.source, session,
363363
});
364-
input.give_task(session, task)
364+
input.give_task::<_, Streams>(session, task)
365365
};
366366
Callback::new(MapCallback { callback })
367367
}

src/impulse/map.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,9 @@ where
182182
}));
183183

184184
let task_source = world.spawn(()).id();
185-
OperateTask::new(task_source, session, source, target, task, None, sender)
186-
.add(world, roster);
185+
OperateTask::<_, Streams>::new(
186+
task_source, session, source, target, task, None, sender,
187+
).add(world, roster);
187188
Ok(())
188189
}
189190
}

src/operation/operate_map.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ where
199199
.f = Some(f);
200200

201201
let task_source = world.spawn(()).id();
202-
OperateTask::new(
202+
OperateTask::<_, Streams>::new(
203203
task_source, session, source, target, task, None, sender,
204204
).add(world, roster);
205205
Ok(())

src/operation/operate_task.rs

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ use crate::{
3939
OperationRoster, Blocker, ManageInput, ChannelQueue, UnhandledErrors,
4040
OperationSetup, OperationRequest, OperationResult, Operation, AddOperation,
4141
OrBroken, OperationCleanup, ChannelItem, OperationError, Broken, ScopeStorage,
42-
OperationReachability, ReachabilityResult, emit_disposal, Disposal,
42+
OperationReachability, ReachabilityResult, emit_disposal, Disposal, StreamPack,
4343
};
4444

4545
struct JobWaker {
@@ -70,7 +70,7 @@ impl WakeQueue {
7070
}
7171

7272
#[derive(Component)]
73-
pub(crate) struct OperateTask<Response: 'static + Send + Sync> {
73+
pub(crate) struct OperateTask<Response: 'static + Send + Sync, Streams: StreamPack> {
7474
source: Entity,
7575
session: Entity,
7676
node: Entity,
@@ -81,9 +81,10 @@ pub(crate) struct OperateTask<Response: 'static + Send + Sync> {
8181
disposal: Option<Disposal>,
8282
being_cleaned: bool,
8383
finished_normally: bool,
84+
_ignore: std::marker::PhantomData<Streams>,
8485
}
8586

86-
impl<Response: 'static + Send + Sync> OperateTask<Response> {
87+
impl<Response: 'static + Send + Sync, Streams: StreamPack> OperateTask<Response, Streams> {
8788
pub(crate) fn new(
8889
source: Entity,
8990
session: Entity,
@@ -104,6 +105,7 @@ impl<Response: 'static + Send + Sync> OperateTask<Response> {
104105
disposal: None,
105106
being_cleaned: false,
106107
finished_normally: false,
108+
_ignore: Default::default(),
107109
}
108110
}
109111

@@ -121,7 +123,11 @@ impl<Response: 'static + Send + Sync> OperateTask<Response> {
121123
}
122124
}
123125

124-
impl<Response: 'static + Send + Sync> Drop for OperateTask<Response> {
126+
impl<Response, Streams> Drop for OperateTask<Response, Streams>
127+
where
128+
Response: 'static + Send + Sync,
129+
Streams: StreamPack,
130+
{
125131
fn drop(&mut self) {
126132
if self.finished_normally {
127133
// The task finished normally so no special action needs to be taken
@@ -157,7 +163,11 @@ impl<Response: 'static + Send + Sync> Drop for OperateTask<Response> {
157163
}
158164
}
159165

160-
impl<Response: 'static + Send + Sync> Operation for OperateTask<Response> {
166+
impl<Response, Streams> Operation for OperateTask<Response, Streams>
167+
where
168+
Response: 'static + Send + Sync,
169+
Streams: StreamPack,
170+
{
161171
fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult {
162172
let wake_queue = world.get_resource_or_insert_with(|| WakeQueue::new());
163173
let waker = Arc::new(JobWaker {
@@ -172,7 +182,7 @@ impl<Response: 'static + Send + Sync> Operation for OperateTask<Response> {
172182
.insert((
173183
self,
174184
JobWakerStorage(waker),
175-
StopTask(stop_task::<Response>),
185+
StopTask(stop_task::<Response, Streams>),
176186
))
177187
.set_parent(node);
178188

@@ -192,7 +202,7 @@ impl<Response: 'static + Send + Sync> Operation for OperateTask<Response> {
192202
let mut source_mut = world.get_entity_mut(source).or_not_ready()?;
193203
// If the task has been stopped / cancelled then OperateTask will have
194204
// been removed, even if it has not despawned yet.
195-
let mut operation = source_mut.get_mut::<OperateTask<Response>>().or_not_ready()?;
205+
let mut operation = source_mut.get_mut::<OperateTask<Response, Streams>>().or_not_ready()?;
196206
if operation.being_cleaned {
197207
// The operation is being cleaned up, so the task will not be
198208
// available and there will be nothing for us to do here. We should
@@ -228,13 +238,35 @@ impl<Response: 'static + Send + Sync> Operation for OperateTask<Response> {
228238
// ChannelQueue has been processed so that any streams from this
229239
// task will be delivered before the final output.
230240
let r = world.entity_mut(target).defer_input(session, result, roster);
231-
world.get_mut::<OperateTask<Response>>(source).or_broken()?.finished_normally = true;
241+
world.get_mut::<OperateTask<Response, Streams>>(source).or_broken()?.finished_normally = true;
232242
cleanup_task::<Response>(session, source, node, unblock, being_cleaned, world, roster);
243+
244+
if Streams::has_streams() {
245+
if let Some(scope) = world.get::<ScopeStorage>(node) {
246+
// When an async task with any number of streams >= 1 is
247+
// finished, we should always do a disposal notification
248+
// to force a reachability check. Normally there are
249+
// specific events that prompt us to check reachability,
250+
// but if a reachability test occurred while the async
251+
// node was running and the reachability depends on a
252+
// stream which may or may not have been emitted, then
253+
// the reachability test may have concluded with a false
254+
// positive, and it needs to be rechecked now that the
255+
// async node has finished.
256+
//
257+
// TODO(@mxgrey): Make this more efficient, e.g. only
258+
// trigger this disposal if we detected that a
259+
// reachability test happened while this task was
260+
// running.
261+
roster.disposed(scope.get(), session);
262+
}
263+
}
264+
233265
r?;
234266
}
235267
Poll::Pending => {
236268
// Task is still running
237-
if let Some(mut operation) = world.get_mut::<OperateTask<Response>>(source) {
269+
if let Some(mut operation) = world.get_mut::<OperateTask<Response, Streams>>(source) {
238270
operation.task = Some(task);
239271
operation.blocker = unblock;
240272
world.entity_mut(source).insert(JobWakerStorage(waker));
@@ -249,7 +281,7 @@ impl<Response: 'static + Send + Sync> Operation for OperateTask<Response> {
249281
|| ChannelQueue::default()
250282
).sender.clone();
251283

252-
let operation = OperateTask::new(
284+
let operation = OperateTask::<_, Streams>::new(
253285
source, session, node, target, task, unblock, sender,
254286
);
255287

@@ -267,7 +299,7 @@ impl<Response: 'static + Send + Sync> Operation for OperateTask<Response> {
267299
let session = clean.session;
268300
let source = clean.source;
269301
let mut source_mut = clean.world.get_entity_mut(source).or_broken()?;
270-
let mut operation = source_mut.get_mut::<OperateTask<Response>>().or_broken()?;
302+
let mut operation = source_mut.get_mut::<OperateTask<Response, Streams>>().or_broken()?;
271303
operation.being_cleaned = true;
272304
let node = operation.node;
273305
let task = operation.task.take();
@@ -292,7 +324,7 @@ impl<Response: 'static + Send + Sync> Operation for OperateTask<Response> {
292324
fn is_reachable(reachability: OperationReachability) -> ReachabilityResult {
293325
let session = reachability.world
294326
.get_entity(reachability.source).or_broken()?
295-
.get::<OperateTask<Response>>().or_broken()?.session;
327+
.get::<OperateTask<Response, Streams>>().or_broken()?.session;
296328
Ok(session == reachability.session)
297329
}
298330
}
@@ -351,13 +383,13 @@ fn cleanup_task<Response>(
351383
#[derive(Component, Clone, Copy)]
352384
pub(crate) struct StopTask(pub(crate) fn(OperationRequest, Disposal) -> OperationResult);
353385

354-
fn stop_task<Response: 'static + Send + Sync>(
386+
fn stop_task<Response: 'static + Send + Sync, Streams: StreamPack>(
355387
OperationRequest { source, world, .. }: OperationRequest,
356388
disposal: Disposal,
357389
) -> OperationResult {
358390
let mut operation = world
359391
.get_entity_mut(source).or_broken()?
360-
.take::<OperateTask<Response>>().or_broken()?;
392+
.take::<OperateTask<Response, Streams>>().or_broken()?;
361393

362394
operation.disposal = Some(disposal);
363395
drop(operation);

src/service/async_srv.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,9 @@ where
238238

239239
let task = AsyncComputeTaskPool::get().spawn(job);
240240

241-
OperateTask::new(task_id, session, source, target, task, blocker, sender)
242-
.add(world, roster);
241+
OperateTask::<_, Streams>::new(
242+
task_id, session, source, target, task, blocker, sender,
243+
).add(world, roster);
243244
Ok(())
244245
}
245246

src/stream.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,9 @@ pub trait StreamPack: 'static + Send + Sync {
291291
world: &mut World,
292292
roster: &mut OperationRoster,
293293
) -> OperationResult;
294+
295+
/// Are there actually any streams in the pack?
296+
fn has_streams() -> bool;
294297
}
295298

296299
impl<T: Stream> StreamPack for T {
@@ -396,6 +399,10 @@ impl<T: Stream> StreamPack for T {
396399

397400
Ok(())
398401
}
402+
403+
fn has_streams() -> bool {
404+
true
405+
}
399406
}
400407

401408
impl StreamPack for () {
@@ -470,6 +477,10 @@ impl StreamPack for () {
470477
) -> OperationResult {
471478
Ok(())
472479
}
480+
481+
fn has_streams() -> bool {
482+
false
483+
}
473484
}
474485

475486
impl<T1: StreamPack> StreamPack for (T1,) {
@@ -545,6 +556,10 @@ impl<T1: StreamPack> StreamPack for (T1,) {
545556
T1::process_buffer(buffer, source, session, unused, world, roster)?;
546557
Ok(())
547558
}
559+
560+
fn has_streams() -> bool {
561+
T1::has_streams()
562+
}
548563
}
549564

550565
impl<T1: StreamPack, T2: StreamPack> StreamPack for (T1, T2) {
@@ -635,6 +650,11 @@ impl<T1: StreamPack, T2: StreamPack> StreamPack for (T1, T2) {
635650
T2::process_buffer(buffer.1, source, session, unused, world, roster)?;
636651
Ok(())
637652
}
653+
654+
fn has_streams() -> bool {
655+
T1::has_streams()
656+
|| T2::has_streams()
657+
}
638658
}
639659

640660
impl<T1: StreamPack, T2: StreamPack, T3: StreamPack> StreamPack for (T1, T2, T3) {
@@ -733,6 +753,12 @@ impl<T1: StreamPack, T2: StreamPack, T3: StreamPack> StreamPack for (T1, T2, T3)
733753
T3::process_buffer(buffer.2, source, session, unused, world, roster)?;
734754
Ok(())
735755
}
756+
757+
fn has_streams() -> bool {
758+
T1::has_streams()
759+
|| T2::has_streams()
760+
|| T3::has_streams()
761+
}
736762
}
737763

738764
/// Used by [`ServiceDiscovery`](crate::ServiceDiscovery) to filter services

src/workflow.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,10 @@ mod tests {
382382
}
383383

384384
#[test]
385-
fn test_reachability_with_unused_streams() {
385+
fn test_stream_reachability() {
386386
let mut context = TestingContext::minimal_plugins();
387387

388+
// Test for streams from a blocking node
388389
let workflow = context.spawn_io_workflow(|scope, builder| {
389390
let stream_node = builder.create_map(|_: BlockingMap<(), StreamOf<u32>>| {
390391
// Do nothing. The purpose of this node is to just return without
@@ -405,5 +406,26 @@ mod tests {
405406
context.run_with_conditions(&mut promise, Duration::from_secs(2));
406407
assert!(promise.peek().is_cancelled());
407408
assert!(context.no_unhandled_errors());
409+
410+
// Test for streams from an async node
411+
let workflow = context.spawn_io_workflow(|scope, builder| {
412+
let stream_node = builder.create_map(|_: AsyncMap<(), StreamOf<u32>>| {
413+
async { /* Do nothing */}
414+
});
415+
416+
builder.connect(scope.input, stream_node.input);
417+
stream_node.streams.chain(builder)
418+
.inner()
419+
.map_block(|value| 2 * value)
420+
.connect(scope.terminate);
421+
});
422+
423+
let mut promise = context.command(|commands| {
424+
commands.request((), workflow).take_response()
425+
});
426+
427+
context.run_with_conditions(&mut promise, Duration::from_secs(2));
428+
assert!(promise.peek().is_cancelled());
429+
assert!(context.no_unhandled_errors());
408430
}
409431
}

0 commit comments

Comments
 (0)