diff --git a/services/commtest/tests/tunnelbroker_integration_tests.rs b/services/commtest/tests/tunnelbroker_integration_tests.rs index 45a0f2446..ff70b5b73 100644 --- a/services/commtest/tests/tunnelbroker_integration_tests.rs +++ b/services/commtest/tests/tunnelbroker_integration_tests.rs @@ -1,92 +1,48 @@ mod proto { tonic::include_proto!("tunnelbroker"); } use commtest::identity::device::create_device; use commtest::tunnelbroker::socket::create_socket; use futures_util::StreamExt; use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; use proto::MessageToDevice; use tunnelbroker_messages::RefreshKeyRequest; #[tokio::test] async fn send_refresh_request() { // Create session as a keyserver let device_info = create_device(None).await; let mut socket = create_socket(&device_info).await; // Send request for keyserver to refresh keys (identity service) let mut tunnelbroker_client = TunnelbrokerServiceClient::connect("http://localhost:50051") .await .unwrap(); let refresh_request = RefreshKeyRequest { device_id: device_info.device_id.clone(), number_of_keys: 5, }; let payload = serde_json::to_string(&refresh_request).unwrap(); let request = MessageToDevice { device_id: device_info.device_id.clone(), payload, }; let grpc_message = tonic::Request::new(request); tunnelbroker_client .send_message_to_device(grpc_message) .await .unwrap(); // Have keyserver receive any websocket messages let response = socket.next().await.unwrap().unwrap(); // Check that message received by keyserver matches what identity server // issued let serialized_response: RefreshKeyRequest = serde_json::from_str(&response.to_text().unwrap()).unwrap(); assert_eq!(serialized_response, refresh_request); } - -/// Test that a message to an offline device gets pushed to dynamodb -/// then recalled once a device connects -#[tokio::test] -async fn persist_messages() { - let device_info = create_device(None).await; - - // Send request for keyserver to refresh keys (identity service) - let mut tunnelbroker_client = - TunnelbrokerServiceClient::connect("http://localhost:50051") - .await - .unwrap(); - - let refresh_request = RefreshKeyRequest { - device_id: device_info.device_id.to_string(), - number_of_keys: 5, - }; - - let payload = serde_json::to_string(&refresh_request).unwrap(); - let request = MessageToDevice { - device_id: device_info.device_id.to_string(), - payload, - }; - let grpc_message = tonic::Request::new(request); - tunnelbroker_client - .send_message_to_device(grpc_message) - .await - .unwrap(); - - // Wait one second to ensure that message had time to persist - use std::{thread, time}; - let ten_millis = time::Duration::from_millis(50); - thread::sleep(ten_millis); - - let mut socket = create_socket(&device_info).await; - // Have keyserver receive any websocket messages - if let Some(Ok(response)) = socket.next().await { - // Check that message received by keyserver matches what identity server - // issued - let serialized_response: RefreshKeyRequest = - serde_json::from_str(&response.to_text().unwrap()).unwrap(); - assert_eq!(serialized_response, refresh_request); - }; -} diff --git a/services/commtest/tests/tunnelbroker_persist_tests.rs b/services/commtest/tests/tunnelbroker_persist_tests.rs new file mode 100644 index 000000000..5461b5f44 --- /dev/null +++ b/services/commtest/tests/tunnelbroker_persist_tests.rs @@ -0,0 +1,95 @@ +mod proto { + tonic::include_proto!("tunnelbroker"); +} + +use commtest::identity::device::create_device; +use commtest::identity::olm_account_infos::{ + MOCK_CLIENT_KEYS_1, MOCK_CLIENT_KEYS_2, +}; +use commtest::tunnelbroker::socket::create_socket; +use futures_util::{SinkExt, StreamExt}; +use proto::tunnelbroker_service_client::TunnelbrokerServiceClient; +use proto::MessageToDevice; +use std::time::Duration; +use tokio::time::sleep; +use tokio_tungstenite::tungstenite::Message; +use tunnelbroker_messages::{ + MessageToDevice as WebsocketMessageToDevice, RefreshKeyRequest, +}; + +/// Tests that a message to an offline device gets pushed to dynamodb +/// then recalled once a device connects + +#[tokio::test] +async fn persist_grpc_messages() { + let device_info = create_device(None).await; + + // Send request for keyserver to refresh keys (identity service) + let mut tunnelbroker_client = + TunnelbrokerServiceClient::connect("http://localhost:50051") + .await + .unwrap(); + + let refresh_request = RefreshKeyRequest { + device_id: device_info.device_id.to_string(), + number_of_keys: 5, + }; + + let payload = serde_json::to_string(&refresh_request).unwrap(); + let request = MessageToDevice { + device_id: device_info.device_id.to_string(), + payload, + }; + let grpc_message = tonic::Request::new(request); + tunnelbroker_client + .send_message_to_device(grpc_message) + .await + .unwrap(); + + // Wait a specified duration to ensure that message had time to persist + sleep(Duration::from_millis(100)).await; + + let mut socket = create_socket(&device_info).await; + // Have keyserver receive any websocket messages + if let Some(Ok(response)) = socket.next().await { + // Check that message received by keyserver matches what identity server + // issued + let serialized_response: RefreshKeyRequest = + serde_json::from_str(&response.to_text().unwrap()).unwrap(); + assert_eq!(serialized_response, refresh_request); + }; +} + +#[tokio::test] +async fn persist_websocket_messages() { + let sender = create_device(Some(&MOCK_CLIENT_KEYS_1)).await; + let receiver = create_device(Some(&MOCK_CLIENT_KEYS_2)).await; + + // Send message to not connected client + let payload = "persisted message"; + let request = WebsocketMessageToDevice { + device_id: receiver.device_id.clone(), + payload: payload.to_string(), + }; + + let serialized_request = serde_json::to_string(&request) + .expect("Failed to serialize message to device"); + + let mut sender_socket = create_socket(&sender).await; + sender_socket + .send(Message::Text(serialized_request)) + .await + .expect("Failed to send message"); + + // Wait a specified duration to ensure that message had time to persist + sleep(Duration::from_millis(100)).await; + + // Connect receiver + let mut receiver_socket = create_socket(&receiver).await; + + // Receive message + if let Some(Ok(response)) = receiver_socket.next().await { + let received_payload = response.to_text().unwrap(); + assert_eq!(payload, received_payload); + }; +} diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs index 3c4e7a32b..4df46a6ca 100644 --- a/services/tunnelbroker/src/grpc/mod.rs +++ b/services/tunnelbroker/src/grpc/mod.rs @@ -1,86 +1,86 @@ mod proto { tonic::include_proto!("tunnelbroker"); } use lapin::{options::BasicPublishOptions, BasicProperties}; use proto::tunnelbroker_service_server::{ TunnelbrokerService, TunnelbrokerServiceServer, }; use proto::Empty; use tonic::transport::Server; use tracing::debug; use crate::database::{handle_ddb_error, DatabaseClient}; use crate::{constants, CONFIG}; struct TunnelbrokerGRPC { client: DatabaseClient, amqp_channel: lapin::Channel, } pub fn handle_amqp_error(error: lapin::Error) -> tonic::Status { match error { lapin::Error::SerialisationError(_) | lapin::Error::ParsingError(_) => { tonic::Status::invalid_argument("Invalid argument") } _ => tonic::Status::internal("Internal Error"), } } #[tonic::async_trait] impl TunnelbrokerService for TunnelbrokerGRPC { async fn send_message_to_device( &self, request: tonic::Request, ) -> Result, tonic::Status> { let message = request.into_inner(); debug!("Received message for {}", &message.device_id); self .client .persist_message(&message.device_id, &message.payload) .await .map_err(handle_ddb_error)?; self .amqp_channel .basic_publish( "", &message.device_id, BasicPublishOptions::default(), - &message.payload.as_bytes(), + message.payload.as_bytes(), BasicProperties::default(), ) .await .map_err(handle_amqp_error)?; let response = tonic::Response::new(Empty {}); Ok(response) } } pub async fn run_server( client: DatabaseClient, ampq_connection: &lapin::Connection, ) -> Result<(), tonic::transport::Error> { let addr = format!("[::]:{}", CONFIG.grpc_port) .parse() .expect("Unable to parse gRPC address"); let amqp_channel = ampq_connection .create_channel() .await .expect("Unable to create amqp channel"); tracing::info!("gRPC server listening on {}", &addr); Server::builder() .http2_keepalive_interval(Some(constants::GRPC_KEEP_ALIVE_PING_INTERVAL)) .http2_keepalive_timeout(Some(constants::GRPC_KEEP_ALIVE_PING_TIMEOUT)) .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC { client, amqp_channel, })) .serve(addr) .await } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs index 21e9f84be..bc8ca8b2a 100644 --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,207 +1,251 @@ +use aws_sdk_dynamodb::error::SdkError; +use aws_sdk_dynamodb::operation::put_item::PutItemError; use derive_more; use futures_util::stream::SplitSink; use futures_util::SinkExt; use futures_util::StreamExt; use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use lapin::message::Delivery; -use lapin::options::{BasicConsumeOptions, QueueDeclareOptions}; +use lapin::options::{ + BasicConsumeOptions, BasicPublishOptions, QueueDeclareOptions, +}; use lapin::types::FieldTable; +use lapin::BasicProperties; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tracing::{debug, error, info}; use tunnelbroker_messages::{session::DeviceTypes, Messages}; use crate::database::{self, DatabaseClient, DeviceMessage}; use crate::error::Error; use crate::identity; pub struct DeviceInfo { pub device_id: String, pub notify_token: Option, pub device_type: DeviceTypes, pub device_app_version: Option, pub device_os: Option, } pub struct WebsocketSession { tx: SplitSink, Message>, db_client: DatabaseClient, pub device_info: DeviceInfo, + amqp_channel: lapin::Channel, // Stream of messages from AMQP endpoint amqp_consumer: lapin::Consumer, } #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum SessionError { InvalidMessage, SerializationError(serde_json::Error), MessageError(database::MessageErrors), AmqpError(lapin::Error), InternalError, UnauthorizedDevice, + PersistenceError(SdkError), } pub fn consume_error(result: Result) { if let Err(e) = result { error!("{}", e) } } // Parse a session request and retrieve the device information pub async fn handle_first_message_from_device( message: &str, ) -> Result { let serialized_message = serde_json::from_str::(message)?; match serialized_message { Messages::ConnectionInitializationMessage(mut session_info) => { let device_info = DeviceInfo { device_id: session_info.device_id.clone(), notify_token: session_info.notify_token.take(), device_type: session_info.device_type, device_app_version: session_info.device_app_version.take(), device_os: session_info.device_os.take(), }; // Authenticate device debug!("Authenticating device: {}", &session_info.device_id); let auth_request = identity::verify_user_access_token( &session_info.user_id, &device_info.device_id, &session_info.access_token, ) .await; match auth_request { Err(e) => { error!("Failed to complete request to identity service: {:?}", e); return Err(SessionError::InternalError.into()); } Ok(false) => { info!("Device failed authentication: {}", &session_info.device_id); return Err(SessionError::UnauthorizedDevice.into()); } Ok(true) => { debug!( "Successfully authenticated device: {}", &session_info.device_id ); } } Ok(device_info) } _ => { debug!("Received invalid request"); Err(SessionError::InvalidMessage.into()) } } } impl WebsocketSession { pub async fn from_frame( tx: SplitSink, Message>, db_client: DatabaseClient, frame: Message, amqp_channel: &lapin::Channel, ) -> Result, Error> { let device_info = match frame { Message::Text(payload) => { handle_first_message_from_device(&payload).await? } _ => { error!("Client sent wrong frame type for establishing connection"); return Err(SessionError::InvalidMessage.into()); } }; // We don't currently have a use case to interact directly with the queue, // however, we need to declare a queue for a given device amqp_channel .queue_declare( &device_info.device_id, QueueDeclareOptions::default(), FieldTable::default(), ) .await?; let amqp_consumer = amqp_channel .basic_consume( &device_info.device_id, "tunnelbroker", BasicConsumeOptions::default(), FieldTable::default(), ) .await?; Ok(WebsocketSession { tx, db_client, device_info, + amqp_channel: amqp_channel.clone(), amqp_consumer, }) } pub async fn handle_websocket_frame_from_device( &self, msg: Message, ) -> Result<(), SessionError> { - debug!("Received frame: {:?}", msg); + let text_msg = match msg { + Message::Text(payload) => payload, + _ => { + error!("Client sent invalid message type"); + return Err(SessionError::InvalidMessage); + } + }; + + let serialized_message = serde_json::from_str::(&text_msg)?; + + match serialized_message { + Messages::MessageToDevice(message_to_device) => { + debug!("Received message for {}", message_to_device.device_id); + self + .db_client + .persist_message( + message_to_device.device_id.as_str(), + message_to_device.payload.as_str(), + ) + .await?; + + self + .amqp_channel + .basic_publish( + "", + &message_to_device.device_id, + BasicPublishOptions::default(), + message_to_device.payload.as_bytes(), + BasicProperties::default(), + ) + .await?; + } + _ => { + error!("Client sent invalid message type"); + return Err(SessionError::InvalidMessage); + } + } Ok(()) } pub async fn next_amqp_message( &mut self, ) -> Option> { self.amqp_consumer.next().await } pub async fn deliver_persisted_messages( &mut self, ) -> Result<(), SessionError> { // Check for persisted messages let messages = self .db_client .retrieve_messages(&self.device_info.device_id) .await .unwrap_or_else(|e| { error!("Error while retrieving messages: {}", e); Vec::new() }); for message in messages { let device_message = DeviceMessage::from_hashmap(message)?; self.send_message_to_device(device_message.payload).await; if let Err(e) = self .db_client .delete_message(&self.device_info.device_id, &device_message.created_at) .await { error!("Failed to delete message: {}:", e); } } debug!( "Flushed messages for device: {}", &self.device_info.device_id ); Ok(()) } pub async fn send_message_to_device(&mut self, incoming_payload: String) { if let Err(e) = self.tx.send(Message::Text(incoming_payload)).await { error!("Failed to send message to device: {}", e); } } // Release websocket and remove from active connections pub async fn close(&mut self) { if let Err(e) = self.tx.close().await { debug!("Failed to close session: {}", e); } } }