diff --git a/lib/handlers/peer-to-peer-message-handler.js b/lib/handlers/peer-to-peer-message-handler.js --- a/lib/handlers/peer-to-peer-message-handler.js +++ b/lib/handlers/peer-to-peer-message-handler.js @@ -35,6 +35,7 @@ deviceKeys.identityKeysBlob.primaryIdentityPublicKeys, encryptedData, sessionVersion, + false, ); console.log( 'Created inbound session with device ' + diff --git a/lib/types/crypto-types.js b/lib/types/crypto-types.js --- a/lib/types/crypto-types.js +++ b/lib/types/crypto-types.js @@ -151,6 +151,7 @@ contentIdentityKeys: OLMIdentityKeys, initialEncryptedData: EncryptedData, sessionVersion: number, + overwrite: boolean, ) => Promise, +contentOutboundSessionCreator: ( contentIdentityKeys: OLMIdentityKeys, diff --git a/native/cpp/CommonCpp/CryptoTools/CryptoModule.h b/native/cpp/CommonCpp/CryptoTools/CryptoModule.h --- a/native/cpp/CommonCpp/CryptoTools/CryptoModule.h +++ b/native/cpp/CommonCpp/CryptoTools/CryptoModule.h @@ -59,7 +59,7 @@ const OlmBuffer &encryptedMessage, const OlmBuffer &idKeys, int sessionVersion, - const bool overwrite = true); + const bool overwrite); int initializeOutboundForSendingSession( const std::string &targetDeviceId, const OlmBuffer &idKeys, diff --git a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp --- a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp +++ b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp @@ -272,13 +272,7 @@ throw std::runtime_error{"OLM_SESSION_CREATION_RACE_CONDITION"}; } - if (overwrite) { - this->sessions.erase(this->sessions.find(targetDeviceId)); - } else { - throw std::runtime_error{ - "error initializeInboundForReceivingSession => session already " - "initialized"}; - } + this->sessions.erase(this->sessions.find(targetDeviceId)); } std::unique_ptr newSession = Session::createSessionAsResponder( this->getOlmAccount(), diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h @@ -133,7 +133,8 @@ jsi::String identityKeys, jsi::Object encryptedDataJSI, jsi::String deviceID, - double sessionVersion) override; + double sessionVersion, + bool overwrite) override; virtual jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) override; virtual jsi::Value decrypt( diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp @@ -1218,7 +1218,8 @@ jsi::String identityKeys, jsi::Object encryptedDataJSI, jsi::String deviceID, - double sessionVersion) { + double sessionVersion, + bool overwrite) { auto identityKeysCpp{identityKeys.utf8(rt)}; size_t messageType = std::lround(encryptedDataJSI.getProperty(rt, "messageType").asNumber()); @@ -1237,7 +1238,8 @@ encryptedMessageCpp.begin(), encryptedMessageCpp.end()), std::vector( identityKeysCpp.begin(), identityKeysCpp.end()), - static_cast(sessionVersion)); + static_cast(sessionVersion), + overwrite); crypto::EncryptedData encryptedData{ std::vector( encryptedMessageCpp.begin(), encryptedMessageCpp.end()), diff --git a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp --- a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp +++ b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp @@ -109,7 +109,7 @@ return static_cast(&turboModule)->initializeContentOutboundSession(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].asString(rt), args[4].asString(rt)); } static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { - return static_cast(&turboModule)->initializeContentInboundSession(rt, args[0].asString(rt), args[1].asObject(rt), args[2].asString(rt), args[3].asNumber()); + return static_cast(&turboModule)->initializeContentInboundSession(rt, args[0].asString(rt), args[1].asObject(rt), args[2].asString(rt), args[3].asNumber(), args[4].asBool()); } static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encrypt(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { return static_cast(&turboModule)->encrypt(rt, args[0].asString(rt), args[1].asString(rt)); @@ -221,7 +221,7 @@ methodMap_["removeKeyserverDataFromNotifStorage"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_removeKeyserverDataFromNotifStorage}; methodMap_["getKeyserverDataFromNotifStorage"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getKeyserverDataFromNotifStorage}; methodMap_["initializeContentOutboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentOutboundSession}; - methodMap_["initializeContentInboundSession"] = MethodMetadata {4, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession}; + methodMap_["initializeContentInboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession}; methodMap_["encrypt"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encrypt}; methodMap_["decrypt"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_decrypt}; methodMap_["signMessage"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_signMessage}; diff --git a/native/cpp/CommonCpp/_generated/commJSI.h b/native/cpp/CommonCpp/_generated/commJSI.h --- a/native/cpp/CommonCpp/_generated/commJSI.h +++ b/native/cpp/CommonCpp/_generated/commJSI.h @@ -51,7 +51,7 @@ virtual jsi::Value removeKeyserverDataFromNotifStorage(jsi::Runtime &rt, jsi::Array keyserverIDsToDelete) = 0; virtual jsi::Value getKeyserverDataFromNotifStorage(jsi::Runtime &rt, jsi::Array keyserverIDs) = 0; virtual jsi::Value initializeContentOutboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, jsi::String oneTimeKey, jsi::String deviceID) = 0; - virtual jsi::Value initializeContentInboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedContent, jsi::String deviceID, double sessionVersion) = 0; + virtual jsi::Value initializeContentInboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedContent, jsi::String deviceID, double sessionVersion, bool overwrite) = 0; virtual jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) = 0; virtual jsi::Value decrypt(jsi::Runtime &rt, jsi::Object encryptedData, jsi::String deviceID) = 0; virtual jsi::Value signMessage(jsi::Runtime &rt, jsi::String message) = 0; @@ -345,13 +345,13 @@ return bridging::callFromJs( rt, &T::initializeContentOutboundSession, jsInvoker_, instance_, std::move(identityKeys), std::move(prekey), std::move(prekeySignature), std::move(oneTimeKey), std::move(deviceID)); } - jsi::Value initializeContentInboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedContent, jsi::String deviceID, double sessionVersion) override { + jsi::Value initializeContentInboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedContent, jsi::String deviceID, double sessionVersion, bool overwrite) override { static_assert( - bridging::getParameterCount(&T::initializeContentInboundSession) == 5, - "Expected initializeContentInboundSession(...) to have 5 parameters"); + bridging::getParameterCount(&T::initializeContentInboundSession) == 6, + "Expected initializeContentInboundSession(...) to have 6 parameters"); return bridging::callFromJs( - rt, &T::initializeContentInboundSession, jsInvoker_, instance_, std::move(identityKeys), std::move(encryptedContent), std::move(deviceID), std::move(sessionVersion)); + rt, &T::initializeContentInboundSession, jsInvoker_, instance_, std::move(identityKeys), std::move(encryptedContent), std::move(deviceID), std::move(sessionVersion), std::move(overwrite)); } jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) override { static_assert( diff --git a/native/crypto/olm-api.js b/native/crypto/olm-api.js --- a/native/crypto/olm-api.js +++ b/native/crypto/olm-api.js @@ -24,6 +24,7 @@ contentIdentityKeys: OLMIdentityKeys, initialEncryptedData: EncryptedData, sessionVersion: number, + overwrite: boolean, ): Promise { const identityKeys = JSON.stringify({ curve25519: contentIdentityKeys.curve25519, @@ -34,6 +35,7 @@ initialEncryptedData, contentIdentityKeys.ed25519, sessionVersion, + overwrite, ); }, async contentOutboundSessionCreator( diff --git a/native/schema/CommCoreModuleSchema.js b/native/schema/CommCoreModuleSchema.js --- a/native/schema/CommCoreModuleSchema.js +++ b/native/schema/CommCoreModuleSchema.js @@ -117,6 +117,7 @@ encryptedContent: Object, deviceID: string, sessionVersion: number, + overwrite: boolean, ) => Promise; +encrypt: (message: string, deviceID: string) => Promise; +decrypt: (encryptedData: Object, deviceID: string) => Promise; @@ -163,6 +164,7 @@ encryptedContent: EncryptedData, deviceID: string, sessionVersion: number, + overwrite: boolean, ) => Promise; } diff --git a/web/shared-worker/worker/worker-crypto.js b/web/shared-worker/worker/worker-crypto.js --- a/web/shared-worker/worker/worker-crypto.js +++ b/web/shared-worker/worker/worker-crypto.js @@ -31,6 +31,7 @@ shouldForgetPrekey, shouldRotatePrekey, retrieveIdentityKeysAndPrekeys, + olmSessionErrors, } from 'lib/utils/olm-utils.js'; import { getIdentityClient } from './identity-client.js'; @@ -417,6 +418,7 @@ contentIdentityKeys: OLMIdentityKeys, initialEncryptedData: EncryptedData, sessionVersion: number, + overwrite: boolean, ): Promise { if (!cryptoStore) { throw new Error('Crypto account not initialized'); @@ -424,10 +426,12 @@ const { contentAccount, contentSessions } = cryptoStore; const existingSession = contentSessions[contentIdentityKeys.ed25519]; - if (existingSession && existingSession.version > sessionVersion) { - throw new Error('OLM_SESSION_ALREADY_CREATED'); - } else if (existingSession && existingSession.version === sessionVersion) { - throw new Error('OLM_SESSION_CREATION_RACE_CONDITION'); + if (existingSession) { + if (!overwrite && existingSession.version > sessionVersion) { + throw new Error(olmSessionErrors.alreadyCreated); + } else if (!overwrite && existingSession.version === sessionVersion) { + throw new Error(olmSessionErrors.raceCondition); + } } const session = new olm.Session();