diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs --- a/services/tunnelbroker/src/main.rs +++ b/services/tunnelbroker/src/main.rs @@ -144,12 +144,8 @@ } }; - let notif_client = NotifClient { - apns, - fcm, - web_push, - wns, - }; + let notif_client = + NotifClient::new(apns, fcm, web_push, wns, db_client.clone()); let grpc_server = grpc::run_server(db_client.clone(), &amqp_connection); let websocket_server = websockets::run_server( diff --git a/services/tunnelbroker/src/notifs/mod.rs b/services/tunnelbroker/src/notifs/mod.rs --- a/services/tunnelbroker/src/notifs/mod.rs +++ b/services/tunnelbroker/src/notifs/mod.rs @@ -1,8 +1,22 @@ -use crate::notifs::apns::APNsClient; +use crate::amqp_client::AmqpClient; +use crate::constants::error_types; +use crate::database::DatabaseClient; +use crate::notifs::apns::headers::NotificationHeaders; +use crate::notifs::apns::{APNsClient, APNsNotif}; +use crate::notifs::fcm::firebase_message::{ + AndroidConfig, AndroidMessagePriority, FCMMessage, +}; use crate::notifs::fcm::FCMClient; -use crate::notifs::web_push::WebPushClient; -use crate::notifs::wns::WNSClient; -use tunnelbroker_messages::Platform; +use crate::notifs::web_push::error::Error::WebPush as NotifsWebPushError; +use crate::notifs::web_push::{WebPushClient, WebPushNotif}; +use crate::notifs::wns::error::Error::WNSNotification as NotifsWNSError; +use crate::notifs::wns::{WNSClient, WNSNotif}; +use crate::websockets::session::SessionError; +use ::web_push::WebPushError; +use tracing::{debug, error}; +use tunnelbroker_messages::bad_device_token::BadDeviceToken; +use tunnelbroker_messages::MessageSentStatus; +use tunnelbroker_messages::{MessageToDeviceRequest, Platform}; pub mod apns; pub mod fcm; @@ -32,8 +46,336 @@ #[derive(Clone)] pub struct NotifClient { - pub(crate) apns: Option, - pub(crate) fcm: Option, - pub(crate) web_push: Option, - pub(crate) wns: Option, + apns: Option, + fcm: Option, + web_push: Option, + wns: Option, + db_client: DatabaseClient, +} + +impl NotifClient { + pub fn new( + apns: Option, + fcm: Option, + web_push: Option, + wns: Option, + db_client: DatabaseClient, + ) -> NotifClient { + NotifClient { + apns, + fcm, + web_push, + wns, + db_client, + } + } + + async fn invalidate_device_token( + &self, + device_id: String, + invalidated_token: String, + amqp_client: &mut AmqpClient, + ) -> Result<(), SessionError> { + let bad_device_token_message = BadDeviceToken { invalidated_token }; + let payload = serde_json::to_string(&bad_device_token_message)?; + let message_request = MessageToDeviceRequest { + client_message_id: uuid::Uuid::new_v4().to_string(), + device_id: device_id.to_string(), + payload, + }; + + amqp_client + .handle_message_to_device(&message_request) + .await?; + + self + .db_client + .mark_device_token_as_invalid(&device_id) + .await + .map_err(SessionError::DatabaseError)?; + + Ok(()) + } + + async fn get_device_token( + &self, + device_id: String, + client: NotifClientType, + ) -> Result { + let db_token = self + .db_client + .get_device_token(&device_id) + .await + .map_err(SessionError::DatabaseError)?; + + match db_token { + Some(token) => { + if let Some(platform) = token.platform { + if !client.supported_platform(platform) { + return Err(SessionError::InvalidNotifProvider); + } + } + if token.token_invalid { + Err(SessionError::InvalidDeviceToken) + } else { + Ok(token.device_token) + } + } + None => Err(SessionError::MissingDeviceToken), + } + } + + pub async fn send_apns_notif( + &self, + notif: tunnelbroker_messages::notif::APNsNotif, + amqp_client: &mut AmqpClient, + ) -> Option { + debug!("Received APNs notif for {}", notif.device_id); + + let Ok(headers) = + serde_json::from_str::(¬if.headers) + else { + return Some(MessageSentStatus::SerializationError(notif.headers)); + }; + + let device_token = match self + .get_device_token(notif.device_id.clone(), NotifClientType::APNs) + .await + { + Ok(token) => token, + Err(e) => { + return Some(MessageSentStatus::from_result( + ¬if.client_message_id, + Err(e), + )); + } + }; + + let apns_notif = APNsNotif { + device_token: device_token.clone(), + headers, + payload: notif.payload, + }; + + if let Some(apns) = self.apns.clone() { + let response = apns.send(apns_notif).await; + if let Err(apns::error::Error::ResponseError(body)) = &response { + if body.reason.should_invalidate_token() { + if let Err(e) = self + .invalidate_device_token( + notif.device_id, + device_token.clone(), + amqp_client, + ) + .await + { + error!( + errorType = error_types::DDB_ERROR, + "Error invalidating device token {}: {:?}", device_token, e + ); + }; + } + } + return Some(MessageSentStatus::from_result( + ¬if.client_message_id, + response, + )); + } + + Some(MessageSentStatus::from_result( + ¬if.client_message_id, + Err(SessionError::MissingAPNsClient), + )) + } + + pub async fn send_fcm_notif( + &self, + notif: tunnelbroker_messages::notif::FCMNotif, + amqp_client: &mut AmqpClient, + ) -> Option { + debug!("Received FCM notif for {}", notif.device_id); + + let Some(priority) = AndroidMessagePriority::from_str(¬if.priority) + else { + return Some(MessageSentStatus::SerializationError(notif.priority)); + }; + + let Ok(data) = serde_json::from_str(¬if.data) else { + return Some(MessageSentStatus::SerializationError(notif.data)); + }; + + let device_token = match self + .get_device_token(notif.device_id.clone(), NotifClientType::FCM) + .await + { + Ok(token) => token, + Err(e) => { + return Some(MessageSentStatus::from_result( + ¬if.client_message_id, + Err(e), + )) + } + }; + + let fcm_message = FCMMessage { + data, + token: device_token.to_string(), + android: AndroidConfig { priority }, + }; + + if let Some(fcm) = self.fcm.clone() { + let result = fcm.send(fcm_message).await; + + if let Err(crate::notifs::fcm::error::Error::FCMError(fcm_error)) = + &result + { + if fcm_error.should_invalidate_token() { + if let Err(e) = self + .invalidate_device_token( + notif.device_id, + device_token.clone(), + amqp_client, + ) + .await + { + error!( + errorType = error_types::DDB_ERROR, + "Error invalidating device token {}: {:?}", device_token, e + ); + }; + } + } + return Some(MessageSentStatus::from_result( + ¬if.client_message_id, + result, + )); + } + + Some(MessageSentStatus::from_result( + ¬if.client_message_id, + Err(SessionError::MissingFCMClient), + )) + } + + pub async fn send_web_notif( + &self, + notif: tunnelbroker_messages::notif::WebPushNotif, + amqp_client: &mut AmqpClient, + ) -> Option { + debug!("Received WebPush notif for {}", notif.device_id); + + let Some(web_push_client) = self.web_push.clone() else { + return Some(MessageSentStatus::from_result( + ¬if.client_message_id, + Err(SessionError::MissingWebPushClient), + )); + }; + + let device_token = match self + .get_device_token(notif.device_id.clone(), NotifClientType::WebPush) + .await + { + Ok(token) => token, + Err(e) => { + return Some(MessageSentStatus::from_result( + ¬if.client_message_id, + Err(e), + )) + } + }; + + let web_push_notif = WebPushNotif { + device_token: device_token.clone(), + payload: notif.payload, + }; + + let result = web_push_client.send(web_push_notif).await; + if let Err(NotifsWebPushError(web_push_error)) = &result { + if matches!( + web_push_error, + WebPushError::EndpointNotValid(_) | WebPushError::EndpointNotFound(_) + ) { + if let Err(e) = self + .invalidate_device_token( + notif.device_id, + device_token.clone(), + amqp_client, + ) + .await + { + error!( + errorType = error_types::DDB_ERROR, + "Error invalidating device token {}: {:?}", device_token, e + ); + }; + } else { + error!( + errorType = error_types::WEB_PUSH_ERROR, + "Failed sending Web Push notification to: {}. Error: {}", + device_token, + web_push_error + ); + } + } + Some(MessageSentStatus::from_result( + ¬if.client_message_id, + result, + )) + } + + pub async fn send_wns_notif( + &self, + notif: tunnelbroker_messages::notif::WNSNotif, + amqp_client: &mut AmqpClient, + ) -> Option { + debug!("Received WNS notif for {}", notif.device_id); + + let Some(wns_client) = self.wns.clone() else { + return Some(MessageSentStatus::from_result( + ¬if.client_message_id, + Err(SessionError::MissingWNSClient), + )); + }; + + let device_token = match self + .get_device_token(notif.device_id.clone(), NotifClientType::WNS) + .await + { + Ok(token) => token, + Err(e) => { + return Some(MessageSentStatus::from_result( + ¬if.client_message_id, + Err(e), + )) + } + }; + + let wns_notif = WNSNotif { + device_token: device_token.clone(), + payload: notif.payload, + }; + + let result = wns_client.send(wns_notif).await; + if let Err(NotifsWNSError(err)) = &result { + if err.should_invalidate_token() { + if let Err(e) = self + .invalidate_device_token( + notif.device_id, + device_token.clone(), + amqp_client, + ) + .await + { + error!( + errorType = error_types::DDB_ERROR, + "Error invalidating device token {}: {:?}", device_token, e + ); + }; + } + } + Some(MessageSentStatus::from_result( + ¬if.client_message_id, + result, + )) + } } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -7,33 +7,22 @@ use futures_util::SinkExt; use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use lapin::message::Delivery; -use notifs::fcm::error::Error::FCMError as NotifsFCMError; -use notifs::web_push::error::Error::WebPush as NotifsWebPushError; -use notifs::wns::error::Error::WNSNotification as NotifsWNSError; + use reqwest::Url; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tracing::{debug, error, info, trace}; -use tunnelbroker_messages::bad_device_token::BadDeviceToken; + use tunnelbroker_messages::{ message_to_device_request_status::MessageSentStatus, session::DeviceTypes, - DeviceToTunnelbrokerMessage, Heartbeat, MessageToDeviceRequest, - MessageToTunnelbroker, + DeviceToTunnelbrokerMessage, Heartbeat, MessageToTunnelbroker, }; use tunnelbroker_messages::{DeviceToTunnelbrokerRequestStatus, Platform}; -use web_push::WebPushError; use crate::amqp_client::AmqpClient; use crate::database::{self, DatabaseClient}; -use crate::notifs::apns::headers::NotificationHeaders; -use crate::notifs::apns::APNsNotif; -use crate::notifs::fcm::firebase_message::{ - AndroidConfig, AndroidMessagePriority, FCMMessage, -}; -use crate::notifs::web_push::WebPushNotif; -use crate::notifs::wns::WNSNotif; -use crate::notifs::{apns, NotifClient, NotifClientType}; -use crate::{identity, notifs}; +use crate::identity; +use crate::notifs::NotifClient; #[derive(Clone)] pub struct DeviceInfo { @@ -354,58 +343,10 @@ ); return Some(MessageSentStatus::Unauthenticated); } - debug!("Received APNs notif for {}", notif.device_id); - - let Ok(headers) = - serde_json::from_str::(¬if.headers) - else { - return Some(MessageSentStatus::SerializationError(notif.headers)); - }; - - let device_token = match self - .get_device_token(notif.device_id.clone(), NotifClientType::APNs) + self + .notif_client + .send_apns_notif(notif, &mut self.amqp_client) .await - { - Ok(token) => token, - Err(e) => { - return Some(MessageSentStatus::from_result( - ¬if.client_message_id, - Err(e), - )); - } - }; - - let apns_notif = APNsNotif { - device_token: device_token.clone(), - headers, - payload: notif.payload, - }; - - if let Some(apns) = self.notif_client.apns.clone() { - let response = apns.send(apns_notif).await; - if let Err(apns::error::Error::ResponseError(body)) = &response { - if body.reason.should_invalidate_token() { - if let Err(e) = self - .invalidate_device_token(notif.device_id, device_token.clone()) - .await - { - error!( - errorType = error_types::DDB_ERROR, - "Error invalidating device token {}: {:?}", device_token, e - ); - }; - } - } - return Some(MessageSentStatus::from_result( - ¬if.client_message_id, - response, - )); - } - - Some(MessageSentStatus::from_result( - ¬if.client_message_id, - Err(SessionError::MissingAPNsClient), - )) } DeviceToTunnelbrokerMessage::FCMNotif(notif) => { // unauthenticated clients cannot send notifs @@ -416,62 +357,10 @@ ); return Some(MessageSentStatus::Unauthenticated); } - debug!("Received FCM notif for {}", notif.device_id); - - let Some(priority) = AndroidMessagePriority::from_str(¬if.priority) - else { - return Some(MessageSentStatus::SerializationError(notif.priority)); - }; - - let Ok(data) = serde_json::from_str(¬if.data) else { - return Some(MessageSentStatus::SerializationError(notif.data)); - }; - - let device_token = match self - .get_device_token(notif.device_id.clone(), NotifClientType::FCM) + self + .notif_client + .send_fcm_notif(notif, &mut self.amqp_client) .await - { - Ok(token) => token, - Err(e) => { - return Some(MessageSentStatus::from_result( - ¬if.client_message_id, - Err(e), - )) - } - }; - - let fcm_message = FCMMessage { - data, - token: device_token.to_string(), - android: AndroidConfig { priority }, - }; - - if let Some(fcm) = self.notif_client.fcm.clone() { - let result = fcm.send(fcm_message).await; - - if let Err(NotifsFCMError(fcm_error)) = &result { - if fcm_error.should_invalidate_token() { - if let Err(e) = self - .invalidate_device_token(notif.device_id, device_token.clone()) - .await - { - error!( - errorType = error_types::DDB_ERROR, - "Error invalidating device token {}: {:?}", device_token, e - ); - }; - } - } - return Some(MessageSentStatus::from_result( - ¬if.client_message_id, - result, - )); - } - - Some(MessageSentStatus::from_result( - ¬if.client_message_id, - Err(SessionError::MissingFCMClient), - )) } DeviceToTunnelbrokerMessage::WebPushNotif(notif) => { // unauthenticated clients cannot send notifs @@ -482,62 +371,10 @@ ); return Some(MessageSentStatus::Unauthenticated); } - debug!("Received WebPush notif for {}", notif.device_id); - - let Some(web_push_client) = self.notif_client.web_push.clone() else { - return Some(MessageSentStatus::from_result( - ¬if.client_message_id, - Err(SessionError::MissingWebPushClient), - )); - }; - - let device_token = match self - .get_device_token(notif.device_id.clone(), NotifClientType::WebPush) + self + .notif_client + .send_web_notif(notif, &mut self.amqp_client) .await - { - Ok(token) => token, - Err(e) => { - return Some(MessageSentStatus::from_result( - ¬if.client_message_id, - Err(e), - )) - } - }; - - let web_push_notif = WebPushNotif { - device_token: device_token.clone(), - payload: notif.payload, - }; - - let result = web_push_client.send(web_push_notif).await; - if let Err(NotifsWebPushError(web_push_error)) = &result { - if matches!( - web_push_error, - WebPushError::EndpointNotValid(_) - | WebPushError::EndpointNotFound(_) - ) { - if let Err(e) = self - .invalidate_device_token(notif.device_id, device_token.clone()) - .await - { - error!( - errorType = error_types::DDB_ERROR, - "Error invalidating device token {}: {:?}", device_token, e - ); - }; - } else { - error!( - errorType = error_types::WEB_PUSH_ERROR, - "Failed sending Web Push notification to: {}. Error: {}", - device_token, - web_push_error - ); - } - } - Some(MessageSentStatus::from_result( - ¬if.client_message_id, - result, - )) } DeviceToTunnelbrokerMessage::WNSNotif(notif) => { if !self.device_info.is_authenticated { @@ -547,51 +384,10 @@ ); return Some(MessageSentStatus::Unauthenticated); } - debug!("Received WNS notif for {}", notif.device_id); - - let Some(wns_client) = self.notif_client.wns.clone() else { - return Some(MessageSentStatus::from_result( - ¬if.client_message_id, - Err(SessionError::MissingWNSClient), - )); - }; - - let device_token = match self - .get_device_token(notif.device_id.clone(), NotifClientType::WNS) + self + .notif_client + .send_wns_notif(notif, &mut self.amqp_client) .await - { - Ok(token) => token, - Err(e) => { - return Some(MessageSentStatus::from_result( - ¬if.client_message_id, - Err(e), - )) - } - }; - - let wns_notif = WNSNotif { - device_token: device_token.clone(), - payload: notif.payload, - }; - - let result = wns_client.send(wns_notif).await; - if let Err(NotifsWNSError(err)) = &result { - if err.should_invalidate_token() { - if let Err(e) = self - .invalidate_device_token(notif.device_id, device_token.clone()) - .await - { - error!( - errorType = error_types::DDB_ERROR, - "Error invalidating device token {}: {:?}", device_token, e - ); - }; - } - } - Some(MessageSentStatus::from_result( - ¬if.client_message_id, - result, - )) } _ => { error!("Client sent invalid message type"); @@ -627,61 +423,6 @@ self.amqp_client.close_connection().await; } - - async fn get_device_token( - &self, - device_id: String, - client: NotifClientType, - ) -> Result { - let db_token = self - .db_client - .get_device_token(&device_id) - .await - .map_err(SessionError::DatabaseError)?; - - match db_token { - Some(token) => { - if let Some(platform) = token.platform { - if !client.supported_platform(platform) { - return Err(SessionError::InvalidNotifProvider); - } - } - if token.token_invalid { - Err(SessionError::InvalidDeviceToken) - } else { - Ok(token.device_token) - } - } - None => Err(SessionError::MissingDeviceToken), - } - } - - async fn invalidate_device_token( - &mut self, - device_id: String, - invalidated_token: String, - ) -> Result<(), SessionError> { - let bad_device_token_message = BadDeviceToken { invalidated_token }; - let payload = serde_json::to_string(&bad_device_token_message)?; - let message_request = MessageToDeviceRequest { - client_message_id: uuid::Uuid::new_v4().to_string(), - device_id: device_id.to_string(), - payload, - }; - - self - .amqp_client - .handle_message_to_device(&message_request) - .await?; - - self - .db_client - .mark_device_token_as_invalid(&device_id) - .await - .map_err(SessionError::DatabaseError)?; - - Ok(()) - } } fn should_ignore_error(err: &hyper_tungstenite::tungstenite::Error) -> bool {