diff --git a/services/tunnelbroker/src/notifs/mod.rs b/services/tunnelbroker/src/notifs/mod.rs
--- a/services/tunnelbroker/src/notifs/mod.rs
+++ b/services/tunnelbroker/src/notifs/mod.rs
@@ -1,11 +1,40 @@
 use crate::notifs::apns::APNsClient;
 use crate::notifs::fcm::FCMClient;
 use crate::notifs::web_push::WebPushClient;
+use tunnelbroker_messages::Platform;
 
 pub mod apns;
 pub mod fcm;
 pub mod web_push;
 
+/// Notification provider that a device token is about to be used with.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum NotifClientType {
+  /// Accepts tokens registered for the iOS and macOS platforms.
+  APNs,
+  /// Accepts tokens registered for the Android platform.
+  FCM,
+  /// Accepts tokens registered for the web platform.
+  WebPush,
+  /// Accepts tokens registered for the Windows platform.
+  WNs,
+}
+
+impl NotifClientType {
+  /// Returns `true` when a device token registered for `platform` may be
+  /// used with this notification provider.
+  pub fn supported_platform(&self, platform: Platform) -> bool {
+    match self {
+      NotifClientType::APNs => {
+        platform == Platform::IOS || platform == Platform::MacOS
+      }
+      NotifClientType::FCM => platform == Platform::Android,
+      NotifClientType::WebPush => platform == Platform::Web,
+      NotifClientType::WNs => platform == Platform::Windows,
+    }
+  }
+}
+
 #[derive(Clone)]
 pub struct NotifClient {
   pub(crate) apns: Option<APNsClient>,
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -38,7 +38,7 @@
   AndroidConfig, AndroidMessagePriority, FCMMessage,
 };
 use crate::notifs::web_push::WebPushNotif;
-use crate::notifs::NotifClient;
+use crate::notifs::{NotifClient, NotifClientType};
 
 pub struct DeviceInfo {
   pub device_id: String,
@@ -76,6 +76,7 @@
   MissingWebPushClient,
   MissingDeviceToken,
   InvalidDeviceToken,
+  InvalidNotifProvider,
 }
 
 // Parse a session request and retrieve the device information
@@ -401,16 +402,18 @@
           return Some(MessageSentStatus::SerializationError(notif.headers));
         };
 
-        let device_token =
-          match self.get_device_token(notif.device_id.clone()).await {
-            Ok(token) => token,
-            Err(e) => {
-              return Some(self.get_message_to_device_status(
-                &notif.client_message_id,
-                Err(e),
-              ))
-            }
-          };
+        let device_token = match self
+          .get_device_token(notif.device_id.clone(), NotifClientType::APNs)
+          .await
+        {
+          Ok(token) => token,
+          Err(e) => {
+            return Some(
+              self
+                .get_message_to_device_status(&notif.client_message_id, Err(e)),
+            )
+          }
+        };
 
         let apns_notif = APNsNotif {
           device_token: device_token.clone(),
@@ -466,7 +469,10 @@
           return Some(MessageSentStatus::SerializationError(notif.data));
         };
 
-        let device_token = match self.get_device_token(notif.device_id).await {
+        let device_token = match self
+          .get_device_token(notif.device_id, NotifClientType::FCM)
+          .await
+        {
           Ok(token) => token,
           Err(e) => {
             return Some(
@@ -513,7 +519,10 @@
           ));
         };
 
-        let device_token = match self.get_device_token(notif.device_id).await {
+        let device_token = match self
+          .get_device_token(notif.device_id, NotifClientType::WebPush)
+          .await
+        {
           Ok(token) => token,
           Err(e) => {
             return Some(
@@ -601,6 +610,7 @@
   async fn get_device_token(
     &self,
     device_id: String,
+    client: NotifClientType,
   ) -> Result<String, SessionError> {
     let db_token = self
       .db_client
@@ -610,6 +620,11 @@
 
     match db_token {
       Some(token) => {
+        if let Some(platform) = token.platform {
+          if !client.supported_platform(platform) {
+            return Err(SessionError::InvalidNotifProvider);
+          }
+        }
         if token.token_invalid {
           Err(SessionError::InvalidDeviceToken)
         } else {