diff --git a/services/tunnelbroker/src/amqp.rs b/services/tunnelbroker/src/amqp.rs --- a/services/tunnelbroker/src/amqp.rs +++ b/services/tunnelbroker/src/amqp.rs @@ -1,10 +1,13 @@ use comm_lib::database::batch_operations::ExponentialBackoffConfig; use lapin::{uri::AMQPUri, Connection, ConnectionProperties}; use once_cell::sync::Lazy; +use std::hash::Hasher; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, RwLock}; use std::time::Duration; -use tracing::info; +use tracing::{debug, error, info, warn}; -use crate::constants::error_types; +use crate::constants::{error_types, NUM_AMQP_CHANNELS}; use crate::CONFIG; static AMQP_URI: Lazy = Lazy::new(|| { @@ -24,45 +27,187 @@ amqp_uri }); -pub async fn connect() -> Connection { +async fn create_connection() -> Result { let options = ConnectionProperties::default() .with_executor(tokio_executor_trait::Tokio::current()) .with_reactor(tokio_reactor_trait::Tokio); let retry_config = ExponentialBackoffConfig { - max_attempts: 5, + max_attempts: 8, base_duration: Duration::from_millis(500), ..Default::default() }; let mut retry_counter = retry_config.new_counter(); tracing::debug!("Attempting to connect to AMQP..."); - let conn_result = loop { + loop { let amqp_uri = Lazy::force(&AMQP_URI).clone(); match lapin::Connection::connect_uri(amqp_uri, options.clone()).await { - Ok(conn) => break Ok(conn), + Ok(conn) => return Ok(conn), Err(err) => { let attempt = retry_counter.attempt(); tracing::warn!(attempt, "AMQP connection attempt failed: {err}."); if retry_counter.sleep_and_retry().await.is_err() { tracing::error!("Unable to connect to AMQP: {err}"); - break Err(err); + return Err(err); } } } - }; + } +} + +/// Inner connection that is a direct wrapper over lapin::Connection +/// This should be instantiated only once +struct ConnectionInner { + conn: lapin::Connection, + // channel pool + channels: [lapin::Channel; NUM_AMQP_CHANNELS], +} + +impl ConnectionInner { + async fn new() -> Result { + let conn = create_connection().await?; + conn.on_error(|err| { + // TODO: we should filter out some IOErrors here to avoid spamming alerts + error!(errorType = error_types::AMQP_ERROR, "Lapin error: {err:?}"); + }); + + debug!("Creating channels..."); + let mut channels = Vec::with_capacity(NUM_AMQP_CHANNELS); + for idx in 0..NUM_AMQP_CHANNELS { + let channel = conn.create_channel().await?; + tracing::trace!("Creating channel ID={} at index={}", channel.id(), idx); + channels.push(channel); + } + + Ok(Self { + conn, + channels: channels + .try_into() + .expect("Channels vec size doesn't match array size"), + }) + } + + pub fn get_channel( + &self, + id_hash: impl std::hash::Hash, + ) -> Result { + // We have channel pool and want to distribute them between connected + // devices. Round robin would work too, but by using "hash modulo N" + // we make sure the same device will always use the same channel. + // Generally this shouldn't matter, but helps avoiding potential issues + // with the same queue name being declared by different channels, + // in case of reconnection. + let mut hasher = std::hash::DefaultHasher::new(); + id_hash.hash(&mut hasher); + let channel_idx: usize = hasher.finish() as usize % NUM_AMQP_CHANNELS; + + let channel = self.channels[channel_idx].clone(); + let channel_id = channel.id(); + tracing::trace!(channel_id, channel_idx, "Retrieving AMQP Channel"); + Ok(channel) + } + + fn is_connected(&self) -> bool { + self.conn.status().connected() + } + + fn raw(&self) -> &lapin::Connection { + &self.conn + } +} + +/// Thread safe connection wrapper that is Clone + Send + Sync +/// and can be shared wherever needed +#[derive(Clone)] +pub struct AmqpConnection { + inner: Arc>, + is_connecting: Arc, +} - let conn = conn_result.expect("Unable to connect to AMQP. Exiting."); - conn.on_error(|error| { - tracing::error!( - errorType = error_types::AMQP_ERROR, - "Lapin error: {error:?}" - ); - }); +impl AmqpConnection { + pub async fn connect() -> Result { + let is_connecting = AtomicBool::new(true); + let conn = ConnectionInner::new().await?; + let inner = Arc::new(RwLock::new(conn)); + is_connecting.store(false, Ordering::Relaxed); + info!("Connected to AMQP endpoint: {}", &CONFIG.amqp_uri); + Ok(Self { + inner, + is_connecting: Arc::new(is_connecting), + }) + } + + pub async fn channel( + &self, + id_hash: impl std::hash::Hash, + ) -> Result { + if !self.is_connected() { + warn!("AMQP disconnected while retrieving channel"); + self.reset_conn().await?; + } + self.inner.read().unwrap().get_channel(id_hash) + } + + async fn reset_conn(&self) -> Result<(), lapin::Error> { + if let Ok(false) = self.is_connecting.compare_exchange( + false, + true, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + debug!("Resetting connection..."); + let new_conn = ConnectionInner::new().await?; + let mut inner = self.inner.write().unwrap(); + if !inner.is_connected() { + *inner = new_conn; + info!("AMQP Connection restored."); + } + + self.is_connecting.store(false, Ordering::Relaxed); + } else { + debug!("Already resetting on other thread"); + while self.is_connecting() { + tokio::time::sleep(Duration::from_millis(10)).await; + } + if !self.is_connected() { + // other thread failed to reset + let state = self.inner.read().unwrap().raw().status().state(); + warn!("Other thread failed to reset. State: {:?}", state); + return Err(lapin::Error::InvalidConnectionState(state)); + } + } + Ok(()) + } + + /// Triggers reconnecting in background, without awaiting + pub fn trigger_reconnect(&self) { + if !self.is_connected() && !self.is_connecting() { + let this = self.clone(); + tokio::spawn(async move { + if let Err(err) = this.reset_conn().await { + tracing::warn!("AMQP background reconnect failed: {:?}", err); + } + }); + } + } + + fn is_connecting(&self) -> bool { + self.is_connecting.load(Ordering::Relaxed) + } + + fn is_connected(&self) -> bool { + self.inner.read().unwrap().is_connected() + } +} - info!("Connected to AMQP endpoint: {}", &CONFIG.amqp_uri); - conn +pub fn is_connection_error(err: &lapin::Error) -> bool { + matches!( + err, + lapin::Error::InvalidChannelState(_) + | lapin::Error::InvalidConnectionState(_) + ) } fn from_env(var_name: &str) -> Option { diff --git a/services/tunnelbroker/src/constants.rs b/services/tunnelbroker/src/constants.rs --- a/services/tunnelbroker/src/constants.rs +++ b/services/tunnelbroker/src/constants.rs @@ -7,6 +7,7 @@ pub const SOCKET_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(3); +pub const NUM_AMQP_CHANNELS: usize = 8; pub const MAX_RMQ_MSG_PRIORITY: u8 = 10; pub const DDB_RMQ_MSG_PRIORITY: u8 = 10; pub const CLIENT_RMQ_MSG_PRIORITY: u8 = 1; diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs --- a/services/tunnelbroker/src/grpc/mod.rs +++ b/services/tunnelbroker/src/grpc/mod.rs @@ -11,13 +11,14 @@ use tracing::debug; use tunnelbroker_messages::MessageToDevice; +use crate::amqp::AmqpConnection; use crate::constants::{CLIENT_RMQ_MSG_PRIORITY, WS_SESSION_CLOSE_AMQP_MSG}; use crate::database::{handle_ddb_error, DatabaseClient}; use crate::{constants, CONFIG}; struct TunnelbrokerGRPC { client: DatabaseClient, - amqp_channel: lapin::Channel, + amqp: AmqpConnection, } pub fn handle_amqp_error(error: lapin::Error) -> tonic::Status { @@ -57,7 +58,10 @@ .map_err(|_| tonic::Status::invalid_argument("Invalid argument"))?; self - .amqp_channel + .amqp + .channel("grpc") + .await + .map_err(handle_amqp_error)? .basic_publish( "", &message.device_id, @@ -81,7 +85,10 @@ debug!("Connection close request for device {}", &message.device_id); self - .amqp_channel + .amqp + .channel("grpc") + .await + .map_err(handle_amqp_error)? .basic_publish( "", &message.device_id, @@ -122,24 +129,19 @@ pub async fn run_server( client: DatabaseClient, - ampq_connection: &lapin::Connection, + amqp_connection: &AmqpConnection, ) -> Result<(), tonic::transport::Error> { let addr = format!("[::]:{}", CONFIG.grpc_port) .parse() .expect("Unable to parse gRPC address"); - let amqp_channel = ampq_connection - .create_channel() - .await - .expect("Unable to create amqp channel"); - tracing::info!("gRPC server listening on {}", &addr); Server::builder() .http2_keepalive_interval(Some(constants::GRPC_KEEP_ALIVE_PING_INTERVAL)) .http2_keepalive_timeout(Some(constants::GRPC_KEEP_ALIVE_PING_TIMEOUT)) .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC { client, - amqp_channel, + amqp: amqp_connection.clone(), })) .serve(addr) .await diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs --- a/services/tunnelbroker/src/main.rs +++ b/services/tunnelbroker/src/main.rs @@ -48,7 +48,9 @@ config::parse_cmdline_args()?; let aws_config = config::load_aws_config().await; let db_client = database::DatabaseClient::new(&aws_config); - let amqp_connection = amqp::connect().await; + let amqp_connection = amqp::AmqpConnection::connect() + .await + .expect("Failed to create AMQP connection"); let apns_config = CONFIG.apns_config.clone(); diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,9 +1,10 @@ pub mod session; +use crate::amqp::AmqpConnection; use crate::constants::{SOCKET_HEARTBEAT_TIMEOUT, WS_SESSION_CLOSE_AMQP_MSG}; use crate::database::DatabaseClient; use crate::notifs::NotifClient; -use crate::websockets::session::{initialize_amqp, SessionError}; +use crate::websockets::session::{handle_first_ws_frame, SessionError}; use crate::CONFIG; use futures_util::stream::SplitSink; use futures_util::{SinkExt, StreamExt}; @@ -18,7 +19,7 @@ use std::pin::Pin; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; -use tracing::{debug, error, info, trace}; +use tracing::{debug, error, info, trace, warn}; use tunnelbroker_messages::{ ConnectionInitializationStatus, DeviceToTunnelbrokerRequestStatus, Heartbeat, MessageSentStatus, @@ -39,7 +40,7 @@ /// It also handles regular HTTP requests (currently health check) struct WebsocketService { addr: SocketAddr, - channel: lapin::Channel, + amqp: AmqpConnection, db_client: DatabaseClient, notif_client: NotifClient, } @@ -62,7 +63,7 @@ fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let db_client = self.db_client.clone(); - let channel = self.channel.clone(); + let amqp = self.amqp.clone(); let notif_client = self.notif_client.clone(); let future = async move { @@ -72,7 +73,7 @@ // Spawn a task to handle the websocket connection. tokio::spawn(async move { - accept_connection(websocket, addr, db_client, channel, notif_client) + accept_connection(websocket, addr, db_client, amqp, notif_client) .await; }); @@ -101,7 +102,7 @@ pub async fn run_server( db_client: DatabaseClient, - amqp_connection: &lapin::Connection, + amqp_connection: &AmqpConnection, notif_client: NotifClient, ) -> Result<(), BoxedError> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") @@ -115,15 +116,12 @@ http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { - let channel = amqp_connection - .create_channel() - .await - .expect("Failed to create AMQP channel"); + let amqp = amqp_connection.clone(); let connection = http .serve_connection( stream, WebsocketService { - channel, + amqp, db_client: db_client.clone(), addr, notif_client: notif_client.clone(), @@ -169,7 +167,7 @@ hyper_ws: HyperWebsocket, addr: SocketAddr, db_client: DatabaseClient, - amqp_channel: lapin::Channel, + amqp_connection: AmqpConnection, notif_client: NotifClient, ) { debug!("Incoming connection from: {}", addr); @@ -194,7 +192,7 @@ outgoing, first_msg, db_client, - amqp_channel, + amqp_connection, notif_client, ) .await @@ -231,16 +229,31 @@ loop { trace!("Polling for messages from: {}", addr); tokio::select! { - Some(Ok(delivery)) = session.next_amqp_message() => { - if let Ok(message) = std::str::from_utf8(&delivery.data) { - if message == WS_SESSION_CLOSE_AMQP_MSG { - debug!("Connection to {} closed by server.", addr); + Some(delivery_result) = session.next_amqp_message() => { + match delivery_result { + Ok(delivery) => { + if let Ok(message) = std::str::from_utf8(&delivery.data) { + if message == WS_SESSION_CLOSE_AMQP_MSG { + debug!("Connection to {} closed by server.", addr); + break; + } else { + session.send_message_to_device(Message::Text(message.to_string())).await; + } + } else { + error!("Invalid payload"); + } + }, + Err(ref err) if crate::amqp::is_connection_error(err) => { + if let Err(e) = session.reset_failed_amqp().await { + warn!("Connection to {} closed due to failed AMQP restoration: {:?}", addr, e); + break; + } + continue; + } + Err(err) => { + warn!("Connection to {} closed due to AMQP error: {:?}", addr, err); break; - } else { - session.send_message_to_device(Message::Text(message.to_string())).await; } - } else { - error!("Invalid payload"); } }, device_message = incoming.next() => { @@ -316,21 +329,14 @@ outgoing: SplitSink, Message>, frame: Message, db_client: DatabaseClient, - amqp_channel: lapin::Channel, + amqp: AmqpConnection, notif_client: NotifClient, ) -> Result, ErrorWithStreamHandle> { - let initialized_session = - initialize_amqp(db_client.clone(), frame, &amqp_channel).await; + let device_info = match handle_first_ws_frame(frame).await { + Ok(info) => info, + Err(e) => return Err((e, outgoing)), + }; - match initialized_session { - Ok((device_info, amqp_consumer)) => Ok(WebsocketSession::new( - outgoing, - db_client, - device_info, - amqp_channel, - amqp_consumer, - notif_client, - )), - Err(e) => Err((e, outgoing)), - } + WebsocketSession::new(outgoing, db_client, device_info, amqp, notif_client) + .await } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,3 +1,4 @@ +use crate::amqp::{is_connection_error, AmqpConnection}; use crate::constants::{ error_types, CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY, MAX_RMQ_MSG_PRIORITY, RMQ_CONSUMER_TAG, @@ -24,7 +25,7 @@ use reqwest::Url; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; -use tracing::{debug, error, info, trace}; +use tracing::{debug, error, info, trace, warn}; use tunnelbroker_messages::bad_device_token::BadDeviceToken; use tunnelbroker_messages::Platform; use tunnelbroker_messages::{ @@ -48,6 +49,7 @@ use crate::notifs::{apns, NotifClient, NotifClientType}; use crate::{identity, notifs}; +#[derive(Clone)] pub struct DeviceInfo { pub device_id: String, pub notify_token: Option, @@ -61,6 +63,7 @@ tx: SplitSink, Message>, db_client: DatabaseClient, pub device_info: DeviceInfo, + amqp: AmqpConnection, amqp_channel: lapin::Channel, // Stream of messages from AMQP endpoint amqp_consumer: lapin::Consumer, @@ -200,11 +203,9 @@ Ok(()) } -pub async fn initialize_amqp( - db_client: DatabaseClient, +pub async fn handle_first_ws_frame( frame: Message, - amqp_channel: &lapin::Channel, -) -> Result<(DeviceInfo, lapin::Consumer), SessionError> { +) -> Result { let device_info = match frame { Message::Text(payload) => { handle_first_message_from_device(&payload).await? @@ -214,42 +215,93 @@ return Err(SessionError::InvalidMessage); } }; - - let mut args = FieldTable::default(); - args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into()); - amqp_channel - .queue_declare(&device_info.device_id, QueueDeclareOptions::default(), args) - .await?; - - publish_persisted_messages(&db_client, amqp_channel, &device_info).await?; - - let amqp_consumer = amqp_channel - .basic_consume( - &device_info.device_id, - RMQ_CONSUMER_TAG, - BasicConsumeOptions::default(), - FieldTable::default(), - ) - .await?; - Ok((device_info, amqp_consumer)) + Ok(device_info) } impl WebsocketSession { - pub fn new( + pub async fn new( tx: SplitSink, Message>, db_client: DatabaseClient, device_info: DeviceInfo, - amqp_channel: lapin::Channel, - amqp_consumer: lapin::Consumer, + amqp: AmqpConnection, notif_client: NotifClient, - ) -> Self { - Self { + ) -> Result> { + let (amqp_channel, amqp_consumer) = + match Self::init_amqp(&device_info, &db_client, &amqp).await { + Ok(consumer) => consumer, + Err(err) => return Err((err, tx)), + }; + + Ok(Self { tx, db_client, device_info, + amqp, amqp_channel, amqp_consumer, notif_client, + }) + } + + async fn init_amqp( + device_info: &DeviceInfo, + db_client: &DatabaseClient, + amqp: &AmqpConnection, + ) -> Result<(lapin::Channel, lapin::Consumer), SessionError> { + let amqp_channel = amqp.channel(&device_info.device_id).await?; + debug!( + "Got AMQP Channel Id={} for device '{}'", + amqp_channel.id(), + device_info.device_id + ); + + let mut args = FieldTable::default(); + args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into()); + amqp_channel + .queue_declare( + &device_info.device_id, + QueueDeclareOptions::default(), + args, + ) + .await?; + + publish_persisted_messages(db_client, &amqp_channel, device_info).await?; + + // cancel previous consumer. If not done, Rabbit yells that + // "trying to reuse tag" and closes channels. + if let Err(e) = amqp_channel + .basic_cancel(RMQ_CONSUMER_TAG, BasicCancelOptions::default()) + .await + { + warn!( + errorType = error_types::AMQP_ERROR, + "Failed to cancel previous consumer: {}", e + ); } + + let amqp_consumer = amqp_channel + .basic_consume( + &device_info.device_id, + RMQ_CONSUMER_TAG, + BasicConsumeOptions::default(), + FieldTable::default(), + ) + .await?; + Ok((amqp_channel, amqp_consumer)) + } + + pub async fn reset_failed_amqp(&mut self) -> Result<(), SessionError> { + debug!( + "Resetting failed amqp for session with {}", + &self.device_info.device_id + ); + + let (amqp_channel, amqp_consumer) = + Self::init_amqp(&self.device_info, &self.db_client, &self.amqp).await?; + + self.amqp_channel = amqp_channel; + self.amqp_consumer = amqp_consumer; + + Ok(()) } pub async fn handle_message_to_device( @@ -704,10 +756,15 @@ ) .await { - error!( - errorType = error_types::AMQP_ERROR, - "Failed to cancel consumer: {}", e - ); + if is_connection_error(&e) { + warn!("AMQP connection dead when closing WS session."); + self.amqp.trigger_reconnect(); + } else { + error!( + errorType = error_types::AMQP_ERROR, + "Failed to cancel consumer: {}", e + ); + } } if let Err(e) = self @@ -718,10 +775,15 @@ ) .await { - error!( - errorType = error_types::AMQP_ERROR, - "Failed to delete queue: {}", e - ); + if is_connection_error(&e) { + warn!("AMQP connection dead when closing WS session."); + self.amqp.trigger_reconnect(); + } else { + error!( + errorType = error_types::AMQP_ERROR, + "Failed to delete queue: {}", e + ); + } } } diff --git a/shared/tunnelbroker_messages/src/messages/session.rs b/shared/tunnelbroker_messages/src/messages/session.rs --- a/shared/tunnelbroker_messages/src/messages/session.rs +++ b/shared/tunnelbroker_messages/src/messages/session.rs @@ -18,7 +18,7 @@ /// messages to device /// - Tunnelbroker then polls for incoming messages from device -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)] #[serde(rename_all = "camelCase")] pub enum DeviceTypes { Mobile,