commtest: add a receive_message helper and forward raw TunnelBroker payloads

TunnelBroker (services/tunnelbroker/src/websockets/mod.rs) now forwards each
AMQP delivery to the device as the UTF-8 text it received. Previously it
deserialized the delivery into MessageToDevice with unwrap() and sent only the
payload field. A delivery that is valid UTF-8 but not valid MessageToDevice JSON
therefore no longer hits that unwrap() panic; decoding moves to the client.

commtest (services/commtest/src/tunnelbroker/socket.rs) gains
receive_message(). It awaits the next websocket frame, parses its text as
MessageToDevice and returns the payload string. Each failure case is returned
as Err via ? rather than panicking inside a function whose signature already
returns Result:
  a stream that yields no message,
  a frame whose content cannot be read as text,
  text that is not a valid MessageToDevice.

The tests switch from hand-rolled socket.next() + to_text() handling to
receive_message(..).await.unwrap(). Several of them previously wrapped their
assertions in `if let Some(Ok(response)) = ... { ... }` with no else branch.
In those tests a missing message let the test pass without asserting anything;
now it fails the test.

NOTE(review): identity_tunnelbroker_tests.rs and
tunnelbroker_sender_confirmation_tests.rs still import futures_util::StreamExt.
Whether it is still used outside the hunks shown is not visible here; confirm
with cargo clippy.

NOTE(review): this patch text appears to have lost its line breaks and some of
its generic arguments in transit (e.g. `WebSocketStream>`, `Result>`).
Regenerate it from the working tree before applying.

diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs --- a/services/commtest/src/tunnelbroker/socket.rs +++ b/services/commtest/src/tunnelbroker/socket.rs @@ -7,7 +7,7 @@ use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tunnelbroker_messages::{ ConnectionInitializationMessage, DeviceTypes, MessageSentStatus, - MessageToDeviceRequest, MessageToDeviceRequestStatus, + MessageToDevice, MessageToDeviceRequest, MessageToDeviceRequestStatus, }; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] @@ -73,3 +73,14 @@ } Err("Failed to confirm sent message".into()) } + +pub async fn receive_message( + socket: &mut WebSocketStream>, +) -> Result> { + let Some(Ok(response)) = socket.next().await else { + return Err("Failed to receive message".into()); + }; + let message = response.to_text()?; + let message_to_device = serde_json::from_str::<MessageToDevice>(message)?; + Ok(message_to_device.payload) +} diff --git a/services/commtest/tests/identity_tunnelbroker_tests.rs b/services/commtest/tests/identity_tunnelbroker_tests.rs --- a/services/commtest/tests/identity_tunnelbroker_tests.rs +++ b/services/commtest/tests/identity_tunnelbroker_tests.rs @@ -2,7 +2,7 @@ create_device, DEVICE_TYPE, PLACEHOLDER_CODE_VERSION, }; use commtest::service_addr; -use commtest::tunnelbroker::socket::create_socket; +use commtest::tunnelbroker::socket::{create_socket, receive_message}; use futures_util::StreamExt; use grpc_clients::identity::protos::authenticated::OutboundKeysForUserRequest; use grpc_clients::identity::protos::client::UploadOneTimeKeysRequest; @@ -92,18 +92,13 @@ let device_info = create_device(None).await; let mut socket = create_socket(&device_info).await; + let response = receive_message(&mut socket).await.unwrap(); + let serialized_response: RefreshKeyRequest = + serde_json::from_str(&response).unwrap(); - // Have
keyserver receive any websocket messages - if let Some(Ok(response)) = socket.next().await { - // Check that message received by keyserver matches what identity server - // issued - let serialized_response: RefreshKeyRequest = - serde_json::from_str(response.to_text().unwrap()).unwrap(); - - let expected_response = RefreshKeyRequest { - device_id: device_info.device_id.to_string(), - number_of_keys: 5, - }; - assert_eq!(serialized_response, expected_response); + let expected_response = RefreshKeyRequest { + device_id: device_info.device_id.to_string(), + number_of_keys: 5, }; + assert_eq!(serialized_response, expected_response); } diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_integration_tests.rs --- a/services/commtest/tests/tunnelbroker_integration_tests.rs +++ b/services/commtest/tests/tunnelbroker_integration_tests.rs @@ -8,9 +8,9 @@ }; use commtest::service_addr; use commtest::tunnelbroker::socket::{ - create_socket, send_message, WebSocketMessageToDevice, + create_socket, receive_message, send_message, WebSocketMessageToDevice, }; -use futures_util::StreamExt; + use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; use proto::MessageToDevice; use std::time::Duration; @@ -48,12 +48,12 @@ .unwrap(); // Have keyserver receive any websocket messages - let response = socket.next().await.unwrap().unwrap(); + let response = receive_message(&mut socket).await.unwrap(); // Check that message received by keyserver matches what identity server // issued let serialized_response: RefreshKeyRequest = - serde_json::from_str(response.to_text().unwrap()).unwrap(); + serde_json::from_str(&response).unwrap(); assert_eq!(serialized_response, refresh_request); } @@ -89,11 +89,7 @@ let mut receiver_socket = create_socket(&receiver).await; for msg in messages { - if let Some(Ok(response)) = receiver_socket.next().await { - let received_payload = response.to_text().unwrap(); - assert_eq!(msg.payload,
received_payload); - } else { - panic!("Unable to receive message"); - } + let response = receive_message(&mut receiver_socket).await.unwrap(); + assert_eq!(msg.payload, response); } } diff --git a/services/commtest/tests/tunnelbroker_persist_tests.rs b/services/commtest/tests/tunnelbroker_persist_tests.rs --- a/services/commtest/tests/tunnelbroker_persist_tests.rs +++ b/services/commtest/tests/tunnelbroker_persist_tests.rs @@ -7,9 +7,8 @@ }; use commtest::service_addr; use commtest::tunnelbroker::socket::{ - create_socket, send_message, WebSocketMessageToDevice, + create_socket, receive_message, send_message, WebSocketMessageToDevice, }; -use futures_util::StreamExt; use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; use proto::MessageToDevice; use std::time::Duration; @@ -50,13 +49,13 @@ let mut socket = create_socket(&device_info).await; // Have keyserver receive any websocket messages - if let Some(Ok(response)) = socket.next().await { - // Check that message received by keyserver matches what identity server - // issued - let serialized_response: RefreshKeyRequest = - serde_json::from_str(response.to_text().unwrap()).unwrap(); - assert_eq!(serialized_response, refresh_request); - }; + let response = receive_message(&mut socket).await.unwrap(); + + // Check that message received by keyserver matches what identity server + // issued + let serialized_response: RefreshKeyRequest = + serde_json::from_str(&response).unwrap(); + assert_eq!(serialized_response, refresh_request); } #[tokio::test] @@ -78,12 +77,7 @@ // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; - // Connect receiver let mut receiver_socket = create_socket(&receiver).await; - - // Receive message - if let Some(Ok(response)) = receiver_socket.next().await { - let received_payload = response.to_text().unwrap(); - assert_eq!(request.payload, received_payload); - }; + let response = receive_message(&mut
receiver_socket).await.unwrap(); + assert_eq!(request.payload, response); } diff --git a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs --- a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs +++ b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs @@ -2,7 +2,7 @@ use commtest::identity::olm_account_infos::{ DEFAULT_CLIENT_KEYS, MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, }; -use commtest::tunnelbroker::socket::create_socket; +use commtest::tunnelbroker::socket::{create_socket, receive_message}; use futures_util::{SinkExt, StreamExt}; use tokio_tungstenite::tungstenite::Message; use tunnelbroker_messages::{ @@ -47,10 +47,8 @@ // Connect receiver to flush DDB and avoid polluting other tests let mut receiver_socket = create_socket(&receiver).await; - if let Some(Ok(response)) = receiver_socket.next().await { - let received_payload = response.to_text().unwrap(); - assert_eq!(payload, received_payload); - }; + let receiver_response = receive_message(&mut receiver_socket).await.unwrap(); + assert_eq!(payload, receiver_response); } #[tokio::test] diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -16,9 +16,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tracing::{debug, error, info}; -use tunnelbroker_messages::{ - MessageSentStatus, MessageToDevice, MessageToDeviceRequestStatus, -}; +use tunnelbroker_messages::{MessageSentStatus, MessageToDeviceRequestStatus}; type BoxedError = Box; @@ -171,8 +169,7 @@ tokio::select!
{ Some(Ok(delivery)) = session.next_amqp_message() => { if let Ok(message) = std::str::from_utf8(&delivery.data) { - let message_to_device = serde_json::from_str::(message).unwrap(); - session.send_message_to_device(Message::Text(message_to_device.payload)).await; + session.send_message_to_device(Message::Text(message.to_string())).await; } else { error!("Invalid payload"); }