diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs index de2656c00..846b3766f 100644 --- a/services/commtest/src/tunnelbroker/socket.rs +++ b/services/commtest/src/tunnelbroker/socket.rs @@ -1,94 +1,107 @@ use crate::identity::device::DeviceInfo; use crate::service_addr; use futures_util::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; use tokio::net::TcpStream; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tunnelbroker_messages::{ - ConnectionInitializationMessage, DeviceTypes, MessageSentStatus, + ConnectionInitializationMessage, ConnectionInitializationResponse, + ConnectionInitializationStatus, DeviceTypes, MessageSentStatus, MessageToDevice, MessageToDeviceRequest, MessageToDeviceRequestStatus, }; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] #[serde(tag = "type", rename_all = "camelCase")] pub struct WebSocketMessageToDevice { #[serde(rename = "deviceID")] pub device_id: String, pub payload: String, } pub async fn create_socket( device_info: &DeviceInfo, -) -> WebSocketStream> { +) -> Result< + WebSocketStream>, + Box, +> { let (mut socket, _) = connect_async(service_addr::TUNNELBROKER_WS) .await .expect("Can't connect"); let session_request = ConnectionInitializationMessage { device_id: device_info.device_id.to_string(), access_token: device_info.access_token.to_string(), user_id: device_info.user_id.to_string(), notify_token: None, device_type: DeviceTypes::Keyserver, device_app_version: None, device_os: None, }; let serialized_request = serde_json::to_string(&session_request) .expect("Failed to serialize connection request"); socket .send(Message::Text(serialized_request)) .await .expect("Failed to send message"); - socket + if let Some(Ok(response)) = socket.next().await { + let response: ConnectionInitializationResponse = + serde_json::from_str(response.to_text().unwrap())?; + return match response.status { + ConnectionInitializationStatus::Success => Ok(socket), + ConnectionInitializationStatus::Error(err) => Err(err.into()), + }; + } + + Err("Failed to get response from Tunnelbroker".into()) } pub async fn send_message( socket: &mut WebSocketStream>, message: WebSocketMessageToDevice, ) -> Result> { let client_message_id = uuid::Uuid::new_v4().to_string(); let request = MessageToDeviceRequest { client_message_id: client_message_id.clone(), device_id: message.device_id, payload: message.payload, }; let serialized_request = serde_json::to_string(&request)?; socket.send(Message::Text(serialized_request)).await?; if let Some(Ok(response)) = socket.next().await { let confirmation: MessageToDeviceRequestStatus = serde_json::from_str(response.to_text().unwrap())?; if confirmation .client_message_ids .contains(&MessageSentStatus::Success(client_message_id.clone())) { return Ok(client_message_id); } } Err("Failed to confirm sent message".into()) } pub async fn receive_message( socket: &mut WebSocketStream>, ) -> Result> { let Some(Ok(response)) = socket.next().await else { return Err("Failed to receive message".into()); }; let message = response.to_text().expect("Failed to get response content"); let message_to_device = serde_json::from_str::(message) .expect("Failed to parse MessageToDevice from response"); let confirmation = tunnelbroker_messages::MessageReceiveConfirmation { message_ids: vec![message_to_device.message_id], }; let serialized_confirmation = serde_json::to_string(&confirmation).unwrap(); socket.send(Message::Text(serialized_confirmation)).await?; Ok(message_to_device.payload) } diff --git a/services/commtest/tests/identity_tunnelbroker_tests.rs b/services/commtest/tests/identity_tunnelbroker_tests.rs index 0fe132a59..183c7289d 100644 --- a/services/commtest/tests/identity_tunnelbroker_tests.rs +++ b/services/commtest/tests/identity_tunnelbroker_tests.rs @@ -1,107 +1,101 @@ use commtest::identity::device::{ create_device, DEVICE_TYPE, PLACEHOLDER_CODE_VERSION, }; use commtest::service_addr; use commtest::tunnelbroker::socket::{create_socket, receive_message}; use futures_util::StreamExt; use grpc_clients::identity::protos::authenticated::OutboundKeysForUserRequest; use grpc_clients::identity::protos::client::UploadOneTimeKeysRequest; use grpc_clients::identity::{get_auth_client, get_unauthenticated_client}; use tunnelbroker_messages::RefreshKeyRequest; #[tokio::test] -#[should_panic] async fn test_tunnelbroker_invalid_auth() { let mut device_info = create_device(None).await; device_info.access_token = "".to_string(); - let mut socket = create_socket(&device_info).await; - - socket - .next() - .await - .expect("Failed to receive response") - .expect("Failed to read the response"); + let socket = create_socket(&device_info).await; + assert!(matches!(socket, Result::Err(_))) } #[tokio::test] async fn test_tunnelbroker_valid_auth() { let device_info = create_device(None).await; - let mut socket = create_socket(&device_info).await; + let mut socket = create_socket(&device_info).await.unwrap(); socket .next() .await .expect("Failed to receive response") .expect("Failed to read the response"); } #[tokio::test] async fn test_refresh_keys_request_upon_depletion() { let identity_grpc_endpoint = service_addr::IDENTITY_GRPC.to_string(); let device_info = create_device(None).await; let mut identity_client = get_unauthenticated_client( &identity_grpc_endpoint, PLACEHOLDER_CODE_VERSION, DEVICE_TYPE.to_string(), ) .await .expect("Couldn't connect to identity service"); let upload_request = UploadOneTimeKeysRequest { user_id: device_info.user_id.clone(), device_id: device_info.device_id.clone(), access_token: device_info.access_token.clone(), content_one_time_pre_keys: vec!["content1".to_string()], notif_one_time_pre_keys: vec!["notif1".to_string()], }; identity_client .upload_one_time_keys(upload_request) .await .unwrap(); // Request outbound keys, which should trigger identity service to ask for more keys let mut client = get_auth_client( &identity_grpc_endpoint, device_info.user_id.clone(), device_info.device_id, device_info.access_token, PLACEHOLDER_CODE_VERSION, DEVICE_TYPE.to_string(), ) .await .expect("Couldn't connect to identity service"); let keyserver_request = OutboundKeysForUserRequest { user_id: device_info.user_id.clone(), }; println!("Getting keyserver info for user, {}", device_info.user_id); let _first_reponse = client .get_keyserver_keys(keyserver_request.clone()) .await .expect("Second keyserver keys request failed") .into_inner() .keyserver_info .unwrap(); // The current threshold is 5, but we only upload two. Should receive request // from Tunnelbroker to refresh keys // Create session as a keyserver let device_info = create_device(None).await; - let mut socket = create_socket(&device_info).await; + let mut socket = create_socket(&device_info).await.unwrap(); for _ in 0..2 { let response = receive_message(&mut socket).await.unwrap(); let serialized_response: RefreshKeyRequest = serde_json::from_str(&response).unwrap(); let expected_response = RefreshKeyRequest { device_id: device_info.device_id.to_string(), number_of_keys: 5, }; assert_eq!(serialized_response, expected_response); } } diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_integration_tests.rs index c8c2cdb20..c688628fe 100644 --- a/services/commtest/tests/tunnelbroker_integration_tests.rs +++ b/services/commtest/tests/tunnelbroker_integration_tests.rs @@ -1,95 +1,95 @@ mod proto { tonic::include_proto!("tunnelbroker"); } use commtest::identity::device::create_device; use commtest::identity::olm_account_infos::{ MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, }; use commtest::service_addr; use commtest::tunnelbroker::socket::{ create_socket, receive_message, send_message, WebSocketMessageToDevice, }; use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; use proto::MessageToDevice; use std::time::Duration; use tokio::time::sleep; use tunnelbroker_messages::RefreshKeyRequest; #[tokio::test] async fn send_refresh_request() { // Create session as a keyserver let device_info = create_device(None).await; - let mut socket = create_socket(&device_info).await; + let mut socket = create_socket(&device_info).await.unwrap(); // Send request for keyserver to refresh keys (identity service) let mut tunnelbroker_client = TunnelbrokerServiceClient::connect(service_addr::TUNNELBROKER_GRPC) .await .unwrap(); let refresh_request = RefreshKeyRequest { device_id: device_info.device_id.clone(), number_of_keys: 5, }; let payload = serde_json::to_string(&refresh_request).unwrap(); let request = MessageToDevice { device_id: device_info.device_id.clone(), payload, }; let grpc_message = tonic::Request::new(request); tunnelbroker_client .send_message_to_device(grpc_message) .await .unwrap(); // Have keyserver receive any websocket messages let response = receive_message(&mut socket).await.unwrap(); // Check that message received by keyserver matches what identity server // issued let serialized_response: RefreshKeyRequest = serde_json::from_str(&response).unwrap(); assert_eq!(serialized_response, refresh_request); } #[tokio::test] async fn test_messages_order() { let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; let messages = vec![ WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "first message".to_string(), }, WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "second message".to_string(), }, WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "third message".to_string(), }, ]; - let mut sender_socket = create_socket(&sender).await; + let mut sender_socket = create_socket(&sender).await.unwrap(); for msg in messages.clone() { send_message(&mut sender_socket, msg).await.unwrap(); } // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; - let mut receiver_socket = create_socket(&receiver).await; + let mut receiver_socket = create_socket(&receiver).await.unwrap(); for msg in messages { let response = receive_message(&mut receiver_socket).await.unwrap(); assert_eq!(msg.payload, response); } } diff --git a/services/commtest/tests/tunnelbroker_persist_tests.rs b/services/commtest/tests/tunnelbroker_persist_tests.rs index 8b5821be6..f01508178 100644 --- a/services/commtest/tests/tunnelbroker_persist_tests.rs +++ b/services/commtest/tests/tunnelbroker_persist_tests.rs @@ -1,83 +1,83 @@ mod proto { tonic::include_proto!("tunnelbroker"); } use commtest::identity::device::create_device; use commtest::identity::olm_account_infos::{ MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, }; use commtest::service_addr; use commtest::tunnelbroker::socket::{ create_socket, receive_message, send_message, WebSocketMessageToDevice, }; use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; use proto::MessageToDevice; use std::time::Duration; use tokio::time::sleep; use tunnelbroker_messages::RefreshKeyRequest; /// Tests that a message to an offline device gets pushed to dynamodb /// then recalled once a device connects #[tokio::test] async fn persist_grpc_messages() { let device_info = create_device(None).await; // Send request for keyserver to refresh keys (identity service) let mut tunnelbroker_client = TunnelbrokerServiceClient::connect(service_addr::TUNNELBROKER_GRPC) .await .unwrap(); let refresh_request = RefreshKeyRequest { device_id: device_info.device_id.to_string(), number_of_keys: 5, }; let payload = serde_json::to_string(&refresh_request).unwrap(); let request = MessageToDevice { device_id: device_info.device_id.to_string(), payload, }; let grpc_message = tonic::Request::new(request); tunnelbroker_client .send_message_to_device(grpc_message) .await .unwrap(); // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; - let mut socket = create_socket(&device_info).await; + let mut socket = create_socket(&device_info).await.unwrap(); // Have keyserver receive any websocket messages let response = receive_message(&mut socket).await.unwrap(); // Check that message received by keyserver matches what identity server // issued let serialized_response: RefreshKeyRequest = serde_json::from_str(&response).unwrap(); assert_eq!(serialized_response, refresh_request); } #[tokio::test] async fn persist_websocket_messages() { let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; // Send message to not connected client - let mut sender_socket = create_socket(&sender).await; + let mut sender_socket = create_socket(&sender).await.unwrap(); let request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "persisted message".to_string(), }; send_message(&mut sender_socket, request.clone()) .await .unwrap(); // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; - let mut receiver_socket = create_socket(&receiver).await; + let mut receiver_socket = create_socket(&receiver).await.unwrap(); let response = receive_message(&mut receiver_socket).await.unwrap(); assert_eq!(request.payload, response); } diff --git a/services/commtest/tests/tunnelbroker_recipient_confirmation_tests.rs b/services/commtest/tests/tunnelbroker_recipient_confirmation_tests.rs index 3dea3b3bd..6055588d1 100644 --- a/services/commtest/tests/tunnelbroker_recipient_confirmation_tests.rs +++ b/services/commtest/tests/tunnelbroker_recipient_confirmation_tests.rs @@ -1,197 +1,197 @@ use commtest::identity::device::create_device; use commtest::identity::olm_account_infos::{ MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, }; use commtest::tunnelbroker::socket::{ create_socket, receive_message, send_message, WebSocketMessageToDevice, }; use futures_util::{SinkExt, StreamExt}; use std::time::Duration; use tokio::time::sleep; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::Message::Close; use tunnelbroker_messages::MessageToDevice; #[tokio::test] async fn deliver_until_confirmation_not_connected() { let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; // send message to not connected client - let mut sender_socket = create_socket(&sender).await; + let mut sender_socket = create_socket(&sender).await.unwrap(); let request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "message from deliver_until_confirmation_not_connected" .to_string(), }; send_message(&mut sender_socket, request.clone()) .await .unwrap(); // wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; - let mut receiver_socket = create_socket(&receiver).await; + let mut receiver_socket = create_socket(&receiver).await.unwrap(); // receive message for the first time (without confirmation) let Some(Ok(response)) = receiver_socket.next().await else { panic!("Receiving first message failed") }; let message = response.to_text().unwrap(); let message_to_device = serde_json::from_str::(message).unwrap(); assert_eq!(request.payload, message_to_device.payload); // restart connection receiver_socket .send(Close(None)) .await .expect("Failed to send message"); - receiver_socket = create_socket(&receiver).await; + receiver_socket = create_socket(&receiver).await.unwrap(); // receive message for the second time let response = receive_message(&mut receiver_socket).await.unwrap(); assert_eq!(request.payload, response); } #[tokio::test] async fn deliver_until_confirmation_connected() { let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; // send message to connected client - let mut receiver_socket = create_socket(&receiver).await; - let mut sender_socket = create_socket(&sender).await; + let mut receiver_socket = create_socket(&receiver).await.unwrap(); + let mut sender_socket = create_socket(&sender).await.unwrap(); let request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "message from deliver_until_confirmation_connected".to_string(), }; send_message(&mut sender_socket, request.clone()) .await .unwrap(); // receive message for the first time (without confirmation) let Some(Ok(response)) = receiver_socket.next().await else { panic!("Receiving first message failed") }; let message = response.to_text().unwrap(); let message_to_device = serde_json::from_str::(message).unwrap(); assert_eq!(request.payload, message_to_device.payload); // restart connection receiver_socket .send(Close(None)) .await .expect("Failed to send message"); - receiver_socket = create_socket(&receiver).await; + receiver_socket = create_socket(&receiver).await.unwrap(); // receive message for the second time let response = receive_message(&mut receiver_socket).await.unwrap(); assert_eq!(request.payload, response); } #[tokio::test] async fn test_confirming_deleted_message() { let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; // send message to connected client - let mut receiver_socket = create_socket(&receiver).await; - let mut sender_socket = create_socket(&sender).await; + let mut receiver_socket = create_socket(&receiver).await.unwrap(); + let mut sender_socket = create_socket(&sender).await.unwrap(); let request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "message to bo confirmed twice".to_string(), }; send_message(&mut sender_socket, request.clone()) .await .unwrap(); // receive a message let Some(Ok(response)) = receiver_socket.next().await else { panic!("Receiving first message failed") }; let message = response.to_text().unwrap(); let message_to_device = serde_json::from_str::(message).unwrap(); assert_eq!(request.payload, message_to_device.payload); let confirmation = tunnelbroker_messages::MessageReceiveConfirmation { message_ids: vec![message_to_device.message_id], }; let serialized_confirmation = serde_json::to_string(&confirmation).unwrap(); // send confirmation twice receiver_socket .send(Message::Text(serialized_confirmation.clone())) .await .expect("Error while sending confirmation"); receiver_socket .send(Message::Text(serialized_confirmation)) .await .expect("Error while sending confirmation"); // test if socket is still alive by sending and receiving a message let second_request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "second request".to_string(), }; send_message(&mut sender_socket, second_request.clone()) .await .unwrap(); let response = receive_message(&mut receiver_socket).await.unwrap(); assert_eq!(second_request.payload, response); } #[tokio::test] async fn test_confirming() { let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; // send message to connected client - let mut receiver_socket = create_socket(&receiver).await; - let mut sender_socket = create_socket(&sender).await; + let mut receiver_socket = create_socket(&receiver).await.unwrap(); + let mut sender_socket = create_socket(&sender).await.unwrap(); // send first message let first_request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "first request".to_string(), }; send_message(&mut sender_socket, first_request.clone()) .await .unwrap(); // receive a first message let response = receive_message(&mut receiver_socket).await.unwrap(); assert_eq!(first_request.payload, response); // restart connection receiver_socket .send(Close(None)) .await .expect("Failed to send message"); tokio::time::sleep(Duration::from_millis(200)).await; - receiver_socket = create_socket(&receiver).await; + receiver_socket = create_socket(&receiver).await.unwrap(); // send second message let second_request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "second request".to_string(), }; send_message(&mut sender_socket, second_request.clone()) .await .unwrap(); // make sure only second message is received let response = receive_message(&mut receiver_socket).await.unwrap(); assert_eq!(second_request.payload, response) } diff --git a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs index fc1016e7a..b83565e06 100644 --- a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs +++ b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs @@ -1,93 +1,93 @@ use commtest::identity::device::create_device; use commtest::identity::olm_account_infos::{ DEFAULT_CLIENT_KEYS, MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, }; use commtest::tunnelbroker::socket::{create_socket, receive_message}; use futures_util::{SinkExt, StreamExt}; use tokio_tungstenite::tungstenite::Message; use tunnelbroker_messages::{ MessageSentStatus, MessageToDeviceRequest, MessageToDeviceRequestStatus, }; /// Tests of responses sent from Tunnelberoker to client /// trying to send message to other device #[tokio::test] async fn get_confirmation() { let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; let client_message_id = "mockID".to_string(); // Send message to not connected client let payload = "persisted message"; let request = MessageToDeviceRequest { client_message_id: client_message_id.clone(), device_id: receiver.device_id.clone(), payload: payload.to_string(), }; let serialized_request = serde_json::to_string(&request) .expect("Failed to serialize message to device"); - let mut sender_socket = create_socket(&sender).await; + let mut sender_socket = create_socket(&sender).await.unwrap(); sender_socket .send(Message::Text(serialized_request)) .await .expect("Failed to send message"); if let Some(Ok(response)) = sender_socket.next().await { let expected_response = MessageToDeviceRequestStatus { client_message_ids: vec![MessageSentStatus::Success(client_message_id)], }; let expected_payload = serde_json::to_string(&expected_response).unwrap(); let received_payload = response.to_text().unwrap(); assert_eq!(received_payload, expected_payload); }; // Connect receiver to flush DDB and avoid polluting other tests - let mut receiver_socket = create_socket(&receiver).await; + let mut receiver_socket = create_socket(&receiver).await.unwrap(); let receiver_response = receive_message(&mut receiver_socket).await.unwrap(); assert_eq!(payload, receiver_response); } #[tokio::test] async fn get_serialization_error() { let sender = create_device(Some(&DEFAULT_CLIENT_KEYS)).await; let message = "some bad json".to_string(); - let mut sender_socket = create_socket(&sender).await; + let mut sender_socket = create_socket(&sender).await.unwrap(); sender_socket .send(Message::Text(message.clone())) .await .expect("Failed to send message"); if let Some(Ok(response)) = sender_socket.next().await { let expected_response = MessageToDeviceRequestStatus { client_message_ids: vec![MessageSentStatus::SerializationError(message)], }; let expected_payload = serde_json::to_string(&expected_response).unwrap(); let received_payload = response.to_text().unwrap(); assert_eq!(received_payload, expected_payload); }; } #[tokio::test] async fn get_invalid_request_error() { let sender = create_device(Some(&DEFAULT_CLIENT_KEYS)).await; - let mut sender_socket = create_socket(&sender).await; + let mut sender_socket = create_socket(&sender).await.unwrap(); sender_socket .send(Message::Binary(vec![])) .await .expect("Failed to send message"); if let Some(Ok(response)) = sender_socket.next().await { let expected_response = MessageToDeviceRequestStatus { client_message_ids: vec![MessageSentStatus::InvalidRequest], }; let expected_payload = serde_json::to_string(&expected_response).unwrap(); let received_payload = response.to_text().unwrap(); assert_eq!(received_payload, expected_payload); }; } diff --git a/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs b/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs index 53cb52a66..e906accf9 100644 --- a/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs +++ b/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs @@ -1,49 +1,49 @@ use commtest::identity::device::create_device; use commtest::identity::olm_account_infos::MOCK_CLIENT_KEYS_1; use commtest::tunnelbroker::socket::create_socket; use futures_util::{SinkExt, StreamExt}; use tokio_tungstenite::tungstenite::{Error, Message, Message::Close}; /// Tests for message types defined in tungstenite crate #[tokio::test] async fn test_ping_pong() { let device = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let ping_message = vec![1, 2, 3, 4, 5]; - let mut socket = create_socket(&device).await; + let mut socket = create_socket(&device).await.unwrap(); socket .send(Message::Ping(ping_message.clone())) .await .expect("Failed to send message"); if let Some(Ok(response)) = socket.next().await { let received_payload = match response { Message::Pong(received_payload) => received_payload, unexpected => panic!( "Unexpected message type or result. Expected Pong, got: {:?}. ", unexpected ), }; assert_eq!(ping_message.clone(), received_payload); }; } #[tokio::test] async fn test_close_message() { let device = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; - let mut socket = create_socket(&device).await; + let mut socket = create_socket(&device).await.unwrap(); socket .send(Close(None)) .await .expect("Failed to send message"); if let Some(response) = socket.next().await { assert!(matches!( response, Err(Error::AlreadyClosed | Error::ConnectionClosed) | Ok(Close(None)) )); }; } diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs index e880f6c71..67046fc68 100644 --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,256 +1,296 @@ pub mod session; use crate::database::DatabaseClient; -use crate::websockets::session::initialize_amqp; +use crate::websockets::session::{initialize_amqp, SessionError}; use crate::CONFIG; use futures_util::stream::SplitSink; -use futures_util::StreamExt; +use futures_util::{SinkExt, StreamExt}; +use hyper::upgrade::Upgraded; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use hyper_tungstenite::WebSocketStream; use std::env; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tracing::{debug, error, info}; -use tunnelbroker_messages::{MessageSentStatus, MessageToDeviceRequestStatus}; +use tunnelbroker_messages::{ + ConnectionInitializationStatus, MessageSentStatus, + MessageToDeviceRequestStatus, +}; type BoxedError = Box; pub type ErrorWithStreamHandle = ( session::SessionError, SplitSink, Message>, ); use self::session::WebsocketSession; /// Hyper HTTP service that handles incoming HTTP and websocket connections /// It handles the initial websocket upgrade request and spawns a task to /// handle the websocket connection. /// It also handles regular HTTP requests (currently health check) struct WebsocketService { addr: SocketAddr, channel: lapin::Channel, db_client: DatabaseClient, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = BoxedError; type Future = Pin> + Send>>; // This function is called to check if the service is ready to accept // connections. Since we don't have any state to check, we're always ready. fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let db_client = self.db_client.clone(); let channel = self.channel.clone(); let future = async move { // Check if the request is a websocket upgrade request. if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; // Spawn a task to handle the websocket connection. tokio::spawn(async move { accept_connection(websocket, addr, db_client, channel).await; }); // Return the response so the spawned future can continue. return Ok(response); } debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); // A simple router for regular HTTP requests let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } pub async fn run_server( db_client: DatabaseClient, amqp_connection: &lapin::Connection, ) -> Result<(), BoxedError> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") .unwrap_or_else(|_| format!("0.0.0.0:{}", &CONFIG.http_port)); let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("WebSocket listening on: {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { let channel = amqp_connection .create_channel() .await .expect("Failed to create AMQP channel"); let connection = http .serve_connection( stream, WebsocketService { channel, db_client: db_client.clone(), addr, }, ) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } +async fn send_error_init_response( + error: SessionError, + mut outgoing: SplitSink, Message>, +) { + let error_response = + tunnelbroker_messages::ConnectionInitializationResponse { + status: ConnectionInitializationStatus::Error(error.to_string()), + }; + + match serde_json::to_string(&error_response) { + Ok(serialized_response) => { + if let Err(send_error) = + outgoing.send(Message::Text(serialized_response)).await + { + error!("Failed to send init error response: {:?}", send_error); + } + } + Err(ser_error) => { + error!("Failed to serialize the error response: {:?}", ser_error); + } + } +} + /// Handler for any incoming websocket connections async fn accept_connection( hyper_ws: HyperWebsocket, addr: SocketAddr, db_client: DatabaseClient, amqp_channel: lapin::Channel, ) { debug!("Incoming connection from: {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { info!( "Failed to establish connection with {}. Reason: {}", addr, e ); return; } }; let (outgoing, mut incoming) = ws_stream.split(); // We don't know the identity of the device until it sends the session // request over the websocket connection let mut session = if let Some(Ok(first_msg)) = incoming.next().await { match initiate_session(outgoing, first_msg, db_client, amqp_channel).await { - Ok(session) => session, - Err(_) => { + Ok(mut session) => { + let response = + tunnelbroker_messages::ConnectionInitializationResponse { + status: ConnectionInitializationStatus::Success, + }; + let serialized_response = serde_json::to_string(&response).unwrap(); + + session + .send_message_to_device(Message::Text(serialized_response)) + .await; + session + } + Err((err, outgoing)) => { error!("Failed to create session with device"); + send_error_init_response(err, outgoing).await; return; } } } else { error!("Failed to create session with device"); + send_error_init_response(SessionError::InvalidMessage, outgoing).await; return; }; // Poll for messages either being sent to the device (rx) // or messages being received from the device (incoming) loop { debug!("Polling for messages from: {}", addr); tokio::select! { Some(Ok(delivery)) = session.next_amqp_message() => { if let Ok(message) = std::str::from_utf8(&delivery.data) { session.send_message_to_device(Message::Text(message.to_string())).await; } else { error!("Invalid payload"); } }, device_message = incoming.next() => { let message: Message = match device_message { Some(Ok(msg)) => msg, _ => { debug!("Connection to {} closed remotely.", addr); break; } }; match message { Message::Close(_) => { debug!("Connection to {} closed.", addr); break; } Message::Pong(_) => { debug!("Received Pong message from {}", addr); } Message::Ping(msg) => { debug!("Received Ping message from {}", addr); session.send_message_to_device(Message::Pong(msg)).await; } Message::Text(msg) => { let Some(message_status) = session.handle_websocket_frame_from_device(msg).await else { continue; }; let request_status = MessageToDeviceRequestStatus { client_message_ids: vec![message_status] }; if let Ok(response) = serde_json::to_string(&request_status) { session.send_message_to_device(Message::text(response)).await; } else { break; } } _ => { error!("Client sent invalid message type"); let confirmation = MessageToDeviceRequestStatus {client_message_ids: vec![MessageSentStatus::InvalidRequest]}; if let Ok(response) = serde_json::to_string(&confirmation) { session.send_message_to_device(Message::text(response)).await; } else { break; } } } }, else => { debug!("Unhealthy connection for: {}", addr); break; }, } } info!("Unregistering connection to: {}", addr); session.close().await } async fn initiate_session( outgoing: SplitSink, Message>, frame: Message, db_client: DatabaseClient, amqp_channel: lapin::Channel, ) -> Result, ErrorWithStreamHandle> { let initialized_session = initialize_amqp(db_client.clone(), frame, &amqp_channel).await; match initialized_session { Ok((device_info, amqp_consumer)) => Ok(WebsocketSession::new( outgoing, db_client, device_info, amqp_channel, amqp_consumer, )), Err(e) => Err((e, outgoing)), } }