diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_integration_tests.rs
index ff70b5b73..de4e3620d 100644
--- a/services/commtest/tests/tunnelbroker_integration_tests.rs
+++ b/services/commtest/tests/tunnelbroker_integration_tests.rs
@@ -1,48 +1,111 @@
mod proto {
  tonic::include_proto!("tunnelbroker");
}
+
use commtest::identity::device::create_device;
+use commtest::identity::olm_account_infos::{
+  MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2,
+};
use commtest::tunnelbroker::socket::create_socket;
-use futures_util::StreamExt;
+use futures_util::{SinkExt, StreamExt};
use proto::tunnelbroker_service_client::TunnelbrokerServiceClient;
use proto::MessageToDevice;
-use tunnelbroker_messages::RefreshKeyRequest;
+use std::time::Duration;
+use tokio::time::sleep;
+use tokio_tungstenite::tungstenite::Message;
+
+use tunnelbroker_messages::{
+  MessageToDevice as WebsocketMessageToDevice, RefreshKeyRequest,
+};

#[tokio::test]
async fn send_refresh_request() {
  // Create session as a keyserver
  let device_info = create_device(None).await;
  let mut socket = create_socket(&device_info).await;

  // Send request for keyserver to refresh keys (identity service)
  let mut tunnelbroker_client =
    TunnelbrokerServiceClient::connect("http://localhost:50051")
      .await
      .unwrap();

  let refresh_request = RefreshKeyRequest {
    device_id: device_info.device_id.clone(),
    number_of_keys: 5,
  };

  let payload = serde_json::to_string(&refresh_request).unwrap();
  let request = MessageToDevice {
    device_id: device_info.device_id.clone(),
    payload,
  };

  let grpc_message = tonic::Request::new(request);
  tunnelbroker_client
    .send_message_to_device(grpc_message)
    .await
    .unwrap();

  // Have keyserver receive any websocket messages
  let response = socket.next().await.unwrap().unwrap();

  // Check that message received by keyserver matches what identity server
  // issued
  let serialized_response: RefreshKeyRequest =
    serde_json::from_str(&response.to_text().unwrap()).unwrap();
  assert_eq!(serialized_response, refresh_request);
}
+
+#[tokio::test]
+async fn test_messages_order() {
+  let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await;
+  let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await;
+
+  let messages = vec![
+    WebsocketMessageToDevice {
+      device_id: receiver.device_id.clone(),
+      payload: "first message".to_string(),
+    },
+    WebsocketMessageToDevice {
+      device_id: receiver.device_id.clone(),
+      payload: "second message".to_string(),
+    },
+    WebsocketMessageToDevice {
+      device_id: receiver.device_id.clone(),
+      payload: "third message".to_string(),
+    },
+  ];
+
+  let serialized_messages: Vec<_> = messages
+    .iter()
+    .map(|message| {
+      serde_json::to_string(message)
+        .expect("Failed to serialize message to device")
+    })
+    .map(Message::text)
+    .collect();
+
+  let (mut sender_socket, _) = create_socket(&sender).await.split();
+
+  for msg in serialized_messages.clone() {
+    sender_socket
+      .send(msg)
+      .await
+      .expect("Failed to send the message over WebSocket");
+  }
+
+  // Wait a specified duration to ensure that the messages had time to persist
+  sleep(Duration::from_millis(100)).await;
+
+  let mut receiver_socket = create_socket(&receiver).await;
+
+  for msg in messages {
+    if let Some(Ok(response)) = receiver_socket.next().await {
+      let received_payload = response.to_text().unwrap();
+      assert_eq!(msg.payload, received_payload);
+    } else {
+      panic!("Unable to receive message");
+    }
+  }
+}
diff --git a/services/terraform/modules/shared/dynamodb.tf b/services/terraform/modules/shared/dynamodb.tf
index bb4da6158..0dec9539e 100644
--- a/services/terraform/modules/shared/dynamodb.tf
+++ b/services/terraform/modules/shared/dynamodb.tf
@@ -1,266 +1,266 @@
resource "aws_dynamodb_table" "backup-service-backup" {
  name         = "backup-service-backup"
  hash_key     = "userID"
  range_key    = "backupID"
  billing_mode = "PAY_PER_REQUEST"

  attribute {
    name = "userID"
    type = "S"
  }

  attribute {
    name = "backupID"
    type = "S"
  }

  attribute {
    name = "created"
    type = "S"
  }

  global_secondary_index {
    name               = "userID-created-index"
    hash_key           = "userID"
    range_key          = "created"
    projection_type    = "INCLUDE"
    non_key_attributes = ["userKeys"]
  }
}

resource "aws_dynamodb_table" "backup-service-log" {
  name         = "backup-service-log"
  hash_key     = "backupID"
  range_key    = "logID"
  billing_mode = "PAY_PER_REQUEST"

  attribute {
    name = "backupID"
    type = "S"
  }

  attribute {
    name = "logID"
    type = "S"
  }
}

resource "aws_dynamodb_table" "blob-service-blobs" {
  name         = "blob-service-blobs"
  hash_key     = "blob_hash"
  range_key    = "holder"
  billing_mode = "PAY_PER_REQUEST"

  attribute {
    name = "blob_hash"
    type = "S"
  }

  attribute {
    name = "holder"
    type = "S"
  }

  attribute {
    name = "last_modified"
    type = "N"
  }

  attribute {
    name = "unchecked"
    type = "S"
  }

  global_secondary_index {
    name            = "unchecked-index"
    hash_key        = "unchecked"
    range_key       = "last_modified"
    projection_type = "KEYS_ONLY"
  }
}

resource "aws_dynamodb_table" "tunnelbroker-undelivered-messages" {
  name         = "tunnelbroker-undelivered-messages"
  hash_key     = "deviceID"
-  range_key    = "createdAt"
+  range_key    = "messageID"
  billing_mode = "PAY_PER_REQUEST"

  attribute {
    name = "deviceID"
    type = "S"
  }

  attribute {
-    name = "createdAt"
-    type = "N"
+    name = "messageID"
+    type = "S"
  }
}

resource "aws_dynamodb_table" "identity-users" {
  name         = "identity-users"
  hash_key     = "userID"
  billing_mode = "PAY_PER_REQUEST"

  attribute {
    name = "userID"
    type = "S"
  }

  attribute {
    name = "username"
    type = "S"
  }

  # walletAddress not defined in prod
  dynamic "attribute" {
    # Create a dummy list to iterate over if is_dev is true
    for_each = var.is_dev ? [1] : []
    content {
      name = "walletAddress"
      type = "S"
    }
  }

  global_secondary_index {
    name            = "username-index"
    hash_key        = "username"
    projection_type = "KEYS_ONLY"
  }

  # walletAddress not defined in prod
  dynamic "global_secondary_index" {
    # Create a dummy list to iterate over if is_dev is true
    for_each = var.is_dev ? [1] : []
    content {
      name            = "walletAddress-index"
      hash_key        = "walletAddress"
      projection_type = "KEYS_ONLY"
    }
  }
}

# Identity users with opaque_ke 2.0 credentials
resource "aws_dynamodb_table" "identity-users-opaque2" {
  # This table doesnt exist in prod
  count = var.is_dev ? 1 : 0

  name         = "identity-users-opaque2"
  hash_key     = "userID"
  billing_mode = "PAY_PER_REQUEST"

  attribute {
    name = "userID"
    type = "S"
  }

  attribute {
    name = "username"
    type = "S"
  }

  attribute {
    name = "walletAddress"
    type = "S"
  }

  global_secondary_index {
    name            = "username-index"
    hash_key        = "username"
    projection_type = "KEYS_ONLY"
  }

  global_secondary_index {
    name            = "walletAddress-index"
    hash_key        = "walletAddress"
    projection_type = "KEYS_ONLY"
  }
}

resource "aws_dynamodb_table" "identity-tokens" {
  name         = "identity-tokens"
  hash_key     = "userID"
  range_key    = "signingPublicKey"
  billing_mode = "PAY_PER_REQUEST"

  attribute {
    name = "userID"
    type = "S"
  }

  attribute {
    name = "signingPublicKey"
    type = "S"
  }
}

resource "aws_dynamodb_table" "identity-nonces" {
  name         = "identity-nonces"
  hash_key     = "nonce"
  billing_mode = "PAY_PER_REQUEST"

  attribute {
    name = "nonce"
    type = "S"
  }

  ttl {
    attribute_name = "expirationTimeUnix"
    enabled        = true
  }
}

resource "aws_dynamodb_table" "identity-reserved-usernames" {
  name         = "identity-reserved-usernames"
  hash_key     = "username"
  billing_mode = "PAY_PER_REQUEST"

  attribute {
    name = "username"
    type = "S"
  }
}

resource "aws_dynamodb_table" "identity-one-time-keys" {
  name         = "identity-one-time-keys"
  hash_key     = "deviceID"
  range_key    = "oneTimeKey"
  billing_mode = "PAY_PER_REQUEST"

  attribute {
    name = "deviceID"
    type = "S"
  }

  attribute {
    name = "oneTimeKey"
    type = "S"
  }
}

resource "aws_dynamodb_table" "feature-flags" {
  name         = "feature-flags"
  hash_key     = "platform"
  range_key    = "feature"
  billing_mode = "PAY_PER_REQUEST"

  attribute {
    name = "platform"
    type = "S"
  }

  attribute {
    name = "feature"
    type = "S"
  }
}

resource "aws_dynamodb_table" "reports-service-reports" {
  name         = "reports-service-reports"
  hash_key     = "reportID"
  billing_mode = "PAY_PER_REQUEST"

  attribute {
    name = "reportID"
    type = "S"
  }
}
diff --git a/services/tunnelbroker/src/constants.rs b/services/tunnelbroker/src/constants.rs
index eac78e0da..1414b1d10 100644
--- a/services/tunnelbroker/src/constants.rs
+++ b/services/tunnelbroker/src/constants.rs
@@ -1,28 +1,31 @@
use tokio::time::Duration;

pub const GRPC_TX_QUEUE_SIZE: usize = 32;
pub const GRPC_SERVER_PORT: u16 = 50051;
pub const GRPC_KEEP_ALIVE_PING_INTERVAL: Duration = Duration::from_secs(3);
pub const GRPC_KEEP_ALIVE_PING_TIMEOUT: Duration = Duration::from_secs(10);

pub const LOG_LEVEL_ENV_VAR: &str =
  tracing_subscriber::filter::EnvFilter::DEFAULT_ENV;

pub mod dynamodb {
  // This table holds messages which could not be immediately delivered to
  // a device.
  //
  // - (primary key) = (deviceID: Partition Key, createdAt: Sort Key)
  // - deviceID: The public key of a device's olm identity key
  // - payload: Message to be delivered. See shared/tunnelbroker_messages.
-  // - createdAt: UNIX timestamp of when the item was inserted.
-  //   Timestamp is needed to order the messages correctly to the device.
+  // - messageID = [createdAt]#[clientMessageID]
+  //   - createdAt: RFC 3339 (ISO 8601) timestamp of when the item was
+  //     inserted. It is needed to deliver messages to the device in order,
+  //     and the format sorts lexicographically in chronological order.
+  //   - clientMessageID: message ID generated on the client using UUID v4.
  pub mod undelivered_messages {
    pub const TABLE_NAME: &str = "tunnelbroker-undelivered-messages";
    pub const PARTITION_KEY: &str = "deviceID";
    pub const DEVICE_ID: &str = "deviceID";
    pub const PAYLOAD: &str = "payload";
-    pub const CREATED_AT: &str = "createdAt";
-    pub const SORT_KEY: &str = "createdAt";
+    pub const MESSAGE_ID: &str = "messageID";
+    pub const SORT_KEY: &str = "messageID";
  }
}
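The new messageID comment above leans on one property: an RFC 3339 timestamp in UTC compares lexicographically in the same order as chronologically, so a DynamoDB string sort key keeps undelivered messages in send order. A minimal sketch of the key shape, not part of the patch, assuming the `chrono` and `uuid` crates; the function name is illustrative.

```rust
use chrono::Utc;
use uuid::Uuid;

// Build a composite key shaped like the new messageID sort key:
// "[RFC 3339 timestamp]#[clientMessageID]".
fn example_message_id() -> String {
  format!("{}#{}", Utc::now().to_rfc3339(), Uuid::new_v4())
}

fn main() {
  let first = example_message_id();
  std::thread::sleep(std::time::Duration::from_millis(2));
  let second = example_message_id();

  // Plain string comparison orders the keys by creation time because the
  // timestamp prefix dominates the comparison.
  assert!(first < second);
  println!("{first}\n{second}");
}
```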
diff --git a/services/tunnelbroker/src/database/message.rs b/services/tunnelbroker/src/database/message.rs
index 73a395f7a..34a4daadc 100644
--- a/services/tunnelbroker/src/database/message.rs
+++ b/services/tunnelbroker/src/database/message.rs
@@ -1,50 +1,50 @@
use std::collections::HashMap;

use aws_sdk_dynamodb::types::AttributeValue;

use crate::constants::dynamodb::undelivered_messages::{
-  CREATED_AT, DEVICE_ID, PAYLOAD,
+  DEVICE_ID, MESSAGE_ID, PAYLOAD,
};

#[derive(Debug)]
pub struct DeviceMessage {
  pub device_id: String,
-  pub created_at: String,
+  pub message_id: String,
  pub payload: String,
}

#[derive(Debug, derive_more::Display, derive_more::Error)]
pub enum MessageErrors {
  SerializationError,
}

impl DeviceMessage {
  pub fn from_hashmap(
    hashmap: HashMap<String, AttributeValue>,
  ) -> Result<DeviceMessage, MessageErrors> {
    let device_id: String = hashmap
      .get(DEVICE_ID)
      .ok_or(MessageErrors::SerializationError)?
      .as_s()
      .map_err(|_| MessageErrors::SerializationError)?
      .to_string();
-    let created_at: String = hashmap
-      .get(CREATED_AT)
+    let message_id: String = hashmap
+      .get(MESSAGE_ID)
      .ok_or(MessageErrors::SerializationError)?
-      .as_n()
+      .as_s()
      .map_err(|_| MessageErrors::SerializationError)?
      .to_string();
    let payload: String = hashmap
      .get(PAYLOAD)
      .ok_or(MessageErrors::SerializationError)?
      .as_s()
      .map_err(|_| MessageErrors::SerializationError)?
      .to_string();

    Ok(DeviceMessage {
      device_id,
-      created_at,
+      message_id,
      payload,
    })
  }
}
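For reference, a small sketch (not from the patch) of the item shape `DeviceMessage::from_hashmap` now expects: every attribute, including the new `messageID`, is a DynamoDB string (`AttributeValue::S`), where the old schema stored `createdAt` as a number. The attribute values below are made-up placeholders.

```rust
use std::collections::HashMap;

use aws_sdk_dynamodb::types::AttributeValue;

// Hand-built item mirroring one row returned by retrieve_messages.
fn example_item() -> HashMap<String, AttributeValue> {
  HashMap::from([
    (
      "deviceID".to_string(),
      AttributeValue::S("<olm identity key>".to_string()), // placeholder
    ),
    (
      "messageID".to_string(),
      // placeholder composite key: "[RFC 3339 timestamp]#[client UUID v4]"
      AttributeValue::S(
        "2023-08-21T12:34:56.789+00:00#11111111-2222-4333-8444-555555555555"
          .to_string(),
      ),
    ),
    (
      "payload".to_string(),
      AttributeValue::S(r#"{"example":"payload"}"#.to_string()), // placeholder
    ),
  ])
}

// Inside the tunnelbroker crate, DeviceMessage::from_hashmap(example_item())
// would yield a DeviceMessage whose message_id is the full composite string.
```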
diff --git a/services/tunnelbroker/src/database/message_id.rs b/services/tunnelbroker/src/database/message_id.rs
index 86ec09f63..e8e34fe2c 100644
--- a/services/tunnelbroker/src/database/message_id.rs
+++ b/services/tunnelbroker/src/database/message_id.rs
@@ -1,130 +1,130 @@
use chrono::{DateTime, Utc};

#[derive(Debug, derive_more::Display, derive_more::Error)]
-enum ParseMessageIdError {
+pub enum ParseMessageIdError {
  InvalidTimestamp(chrono::ParseError),
  InvalidFormat,
}

#[derive(Debug)]
-struct MessageID {
+pub struct MessageID {
  timestamp: DateTime<Utc>,
  client_message_id: String,
}

impl MessageID {
  pub fn new(client_message_id: String) -> Self {
    Self {
      timestamp: Utc::now(),
      client_message_id,
    }
  }
}

impl TryFrom<String> for MessageID {
  type Error = ParseMessageIdError;

  fn try_from(value: String) -> Result<Self, Self::Error> {
    let parts: Vec<&str> = value.splitn(2, '#').collect();
    if parts.len() != 2 {
      return Err(ParseMessageIdError::InvalidFormat);
    }

    let timestamp = DateTime::parse_from_rfc3339(parts[0])
      .map_err(ParseMessageIdError::InvalidTimestamp)?
      .with_timezone(&Utc);
    let client_message_id = parts[1].to_string();

    Ok(Self {
      timestamp,
      client_message_id,
    })
  }
}

impl From<MessageID> for String {
  fn from(value: MessageID) -> Self {
    format!(
      "{}#{}",
      value.timestamp.to_rfc3339(),
      value.client_message_id
    )
  }
}

#[cfg(test)]
mod message_id_tests {
  use super::*;
  use std::convert::TryInto;

  #[test]
  fn test_into_string() {
    let message_id = MessageID::new("abc123".to_string());
    let message_id_string: String = message_id.into();

    assert!(
      message_id_string.contains("abc123"),
      "Expected 'abc123' in the resulting string, but not found"
    );

    let parts: Vec<&str> = message_id_string.splitn(2, '#').collect();
    assert_eq!(
      parts.len(),
      2,
      "Expected the string to contain 2 parts separated by '#'"
    );
  }

  #[test]
  fn test_try_from_string_valid() {
    let client_message_id = "abc123".to_string();
    let timestamp = Utc::now().to_rfc3339();
    let valid_string = format!("{}#{}", timestamp, client_message_id);

    let message_id_result: Result<MessageID, ParseMessageIdError> =
      valid_string.try_into();
    assert!(
      message_id_result.is_ok(),
      "Expected Ok, but found {:?}",
      message_id_result
    );

    let message_id = message_id_result.unwrap();
    assert_eq!(message_id.client_message_id, client_message_id);
  }

  #[test]
  fn test_try_from_string_invalid_format() {
    let message_id = MessageID::new("abc123".to_string());
    let message_id_str: String = message_id.into();
    let converted_message_id: Result<MessageID, ParseMessageIdError> =
      message_id_str.try_into();
    assert!(
      converted_message_id.is_ok(),
      "Expected Ok, but found {:?}",
      converted_message_id
    );

    let message_id_after_conversion = converted_message_id.unwrap();
    assert_eq!(
      message_id_after_conversion.client_message_id,
      "abc123".to_string()
    );
  }

  #[test]
  fn test_conversion() {
    let client_message_id = "abc123".to_string();
    let timestamp = Utc::now().to_rfc3339();
    let valid_string = format!("{}#{}", timestamp, client_message_id);

    let message_id_result: Result<MessageID, ParseMessageIdError> =
      valid_string.try_into();
    assert!(
      message_id_result.is_ok(),
      "Expected Ok, but found {:?}",
      message_id_result
    );

    let message_id = message_id_result.unwrap();
    assert_eq!(message_id.client_message_id, client_message_id);
  }
}
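A usage sketch of the `MessageID` round trip introduced above, assuming it runs inside the tunnelbroker crate (the fields stay private, so only the public conversions are exercised); the client-ID literal is a stand-in for a UUID Version 4.

```rust
use std::convert::TryFrom;

use crate::database::message_id::MessageID;

fn message_id_round_trip() {
  // "client-uuid-v4" stands in for a client-generated UUID Version 4.
  let id = MessageID::new("client-uuid-v4".to_string());

  // Serializes to "<RFC 3339 timestamp>#client-uuid-v4".
  let key: String = id.into();

  // Parsing accepts exactly the "[createdAt]#[clientMessageID]" shape, and
  // converting back reproduces the same string.
  let parsed = MessageID::try_from(key.clone()).expect("well-formed key");
  let key_again: String = parsed.into();
  assert_eq!(key, key_again);

  // A value without the '#' separator is rejected with InvalidFormat.
  assert!(MessageID::try_from("not-a-message-id".to_string()).is_err());
}
```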
diff --git a/services/tunnelbroker/src/database/mod.rs b/services/tunnelbroker/src/database/mod.rs
index 547f5ae59..ae7c98526 100644
--- a/services/tunnelbroker/src/database/mod.rs
+++ b/services/tunnelbroker/src/database/mod.rs
@@ -1,130 +1,128 @@
use aws_config::SdkConfig;
use aws_sdk_dynamodb::error::SdkError;
use aws_sdk_dynamodb::operation::delete_item::{
  DeleteItemError, DeleteItemOutput,
};
-use aws_sdk_dynamodb::operation::put_item::{PutItemError, PutItemOutput};
+use aws_sdk_dynamodb::operation::put_item::PutItemError;
use aws_sdk_dynamodb::operation::query::QueryError;
use aws_sdk_dynamodb::{types::AttributeValue, Client};
use std::collections::HashMap;
use std::sync::Arc;
-use std::time::{SystemTime, UNIX_EPOCH};
use tracing::{debug, error};

use crate::constants::dynamodb::undelivered_messages::{
  PARTITION_KEY, PAYLOAD, SORT_KEY, TABLE_NAME,
};

pub mod message;
pub mod message_id;

+use crate::database::message_id::MessageID;
pub use message::*;

#[derive(Clone)]
pub struct DatabaseClient {
  client: Arc<Client>,
}

-pub fn unix_timestamp() -> u64 {
-  SystemTime::now()
-    .duration_since(UNIX_EPOCH)
-    .expect("System time is misconfigured")
-    .as_secs()
-}
-
pub fn handle_ddb_error<E>(db_error: SdkError<E>) -> tonic::Status {
  match db_error {
    SdkError::TimeoutError(_) | SdkError::ServiceError(_) => {
      tonic::Status::unavailable("please retry")
    }
    e => {
      error!("Encountered an unexpected error: {}", e);
      tonic::Status::failed_precondition("unexpected error")
    }
  }
}

impl DatabaseClient {
  pub fn new(aws_config: &SdkConfig) -> Self {
    let client = Client::new(aws_config);

    DatabaseClient {
      client: Arc::new(client),
    }
  }

  pub async fn persist_message(
    &self,
    device_id: &str,
    payload: &str,
-  ) -> Result<PutItemOutput, SdkError<PutItemError>> {
+    client_message_id: &str,
+  ) -> Result<String, SdkError<PutItemError>> {
+    let message_id: String =
+      MessageID::new(client_message_id.to_string()).into();
+
    let device_av = AttributeValue::S(device_id.to_string());
    let payload_av = AttributeValue::S(payload.to_string());
-    let created_av = AttributeValue::N(unix_timestamp().to_string());
+    let message_id_av = AttributeValue::S(message_id.clone());

    let request = self
      .client
      .put_item()
      .table_name(TABLE_NAME)
      .item(PARTITION_KEY, device_av)
-      .item(SORT_KEY, created_av)
+      .item(SORT_KEY, message_id_av)
      .item(PAYLOAD, payload_av);

    debug!("Persisting message to device: {}", &device_id);

-    request.send().await
+    request.send().await?;
+    Ok(message_id)
  }

  pub async fn retrieve_messages(
    &self,
    device_id: &str,
  ) -> Result<Vec<HashMap<String, AttributeValue>>, SdkError<QueryError>> {
    debug!("Retrieving messages for device: {}", device_id);

    let response = self
      .client
      .query()
      .table_name(TABLE_NAME)
      .key_condition_expression(format!("{} = :u", PARTITION_KEY))
      .expression_attribute_values(
        ":u",
        AttributeValue::S(device_id.to_string()),
      )
      .consistent_read(true)
      .send()
      .await?;

    debug!("Retrieved {} messages for {}", response.count, device_id);

    match response.items {
      None => Ok(Vec::new()),
      Some(items) => Ok(items.to_vec()),
    }
  }

  pub async fn delete_message(
    &self,
    device_id: &str,
-    created_at: &str,
+    message_id: &str,
  ) -> Result<DeleteItemOutput, SdkError<DeleteItemError>> {
    debug!("Deleting message for device: {}", device_id);

    let key = HashMap::from([
      (
        PARTITION_KEY.to_string(),
        AttributeValue::S(device_id.to_string()),
      ),
      (
        SORT_KEY.to_string(),
-        AttributeValue::N(created_at.to_string()),
+        AttributeValue::S(message_id.to_string()),
      ),
    ]);

    self
      .client
      .delete_item()
      .table_name(TABLE_NAME)
      .set_key(Some(key))
      .send()
      .await
  }
}
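The gRPC handler below (and the WebSocket handler later in this diff) currently passes the literal string "message_id" as `client_message_id`. A sketch of how a caller could use the reshaped API with a real client-generated UUID; `persist_and_ack` and its error type are illustrative, not part of the patch, and the sketch assumes the `uuid` crate and in-crate access to `DatabaseClient`.

```rust
use crate::database::DatabaseClient;

// Illustrative caller: persist a payload under a client-generated UUID and
// keep the returned composite messageID for the later delete.
async fn persist_and_ack(
  db_client: &DatabaseClient,
  device_id: &str,
  payload: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
  let client_message_id = uuid::Uuid::new_v4().to_string();

  // persist_message now returns the "[RFC 3339 timestamp]#[UUID]" sort key
  // it wrote, instead of the raw PutItemOutput.
  let message_id = db_client
    .persist_message(device_id, payload, &client_message_id)
    .await?;

  // Once the payload has actually been delivered, the same key is what
  // delete_message expects as the sort key.
  db_client.delete_message(device_id, &message_id).await?;
  Ok(())
}
```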
diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs
index 4df46a6ca..8998eb697 100644
--- a/services/tunnelbroker/src/grpc/mod.rs
+++ b/services/tunnelbroker/src/grpc/mod.rs
@@ -1,86 +1,86 @@
mod proto {
  tonic::include_proto!("tunnelbroker");
}

use lapin::{options::BasicPublishOptions, BasicProperties};
use proto::tunnelbroker_service_server::{
  TunnelbrokerService, TunnelbrokerServiceServer,
};
use proto::Empty;
use tonic::transport::Server;
use tracing::debug;

use crate::database::{handle_ddb_error, DatabaseClient};
use crate::{constants, CONFIG};

struct TunnelbrokerGRPC {
  client: DatabaseClient,
  amqp_channel: lapin::Channel,
}

pub fn handle_amqp_error(error: lapin::Error) -> tonic::Status {
  match error {
    lapin::Error::SerialisationError(_) | lapin::Error::ParsingError(_) => {
      tonic::Status::invalid_argument("Invalid argument")
    }
    _ => tonic::Status::internal("Internal Error"),
  }
}

#[tonic::async_trait]
impl TunnelbrokerService for TunnelbrokerGRPC {
  async fn send_message_to_device(
    &self,
    request: tonic::Request<proto::MessageToDevice>,
  ) -> Result<tonic::Response<Empty>, tonic::Status> {
    let message = request.into_inner();

    debug!("Received message for {}", &message.device_id);
    self
      .client
-      .persist_message(&message.device_id, &message.payload)
+      .persist_message(&message.device_id, &message.payload, "message_id")
      .await
      .map_err(handle_ddb_error)?;

    self
      .amqp_channel
      .basic_publish(
        "",
        &message.device_id,
        BasicPublishOptions::default(),
        message.payload.as_bytes(),
        BasicProperties::default(),
      )
      .await
      .map_err(handle_amqp_error)?;

    let response = tonic::Response::new(Empty {});
    Ok(response)
  }
}

pub async fn run_server(
  client: DatabaseClient,
  ampq_connection: &lapin::Connection,
) -> Result<(), tonic::transport::Error> {
  let addr = format!("[::]:{}", CONFIG.grpc_port)
    .parse()
    .expect("Unable to parse gRPC address");

  let amqp_channel = ampq_connection
    .create_channel()
    .await
    .expect("Unable to create amqp channel");

  tracing::info!("gRPC server listening on {}", &addr);
  Server::builder()
    .http2_keepalive_interval(Some(constants::GRPC_KEEP_ALIVE_PING_INTERVAL))
    .http2_keepalive_timeout(Some(constants::GRPC_KEEP_ALIVE_PING_TIMEOUT))
    .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC {
      client,
      amqp_channel,
    }))
    .serve(addr)
    .await
}
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
index 65bf6f31e..0b4bc2e75 100644
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -1,274 +1,275 @@
use aws_sdk_dynamodb::error::SdkError;
use aws_sdk_dynamodb::operation::put_item::PutItemError;
use derive_more;
use futures_util::stream::SplitSink;
use futures_util::SinkExt;
use futures_util::StreamExt;
use hyper_tungstenite::{tungstenite::Message, WebSocketStream};
use lapin::message::Delivery;
use lapin::options::{
  BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions,
  QueueDeclareOptions, QueueDeleteOptions,
};
use lapin::types::FieldTable;
use lapin::BasicProperties;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tracing::{debug, error, info};
use tunnelbroker_messages::{session::DeviceTypes, Messages};

use crate::database::{self, DatabaseClient, DeviceMessage};
use crate::error::Error;
use crate::identity;

pub struct DeviceInfo {
  pub device_id: String,
  pub notify_token: Option<String>,
  pub device_type: DeviceTypes,
  pub device_app_version: Option<String>,
  pub device_os: Option<String>,
}

pub struct WebsocketSession<S> {
  tx: SplitSink<WebSocketStream<S>, Message>,
  db_client: DatabaseClient,
  pub device_info: DeviceInfo,
  amqp_channel: lapin::Channel,
  // Stream of messages from AMQP endpoint
  amqp_consumer: lapin::Consumer,
}

#[derive(
  Debug, derive_more::Display, derive_more::From, derive_more::Error,
)]
pub enum SessionError {
  InvalidMessage,
  SerializationError(serde_json::Error),
  MessageError(database::MessageErrors),
  AmqpError(lapin::Error),
  InternalError,
  UnauthorizedDevice,
  PersistenceError(SdkError<PutItemError>),
}

pub fn consume_error<T>(result: Result<T, SessionError>) {
  if let Err(e) = result {
    error!("{}", e)
  }
}

// Parse a session request and retrieve the device information
pub async fn handle_first_message_from_device(
  message: &str,
) -> Result<DeviceInfo, Error> {
  let serialized_message = serde_json::from_str::<Messages>(message)?;

  match serialized_message {
    Messages::ConnectionInitializationMessage(mut session_info) => {
      let device_info = DeviceInfo {
        device_id: session_info.device_id.clone(),
        notify_token: session_info.notify_token.take(),
        device_type: session_info.device_type,
        device_app_version: session_info.device_app_version.take(),
        device_os: session_info.device_os.take(),
      };

      // Authenticate device
      debug!("Authenticating device: {}", &session_info.device_id);
      let auth_request = identity::verify_user_access_token(
        &session_info.user_id,
        &device_info.device_id,
        &session_info.access_token,
      )
      .await;

      match auth_request {
        Err(e) => {
          error!("Failed to complete request to identity service: {:?}", e);
          return Err(SessionError::InternalError.into());
        }
        Ok(false) => {
          info!("Device failed authentication: {}", &session_info.device_id);
          return Err(SessionError::UnauthorizedDevice.into());
        }
        Ok(true) => {
          debug!(
            "Successfully authenticated device: {}",
            &session_info.device_id
          );
        }
      }

      Ok(device_info)
    }
    _ => {
      debug!("Received invalid request");
      Err(SessionError::InvalidMessage.into())
    }
  }
}

impl<S: AsyncRead + AsyncWrite + Unpin> WebsocketSession<S> {
  pub async fn from_frame(
    tx: SplitSink<WebSocketStream<S>, Message>,
    db_client: DatabaseClient,
    frame: Message,
    amqp_channel: &lapin::Channel,
  ) -> Result<WebsocketSession<S>, Error> {
    let device_info = match frame {
      Message::Text(payload) => {
        handle_first_message_from_device(&payload).await?
      }
      _ => {
        error!("Client sent wrong frame type for establishing connection");
        return Err(SessionError::InvalidMessage.into());
      }
    };

    // We don't currently have a use case to interact directly with the queue,
    // however, we need to declare a queue for a given device
    amqp_channel
      .queue_declare(
        &device_info.device_id,
        QueueDeclareOptions::default(),
        FieldTable::default(),
      )
      .await?;

    let amqp_consumer = amqp_channel
      .basic_consume(
        &device_info.device_id,
        "tunnelbroker",
        BasicConsumeOptions::default(),
        FieldTable::default(),
      )
      .await?;

    Ok(WebsocketSession {
      tx,
      db_client,
      device_info,
      amqp_channel: amqp_channel.clone(),
      amqp_consumer,
    })
  }

  pub async fn handle_websocket_frame_from_device(
    &self,
    msg: Message,
  ) -> Result<(), SessionError> {
    let text_msg = match msg {
      Message::Text(payload) => payload,
      _ => {
        error!("Client sent invalid message type");
        return Err(SessionError::InvalidMessage);
      }
    };
    let serialized_message = serde_json::from_str::<Messages>(&text_msg)?;

    match serialized_message {
      Messages::MessageToDevice(message_to_device) => {
        debug!("Received message for {}", message_to_device.device_id);
        self
          .db_client
          .persist_message(
            message_to_device.device_id.as_str(),
            message_to_device.payload.as_str(),
+            "message_id",
          )
          .await?;

        self
          .amqp_channel
          .basic_publish(
            "",
            &message_to_device.device_id,
            BasicPublishOptions::default(),
            message_to_device.payload.as_bytes(),
            BasicProperties::default(),
          )
          .await?;
      }
      _ => {
        error!("Client sent invalid message type");
        return Err(SessionError::InvalidMessage);
      }
    }

    Ok(())
  }

  pub async fn next_amqp_message(
    &mut self,
  ) -> Option<Result<Delivery, lapin::Error>> {
    self.amqp_consumer.next().await
  }

  pub async fn deliver_persisted_messages(
    &mut self,
  ) -> Result<(), SessionError> {
    // Check for persisted messages
    let messages = self
      .db_client
      .retrieve_messages(&self.device_info.device_id)
      .await
      .unwrap_or_else(|e| {
        error!("Error while retrieving messages: {}", e);
        Vec::new()
      });

    for message in messages {
      let device_message = DeviceMessage::from_hashmap(message)?;
      self.send_message_to_device(device_message.payload).await;
      if let Err(e) = self
        .db_client
-        .delete_message(&self.device_info.device_id, &device_message.created_at)
+        .delete_message(&self.device_info.device_id, &device_message.message_id)
        .await
      {
        error!("Failed to delete message: {}:", e);
      }
    }

    debug!(
      "Flushed messages for device: {}",
      &self.device_info.device_id
    );

    Ok(())
  }

  pub async fn send_message_to_device(&mut self, incoming_payload: String) {
    if let Err(e) = self.tx.send(Message::Text(incoming_payload)).await {
      error!("Failed to send message to device: {}", e);
    }
  }

  // Release WebSocket and remove from active connections
  pub async fn close(&mut self) {
    if let Err(e) = self.tx.close().await {
      debug!("Failed to close WebSocket session: {}", e);
    }

    if let Err(e) = self
      .amqp_channel
      .basic_cancel(
        self.amqp_consumer.tag().as_str(),
        BasicCancelOptions::default(),
      )
      .await
    {
      error!("Failed to cancel consumer: {}", e);
    }

    if let Err(e) = self
      .amqp_channel
      .queue_delete(
        self.device_info.device_id.as_str(),
        QueueDeleteOptions::default(),
      )
      .await
    {
      error!("Failed to delete queue: {}", e);
    }
  }
}
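The ordering that `deliver_persisted_messages` and the new `test_messages_order` integration test rely on comes from DynamoDB's Query returning items in ascending sort-key order by default, i.e. the oldest `messageID` (earliest timestamp prefix) first. If that assumption were ever to be made explicit, the query in `retrieve_messages` could state it with the same builder API; a sketch, not part of the patch, with an illustrative function name.

```rust
use std::collections::HashMap;

use aws_sdk_dynamodb::types::AttributeValue;
use aws_sdk_dynamodb::Client;

use crate::constants::dynamodb::undelivered_messages::{
  PARTITION_KEY, TABLE_NAME,
};

// Same query as retrieve_messages, with the (default) ascending sort-key
// order spelled out via scan_index_forward(true).
async fn query_undelivered_in_order(
  client: &Client,
  device_id: &str,
) -> Result<
  Vec<HashMap<String, AttributeValue>>,
  Box<dyn std::error::Error + Send + Sync>,
> {
  let response = client
    .query()
    .table_name(TABLE_NAME)
    .key_condition_expression(format!("{} = :u", PARTITION_KEY))
    .expression_attribute_values(
      ":u",
      AttributeValue::S(device_id.to_string()),
    )
    .consistent_read(true)
    .scan_index_forward(true)
    .send()
    .await?;

  // Items arrive ordered by the messageID sort key, oldest first.
  Ok(response.items.unwrap_or_default())
}
```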