diff --git a/services/tunnelbroker/src/cxx_bridge.rs b/services/tunnelbroker/src/cxx_bridge.rs
--- a/services/tunnelbroker/src/cxx_bridge.rs
+++ b/services/tunnelbroker/src/cxx_bridge.rs
@@ -21,6 +21,15 @@
     deviceOS: String,
     isOnline: bool,
   }
+
+  /// A message record read from the database, shared between Rust and C++.
+  struct MessageItem {
+    messageID: String,
+    fromDeviceID: String,
+    toDeviceID: String,
+    payload: String,
+    blobHashes: String,
+  }
 
   unsafe extern "C++" {
     include!("tunnelbroker/src/libcpp/Tunnelbroker.h");
@@ -44,5 +53,14 @@
       sessionID: &str,
       newNotifToken: &str,
     ) -> Result<()>;
+    /// Returns the messages stored in the database for receiver `deviceID`.
+    pub fn getMessagesFromDatabase(deviceID: &str) -> Result<Vec<MessageItem>>;
+    /// Erases the AMQP (DeliveryBroker) messages kept for `deviceID`.
+    pub fn eraseMessagesFromAMQP(deviceID: &str) -> Result<()>;
+    /// Removes the messages with the given IDs for `deviceID` from the database.
+    pub fn removeMessagesFromDatabase(
+      deviceID: &str,
+      messageIDs: Vec<String>,
+    ) -> Result<()>;
   }
 }
diff --git a/services/tunnelbroker/src/libcpp/Tunnelbroker.h b/services/tunnelbroker/src/libcpp/Tunnelbroker.h
--- a/services/tunnelbroker/src/libcpp/Tunnelbroker.h
+++ b/services/tunnelbroker/src/libcpp/Tunnelbroker.h
@@ -16,3 +16,8 @@
 SessionItem getSessionItem(rust::Str sessionID);
 void updateSessionItemIsOnline(rust::Str sessionID, bool isOnline);
 void updateSessionItemDeviceToken(rust::Str sessionID, rust::Str newNotifToken);
+rust::Vec<MessageItem> getMessagesFromDatabase(rust::Str deviceID);
+void eraseMessagesFromAMQP(rust::Str deviceID);
+void removeMessagesFromDatabase(
+    rust::Str deviceID,
+    rust::Vec<rust::String> messageIDs);
diff --git a/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp b/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp
--- a/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp
+++ b/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp
@@ -4,6 +4,7 @@
 #include "ConfigManager.h"
 #include "CryptoTools.h"
 #include "DatabaseManager.h"
+#include "DeliveryBroker.h"
 #include "GlobalTools.h"
 #include "Tools.h"
 
@@ -169,3 +170,44 @@
       .updateSessionItemDeviceToken(
           std::string{sessionID}, std::string{newNotifToken});
 }
+
+// Returns every message stored in the database for the receiver `deviceID`,
+// converted to the cxx-shared MessageItem struct.
+rust::Vec<MessageItem> getMessagesFromDatabase(rust::Str deviceID) {
+  std::vector<std::shared_ptr<comm::network::database::MessageItem>>
+      messagesFromDatabase =
+          comm::network::database::DatabaseManager::getInstance()
+              .findMessageItemsByReceiver(std::string{deviceID});
+  rust::Vec<MessageItem> result;
+  result.reserve(messagesFromDatabase.size());
+  for (const auto &messageFromDatabase : messagesFromDatabase) {
+    result.push_back(MessageItem{
+        .messageID = messageFromDatabase->getMessageID(),
+        .fromDeviceID = messageFromDatabase->getFromDeviceID(),
+        .toDeviceID = messageFromDatabase->getToDeviceID(),
+        .payload = messageFromDatabase->getPayload(),
+        .blobHashes = messageFromDatabase->getBlobHashes(),
+    });
+  }
+  return result;
+}
+
+// Erases the messages the DeliveryBroker (AMQP) keeps for `deviceID`.
+void eraseMessagesFromAMQP(rust::Str deviceID) {
+  comm::network::DeliveryBroker::getInstance().erase(std::string{deviceID});
+}
+
+// Removes the messages with the given IDs for `deviceID` from the
+// database.
+void removeMessagesFromDatabase(
+    rust::Str deviceID,
+    rust::Vec<rust::String> messageIDs) {
+  std::vector<std::string> stdMessagesIDs;
+  stdMessagesIDs.reserve(messageIDs.size());
+  for (const auto &messageID : messageIDs) {
+    stdMessagesIDs.push_back(std::string{messageID});
+  }
+  comm::network::database::DatabaseManager::getInstance()
+      .removeMessageItemsByIDsForDeviceID(
+          stdMessagesIDs, std::string{deviceID});
+}
diff --git a/services/tunnelbroker/src/server/mod.rs b/services/tunnelbroker/src/server/mod.rs
--- a/services/tunnelbroker/src/server/mod.rs
+++ b/services/tunnelbroker/src/server/mod.rs
@@ -1,6 +1,7 @@
 use super::constants;
 use super::cxx_bridge::ffi::{
-  getSessionItem, newSessionHandler, sessionSignatureHandler,
+  eraseMessagesFromAMQP, getMessagesFromDatabase, getSessionItem,
+  newSessionHandler, removeMessagesFromDatabase, sessionSignatureHandler,
   updateSessionItemDeviceToken, updateSessionItemIsOnline,
 };
 use futures::Stream;
@@ -10,6 +11,9 @@
 use tokio_stream::{wrappers::ReceiverStream, StreamExt};
 use tonic::{transport::Server, Request, Response, Status, Streaming};
 use tracing::debug;
+use tunnelbroker::message_to_client::Data::{
+  MessagesToDeliver, NewNotifyTokenRequired, Ping,
+};
 use tunnelbroker::message_to_tunnelbroker::Data::{
   MessagesToSend, NewNotifyToken, ProcessedMessages,
 };
@@ -136,9 +140,7 @@
         &session_id,
         &tx,
         Ok(tunnelbroker::MessageToClient {
-          data: Some(
-            tunnelbroker::message_to_client::Data::NewNotifyTokenRequired(()),
-          ),
+          data: Some(NewNotifyTokenRequired(())),
         }),
       );
       if let Err(err) = result.await {
@@ -149,6 +151,60 @@
       };
     }
 
+    // When a client connects to the bidirectional messages stream, first we check
+    // if there are undelivered messages in the database
+    let messages_from_database =
+      match getMessagesFromDatabase(&session_item.deviceID) {
+        Ok(messages) => messages,
+        Err(err) => return Err(Status::internal(err.what())),
+      };
+    if !messages_from_database.is_empty() {
+      if let Err(err) = eraseMessagesFromAMQP(&session_item.deviceID) {
+        return Err(Status::internal(err.what()));
+      };
+      // The IDs are collected up front because the messages themselves are
+      // moved into the response below.
+      let message_ids: Vec<String> = messages_from_database
+        .iter()
+        .map(|message| message.messageID.clone())
+        .collect();
+      let messages_to_response: Vec<tunnelbroker::MessageToClientStruct> =
+        messages_from_database
+          .into_iter()
+          .map(|message| tunnelbroker::MessageToClientStruct {
+            message_id: message.messageID,
+            from_device_id: message.fromDeviceID,
+            payload: message.payload,
+            // NOTE(review): the database keeps blob hashes in one string;
+            // confirm it should not be split into separate entries here.
+            blob_hashes: vec![message.blobHashes],
+          })
+          .collect();
+      let result_from_writer = tx_writer(
+        &session_id,
+        &tx,
+        Ok(tunnelbroker::MessageToClient {
+          data: Some(MessagesToDeliver(tunnelbroker::MessagesToDeliver {
+            messages: messages_to_response,
+          })),
+        }),
+      );
+      if let Err(err) = result_from_writer.await {
+        eprintln!(
+          "Error while sending undelivered messages from database to the client: {}",
+          err
+        );
+        return Err(Status::aborted(err));
+      };
+      // Remove the messages only after the write succeeded: a failed write
+      // returns early above and leaves them in the database.
+      if let Err(err) =
+        removeMessagesFromDatabase(&session_item.deviceID, message_ids)
+      {
+        return Err(Status::internal(err.what()));
+      };
+    }
+
     // Spawning asynchronous Tokio task with the client pinging loop inside to
     // make sure that the client is online
     tokio::spawn({
@@ -161,7 +217,7 @@
             &session_id,
             &tx,
             Ok(tunnelbroker::MessageToClient {
-              data: Some(tunnelbroker::message_to_client::Data::Ping(())),
+              data: Some(Ping(())),
             }),
           );
           if let Err(_) = result.await {