diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_integration_tests.rs
--- a/services/commtest/tests/tunnelbroker_integration_tests.rs
+++ b/services/commtest/tests/tunnelbroker_integration_tests.rs
@@ -46,47 +46,3 @@
     serde_json::from_str(&response.to_text().unwrap()).unwrap();
   assert_eq!(serialized_response, refresh_request);
 }
-
-/// Test that a message to an offline device gets pushed to dynamodb
-/// then recalled once a device connects
-#[tokio::test]
-async fn persist_messages() {
-  let device_info = create_device(None).await;
-
-  // Send request for keyserver to refresh keys (identity service)
-  let mut tunnelbroker_client =
-    TunnelbrokerServiceClient::connect("http://localhost:50051")
-      .await
-      .unwrap();
-
-  let refresh_request = RefreshKeyRequest {
-    device_id: device_info.device_id.to_string(),
-    number_of_keys: 5,
-  };
-
-  let payload = serde_json::to_string(&refresh_request).unwrap();
-  let request = MessageToDevice {
-    device_id: device_info.device_id.to_string(),
-    payload,
-  };
-  let grpc_message = tonic::Request::new(request);
-  tunnelbroker_client
-    .send_message_to_device(grpc_message)
-    .await
-    .unwrap();
-
-  // Wait one second to ensure that message had time to persist
-  use std::{thread, time};
-  let ten_millis = time::Duration::from_millis(50);
-  thread::sleep(ten_millis);
-
-  let mut socket = create_socket(&device_info).await;
-  // Have keyserver receive any websocket messages
-  if let Some(Ok(response)) = socket.next().await {
-    // Check that message received by keyserver matches what identity server
-    // issued
-    let serialized_response: RefreshKeyRequest =
-      serde_json::from_str(&response.to_text().unwrap()).unwrap();
-    assert_eq!(serialized_response, refresh_request);
-  };
-}
diff --git a/services/commtest/tests/tunnelbroker_persist_tests.rs b/services/commtest/tests/tunnelbroker_persist_tests.rs
new file mode 100644
--- /dev/null
+++ b/services/commtest/tests/tunnelbroker_persist_tests.rs
@@ -0,0 +1,112 @@
+mod proto {
+  tonic::include_proto!("tunnelbroker");
+}
+
+use commtest::identity::device::create_device;
+use commtest::identity::olm_account_infos::{
+  MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2,
+};
+use commtest::tunnelbroker::socket::create_socket;
+use futures_util::{SinkExt, StreamExt};
+use proto::tunnelbroker_service_client::TunnelbrokerServiceClient;
+use proto::MessageToDevice;
+use std::time::Duration;
+use tokio::time::{sleep, timeout};
+use tokio_tungstenite::tungstenite::Message;
+use tunnelbroker_messages::{
+  MessageToDevice as WebsocketMessageToDevice, RefreshKeyRequest,
+};
+
+/// How long to wait after sending so that tunnelbroker has time to persist
+/// the message before the recipient device connects.
+const PERSIST_DELAY: Duration = Duration::from_millis(100);
+
+/// Upper bound on how long a newly connected device may take to receive a
+/// persisted message, so a missing message fails the test instead of hanging.
+const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);
+
+/// Tests that a message sent over gRPC to an offline device gets pushed to
+/// DynamoDB, then recalled once the device connects.
+#[tokio::test]
+async fn persist_grpc_messages() {
+  let device_info = create_device(None).await;
+
+  // Send request for keyserver to refresh keys (identity service)
+  let mut tunnelbroker_client =
+    TunnelbrokerServiceClient::connect("http://localhost:50051")
+      .await
+      .unwrap();
+
+  let refresh_request = RefreshKeyRequest {
+    device_id: device_info.device_id.to_string(),
+    number_of_keys: 5,
+  };
+
+  let payload = serde_json::to_string(&refresh_request).unwrap();
+  let request = MessageToDevice {
+    device_id: device_info.device_id.to_string(),
+    payload,
+  };
+  let grpc_message = tonic::Request::new(request);
+  tunnelbroker_client
+    .send_message_to_device(grpc_message)
+    .await
+    .unwrap();
+
+  // Give tunnelbroker time to persist the message before the device connects
+  sleep(PERSIST_DELAY).await;
+
+  // Connect as the keyserver and wait for the persisted message. Fail (rather
+  // than silently pass or hang) if nothing is delivered.
+  let mut socket = create_socket(&device_info).await;
+  let response = timeout(RECEIVE_TIMEOUT, socket.next())
+    .await
+    .expect("Timed out waiting for persisted message")
+    .expect("Socket closed before persisted message was delivered")
+    .expect("Failed to read persisted message");
+
+  // Check that message received by keyserver matches what identity server
+  // issued
+  let received_request: RefreshKeyRequest =
+    serde_json::from_str(response.to_text().unwrap()).unwrap();
+  assert_eq!(received_request, refresh_request);
+}
+
+/// Tests that a message sent over a websocket to an offline device gets
+/// pushed to DynamoDB, then recalled once the device connects.
+#[tokio::test]
+async fn persist_websocket_messages() {
+  let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await;
+  let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await;
+
+  // Send a message to a device that is not connected yet
+  let payload = "persisted message";
+  let request = WebsocketMessageToDevice {
+    device_id: receiver.device_id.clone(),
+    payload: payload.to_string(),
+  };
+
+  let serialized_request = serde_json::to_string(&request)
+    .expect("Failed to serialize message to device");
+
+  let mut sender_socket = create_socket(&sender).await;
+  sender_socket
+    .send(Message::Text(serialized_request))
+    .await
+    .expect("Failed to send message");
+
+  // Give tunnelbroker time to persist the message before the device connects
+  sleep(PERSIST_DELAY).await;
+
+  // Connect the receiver and wait for the persisted message. Fail (rather
+  // than silently pass or hang) if nothing is delivered.
+  let mut receiver_socket = create_socket(&receiver).await;
+  let response = timeout(RECEIVE_TIMEOUT, receiver_socket.next())
+    .await
+    .expect("Timed out waiting for persisted message")
+    .expect("Socket closed before persisted message was delivered")
+    .expect("Failed to read persisted message");
+
+  let received_payload = response.to_text().unwrap();
+  assert_eq!(payload, received_payload);
+}
diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs
--- a/services/tunnelbroker/src/grpc/mod.rs
+++ b/services/tunnelbroker/src/grpc/mod.rs
@@ -49,7 +49,7 @@
         "",
         &message.device_id,
         BasicPublishOptions::default(),
-        &message.payload.as_bytes(),
+        message.payload.as_bytes(),
         BasicProperties::default(),
       )
       .await
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -1,11 +1,16 @@
+use aws_sdk_dynamodb::error::SdkError;
+use aws_sdk_dynamodb::operation::put_item::PutItemError;
 use derive_more;
 use futures_util::stream::SplitSink;
 use futures_util::SinkExt;
 use futures_util::StreamExt;
 use hyper_tungstenite::{tungstenite::Message, WebSocketStream};
 use lapin::message::Delivery;
-use lapin::options::{BasicConsumeOptions, QueueDeclareOptions};
+use lapin::options::{
+  BasicConsumeOptions, BasicPublishOptions, QueueDeclareOptions,
+};
 use lapin::types::FieldTable;
+use lapin::BasicProperties;
 use tokio::io::AsyncRead;
 use tokio::io::AsyncWrite;
 use tracing::{debug, error, info};
@@ -27,6 +32,7 @@
   tx: SplitSink<WebSocketStream<S>, Message>,
   db_client: DatabaseClient,
   pub device_info: DeviceInfo,
+  amqp_channel: lapin::Channel,
   // Stream of messages from AMQP endpoint
   amqp_consumer: lapin::Consumer,
 }
@@ -41,6 +47,7 @@
   AmqpError(lapin::Error),
   InternalError,
   UnauthorizedDevice,
+  PersistenceError(SdkError<PutItemError>),
 }
 
 pub fn consume_error<T>(result: Result<T, SessionError>) {
@@ -140,6 +147,7 @@
       tx,
       db_client,
       device_info,
+      amqp_channel: amqp_channel.clone(),
       amqp_consumer,
     })
   }
@@ -148,7 +156,52 @@
     &self,
     msg: Message,
   ) -> Result<(), SessionError> {
-    debug!("Received frame: {:?}", msg);
+    // NOTE(review): confirm the caller filters out Ping/Pong/Close frames
+    // before this point; otherwise they are rejected as invalid messages.
+    let text_msg = match msg {
+      Message::Text(payload) => payload,
+      _ => {
+        error!("Client sent invalid message type");
+        return Err(SessionError::InvalidMessage);
+      }
+    };
+
+    let message = serde_json::from_str::<Messages>(&text_msg)?;
+
+    match message {
+      Messages::MessageToDevice(message_to_device) => {
+        debug!("Received message for {}", message_to_device.device_id);
+        // Persist the message so it can be delivered once the recipient
+        // connects, even if it is offline right now
+        self
+          .db_client
+          .persist_message(
+            message_to_device.device_id.as_str(),
+            message_to_device.payload.as_str(),
+          )
+          .await?;
+
+        // Publish to the queue named after the recipient's device ID,
+        // presumably consumed by its session if it is currently connected.
+        // NOTE(review): an online recipient then gets this message via AMQP
+        // while a copy also stays persisted; confirm that copy is removed
+        // or deduplicated on delivery.
+        self
+          .amqp_channel
+          .basic_publish(
+            "",
+            &message_to_device.device_id,
+            BasicPublishOptions::default(),
+            message_to_device.payload.as_bytes(),
+            BasicProperties::default(),
+          )
+          .await?;
+      }
+      _ => {
+        error!("Client sent unsupported message type");
+        return Err(SessionError::InvalidMessage);
+      }
+    }
 
     Ok(())
   }