10000 OnClosed callback for producer by allevo · Pull Request #293 · rabbitmq/rabbitmq-stream-rust-client · GitHub
[go: up one dir, main page]

Skip to content

OnClosed callback for producer #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions examples/ha_producer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use rabbitmq_stream_client::error::{ProducerPublishError, StreamCreateError};
use rabbitmq_stream_client::types::{ByteCapacity, Message, ResponseCode};
use rabbitmq_stream_client::Environment;
use rabbitmq_stream_client::{
ConfirmationStatus, NoDedup, OnClosed, Producer, RabbitMQStreamResult,
};
use tokio::sync::RwLock;

struct MyHAProducer {
environment: Environment,
stream: String,
producer: RwLock<Producer<NoDedup>>,
}

#[async_trait::async_trait]
impl OnClosed for MyHAProducer {
async fn on_closed(&self, unconfirmed: Vec<Message>) {
let mut producer = self.producer.write().await;

let new_producer = self
.environment
.producer()
.build(&self.stream)
.await
.unwrap();

if let Err(e) = producer.batch_send_with_confirm(unconfirmed).await {
eprintln!("Error resending unconfirmed messages: {:?}", e);
}

*producer = new_producer;
}
}

impl MyHAProducer {
async fn new(environment: Environment, stream: &str) -> RabbitMQStreamResult<Self> {
ensure_stream_exists(&environment, stream).await?;

let producer = environment.producer().build(stream).await.unwrap();

Ok(Self {
environment,
stream: stream.to_string(),
producer: RwLock::new(producer),
})
}

async fn send_with_confirm(
&self,
message: Message,
) -> Result<ConfirmationStatus, ProducerPublishError> {
let producer = self.producer.read().await;
producer.send_with_confirm(message).await
}
}

async fn ensure_stream_exists(environment: &Environment, stream: &str) -> RabbitMQStreamResult<()> {
let create_response = environment
.stream_creator()
.max_length(ByteCapacity::GB(5))
.create(stream)
.await;

if let Err(e) = create_response {
if let StreamCreateError::Create { stream, status } = e {
match status {
// we can ignore this error because the stream already exists
ResponseCode::StreamAlreadyExists => {}
err => {
panic!("Error creating stream: {:?} {:?}", stream, err);
}
}
}
}

Ok(())
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let environment = Environment::builder().build().await?;
let stream = "hello-rust-stream";

let producer = MyHAProducer::new(environment, stream).await?;

producer
.send_with_confirm(Message::builder().body("Hello, world!").build())
.await?;

/*
let number_of_messages = 1000000;
for i in 0..number_of_messages {
let msg = Message::builder()
.body(format!("stream message_{}", i))
.build();
producer.send_with_confirm(msg).await?;
}
producer.close().await?;
*/

Ok(())
}
32 changes: 16 additions & 16 deletions src/client/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,6 @@ use super::{channel::ChannelReceiver, handler::MessageHandler};
#[derive(Clone)]
pub(crate) struct Dispatcher<T>(DispatcherState<T>);

pub(crate) struct DispatcherState<T> {
requests: Arc<RequestsMap>,
correlation_id: Arc<AtomicU32>,
handler: Arc<RwLock<Option<T>>>,
}

impl<T> Clone for DispatcherState<T> {
fn clone(&self) -> Self {
DispatcherState {
requests: self.requests.clone(),
correlation_id: self.correlation_id.clone(),
handler: self.handler.clone(),
}
}
}

struct RequestsMap {
requests: DashMap<u32, Sender<Response>>,
closed: AtomicBool,
Expand Down Expand Up @@ -126,6 +110,22 @@ where
}
}

pub(crate) struct DispatcherState<T> {
requests: Arc<RequestsMap>,
correlation_id: Arc<AtomicU32>,
handler: Arc<RwLock<Option<T>>>,
}

impl<T> Clone for DispatcherState<T> {
fn clone(&self) -> Self {
DispatcherState {
requests: self.requests.clone(),
correlation_id: self.correlation_id.clone(),
handler: self.handler.clone(),
}
}
}

impl<T> DispatcherState<T>
where
T: MessageHandler,
Expand Down
6 changes: 5 additions & 1 deletion src/client/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl BaseMessage for Message {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ClientMessage {
publishing_id: u64,
message: Message,
Expand All @@ -39,6 +39,10 @@ impl ClientMessage {
pub fn filter_value_extract(&mut self, filter_value_extractor: impl Fn(&Message) -> String) {
self.filter_value = Some(filter_value_extractor(&self.message));
}

pub fn into_message(self) -> Message {
self.message
}
}

impl BaseMessage for ClientMessage {
Expand Down
143 changes: 97 additions & 46 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::ops::DerefMut;
use std::{
collections::HashMap,
io,
Expand Down Expand Up @@ -142,42 +141,7 @@
max_frame_size: u32,
last_heatbeat: Instant,
heartbeat_task: Option<task::TaskHandle>,
}

#[async_trait::async_trait]
impl MessageHandler for Client {
async fn handle_message(&self, item: MessageResult) -> RabbitMQStreamResult<()> {
match &item {
Some(Ok(response)) => match response.kind_ref() {
ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await,
ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await,
_ => {
if let Some(handler) = self.state.read().await.handler.as_ref() {
let handler = handler.clone();

tokio::task::spawn(async move { handler.handle_message(item).await });
}
}
},
Some(Err(err)) => {
trace!(?err);
if let Some(handler) = self.state.read().await.handler.as_ref() {
let handler = handler.clone();

tokio::task::spawn(async move { handler.handle_message(item).await });
}
}
None => {
trace!("Closing client");
if let Some(handler) = self.state.read().await.handler.as_ref() {
let handler = handler.clone();
tokio::task::spawn(async move { handler.handle_message(None).await });
}
}
}

Ok(())
}
last_received_message: Arc<RwLock<Instant>>,
}

/// Raw API for taking to RabbitMQ stream
Expand All @@ -201,8 +165,9 @@

let (sender, receiver) = Client::create_connection(&broker).await?;

let dispatcher = Dispatcher::new();
let last_received_message = Arc::new(RwLock::new(Instant::now()));

let dispatcher = Dispatcher::new();
let state = ClientState {
server_properties: HashMap::new(),
connection_properties: HashMap::new(),
Expand All @@ -211,6 +176,7 @@
max_frame_size: broker.max_frame_size,
last_heatbeat: Instant::now(),
heartbeat_task: None,
last_received_message: last_received_message.clone(),
};
let mut client = Client {
dispatcher,
Expand Down Expand Up @@ -519,6 +485,15 @@
self.filtering_supported
}

/// Don't use this method in production code.
pub async fn set_heartbeat(&self, heartbeat: u32) {
let mut state = self.state.write().await;
state.heartbeat = heartbeat;
// Eventually, this drops the previous heartbeat task
state.heartbeat_task =
self.start_hearbeat_task(heartbeat, state.last_received_message.clone());
}

async fn create_connection(
broker: &ClientOptions,
) -> Result<
Expand All @@ -536,6 +511,7 @@

Ok((tx, rx))
}

async fn initialize<T>(&mut self, receiver: ChannelReceiver<T>) -> Result<(), ClientError>
where
T: Stream<Item = Result<Response, ClientError>> + Unpin + Send,
Expand All @@ -558,7 +534,10 @@
.await?;

// Start heartbeat task after connection is established
self.start_hearbeat_task(self.state.write().await.deref_mut());
let mut state = self.state.write().await;
state.heartbeat_task =
self.start_hearbeat_task(state.heartbeat, state.last_received_message.clone());
drop(state);

Ok(())
}
Expand Down Expand Up @@ -697,7 +676,9 @@
);

if state.heartbeat_task.take().is_some() {
self.start_hearbeat_task(&mut state);
// Start heartbeat task after connection is established
state.heartbeat_task =
self.start_hearbeat_task(state.heartbeat, state.last_received_message.clone());

Check warning on line 681 in src/client/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/client/mod.rs#L679-L681

Added lines #L679 - L681 were not covered by tests
}

drop(state);
Expand All @@ -710,13 +691,22 @@
self.tune_notifier.notify_one();
}

fn start_hearbeat_task(&self, state: &mut ClientState) {
if state.heartbeat == 0 {
return;
fn start_hearbeat_task(
&self,
heartbeat: u32,
last_received_message: Arc<RwLock<Instant>>,
) -> Option<task::TaskHandle> {
if heartbeat == 0 {
return None;

Check warning on line 700 in src/client/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/client/mod.rs#L700

Added line #L700 was not covered by tests
}
let heartbeat_interval = (state.heartbeat / 2).max(1);
let heartbeat_interval = (heartbeat / 2).max(1);
let channel = self.channel.clone();
let heartbeat_task = tokio::spawn(async move {

let client = self.clone();

let heartbeat_task: task::TaskHandle = tokio::spawn(async move {
let timeout_threashold = u64::from(heartbeat * 4);

loop {
trace!("Sending heartbeat");
if channel
Expand All @@ -727,11 +717,25 @@
break;
}
tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await;

let now = Instant::now();
let last_message = last_received_message.read().await;
if now.duration_since(*last_message) >= Duration::from_secs(timeout_threashold) {
warn!("Heartbeat timeout reached. Force closing connection.");
if !client.is_closed() {
if let Err(e) = client.close().await {
warn!("Error closing client: {}", e);

Check warning on line 727 in src/client/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/client/mod.rs#L727

Added line #L727 was not covered by tests
}
}

Check warning on line 729 in src/client/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/client/mod.rs#L729

Added line #L729 was not covered by tests
break;
}
}

warn!("Heartbeat task stopped. Force closing connection");
})
.into();
state.heartbeat_task = Some(heartbeat_task);

Some(heartbeat_task)
}

async fn handle_heart_beat_command(&self) {
Expand All @@ -751,3 +755,50 @@
.await
}
}

#[async_trait::async_trait]
impl MessageHandler for Client {
async fn handle_message(&self, item: MessageResult) -> RabbitMQStreamResult<()> {
match &item {
Some(Ok(response)) => {
// Update last received message time: needed for heartbeat task
{
let s = self.state.read().await;
let mut last_received_message = s.last_received_message.write().await;
*last_received_message = Instant::now();
drop(last_received_message);
drop(s);
}

match response.kind_ref() {
ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await,
ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await,
_ => {
if let Some(handler) = self.state.read().await.handler.as_ref() {
let handler = handler.clone();

tokio::task::spawn(async move { handler.handle_message(item).await });
}
}
}
}
Some(Err(err)) => {
trace!(?err);
if let Some(handler) = self.state.read().await.handler.as_ref() {
let handler = handler.clone();

tokio::task::spawn(async move { handler.handle_message(item).await });
}

Check warning on line 791 in src/client/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/client/mod.rs#L785-L791

Added lines #L785 - L791 were not covered by tests
}
None => {
trace!("Closing client");
if let Some(handler) = self.state.read().await.handler.as_ref() {
let handler = handler.clone();
tokio::task::spawn(async move { handler.handle_message(None).await });
}
}
}

Ok(())
}
}
2 changes: 2 additions & 0 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ impl Environment {
data: PhantomData,
filter_value_extractor: None,
client_provided_name: String::from("rust-stream-producer"),
on_closed: None,
overwrite_heartbeat: None,
}
}

Expand Down
Loading
Loading
0