// Patch summary (tunnelbroker): replaces the local `DeviceMessage` struct with
// `tunnelbroker_messages::MessageToDevice` and changes what is published to AMQP
// from the raw payload bytes to a JSON-serialized `MessageToDevice`.
//
// Per file, as shown in the hunks below:
//   - database/message.rs: removes `DeviceMessage`; `from_hashmap` moves into a new
//     `MessageToDeviceExt` trait that is implemented for `MessageToDevice`.
//   - grpc/mod.rs: keeps the value returned by `persist_message` as `message_id`,
//     builds a `MessageToDevice`, and publishes `serde_json::to_string` of it.
//   - websockets/mod.rs: parses each UTF-8 AMQP delivery with `serde_json::from_str`
//     and forwards only its `payload` field as `Message::Text`.
//   - websockets/session.rs: both publish sites (the loop over stored messages and the
//     message-request path) serialize a `MessageToDevice` before `basic_publish`.
//
// NOTE(review): in websockets/mod.rs the result of `serde_json::from_str` is
// `.unwrap()`ed, so a delivery that is valid UTF-8 but does not deserialize panics
// instead of reaching the existing `error!("Invalid payload")` branch. The removed
// lines published raw payload bytes, so anything enqueued in that form would hit this
// path. Suggest matching on the result and logging on `Err` -- TODO confirm whether
// old-format messages can still be in a queue when this ships.
// NOTE(review): grpc/mod.rs maps a serialization failure to
// `tonic::Status::invalid_argument`; presumably an internal-error status is closer,
// since the failure occurs after the message has been persisted -- confirm intent.
// NOTE(review): generic arguments look stripped from this copy of the patch
// (`HashMap,`, `Result;`, `Box;`, `from_str::(message)`); the hunks are left untouched
// here -- verify against the original diff before applying.
diff --git a/services/tunnelbroker/src/database/message.rs b/services/tunnelbroker/src/database/message.rs --- a/services/tunnelbroker/src/database/message.rs +++ b/services/tunnelbroker/src/database/message.rs @@ -1,27 +1,27 @@ use std::collections::HashMap; use aws_sdk_dynamodb::types::AttributeValue; +use tunnelbroker_messages::MessageToDevice; use crate::constants::dynamodb::undelivered_messages::{ DEVICE_ID, MESSAGE_ID, PAYLOAD, }; -#[derive(Debug)] -pub struct DeviceMessage { - pub device_id: String, - pub message_id: String, - pub payload: String, -} - #[derive(Debug, derive_more::Display, derive_more::Error)] pub enum MessageErrors { SerializationError, } -impl DeviceMessage { - pub fn from_hashmap( +pub trait MessageToDeviceExt { + fn from_hashmap( + hashmap: HashMap, + ) -> Result; +} + +impl MessageToDeviceExt for MessageToDevice { + fn from_hashmap( hashmap: HashMap, - ) -> Result { + ) -> Result { let device_id: String = hashmap .get(DEVICE_ID) .ok_or(MessageErrors::SerializationError)? @@ -41,7 +41,7 @@ .map_err(|_| MessageErrors::SerializationError)? 
.to_string(); - Ok(DeviceMessage { + Ok(MessageToDevice { device_id, message_id, payload, diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs --- a/services/tunnelbroker/src/grpc/mod.rs +++ b/services/tunnelbroker/src/grpc/mod.rs @@ -9,6 +9,7 @@ use proto::Empty; use tonic::transport::Server; use tracing::debug; +use tunnelbroker_messages::MessageToDevice; use crate::constants::CLIENT_RMQ_MSG_PRIORITY; use crate::database::{handle_ddb_error, DatabaseClient}; @@ -40,19 +41,28 @@ let client_message_id = uuid::Uuid::new_v4().to_string(); - self + let message_id = self .client .persist_message(&message.device_id, &message.payload, &client_message_id) .await .map_err(handle_ddb_error)?; + let message_to_device = MessageToDevice { + device_id: message.device_id.clone(), + payload: message.payload, + message_id, + }; + + let serialized_message = serde_json::to_string(&message_to_device) + .map_err(|_| tonic::Status::invalid_argument("Invalid argument"))?; + self .amqp_channel .basic_publish( "", &message.device_id, BasicPublishOptions::default(), - message.payload.as_bytes(), + serialized_message.as_bytes(), BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY), ) .await diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -16,7 +16,9 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tracing::{debug, error, info}; -use tunnelbroker_messages::{MessageSentStatus, MessageToDeviceRequestStatus}; +use tunnelbroker_messages::{ + MessageSentStatus, MessageToDevice, MessageToDeviceRequestStatus, +}; type BoxedError = Box; @@ -169,7 +171,8 @@ tokio::select! 
{ Some(Ok(delivery)) = session.next_amqp_message() => { if let Ok(message) = std::str::from_utf8(&delivery.data) { - session.send_message_to_device(Message::Text(message.to_string())).await; + let message_to_device = serde_json::from_str::(message).unwrap(); + session.send_message_to_device(Message::Text(message_to_device.payload)).await; } else { error!("Invalid payload"); } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -22,10 +22,10 @@ use tunnelbroker_messages::{ message_to_device_request_status::Failure, message_to_device_request_status::MessageSentStatus, session::DeviceTypes, - MessageToDeviceRequest, Messages, + MessageToDevice, MessageToDeviceRequest, Messages, }; -use crate::database::{self, DatabaseClient, DeviceMessage}; +use crate::database::{self, DatabaseClient, MessageToDeviceExt}; use crate::identity; pub struct DeviceInfo { @@ -123,20 +123,22 @@ }); for message in messages { - let device_message = DeviceMessage::from_hashmap(message)?; + let message_to_device = MessageToDevice::from_hashmap(message)?; + + let serialized_message = serde_json::to_string(&message_to_device)?; amqp_channel .basic_publish( "", - &device_message.device_id, + &message_to_device.device_id, BasicPublishOptions::default(), - device_message.payload.as_bytes(), + serialized_message.as_bytes(), BasicProperties::default().with_priority(DDB_RMQ_MSG_PRIORITY), ) .await?; if let Err(e) = db_client - .delete_message(&device_info.device_id, &device_message.message_id) + .delete_message(&device_info.device_id, &message_to_device.message_id) .await { error!("Failed to delete message: {}:", e); @@ -207,13 +209,21 @@ ) .await?; + let message_to_device = MessageToDevice { + device_id: message_request.device_id.clone(), + payload: message_request.payload.clone(), + message_id: message_id.clone(), + }; + + let 
serialized_message = serde_json::to_string(&message_to_device)?; + let publish_result = self .amqp_channel .basic_publish( "", &message_request.device_id, BasicPublishOptions::default(), - message_request.payload.as_bytes(), + serialized_message.as_bytes(), BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY), ) .await;