diff --git a/services/tunnelbroker/src/notifs/apns/mod.rs b/services/tunnelbroker/src/notifs/apns/mod.rs
--- a/services/tunnelbroker/src/notifs/apns/mod.rs
+++ b/services/tunnelbroker/src/notifs/apns/mod.rs
@@ -12,7 +12,7 @@
 pub mod config;
 pub mod error;
 pub(crate) mod headers;
-mod response;
+pub(crate) mod response;
 pub mod token;
 
 #[derive(Clone)]
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -19,6 +19,7 @@
 use tokio::io::AsyncRead;
 use tokio::io::AsyncWrite;
 use tracing::{debug, error, info, trace};
+use tunnelbroker_messages::bad_device_token::BadDeviceToken;
 use tunnelbroker_messages::{
   message_to_device_request_status::Failure,
   message_to_device_request_status::MessageSentStatus, session::DeviceTypes,
@@ -26,8 +27,10 @@
   MessageToDeviceRequest, MessageToTunnelbroker,
 };
 
 use crate::database::{self, DatabaseClient, MessageToDeviceExt};
 use crate::identity;
+use crate::notifs::apns::error::Error;
 use crate::notifs::apns::headers::NotificationHeaders;
+use crate::notifs::apns::response::ErrorReason;
 use crate::notifs::apns::APNsNotif;
 use crate::notifs::fcm::firebase_message::{
@@ -381,24 +384,42 @@
           return Some(MessageSentStatus::SerializationError(notif.headers));
         };
 
-        let device_token = match self.get_device_token(notif.device_id).await {
-          Ok(token) => token,
-          Err(e) => {
-            return Some(
-              self
-                .get_message_to_device_status(&notif.client_message_id, Err(e)),
-            )
-          }
-        };
+        let device_token =
+          match self.get_device_token(notif.device_id.clone()).await {
+            Ok(token) => token,
+            Err(e) => {
+              return Some(self.get_message_to_device_status(
+                &notif.client_message_id,
+                Err(e),
+              ))
+            }
+          };
 
         let apns_notif = APNsNotif {
-          device_token,
+          device_token: device_token.clone(),
           headers,
           payload: notif.payload,
         };
 
         if let Some(apns) = self.notif_client.apns.clone() {
           let response = apns.send(apns_notif).await;
+          if let Err(Error::ResponseError(body)) = &response {
+            // Token-level APNs rejections: the device token itself can no
+            // longer be used, so notify the device and invalidate it.
+            if matches!(
+              body.reason,
+              ErrorReason::BadDeviceToken
+                | ErrorReason::Unregistered
+                | ErrorReason::ExpiredToken
+            ) {
+              if let Err(e) = self
+                .invalidate_device_token(notif.device_id, device_token)
+                .await
+              {
+                error!("Error invalidating device token: {:?}", e);
+              }
+            }
+          }
           return Some(
             self
               .get_message_to_device_status(&notif.client_message_id, response),
@@ -583,4 +604,31 @@
       None => Err(SessionError::MissingDeviceToken),
     }
   }
+
+  /// Sends the device a `BadDeviceToken` message for `invalidated_token`,
+  /// then marks its token as invalid in the database. If sending the
+  /// message fails, the error is returned and the token is not marked.
+  async fn invalidate_device_token(
+    &self,
+    device_id: String,
+    invalidated_token: String,
+  ) -> Result<(), SessionError> {
+    let bad_device_token_message = BadDeviceToken { invalidated_token };
+    let payload = serde_json::to_string(&bad_device_token_message)?;
+    let message_request = MessageToDeviceRequest {
+      client_message_id: uuid::Uuid::new_v4().to_string(),
+      device_id: device_id.clone(),
+      payload,
+    };
+
+    self.handle_message_to_device(&message_request).await?;
+
+    self
+      .db_client
+      .mark_device_token_as_invalid(&device_id)
+      .await
+      .map_err(SessionError::DatabaseError)?;
+
+    Ok(())
+  }
 }