diff --git a/services/tunnelbroker/src/database/message.rs b/services/tunnelbroker/src/database/message.rs new file mode 100644 index 000000000..450c53a2d --- /dev/null +++ b/services/tunnelbroker/src/database/message.rs @@ -0,0 +1,50 @@ +use std::collections::HashMap; + +use aws_sdk_dynamodb::types::AttributeValue; + +use crate::constants::dynamodb::undelivered_messages::{ + CREATED_AT, DEVICE_ID, PAYLOAD, +}; + +#[derive(Debug)] +pub struct DeviceMessage { + pub device_id: String, + pub created_at: String, + pub payload: String, +} + +#[derive(Debug, derive_more::Display)] +pub enum MessageErrors { + SerializationError, +} + +impl DeviceMessage { + pub fn from_hashmap( + hashmap: HashMap, + ) -> Result { + let device_id: String = hashmap + .get(DEVICE_ID) + .ok_or(MessageErrors::SerializationError)? + .as_s() + .map_err(|_| MessageErrors::SerializationError)? + .to_string(); + let created_at: String = hashmap + .get(CREATED_AT) + .ok_or(MessageErrors::SerializationError)? + .as_n() + .map_err(|_| MessageErrors::SerializationError)? + .to_string(); + let payload: String = hashmap + .get(PAYLOAD) + .ok_or(MessageErrors::SerializationError)? + .as_s() + .map_err(|_| MessageErrors::SerializationError)? + .to_string(); + + Ok(DeviceMessage { + device_id, + created_at, + payload, + }) + } +} diff --git a/services/tunnelbroker/src/database.rs b/services/tunnelbroker/src/database/mod.rs similarity index 98% rename from services/tunnelbroker/src/database.rs rename to services/tunnelbroker/src/database/mod.rs index c870605d6..157b02809 100644 --- a/services/tunnelbroker/src/database.rs +++ b/services/tunnelbroker/src/database/mod.rs @@ -1,125 +1,128 @@ use aws_config::SdkConfig; use aws_sdk_dynamodb::error::SdkError; use aws_sdk_dynamodb::operation::delete_item::{ DeleteItemError, DeleteItemOutput, }; use aws_sdk_dynamodb::operation::put_item::{PutItemError, PutItemOutput}; use aws_sdk_dynamodb::operation::query::QueryError; use aws_sdk_dynamodb::{types::AttributeValue, Client}; use std::collections::HashMap; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use tracing::{debug, error}; use crate::constants::dynamodb::undelivered_messages::{ PARTITION_KEY, PAYLOAD, SORT_KEY, TABLE_NAME, }; +pub mod message; +pub use message::*; + #[derive(Clone)] pub struct DatabaseClient { client: Arc, } pub fn unix_timestamp() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) .expect("System time is misconfigured") .as_secs() } pub fn handle_ddb_error(db_error: SdkError) -> tonic::Status { match db_error { SdkError::TimeoutError(_) | SdkError::ServiceError(_) => { tonic::Status::unavailable("please retry") } e => { error!("Encountered an unexpected error: {}", e); tonic::Status::failed_precondition("unexpected error") } } } impl DatabaseClient { pub fn new(aws_config: &SdkConfig) -> Self { let client = Client::new(aws_config); DatabaseClient { client: Arc::new(client), } } pub async fn persist_message( &self, device_id: &str, payload: &str, ) -> Result> { let device_av = AttributeValue::S(device_id.to_string()); let payload_av = AttributeValue::S(payload.to_string()); let created_av = AttributeValue::N(unix_timestamp().to_string()); let request = self .client .put_item() .table_name(TABLE_NAME) .item(PARTITION_KEY, device_av) .item(SORT_KEY, created_av) .item(PAYLOAD, payload_av); debug!("Persisting message to device: {}", &device_id); request.send().await } pub async fn retrieve_messages( &self, device_id: &str, ) -> Result>, SdkError> { debug!("Retrieving messages for device: {}", device_id); let response = self .client .query() .table_name(TABLE_NAME) .key_condition_expression(format!("{} = :u", PARTITION_KEY)) .expression_attribute_values( ":u", AttributeValue::S(device_id.to_string()), ) .consistent_read(true) .send() .await?; debug!("Retrieved {} messages for {}", response.count, device_id); match response.items { None => Ok(Vec::new()), Some(items) => Ok(items.to_vec()), } } pub async fn delete_message( &self, device_id: &str, created_at: &str, ) -> Result> { debug!("Deleting message for device: {}", device_id); let key = HashMap::from([ ( PARTITION_KEY.to_string(), AttributeValue::S(device_id.to_string()), ), ( SORT_KEY.to_string(), AttributeValue::N(created_at.to_string()), ), ]); self .client .delete_item() .table_name(TABLE_NAME) .set_key(Some(key)) .send() .await } } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs index cc9ad5fa0..30f27d1e3 100644 --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,147 +1,147 @@ use derive_more; use futures_util::stream::SplitSink; use futures_util::SinkExt; use tokio::{net::TcpStream, sync::mpsc::UnboundedSender}; use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; use tracing::{debug, error}; use tunnelbroker_messages::{session::DeviceTypes, Messages}; use crate::{ - constants::dynamodb::undelivered_messages::CREATED_AT, - database::DatabaseClient, ACTIVE_CONNECTIONS, + database::{self, DatabaseClient, DeviceMessage}, + ACTIVE_CONNECTIONS, }; pub struct DeviceInfo { pub device_id: String, pub notify_token: Option, pub device_type: DeviceTypes, pub device_app_version: Option, pub device_os: Option, } pub struct WebsocketSession { tx: SplitSink, Message>, db_client: DatabaseClient, device_info: Option, } #[derive(Debug, derive_more::Display, derive_more::From)] pub enum SessionError { InvalidMessage, SerializationError(serde_json::Error), + MessageError(database::MessageErrors), } fn consume_error(result: Result) { if let Err(e) = result { error!("{}", e) } } impl WebsocketSession { pub fn new( tx: SplitSink, Message>, db_client: DatabaseClient, ) -> WebsocketSession { WebsocketSession { tx, db_client, device_info: None, } } pub async fn handle_websocket_frame_from_device( &mut self, frame: Message, tx: UnboundedSender, ) { debug!("Received message from device: {}", frame); let result = match frame { Message::Text(payload) => { self.handle_message_from_device(&payload, tx).await } Message::Close(_) => { self.close().await; Ok(()) } _ => Err(SessionError::InvalidMessage), }; consume_error(result); } pub async fn handle_message_from_device( &mut self, message: &str, tx: UnboundedSender, ) -> Result<(), SessionError> { let serialized_message = serde_json::from_str::(message)?; match serialized_message { Messages::SessionRequest(mut session_info) => { // TODO: Authenticate device using auth token // Check if session request was already sent if self.device_info.is_some() { return Err(SessionError::InvalidMessage); } let device_info = DeviceInfo { device_id: session_info.device_id.clone(), notify_token: session_info.notify_token.take(), device_type: session_info.device_type, device_app_version: session_info.device_app_version.take(), device_os: session_info.device_os.take(), }; // Check for persisted messages let messages = self .db_client .retrieve_messages(&device_info.device_id) .await .unwrap_or_else(|e| { error!("Error while retrieving messages: {}", e); Vec::new() }); ACTIVE_CONNECTIONS.insert(device_info.device_id.clone(), tx.clone()); for message in messages { - let payload = - message.get("payload").unwrap().as_s().unwrap().to_string(); - let created_at = - message.get(CREATED_AT).unwrap().as_n().unwrap().to_string(); - self.send_message_to_device(payload).await; - self + let device_message = DeviceMessage::from_hashmap(message)?; + self.send_message_to_device(device_message.payload).await; + if let Err(e) = self .db_client - .delete_message(&device_info.device_id, &created_at) + .delete_message(&device_info.device_id, &device_message.created_at) .await - .expect("Failed to delete messages"); + { + error!("Failed to delete message: {}:", e); + } } debug!("Flushed messages for device: {}", &session_info.device_id); self.device_info = Some(device_info); } _ => { debug!("Received invalid request"); } } Ok(()) } pub async fn send_message_to_device(&mut self, incoming_payload: String) { if let Err(e) = self.tx.send(Message::Text(incoming_payload)).await { error!("Failed to send message to device: {}", e); } } // Release websocket and remove from active connections pub async fn close(&mut self) { if let Some(device_info) = &self.device_info { ACTIVE_CONNECTIONS.remove(&device_info.device_id); } if let Err(e) = self.tx.close().await { debug!("Failed to close session: {}", e); } } }