diff --git a/services/tunnelbroker/src/notifs/fcm/mod.rs b/services/tunnelbroker/src/notifs/fcm/mod.rs index b6f249d25..3060bdd54 100644 --- a/services/tunnelbroker/src/notifs/fcm/mod.rs +++ b/services/tunnelbroker/src/notifs/fcm/mod.rs @@ -1,95 +1,102 @@ use crate::constants::error_types; use crate::constants::PUSH_SERVICE_REQUEST_TIMEOUT; use crate::notifs::fcm::config::FCMConfig; use crate::notifs::fcm::error::Error::FCMError; use crate::notifs::fcm::firebase_message::{FCMMessage, FCMMessageWrapper}; use crate::notifs::fcm::response::FCMErrorResponse; use crate::notifs::fcm::token::FCMToken; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION}; use reqwest::StatusCode; use std::time::Duration; use tracing::{debug, error}; pub mod config; pub mod error; pub mod firebase_message; pub mod response; mod token; #[derive(Clone)] pub struct FCMClient { http_client: reqwest::Client, config: FCMConfig, token: FCMToken, } impl FCMClient { pub fn new(config: &FCMConfig) -> Result { let http_client = reqwest::Client::builder() .timeout(PUSH_SERVICE_REQUEST_TIMEOUT) .build()?; // Token must be a short-lived token (60 minutes) and in a reasonable // timeframe. let token_ttl = Duration::from_secs(60 * 55); let token = FCMToken::new(&config.clone(), token_ttl)?; Ok(FCMClient { http_client, config: config.clone(), token, }) } pub async fn send(&self, message: FCMMessage) -> Result<(), error::Error> { let token = message.token.clone(); debug!("Sending FCM notif to {}", token); - let mut headers = HeaderMap::new(); - headers.insert( - reqwest::header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - - let bearer = self.token.get_auth_bearer().await?; - headers.insert(AUTHORIZATION, HeaderValue::from_str(&bearer)?); - - let url = format!( - "https://fcm.googleapis.com/v1/projects/{}/messages:send", - self.config.project_id - ); - let msg_wrapper = FCMMessageWrapper { message }; let payload = serde_json::to_string(&msg_wrapper).unwrap(); + let mut is_retry = false; - let response = self - .http_client - .post(&url) - .headers(headers) - .body(payload) - .send() - .await?; + loop { + let mut headers = HeaderMap::new(); + headers.insert( + reqwest::header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + let bearer = self.token.get_auth_bearer(is_retry).await?; + headers.insert(AUTHORIZATION, HeaderValue::from_str(&bearer)?); - match response.status() { - StatusCode::OK => { - debug!("Successfully sent FCM notif to {}", token); - Ok(()) - } - error_status => { - let body = response - .text() - .await - .unwrap_or_else(|error| format!("Error occurred: {}", error)); - error!( - errorType = error_types::FCM_ERROR, - "Failed sending FCM notification to: {}. Status: {}. Body: {}", - token, - error_status, - body - ); - let fcm_error = FCMErrorResponse::from_status(error_status, body); - Err(FCMError(fcm_error)) + let url = format!( + "https://fcm.googleapis.com/v1/projects/{}/messages:send", + self.config.project_id + ); + + let response = self + .http_client + .post(&url) + .headers(headers) + .body(payload.clone()) + .send() + .await?; + + match response.status() { + StatusCode::OK => { + debug!("Successfully sent FCM notif to {}", token); + return Ok(()); + } + StatusCode::UNAUTHORIZED if !is_retry => { + is_retry = true; + debug!("Retrying after first 401 to regenerate token."); + continue; + } + error_status => { + let body = response + .text() + .await + .unwrap_or_else(|error| format!("Error occurred: {}", error)); + error!( + errorType = error_types::FCM_ERROR, + "Failed sending FCM notification to: {}. Status: {}. Body: {}", + token, + error_status, + body + ); + let fcm_error = FCMErrorResponse::from_status(error_status, body); + return Err(FCMError(fcm_error)); + } } } } } diff --git a/services/tunnelbroker/src/notifs/fcm/token.rs b/services/tunnelbroker/src/notifs/fcm/token.rs index 61282fe54..e474f7c36 100644 --- a/services/tunnelbroker/src/notifs/fcm/token.rs +++ b/services/tunnelbroker/src/notifs/fcm/token.rs @@ -1,132 +1,135 @@ use crate::constants::FCM_ACCESS_TOKEN_GENERATION_THRESHOLD; use crate::notifs::fcm::config::FCMConfig; use crate::notifs::fcm::error::Error; use crate::notifs::fcm::error::Error::FCMTokenNotInitialized; use jsonwebtoken::{Algorithm, EncodingKey, Header}; use serde::Deserialize; use serde_json::json; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; use tracing::debug; #[derive(Debug, Clone, Deserialize)] struct FCMAccessTokenResponse { access_token: String, token_type: String, expires_in: u64, } #[derive(Debug, Clone, Deserialize)] struct FCMAccessToken { access_token: String, token_type: String, expiration_time: u64, } #[derive(Debug, Clone)] pub struct FCMToken { token: Arc>>, config: FCMConfig, validity_duration: Duration, } impl FCMToken { pub fn new(config: &FCMConfig, token_ttl: Duration) -> Result { Ok(FCMToken { token: Arc::new(RwLock::new(None)), config: config.clone(), validity_duration: token_ttl, }) } - pub async fn get_auth_bearer(&self) -> Result { - if self.fcm_token_needs_generation().await { + pub async fn get_auth_bearer( + &self, + force_regenerate: bool, + ) -> Result { + if force_regenerate || self.fcm_token_needs_generation().await { self.generate_fcm_token().await?; } let bearer = self.token.read().await; match &*bearer { Some(token) => Ok(format!("{} {}", token.token_type, token.access_token)), None => Err(FCMTokenNotInitialized), } } fn get_jwt_token(&self, created_at: u64) -> Result { let exp = created_at + self.validity_duration.as_secs(); let payload = json!({ // The email address of the service account. "iss": self.config.client_email, // A descriptor of the intended target of the assertion. "aud": self.config.token_uri, // The time the assertion was issued. "iat": created_at, // The expiration time of the assertion. // This value has a maximum of 1 hour after the issued time. "exp": exp, // A space-delimited list of the permissions that the application // requests. "scope": "https://www.googleapis.com/auth/firebase.messaging", }); debug!("Encoding JWT token for FCM, created at: {}", created_at); let header = Header::new(Algorithm::RS256); let encoding_key = EncodingKey::from_rsa_pem(self.config.private_key.as_bytes()).unwrap(); let token = jsonwebtoken::encode(&header, &payload, &encoding_key)?; Ok(token) } async fn get_fcm_access_token( &self, jwt_token: String, ) -> Result { let response = reqwest::Client::new() .post(self.config.token_uri.clone()) .form(&[ ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), ("assertion", &jwt_token), ]) .send() .await?; let access_token = response.json::().await?; Ok(access_token) } async fn fcm_token_needs_generation(&self) -> bool { let token = self.token.read().await; match &*token { None => true, Some(token) => { get_time() - FCM_ACCESS_TOKEN_GENERATION_THRESHOLD >= token.expiration_time } } } async fn generate_fcm_token(&self) -> Result<(), Error> { debug!("Generating FCM access token"); let mut token = self.token.write().await; let created_at = get_time(); let new_jwt_token = self.get_jwt_token(created_at)?; let access_token_response = self.get_fcm_access_token(new_jwt_token).await?; *token = Some(FCMAccessToken { access_token: access_token_response.access_token, token_type: access_token_response.token_type, expiration_time: created_at + access_token_response.expires_in, }); Ok(()) } } fn get_time() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs() }