diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -21,10 +21,12 @@ use notifs::fcm::error::Error::FCMError as NotifsFCMError; use notifs::web_push::error::Error::WebPush as NotifsWebPushError; use notifs::wns::error::Error::WNSNotification as NotifsWNSError; +use reqwest::Url; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tracing::{debug, error, info, trace}; use tunnelbroker_messages::bad_device_token::BadDeviceToken; +use tunnelbroker_messages::Platform; use tunnelbroker_messages::{ message_to_device_request_status::Failure, message_to_device_request_status::MessageSentStatus, session::DeviceTypes, @@ -304,6 +306,33 @@ MessageToTunnelbroker::SetDeviceTokenWithPlatform( token_with_platform, ) => { + if matches!(token_with_platform.platform, Platform::Windows) { + let Ok(parsed_url) = Url::parse(&token_with_platform.device_token) + else { + debug!( + device_token = &token_with_platform.device_token, + device_id = &self.device_info.device_id, + "Device token could not be parsed as a URL" + ); + return Err(SessionError::InvalidDeviceToken); + }; + let Some(domain) = parsed_url.domain() else { + debug!( + device_token = &token_with_platform.device_token, + device_id = &self.device_info.device_id, + "Domain missing from device token URL" + ); + return Err(SessionError::InvalidDeviceToken); + }; + if !domain.ends_with("notify.windows.com") { + debug!( + device_token = &token_with_platform.device_token, + device_id = &self.device_info.device_id, + "Device token URL contains invalid domain" + ); + return Err(SessionError::InvalidDeviceToken); + } + } self .db_client .set_device_token(