diff --git a/services/tunnelbroker/src/constants.rs b/services/tunnelbroker/src/constants.rs index 7251938bb..5b9b7c58d 100644 --- a/services/tunnelbroker/src/constants.rs +++ b/services/tunnelbroker/src/constants.rs @@ -1,61 +1,63 @@ use tokio::time::Duration; pub const GRPC_TX_QUEUE_SIZE: usize = 32; pub const GRPC_SERVER_PORT: u16 = 50051; pub const GRPC_KEEP_ALIVE_PING_INTERVAL: Duration = Duration::from_secs(3); pub const GRPC_KEEP_ALIVE_PING_TIMEOUT: Duration = Duration::from_secs(10); pub const SOCKET_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(3); pub const MAX_RMQ_MSG_PRIORITY: u8 = 10; pub const DDB_RMQ_MSG_PRIORITY: u8 = 10; pub const CLIENT_RMQ_MSG_PRIORITY: u8 = 1; pub const RMQ_CONSUMER_TAG: &str = "tunnelbroker"; pub const WS_SESSION_CLOSE_AMQP_MSG: &str = "SessionClose"; pub const ENV_APNS_CONFIG: &str = "APNS_CONFIG"; pub const ENV_FCM_CONFIG: &str = "FCM_CONFIG"; pub const ENV_WEB_PUSH_CONFIG: &str = "WEB_PUSH_CONFIG"; pub const ENV_WNS_CONFIG: &str = "WNS_CONFIG"; pub const LOG_LEVEL_ENV_VAR: &str = tracing_subscriber::filter::EnvFilter::DEFAULT_ENV; pub const FCM_ACCESS_TOKEN_GENERATION_THRESHOLD: u64 = 5 * 60; +pub const PUSH_SERVICE_REQUEST_TIMEOUT: Duration = Duration::from_secs(8); + pub mod dynamodb { // This table holds messages which could not be immediately delivered to // a device. // // - (primary key) = (deviceID: Partition Key, createdAt: Sort Key) // - deviceID: The public key of a device's olm identity key // - payload: Message to be delivered. See shared/tunnelbroker_messages. // - messageID = [createdAt]#[clientMessageID] // - createdAd: UNIX timestamp of when the item was inserted. // Timestamp is needed to order the messages correctly to the device. // Timestamp format is ISO 8601 to handle lexicographical sorting. // - clientMessageID: Message ID generated on client using UUID Version 4. pub mod undelivered_messages { pub const TABLE_NAME: &str = "tunnelbroker-undelivered-messages"; pub const PARTITION_KEY: &str = "deviceID"; pub const DEVICE_ID: &str = "deviceID"; pub const PAYLOAD: &str = "payload"; pub const MESSAGE_ID: &str = "messageID"; pub const SORT_KEY: &str = "messageID"; } // This table holds a device token associated with a device. // // - (primary key) = (deviceID: Partition Key) // - deviceID: The public key of a device's olm identity key. // - deviceToken: Token to push services uploaded by device. // - tokenInvalid: Information is token is invalid. pub mod device_tokens { pub const TABLE_NAME: &str = "tunnelbroker-device-tokens"; pub const PARTITION_KEY: &str = "deviceID"; pub const DEVICE_ID: &str = "deviceID"; pub const DEVICE_TOKEN: &str = "deviceToken"; pub const TOKEN_INVALID: &str = "tokenInvalid"; pub const PLATFORM: &str = "platform"; pub const DEVICE_TOKEN_INDEX_NAME: &str = "deviceToken-index"; } } diff --git a/services/tunnelbroker/src/notifs/apns/mod.rs b/services/tunnelbroker/src/notifs/apns/mod.rs index ccf39b6bb..e4580cafa 100644 --- a/services/tunnelbroker/src/notifs/apns/mod.rs +++ b/services/tunnelbroker/src/notifs/apns/mod.rs @@ -1,137 +1,139 @@ +use crate::constants::PUSH_SERVICE_REQUEST_TIMEOUT; use crate::notifs::apns::config::APNsConfig; use crate::notifs::apns::error::Error::ResponseError; use crate::notifs::apns::headers::{NotificationHeaders, PushType}; use crate::notifs::apns::response::ErrorBody; use crate::notifs::apns::token::APNsToken; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION}; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use std::time::Duration; use tracing::debug; pub mod config; pub mod error; pub(crate) mod headers; pub mod response; pub mod token; #[derive(Clone)] pub struct APNsClient { http2_client: reqwest::Client, token: APNsToken, is_prod: bool, } #[derive(Serialize, Deserialize)] pub struct APNsNotif { pub device_token: String, pub headers: NotificationHeaders, pub payload: String, } impl APNsClient { pub fn new(config: &APNsConfig) -> Result { let token_ttl = Duration::from_secs(60 * 55); let token = APNsToken::new(config, token_ttl)?; let http2_client = reqwest::Client::builder() .http2_prior_knowledge() .http2_keep_alive_interval(Some(Duration::from_secs(5))) .http2_keep_alive_while_idle(true) + .timeout(PUSH_SERVICE_REQUEST_TIMEOUT) .build()?; Ok(APNsClient { http2_client, token, is_prod: config.production, }) } async fn build_headers( &self, notif_headers: NotificationHeaders, ) -> Result { let mut headers = HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, HeaderValue::from_static("application/json"), ); let bearer = self.token.get_bearer().await?; let token = format!("bearer {bearer}"); headers.insert(AUTHORIZATION, HeaderValue::from_str(&token)?); if let Some(apns_topic) = ¬if_headers.apns_topic { headers.insert("apns-topic", HeaderValue::from_str(apns_topic)?); } if let Some(apns_id) = ¬if_headers.apns_id { headers.insert("apns-id", HeaderValue::from_str(apns_id)?); } if let Some(push_type) = ¬if_headers.apns_push_type { let push_type_str = match push_type { PushType::Alert => "alert", PushType::Background => "background", PushType::Location => "location", PushType::Voip => "voip", PushType::Complication => "complication", PushType::FileProvider => "fileprovider", PushType::Mdm => "mdm", PushType::LiveActivity => "live", PushType::PushToTalk => "pushtotalk", }; headers.insert("apns-push-type", HeaderValue::from_static(push_type_str)); } if let Some(expiration) = notif_headers.apns_expiration { headers.insert("apns-expiration", HeaderValue::from(expiration)); } if let Some(priority) = notif_headers.apns_priority { headers.insert("apns-priority", HeaderValue::from(priority)); } if let Some(collapse_id) = ¬if_headers.apns_collapse_id { headers.insert("apns-collapse-id", HeaderValue::from_str(collapse_id)?); } Ok(headers) } fn get_endpoint(&self) -> &'static str { if self.is_prod { return "api.push.apple.com"; } "api.development.push.apple.com" } pub async fn send(&self, notif: APNsNotif) -> Result<(), error::Error> { debug!("Sending APNs notif to {}", notif.device_token); let headers = self.build_headers(notif.headers.clone()).await?; let url = format!( "https://{}/3/device/{}", self.get_endpoint(), notif.device_token ); let response = self .http2_client .post(url) .headers(headers.clone()) .body(notif.payload) .send() .await?; match response.status() { StatusCode::OK => Ok(()), _ => { let error_body: ErrorBody = response.json().await?; Err(ResponseError(error_body)) } } } } diff --git a/services/tunnelbroker/src/notifs/fcm/mod.rs b/services/tunnelbroker/src/notifs/fcm/mod.rs index 421d37eb9..416e4481d 100644 --- a/services/tunnelbroker/src/notifs/fcm/mod.rs +++ b/services/tunnelbroker/src/notifs/fcm/mod.rs @@ -1,88 +1,91 @@ +use crate::constants::PUSH_SERVICE_REQUEST_TIMEOUT; use crate::notifs::fcm::config::FCMConfig; use crate::notifs::fcm::error::Error::FCMError; use crate::notifs::fcm::firebase_message::{FCMMessage, FCMMessageWrapper}; use crate::notifs::fcm::response::FCMErrorResponse; use crate::notifs::fcm::token::FCMToken; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION}; use reqwest::StatusCode; use std::time::Duration; use tracing::{debug, error}; pub mod config; pub mod error; pub mod firebase_message; pub mod response; mod token; #[derive(Clone)] pub struct FCMClient { http_client: reqwest::Client, config: FCMConfig, token: FCMToken, } impl FCMClient { pub fn new(config: &FCMConfig) -> Result { - let http_client = reqwest::Client::builder().build()?; + let http_client = reqwest::Client::builder() + .timeout(PUSH_SERVICE_REQUEST_TIMEOUT) + .build()?; // Token must be a short-lived token (60 minutes) and in a reasonable // timeframe. let token_ttl = Duration::from_secs(60 * 55); let token = FCMToken::new(&config.clone(), token_ttl)?; Ok(FCMClient { http_client, config: config.clone(), token, }) } pub async fn send(&self, message: FCMMessage) -> Result<(), error::Error> { let token = message.token.clone(); debug!("Sending FCM notif to {}", token); let mut headers = HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, HeaderValue::from_static("application/json"), ); let bearer = self.token.get_auth_bearer().await?; headers.insert(AUTHORIZATION, HeaderValue::from_str(&bearer)?); let url = format!( "https://fcm.googleapis.com/v1/projects/{}/messages:send", self.config.project_id ); let msg_wrapper = FCMMessageWrapper { message }; let payload = serde_json::to_string(&msg_wrapper).unwrap(); let response = self .http_client .post(&url) .headers(headers) .body(payload) .send() .await?; match response.status() { StatusCode::OK => { debug!("Successfully sent FCM notif to {}", token); Ok(()) } error_status => { let body = response .text() .await .unwrap_or_else(|error| format!("Error occurred: {}", error)); error!( "Failed sending FCM notification to: {}. Status: {}. Body: {}", token, error_status, body ); let fcm_error = FCMErrorResponse::from_status(error_status, body); Err(FCMError(fcm_error)) } } } } diff --git a/services/tunnelbroker/src/notifs/web_push/mod.rs b/services/tunnelbroker/src/notifs/web_push/mod.rs index af2df4be4..21d4b5ef6 100644 --- a/services/tunnelbroker/src/notifs/web_push/mod.rs +++ b/services/tunnelbroker/src/notifs/web_push/mod.rs @@ -1,60 +1,67 @@ +use crate::constants::PUSH_SERVICE_REQUEST_TIMEOUT; use serde::{Deserialize, Serialize}; use web_push::{ ContentEncoding, HyperWebPushClient, SubscriptionInfo, VapidSignatureBuilder, WebPushMessageBuilder, }; use web_push::{PartialVapidSignatureBuilder, WebPushClient as _}; use crate::notifs::web_push::config::WebPushConfig; pub mod config; pub mod error; #[derive(Serialize, Deserialize)] pub struct WebPushNotif { /// Device token for web is a JSON-encoded [`SubscriptionInfo`]. pub device_token: String, pub payload: String, } #[derive(Clone)] pub struct WebPushClient { _config: WebPushConfig, inner_client: HyperWebPushClient, signature_builder: PartialVapidSignatureBuilder, } impl WebPushClient { pub fn new(config: &WebPushConfig) -> Result { let inner_client = HyperWebPushClient::new(); let signature_builder = VapidSignatureBuilder::from_base64_no_sub( &config.private_key, web_push::URL_SAFE_NO_PAD, )?; Ok(WebPushClient { _config: config.clone(), inner_client, signature_builder, }) } pub async fn send(&self, notif: WebPushNotif) -> Result<(), error::Error> { let subscription = serde_json::from_str::(¬if.device_token)?; let vapid_signature = self .signature_builder .clone() .add_sub_info(&subscription) .build()?; let mut builder = WebPushMessageBuilder::new(&subscription); builder.set_payload(ContentEncoding::Aes128Gcm, notif.payload.as_bytes()); builder.set_vapid_signature(vapid_signature); let message = builder.build()?; - self.inner_client.send(message).await?; + let response_future = self.inner_client.send(message); + + tokio::time::timeout(PUSH_SERVICE_REQUEST_TIMEOUT, response_future) + .await + .map_err(|err| { + error::Error::WebPush(web_push::WebPushError::Other(err.to_string())) + })??; Ok(()) } } diff --git a/services/tunnelbroker/src/notifs/wns/mod.rs b/services/tunnelbroker/src/notifs/wns/mod.rs index dfa73e6e5..30fdf3e14 100644 --- a/services/tunnelbroker/src/notifs/wns/mod.rs +++ b/services/tunnelbroker/src/notifs/wns/mod.rs @@ -1,143 +1,146 @@ +use crate::constants::PUSH_SERVICE_REQUEST_TIMEOUT; use crate::notifs::wns::config::WNSConfig; use error::WNSTokenError; use reqwest::StatusCode; use response::WNSErrorResponse; use std::{ sync::{Arc, RwLock}, time::{Duration, SystemTime}, }; pub mod config; pub mod error; pub mod response; #[derive(Debug, Clone)] pub struct WNSAccessToken { token: String, expires: SystemTime, } #[derive(Debug, Clone)] pub struct WNSNotif { pub device_token: String, pub payload: String, } #[derive(Clone)] pub struct WNSClient { http_client: reqwest::Client, config: WNSConfig, access_token: Arc>>, } impl WNSClient { pub fn new(config: &WNSConfig) -> Result { - let http_client = reqwest::Client::builder().build()?; + let http_client = reqwest::Client::builder() + .timeout(PUSH_SERVICE_REQUEST_TIMEOUT) + .build()?; Ok(WNSClient { http_client, config: config.clone(), access_token: Arc::new(RwLock::new(None)), }) } pub async fn send(&self, notif: WNSNotif) -> Result<(), error::Error> { let wns_access_token = self.get_wns_token().await?; let url = notif.device_token; // Send the notification let response = self .http_client .post(&url) .header("Content-Type", "application/octet-stream") .header("X-WNS-Type", "wns/raw") .bearer_auth(wns_access_token) .body(notif.payload) .send() .await?; match response.status() { StatusCode::OK => { tracing::debug!("Successfully sent WNS notif to {}", &url); Ok(()) } error_status => { let body = response .text() .await .unwrap_or_else(|error| format!("Error occurred: {}", error)); tracing::error!( "Failed sending WNS notification to: {}. Status: {}. Body: {}", &url, error_status, body ); let wns_error = WNSErrorResponse::from_status(error_status, body); Err(error::Error::WNSNotification(wns_error)) } } } pub async fn get_wns_token(&self) -> Result { const EXPIRY_WINDOW: Duration = Duration::from_secs(10); { let read_guard = self .access_token .read() .map_err(|_| error::Error::ReadLock)?; if let Some(ref token) = *read_guard { if token.expires >= SystemTime::now() - EXPIRY_WINDOW { return Ok(token.token.clone()); } } } let params = [ ("grant_type", "client_credentials"), ("client_id", &self.config.app_id), ("client_secret", &self.config.secret), ("scope", "https://wns.windows.com/.default"), ]; let url = format!( "https://login.microsoftonline.com/{}/oauth2/v2.0/token", self.config.tenant_id ); let response = self.http_client.post(&url).form(¶ms).send().await?; if !response.status().is_success() { let status = response.status().to_string(); let body = response .text() .await .unwrap_or_else(|_| String::from("")); tracing::error!(status, "Failure when getting the WNS token: {}", body); return Err(error::Error::WNSToken(WNSTokenError::Unknown(status))); } let response_json: serde_json::Value = response.json().await?; let token = response_json["access_token"] .as_str() .ok_or(error::Error::WNSToken(WNSTokenError::TokenNotFound))? .to_string(); let expires_in = response_json["expires_in"] .as_u64() .ok_or(error::Error::WNSToken(WNSTokenError::ExpiryNotFound))?; let expires = SystemTime::now() + Duration::from_secs(expires_in); { let mut write_guard = self .access_token .write() .map_err(|_| error::Error::WriteLock)?; *write_guard = Some(WNSAccessToken { token: token.clone(), expires, }); } Ok(token) } }