diff --git a/services/tunnelbroker/src/notifs/wns/error.rs b/services/tunnelbroker/src/notifs/wns/error.rs index 94b4763c7..911506bac 100644 --- a/services/tunnelbroker/src/notifs/wns/error.rs +++ b/services/tunnelbroker/src/notifs/wns/error.rs @@ -1,6 +1,15 @@ use derive_more::{Display, Error, From}; #[derive(Debug, From, Display, Error)] pub enum Error { Reqwest(reqwest::Error), + SerdeJson(serde_json::Error), + #[display(fmt = "Token not found in response")] + TokenNotFound, + #[display(fmt = "Expiry time not found in response")] + ExpiryNotFound, + #[display(fmt = "Failed to acquire read lock")] + ReadLock, + #[display(fmt = "Failed to acquire write lock")] + WriteLock, } diff --git a/services/tunnelbroker/src/notifs/wns/mod.rs b/services/tunnelbroker/src/notifs/wns/mod.rs index 856257dbb..22b4a5e83 100644 --- a/services/tunnelbroker/src/notifs/wns/mod.rs +++ b/services/tunnelbroker/src/notifs/wns/mod.rs @@ -1,20 +1,97 @@ use crate::notifs::wns::config::WNSConfig; +use std::{ + sync::{Arc, RwLock}, + time::{Duration, SystemTime}, +}; pub mod config; mod error; +#[derive(Debug, Clone)] +pub struct WNSAccessToken { + token: String, + expires: SystemTime, +} + #[derive(Clone)] pub struct WNSClient { http_client: reqwest::Client, config: WNSConfig, + access_token: Arc>>, } impl WNSClient { pub fn new(config: &WNSConfig) -> Result { let http_client = reqwest::Client::builder().build()?; Ok(WNSClient { http_client, config: config.clone(), + access_token: Arc::new(RwLock::new(None)), }) } + + pub async fn get_wns_token( + &mut self, + ) -> Result, error::Error> { + const EXPIRY_WINDOW: Duration = Duration::from_secs(10); + + { + let read_guard = self + .access_token + .read() + .map_err(|_| error::Error::ReadLock)?; + if let Some(ref token) = *read_guard { + if token.expires >= SystemTime::now() - EXPIRY_WINDOW { + return Ok(Some(token.token.clone())); + } + } + } + + let params = [ + ("grant_type", "client_credentials"), + ("client_id", &self.config.app_id), + ("client_secret", &self.config.secret), + ("scope", "https://wns.windows.com/.default"), + ]; + + let url = format!( + "https://login.microsoftonline.com/{}/oauth2/v2.0/token", + self.config.tenant_id + ); + + let response = self.http_client.post(&url).form(¶ms).send().await?; + + if !response.status().is_success() { + let status = response.status().to_string(); + let body = response + .text() + .await + .unwrap_or_else(|_| String::from("")); + tracing::error!(status, "Failure when getting the WNS token: {}", body); + return Ok(None); + } + + let response_json: serde_json::Value = response.json().await?; + let token = response_json["access_token"] + .as_str() + .ok_or(error::Error::TokenNotFound)? + .to_string(); + let expires_in = response_json["expires_in"] + .as_u64() + .ok_or(error::Error::ExpiryNotFound)?; + + let expires = SystemTime::now() + Duration::from_secs(expires_in); + + { + let mut write_guard = self + .access_token + .write() + .map_err(|_| error::Error::WriteLock)?; + *write_guard = Some(WNSAccessToken { + token: token.clone(), + expires, + }); + } + Ok(Some(token)) + } }