diff --git a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp
--- a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp
+++ b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp
@@ -383,26 +383,7 @@
   if (!this->hasSessionFor(targetDeviceId)) {
     throw std::runtime_error{"error encrypt => uninitialized session"};
   }
-  OlmSession *session = this->sessions.at(targetDeviceId)->getOlmSession();
-  OlmBuffer encryptedMessage(
-      ::olm_encrypt_message_length(session, content.size()));
-  OlmBuffer messageRandom;
-  PlatformSpecificTools::generateSecureRandomBytes(
-      messageRandom, ::olm_encrypt_random_length(session));
-  size_t messageType = ::olm_encrypt_message_type(session);
-  if (-1 ==
-      ::olm_encrypt(
-          session,
-          (uint8_t *)content.data(),
-          content.size(),
-          messageRandom.data(),
-          messageRandom.size(),
-          encryptedMessage.data(),
-          encryptedMessage.size())) {
-    throw std::runtime_error{
-        "error encrypt => " + std::string{::olm_session_last_error(session)}};
-  }
-  return {encryptedMessage, messageType};
+  return this->sessions.at(targetDeviceId)->encrypt(content);
 }
 
 std::string CryptoModule::decrypt(
diff --git a/native/cpp/CommonCpp/CryptoTools/Session.h b/native/cpp/CommonCpp/CryptoTools/Session.h
--- a/native/cpp/CommonCpp/CryptoTools/Session.h
+++ b/native/cpp/CommonCpp/CryptoTools/Session.h
@@ -33,6 +33,9 @@
   restoreFromB64(const std::string &secretKey, OlmBuffer &b64);
   OlmSession *getOlmSession();
   std::string decrypt(EncryptedData &encryptedData);
+  // Encrypts `content` with this olm session; throws std::runtime_error on
+  // olm failure. Mutates the session, so stored copies must be re-persisted.
+  EncryptedData encrypt(const std::string &content);
   int getVersion();
   void setVersion(int newVersion);
 };
diff --git a/native/cpp/CommonCpp/CryptoTools/Session.cpp b/native/cpp/CommonCpp/CryptoTools/Session.cpp
--- a/native/cpp/CommonCpp/CryptoTools/Session.cpp
+++ b/native/cpp/CommonCpp/CryptoTools/Session.cpp
@@ -179,6 +179,40 @@
   return std::string{(char *)decryptedMessage.data(), decryptedSize};
 }
 
+// Encrypts `content` with this session. olm_encrypt() mutates the underlying
+// OlmSession, so callers that keep sessions in storage must persist the
+// session afterwards. Throws std::runtime_error if olm reports an error.
+EncryptedData Session::encrypt(const std::string &content) {
+  OlmSession *session = this->getOlmSession();
+  OlmBuffer encryptedMessage(
+      ::olm_encrypt_message_length(session, content.size()));
+  OlmBuffer messageRandom;
+  PlatformSpecificTools::generateSecureRandomBytes(
+      messageRandom, ::olm_encrypt_random_length(session));
+  // The message type describes the *next* message olm_encrypt() produces,
+  // so it has to be queried before encrypting.
+  const size_t messageType = ::olm_encrypt_message_type(session);
+  if (messageType == ::olm_error()) {
+    throw std::runtime_error{
+        "error encrypt => " + std::string{::olm_session_last_error(session)}};
+  }
+  // olm reports failure via the size_t olm_error() sentinel; compare against
+  // it rather than a signed -1 to avoid an implicit sign conversion.
+  if (::olm_error() ==
+      ::olm_encrypt(
+          session,
+          content.data(),
+          content.size(),
+          messageRandom.data(),
+          messageRandom.size(),
+          encryptedMessage.data(),
+          encryptedMessage.size())) {
+    throw std::runtime_error{
+        "error encrypt => " + std::string{::olm_session_last_error(session)}};
+  }
+  return {encryptedMessage, messageType};
+}
+
 int Session::getVersion() {
   return this->version;
 }
diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h
--- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h
+++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h
@@ -142,6 +142,10 @@
       jsi::String deviceID) override;
   virtual jsi::Value
   encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) override;
+  virtual jsi::Value encryptNotification(
+      jsi::Runtime &rt,
+      jsi::String payload,
+      jsi::String deviceID) override;
   virtual jsi::Value encryptAndPersist(
       jsi::Runtime &rt,
       jsi::String message,
diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
--- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
+++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
@@ -1543,16 +1543,40 @@
               promise->reject(error);
               return;
             }
-            auto encryptedDataJSI = jsi::Object(innerRt);
-            auto message = std::string{
-                encryptedMessage.message.begin(),
-                encryptedMessage.message.end()};
-            auto messageJSI = jsi::String::createFromUtf8(innerRt, message);
-            encryptedDataJSI.setProperty(innerRt, "message", messageJSI);
-            encryptedDataJSI.setProperty(
-                innerRt,
-                "messageType",
-                static_cast<double>(encryptedMessage.messageType));
+            auto encryptedDataJSI =
+                parseEncryptedData(innerRt, encryptedMessage);
+            promise->resolve(std::move(encryptedDataJSI));
+          });
+        };
+        this->cryptoThread->scheduleTask(job);
+      });
+}
+
+jsi::Value CommCoreModule::encryptNotification(
+    jsi::Runtime &rt,
+    jsi::String payload,
+    jsi::String deviceID) {
+  auto payloadCpp{payload.utf8(rt)};
+  auto deviceIDCpp{deviceID.utf8(rt)};
+
+  return createPromiseAsJSIValue(
+      rt, [=](jsi::Runtime &innerRt, std::shared_ptr<Promise> promise) {
+        taskType job = [=, &innerRt]() {
+          std::string error;
+          crypto::EncryptedData result;
+          try {
+            result =
+                NotificationsCryptoModule::encrypt(deviceIDCpp, payloadCpp);
+          } catch (const std::exception &e) {
+            error = e.what();
+          }
+
+          this->jsInvoker_->invokeAsync([=, &innerRt]() {
+            if (error.size()) {
+              promise->reject(error);
+              return;
+            }
+            auto encryptedDataJSI = parseEncryptedData(innerRt, result);
             promise->resolve(std::move(encryptedDataJSI));
           });
         };
@@ -1623,16 +1647,8 @@
               promise->reject(error);
               return;
             }
-            auto encryptedDataJSI = jsi::Object(innerRt);
-            auto message = std::string{
-                encryptedMessage.message.begin(),
-                encryptedMessage.message.end()};
-            auto messageJSI = jsi::String::createFromUtf8(innerRt, message);
-            encryptedDataJSI.setProperty(innerRt, "message", messageJSI);
-            encryptedDataJSI.setProperty(
-                innerRt,
-                "messageType",
-                static_cast<double>(encryptedMessage.messageType));
+            auto encryptedDataJSI =
+                parseEncryptedData(innerRt, encryptedMessage);
             promise->resolve(std::move(encryptedDataJSI));
           });
         };
diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h
--- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h
+++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h
@@ -116,6 +116,11 @@
       const std::string &data,
       const size_t messageType);
 
+  // Encrypts `payload` with the notifications session stored for `deviceID`
+  // and persists the updated session. Throws if no session is stored.
+  static crypto::EncryptedData
+  encrypt(const std::string &deviceID, const std::string &payload);
+
   static void
   flushState(std::unique_ptr<StatefulDecryptResult> statefulDecryptResult);
 };
diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp
--- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp
+++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp
@@ -413,6 +413,27 @@
   return decryptedData;
 }
 
+// Encrypts `payload` with the notifications olm session stored for
+// `deviceID`, then persists the session so the advanced olm state is not
+// lost. Throws std::runtime_error if no such session has been stored.
+crypto::EncryptedData NotificationsCryptoModule::encrypt(
+    const std::string &deviceID,
+    const std::string &payload) {
+  auto sessionWithPicklingKey =
+      NotificationsCryptoModule::fetchNotificationsSession(false, deviceID);
+  if (!sessionWithPicklingKey.has_value()) {
+    throw std::runtime_error(
+        "Session with deviceID: " + deviceID + " not initialized.");
+  }
+  std::unique_ptr<crypto::Session> session =
+      std::move(sessionWithPicklingKey.value().first);
+  std::string picklingKey = sessionWithPicklingKey.value().second;
+  crypto::EncryptedData encryptedData = session->encrypt(payload);
+  NotificationsCryptoModule::persistNotificationsSessionInternal(
+      false, deviceID, picklingKey, std::move(session));
+  return encryptedData;
+}
+
 std::unique_ptr<NotificationsCryptoModule::StatefulDecryptResult>
 NotificationsCryptoModule::statefulDecrypt(
     const std::string &keyserverID,
diff --git a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp
--- a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp
+++ b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp
@@ -90,6 +90,9 @@
 static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encrypt(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
   return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->encrypt(rt, args[0].asString(rt), args[1].asString(rt));
 }
+static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encryptNotification(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
+  return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->encryptNotification(rt, args[0].asString(rt), args[1].asString(rt));
+}
 static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encryptAndPersist(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
   return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->encryptAndPersist(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt));
 }
@@ -248,6 +251,7 @@
   methodMap_["initializeContentInboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession};
   methodMap_["initializeNotificationsOutboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeNotificationsOutboundSession};
   methodMap_["encrypt"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encrypt};
+  methodMap_["encryptNotification"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encryptNotification};
   methodMap_["encryptAndPersist"] = MethodMetadata {3, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encryptAndPersist};
   methodMap_["decrypt"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_decrypt};
   methodMap_["decryptAndPersist"] = MethodMetadata {3, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_decryptAndPersist};
diff --git a/native/cpp/CommonCpp/_generated/commJSI.h b/native/cpp/CommonCpp/_generated/commJSI.h
--- a/native/cpp/CommonCpp/_generated/commJSI.h
+++ b/native/cpp/CommonCpp/_generated/commJSI.h
@@ -45,6 +45,7 @@
   virtual jsi::Value initializeContentInboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedContent, jsi::String deviceID, double sessionVersion, bool overwrite) = 0;
   virtual jsi::Value initializeNotificationsOutboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, std::optional<jsi::String> oneTimeKey, jsi::String deviceID) = 0;
   virtual jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) = 0;
+  virtual jsi::Value encryptNotification(jsi::Runtime &rt, jsi::String payload, jsi::String deviceID) = 0;
   virtual jsi::Value encryptAndPersist(jsi::Runtime &rt, jsi::String message, jsi::String deviceID, jsi::String messageID) = 0;
   virtual jsi::Value decrypt(jsi::Runtime &rt, jsi::Object encryptedData, jsi::String deviceID) = 0;
   virtual jsi::Value decryptAndPersist(jsi::Runtime &rt, jsi::Object encryptedData, jsi::String deviceID, jsi::String messageID) = 0;
@@ -308,6 +309,14 @@
       return bridging::callFromJs<jsi::Value>(
           rt, &T::encrypt, jsInvoker_, instance_, std::move(message), std::move(deviceID));
     }
+    jsi::Value encryptNotification(jsi::Runtime &rt, jsi::String payload, jsi::String deviceID) override {
+      static_assert(
+          bridging::getParameterCount(&T::encryptNotification) == 3,
+          "Expected encryptNotification(...) to have 3 parameters");
+
+      return bridging::callFromJs<jsi::Value>(
+          rt, &T::encryptNotification, jsInvoker_, instance_, std::move(payload), std::move(deviceID));
+    }
     jsi::Value encryptAndPersist(jsi::Runtime &rt, jsi::String message, jsi::String deviceID, jsi::String messageID) override {
       static_assert(
           bridging::getParameterCount(&T::encryptAndPersist) == 4,
diff --git a/native/push/encrypted-notif-utils-api.js b/native/push/encrypted-notif-utils-api.js
--- a/native/push/encrypted-notif-utils-api.js
+++ b/native/push/encrypted-notif-utils-api.js
@@ -2,7 +2,7 @@
 
 import type { EncryptedNotifUtilsAPI } from 'lib/types/notif-types.js';
 
-import { commUtilsModule } from '../native-modules.js';
+import { commUtilsModule, commCoreModule } from '../native-modules.js';
 
 const encryptedNotifUtilsAPI: EncryptedNotifUtilsAPI = {
   encryptSerializedNotifPayload: async (
@@ -13,16 +13,12 @@
       type: '1' | '0',
     ) => boolean,
   ) => {
-    // The "mock" implementation below will be replaced with proper
-    // implementation after olm notif sessions initialization is
-    // implemented. for now it is actually beneficial to return
-    // original string as encrypted string since it allows for
-    // better testing as we can verify which data are encrypted
-    // and which aren't.
+    const { message: body, messageType: type } =
+      await commCoreModule.encryptNotification(unencryptedPayload, cryptoID);
     return {
-      encryptedData: { body: unencryptedPayload, type: 1 },
+      encryptedData: { body, type },
       sizeLimitViolated: encryptedPayloadSizeValidator
-        ? !encryptedPayloadSizeValidator(unencryptedPayload, '1')
+        ? !encryptedPayloadSizeValidator(body, type ? '1' : '0')
         : false,
     };
   },
diff --git a/native/schema/CommCoreModuleSchema.js b/native/schema/CommCoreModuleSchema.js
--- a/native/schema/CommCoreModuleSchema.js
+++ b/native/schema/CommCoreModuleSchema.js
@@ -99,6 +99,10 @@
     deviceID: string,
   ) => Promise<EncryptedData>;
   +encrypt: (message: string, deviceID: string) => Promise<EncryptedData>;
+  +encryptNotification: (
+    payload: string,
+    deviceID: string,
+  ) => Promise<EncryptedData>;
   +encryptAndPersist: (
     message: string,
     deviceID: string,