diff --git a/services/tunnelbroker/src/constants.rs b/services/tunnelbroker/src/constants.rs index ca3b163bd..b24c12d64 100644 --- a/services/tunnelbroker/src/constants.rs +++ b/services/tunnelbroker/src/constants.rs @@ -1,64 +1,78 @@ use tokio::time::Duration; pub const GRPC_TX_QUEUE_SIZE: usize = 32; pub const GRPC_SERVER_PORT: u16 = 50051; pub const GRPC_KEEP_ALIVE_PING_INTERVAL: Duration = Duration::from_secs(3); pub const GRPC_KEEP_ALIVE_PING_TIMEOUT: Duration = Duration::from_secs(10); pub const SOCKET_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(3); pub const MAX_RMQ_MSG_PRIORITY: u8 = 10; pub const DDB_RMQ_MSG_PRIORITY: u8 = 10; pub const CLIENT_RMQ_MSG_PRIORITY: u8 = 1; pub const RMQ_CONSUMER_TAG: &str = "tunnelbroker"; pub const WS_SESSION_CLOSE_AMQP_MSG: &str = "SessionClose"; pub const ENV_APNS_CONFIG: &str = "APNS_CONFIG"; pub const ENV_FCM_CONFIG: &str = "FCM_CONFIG"; pub const ENV_WEB_PUSH_CONFIG: &str = "WEB_PUSH_CONFIG"; pub const ENV_WNS_CONFIG: &str = "WNS_CONFIG"; pub const COMM_SERVICES_USE_JSON_LOGS: &str = "COMM_SERVICES_USE_JSON_LOGS"; pub const LOG_LEVEL_ENV_VAR: &str = tracing_subscriber::filter::EnvFilter::DEFAULT_ENV; pub const FCM_ACCESS_TOKEN_GENERATION_THRESHOLD: u64 = 5 * 60; pub const PUSH_SERVICE_REQUEST_TIMEOUT: Duration = Duration::from_secs(8); pub mod dynamodb { // This table holds messages which could not be immediately delivered to // a device. // // - (primary key) = (deviceID: Partition Key, createdAt: Sort Key) // - deviceID: The public key of a device's olm identity key // - payload: Message to be delivered. See shared/tunnelbroker_messages. // - messageID = [createdAt]#[clientMessageID] // - createdAd: UNIX timestamp of when the item was inserted. // Timestamp is needed to order the messages correctly to the device. // Timestamp format is ISO 8601 to handle lexicographical sorting. // - clientMessageID: Message ID generated on client using UUID Version 4. pub mod undelivered_messages { pub const TABLE_NAME: &str = "tunnelbroker-undelivered-messages"; pub const PARTITION_KEY: &str = "deviceID"; pub const DEVICE_ID: &str = "deviceID"; pub const PAYLOAD: &str = "payload"; pub const MESSAGE_ID: &str = "messageID"; pub const SORT_KEY: &str = "messageID"; } // This table holds a device token associated with a device. // // - (primary key) = (deviceID: Partition Key) // - deviceID: The public key of a device's olm identity key. // - deviceToken: Token to push services uploaded by device. // - tokenInvalid: Information is token is invalid. pub mod device_tokens { pub const TABLE_NAME: &str = "tunnelbroker-device-tokens"; pub const PARTITION_KEY: &str = "deviceID"; pub const DEVICE_ID: &str = "deviceID"; pub const DEVICE_TOKEN: &str = "deviceToken"; pub const TOKEN_INVALID: &str = "tokenInvalid"; pub const PLATFORM: &str = "platform"; pub const DEVICE_TOKEN_INDEX_NAME: &str = "deviceToken-index"; } } + +// Log Error Types + +pub mod error_types { + pub const AMQP_ERROR: &str = "AMQP Error"; + pub const DDB_ERROR: &str = "DDB Error"; + pub const FCM_ERROR: &str = "FCM Error"; + pub const APNS_ERROR: &str = "APNs Error"; + pub const WEB_PUSH_ERROR: &str = "Web Push Error"; + pub const WNS_ERROR: &str = "WNS Error"; + pub const IDENTITY_ERROR: &str = "Identity Error"; + pub const WEBSOCKET_ERROR: &str = "Websocket Error"; + pub const SERVER_ERROR: &str = "Server Error"; +} diff --git a/services/tunnelbroker/src/database/message.rs b/services/tunnelbroker/src/database/message.rs index bae00611b..dfea56d70 100644 --- a/services/tunnelbroker/src/database/message.rs +++ b/services/tunnelbroker/src/database/message.rs @@ -1,40 +1,45 @@ use comm_lib::database::{AttributeExtractor, AttributeMap, DBItemError}; use tunnelbroker_messages::MessageToDevice; use crate::constants::dynamodb::undelivered_messages::{ DEVICE_ID, MESSAGE_ID, PAYLOAD, }; +use crate::constants::error_types; #[derive(Debug, derive_more::Display, derive_more::Error)] pub enum MessageErrors { SerializationError, } impl From for MessageErrors { fn from(err: DBItemError) -> Self { - tracing::error!("Failed to extract MessageToDevice attribute: {:?}", err); + tracing::error!( + errorType = error_types::DDB_ERROR, + "Failed to extract MessageToDevice attribute: {:?}", + err + ); MessageErrors::SerializationError } } pub trait MessageToDeviceExt { fn from_hashmap( hashmap: AttributeMap, ) -> Result; } impl MessageToDeviceExt for MessageToDevice { fn from_hashmap( mut hashmap: AttributeMap, ) -> Result { let device_id: String = hashmap.take_attr(DEVICE_ID)?; let message_id: String = hashmap.take_attr(MESSAGE_ID)?; let payload: String = hashmap.take_attr(PAYLOAD)?; Ok(MessageToDevice { device_id, message_id, payload, }) } } diff --git a/services/tunnelbroker/src/database/mod.rs b/services/tunnelbroker/src/database/mod.rs index dd78d45ca..efc0fe8b5 100644 --- a/services/tunnelbroker/src/database/mod.rs +++ b/services/tunnelbroker/src/database/mod.rs @@ -1,316 +1,329 @@ use comm_lib::aws::ddb::error::SdkError; use comm_lib::aws::ddb::operation::delete_item::{ DeleteItemError, DeleteItemOutput, }; use comm_lib::aws::ddb::operation::put_item::PutItemError; use comm_lib::aws::ddb::operation::query::QueryError; use comm_lib::aws::ddb::types::AttributeValue; use comm_lib::aws::{AwsConfig, DynamoDBClient}; use comm_lib::database::{ AttributeExtractor, AttributeMap, DBItemAttributeError, DBItemError, Error, }; use std::collections::HashMap; use std::sync::Arc; use tracing::{debug, error, warn}; use crate::constants::dynamodb::{device_tokens, undelivered_messages}; +use crate::constants::error_types; pub mod message; pub mod message_id; use crate::database::message_id::MessageID; pub use message::*; use std::str::FromStr; use tunnelbroker_messages::Platform; #[derive(Clone)] pub struct DatabaseClient { client: Arc, } pub fn handle_ddb_error(db_error: SdkError) -> tonic::Status { match db_error { SdkError::TimeoutError(_) | SdkError::ServiceError(_) => { tonic::Status::unavailable("please retry") } e => { - error!("Encountered an unexpected error: {}", e); + error!( + errorType = error_types::DDB_ERROR, + "Encountered an unexpected error: {}", e + ); tonic::Status::failed_precondition("unexpected error") } } } pub struct DeviceTokenEntry { pub device_token: String, pub token_invalid: bool, pub platform: Option, } impl DatabaseClient { pub fn new(aws_config: &AwsConfig) -> Self { let client = DynamoDBClient::new(aws_config); DatabaseClient { client: Arc::new(client), } } pub async fn persist_message( &self, device_id: &str, payload: &str, client_message_id: &str, ) -> Result> { let message_id: String = MessageID::new(client_message_id.to_string()).into(); let device_av = AttributeValue::S(device_id.to_string()); let payload_av = AttributeValue::S(payload.to_string()); let message_id_av = AttributeValue::S(message_id.clone()); let request = self .client .put_item() .table_name(undelivered_messages::TABLE_NAME) .item(undelivered_messages::PARTITION_KEY, device_av) .item(undelivered_messages::SORT_KEY, message_id_av) .item(undelivered_messages::PAYLOAD, payload_av); debug!("Persisting message to device: {}", &device_id); request.send().await?; Ok(message_id) } pub async fn retrieve_messages( &self, device_id: &str, ) -> Result, SdkError> { debug!("Retrieving messages for device: {}", device_id); let response = self .client .query() .table_name(undelivered_messages::TABLE_NAME) .key_condition_expression(format!( "{} = :u", undelivered_messages::PARTITION_KEY )) .expression_attribute_values( ":u", AttributeValue::S(device_id.to_string()), ) .consistent_read(true) .send() .await?; debug!("Retrieved {} messages for {}", response.count, device_id); match response.items { None => Ok(Vec::new()), Some(items) => Ok(items.to_vec()), } } pub async fn delete_message( &self, device_id: &str, message_id: &str, ) -> Result> { debug!("Deleting message for device: {}", device_id); let key = HashMap::from([ ( undelivered_messages::PARTITION_KEY.to_string(), AttributeValue::S(device_id.to_string()), ), ( undelivered_messages::SORT_KEY.to_string(), AttributeValue::S(message_id.to_string()), ), ]); self .client .delete_item() .table_name(undelivered_messages::TABLE_NAME) .set_key(Some(key)) .send() .await } pub async fn remove_device_token( &self, device_id: &str, ) -> Result<(), Error> { debug!("Removing device token for device: {}", &device_id); let device_av = AttributeValue::S(device_id.to_string()); self .client .delete_item() .table_name(device_tokens::TABLE_NAME) .key(device_tokens::PARTITION_KEY, device_av) .send() .await .map_err(|e| { - error!("DynamoDB client failed to remove device token: {:?}", e); + error!( + errorType = error_types::DDB_ERROR, + "DynamoDB client failed to remove device token: {:?}", e + ); Error::AwsSdk(e.into()) })?; Ok(()) } pub async fn get_device_token( &self, device_id: &str, ) -> Result, Error> { let get_response = self .client .get_item() .table_name(device_tokens::TABLE_NAME) .key( device_tokens::PARTITION_KEY, AttributeValue::S(device_id.into()), ) .send() .await .map_err(|e| { - error!("DynamoDB client failed to get device token"); + error!( + errorType = error_types::DDB_ERROR, + "DynamoDB client failed to get device token" + ); Error::AwsSdk(e.into()) })?; let Some(mut item) = get_response.item else { return Ok(None); }; let device_token: String = item.take_attr(device_tokens::DEVICE_TOKEN)?; let token_invalid: Option = item.take_attr(device_tokens::TOKEN_INVALID)?; let platform = if let Some(platform_str) = item.take_attr::>(device_tokens::PLATFORM)? { Some(Platform::from_str(&platform_str).map_err(|_| { DBItemError::new( device_tokens::TOKEN_INVALID.to_string(), platform_str.clone().into(), DBItemAttributeError::InvalidValue, ) })?) } else { None }; Ok(Some(DeviceTokenEntry { device_token, token_invalid: token_invalid.unwrap_or(false), platform, })) } pub async fn set_device_token( &self, device_id: &str, device_token: &str, platform: Option, ) -> Result<(), Error> { debug!("Setting device token for device: {}", &device_id); let query_response = self .client .query() .table_name(device_tokens::TABLE_NAME) .index_name(device_tokens::DEVICE_TOKEN_INDEX_NAME) .key_condition_expression("#device_token = :token") .expression_attribute_names("#device_token", device_tokens::DEVICE_TOKEN) .expression_attribute_values( ":token", AttributeValue::S(device_token.to_string()), ) .send() .await .map_err(|e| { error!( - "DynamoDB client failed to find existing device token {:?}", - e + errorType = error_types::DDB_ERROR, + "DynamoDB client failed to find existing device token {:?}", e ); Error::AwsSdk(e.into()) })?; if let Some(existing_tokens) = query_response.items { if existing_tokens.len() > 1 { warn!("Found the same token for multiple devices!"); debug!("Duplicated token is: {device_token}. Removing..."); } else if !existing_tokens.is_empty() { debug!( "Device token {device_token} already exists. It will be replaced..." ); } for mut item in existing_tokens { let found_device_id = item.take_attr::(device_tokens::DEVICE_ID)?; // PutItem will replace token with `device_id` key anyway. if found_device_id != device_id { self.remove_device_token(&found_device_id).await?; } } } let mut put_item_input = self .client .put_item() .table_name(device_tokens::TABLE_NAME) .item( device_tokens::PARTITION_KEY, AttributeValue::S(device_id.to_string()), ) .item( device_tokens::DEVICE_TOKEN, AttributeValue::S(device_token.to_string()), ); if let Some(platform_atr) = platform { put_item_input = put_item_input.item( device_tokens::PLATFORM, AttributeValue::S(platform_atr.to_string()), ); } put_item_input.send().await.map_err(|e| { - error!("DynamoDB client failed to set device token {:?}", e); + error!( + errorType = error_types::DDB_ERROR, + "DynamoDB client failed to set device token {:?}", e + ); Error::AwsSdk(e.into()) })?; Ok(()) } pub async fn mark_device_token_as_invalid( &self, device_id: &str, ) -> Result<(), Error> { let update_expression = format!("SET {0} = :val", device_tokens::TOKEN_INVALID); self .client .update_item() .table_name(device_tokens::TABLE_NAME) .key( device_tokens::DEVICE_ID, AttributeValue::S(device_id.to_string()), ) .update_expression(update_expression) .expression_attribute_values(":val", AttributeValue::Bool(true)) .send() .await .map_err(|e| { error!( - "DynamoDB client failed to mark device token as invalid {:?}", - e + errorType = error_types::DDB_ERROR, + "DynamoDB client failed to mark device token as invalid {:?}", e ); Error::AwsSdk(e.into()) })?; Ok(()) } } diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs index b5fec5296..f8a1e5689 100644 --- a/services/tunnelbroker/src/main.rs +++ b/services/tunnelbroker/src/main.rs @@ -1,148 +1,166 @@ pub mod amqp; pub mod config; pub mod constants; pub mod database; pub mod error; pub mod grpc; pub mod identity; pub mod notifs; pub mod websockets; use crate::notifs::apns::APNsClient; use crate::notifs::fcm::FCMClient; use crate::notifs::web_push::WebPushClient; use crate::notifs::wns::WNSClient; use crate::notifs::NotifClient; use anyhow::{anyhow, Result}; use config::CONFIG; -use constants::COMM_SERVICES_USE_JSON_LOGS; +use constants::{error_types, COMM_SERVICES_USE_JSON_LOGS}; use std::env; use tracing::{self, error, info, Level}; use tracing_subscriber::EnvFilter; #[tokio::main] async fn main() -> Result<()> { let use_json_logs: bool = env::var(COMM_SERVICES_USE_JSON_LOGS) .unwrap_or("false".to_string()) .parse() .unwrap_or_default(); let filter = EnvFilter::builder() .with_default_directive(Level::INFO.into()) .with_env_var(constants::LOG_LEVEL_ENV_VAR) .from_env_lossy(); if use_json_logs { let subscriber = tracing_subscriber::fmt() .json() .with_env_filter(filter) .finish(); tracing::subscriber::set_global_default(subscriber) .expect("Unable to configure tracing"); } else { let subscriber = tracing_subscriber::fmt().with_env_filter(filter).finish(); tracing::subscriber::set_global_default(subscriber) .expect("Unable to configure tracing"); } config::parse_cmdline_args()?; let aws_config = config::load_aws_config().await; let db_client = database::DatabaseClient::new(&aws_config); let amqp_connection = amqp::connect().await; let apns_config = CONFIG.apns_config.clone(); let apns = match apns_config { Some(config) => match APNsClient::new(&config) { Ok(apns_client) => { info!("APNs client created successfully"); Some(apns_client) } Err(err) => { - error!("Error creating APNs client: {}", err); + error!( + errorType = error_types::APNS_ERROR, + "Error creating APNs client: {}", err + ); None } }, None => { - error!("APNs config is missing"); + error!( + errorType = error_types::APNS_ERROR, + "APNs config is missing" + ); None } }; let fcm_config = CONFIG.fcm_config.clone(); let fcm = match fcm_config { Some(config) => match FCMClient::new(&config) { Ok(fcm_client) => { info!("FCM client created successfully"); Some(fcm_client) } Err(err) => { - error!("Error creating FCM client: {}", err); + error!( + errorType = error_types::FCM_ERROR, + "Error creating FCM client: {}", err + ); None } }, None => { - error!("FCM config is missing"); + error!(errorType = error_types::FCM_ERROR, "FCM config is missing"); None } }; let web_push_config = CONFIG.web_push_config.clone(); let web_push = match web_push_config { Some(config) => match WebPushClient::new(&config) { Ok(web_client) => { info!("Web Push client created successfully"); Some(web_client) } Err(err) => { - error!("Error creating Web Push client: {}", err); + error!( + errorType = error_types::WEB_PUSH_ERROR, + "Error creating Web Push client: {}", err + ); None } }, None => { - error!("Web Push config is missing"); + error!( + errorType = error_types::WEB_PUSH_ERROR, + "Web Push config is missing" + ); None } }; let wns_config = CONFIG.wns_config.clone(); let wns = match wns_config { Some(config) => match WNSClient::new(&config) { Ok(wns_client) => { info!("WNS client created successfully"); Some(wns_client) } Err(err) => { - error!("Error creating WNS client: {}", err); + error!( + errorType = error_types::WNS_ERROR, + "Error creating WNS client: {}", err + ); None } }, None => { - error!("WNS config is missing"); + error!(errorType = error_types::WNS_ERROR, "WNS config is missing"); None } }; let notif_client = NotifClient { apns, fcm, web_push, wns, }; let grpc_server = grpc::run_server(db_client.clone(), &amqp_connection); let websocket_server = websockets::run_server( db_client.clone(), &amqp_connection, notif_client.clone(), ); tokio::select! { Ok(_) = grpc_server => { Ok(()) }, Ok(_) = websocket_server => { Ok(()) }, else => { - tracing::error!("A grpc or websocket server crashed."); + tracing::error!(errorType = error_types::SERVER_ERROR, "A grpc or websocket server crashed."); Err(anyhow!("A grpc or websocket server crashed.")) } } } diff --git a/services/tunnelbroker/src/notifs/wns/mod.rs b/services/tunnelbroker/src/notifs/wns/mod.rs index 30fdf3e14..f971a091b 100644 --- a/services/tunnelbroker/src/notifs/wns/mod.rs +++ b/services/tunnelbroker/src/notifs/wns/mod.rs @@ -1,146 +1,151 @@ -use crate::constants::PUSH_SERVICE_REQUEST_TIMEOUT; +use crate::constants::{error_types, PUSH_SERVICE_REQUEST_TIMEOUT}; use crate::notifs::wns::config::WNSConfig; use error::WNSTokenError; use reqwest::StatusCode; use response::WNSErrorResponse; use std::{ sync::{Arc, RwLock}, time::{Duration, SystemTime}, }; pub mod config; pub mod error; pub mod response; #[derive(Debug, Clone)] pub struct WNSAccessToken { token: String, expires: SystemTime, } #[derive(Debug, Clone)] pub struct WNSNotif { pub device_token: String, pub payload: String, } #[derive(Clone)] pub struct WNSClient { http_client: reqwest::Client, config: WNSConfig, access_token: Arc>>, } impl WNSClient { pub fn new(config: &WNSConfig) -> Result { let http_client = reqwest::Client::builder() .timeout(PUSH_SERVICE_REQUEST_TIMEOUT) .build()?; Ok(WNSClient { http_client, config: config.clone(), access_token: Arc::new(RwLock::new(None)), }) } pub async fn send(&self, notif: WNSNotif) -> Result<(), error::Error> { let wns_access_token = self.get_wns_token().await?; let url = notif.device_token; // Send the notification let response = self .http_client .post(&url) .header("Content-Type", "application/octet-stream") .header("X-WNS-Type", "wns/raw") .bearer_auth(wns_access_token) .body(notif.payload) .send() .await?; match response.status() { StatusCode::OK => { tracing::debug!("Successfully sent WNS notif to {}", &url); Ok(()) } error_status => { let body = response .text() .await .unwrap_or_else(|error| format!("Error occurred: {}", error)); tracing::error!( "Failed sending WNS notification to: {}. Status: {}. Body: {}", &url, error_status, body ); let wns_error = WNSErrorResponse::from_status(error_status, body); Err(error::Error::WNSNotification(wns_error)) } } } pub async fn get_wns_token(&self) -> Result { const EXPIRY_WINDOW: Duration = Duration::from_secs(10); { let read_guard = self .access_token .read() .map_err(|_| error::Error::ReadLock)?; if let Some(ref token) = *read_guard { if token.expires >= SystemTime::now() - EXPIRY_WINDOW { return Ok(token.token.clone()); } } } let params = [ ("grant_type", "client_credentials"), ("client_id", &self.config.app_id), ("client_secret", &self.config.secret), ("scope", "https://wns.windows.com/.default"), ]; let url = format!( "https://login.microsoftonline.com/{}/oauth2/v2.0/token", self.config.tenant_id ); let response = self.http_client.post(&url).form(¶ms).send().await?; if !response.status().is_success() { let status = response.status().to_string(); let body = response .text() .await .unwrap_or_else(|_| String::from("")); - tracing::error!(status, "Failure when getting the WNS token: {}", body); + tracing::error!( + errorType = error_types::WNS_ERROR, + status, + "Failure when getting the WNS token: {}", + body + ); return Err(error::Error::WNSToken(WNSTokenError::Unknown(status))); } let response_json: serde_json::Value = response.json().await?; let token = response_json["access_token"] .as_str() .ok_or(error::Error::WNSToken(WNSTokenError::TokenNotFound))? .to_string(); let expires_in = response_json["expires_in"] .as_u64() .ok_or(error::Error::WNSToken(WNSTokenError::ExpiryNotFound))?; let expires = SystemTime::now() + Duration::from_secs(expires_in); { let mut write_guard = self .access_token .write() .map_err(|_| error::Error::WriteLock)?; *write_guard = Some(WNSAccessToken { token: token.clone(), expires, }); } Ok(token) } } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs index 5618e42a2..32dcfc8ca 100644 --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,778 +1,796 @@ use crate::constants::{ - CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY, MAX_RMQ_MSG_PRIORITY, - RMQ_CONSUMER_TAG, + error_types, CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY, + MAX_RMQ_MSG_PRIORITY, RMQ_CONSUMER_TAG, }; use crate::notifs::fcm::response::FCMErrorResponse; use crate::notifs::wns::response::WNSErrorResponse; use comm_lib::aws::ddb::error::SdkError; use comm_lib::aws::ddb::operation::put_item::PutItemError; use derive_more; use futures_util::stream::SplitSink; use futures_util::SinkExt; use futures_util::StreamExt; use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use lapin::message::Delivery; use lapin::options::{ BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions, QueueDeclareOptions, QueueDeleteOptions, }; use lapin::types::FieldTable; use lapin::BasicProperties; use notifs::fcm::error::Error::FCMError as NotifsFCMError; use notifs::web_push::error::Error::WebPush as NotifsWebPushError; use notifs::wns::error::Error::WNSNotification as NotifsWNSError; use reqwest::Url; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tracing::{debug, error, info, trace}; use tunnelbroker_messages::bad_device_token::BadDeviceToken; use tunnelbroker_messages::Platform; use tunnelbroker_messages::{ message_to_device_request_status::Failure, message_to_device_request_status::MessageSentStatus, session::DeviceTypes, DeviceToTunnelbrokerMessage, Heartbeat, MessageToDevice, MessageToDeviceRequest, MessageToTunnelbroker, }; use web_push::WebPushError; use crate::notifs::apns::response::ErrorReason; use crate::database::{self, DatabaseClient, MessageToDeviceExt}; use crate::notifs::apns::headers::NotificationHeaders; use crate::notifs::apns::APNsNotif; use crate::notifs::fcm::firebase_message::{ AndroidConfig, AndroidMessagePriority, FCMMessage, }; use crate::notifs::web_push::WebPushNotif; use crate::notifs::wns::WNSNotif; use crate::notifs::{apns, NotifClient, NotifClientType}; use crate::{identity, notifs}; pub struct DeviceInfo { pub device_id: String, pub notify_token: Option, pub device_type: DeviceTypes, pub device_app_version: Option, pub device_os: Option, pub is_authenticated: bool, } pub struct WebsocketSession { tx: SplitSink, Message>, db_client: DatabaseClient, pub device_info: DeviceInfo, amqp_channel: lapin::Channel, // Stream of messages from AMQP endpoint amqp_consumer: lapin::Consumer, notif_client: NotifClient, } #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum SessionError { InvalidMessage, SerializationError(serde_json::Error), MessageError(database::MessageErrors), AmqpError(lapin::Error), InternalError, UnauthorizedDevice, PersistenceError(SdkError), DatabaseError(comm_lib::database::Error), MissingAPNsClient, MissingFCMClient, MissingWebPushClient, MissingWNSClient, MissingDeviceToken, InvalidDeviceToken, InvalidNotifProvider, InvalidDeviceTokenUpload, } // Parse a session request and retrieve the device information pub async fn handle_first_message_from_device( message: &str, ) -> Result { let serialized_message = serde_json::from_str::(message)?; match serialized_message { DeviceToTunnelbrokerMessage::ConnectionInitializationMessage( mut session_info, ) => { let device_info = DeviceInfo { device_id: session_info.device_id.clone(), notify_token: session_info.notify_token.take(), device_type: session_info.device_type, device_app_version: session_info.device_app_version.take(), device_os: session_info.device_os.take(), is_authenticated: true, }; // Authenticate device debug!("Authenticating device: {}", &session_info.device_id); let auth_request = identity::verify_user_access_token( &session_info.user_id, &device_info.device_id, &session_info.access_token, ) .await; match auth_request { Err(e) => { - error!("Failed to complete request to identity service: {:?}", e); + error!( + errorType = error_types::IDENTITY_ERROR, + "Failed to complete request to identity service: {:?}", e + ); return Err(SessionError::InternalError); } Ok(false) => { info!("Device failed authentication: {}", &session_info.device_id); return Err(SessionError::UnauthorizedDevice); } Ok(true) => { debug!( "Successfully authenticated device: {}", &session_info.device_id ); } } Ok(device_info) } DeviceToTunnelbrokerMessage::AnonymousInitializationMessage( session_info, ) => { debug!( "Starting unauthenticated session with device: {}", &session_info.device_id ); let device_info = DeviceInfo { device_id: session_info.device_id, device_type: session_info.device_type, device_app_version: session_info.device_app_version, device_os: session_info.device_os, is_authenticated: false, notify_token: None, }; Ok(device_info) } _ => { debug!("Received invalid request"); Err(SessionError::InvalidMessage) } } } async fn publish_persisted_messages( db_client: &DatabaseClient, amqp_channel: &lapin::Channel, device_info: &DeviceInfo, ) -> Result<(), SessionError> { let messages = db_client .retrieve_messages(&device_info.device_id) .await .unwrap_or_else(|e| { - error!("Error while retrieving messages: {}", e); + error!( + errorType = error_types::DDB_ERROR, + "Error while retrieving messages: {}", e + ); Vec::new() }); for message in messages { let message_to_device = MessageToDevice::from_hashmap(message)?; let serialized_message = serde_json::to_string(&message_to_device)?; amqp_channel .basic_publish( "", &message_to_device.device_id, BasicPublishOptions::default(), serialized_message.as_bytes(), BasicProperties::default().with_priority(DDB_RMQ_MSG_PRIORITY), ) .await?; } debug!("Flushed messages for device: {}", &device_info.device_id); Ok(()) } pub async fn initialize_amqp( db_client: DatabaseClient, frame: Message, amqp_channel: &lapin::Channel, ) -> Result<(DeviceInfo, lapin::Consumer), SessionError> { let device_info = match frame { Message::Text(payload) => { handle_first_message_from_device(&payload).await? } _ => { error!("Client sent wrong frame type for establishing connection"); return Err(SessionError::InvalidMessage); } }; let mut args = FieldTable::default(); args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into()); amqp_channel .queue_declare(&device_info.device_id, QueueDeclareOptions::default(), args) .await?; publish_persisted_messages(&db_client, amqp_channel, &device_info).await?; let amqp_consumer = amqp_channel .basic_consume( &device_info.device_id, RMQ_CONSUMER_TAG, BasicConsumeOptions::default(), FieldTable::default(), ) .await?; Ok((device_info, amqp_consumer)) } impl WebsocketSession { pub fn new( tx: SplitSink, Message>, db_client: DatabaseClient, device_info: DeviceInfo, amqp_channel: lapin::Channel, amqp_consumer: lapin::Consumer, notif_client: NotifClient, ) -> Self { Self { tx, db_client, device_info, amqp_channel, amqp_consumer, notif_client, } } pub async fn handle_message_to_device( &self, message_request: &MessageToDeviceRequest, ) -> Result<(), SessionError> { let message_id = self .db_client .persist_message( &message_request.device_id, &message_request.payload, &message_request.client_message_id, ) .await?; let message_to_device = MessageToDevice { device_id: message_request.device_id.clone(), payload: message_request.payload.clone(), message_id: message_id.clone(), }; let serialized_message = serde_json::to_string(&message_to_device)?; let publish_result = self .amqp_channel .basic_publish( "", &message_request.device_id, BasicPublishOptions::default(), serialized_message.as_bytes(), BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY), ) .await; if let Err(publish_error) = publish_result { self .db_client .delete_message(&self.device_info.device_id, &message_id) .await .expect("Error deleting message"); return Err(SessionError::AmqpError(publish_error)); } Ok(()) } pub async fn handle_message_to_tunnelbroker( &self, message_to_tunnelbroker: &MessageToTunnelbroker, ) -> Result<(), SessionError> { match message_to_tunnelbroker { MessageToTunnelbroker::SetDeviceToken(token) => { self .db_client .set_device_token( &self.device_info.device_id, &token.device_token, None, ) .await?; } MessageToTunnelbroker::SetDeviceTokenWithPlatform( token_with_platform, ) => { if matches!(token_with_platform.platform, Platform::Windows) { Url::parse(&token_with_platform.device_token) .ok() .filter(|url| { url .domain() .is_some_and(|domain| domain.ends_with("notify.windows.com")) }) .ok_or_else(|| { debug!( device_token = &token_with_platform.device_token, device_id = &self.device_info.device_id, "Invalid Windows device token" ); SessionError::InvalidDeviceTokenUpload })?; } self .db_client .set_device_token( &self.device_info.device_id, &token_with_platform.device_token, Some(token_with_platform.platform.clone()), ) .await?; } } Ok(()) } pub async fn handle_websocket_frame_from_device( &mut self, msg: String, ) -> Option { let Ok(serialized_message) = serde_json::from_str::(&msg) else { return Some(MessageSentStatus::SerializationError(msg)); }; match serialized_message { DeviceToTunnelbrokerMessage::Heartbeat(Heartbeat {}) => { trace!("Received heartbeat from: {}", self.device_info.device_id); None } DeviceToTunnelbrokerMessage::MessageReceiveConfirmation(confirmation) => { for message_id in confirmation.message_ids { if let Err(e) = self .db_client .delete_message(&self.device_info.device_id, &message_id) .await { - error!("Failed to delete message: {}:", e); + error!( + errorType = error_types::DDB_ERROR, + "Failed to delete message: {}:", e + ); } } None } DeviceToTunnelbrokerMessage::MessageToDeviceRequest(message_request) => { // unauthenticated clients cannot send messages if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text message. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received message for {}", message_request.device_id); let result = self.handle_message_to_device(&message_request).await; Some(self.get_message_to_device_status( &message_request.client_message_id, result, )) } DeviceToTunnelbrokerMessage::MessageToTunnelbrokerRequest( message_request, ) => { // unauthenticated clients cannot send messages if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text message. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received message for Tunnelbroker"); let Ok(message_to_tunnelbroker) = serde_json::from_str(&message_request.payload) else { return Some(MessageSentStatus::SerializationError( message_request.payload, )); }; let result = self .handle_message_to_tunnelbroker(&message_to_tunnelbroker) .await; Some(self.get_message_to_device_status( &message_request.client_message_id, result, )) } DeviceToTunnelbrokerMessage::APNsNotif(notif) => { // unauthenticated clients cannot send notifs if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text notif. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received APNs notif for {}", notif.device_id); let Ok(headers) = serde_json::from_str::(¬if.headers) else { return Some(MessageSentStatus::SerializationError(notif.headers)); }; let device_token = match self .get_device_token(notif.device_id.clone(), NotifClientType::APNs) .await { Ok(token) => token, Err(e) => { return Some( self .get_message_to_device_status(¬if.client_message_id, Err(e)), ) } }; let apns_notif = APNsNotif { device_token: device_token.clone(), headers, payload: notif.payload, }; if let Some(apns) = self.notif_client.apns.clone() { let response = apns.send(apns_notif).await; if let Err(apns::error::Error::ResponseError(body)) = &response { if matches!( body.reason, ErrorReason::BadDeviceToken | ErrorReason::Unregistered | ErrorReason::ExpiredToken ) { if let Err(e) = self .invalidate_device_token(notif.device_id, device_token.clone()) .await { error!( - "Error invalidating device token {}: {:?}", - device_token, e + errorType = error_types::DDB_ERROR, + "Error invalidating device token {}: {:?}", device_token, e ); }; } } return Some( self .get_message_to_device_status(¬if.client_message_id, response), ); } Some(self.get_message_to_device_status( ¬if.client_message_id, Err(SessionError::MissingAPNsClient), )) } DeviceToTunnelbrokerMessage::FCMNotif(notif) => { // unauthenticated clients cannot send notifs if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text notif. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received FCM notif for {}", notif.device_id); let Some(priority) = AndroidMessagePriority::from_str(¬if.priority) else { return Some(MessageSentStatus::SerializationError(notif.priority)); }; let Ok(data) = serde_json::from_str(¬if.data) else { return Some(MessageSentStatus::SerializationError(notif.data)); }; let device_token = match self .get_device_token(notif.device_id.clone(), NotifClientType::FCM) .await { Ok(token) => token, Err(e) => { return Some( self .get_message_to_device_status(¬if.client_message_id, Err(e)), ) } }; let fcm_message = FCMMessage { data, token: device_token.to_string(), android: AndroidConfig { priority }, }; if let Some(fcm) = self.notif_client.fcm.clone() { let result = fcm.send(fcm_message).await; if let Err(NotifsFCMError(fcm_error)) = &result { if matches!( fcm_error, FCMErrorResponse::Unregistered | FCMErrorResponse::InvalidArgument(_) ) { if let Err(e) = self .invalidate_device_token(notif.device_id, device_token.clone()) .await { error!( - "Error invalidating device token {}: {:?}", - device_token, e + errorType = error_types::DDB_ERROR, + "Error invalidating device token {}: {:?}", device_token, e ); }; } } return Some( self.get_message_to_device_status(¬if.client_message_id, result), ); } Some(self.get_message_to_device_status( ¬if.client_message_id, Err(SessionError::MissingFCMClient), )) } DeviceToTunnelbrokerMessage::WebPushNotif(notif) => { // unauthenticated clients cannot send notifs if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send web push notif. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received WebPush notif for {}", notif.device_id); let Some(web_push_client) = self.notif_client.web_push.clone() else { return Some(self.get_message_to_device_status( ¬if.client_message_id, Err(SessionError::MissingWebPushClient), )); }; let device_token = match self .get_device_token(notif.device_id.clone(), NotifClientType::WebPush) .await { Ok(token) => token, Err(e) => { return Some( self .get_message_to_device_status(¬if.client_message_id, Err(e)), ) } }; let web_push_notif = WebPushNotif { device_token: device_token.clone(), payload: notif.payload, }; let result = web_push_client.send(web_push_notif).await; if let Err(NotifsWebPushError(web_push_error)) = &result { if matches!( web_push_error, WebPushError::EndpointNotValid | WebPushError::EndpointNotFound ) { if let Err(e) = self .invalidate_device_token(notif.device_id, device_token.clone()) .await { error!( - "Error invalidating device token {}: {:?}", - device_token, e + errorType = error_types::DDB_ERROR, + "Error invalidating device token {}: {:?}", device_token, e ); }; } } Some( self.get_message_to_device_status(¬if.client_message_id, result), ) } DeviceToTunnelbrokerMessage::WNSNotif(notif) => { if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send WNS notif. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received WNS notif for {}", notif.device_id); let Some(wns_client) = self.notif_client.wns.clone() else { return Some(self.get_message_to_device_status( ¬if.client_message_id, Err(SessionError::MissingWNSClient), )); }; let device_token = match self .get_device_token(notif.device_id.clone(), NotifClientType::WNS) .await { Ok(token) => token, Err(e) => { return Some( self .get_message_to_device_status(¬if.client_message_id, Err(e)), ) } }; let wns_notif = WNSNotif { device_token: device_token.clone(), payload: notif.payload, }; let result = wns_client.send(wns_notif).await; if let Err(NotifsWNSError(err)) = &result { if matches!(err, WNSErrorResponse::NotFound | WNSErrorResponse::Gone) { if let Err(e) = self .invalidate_device_token(notif.device_id, device_token.clone()) .await { error!( - "Error invalidating device token {}: {:?}", - device_token, e + errorType = error_types::DDB_ERROR, + "Error invalidating device token {}: {:?}", device_token, e ); }; } } Some( self.get_message_to_device_status(¬if.client_message_id, result), ) } _ => { error!("Client sent invalid message type"); Some(MessageSentStatus::InvalidRequest) } } } pub async fn next_amqp_message( &mut self, ) -> Option> { self.amqp_consumer.next().await } pub async fn send_message_to_device(&mut self, message: Message) { if let Err(e) = self.tx.send(message).await { - error!("Failed to send message to device: {}", e); + error!( + errorType = error_types::AMQP_ERROR, + "Failed to send message to device: {}", e + ); } } // Release WebSocket and remove from active connections pub async fn close(&mut self) { if let Err(e) = self.tx.close().await { debug!("Failed to close WebSocket session: {}", e); } if let Err(e) = self .amqp_channel .basic_cancel( self.amqp_consumer.tag().as_str(), BasicCancelOptions::default(), ) .await { - error!("Failed to cancel consumer: {}", e); + error!( + errorType = error_types::AMQP_ERROR, + "Failed to cancel consumer: {}", e + ); } if let Err(e) = self .amqp_channel .queue_delete( self.device_info.device_id.as_str(), QueueDeleteOptions::default(), ) .await { - error!("Failed to delete queue: {}", e); + error!( + errorType = error_types::AMQP_ERROR, + "Failed to delete queue: {}", e + ); } } pub fn get_message_to_device_status( &mut self, client_message_id: &str, result: Result<(), E>, ) -> MessageSentStatus where E: std::error::Error, { match result { Ok(()) => MessageSentStatus::Success(client_message_id.to_string()), Err(err) => MessageSentStatus::Error(Failure { id: client_message_id.to_string(), error: err.to_string(), }), } } async fn get_device_token( &self, device_id: String, client: NotifClientType, ) -> Result { let db_token = self .db_client .get_device_token(&device_id) .await .map_err(SessionError::DatabaseError)?; match db_token { Some(token) => { if let Some(platform) = token.platform { if !client.supported_platform(platform) { return Err(SessionError::InvalidNotifProvider); } } if token.token_invalid { Err(SessionError::InvalidDeviceToken) } else { Ok(token.device_token) } } None => Err(SessionError::MissingDeviceToken), } } async fn invalidate_device_token( &self, device_id: String, invalidated_token: String, ) -> Result<(), SessionError> { let bad_device_token_message = BadDeviceToken { invalidated_token }; let payload = serde_json::to_string(&bad_device_token_message)?; let message_request = MessageToDeviceRequest { client_message_id: uuid::Uuid::new_v4().to_string(), device_id: device_id.to_string(), payload, }; self.handle_message_to_device(&message_request).await?; self .db_client .mark_device_token_as_invalid(&device_id) .await .map_err(SessionError::DatabaseError)?; Ok(()) } }