diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp @@ -588,7 +588,8 @@ if (maybeUpdatedNotifsCryptoModule.has_value()) { NotificationsCryptoModule::persistNotificationsAccount( maybeUpdatedNotifsCryptoModule.value().first, - maybeUpdatedNotifsCryptoModule.value().second); + maybeUpdatedNotifsCryptoModule.value().second, + true); } DatabaseManager::getQueryExecutor().commitTransaction(); persistencePromise.set_value(); diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h @@ -25,6 +25,7 @@ getKeyserverNotificationsSessionKey(const std::string &keyserverID); static std::string getDeviceNotificationsSessionKey(const std::string &deviceID); + static void setNewSynchronizationValue(); static std::string serializeNotificationsSession( std::shared_ptr session, std::string picklingKey); @@ -63,7 +64,8 @@ // notifications account static void persistNotificationsAccount( const std::shared_ptr cryptoModule, - const std::string &picklingKey); + const std::string &picklingKey, + bool setNewSynchronizationValue); static std::optional< std::pair, std::string>> fetchNotificationsAccount(); @@ -121,11 +123,13 @@ std::string sessionPicklingKey, std::string accountPicklingKey, std::string deviceID, + std::optional expectedSynchronizationValue, std::string decryptedData); std::shared_ptr sessionState; std::shared_ptr accountState; std::string accountPicklingKey; std::string deviceID; + std::optional expectedSynchronizationValue; friend NotificationsCryptoModule; public: @@ -137,10 +141,12 @@ std::unique_ptr session, std::string deviceID, std::string picklingKey, + std::optional expectedSynchronizationValue, std::string decryptedData); std::unique_ptr sessionState; std::string deviceID; + std::optional expectedSynchronizationValue; friend NotificationsCryptoModule; public: diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp @@ -40,6 +40,7 @@ "256"; const int temporaryFilePathRandomSuffixLength = 32; const std::string notificationsAccountKey = "NOTIFS.ACCOUNT"; +const std::string notifsSyncKey = "NOTIFS.SYNC_KEY"; std::unique_ptr NotificationsCryptoModule::deserializeCryptoModule( @@ -160,6 +161,14 @@ return "DEVICE." + deviceID + ".NOTIFS_SESSION"; } +void NotificationsCryptoModule::setNewSynchronizationValue() { + const std::string newSynchronizationValue = + crypto::Tools::generateRandomString(32); + if (!CommMMKV::setString(notifsSyncKey, newSynchronizationValue)) { + throw std::runtime_error("Failed to persist notifs synchronization value."); + } +} + std::string NotificationsCryptoModule::serializeNotificationsSession( std::shared_ptr session, std::string picklingKey) { @@ -272,7 +281,8 @@ void NotificationsCryptoModule::persistNotificationsAccount( const std::shared_ptr cryptoModule, - const std::string &picklingKey) { + const std::string &picklingKey, + bool setNewSynchronizationValue) { crypto::Persist serializedCryptoModule = cryptoModule->storeAsB64(picklingKey); crypto::OlmBuffer serializedAccount = serializedCryptoModule.account; @@ -283,8 +293,16 @@ "account", serializedAccountString)("picklingKey", picklingKey); std::string serializedAccountJson = folly::toJson(serializedAccountObject); - bool accountPersisted = - CommMMKV::setString(notificationsAccountKey, serializedAccountJson); + bool accountPersisted; + if (setNewSynchronizationValue) { + CommMMKV::ScopedCommMMKVLock{}; + NotificationsCryptoModule::setNewSynchronizationValue(); + accountPersisted = + CommMMKV::setString(notificationsAccountKey, serializedAccountJson); + } else { + accountPersisted = + CommMMKV::setString(notificationsAccountKey, serializedAccountJson); + } if (!accountPersisted) { throw std::runtime_error("Failed to persist notifications crypto account."); @@ -453,33 +471,57 @@ std::string sessionPicklingKey, std::string accountPicklingKey, std::string deviceID, + std::optional expectedSynchronizationValue, std::string decryptedData) : BaseStatefulDecryptResult(sessionPicklingKey, decryptedData), sessionState(session), accountState(account), accountPicklingKey(accountPicklingKey), - deviceID(deviceID) { + deviceID(deviceID), + expectedSynchronizationValue(expectedSynchronizationValue) { } void NotificationsCryptoModule::StatefulPeerInitDecryptResult::flushState() { + CommMMKV::ScopedCommMMKVLock{}; + std::optional synchronizationValue = + CommMMKV::getString(notifsSyncKey); + if (this->expectedSynchronizationValue.has_value() != + synchronizationValue.has_value() || + this->expectedSynchronizationValue.value() != + synchronizationValue.value()) { + return; + } + NotificationsCryptoModule::setNewSynchronizationValue(); NotificationsCryptoModule::persistNotificationsSessionInternal( false, this->deviceID, this->picklingKey, std::move(this->sessionState)); NotificationsCryptoModule::persistNotificationsAccount( - std::move(this->accountState), this->accountPicklingKey); + std::move(this->accountState), this->accountPicklingKey, false); } NotificationsCryptoModule::StatefulPeerDecryptResult::StatefulPeerDecryptResult( std::unique_ptr session, std::string deviceID, std::string picklingKey, + std::optional expectedSynchronizationValue, std::string decryptedData) : NotificationsCryptoModule::BaseStatefulDecryptResult:: BaseStatefulDecryptResult(picklingKey, decryptedData), sessionState(std::move(session)), + expectedSynchronizationValue(expectedSynchronizationValue), deviceID(deviceID) { } void NotificationsCryptoModule::StatefulPeerDecryptResult::flushState() { + CommMMKV::ScopedCommMMKVLock{}; + std::optional synchronizationValue = + CommMMKV::getString(notifsSyncKey); + if (this->expectedSynchronizationValue.has_value() != + synchronizationValue.has_value() || + this->expectedSynchronizationValue.value() != + synchronizationValue.value()) { + return; + } + NotificationsCryptoModule::setNewSynchronizationValue(); NotificationsCryptoModule::persistNotificationsSessionInternal( false, this->deviceID, this->picklingKey, std::move(this->sessionState)); } @@ -571,6 +613,7 @@ std::move(sessionWithPicklingKey.value().first); std::string picklingKey = sessionWithPicklingKey.value().second; crypto::EncryptedData encryptedData = session->encrypt(payload); + NotificationsCryptoModule::setNewSynchronizationValue(); NotificationsCryptoModule::persistNotificationsSessionInternal( false, deviceID, picklingKey, std::move(session)); return encryptedData; @@ -607,15 +650,26 @@ const std::string &deviceID, const std::string &data, const size_t messageType) { + + std::optional expectedSynchronizationValue; + std::optional, std::string>> + maybeSessionWithPicklingKey; + std::optional, std::string>> + maybeAccountWithPicklingKey; + if (messageType != OLM_MESSAGE_TYPE_MESSAGE && messageType != OLM_MESSAGE_TYPE_PRE_KEY) { throw std::runtime_error( "Received message of invalid type from device: " + deviceID); + } else { + CommMMKV::ScopedCommMMKVLock scopedLock{}; + expectedSynchronizationValue = CommMMKV::getString(notifsSyncKey); + maybeSessionWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsSession(false, deviceID); + maybeAccountWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsAccount(); } - auto maybeSessionWithPicklingKey = - NotificationsCryptoModule::fetchNotificationsSession(false, deviceID); - if (!maybeSessionWithPicklingKey.has_value() && messageType == OLM_MESSAGE_TYPE_MESSAGE) { throw std::runtime_error( @@ -652,6 +706,7 @@ std::move(maybeSessionWithPicklingKey.value().first), deviceID, maybeSessionWithPicklingKey.value().second, + expectedSynchronizationValue, decryptedData); return std::make_unique( std::move(decryptResult)); @@ -663,8 +718,6 @@ std::string notifInboundKeys = NotificationsInboundKeysProvider::getNotifsInboundKeysForDeviceID( deviceID); - auto maybeAccountWithPicklingKey = - NotificationsCryptoModule::fetchNotificationsAccount(); if (!maybeAccountWithPicklingKey.has_value()) { throw std::runtime_error("Notifications account not initialized."); @@ -712,6 +765,7 @@ sessionPicklingKey, accountWithPicklingKey.second, deviceID, + expectedSynchronizationValue, decryptedData); return std::make_unique( std::move(decryptResult));