Skip to content

Commit a3a661a

Browse files
committed
implementing close()
1 parent e969217 commit a3a661a

File tree

4 files changed

+87
-39
lines changed

4 files changed

+87
-39
lines changed

examples/receive_super_stream.rs

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,6 @@ use rabbitmq_stream_client::error::StreamCreateError;
33
use rabbitmq_stream_client::types::{
44
ByteCapacity, OffsetSpecification, ResponseCode, SuperStreamConsumer,
55
};
6-
use std::sync::atomic::{AtomicU32, Ordering};
7-
use std::sync::Arc;
8-
use std::time::Duration;
9-
use tokio::task;
10-
use tokio::time::sleep;
116

127
#[tokio::main]
138
async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -41,34 +36,25 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
4136
.await
4237
.unwrap();
4338

44-
let received_messages = Arc::new(AtomicU32::new(0));
45-
46-
for mut consumer in super_stream_consumer.get_consumers().await.into_iter() {
47-
let received_messages_outer = received_messages.clone();
48-
49-
task::spawn(async move {
50-
let mut inner_received_messages = received_messages_outer.clone();
51-
while let Some(delivery) = consumer.next().await {
52-
let d = delivery.unwrap();
53-
println!(
54-
"Got message: {:#?} from stream: {} with offset: {}",
55-
d.message()
56-
.data()
57-
.map(|data| String::from_utf8(data.to_vec()).unwrap()),
58-
d.stream(),
59-
d.offset(),
60-
);
61-
let value = inner_received_messages.fetch_add(1, Ordering::Relaxed);
62-
if value == message_count {
63-
let handle = consumer.handle();
64-
_ = handle.close().await;
65-
break;
66-
}
67-
}
68-
});
39+
let mut received_messages = 0;
40+
41+
while let delivery = super_stream_consumer.next().await.unwrap() {
42+
println!("inside while delivery loop");
43+
let d = delivery.unwrap();
44+
println!(
45+
"Got message: {:#?} from stream: {} with offset: {}",
46+
d.message()
47+
.data()
48+
.map(|data| String::from_utf8(data.to_vec()).unwrap()),
49+
d.stream(),
50+
d.offset()
51+
);
52+
53+
received_messages = received_messages + 1;
54+
if received_messages == 10 {
55+
break;
56+
}
6957
}
7058

71-
sleep(Duration::from_millis(20000)).await;
72-
7359
Ok(())
7460
}

src/consumer.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,16 @@ impl Stream for Consumer {
261261
}
262262

263263
/// Handler API for [`Consumer`]
264+
///
264265
pub struct ConsumerHandle(Arc<ConsumerInternal>);
265266

266267
impl ConsumerHandle {
267268
/// Close the [`Consumer`] associated to this handle
268269
pub async fn close(self) -> Result<(), ConsumerCloseError> {
270+
self.internal_close().await
271+
}
272+
273+
pub(crate) async fn internal_close(&self) -> Result<(), ConsumerCloseError> {
269274
match self.0.closed.compare_exchange(false, true, SeqCst, SeqCst) {
270275
Ok(false) => {
271276
let response = self.0.client.unsubscribe(self.0.subscription_id).await?;

src/superstream_consumer.rs

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
use crate::consumer::Delivery;
2-
use crate::error::ConsumerDeliveryError;
2+
use crate::error::{ConsumerCloseError, ConsumerDeliveryError};
33
use crate::superstream::DefaultSuperStreamMetadata;
4-
use crate::{error::ConsumerCreateError, Environment};
4+
use crate::{error::ConsumerCreateError, ConsumerHandle, Environment};
5+
use futures::task::AtomicWaker;
56
use futures::{Stream, StreamExt};
67
use rabbitmq_stream_protocol::commands::subscribe::OffsetSpecification;
78
use std::pin::Pin;
9+
use std::sync::atomic::AtomicBool;
10+
use std::sync::atomic::Ordering::{Relaxed, SeqCst};
11+
use std::sync::Arc;
812
use std::task::{Context, Poll};
913
use tokio::sync::mpsc::{channel, Receiver};
1014
use tokio::task;
@@ -13,11 +17,14 @@ use tokio::task;
1317

1418
/// API for consuming RabbitMQ stream messages
1519
pub struct SuperStreamConsumer {
16-
internal: SuperStreamConsumerInternal,
20+
internal: Arc<SuperStreamConsumerInternal>,
21+
receiver: Receiver<Result<Delivery, ConsumerDeliveryError>>,
1722
}
1823

1924
struct SuperStreamConsumerInternal {
20-
receiver: Receiver<Result<Delivery, ConsumerDeliveryError>>,
25+
closed: Arc<AtomicBool>,
26+
handlers: Vec<ConsumerHandle>,
27+
waker: AtomicWaker,
2128
}
2229

2330
/// Builder for [`Consumer`]
@@ -44,6 +51,7 @@ impl SuperStreamConsumerBuilder {
4451
};
4552
let partitions = super_stream_metadata.partitions().await;
4653

54+
let mut handlers = Vec::<ConsumerHandle>::new();
4755
for partition in partitions.into_iter() {
4856
let tx_cloned = tx.clone();
4957
let mut consumer = self
@@ -54,17 +62,24 @@ impl SuperStreamConsumerBuilder {
5462
.await
5563
.unwrap();
5664

65+
handlers.push(consumer.handle());
66+
5767
task::spawn(async move {
5868
while let Some(d) = consumer.next().await {
5969
_ = tx_cloned.send(d).await;
6070
}
6171
});
6272
}
6373

64-
let super_stream_consumer_internal = SuperStreamConsumerInternal { receiver: rx };
74+
let super_stream_consumer_internal = SuperStreamConsumerInternal {
75+
closed: Arc::new(AtomicBool::new(false)),
76+
handlers,
77+
waker: AtomicWaker::new(),
78+
};
6579

6680
Ok(SuperStreamConsumer {
67-
internal: super_stream_consumer_internal,
81+
internal: Arc::new(super_stream_consumer_internal),
82+
receiver: rx,
6883
})
6984
}
7085

@@ -78,6 +93,46 @@ impl Stream for SuperStreamConsumer {
7893
type Item = Result<Delivery, ConsumerDeliveryError>;
7994

8095
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81-
Pin::new(&mut self.internal.receiver).poll_recv(cx)
96+
self.internal.waker.register(cx.waker());
97+
let poll = Pin::new(&mut self.receiver).poll_recv(cx);
98+
match (self.is_closed(), poll.is_ready()) {
99+
(true, false) => Poll::Ready(None),
100+
_ => poll,
101+
}
102+
}
103+
}
104+
105+
impl SuperStreamConsumer {
106+
/// Check if the consumer is closed
107+
pub fn is_closed(&self) -> bool {
108+
self.internal.is_closed()
109+
}
110+
111+
pub fn handle(&self) -> SuperStreamConsumerHandle {
112+
SuperStreamConsumerHandle(self.internal.clone())
113+
}
114+
}
115+
116+
impl SuperStreamConsumerInternal {
117+
fn is_closed(&self) -> bool {
118+
self.closed.load(Relaxed)
119+
}
120+
}
121+
122+
pub struct SuperStreamConsumerHandle(Arc<SuperStreamConsumerInternal>);
123+
124+
impl SuperStreamConsumerHandle {
125+
/// Close the [`Consumer`] associated to this handle
126+
pub async fn close(self) -> Result<(), ConsumerCloseError> {
127+
self.0.waker.wake();
128+
match self.0.closed.compare_exchange(false, true, SeqCst, SeqCst) {
129+
Ok(false) => {
130+
for handle in &self.0.handlers {
131+
handle.internal_close().await.unwrap();
132+
}
133+
Ok(())
134+
}
135+
_ => Err(ConsumerCloseError::AlreadyClosed),
136+
}
82137
}
83138
}

tests/integration/consumer_test.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ async fn super_stream_consumer_test() {
9898
}
9999

100100
let mut received_messages = 0;
101+
let handle = super_stream_consumer.handle();
101102

102103
println!("before looping");
103104
while let delivery = super_stream_consumer.next().await.unwrap() {
@@ -121,6 +122,7 @@ async fn super_stream_consumer_test() {
121122
assert!(received_messages == message_count);
122123

123124
super_stream_producer.close().await.unwrap();
125+
_ = handle.close().await;
124126
}
125127

126128
#[tokio::test(flavor = "multi_thread")]

0 commit comments

Comments
 (0)