diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs index fdc65b1b3..3258a419c 100644 --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,308 +1,309 @@ use aws_sdk_dynamodb::error::SdkError; use aws_sdk_dynamodb::operation::put_item::PutItemError; use derive_more; use futures_util::stream::SplitSink; use futures_util::SinkExt; use futures_util::StreamExt; use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use lapin::message::Delivery; use lapin::options::{ BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions, QueueDeclareOptions, QueueDeleteOptions, }; use lapin::types::FieldTable; use lapin::BasicProperties; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tracing::{debug, error, info}; use tunnelbroker_messages::{ - send_confirmation::Failure, send_confirmation::MessageSentStatus, - session::DeviceTypes, MessageToDeviceRequest, Messages, + message_to_device_request_status::Failure, + message_to_device_request_status::MessageSentStatus, session::DeviceTypes, + MessageToDeviceRequest, Messages, }; use crate::database::{self, DatabaseClient, DeviceMessage}; use crate::error::Error; use crate::identity; pub struct DeviceInfo { pub device_id: String, pub notify_token: Option, pub device_type: DeviceTypes, pub device_app_version: Option, pub device_os: Option, } pub struct WebsocketSession { tx: SplitSink, Message>, db_client: DatabaseClient, pub device_info: DeviceInfo, amqp_channel: lapin::Channel, // Stream of messages from AMQP endpoint amqp_consumer: lapin::Consumer, } #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum SessionError { InvalidMessage, SerializationError(serde_json::Error), MessageError(database::MessageErrors), AmqpError(lapin::Error), InternalError, UnauthorizedDevice, PersistenceError(SdkError), } pub fn consume_error(result: Result) { if let Err(e) = result { error!("{}", e) } } // Parse a session request and retrieve the device information pub async fn handle_first_message_from_device( message: &str, ) -> Result { let serialized_message = serde_json::from_str::(message)?; match serialized_message { Messages::ConnectionInitializationMessage(mut session_info) => { let device_info = DeviceInfo { device_id: session_info.device_id.clone(), notify_token: session_info.notify_token.take(), device_type: session_info.device_type, device_app_version: session_info.device_app_version.take(), device_os: session_info.device_os.take(), }; // Authenticate device debug!("Authenticating device: {}", &session_info.device_id); let auth_request = identity::verify_user_access_token( &session_info.user_id, &device_info.device_id, &session_info.access_token, ) .await; match auth_request { Err(e) => { error!("Failed to complete request to identity service: {:?}", e); return Err(SessionError::InternalError.into()); } Ok(false) => { info!("Device failed authentication: {}", &session_info.device_id); return Err(SessionError::UnauthorizedDevice.into()); } Ok(true) => { debug!( "Successfully authenticated device: {}", &session_info.device_id ); } } Ok(device_info) } _ => { debug!("Received invalid request"); Err(SessionError::InvalidMessage.into()) } } } impl WebsocketSession { pub async fn from_frame( tx: SplitSink, Message>, db_client: DatabaseClient, frame: Message, amqp_channel: &lapin::Channel, ) -> Result, Error> { let device_info = match frame { Message::Text(payload) => { handle_first_message_from_device(&payload).await? } _ => { error!("Client sent wrong frame type for establishing connection"); return Err(SessionError::InvalidMessage.into()); } }; // We don't currently have a use case to interact directly with the queue, // however, we need to declare a queue for a given device amqp_channel .queue_declare( &device_info.device_id, QueueDeclareOptions::default(), FieldTable::default(), ) .await?; let amqp_consumer = amqp_channel .basic_consume( &device_info.device_id, "tunnelbroker", BasicConsumeOptions::default(), FieldTable::default(), ) .await?; Ok(WebsocketSession { tx, db_client, device_info, amqp_channel: amqp_channel.clone(), amqp_consumer, }) } pub async fn handle_message_to_device( &self, message_request: &MessageToDeviceRequest, ) -> Result<(), SessionError> { let message_id = self .db_client .persist_message( &message_request.device_id, &message_request.payload, &message_request.client_message_id, ) .await?; let publish_result = self .amqp_channel .basic_publish( "", &message_request.device_id, BasicPublishOptions::default(), message_request.payload.as_bytes(), BasicProperties::default(), ) .await; if let Err(publish_error) = publish_result { self .db_client .delete_message(&self.device_info.device_id, &message_id) .await .expect("Error deleting message"); return Err(SessionError::AmqpError(publish_error)); } Ok(()) } pub async fn handle_websocket_frame_from_device( &mut self, msg: String, ) -> MessageSentStatus { let Ok(serialized_message) = serde_json::from_str::(&msg) else { return MessageSentStatus::SerializationError(msg); }; match serialized_message { Messages::MessageToDeviceRequest(message_request) => { debug!("Received message for {}", message_request.device_id); let result = self.handle_message_to_device(&message_request).await; self.get_message_to_device_status( &message_request.client_message_id, result, ) } _ => { error!("Client sent invalid message type"); MessageSentStatus::InvalidRequest } } } pub async fn next_amqp_message( &mut self, ) -> Option> { self.amqp_consumer.next().await } pub async fn deliver_persisted_messages( &mut self, ) -> Result<(), SessionError> { // Check for persisted messages let messages = self .db_client .retrieve_messages(&self.device_info.device_id) .await .unwrap_or_else(|e| { error!("Error while retrieving messages: {}", e); Vec::new() }); for message in messages { let device_message = DeviceMessage::from_hashmap(message)?; self .send_message_to_device(Message::Text(device_message.payload)) .await; if let Err(e) = self .db_client .delete_message(&self.device_info.device_id, &device_message.message_id) .await { error!("Failed to delete message: {}:", e); } } debug!( "Flushed messages for device: {}", &self.device_info.device_id ); Ok(()) } pub async fn send_message_to_device(&mut self, message: Message) { if let Err(e) = self.tx.send(message).await { error!("Failed to send message to device: {}", e); } } // Release WebSocket and remove from active connections pub async fn close(&mut self) { if let Err(e) = self.tx.close().await { debug!("Failed to close WebSocket session: {}", e); } if let Err(e) = self .amqp_channel .basic_cancel( self.amqp_consumer.tag().as_str(), BasicCancelOptions::default(), ) .await { error!("Failed to cancel consumer: {}", e); } if let Err(e) = self .amqp_channel .queue_delete( self.device_info.device_id.as_str(), QueueDeleteOptions::default(), ) .await { error!("Failed to delete queue: {}", e); } } pub fn get_message_to_device_status( &mut self, client_message_id: &str, result: Result<(), SessionError>, ) -> MessageSentStatus { match result { Ok(()) => MessageSentStatus::Success(client_message_id.to_string()), Err(err) => MessageSentStatus::Error(Failure { id: client_message_id.to_string(), error: err.to_string(), }), } } } diff --git a/shared/tunnelbroker_messages/src/lib.rs b/shared/tunnelbroker_messages/src/lib.rs index 14d04d641..5a81ba56b 100644 --- a/shared/tunnelbroker_messages/src/lib.rs +++ b/shared/tunnelbroker_messages/src/lib.rs @@ -1,4 +1,4 @@ pub mod messages; +pub use message_to_device_request_status::*; pub use messages::*; -pub use send_confirmation::*; diff --git a/shared/tunnelbroker_messages/src/messages/keys.rs b/shared/tunnelbroker_messages/src/messages/keys.rs index 911ada168..c2f5ab779 100644 --- a/shared/tunnelbroker_messages/src/messages/keys.rs +++ b/shared/tunnelbroker_messages/src/messages/keys.rs @@ -1,29 +1,29 @@ -// Messages sent between Tunnelbroker and a device +//! Messages sent from Tunnelbroker to a device. use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug)] #[serde(tag = "type", rename_all = "camelCase")] pub struct RefreshKeyRequest { #[serde(rename = "deviceID")] pub device_id: String, pub number_of_keys: u32, } #[cfg(test)] mod key_tests { use super::*; #[test] fn test_refresh_deserialization() { let example_payload = r#"{ "type": "RefreshKeyRequest", "deviceID": "adfjEDFS", "numberOfKeys": 6 }"#; let request = serde_json::from_str::(example_payload).unwrap(); assert_eq!(request.number_of_keys, 6); } } diff --git a/shared/tunnelbroker_messages/src/messages/message_to_device.rs b/shared/tunnelbroker_messages/src/messages/message_to_device.rs index 312baddd0..31f39ab51 100644 --- a/shared/tunnelbroker_messages/src/messages/message_to_device.rs +++ b/shared/tunnelbroker_messages/src/messages/message_to_device.rs @@ -1,30 +1,30 @@ -// Messages sent between Tunnelbroker and a device via WebSocket +//! Messages sent between Tunnelbroker and a device via WebSocket. use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] #[serde(tag = "type", rename_all = "camelCase")] pub struct MessageToDevice { #[serde(rename = "deviceID")] pub device_id: String, pub payload: String, } #[cfg(test)] mod message_to_device_tests { use super::*; #[test] fn test_message_to_device_deserialization() { let example_payload = r#"{ "type": "MessageToDevice", "deviceID": "alice", "payload": "message from Bob" }"#; let request = serde_json::from_str::(example_payload).unwrap(); assert_eq!(request.device_id, "alice"); assert_eq!(request.payload, "message from Bob"); } } diff --git a/shared/tunnelbroker_messages/src/messages/message_to_device_request.rs b/shared/tunnelbroker_messages/src/messages/message_to_device_request.rs index 1138824e6..cc92e6729 100644 --- a/shared/tunnelbroker_messages/src/messages/message_to_device_request.rs +++ b/shared/tunnelbroker_messages/src/messages/message_to_device_request.rs @@ -1,34 +1,34 @@ -// Message sent from WebSocket clients to Tunnelbroker +//! Message sent from Tunnelbroker to WebSocket. use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug)] #[serde(tag = "type", rename_all = "camelCase")] pub struct MessageToDeviceRequest { #[serde(rename = "clientMessageID")] pub client_message_id: String, #[serde(rename = "deviceID")] pub device_id: String, pub payload: String, } #[cfg(test)] mod message_to_device_request_tests { use super::*; #[test] fn test_message_to_device_request_deserialization() { let example_payload = r#"{ "type": "MessageToDeviceRequest", "clientMessageID": "client123", "deviceID": "alice", "payload": "message from Bob" }"#; let request = serde_json::from_str::(example_payload).unwrap(); assert_eq!(request.client_message_id, "client123"); assert_eq!(request.device_id, "alice"); assert_eq!(request.payload, "message from Bob"); } } diff --git a/shared/tunnelbroker_messages/src/messages/send_confirmation.rs b/shared/tunnelbroker_messages/src/messages/message_to_device_request_status.rs similarity index 95% rename from shared/tunnelbroker_messages/src/messages/send_confirmation.rs rename to shared/tunnelbroker_messages/src/messages/message_to_device_request_status.rs index 9e33cd180..bd77e3f40 100644 --- a/shared/tunnelbroker_messages/src/messages/send_confirmation.rs +++ b/shared/tunnelbroker_messages/src/messages/message_to_device_request_status.rs @@ -1,86 +1,86 @@ -// Message sent from Tunnelbroker to WebSocket clients to inform that message -// was processed, saved in DDB and will be delivered. +//! Message sent from Tunnelbroker to WebSocket clients to inform that message +//! was processed, saved in DDB, and will be delivered. use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug)] pub struct Failure { pub id: String, pub error: String, } #[derive(Serialize, Deserialize, PartialEq, Debug)] #[serde(tag = "type", content = "data")] pub enum MessageSentStatus { /// The message with the provided ID (String) has been processed /// by the Tunnelbroker and is queued for delivery. Success(String), /// 'Failure' contains information about the message ID /// along with the specific error message. Error(Failure), /// The request was invalid (e.g., Bytes instead of Text). /// In this case, the ID cannot be retrieved. InvalidRequest, /// The JSON could not be serialized, which is why the entire message is /// returned back. /// It becomes impossible to retrieve the message ID in such circumstances. SerializationError(String), } #[derive(Serialize, Deserialize, PartialEq, Debug)] #[serde(tag = "type", rename_all = "camelCase")] pub struct MessageToDeviceRequestStatus { #[serde(rename = "clientMessageIDs")] pub client_message_ids: Vec, } #[cfg(test)] mod send_confirmation_tests { use super::*; #[test] fn test_send_confirmation_deserialization() { let example_payload = r#"{ "type": "MessageToDeviceRequestStatus", "clientMessageIDs": [ {"type": "Success", "data": "id123"}, {"type": "Success", "data": "id456"}, {"type": "Error", "data": {"id": "id789", "error": "Something went wrong"}}, {"type": "SerializationError", "data": "message"}, {"type": "InvalidRequest"} ] }"#; let request = serde_json::from_str::(example_payload) .unwrap(); let expected_client_message_ids = vec![ MessageSentStatus::Success("id123".to_string()), MessageSentStatus::Success("id456".to_string()), MessageSentStatus::Error(Failure { id: String::from("id789"), error: String::from("Something went wrong"), }), MessageSentStatus::SerializationError("message".to_string()), MessageSentStatus::InvalidRequest, ]; assert_eq!(request.client_message_ids, expected_client_message_ids); } #[test] fn test_send_confirmation_deserialization_empty_vec() { let example_payload = r#"{ "type": "MessageToDeviceRequestStatus", "clientMessageIDs": [] }"#; let request = serde_json::from_str::(example_payload) .unwrap(); let expected_client_message_ids: Vec = Vec::new(); assert_eq!(request.client_message_ids, expected_client_message_ids); } } diff --git a/shared/tunnelbroker_messages/src/messages/mod.rs b/shared/tunnelbroker_messages/src/messages/mod.rs index d18ef0963..27c3b4328 100644 --- a/shared/tunnelbroker_messages/src/messages/mod.rs +++ b/shared/tunnelbroker_messages/src/messages/mod.rs @@ -1,27 +1,28 @@ -// Messages sent between Tunnelbroker and a device +//! Messages sent between Tunnelbroker and a device. + pub mod keys; pub mod message_to_device; pub mod message_to_device_request; -pub mod send_confirmation; +pub mod message_to_device_request_status; pub mod session; pub use keys::*; pub use message_to_device::*; pub use message_to_device_request::*; -pub use send_confirmation::*; +pub use message_to_device_request_status::*; pub use session::*; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum Messages { RefreshKeysRequest(RefreshKeyRequest), ConnectionInitializationMessage(ConnectionInitializationMessage), // MessageToDeviceRequest must be placed before MessageToDevice. // This is due to serde's pattern matching behavior where it prioritizes // the first matching pattern it encounters. MessageToDeviceRequest(MessageToDeviceRequest), MessageToDevice(MessageToDevice), MessageToDeviceRequestStatus(MessageToDeviceRequestStatus), } diff --git a/shared/tunnelbroker_messages/src/messages/session.rs b/shared/tunnelbroker_messages/src/messages/session.rs index 0f99d6d69..e9e5e9519 100644 --- a/shared/tunnelbroker_messages/src/messages/session.rs +++ b/shared/tunnelbroker_messages/src/messages/session.rs @@ -1,73 +1,73 @@ -// Messages sent between Tunnelbroker and a device +//! The first message sent from WebSocket client to Tunnelbroker. use serde::{Deserialize, Serialize}; -/// The workflow when estabilishing a Tunnelbroker connection: +/// The workflow when establishing a Tunnelbroker connection: /// - Client sends ConnectionInitializationMessage /// - Tunnelbroker validates access_token with identity service /// - Tunnelbroker emits an AMQP message declaring that it has opened a new -/// connection with a given device, so that the respective tunnelbroker +/// connection with a given device, so that the respective Tunnelbroker /// instance can close the existing connection. /// - Tunnelbroker returns a session_id representing that the connection was /// accepted /// - Tunnelbroker will flush all messages related to device from RabbitMQ. /// This must be done first before flushing DynamoDB to prevent duplicated /// messages. /// - Tunnelbroker flushes all messages in DynamoDB /// - Tunnelbroker orders messages by creation date (oldest first), and sends /// messages to device /// - Tunnelbroker then polls for incoming messages from device #[derive(Serialize, Deserialize, Debug, PartialEq)] #[serde(rename_all = "camelCase")] pub enum DeviceTypes { Mobile, Web, Keyserver, } /// Message sent by a client to Tunnelbroker to initiate a websocket /// session. Tunnelbroker will then validate the access token with identity /// service before continuing with the request. #[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "camelCase")] pub struct ConnectionInitializationMessage { #[serde(rename = "deviceID")] pub device_id: String, pub access_token: String, #[serde(rename = "userID")] pub user_id: String, pub notify_token: Option, pub device_type: DeviceTypes, pub device_app_version: Option, pub device_os: Option, } #[derive(Serialize, Deserialize)] pub struct ConnectionInitializationResponse { pub session_id: String, } #[cfg(test)] mod session_tests { use super::*; #[test] fn test_session_deserialization() { let example_payload = r#"{ "type": "sessionRequest", "accessToken": "xkdeifjsld", "deviceID": "foo", "userID": "alice", "deviceType": "keyserver" }"#; let request = serde_json::from_str::(example_payload) .unwrap(); assert_eq!(request.device_id, "foo"); assert_eq!(request.access_token, "xkdeifjsld"); assert_eq!(request.device_os, None); assert_eq!(request.device_type, DeviceTypes::Keyserver); } }