diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs index 8614aa964..7bfffa493 100644 --- a/services/commtest/src/tunnelbroker/socket.rs +++ b/services/commtest/src/tunnelbroker/socket.rs @@ -1,66 +1,75 @@ use crate::identity::device::DeviceInfo; use crate::service_addr; use futures_util::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; use tokio::net::TcpStream; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tunnelbroker_messages::{ ConnectionInitializationMessage, DeviceTypes, MessageSentStatus, - MessageToDevice, MessageToDeviceRequest, MessageToDeviceRequestStatus, + MessageToDeviceRequest, MessageToDeviceRequestStatus, }; +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] +#[serde(tag = "type", rename_all = "camelCase")] +pub struct WebSocketMessageToDevice { + #[serde(rename = "deviceID")] + pub device_id: String, + pub payload: String, +} + pub async fn create_socket( device_info: &DeviceInfo, ) -> WebSocketStream> { let (mut socket, _) = connect_async(service_addr::TUNNELBROKER_WS) .await .expect("Can't connect"); let session_request = ConnectionInitializationMessage { device_id: device_info.device_id.to_string(), access_token: device_info.access_token.to_string(), user_id: device_info.user_id.to_string(), notify_token: None, device_type: DeviceTypes::Keyserver, device_app_version: None, device_os: None, }; let serialized_request = serde_json::to_string(&session_request) .expect("Failed to serialize connection request"); socket .send(Message::Text(serialized_request)) .await .expect("Failed to send message"); socket } pub async fn send_message( socket: &mut WebSocketStream>, - message: MessageToDevice, + message: WebSocketMessageToDevice, ) -> Result> { let client_message_id = uuid::Uuid::new_v4().to_string(); let request = MessageToDeviceRequest { client_message_id: client_message_id.clone(), device_id: message.device_id, payload: message.payload, }; let serialized_request = serde_json::to_string(&request)?; socket.send(Message::Text(serialized_request)).await?; if let Some(Ok(response)) = socket.next().await { let confirmation: MessageToDeviceRequestStatus = serde_json::from_str(response.to_text().unwrap())?; if confirmation .client_message_ids .contains(&MessageSentStatus::Success(client_message_id.clone())) { return Ok(client_message_id); } } Err("Failed to confirm sent message".into()) } diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_integration_tests.rs index e24942177..4ef700fc4 100644 --- a/services/commtest/tests/tunnelbroker_integration_tests.rs +++ b/services/commtest/tests/tunnelbroker_integration_tests.rs @@ -1,99 +1,99 @@ mod proto { tonic::include_proto!("tunnelbroker"); } use commtest::identity::device::create_device; use commtest::identity::olm_account_infos::{ MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, }; use commtest::service_addr; -use commtest::tunnelbroker::socket::{create_socket, send_message}; +use commtest::tunnelbroker::socket::{ + create_socket, send_message, WebSocketMessageToDevice, +}; use futures_util::StreamExt; use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; use proto::MessageToDevice; use std::time::Duration; use tokio::time::sleep; -use tunnelbroker_messages::{ - MessageToDevice as WebSocketMessageToDevice, RefreshKeyRequest, -}; +use tunnelbroker_messages::RefreshKeyRequest; #[tokio::test] async fn send_refresh_request() { // Create session as a keyserver let device_info = create_device(None).await; let mut socket = create_socket(&device_info).await; // Send request for keyserver to refresh keys (identity service) let mut tunnelbroker_client = TunnelbrokerServiceClient::connect(service_addr::TUNNELBROKER_GRPC) .await .unwrap(); let refresh_request = RefreshKeyRequest { device_id: device_info.device_id.clone(), number_of_keys: 5, }; let payload = serde_json::to_string(&refresh_request).unwrap(); let request = MessageToDevice { device_id: device_info.device_id.clone(), payload, }; let grpc_message = tonic::Request::new(request); tunnelbroker_client .send_message_to_device(grpc_message) .await .unwrap(); // Have keyserver receive any websocket messages let response = socket.next().await.unwrap().unwrap(); // Check that message received by keyserver matches what identity server // issued let serialized_response: RefreshKeyRequest = serde_json::from_str(response.to_text().unwrap()).unwrap(); assert_eq!(serialized_response, refresh_request); } #[tokio::test] async fn test_messages_order() { let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; let messages = vec![ WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "first message".to_string(), }, WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "second message".to_string(), }, WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "third message".to_string(), }, ]; let mut sender_socket = create_socket(&sender).await; for msg in messages.clone() { send_message(&mut sender_socket, msg).await.unwrap(); } // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; let mut receiver_socket = create_socket(&receiver).await; for msg in messages { if let Some(Ok(response)) = receiver_socket.next().await { let received_payload = response.to_text().unwrap(); assert_eq!(msg.payload, received_payload); } else { panic!("Unable to receive message"); } } } diff --git a/services/commtest/tests/tunnelbroker_persist_tests.rs b/services/commtest/tests/tunnelbroker_persist_tests.rs index b3ace3760..d13af10b1 100644 --- a/services/commtest/tests/tunnelbroker_persist_tests.rs +++ b/services/commtest/tests/tunnelbroker_persist_tests.rs @@ -1,89 +1,89 @@ mod proto { tonic::include_proto!("tunnelbroker"); } use commtest::identity::device::create_device; use commtest::identity::olm_account_infos::{ MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, }; use commtest::service_addr; -use commtest::tunnelbroker::socket::{create_socket, send_message}; +use commtest::tunnelbroker::socket::{ + create_socket, send_message, WebSocketMessageToDevice, +}; use futures_util::StreamExt; use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; use proto::MessageToDevice; use std::time::Duration; use tokio::time::sleep; -use tunnelbroker_messages::{ - MessageToDevice as WebSocketMessageToDevice, RefreshKeyRequest, -}; +use tunnelbroker_messages::RefreshKeyRequest; /// Tests that a message to an offline device gets pushed to dynamodb /// then recalled once a device connects #[tokio::test] async fn persist_grpc_messages() { let device_info = create_device(None).await; // Send request for keyserver to refresh keys (identity service) let mut tunnelbroker_client = TunnelbrokerServiceClient::connect(service_addr::TUNNELBROKER_GRPC) .await .unwrap(); let refresh_request = RefreshKeyRequest { device_id: device_info.device_id.to_string(), number_of_keys: 5, }; let payload = serde_json::to_string(&refresh_request).unwrap(); let request = MessageToDevice { device_id: device_info.device_id.to_string(), payload, }; let grpc_message = tonic::Request::new(request); tunnelbroker_client .send_message_to_device(grpc_message) .await .unwrap(); // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; let mut socket = create_socket(&device_info).await; // Have keyserver receive any websocket messages if let Some(Ok(response)) = socket.next().await { // Check that message received by keyserver matches what identity server // issued let serialized_response: RefreshKeyRequest = serde_json::from_str(response.to_text().unwrap()).unwrap(); assert_eq!(serialized_response, refresh_request); }; } #[tokio::test] async fn persist_websocket_messages() { let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; // Send message to not connected client let mut sender_socket = create_socket(&sender).await; let request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "persisted message".to_string(), }; send_message(&mut sender_socket, request.clone()) .await .unwrap(); // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; // Connect receiver let mut receiver_socket = create_socket(&receiver).await; // Receive message if let Some(Ok(response)) = receiver_socket.next().await { let received_payload = response.to_text().unwrap(); assert_eq!(request.payload, received_payload); }; } diff --git a/shared/tunnelbroker_messages/src/messages/message_to_device.rs b/shared/tunnelbroker_messages/src/messages/message_to_device.rs index 31f39ab51..cefe2483e 100644 --- a/shared/tunnelbroker_messages/src/messages/message_to_device.rs +++ b/shared/tunnelbroker_messages/src/messages/message_to_device.rs @@ -1,30 +1,34 @@ //! Messages sent between Tunnelbroker and a device via WebSocket. use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] #[serde(tag = "type", rename_all = "camelCase")] pub struct MessageToDevice { #[serde(rename = "deviceID")] pub device_id: String, pub payload: String, + #[serde(rename = "messageID")] + pub message_id: String, } #[cfg(test)] mod message_to_device_tests { use super::*; #[test] fn test_message_to_device_deserialization() { let example_payload = r#"{ "type": "MessageToDevice", "deviceID": "alice", - "payload": "message from Bob" + "payload": "message from Bob", + "messageID": "id234" }"#; let request = serde_json::from_str::(example_payload).unwrap(); assert_eq!(request.device_id, "alice"); assert_eq!(request.payload, "message from Bob"); + assert_eq!(request.message_id, "id234"); } }