diff --git a/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs b/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs
new file mode 100644
--- /dev/null
+++ b/services/commtest/tests/tunnelbroker_websocket_messages_tests.rs
@@ -0,0 +1,57 @@
+//! Integration tests for the WebSocket message types defined in the
+//! tungstenite crate (Ping/Pong and Close), run against a Tunnelbroker socket.
+
+use commtest::identity::device::create_device;
+use commtest::identity::olm_account_infos::MOCK_CLIENT_KEYS_1;
+use commtest::tunnelbroker::socket::create_socket;
+use futures_util::{SinkExt, StreamExt};
+use tokio_tungstenite::tungstenite::{Error, Message, Message::Close};
+
+/// Sends a Ping frame and expects the very next frame to be a Pong that
+/// echoes the same payload.
+#[tokio::test]
+async fn test_ping_pong() {
+  let device = create_device(Some(&MOCK_CLIENT_KEYS_1)).await;
+
+  let ping_message = vec![1, 2, 3, 4, 5];
+
+  let mut socket = create_socket(&device).await;
+  socket
+    .send(Message::Ping(ping_message.clone()))
+    .await
+    .expect("Failed to send message");
+
+  // A terminated stream or a transport error must fail the test instead of
+  // letting it pass without ever checking the Pong.
+  let response = socket
+    .next()
+    .await
+    .expect("Connection closed before a response was received")
+    .expect("Failed to receive message");
+  let Message::Pong(received_payload) = response else {
+    panic!("Unexpected message type or result. Expected Pong. ")
+  };
+  assert_eq!(ping_message, received_payload);
+}
+
+/// Sends a Close frame and checks that whatever the socket yields next is
+/// consistent with a closed connection.
+#[tokio::test]
+async fn test_close_message() {
+  let device = create_device(Some(&MOCK_CLIENT_KEYS_1)).await;
+
+  let mut socket = create_socket(&device).await;
+  socket
+    .send(Close(None))
+    .await
+    .expect("Failed to send message");
+
+  // An already-terminated stream (`None`) is accepted as well; only an
+  // unexpected frame or error fails the test.
+  if let Some(response) = socket.next().await {
+    assert!(matches!(
+      response,
+      Err(Error::AlreadyClosed | Error::ConnectionClosed) | Ok(Close(None))
+    ));
+  }
+}
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -168,20 +168,39 @@
     tokio::select! {
       Some(Ok(delivery)) = session.next_amqp_message() => {
         if let Ok(message) = std::str::from_utf8(&delivery.data) {
-          session.send_message_to_device(message.to_string()).await;
+          session.send_message_to_device(Message::Text(message.to_string())).await;
         } else {
           error!("Invalid payload");
         }
       },
       device_message = incoming.next() => {
-        match device_message {
-          Some(Ok(msg)) => {
-            session::consume_error(session.handle_websocket_frame_from_device(msg).await);
-          }
+        let message: Message = match device_message {
+          Some(Ok(msg)) => msg,
           _ => {
             debug!("Connection to {} closed remotely.", addr);
             break;
           }
+        };
+        match message {
+          Message::Close(_) => {
+            debug!("Connection to {} closed.", addr);
+            break;
+          }
+          Message::Pong(_) => {
+            debug!("Received Pong message from {}", addr);
+          }
+          // Echo the Ping payload back to the device in a Pong frame.
+          Message::Ping(msg) => {
+            debug!("Received Ping message from {}", addr);
+            session.send_message_to_device(Message::Pong(msg)).await;
+          }
+          Message::Text(msg) => {
+            session::consume_error(session.handle_websocket_frame_from_device(msg).await);
+          }
+          // Any other frame type is only logged; there is no `break` here.
+          _ => {
+            error!("Client sent invalid message type");
+          }
         }
       },
       else => {
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -155,17 +155,9 @@
 
   pub async fn handle_websocket_frame_from_device(
     &self,
-    msg: Message,
+    msg: String,
   ) -> Result<(), SessionError> {
-    let text_msg = match msg {
-      Message::Text(payload) => payload,
-      _ => {
-        error!("Client sent invalid message type");
-        return Err(SessionError::InvalidMessage);
-      }
-    };
-
-    let serialized_message = serde_json::from_str::<Messages>(&text_msg)?;
+    let serialized_message = serde_json::from_str::<Messages>(&msg)?;
 
     match serialized_message {
       Messages::MessageToDevice(message_to_device) => {
@@ -220,7 +212,9 @@
 
     for message in messages {
       let device_message = DeviceMessage::from_hashmap(message)?;
-      self.send_message_to_device(device_message.payload).await;
+      self
+        .send_message_to_device(Message::Text(device_message.payload))
+        .await;
       if let Err(e) = self
         .db_client
         .delete_message(&self.device_info.device_id, &device_message.message_id)
@@ -238,8 +232,10 @@
     Ok(())
   }
 
-  pub async fn send_message_to_device(&mut self, incoming_payload: String) {
-    if let Err(e) = self.tx.send(Message::Text(incoming_payload)).await {
+  /// Forwards `message` to the device through `self.tx`; a failed send is
+  /// logged and otherwise ignored.
+  pub async fn send_message_to_device(&mut self, message: Message) {
+    if let Err(e) = self.tx.send(message).await {
       error!("Failed to send message to device: {}", e);
     }
   }