diff --git a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp
--- a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp
+++ b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp
@@ -1,4 +1,5 @@
 #include "CryptoModule.h"
+#include "Logger.h"
 #include "PlatformSpecificTools.h"
 #include "olm/account.hh"
 #include "olm/session.hh"
@@ -213,9 +214,9 @@
     const OlmBuffer &oneTimeKeys,
     size_t keyIndex) {
   if (this->hasSessionFor(targetUserId)) {
-    throw std::runtime_error{
-        "error initializeOutboundForSendingSession => session already "
-        "initialized"};
+    Logger::log(
+        "olm session overwritten for the user with id: " + targetUserId);
+    this->sessions.erase(targetUserId);
   }
   std::unique_ptr<Session> newSession = Session::createSessionAsInitializer(
       this->account,
diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
--- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
+++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
@@ -1005,7 +1005,8 @@
           try {
             if (!error.size()) {
               notificationsKeysResult =
-                  NotificationsCryptoModule::getNotificationsIdentityKeys();
+                  NotificationsCryptoModule::getNotificationsIdentityKeys(
+                      "Comm");
             }
           } catch (const std::exception &e) {
             error = e.what();
diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h
--- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h
+++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h
@@ -8,6 +8,8 @@
 class NotificationsCryptoModule {
   const static std::string secureStoreNotificationsAccountDataKey;
   const static std::string notificationsCryptoAccountID;
+  const static std::string keyserverHostedNotificationsID;
+  const static std::string initialEncryptedMessageContent;
 
   static void serializeAndFlushCryptoModule(
       crypto::CryptoModule &cryptoModule,
@@ -17,11 +19,31 @@
   static crypto::CryptoModule deserializeCryptoModule(
       const std::string &path,
       const std::string &picklingKey);
+  // Deserializes the notifications crypto account, invokes `caller` on it and
+  // then serializes and flushes it back, so mutations made by `caller` are
+  // persisted. Throws std::runtime_error if the pickling key is missing.
+  static void callCryptoModule(
+      std::function<void(crypto::CryptoModule &cryptoModule)> caller,
+      const std::string &callingProcessName);
 
 public:
   static void
   initializeNotificationsCryptoAccount(const std::string &callingProcessName);
   static void clearSensitiveData();
-  static std::string getNotificationsIdentityKeys();
+  // Note: also re-serializes and flushes the account (via callCryptoModule).
+  static std::string
+  getNotificationsIdentityKeys(const std::string &callingProcessName);
+  // Creates an outbound session for the keyserver-hosted notifications
+  // identity (replacing any existing one) and returns the encrypted
+  // initialEncryptedMessageContent payload.
+  static crypto::EncryptedData initializeNotificationsSession(
+      const std::string &identityKeys,
+      const std::string &prekey,
+      const std::string &prekeySignature,
+      const std::string &oneTimeKeys,
+      const std::string &callingProcessName);
+  // Returns whether a session for keyserverHostedNotificationsID exists.
+  static bool
+  isNotificationsSessionInitialized(const std::string &callingProcessName);
 };
 } // namespace comm
diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp
--- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp
+++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp
@@ -19,6 +19,10 @@
     "notificationsCryptoAccountDataKey";
 const std::string NotificationsCryptoModule::notificationsCryptoAccountID =
     "notificationsCryptoAccountDataID";
+const std::string NotificationsCryptoModule::keyserverHostedNotificationsID =
+    "keyserverHostedNotificationsID";
+const std::string NotificationsCryptoModule::initialEncryptedMessageContent =
+    "{\"type\": \"init\"}";
 
 crypto::CryptoModule NotificationsCryptoModule::deserializeCryptoModule(
     const std::string &path,
@@ -126,6 +130,28 @@
   }
 }
 
+void NotificationsCryptoModule::callCryptoModule(
+    std::function<void(crypto::CryptoModule &cryptoModule)> caller,
+    const std::string &callingProcessName) {
+  CommSecureStore secureStore{};
+  folly::Optional<std::string> picklingKey = secureStore.get(
+      NotificationsCryptoModule::secureStoreNotificationsAccountDataKey);
+  if (!picklingKey.hasValue()) {
+    throw std::runtime_error(
+        "Attempt to retrieve notifications crypto account before it was "
+        "correctly initialized.");
+  }
+
+  const std::string path =
+      PlatformSpecificTools::getNotificationsCryptoAccountPath();
+  crypto::CryptoModule cryptoModule =
+      NotificationsCryptoModule::deserializeCryptoModule(
+          path, picklingKey.value());
+  caller(cryptoModule);
+  NotificationsCryptoModule::serializeAndFlushCryptoModule(
+      cryptoModule, path, picklingKey.value(), callingProcessName);
+}
+
 void NotificationsCryptoModule::initializeNotificationsCryptoAccount(
     const std::string &callingProcessName) {
   const std::string notificationsCryptoAccountPath =
@@ -155,22 +181,47 @@
       callingProcessName);
 }
 
-std::string NotificationsCryptoModule::getNotificationsIdentityKeys() {
-  CommSecureStore secureStore{};
-  folly::Optional<std::string> picklingKey = secureStore.get(
-      NotificationsCryptoModule::secureStoreNotificationsAccountDataKey);
-  if (!picklingKey.hasValue()) {
-    throw std::runtime_error(
-        "Attempt to retrieve notifications crypto account before it was "
-        "correctly initialized.");
-  }
+std::string NotificationsCryptoModule::getNotificationsIdentityKeys(
+    const std::string &callingProcessName) {
+  std::string identityKeys;
+  auto caller = [&identityKeys](crypto::CryptoModule &cryptoModule) {
+    identityKeys = cryptoModule.getIdentityKeys();
+  };
+  NotificationsCryptoModule::callCryptoModule(caller, callingProcessName);
+  return identityKeys;
+}
 
-  const std::string path =
-      PlatformSpecificTools::getNotificationsCryptoAccountPath();
-  crypto::CryptoModule cryptoModule =
-      NotificationsCryptoModule::deserializeCryptoModule(
-          path, picklingKey.value());
-  return cryptoModule.getIdentityKeys();
+crypto::EncryptedData NotificationsCryptoModule::initializeNotificationsSession(
+    const std::string &identityKeys,
+    const std::string &prekey,
+    const std::string &prekeySignature,
+    const std::string &oneTimeKeys,
+    const std::string &callingProcessName) {
+  crypto::EncryptedData initialEncryptedMessage;
+  auto caller = [&](crypto::CryptoModule &cryptoModule) {
+    cryptoModule.initializeOutboundForSendingSession(
+        NotificationsCryptoModule::keyserverHostedNotificationsID,
+        std::vector<uint8_t>(identityKeys.begin(), identityKeys.end()),
+        std::vector<uint8_t>(prekey.begin(), prekey.end()),
+        std::vector<uint8_t>(prekeySignature.begin(), prekeySignature.end()),
+        std::vector<uint8_t>(oneTimeKeys.begin(), oneTimeKeys.end()));
+    initialEncryptedMessage = cryptoModule.encrypt(
+        NotificationsCryptoModule::keyserverHostedNotificationsID,
+        NotificationsCryptoModule::initialEncryptedMessageContent);
+  };
+  NotificationsCryptoModule::callCryptoModule(caller, callingProcessName);
+  return initialEncryptedMessage;
+}
+
+bool NotificationsCryptoModule::isNotificationsSessionInitialized(
+    const std::string &callingProcessName) {
+  bool sessionInitialized = false;
+  auto caller = [&sessionInitialized](crypto::CryptoModule &cryptoModule) {
+    sessionInitialized = cryptoModule.hasSessionFor(
+        NotificationsCryptoModule::keyserverHostedNotificationsID);
+  };
+  NotificationsCryptoModule::callCryptoModule(caller, callingProcessName);
+  return sessionInitialized;
 }
 
 void NotificationsCryptoModule::clearSensitiveData() {