diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_integration_tests.rs --- a/services/commtest/tests/tunnelbroker_integration_tests.rs +++ b/services/commtest/tests/tunnelbroker_integration_tests.rs @@ -46,47 +46,3 @@ serde_json::from_str(&response.to_text().unwrap()).unwrap(); assert_eq!(serialized_response, refresh_request); } - -/// Test that a message to an offline device gets pushed to dynamodb -/// then recalled once a device connects -#[tokio::test] -async fn persist_messages() { - let device_info = create_device(None).await; - - // Send request for keyserver to refresh keys (identity service) - let mut tunnelbroker_client = - TunnelbrokerServiceClient::connect("http://localhost:50051") - .await - .unwrap(); - - let refresh_request = RefreshKeyRequest { - device_id: device_info.device_id.to_string(), - number_of_keys: 5, - }; - - let payload = serde_json::to_string(&refresh_request).unwrap(); - let request = MessageToDevice { - device_id: device_info.device_id.to_string(), - payload, - }; - let grpc_message = tonic::Request::new(request); - tunnelbroker_client - .send_message_to_device(grpc_message) - .await - .unwrap(); - - // Wait one second to ensure that message had time to persist - use std::{thread, time}; - let ten_millis = time::Duration::from_millis(50); - thread::sleep(ten_millis); - - let mut socket = create_socket(&device_info).await; - // Have keyserver receive any websocket messages - if let Some(Ok(response)) = socket.next().await { - // Check that message received by keyserver matches what identity server - // issued - let serialized_response: RefreshKeyRequest = - serde_json::from_str(&response.to_text().unwrap()).unwrap(); - assert_eq!(serialized_response, refresh_request); - }; -} diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_persist_tests.rs copy from services/commtest/tests/tunnelbroker_integration_tests.rs copy to 
services/commtest/tests/tunnelbroker_persist_tests.rs --- a/services/commtest/tests/tunnelbroker_integration_tests.rs +++ b/services/commtest/tests/tunnelbroker_persist_tests.rs @@ -2,55 +2,22 @@ tonic::include_proto!("tunnelbroker"); } use commtest::identity::device::create_device; +use commtest::identity::olm_account_infos::{ + MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, +}; use commtest::tunnelbroker::socket::create_socket; -use futures_util::StreamExt; +use futures_util::{SinkExt, StreamExt}; use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; use proto::MessageToDevice; -use tunnelbroker_messages::RefreshKeyRequest; +use tokio_tungstenite::tungstenite::Message; +use tunnelbroker_messages::{ + MessageToDevice as WebsocketMessageToDevice, RefreshKeyRequest, +}; -#[tokio::test] -async fn send_refresh_request() { - // Create session as a keyserver - let device_info = create_device(None).await; - let mut socket = create_socket(&device_info).await; - - // Send request for keyserver to refresh keys (identity service) - let mut tunnelbroker_client = - TunnelbrokerServiceClient::connect("http://localhost:50051") - .await - .unwrap(); - - let refresh_request = RefreshKeyRequest { - device_id: device_info.device_id.clone(), - number_of_keys: 5, - }; - - let payload = serde_json::to_string(&refresh_request).unwrap(); - let request = MessageToDevice { - device_id: device_info.device_id.clone(), - payload, - }; - let grpc_message = tonic::Request::new(request); - - tunnelbroker_client - .send_message_to_device(grpc_message) - .await - .unwrap(); - - // Have keyserver receive any websocket messages - let response = socket.next().await.unwrap().unwrap(); - - // Check that message received by keyserver matches what identity server - // issued - let serialized_response: RefreshKeyRequest = - serde_json::from_str(&response.to_text().unwrap()).unwrap(); - assert_eq!(serialized_response, refresh_request); -} - -/// Test that a message to an offline device gets pushed to
dynamodb +/// Tests that a message to an offline device gets pushed to dynamodb /// then recalled once a device connects #[tokio::test] -async fn persist_messages() { +async fn persist_grpc_messages() { let device_info = create_device(None).await; // Send request for keyserver to refresh keys (identity service) @@ -75,7 +42,7 @@ .await .unwrap(); - // Wait one second to ensure that message had time to persist + // Wait 50ms to ensure that message had time to persist use std::{thread, time}; let ten_millis = time::Duration::from_millis(50); thread::sleep(ten_millis); @@ -90,3 +57,45 @@ assert_eq!(serialized_response, refresh_request); }; } + +/// Tests that a message sent over a websocket to an offline device is +/// persisted and delivered once that device connects +#[tokio::test] +async fn persist_websocket_messages() { + let sender = create_device(Some(MOCK_CLIENT_KEYS_1.clone())).await; + let receiver = create_device(Some(MOCK_CLIENT_KEYS_2.clone())).await; + + // Send a message to a receiver that is not connected yet + let payload = "persisted message"; + let request = WebsocketMessageToDevice { + device_id: receiver.device_id.clone(), + payload: payload.to_string(), + }; + + let serialized_request = serde_json::to_string(&request) + .expect("Failed to serialize message to device"); + + let mut sender_socket = create_socket(&sender).await; + sender_socket + .send(Message::Text(serialized_request)) + .await + .expect("Failed to send message"); + + // Wait 500ms to ensure that message had time to persist + use std::{thread, time}; + let persist_delay = time::Duration::from_millis(500); + thread::sleep(persist_delay); + + // Connect receiver + let mut receiver_socket = create_socket(&receiver).await; + + // Receive the persisted message. Fail (rather than silently pass) if the + // socket closes or errors before anything arrives. + let response = receiver_socket + .next() + .await + .expect("Receiver socket closed before persisted message arrived") + .expect("Failed to receive persisted message"); + let received_payload = response.to_text().unwrap(); + assert_eq!(payload, received_payload); +} diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs --- a/services/tunnelbroker/src/grpc/mod.rs +++ b/services/tunnelbroker/src/grpc/mod.rs @@ -49,7 +49,7 @@ "", &message.device_id,
BasicPublishOptions::default(), - &message.payload.as_bytes(), + message.payload.as_bytes(), BasicProperties::default(), ) .await diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,11 +1,16 @@ +use aws_sdk_dynamodb::error::SdkError; +use aws_sdk_dynamodb::operation::put_item::PutItemError; use derive_more; use futures_util::stream::SplitSink; use futures_util::SinkExt; use futures_util::StreamExt; use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use lapin::message::Delivery; -use lapin::options::{BasicConsumeOptions, QueueDeclareOptions}; +use lapin::options::{ + BasicConsumeOptions, BasicPublishOptions, QueueDeclareOptions, +}; use lapin::types::FieldTable; +use lapin::BasicProperties; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tracing::{debug, error, info}; @@ -27,6 +32,7 @@ tx: SplitSink, Message>, db_client: DatabaseClient, pub device_info: DeviceInfo, + amqp_channel: lapin::Channel, // Stream of messages from AMQP endpoint amqp_consumer: lapin::Consumer, } @@ -41,6 +47,7 @@ AmqpError(lapin::Error), InternalError, UnauthorizedDevice, + PersistenceError(SdkError<PutItemError>), } pub fn consume_error(result: Result) { @@ -140,6 +147,7 @@ tx, db_client, device_info, + amqp_channel: amqp_channel.clone(), amqp_consumer, }) } @@ -148,7 +156,43 @@ &self, msg: Message, ) -> Result<(), SessionError> { - debug!("Received frame: {:?}", msg); + let text_msg = match msg { + Message::Text(payload) => payload, + _ => { + error!("Client sent invalid message type"); + return Err(SessionError::InvalidMessage); + } + }; + + let serialized_message = serde_json::from_str::<Messages>(&text_msg)?; + + match serialized_message { + Messages::MessageToDevice(message_to_device) => { + debug!("Received message for {}", message_to_device.device_id); + self + .db_client + .persist_message(
message_to_device.device_id.as_str(), + message_to_device.payload.as_str(), + ) + .await?; + + self + .amqp_channel + .basic_publish( + "", + &message_to_device.device_id, + BasicPublishOptions::default(), + message_to_device.payload.as_bytes(), + BasicProperties::default(), + ) + .await?; + } + _ => { + error!("Client sent invalid message type"); + return Err(SessionError::InvalidMessage); + } + } Ok(()) }