diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h @@ -34,7 +34,7 @@ const std::string publicCryptoAccountID = "publicCryptoAccountID"; std::unique_ptr contentCryptoModule; const std::string notifsCryptoAccountID = "notifsCryptoAccountID"; - std::unique_ptr notifsCryptoModule; + DraftStore draftStore; ThreadStore threadStore; MessageStore messageStore; @@ -48,8 +48,11 @@ ThreadActivityStore threadActivityStore; EntryStore entryStore; - void - persistCryptoModules(bool persistContentModule, bool persistNotifsModule); + void persistCryptoModules( + bool persistContentModule, + const std::optional< + std::pair, std::string>> + &maybeUpdatedNotifsCryptoModule); jsi::Value createNewBackupInternal( jsi::Runtime &rt, std::string backupSecret, diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp @@ -553,10 +553,12 @@ void CommCoreModule::persistCryptoModules( bool persistContentModule, - bool persistNotifsModule) { + const std::optional< + std::pair, std::string>> + &maybeUpdatedNotifsCryptoModule) { std::string storedSecretKey = getAccountDataKey(secureStoreAccountDataKey); - if (!persistContentModule && !persistNotifsModule) { + if (!persistContentModule && !maybeUpdatedNotifsCryptoModule.has_value()) { return; } @@ -565,15 +567,10 @@ newContentPersist = this->contentCryptoModule->storeAsB64(storedSecretKey); } - crypto::Persist newNotifsPersist; - if (persistNotifsModule) { - newNotifsPersist = this->notifsCryptoModule->storeAsB64(storedSecretKey); - } - std::promise persistencePromise; std::future persistenceFuture = persistencePromise.get_future(); GlobalDBSingleton::instance.scheduleOrRunCancellable( - [=, &persistencePromise]() { + [=, &persistencePromise, &maybeUpdatedNotifsCryptoModule]() { try { DatabaseManager::getQueryExecutor().beginTransaction(); if (persistContentModule) { @@ -581,10 +578,10 @@ DatabaseManager::getQueryExecutor().getContentAccountID(), newContentPersist); } - if (persistNotifsModule) { - DatabaseManager::getQueryExecutor().storeOlmPersistData( - DatabaseManager::getQueryExecutor().getNotifsAccountID(), - newNotifsPersist); + if (maybeUpdatedNotifsCryptoModule.has_value()) { + NotificationsCryptoModule::persistNotificationsAccount( + maybeUpdatedNotifsCryptoModule.value().first, + maybeUpdatedNotifsCryptoModule.value().second); } DatabaseManager::getQueryExecutor().commitTransaction(); persistencePromise.set_value(); @@ -611,6 +608,7 @@ crypto::Persist contentPersist; crypto::Persist notifsPersist; std::string error; + bool notifsCryptoAccountPresentInMMKV = false; try { std::optional contentAccountData = DatabaseManager::getQueryExecutor().getOlmPersistAccountData( @@ -633,9 +631,17 @@ } } - std::optional notifsAccountData = - DatabaseManager::getQueryExecutor().getOlmPersistAccountData( - DatabaseManager::getQueryExecutor().getNotifsAccountID()); + std::optional notifsAccountData; + + if (NotificationsCryptoModule:: + isNotificationsAccountInitialized()) { + notifsCryptoAccountPresentInMMKV = true; + } else { + notifsAccountData = + DatabaseManager::getQueryExecutor().getOlmPersistAccountData( + DatabaseManager::getQueryExecutor().getNotifsAccountID()); + } + if (notifsAccountData.has_value()) { notifsPersist.account = crypto::OlmBuffer( notifsAccountData->begin(), notifsAccountData->end()); @@ -652,14 +658,22 @@ storedSecretKey.value(), contentPersist)); - this->notifsCryptoModule.reset(new crypto::CryptoModule( - this->notifsCryptoAccountID, - storedSecretKey.value(), - notifsPersist)); + std::optional< + std::pair, std::string>> + maybeNotifsCryptoAccountToPersist; + + if (!notifsCryptoAccountPresentInMMKV) { + maybeNotifsCryptoAccountToPersist = { + std::make_unique( + this->notifsCryptoAccountID, + storedSecretKey.value(), + notifsPersist), + storedSecretKey.value()}; + } try { this->persistCryptoModules( - contentPersist.isEmpty(), notifsPersist.isEmpty()); + contentPersist.isEmpty(), maybeNotifsCryptoAccountToPersist); } catch (const std::exception &e) { error = e.what(); } @@ -688,12 +702,12 @@ std::string primaryKeysResult; std::string notificationsKeysResult; if (this->contentCryptoModule == nullptr || - this->notifsCryptoModule == nullptr) { + !NotificationsCryptoModule::isNotificationsAccountInitialized()) { error = "user has not been initialized"; } else { primaryKeysResult = this->contentCryptoModule->getIdentityKeys(); notificationsKeysResult = - this->notifsCryptoModule->getIdentityKeys(); + NotificationsCryptoModule::getIdentityKeys(); } std::string notificationsCurve25519Cpp, notificationsEd25519Cpp, @@ -852,7 +866,7 @@ std::string contentResult; std::string notifResult; if (this->contentCryptoModule == nullptr || - this->notifsCryptoModule == nullptr) { + !NotificationsCryptoModule::isNotificationsAccountInitialized()) { this->jsInvoker_->invokeAsync([=, &innerRt]() { promise->reject("user has not been initialized"); }); @@ -862,9 +876,14 @@ contentResult = this->contentCryptoModule->getOneTimeKeysForPublishing( oneTimeKeysAmount); - notifResult = this->notifsCryptoModule->getOneTimeKeysForPublishing( - oneTimeKeysAmount); - this->persistCryptoModules(true, true); + std::pair, std::string> + notifsCryptoModuleWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsAccount() + .value(); + notifResult = notifsCryptoModuleWithPicklingKey.first + ->getOneTimeKeysForPublishing(oneTimeKeysAmount); + this->persistCryptoModules( + true, std::move(notifsCryptoModuleWithPicklingKey)); } catch (const std::exception &e) { error = e.what(); } @@ -897,19 +916,25 @@ std::optional maybeNotifsPrekeyToUpload; if (this->contentCryptoModule == nullptr || - this->notifsCryptoModule == nullptr) { + !NotificationsCryptoModule::isNotificationsAccountInitialized()) { this->jsInvoker_->invokeAsync([=, &innerRt]() { promise->reject("user has not been initialized"); }); return; } + std::optional< + std::pair, std::string>> + notifsCryptoModuleWithPicklingKey; try { + notifsCryptoModuleWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsAccount(); maybeContentPrekeyToUpload = this->contentCryptoModule->validatePrekey(); maybeNotifsPrekeyToUpload = - this->notifsCryptoModule->validatePrekey(); - this->persistCryptoModules(true, true); + notifsCryptoModuleWithPicklingKey.value() + .first->validatePrekey(); + this->persistCryptoModules(true, notifsCryptoModuleWithPicklingKey); if (!maybeContentPrekeyToUpload.has_value()) { maybeContentPrekeyToUpload = @@ -917,7 +942,8 @@ } if (!maybeNotifsPrekeyToUpload.has_value()) { maybeNotifsPrekeyToUpload = - this->notifsCryptoModule->getUnpublishedPrekey(); + notifsCryptoModuleWithPicklingKey.value() + .first->getUnpublishedPrekey(); } } catch (const std::exception &e) { error = e.what(); @@ -947,7 +973,8 @@ if (maybeNotifsPrekeyToUpload.has_value()) { notifsPrekeyToUpload = maybeNotifsPrekeyToUpload.value(); } else { - notifsPrekeyToUpload = this->notifsCryptoModule->getPrekey(); + notifsPrekeyToUpload = + notifsCryptoModuleWithPicklingKey.value().first->getPrekey(); } std::string prekeyUploadError; @@ -956,7 +983,8 @@ std::string contentPrekeySignature = this->contentCryptoModule->getPrekeySignature(); std::string notifsPrekeySignature = - this->notifsCryptoModule->getPrekeySignature(); + notifsCryptoModuleWithPicklingKey.value() + .first->getPrekeySignature(); try { std::promise prekeyPromise; @@ -989,8 +1017,10 @@ if (!prekeyUploadError.size()) { this->contentCryptoModule->markPrekeyAsPublished(); - this->notifsCryptoModule->markPrekeyAsPublished(); - this->persistCryptoModules(true, true); + notifsCryptoModuleWithPicklingKey.value() + .first->markPrekeyAsPublished(); + this->persistCryptoModules( + true, notifsCryptoModuleWithPicklingKey); } } catch (std::exception &e) { error = e.what(); @@ -1023,13 +1053,19 @@ std::optional notifPrekeyBlob; if (this->contentCryptoModule == nullptr || - this->notifsCryptoModule == nullptr) { + !NotificationsCryptoModule::isNotificationsAccountInitialized()) { this->jsInvoker_->invokeAsync([=, &innerRt]() { promise->reject("user has not been initialized"); }); return; } + + std::optional< + std::pair, std::string>> + notifsCryptoModuleWithPicklingKey; try { + notifsCryptoModuleWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsAccount(); contentPrekeyBlob = this->contentCryptoModule->validatePrekey(); if (!contentPrekeyBlob) { contentPrekeyBlob = @@ -1039,20 +1075,22 @@ contentPrekeyBlob = this->contentCryptoModule->getPrekey(); } - notifPrekeyBlob = this->notifsCryptoModule->validatePrekey(); + notifPrekeyBlob = notifsCryptoModuleWithPicklingKey.value() + .first->validatePrekey(); if (!notifPrekeyBlob) { - notifPrekeyBlob = - this->notifsCryptoModule->getUnpublishedPrekey(); + notifPrekeyBlob = notifsCryptoModuleWithPicklingKey.value() + .first->getUnpublishedPrekey(); } if (!notifPrekeyBlob) { - notifPrekeyBlob = this->notifsCryptoModule->getPrekey(); + notifPrekeyBlob = + notifsCryptoModuleWithPicklingKey.value().first->getPrekey(); } - this->persistCryptoModules(true, true); + this->persistCryptoModules(true, notifsCryptoModuleWithPicklingKey); contentPrekeySignature = this->contentCryptoModule->getPrekeySignature(); - notifPrekeySignature = - this->notifsCryptoModule->getPrekeySignature(); + notifPrekeySignature = notifsCryptoModuleWithPicklingKey.value() + .first->getPrekeySignature(); contentPrekey = parseOLMPrekey(contentPrekeyBlob.value()); notifPrekey = parseOLMPrekey(notifPrekeyBlob.value()); @@ -1114,34 +1152,43 @@ taskType job = [=, &innerRt]() { std::string error; crypto::EncryptedData result; + std::optional< + std::pair, std::string>> + notifsCryptoModuleWithPicklingKey; try { + notifsCryptoModuleWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsAccount(); std::optional oneTimeKeyBuffer; if (oneTimeKeyCpp) { oneTimeKeyBuffer = crypto::OlmBuffer( oneTimeKeyCpp->begin(), oneTimeKeyCpp->end()); } - this->notifsCryptoModule->initializeOutboundForSendingSession( - keyserverIDCpp, - std::vector( - identityKeysCpp.begin(), identityKeysCpp.end()), - std::vector(prekeyCpp.begin(), prekeyCpp.end()), - std::vector( - prekeySignatureCpp.begin(), prekeySignatureCpp.end()), - oneTimeKeyBuffer); + notifsCryptoModuleWithPicklingKey.value() + .first->initializeOutboundForSendingSession( + keyserverIDCpp, + std::vector( + identityKeysCpp.begin(), identityKeysCpp.end()), + std::vector(prekeyCpp.begin(), prekeyCpp.end()), + std::vector( + prekeySignatureCpp.begin(), prekeySignatureCpp.end()), + oneTimeKeyBuffer); - result = this->notifsCryptoModule->encrypt( + result = notifsCryptoModuleWithPicklingKey.value().first->encrypt( keyserverIDCpp, NotificationsCryptoModule::initialEncryptedMessageContent); std::shared_ptr keyserverNotificationsSession = - this->notifsCryptoModule->getSessionByDeviceId(keyserverIDCpp); + notifsCryptoModuleWithPicklingKey.value() + .first->getSessionByDeviceId(keyserverIDCpp); NotificationsCryptoModule::persistNotificationsSession( keyserverIDCpp, keyserverNotificationsSession); - this->notifsCryptoModule->removeSessionByDeviceId(keyserverIDCpp); - this->persistCryptoModules(false, true); + notifsCryptoModuleWithPicklingKey.value() + .first->removeSessionByDeviceId(keyserverIDCpp); + this->persistCryptoModules( + false, notifsCryptoModuleWithPicklingKey); } catch (const std::exception &e) { error = e.what(); } @@ -1383,7 +1430,7 @@ initialEncryptedData = contentCryptoModule->encrypt(deviceIDCpp, initMessage); - this->persistCryptoModules(true, false); + this->persistCryptoModules(true, std::nullopt); } catch (const std::exception &e) { error = e.what(); } @@ -1449,7 +1496,7 @@ messageType}; decryptedMessage = this->contentCryptoModule->decrypt(deviceIDCpp, encryptedData); - this->persistCryptoModules(true, false); + this->persistCryptoModules(true, std::nullopt); } catch (const std::exception &e) { error = e.what(); } @@ -1504,33 +1551,42 @@ taskType job = [=, &innerRt]() { std::string error; crypto::EncryptedData result; + std::optional< + std::pair, std::string>> + notifsCryptoModuleWithPicklingKey; try { + notifsCryptoModuleWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsAccount(); std::optional oneTimeKeyBuffer; if (oneTimeKeyCpp) { oneTimeKeyBuffer = crypto::OlmBuffer( oneTimeKeyCpp->begin(), oneTimeKeyCpp->end()); } - this->notifsCryptoModule->initializeOutboundForSendingSession( - deviceIDCpp, - std::vector( - identityKeysCpp.begin(), identityKeysCpp.end()), - std::vector(prekeyCpp.begin(), prekeyCpp.end()), - std::vector( - prekeySignatureCpp.begin(), prekeySignatureCpp.end()), - oneTimeKeyBuffer); + notifsCryptoModuleWithPicklingKey.value() + .first->initializeOutboundForSendingSession( + deviceIDCpp, + std::vector( + identityKeysCpp.begin(), identityKeysCpp.end()), + std::vector(prekeyCpp.begin(), prekeyCpp.end()), + std::vector( + prekeySignatureCpp.begin(), prekeySignatureCpp.end()), + oneTimeKeyBuffer); - result = this->notifsCryptoModule->encrypt( + result = notifsCryptoModuleWithPicklingKey.value().first->encrypt( deviceIDCpp, NotificationsCryptoModule::initialEncryptedMessageContent); std::shared_ptr peerNotificationsSession = - this->notifsCryptoModule->getSessionByDeviceId(deviceIDCpp); + notifsCryptoModuleWithPicklingKey.value() + .first->getSessionByDeviceId(deviceIDCpp); NotificationsCryptoModule::persistPeerNotificationsSession( deviceIDCpp, peerNotificationsSession); - this->notifsCryptoModule->removeSessionByDeviceId(deviceIDCpp); - this->persistCryptoModules(false, true); + notifsCryptoModuleWithPicklingKey.value() + .first->removeSessionByDeviceId(deviceIDCpp); + this->persistCryptoModules( + false, notifsCryptoModuleWithPicklingKey); } catch (const std::exception &e) { error = e.what(); } @@ -1567,7 +1623,7 @@ try { encryptedMessage = contentCryptoModule->encrypt(deviceIDCpp, messageCpp); - this->persistCryptoModules(true, false); + this->persistCryptoModules(true, std::nullopt); } catch (const std::exception &e) { error = e.what(); } @@ -1731,7 +1787,7 @@ messageType}; decryptedMessage = this->contentCryptoModule->decrypt(deviceIDCpp, encryptedData); - this->persistCryptoModules(true, false); + this->persistCryptoModules(true, std::nullopt); } catch (const std::exception &e) { error = e.what(); } @@ -2718,17 +2774,23 @@ std::string error; if (this->contentCryptoModule == nullptr || - this->notifsCryptoModule == nullptr) { + !NotificationsCryptoModule::isNotificationsAccountInitialized()) { this->jsInvoker_->invokeAsync([=, &innerRt]() { promise->reject("user has not been initialized"); }); return; } + std::optional< + std::pair, std::string>> + notifsCryptoModuleWithPicklingKey; try { + notifsCryptoModuleWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsAccount(); this->contentCryptoModule->markPrekeyAsPublished(); - this->notifsCryptoModule->markPrekeyAsPublished(); - this->persistCryptoModules(true, true); + notifsCryptoModuleWithPicklingKey.value() + .first->markPrekeyAsPublished(); + this->persistCryptoModules(true, notifsCryptoModuleWithPicklingKey); } catch (std::exception &e) { error = e.what(); } diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h @@ -45,6 +45,8 @@ const static int olmEncryptedTypeMessage; static void clearSensitiveData(); + + // notifications sessions static void persistNotificationsSession( const std::string &keyserverID, std::shared_ptr keyserverNotificationsSession); @@ -55,6 +57,16 @@ static bool isPeerNotificationsSessionInitialized(const std::string &deviceID); + // notifications account + static void persistNotificationsAccount( + const std::unique_ptr &cryptoModule, + const std::string &picklingKey); + static std::optional< + std::pair, std::string>> + fetchNotificationsAccount(); + static bool isNotificationsAccountInitialized(); + static std::string getIdentityKeys(); + class BaseStatefulDecryptResult { BaseStatefulDecryptResult( std::string picklingKey, diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp @@ -36,6 +36,7 @@ const std::string ashoatKeyserverIDUsedOnlyForMigrationFromLegacyNotifStorage = "256"; const int temporaryFilePathRandomSuffixLength = 32; +const std::string notificationsAccountKey = "NOTIFS.ACCOUNT"; std::unique_ptr NotificationsCryptoModule::deserializeCryptoModule( @@ -266,6 +267,62 @@ serializedSession.value()); } +void NotificationsCryptoModule::persistNotificationsAccount( + const std::unique_ptr &cryptoModule, + const std::string &picklingKey) { + crypto::Persist serializedCryptoModule = + cryptoModule->storeAsB64(picklingKey); + crypto::OlmBuffer serializedAccount = serializedCryptoModule.account; + std::string serializedAccountString{ + serializedAccount.begin(), serializedAccount.end()}; + + folly::dynamic serializedAccountObject = folly::dynamic::object( + "account", serializedAccountString)("picklingKey", picklingKey); + std::string serializedAccountJson = folly::toJson(serializedAccountObject); + + bool accountPersisted = + CommMMKV::setString(notificationsAccountKey, serializedAccountJson); + + if (!accountPersisted) { + throw std::runtime_error("Failed to persist notifications crypto account."); + } +} + +std::optional, std::string>> +NotificationsCryptoModule::fetchNotificationsAccount() { + std::optional serializedAccountJson; + try { + serializedAccountJson = CommMMKV::getString(notificationsAccountKey); + } catch (const CommMMKV::InitFromNSEForbiddenError &e) { + serializedAccountJson = std::nullopt; + } + + if (!serializedAccountJson.has_value()) { + return std::nullopt; + } + + folly::dynamic serializedAccountObject; + try { + serializedAccountObject = folly::parseJson(serializedAccountJson.value()); + } catch (const folly::json::parse_error &e) { + throw std::runtime_error( + "Notifications account deserialization failed with reason: " + + std::string(e.what())); + } + + std::string picklingKey = serializedAccountObject["picklingKey"].asString(); + std::string accountString = serializedAccountObject["account"].asString(); + crypto::OlmBuffer account = + crypto::OlmBuffer{accountString.begin(), accountString.end()}; + crypto::Persist serializedCryptoModule{account, {}}; + + std::unique_ptr cryptoModule = + std::make_unique( + notificationsCryptoAccountID, picklingKey, serializedCryptoModule); + + return {{std::move(cryptoModule), picklingKey}}; +} + void NotificationsCryptoModule::persistNotificationsSession( const std::string &keyserverID, std::shared_ptr keyserverNotificationsSession) { @@ -296,6 +353,21 @@ return CommMMKV::getString(peerNotificationsSessionKey).has_value(); } +// notifications account + +bool NotificationsCryptoModule::isNotificationsAccountInitialized() { + return fetchNotificationsAccount().has_value(); +} + +std::string NotificationsCryptoModule::getIdentityKeys() { + auto cryptoModuleWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsAccount(); + if (!cryptoModuleWithPicklingKey.has_value()) { + throw std::runtime_error("Notifications crypto account not initialized."); + } + return cryptoModuleWithPicklingKey.value().first->getIdentityKeys(); +} + NotificationsCryptoModule::BaseStatefulDecryptResult::BaseStatefulDecryptResult( std::string picklingKey, std::string decryptedData)