diff --git a/services/commtest/Cargo.lock b/services/commtest/Cargo.lock --- a/services/commtest/Cargo.lock +++ b/services/commtest/Cargo.lock @@ -698,6 +698,7 @@ "tonic-build 0.8.4", "tunnelbroker_messages", "url", + "uuid", ] [[package]] @@ -2672,6 +2673,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "uuid" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" +dependencies = [ + "getrandom", +] + [[package]] name = "valuable" version = "0.1.0" diff --git a/services/commtest/Cargo.toml b/services/commtest/Cargo.toml --- a/services/commtest/Cargo.toml +++ b/services/commtest/Cargo.toml @@ -30,6 +30,7 @@ reqwest = { version = "0.11", features = ["json", "multipart", "stream"] } serde = "1.0" comm-services-lib = { path = "../comm-services-lib" } +uuid = { version = "1.2", features = ["v4"] } [build-dependencies] tonic-build = "0.8" diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs --- a/services/commtest/src/tunnelbroker/socket.rs +++ b/services/commtest/src/tunnelbroker/socket.rs @@ -1,9 +1,12 @@ use crate::identity::device::DeviceInfo; -use futures_util::SinkExt; +use futures_util::{SinkExt, StreamExt}; use tokio::net::TcpStream; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; -use tunnelbroker_messages::{ConnectionInitializationMessage, DeviceTypes}; +use tunnelbroker_messages::{ + ConnectionInitializationMessage, DeviceTypes, MessageSentStatus, + MessageToDevice, MessageToDeviceRequest, SendConfirmation, +}; pub async fn create_socket( device_info: &DeviceInfo, @@ -32,3 +35,42 @@ socket } + +/// Sends `message` over `socket` and waits for the send confirmation. +/// +/// The message is wrapped in a `MessageToDeviceRequest` tagged with a fresh +/// UUID v4 client message ID and sent as a JSON text frame. Only the next +/// frame read from `socket` is inspected: it must parse as a +/// `SendConfirmation` listing `MessageSentStatus::Success` for that ID. +/// +/// Returns the generated client message ID. Fails on serialization, socket or +/// parse errors, if the stream ends first, or if success is not confirmed. +pub async fn send_message( + socket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>, + message: MessageToDevice, +) -> Result<String, Box<dyn std::error::Error>> { + let
client_message_id = uuid::Uuid::new_v4().to_string(); + let request = MessageToDeviceRequest { + client_message_id: client_message_id.clone(), + device_id: message.device_id, + payload: message.payload, + }; + + let serialized_request = serde_json::to_string(&request)?; + + socket.send(Message::Text(serialized_request)).await?; + + if let Some(response) = socket.next().await { + // Surface socket and decoding errors instead of masking them as a + // missing confirmation. + let confirmation: SendConfirmation = + serde_json::from_str(response?.to_text()?)?; + if confirmation + .client_message_ids + .contains(&MessageSentStatus::Success(client_message_id.clone())) + { + return Ok(client_message_id); + } + } + Err("Failed to confirm sent message".into()) +} diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_integration_tests.rs --- a/services/commtest/tests/tunnelbroker_integration_tests.rs +++ b/services/commtest/tests/tunnelbroker_integration_tests.rs @@ -6,15 +6,16 @@ use commtest::identity::olm_account_infos::{ MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, }; -use commtest::tunnelbroker::socket::create_socket; -use futures_util::{SinkExt, StreamExt}; +use commtest::tunnelbroker::socket::{create_socket, send_message}; +use futures_util::StreamExt; use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; use proto::MessageToDevice; use std::time::Duration; use tokio::time::sleep; -use tokio_tungstenite::tungstenite::Message; -use tunnelbroker_messages::{MessageToDeviceRequest, RefreshKeyRequest}; +use tunnelbroker_messages::{ + MessageToDevice as WebSocketMessageToDevice, RefreshKeyRequest, +}; #[tokio::test] async fn send_refresh_request() { @@ -61,39 +62,24 @@ let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; let messages = vec![ - MessageToDeviceRequest { - client_message_id: "5".to_string(), + WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "first message".to_string(), }, - MessageToDeviceRequest { - client_message_id: "2".to_string(), +
WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "second message".to_string(), }, - MessageToDeviceRequest { - client_message_id: "7".to_string(), + WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "third message".to_string(), }, ]; - let serialized_messages: Vec<_> = messages - .iter() - .map(|message| { - serde_json::to_string(message) - .expect("Failed to serialize message to device") - }) - .map(Message::text) - .collect(); - - let (mut sender_socket, _) = create_socket(&sender).await.split(); + let mut sender_socket = create_socket(&sender).await; - for msg in serialized_messages.clone() { - sender_socket - .send(msg) - .await - .expect("Failed to send the message over WebSocket"); + for msg in messages.clone() { + send_message(&mut sender_socket, msg).await.unwrap(); } // Wait a specified duration to ensure that message had time to persist diff --git a/services/commtest/tests/tunnelbroker_persist_tests.rs b/services/commtest/tests/tunnelbroker_persist_tests.rs --- a/services/commtest/tests/tunnelbroker_persist_tests.rs +++ b/services/commtest/tests/tunnelbroker_persist_tests.rs @@ -1,19 +1,19 @@ mod proto { tonic::include_proto!("tunnelbroker"); } - use commtest::identity::device::create_device; use commtest::identity::olm_account_infos::{ MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, }; -use commtest::tunnelbroker::socket::create_socket; -use futures_util::{SinkExt, StreamExt}; +use commtest::tunnelbroker::socket::{create_socket, send_message}; +use futures_util::StreamExt; use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; use proto::MessageToDevice; use std::time::Duration; use tokio::time::sleep; -use tokio_tungstenite::tungstenite::Message; -use tunnelbroker_messages::{MessageToDeviceRequest, RefreshKeyRequest}; +use tunnelbroker_messages::{ + MessageToDevice as WebSocketMessageToDevice, RefreshKeyRequest, +}; /// Tests that a message to an offline device gets pushed to dynamodb /// then
recalled once a device connects @@ -64,21 +64,15 @@ let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; // Send message to not connected client - let payload = "persisted message"; - let request = MessageToDeviceRequest { - client_message_id: "mockID".to_string(), + let mut sender_socket = create_socket(&sender).await; + + let request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), - payload: payload.to_string(), + payload: "persisted message".to_string(), }; - - let serialized_request = serde_json::to_string(&request) - .expect("Failed to serialize message to device"); - - let mut sender_socket = create_socket(&sender).await; - sender_socket - .send(Message::Text(serialized_request)) + send_message(&mut sender_socket, request.clone()) .await - .expect("Failed to send message"); + .unwrap(); // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; @@ -89,6 +83,6 @@ // Receive message if let Some(Ok(response)) = receiver_socket.next().await { let received_payload = response.to_text().unwrap(); - assert_eq!(payload, received_payload); + assert_eq!(request.payload, received_payload); }; } diff --git a/shared/tunnelbroker_messages/src/messages/message_to_device.rs b/shared/tunnelbroker_messages/src/messages/message_to_device.rs --- a/shared/tunnelbroker_messages/src/messages/message_to_device.rs +++ b/shared/tunnelbroker_messages/src/messages/message_to_device.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, PartialEq, Debug)] +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] #[serde(tag = "type", rename_all = "camelCase")] pub struct MessageToDevice { #[serde(rename = "deviceID")]