diff --git a/services/tunnelbroker/src/constants.rs b/services/tunnelbroker/src/constants.rs index 146f572bf..f4c47c1af 100644 --- a/services/tunnelbroker/src/constants.rs +++ b/services/tunnelbroker/src/constants.rs @@ -1,53 +1,54 @@ use tokio::time::Duration; pub const GRPC_TX_QUEUE_SIZE: usize = 32; pub const GRPC_SERVER_PORT: u16 = 50051; pub const GRPC_KEEP_ALIVE_PING_INTERVAL: Duration = Duration::from_secs(3); pub const GRPC_KEEP_ALIVE_PING_TIMEOUT: Duration = Duration::from_secs(10); pub const SOCKET_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(3); pub const MAX_RMQ_MSG_PRIORITY: u8 = 10; pub const DDB_RMQ_MSG_PRIORITY: u8 = 10; pub const CLIENT_RMQ_MSG_PRIORITY: u8 = 1; pub const RMQ_CONSUMER_TAG: &str = "tunnelbroker"; pub const ENV_APNS_CONFIG: &str = "APNS_CONFIG"; pub const ENV_FCM_CONFIG: &str = "FCM_CONFIG"; pub const LOG_LEVEL_ENV_VAR: &str = tracing_subscriber::filter::EnvFilter::DEFAULT_ENV; +pub const FCM_ACCESS_TOKEN_GENERATION_THRESHOLD: u64 = 5 * 60; pub mod dynamodb { // This table holds messages which could not be immediately delivered to // a device. // // - (primary key) = (deviceID: Partition Key, createdAt: Sort Key) // - deviceID: The public key of a device's olm identity key // - payload: Message to be delivered. See shared/tunnelbroker_messages. // - messageID = [createdAt]#[clientMessageID] // - createdAd: UNIX timestamp of when the item was inserted. // Timestamp is needed to order the messages correctly to the device. // Timestamp format is ISO 8601 to handle lexicographical sorting. // - clientMessageID: Message ID generated on client using UUID Version 4. pub mod undelivered_messages { pub const TABLE_NAME: &str = "tunnelbroker-undelivered-messages"; pub const PARTITION_KEY: &str = "deviceID"; pub const DEVICE_ID: &str = "deviceID"; pub const PAYLOAD: &str = "payload"; pub const MESSAGE_ID: &str = "messageID"; pub const SORT_KEY: &str = "messageID"; } // This table holds a device token associated with a device. // // - (primary key) = (deviceID: Partition Key) // - deviceID: The public key of a device's olm identity key // - deviceToken: Token to push services uploaded by device. pub mod device_tokens { pub const TABLE_NAME: &str = "tunnelbroker-device-tokens"; pub const PARTITION_KEY: &str = "deviceID"; pub const DEVICE_ID: &str = "deviceID"; pub const DEVICE_TOKEN: &str = "deviceToken"; pub const DEVICE_TOKEN_INDEX_NAME: &str = "deviceToken-index"; } } diff --git a/services/tunnelbroker/src/notifs/fcm/token.rs b/services/tunnelbroker/src/notifs/fcm/token.rs index 835ba6473..61282fe54 100644 --- a/services/tunnelbroker/src/notifs/fcm/token.rs +++ b/services/tunnelbroker/src/notifs/fcm/token.rs @@ -1,92 +1,132 @@ +use crate::constants::FCM_ACCESS_TOKEN_GENERATION_THRESHOLD; use crate::notifs::fcm::config::FCMConfig; use crate::notifs::fcm::error::Error; use crate::notifs::fcm::error::Error::FCMTokenNotInitialized; use jsonwebtoken::{Algorithm, EncodingKey, Header}; use serde::Deserialize; use serde_json::json; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; use tracing::debug; #[derive(Debug, Clone, Deserialize)] struct FCMAccessTokenResponse { access_token: String, token_type: String, expires_in: u64, } #[derive(Debug, Clone, Deserialize)] struct FCMAccessToken { access_token: String, token_type: String, expiration_time: u64, } #[derive(Debug, Clone)] pub struct FCMToken { token: Arc>>, config: FCMConfig, validity_duration: Duration, } impl FCMToken { pub fn new(config: &FCMConfig, token_ttl: Duration) -> Result { Ok(FCMToken { token: Arc::new(RwLock::new(None)), config: config.clone(), validity_duration: token_ttl, }) } pub async fn get_auth_bearer(&self) -> Result { + if self.fcm_token_needs_generation().await { + self.generate_fcm_token().await?; + } + let bearer = self.token.read().await; match &*bearer { Some(token) => Ok(format!("{} {}", token.token_type, token.access_token)), None => Err(FCMTokenNotInitialized), } } fn get_jwt_token(&self, created_at: u64) -> Result { let exp = created_at + self.validity_duration.as_secs(); let payload = json!({ // The email address of the service account. "iss": self.config.client_email, // A descriptor of the intended target of the assertion. "aud": self.config.token_uri, // The time the assertion was issued. "iat": created_at, // The expiration time of the assertion. // This value has a maximum of 1 hour after the issued time. "exp": exp, // A space-delimited list of the permissions that the application // requests. "scope": "https://www.googleapis.com/auth/firebase.messaging", }); debug!("Encoding JWT token for FCM, created at: {}", created_at); let header = Header::new(Algorithm::RS256); let encoding_key = EncodingKey::from_rsa_pem(self.config.private_key.as_bytes()).unwrap(); let token = jsonwebtoken::encode(&header, &payload, &encoding_key)?; Ok(token) } async fn get_fcm_access_token( &self, jwt_token: String, ) -> Result { let response = reqwest::Client::new() .post(self.config.token_uri.clone()) .form(&[ ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), ("assertion", &jwt_token), ]) .send() .await?; let access_token = response.json::().await?; Ok(access_token) } + + async fn fcm_token_needs_generation(&self) -> bool { + let token = self.token.read().await; + match &*token { + None => true, + Some(token) => { + get_time() - FCM_ACCESS_TOKEN_GENERATION_THRESHOLD + >= token.expiration_time + } + } + } + async fn generate_fcm_token(&self) -> Result<(), Error> { + debug!("Generating FCM access token"); + let mut token = self.token.write().await; + + let created_at = get_time(); + let new_jwt_token = self.get_jwt_token(created_at)?; + let access_token_response = + self.get_fcm_access_token(new_jwt_token).await?; + + *token = Some(FCMAccessToken { + access_token: access_token_response.access_token, + token_type: access_token_response.token_type, + expiration_time: created_at + access_token_response.expires_in, + }); + + Ok(()) + } +} + +fn get_time() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() }