diff --git a/services/tunnelbroker/src/constants.rs b/services/tunnelbroker/src/constants.rs index 1414b1d10..2b9ef2c35 100644 --- a/services/tunnelbroker/src/constants.rs +++ b/services/tunnelbroker/src/constants.rs @@ -1,31 +1,36 @@ use tokio::time::Duration; pub const GRPC_TX_QUEUE_SIZE: usize = 32; pub const GRPC_SERVER_PORT: u16 = 50051; pub const GRPC_KEEP_ALIVE_PING_INTERVAL: Duration = Duration::from_secs(3); pub const GRPC_KEEP_ALIVE_PING_TIMEOUT: Duration = Duration::from_secs(10); +pub const MAX_RMQ_MSG_PRIORITY: u8 = 10; +pub const DDB_RMQ_MSG_PRIORITY: u8 = 10; +pub const CLIENT_RMQ_MSG_PRIORITY: u8 = 1; +pub const RMQ_CONSUMER_TAG: &str = "tunnelbroker"; + pub const LOG_LEVEL_ENV_VAR: &str = tracing_subscriber::filter::EnvFilter::DEFAULT_ENV; pub mod dynamodb { // This table holds messages which could not be immediately delivered to // a device. // // - (primary key) = (deviceID: Partition Key, createdAt: Sort Key) // - deviceID: The public key of a device's olm identity key // - payload: Message to be delivered. See shared/tunnelbroker_messages. // - messageID = [createdAt]#[clientMessageID] // - createdAd: UNIX timestamp of when the item was inserted. // Timestamp is needed to order the messages correctly to the device. // Timestamp format is ISO 8601 to handle lexicographical sorting. // - clientMessageID: Message ID generated on client using UUID Version 4. pub mod undelivered_messages { pub const TABLE_NAME: &str = "tunnelbroker-undelivered-messages"; pub const PARTITION_KEY: &str = "deviceID"; pub const DEVICE_ID: &str = "deviceID"; pub const PAYLOAD: &str = "payload"; pub const MESSAGE_ID: &str = "messageID"; pub const SORT_KEY: &str = "messageID"; } } diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs index 51ec821ec..e395bc932 100644 --- a/services/tunnelbroker/src/grpc/mod.rs +++ b/services/tunnelbroker/src/grpc/mod.rs @@ -1,88 +1,89 @@ mod proto { tonic::include_proto!("tunnelbroker"); } use lapin::{options::BasicPublishOptions, BasicProperties}; use proto::tunnelbroker_service_server::{ TunnelbrokerService, TunnelbrokerServiceServer, }; use proto::Empty; use tonic::transport::Server; use tracing::debug; +use crate::constants::CLIENT_RMQ_MSG_PRIORITY; use crate::database::{handle_ddb_error, DatabaseClient}; use crate::{constants, CONFIG}; struct TunnelbrokerGRPC { client: DatabaseClient, amqp_channel: lapin::Channel, } pub fn handle_amqp_error(error: lapin::Error) -> tonic::Status { match error { lapin::Error::SerialisationError(_) | lapin::Error::ParsingError(_) => { tonic::Status::invalid_argument("Invalid argument") } _ => tonic::Status::internal("Internal Error"), } } #[tonic::async_trait] impl TunnelbrokerService for TunnelbrokerGRPC { async fn send_message_to_device( &self, request: tonic::Request, ) -> Result, tonic::Status> { let message = request.into_inner(); debug!("Received message for {}", &message.device_id); let client_message_id = uuid::Uuid::new_v4().to_string(); self .client .persist_message(&message.device_id, &message.payload, &client_message_id) .await .map_err(handle_ddb_error)?; self .amqp_channel .basic_publish( "", &message.device_id, BasicPublishOptions::default(), message.payload.as_bytes(), - BasicProperties::default(), + BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY), ) .await .map_err(handle_amqp_error)?; let response = tonic::Response::new(Empty {}); Ok(response) } } pub async fn run_server( client: DatabaseClient, ampq_connection: &lapin::Connection, ) -> Result<(), tonic::transport::Error> { let addr = format!("[::]:{}", CONFIG.grpc_port) .parse() .expect("Unable to parse gRPC address"); let amqp_channel = ampq_connection .create_channel() .await .expect("Unable to create amqp channel"); tracing::info!("gRPC server listening on {}", &addr); Server::builder() .http2_keepalive_interval(Some(constants::GRPC_KEEP_ALIVE_PING_INTERVAL)) .http2_keepalive_timeout(Some(constants::GRPC_KEEP_ALIVE_PING_TIMEOUT)) .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC { client, amqp_channel, })) .serve(addr) .await } diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs index bc2f12554..af92e925f 100644 --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,251 +1,249 @@ pub mod session; use crate::database::DatabaseClient; use crate::websockets::session::SessionError; use crate::CONFIG; use futures_util::stream::SplitSink; use futures_util::StreamExt; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use hyper_tungstenite::WebSocketStream; use std::env; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tracing::{debug, error, info}; use tunnelbroker_messages::{MessageSentStatus, MessageToDeviceRequestStatus}; type BoxedError = Box; use self::session::WebsocketSession; /// Hyper HTTP service that handles incoming HTTP and websocket connections /// It handles the initial websocket upgrade request and spawns a task to /// handle the websocket connection. /// It also handles regular HTTP requests (currently health check) struct WebsocketService { addr: SocketAddr, channel: lapin::Channel, db_client: DatabaseClient, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = BoxedError; type Future = Pin> + Send>>; // This function is called to check if the service is ready to accept // connections. Since we don't have any state to check, we're always ready. fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let db_client = self.db_client.clone(); let channel = self.channel.clone(); let future = async move { // Check if the request is a websocket upgrade request. if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; // Spawn a task to handle the websocket connection. tokio::spawn(async move { accept_connection(websocket, addr, db_client, channel).await; }); // Return the response so the spawned future can continue. return Ok(response); } debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); // A simple router for regular HTTP requests let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } pub async fn run_server( db_client: DatabaseClient, amqp_connection: &lapin::Connection, ) -> Result<(), BoxedError> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") .unwrap_or_else(|_| format!("0.0.0.0:{}", &CONFIG.http_port)); let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("WebSocket listening on: {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { let channel = amqp_connection .create_channel() .await .expect("Failed to create AMQP channel"); let connection = http .serve_connection( stream, WebsocketService { channel, db_client: db_client.clone(), addr, }, ) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } /// Handler for any incoming websocket connections async fn accept_connection( hyper_ws: HyperWebsocket, addr: SocketAddr, db_client: DatabaseClient, amqp_channel: lapin::Channel, ) { debug!("Incoming connection from: {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { info!( "Failed to establish connection with {}. Reason: {}", addr, e ); return; } }; let (outgoing, mut incoming) = ws_stream.split(); // We don't know the identity of the device until it sends the session // request over the websocket connection let mut session = if let Some(Ok(first_msg)) = incoming.next().await { match initiate_session(outgoing, first_msg, db_client, amqp_channel).await { Ok(session) => session, Err(_) => { error!("Failed to create session with device"); return; } } } else { error!("Failed to create session with device"); return; }; // Poll for messages either being sent to the device (rx) // or messages being received from the device (incoming) loop { debug!("Polling for messages from: {}", addr); tokio::select! { Some(Ok(delivery)) = session.next_amqp_message() => { if let Ok(message) = std::str::from_utf8(&delivery.data) { session.send_message_to_device(Message::Text(message.to_string())).await; } else { error!("Invalid payload"); } }, device_message = incoming.next() => { let message: Message = match device_message { Some(Ok(msg)) => msg, _ => { debug!("Connection to {} closed remotely.", addr); break; } }; match message { Message::Close(_) => { debug!("Connection to {} closed.", addr); break; } Message::Pong(_) => { debug!("Received Pong message from {}", addr); } Message::Ping(msg) => { debug!("Received Ping message from {}", addr); session.send_message_to_device(Message::Pong(msg)).await; } Message::Text(msg) => { let message_status = session.handle_websocket_frame_from_device(msg).await; let request_status = MessageToDeviceRequestStatus { client_message_ids: vec![message_status] }; if let Ok(response) = serde_json::to_string(&request_status) { session.send_message_to_device(Message::text(response)).await; } else { break; } } _ => { error!("Client sent invalid message type"); let confirmation = MessageToDeviceRequestStatus {client_message_ids: vec![MessageSentStatus::InvalidRequest]}; if let Ok(response) = serde_json::to_string(&confirmation) { session.send_message_to_device(Message::text(response)).await; } else { break; } } } }, else => { debug!("Unhealthy connection for: {}", addr); break; }, } } info!("Unregistering connection to: {}", addr); session.close().await } async fn initiate_session( outgoing: SplitSink, Message>, frame: Message, db_client: DatabaseClient, amqp_channel: lapin::Channel, ) -> Result, session::SessionError> { - let mut session = session::WebsocketSession::from_frame( + let session = session::WebsocketSession::from_frame( outgoing, db_client.clone(), frame, &amqp_channel, ) .await .map_err(|_| { error!("Device failed to send valid connection request."); SessionError::InvalidMessage })?; - session::consume_error(session.deliver_persisted_messages().await); - Ok(session) } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs index 3258a419c..b3325014c 100644 --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,309 +1,311 @@ +use crate::constants::{ + CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY, MAX_RMQ_MSG_PRIORITY, + RMQ_CONSUMER_TAG, +}; use aws_sdk_dynamodb::error::SdkError; use aws_sdk_dynamodb::operation::put_item::PutItemError; use derive_more; use futures_util::stream::SplitSink; use futures_util::SinkExt; use futures_util::StreamExt; use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use lapin::message::Delivery; use lapin::options::{ BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions, QueueDeclareOptions, QueueDeleteOptions, }; use lapin::types::FieldTable; use lapin::BasicProperties; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tracing::{debug, error, info}; use tunnelbroker_messages::{ message_to_device_request_status::Failure, message_to_device_request_status::MessageSentStatus, session::DeviceTypes, MessageToDeviceRequest, Messages, }; use crate::database::{self, DatabaseClient, DeviceMessage}; -use crate::error::Error; use crate::identity; pub struct DeviceInfo { pub device_id: String, pub notify_token: Option, pub device_type: DeviceTypes, pub device_app_version: Option, pub device_os: Option, } pub struct WebsocketSession { tx: SplitSink, Message>, db_client: DatabaseClient, pub device_info: DeviceInfo, amqp_channel: lapin::Channel, // Stream of messages from AMQP endpoint amqp_consumer: lapin::Consumer, } #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum SessionError { InvalidMessage, SerializationError(serde_json::Error), MessageError(database::MessageErrors), AmqpError(lapin::Error), InternalError, UnauthorizedDevice, PersistenceError(SdkError), } -pub fn consume_error(result: Result) { - if let Err(e) = result { - error!("{}", e) - } -} - // Parse a session request and retrieve the device information pub async fn handle_first_message_from_device( message: &str, -) -> Result { +) -> Result { let serialized_message = serde_json::from_str::(message)?; match serialized_message { Messages::ConnectionInitializationMessage(mut session_info) => { let device_info = DeviceInfo { device_id: session_info.device_id.clone(), notify_token: session_info.notify_token.take(), device_type: session_info.device_type, device_app_version: session_info.device_app_version.take(), device_os: session_info.device_os.take(), }; // Authenticate device debug!("Authenticating device: {}", &session_info.device_id); let auth_request = identity::verify_user_access_token( &session_info.user_id, &device_info.device_id, &session_info.access_token, ) .await; match auth_request { Err(e) => { error!("Failed to complete request to identity service: {:?}", e); - return Err(SessionError::InternalError.into()); + return Err(SessionError::InternalError); } Ok(false) => { info!("Device failed authentication: {}", &session_info.device_id); - return Err(SessionError::UnauthorizedDevice.into()); + return Err(SessionError::UnauthorizedDevice); } Ok(true) => { debug!( "Successfully authenticated device: {}", &session_info.device_id ); } } Ok(device_info) } _ => { debug!("Received invalid request"); - Err(SessionError::InvalidMessage.into()) + Err(SessionError::InvalidMessage) } } } +async fn publish_persisted_messages( + db_client: &DatabaseClient, + amqp_channel: &lapin::Channel, + device_info: &DeviceInfo, +) -> Result<(), SessionError> { + let messages = db_client + .retrieve_messages(&device_info.device_id) + .await + .unwrap_or_else(|e| { + error!("Error while retrieving messages: {}", e); + Vec::new() + }); + + for message in messages { + let device_message = DeviceMessage::from_hashmap(message)?; + + amqp_channel + .basic_publish( + "", + &device_message.device_id, + BasicPublishOptions::default(), + device_message.payload.as_bytes(), + BasicProperties::default().with_priority(DDB_RMQ_MSG_PRIORITY), + ) + .await?; + + if let Err(e) = db_client + .delete_message(&device_info.device_id, &device_message.message_id) + .await + { + error!("Failed to delete message: {}:", e); + } + } + + debug!("Flushed messages for device: {}", &device_info.device_id); + Ok(()) +} + impl WebsocketSession { pub async fn from_frame( tx: SplitSink, Message>, db_client: DatabaseClient, frame: Message, amqp_channel: &lapin::Channel, - ) -> Result, Error> { + ) -> Result, SessionError> { let device_info = match frame { Message::Text(payload) => { handle_first_message_from_device(&payload).await? } _ => { error!("Client sent wrong frame type for establishing connection"); - return Err(SessionError::InvalidMessage.into()); + return Err(SessionError::InvalidMessage); } }; - // We don't currently have a use case to interact directly with the queue, - // however, we need to declare a queue for a given device + let mut args = FieldTable::default(); + args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into()); amqp_channel .queue_declare( &device_info.device_id, QueueDeclareOptions::default(), - FieldTable::default(), + args, ) .await?; + publish_persisted_messages(&db_client, amqp_channel, &device_info).await?; + let amqp_consumer = amqp_channel .basic_consume( &device_info.device_id, - "tunnelbroker", + RMQ_CONSUMER_TAG, BasicConsumeOptions::default(), FieldTable::default(), ) .await?; Ok(WebsocketSession { tx, db_client, device_info, amqp_channel: amqp_channel.clone(), amqp_consumer, }) } pub async fn handle_message_to_device( &self, message_request: &MessageToDeviceRequest, ) -> Result<(), SessionError> { let message_id = self .db_client .persist_message( &message_request.device_id, &message_request.payload, &message_request.client_message_id, ) .await?; let publish_result = self .amqp_channel .basic_publish( "", &message_request.device_id, BasicPublishOptions::default(), message_request.payload.as_bytes(), - BasicProperties::default(), + BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY), ) .await; if let Err(publish_error) = publish_result { self .db_client .delete_message(&self.device_info.device_id, &message_id) .await .expect("Error deleting message"); return Err(SessionError::AmqpError(publish_error)); } Ok(()) } pub async fn handle_websocket_frame_from_device( &mut self, msg: String, ) -> MessageSentStatus { let Ok(serialized_message) = serde_json::from_str::(&msg) else { return MessageSentStatus::SerializationError(msg); }; match serialized_message { Messages::MessageToDeviceRequest(message_request) => { debug!("Received message for {}", message_request.device_id); let result = self.handle_message_to_device(&message_request).await; self.get_message_to_device_status( &message_request.client_message_id, result, ) } _ => { error!("Client sent invalid message type"); MessageSentStatus::InvalidRequest } } } pub async fn next_amqp_message( &mut self, ) -> Option> { self.amqp_consumer.next().await } - pub async fn deliver_persisted_messages( - &mut self, - ) -> Result<(), SessionError> { - // Check for persisted messages - let messages = self - .db_client - .retrieve_messages(&self.device_info.device_id) - .await - .unwrap_or_else(|e| { - error!("Error while retrieving messages: {}", e); - Vec::new() - }); - - for message in messages { - let device_message = DeviceMessage::from_hashmap(message)?; - self - .send_message_to_device(Message::Text(device_message.payload)) - .await; - if let Err(e) = self - .db_client - .delete_message(&self.device_info.device_id, &device_message.message_id) - .await - { - error!("Failed to delete message: {}:", e); - } - } - - debug!( - "Flushed messages for device: {}", - &self.device_info.device_id - ); - - Ok(()) - } - pub async fn send_message_to_device(&mut self, message: Message) { if let Err(e) = self.tx.send(message).await { error!("Failed to send message to device: {}", e); } } // Release WebSocket and remove from active connections pub async fn close(&mut self) { if let Err(e) = self.tx.close().await { debug!("Failed to close WebSocket session: {}", e); } if let Err(e) = self .amqp_channel .basic_cancel( self.amqp_consumer.tag().as_str(), BasicCancelOptions::default(), ) .await { error!("Failed to cancel consumer: {}", e); } if let Err(e) = self .amqp_channel .queue_delete( self.device_info.device_id.as_str(), QueueDeleteOptions::default(), ) .await { error!("Failed to delete queue: {}", e); } } pub fn get_message_to_device_status( &mut self, client_message_id: &str, result: Result<(), SessionError>, ) -> MessageSentStatus { match result { Ok(()) => MessageSentStatus::Success(client_message_id.to_string()), Err(err) => MessageSentStatus::Error(Failure { id: client_message_id.to_string(), error: err.to_string(), }), } } }