diff --git a/services/tunnelbroker/src/amqp.rs b/services/tunnelbroker/src/amqp_client/amqp.rs
rename from services/tunnelbroker/src/amqp.rs
rename to services/tunnelbroker/src/amqp_client/amqp.rs
diff --git a/services/tunnelbroker/src/amqp_client/mod.rs b/services/tunnelbroker/src/amqp_client/mod.rs
new file mode 100644
--- /dev/null
+++ b/services/tunnelbroker/src/amqp_client/mod.rs
@@ -0,0 +1,299 @@
+//! AMQP (RabbitMQ) client bound to a single device's WebSocket session.
+//!
+//! The client declares a queue named after the device ID, flushes messages
+//! persisted in the database into it, and consumes deliveries from it.
+
+use crate::amqp_client::amqp::{is_connection_error, AmqpConnection};
+use crate::constants::{
+  error_types, CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY,
+  MAX_RMQ_MSG_PRIORITY, RMQ_CONSUMER_TAG,
+};
+use crate::database::{DatabaseClient, MessageToDeviceExt};
+use crate::websockets::session::{DeviceInfo, SessionError};
+use futures_util::StreamExt;
+use lapin::message::Delivery;
+use lapin::options::{
+  BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions,
+  QueueDeclareOptions, QueueDeleteOptions,
+};
+use lapin::types::FieldTable;
+use lapin::BasicProperties;
+use tracing::{debug, error, warn};
+use tunnelbroker_messages::{MessageToDevice, MessageToDeviceRequest};
+
+pub mod amqp;
+
+/// AMQP state owned by a single WebSocket session.
+///
+/// Holds the channel and consumer for the queue named after
+/// `device_info.device_id`, plus what is needed to re-create them.
+pub struct AmqpClient {
+  db_client: DatabaseClient,
+  device_info: DeviceInfo,
+  amqp: AmqpConnection,
+  amqp_channel: lapin::Channel,
+  // Stream of messages from this device's queue
+  amqp_consumer: lapin::Consumer,
+}
+
+/// Re-publishes every message persisted for `device_info` to its queue with
+/// `DDB_RMQ_MSG_PRIORITY`.
+///
+/// A failure to read from the database is logged and treated as "no
+/// messages"; a message that fails to decode, serialize or publish aborts the
+/// flush with an error.
+async fn publish_persisted_messages(
+  db_client: &DatabaseClient,
+  amqp_channel: &lapin::Channel,
+  device_info: &DeviceInfo,
+) -> Result<(), SessionError> {
+  let messages = db_client
+    .retrieve_messages(&device_info.device_id)
+    .await
+    .unwrap_or_else(|e| {
+      error!(
+        errorType = error_types::DDB_ERROR,
+        "Error while retrieving messages: {}", e
+      );
+      Vec::new()
+    });
+
+  for message in messages {
+    let message_to_device = MessageToDevice::from_hashmap(message)?;
+
+    let serialized_message = serde_json::to_string(&message_to_device)?;
+
+    amqp_channel
+      .basic_publish(
+        "",
+        &message_to_device.device_id,
+        BasicPublishOptions::default(),
+        serialized_message.as_bytes(),
+        BasicProperties::default().with_priority(DDB_RMQ_MSG_PRIORITY),
+      )
+      .await?;
+  }
+
+  debug!("Flushed messages for device: {}", &device_info.device_id);
+  Ok(())
+}
+
+impl AmqpClient {
+  /// Opens a channel, declares the device queue, flushes persisted messages
+  /// into it and starts consuming from it.
+  pub async fn new(
+    db_client: DatabaseClient,
+    device_info: DeviceInfo,
+    amqp: AmqpConnection,
+  ) -> Result<Self, SessionError> {
+    let (amqp_channel, amqp_consumer) =
+      Self::init_amqp(&device_info, &db_client, &amqp).await?;
+
+    Ok(Self {
+      db_client,
+      device_info,
+      amqp,
+      amqp_channel,
+      amqp_consumer,
+    })
+  }
+
+  /// Opens a new channel on `amqp`, declares the device's priority queue,
+  /// flushes persisted messages into it and registers a `no_ack` consumer.
+  async fn init_amqp(
+    device_info: &DeviceInfo,
+    db_client: &DatabaseClient,
+    amqp: &AmqpConnection,
+  ) -> Result<(lapin::Channel, lapin::Consumer), SessionError> {
+    let amqp_channel = amqp.new_channel().await?;
+    debug!(
+      "Got AMQP Channel Id={} for device '{}'",
+      amqp_channel.id(),
+      device_info.device_id
+    );
+
+    let mut args = FieldTable::default();
+    args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into());
+    amqp_channel
+      .queue_declare(
+        &device_info.device_id,
+        QueueDeclareOptions::default(),
+        args,
+      )
+      .await?;
+
+    publish_persisted_messages(db_client, &amqp_channel, device_info).await?;
+
+    let amqp_consumer = amqp_channel
+      .basic_consume(
+        &device_info.device_id,
+        RMQ_CONSUMER_TAG,
+        BasicConsumeOptions {
+          no_ack: true,
+          ..Default::default()
+        },
+        FieldTable::default(),
+      )
+      .await?;
+    Ok((amqp_channel, amqp_consumer))
+  }
+
+  /// Returns `true` when the AMQP channel is no longer connected.
+  fn is_amqp_channel_dead(&self) -> bool {
+    !self.amqp_channel.status().connected()
+  }
+
+  /// Re-creates the channel and consumer unless both are still healthy.
+  pub async fn reset_failed_amqp(&mut self) -> Result<(), SessionError> {
+    if self.amqp_channel.status().connected()
+      && self.amqp_consumer.state().is_active()
+    {
+      return Ok(());
+    }
+    debug!(
+      "Resetting failed amqp for session with {}",
+      &self.device_info.device_id
+    );
+
+    let (amqp_channel, amqp_consumer) =
+      Self::init_amqp(&self.device_info, &self.db_client, &self.amqp).await?;
+
+    self.amqp_channel = amqp_channel;
+    self.amqp_consumer = amqp_consumer;
+
+    Ok(())
+  }
+
+  /// Publishes `payload` to the queue named `device_id` with
+  /// `CLIENT_RMQ_MSG_PRIORITY`, resetting the channel first if it is dead.
+  async fn publish_amqp_message_to_device(
+    &mut self,
+    device_id: &str,
+    payload: &[u8],
+  ) -> Result<lapin::publisher_confirm::PublisherConfirm, SessionError> {
+    if self.is_amqp_channel_dead() {
+      self.reset_failed_amqp().await?;
+    }
+    let publish_result = self
+      .amqp_channel
+      .basic_publish(
+        "",
+        device_id,
+        BasicPublishOptions::default(),
+        payload,
+        BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY),
+      )
+      .await?;
+    Ok(publish_result)
+  }
+
+  /// Waits for the next delivery from this device's queue; returns `None`
+  /// once the consumer stream has ended.
+  pub async fn next_amqp_message(
+    &mut self,
+  ) -> Option<Result<Delivery, lapin::Error>> {
+    self.amqp_consumer.next().await
+  }
+
+  /// Cancels the consumer and deletes this device's queue.
+  ///
+  /// If the channel is already dead, only a background reconnect is
+  /// triggered. Connection errors during teardown are ignored; any other
+  /// error is logged.
+  pub async fn close_connection(&mut self) {
+    if self.is_amqp_channel_dead() {
+      warn!("AMQP channel or connection dead when closing WS session.");
+      self.amqp.maybe_reconnect_in_background();
+      return;
+    }
+    if let Err(e) = self
+      .amqp_channel
+      .basic_cancel(
+        self.amqp_consumer.tag().as_str(),
+        BasicCancelOptions::default(),
+      )
+      .await
+    {
+      if !is_connection_error(&e) {
+        error!(
+          errorType = error_types::AMQP_ERROR,
+          "Failed to cancel consumer: {}", e
+        );
+      }
+    }
+
+    if let Err(e) = self
+      .amqp_channel
+      .queue_delete(
+        self.device_info.device_id.as_str(),
+        QueueDeleteOptions::default(),
+      )
+      .await
+    {
+      if !is_connection_error(&e) {
+        error!(
+          errorType = error_types::AMQP_ERROR,
+          "Failed to delete queue: {}", e
+        );
+      }
+    }
+  }
+
+  /// Persists `message_request` and publishes it to the recipient's queue.
+  ///
+  /// If publishing fails, the persisted copy is deleted again (so it is not
+  /// re-delivered later by `publish_persisted_messages`) and the publish
+  /// error is returned.
+  ///
+  /// # Errors
+  ///
+  /// Returns an error if persisting, serializing or publishing fails.
+  pub async fn handle_message_to_device(
+    &mut self,
+    message_request: &MessageToDeviceRequest,
+  ) -> Result<(), SessionError> {
+    let message_id = self
+      .db_client
+      .persist_message(
+        &message_request.device_id,
+        &message_request.payload,
+        &message_request.client_message_id,
+      )
+      .await?;
+
+    let message_to_device = MessageToDevice {
+      device_id: message_request.device_id.clone(),
+      payload: message_request.payload.clone(),
+      message_id: message_id.clone(),
+    };
+
+    let serialized_message = serde_json::to_string(&message_to_device)?;
+
+    let publish_result = self
+      .publish_amqp_message_to_device(
+        &message_request.device_id,
+        serialized_message.as_bytes(),
+      )
+      .await;
+
+    if let Err(amqp_session_error) = publish_result {
+      // Roll back the persisted copy. It was stored under the recipient's
+      // device ID (`message_request.device_id`), which need not be this
+      // session's device ID, so the delete must use that same key.
+      if let Err(e) = self
+        .db_client
+        .delete_message(&message_request.device_id, &message_id)
+        .await
+      {
+        // A failed rollback must not take down the session; log it and
+        // still report the original publish error to the caller.
+        error!(
+          errorType = error_types::DDB_ERROR,
+          "Failed to delete message after AMQP publish failure: {:?}", e
+        );
+      }
+      return Err(amqp_session_error);
+    }
+    Ok(())
+  }
+}
diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs
--- a/services/tunnelbroker/src/grpc/mod.rs
+++ b/services/tunnelbroker/src/grpc/mod.rs
@@ -11,7 +11,7 @@
 use tracing::debug;
 use tunnelbroker_messages::MessageToDevice;
 
-use crate::amqp::{AmqpChannel, AmqpConnection};
+use crate::amqp_client::amqp::{AmqpChannel, AmqpConnection};
 use crate::constants::{CLIENT_RMQ_MSG_PRIORITY, WS_SESSION_CLOSE_AMQP_MSG};
 use crate::database::{handle_ddb_error, DatabaseClient};
 use crate::{constants, CONFIG};
diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs
--- a/services/tunnelbroker/src/main.rs
+++ b/services/tunnelbroker/src/main.rs
@@ -1,4 +1,4 @@
-pub mod amqp;
+pub mod amqp_client;
 pub mod config;
 pub mod constants;
 pub mod database;
@@ -13,6 +13,7 @@
 use crate::notifs::web_push::WebPushClient;
 use crate::notifs::wns::WNSClient;
 use crate::notifs::NotifClient;
+use amqp_client::amqp;
 use anyhow::{anyhow, Result};
 use config::CONFIG;
 use constants::{error_types, COMM_SERVICES_USE_JSON_LOGS};
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -1,6 +1,6 @@
 pub mod session;
 
-use crate::amqp::AmqpConnection;
+use crate::amqp_client::amqp::AmqpConnection;
 use crate::constants::{SOCKET_HEARTBEAT_TIMEOUT, WS_SESSION_CLOSE_AMQP_MSG};
 use crate::database::DatabaseClient;
 use crate::notifs::NotifClient;
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -1,40 +1,30 @@
-use crate::amqp::{is_connection_error, AmqpConnection};
-use crate::constants::{ - error_types, CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY, - MAX_RMQ_MSG_PRIORITY, RMQ_CONSUMER_TAG, -}; +use crate::amqp_client::amqp::AmqpConnection; +use crate::constants::error_types; use comm_lib::aws::ddb::error::SdkError; use comm_lib::aws::ddb::operation::put_item::PutItemError; use derive_more; use futures_util::stream::SplitSink; use futures_util::SinkExt; -use futures_util::StreamExt; use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use lapin::message::Delivery; -use lapin::options::{ - BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions, - QueueDeclareOptions, QueueDeleteOptions, -}; -use lapin::types::FieldTable; -use lapin::BasicProperties; use notifs::fcm::error::Error::FCMError as NotifsFCMError; use notifs::web_push::error::Error::WebPush as NotifsWebPushError; use notifs::wns::error::Error::WNSNotification as NotifsWNSError; use reqwest::Url; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; -use tracing::{debug, error, info, trace, warn}; +use tracing::{debug, error, info, trace}; use tunnelbroker_messages::bad_device_token::BadDeviceToken; use tunnelbroker_messages::{ - message_to_device_request_status::Failure, message_to_device_request_status::MessageSentStatus, session::DeviceTypes, - DeviceToTunnelbrokerMessage, Heartbeat, MessageToDevice, - MessageToDeviceRequest, MessageToTunnelbroker, + DeviceToTunnelbrokerMessage, Heartbeat, MessageToDeviceRequest, + MessageToTunnelbroker, }; use tunnelbroker_messages::{DeviceToTunnelbrokerRequestStatus, Platform}; use web_push::WebPushError; -use crate::database::{self, DatabaseClient, MessageToDeviceExt}; +use crate::amqp_client::AmqpClient; +use crate::database::{self, DatabaseClient}; use crate::notifs::apns::headers::NotificationHeaders; use crate::notifs::apns::APNsNotif; use crate::notifs::fcm::firebase_message::{ @@ -45,6 +35,7 @@ use crate::notifs::{apns, NotifClient, NotifClientType}; use crate::{identity, notifs}; +#[derive(Clone)] pub 
struct DeviceInfo { pub device_id: String, pub notify_token: Option, @@ -58,10 +49,8 @@ tx: SplitSink, Message>, db_client: DatabaseClient, pub device_info: DeviceInfo, - amqp: AmqpConnection, - amqp_channel: lapin::Channel, - // Stream of messages from AMQP endpoint - amqp_consumer: lapin::Consumer, + // Each websocket has an AMQP connection associated with a particular device + amqp_client: AmqpClient, notif_client: NotifClient, } @@ -162,42 +151,6 @@ } } -async fn publish_persisted_messages( - db_client: &DatabaseClient, - amqp_channel: &lapin::Channel, - device_info: &DeviceInfo, -) -> Result<(), SessionError> { - let messages = db_client - .retrieve_messages(&device_info.device_id) - .await - .unwrap_or_else(|e| { - error!( - errorType = error_types::DDB_ERROR, - "Error while retrieving messages: {}", e - ); - Vec::new() - }); - - for message in messages { - let message_to_device = MessageToDevice::from_hashmap(message)?; - - let serialized_message = serde_json::to_string(&message_to_device)?; - - amqp_channel - .basic_publish( - "", - &message_to_device.device_id, - BasicPublishOptions::default(), - serialized_message.as_bytes(), - BasicProperties::default().with_priority(DDB_RMQ_MSG_PRIORITY), - ) - .await?; - } - - debug!("Flushed messages for device: {}", &device_info.device_id); - Ok(()) -} - pub async fn get_device_info_from_frame( frame: Message, ) -> Result { @@ -221,9 +174,10 @@ amqp: AmqpConnection, notif_client: NotifClient, ) -> Result> { - let (amqp_channel, amqp_consumer) = - match Self::init_amqp(&device_info, &db_client, &amqp).await { - Ok(consumer) => consumer, + let amqp_client = + match AmqpClient::new(db_client.clone(), device_info.clone(), amqp).await + { + Ok(client) => client, Err(err) => return Err((err, tx)), }; @@ -231,133 +185,13 @@ tx, db_client, device_info, - amqp, - amqp_channel, - amqp_consumer, + amqp_client, notif_client, }) } - async fn init_amqp( - device_info: &DeviceInfo, - db_client: &DatabaseClient, - amqp: 
&AmqpConnection, - ) -> Result<(lapin::Channel, lapin::Consumer), SessionError> { - let amqp_channel = amqp.new_channel().await?; - debug!( - "Got AMQP Channel Id={} for device '{}'", - amqp_channel.id(), - device_info.device_id - ); - - let mut args = FieldTable::default(); - args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into()); - amqp_channel - .queue_declare( - &device_info.device_id, - QueueDeclareOptions::default(), - args, - ) - .await?; - - publish_persisted_messages(db_client, &amqp_channel, device_info).await?; - - let amqp_consumer = amqp_channel - .basic_consume( - &device_info.device_id, - RMQ_CONSUMER_TAG, - BasicConsumeOptions { - no_ack: true, - ..Default::default() - }, - FieldTable::default(), - ) - .await?; - Ok((amqp_channel, amqp_consumer)) - } - - fn is_amqp_channel_dead(&self) -> bool { - !self.amqp_channel.status().connected() - } - - async fn publish_amqp_message_to_device( - &mut self, - device_id: &str, - payload: &[u8], - ) -> Result { - if self.is_amqp_channel_dead() { - self.reset_failed_amqp().await?; - } - let publish_result = self - .amqp_channel - .basic_publish( - "", - device_id, - BasicPublishOptions::default(), - payload, - BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY), - ) - .await?; - Ok(publish_result) - } - pub async fn reset_failed_amqp(&mut self) -> Result<(), SessionError> { - if self.amqp_channel.status().connected() - && self.amqp_consumer.state().is_active() - { - return Ok(()); - } - debug!( - "Resetting failed amqp for session with {}", - &self.device_info.device_id - ); - - let (amqp_channel, amqp_consumer) = - Self::init_amqp(&self.device_info, &self.db_client, &self.amqp).await?; - - self.amqp_channel = amqp_channel; - self.amqp_consumer = amqp_consumer; - - Ok(()) - } - - pub async fn handle_message_to_device( - &mut self, - message_request: &MessageToDeviceRequest, - ) -> Result<(), SessionError> { - let message_id = self - .db_client - .persist_message( - 
&message_request.device_id, - &message_request.payload, - &message_request.client_message_id, - ) - .await?; - - let message_to_device = MessageToDevice { - device_id: message_request.device_id.clone(), - payload: message_request.payload.clone(), - message_id: message_id.clone(), - }; - - let serialized_message = serde_json::to_string(&message_to_device)?; - - let publish_result = self - .publish_amqp_message_to_device( - &message_request.device_id, - serialized_message.as_bytes(), - ) - .await; - - if let Err(amqp_session_error) = publish_result { - self - .db_client - .delete_message(&self.device_info.device_id, &message_id) - .await - .expect("Error deleting message"); - return Err(amqp_session_error); - } - Ok(()) + self.amqp_client.reset_failed_amqp().await } pub async fn handle_message_to_tunnelbroker( @@ -473,7 +307,10 @@ } debug!("Received message for {}", message_request.device_id); - let result = self.handle_message_to_device(&message_request).await; + let result = self + .amqp_client + .handle_message_to_device(&message_request) + .await; Some(MessageSentStatus::from_result( &message_request.client_message_id, result, @@ -766,7 +603,7 @@ pub async fn next_amqp_message( &mut self, ) -> Option> { - self.amqp_consumer.next().await + self.amqp_client.next_amqp_message().await } pub async fn send_message_to_device(&mut self, message: Message) { @@ -788,43 +625,7 @@ debug!("Failed to close WebSocket session: {}", e); } - if self.is_amqp_channel_dead() { - warn!("AMQP channel or connection dead when closing WS session."); - self.amqp.maybe_reconnect_in_background(); - return; - } - - if let Err(e) = self - .amqp_channel - .basic_cancel( - self.amqp_consumer.tag().as_str(), - BasicCancelOptions::default(), - ) - .await - { - if !is_connection_error(&e) { - error!( - errorType = error_types::AMQP_ERROR, - "Failed to cancel consumer: {}", e - ); - } - } - - if let Err(e) = self - .amqp_channel - .queue_delete( - self.device_info.device_id.as_str(), - 
QueueDeleteOptions::default(), - ) - .await - { - if !is_connection_error(&e) { - error!( - errorType = error_types::AMQP_ERROR, - "Failed to delete queue: {}", e - ); - } - } + self.amqp_client.close_connection().await; } async fn get_device_token( @@ -868,7 +669,10 @@ payload, }; - self.handle_message_to_device(&message_request).await?; + self + .amqp_client + .handle_message_to_device(&message_request) + .await?; self .db_client diff --git a/shared/tunnelbroker_messages/src/messages/session.rs b/shared/tunnelbroker_messages/src/messages/session.rs --- a/shared/tunnelbroker_messages/src/messages/session.rs +++ b/shared/tunnelbroker_messages/src/messages/session.rs @@ -18,7 +18,7 @@ /// messages to device /// - Tunnelbroker then polls for incoming messages from device -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] #[serde(rename_all = "camelCase")] pub enum DeviceTypes { Mobile,