diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs --- a/services/tunnelbroker/src/main.rs +++ b/services/tunnelbroker/src/main.rs @@ -56,8 +56,11 @@ let notif_client = NotifClient { apns }; let grpc_server = grpc::run_server(db_client.clone(), &amqp_connection); - let websocket_server = - websockets::run_server(db_client.clone(), &amqp_connection); + let websocket_server = websockets::run_server( + db_client.clone(), + &amqp_connection, + notif_client.clone(), + ); tokio::select! { Ok(_) = grpc_server => { Ok(()) }, diff --git a/services/tunnelbroker/src/notifs/apns/mod.rs b/services/tunnelbroker/src/notifs/apns/mod.rs --- a/services/tunnelbroker/src/notifs/apns/mod.rs +++ b/services/tunnelbroker/src/notifs/apns/mod.rs @@ -12,7 +12,7 @@ pub mod config; pub mod error; -mod headers; +pub(crate) mod headers; mod response; pub mod token; diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -2,6 +2,7 @@ use crate::constants::SOCKET_HEARTBEAT_TIMEOUT; use crate::database::DatabaseClient; +use crate::notifs::NotifClient; use crate::websockets::session::{initialize_amqp, SessionError}; use crate::CONFIG; use futures_util::stream::SplitSink; @@ -40,6 +41,7 @@ addr: SocketAddr, channel: lapin::Channel, db_client: DatabaseClient, + notif_client: NotifClient, } impl hyper::service::Service> for WebsocketService { @@ -61,6 +63,7 @@ let addr = self.addr; let db_client = self.db_client.clone(); let channel = self.channel.clone(); + let notif_client = self.notif_client.clone(); let future = async move { // Check if the request is a websocket upgrade request. @@ -69,7 +72,8 @@ // Spawn a task to handle the websocket connection. 
tokio::spawn(async move { - accept_connection(websocket, addr, db_client, channel).await; + accept_connection(websocket, addr, db_client, channel, notif_client) + .await; }); // Return the response so the spawned future can continue. @@ -98,6 +102,7 @@ pub async fn run_server( db_client: DatabaseClient, amqp_connection: &lapin::Connection, + notif_client: NotifClient, ) -> Result<(), BoxedError> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") .unwrap_or_else(|_| format!("0.0.0.0:{}", &CONFIG.http_port)); @@ -121,6 +126,7 @@ channel, db_client: db_client.clone(), addr, + notif_client: notif_client.clone(), }, ) .with_upgrades(); @@ -164,6 +170,7 @@ addr: SocketAddr, db_client: DatabaseClient, amqp_channel: lapin::Channel, + notif_client: NotifClient, ) { debug!("Incoming connection from: {}", addr); @@ -183,7 +190,15 @@ // We don't know the identity of the device until it sends the session // request over the websocket connection let mut session = if let Some(Ok(first_msg)) = incoming.next().await { - match initiate_session(outgoing, first_msg, db_client, amqp_channel).await { + match initiate_session( + outgoing, + first_msg, + db_client, + amqp_channel, + notif_client, + ) + .await + { Ok(mut session) => { let response = tunnelbroker_messages::ConnectionInitializationResponse { @@ -297,6 +312,7 @@ frame: Message, db_client: DatabaseClient, amqp_channel: lapin::Channel, + notif_client: NotifClient, ) -> Result, ErrorWithStreamHandle> { let initialized_session = initialize_amqp(db_client.clone(), frame, &amqp_channel).await; @@ -308,6 +324,7 @@ device_info, amqp_channel, amqp_consumer, + notif_client, )), Err(e) => Err((e, outgoing)), } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -28,6 +28,9 @@ use crate::database::{self, DatabaseClient, MessageToDeviceExt}; use crate::identity; 
+use crate::notifs::apns::headers::NotificationHeaders; +use crate::notifs::apns::APNsNotif; +use crate::notifs::NotifClient; pub struct DeviceInfo { pub device_id: String, @@ -45,6 +48,7 @@ amqp_channel: lapin::Channel, // Stream of messages from AMQP endpoint amqp_consumer: lapin::Consumer, + notif_client: NotifClient, } #[derive( @@ -59,6 +63,8 @@ UnauthorizedDevice, PersistenceError(SdkError), DatabaseError(comm_lib::database::Error), + MissingAPNsClient, + MissingDeviceToken, } // Parse a session request and retrieve the device information @@ -201,6 +207,7 @@ device_info: DeviceInfo, amqp_channel: lapin::Channel, amqp_consumer: lapin::Consumer, + notif_client: NotifClient, ) -> Self { Self { tx, @@ -208,6 +215,7 @@ device_info, amqp_channel, amqp_consumer, + notif_client, } } @@ -340,6 +348,61 @@ result, )) } + Messages::APNsNotif(notif) => { + // unauthenticated clients cannot send notifs + if !self.device_info.is_authenticated { + debug!( + "Unauthenticated device {} tried to send APNs notif. 
Aborting.", + self.device_info.device_id + ); + return Some(MessageSentStatus::Unauthenticated); + } + debug!("Received APNs notif for {}", notif.device_id); + + let Ok(headers) = + serde_json::from_str::<NotificationHeaders>(&notif.headers) + else { + return Some(MessageSentStatus::SerializationError(notif.headers)); + }; + + let device_token = + match self.db_client.get_device_token(&notif.device_id).await { + Ok(db_token) => { + let Some(token) = db_token else { + return Some(self.get_message_to_device_status( + &notif.client_message_id, + Err(SessionError::MissingDeviceToken), + )); + }; + token + } + Err(e) => { + return Some(self.get_message_to_device_status( + &notif.client_message_id, + Err(SessionError::DatabaseError(e)), + )); + } + }; + + let apns_notif = APNsNotif { + device_token, + headers, + payload: notif.payload, + }; + + if let Some(apns) = self.notif_client.apns.clone() { + let response = apns.send(apns_notif).await; + return Some( + self + .get_message_to_device_status(&notif.client_message_id, response), + ); + } + + Some(self.get_message_to_device_status( + &notif.client_message_id, + Err(SessionError::MissingAPNsClient), + )) + } _ => { error!("Client sent invalid message type"); Some(MessageSentStatus::InvalidRequest) @@ -388,11 +451,14 @@ } } - pub fn get_message_to_device_status( + pub fn get_message_to_device_status<E>( &mut self, client_message_id: &str, - result: Result<(), SessionError>, - ) -> MessageSentStatus { + result: Result<(), E>, + ) -> MessageSentStatus + where + E: std::error::Error, + { match result { Ok(()) => MessageSentStatus::Success(client_message_id.to_string()), Err(err) => MessageSentStatus::Error(Failure {