Page Menu · Home · Phabricator

D9574.id32343.diff
No One · Temporary

D9574.id32343.diff

diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs
--- a/services/commtest/src/tunnelbroker/socket.rs
+++ b/services/commtest/src/tunnelbroker/socket.rs
@@ -7,7 +7,7 @@
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use tunnelbroker_messages::{
ConnectionInitializationMessage, DeviceTypes, MessageSentStatus,
- MessageToDeviceRequest, MessageToDeviceRequestStatus,
+ MessageToDevice, MessageToDeviceRequest, MessageToDeviceRequestStatus,
};
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
@@ -73,3 +73,16 @@
}
Err("Failed to confirm sent message".into())
}
+
+/// Reads the next frame from `socket` and returns its `MessageToDevice` payload; errors are propagated, not panicked on.
+pub async fn receive_message(
+  socket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
+) -> Result<String, Box<dyn std::error::Error>> {
+  let response = match socket.next().await {
+    Some(response) => response?,
+    None => return Err("Failed to receive message".into()),
+  };
+  let message = response.to_text()?;
+  let message_to_device = serde_json::from_str::<MessageToDevice>(message)?;
+  Ok(message_to_device.payload)
+}
diff --git a/services/commtest/tests/identity_tunnelbroker_tests.rs b/services/commtest/tests/identity_tunnelbroker_tests.rs
--- a/services/commtest/tests/identity_tunnelbroker_tests.rs
+++ b/services/commtest/tests/identity_tunnelbroker_tests.rs
@@ -2,7 +2,7 @@
create_device, DEVICE_TYPE, PLACEHOLDER_CODE_VERSION,
};
use commtest::service_addr;
-use commtest::tunnelbroker::socket::create_socket;
+use commtest::tunnelbroker::socket::{create_socket, receive_message};
use futures_util::StreamExt;
use grpc_clients::identity::protos::authenticated::OutboundKeysForUserRequest;
use grpc_clients::identity::protos::client::UploadOneTimeKeysRequest;
@@ -92,18 +92,13 @@
let device_info = create_device(None).await;
let mut socket = create_socket(&device_info).await;
+ let response = receive_message(&mut socket).await.unwrap();
+ let serialized_response: RefreshKeyRequest =
+ serde_json::from_str(&response).unwrap();
- // Have keyserver receive any websocket messages
- if let Some(Ok(response)) = socket.next().await {
- // Check that message received by keyserver matches what identity server
- // issued
- let serialized_response: RefreshKeyRequest =
- serde_json::from_str(response.to_text().unwrap()).unwrap();
-
- let expected_response = RefreshKeyRequest {
- device_id: device_info.device_id.to_string(),
- number_of_keys: 5,
- };
- assert_eq!(serialized_response, expected_response);
+ let expected_response = RefreshKeyRequest {
+ device_id: device_info.device_id.to_string(),
+ number_of_keys: 5,
};
+ assert_eq!(serialized_response, expected_response);
}
diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_integration_tests.rs
--- a/services/commtest/tests/tunnelbroker_integration_tests.rs
+++ b/services/commtest/tests/tunnelbroker_integration_tests.rs
@@ -8,9 +8,9 @@
};
use commtest::service_addr;
use commtest::tunnelbroker::socket::{
- create_socket, send_message, WebSocketMessageToDevice,
+ create_socket, receive_message, send_message, WebSocketMessageToDevice,
};
-use futures_util::StreamExt;
+
use proto::tunnelbroker_service_client::TunnelbrokerServiceClient;
use proto::MessageToDevice;
use std::time::Duration;
@@ -48,12 +48,12 @@
.unwrap();
// Have keyserver receive any websocket messages
- let response = socket.next().await.unwrap().unwrap();
+ let response = receive_message(&mut socket).await.unwrap();
// Check that message received by keyserver matches what identity server
// issued
let serialized_response: RefreshKeyRequest =
- serde_json::from_str(response.to_text().unwrap()).unwrap();
+ serde_json::from_str(&response).unwrap();
assert_eq!(serialized_response, refresh_request);
}
@@ -89,11 +89,7 @@
let mut receiver_socket = create_socket(&receiver).await;
for msg in messages {
- if let Some(Ok(response)) = receiver_socket.next().await {
- let received_payload = response.to_text().unwrap();
- assert_eq!(msg.payload, received_payload);
- } else {
- panic!("Unable to receive message");
- }
+ let response = receive_message(&mut receiver_socket).await.unwrap();
+ assert_eq!(msg.payload, response);
}
}
diff --git a/services/commtest/tests/tunnelbroker_persist_tests.rs b/services/commtest/tests/tunnelbroker_persist_tests.rs
--- a/services/commtest/tests/tunnelbroker_persist_tests.rs
+++ b/services/commtest/tests/tunnelbroker_persist_tests.rs
@@ -7,9 +7,8 @@
};
use commtest::service_addr;
use commtest::tunnelbroker::socket::{
- create_socket, send_message, WebSocketMessageToDevice,
+ create_socket, receive_message, send_message, WebSocketMessageToDevice,
};
-use futures_util::StreamExt;
use proto::tunnelbroker_service_client::TunnelbrokerServiceClient;
use proto::MessageToDevice;
use std::time::Duration;
@@ -50,13 +49,13 @@
let mut socket = create_socket(&device_info).await;
// Have keyserver receive any websocket messages
- if let Some(Ok(response)) = socket.next().await {
- // Check that message received by keyserver matches what identity server
- // issued
- let serialized_response: RefreshKeyRequest =
- serde_json::from_str(response.to_text().unwrap()).unwrap();
- assert_eq!(serialized_response, refresh_request);
- };
+ let response = receive_message(&mut socket).await.unwrap();
+
+ // Check that message received by keyserver matches what identity server
+ // issued
+ let serialized_response: RefreshKeyRequest =
+ serde_json::from_str(&response).unwrap();
+ assert_eq!(serialized_response, refresh_request);
}
#[tokio::test]
@@ -78,12 +77,7 @@
// Wait a specified duration to ensure that message had time to persist
sleep(Duration::from_millis(100)).await;
- // Connect receiver
let mut receiver_socket = create_socket(&receiver).await;
-
- // Receive message
- if let Some(Ok(response)) = receiver_socket.next().await {
- let received_payload = response.to_text().unwrap();
- assert_eq!(request.payload, received_payload);
- };
+ let response = receive_message(&mut receiver_socket).await.unwrap();
+ assert_eq!(request.payload, response);
}
diff --git a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs
--- a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs
+++ b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs
@@ -2,7 +2,7 @@
use commtest::identity::olm_account_infos::{
DEFAULT_CLIENT_KEYS, MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2,
};
-use commtest::tunnelbroker::socket::create_socket;
+use commtest::tunnelbroker::socket::{create_socket, receive_message};
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::tungstenite::Message;
use tunnelbroker_messages::{
@@ -47,10 +47,8 @@
// Connect receiver to flush DDB and avoid polluting other tests
let mut receiver_socket = create_socket(&receiver).await;
- if let Some(Ok(response)) = receiver_socket.next().await {
- let received_payload = response.to_text().unwrap();
- assert_eq!(payload, received_payload);
- };
+ let receiver_response = receive_message(&mut receiver_socket).await.unwrap();
+ assert_eq!(payload, receiver_response);
}
#[tokio::test]
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -16,9 +16,7 @@
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tracing::{debug, error, info};
-use tunnelbroker_messages::{
- MessageSentStatus, MessageToDevice, MessageToDeviceRequestStatus,
-};
+use tunnelbroker_messages::{MessageSentStatus, MessageToDeviceRequestStatus};
type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;
@@ -171,8 +169,7 @@
tokio::select! {
Some(Ok(delivery)) = session.next_amqp_message() => {
if let Ok(message) = std::str::from_utf8(&delivery.data) {
- let message_to_device = serde_json::from_str::<MessageToDevice>(message).unwrap();
- session.send_message_to_device(Message::Text(message_to_device.payload)).await;
+ session.send_message_to_device(Message::Text(message.to_string())).await;
} else {
error!("Invalid payload");
}

File Metadata

Mime Type
text/plain
Expires
Sat, Nov 16, 5:20 PM (17 h, 34 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
2499366
Default Alt Text
D9574.id32343.diff (8 KB)

Event Timeline