Skip to content

Commit cdd0bb9

Browse files
authored
Move AcquisitionReport work into asynchronous thread (#210)
* Move AcquisitionReport work into asynchronous thread * Remove unnecessary comment
1 parent 03c5043 commit cdd0bb9

File tree

1 file changed

+55
-13
lines changed

1 file changed

+55
-13
lines changed

ada_feeding_action_select/ada_feeding_action_select/policy_service.py

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import copy
1414
import errno
1515
import os
16+
import threading
1617
import time
1718
from typing import Dict
1819
import uuid
@@ -292,12 +293,14 @@ def __init__(self):
292293

293294
# Create AcquisitionSelect cache
294295
# UUID -> {context, request, response}
296+
self.cache_lock = threading.Lock()
295297
self.cache = {}
296298

297299
# Init Checkpoints / Data Record
298300
self._init_checkpoints_record(context_cls, posthoc_cls)
299301

300302
# Start ROS services
303+
self.acquisition_report_threads = []
301304
self.ros_objs = []
302305
self.ros_objs.append(
303306
self.create_service(
@@ -339,11 +342,12 @@ def select_callback(
339342
response.probabilities = list(res[0])
340343
response.actions = list(res[1])
341344
select_id = str(uuid.uuid4())
342-
self.cache[select_id] = {
343-
"context": np.copy(context),
344-
"request": copy.deepcopy(request),
345-
"response": copy.deepcopy(response),
346-
}
345+
with self.cache_lock:
346+
self.cache[select_id] = {
347+
"context": np.copy(context),
348+
"request": copy.deepcopy(request),
349+
"response": copy.deepcopy(response),
350+
}
347351
response.id = select_id
348352

349353
if response.status != "Success":
@@ -365,13 +369,49 @@ def report_callback(
365369
f"AcquisitionReport Request with ID: '{request.id}' and loss '{request.loss}'"
366370
)
367371

368-
# Collect cached context
369-
if request.id not in self.cache:
370-
response.status = "id does not map to previous select call"
371-
self.get_logger().error(f"AcquistionReport: {response.status}")
372-
response.success = False
373-
return response
374-
cache = self.cache[request.id]
372+
# Remove any completed threads
373+
i = 0
374+
while i < len(self.acquisition_report_threads):
375+
if not self.acquisition_report_threads[i].is_alive():
376+
self.get_logger().info("Removing completed acquisition report thread")
377+
self.acquisition_report_threads.pop(i)
378+
else:
379+
i += 1
380+
381+
# Start the asynch thread
382+
request_copy = copy.deepcopy(request)
383+
response_copy = copy.deepcopy(response)
384+
thread = threading.Thread(
385+
target=self.report_callback_work, args=(request_copy, response_copy)
386+
)
387+
self.acquisition_report_threads.append(thread)
388+
self.get_logger().info("Starting new acquisition report thread")
389+
thread.start()
390+
391+
# Return success immediately
392+
response.status = "Success"
393+
response.success = True
394+
return response
395+
396+
# pylint: disable=too-many-statements
397+
# One over is fine for this function.
398+
def report_callback_work(
399+
self, request: AcquisitionReport.Request, response: AcquisitionReport.Response
400+
) -> AcquisitionReport.Response:
401+
"""
402+
Perform the work of updating the policy based on the acquisition. This is a workaround
403+
to the fact that either ROSLib or rosbridge (likely the latter) cannot process a service
404+
and action at the same time, so in practice the next motion waits until after the policy
405+
has been updated, which adds a few seconds of unnecessary latency.
406+
"""
407+
with self.cache_lock:
408+
# Collect cached context
409+
if request.id not in self.cache:
410+
response.status = "id does not map to previous select call"
411+
self.get_logger().error(f"AcquistionReport: {response.status}")
412+
response.success = False
413+
return response
414+
cache = copy.deepcopy(self.cache[request.id])
375415
context = cache["context"]
376416

377417
# Collect executed action
@@ -409,7 +449,9 @@ def report_callback(
409449

410450
# Report completed
411451
self.n_successful_reports += 1
412-
del self.cache[request.id]
452+
with self.cache_lock:
453+
if request.id in self.cache:
454+
del self.cache[request.id]
413455

414456
# Save checkpoint if requested
415457
if (

0 commit comments

Comments
 (0)