diff --git a/lib/types/tunnelbroker/heartbeat-types.js b/lib/types/tunnelbroker/heartbeat-types.js new file mode 100644 index 000000000..66e0a3523 --- /dev/null +++ b/lib/types/tunnelbroker/heartbeat-types.js @@ -0,0 +1,13 @@ +// @flow + +import type { TInterface } from 'tcomb'; + +import { tShape, tString } from '../../utils/validation-utils.js'; + +export type Heartbeat = { + +type: 'Heartbeat', +}; + +export const heartbeatValidator: TInterface = tShape({ + type: tString('Heartbeat'), +}); diff --git a/lib/types/tunnelbroker/messages.js b/lib/types/tunnelbroker/messages.js index 904c8e8e1..d8a0ae883 100644 --- a/lib/types/tunnelbroker/messages.js +++ b/lib/types/tunnelbroker/messages.js @@ -1,74 +1,78 @@ // @flow import type { TUnion } from 'tcomb'; import t from 'tcomb'; import { type ConnectionInitializationResponse, connectionInitializationResponseValidator, } from './connection-initialization-response-types.js'; +import { type Heartbeat, heartbeatValidator } from './heartbeat-types.js'; import { type RefreshKeyRequest, refreshKeysRequestValidator, } from './keys-types.js'; import { type MessageReceiveConfirmation, messageReceiveConfirmationValidator, } from './message-receive-confirmation-types.js'; import { type MessageToDeviceRequestStatus, messageToDeviceRequestStatusValidator, } from './message-to-device-request-status-types.js'; import { type MessageToDeviceRequest, messageToDeviceRequestValidator, } from './message-to-device-request-types.js'; import { type MessageToDevice, messageToDeviceValidator, } from './message-to-device-types.js'; import { type ConnectionInitializationMessage, connectionInitializationMessageValidator, } from './session-types.js'; /* * This file defines types and validation for messages exchanged * with the Tunnelbroker. The definitions in this file should remain in sync * with the structures defined in the corresponding * Rust file at `shared/tunnelbroker_messages/src/messages/mod.rs`. * * If you edit the definitions in one file, * please make sure to update the corresponding definitions in the other. * */ export const tunnelbrokerMessageTypes = Object.freeze({ REFRESH_KEYS_REQUEST: 'RefreshKeyRequest', CONNECTION_INITIALIZATION_MESSAGE: 'ConnectionInitializationMessage', CONNECTION_INITIALIZATION_RESPONSE: 'ConnectionInitializationResponse', MESSAGE_TO_DEVICE_REQUEST_STATUS: 'MessageToDeviceRequestStatus', MESSAGE_TO_DEVICE_REQUEST: 'MessageToDeviceRequest', MESSAGE_TO_DEVICE: 'MessageToDevice', MESSAGE_RECEIVE_CONFIRMATION: 'MessageReceiveConfirmation', + HEARTBEAT: 'Heartbeat', }); export const tunnelbrokerMessageValidator: TUnion = t.union([ refreshKeysRequestValidator, connectionInitializationMessageValidator, connectionInitializationResponseValidator, messageToDeviceRequestStatusValidator, messageToDeviceRequestValidator, messageToDeviceValidator, messageReceiveConfirmationValidator, + heartbeatValidator, ]); export type TunnelbrokerMessage = | RefreshKeyRequest | ConnectionInitializationMessage | ConnectionInitializationResponse | MessageToDeviceRequestStatus | MessageToDeviceRequest | MessageToDevice - | MessageReceiveConfirmation; + | MessageReceiveConfirmation + | Heartbeat; diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs index 846b3766f..fca00f1d1 100644 --- a/services/commtest/src/tunnelbroker/socket.rs +++ b/services/commtest/src/tunnelbroker/socket.rs @@ -1,107 +1,117 @@ use crate::identity::device::DeviceInfo; use crate::service_addr; use futures_util::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; use tokio::net::TcpStream; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tunnelbroker_messages::{ ConnectionInitializationMessage, ConnectionInitializationResponse, - ConnectionInitializationStatus, DeviceTypes, MessageSentStatus, - MessageToDevice, MessageToDeviceRequest, MessageToDeviceRequestStatus, + ConnectionInitializationStatus, DeviceTypes, Heartbeat, MessageSentStatus, + MessageToDeviceRequest, MessageToDeviceRequestStatus, Messages, }; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] #[serde(tag = "type", rename_all = "camelCase")] pub struct WebSocketMessageToDevice { #[serde(rename = "deviceID")] pub device_id: String, pub payload: String, } pub async fn create_socket( device_info: &DeviceInfo, ) -> Result< WebSocketStream>, Box, > { let (mut socket, _) = connect_async(service_addr::TUNNELBROKER_WS) .await .expect("Can't connect"); let session_request = ConnectionInitializationMessage { device_id: device_info.device_id.to_string(), access_token: device_info.access_token.to_string(), user_id: device_info.user_id.to_string(), notify_token: None, device_type: DeviceTypes::Keyserver, device_app_version: None, device_os: None, }; let serialized_request = serde_json::to_string(&session_request) .expect("Failed to serialize connection request"); socket .send(Message::Text(serialized_request)) .await .expect("Failed to send message"); if let Some(Ok(response)) = socket.next().await { let response: ConnectionInitializationResponse = serde_json::from_str(response.to_text().unwrap())?; return match response.status { ConnectionInitializationStatus::Success => Ok(socket), ConnectionInitializationStatus::Error(err) => Err(err.into()), }; } Err("Failed to get response from Tunnelbroker".into()) } pub async fn send_message( socket: &mut WebSocketStream>, message: WebSocketMessageToDevice, ) -> Result> { let client_message_id = uuid::Uuid::new_v4().to_string(); let request = MessageToDeviceRequest { client_message_id: client_message_id.clone(), device_id: message.device_id, payload: message.payload, }; let serialized_request = serde_json::to_string(&request)?; socket.send(Message::Text(serialized_request)).await?; if let Some(Ok(response)) = socket.next().await { let confirmation: MessageToDeviceRequestStatus = serde_json::from_str(response.to_text().unwrap())?; if confirmation .client_message_ids .contains(&MessageSentStatus::Success(client_message_id.clone())) { return Ok(client_message_id); } } Err("Failed to confirm sent message".into()) } pub async fn receive_message( socket: &mut WebSocketStream>, ) -> Result> { - let Some(Ok(response)) = socket.next().await else { - return Err("Failed to receive message".into()); - }; - let message = response.to_text().expect("Failed to get response content"); - let message_to_device = serde_json::from_str::(message) - .expect("Failed to parse MessageToDevice from response"); - - let confirmation = tunnelbroker_messages::MessageReceiveConfirmation { - message_ids: vec![message_to_device.message_id], - }; - let serialized_confirmation = serde_json::to_string(&confirmation).unwrap(); - socket.send(Message::Text(serialized_confirmation)).await?; + while let Some(Ok(response)) = socket.next().await { + let message_str = + response.to_text().expect("Failed to get response content"); + let message = serde_json::from_str::(message_str).unwrap(); + match message { + Messages::MessageToDevice(msg) => { + let confirmation = tunnelbroker_messages::MessageReceiveConfirmation { + message_ids: vec![msg.message_id], + }; + let serialized_confirmation = + serde_json::to_string(&confirmation).unwrap(); + socket.send(Message::Text(serialized_confirmation)).await?; + return Ok(msg.payload); + } + Messages::Heartbeat(Heartbeat {}) => { + let msg = Heartbeat {}; + let serialized = serde_json::to_string(&msg).unwrap(); + socket.send(Message::Text(serialized)).await?; + } + _ => return Err(format!("Unexpected message type {message:?}").into()), + } + } - Ok(message_to_device.payload) + Err("Failed to receive message".into()) } diff --git a/services/commtest/tests/tunnelbroker_heartbeat_tests.rs b/services/commtest/tests/tunnelbroker_heartbeat_tests.rs new file mode 100644 index 000000000..3c9b6e628 --- /dev/null +++ b/services/commtest/tests/tunnelbroker_heartbeat_tests.rs @@ -0,0 +1,85 @@ +use commtest::identity::device::create_device; +use commtest::tunnelbroker::socket::create_socket; +use futures_util::sink::SinkExt; +use futures_util::stream::StreamExt; +use tokio::net::TcpStream; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::tungstenite::Message::Close; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; +use tunnelbroker_messages::Heartbeat; + +async fn receive_and_parse_message( + socket: &mut WebSocketStream>, +) -> Heartbeat { + if let Some(Ok(response)) = socket.next().await { + let message = response + .to_text() + .expect("Unable to retrieve response content"); + serde_json::from_str::(message) + .expect("Unable to parse Heartbeat from response") + } else { + panic!("Received incorrect message type.") + } +} + +#[tokio::test] +async fn test_receiving() { + let client = create_device(None).await; + let mut socket = create_socket(&client).await.unwrap(); + + let message_to_device = receive_and_parse_message(&mut socket).await; + + assert_eq!(message_to_device, Heartbeat {}); + + socket + .send(Close(None)) + .await + .expect("Failed to close socket"); +} + +#[tokio::test] +async fn test_responding() { + let client = create_device(None).await; + let mut socket = create_socket(&client).await.unwrap(); + + let message_to_device = receive_and_parse_message(&mut socket).await; + + assert_eq!(message_to_device, Heartbeat {}); + + let heartbeat = Heartbeat {}; + let serialized = serde_json::to_string(&heartbeat).unwrap(); + socket + .send(Message::Text(serialized)) + .await + .expect("Error while sending heartbeat"); + + // Receive and parse another heartbeat message + let message_to_device = receive_and_parse_message(&mut socket).await; + + assert_eq!(message_to_device, Heartbeat {}); + + socket + .send(Close(None)) + .await + .expect("Failed to close the socket"); +} + +#[tokio::test] +async fn test_closing() { + let client = create_device(None).await; + let mut socket = create_socket(&client).await.unwrap(); + + let message_to_device = receive_and_parse_message(&mut socket).await; + + assert_eq!(message_to_device, Heartbeat {}); + + // The next message should be a Close message because we did not respond + // to the Heartbeat. + // This suggests that the Tunnelbroker might consider the connection + // as unhealthy and decide to close it. + if let Some(Ok(response)) = socket.next().await { + assert_eq!(response, Message::Close(None)) + } else { + panic!("Received incorrect message type. Expected Close.") + } +} diff --git a/services/tunnelbroker/src/constants.rs b/services/tunnelbroker/src/constants.rs index 2b9ef2c35..557728d5e 100644 --- a/services/tunnelbroker/src/constants.rs +++ b/services/tunnelbroker/src/constants.rs @@ -1,36 +1,38 @@ use tokio::time::Duration; pub const GRPC_TX_QUEUE_SIZE: usize = 32; pub const GRPC_SERVER_PORT: u16 = 50051; pub const GRPC_KEEP_ALIVE_PING_INTERVAL: Duration = Duration::from_secs(3); pub const GRPC_KEEP_ALIVE_PING_TIMEOUT: Duration = Duration::from_secs(10); +pub const SOCKET_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(3); + pub const MAX_RMQ_MSG_PRIORITY: u8 = 10; pub const DDB_RMQ_MSG_PRIORITY: u8 = 10; pub const CLIENT_RMQ_MSG_PRIORITY: u8 = 1; pub const RMQ_CONSUMER_TAG: &str = "tunnelbroker"; pub const LOG_LEVEL_ENV_VAR: &str = tracing_subscriber::filter::EnvFilter::DEFAULT_ENV; pub mod dynamodb { // This table holds messages which could not be immediately delivered to // a device. // // - (primary key) = (deviceID: Partition Key, createdAt: Sort Key) // - deviceID: The public key of a device's olm identity key // - payload: Message to be delivered. See shared/tunnelbroker_messages. // - messageID = [createdAt]#[clientMessageID] // - createdAd: UNIX timestamp of when the item was inserted. // Timestamp is needed to order the messages correctly to the device. // Timestamp format is ISO 8601 to handle lexicographical sorting. // - clientMessageID: Message ID generated on client using UUID Version 4. pub mod undelivered_messages { pub const TABLE_NAME: &str = "tunnelbroker-undelivered-messages"; pub const PARTITION_KEY: &str = "deviceID"; pub const DEVICE_ID: &str = "deviceID"; pub const PAYLOAD: &str = "payload"; pub const MESSAGE_ID: &str = "messageID"; pub const SORT_KEY: &str = "messageID"; } } diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs index 67046fc68..3a66439c5 100644 --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,296 +1,314 @@ pub mod session; +use crate::constants::SOCKET_HEARTBEAT_TIMEOUT; use crate::database::DatabaseClient; use crate::websockets::session::{initialize_amqp, SessionError}; use crate::CONFIG; use futures_util::stream::SplitSink; use futures_util::{SinkExt, StreamExt}; use hyper::upgrade::Upgraded; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use hyper_tungstenite::WebSocketStream; use std::env; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tracing::{debug, error, info}; use tunnelbroker_messages::{ - ConnectionInitializationStatus, MessageSentStatus, + ConnectionInitializationStatus, Heartbeat, MessageSentStatus, MessageToDeviceRequestStatus, }; type BoxedError = Box; pub type ErrorWithStreamHandle = ( session::SessionError, SplitSink, Message>, ); use self::session::WebsocketSession; /// Hyper HTTP service that handles incoming HTTP and websocket connections /// It handles the initial websocket upgrade request and spawns a task to /// handle the websocket connection. /// It also handles regular HTTP requests (currently health check) struct WebsocketService { addr: SocketAddr, channel: lapin::Channel, db_client: DatabaseClient, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = BoxedError; type Future = Pin> + Send>>; // This function is called to check if the service is ready to accept // connections. Since we don't have any state to check, we're always ready. fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let db_client = self.db_client.clone(); let channel = self.channel.clone(); let future = async move { // Check if the request is a websocket upgrade request. if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; // Spawn a task to handle the websocket connection. tokio::spawn(async move { accept_connection(websocket, addr, db_client, channel).await; }); // Return the response so the spawned future can continue. return Ok(response); } debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); // A simple router for regular HTTP requests let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } pub async fn run_server( db_client: DatabaseClient, amqp_connection: &lapin::Connection, ) -> Result<(), BoxedError> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") .unwrap_or_else(|_| format!("0.0.0.0:{}", &CONFIG.http_port)); let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("WebSocket listening on: {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { let channel = amqp_connection .create_channel() .await .expect("Failed to create AMQP channel"); let connection = http .serve_connection( stream, WebsocketService { channel, db_client: db_client.clone(), addr, }, ) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } async fn send_error_init_response( error: SessionError, mut outgoing: SplitSink, Message>, ) { let error_response = tunnelbroker_messages::ConnectionInitializationResponse { status: ConnectionInitializationStatus::Error(error.to_string()), }; match serde_json::to_string(&error_response) { Ok(serialized_response) => { if let Err(send_error) = outgoing.send(Message::Text(serialized_response)).await { error!("Failed to send init error response: {:?}", send_error); } } Err(ser_error) => { error!("Failed to serialize the error response: {:?}", ser_error); } } } /// Handler for any incoming websocket connections async fn accept_connection( hyper_ws: HyperWebsocket, addr: SocketAddr, db_client: DatabaseClient, amqp_channel: lapin::Channel, ) { debug!("Incoming connection from: {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { info!( "Failed to establish connection with {}. Reason: {}", addr, e ); return; } }; let (outgoing, mut incoming) = ws_stream.split(); // We don't know the identity of the device until it sends the session // request over the websocket connection let mut session = if let Some(Ok(first_msg)) = incoming.next().await { match initiate_session(outgoing, first_msg, db_client, amqp_channel).await { Ok(mut session) => { let response = tunnelbroker_messages::ConnectionInitializationResponse { status: ConnectionInitializationStatus::Success, }; let serialized_response = serde_json::to_string(&response).unwrap(); session .send_message_to_device(Message::Text(serialized_response)) .await; session } Err((err, outgoing)) => { error!("Failed to create session with device"); send_error_init_response(err, outgoing).await; return; } } } else { error!("Failed to create session with device"); send_error_init_response(SessionError::InvalidMessage, outgoing).await; return; }; + let mut ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); + let mut got_heartbeat_response = true; + // Poll for messages either being sent to the device (rx) // or messages being received from the device (incoming) loop { debug!("Polling for messages from: {}", addr); tokio::select! { Some(Ok(delivery)) = session.next_amqp_message() => { if let Ok(message) = std::str::from_utf8(&delivery.data) { session.send_message_to_device(Message::Text(message.to_string())).await; } else { error!("Invalid payload"); } }, device_message = incoming.next() => { let message: Message = match device_message { Some(Ok(msg)) => msg, _ => { debug!("Connection to {} closed remotely.", addr); break; } }; match message { Message::Close(_) => { debug!("Connection to {} closed.", addr); break; } Message::Pong(_) => { debug!("Received Pong message from {}", addr); } Message::Ping(msg) => { debug!("Received Ping message from {}", addr); session.send_message_to_device(Message::Pong(msg)).await; } Message::Text(msg) => { + got_heartbeat_response = true; + ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); + let Some(message_status) = session.handle_websocket_frame_from_device(msg).await else { continue; }; let request_status = MessageToDeviceRequestStatus { client_message_ids: vec![message_status] }; if let Ok(response) = serde_json::to_string(&request_status) { session.send_message_to_device(Message::text(response)).await; } else { break; } } _ => { error!("Client sent invalid message type"); let confirmation = MessageToDeviceRequestStatus {client_message_ids: vec![MessageSentStatus::InvalidRequest]}; if let Ok(response) = serde_json::to_string(&confirmation) { session.send_message_to_device(Message::text(response)).await; } else { break; } } } }, + _ = &mut ping_timeout => { + if !got_heartbeat_response { + error!("Connection to {} died", addr); + break; + } + let serialized = serde_json::to_string(&Heartbeat {}).unwrap(); + session.send_message_to_device(Message::text(serialized)).await; + + got_heartbeat_response = false; + ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); + } else => { debug!("Unhealthy connection for: {}", addr); break; }, } } info!("Unregistering connection to: {}", addr); session.close().await } async fn initiate_session( outgoing: SplitSink, Message>, frame: Message, db_client: DatabaseClient, amqp_channel: lapin::Channel, ) -> Result, ErrorWithStreamHandle> { let initialized_session = initialize_amqp(db_client.clone(), frame, &amqp_channel).await; match initialized_session { Ok((device_info, amqp_consumer)) => Ok(WebsocketSession::new( outgoing, db_client, device_info, amqp_channel, amqp_consumer, )), Err(e) => Err((e, outgoing)), } } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs index 6265811c2..cffe478c0 100644 --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,330 +1,334 @@ use crate::constants::{ CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY, MAX_RMQ_MSG_PRIORITY, RMQ_CONSUMER_TAG, }; use aws_sdk_dynamodb::error::SdkError; use aws_sdk_dynamodb::operation::put_item::PutItemError; use derive_more; use futures_util::stream::SplitSink; use futures_util::SinkExt; use futures_util::StreamExt; use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use lapin::message::Delivery; use lapin::options::{ BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions, QueueDeclareOptions, QueueDeleteOptions, }; use lapin::types::FieldTable; use lapin::BasicProperties; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tracing::{debug, error, info}; use tunnelbroker_messages::{ message_to_device_request_status::Failure, message_to_device_request_status::MessageSentStatus, session::DeviceTypes, - MessageToDevice, MessageToDeviceRequest, Messages, + Heartbeat, MessageToDevice, MessageToDeviceRequest, Messages, }; use crate::database::{self, DatabaseClient, MessageToDeviceExt}; use crate::identity; pub struct DeviceInfo { pub device_id: String, pub notify_token: Option, pub device_type: DeviceTypes, pub device_app_version: Option, pub device_os: Option, } pub struct WebsocketSession { tx: SplitSink, Message>, db_client: DatabaseClient, pub device_info: DeviceInfo, amqp_channel: lapin::Channel, // Stream of messages from AMQP endpoint amqp_consumer: lapin::Consumer, } #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum SessionError { InvalidMessage, SerializationError(serde_json::Error), MessageError(database::MessageErrors), AmqpError(lapin::Error), InternalError, UnauthorizedDevice, PersistenceError(SdkError), } // Parse a session request and retrieve the device information pub async fn handle_first_message_from_device( message: &str, ) -> Result { let serialized_message = serde_json::from_str::(message)?; match serialized_message { Messages::ConnectionInitializationMessage(mut session_info) => { let device_info = DeviceInfo { device_id: session_info.device_id.clone(), notify_token: session_info.notify_token.take(), device_type: session_info.device_type, device_app_version: session_info.device_app_version.take(), device_os: session_info.device_os.take(), }; // Authenticate device debug!("Authenticating device: {}", &session_info.device_id); let auth_request = identity::verify_user_access_token( &session_info.user_id, &device_info.device_id, &session_info.access_token, ) .await; match auth_request { Err(e) => { error!("Failed to complete request to identity service: {:?}", e); return Err(SessionError::InternalError); } Ok(false) => { info!("Device failed authentication: {}", &session_info.device_id); return Err(SessionError::UnauthorizedDevice); } Ok(true) => { debug!( "Successfully authenticated device: {}", &session_info.device_id ); } } Ok(device_info) } _ => { debug!("Received invalid request"); Err(SessionError::InvalidMessage) } } } async fn publish_persisted_messages( db_client: &DatabaseClient, amqp_channel: &lapin::Channel, device_info: &DeviceInfo, ) -> Result<(), SessionError> { let messages = db_client .retrieve_messages(&device_info.device_id) .await .unwrap_or_else(|e| { error!("Error while retrieving messages: {}", e); Vec::new() }); for message in messages { let message_to_device = MessageToDevice::from_hashmap(message)?; let serialized_message = serde_json::to_string(&message_to_device)?; amqp_channel .basic_publish( "", &message_to_device.device_id, BasicPublishOptions::default(), serialized_message.as_bytes(), BasicProperties::default().with_priority(DDB_RMQ_MSG_PRIORITY), ) .await?; } debug!("Flushed messages for device: {}", &device_info.device_id); Ok(()) } pub async fn initialize_amqp( db_client: DatabaseClient, frame: Message, amqp_channel: &lapin::Channel, ) -> Result<(DeviceInfo, lapin::Consumer), SessionError> { let device_info = match frame { Message::Text(payload) => { handle_first_message_from_device(&payload).await? } _ => { error!("Client sent wrong frame type for establishing connection"); return Err(SessionError::InvalidMessage); } }; let mut args = FieldTable::default(); args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into()); amqp_channel .queue_declare(&device_info.device_id, QueueDeclareOptions::default(), args) .await?; publish_persisted_messages(&db_client, amqp_channel, &device_info).await?; let amqp_consumer = amqp_channel .basic_consume( &device_info.device_id, RMQ_CONSUMER_TAG, BasicConsumeOptions::default(), FieldTable::default(), ) .await?; Ok((device_info, amqp_consumer)) } impl WebsocketSession { pub fn new( tx: SplitSink, Message>, db_client: DatabaseClient, device_info: DeviceInfo, amqp_channel: lapin::Channel, amqp_consumer: lapin::Consumer, ) -> Self { Self { tx, db_client, device_info, amqp_channel, amqp_consumer, } } pub async fn handle_message_to_device( &self, message_request: &MessageToDeviceRequest, ) -> Result<(), SessionError> { let message_id = self .db_client .persist_message( &message_request.device_id, &message_request.payload, &message_request.client_message_id, ) .await?; let message_to_device = MessageToDevice { device_id: message_request.device_id.clone(), payload: message_request.payload.clone(), message_id: message_id.clone(), }; let serialized_message = serde_json::to_string(&message_to_device)?; let publish_result = self .amqp_channel .basic_publish( "", &message_request.device_id, BasicPublishOptions::default(), serialized_message.as_bytes(), BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY), ) .await; if let Err(publish_error) = publish_result { self .db_client .delete_message(&self.device_info.device_id, &message_id) .await .expect("Error deleting message"); return Err(SessionError::AmqpError(publish_error)); } Ok(()) } pub async fn handle_websocket_frame_from_device( &mut self, msg: String, ) -> Option { let Ok(serialized_message) = serde_json::from_str::(&msg) else { return Option::from(MessageSentStatus::SerializationError(msg)); }; match serialized_message { + Messages::Heartbeat(Heartbeat {}) => { + debug!("Received heartbeat from: {}", self.device_info.device_id); + None + } Messages::MessageReceiveConfirmation(confirmation) => { for message_id in confirmation.message_ids { if let Err(e) = self .db_client .delete_message(&self.device_info.device_id, &message_id) .await { error!("Failed to delete message: {}:", e); } } None } Messages::MessageToDeviceRequest(message_request) => { debug!("Received message for {}", message_request.device_id); let result = self.handle_message_to_device(&message_request).await; Option::from(self.get_message_to_device_status( &message_request.client_message_id, result, )) } _ => { error!("Client sent invalid message type"); Option::from(MessageSentStatus::InvalidRequest) } } } pub async fn next_amqp_message( &mut self, ) -> Option> { self.amqp_consumer.next().await } pub async fn send_message_to_device(&mut self, message: Message) { if let Err(e) = self.tx.send(message).await { error!("Failed to send message to device: {}", e); } } // Release WebSocket and remove from active connections pub async fn close(&mut self) { if let Err(e) = self.tx.close().await { debug!("Failed to close WebSocket session: {}", e); } if let Err(e) = self .amqp_channel .basic_cancel( self.amqp_consumer.tag().as_str(), BasicCancelOptions::default(), ) .await { error!("Failed to cancel consumer: {}", e); } if let Err(e) = self .amqp_channel .queue_delete( self.device_info.device_id.as_str(), QueueDeleteOptions::default(), ) .await { error!("Failed to delete queue: {}", e); } } pub fn get_message_to_device_status( &mut self, client_message_id: &str, result: Result<(), SessionError>, ) -> MessageSentStatus { match result { Ok(()) => MessageSentStatus::Success(client_message_id.to_string()), Err(err) => MessageSentStatus::Error(Failure { id: client_message_id.to_string(), error: err.to_string(), }), } } } diff --git a/shared/tunnelbroker_messages/src/messages/heartbeat.rs b/shared/tunnelbroker_messages/src/messages/heartbeat.rs new file mode 100644 index 000000000..b2c05fbc7 --- /dev/null +++ b/shared/tunnelbroker_messages/src/messages/heartbeat.rs @@ -0,0 +1,23 @@ +//! Messages sent between Tunnelbroker and devices. + +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, PartialEq, Debug)] +#[serde(tag = "type")] +pub struct Heartbeat {} + +#[cfg(test)] +mod heartbeat_tests { + use super::*; + + #[test] + fn test_heartbeat_deserialization() { + let example_payload = r#"{ + "type": "Heartbeat" + }"#; + + let request = serde_json::from_str::(example_payload).unwrap(); + let expected = Heartbeat {}; + assert_eq!(request, expected); + } +} diff --git a/shared/tunnelbroker_messages/src/messages/mod.rs b/shared/tunnelbroker_messages/src/messages/mod.rs index 3282c9b33..2f0cfc084 100644 --- a/shared/tunnelbroker_messages/src/messages/mod.rs +++ b/shared/tunnelbroker_messages/src/messages/mod.rs @@ -1,42 +1,45 @@ //! Messages sent between Tunnelbroker and a device. pub mod connection_initialization_response; +pub mod heartbeat; pub mod keys; pub mod message_receive_confirmation; pub mod message_to_device; pub mod message_to_device_request; pub mod message_to_device_request_status; pub mod session; pub use connection_initialization_response::*; +pub use heartbeat::*; pub use keys::*; pub use message_receive_confirmation::*; pub use message_to_device::*; pub use message_to_device_request::*; pub use message_to_device_request_status::*; pub use session::*; use serde::{Deserialize, Serialize}; // This file defines types and validation for messages exchanged // with the Tunnelbroker. The definitions in this file should remain in sync // with the structures defined in the corresponding // JavaScript file at `lib/types/tunnelbroker/messages.js`. // If you edit the definitions in one file, // please make sure to update the corresponding definitions in the other. #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum Messages { RefreshKeysRequest(RefreshKeyRequest), ConnectionInitializationMessage(ConnectionInitializationMessage), ConnectionInitializationResponse(ConnectionInitializationResponse), // MessageToDeviceRequestStatus must be placed before MessageToDeviceRequest. // This is due to serde's pattern matching behavior where it prioritizes // the first matching pattern it encounters. MessageToDeviceRequestStatus(MessageToDeviceRequestStatus), MessageToDeviceRequest(MessageToDeviceRequest), MessageToDevice(MessageToDevice), MessageReceiveConfirmation(MessageReceiveConfirmation), + Heartbeat(Heartbeat), }