diff --git a/lib/types/tunnelbroker/heartbeat-types.js b/lib/types/tunnelbroker/heartbeat-types.js new file mode 100644 --- /dev/null +++ b/lib/types/tunnelbroker/heartbeat-types.js @@ -0,0 +1,13 @@ +// @flow + +import type { TInterface } from 'tcomb'; + +import { tShape, tString } from '../../utils/validation-utils.js'; + +export type Heartbeat = { + +type: 'Heartbeat', +}; + +export const heartbeatValidator: TInterface = tShape({ + type: tString('Heartbeat'), +}); diff --git a/lib/types/tunnelbroker/messages.js b/lib/types/tunnelbroker/messages.js --- a/lib/types/tunnelbroker/messages.js +++ b/lib/types/tunnelbroker/messages.js @@ -7,6 +7,7 @@ type ConnectionInitializationResponse, connectionInitializationResponseValidator, } from './connection-initialization-response-types.js'; +import { type Heartbeat, heartbeatValidator } from './heartbeat-types.js'; import { type RefreshKeyRequest, refreshKeysRequestValidator, @@ -51,6 +52,7 @@ MESSAGE_TO_DEVICE_REQUEST: 'MessageToDeviceRequest', MESSAGE_TO_DEVICE: 'MessageToDevice', MESSAGE_RECEIVE_CONFIRMATION: 'MessageReceiveConfirmation', + HEARTBEAT: 'Heartbeat', }); export const tunnelbrokerMessageValidator: TUnion = @@ -62,6 +64,7 @@ messageToDeviceRequestValidator, messageToDeviceValidator, messageReceiveConfirmationValidator, + heartbeatValidator, ]); export type TunnelbrokerMessage = @@ -71,4 +74,5 @@ | MessageToDeviceRequestStatus | MessageToDeviceRequest | MessageToDevice - | MessageReceiveConfirmation; + | MessageReceiveConfirmation + | Heartbeat; diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs --- a/services/commtest/src/tunnelbroker/socket.rs +++ b/services/commtest/src/tunnelbroker/socket.rs @@ -7,8 +7,8 @@ use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tunnelbroker_messages::{ ConnectionInitializationMessage, ConnectionInitializationResponse, - ConnectionInitializationStatus, DeviceTypes, MessageSentStatus, - 
MessageToDevice, MessageToDeviceRequest, MessageToDeviceRequestStatus, + ConnectionInitializationStatus, DeviceTypes, Heartbeat, MessageSentStatus, + MessageToDeviceRequest, MessageToDeviceRequestStatus, Messages, }; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] @@ -90,18 +90,36 @@ pub async fn receive_message( socket: &mut WebSocketStream>, ) -> Result> { - let Some(Ok(response)) = socket.next().await else { - return Err("Failed to receive message".into()); - }; - let message = response.to_text().expect("Failed to get response content"); - let message_to_device = serde_json::from_str::(message) - .expect("Failed to parse MessageToDevice from response"); - - let confirmation = tunnelbroker_messages::MessageReceiveConfirmation { - message_ids: vec![message_to_device.message_id], - }; - let serialized_confirmation = serde_json::to_string(&confirmation).unwrap(); - socket.send(Message::Text(serialized_confirmation)).await?; + // Read until a MessageToDevice arrives, answering heartbeats meanwhile. + while let Some(frame) = socket.next().await { + let response = frame?; + // Only text frames carry Tunnelbroker messages. tungstenite answers + // pings itself, and a close frame means no message is coming. + let message_str = match response { + Message::Text(text) => text, + Message::Ping(_) | Message::Pong(_) => continue, + Message::Close(_) => break, + other => return Err(format!("Unexpected frame {other:?}").into()), + }; + let message = serde_json::from_str::<Messages>(&message_str)?; + match message { + Messages::MessageToDevice(msg) => { + // Confirm receipt of this message to Tunnelbroker. + let confirmation = tunnelbroker_messages::MessageReceiveConfirmation { + message_ids: vec![msg.message_id], + }; + let serialized_confirmation = serde_json::to_string(&confirmation)?; + socket.send(Message::Text(serialized_confirmation)).await?; + return Ok(msg.payload); + } + Messages::Heartbeat(Heartbeat {}) => { + // Tunnelbroker drops connections that leave a heartbeat unanswered. + let serialized = serde_json::to_string(&Heartbeat {})?; + socket.send(Message::Text(serialized)).await?; + } + _ => return Err(format!("Unexpected message type {message:?}").into()), + } + } - Ok(message_to_device.payload) + Err("Failed to receive message".into()) } diff --git a/services/commtest/tests/tunnelbroker_heartbeat_tests.rs b/services/commtest/tests/tunnelbroker_heartbeat_tests.rs new file mode 100644 --- /dev/null +++
b/services/commtest/tests/tunnelbroker_heartbeat_tests.rs @@ -0,0 +1,85 @@ +use commtest::identity::device::create_device; +use commtest::tunnelbroker::socket::create_socket; +use futures_util::sink::SinkExt; +use futures_util::stream::StreamExt; +use tokio::net::TcpStream; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::tungstenite::Message::Close; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; +use tunnelbroker_messages::Heartbeat; + +async fn receive_and_parse_message( + socket: &mut WebSocketStream>, +) -> Heartbeat { + if let Some(Ok(response)) = socket.next().await { + let message = response + .to_text() + .expect("Unable to retrieve response content"); + serde_json::from_str::(message) + .expect("Unable to parse Heartbeat from response") + } else { + panic!("Received incorrect message type.") + } +} + +#[tokio::test] +async fn test_receiving() { + let client = create_device(None).await; + let mut socket = create_socket(&client).await.unwrap(); + + let message_to_device = receive_and_parse_message(&mut socket).await; + + assert_eq!(message_to_device, Heartbeat {}); + + socket + .send(Close(None)) + .await + .expect("Failed to close socket"); +} + +#[tokio::test] +async fn test_responding() { + let client = create_device(None).await; + let mut socket = create_socket(&client).await.unwrap(); + + let message_to_device = receive_and_parse_message(&mut socket).await; + + assert_eq!(message_to_device, Heartbeat {}); + + let heartbeat = Heartbeat {}; + let serialized = serde_json::to_string(&heartbeat).unwrap(); + socket + .send(Message::Text(serialized)) + .await + .expect("Error while sending heartbeat"); + + // Receive and parse another heartbeat message + let message_to_device = receive_and_parse_message(&mut socket).await; + + assert_eq!(message_to_device, Heartbeat {}); + + socket + .send(Close(None)) + .await + .expect("Failed to close the socket"); +} + +#[tokio::test] +async fn test_closing() { + let client = 
create_device(None).await; + let mut socket = create_socket(&client).await.unwrap(); + + let message_to_device = receive_and_parse_message(&mut socket).await; + + assert_eq!(message_to_device, Heartbeat {}); + + // The next message should be a Close message because we did not respond + // to the Heartbeat. + // This suggests that the Tunnelbroker might consider the connection + // as unhealthy and decide to close it. + if let Some(Ok(response)) = socket.next().await { + assert_eq!(response, Message::Close(None)) + } else { + panic!("Received incorrect message type. Expected Close.") + } +} diff --git a/services/tunnelbroker/src/constants.rs b/services/tunnelbroker/src/constants.rs --- a/services/tunnelbroker/src/constants.rs +++ b/services/tunnelbroker/src/constants.rs @@ -5,6 +5,8 @@ pub const GRPC_KEEP_ALIVE_PING_INTERVAL: Duration = Duration::from_secs(3); pub const GRPC_KEEP_ALIVE_PING_TIMEOUT: Duration = Duration::from_secs(10); +pub const SOCKET_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(3); + pub const MAX_RMQ_MSG_PRIORITY: u8 = 10; pub const DDB_RMQ_MSG_PRIORITY: u8 = 10; pub const CLIENT_RMQ_MSG_PRIORITY: u8 = 1; diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,5 +1,6 @@ pub mod session; +use crate::constants::SOCKET_HEARTBEAT_TIMEOUT; use crate::database::DatabaseClient; use crate::websockets::session::{initialize_amqp, SessionError}; use crate::CONFIG; @@ -18,7 +19,7 @@ use tokio::net::TcpListener; use tracing::{debug, error, info}; use tunnelbroker_messages::{ - ConnectionInitializationStatus, MessageSentStatus, + ConnectionInitializationStatus, Heartbeat, MessageSentStatus, MessageToDeviceRequestStatus, }; @@ -207,6 +208,9 @@ return; }; + let mut ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); + let mut got_heartbeat_response = true; + // Poll for messages either 
being sent to the device (rx) // or messages being received from the device (incoming) loop { @@ -240,6 +244,9 @@ session.send_message_to_device(Message::Pong(msg)).await; } Message::Text(msg) => { + got_heartbeat_response = true; + ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); + let Some(message_status) = session.handle_websocket_frame_from_device(msg).await else { continue; }; @@ -263,6 +270,17 @@ } } }, + _ = &mut ping_timeout => { + if !got_heartbeat_response { + error!("Connection to {} died", addr); + break; + } + let serialized = serde_json::to_string(&Heartbeat {}).unwrap(); + session.send_message_to_device(Message::text(serialized)).await; + + got_heartbeat_response = false; + ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); + } else => { debug!("Unhealthy connection for: {}", addr); break; diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -22,7 +22,7 @@ use tunnelbroker_messages::{ message_to_device_request_status::Failure, message_to_device_request_status::MessageSentStatus, session::DeviceTypes, - MessageToDevice, MessageToDeviceRequest, Messages, + Heartbeat, MessageToDevice, MessageToDeviceRequest, Messages, }; use crate::database::{self, DatabaseClient, MessageToDeviceExt}; @@ -244,6 +244,10 @@ }; match serialized_message { + Messages::Heartbeat(Heartbeat {}) => { + debug!("Received heartbeat from: {}", self.device_info.device_id); + None + } Messages::MessageReceiveConfirmation(confirmation) => { for message_id in confirmation.message_ids { if let Err(e) = self diff --git a/shared/tunnelbroker_messages/src/messages/heartbeat.rs b/shared/tunnelbroker_messages/src/messages/heartbeat.rs new file mode 100644 --- /dev/null +++ b/shared/tunnelbroker_messages/src/messages/heartbeat.rs @@ -0,0 +1,23 @@ +//! 
Messages sent between Tunnelbroker and devices. + +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, PartialEq, Debug)] +#[serde(tag = "type")] +pub struct Heartbeat {} + +#[cfg(test)] +mod heartbeat_tests { + use super::*; + + #[test] + fn test_heartbeat_deserialization() { + let example_payload = r#"{ + "type": "Heartbeat" + }"#; + + let request = serde_json::from_str::(example_payload).unwrap(); + let expected = Heartbeat {}; + assert_eq!(request, expected); + } +} diff --git a/shared/tunnelbroker_messages/src/messages/mod.rs b/shared/tunnelbroker_messages/src/messages/mod.rs --- a/shared/tunnelbroker_messages/src/messages/mod.rs +++ b/shared/tunnelbroker_messages/src/messages/mod.rs @@ -1,6 +1,7 @@ //! Messages sent between Tunnelbroker and a device. pub mod connection_initialization_response; +pub mod heartbeat; pub mod keys; pub mod message_receive_confirmation; pub mod message_to_device; @@ -9,6 +10,7 @@ pub mod session; pub use connection_initialization_response::*; +pub use heartbeat::*; pub use keys::*; pub use message_receive_confirmation::*; pub use message_to_device::*; @@ -39,4 +41,5 @@ MessageToDeviceRequest(MessageToDeviceRequest), MessageToDevice(MessageToDevice), MessageReceiveConfirmation(MessageReceiveConfirmation), + Heartbeat(Heartbeat), }