# Patch summary (reviewer notes). Text before the first `diff --git` header is
# ignored by `git apply` / `patch`.
#
# CommCoreModule: adds initializeNotificationsOutboundSession(identityKeys,
#   prekey, prekeySignature, oneTimeKey?, deviceID).
#   - JS thread: converts the jsi::String arguments to std::string.
#   - cryptoThread, in order:
#       notifsCryptoModule->initializeOutboundForSendingSession(...);
#       encrypt(NotificationsCryptoModule::initialEncryptedMessageContent);
#       copy the session to MMKV via persistPeerNotificationsSession;
#       notifsCryptoModule->removeSessionByDeviceId(deviceID);
#       persistCryptoModules(false, true).
#   - The promise resolves with { message, messageType } or rejects with the
#     caught exception's what().
#
# NotificationsCryptoModule:
#   - Peer (device) sessions are stored under "DEVICE.<deviceID>.NOTIFS_SESSION".
#     Keyserver sessions stay under "KEYSERVER.<keyserverID>.NOTIFS_SESSION".
#   - persistNotificationsSessionInternal and fetchNotificationsSession take an
#     isKeyserverSession flag; all existing keyserver call sites pass `true`.
#   - Adds persistPeerNotificationsSession and
#     isPeerNotificationsSessionInitialized.
#
# Also updates the generated JSI bindings (_generated/) and
# CommCoreModuleSchema.js with the new 5-argument method.
#
# NOTE(review): This copy of the patch is collapsed onto a few long lines, and
#   template arguments appear stripped (e.g. `std::optional oneTimeKey`,
#   `std::shared_ptr promise`, `=> Promise;`). It will not apply or compile
#   as-is; regenerate it from the source tree instead of hand-editing it.
# NOTE(review): fetchNotificationsSession only throws on a missing session when
#   isKeyserverSession is true. A missing peer session yields std::nullopt;
#   confirm that callers expect this.
# NOTE(review): persistNotificationsSessionInternal builds
#   persistenceErrorMessage on every call, including successful ones. It could
#   be built only when CommMMKV::setString returns false.
# NOTE(review): If persistCryptoModules throws, the JS promise is rejected even
#   though the session was already written to MMKV and removed from
#   notifsCryptoModule. Confirm this is intended.
# NOTE(review): The _generated/ files presumably come from codegen run against
#   CommCoreModuleSchema.js. Verify they were regenerated, not hand-edited.
diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h @@ -130,6 +130,13 @@ jsi::String deviceID, double sessionVersion, bool overwrite) override; + virtual jsi::Value initializeNotificationsOutboundSession( + jsi::Runtime &rt, + jsi::String identityKeys, + jsi::String prekey, + jsi::String prekeySignature, + std::optional oneTimeKey, + jsi::String deviceID) override; virtual jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) override; virtual jsi::Value encryptAndPersist( diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp @@ -1439,6 +1439,77 @@ }); } +jsi::Value CommCoreModule::initializeNotificationsOutboundSession( + jsi::Runtime &rt, + jsi::String identityKeys, + jsi::String prekey, + jsi::String prekeySignature, + std::optional oneTimeKey, + jsi::String deviceID) { + auto identityKeysCpp{identityKeys.utf8(rt)}; + auto prekeyCpp{prekey.utf8(rt)}; + auto prekeySignatureCpp{prekeySignature.utf8(rt)}; + auto deviceIDCpp{deviceID.utf8(rt)}; + + std::optional oneTimeKeyCpp; + if (oneTimeKey) { + oneTimeKeyCpp = oneTimeKey->utf8(rt); + } + + return createPromiseAsJSIValue( + rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { + taskType job = [=, &innerRt]() { + std::string error; + crypto::EncryptedData result; + try { + std::optional oneTimeKeyBuffer; + if (oneTimeKeyCpp) { + oneTimeKeyBuffer = crypto::OlmBuffer( + oneTimeKeyCpp->begin(), oneTimeKeyCpp->end()); + } + this->notifsCryptoModule->initializeOutboundForSendingSession( + deviceIDCpp, + std::vector( + identityKeysCpp.begin(), identityKeysCpp.end()), + std::vector(prekeyCpp.begin(), prekeyCpp.end()),
std::vector( + prekeySignatureCpp.begin(), prekeySignatureCpp.end()), + oneTimeKeyBuffer); + + result = this->notifsCryptoModule->encrypt( + deviceIDCpp, + NotificationsCryptoModule::initialEncryptedMessageContent); + + std::shared_ptr peerNotificationsSession = + this->notifsCryptoModule->getSessionByDeviceId(deviceIDCpp); + + NotificationsCryptoModule::persistPeerNotificationsSession( + deviceIDCpp, peerNotificationsSession); + + this->notifsCryptoModule->removeSessionByDeviceId(deviceIDCpp); + this->persistCryptoModules(false, true); + } catch (const std::exception &e) { + error = e.what(); + } + this->jsInvoker_->invokeAsync([=, &innerRt]() { + if (error.size()) { + promise->reject(error); + return; + } + auto initialEncryptedDataJSI = jsi::Object(innerRt); + auto message = + std::string{result.message.begin(), result.message.end()}; + auto messageJSI = jsi::String::createFromUtf8(innerRt, message); + initialEncryptedDataJSI.setProperty(innerRt, "message", messageJSI); + initialEncryptedDataJSI.setProperty( + innerRt, "messageType", static_cast(result.messageType)); + promise->resolve(std::move(initialEncryptedDataJSI)); + }); + }; + this->cryptoThread->scheduleTask(job); + }); +} + jsi::Value CommCoreModule::encrypt( jsi::Runtime &rt, jsi::String message, diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h @@ -23,17 +23,22 @@ static std::string getKeyserverNotificationsSessionKey(const std::string &keyserverID); + static std::string + getPeerNotificationsSessionKey(const std::string &deviceID); static std::string serializeNotificationsSession( std::shared_ptr session, std::string picklingKey); static std::pair, std::string>
deserializeNotificationsSession(const std::string &serializedSession); static void persistNotificationsSessionInternal( - const std::string &keyserverID, + bool isKeyserverSession, + const std::string &senderID, const std::string &picklingKey, std::shared_ptr session); static std::optional, std::string>> - fetchNotificationsSession(const std::string &keyserverID); + fetchNotificationsSession( + bool isKeyserverSession, + const std::string &senderID); public: const static std::string initialEncryptedMessageContent; @@ -43,7 +48,12 @@ static void persistNotificationsSession( const std::string &keyserverID, std::shared_ptr keyserverNotificationsSession); + static void persistPeerNotificationsSession( + const std::string &deviceID, + std::shared_ptr peerNotificationsSession); static bool isNotificationsSessionInitialized(const std::string &keyserverID); + static bool + isPeerNotificationsSessionInitialized(const std::string &deviceID); class BaseStatefulDecryptResult { BaseStatefulDecryptResult( diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp @@ -151,6 +151,11 @@ return "KEYSERVER." + keyserverID + ".NOTIFS_SESSION"; } +std::string NotificationsCryptoModule::getPeerNotificationsSessionKey( + const std::string &deviceID) { + return "DEVICE."
+ deviceID + ".NOTIFS_SESSION"; +} + std::string NotificationsCryptoModule::serializeNotificationsSession( std::shared_ptr session, std::string picklingKey) { @@ -194,45 +199,65 @@ } void NotificationsCryptoModule::persistNotificationsSessionInternal( - const std::string &keyserverID, + bool isKeyserverSession, + const std::string &senderID, const std::string &picklingKey, std::shared_ptr session) { std::string serializedSession = NotificationsCryptoModule::serializeNotificationsSession( session, picklingKey); - std::string keyserverNotificationsSessionKey = - NotificationsCryptoModule::getKeyserverNotificationsSessionKey( - keyserverID); + + std::string notificationsSessionKey; + std::string persistenceErrorMessage; + + if (isKeyserverSession) { + notificationsSessionKey = + NotificationsCryptoModule::getKeyserverNotificationsSessionKey( + senderID); + persistenceErrorMessage = + "Failed to persist to MMKV notifications session for keyserver: " + + senderID; + } else { + notificationsSessionKey = + NotificationsCryptoModule::getPeerNotificationsSessionKey(senderID); + persistenceErrorMessage = + "Failed to persist to MMKV notifications session for device: " + + senderID; + } bool sessionStored = - CommMMKV::setString(keyserverNotificationsSessionKey, serializedSession); + CommMMKV::setString(notificationsSessionKey, serializedSession); if (!sessionStored) { - throw std::runtime_error( - "Failed to persist to MMKV notifications session for keyserver: " + - keyserverID); + throw std::runtime_error(persistenceErrorMessage); } } std::optional, std::string>> NotificationsCryptoModule::fetchNotificationsSession( - const std::string &keyserverID) { - std::string keyserverNotificationsSessionKey = - NotificationsCryptoModule::getKeyserverNotificationsSessionKey( - keyserverID); + bool isKeyserverSession, + const std::string &senderID) { + std::string notificationsSessionKey; + if (isKeyserverSession) { + notificationsSessionKey =
NotificationsCryptoModule::getKeyserverNotificationsSessionKey( + senderID); + } else { + notificationsSessionKey = + NotificationsCryptoModule::getPeerNotificationsSessionKey(senderID); + } std::optional serializedSession; try { - serializedSession = CommMMKV::getString(keyserverNotificationsSessionKey); + serializedSession = CommMMKV::getString(notificationsSessionKey); } catch (const CommMMKV::InitFromNSEForbiddenError &e) { serializedSession = std::nullopt; } - if (!serializedSession.has_value() && - keyserverID != - ashoatKeyserverIDUsedOnlyForMigrationFromLegacyNotifStorage) { + if (!serializedSession.has_value() && isKeyserverSession && + senderID != ashoatKeyserverIDUsedOnlyForMigrationFromLegacyNotifStorage) { throw std::runtime_error( - "Missing notifications session for keyserver: " + keyserverID); + "Missing notifications session for keyserver: " + senderID); } else if (!serializedSession.has_value()) { return std::nullopt; } @@ -246,16 +271,31 @@ std::shared_ptr keyserverNotificationsSession) { std::string picklingKey = crypto::Tools::generateRandomString(64); NotificationsCryptoModule::persistNotificationsSessionInternal( - keyserverID, picklingKey, keyserverNotificationsSession); + true, keyserverID, picklingKey, keyserverNotificationsSession); +} + +void NotificationsCryptoModule::persistPeerNotificationsSession( + const std::string &deviceID, + std::shared_ptr peerNotificationsSession) { + std::string picklingKey = crypto::Tools::generateRandomString(64); + NotificationsCryptoModule::persistNotificationsSessionInternal( + false, deviceID, picklingKey, peerNotificationsSession); } bool NotificationsCryptoModule::isNotificationsSessionInitialized( const std::string &keyserverID) { std::string keyserverNotificationsSessionKey = - "KEYSERVER."
+ keyserverID + ".NOTIFS_SESSION"; + getKeyserverNotificationsSessionKey(keyserverID); return CommMMKV::getString(keyserverNotificationsSessionKey).has_value(); } +bool NotificationsCryptoModule::isPeerNotificationsSessionInitialized( + const std::string &deviceID) { + std::string peerNotificationsSessionKey = + getPeerNotificationsSessionKey(deviceID); + return CommMMKV::getString(peerNotificationsSessionKey).has_value(); +} + NotificationsCryptoModule::BaseStatefulDecryptResult::BaseStatefulDecryptResult( std::string picklingKey, std::string decryptedData) @@ -280,7 +320,10 @@ void NotificationsCryptoModule::StatefulDecryptResult::flushState() { NotificationsCryptoModule::persistNotificationsSessionInternal( - this->keyserverID, this->picklingKey, std::move(this->sessionState)); + true, + this->keyserverID, + this->picklingKey, + std::move(this->sessionState)); } NotificationsCryptoModule::LegacyStatefulDecryptResult:: @@ -349,7 +392,7 @@ const size_t messageType) { auto sessionWithPicklingKey = - NotificationsCryptoModule::fetchNotificationsSession(keyserverID); + NotificationsCryptoModule::fetchNotificationsSession(true, keyserverID); if (!sessionWithPicklingKey.has_value()) { auto statefulDecryptResult = NotificationsCryptoModule::prepareLegacyDecryptedState( @@ -366,7 +409,7 @@ std::string decryptedData = session->decrypt(encryptedData); NotificationsCryptoModule::persistNotificationsSessionInternal( - keyserverID, picklingKey, std::move(session)); + true, keyserverID, picklingKey, std::move(session)); return decryptedData; } @@ -377,7 +420,7 @@ const size_t messageType) { auto sessionWithPicklingKey = - NotificationsCryptoModule::fetchNotificationsSession(keyserverID); + NotificationsCryptoModule::fetchNotificationsSession(true, keyserverID); if (!sessionWithPicklingKey.has_value()) { return NotificationsCryptoModule::prepareLegacyDecryptedState( data, messageType); diff --git a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp
b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp --- a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp +++ b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp @@ -84,6 +84,9 @@ static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { return static_cast(&turboModule)->initializeContentInboundSession(rt, args[0].asString(rt), args[1].asObject(rt), args[2].asString(rt), args[3].asNumber(), args[4].asBool()); } +static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeNotificationsOutboundSession(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { + return static_cast(&turboModule)->initializeNotificationsOutboundSession(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].isNull() || args[3].isUndefined() ? std::nullopt : std::make_optional(args[3].asString(rt)), args[4].asString(rt)); +} static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encrypt(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { return static_cast(&turboModule)->encrypt(rt, args[0].asString(rt), args[1].asString(rt)); } @@ -234,6 +237,7 @@ methodMap_["getKeyserverDataFromNotifStorage"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getKeyserverDataFromNotifStorage}; methodMap_["initializeContentOutboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentOutboundSession}; methodMap_["initializeContentInboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession}; + methodMap_["initializeNotificationsOutboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeNotificationsOutboundSession}; methodMap_["encrypt"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encrypt};
methodMap_["encryptAndPersist"] = MethodMetadata {3, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encryptAndPersist}; methodMap_["decrypt"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_decrypt}; diff --git a/native/cpp/CommonCpp/_generated/commJSI.h b/native/cpp/CommonCpp/_generated/commJSI.h --- a/native/cpp/CommonCpp/_generated/commJSI.h +++ b/native/cpp/CommonCpp/_generated/commJSI.h @@ -43,6 +43,7 @@ virtual jsi::Value getKeyserverDataFromNotifStorage(jsi::Runtime &rt, jsi::Array keyserverIDs) = 0; virtual jsi::Value initializeContentOutboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, std::optional oneTimeKey, jsi::String deviceID) = 0; virtual jsi::Value initializeContentInboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedContent, jsi::String deviceID, double sessionVersion, bool overwrite) = 0; + virtual jsi::Value initializeNotificationsOutboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, std::optional oneTimeKey, jsi::String deviceID) = 0; virtual jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) = 0; virtual jsi::Value encryptAndPersist(jsi::Runtime &rt, jsi::String message, jsi::String deviceID, jsi::String messageID) = 0; virtual jsi::Value decrypt(jsi::Runtime &rt, jsi::Object encryptedData, jsi::String deviceID) = 0; @@ -288,6 +289,14 @@ return bridging::callFromJs( rt, &T::initializeContentInboundSession, jsInvoker_, instance_, std::move(identityKeys), std::move(encryptedContent), std::move(deviceID), std::move(sessionVersion), std::move(overwrite)); } + jsi::Value initializeNotificationsOutboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, std::optional oneTimeKey, jsi::String deviceID) override { + static_assert( + bridging::getParameterCount(&T::initializeNotificationsOutboundSession) == 6, + "Expected
initializeNotificationsOutboundSession(...) to have 6 parameters"); + + return bridging::callFromJs( + rt, &T::initializeNotificationsOutboundSession, jsInvoker_, instance_, std::move(identityKeys), std::move(prekey), std::move(prekeySignature), std::move(oneTimeKey), std::move(deviceID)); + } jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) override { static_assert( bridging::getParameterCount(&T::encrypt) == 3, diff --git a/native/schema/CommCoreModuleSchema.js b/native/schema/CommCoreModuleSchema.js --- a/native/schema/CommCoreModuleSchema.js +++ b/native/schema/CommCoreModuleSchema.js @@ -91,6 +91,13 @@ sessionVersion: number, overwrite: boolean, ) => Promise; + +initializeNotificationsOutboundSession: ( + identityKeys: string, + prekey: string, + prekeySignature: string, + oneTimeKey: ?string, + deviceID: string, + ) => Promise; +encrypt: (message: string, deviceID: string) => Promise; +encryptAndPersist: ( message: string,