Skip to content

Commit a587c69

Browse files
authored
Add option to combine all perception nodes into one (#196)
* Move perception nodes to RELIABLE QoS * Combine perception into a single node
1 parent b70b2f5 commit a587c69

14 files changed

+726
-486
lines changed
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
"""
2+
This module contains the ADAFeedingPerceptionNode class, which is used as a component
3+
of all perception nodes in the ADA Feeding project. Specifically, by storing all
4+
functionality to get camera images (RGB and depth) and info in this node, it makes
5+
it easier to combine one or more perception functionalities into a single node,
6+
which will reduce the number of parallel subscriptions the ROS2 middleware has to
7+
manage.
8+
"""
9+
# Standard imports
10+
from collections import deque
11+
from functools import partial
12+
from threading import Lock, Thread
13+
from typing import Any, Optional, Type, Union
14+
15+
# Third-party imports
16+
import rclpy
17+
from rclpy.callback_groups import CallbackGroup, MutuallyExclusiveCallbackGroup
18+
from rclpy.executors import MultiThreadedExecutor
19+
from rclpy.node import Node
20+
from rclpy.qos import QoSProfile, ReliabilityPolicy
21+
from rclpy.subscription import Subscription
22+
23+
# Local imports
24+
25+
26+
class ADAFeedingPerceptionNode(Node):
27+
"""
28+
Class that contains all functionality to get camera images (RGB and depth) and
29+
camera info. This class is meant to consolidate the number of parallel subscriptions
30+
the ROS2 middleware has to manage.
31+
"""
32+
33+
def __init__(self, name: str):
34+
"""
35+
Constructor for the ADAFeedingPerceptionNode class.
36+
37+
Parameters
38+
----------
39+
name : str
40+
The name of the node.
41+
"""
42+
super().__init__(name)
43+
self.__msg_locks: dict[str, Lock] = {}
44+
self.__latest_msgs: dict[str, deque[Any]] = {}
45+
self.__subs: dict[str, Subscription] = {}
46+
47+
# pylint: disable=too-many-arguments
48+
# These are fine to mimic ROS2 API
49+
def add_subscription(
50+
self,
51+
msg_type: Type,
52+
topic: str,
53+
qos_profile: Union[QoSProfile, int] = QoSProfile(
54+
depth=1, reliability=ReliabilityPolicy.BEST_EFFORT
55+
),
56+
callback_group: Optional[CallbackGroup] = MutuallyExclusiveCallbackGroup(),
57+
num_msgs: int = 1,
58+
) -> None:
59+
"""
60+
Adds a subscription to this node.
61+
62+
Parameters
63+
----------
64+
msg_type : Type
65+
The type of message to subscribe to.
66+
topic : str
67+
The name of the topic to subscribe to.
68+
qos_profile : Union[QoSProfile, int], optional
69+
The quality of service profile to use for the subscription, by default
70+
QoSProfile(depth=1, reliability=ReliabilityPolicy.BEST_EFFORT).
71+
callback_group : Optional[CallbackGroup], optional
72+
The callback group to use for the subscription, by default
73+
MutuallyExclusiveCallbackGroup().
74+
num_msgs : int, optional
75+
The number of messages to store in the subscription, by default 1.
76+
"""
77+
if topic in self.__msg_locks:
78+
with self.__msg_locks[topic]:
79+
if num_msgs > self.__latest_msgs[topic].maxlen:
80+
# Grow the deque
81+
self.__latest_msgs[topic] = deque(
82+
self.__latest_msgs[topic], maxlen=num_msgs
83+
)
84+
return
85+
self.__msg_locks[topic] = Lock()
86+
self.__latest_msgs[topic] = deque(maxlen=num_msgs)
87+
self.__subs[topic] = self.create_subscription(
88+
msg_type=msg_type,
89+
topic=topic,
90+
callback=partial(self.__callback, topic=topic),
91+
qos_profile=qos_profile,
92+
callback_group=callback_group,
93+
)
94+
95+
def get_latest_msg(self, topic: str) -> Optional[Any]:
96+
"""
97+
Returns the latest message from the specified topic.
98+
99+
Parameters
100+
----------
101+
topic : str
102+
The name of the topic to get the latest message from.
103+
104+
Returns
105+
-------
106+
Optional[Any]
107+
The latest message from the specified topic, or None if no message has been
108+
received.
109+
"""
110+
if topic not in self.__msg_locks:
111+
self.get_logger().error(f"Topic '{topic}' not found.")
112+
return None
113+
with self.__msg_locks[topic]:
114+
if len(self.__latest_msgs[topic]) == 0:
115+
self.get_logger().error(f"No message received from topic '{topic}'.")
116+
return None
117+
return self.__latest_msgs[topic][-1]
118+
119+
def get_all_msgs(self, topic: str, copy: bool = True) -> Optional[deque[Any]]:
120+
"""
121+
Returns all messages from the specified topic.
122+
123+
Parameters
124+
----------
125+
topic : str
126+
The name of the topic to get the messages from.
127+
copy : bool, optional
128+
Whether to return a copy of the messages, by default True.
129+
130+
Returns
131+
-------
132+
Optional[deque[Any]]
133+
All messages from the specified topic, or None if no messages have been
134+
received.
135+
"""
136+
if topic not in self.__msg_locks:
137+
self.get_logger().error(f"Topic '{topic}' not found.")
138+
return None
139+
with self.__msg_locks[topic]:
140+
if len(self.__latest_msgs[topic]) == 0:
141+
self.get_logger().error(f"No message received from topic '{topic}'.")
142+
return None
143+
if copy:
144+
return deque(
145+
self.__latest_msgs[topic], maxlen=self.__latest_msgs[topic].maxlen
146+
)
147+
return self.__latest_msgs[topic]
148+
149+
def __callback(self, msg: Any, topic: str) -> None:
150+
"""
151+
Callback function for the subscription.
152+
153+
Parameters
154+
----------
155+
msg : Any
156+
The message received from the subscription.
157+
topic : str
158+
The name of the topic the message was received from.
159+
"""
160+
with self.__msg_locks[topic]:
161+
self.__latest_msgs[topic].append(msg)
162+
163+
164+
# pylint: disable=too-many-locals
165+
def main(args=None):
166+
"""
167+
Launch the ROS node and spin.
168+
"""
169+
# Import the necessary modules
170+
# pylint: disable=import-outside-toplevel
171+
from ada_feeding_perception.face_detection import FaceDetectionNode
172+
from ada_feeding_perception.food_on_fork_detection import FoodOnForkDetectionNode
173+
from ada_feeding_perception.segment_from_point import SegmentFromPointNode
174+
from ada_feeding_perception.table_detection import TableDetectionNode
175+
176+
rclpy.init(args=args)
177+
178+
node = ADAFeedingPerceptionNode("ada_feeding_perception")
179+
face_detection = FaceDetectionNode(node)
180+
food_on_fork_detection = FoodOnForkDetectionNode(node)
181+
segment_from_point = SegmentFromPointNode(node) # pylint: disable=unused-variable
182+
table_detection = TableDetectionNode(node)
183+
executor = MultiThreadedExecutor(num_threads=16)
184+
185+
# Spin in the background initially
186+
spin_thread = Thread(
187+
target=rclpy.spin,
188+
args=(node,),
189+
kwargs={"executor": executor},
190+
daemon=True,
191+
)
192+
spin_thread.start()
193+
194+
# Run the perception nodes
195+
def face_detection_run():
196+
try:
197+
face_detection.run()
198+
except KeyboardInterrupt:
199+
pass
200+
201+
def food_on_fork_detection_run():
202+
try:
203+
food_on_fork_detection.run()
204+
except KeyboardInterrupt:
205+
pass
206+
207+
def table_detection_run():
208+
try:
209+
table_detection.run()
210+
except KeyboardInterrupt:
211+
pass
212+
213+
face_detection_thread = Thread(target=face_detection_run, daemon=True)
214+
face_detection_thread.start()
215+
food_on_fork_detection_thread = Thread(
216+
target=food_on_fork_detection_run, daemon=True
217+
)
218+
food_on_fork_detection_thread.start()
219+
table_detection_thread = Thread(target=table_detection_run, daemon=True)
220+
table_detection_thread.start()
221+
222+
# Spin in the foreground
223+
spin_thread.join()
224+
225+
# Terminate this node
226+
node.destroy_node()
227+
rclpy.shutdown()
228+
229+
230+
if __name__ == "__main__":
231+
main()

0 commit comments

Comments
 (0)