diff --git a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs
new file mode 100644
--- /dev/null
+++ b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs
@@ -0,0 +1,103 @@
+use commtest::identity::device::create_device;
+use commtest::identity::olm_account_infos::{
+  DEFAULT_CLIENT_KEYS, MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2,
+};
+use commtest::tunnelbroker::socket::create_socket;
+use futures_util::{SinkExt, StreamExt};
+use tokio_tungstenite::tungstenite::Message;
+use tunnelbroker_messages::{
+  MessageSentStatus, MessageToDeviceRequest, SendConfirmation,
+};
+
+// Tests of the responses Tunnelbroker sends back to a client
+// that is trying to send a message to another device
+
+#[tokio::test]
+async fn get_confirmation() {
+  let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await;
+  let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await;
+
+  let client_message_id = "mockID".to_string();
+
+  // Send a message to a device that is not connected yet
+  let payload = "persisted message";
+  let request = MessageToDeviceRequest {
+    client_message_id: client_message_id.clone(),
+    device_id: receiver.device_id.clone(),
+    payload: payload.to_string(),
+  };
+
+  let serialized_request = serde_json::to_string(&request)
+    .expect("Failed to serialize message to device");
+
+  let mut sender_socket = create_socket(&sender).await;
+  sender_socket
+    .send(Message::Text(serialized_request))
+    .await
+    .expect("Failed to send message");
+
+  if let Some(Ok(response)) = sender_socket.next().await {
+    let expected_response = SendConfirmation {
+      client_message_ids: vec![MessageSentStatus::Success(client_message_id)],
+    };
+    let expected_payload = serde_json::to_string(&expected_response).unwrap();
+    let received_payload = response.to_text().unwrap();
+    assert_eq!(received_payload, expected_payload);
+  } else {
+    panic!("Sender did not receive a send confirmation");
+  }
+
+  // Connect receiver to flush DDB and avoid polluting other tests
+  let mut receiver_socket = create_socket(&receiver).await;
+  if let Some(Ok(response)) = receiver_socket.next().await {
+    let received_payload = response.to_text().unwrap();
+    assert_eq!(payload, received_payload);
+  } else {
+    panic!("Receiver did not receive the persisted message");
+  }
+}
+
+#[tokio::test]
+async fn get_serialization_error() {
+  let sender = create_device(Some(&DEFAULT_CLIENT_KEYS)).await;
+  let message = "some bad json".to_string();
+
+  let mut sender_socket = create_socket(&sender).await;
+  sender_socket
+    .send(Message::Text(message.clone()))
+    .await
+    .expect("Failed to send message");
+
+  if let Some(Ok(response)) = sender_socket.next().await {
+    let expected_response = SendConfirmation {
+      client_message_ids: vec![MessageSentStatus::SerializationError(message)],
+    };
+    let expected_payload = serde_json::to_string(&expected_response).unwrap();
+    let received_payload = response.to_text().unwrap();
+    assert_eq!(received_payload, expected_payload);
+  } else {
+    panic!("Sender did not receive a serialization error response");
+  }
+}
+
+#[tokio::test]
+async fn get_invalid_request_error() {
+  let sender = create_device(Some(&DEFAULT_CLIENT_KEYS)).await;
+
+  let mut sender_socket = create_socket(&sender).await;
+  sender_socket
+    .send(Message::Binary(vec![]))
+    .await
+    .expect("Failed to send message");
+
+  if let Some(Ok(response)) = sender_socket.next().await {
+    let expected_response = SendConfirmation {
+      client_message_ids: vec![MessageSentStatus::InvalidRequest],
+    };
+    let expected_payload = serde_json::to_string(&expected_response).unwrap();
+    let received_payload = response.to_text().unwrap();
+    assert_eq!(received_payload, expected_payload);
+  } else {
+    panic!("Sender did not receive an invalid request response");
+  }
+}
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -16,6 +16,7 @@
 use tokio::io::{AsyncRead, AsyncWrite};
 use tokio::net::TcpListener;
 use tracing::{debug, error, info};
+use tunnelbroker_messages::{MessageSentStatus, SendConfirmation};
 
 type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;
 
@@ -194,10 +195,17 @@
             session.send_message_to_device(Message::Pong(msg)).await;
           }
           Message::Text(msg) => {
-            session::consume_error(session.handle_websocket_frame_from_device(msg).await);
+            let confirmation = session.handle_websocket_frame_from_device(msg).await;
+            let response = serde_json::to_string(&confirmation).unwrap();
+            session.send_message_to_device(Message::text(response)).await;
           }
           _ => {
             error!("Client sent invalid message type");
+            let confirmation = SendConfirmation {
+              client_message_ids: vec![MessageSentStatus::InvalidRequest],
+            };
+            let response = serde_json::to_string(&confirmation).unwrap();
+            session.send_message_to_device(Message::text(response)).await;
           }
         }
       },
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -15,7 +15,10 @@
 use tokio::io::AsyncRead;
 use tokio::io::AsyncWrite;
 use tracing::{debug, error, info};
-use tunnelbroker_messages::{session::DeviceTypes, Messages};
+use tunnelbroker_messages::{
+  send_confirmation::Failure, send_confirmation::MessageSentStatus,
+  session::DeviceTypes, MessageToDeviceRequest, Messages, SendConfirmation,
+};
 
 use crate::database::{self, DatabaseClient, DeviceMessage};
 use crate::error::Error;
@@ -153,45 +156,78 @@
     })
   }
 
-  pub async fn handle_websocket_frame_from_device(
+  /// Persists the message for the target device and publishes its payload to
+  /// that device's AMQP queue. If publishing fails, the persisted message is
+  /// deleted again and the AMQP error is returned.
+  pub async fn handle_message_to_device(
     &self,
-    msg: String,
+    message_request: &MessageToDeviceRequest,
   ) -> Result<(), SessionError> {
-    let serialized_message = serde_json::from_str::<Messages>(&msg)?;
+    let message_id = self
+      .db_client
+      .persist_message(
+        message_request.device_id.as_str(),
+        message_request.payload.as_str(),
+        message_request.client_message_id.as_str(),
+      )
+      .await?;
+
+    let publish_result = self
+      .amqp_channel
+      .basic_publish(
+        "",
+        &message_request.device_id,
+        BasicPublishOptions::default(),
+        message_request.payload.as_bytes(),
+        BasicProperties::default(),
+      )
+      .await;
+
+    if let Err(publish_error) = publish_result {
+      // Roll back with the same device ID the message was persisted under
+      // (the recipient's). A failed rollback is only logged so that the
+      // sender still receives the original AMQP error.
+      let delete_result = self
+        .db_client
+        .delete_message(&message_request.device_id, &message_id)
+        .await;
+      if let Err(delete_error) = delete_result {
+        error!("Failed to delete unpublished message: {:?}", delete_error);
+      }
+      return Err(SessionError::AmqpError(publish_error));
+    }
+    Ok(())
+  }
+
+  /// Parses a text frame received from the device, handles the request it
+  /// contains and returns the confirmation to send back to the device.
+  pub async fn handle_websocket_frame_from_device(
+    &self,
+    msg: String,
+  ) -> SendConfirmation {
+    let serialized_message = match serde_json::from_str::<Messages>(&msg) {
+      Ok(message) => message,
+      Err(_) => {
+        return SendConfirmation {
+          client_message_ids: vec![MessageSentStatus::SerializationError(msg)],
+        };
+      }
+    };
 
     match serialized_message {
-      Messages::MessageToDeviceRequest(message_to_device_request) => {
-        debug!(
-          "Received message for {}",
-          message_to_device_request.device_id
-        );
-        self
-          .db_client
-          .persist_message(
-            message_to_device_request.device_id.as_str(),
-            message_to_device_request.payload.as_str(),
-            message_to_device_request.client_message_id.as_str(),
-          )
-          .await?;
+      Messages::MessageToDeviceRequest(message_request) => {
+        debug!("Received message for {}", message_request.device_id);
 
-        self
-          .amqp_channel
-          .basic_publish(
-            "",
-            &message_to_device_request.device_id,
-            BasicPublishOptions::default(),
-            message_to_device_request.payload.as_bytes(),
-            BasicProperties::default(),
-          )
-          .await?;
+        let result = self.handle_message_to_device(&message_request).await;
+        self.get_send_confirmation(&message_request.client_message_id, result)
       }
       _ => {
         error!("Client sent invalid message type");
-        return Err(SessionError::InvalidMessage);
+        SendConfirmation {
+          client_message_ids: vec![MessageSentStatus::InvalidRequest],
+        }
       }
     }
-
-    Ok(())
   }
 
   pub async fn next_amqp_message(
@@ -269,4 +305,23 @@
       error!("Failed to delete queue: {}", e);
     }
   }
+
+  /// Builds the confirmation for a single client message from the result of
+  /// handling it: `Success` on `Ok`, `Error` carrying the error text on `Err`.
+  pub fn get_send_confirmation(
+    &self,
+    client_message_id: &str,
+    result: Result<(), SessionError>,
+  ) -> SendConfirmation {
+    let status = match result {
+      Ok(_) => MessageSentStatus::Success(client_message_id.to_string()),
+      Err(err) => MessageSentStatus::Error(Failure {
+        id: client_message_id.to_string(),
+        error: err.to_string(),
+      }),
+    };
+    SendConfirmation {
+      client_message_ids: vec![status],
+    }
+  }
 }