diff --git a/services/tunnelbroker/src/cxx_bridge.rs b/services/tunnelbroker/src/cxx_bridge.rs
--- a/services/tunnelbroker/src/cxx_bridge.rs
+++ b/services/tunnelbroker/src/cxx_bridge.rs
@@ -40,6 +40,14 @@
     deviceOS: String,
     isOnline: bool,
   }
+  // A message record passed from the C++ database layer to Rust.
+  struct MessageItem {
+    messageID: String,
+    fromDeviceID: String,
+    toDeviceID: String,
+    payload: String,
+    blobHashes: String,
+  }
 
   unsafe extern "C++" {
     include!("tunnelbroker/src/libcpp/Tunnelbroker.h");
@@ -65,5 +73,7 @@
       sessionID: &str,
       newNotifToken: &str,
     ) -> Result<()>;
+    pub fn getMessagesFromDatabase(deviceID: &str) -> Result<Vec<MessageItem>>;
+    pub fn eraseMessagesFromAMQP(deviceID: &str) -> Result<()>;
   }
 }
diff --git a/services/tunnelbroker/src/libcpp/Tunnelbroker.h b/services/tunnelbroker/src/libcpp/Tunnelbroker.h
--- a/services/tunnelbroker/src/libcpp/Tunnelbroker.h
+++ b/services/tunnelbroker/src/libcpp/Tunnelbroker.h
@@ -18,3 +18,5 @@
 SessionItem getSessionItem(rust::Str sessionID);
 void updateSessionItemIsOnline(rust::Str sessionID, bool isOnline);
 void updateSessionItemDeviceToken(rust::Str sessionID, rust::Str newNotifToken);
+rust::Vec<MessageItem> getMessagesFromDatabase(rust::Str deviceID);
+void eraseMessagesFromAMQP(rust::Str deviceID);
diff --git a/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp b/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp
--- a/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp
+++ b/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp
@@ -4,6 +4,7 @@
 #include "ConfigManager.h"
 #include "CryptoTools.h"
 #include "DatabaseManager.h"
+#include "DeliveryBroker.h"
 #include "GlobalTools.h"
 #include "Tools.h"
 
@@ -183,3 +184,28 @@
       .updateSessionItemDeviceToken(
           std::string{sessionID}, std::string{newNotifToken});
 }
+
+// Loads the messages whose receiver is `deviceID` from the database and
+// converts each one into the cxx-shared MessageItem struct for the Rust side.
+rust::Vec<MessageItem> getMessagesFromDatabase(rust::Str deviceID) {
+  auto messagesFromDatabase =
+      comm::network::database::DatabaseManager::getInstance()
+          .findMessageItemsByReceiver(std::string{deviceID});
+  rust::Vec<MessageItem> result;
+  result.reserve(messagesFromDatabase.size());
+  for (auto &messageFromDatabase : messagesFromDatabase) {
+    result.push_back(MessageItem{
+        .messageID = messageFromDatabase->getMessageID(),
+        .fromDeviceID = messageFromDatabase->getFromDeviceID(),
+        .toDeviceID = messageFromDatabase->getToDeviceID(),
+        .payload = messageFromDatabase->getPayload(),
+        .blobHashes = messageFromDatabase->getBlobHashes(),
+    });
+  }
+  return result;
+}
+
+// Calls DeliveryBroker::erase for `deviceID` to drop its AMQP-side messages.
+void eraseMessagesFromAMQP(rust::Str deviceID) {
+  comm::network::DeliveryBroker::getInstance().erase(std::string{deviceID});
+}
diff --git a/services/tunnelbroker/src/server/mod.rs b/services/tunnelbroker/src/server/mod.rs
--- a/services/tunnelbroker/src/server/mod.rs
+++ b/services/tunnelbroker/src/server/mod.rs
@@ -1,7 +1,8 @@
 use super::constants;
 use super::cxx_bridge::ffi::{
-  getSessionItem, newSessionHandler, sessionSignatureHandler,
-  updateSessionItemDeviceToken, updateSessionItemIsOnline, GRPCStatusCodes,
+  eraseMessagesFromAMQP, getMessagesFromDatabase, getSessionItem,
+  newSessionHandler, sessionSignatureHandler, updateSessionItemDeviceToken,
+  updateSessionItemIsOnline, GRPCStatusCodes,
 };
 use anyhow::Result;
 use futures::Stream;
@@ -151,6 +152,48 @@
       };
     }
 
+    // When a client connects to the bidirectional messages stream, first we check
+    // if there are undelivered messages in the database
+    let messages_from_database =
+      match getMessagesFromDatabase(&session_item.deviceID) {
+        Ok(messages) => messages,
+        Err(err) => return Err(Status::internal(err.what())),
+      };
+    if !messages_from_database.is_empty() {
+      // These messages are delivered from the database below, so their AMQP
+      // copies are erased first (presumably to avoid duplicates - confirm).
+      if let Err(err) = eraseMessagesFromAMQP(&session_item.deviceID) {
+        return Err(Status::internal(err.what()));
+      };
+      let messages_to_response: Vec<_> = messages_from_database
+        .into_iter()
+        .map(|message| tunnelbroker::MessageToClientStruct {
+          message_id: message.messageID,
+          from_device_id: message.fromDeviceID,
+          payload: message.payload,
+          blob_hashes: vec![message.blobHashes],
+        })
+        .collect();
+      let result_from_writer = tx_writer(
+        &session_id,
+        &tx,
+        Ok(tunnelbroker::MessageToClient {
+          data: Some(tunnelbroker::message_to_client::Data::MessagesToDeliver(
+            tunnelbroker::MessagesToDeliver {
+              messages: messages_to_response,
+            },
+          )),
+        }),
+      );
+      if let Err(err) = result_from_writer.await {
+        debug!(
+          "Error while sending undelivered messages from database to the client: {}",
+          err
+        );
+        return Err(Status::aborted(err));
+      };
+    }
+
     // Spawning asynchronous Tokio task with the client pinging loop inside to
     // make sure that the client is online
     tokio::spawn({