diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h @@ -133,6 +133,13 @@ jsi::String deviceID, double sessionVersion, bool overwrite) override; + virtual jsi::Value initializeNotificationsOutboundSession( + jsi::Runtime &rt, + jsi::String identityKeys, + jsi::String prekey, + jsi::String prekeySignature, + std::optional oneTimeKey, + jsi::String deviceID) override; virtual jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) override; virtual jsi::Value encryptAndPersist( diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp @@ -851,6 +851,20 @@ return jsiOneTimeKeysResult; } +// Converts encrypted data into a JS object with `message` and `messageType`. +static jsi::Object parseEncryptedData( + jsi::Runtime &rt, + const crypto::EncryptedData &encryptedData) { + auto encryptedDataJSI = jsi::Object(rt); + auto message = + std::string{encryptedData.message.begin(), encryptedData.message.end()}; + auto messageJSI = jsi::String::createFromUtf8(rt, message); + encryptedDataJSI.setProperty(rt, "message", messageJSI); + encryptedDataJSI.setProperty( + rt, "messageType", static_cast(encryptedData.messageType)); + return encryptedDataJSI; +} + jsi::Value CommCoreModule::getOneTimeKeys(jsi::Runtime &rt, double oneTimeKeysAmount) { return createPromiseAsJSIValue( @@ -1148,6 +1162,8 @@ NotificationsCryptoModule::persistNotificationsSession( keyserverIDCpp, keyserverNotificationsSession); + // Session is removed from the account since it is persisted + // at a different location than the account after serialization this->notifsCryptoModule->removeSessionByDeviceId(keyserverIDCpp); this->persistCryptoModules(false, true); } catch (const std::exception &e) { @@ -1373,17 
+1389,8 @@ promise->reject(error); return; } - auto initialEncryptedDataJSI = jsi::Object(innerRt); - auto message = std::string{ - initialEncryptedData.message.begin(), - initialEncryptedData.message.end()}; - auto messageJSI = jsi::String::createFromUtf8(innerRt, message); - initialEncryptedDataJSI.setProperty(innerRt, "message", messageJSI); - initialEncryptedDataJSI.setProperty( - innerRt, - "messageType", - static_cast(initialEncryptedData.messageType)); - + auto initialEncryptedDataJSI = + parseEncryptedData(innerRt, initialEncryptedData); auto outboundSessionCreationResultJSI = jsi::Object(innerRt); outboundSessionCreationResultJSI.setProperty( innerRt, "encryptedData", initialEncryptedDataJSI); @@ -1447,6 +1454,76 @@ }); } +// Creates an outbound notifications session with `deviceID`, encrypts the +// initial message content with it and persists the session to MMKV rather +// than in the account. Resolves with the encrypted initial message. +jsi::Value CommCoreModule::initializeNotificationsOutboundSession( + jsi::Runtime &rt, + jsi::String identityKeys, + jsi::String prekey, + jsi::String prekeySignature, + std::optional oneTimeKey, + jsi::String deviceID) { + auto identityKeysCpp{identityKeys.utf8(rt)}; + auto prekeyCpp{prekey.utf8(rt)}; + auto prekeySignatureCpp{prekeySignature.utf8(rt)}; + auto deviceIDCpp{deviceID.utf8(rt)}; + + std::optional oneTimeKeyCpp; + if (oneTimeKey) { + oneTimeKeyCpp = oneTimeKey->utf8(rt); + } + + return createPromiseAsJSIValue( + rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { + taskType job = [=, &innerRt]() { + std::string error; + crypto::EncryptedData result; + try { + std::optional oneTimeKeyBuffer; + if (oneTimeKeyCpp) { + oneTimeKeyBuffer = crypto::OlmBuffer( + oneTimeKeyCpp->begin(), oneTimeKeyCpp->end()); + } + this->notifsCryptoModule->initializeOutboundForSendingSession( + deviceIDCpp, + std::vector( + identityKeysCpp.begin(), identityKeysCpp.end()), + std::vector(prekeyCpp.begin(), prekeyCpp.end()), + std::vector( + prekeySignatureCpp.begin(), prekeySignatureCpp.end()), + oneTimeKeyBuffer); + + result = this->notifsCryptoModule->encrypt( + deviceIDCpp, + NotificationsCryptoModule::initialEncryptedMessageContent); + 
std::shared_ptr peerNotificationsSession = + this->notifsCryptoModule->getSessionByDeviceId(deviceIDCpp); + + NotificationsCryptoModule::persistDeviceNotificationsSession( + deviceIDCpp, peerNotificationsSession); + + // Session is removed from the account since it is persisted + // at a different location than the account after serialization + this->notifsCryptoModule->removeSessionByDeviceId(deviceIDCpp); + this->persistCryptoModules(false, true); + } catch (const std::exception &e) { + error = e.what(); + } + this->jsInvoker_->invokeAsync([=, &innerRt]() { + if (error.size()) { + promise->reject(error); + return; + } + auto initialEncryptedDataJSI = parseEncryptedData(innerRt, result); + promise->resolve(std::move(initialEncryptedDataJSI)); + }); + }; + this->cryptoThread->scheduleTask(job); + }); +} + jsi::Value CommCoreModule::encrypt( jsi::Runtime &rt, jsi::String message, diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h @@ -23,17 +23,22 @@ static std::string getKeyserverNotificationsSessionKey(const std::string &keyserverID); + static std::string + getDeviceNotificationsSessionKey(const std::string &deviceID); static std::string serializeNotificationsSession( std::shared_ptr session, std::string picklingKey); static std::pair, std::string> deserializeNotificationsSession(const std::string &serializedSession); static void persistNotificationsSessionInternal( - const std::string &keyserverID, + bool isKeyserverSession, + const std::string &senderID, const std::string &picklingKey, std::shared_ptr session); static std::optional, std::string>> - fetchNotificationsSession(const std::string &keyserverID); + fetchNotificationsSession( + bool 
isKeyserverSession, + const std::string &senderID); public: const static std::string initialEncryptedMessageContent; @@ -43,7 +48,12 @@ static void persistNotificationsSession( const std::string &keyserverID, std::shared_ptr keyserverNotificationsSession); + static void persistDeviceNotificationsSession( + const std::string &deviceID, + std::shared_ptr peerNotificationsSession); static bool isNotificationsSessionInitialized(const std::string &keyserverID); + static bool + isDeviceNotificationsSessionInitialized(const std::string &deviceID); class BaseStatefulDecryptResult { BaseStatefulDecryptResult( diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp @@ -151,6 +151,11 @@ return "KEYSERVER." + keyserverID + ".NOTIFS_SESSION"; } +std::string NotificationsCryptoModule::getDeviceNotificationsSessionKey( + const std::string &deviceID) { + return "DEVICE." 
+ deviceID + ".NOTIFS_SESSION"; +} + std::string NotificationsCryptoModule::serializeNotificationsSession( std::shared_ptr session, std::string picklingKey) { @@ -194,45 +199,65 @@ } void NotificationsCryptoModule::persistNotificationsSessionInternal( - const std::string &keyserverID, + bool isKeyserverSession, + const std::string &senderID, const std::string &picklingKey, std::shared_ptr session) { std::string serializedSession = NotificationsCryptoModule::serializeNotificationsSession( session, picklingKey); - std::string keyserverNotificationsSessionKey = - NotificationsCryptoModule::getKeyserverNotificationsSessionKey( - keyserverID); + + std::string notificationsSessionKey; + std::string persistenceErrorMessage; + + if (isKeyserverSession) { + notificationsSessionKey = + NotificationsCryptoModule::getKeyserverNotificationsSessionKey( + senderID); + persistenceErrorMessage = + "Failed to persist to MMKV notifications session for keyserver: " + + senderID; + } else { + notificationsSessionKey = + NotificationsCryptoModule::getDeviceNotificationsSessionKey(senderID); + persistenceErrorMessage = + "Failed to persist to MMKV notifications session for device: " + + senderID; + } bool sessionStored = - CommMMKV::setString(keyserverNotificationsSessionKey, serializedSession); + CommMMKV::setString(notificationsSessionKey, serializedSession); if (!sessionStored) { - throw std::runtime_error( - "Failed to persist to MMKV notifications session for keyserver: " + - keyserverID); + throw std::runtime_error(persistenceErrorMessage); } } std::optional, std::string>> NotificationsCryptoModule::fetchNotificationsSession( - const std::string &keyserverID) { - std::string keyserverNotificationsSessionKey = - NotificationsCryptoModule::getKeyserverNotificationsSessionKey( - keyserverID); + bool isKeyserverSession, + const std::string &senderID) { + std::string notificationsSessionKey; + if (isKeyserverSession) { + notificationsSessionKey = + 
NotificationsCryptoModule::getKeyserverNotificationsSessionKey( + senderID); + } else { + notificationsSessionKey = + NotificationsCryptoModule::getDeviceNotificationsSessionKey(senderID); + } std::optional serializedSession; try { - serializedSession = CommMMKV::getString(keyserverNotificationsSessionKey); + serializedSession = CommMMKV::getString(notificationsSessionKey); } catch (const CommMMKV::InitFromNSEForbiddenError &e) { serializedSession = std::nullopt; } - if (!serializedSession.has_value() && - keyserverID != - ashoatKeyserverIDUsedOnlyForMigrationFromLegacyNotifStorage) { + if (!serializedSession.has_value() && isKeyserverSession && + senderID != ashoatKeyserverIDUsedOnlyForMigrationFromLegacyNotifStorage) { throw std::runtime_error( - "Missing notifications session for keyserver: " + keyserverID); + "Missing notifications session for keyserver: " + senderID); } else if (!serializedSession.has_value()) { return std::nullopt; } @@ -246,16 +271,31 @@ std::shared_ptr keyserverNotificationsSession) { std::string picklingKey = crypto::Tools::generateRandomString(64); NotificationsCryptoModule::persistNotificationsSessionInternal( - keyserverID, picklingKey, keyserverNotificationsSession); + true, keyserverID, picklingKey, keyserverNotificationsSession); +} + +void NotificationsCryptoModule::persistDeviceNotificationsSession( + const std::string &deviceID, + std::shared_ptr peerNotificationsSession) { + std::string picklingKey = crypto::Tools::generateRandomString(64); + NotificationsCryptoModule::persistNotificationsSessionInternal( + false, deviceID, picklingKey, peerNotificationsSession); } bool NotificationsCryptoModule::isNotificationsSessionInitialized( const std::string &keyserverID) { std::string keyserverNotificationsSessionKey = - "KEYSERVER." 
+ keyserverID + ".NOTIFS_SESSION"; + getKeyserverNotificationsSessionKey(keyserverID); return CommMMKV::getString(keyserverNotificationsSessionKey).has_value(); } +bool NotificationsCryptoModule::isDeviceNotificationsSessionInitialized( + const std::string &deviceID) { + std::string peerNotificationsSessionKey = + getDeviceNotificationsSessionKey(deviceID); + return CommMMKV::getString(peerNotificationsSessionKey).has_value(); +} + NotificationsCryptoModule::BaseStatefulDecryptResult::BaseStatefulDecryptResult( std::string picklingKey, std::string decryptedData) @@ -280,7 +320,10 @@ void NotificationsCryptoModule::StatefulDecryptResult::flushState() { NotificationsCryptoModule::persistNotificationsSessionInternal( - this->keyserverID, this->picklingKey, std::move(this->sessionState)); + true, + this->keyserverID, + this->picklingKey, + std::move(this->sessionState)); } NotificationsCryptoModule::LegacyStatefulDecryptResult:: @@ -349,7 +392,7 @@ const size_t messageType) { auto sessionWithPicklingKey = - NotificationsCryptoModule::fetchNotificationsSession(keyserverID); + NotificationsCryptoModule::fetchNotificationsSession(true, keyserverID); if (!sessionWithPicklingKey.has_value()) { auto statefulDecryptResult = NotificationsCryptoModule::prepareLegacyDecryptedState( @@ -366,7 +409,7 @@ std::string decryptedData = session->decrypt(encryptedData); NotificationsCryptoModule::persistNotificationsSessionInternal( - keyserverID, picklingKey, std::move(session)); + true, keyserverID, picklingKey, std::move(session)); return decryptedData; } @@ -377,7 +420,7 @@ const size_t messageType) { auto sessionWithPicklingKey = - NotificationsCryptoModule::fetchNotificationsSession(keyserverID); + NotificationsCryptoModule::fetchNotificationsSession(true, keyserverID); if (!sessionWithPicklingKey.has_value()) { return NotificationsCryptoModule::prepareLegacyDecryptedState( data, messageType); diff --git a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp 
b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp --- a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp +++ b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp @@ -84,6 +84,9 @@ static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { return static_cast(&turboModule)->initializeContentInboundSession(rt, args[0].asString(rt), args[1].asObject(rt), args[2].asString(rt), args[3].asNumber(), args[4].asBool()); } +static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeNotificationsOutboundSession(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { + return static_cast(&turboModule)->initializeNotificationsOutboundSession(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].isNull() || args[3].isUndefined() ? std::nullopt : std::make_optional(args[3].asString(rt)), args[4].asString(rt)); +} static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encrypt(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { return static_cast(&turboModule)->encrypt(rt, args[0].asString(rt), args[1].asString(rt)); } @@ -243,6 +246,7 @@ methodMap_["getKeyserverDataFromNotifStorage"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getKeyserverDataFromNotifStorage}; methodMap_["initializeContentOutboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentOutboundSession}; methodMap_["initializeContentInboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession}; + methodMap_["initializeNotificationsOutboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeNotificationsOutboundSession}; methodMap_["encrypt"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encrypt}; 
methodMap_["encryptAndPersist"] = MethodMetadata {3, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encryptAndPersist}; methodMap_["decrypt"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_decrypt}; diff --git a/native/cpp/CommonCpp/_generated/commJSI.h b/native/cpp/CommonCpp/_generated/commJSI.h --- a/native/cpp/CommonCpp/_generated/commJSI.h +++ b/native/cpp/CommonCpp/_generated/commJSI.h @@ -43,6 +43,7 @@ virtual jsi::Value getKeyserverDataFromNotifStorage(jsi::Runtime &rt, jsi::Array keyserverIDs) = 0; virtual jsi::Value initializeContentOutboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, std::optional oneTimeKey, jsi::String deviceID) = 0; virtual jsi::Value initializeContentInboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedContent, jsi::String deviceID, double sessionVersion, bool overwrite) = 0; + virtual jsi::Value initializeNotificationsOutboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, std::optional oneTimeKey, jsi::String deviceID) = 0; virtual jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) = 0; virtual jsi::Value encryptAndPersist(jsi::Runtime &rt, jsi::String message, jsi::String deviceID, jsi::String messageID) = 0; virtual jsi::Value decrypt(jsi::Runtime &rt, jsi::Object encryptedData, jsi::String deviceID) = 0; @@ -291,6 +292,14 @@ return bridging::callFromJs( rt, &T::initializeContentInboundSession, jsInvoker_, instance_, std::move(identityKeys), std::move(encryptedContent), std::move(deviceID), std::move(sessionVersion), std::move(overwrite)); } + jsi::Value initializeNotificationsOutboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, std::optional oneTimeKey, jsi::String deviceID) override { + static_assert( + bridging::getParameterCount(&T::initializeNotificationsOutboundSession) == 6, + "Expected 
initializeNotificationsOutboundSession(...) to have 6 parameters"); + + return bridging::callFromJs( + rt, &T::initializeNotificationsOutboundSession, jsInvoker_, instance_, std::move(identityKeys), std::move(prekey), std::move(prekeySignature), std::move(oneTimeKey), std::move(deviceID)); + } jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) override { static_assert( bridging::getParameterCount(&T::encrypt) == 3, diff --git a/native/schema/CommCoreModuleSchema.js b/native/schema/CommCoreModuleSchema.js --- a/native/schema/CommCoreModuleSchema.js +++ b/native/schema/CommCoreModuleSchema.js @@ -91,6 +91,13 @@ sessionVersion: number, overwrite: boolean, ) => Promise; + +initializeNotificationsOutboundSession: ( + identityKeys: string, + prekey: string, + prekeySignature: string, + oneTimeKey: ?string, + deviceID: string, + ) => Promise; +encrypt: (message: string, deviceID: string) => Promise; +encryptAndPersist: ( message: string,