diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs --- a/services/commtest/src/tunnelbroker/socket.rs +++ b/services/commtest/src/tunnelbroker/socket.rs @@ -6,7 +6,8 @@ use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tunnelbroker_messages::{ - ConnectionInitializationMessage, DeviceTypes, MessageSentStatus, + ConnectionInitializationMessage, ConnectionInitializationResponse, + ConnectionInitializationStatus, DeviceTypes, MessageSentStatus, MessageToDevice, MessageToDeviceRequest, MessageToDeviceRequestStatus, }; @@ -20,7 +21,10 @@ pub async fn create_socket( device_info: &DeviceInfo, -) -> WebSocketStream> { +) -> Result< + WebSocketStream>, + Box, +> { let (mut socket, _) = connect_async(service_addr::TUNNELBROKER_WS) .await .expect("Can't connect"); @@ -43,7 +47,16 @@ .await .expect("Failed to send message"); - socket + if let Some(Ok(response)) = socket.next().await { + let response: ConnectionInitializationResponse = + serde_json::from_str(response.to_text().unwrap())?; + return match response.status { + ConnectionInitializationStatus::Success => Ok(socket), + ConnectionInitializationStatus::Error(err) => Err(err.into()), + }; + } + + Err("Failed to get response from Tunnelbroker".into()) } pub async fn send_message( diff --git a/services/commtest/tests/identity_tunnelbroker_tests.rs b/services/commtest/tests/identity_tunnelbroker_tests.rs --- a/services/commtest/tests/identity_tunnelbroker_tests.rs +++ b/services/commtest/tests/identity_tunnelbroker_tests.rs @@ -10,23 +10,17 @@ use tunnelbroker_messages::RefreshKeyRequest; #[tokio::test] -#[should_panic] async fn test_tunnelbroker_invalid_auth() { let mut device_info = create_device(None).await; device_info.access_token = "".to_string(); - let mut socket = create_socket(&device_info).await; - - socket - .next() - .await - .expect("Failed to receive response") - .expect("Failed to read the response"); + let socket = create_socket(&device_info).await; + assert!(matches!(socket, Result::Err(_))) } #[tokio::test] async fn test_tunnelbroker_valid_auth() { let device_info = create_device(None).await; - let mut socket = create_socket(&device_info).await; + let mut socket = create_socket(&device_info).await.unwrap(); socket .next() @@ -91,7 +85,7 @@ // Create session as a keyserver let device_info = create_device(None).await; - let mut socket = create_socket(&device_info).await; + let mut socket = create_socket(&device_info).await.unwrap(); for _ in 0..2 { let response = receive_message(&mut socket).await.unwrap(); let serialized_response: RefreshKeyRequest = diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_integration_tests.rs --- a/services/commtest/tests/tunnelbroker_integration_tests.rs +++ b/services/commtest/tests/tunnelbroker_integration_tests.rs @@ -22,7 +22,7 @@ async fn send_refresh_request() { // Create session as a keyserver let device_info = create_device(None).await; - let mut socket = create_socket(&device_info).await; + let mut socket = create_socket(&device_info).await.unwrap(); // Send request for keyserver to refresh keys (identity service) let mut tunnelbroker_client = @@ -77,7 +77,7 @@ }, ]; - let mut sender_socket = create_socket(&sender).await; + let mut sender_socket = create_socket(&sender).await.unwrap(); for msg in messages.clone() { send_message(&mut sender_socket, msg).await.unwrap(); @@ -86,7 +86,7 @@ // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; - let mut receiver_socket = create_socket(&receiver).await; + let mut receiver_socket = create_socket(&receiver).await.unwrap(); for msg in messages { let response = receive_message(&mut receiver_socket).await.unwrap(); diff --git a/services/commtest/tests/tunnelbroker_persist_tests.rs b/services/commtest/tests/tunnelbroker_persist_tests.rs --- a/services/commtest/tests/tunnelbroker_persist_tests.rs +++ b/services/commtest/tests/tunnelbroker_persist_tests.rs @@ -47,7 +47,7 @@ // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; - let mut socket = create_socket(&device_info).await; + let mut socket = create_socket(&device_info).await.unwrap(); // Have keyserver receive any websocket messages let response = receive_message(&mut socket).await.unwrap(); @@ -64,7 +64,7 @@ let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; // Send message to not connected client - let mut sender_socket = create_socket(&sender).await; + let mut sender_socket = create_socket(&sender).await.unwrap(); let request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), @@ -77,7 +77,7 @@ // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; - let mut receiver_socket = create_socket(&receiver).await; + let mut receiver_socket = create_socket(&receiver).await.unwrap(); let response = receive_message(&mut receiver_socket).await.unwrap(); assert_eq!(request.payload, response); } diff --git a/services/commtest/tests/tunnelbroker_recipient_confirmation_tests.rs b/services/commtest/tests/tunnelbroker_recipient_confirmation_tests.rs --- a/services/commtest/tests/tunnelbroker_recipient_confirmation_tests.rs +++ b/services/commtest/tests/tunnelbroker_recipient_confirmation_tests.rs @@ -18,7 +18,7 @@ let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; // Send message to not connected client - let mut sender_socket = create_socket(&sender).await; + let mut sender_socket = create_socket(&sender).await.unwrap(); let request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), @@ -32,7 +32,7 @@ // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; - let mut receiver_socket = create_socket(&receiver).await; + let mut receiver_socket = create_socket(&receiver).await.unwrap(); // receive message for the first time (without confirmation) if let Some(Ok(response)) = receiver_socket.next().await { @@ -49,7 +49,7 @@ .send(Close(None)) .await .expect("Failed to send message"); - receiver_socket = create_socket(&receiver).await; + receiver_socket = create_socket(&receiver).await.unwrap(); // receive message for the second time let response = receive_message(&mut receiver_socket).await.unwrap(); @@ -62,8 +62,8 @@ let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; // Send message to connected client - let mut receiver_socket = create_socket(&receiver).await; - let mut sender_socket = create_socket(&sender).await; + let mut receiver_socket = create_socket(&receiver).await.unwrap(); + let mut sender_socket = create_socket(&sender).await.unwrap(); let request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), @@ -88,7 +88,7 @@ .send(Close(None)) .await .expect("Failed to send message"); - receiver_socket = create_socket(&receiver).await; + receiver_socket = create_socket(&receiver).await.unwrap(); // receive message for the second time let response = receive_message(&mut receiver_socket).await.unwrap(); diff --git a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs --- a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs +++ b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs @@ -30,7 +30,7 @@ let serialized_request = serde_json::to_string(&request) .expect("Failed to serialize message to device"); - let mut sender_socket = create_socket(&sender).await; + let mut sender_socket = create_socket(&sender).await.unwrap(); sender_socket .send(Message::Text(serialized_request)) .await @@ -46,7 +46,7 @@ }; // Connect receiver to flush DDB and avoid polluting other tests - let mut receiver_socket = create_socket(&receiver).await; + let mut receiver_socket = create_socket(&receiver).await.unwrap(); let receiver_response = receive_message(&mut receiver_socket).await.unwrap(); assert_eq!(payload, receiver_response); } @@ -56,7 +56,7 @@ let sender = create_device(Some(&DEFAULT_CLIENT_KEYS)).await; let message = "some bad json".to_string(); - let mut sender_socket = create_socket(&sender).await; + let mut sender_socket = create_socket(&sender).await.unwrap(); sender_socket .send(Message::Text(message.clone())) .await @@ -76,7 +76,7 @@ async fn get_invalid_request_error() { let sender = create_device(Some(&DEFAULT_CLIENT_KEYS)).await; - let mut sender_socket = create_socket(&sender).await; + let mut sender_socket = create_socket(&sender).await.unwrap(); sender_socket .send(Message::Binary(vec![])) .await diff --git a/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs b/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs --- a/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs +++ b/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs @@ -12,7 +12,7 @@ let ping_message = vec![1, 2, 3, 4, 5]; - let mut socket = create_socket(&device).await; + let mut socket = create_socket(&device).await.unwrap(); socket .send(Message::Ping(ping_message.clone())) .await @@ -34,7 +34,7 @@ async fn test_close_message() { let device = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; - let mut socket = create_socket(&device).await; + let mut socket = create_socket(&device).await.unwrap(); socket .send(Close(None)) .await diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -4,7 +4,8 @@ use crate::websockets::session::initialize_amqp; use crate::CONFIG; use futures_util::stream::SplitSink; -use futures_util::StreamExt; +use futures_util::{SinkExt, StreamExt}; +use hyper::upgrade::Upgraded; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; @@ -16,7 +17,10 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tracing::{debug, error, info}; -use tunnelbroker_messages::{MessageSentStatus, MessageToDeviceRequestStatus}; +use tunnelbroker_messages::{ + ConnectionInitializationStatus, MessageSentStatus, + MessageToDeviceRequestStatus, +}; type BoxedError = Box; @@ -125,6 +129,29 @@ Ok(()) } +async fn send_error_init_response( + error: SessionError, + mut outgoing: SplitSink, Message>, +) { + let error_response = + tunnelbroker_messages::ConnectionInitializationResponse { + status: ConnectionInitializationStatus::Error(error.to_string()), + }; + + match serde_json::to_string(&error_response) { + Ok(serialized_response) => { + if let Err(send_error) = + outgoing.send(Message::Text(serialized_response)).await + { + error!("Failed to send init error response: {:?}", send_error); + } + } + Err(ser_error) => { + error!("Failed to serialize the error response: {:?}", ser_error); + } + } +} + /// Handler for any incoming websocket connections async fn accept_connection( hyper_ws: HyperWebsocket, @@ -151,14 +178,27 @@ // request over the websocket connection let mut session = if let Some(Ok(first_msg)) = incoming.next().await { match initiate_session(outgoing, first_msg, db_client, amqp_channel).await { - Ok(session) => session, - Err(_) => { + Ok(mut session) => { + let response = + tunnelbroker_messages::ConnectionInitializationResponse { + status: ConnectionInitializationStatus::Success, + }; + let serialized_response = serde_json::to_string(&response).unwrap(); + + session + .send_message_to_device(Message::Text(serialized_response)) + .await; + session + } + Err((err, tx)) => { error!("Failed to create session with device"); + send_error_init_response(err, tx).await; return; } } } else { error!("Failed to create session with device"); + send_error_init_response(SessionError::InvalidMessage, outgoing).await; return; };