diff --git a/services/tunnelbroker/src/amqp.rs b/services/tunnelbroker/src/amqp.rs
--- a/services/tunnelbroker/src/amqp.rs
+++ b/services/tunnelbroker/src/amqp.rs
@@ -5,7 +5,7 @@
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::{Arc, RwLock};
 use std::time::Duration;
-use tracing::{debug, error, info};
+use tracing::{debug, error, info, warn};
 
 use crate::constants::{error_types, NUM_AMQP_CHANNELS};
 use crate::CONFIG;
@@ -139,6 +139,9 @@
 #[derive(Clone)]
 pub struct AmqpConnection {
   inner: Arc<RwLock<ConnectionInner>>,
+  /// Set while one task is re-establishing the connection, so that
+  /// concurrent callers wait instead of starting their own reconnect.
+  is_connecting: Arc<AtomicBool>,
 }
 
 impl AmqpConnection {
@@ -147,8 +150,86 @@
     let inner = Arc::new(RwLock::new(conn));
 
     info!("Connected to AMQP endpoint: {}", &CONFIG.amqp_uri);
-    Ok(Self { inner })
+    let is_connecting = AtomicBool::new(false);
+    Ok(Self {
+      inner,
+      is_connecting: Arc::new(is_connecting),
+    })
+  }
+
+  /// Returns a channel for `id_hash` from the current connection.
+  ///
+  /// If the connection is reported as down, a reconnect is attempted first.
+  ///
+  /// # Errors
+  /// Returns the error from a failed reconnect attempt or from the channel
+  /// lookup itself.
+  pub async fn channel(
+    &self,
+    id_hash: impl std::hash::Hash,
+  ) -> Result<lapin::Channel, lapin::Error> {
+    if !self.is_connected() {
+      warn!("AMQP disconnected while retrieving channel");
+      self.reset_conn().await?;
+    }
+    self.inner.read().unwrap().get_channel(id_hash)
+  }
+
+  /// Re-establishes the AMQP connection, letting only one task do the work.
+  ///
+  /// The task that wins the `is_connecting` flag opens a new connection and
+  /// swaps it in; every other caller polls until the flag clears and then
+  /// checks whether the connection is usable.
+  ///
+  /// # Errors
+  /// Returns the connection error if this task's attempt fails, or
+  /// `InvalidConnectionState` if another task's attempt left the connection
+  /// down.
+  async fn reset_conn(&self) -> Result<(), lapin::Error> {
+    // Clears the `is_connecting` flag when dropped.
+    struct ConnectingGuard<'a>(&'a AtomicBool);
+    impl Drop for ConnectingGuard<'_> {
+      fn drop(&mut self) {
+        self.0.store(false, Ordering::Relaxed);
+      }
+    }
+
+    let won_flag = self
+      .is_connecting
+      .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
+      .is_ok();
+    if won_flag {
+      // The guard clears the flag on every exit path, including the `?`
+      // below and cancellation of this future. Without it a failed
+      // reconnect would leave the flag set and waiters polling forever.
+      let _guard = ConnectingGuard(&self.is_connecting);
+      debug!("Resetting connection...");
+      let new_conn = ConnectionInner::new().await?;
+      let mut inner = self.inner.write().unwrap();
+      if !inner.is_connected() {
+        *inner = new_conn;
+        info!("AMQP Connection restored.");
+      }
+    } else {
+      debug!("Already resetting on other thread");
+      while self.is_connecting() {
+        tokio::time::sleep(Duration::from_millis(10)).await;
+      }
+      if !self.is_connected() {
+        // other thread failed to reset
+        let state = self.inner.read().unwrap().raw().status().state();
+        warn!("Other thread failed to reset. State: {:?}", state);
+        return Err(lapin::Error::InvalidConnectionState(state));
+      }
+    }
+    Ok(())
+  }
+
+  /// Reports whether some task is currently re-establishing the connection.
+  fn is_connecting(&self) -> bool {
+    self.is_connecting.load(Ordering::Relaxed)
   }
+
   fn is_connected(&self) -> bool {
     self.inner.read().unwrap().is_connected()
   }