Page MenuHomePhabricator

D9595.id32731.diff
No OneTemporary

D9595.id32731.diff

diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs
--- a/services/commtest/src/tunnelbroker/socket.rs
+++ b/services/commtest/src/tunnelbroker/socket.rs
@@ -6,7 +6,8 @@
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use tunnelbroker_messages::{
- ConnectionInitializationMessage, DeviceTypes, MessageSentStatus,
+ ConnectionInitializationMessage, ConnectionInitializationResponse,
+ ConnectionInitializationStatus, DeviceTypes, MessageSentStatus,
MessageToDevice, MessageToDeviceRequest, MessageToDeviceRequestStatus,
};
@@ -20,7 +21,10 @@
pub async fn create_socket(
device_info: &DeviceInfo,
-) -> WebSocketStream<MaybeTlsStream<TcpStream>> {
+) -> Result<
+ WebSocketStream<MaybeTlsStream<TcpStream>>,
+ Box<dyn std::error::Error>,
+> {
let (mut socket, _) = connect_async(service_addr::TUNNELBROKER_WS)
.await
.expect("Can't connect");
@@ -43,7 +47,16 @@
.await
.expect("Failed to send message");
- socket
+ if let Some(Ok(response)) = socket.next().await {
+ let response: ConnectionInitializationResponse =
+ serde_json::from_str(response.to_text().unwrap())?;
+ return match response.status {
+ ConnectionInitializationStatus::Success => Ok(socket),
+ ConnectionInitializationStatus::Error(err) => Err(err.into()),
+ };
+ }
+
+ Err("Failed to get response from Tunnelbroker".into())
}
pub async fn send_message(
diff --git a/services/commtest/tests/identity_tunnelbroker_tests.rs b/services/commtest/tests/identity_tunnelbroker_tests.rs
--- a/services/commtest/tests/identity_tunnelbroker_tests.rs
+++ b/services/commtest/tests/identity_tunnelbroker_tests.rs
@@ -10,23 +10,17 @@
use tunnelbroker_messages::RefreshKeyRequest;
#[tokio::test]
-#[should_panic]
async fn test_tunnelbroker_invalid_auth() {
let mut device_info = create_device(None).await;
device_info.access_token = "".to_string();
- let mut socket = create_socket(&device_info).await;
-
- socket
- .next()
- .await
- .expect("Failed to receive response")
- .expect("Failed to read the response");
+ let socket = create_socket(&device_info).await;
+ assert!(matches!(socket, Result::Err(_)))
}
#[tokio::test]
async fn test_tunnelbroker_valid_auth() {
let device_info = create_device(None).await;
- let mut socket = create_socket(&device_info).await;
+ let mut socket = create_socket(&device_info).await.unwrap();
socket
.next()
@@ -91,7 +85,7 @@
// Create session as a keyserver
let device_info = create_device(None).await;
- let mut socket = create_socket(&device_info).await;
+ let mut socket = create_socket(&device_info).await.unwrap();
for _ in 0..2 {
let response = receive_message(&mut socket).await.unwrap();
let serialized_response: RefreshKeyRequest =
diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_integration_tests.rs
--- a/services/commtest/tests/tunnelbroker_integration_tests.rs
+++ b/services/commtest/tests/tunnelbroker_integration_tests.rs
@@ -22,7 +22,7 @@
async fn send_refresh_request() {
// Create session as a keyserver
let device_info = create_device(None).await;
- let mut socket = create_socket(&device_info).await;
+ let mut socket = create_socket(&device_info).await.unwrap();
// Send request for keyserver to refresh keys (identity service)
let mut tunnelbroker_client =
@@ -77,7 +77,7 @@
},
];
- let mut sender_socket = create_socket(&sender).await;
+ let mut sender_socket = create_socket(&sender).await.unwrap();
for msg in messages.clone() {
send_message(&mut sender_socket, msg).await.unwrap();
@@ -86,7 +86,7 @@
// Wait a specified duration to ensure that message had time to persist
sleep(Duration::from_millis(100)).await;
- let mut receiver_socket = create_socket(&receiver).await;
+ let mut receiver_socket = create_socket(&receiver).await.unwrap();
for msg in messages {
let response = receive_message(&mut receiver_socket).await.unwrap();
diff --git a/services/commtest/tests/tunnelbroker_persist_tests.rs b/services/commtest/tests/tunnelbroker_persist_tests.rs
--- a/services/commtest/tests/tunnelbroker_persist_tests.rs
+++ b/services/commtest/tests/tunnelbroker_persist_tests.rs
@@ -47,7 +47,7 @@
// Wait a specified duration to ensure that message had time to persist
sleep(Duration::from_millis(100)).await;
- let mut socket = create_socket(&device_info).await;
+ let mut socket = create_socket(&device_info).await.unwrap();
// Have keyserver receive any websocket messages
let response = receive_message(&mut socket).await.unwrap();
@@ -64,7 +64,7 @@
let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await;
// Send message to not connected client
- let mut sender_socket = create_socket(&sender).await;
+ let mut sender_socket = create_socket(&sender).await.unwrap();
let request = WebSocketMessageToDevice {
device_id: receiver.device_id.clone(),
@@ -77,7 +77,7 @@
// Wait a specified duration to ensure that message had time to persist
sleep(Duration::from_millis(100)).await;
- let mut receiver_socket = create_socket(&receiver).await;
+ let mut receiver_socket = create_socket(&receiver).await.unwrap();
let response = receive_message(&mut receiver_socket).await.unwrap();
assert_eq!(request.payload, response);
}
diff --git a/services/commtest/tests/tunnelbroker_recipient_confirmation_tests.rs b/services/commtest/tests/tunnelbroker_recipient_confirmation_tests.rs
--- a/services/commtest/tests/tunnelbroker_recipient_confirmation_tests.rs
+++ b/services/commtest/tests/tunnelbroker_recipient_confirmation_tests.rs
@@ -19,7 +19,7 @@
let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await;
// send message to not connected client
- let mut sender_socket = create_socket(&sender).await;
+ let mut sender_socket = create_socket(&sender).await.unwrap();
let request = WebSocketMessageToDevice {
device_id: receiver.device_id.clone(),
@@ -33,7 +33,7 @@
// wait a specified duration to ensure that message had time to persist
sleep(Duration::from_millis(100)).await;
- let mut receiver_socket = create_socket(&receiver).await;
+ let mut receiver_socket = create_socket(&receiver).await.unwrap();
// receive message for the first time (without confirmation)
let Some(Ok(response)) = receiver_socket.next().await else {
@@ -50,7 +50,7 @@
.send(Close(None))
.await
.expect("Failed to send message");
- receiver_socket = create_socket(&receiver).await;
+ receiver_socket = create_socket(&receiver).await.unwrap();
// receive message for the second time
let response = receive_message(&mut receiver_socket).await.unwrap();
@@ -63,8 +63,8 @@
let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await;
// send message to connected client
- let mut receiver_socket = create_socket(&receiver).await;
- let mut sender_socket = create_socket(&sender).await;
+ let mut receiver_socket = create_socket(&receiver).await.unwrap();
+ let mut sender_socket = create_socket(&sender).await.unwrap();
let request = WebSocketMessageToDevice {
device_id: receiver.device_id.clone(),
@@ -89,7 +89,7 @@
.send(Close(None))
.await
.expect("Failed to send message");
- receiver_socket = create_socket(&receiver).await;
+ receiver_socket = create_socket(&receiver).await.unwrap();
// receive message for the second time
let response = receive_message(&mut receiver_socket).await.unwrap();
@@ -102,8 +102,8 @@
let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await;
// send message to connected client
- let mut receiver_socket = create_socket(&receiver).await;
- let mut sender_socket = create_socket(&sender).await;
+ let mut receiver_socket = create_socket(&receiver).await.unwrap();
+ let mut sender_socket = create_socket(&sender).await.unwrap();
let request = WebSocketMessageToDevice {
device_id: receiver.device_id.clone(),
@@ -157,8 +157,8 @@
let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await;
// send message to connected client
- let mut receiver_socket = create_socket(&receiver).await;
- let mut sender_socket = create_socket(&sender).await;
+ let mut receiver_socket = create_socket(&receiver).await.unwrap();
+ let mut sender_socket = create_socket(&sender).await.unwrap();
// send first message
let first_request = WebSocketMessageToDevice {
@@ -180,7 +180,7 @@
.expect("Failed to send message");
tokio::time::sleep(Duration::from_millis(200)).await;
- receiver_socket = create_socket(&receiver).await;
+ receiver_socket = create_socket(&receiver).await.unwrap();
// send second message
let second_request = WebSocketMessageToDevice {
diff --git a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs
--- a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs
+++ b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs
@@ -30,7 +30,7 @@
let serialized_request = serde_json::to_string(&request)
.expect("Failed to serialize message to device");
- let mut sender_socket = create_socket(&sender).await;
+ let mut sender_socket = create_socket(&sender).await.unwrap();
sender_socket
.send(Message::Text(serialized_request))
.await
@@ -46,7 +46,7 @@
};
// Connect receiver to flush DDB and avoid polluting other tests
- let mut receiver_socket = create_socket(&receiver).await;
+ let mut receiver_socket = create_socket(&receiver).await.unwrap();
let receiver_response = receive_message(&mut receiver_socket).await.unwrap();
assert_eq!(payload, receiver_response);
}
@@ -56,7 +56,7 @@
let sender = create_device(Some(&DEFAULT_CLIENT_KEYS)).await;
let message = "some bad json".to_string();
- let mut sender_socket = create_socket(&sender).await;
+ let mut sender_socket = create_socket(&sender).await.unwrap();
sender_socket
.send(Message::Text(message.clone()))
.await
@@ -76,7 +76,7 @@
async fn get_invalid_request_error() {
let sender = create_device(Some(&DEFAULT_CLIENT_KEYS)).await;
- let mut sender_socket = create_socket(&sender).await;
+ let mut sender_socket = create_socket(&sender).await.unwrap();
sender_socket
.send(Message::Binary(vec![]))
.await
diff --git a/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs b/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs
--- a/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs
+++ b/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs
@@ -12,7 +12,7 @@
let ping_message = vec![1, 2, 3, 4, 5];
- let mut socket = create_socket(&device).await;
+ let mut socket = create_socket(&device).await.unwrap();
socket
.send(Message::Ping(ping_message.clone()))
.await
@@ -34,7 +34,7 @@
async fn test_close_message() {
let device = create_device(Some(&MOCK_CLIENT_KEYS_1)).await;
- let mut socket = create_socket(&device).await;
+ let mut socket = create_socket(&device).await.unwrap();
socket
.send(Close(None))
.await
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -1,10 +1,11 @@
pub mod session;
use crate::database::DatabaseClient;
-use crate::websockets::session::initialize_amqp;
+use crate::websockets::session::{initialize_amqp, SessionError};
use crate::CONFIG;
use futures_util::stream::SplitSink;
-use futures_util::StreamExt;
+use futures_util::{SinkExt, StreamExt};
+use hyper::upgrade::Upgraded;
use hyper::{Body, Request, Response, StatusCode};
use hyper_tungstenite::tungstenite::Message;
use hyper_tungstenite::HyperWebsocket;
@@ -16,7 +17,10 @@
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tracing::{debug, error, info};
-use tunnelbroker_messages::{MessageSentStatus, MessageToDeviceRequestStatus};
+use tunnelbroker_messages::{
+ ConnectionInitializationStatus, MessageSentStatus,
+ MessageToDeviceRequestStatus,
+};
type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;
@@ -130,6 +134,29 @@
Ok(())
}
+async fn send_error_init_response(
+ error: SessionError,
+ mut outgoing: SplitSink<WebSocketStream<Upgraded>, Message>,
+) {
+ let error_response =
+ tunnelbroker_messages::ConnectionInitializationResponse {
+ status: ConnectionInitializationStatus::Error(error.to_string()),
+ };
+
+ match serde_json::to_string(&error_response) {
+ Ok(serialized_response) => {
+ if let Err(send_error) =
+ outgoing.send(Message::Text(serialized_response)).await
+ {
+ error!("Failed to send init error response: {:?}", send_error);
+ }
+ }
+ Err(ser_error) => {
+ error!("Failed to serialize the error response: {:?}", ser_error);
+ }
+ }
+}
+
/// Handler for any incoming websocket connections
async fn accept_connection(
hyper_ws: HyperWebsocket,
@@ -156,14 +183,27 @@
// request over the websocket connection
let mut session = if let Some(Ok(first_msg)) = incoming.next().await {
match initiate_session(outgoing, first_msg, db_client, amqp_channel).await {
- Ok(session) => session,
- Err(_) => {
+ Ok(mut session) => {
+ let response =
+ tunnelbroker_messages::ConnectionInitializationResponse {
+ status: ConnectionInitializationStatus::Success,
+ };
+ let serialized_response = serde_json::to_string(&response).unwrap();
+
+ session
+ .send_message_to_device(Message::Text(serialized_response))
+ .await;
+ session
+ }
+ Err((err, outgoing)) => {
error!("Failed to create session with device");
+ send_error_init_response(err, outgoing).await;
return;
}
}
} else {
error!("Failed to create session with device");
+ send_error_init_response(SessionError::InvalidMessage, outgoing).await;
return;
};

File Metadata

Mime Type
text/plain
Expires
Sun, Nov 17, 1:46 AM (20 h, 53 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
2510349
Default Alt Text
D9595.id32731.diff (13 KB)

Event Timeline