diff --git a/services/tunnelbroker/src/cxx_bridge.rs b/services/tunnelbroker/src/cxx_bridge.rs index 894506fa7..9b9db01b6 100644 --- a/services/tunnelbroker/src/cxx_bridge.rs +++ b/services/tunnelbroker/src/cxx_bridge.rs @@ -1,86 +1,87 @@ #[cxx::bridge] pub mod ffi { enum GRPCStatusCodes { Ok, Cancelled, Unknown, InvalidArgument, DeadlineExceeded, NotFound, AlreadyExists, PermissionDenied, ResourceExhausted, FailedPrecondition, Aborted, OutOfRange, Unimplemented, Internal, Unavailable, DataLoss, Unauthenticated, } struct GrpcResult { statusCode: GRPCStatusCodes, errorText: String, } struct SessionSignatureResult { toSign: String, grpcStatus: GrpcResult, } struct NewSessionResult { sessionID: String, grpcStatus: GrpcResult, } struct SessionItem { deviceID: String, publicKey: String, notifyToken: String, deviceType: i32, appVersion: String, deviceOS: String, isOnline: bool, } struct MessageItem { messageID: String, fromDeviceID: String, toDeviceID: String, payload: String, blobHashes: String, deliveryTag: u64, } unsafe extern "C++" { include!("tunnelbroker/src/libcpp/Tunnelbroker.h"); pub fn initialize(); pub fn getConfigParameter(parameter: &str) -> Result; pub fn isSandbox() -> Result; pub fn sessionSignatureHandler(deviceID: &str) -> SessionSignatureResult; pub fn newSessionHandler( deviceID: &str, publicKey: &str, signature: &str, deviceType: i32, deviceAppVersion: &str, deviceOS: &str, notifyToken: &str, ) -> NewSessionResult; pub fn getSessionItem(sessionID: &str) -> Result; pub fn updateSessionItemIsOnline( sessionID: &str, isOnline: bool, ) -> Result<()>; pub fn updateSessionItemDeviceToken( sessionID: &str, newNotifToken: &str, ) -> Result<()>; pub fn getMessagesFromDatabase(deviceID: &str) -> Result>; + pub fn sendMessages(messages: &Vec) -> Result>; pub fn eraseMessagesFromAMQP(deviceID: &str) -> Result<()>; pub fn ackMessageFromAMQP(deliveryTag: u64) -> Result<()>; pub fn waitMessageFromDeliveryBroker(deviceID: &str) -> Result; pub fn removeMessages( deviceID: &str, messagesIDs: &Vec, ) -> Result<()>; } } diff --git a/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp b/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp index 4132b5317..e299e3abe 100644 --- a/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp +++ b/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp @@ -1,241 +1,264 @@ #include "Tunnelbroker.h" #include "AmqpManager.h" #include "AwsTools.h" #include "ConfigManager.h" #include "CryptoTools.h" #include "DatabaseManager.h" #include "DeliveryBroker.h" #include "GlobalTools.h" #include "Tools.h" #include "rust/cxx.h" #include "tunnelbroker/src/cxx_bridge.rs.h" #include void initialize() { comm::network::tools::InitLogging("tunnelbroker"); comm::network::config::ConfigManager::getInstance().load(); Aws::InitAPI({}); // List of AWS DynamoDB tables to check if they are created and can be // accessed before any AWS API methods const std::list tablesList = { comm::network::config::ConfigManager::getInstance().getParameter( comm::network::config::ConfigManager::OPTION_DYNAMODB_SESSIONS_TABLE), comm::network::config::ConfigManager::getInstance().getParameter( comm::network::config::ConfigManager:: OPTION_DYNAMODB_SESSIONS_VERIFICATION_TABLE), comm::network::config::ConfigManager::getInstance().getParameter( comm::network::config::ConfigManager:: OPTION_DYNAMODB_SESSIONS_PUBLIC_KEY_TABLE), comm::network::config::ConfigManager::getInstance().getParameter( comm::network::config::ConfigManager:: OPTION_DYNAMODB_MESSAGES_TABLE)}; for (const std::string &table : tablesList) { if (!comm::network::database::DatabaseManager::getInstance() .isTableAvailable(table)) { throw std::runtime_error( "Error: AWS DynamoDB table '" + table + "' is not available"); } }; comm::network::AmqpManager::getInstance().init(); } rust::String getConfigParameter(rust::Str parameter) { return rust::String{ comm::network::config::ConfigManager::getInstance().getParameter( std::string{parameter})}; } bool isSandbox() { return comm::network::tools::isSandbox(); } SessionSignatureResult sessionSignatureHandler(rust::Str deviceID) { const std::string requestedDeviceID(deviceID); if (!comm::network::tools::validateDeviceID(requestedDeviceID)) { return SessionSignatureResult{ .grpcStatus = { .statusCode = GRPCStatusCodes::InvalidArgument, .errorText = "Format validation failed for deviceID: " + requestedDeviceID}}; } const std::string toSign = comm::network::tools::generateRandomString( comm::network::SIGNATURE_REQUEST_LENGTH); std::shared_ptr SessionSignItem = std::make_shared( toSign, requestedDeviceID); comm::network::database::DatabaseManager::getInstance().putSessionSignItem( *SessionSignItem); return SessionSignatureResult{ .toSign = toSign, .grpcStatus = {.statusCode = GRPCStatusCodes::Ok}}; } NewSessionResult newSessionHandler( rust::Str deviceID, rust::Str publicKey, rust::Str signature, int32_t deviceType, rust::Str deviceAppVersion, rust::Str deviceOS, rust::Str notifyToken) { std::shared_ptr deviceSessionItem; std::shared_ptr sessionSignItem; std::shared_ptr publicKeyItem; const std::string stringDeviceID{deviceID}; if (!comm::network::tools::validateDeviceID(stringDeviceID)) { return NewSessionResult{ .grpcStatus = { .statusCode = GRPCStatusCodes::InvalidArgument, .errorText = "Format validation failed for deviceID"}}; } const std::string stringPublicKey{publicKey}; const std::string newSessionID = comm::network::tools::generateUUID(); try { sessionSignItem = comm::network::database::DatabaseManager::getInstance() .findSessionSignItem(stringDeviceID); if (sessionSignItem == nullptr) { return NewSessionResult{ .grpcStatus = { .statusCode = GRPCStatusCodes::NotFound, .errorText = "Session signature request not found for deviceID"}}; } publicKeyItem = comm::network::database::DatabaseManager::getInstance() .findPublicKeyItem(stringDeviceID); if (publicKeyItem == nullptr) { std::shared_ptr newPublicKeyItem = std::make_shared( stringDeviceID, stringPublicKey); comm::network::database::DatabaseManager::getInstance().putPublicKeyItem( *newPublicKeyItem); } else if (stringPublicKey != publicKeyItem->getPublicKey()) { return NewSessionResult{ .grpcStatus = { .statusCode = GRPCStatusCodes::PermissionDenied, .errorText = "The public key doesn't match for deviceID"}}; } const std::string verificationMessage = sessionSignItem->getSign(); if (!comm::network::crypto::rsaVerifyString( stringPublicKey, verificationMessage, std::string{signature})) { return NewSessionResult{ .grpcStatus = { .statusCode = GRPCStatusCodes::PermissionDenied, .errorText = "Signature for the verification message is not valid"}}; } comm::network::database::DatabaseManager::getInstance() .removeSessionSignItem(stringDeviceID); deviceSessionItem = std::make_shared( newSessionID, stringDeviceID, stringPublicKey, std::string{notifyToken}, deviceType, std::string{deviceAppVersion}, std::string{deviceOS}); comm::network::database::DatabaseManager::getInstance().putSessionItem( *deviceSessionItem); } catch (std::runtime_error &e) { LOG(ERROR) << "gRPC: " << "Error while processing 'NewSession' request: " << e.what(); return NewSessionResult{ .grpcStatus = { .statusCode = GRPCStatusCodes::Internal, .errorText = e.what()}}; } return NewSessionResult{ .sessionID = newSessionID, .grpcStatus = {.statusCode = GRPCStatusCodes::Ok}}; } SessionItem getSessionItem(rust::Str sessionID) { const std::string stringSessionID = std::string{sessionID}; if (!comm::network::tools::validateSessionID(stringSessionID)) { throw std::invalid_argument("Invalid format for 'sessionID'"); } std::shared_ptr sessionItem = comm::network::database::DatabaseManager::getInstance().findSessionItem( stringSessionID); if (sessionItem == nullptr) { throw std::invalid_argument( "No sessions found for 'sessionID': " + stringSessionID); } return SessionItem{ .deviceID = sessionItem->getDeviceID(), .publicKey = sessionItem->getPubKey(), .notifyToken = sessionItem->getNotifyToken(), .deviceType = static_cast(sessionItem->getDeviceType()), .appVersion = sessionItem->getAppVersion(), .deviceOS = sessionItem->getDeviceOs(), .isOnline = sessionItem->getIsOnline()}; } void updateSessionItemIsOnline(rust::Str sessionID, bool isOnline) { comm::network::database::DatabaseManager::getInstance() .updateSessionItemIsOnline(std::string{sessionID}, isOnline); } void updateSessionItemDeviceToken( rust::Str sessionID, rust::Str newNotifToken) { comm::network::database::DatabaseManager::getInstance() .updateSessionItemDeviceToken( std::string{sessionID}, std::string{newNotifToken}); } rust::Vec getMessagesFromDatabase(rust::Str deviceID) { std::vector> messagesFromDatabase = comm::network::database::DatabaseManager::getInstance() .findMessageItemsByReceiver(std::string{deviceID}); rust::Vec result; for (auto &messageFromDatabase : messagesFromDatabase) { result.push_back(MessageItem{ .messageID = messageFromDatabase->getMessageID(), .fromDeviceID = messageFromDatabase->getFromDeviceID(), .payload = messageFromDatabase->getPayload(), .blobHashes = messageFromDatabase->getBlobHashes(), }); } return result; } void eraseMessagesFromAMQP(rust::Str deviceID) { comm::network::DeliveryBroker::getInstance().erase(std::string{deviceID}); } void ackMessageFromAMQP(uint64_t deliveryTag) { comm::network::AmqpManager::getInstance().ack(deliveryTag); } MessageItem waitMessageFromDeliveryBroker(rust::Str deviceID) { const auto message = comm::network::DeliveryBroker::getInstance().pop(std::string{deviceID}); return MessageItem{ .messageID = message.messageID, .fromDeviceID = message.fromDeviceID, .payload = message.payload, .deliveryTag = message.deliveryTag}; } void removeMessages( rust::Str deviceID, const rust::Vec &messagesIDs) { std::vector vectorOfmessagesIDs; std::string stringDeviceID = std::string{deviceID}; for (auto id : messagesIDs) { vectorOfmessagesIDs.push_back(std::string{id}); }; comm::network::database::DatabaseManager::getInstance() .removeMessageItemsByIDsForDeviceID(vectorOfmessagesIDs, stringDeviceID); // If messages queue for `deviceID` is empty we don't need to store // `folly::MPMCQueue` for it and need to free memory to fix possible // 'ghost' queues in DeliveryBroker. // We call `deleteQueueIfEmpty()` for this purpose here after removing // messages. comm::network::DeliveryBroker::DeliveryBroker::getInstance() .deleteQueueIfEmpty(stringDeviceID); } + +rust::Vec sendMessages(const rust::Vec &messages) { + std::vector vectorOfMessages; + rust::Vec messagesIDs; + for (auto &message : messages) { + std::string messageID = comm::network::tools::generateUUID(); + vectorOfMessages.push_back(comm::network::database::MessageItem{ + comm::network::database::MessageItem{ + messageID, + std::string{message.fromDeviceID}, + std::string{message.toDeviceID}, + std::string{message.payload}, + std::string{message.blobHashes}, + }}); + messagesIDs.push_back(rust::String{messageID}); + }; + comm::network::database::DatabaseManager::getInstance() + .putMessageItemsByBatch(vectorOfMessages); + for (auto message : vectorOfMessages) { + comm::network::AmqpManager::getInstance().send(&message); + } + return messagesIDs; +} diff --git a/services/tunnelbroker/src/libcpp/Tunnelbroker.h b/services/tunnelbroker/src/libcpp/Tunnelbroker.h index 345842d6c..951c234f1 100644 --- a/services/tunnelbroker/src/libcpp/Tunnelbroker.h +++ b/services/tunnelbroker/src/libcpp/Tunnelbroker.h @@ -1,27 +1,28 @@ #pragma once #include "rust/cxx.h" #include "tunnelbroker/src/cxx_bridge.rs.h" void initialize(); rust::String getConfigParameter(rust::Str parameter); bool isSandbox(); SessionSignatureResult sessionSignatureHandler(rust::Str deviceID); NewSessionResult newSessionHandler( rust::Str deviceID, rust::Str publicKey, rust::Str signature, int32_t deviceType, rust::Str deviceAppVersion, rust::Str deviceOS, rust::Str notifyToken); SessionItem getSessionItem(rust::Str sessionID); void updateSessionItemIsOnline(rust::Str sessionID, bool isOnline); void updateSessionItemDeviceToken(rust::Str sessionID, rust::Str newNotifToken); rust::Vec getMessagesFromDatabase(rust::Str deviceID); +rust::Vec sendMessages(const rust::Vec &messages); void eraseMessagesFromAMQP(rust::Str deviceID); void ackMessageFromAMQP(uint64_t deliveryTag); MessageItem waitMessageFromDeliveryBroker(rust::Str deviceID); void removeMessages( rust::Str deviceID, const rust::Vec &messagesIDs); diff --git a/services/tunnelbroker/src/libcpp/src/Database/DatabaseManager.cpp b/services/tunnelbroker/src/libcpp/src/Database/DatabaseManager.cpp index fd00352e0..87c8f84ee 100644 --- a/services/tunnelbroker/src/libcpp/src/Database/DatabaseManager.cpp +++ b/services/tunnelbroker/src/libcpp/src/Database/DatabaseManager.cpp @@ -1,352 +1,352 @@ #include "DatabaseManager.h" #include "DynamoDBTools.h" #include "GlobalTools.h" #include namespace comm { namespace network { namespace database { DatabaseManager &DatabaseManager::getInstance() { static DatabaseManager instance; return instance; } bool DatabaseManager::isTableAvailable(const std::string &tableName) { Aws::DynamoDB::Model::DescribeTableRequest request; request.SetTableName(tableName); // Check table availability by invoking DescribeTable const Aws::DynamoDB::Model::DescribeTableOutcome &result = getDynamoDBClient()->DescribeTable(request); return result.IsSuccess(); } void DatabaseManager::putSessionItem(const DeviceSessionItem &item) { Aws::DynamoDB::Model::PutItemRequest request; request.SetTableName(item.getTableName()); request.AddItem( DeviceSessionItem::FIELD_SESSION_ID, Aws::DynamoDB::Model::AttributeValue(item.getSessionID())); request.AddItem( DeviceSessionItem::FIELD_DEVICE_ID, Aws::DynamoDB::Model::AttributeValue(item.getDeviceID())); request.AddItem( DeviceSessionItem::FIELD_PUBKEY, Aws::DynamoDB::Model::AttributeValue(item.getPubKey())); request.AddItem( DeviceSessionItem::FIELD_NOTIFY_TOKEN, Aws::DynamoDB::Model::AttributeValue(item.getNotifyToken())); request.AddItem( DeviceSessionItem::FIELD_DEVICE_TYPE, Aws::DynamoDB::Model::AttributeValue( std::to_string(item.getDeviceType()))); request.AddItem( DeviceSessionItem::FIELD_APP_VERSION, Aws::DynamoDB::Model::AttributeValue(item.getAppVersion())); request.AddItem( DeviceSessionItem::FIELD_DEVICE_OS, Aws::DynamoDB::Model::AttributeValue(item.getDeviceOs())); request.AddItem( DeviceSessionItem::FIELD_CHECKPOINT_TIME, Aws::DynamoDB::Model::AttributeValue( std::to_string(item.getCheckpointTime()))); request.AddItem( DeviceSessionItem::FIELD_EXPIRE, Aws::DynamoDB::Model::AttributeValue(std::to_string( static_cast(std::time(0)) + SESSION_RECORD_TTL))); request.AddItem( DeviceSessionItem::FIELD_IS_ONLINE, Aws::DynamoDB::Model::AttributeValue().SetBool(false)); this->innerPutItem(std::make_shared(item), request); } std::shared_ptr DatabaseManager::findSessionItem(const std::string &sessionID) { Aws::DynamoDB::Model::GetItemRequest request; request.AddKey( DeviceSessionItem::FIELD_SESSION_ID, Aws::DynamoDB::Model::AttributeValue(sessionID)); return this->innerFindItem(request); } void DatabaseManager::removeSessionItem(const std::string &sessionID) { std::shared_ptr item = this->findSessionItem(sessionID); if (item == nullptr) { return; } this->innerRemoveItem(*item); } void DatabaseManager::updateSessionItemIsOnline( const std::string &sessionID, bool isOnline) { std::shared_ptr item = this->findSessionItem(sessionID); if (item == nullptr) { LOG(ERROR) << "Can't find for update sessionItem for sessionID: " << sessionID; return; } Aws::DynamoDB::Model::UpdateItemRequest request; request.SetTableName(item->getTableName()); Aws::DynamoDB::Model::AttributeValue attributeKeyValue; attributeKeyValue.SetS(sessionID); request.AddKey(DeviceSessionItem::FIELD_SESSION_ID, attributeKeyValue); Aws::String update_expression("SET #a = :valueA"); request.SetUpdateExpression(update_expression); Aws::Map expressionAttributeNames; expressionAttributeNames["#a"] = DeviceSessionItem::FIELD_IS_ONLINE; request.SetExpressionAttributeNames(expressionAttributeNames); Aws::DynamoDB::Model::AttributeValue attributeUpdatedValue; attributeUpdatedValue.SetBool(isOnline); Aws::Map expressionAttributeValue; expressionAttributeValue[":valueA"] = attributeUpdatedValue; request.SetExpressionAttributeValues(expressionAttributeValue); const Aws::DynamoDB::Model::UpdateItemOutcome &result = getDynamoDBClient()->UpdateItem(request); if (!result.IsSuccess()) { LOG(ERROR) << "Error updating device online status at " "`updateSessionItemIsOnline`: " << result.GetError().GetMessage(); } } bool DatabaseManager::updateSessionItemDeviceToken( const std::string &sessionID, const std::string &newDeviceToken) { std::shared_ptr item = this->findSessionItem(sessionID); if (item == nullptr) { LOG(ERROR) << "Can't find for update sessionItem for sessionID: " << sessionID; return false; } Aws::DynamoDB::Model::UpdateItemRequest request; request.SetTableName(item->getTableName()); Aws::DynamoDB::Model::AttributeValue attributeKeyValue; attributeKeyValue.SetS(sessionID); request.AddKey(DeviceSessionItem::FIELD_SESSION_ID, attributeKeyValue); Aws::String update_expression("SET #a = :valueA"); request.SetUpdateExpression(update_expression); Aws::Map expressionAttributeNames; expressionAttributeNames["#a"] = DeviceSessionItem::FIELD_NOTIFY_TOKEN; request.SetExpressionAttributeNames(expressionAttributeNames); Aws::DynamoDB::Model::AttributeValue attributeUpdatedValue; attributeUpdatedValue.SetS(newDeviceToken); Aws::Map expressionAttributeValue; expressionAttributeValue[":valueA"] = attributeUpdatedValue; request.SetExpressionAttributeValues(expressionAttributeValue); const Aws::DynamoDB::Model::UpdateItemOutcome &result = getDynamoDBClient()->UpdateItem(request); if (!result.IsSuccess()) { LOG(ERROR) << "Error updating device token at updateSessionItemDeviceToken: " << result.GetError().GetMessage(); return false; } return true; } void DatabaseManager::putSessionSignItem(const SessionSignItem &item) { Aws::DynamoDB::Model::PutItemRequest request; request.SetTableName(item.getTableName()); request.AddItem( SessionSignItem::FIELD_SESSION_VERIFICATION, Aws::DynamoDB::Model::AttributeValue(item.getSign())); request.AddItem( SessionSignItem::FIELD_DEVICE_ID, Aws::DynamoDB::Model::AttributeValue(item.getDeviceID())); request.AddItem( SessionSignItem::FIELD_EXPIRE, Aws::DynamoDB::Model::AttributeValue(std::to_string( static_cast(std::time(0)) + SESSION_SIGN_RECORD_TTL))); this->innerPutItem(std::make_shared(item), request); } std::shared_ptr DatabaseManager::findSessionSignItem(const std::string &deviceID) { Aws::DynamoDB::Model::GetItemRequest request; request.AddKey( SessionSignItem::FIELD_DEVICE_ID, Aws::DynamoDB::Model::AttributeValue(deviceID)); return this->innerFindItem(request); } void DatabaseManager::removeSessionSignItem(const std::string &deviceID) { std::shared_ptr item = this->findSessionSignItem(deviceID); if (item == nullptr) { return; } this->innerRemoveItem(*item); } void DatabaseManager::putPublicKeyItem(const PublicKeyItem &item) { Aws::DynamoDB::Model::PutItemRequest request; request.SetTableName(item.getTableName()); request.AddItem( PublicKeyItem::FIELD_DEVICE_ID, Aws::DynamoDB::Model::AttributeValue(item.getDeviceID())); request.AddItem( PublicKeyItem::FIELD_PUBLIC_KEY, Aws::DynamoDB::Model::AttributeValue(item.getPublicKey())); this->innerPutItem(std::make_shared(item), request); } std::shared_ptr DatabaseManager::findPublicKeyItem(const std::string &deviceID) { Aws::DynamoDB::Model::GetItemRequest request; request.AddKey( PublicKeyItem::FIELD_DEVICE_ID, Aws::DynamoDB::Model::AttributeValue(deviceID)); return this->innerFindItem(request); } void DatabaseManager::removePublicKeyItem(const std::string &deviceID) { std::shared_ptr item = this->findPublicKeyItem(deviceID); if (item == nullptr) { return; } this->innerRemoveItem(*item); } template T DatabaseManager::populatePutRequestFromMessageItem( T &putRequest, const MessageItem &item) { putRequest.AddItem( MessageItem::FIELD_MESSAGE_ID, Aws::DynamoDB::Model::AttributeValue(item.getMessageID())); putRequest.AddItem( MessageItem::FIELD_FROM_DEVICE_ID, Aws::DynamoDB::Model::AttributeValue(item.getFromDeviceID())); putRequest.AddItem( MessageItem::FIELD_TO_DEVICE_ID, Aws::DynamoDB::Model::AttributeValue(item.getToDeviceID())); putRequest.AddItem( MessageItem::FIELD_PAYLOAD, Aws::DynamoDB::Model::AttributeValue(item.getPayload())); putRequest.AddItem( MessageItem::FIELD_BLOB_HASHES, Aws::DynamoDB::Model::AttributeValue(item.getBlobHashes())); putRequest.AddItem( MessageItem::FIELD_EXPIRE, Aws::DynamoDB::Model::AttributeValue(std::to_string( static_cast(std::time(0) + MESSAGE_RECORD_TTL)))); putRequest.AddItem( MessageItem::FIELD_CREATED_AT, Aws::DynamoDB::Model::AttributeValue( std::to_string(tools::getCurrentTimestamp()))); return putRequest; } void DatabaseManager::putMessageItem(const MessageItem &item) { Aws::DynamoDB::Model::PutItemRequest request; request = this->populatePutRequestFromMessageItem(request, item); request.SetTableName(item.getTableName()); this->innerPutItem(std::make_shared(item), request); } void DatabaseManager::putMessageItemsByBatch( - std::vector &messageItems) { + const std::vector &messageItems) { std::vector writeRequests; - for (MessageItem &messageItem : messageItems) { + for (MessageItem messageItem : messageItems) { Aws::DynamoDB::Model::PutRequest putRequest; putRequest = this->populatePutRequestFromMessageItem(putRequest, messageItem); Aws::DynamoDB::Model::WriteRequest writeRequest; writeRequest.SetPutRequest(putRequest); writeRequests.push_back(writeRequest); } this->innerBatchWriteItem( messageItems[0].getTableName(), DYNAMODB_MAX_BATCH_ITEMS, DYNAMODB_BACKOFF_FIRST_RETRY_DELAY, DYNAMODB_MAX_BACKOFF_TIME, writeRequests); } std::shared_ptr DatabaseManager::findMessageItem( const std::string &toDeviceID, const std::string &messageID) { Aws::DynamoDB::Model::GetItemRequest request; request.AddKey( MessageItem::FIELD_TO_DEVICE_ID, Aws::DynamoDB::Model::AttributeValue(toDeviceID)); request.AddKey( MessageItem::FIELD_MESSAGE_ID, Aws::DynamoDB::Model::AttributeValue(messageID)); return this->innerFindItem(request); } std::vector> DatabaseManager::findMessageItemsByReceiver(const std::string &toDeviceID) { std::vector> result; Aws::DynamoDB::Model::QueryRequest req; req.SetTableName(MessageItem().getTableName()); req.SetKeyConditionExpression( MessageItem::FIELD_TO_DEVICE_ID + " = :valueToMatch"); AttributeValues attributeValues; attributeValues.emplace(":valueToMatch", toDeviceID); req.SetExpressionAttributeValues(attributeValues); const Aws::DynamoDB::Model::QueryOutcome &outcome = getDynamoDBClient()->Query(req); if (!outcome.IsSuccess()) { throw std::runtime_error(outcome.GetError().GetMessage()); } const Aws::Vector &items = outcome.GetResult().GetItems(); for (auto &item : items) { result.push_back(std::make_shared(item)); } return result; } void DatabaseManager::removeMessageItem( const std::string &toDeviceID, const std::string &messageID) { std::shared_ptr item = this->findMessageItem(toDeviceID, messageID); if (item == nullptr) { return; } this->innerRemoveItem(*item); } void DatabaseManager::removeMessageItemsByIDsForDeviceID( std::vector &messageIDs, const std::string &toDeviceID) { std::vector writeRequests; for (std::string &messageID : messageIDs) { Aws::DynamoDB::Model::DeleteRequest deleteRequest; deleteRequest.AddKey( MessageItem::FIELD_TO_DEVICE_ID, Aws::DynamoDB::Model::AttributeValue(toDeviceID)); deleteRequest.AddKey( MessageItem::FIELD_MESSAGE_ID, Aws::DynamoDB::Model::AttributeValue(messageID)); Aws::DynamoDB::Model::WriteRequest currentWriteRequest; currentWriteRequest.SetDeleteRequest(deleteRequest); writeRequests.push_back(currentWriteRequest); } this->innerBatchWriteItem( MessageItem().getTableName(), DYNAMODB_MAX_BATCH_ITEMS, DYNAMODB_BACKOFF_FIRST_RETRY_DELAY, DYNAMODB_MAX_BACKOFF_TIME, writeRequests); } } // namespace database } // namespace network } // namespace comm diff --git a/services/tunnelbroker/src/libcpp/src/Database/DatabaseManager.h b/services/tunnelbroker/src/libcpp/src/Database/DatabaseManager.h index b6bda4d58..c750ef124 100644 --- a/services/tunnelbroker/src/libcpp/src/Database/DatabaseManager.h +++ b/services/tunnelbroker/src/libcpp/src/Database/DatabaseManager.h @@ -1,75 +1,75 @@ #pragma once #include "AwsTools.h" #include "Constants.h" #include "DatabaseEntitiesTools.h" #include "DatabaseManagerBase.h" #include "DeviceSessionItem.h" #include "MessageItem.h" #include "PublicKeyItem.h" #include "SessionSignItem.h" #include "Tools.h" #include #include #include #include #include #include #include #include #include #include #include #include #include namespace comm { namespace network { namespace database { class DatabaseManager : public DatabaseManagerBase { private: template T populatePutRequestFromMessageItem(T &putRequest, const MessageItem &item); public: static DatabaseManager &getInstance(); bool isTableAvailable(const std::string &tableName); void putSessionItem(const DeviceSessionItem &item); std::shared_ptr findSessionItem(const std::string &deviceID); void removeSessionItem(const std::string &sessionID); void updateSessionItemIsOnline(const std::string &sessionID, bool isOnline); bool updateSessionItemDeviceToken( const std::string &sessionID, const std::string &newDeviceToken); void putSessionSignItem(const SessionSignItem &item); std::shared_ptr findSessionSignItem(const std::string &deviceID); void removeSessionSignItem(const std::string &deviceID); void putPublicKeyItem(const PublicKeyItem &item); std::shared_ptr findPublicKeyItem(const std::string &deviceID); void removePublicKeyItem(const std::string &deviceID); void putMessageItem(const MessageItem &item); - void putMessageItemsByBatch(std::vector &messageItems); + void putMessageItemsByBatch(const std::vector &messageItems); std::shared_ptr findMessageItem(const std::string &toDeviceID, const std::string &messageID); std::vector> findMessageItemsByReceiver(const std::string &toDeviceID); void removeMessageItem( const std::string &toDeviceID, const std::string &messageID); void removeMessageItemsByIDsForDeviceID( std::vector &messageIDs, const std::string &toDeviceID); }; } // namespace database } // namespace network } // namespace comm diff --git a/services/tunnelbroker/src/server/mod.rs b/services/tunnelbroker/src/server/mod.rs index e290d4619..ea6db3664 100644 --- a/services/tunnelbroker/src/server/mod.rs +++ b/services/tunnelbroker/src/server/mod.rs @@ -1,382 +1,422 @@ +use crate::cxx_bridge::ffi::MessageItem; + use super::constants; use super::cxx_bridge::ffi::{ ackMessageFromAMQP, eraseMessagesFromAMQP, getMessagesFromDatabase, - getSessionItem, newSessionHandler, removeMessages, sessionSignatureHandler, - updateSessionItemDeviceToken, updateSessionItemIsOnline, - waitMessageFromDeliveryBroker, GRPCStatusCodes, + getSessionItem, newSessionHandler, removeMessages, sendMessages, + sessionSignatureHandler, updateSessionItemDeviceToken, + updateSessionItemIsOnline, waitMessageFromDeliveryBroker, GRPCStatusCodes, }; use anyhow::Result; use futures::Stream; use std::pin::Pin; use tokio::sync::mpsc; use tokio::time::{sleep, Duration}; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; use tonic::{transport::Server, Request, Response, Status, Streaming}; use tracing::{debug, error}; use tunnelbroker::message_to_tunnelbroker::Data::{ MessagesToSend, NewNotifyToken, ProcessedMessages, }; use tunnelbroker::tunnelbroker_service_server::{ TunnelbrokerService, TunnelbrokerServiceServer, }; mod tools; mod tunnelbroker { tonic::include_proto!("tunnelbroker"); } #[derive(Debug, Default)] struct TunnelbrokerServiceHandlers {} #[tonic::async_trait] impl TunnelbrokerService for TunnelbrokerServiceHandlers { async fn session_signature( &self, request: Request, ) -> Result, Status> { let result = sessionSignatureHandler(&request.into_inner().device_id); if result.grpcStatus.statusCode != GRPCStatusCodes::Ok { return Err(tools::create_tonic_status( result.grpcStatus.statusCode, &result.grpcStatus.errorText, )); } Ok(Response::new(tunnelbroker::SessionSignatureResponse { to_sign: result.toSign, })) } async fn new_session( &self, request: Request, ) -> Result, Status> { let inner_request = request.into_inner(); let notify_token = inner_request.notify_token.unwrap_or(String::new()); if !tunnelbroker::new_session_request::DeviceTypes::is_valid( inner_request.device_type, ) { return Err(tools::create_tonic_status( GRPCStatusCodes::InvalidArgument, "Unsupported device type", )); }; let result = newSessionHandler( &inner_request.device_id, &inner_request.public_key, &inner_request.signature, inner_request.device_type, &inner_request.device_app_version, &inner_request.device_os, ¬ify_token, ); if result.grpcStatus.statusCode != GRPCStatusCodes::Ok { return Err(tools::create_tonic_status( result.grpcStatus.statusCode, &result.grpcStatus.errorText, )); } Ok(Response::new(tunnelbroker::NewSessionResponse { session_id: result.sessionID, })) } type MessagesStreamStream = Pin< Box< dyn Stream> + Send, >, >; async fn messages_stream( &self, request: Request>, ) -> Result, Status> { let session_id = match request.metadata().get("sessionID") { Some(metadata_session_id) => metadata_session_id .to_str() .expect("metadata session id was not valid UTF8") .to_string(), None => { return Err(Status::invalid_argument( "No 'sessionID' in metadata was provided", )) } }; let session_item = match getSessionItem(&session_id) { Ok(database_item) => database_item, Err(err) => return Err(Status::unauthenticated(err.what())), }; let (tx, rx) = mpsc::channel(constants::GRPC_TX_QUEUE_SIZE); // Through this function, we will write to the output stream from different Tokio // tasks and update the device's online status if the write was unsuccessful async fn tx_writer( session_id: &str, channel: &tokio::sync::mpsc::Sender, payload: T, ) -> Result<(), String> { let result = channel.send(payload).await; match result { Ok(result) => Ok(result), Err(err) => { if let Err(err) = updateSessionItemIsOnline(&session_id, false) { return Err(err.what().to_string()); } return Err(err.to_string()); } } } if let Err(err) = updateSessionItemIsOnline(&session_id, true) { return Err(Status::internal(err.what())); } // Checking for an empty notif token and requesting the new one from the client if session_item.notifyToken.is_empty() && session_item.deviceType == tunnelbroker::new_session_request::DeviceTypes::Mobile as i32 { let result = tx_writer( &session_id, &tx, Ok(tunnelbroker::MessageToClient { data: Some( tunnelbroker::message_to_client::Data::NewNotifyTokenRequired(()), ), }), ); if let Err(err) = result.await { debug!( "Error while sending notification token request to the client: {}", err ); }; } // When a client connects to the bidirectional messages stream, first we check // if there are undelivered messages in the database let messages_from_database = match getMessagesFromDatabase(&session_item.deviceID) { Ok(messages) => messages, Err(err) => return Err(Status::internal(err.what())), }; if messages_from_database.len() > 0 { if let Err(err) = eraseMessagesFromAMQP(&session_item.deviceID) { return Err(Status::internal(err.what())); }; let mut messages_to_response = vec![]; for message in &messages_from_database { messages_to_response.push(tunnelbroker::MessageToClientStruct { message_id: message.messageID.clone(), from_device_id: message.fromDeviceID.clone(), payload: message.payload.clone(), blob_hashes: vec![message.blobHashes.clone()], }); } let result_from_writer = tx_writer( &session_id, &tx, Ok(tunnelbroker::MessageToClient { data: Some(tunnelbroker::message_to_client::Data::MessagesToDeliver( tunnelbroker::MessagesToDeliver { messages: messages_to_response, }, )), }), ); if let Err(err) = result_from_writer.await { debug!( "Error while sending undelivered messages from database to the client: {}", err ); return Err(Status::aborted(err)); }; } // Spawning asynchronous Tokio task with the client pinging loop inside to // make sure that the client is online tokio::spawn({ let session_id = session_id.clone(); let tx = tx.clone(); async move { loop { sleep(Duration::from_millis(constants::GRPC_PING_INTERVAL_MS)).await; let result = tx_writer( &session_id, &tx, Ok(tunnelbroker::MessageToClient { data: Some(tunnelbroker::message_to_client::Data::Ping(())), }), ); if let Err(err) = result.await { debug!("Failed to write ping to a channel: {}", err); break; }; } } }); // Spawning asynchronous Tokio task to deliver new messages // to the client from delivery broker tokio::spawn({ let device_id = session_item.deviceID.clone(); let session_id = session_id.clone(); let tx = tx.clone(); async move { loop { let message_to_deliver = match waitMessageFromDeliveryBroker(&device_id) { Ok(message_item) => message_item, Err(err) => { error!( "Error on waiting messages from DeliveryBroker: {}", err.what() ); return; } }; let writer_result = tx_writer( &session_id, &tx, Ok(tunnelbroker::MessageToClient { data: Some( tunnelbroker::message_to_client::Data::MessagesToDeliver( tunnelbroker::MessagesToDeliver { messages: vec![tunnelbroker::MessageToClientStruct { message_id: message_to_deliver.messageID, from_device_id: message_to_deliver.fromDeviceID, payload: message_to_deliver.payload, blob_hashes: vec![message_to_deliver.blobHashes], }], }, ), ), }), ); if let Err(err) = writer_result.await { debug!("Error on writing to the stream: {}", err); return; }; if let Err(err) = ackMessageFromAMQP(message_to_deliver.deliveryTag) { debug!("Error on message acknowledgement in AMQP queue: {}", err); return; }; } } }); let mut input_stream = request.into_inner(); // Spawning asynchronous Tokio task for handling incoming messages from the client tokio::spawn(async move { while let Some(result) = input_stream.next().await { if let Err(err) = result { debug!("Error in input stream: {}", err); break; } if let Some(message_data) = result.unwrap().data { match message_data { NewNotifyToken(new_token) => { if let Err(err) = updateSessionItemDeviceToken(&session_id, &new_token) { error!( "Error in updating the device notification token in the database: {}", err.what() ); let writer_result = tx_writer( &session_id, &tx, Err( Status::internal( "Error in updating the device notification token in the database" ) ), ); if let Err(err) = writer_result.await { debug!( "Failed to write internal error to a channel: {}", err ); }; } } - MessagesToSend(_) => (), + MessagesToSend(messages_to_send) => { + let mut messages_vec = vec![]; + for message in messages_to_send.messages { + messages_vec.push(MessageItem { + messageID: String::new(), + fromDeviceID: session_item.deviceID.clone(), + toDeviceID: message.to_device_id, + payload: message.payload, + blobHashes: String::new(), + deliveryTag: 0, + }); + } + let messages_ids = match sendMessages(&messages_vec) { + Err(err) => { + error!("Error on sending messages: {}", err.what()); + return; + } + Ok(ids) => ids, + }; + if let Err(err) = tx_writer( + &session_id, + &tx, + Ok(tunnelbroker::MessageToClient { + data: Some( + tunnelbroker::message_to_client::Data::ProcessedMessages( + tunnelbroker::ProcessedMessages { + message_id: messages_ids, + }, + ), + ), + }), + ) + .await + { + debug!( + "Error on sending back processed messages IDs to the stream: {}", + err); + }; + } ProcessedMessages(processed_messages) => { if let Err(err) = removeMessages( &session_item.deviceID, &processed_messages.message_id, ) { error!( "Error removing messages from the database: {}", err.what() ); }; } } } } if let Err(err) = updateSessionItemIsOnline(&session_id, false) { error!( "Error in updating the session online state in the database: {}", err.what() ); } }); let output_stream = ReceiverStream::new(rx); Ok(Response::new( Box::pin(output_stream) as Self::MessagesStreamStream )) } // These empty old API handlers are deprecated and should be removed. // They are implemented only to fix the building process. async fn check_if_primary_device_online( &self, _request: Request, ) -> Result, Status> { Err(Status::cancelled("Deprecated")) } async fn become_new_primary_device( &self, _request: Request, ) -> Result, Status> { Err(Status::cancelled("Deprecated")) } async fn send_pong( &self, _request: Request, ) -> Result, Status> { Err(Status::cancelled("Deprecated")) } async fn send( &self, _request: Request, ) -> Result, Status> { Err(Status::cancelled("Deprecated")) } type GetStream = Pin< Box> + Send>, >; async fn get( &self, _request: Request, ) -> Result, Status> { Err(Status::cancelled("Deprecated")) } } pub async fn run_grpc_server() -> Result<()> { let addr = format!("[::1]:{}", constants::GRPC_SERVER_PORT).parse()?; Server::builder() .add_service(TunnelbrokerServiceServer::new( TunnelbrokerServiceHandlers::default(), )) .serve(addr) .await?; Ok(()) }