diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs index 7bfffa493..f6243a41c 100644 --- a/services/commtest/src/tunnelbroker/socket.rs +++ b/services/commtest/src/tunnelbroker/socket.rs @@ -1,75 +1,87 @@ use crate::identity::device::DeviceInfo; use crate::service_addr; use futures_util::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; use tokio::net::TcpStream; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tunnelbroker_messages::{ ConnectionInitializationMessage, DeviceTypes, MessageSentStatus, - MessageToDeviceRequest, MessageToDeviceRequestStatus, + MessageToDevice, MessageToDeviceRequest, MessageToDeviceRequestStatus, }; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] #[serde(tag = "type", rename_all = "camelCase")] pub struct WebSocketMessageToDevice { #[serde(rename = "deviceID")] pub device_id: String, pub payload: String, } pub async fn create_socket( device_info: &DeviceInfo, ) -> WebSocketStream> { let (mut socket, _) = connect_async(service_addr::TUNNELBROKER_WS) .await .expect("Can't connect"); let session_request = ConnectionInitializationMessage { device_id: device_info.device_id.to_string(), access_token: device_info.access_token.to_string(), user_id: device_info.user_id.to_string(), notify_token: None, device_type: DeviceTypes::Keyserver, device_app_version: None, device_os: None, }; let serialized_request = serde_json::to_string(&session_request) .expect("Failed to serialize connection request"); socket .send(Message::Text(serialized_request)) .await .expect("Failed to send message"); socket } pub async fn send_message( socket: &mut WebSocketStream>, message: WebSocketMessageToDevice, ) -> Result> { let client_message_id = uuid::Uuid::new_v4().to_string(); let request = MessageToDeviceRequest { client_message_id: client_message_id.clone(), device_id: message.device_id, payload: message.payload, }; let serialized_request = serde_json::to_string(&request)?; socket.send(Message::Text(serialized_request)).await?; if let Some(Ok(response)) = socket.next().await { let confirmation: MessageToDeviceRequestStatus = serde_json::from_str(response.to_text().unwrap())?; if confirmation .client_message_ids .contains(&MessageSentStatus::Success(client_message_id.clone())) { return Ok(client_message_id); } } Err("Failed to confirm sent message".into()) } + +pub async fn receive_message( + socket: &mut WebSocketStream>, +) -> Result> { + let Some(Ok(response)) = socket.next().await else { + return Err("Failed to receive message".into()); + }; + let message = response.to_text().expect("Failed to get response content"); + let message_to_device = serde_json::from_str::(message) + .expect("Failed to parse MessageToDevice from response"); + Ok(message_to_device.payload) +} diff --git a/services/commtest/tests/identity_tunnelbroker_tests.rs b/services/commtest/tests/identity_tunnelbroker_tests.rs index e316e5cf2..eff29b6e8 100644 --- a/services/commtest/tests/identity_tunnelbroker_tests.rs +++ b/services/commtest/tests/identity_tunnelbroker_tests.rs @@ -1,109 +1,104 @@ use commtest::identity::device::{ create_device, DEVICE_TYPE, PLACEHOLDER_CODE_VERSION, }; use commtest::service_addr; -use commtest::tunnelbroker::socket::create_socket; +use commtest::tunnelbroker::socket::{create_socket, receive_message}; use futures_util::StreamExt; use grpc_clients::identity::protos::authenticated::OutboundKeysForUserRequest; use grpc_clients::identity::protos::client::UploadOneTimeKeysRequest; use grpc_clients::identity::{get_auth_client, get_unauthenticated_client}; use tunnelbroker_messages::RefreshKeyRequest; #[tokio::test] #[should_panic] async fn test_tunnelbroker_invalid_auth() { let mut device_info = create_device(None).await; device_info.access_token = "".to_string(); let mut socket = create_socket(&device_info).await; socket .next() .await .expect("Failed to receive response") .expect("Failed to read the response"); } #[tokio::test] async fn test_tunnelbroker_valid_auth() { let device_info = create_device(None).await; let mut socket = create_socket(&device_info).await; socket .next() .await .expect("Failed to receive response") .expect("Failed to read the response"); } #[tokio::test] async fn test_refresh_keys_request_upon_depletion() { let identity_grpc_endpoint = service_addr::IDENTITY_GRPC.to_string(); let device_info = create_device(None).await; let mut identity_client = get_unauthenticated_client( &identity_grpc_endpoint, PLACEHOLDER_CODE_VERSION, DEVICE_TYPE.to_string(), ) .await .expect("Couldn't connect to identity service"); let upload_request = UploadOneTimeKeysRequest { user_id: device_info.user_id.clone(), device_id: device_info.device_id.clone(), access_token: device_info.access_token.clone(), content_one_time_pre_keys: vec!["content1".to_string()], notif_one_time_pre_keys: vec!["notif1".to_string()], }; identity_client .upload_one_time_keys(upload_request) .await .unwrap(); // Request outbound keys, which should trigger identity service to ask for more keys let mut client = get_auth_client( &identity_grpc_endpoint, device_info.user_id.clone(), device_info.device_id, device_info.access_token, PLACEHOLDER_CODE_VERSION, DEVICE_TYPE.to_string(), ) .await .expect("Couldn't connect to identity service"); let keyserver_request = OutboundKeysForUserRequest { user_id: device_info.user_id.clone(), }; println!("Getting keyserver info for user, {}", device_info.user_id); let _first_reponse = client .get_keyserver_keys(keyserver_request.clone()) .await .expect("Second keyserver keys request failed") .into_inner() .keyserver_info .unwrap(); // The current threshold is 5, but we only upload two. Should receive request // from Tunnelbroker to refresh keys // Create session as a keyserver let device_info = create_device(None).await; let mut socket = create_socket(&device_info).await; + let response = receive_message(&mut socket).await.unwrap(); + let serialized_response: RefreshKeyRequest = + serde_json::from_str(&response).unwrap(); - // Have keyserver receive any websocket messages - if let Some(Ok(response)) = socket.next().await { - // Check that message received by keyserver matches what identity server - // issued - let serialized_response: RefreshKeyRequest = - serde_json::from_str(response.to_text().unwrap()).unwrap(); - - let expected_response = RefreshKeyRequest { - device_id: device_info.device_id.to_string(), - number_of_keys: 5, - }; - assert_eq!(serialized_response, expected_response); + let expected_response = RefreshKeyRequest { + device_id: device_info.device_id.to_string(), + number_of_keys: 5, }; + assert_eq!(serialized_response, expected_response); } diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_integration_tests.rs index 4ef700fc4..c8c2cdb20 100644 --- a/services/commtest/tests/tunnelbroker_integration_tests.rs +++ b/services/commtest/tests/tunnelbroker_integration_tests.rs @@ -1,99 +1,95 @@ mod proto { tonic::include_proto!("tunnelbroker"); } use commtest::identity::device::create_device; use commtest::identity::olm_account_infos::{ MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, }; use commtest::service_addr; use commtest::tunnelbroker::socket::{ - create_socket, send_message, WebSocketMessageToDevice, + create_socket, receive_message, send_message, WebSocketMessageToDevice, }; -use futures_util::StreamExt; + use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; use proto::MessageToDevice; use std::time::Duration; use tokio::time::sleep; use tunnelbroker_messages::RefreshKeyRequest; #[tokio::test] async fn send_refresh_request() { // Create session as a keyserver let device_info = create_device(None).await; let mut socket = create_socket(&device_info).await; // Send request for keyserver to refresh keys (identity service) let mut tunnelbroker_client = TunnelbrokerServiceClient::connect(service_addr::TUNNELBROKER_GRPC) .await .unwrap(); let refresh_request = RefreshKeyRequest { device_id: device_info.device_id.clone(), number_of_keys: 5, }; let payload = serde_json::to_string(&refresh_request).unwrap(); let request = MessageToDevice { device_id: device_info.device_id.clone(), payload, }; let grpc_message = tonic::Request::new(request); tunnelbroker_client .send_message_to_device(grpc_message) .await .unwrap(); // Have keyserver receive any websocket messages - let response = socket.next().await.unwrap().unwrap(); + let response = receive_message(&mut socket).await.unwrap(); // Check that message received by keyserver matches what identity server // issued let serialized_response: RefreshKeyRequest = - serde_json::from_str(response.to_text().unwrap()).unwrap(); + serde_json::from_str(&response).unwrap(); assert_eq!(serialized_response, refresh_request); } #[tokio::test] async fn test_messages_order() { let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; let messages = vec![ WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "first message".to_string(), }, WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "second message".to_string(), }, WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "third message".to_string(), }, ]; let mut sender_socket = create_socket(&sender).await; for msg in messages.clone() { send_message(&mut sender_socket, msg).await.unwrap(); } // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; let mut receiver_socket = create_socket(&receiver).await; for msg in messages { - if let Some(Ok(response)) = receiver_socket.next().await { - let received_payload = response.to_text().unwrap(); - assert_eq!(msg.payload, received_payload); - } else { - panic!("Unable to receive message"); - } + let response = receive_message(&mut receiver_socket).await.unwrap(); + assert_eq!(msg.payload, response); } } diff --git a/services/commtest/tests/tunnelbroker_persist_tests.rs b/services/commtest/tests/tunnelbroker_persist_tests.rs index d13af10b1..8b5821be6 100644 --- a/services/commtest/tests/tunnelbroker_persist_tests.rs +++ b/services/commtest/tests/tunnelbroker_persist_tests.rs @@ -1,89 +1,83 @@ mod proto { tonic::include_proto!("tunnelbroker"); } use commtest::identity::device::create_device; use commtest::identity::olm_account_infos::{ MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, }; use commtest::service_addr; use commtest::tunnelbroker::socket::{ - create_socket, send_message, WebSocketMessageToDevice, + create_socket, receive_message, send_message, WebSocketMessageToDevice, }; -use futures_util::StreamExt; use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; use proto::MessageToDevice; use std::time::Duration; use tokio::time::sleep; use tunnelbroker_messages::RefreshKeyRequest; /// Tests that a message to an offline device gets pushed to dynamodb /// then recalled once a device connects #[tokio::test] async fn persist_grpc_messages() { let device_info = create_device(None).await; // Send request for keyserver to refresh keys (identity service) let mut tunnelbroker_client = TunnelbrokerServiceClient::connect(service_addr::TUNNELBROKER_GRPC) .await .unwrap(); let refresh_request = RefreshKeyRequest { device_id: device_info.device_id.to_string(), number_of_keys: 5, }; let payload = serde_json::to_string(&refresh_request).unwrap(); let request = MessageToDevice { device_id: device_info.device_id.to_string(), payload, }; let grpc_message = tonic::Request::new(request); tunnelbroker_client .send_message_to_device(grpc_message) .await .unwrap(); // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; let mut socket = create_socket(&device_info).await; // Have keyserver receive any websocket messages - if let Some(Ok(response)) = socket.next().await { - // Check that message received by keyserver matches what identity server - // issued - let serialized_response: RefreshKeyRequest = - serde_json::from_str(response.to_text().unwrap()).unwrap(); - assert_eq!(serialized_response, refresh_request); - }; + let response = receive_message(&mut socket).await.unwrap(); + + // Check that message received by keyserver matches what identity server + // issued + let serialized_response: RefreshKeyRequest = + serde_json::from_str(&response).unwrap(); + assert_eq!(serialized_response, refresh_request); } #[tokio::test] async fn persist_websocket_messages() { let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; // Send message to not connected client let mut sender_socket = create_socket(&sender).await; let request = WebSocketMessageToDevice { device_id: receiver.device_id.clone(), payload: "persisted message".to_string(), }; send_message(&mut sender_socket, request.clone()) .await .unwrap(); // Wait a specified duration to ensure that message had time to persist sleep(Duration::from_millis(100)).await; - // Connect receiver let mut receiver_socket = create_socket(&receiver).await; - - // Receive message - if let Some(Ok(response)) = receiver_socket.next().await { - let received_payload = response.to_text().unwrap(); - assert_eq!(request.payload, received_payload); - }; + let response = receive_message(&mut receiver_socket).await.unwrap(); + assert_eq!(request.payload, response); } diff --git a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs index 4b1d0e452..fc1016e7a 100644 --- a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs +++ b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs @@ -1,95 +1,93 @@ use commtest::identity::device::create_device; use commtest::identity::olm_account_infos::{ DEFAULT_CLIENT_KEYS, MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, }; -use commtest::tunnelbroker::socket::create_socket; +use commtest::tunnelbroker::socket::{create_socket, receive_message}; use futures_util::{SinkExt, StreamExt}; use tokio_tungstenite::tungstenite::Message; use tunnelbroker_messages::{ MessageSentStatus, MessageToDeviceRequest, MessageToDeviceRequestStatus, }; /// Tests of responses sent from Tunnelberoker to client /// trying to send message to other device #[tokio::test] async fn get_confirmation() { let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; let client_message_id = "mockID".to_string(); // Send message to not connected client let payload = "persisted message"; let request = MessageToDeviceRequest { client_message_id: client_message_id.clone(), device_id: receiver.device_id.clone(), payload: payload.to_string(), }; let serialized_request = serde_json::to_string(&request) .expect("Failed to serialize message to device"); let mut sender_socket = create_socket(&sender).await; sender_socket .send(Message::Text(serialized_request)) .await .expect("Failed to send message"); if let Some(Ok(response)) = sender_socket.next().await { let expected_response = MessageToDeviceRequestStatus { client_message_ids: vec![MessageSentStatus::Success(client_message_id)], }; let expected_payload = serde_json::to_string(&expected_response).unwrap(); let received_payload = response.to_text().unwrap(); assert_eq!(received_payload, expected_payload); }; // Connect receiver to flush DDB and avoid polluting other tests let mut receiver_socket = create_socket(&receiver).await; - if let Some(Ok(response)) = receiver_socket.next().await { - let received_payload = response.to_text().unwrap(); - assert_eq!(payload, received_payload); - }; + let receiver_response = receive_message(&mut receiver_socket).await.unwrap(); + assert_eq!(payload, receiver_response); } #[tokio::test] async fn get_serialization_error() { let sender = create_device(Some(&DEFAULT_CLIENT_KEYS)).await; let message = "some bad json".to_string(); let mut sender_socket = create_socket(&sender).await; sender_socket .send(Message::Text(message.clone())) .await .expect("Failed to send message"); if let Some(Ok(response)) = sender_socket.next().await { let expected_response = MessageToDeviceRequestStatus { client_message_ids: vec![MessageSentStatus::SerializationError(message)], }; let expected_payload = serde_json::to_string(&expected_response).unwrap(); let received_payload = response.to_text().unwrap(); assert_eq!(received_payload, expected_payload); }; } #[tokio::test] async fn get_invalid_request_error() { let sender = create_device(Some(&DEFAULT_CLIENT_KEYS)).await; let mut sender_socket = create_socket(&sender).await; sender_socket .send(Message::Binary(vec![])) .await .expect("Failed to send message"); if let Some(Ok(response)) = sender_socket.next().await { let expected_response = MessageToDeviceRequestStatus { client_message_ids: vec![MessageSentStatus::InvalidRequest], }; let expected_payload = serde_json::to_string(&expected_response).unwrap(); let received_payload = response.to_text().unwrap(); assert_eq!(received_payload, expected_payload); }; } diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs index 2cc3a5117..af92e925f 100644 --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,252 +1,249 @@ pub mod session; use crate::database::DatabaseClient; use crate::websockets::session::SessionError; use crate::CONFIG; use futures_util::stream::SplitSink; use futures_util::StreamExt; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use hyper_tungstenite::WebSocketStream; use std::env; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tracing::{debug, error, info}; -use tunnelbroker_messages::{ - MessageSentStatus, MessageToDevice, MessageToDeviceRequestStatus, -}; +use tunnelbroker_messages::{MessageSentStatus, MessageToDeviceRequestStatus}; type BoxedError = Box; use self::session::WebsocketSession; /// Hyper HTTP service that handles incoming HTTP and websocket connections /// It handles the initial websocket upgrade request and spawns a task to /// handle the websocket connection. /// It also handles regular HTTP requests (currently health check) struct WebsocketService { addr: SocketAddr, channel: lapin::Channel, db_client: DatabaseClient, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = BoxedError; type Future = Pin> + Send>>; // This function is called to check if the service is ready to accept // connections. Since we don't have any state to check, we're always ready. fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let db_client = self.db_client.clone(); let channel = self.channel.clone(); let future = async move { // Check if the request is a websocket upgrade request. if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; // Spawn a task to handle the websocket connection. tokio::spawn(async move { accept_connection(websocket, addr, db_client, channel).await; }); // Return the response so the spawned future can continue. return Ok(response); } debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); // A simple router for regular HTTP requests let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } pub async fn run_server( db_client: DatabaseClient, amqp_connection: &lapin::Connection, ) -> Result<(), BoxedError> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") .unwrap_or_else(|_| format!("0.0.0.0:{}", &CONFIG.http_port)); let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("WebSocket listening on: {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { let channel = amqp_connection .create_channel() .await .expect("Failed to create AMQP channel"); let connection = http .serve_connection( stream, WebsocketService { channel, db_client: db_client.clone(), addr, }, ) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } /// Handler for any incoming websocket connections async fn accept_connection( hyper_ws: HyperWebsocket, addr: SocketAddr, db_client: DatabaseClient, amqp_channel: lapin::Channel, ) { debug!("Incoming connection from: {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { info!( "Failed to establish connection with {}. Reason: {}", addr, e ); return; } }; let (outgoing, mut incoming) = ws_stream.split(); // We don't know the identity of the device until it sends the session // request over the websocket connection let mut session = if let Some(Ok(first_msg)) = incoming.next().await { match initiate_session(outgoing, first_msg, db_client, amqp_channel).await { Ok(session) => session, Err(_) => { error!("Failed to create session with device"); return; } } } else { error!("Failed to create session with device"); return; }; // Poll for messages either being sent to the device (rx) // or messages being received from the device (incoming) loop { debug!("Polling for messages from: {}", addr); tokio::select! { Some(Ok(delivery)) = session.next_amqp_message() => { if let Ok(message) = std::str::from_utf8(&delivery.data) { - let message_to_device = serde_json::from_str::(message).unwrap(); - session.send_message_to_device(Message::Text(message_to_device.payload)).await; + session.send_message_to_device(Message::Text(message.to_string())).await; } else { error!("Invalid payload"); } }, device_message = incoming.next() => { let message: Message = match device_message { Some(Ok(msg)) => msg, _ => { debug!("Connection to {} closed remotely.", addr); break; } }; match message { Message::Close(_) => { debug!("Connection to {} closed.", addr); break; } Message::Pong(_) => { debug!("Received Pong message from {}", addr); } Message::Ping(msg) => { debug!("Received Ping message from {}", addr); session.send_message_to_device(Message::Pong(msg)).await; } Message::Text(msg) => { let message_status = session.handle_websocket_frame_from_device(msg).await; let request_status = MessageToDeviceRequestStatus { client_message_ids: vec![message_status] }; if let Ok(response) = serde_json::to_string(&request_status) { session.send_message_to_device(Message::text(response)).await; } else { break; } } _ => { error!("Client sent invalid message type"); let confirmation = MessageToDeviceRequestStatus {client_message_ids: vec![MessageSentStatus::InvalidRequest]}; if let Ok(response) = serde_json::to_string(&confirmation) { session.send_message_to_device(Message::text(response)).await; } else { break; } } } }, else => { debug!("Unhealthy connection for: {}", addr); break; }, } } info!("Unregistering connection to: {}", addr); session.close().await } async fn initiate_session( outgoing: SplitSink, Message>, frame: Message, db_client: DatabaseClient, amqp_channel: lapin::Channel, ) -> Result, session::SessionError> { let session = session::WebsocketSession::from_frame( outgoing, db_client.clone(), frame, &amqp_channel, ) .await .map_err(|_| { error!("Device failed to send valid connection request."); SessionError::InvalidMessage })?; Ok(session) }