diff --git a/services/tunnelbroker/src/Amqp/AmqpManager.cpp b/services/tunnelbroker/src/Amqp/AmqpManager.cpp index db78d4c18..ebfb2c6ee 100644 --- a/services/tunnelbroker/src/Amqp/AmqpManager.cpp +++ b/services/tunnelbroker/src/Amqp/AmqpManager.cpp @@ -1,144 +1,147 @@ #include "AmqpManager.h" #include "ConfigManager.h" #include "Constants.h" #include "DeliveryBroker.h" #include "Tools.h" #include namespace comm { namespace network { AmqpManager &AmqpManager::getInstance() { static AmqpManager instance; return instance; } void AmqpManager::connectInternal() { const std::string amqpUri = config::ConfigManager::getInstance().getParameter( config::ConfigManager::OPTION_AMQP_URI); const std::string tunnelbrokerID = config::ConfigManager::getInstance().getParameter( config::ConfigManager::OPTION_TUNNELBROKER_ID); const std::string fanoutExchangeName = config::ConfigManager::getInstance().getParameter( config::ConfigManager::OPTION_AMQP_FANOUT_EXCHANGE); std::cout << "AMQP: Connecting to " << amqpUri << std::endl; auto *loop = uv_default_loop(); AMQP::LibUvHandler handler(loop); AMQP::TcpConnection connection(&handler, AMQP::Address(amqpUri)); this->amqpChannel = std::make_unique(&connection); this->amqpChannel->onError([this](const char *message) { std::cout << "AMQP: channel error: " << message << ", will try to reconnect" << std::endl; this->amqpReady = false; }); AMQP::Table arguments; arguments["x-message-ttl"] = AMQP_MESSAGE_TTL; arguments["x-expires"] = AMQP_QUEUE_TTL; this->amqpChannel->declareExchange(fanoutExchangeName, AMQP::fanout); this->amqpChannel->declareQueue(tunnelbrokerID, AMQP::durable, arguments) .onSuccess([this, tunnelbrokerID, fanoutExchangeName]( const std::string &name, uint32_t messagecount, uint32_t consumercount) { std::cout << "AMQP: Queue " << name << " created" << std::endl; this->amqpChannel->bindQueue(fanoutExchangeName, tunnelbrokerID, "") .onError([this, tunnelbrokerID, fanoutExchangeName]( const char *message) { std::cout << "AMQP: Failed to bind queue: " << tunnelbrokerID << " to exchange: " << fanoutExchangeName << std::endl; this->amqpReady = false; }); this->amqpReady = true; this->amqpChannel->consume(tunnelbrokerID) .onReceived([](const AMQP::Message &message, uint64_t deliveryTag, bool redelivered) { try { AMQP::Table headers = message.headers(); const std::string payload(message.body()); + const std::string messageID(headers[AMQP_HEADER_MESSAGEID]); const std::string toDeviceID(headers[AMQP_HEADER_TO_DEVICEID]); const std::string fromDeviceID( headers[AMQP_HEADER_FROM_DEVICEID]); std::cout << "AMQP: Message consumed for deviceID: " << toDeviceID << std::endl; DeliveryBroker::getInstance().push( - deliveryTag, toDeviceID, fromDeviceID, payload); + messageID, deliveryTag, toDeviceID, fromDeviceID, payload); } catch (const std::exception &e) { std::cout << "AMQP: Message parsing exception: " << e.what() << std::endl; } }) .onError([](const char *message) { std::cout << "AMQP: Error on message consume: " << message << std::endl; }); }) .onError([](const char *message) { throw std::runtime_error( "AMQP: Queue creation error: " + std::string(message)); }); uv_run(loop, UV_RUN_DEFAULT); }; void AmqpManager::connect() { while (true) { int64_t currentTimestamp = tools::getCurrentTimestamp(); if (this->lastConnectionTimestamp && currentTimestamp - this->lastConnectionTimestamp < AMQP_SHORTEST_RECONNECTION_ATTEMPT_INTERVAL) { throw std::runtime_error( "AMQP reconnection attempt interval too short, tried to reconnect " "after " + std::to_string(currentTimestamp - this->lastConnectionTimestamp) + "ms, the shortest allowed interval is " + std::to_string(AMQP_SHORTEST_RECONNECTION_ATTEMPT_INTERVAL) + "ms"); } this->lastConnectionTimestamp = currentTimestamp; this->connectInternal(); } } bool AmqpManager::send( - std::string toDeviceID, + std::string messageID, std::string fromDeviceID, + std::string toDeviceID, std::string payload) { if (!this->amqpReady) { std::cout << "AMQP: Message send error: channel not ready" << std::endl; return false; } try { AMQP::Envelope env(payload.c_str(), payload.size()); AMQP::Table headers; + headers[AMQP_HEADER_MESSAGEID] = messageID; headers[AMQP_HEADER_FROM_DEVICEID] = fromDeviceID; headers[AMQP_HEADER_TO_DEVICEID] = toDeviceID; // Set delivery mode to: Durable (2) env.setDeliveryMode(2); env.setHeaders(std::move(headers)); this->amqpChannel->publish( config::ConfigManager::getInstance().getParameter( config::ConfigManager::OPTION_AMQP_FANOUT_EXCHANGE), "", env); } catch (std::runtime_error &e) { std::cout << "AMQP: Error while publishing message: " << e.what() << std::endl; return false; } return true; }; void AmqpManager::ack(uint64_t deliveryTag) { if (!this->amqpReady) { std::cout << "AMQP: Message ACK error: channel not ready" << std::endl; return; } this->amqpChannel->ack(deliveryTag); } } // namespace network } // namespace comm diff --git a/services/tunnelbroker/src/Amqp/AmqpManager.h b/services/tunnelbroker/src/Amqp/AmqpManager.h index 875a0f73d..8885ea3f3 100644 --- a/services/tunnelbroker/src/Amqp/AmqpManager.h +++ b/services/tunnelbroker/src/Amqp/AmqpManager.h @@ -1,33 +1,36 @@ #pragma once #include #include #include #include #include namespace comm { namespace network { class AmqpManager { AmqpManager(){}; std::unique_ptr amqpChannel; std::atomic amqpReady; std::atomic lastConnectionTimestamp; void connectInternal(); public: static AmqpManager &getInstance(); void connect(); - bool - send(std::string toDeviceID, std::string fromDeviceID, std::string payload); + bool send( + std::string messageID, + std::string fromDeviceID, + std::string toDeviceID, + std::string payload); void ack(uint64_t deliveryTag); AmqpManager(AmqpManager const &) = delete; void operator=(AmqpManager const &) = delete; }; } // namespace network } // namespace comm diff --git a/services/tunnelbroker/src/Constants.h b/services/tunnelbroker/src/Constants.h index f39ce1168..92e14ba04 100644 --- a/services/tunnelbroker/src/Constants.h +++ b/services/tunnelbroker/src/Constants.h @@ -1,52 +1,56 @@ #pragma once #include #include #include namespace comm { namespace network { // AWS DynamoDB const std::string DEVICE_SESSIONS_TABLE_NAME = "tunnelbroker-device-session"; const std::string DEVICE_SESSIONS_VERIFICATION_MESSAGES_TABLE_NAME = "tunnelbroker-verification-message"; const std::string DEVICE_PUBLIC_KEY_TABLE_NAME = "tunnelbroker-public-key"; const std::string MESSAGES_TABLE_NAME = "tunnelbroker-message"; // Sessions const size_t SIGNATURE_REQUEST_LENGTH = 64; const size_t SESSION_ID_LENGTH = 64; const size_t SESSION_RECORD_TTL = 30 * 24 * 3600; // 30 days const size_t SESSION_SIGN_RECORD_TTL = 24 * 3600; // 24 hours const std::regex SESSION_ID_FORMAT_REGEX( "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"); // gRPC Server const std::string SERVER_LISTEN_ADDRESS = "0.0.0.0:50051"; // AMQP (RabbitMQ) const std::string AMQP_FANOUT_EXCHANGE_NAME = "allBrokers"; // message TTL const size_t AMQP_MESSAGE_TTL = 300 * 1000; // 5 min // queue TTL in case of no consumers (tunnelbroker is down) const size_t AMQP_QUEUE_TTL = 24 * 3600 * 1000; // 24 hours // routing message headers name -const std::string AMQP_HEADER_FROM_DEVICEID = "fromDeviceid"; -const std::string AMQP_HEADER_TO_DEVICEID = "toDeviceid"; +const std::string AMQP_HEADER_FROM_DEVICEID = "fromDeviceID"; +const std::string AMQP_HEADER_TO_DEVICEID = "toDeviceID"; +const std::string AMQP_HEADER_MESSAGEID = "messageID"; const int64_t AMQP_SHORTEST_RECONNECTION_ATTEMPT_INTERVAL = 1000 * 60; // 1 min // DeviceID const size_t DEVICEID_CHAR_LENGTH = 64; const std::regex DEVICEID_FORMAT_REGEX( "^(ks|mobile|web):[a-zA-Z0-9]{" + std::to_string(DEVICEID_CHAR_LENGTH) + "}$"); // Config const std::string CONFIG_FILE_PATH = std::string(std::getenv("HOME")) + "/tunnelbroker/tunnelbroker.ini"; +// DeliveryBroker +const size_t DELIVERY_BROKER_MAX_QUEUE_SIZE = 100; + } // namespace network } // namespace comm diff --git a/services/tunnelbroker/src/DeliveryBroker/DeliveryBroker.cpp b/services/tunnelbroker/src/DeliveryBroker/DeliveryBroker.cpp index d690b04d7..46e217db4 100644 --- a/services/tunnelbroker/src/DeliveryBroker/DeliveryBroker.cpp +++ b/services/tunnelbroker/src/DeliveryBroker/DeliveryBroker.cpp @@ -1,67 +1,68 @@ #include "DeliveryBroker.h" namespace comm { namespace network { DeliveryBroker &DeliveryBroker::getInstance() { static DeliveryBroker instance; return instance; }; void DeliveryBroker::push( + const std::string messageID, const uint64_t deliveryTag, const std::string toDeviceID, const std::string fromDeviceID, const std::string payload) { try { - std::unique_lock localLock(this->localMutex); - std::vector messagesList; - const DeliveryBrokerMessage newMessage = { - .deliveryTag = deliveryTag, - .fromDeviceID = fromDeviceID, - .payload = payload}; - if (this->messagesMap.find(toDeviceID) == this->messagesMap.end()) { - messagesList.push_back(newMessage); - this->messagesMap.insert({toDeviceID, messagesList}); - this->localCv.notify_all(); - return; + this->messagesMap.insert( + toDeviceID, + std::make_unique( + DELIVERY_BROKER_MAX_QUEUE_SIZE)); } - - messagesList = this->messagesMap[toDeviceID]; - messagesList.push_back(newMessage); - this->messagesMap.assign(toDeviceID, messagesList); - this->localCv.notify_all(); + this->messagesMap.find(toDeviceID) + ->second->blockingWrite(DeliveryBrokerMessage{ + .messageID = messageID, + .deliveryTag = deliveryTag, + .fromDeviceID = fromDeviceID, + .payload = payload}); } catch (const std::exception &e) { - std::cout << "DeliveryBroker: " + std::cout << "DeliveryBroker push: " << "Got an exception " << e.what() << std::endl; - this->localCv.notify_all(); } }; -std::vector -DeliveryBroker::get(const std::string deviceID) { +bool DeliveryBroker::isEmpty(const std::string deviceID) { if (this->messagesMap.find(deviceID) == this->messagesMap.end()) { - return {}; - } - return this->messagesMap[deviceID]; -}; - -bool DeliveryBroker::isEmpty(const std::string key) { - if (this->messagesMap.empty()) { return true; - } - return (this->messagesMap.find(key) == this->messagesMap.end()); + }; + return this->messagesMap.find(deviceID)->second->isEmpty(); }; -void DeliveryBroker::remove(const std::string key) { - this->messagesMap.erase(key); +DeliveryBrokerMessage DeliveryBroker::pop(const std::string deviceID) { + try { + // If we don't already have a queue, insert it for the blocking read purpose + // in case we listen first before the insert happens. + if (this->messagesMap.find(deviceID) == this->messagesMap.end()) { + this->messagesMap.insert( + deviceID, + std::make_unique( + DELIVERY_BROKER_MAX_QUEUE_SIZE)); + } + DeliveryBrokerMessage receievedMessage; + this->messagesMap.find(deviceID)->second->blockingRead(receievedMessage); + return receievedMessage; + } catch (const std::exception &e) { + std::cout << "DeliveryBroker pop: " + << "Got an exception " << e.what() << std::endl; + } + return {}; }; -void DeliveryBroker::wait(const std::string key) { - std::unique_lock localLock(this->localMutex); - this->localCv.wait(localLock, [this, &key] { return !this->isEmpty(key); }); +void DeliveryBroker::erase(const std::string deviceID) { + this->messagesMap.erase(deviceID); }; } // namespace network } // namespace comm diff --git a/services/tunnelbroker/src/DeliveryBroker/DeliveryBroker.h b/services/tunnelbroker/src/DeliveryBroker/DeliveryBroker.h index 0bcde4868..64ba59756 100644 --- a/services/tunnelbroker/src/DeliveryBroker/DeliveryBroker.h +++ b/services/tunnelbroker/src/DeliveryBroker/DeliveryBroker.h @@ -1,37 +1,33 @@ #pragma once #include "Constants.h" #include "DeliveryBrokerEntites.h" #include -#include #include #include -#include namespace comm { namespace network { class DeliveryBroker { - folly::ConcurrentHashMap> + folly::ConcurrentHashMap> messagesMap; - std::mutex localMutex; - std::condition_variable localCv; public: static DeliveryBroker &getInstance(); void push( + const std::string messageID, const uint64_t deliveryTag, const std::string toDeviceID, const std::string fromDeviceID, const std::string payload); - std::vector get(const std::string deviceID); - bool isEmpty(const std::string key); - void remove(const std::string key); - void wait(const std::string key); + bool isEmpty(const std::string deviceID); + DeliveryBrokerMessage pop(const std::string deviceID); + void erase(const std::string deviceID); }; } // namespace network } // namespace comm diff --git a/services/tunnelbroker/src/DeliveryBroker/DeliveryBrokerEntites.h b/services/tunnelbroker/src/DeliveryBroker/DeliveryBrokerEntites.h index 94f223032..bbdd5efeb 100644 --- a/services/tunnelbroker/src/DeliveryBroker/DeliveryBrokerEntites.h +++ b/services/tunnelbroker/src/DeliveryBroker/DeliveryBrokerEntites.h @@ -1,17 +1,22 @@ #pragma once +#include + #include #include namespace comm { namespace network { struct DeliveryBrokerMessage { + std::string messageID; uint64_t deliveryTag; std::string fromDeviceID; std::string payload; std::vector blobHashes; }; +typedef folly::MPMCQueue DeliveryBrokerQueue; + } // namespace network } // namespace comm diff --git a/services/tunnelbroker/src/Service/TunnelbrokerServiceImpl.cpp b/services/tunnelbroker/src/Service/TunnelbrokerServiceImpl.cpp index 849e3607b..57ec7273a 100644 --- a/services/tunnelbroker/src/Service/TunnelbrokerServiceImpl.cpp +++ b/services/tunnelbroker/src/Service/TunnelbrokerServiceImpl.cpp @@ -1,226 +1,227 @@ #include "TunnelbrokerServiceImpl.h" #include "AmqpManager.h" #include "AwsTools.h" #include "ConfigManager.h" #include "CryptoTools.h" #include "DatabaseManager.h" #include "DeliveryBroker.h" #include "Tools.h" + namespace comm { namespace network { TunnelBrokerServiceImpl::TunnelBrokerServiceImpl() { Aws::InitAPI({}); // List of AWS DynamoDB tables to check if they are created and can be // accessed before any AWS API methods const std::list tablesList = { config::ConfigManager::getInstance().getParameter( config::ConfigManager::OPTION_DYNAMODB_SESSIONS_TABLE), config::ConfigManager::getInstance().getParameter( config::ConfigManager::OPTION_DYNAMODB_SESSIONS_VERIFICATION_TABLE), config::ConfigManager::getInstance().getParameter( config::ConfigManager::OPTION_DYNAMODB_SESSIONS_PUBLIC_KEY_TABLE), config::ConfigManager::getInstance().getParameter( config::ConfigManager::OPTION_DYNAMODB_MESSAGES_TABLE)}; for (const std::string &table : tablesList) { if (!database::DatabaseManager::getInstance().isTableAvailable(table)) { throw std::runtime_error( "Error: AWS DynamoDB table '" + table + "' is not available"); } }; }; TunnelBrokerServiceImpl::~TunnelBrokerServiceImpl() { Aws::ShutdownAPI({}); }; grpc::Status TunnelBrokerServiceImpl::SessionSignature( grpc::ServerContext *context, const tunnelbroker::SessionSignatureRequest *request, tunnelbroker::SessionSignatureResponse *reply) { const std::string deviceID = request->deviceid(); if (!tools::validateDeviceID(deviceID)) { std::cout << "gRPC: " << "Format validation failed for " << deviceID << std::endl; return grpc::Status( grpc::StatusCode::INVALID_ARGUMENT, "Format validation failed for deviceID"); } const std::string toSign = tools::generateRandomString(SIGNATURE_REQUEST_LENGTH); std::shared_ptr SessionSignItem = std::make_shared(toSign, deviceID); database::DatabaseManager::getInstance().putSessionSignItem(*SessionSignItem); reply->set_tosign(toSign); return grpc::Status::OK; }; grpc::Status TunnelBrokerServiceImpl::NewSession( grpc::ServerContext *context, const tunnelbroker::NewSessionRequest *request, tunnelbroker::NewSessionResponse *reply) { std::shared_ptr deviceSessionItem; std::shared_ptr sessionSignItem; std::shared_ptr publicKeyItem; const std::string deviceID = request->deviceid(); if (!tools::validateDeviceID(deviceID)) { std::cout << "gRPC: " << "Format validation failed for " << deviceID << std::endl; return grpc::Status( grpc::StatusCode::INVALID_ARGUMENT, "Format validation failed for deviceID"); } const std::string signature = request->signature(); const std::string publicKey = request->publickey(); const std::string newSessionID = tools::generateUUID(); try { sessionSignItem = database::DatabaseManager::getInstance().findSessionSignItem(deviceID); if (sessionSignItem == nullptr) { std::cout << "gRPC: " << "Session sign request not found for deviceID: " << deviceID << std::endl; return grpc::Status( grpc::StatusCode::NOT_FOUND, "Session sign request not found"); } publicKeyItem = database::DatabaseManager::getInstance().findPublicKeyItem(deviceID); if (publicKeyItem == nullptr) { std::shared_ptr newPublicKeyItem = std::make_shared(deviceID, publicKey); database::DatabaseManager::getInstance().putPublicKeyItem( *newPublicKeyItem); } else if (publicKey != publicKeyItem->getPublicKey()) { std::cout << "gRPC: " << "The public key doesn't match for deviceID" << std::endl; return grpc::Status( grpc::StatusCode::PERMISSION_DENIED, "The public key doesn't match for deviceID"); } const std::string verificationMessage = sessionSignItem->getSign(); if (!comm::network::crypto::rsaVerifyString( publicKey, verificationMessage, signature)) { std::cout << "gRPC: " << "Signature for the verification message is not valid" << std::endl; return grpc::Status( grpc::StatusCode::PERMISSION_DENIED, "Signature for the verification message is not valid"); } database::DatabaseManager::getInstance().removeSessionSignItem(deviceID); deviceSessionItem = std::make_shared( newSessionID, deviceID, request->publickey(), request->notifytoken(), tunnelbroker::NewSessionRequest_DeviceTypes_Name(request->devicetype()), request->deviceappversion(), request->deviceos()); database::DatabaseManager::getInstance().putSessionItem(*deviceSessionItem); } catch (std::runtime_error &e) { std::cout << "gRPC: " << "Error while processing 'NewSession' request: " << e.what() << std::endl; return grpc::Status(grpc::StatusCode::INTERNAL, e.what()); } reply->set_sessionid(newSessionID); return grpc::Status::OK; }; grpc::Status TunnelBrokerServiceImpl::Send( grpc::ServerContext *context, const tunnelbroker::SendRequest *request, google::protobuf::Empty *reply) { try { const std::string sessionID = request->sessionid(); if (!tools::validateSessionID(sessionID)) { std::cout << "gRPC: " << "Format validation failed for " << sessionID << std::endl; return grpc::Status( grpc::StatusCode::INVALID_ARGUMENT, "Format validation failed for sessionID"); } std::shared_ptr sessionItem = database::DatabaseManager::getInstance().findSessionItem(sessionID); if (sessionItem == nullptr) { std::cout << "gRPC: " << "Session " << sessionID << " not found" << std::endl; return grpc::Status( grpc::StatusCode::PERMISSION_DENIED, "No such session found. SessionID: " + sessionID); } const std::string clientDeviceID = sessionItem->getDeviceID(); + const std::string messageID = tools::generateUUID(); if (!AmqpManager::getInstance().send( - request->todeviceid(), + messageID, clientDeviceID, + request->todeviceid(), std::string(request->payload()))) { std::cout << "gRPC: " << "Error while publish the message to AMQP" << std::endl; return grpc::Status( grpc::StatusCode::INTERNAL, "Error while publish the message to AMQP"); } } catch (std::runtime_error &e) { std::cout << "gRPC: " << "Error while processing 'Send' request: " << e.what() << std::endl; return grpc::Status(grpc::StatusCode::INTERNAL, e.what()); } return grpc::Status::OK; }; grpc::Status TunnelBrokerServiceImpl::Get( grpc::ServerContext *context, const tunnelbroker::GetRequest *request, grpc::ServerWriter *writer) { try { const std::string sessionID = request->sessionid(); if (!tools::validateSessionID(sessionID)) { std::cout << "gRPC: " << "Format validation failed for " << sessionID << std::endl; return grpc::Status( grpc::StatusCode::INVALID_ARGUMENT, "Format validation failed for sessionID"); } std::shared_ptr sessionItem = database::DatabaseManager::getInstance().findSessionItem(sessionID); if (sessionItem == nullptr) { std::cout << "gRPC: " << "Session " << sessionID << " not found" << std::endl; return grpc::Status( grpc::StatusCode::PERMISSION_DENIED, "No such session found. SessionID: " + sessionID); } const std::string clientDeviceID = sessionItem->getDeviceID(); - std::vector messagesToDeliver; + DeliveryBrokerMessage messageToDeliver; while (1) { - messagesToDeliver = DeliveryBroker::getInstance().get(clientDeviceID); - for (auto const &message : messagesToDeliver) { - tunnelbroker::GetResponse response; - response.set_fromdeviceid(message.fromDeviceID); - response.set_payload(message.payload); - if (!writer->Write(response)) { - throw std::runtime_error( - "gRPC: 'Get' writer error on sending data to the client"); - } - comm::network::AmqpManager::getInstance().ack(message.deliveryTag); + messageToDeliver = DeliveryBroker::getInstance().pop(clientDeviceID); + tunnelbroker::GetResponse response; + response.set_fromdeviceid(messageToDeliver.fromDeviceID); + response.set_payload(messageToDeliver.payload); + if (!writer->Write(response)) { + throw std::runtime_error( + "gRPC: 'Get' writer error on sending data to the client"); } - if (!DeliveryBroker::getInstance().isEmpty(clientDeviceID)) { - DeliveryBroker::getInstance().remove(clientDeviceID); + comm::network::AmqpManager::getInstance().ack( + messageToDeliver.deliveryTag); + if (DeliveryBroker::getInstance().isEmpty(clientDeviceID)) { + DeliveryBroker::getInstance().erase(clientDeviceID); } - DeliveryBroker::getInstance().wait(clientDeviceID); } } catch (std::runtime_error &e) { std::cout << "gRPC: " << "Error while processing 'Get' request: " << e.what() << std::endl; return grpc::Status(grpc::StatusCode::INTERNAL, e.what()); } return grpc::Status::OK; }; } // namespace network } // namespace comm