diff --git a/lib/types/crypto-types.js b/lib/types/crypto-types.js
--- a/lib/types/crypto-types.js
+++ b/lib/types/crypto-types.js
@@ -167,6 +167,7 @@
     contentIdentityKeys: OLMIdentityKeys,
     contentInitializationInfo: OlmSessionInitializationInfo,
   ) => Promise,
+  +isContentSessionInitialized: (deviceID: string) => Promise<boolean>,
   +notificationsSessionCreator: (
     cookie: ?string,
     notificationsIdentityKeys: OLMIdentityKeys,
@@ -178,6 +179,9 @@
     contentIdentityKeys: OLMIdentityKeys,
     notificationsInitializationInfo: OlmSessionInitializationInfo,
   ) => Promise,
+  +isPeerNotificationsSessionInitialized: (
+    deviceID: string,
+  ) => Promise<boolean>,
   +reassignNotificationsSession?: (
     prevCookie: ?string,
     newCookie: ?string,
diff --git a/lib/utils/__mocks__/config.js b/lib/utils/__mocks__/config.js
--- a/lib/utils/__mocks__/config.js
+++ b/lib/utils/__mocks__/config.js
@@ -23,6 +23,8 @@
     contentOutboundSessionCreator: jest.fn(),
     notificationsSessionCreator: jest.fn(),
     notificationsOutboundSessionCreator: jest.fn(),
+    isContentSessionInitialized: jest.fn(),
+    isPeerNotificationsSessionInitialized: jest.fn(),
     getOneTimeKeys: jest.fn(),
     validateAndUploadPrekeys: jest.fn(),
     signMessage: jest.fn(),
diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h
--- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h
+++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h
@@ -107,6 +107,9 @@
       jsi::String keyserverID) override;
   virtual jsi::Value
   isNotificationsSessionInitialized(jsi::Runtime &rt) override;
+  virtual jsi::Value isPeerNotificationsSessionInitialized(
+      jsi::Runtime &rt,
+      jsi::String deviceID) override;
   virtual jsi::Value updateKeyserverDataInNotifStorage(
       jsi::Runtime &rt,
       jsi::Array keyserversData) override;
@@ -130,6 +133,8 @@
       jsi::String deviceID,
       double sessionVersion,
       bool overwrite) override;
+  virtual jsi::Value
+  isContentSessionInitialized(jsi::Runtime &rt, jsi::String deviceID) override;
   virtual jsi::Value initializeNotificationsOutboundSession(
       jsi::Runtime &rt,
       jsi::String identityKeys,
diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
--- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
+++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
@@ -1184,6 +1184,37 @@
       });
 }
 
+// Resolves to whether NotificationsCryptoModule reports a notifications
+// session with peer device `deviceID`. The lookup runs on the crypto thread;
+// if it throws, the promise is rejected with the exception message.
+jsi::Value CommCoreModule::isPeerNotificationsSessionInitialized(
+    jsi::Runtime &rt,
+    jsi::String deviceID) {
+  auto deviceIDCpp{deviceID.utf8(rt)};
+  return createPromiseAsJSIValue(
+      rt, [=](jsi::Runtime &innerRt, std::shared_ptr<Promise> promise) {
+        taskType job = [=, &innerRt]() {
+          std::string error;
+          // Initialized: copied into the lambda below even on the error path.
+          bool result = false;
+          try {
+            result = NotificationsCryptoModule::
+                isPeerNotificationsSessionInitialized(deviceIDCpp);
+          } catch (const std::exception &e) {
+            error = e.what();
+          }
+          this->jsInvoker_->invokeAsync([=, &innerRt]() {
+            if (error.size()) {
+              promise->reject(error);
+              return;
+            }
+            promise->resolve(result);
+          });
+        };
+        this->cryptoThread->scheduleTask(job);
+      });
+}
+
 jsi::Value CommCoreModule::updateKeyserverDataInNotifStorage(
     jsi::Runtime &rt,
     jsi::Array keyserversData) {
@@ -1439,6 +1470,46 @@
       });
 }
 
+// Resolves to whether the content crypto module has a session for `deviceID`.
+// Rejects with "user has not been initialized" if either crypto module is
+// missing, or with the exception message if the lookup throws.
+jsi::Value CommCoreModule::isContentSessionInitialized(
+    jsi::Runtime &rt,
+    jsi::String deviceID) {
+  auto deviceIDCpp{deviceID.utf8(rt)};
+  return createPromiseAsJSIValue(
+      rt, [=](jsi::Runtime &innerRt, std::shared_ptr<Promise> promise) {
+        taskType job = [=, &innerRt]() {
+          std::string error;
+          // Initialized: copied into the lambda below even on the error path.
+          bool result = false;
+
+          if (this->contentCryptoModule == nullptr ||
+              this->notifsCryptoModule == nullptr) {
+            this->jsInvoker_->invokeAsync([=, &innerRt]() {
+              promise->reject("user has not been initialized");
+            });
+            return;
+          }
+
+          try {
+            result = this->contentCryptoModule->hasSessionFor(deviceIDCpp);
+          } catch (const std::exception &e) {
+            error = e.what();
+          }
+
+          this->jsInvoker_->invokeAsync([=, &innerRt]() {
+            if (error.size()) {
+              promise->reject(error);
+              return;
+            }
+            promise->resolve(result);
+          });
+        };
+        this->cryptoThread->scheduleTask(job);
+      });
+}
+
 jsi::Value CommCoreModule::initializeNotificationsOutboundSession(
     jsi::Runtime &rt,
     jsi::String identityKeys,
diff --git a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp
--- a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp
+++ b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp
@@ -69,6 +69,9 @@
 static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_isNotificationsSessionInitialized(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
   return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->isNotificationsSessionInitialized(rt);
 }
+static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_isPeerNotificationsSessionInitialized(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
+  return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->isPeerNotificationsSessionInitialized(rt, args[0].asString(rt));
+}
 static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_updateKeyserverDataInNotifStorage(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
   return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->updateKeyserverDataInNotifStorage(rt, args[0].asObject(rt).asArray(rt));
 }
@@ -84,6 +87,9 @@
 static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
   return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->initializeContentInboundSession(rt, args[0].asString(rt), args[1].asObject(rt), args[2].asString(rt), args[3].asNumber(), args[4].asBool());
 }
+static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_isContentSessionInitialized(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
+  return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->isContentSessionInitialized(rt, args[0].asString(rt));
+}
 static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeNotificationsOutboundSession(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
   return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->initializeNotificationsOutboundSession(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].isNull() || args[3].isUndefined() ? std::nullopt : std::make_optional(args[3].asString(rt)), args[4].asString(rt));
 }
@@ -235,11 +241,13 @@
   methodMap_["validateAndUploadPrekeys"] = MethodMetadata {3, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_validateAndUploadPrekeys};
   methodMap_["initializeNotificationsSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeNotificationsSession};
   methodMap_["isNotificationsSessionInitialized"] = MethodMetadata {0, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_isNotificationsSessionInitialized};
+  methodMap_["isPeerNotificationsSessionInitialized"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_isPeerNotificationsSessionInitialized};
   methodMap_["updateKeyserverDataInNotifStorage"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_updateKeyserverDataInNotifStorage};
   methodMap_["removeKeyserverDataFromNotifStorage"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_removeKeyserverDataFromNotifStorage};
   methodMap_["getKeyserverDataFromNotifStorage"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getKeyserverDataFromNotifStorage};
   methodMap_["initializeContentOutboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentOutboundSession};
   methodMap_["initializeContentInboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession};
+  methodMap_["isContentSessionInitialized"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_isContentSessionInitialized};
   methodMap_["initializeNotificationsOutboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeNotificationsOutboundSession};
   methodMap_["encrypt"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encrypt};
   methodMap_["encryptNotification"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encryptNotification};
diff --git a/native/cpp/CommonCpp/_generated/commJSI.h b/native/cpp/CommonCpp/_generated/commJSI.h
--- a/native/cpp/CommonCpp/_generated/commJSI.h
+++ b/native/cpp/CommonCpp/_generated/commJSI.h
@@ -38,11 +38,13 @@
   virtual jsi::Value validateAndUploadPrekeys(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken) = 0;
   virtual jsi::Value initializeNotificationsSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, std::optional<jsi::String> oneTimeKey, jsi::String keyserverID) = 0;
   virtual jsi::Value isNotificationsSessionInitialized(jsi::Runtime &rt) = 0;
+  virtual jsi::Value isPeerNotificationsSessionInitialized(jsi::Runtime &rt, jsi::String deviceID) = 0;
   virtual jsi::Value updateKeyserverDataInNotifStorage(jsi::Runtime &rt, jsi::Array keyserversData) = 0;
   virtual jsi::Value removeKeyserverDataFromNotifStorage(jsi::Runtime &rt, jsi::Array keyserverIDsToDelete) = 0;
   virtual jsi::Value getKeyserverDataFromNotifStorage(jsi::Runtime &rt, jsi::Array keyserverIDs) = 0;
   virtual jsi::Value initializeContentOutboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, std::optional<jsi::String> oneTimeKey, jsi::String deviceID) = 0;
   virtual jsi::Value initializeContentInboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedContent, jsi::String deviceID, double sessionVersion, bool overwrite) = 0;
+  virtual jsi::Value isContentSessionInitialized(jsi::Runtime &rt, jsi::String deviceID) = 0;
   virtual jsi::Value initializeNotificationsOutboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, std::optional<jsi::String> oneTimeKey, jsi::String deviceID) = 0;
   virtual jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) = 0;
   virtual jsi::Value encryptNotification(jsi::Runtime &rt, jsi::String payload, jsi::String deviceID) = 0;
@@ -250,6 +252,14 @@
     return bridging::callFromJs<jsi::Value>(
         rt, &T::isNotificationsSessionInitialized, jsInvoker_, instance_);
   }
+  jsi::Value isPeerNotificationsSessionInitialized(jsi::Runtime &rt, jsi::String deviceID) override {
+    static_assert(
+        bridging::getParameterCount(&T::isPeerNotificationsSessionInitialized) == 2,
+        "Expected isPeerNotificationsSessionInitialized(...) to have 2 parameters");
+
+    return bridging::callFromJs<jsi::Value>(
+        rt, &T::isPeerNotificationsSessionInitialized, jsInvoker_, instance_, std::move(deviceID));
+  }
   jsi::Value updateKeyserverDataInNotifStorage(jsi::Runtime &rt, jsi::Array keyserversData) override {
     static_assert(
         bridging::getParameterCount(&T::updateKeyserverDataInNotifStorage) == 2,
@@ -290,6 +300,14 @@
     return bridging::callFromJs<jsi::Value>(
         rt, &T::initializeContentInboundSession, jsInvoker_, instance_, std::move(identityKeys), std::move(encryptedContent), std::move(deviceID), std::move(sessionVersion), std::move(overwrite));
   }
+  jsi::Value isContentSessionInitialized(jsi::Runtime &rt, jsi::String deviceID) override {
+    static_assert(
+        bridging::getParameterCount(&T::isContentSessionInitialized) == 2,
+        "Expected isContentSessionInitialized(...) to have 2 parameters");
+
+    return bridging::callFromJs<jsi::Value>(
+        rt, &T::isContentSessionInitialized, jsInvoker_, instance_, std::move(deviceID));
+  }
   jsi::Value initializeNotificationsOutboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, std::optional<jsi::String> oneTimeKey, jsi::String deviceID) override {
     static_assert(
         bridging::getParameterCount(&T::initializeNotificationsOutboundSession) == 6,
diff --git a/native/crypto/olm-api.js b/native/crypto/olm-api.js
--- a/native/crypto/olm-api.js
+++ b/native/crypto/olm-api.js
@@ -41,6 +41,7 @@
       overwrite,
     );
   },
+  isContentSessionInitialized: commCoreModule.isContentSessionInitialized,
   async contentOutboundSessionCreator(
     contentIdentityKeys: OLMIdentityKeys,
     contentInitializationInfo: OlmSessionInitializationInfo,
@@ -94,6 +95,8 @@
       contentIdentityKeys.ed25519,
     );
   },
+  isPeerNotificationsSessionInitialized:
+    commCoreModule.isPeerNotificationsSessionInitialized,
   async getOneTimeKeys(numberOfKeys: number): Promise {
     const { contentOneTimeKeys, notificationsOneTimeKeys } =
       await commCoreModule.getOneTimeKeys(numberOfKeys);
diff --git a/native/schema/CommCoreModuleSchema.js b/native/schema/CommCoreModuleSchema.js
--- a/native/schema/CommCoreModuleSchema.js
+++ b/native/schema/CommCoreModuleSchema.js
@@ -68,6 +68,9 @@
     keyserverID: string,
   ) => Promise;
   +isNotificationsSessionInitialized: () => Promise;
+  +isPeerNotificationsSessionInitialized: (
+    deviceID: string,
+  ) => Promise<boolean>;
   +updateKeyserverDataInNotifStorage: (
     keyserversData: $ReadOnlyArray<{ +id: string, +unreadCount: number }>,
   ) => Promise;
@@ -91,6 +94,7 @@
     sessionVersion: number,
     overwrite: boolean,
   ) => Promise;
+  +isContentSessionInitialized: (deviceID: string) => Promise<boolean>;
   +initializeNotificationsOutboundSession: (
     identityKeys: string,
     prekey: string,
diff --git a/web/crypto/olm-api.js b/web/crypto/olm-api.js
--- a/web/crypto/olm-api.js
+++ b/web/crypto/olm-api.js
@@ -50,6 +50,10 @@
   decryptSequentialAndPersist: proxyToWorker('decryptSequentialAndPersist'),
   contentInboundSessionCreator: proxyToWorker('contentInboundSessionCreator'),
   contentOutboundSessionCreator: proxyToWorker('contentOutboundSessionCreator'),
+  isContentSessionInitialized: proxyToWorker('isContentSessionInitialized'),
+  isPeerNotificationsSessionInitialized: proxyToWorker(
+    'isPeerNotificationsSessionInitialized',
+  ),
   notificationsSessionCreator: proxyToWorker('notificationsSessionCreator'),
   notificationsOutboundSessionCreator: proxyToWorker(
     'notificationsOutboundSessionCreator',
diff --git a/web/shared-worker/worker/worker-crypto.js b/web/shared-worker/worker/worker-crypto.js
--- a/web/shared-worker/worker/worker-crypto.js
+++ b/web/shared-worker/worker/worker-crypto.js
@@ -720,6 +720,14 @@
 
     return { encryptedData, sessionVersion: newSessionVersion };
   },
+  // Whether cryptoStore holds a content session for `deviceID`; the returned
+  // promise rejects if the crypto account has not been initialized.
+  async isContentSessionInitialized(deviceID: string) {
+    if (!cryptoStore) {
+      throw new Error('Crypto account not initialized');
+    }
+    return !!cryptoStore.contentSessions[deviceID];
+  },
   async notificationsOutboundSessionCreator(
     notificationsIdentityKeys: OLMIdentityKeys,
     contentIdentityKeys: OLMIdentityKeys,
@@ -738,6 +746,20 @@
       dataEncryptionKeyDBLabel,
     );
   },
+  // True only when both the Olm data key and the encryption-key DB label for
+  // `deviceID` are present among the localforage keys.
+  async isPeerNotificationsSessionInitialized(deviceID: string) {
+    const dataPersistenceKey = getOlmDataKeyForDeviceID(deviceID);
+    const dataEncryptionKeyDBLabel =
+      getOlmEncryptionKeyDBLabelForDeviceID(deviceID);
+
+    const allKeys = await localforage.keys();
+    const allKeysSet = new Set(allKeys);
+    return (
+      allKeysSet.has(dataPersistenceKey) &&
+      allKeysSet.has(dataEncryptionKeyDBLabel)
+    );
+  },
   async notificationsSessionCreator(
     cookie: ?string,
     notificationsIdentityKeys: OLMIdentityKeys,