diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs
--- a/services/commtest/src/tunnelbroker/socket.rs
+++ b/services/commtest/src/tunnelbroker/socket.rs
@@ -83,5 +83,14 @@
   let message = response.to_text().expect("Failed to get response content");
   let message_to_device = serde_json::from_str::<MessageToDevice>(message)
     .expect("Failed to parse MessageToDevice from response");
+
+  // Acknowledge receipt so tunnelbroker deletes the stored message.
+  let confirmation = tunnelbroker_messages::MessageReceiveConfirmation {
+    message_ids: vec![message_to_device.message_id],
+  };
+  let serialized_confirmation = serde_json::to_string(&confirmation)
+    .expect("Failed to serialize MessageReceiveConfirmation");
+  socket.send(Message::Text(serialized_confirmation)).await?;
+
   Ok(message_to_device.payload)
 }
diff --git a/services/commtest/tests/identity_tunnelbroker_tests.rs b/services/commtest/tests/identity_tunnelbroker_tests.rs
--- a/services/commtest/tests/identity_tunnelbroker_tests.rs
+++ b/services/commtest/tests/identity_tunnelbroker_tests.rs
@@ -92,13 +92,16 @@
   let device_info = create_device(None).await;
   let mut socket = create_socket(&device_info).await;
 
-  let response = receive_message(&mut socket).await.unwrap();
-  let serialized_response: RefreshKeyRequest =
-    serde_json::from_str(&response).unwrap();
+  for _ in 0..2 {
+    let response = receive_message(&mut socket).await.unwrap();
+    let serialized_response: RefreshKeyRequest =
+      serde_json::from_str(&response).unwrap();
 
-  let expected_response = RefreshKeyRequest {
-    device_id: device_info.device_id.to_string(),
-    number_of_keys: 5,
-  };
-  assert_eq!(serialized_response, expected_response);
+    let expected_response = RefreshKeyRequest {
+      device_id: device_info.device_id.to_string(),
+      number_of_keys: 5,
+    };
+
+    assert_eq!(serialized_response, expected_response);
+  }
 }
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -195,7 +195,9 @@
             session.send_message_to_device(Message::Pong(msg)).await;
           }
           Message::Text(msg) => {
-            let message_status = session.handle_websocket_frame_from_device(msg).await;
+            let Some(message_status) = session.handle_websocket_frame_from_device(msg).await else {
+              continue;
+            };
             let request_status = MessageToDeviceRequestStatus {
               client_message_ids: vec![message_status]
             };
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -136,13 +136,6 @@
         BasicProperties::default().with_priority(DDB_RMQ_MSG_PRIORITY),
       )
       .await?;
-
-    if let Err(e) = db_client
-      .delete_message(&device_info.device_id, &message_to_device.message_id)
-      .await
-    {
-      error!("Failed to delete message: {}:", e);
-    }
   }
 
   debug!("Flushed messages for device: {}", &device_info.device_id);
@@ -242,24 +235,38 @@
   pub async fn handle_websocket_frame_from_device(
     &mut self,
     msg: String,
-  ) -> MessageSentStatus {
+  ) -> Option<MessageSentStatus> {
     let Ok(serialized_message) = serde_json::from_str::<Messages>(&msg) else {
-      return MessageSentStatus::SerializationError(msg);
+      return Some(MessageSentStatus::SerializationError(msg));
     };
 
     match serialized_message {
+      Messages::MessageReceiveConfirmation(confirmation) => {
+        // Remove acknowledged messages from the DB; no status reply is sent.
+        for message_id in confirmation.message_ids {
+          if let Err(e) = self
+            .db_client
+            .delete_message(&self.device_info.device_id, &message_id)
+            .await
+          {
+            error!("Failed to delete message {}: {}", message_id, e);
+          }
+        }
+
+        None
+      }
       Messages::MessageToDeviceRequest(message_request) => {
         debug!("Received message for {}", message_request.device_id);
 
         let result = self.handle_message_to_device(&message_request).await;
-        self.get_message_to_device_status(
+        Some(self.get_message_to_device_status(
           &message_request.client_message_id,
           result,
-        )
+        ))
       }
       _ => {
         error!("Client sent invalid message type");
-        MessageSentStatus::InvalidRequest
+        Some(MessageSentStatus::InvalidRequest)
       }
     }
   }