diff --git a/services/commtest/tests/tunnelbroker_integration_test.rs b/services/commtest/tests/tunnelbroker_integration_test.rs
--- a/services/commtest/tests/tunnelbroker_integration_test.rs
+++ b/services/commtest/tests/tunnelbroker_integration_test.rs
@@ -110,11 +110,17 @@
     .expect("Failed to send message");
 
   // Have keyserver receive any websocket messages
-  let response = socket.next().await.unwrap().unwrap();
+  // Fail loudly if nothing arrives: skipping the assertion would let the
+  // test pass even when the refresh request was never delivered
+  let response = socket
+    .next()
+    .await
+    .expect("Keyserver websocket closed before receiving a message")
+    .expect("Failed to receive websocket message");
 
   // Check that message received by keyserver matches what identity server
   // issued
   let serialized_response: RefreshKeyRequest =
     serde_json::from_str(&response.to_text().unwrap()).unwrap();
   assert_eq!(serialized_response, refresh_request);
 }
diff --git a/services/tunnelbroker/Cargo.lock b/services/tunnelbroker/Cargo.lock
--- a/services/tunnelbroker/Cargo.lock
+++ b/services/tunnelbroker/Cargo.lock
@@ -664,6 +664,12 @@
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
 
+[[package]]
+name = "convert_case"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e"
+
 [[package]]
 name = "core-foundation"
 version = "0.9.3"
@@ -756,6 +762,19 @@
  "parking_lot_core",
 ]
 
+[[package]]
+name = "derive_more"
+version = "0.99.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321"
+dependencies = [
+ "convert_case",
+ "proc-macro2",
+ "quote",
+ "rustc_version",
+ "syn 1.0.109",
+]
+
 [[package]]
 name = "digest"
 version = "0.10.6"
@@ -2491,6 +2510,7 @@
  "base64 0.20.0",
  "clap",
  "dashmap",
+ "derive_more",
  "env_logger",
  "fcm",
  "futures",
diff --git a/services/tunnelbroker/Cargo.toml b/services/tunnelbroker/Cargo.toml
--- a/services/tunnelbroker/Cargo.toml
+++ b/services/tunnelbroker/Cargo.toml
@@ -34,6 +34,7 @@
 tracing = "0.1"
 tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
 tunnelbroker_messages = { path = "../../shared/tunnelbroker_messages" }
+derive_more = "0.99.17"
 
 [build-dependencies]
 tonic-build = "0.8"
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -2,18 +2,12 @@
 
 use crate::database::DatabaseClient;
 use crate::CONFIG;
-use futures_util::stream::SplitSink;
-use futures_util::SinkExt;
-use futures_util::{StreamExt, TryStreamExt};
+use futures_util::StreamExt;
 use std::net::SocketAddr;
 use std::{env, io::Error};
 use tokio::net::{TcpListener, TcpStream};
 use tokio::sync::mpsc;
-use tokio_tungstenite::tungstenite::Message;
-use tokio_tungstenite::WebSocketStream;
-use tracing::{debug, error, info};
-
-use crate::ACTIVE_CONNECTIONS;
+use tracing::{debug, info};
 
 pub async fn run_server(db_client: DatabaseClient) -> Result<(), Error> {
   let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR")
@@ -48,49 +42,34 @@
     }
   };
 
-  let (mut outgoing, incoming) = ws_stream.split();
+  let (outgoing, mut incoming) = ws_stream.split();
 
   // Create channel for messages to be passed to this connection
   let (tx, mut rx) = mpsc::unbounded_channel::<String>();
-  let session = session::WebsocketSession::new(tx.clone(), db_client.clone());
-  let handle_incoming = incoming.try_for_each(|msg| async {
-    debug!("Received message from {}", addr);
-    match msg {
-      Message::Text(text) => {
-        match session.handle_message_from_device(&text).await {
-          Ok(_) => {
-            debug!("Successfully handled message: {}", text)
-          }
-          Err(e) => {
-            error!("Failed to process message: {}", e);
-          }
-        };
-      }
-      _ => {
-        error!("Invalid message was received");
-      }
-    }
-    Ok(())
-  });
+  let mut session = session::WebsocketSession::new(outgoing, db_client.clone());
 
-  debug!("Polling for messages from: {}", addr);
   // Poll for messages either being sent to the device (rx)
-  // or messages being received from the device (handle_incoming)
-  tokio::select! {
-    Some(message) = rx.recv() => { handle_message_from_service(message, &mut outgoing).await; },
-    Ok(_) = handle_incoming => { debug!("Received message from websocket") },
-    else => {
-      info!("Connection with {} closed.", addr);
-      ACTIVE_CONNECTIONS.remove("test");
+  // or messages being received from the device (incoming)
+  loop {
+    debug!("Polling for messages from: {}", addr);
+    tokio::select! {
+      Some(message) = rx.recv() => { session.send_message_to_device(message).await; },
+      device_message = incoming.next() => {
+        match device_message {
+          Some(Ok(msg)) => session.handle_websocket_frame_from_device(msg, tx.clone()).await,
+          _ => {
+            debug!("Connection to {} closed remotely.", addr);
+            break;
+          }
+        }
+      },
+      else => {
+        debug!("Unhealthy connection for: {}", addr);
+        break;
+      },
     }
   }
-}
 
-async fn handle_message_from_service(
-  incoming_payload: String,
-  outgoing: &mut SplitSink<WebSocketStream<TcpStream>, Message>,
-) {
-  if let Err(e) = outgoing.send(Message::Text(incoming_payload)).await {
-    error!("Failed to send message to device: {}", e);
-  }
+  info!("Unregistering connection to: {}", addr);
+  session.close().await
 }
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -1,58 +1,143 @@
-use tracing::debug;
-use tunnelbroker_messages::Messages;
+use futures_util::stream::SplitSink;
+use futures_util::SinkExt;
+use tokio::{net::TcpStream, sync::mpsc::UnboundedSender};
+use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
+use tracing::{debug, error};
+use tunnelbroker_messages::{session::DeviceTypes, Messages};
 
 use crate::{
   constants::dynamodb::undelivered_messages::CREATED_AT,
   database::DatabaseClient, ACTIVE_CONNECTIONS,
 };
 
+/// Device metadata captured from the device's session request.
+pub struct DeviceInfo {
+  pub device_id: String,
+  pub notify_token: Option<String>,
+  pub device_type: DeviceTypes,
+  pub device_app_version: Option<String>,
+  pub device_os: Option<String>,
+}
+
+/// State for a single device's websocket connection.
 pub struct WebsocketSession {
-  tx: tokio::sync::mpsc::UnboundedSender<String>,
+  tx: SplitSink<WebSocketStream<TcpStream>, Message>,
   db_client: DatabaseClient,
+  // Some(..) once a session request has been handled; cleared by `close`
+  device_info: Option<DeviceInfo>,
+}
+
+/// Errors raised while handling a frame received from the device.
+#[derive(Debug, derive_more::Display, derive_more::From)]
+pub enum SessionError {
+  InvalidMessage,
+  SerializationError(serde_json::Error),
 }
 
+/// Logs the error of a result that has no caller left to handle it.
+fn consume_error<T>(result: Result<T, SessionError>) {
+  if let Err(e) = result {
+    error!("{}", e)
+  }
+}
+
 impl WebsocketSession {
   pub fn new(
-    tx: tokio::sync::mpsc::UnboundedSender<String>,
+    tx: SplitSink<WebSocketStream<TcpStream>, Message>,
     db_client: DatabaseClient,
   ) -> WebsocketSession {
-    WebsocketSession { tx, db_client }
+    WebsocketSession {
+      tx,
+      db_client,
+      device_info: None,
+    }
+  }
+
+  /// Dispatches a single websocket frame received from the device.
+  ///
+  /// `tx` is the sender half of this connection's outbound channel; it is
+  /// registered in `ACTIVE_CONNECTIONS` when a session request is handled.
+  pub async fn handle_websocket_frame_from_device(
+    &mut self,
+    frame: Message,
+    tx: UnboundedSender<String>,
+  ) {
+    debug!("Received message from device: {}", frame);
+    let result = match frame {
+      Message::Text(payload) => {
+        self.handle_message_from_device(&payload, tx).await
+      }
+      Message::Close(_) => {
+        self.close().await;
+        Ok(())
+      }
+      _ => Err(SessionError::InvalidMessage),
+    };
+    consume_error(result);
   }
 
   pub async fn handle_message_from_device(
-    &self,
+    &mut self,
     message: &str,
-  ) -> Result<(), serde_json::Error> {
-    match serde_json::from_str::<Messages>(message)? {
-      Messages::SessionRequest(session_info) => {
+    tx: UnboundedSender<String>,
+  ) -> Result<(), SessionError> {
+    let serialized_message = serde_json::from_str::<Messages>(message)?;
+
+    match serialized_message {
+      Messages::SessionRequest(mut session_info) => {
         // TODO: Authenticate device using auth token
+
+        // Check if session request was already sent
+        if self.device_info.is_some() {
+          return Err(SessionError::InvalidMessage);
+        }
+
+        let device_info = DeviceInfo {
+          device_id: session_info.device_id.clone(),
+          notify_token: session_info.notify_token.take(),
+          device_type: session_info.device_type,
+          device_app_version: session_info.device_app_version.take(),
+          device_os: session_info.device_os.take(),
+        };
+
         // Check for persisted messages
         let messages = self
           .db_client
-          .retrieve_messages(&session_info.device_id)
+          .retrieve_messages(&device_info.device_id)
           .await
-          .expect("Failed to retreive messages");
+          .unwrap_or_else(|e| {
+            error!("Error while retrieving messages: {}", e);
+            Vec::new()
+          });
 
-        ACTIVE_CONNECTIONS
-          .insert(session_info.device_id.clone(), self.tx.clone());
+        ACTIVE_CONNECTIONS.insert(device_info.device_id.clone(), tx.clone());
 
         for message in messages {
           let payload =
             message.get("payload").unwrap().as_s().unwrap().to_string();
-          self
-            .tx
-            .send(payload)
-            .expect("Failed to send message to client");
           let created_at =
             message.get(CREATED_AT).unwrap().as_n().unwrap().to_string();
-          self
-            .db_client
-            .delete_message(&session_info.device_id, &created_at)
-            .await
-            .expect("Failed to delete messages");
+          // Only delete the persisted copy once it has been sent over the
+          // socket; on a send failure stop flushing so this and the remaining
+          // messages stay persisted for a later connection
+          if let Err(e) = self.tx.send(Message::Text(payload)).await {
+            error!("Failed to send message to device: {}", e);
+            break;
+          }
+          // A panic here would skip `close()` and leave the device registered
+          // in ACTIVE_CONNECTIONS, so delete failures are only logged
+          if let Err(e) = self
+            .db_client
+            .delete_message(&device_info.device_id, &created_at)
+            .await
+          {
+            error!("Failed to delete message: {:?}", e);
+          }
         }
 
         debug!("Flushed messages for device: {}", &session_info.device_id);
+
+        self.device_info = Some(device_info);
       }
       _ => {
         debug!("Received invalid request");
@@ -61,4 +146,26 @@
 
     Ok(())
   }
+
+  /// Sends a text frame to the device, logging (not returning) any failure.
+  pub async fn send_message_to_device(&mut self, incoming_payload: String) {
+    if let Err(e) = self.tx.send(Message::Text(incoming_payload)).await {
+      error!("Failed to send message to device: {}", e);
+    }
+  }
+
+  /// Releases the websocket and removes the device from active connections.
+  ///
+  /// Only the first call unregisters the device (`device_info` is taken), so
+  /// closing again, e.g. after a Close frame and then when the stream ends,
+  /// does not remove an entry another session registered in the meantime.
+  pub async fn close(&mut self) {
+    if let Some(device_info) = self.device_info.take() {
+      ACTIVE_CONNECTIONS.remove(&device_info.device_id);
+    }
+
+    if let Err(e) = self.tx.close().await {
+      debug!("Failed to close session: {}", e);
+    }
+  }
 }