diff --git a/lib/types/tunnelbroker/notif-types.js b/lib/types/tunnelbroker/notif-types.js index 2ba570bd6..abae830ba 100644 --- a/lib/types/tunnelbroker/notif-types.js +++ b/lib/types/tunnelbroker/notif-types.js @@ -1,29 +1,37 @@ // @flow export type TunnelbrokerAPNsNotif = { +type: 'APNsNotif', +headers: string, +clientMessageID: string, +deviceID: string, +payload: string, }; export type TunnelbrokerFCMNotif = { +type: 'FCMNotif', +clientMessageID: string, +deviceID: string, +data: string, +priority: 'NORMAL' | 'HIGH', }; export type TunnelbrokerWebPushNotif = { +type: 'WebPushNotif', +clientMessageID: string, +deviceID: string, +payload: string, }; +export type TunnelbrokerWNSNotif = { + +type: 'WNSNotif', + +clientMessageID: string, + +deviceID: string, + +payload: string, +}; + export type TunnelbrokerNotif = | TunnelbrokerAPNsNotif | TunnelbrokerFCMNotif - | TunnelbrokerWebPushNotif; + | TunnelbrokerWebPushNotif + | TunnelbrokerWNSNotif; diff --git a/services/tunnelbroker/src/notifs/mod.rs b/services/tunnelbroker/src/notifs/mod.rs index 5ccc59ac4..16801f102 100644 --- a/services/tunnelbroker/src/notifs/mod.rs +++ b/services/tunnelbroker/src/notifs/mod.rs @@ -1,39 +1,39 @@ use crate::notifs::apns::APNsClient; use crate::notifs::fcm::FCMClient; use crate::notifs::web_push::WebPushClient; use crate::notifs::wns::WNSClient; use tunnelbroker_messages::Platform; pub mod apns; pub mod fcm; pub mod web_push; pub mod wns; #[derive(PartialEq)] pub enum NotifClientType { APNs, FCM, WebPush, - WNs, + WNS, } impl NotifClientType { pub fn supported_platform(&self, platform: Platform) -> bool { match self { NotifClientType::APNs => { platform == Platform::IOS || platform == Platform::MacOS } NotifClientType::FCM => platform == Platform::Android, NotifClientType::WebPush => platform == Platform::Web, - NotifClientType::WNs => platform == Platform::Windows, + NotifClientType::WNS => platform == Platform::Windows, } } } #[derive(Clone)] pub struct NotifClient { pub(crate) apns: Option, pub(crate) fcm: Option, pub(crate) web_push: Option, pub(crate) wns: Option, } diff --git a/services/tunnelbroker/src/notifs/wns/error.rs b/services/tunnelbroker/src/notifs/wns/error.rs index 911506bac..7372e2e20 100644 --- a/services/tunnelbroker/src/notifs/wns/error.rs +++ b/services/tunnelbroker/src/notifs/wns/error.rs @@ -1,15 +1,29 @@ use derive_more::{Display, Error, From}; +use super::response::WNSErrorResponse; + #[derive(Debug, From, Display, Error)] pub enum Error { Reqwest(reqwest::Error), SerdeJson(serde_json::Error), - #[display(fmt = "Token not found in response")] - TokenNotFound, - #[display(fmt = "Expiry time not found in response")] - ExpiryNotFound, + #[display(fmt = "WNS Token Error: {}", _0)] + WNSToken(WNSTokenError), #[display(fmt = "Failed to acquire read lock")] ReadLock, #[display(fmt = "Failed to acquire write lock")] WriteLock, + #[display(fmt = "WNS Notification Error: {}", _0)] + WNSNotification(WNSErrorResponse), +} + +#[derive(Debug, From, Display)] +pub enum WNSTokenError { + #[display(fmt = "Token not found in response")] + TokenNotFound, + #[display(fmt = "Expiry time not found in response")] + ExpiryNotFound, + #[display(fmt = "Unknown Error: {}", _0)] + Unknown(String), } + +impl std::error::Error for WNSTokenError {} diff --git a/services/tunnelbroker/src/notifs/wns/mod.rs b/services/tunnelbroker/src/notifs/wns/mod.rs index 22b4a5e83..e6ded493b 100644 --- a/services/tunnelbroker/src/notifs/wns/mod.rs +++ b/services/tunnelbroker/src/notifs/wns/mod.rs @@ -1,97 +1,143 @@ use crate::notifs::wns::config::WNSConfig; +use error::WNSTokenError; +use reqwest::StatusCode; +use response::WNSErrorResponse; use std::{ sync::{Arc, RwLock}, time::{Duration, SystemTime}, }; pub mod config; mod error; +mod response; #[derive(Debug, Clone)] pub struct WNSAccessToken { token: String, expires: SystemTime, } +#[derive(Debug, Clone)] +pub struct WNSNotif { + pub device_token: String, + pub payload: String, +} + #[derive(Clone)] pub struct WNSClient { http_client: reqwest::Client, config: WNSConfig, access_token: Arc>>, } impl WNSClient { pub fn new(config: &WNSConfig) -> Result { let http_client = reqwest::Client::builder().build()?; Ok(WNSClient { http_client, config: config.clone(), access_token: Arc::new(RwLock::new(None)), }) } - pub async fn get_wns_token( - &mut self, - ) -> Result, error::Error> { + pub async fn send(&self, notif: WNSNotif) -> Result<(), error::Error> { + let wns_access_token = self.get_wns_token().await?; + + let url = notif.device_token; + + // Send the notification + let response = self + .http_client + .post(&url) + .header("Content-Type", "application/octet-stream") + .header("X-WNS-Type", "wns/raw") + .bearer_auth(wns_access_token) + .body(notif.payload) + .send() + .await?; + + match response.status() { + StatusCode::OK => { + tracing::debug!("Successfully sent WNS notif to {}", &url); + Ok(()) + } + error_status => { + let body = response + .text() + .await + .unwrap_or_else(|error| format!("Error occurred: {}", error)); + tracing::error!( + "Failed sending WNS notification to: {}. Status: {}. Body: {}", + &url, + error_status, + body + ); + let wns_error = WNSErrorResponse::from_status(error_status, body); + Err(error::Error::WNSNotification(wns_error)) + } + } + } + + pub async fn get_wns_token(&self) -> Result { const EXPIRY_WINDOW: Duration = Duration::from_secs(10); { let read_guard = self .access_token .read() .map_err(|_| error::Error::ReadLock)?; if let Some(ref token) = *read_guard { if token.expires >= SystemTime::now() - EXPIRY_WINDOW { - return Ok(Some(token.token.clone())); + return Ok(token.token.clone()); } } } let params = [ ("grant_type", "client_credentials"), ("client_id", &self.config.app_id), ("client_secret", &self.config.secret), ("scope", "https://wns.windows.com/.default"), ]; let url = format!( "https://login.microsoftonline.com/{}/oauth2/v2.0/token", self.config.tenant_id ); let response = self.http_client.post(&url).form(¶ms).send().await?; if !response.status().is_success() { let status = response.status().to_string(); let body = response .text() .await .unwrap_or_else(|_| String::from("")); tracing::error!(status, "Failure when getting the WNS token: {}", body); - return Ok(None); + return Err(error::Error::WNSToken(WNSTokenError::Unknown(status))); } let response_json: serde_json::Value = response.json().await?; let token = response_json["access_token"] .as_str() - .ok_or(error::Error::TokenNotFound)? + .ok_or(error::Error::WNSToken(WNSTokenError::TokenNotFound))? .to_string(); let expires_in = response_json["expires_in"] .as_u64() - .ok_or(error::Error::ExpiryNotFound)?; + .ok_or(error::Error::WNSToken(WNSTokenError::ExpiryNotFound))?; let expires = SystemTime::now() + Duration::from_secs(expires_in); { let mut write_guard = self .access_token .write() .map_err(|_| error::Error::WriteLock)?; *write_guard = Some(WNSAccessToken { token: token.clone(), expires, }); } - Ok(Some(token)) + Ok(token) } } diff --git a/services/tunnelbroker/src/notifs/wns/response.rs b/services/tunnelbroker/src/notifs/wns/response.rs new file mode 100644 index 000000000..377b76010 --- /dev/null +++ b/services/tunnelbroker/src/notifs/wns/response.rs @@ -0,0 +1,77 @@ +use derive_more::{Display, Error}; +use reqwest::StatusCode; + +#[derive(PartialEq, Debug, Clone, Display, Error)] +pub struct InvalidArgumentError { + pub details: String, +} + +#[derive(PartialEq, Debug, Display, Error)] +pub enum WNSErrorResponse { + /// No more information is available about this error. + UnspecifiedError, + + /// HTTP error code = 400. + /// One or more headers were specified incorrectly or conflict with another + /// header. + BadRequest(InvalidArgumentError), + + /// HTTP error code = 401. + /// The cloud service did not present a valid authentication ticket. + Unauthorized, + + /// HTTP error code = 403. + /// The cloud service is not authorized to send a notification to this URI. + Forbidden, + + /// HTTP error code = 404. + /// The channel URI is not valid or is not recognized by WNS. + NotFound, + + /// HTTP error code = 405. + /// Invalid method (GET, CREATE); only POST (Windows or Windows Phone) or + /// DELETE (Windows Phone only) is allowed. + MethodNotAllowed, + + /// HTTP error code = 406. + /// The cloud service exceeded its throttle limit. + NotAcceptable, + + /// HTTP error code = 410. + /// The channel expired. + Gone, + + /// HTTP error code = 413. + /// The notification payload exceeds the 5000 byte size limit. + RequestEntityTooLarge, + + /// HTTP error code = 500. + /// An internal failure caused notification delivery to fail. + InternalServerError, + + /// HTTP error code = 503. + /// The server is currently unavailable. + ServiceUnavailable, +} + +impl WNSErrorResponse { + pub fn from_status(status: StatusCode, body: String) -> Self { + match status { + StatusCode::BAD_REQUEST => { + WNSErrorResponse::BadRequest(InvalidArgumentError { details: body }) + } + StatusCode::UNAUTHORIZED => WNSErrorResponse::Unauthorized, + StatusCode::FORBIDDEN => WNSErrorResponse::Forbidden, + StatusCode::NOT_FOUND => WNSErrorResponse::NotFound, + StatusCode::METHOD_NOT_ALLOWED => WNSErrorResponse::MethodNotAllowed, + StatusCode::NOT_ACCEPTABLE => WNSErrorResponse::NotAcceptable, + StatusCode::GONE => WNSErrorResponse::Gone, + StatusCode::PAYLOAD_TOO_LARGE => WNSErrorResponse::RequestEntityTooLarge, + StatusCode::INTERNAL_SERVER_ERROR => { + WNSErrorResponse::InternalServerError + } + StatusCode::SERVICE_UNAVAILABLE => WNSErrorResponse::ServiceUnavailable, + _ => WNSErrorResponse::UnspecifiedError, + } + } +} diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs index fde7db500..4944f59a8 100644 --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,700 +1,742 @@ use crate::constants::{ CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY, MAX_RMQ_MSG_PRIORITY, RMQ_CONSUMER_TAG, }; use crate::notifs::fcm::response::FCMErrorResponse; use comm_lib::aws::ddb::error::SdkError; use comm_lib::aws::ddb::operation::put_item::PutItemError; use derive_more; use futures_util::stream::SplitSink; use futures_util::SinkExt; use futures_util::StreamExt; use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use lapin::message::Delivery; use lapin::options::{ BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions, QueueDeclareOptions, QueueDeleteOptions, }; use lapin::types::FieldTable; use lapin::BasicProperties; use notifs::fcm::error::Error::FCMError as NotifsFCMError; use notifs::web_push::error::Error::WebPush as NotifsWebPushError; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tracing::{debug, error, info, trace}; use tunnelbroker_messages::bad_device_token::BadDeviceToken; use tunnelbroker_messages::{ message_to_device_request_status::Failure, message_to_device_request_status::MessageSentStatus, session::DeviceTypes, DeviceToTunnelbrokerMessage, Heartbeat, MessageToDevice, MessageToDeviceRequest, MessageToTunnelbroker, }; use web_push::WebPushError; use crate::notifs::apns::response::ErrorReason; use crate::database::{self, DatabaseClient, MessageToDeviceExt}; use crate::notifs::apns::headers::NotificationHeaders; use crate::notifs::apns::APNsNotif; use crate::notifs::fcm::firebase_message::{ AndroidConfig, AndroidMessagePriority, FCMMessage, }; use crate::notifs::web_push::WebPushNotif; +use crate::notifs::wns::WNSNotif; use crate::notifs::{apns, NotifClient, NotifClientType}; use crate::{identity, notifs}; pub struct DeviceInfo { pub device_id: String, pub notify_token: Option, pub device_type: DeviceTypes, pub device_app_version: Option, pub device_os: Option, pub is_authenticated: bool, } pub struct WebsocketSession { tx: SplitSink, Message>, db_client: DatabaseClient, pub device_info: DeviceInfo, amqp_channel: lapin::Channel, // Stream of messages from AMQP endpoint amqp_consumer: lapin::Consumer, notif_client: NotifClient, } #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum SessionError { InvalidMessage, SerializationError(serde_json::Error), MessageError(database::MessageErrors), AmqpError(lapin::Error), InternalError, UnauthorizedDevice, PersistenceError(SdkError), DatabaseError(comm_lib::database::Error), MissingAPNsClient, MissingFCMClient, MissingWebPushClient, + MissingWNSClient, MissingDeviceToken, InvalidDeviceToken, InvalidNotifProvider, } // Parse a session request and retrieve the device information pub async fn handle_first_message_from_device( message: &str, ) -> Result { let serialized_message = serde_json::from_str::(message)?; match serialized_message { DeviceToTunnelbrokerMessage::ConnectionInitializationMessage( mut session_info, ) => { let device_info = DeviceInfo { device_id: session_info.device_id.clone(), notify_token: session_info.notify_token.take(), device_type: session_info.device_type, device_app_version: session_info.device_app_version.take(), device_os: session_info.device_os.take(), is_authenticated: true, }; // Authenticate device debug!("Authenticating device: {}", &session_info.device_id); let auth_request = identity::verify_user_access_token( &session_info.user_id, &device_info.device_id, &session_info.access_token, ) .await; match auth_request { Err(e) => { error!("Failed to complete request to identity service: {:?}", e); return Err(SessionError::InternalError); } Ok(false) => { info!("Device failed authentication: {}", &session_info.device_id); return Err(SessionError::UnauthorizedDevice); } Ok(true) => { debug!( "Successfully authenticated device: {}", &session_info.device_id ); } } Ok(device_info) } DeviceToTunnelbrokerMessage::AnonymousInitializationMessage( session_info, ) => { debug!( "Starting unauthenticated session with device: {}", &session_info.device_id ); let device_info = DeviceInfo { device_id: session_info.device_id, device_type: session_info.device_type, device_app_version: session_info.device_app_version, device_os: session_info.device_os, is_authenticated: false, notify_token: None, }; Ok(device_info) } _ => { debug!("Received invalid request"); Err(SessionError::InvalidMessage) } } } async fn publish_persisted_messages( db_client: &DatabaseClient, amqp_channel: &lapin::Channel, device_info: &DeviceInfo, ) -> Result<(), SessionError> { let messages = db_client .retrieve_messages(&device_info.device_id) .await .unwrap_or_else(|e| { error!("Error while retrieving messages: {}", e); Vec::new() }); for message in messages { let message_to_device = MessageToDevice::from_hashmap(message)?; let serialized_message = serde_json::to_string(&message_to_device)?; amqp_channel .basic_publish( "", &message_to_device.device_id, BasicPublishOptions::default(), serialized_message.as_bytes(), BasicProperties::default().with_priority(DDB_RMQ_MSG_PRIORITY), ) .await?; } debug!("Flushed messages for device: {}", &device_info.device_id); Ok(()) } pub async fn initialize_amqp( db_client: DatabaseClient, frame: Message, amqp_channel: &lapin::Channel, ) -> Result<(DeviceInfo, lapin::Consumer), SessionError> { let device_info = match frame { Message::Text(payload) => { handle_first_message_from_device(&payload).await? } _ => { error!("Client sent wrong frame type for establishing connection"); return Err(SessionError::InvalidMessage); } }; let mut args = FieldTable::default(); args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into()); amqp_channel .queue_declare(&device_info.device_id, QueueDeclareOptions::default(), args) .await?; publish_persisted_messages(&db_client, amqp_channel, &device_info).await?; let amqp_consumer = amqp_channel .basic_consume( &device_info.device_id, RMQ_CONSUMER_TAG, BasicConsumeOptions::default(), FieldTable::default(), ) .await?; Ok((device_info, amqp_consumer)) } impl WebsocketSession { pub fn new( tx: SplitSink, Message>, db_client: DatabaseClient, device_info: DeviceInfo, amqp_channel: lapin::Channel, amqp_consumer: lapin::Consumer, notif_client: NotifClient, ) -> Self { Self { tx, db_client, device_info, amqp_channel, amqp_consumer, notif_client, } } pub async fn handle_message_to_device( &self, message_request: &MessageToDeviceRequest, ) -> Result<(), SessionError> { let message_id = self .db_client .persist_message( &message_request.device_id, &message_request.payload, &message_request.client_message_id, ) .await?; let message_to_device = MessageToDevice { device_id: message_request.device_id.clone(), payload: message_request.payload.clone(), message_id: message_id.clone(), }; let serialized_message = serde_json::to_string(&message_to_device)?; let publish_result = self .amqp_channel .basic_publish( "", &message_request.device_id, BasicPublishOptions::default(), serialized_message.as_bytes(), BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY), ) .await; if let Err(publish_error) = publish_result { self .db_client .delete_message(&self.device_info.device_id, &message_id) .await .expect("Error deleting message"); return Err(SessionError::AmqpError(publish_error)); } Ok(()) } pub async fn handle_message_to_tunnelbroker( &self, message_to_tunnelbroker: &MessageToTunnelbroker, ) -> Result<(), SessionError> { match message_to_tunnelbroker { MessageToTunnelbroker::SetDeviceToken(token) => { self .db_client .set_device_token( &self.device_info.device_id, &token.device_token, None, ) .await?; } MessageToTunnelbroker::SetDeviceTokenWithPlatform( token_with_platform, ) => { self .db_client .set_device_token( &self.device_info.device_id, &token_with_platform.device_token, Some(token_with_platform.platform.clone()), ) .await?; } } Ok(()) } pub async fn handle_websocket_frame_from_device( &mut self, msg: String, ) -> Option { let Ok(serialized_message) = serde_json::from_str::(&msg) else { return Some(MessageSentStatus::SerializationError(msg)); }; match serialized_message { DeviceToTunnelbrokerMessage::Heartbeat(Heartbeat {}) => { trace!("Received heartbeat from: {}", self.device_info.device_id); None } DeviceToTunnelbrokerMessage::MessageReceiveConfirmation(confirmation) => { for message_id in confirmation.message_ids { if let Err(e) = self .db_client .delete_message(&self.device_info.device_id, &message_id) .await { error!("Failed to delete message: {}:", e); } } None } DeviceToTunnelbrokerMessage::MessageToDeviceRequest(message_request) => { // unauthenticated clients cannot send messages if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text message. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received message for {}", message_request.device_id); let result = self.handle_message_to_device(&message_request).await; Some(self.get_message_to_device_status( &message_request.client_message_id, result, )) } DeviceToTunnelbrokerMessage::MessageToTunnelbrokerRequest( message_request, ) => { // unauthenticated clients cannot send messages if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text message. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received message for Tunnelbroker"); let Ok(message_to_tunnelbroker) = serde_json::from_str(&message_request.payload) else { return Some(MessageSentStatus::SerializationError( message_request.payload, )); }; let result = self .handle_message_to_tunnelbroker(&message_to_tunnelbroker) .await; Some(self.get_message_to_device_status( &message_request.client_message_id, result, )) } DeviceToTunnelbrokerMessage::APNsNotif(notif) => { // unauthenticated clients cannot send notifs if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text notif. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received APNs notif for {}", notif.device_id); let Ok(headers) = serde_json::from_str::(¬if.headers) else { return Some(MessageSentStatus::SerializationError(notif.headers)); }; let device_token = match self .get_device_token(notif.device_id.clone(), NotifClientType::APNs) .await { Ok(token) => token, Err(e) => { return Some( self .get_message_to_device_status(¬if.client_message_id, Err(e)), ) } }; let apns_notif = APNsNotif { device_token: device_token.clone(), headers, payload: notif.payload, }; if let Some(apns) = self.notif_client.apns.clone() { let response = apns.send(apns_notif).await; if let Err(apns::error::Error::ResponseError(body)) = &response { if matches!( body.reason, ErrorReason::BadDeviceToken | ErrorReason::Unregistered | ErrorReason::ExpiredToken ) { if let Err(e) = self .invalidate_device_token(notif.device_id, device_token.clone()) .await { error!( "Error invalidating device token {}: {:?}", device_token, e ); }; } } return Some( self .get_message_to_device_status(¬if.client_message_id, response), ); } Some(self.get_message_to_device_status( ¬if.client_message_id, Err(SessionError::MissingAPNsClient), )) } DeviceToTunnelbrokerMessage::FCMNotif(notif) => { // unauthenticated clients cannot send notifs if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text notif. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received FCM notif for {}", notif.device_id); let Some(priority) = AndroidMessagePriority::from_str(¬if.priority) else { return Some(MessageSentStatus::SerializationError(notif.priority)); }; let Ok(data) = serde_json::from_str(¬if.data) else { return Some(MessageSentStatus::SerializationError(notif.data)); }; let device_token = match self .get_device_token(notif.device_id.clone(), NotifClientType::FCM) .await { Ok(token) => token, Err(e) => { return Some( self .get_message_to_device_status(¬if.client_message_id, Err(e)), ) } }; let fcm_message = FCMMessage { data, token: device_token.to_string(), android: AndroidConfig { priority }, }; if let Some(fcm) = self.notif_client.fcm.clone() { let result = fcm.send(fcm_message).await; if let Err(NotifsFCMError(fcm_error)) = &result { if matches!( fcm_error, FCMErrorResponse::Unregistered | FCMErrorResponse::InvalidArgument(_) ) { if let Err(e) = self .invalidate_device_token(notif.device_id, device_token.clone()) .await { error!( "Error invalidating device token {}: {:?}", device_token, e ); }; } } return Some( self.get_message_to_device_status(¬if.client_message_id, result), ); } Some(self.get_message_to_device_status( ¬if.client_message_id, Err(SessionError::MissingFCMClient), )) } DeviceToTunnelbrokerMessage::WebPushNotif(notif) => { // unauthenticated clients cannot send notifs if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send web push notif. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received WebPush notif for {}", notif.device_id); let Some(web_push_client) = self.notif_client.web_push.clone() else { return Some(self.get_message_to_device_status( ¬if.client_message_id, Err(SessionError::MissingWebPushClient), )); }; let device_token = match self .get_device_token(notif.device_id.clone(), NotifClientType::WebPush) .await { Ok(token) => token, Err(e) => { return Some( self .get_message_to_device_status(¬if.client_message_id, Err(e)), ) } }; let web_push_notif = WebPushNotif { device_token: device_token.clone(), payload: notif.payload, }; let result = web_push_client.send(web_push_notif).await; if let Err(NotifsWebPushError(web_push_error)) = &result { if matches!( web_push_error, WebPushError::EndpointNotValid | WebPushError::EndpointNotFound ) { if let Err(e) = self .invalidate_device_token(notif.device_id, device_token.clone()) .await { error!( "Error invalidating device token {}: {:?}", device_token, e ); }; } } Some( self.get_message_to_device_status(¬if.client_message_id, result), ) } + DeviceToTunnelbrokerMessage::WNSNotif(notif) => { + if !self.device_info.is_authenticated { + debug!( + "Unauthenticated device {} tried to send WNS notif. Aborting.", + self.device_info.device_id + ); + return Some(MessageSentStatus::Unauthenticated); + } + debug!("Received WNS notif for {}", notif.device_id); + + let Some(wns_client) = self.notif_client.wns.clone() else { + return Some(self.get_message_to_device_status( + ¬if.client_message_id, + Err(SessionError::MissingWNSClient), + )); + }; + + let device_token = match self + .get_device_token(notif.device_id, NotifClientType::WNS) + .await + { + Ok(token) => token, + Err(e) => { + return Some( + self + .get_message_to_device_status(¬if.client_message_id, Err(e)), + ) + } + }; + + let wns_notif = WNSNotif { + device_token, + payload: notif.payload, + }; + + let result = wns_client.send(wns_notif).await; + Some( + self.get_message_to_device_status(¬if.client_message_id, result), + ) + } _ => { error!("Client sent invalid message type"); Some(MessageSentStatus::InvalidRequest) } } } pub async fn next_amqp_message( &mut self, ) -> Option> { self.amqp_consumer.next().await } pub async fn send_message_to_device(&mut self, message: Message) { if let Err(e) = self.tx.send(message).await { error!("Failed to send message to device: {}", e); } } // Release WebSocket and remove from active connections pub async fn close(&mut self) { if let Err(e) = self.tx.close().await { debug!("Failed to close WebSocket session: {}", e); } if let Err(e) = self .amqp_channel .basic_cancel( self.amqp_consumer.tag().as_str(), BasicCancelOptions::default(), ) .await { error!("Failed to cancel consumer: {}", e); } if let Err(e) = self .amqp_channel .queue_delete( self.device_info.device_id.as_str(), QueueDeleteOptions::default(), ) .await { error!("Failed to delete queue: {}", e); } } pub fn get_message_to_device_status( &mut self, client_message_id: &str, result: Result<(), E>, ) -> MessageSentStatus where E: std::error::Error, { match result { Ok(()) => MessageSentStatus::Success(client_message_id.to_string()), Err(err) => MessageSentStatus::Error(Failure { id: client_message_id.to_string(), error: err.to_string(), }), } } async fn get_device_token( &self, device_id: String, client: NotifClientType, ) -> Result { let db_token = self .db_client .get_device_token(&device_id) .await .map_err(SessionError::DatabaseError)?; match db_token { Some(token) => { if let Some(platform) = token.platform { if !client.supported_platform(platform) { return Err(SessionError::InvalidNotifProvider); } } if token.token_invalid { Err(SessionError::InvalidDeviceToken) } else { Ok(token.device_token) } } None => Err(SessionError::MissingDeviceToken), } } async fn invalidate_device_token( &self, device_id: String, invalidated_token: String, ) -> Result<(), SessionError> { let bad_device_token_message = BadDeviceToken { invalidated_token }; let payload = serde_json::to_string(&bad_device_token_message)?; let message_request = MessageToDeviceRequest { client_message_id: uuid::Uuid::new_v4().to_string(), device_id: device_id.to_string(), payload, }; self.handle_message_to_device(&message_request).await?; self .db_client .mark_device_token_as_invalid(&device_id) .await .map_err(SessionError::DatabaseError)?; Ok(()) } } diff --git a/shared/tunnelbroker_messages/src/messages/mod.rs b/shared/tunnelbroker_messages/src/messages/mod.rs index 31ebc1128..e07b99e43 100644 --- a/shared/tunnelbroker_messages/src/messages/mod.rs +++ b/shared/tunnelbroker_messages/src/messages/mod.rs @@ -1,83 +1,84 @@ //! Messages sent between Tunnelbroker and a device. pub mod bad_device_token; pub mod device_list_updated; pub mod keys; pub mod message_receive_confirmation; pub mod message_to_device; pub mod message_to_device_request; pub mod message_to_device_request_status; pub mod message_to_tunnelbroker; pub mod message_to_tunnelbroker_request; pub mod notif; pub mod session; pub use device_list_updated::*; pub use keys::*; pub use message_receive_confirmation::*; pub use message_to_device::*; pub use message_to_device_request::*; pub use message_to_device_request_status::*; pub use message_to_tunnelbroker::*; pub use message_to_tunnelbroker_request::*; pub use session::*; pub use websocket_messages::{ ConnectionInitializationResponse, ConnectionInitializationStatus, Heartbeat, }; use crate::bad_device_token::BadDeviceToken; use crate::notif::*; use serde::{Deserialize, Serialize}; // This file defines types and validation for messages exchanged // with the Tunnelbroker. The definitions in this file should remain in sync // with the structures defined in the corresponding // JavaScript file at `lib/types/tunnelbroker/messages.js`. // If you edit the definitions in one file, // please make sure to update the corresponding definitions in the other. // Messages sent from Device to Tunnelbroker. #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum DeviceToTunnelbrokerMessage { ConnectionInitializationMessage(ConnectionInitializationMessage), AnonymousInitializationMessage(AnonymousInitializationMessage), APNsNotif(APNsNotif), FCMNotif(FCMNotif), WebPushNotif(WebPushNotif), + WNSNotif(WNSNotif), MessageToDeviceRequest(MessageToDeviceRequest), MessageReceiveConfirmation(MessageReceiveConfirmation), MessageToTunnelbrokerRequest(MessageToTunnelbrokerRequest), Heartbeat(Heartbeat), } // Messages sent from Tunnelbroker to Device. #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum TunnelbrokerToDeviceMessage { ConnectionInitializationResponse(ConnectionInitializationResponse), DeviceToTunnelbrokerRequestStatus(DeviceToTunnelbrokerRequestStatus), MessageToDevice(MessageToDevice), BadDeviceToken(BadDeviceToken), Heartbeat(Heartbeat), } // Messages sent from Services (e.g. Identity) to Device. // This type is sent to a Device as a payload of MessageToDevice. #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum ServiceToDeviceMessages { RefreshKeysRequest(RefreshKeyRequest), IdentityDeviceListUpdated(IdentityDeviceListUpdated), BadDeviceToken(BadDeviceToken), } // Messages sent from Device to Tunnelbroker which Tunnelbroker itself should handle. // This type is sent to a Tunnelbroker as a payload of MessageToTunnelbrokerRequest. #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum MessageToTunnelbroker { SetDeviceTokenWithPlatform(SetDeviceTokenWithPlatform), SetDeviceToken(SetDeviceToken), } diff --git a/shared/tunnelbroker_messages/src/messages/notif.rs b/shared/tunnelbroker_messages/src/messages/notif.rs index 383159874..432a870b7 100644 --- a/shared/tunnelbroker_messages/src/messages/notif.rs +++ b/shared/tunnelbroker_messages/src/messages/notif.rs @@ -1,37 +1,48 @@ use serde::{Deserialize, Serialize}; use util_macros::TagAwareDeserialize; /// APNs notif built on client. #[derive(Serialize, Deserialize, PartialEq, Debug)] #[serde(tag = "type", rename_all = "camelCase")] pub struct APNsNotif { pub headers: String, #[serde(rename = "clientMessageID")] pub client_message_id: String, #[serde(rename = "deviceID")] pub device_id: String, pub payload: String, } /// FCM notif built on client. #[derive(Serialize, Deserialize, PartialEq, Debug)] #[serde(tag = "type", rename_all = "camelCase")] pub struct FCMNotif { #[serde(rename = "clientMessageID")] pub client_message_id: String, #[serde(rename = "deviceID")] pub device_id: String, pub data: String, pub priority: String, } /// WebPush notif built on client. #[derive(Serialize, Deserialize, TagAwareDeserialize, PartialEq, Debug)] #[serde(tag = "type", remote = "Self", rename_all = "camelCase")] pub struct WebPushNotif { #[serde(rename = "clientMessageID")] pub client_message_id: String, #[serde(rename = "deviceID")] pub device_id: String, pub payload: String, } + +/// WNS notif built on client. +#[derive(Serialize, Deserialize, TagAwareDeserialize, PartialEq, Debug)] +#[serde(tag = "type", remote = "Self", rename_all = "camelCase")] +pub struct WNSNotif { + #[serde(rename = "clientMessageID")] + pub client_message_id: String, + #[serde(rename = "deviceID")] + pub device_id: String, + pub payload: String, +}