diff --git a/services/tunnelbroker/src/Database/DatabaseManager.h b/services/tunnelbroker/src/Database/DatabaseManager.h
--- a/services/tunnelbroker/src/Database/DatabaseManager.h
+++ b/services/tunnelbroker/src/Database/DatabaseManager.h
@@ -48,10 +48,13 @@
 
   void putMessageItem(const MessageItem &item);
   void putMessageItemsByBatch(std::vector<MessageItem> &messageItems);
-  std::shared_ptr<MessageItem> findMessageItem(const std::string &messageID);
+  std::shared_ptr<MessageItem>
+  findMessageItem(const std::string &toDeviceID, const std::string &messageID);
   std::vector<std::shared_ptr<MessageItem>>
   findMessageItemsByReceiver(const std::string &toDeviceID);
-  void removeMessageItem(const std::string &messageID);
+  void removeMessageItem(
+      const std::string &toDeviceID,
+      const std::string &messageID);
   void removeMessageItemsByIDsForDeviceID(
       std::vector<std::string> &messageIDs,
       const std::string &toDeviceID);
diff --git a/services/tunnelbroker/src/Database/DatabaseManager.cpp b/services/tunnelbroker/src/Database/DatabaseManager.cpp
--- a/services/tunnelbroker/src/Database/DatabaseManager.cpp
+++ b/services/tunnelbroker/src/Database/DatabaseManager.cpp
@@ -191,9 +191,13 @@
       writeRequests);
 }
 
-std::shared_ptr<MessageItem>
-DatabaseManager::findMessageItem(const std::string &messageID) {
+std::shared_ptr<MessageItem> DatabaseManager::findMessageItem(
+    const std::string &toDeviceID,
+    const std::string &messageID) {
   Aws::DynamoDB::Model::GetItemRequest request;
+  request.AddKey(
+      MessageItem::FIELD_TO_DEVICE_ID,
+      Aws::DynamoDB::Model::AttributeValue(toDeviceID));
   request.AddKey(
       MessageItem::FIELD_MESSAGE_ID,
       Aws::DynamoDB::Model::AttributeValue(messageID));
@@ -228,8 +232,11 @@
   return result;
 }
 
-void DatabaseManager::removeMessageItem(const std::string &messageID) {
-  std::shared_ptr<MessageItem> item = this->findMessageItem(messageID);
+void DatabaseManager::removeMessageItem(
+    const std::string &toDeviceID,
+    const std::string &messageID) {
+  std::shared_ptr<MessageItem> item =
+      this->findMessageItem(toDeviceID, messageID);
   if (item == nullptr) {
     return;
   }
diff --git a/services/tunnelbroker/src/Service/TunnelbrokerServiceImpl.cpp b/services/tunnelbroker/src/Service/TunnelbrokerServiceImpl.cpp
--- a/services/tunnelbroker/src/Service/TunnelbrokerServiceImpl.cpp
+++ b/services/tunnelbroker/src/Service/TunnelbrokerServiceImpl.cpp
@@ -228,7 +228,7 @@
           messageFromDatabase->getFromDeviceID(),
           messageFromDatabase->getPayload());
       database::DatabaseManager::getInstance().removeMessageItem(
-          messageFromDatabase->getMessageID());
+          clientDeviceID, messageFromDatabase->getMessageID());
     }
     while (1) {
       messageToDeliver = DeliveryBroker::getInstance().pop(clientDeviceID);
@@ -236,7 +236,7 @@
       comm::network::AmqpManager::getInstance().ack(
           messageToDeliver.deliveryTag);
       database::DatabaseManager::getInstance().removeMessageItem(
-          messageToDeliver.messageID);
+          clientDeviceID, messageToDeliver.messageID);
       // If messages queue for `clientDeviceID` is empty we don't need to store
       // `folly::MPMCQueue` for it and need to free memory to fix possible
       // 'ghost' queues in DeliveryBroker.
diff --git a/services/tunnelbroker/test/DatabaseManagerTest.cpp b/services/tunnelbroker/test/DatabaseManagerTest.cpp
--- a/services/tunnelbroker/test/DatabaseManagerTest.cpp
+++ b/services/tunnelbroker/test/DatabaseManagerTest.cpp
@@ -46,7 +46,7 @@
   database::DatabaseManager::getInstance().putMessageItem(item);
   std::shared_ptr<database::MessageItem> foundItem =
       database::DatabaseManager::getInstance().findMessageItem(
-          item.getMessageID());
+          item.getToDeviceID(), item.getMessageID());
   EXPECT_NE(foundItem, nullptr);
   EXPECT_EQ(item.getFromDeviceID(), foundItem->getFromDeviceID());
   EXPECT_EQ(item.getToDeviceID(), foundItem->getToDeviceID());
@@ -62,7 +62,7 @@
       foundItem->getCreatedAt() <= tools::getCurrentTimestamp(),
       true);
   database::DatabaseManager::getInstance().removeMessageItem(
-      item.getMessageID());
+      item.getToDeviceID(), item.getMessageID());
 }
 
 TEST_F(DatabaseManagerTest, PutAndFoundMessageItemsGeneratedDataIsSame) {
@@ -79,7 +79,7 @@
   database::DatabaseManager::getInstance().putMessageItem(item);
   std::shared_ptr<database::MessageItem> foundItem =
       database::DatabaseManager::getInstance().findMessageItem(
-          item.getMessageID());
+          item.getToDeviceID(), item.getMessageID());
   EXPECT_NE(foundItem, nullptr);
   EXPECT_EQ(item.getFromDeviceID(), foundItem->getFromDeviceID())
       << "Generated FromDeviceID \"" << item.getFromDeviceID()
@@ -98,7 +98,7 @@
       << "\" differs from what is found in the database "
       << foundItem->getBlobHashes();
   database::DatabaseManager::getInstance().removeMessageItem(
-      item.getMessageID());
+      item.getToDeviceID(), item.getMessageID());
 }
 
 TEST_F(DatabaseManagerTest, BatchPutAndFoundMessagesItemsCountIsSame) {
@@ -127,7 +127,7 @@
   EXPECT_EQ(foundItems.size(), itemsSize);
   for (std::shared_ptr<database::MessageItem> messageItem : foundItems) {
     database::DatabaseManager::getInstance().removeMessageItem(
-        messageItem->getMessageID());
+        messageItem->getToDeviceID(), messageItem->getMessageID());
   }
 }
 
@@ -330,5 +330,5 @@
           static_cast<uint64_t>(std::time(0) + MESSAGE_RECORD_TTL)),
       true);
   database::DatabaseManager::getInstance().removeMessageItem(
-      item.getMessageID());
+      item.getToDeviceID(), item.getMessageID());
 }