diff --git a/examples/ha_producer.rs b/examples/ha_producer.rs new file mode 100644 index 00000000..bbcd29fb --- /dev/null +++ b/examples/ha_producer.rs @@ -0,0 +1,151 @@ +use core::panic; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use std::time::Duration; + +use rabbitmq_stream_client::error::{ProducerPublishError, StreamCreateError}; +use rabbitmq_stream_client::types::{ByteCapacity, Message, ResponseCode}; +use rabbitmq_stream_client::Environment; +use rabbitmq_stream_client::{ + ConfirmationStatus, NoDedup, OnClosed, Producer, RabbitMQStreamResult, +}; +use tokio::sync::Notify; +use tokio::sync::RwLock; +use tokio::time::sleep; +use tracing::info; + +struct MyHAProducerInner { + environment: Environment, + stream: String, + producer: RwLock>, + notify: Notify, + is_opened: AtomicBool, +} + +#[derive(Clone)] +struct MyHAProducer(Arc); + +#[async_trait::async_trait] +impl OnClosed for MyHAProducer { + async fn on_closed(&self, unconfirmed: Vec) { + info!("Producer is closed. Creating new one"); + + self.0 + .is_opened + .store(false, std::sync::atomic::Ordering::SeqCst); + + let mut producer = self.0.producer.write().await; + + let new_producer = self + .0 + .environment + .producer() + .build(&self.0.stream) + .await + .unwrap(); + + new_producer.set_on_closed(Box::new(self.clone())).await; + + if !unconfirmed.is_empty() { + info!("Resending {} unconfirmed messages.", unconfirmed.len()); + if let Err(e) = producer.batch_send_with_confirm(unconfirmed).await { + eprintln!("Error resending unconfirmed messages: {:?}", e); + } + } + + *producer = new_producer; + + self.0 + .is_opened + .store(true, std::sync::atomic::Ordering::SeqCst); + self.0.notify.notify_waiters(); + } +} + +impl MyHAProducer { + async fn new(environment: Environment, stream: &str) -> RabbitMQStreamResult { + ensure_stream_exists(&environment, stream).await?; + + let producer = environment.producer().build(stream).await.unwrap(); + + let inner = MyHAProducerInner { + environment, + stream: stream.to_string(), + producer: RwLock::new(producer), + notify: Notify::new(), + is_opened: AtomicBool::new(true), + }; + let s = Self(Arc::new(inner)); + + let p = s.0.producer.write().await; + p.set_on_closed(Box::new(s.clone())).await; + drop(p); + + Ok(s) + } + + async fn send_with_confirm( + &self, + message: Message, + ) -> Result { + if !self.0.is_opened.load(std::sync::atomic::Ordering::SeqCst) { + self.0.notify.notified().await; + } + + let producer = self.0.producer.read().await; + let err = producer.send_with_confirm(message.clone()).await; + + match err { + Ok(s) => Ok(s), + Err(e) => match e { + ProducerPublishError::Timeout | ProducerPublishError::Closed => { + Box::pin(self.send_with_confirm(message)).await + } + _ => return Err(e), + }, + } + } +} + +async fn ensure_stream_exists(environment: &Environment, stream: &str) -> RabbitMQStreamResult<()> { + let create_response = environment + .stream_creator() + .max_length(ByteCapacity::GB(5)) + .create(stream) + .await; + + if let Err(e) = create_response { + if let StreamCreateError::Create { stream, status } = e { + match status { + // we can ignore this error because the stream already exists + ResponseCode::StreamAlreadyExists => {} + err => { + panic!("Error creating stream: {:?} {:?}", stream, err); + } + } + } + } + + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let _ = tracing_subscriber::fmt::try_init(); + + let environment = Environment::builder().build().await?; + let stream = "hello-rust-stream"; + + let producer = MyHAProducer::new(environment, stream).await?; + + let number_of_messages = 1000000; + for i in 0..number_of_messages { + let msg = Message::builder() + .body(format!("stream message_{}", i)) + .build(); + producer.send_with_confirm(msg).await?; + sleep(Duration::from_millis(100)).await; + } + + Ok(()) +} diff --git a/src/client/dispatcher.rs b/src/client/dispatcher.rs index 1abcf43b..1f9fa6d5 100644 --- a/src/client/dispatcher.rs +++ b/src/client/dispatcher.rs @@ -19,22 +19,6 @@ use super::{channel::ChannelReceiver, handler::MessageHandler}; #[derive(Clone)] pub(crate) struct Dispatcher(DispatcherState); -pub(crate) struct DispatcherState { - requests: Arc, - correlation_id: Arc, - handler: Arc>>, -} - -impl Clone for DispatcherState { - fn clone(&self) -> Self { - DispatcherState { - requests: self.requests.clone(), - correlation_id: self.correlation_id.clone(), - handler: self.handler.clone(), - } - } -} - struct RequestsMap { requests: DashMap>, closed: AtomicBool, @@ -126,6 +110,22 @@ where } } +pub(crate) struct DispatcherState { + requests: Arc, + correlation_id: Arc, + handler: Arc>>, +} + +impl Clone for DispatcherState { + fn clone(&self) -> Self { + DispatcherState { + requests: self.requests.clone(), + correlation_id: self.correlation_id.clone(), + handler: self.handler.clone(), + } + } +} + impl DispatcherState where T: MessageHandler, diff --git a/src/client/message.rs b/src/client/message.rs index 7e38f2f0..6679020a 100644 --- a/src/client/message.rs +++ b/src/client/message.rs @@ -20,7 +20,7 @@ impl BaseMessage for Message { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ClientMessage { publishing_id: u64, message: Message, @@ -39,6 +39,10 @@ impl ClientMessage { pub fn filter_value_extract(&mut self, filter_value_extractor: impl Fn(&Message) -> String) { self.filter_value = Some(filter_value_extractor(&self.message)); } + + pub fn into_message(self) -> Message { + self.message + } } impl BaseMessage for ClientMessage { diff --git a/src/client/mod.rs b/src/client/mod.rs index 15c42242..c6d65c07 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,4 +1,3 @@ -use std::ops::DerefMut; use std::{ collections::HashMap, io, @@ -142,42 +141,7 @@ pub struct ClientState { max_frame_size: u32, last_heatbeat: Instant, heartbeat_task: Option, -} - -#[async_trait::async_trait] -impl MessageHandler for Client { - async fn handle_message(&self, item: MessageResult) -> RabbitMQStreamResult<()> { - match &item { - Some(Ok(response)) => match response.kind_ref() { - ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await, - ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await, - _ => { - if let Some(handler) = self.state.read().await.handler.as_ref() { - let handler = handler.clone(); - - tokio::task::spawn(async move { handler.handle_message(item).await }); - } - } - }, - Some(Err(err)) => { - trace!(?err); - if let Some(handler) = self.state.read().await.handler.as_ref() { - let handler = handler.clone(); - - tokio::task::spawn(async move { handler.handle_message(item).await }); - } - } - None => { - trace!("Closing client"); - if let Some(handler) = self.state.read().await.handler.as_ref() { - let handler = handler.clone(); - tokio::task::spawn(async move { handler.handle_message(None).await }); - } - } - } - - Ok(()) - } + last_received_message: Arc>, } /// Raw API for taking to RabbitMQ stream @@ -201,8 +165,9 @@ impl Client { let (sender, receiver) = Client::create_connection(&broker).await?; - let dispatcher = Dispatcher::new(); + let last_received_message = Arc::new(RwLock::new(Instant::now())); + let dispatcher = Dispatcher::new(); let state = ClientState { server_properties: HashMap::new(), connection_properties: HashMap::new(), @@ -211,6 +176,7 @@ impl Client { max_frame_size: broker.max_frame_size, last_heatbeat: Instant::now(), heartbeat_task: None, + last_received_message: last_received_message.clone(), }; let mut client = Client { dispatcher, @@ -519,6 +485,15 @@ impl Client { self.filtering_supported } + /// Don't use this method in production code. + pub async fn set_heartbeat(&self, heartbeat: u32) { + let mut state = self.state.write().await; + state.heartbeat = heartbeat; + // Eventually, this drops the previous heartbeat task + state.heartbeat_task = + self.start_hearbeat_task(heartbeat, state.last_received_message.clone()); + } + async fn create_connection( broker: &ClientOptions, ) -> Result< @@ -536,6 +511,7 @@ impl Client { Ok((tx, rx)) } + async fn initialize(&mut self, receiver: ChannelReceiver) -> Result<(), ClientError> where T: Stream> + Unpin + Send, @@ -558,7 +534,10 @@ impl Client { .await?; // Start heartbeat task after connection is established - self.start_hearbeat_task(self.state.write().await.deref_mut()); + let mut state = self.state.write().await; + state.heartbeat_task = + self.start_hearbeat_task(state.heartbeat, state.last_received_message.clone()); + drop(state); Ok(()) } @@ -697,7 +676,9 @@ impl Client { ); if state.heartbeat_task.take().is_some() { - self.start_hearbeat_task(&mut state); + // Start heartbeat task after connection is established + state.heartbeat_task = + self.start_hearbeat_task(state.heartbeat, state.last_received_message.clone()); } drop(state); @@ -710,13 +691,22 @@ impl Client { self.tune_notifier.notify_one(); } - fn start_hearbeat_task(&self, state: &mut ClientState) { - if state.heartbeat == 0 { - return; + fn start_hearbeat_task( + &self, + heartbeat: u32, + last_received_message: Arc>, + ) -> Option { + if heartbeat == 0 { + return None; } - let heartbeat_interval = (state.heartbeat / 2).max(1); + let heartbeat_interval = (heartbeat / 2).max(1); let channel = self.channel.clone(); - let heartbeat_task = tokio::spawn(async move { + + let client = self.clone(); + + let heartbeat_task: task::TaskHandle = tokio::spawn(async move { + let timeout_threashold = u64::from(heartbeat * 4); + loop { trace!("Sending heartbeat"); if channel @@ -727,11 +717,25 @@ impl Client { break; } tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await; + + let now = Instant::now(); + let last_message = last_received_message.read().await; + if now.duration_since(*last_message) >= Duration::from_secs(timeout_threashold) { + warn!("Heartbeat timeout reached. Force closing connection."); + if !client.is_closed() { + if let Err(e) = client.close().await { + warn!("Error closing client: {}", e); + } + } + break; + } } + warn!("Heartbeat task stopped. Force closing connection"); }) .into(); - state.heartbeat_task = Some(heartbeat_task); + + Some(heartbeat_task) } async fn handle_heart_beat_command(&self) { @@ -751,3 +755,50 @@ impl Client { .await } } + +#[async_trait::async_trait] +impl MessageHandler for Client { + async fn handle_message(&self, item: MessageResult) -> RabbitMQStreamResult<()> { + match &item { + Some(Ok(response)) => { + // Update last received message time: needed for heartbeat task + { + let s = self.state.read().await; + let mut last_received_message = s.last_received_message.write().await; + *last_received_message = Instant::now(); + drop(last_received_message); + drop(s); + } + + match response.kind_ref() { + ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await, + ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await, + _ => { + if let Some(handler) = self.state.read().await.handler.as_ref() { + let handler = handler.clone(); + + tokio::task::spawn(async move { handler.handle_message(item).await }); + } + } + } + } + Some(Err(err)) => { + trace!(?err); + if let Some(handler) = self.state.read().await.handler.as_ref() { + let handler = handler.clone(); + + tokio::task::spawn(async move { handler.handle_message(item).await }); + } + } + None => { + trace!("Closing client"); + if let Some(handler) = self.state.read().await.handler.as_ref() { + let handler = handler.clone(); + tokio::task::spawn(async move { handler.handle_message(None).await }); + } + } + } + + Ok(()) + } +} diff --git a/src/environment.rs b/src/environment.rs index 9592dbcd..40e2793e 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -199,6 +199,8 @@ impl Environment { data: PhantomData, filter_value_extractor: None, client_provided_name: String::from("rust-stream-producer"), + on_closed: None, + overwrite_heartbeat: None, } } diff --git a/src/lib.rs b/src/lib.rs index d6633b00..bcac483b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,7 +93,9 @@ pub use crate::consumer::{ Consumer, ConsumerBuilder, ConsumerHandle, FilterConfiguration, MessageContext, }; pub use crate::environment::{Environment, EnvironmentBuilder}; -pub use crate::producer::{Dedup, NoDedup, Producer, ProducerBuilder}; +pub use crate::producer::{ + ConfirmationStatus, Dedup, NoDedup, OnClosed, Producer, ProducerBuilder, +}; pub mod types { pub use crate::byte_capacity::ByteCapacity; diff --git a/src/producer.rs b/src/producer.rs index 9fa3b93e..2c40c034 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -11,10 +11,10 @@ use std::{ use dashmap::DashMap; use futures::{future::BoxFuture, FutureExt}; -use tokio::sync::mpsc; use tokio::sync::mpsc::channel; +use tokio::sync::{mpsc, RwLock}; use tokio::time::sleep; -use tracing::{error, info, trace}; +use tracing::{error, info, trace, warn}; use rabbitmq_stream_protocol::{message::Message, ResponseCode, ResponseKind}; @@ -27,7 +27,7 @@ use crate::{ error::{ClientError, ProducerCloseError, ProducerCreateError, ProducerPublishError}, }; -type WaiterMap = Arc>; +type WaiterMap = Arc>; type FilterValueExtractor = Arc String + 'static + Send + Sync>; #[derive(Debug)] @@ -69,6 +69,7 @@ pub struct ProducerInternal { closed: Arc, sender: mpsc::Sender, filter_value_extractor: Option, + on_closed: Arc>>>, } impl Drop for ProducerInternal { @@ -116,6 +117,8 @@ pub struct ProducerBuilder { pub(crate) data: PhantomData, pub filter_value_extractor: Option, pub(crate) client_provided_name: String, + pub(crate) on_closed: Option>, + pub(crate) overwrite_heartbeat: Option, } #[derive(Clone)] @@ -136,6 +139,10 @@ impl ProducerBuilder { .create_producer_client(stream, self.client_provided_name.clone()) .await?; + if let Some(heartbeat) = self.overwrite_heartbeat { + client.set_heartbeat(heartbeat).await; + } + let mut publish_version = 1; if self.filter_value_extractor.is_some() { @@ -146,11 +153,14 @@ impl ProducerBuilder { } } + let on_closed = Arc::new(RwLock::new(self.on_closed)); + let waiting_confirmations: WaiterMap = Arc::new(DashMap::new()); let confirm_handler = ProducerConfirmHandler { waiting_confirmations: waiting_confirmations.clone(), metrics_collector, + on_closed: on_closed.clone(), }; client.set_handler(confirm_handler).await; @@ -183,6 +193,7 @@ impl ProducerBuilder { closed: Arc::new(AtomicBool::new(false)), sender, filter_value_extractor: self.filter_value_extractor, + on_closed, }; let internal_producer = Arc::new(producer); @@ -204,11 +215,22 @@ impl ProducerBuilder { } } + pub fn on_closed(mut self, on_closed: Box) -> ProducerBuilder { + self.on_closed = Some(on_closed); + self + } + pub fn batch_size(mut self, batch_size: usize) -> Self { self.batch_size = batch_size; self } + /// Don't use this in production, it is only for testing purposes. + pub fn overwrite_heartbeat(mut self, heartbeat: u32) -> ProducerBuilder { + self.overwrite_heartbeat = Some(heartbeat); + self + } + pub fn client_provided_name(mut self, name: &str) -> Self { self.client_provided_name = String::from(name); self @@ -223,6 +245,8 @@ impl ProducerBuilder { data: PhantomData, filter_value_extractor: None, client_provided_name: String::from("rust-stream-producer"), + on_closed: self.on_closed, + overwrite_heartbeat: None, } } @@ -266,6 +290,15 @@ fn schedule_batch_send( Ok(_) => {} Err(e) => { error!("Error publishing batch {:?}", e); + + // If the underlying error is a broken pipe, we can assume the connection is closed + // In fact, BorkenPipe is not recoverable, so we can exit the loop. + // This will close the receiver, so, the next time a send is called, it will return an error. + if matches!(e, ClientError::Io(e) if e.kind() == std::io::ErrorKind::BrokenPipe) + { + // If the error is a broken pipe, we can assume the connection is closed + break; + } } }; } @@ -445,11 +478,19 @@ impl Producer { } let waiter = OnceProducerMessageWaiter::waiter_with_cb(cb, message); - self.0 - .waiting_confirmations - .insert(publishing_id, ProducerMessageWaiter::Once(waiter)); + self.0.waiting_confirmations.insert( + publishing_id, + (msg.clone(), ProducerMessageWaiter::Once(waiter)), + ); if let Err(e) = self.0.sender.send(msg).await { + // `send` fails only when the receiver is closed, which means the TCP connection is broken. + // In this case, we forcefully close the producer and return an error. + // The current message will not be sent, but it is not lost: + // `on_closed` handler will be called, and it can resend the message if needed. + if let Err(err) = self.0.close().await { + error!(error = ?err, "Failed to close producer after send error"); + } return Err(ClientError::GenericError(Box::new(e)))?; } @@ -484,13 +525,18 @@ impl Producer { client_message.filter_value_extract(f.as_ref()) } + self.0.waiting_confirmations.insert( + publishing_id, + ( + client_message.clone(), + ProducerMessageWaiter::Shared(waiter.clone()), + ), + ); + // Queue the message for sending if let Err(e) = self.0.sender.send(client_message).await { return Err(ClientError::GenericError(Box::new(e)))?; } - self.0 - .waiting_confirmations - .insert(publishing_id, ProducerMessageWaiter::Shared(waiter.clone())); } Ok(()) @@ -503,11 +549,22 @@ impl Producer { pub async fn close(self) -> Result<(), ProducerCloseError> { self.0.close().await } + + pub async fn set_on_closed(&self, on_closed: Box) { + let mut on_closed_lock = self.0.on_closed.write().await; + *on_closed_lock = Some(on_closed); + } +} + +#[async_trait::async_trait] +pub trait OnClosed { + async fn on_closed(&self, unconfirmed: Vec); } struct ProducerConfirmHandler { waiting_confirmations: WaiterMap, metrics_collector: Arc, + on_closed: Arc>>>, } #[async_trait::async_trait] @@ -522,7 +579,8 @@ impl MessageHandler for ProducerConfirmHandler { for publishing_id in &confirm.publishing_ids { let id = *publishing_id; - let waiter = match self.waiting_confirmations.remove(publishing_id) { + let (_, waiter) = match self.waiting_confirmations.remove(publishing_id) + { Some((_, confirm_sender)) => confirm_sender, None => todo!(), }; @@ -559,7 +617,7 @@ impl MessageHandler for ProducerConfirmHandler { let code = err.error_code.clone(); let id = err.publishing_id; - let waiter = match self.waiting_confirmations.remove(&id) { + let (_, waiter) = match self.waiting_confirmations.remove(&id) { Some((_, confirm_sender)) => confirm_sender, None => todo!(), }; @@ -582,8 +640,23 @@ impl MessageHandler for ProducerConfirmHandler { // TODO clean all waiting for confirm } None => { - trace!("Connection closed"); - // TODO connection close clean all waiting + info!("Connection closed"); + let on_closed = self.on_closed.read().await; + if let Some(on_close) = &*on_closed { + let mut unconfirmed: Vec<(u64, Message)> = self + .waiting_confirmations + .iter() + .map(|entry| (*entry.key(), entry.value().0.clone().into_message())) + .collect(); + unconfirmed.sort_by_key(|(id, _)| *id); + + let unconfirmed: Vec = + unconfirmed.into_iter().map(|(_, msg)| msg).collect(); + + on_close.on_closed(unconfirmed).await; + } else { + warn!("No on_closed handler set, unconfirmed messages will be lost."); + } } } Ok(()) diff --git a/tests/consumer_test.rs b/tests/consumer_test.rs index 7fc07abe..4ec97ed1 100644 --- a/tests/consumer_test.rs +++ b/tests/consumer_test.rs @@ -8,10 +8,7 @@ use common::*; use fake::{Fake, Faker}; use futures::StreamExt; use rabbitmq_stream_client::{ - error::{ - ClientError, ConsumerCloseError, ConsumerDeliveryError, ConsumerStoreOffsetError, - ProducerCloseError, - }, + error::{ClientError, ConsumerCloseError, ConsumerDeliveryError, ConsumerStoreOffsetError}, types::{Delivery, Message, OffsetSpecification, SuperStreamConsumer}, Consumer, FilterConfiguration, NoDedup, Producer, }; diff --git a/tests/producer_test.rs b/tests/producer_test.rs index 0e9c1858..99a464e3 100644 --- a/tests/producer_test.rs +++ b/tests/producer_test.rs @@ -3,12 +3,12 @@ use std::{collections::HashSet, sync::Arc, time::Duration}; use chrono::Utc; use fake::{Fake, Faker}; use futures::{lock::Mutex, StreamExt}; -use tokio::{sync::mpsc::channel, task::yield_now, time::sleep}; +use tokio::{sync::mpsc::channel, time::sleep}; use rabbitmq_stream_client::{ error::ClientError, types::{Message, OffsetSpecification, SimpleValue}, - Environment, + Environment, OnClosed, }; #[path = "./common.rs"] @@ -19,7 +19,6 @@ use common::*; use rabbitmq_stream_client::types::{ HashRoutingMurmurStrategy, RoutingKeyRoutingStrategy, RoutingStrategy, }; -use tracing::span; use std::sync::atomic::{AtomicU32, Ordering}; use tokio::sync::Notify; @@ -784,3 +783,162 @@ async fn producer_drop() { let metrics = tokio::runtime::Handle::current().metrics(); assert_eq!(metrics.num_alive_tasks(), 0); } + +#[tokio::test(flavor = "multi_thread")] +async fn producer_drop_connection_on_close() { + struct Foo { + notifier: Arc, + } + #[async_trait::async_trait] + impl OnClosed for Foo { + async fn on_closed(&self, _: Vec) { + self.notifier.notify_one(); + } + } + + let notifier = Arc::new(Notify::new()); + let _ = tracing_subscriber::fmt::try_init(); + let client_provided_name: String = Faker.fake(); + let env = TestEnvironment::create().await; + let producer = env + .env + .producer() + .client_provided_name(&client_provided_name) + .on_closed(Box::new(Foo { + notifier: notifier.clone(), + })) + .build(&env.stream) + .await + .unwrap(); + + producer + .send_with_confirm(Message::builder().body(b"message".to_vec()).build()) + .await + .unwrap(); + + sleep(Duration::from_millis(500)).await; + + let connection = wait_for_named_connection(client_provided_name.clone()).await; + drop_connection(connection).await; + + notifier.notified().await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn producer_timeout() { + struct Foo { + notifier: Arc, + } + #[async_trait::async_trait] + impl OnClosed for Foo { + async fn on_closed(&self, _: Vec) { + self.notifier.notify_one(); + } + } + + let notifier = Arc::new(Notify::new()); + let _ = tracing_subscriber::fmt::try_init(); + let client_provided_name: String = Faker.fake(); + let env = TestEnvironment::create().await; + let producer = env + .env + .producer() + .client_provided_name(&client_provided_name) + .overwrite_heartbeat(1) + .on_closed(Box::new(Foo { + notifier: notifier.clone(), + })) + .build(&env.stream) + .await + .unwrap(); + + producer + .send_with_confirm(Message::builder().body(b"message".to_vec()).build()) + .await + .unwrap(); + + sleep(Duration::from_millis(500)).await; + + let is_stopped = tokio::select! { + _ = notifier.notified() => true, + _ = sleep(Duration::from_secs(5)) => false, + }; + + assert!(is_stopped, "Producer did not stop after timeout"); +} + +#[tokio::test(flavor = "multi_thread")] +async fn producer_got_back_unconfirmed_messages_on_close() { + struct Foo { + on_closed_sender: tokio::sync::mpsc::Sender>, + } + #[async_trait::async_trait] + impl OnClosed for Foo { + async fn on_closed(&self, unconfirmed: Vec) { + self.on_closed_sender.send(unconfirmed).await.unwrap(); + } + } + + let (on_closed_sender, mut on_closed_receiver) = tokio::sync::mpsc::channel(1); + let _ = tracing_subscriber::fmt::try_init(); + let client_provided_name: String = Faker.fake(); + let env = TestEnvironment::create().await; + let producer = env + .env + .producer() + .client_provided_name(&client_provided_name) + .on_closed(Box::new(Foo { on_closed_sender })) + .build(&env.stream) + .await + .unwrap(); + + let connection = wait_for_named_connection(client_provided_name).await; + + let (sender, receiver) = tokio::sync::oneshot::channel(); + + let join_handler = tokio::spawn(async move { + let mut sender = Some(sender); + for i in 0..100 { + if i == 2 { + if let Some(sender) = sender.take() { + // Simulate a delay to ensure the producer is closed before sending more messages + sender.send(()).unwrap(); + } + } + + let message = Message::builder().body(format!("{}", i)).build(); + if producer.send(message, |_| async {}).await.is_err() { + break; + } + } + }); + + receiver.await.unwrap(); + drop_connection(connection).await; + + // Wait for the above task ends + join_handler.await.unwrap(); + + let unconfirmed = on_closed_receiver.recv().await.unwrap(); + + // Some messages shouldn't be confirmed + assert!(!unconfirmed.is_empty()); + + // Check that the unconfirmed messages are in order + for couple in unconfirmed.windows(2) { + let first = couple[0].data().expect("First message should have data"); + let second = couple[1].data().expect("Second message should have data"); + let first_value = + String::from_utf8(first.to_vec()).expect("First message should be valid UTF-8"); + let second_value = + String::from_utf8(second.to_vec()).expect("Second message should be valid UTF-8"); + let first_number: u32 = first_value + .parse() + .expect("First message should be a number"); + let second_number: u32 = second_value + .parse() + .expect("Second message should be a number"); + + assert!(first_number < second_number, "Messages should be in order"); + } +}