Olm content sessions: track and exchange a per-device session version

CryptoModule::initializeOutboundForSendingSession() now returns the version
of the session it created: 1 when no session exists for the device, or the
existing session's version + 1 when it overwrites one.
initializeContentOutboundSession resolves to { encryptedData, sessionVersion }.

The receiving side passes that version to
CryptoModule::initializeInboundForReceivingSession(). When a session for the
device already exists, and before `overwrite` is consulted, it throws:
  - OLM_SESSION_ALREADY_CREATED if the existing session's version is higher;
  - OLM_SESSION_CREATION_RACE_CONDITION if both versions are equal.

`int sessionVersion` in CommCoreModule::initializeContentOutboundSession is
initialized to 0. It is captured by copy into the invokeAsync lambda even
when the try block throws before assigning it, and copying an indeterminate
int is undefined behaviour.

NOTE(review): `int sessionVersion` is inserted before the defaulted
`const bool overwrite` parameter. A pre-existing call that passed `overwrite`
positionally (e.g. `..., idKeys, false)`) still compiles: `false` silently
becomes sessionVersion 0 and `overwrite` falls back to true. Confirm that
every caller was updated.

NOTE(review): `static_cast<int>(sessionVersion)` converts an unchecked JS
number. NaN or an out-of-range value is undefined behaviour, so consider
validating it before the cast.

diff --git a/native/cpp/CommonCpp/CryptoTools/CryptoModule.h b/native/cpp/CommonCpp/CryptoTools/CryptoModule.h
--- a/native/cpp/CommonCpp/CryptoTools/CryptoModule.h
+++ b/native/cpp/CommonCpp/CryptoTools/CryptoModule.h
@@ -58,8 +58,9 @@
       const std::string &targetDeviceId,
       const OlmBuffer &encryptedMessage,
       const OlmBuffer &idKeys,
+      int sessionVersion,
       const bool overwrite = true);
-  void initializeOutboundForSendingSession(
+  int initializeOutboundForSendingSession(
       const std::string &targetDeviceId,
       const OlmBuffer &idKeys,
       const OlmBuffer &preKeys,
diff --git a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp
--- a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp
+++ b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp
@@ -261,8 +261,17 @@
     const std::string &targetDeviceId,
     const OlmBuffer &encryptedMessage,
     const OlmBuffer &idKeys,
+    int sessionVersion,
     const bool overwrite) {
   if (this->hasSessionFor(targetDeviceId)) {
+    std::shared_ptr<Session> existingSession =
+        getSessionByDeviceId(targetDeviceId);
+    if (existingSession->version > sessionVersion) {
+      throw std::runtime_error{"OLM_SESSION_ALREADY_CREATED"};
+    } else if (existingSession->version == sessionVersion) {
+      throw std::runtime_error{"OLM_SESSION_CREATION_RACE_CONDITION"};
+    }
+
     if (overwrite) {
       this->sessions.erase(this->sessions.find(targetDeviceId));
     } else {
@@ -276,16 +285,21 @@
       this->keys.identityKeys.data(),
       encryptedMessage,
       idKeys);
+  newSession->version = sessionVersion;
   this->sessions.insert(make_pair(targetDeviceId, std::move(newSession)));
 }
 
-void CryptoModule::initializeOutboundForSendingSession(
+int CryptoModule::initializeOutboundForSendingSession(
     const std::string &targetDeviceId,
     const OlmBuffer &idKeys,
     const OlmBuffer &preKeys,
     const OlmBuffer &preKeySignature,
     const OlmBuffer &oneTimeKey) {
+  int newSessionVersion = 1;
   if (this->hasSessionFor(targetDeviceId)) {
+    std::shared_ptr<Session> existingSession =
+        getSessionByDeviceId(targetDeviceId);
+    newSessionVersion = existingSession->version + 1;
     Logger::log(
         "olm session overwritten for the device with id: " + targetDeviceId);
     this->sessions.erase(this->sessions.find(targetDeviceId));
@@ -297,7 +311,9 @@
       preKeys,
       preKeySignature,
       oneTimeKey);
+  newSession->version = newSessionVersion;
   this->sessions.insert(make_pair(targetDeviceId, std::move(newSession)));
+  return newSessionVersion;
 }
 
 bool CryptoModule::hasSessionFor(const std::string &targetDeviceId) {
diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h
--- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h
+++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h
@@ -132,7 +132,8 @@
       jsi::Runtime &rt,
       jsi::String identityKeys,
       jsi::Object encryptedDataJSI,
-      jsi::String deviceID) override;
+      jsi::String deviceID,
+      double sessionVersion) override;
   virtual jsi::Value
   encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) override;
   virtual jsi::Value decrypt(
diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
--- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
+++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
@@ -1163,16 +1163,18 @@
         taskType job = [=, &innerRt]() {
           std::string error;
           crypto::EncryptedData initialEncryptedData;
+          int sessionVersion = 0;
           try {
-            this->contentCryptoModule->initializeOutboundForSendingSession(
-                deviceIDCpp,
-                std::vector<uint8_t>(
-                    identityKeysCpp.begin(), identityKeysCpp.end()),
-                std::vector<uint8_t>(prekeyCpp.begin(), prekeyCpp.end()),
-                std::vector<uint8_t>(
-                    prekeySignatureCpp.begin(), prekeySignatureCpp.end()),
-                std::vector<uint8_t>(
-                    oneTimeKeyCpp.begin(), oneTimeKeyCpp.end()));
+            sessionVersion =
+                this->contentCryptoModule->initializeOutboundForSendingSession(
+                    deviceIDCpp,
+                    std::vector<uint8_t>(
+                        identityKeysCpp.begin(), identityKeysCpp.end()),
+                    std::vector<uint8_t>(prekeyCpp.begin(), prekeyCpp.end()),
+                    std::vector<uint8_t>(
+                        prekeySignatureCpp.begin(), prekeySignatureCpp.end()),
+                    std::vector<uint8_t>(
+                        oneTimeKeyCpp.begin(), oneTimeKeyCpp.end()));
 
             const std::string initMessage = "{\"type\": \"init\"}";
             initialEncryptedData =
@@ -1198,7 +1200,13 @@
                 "messageType",
                 static_cast<int>(initialEncryptedData.messageType));
 
-            promise->resolve(std::move(initialEncryptedDataJSI));
+            auto outboundSessionCreationResultJSI = jsi::Object(innerRt);
+            outboundSessionCreationResultJSI.setProperty(
+                innerRt, "encryptedData", initialEncryptedDataJSI);
+            outboundSessionCreationResultJSI.setProperty(
+                innerRt, "sessionVersion", sessionVersion);
+
+            promise->resolve(std::move(outboundSessionCreationResultJSI));
           });
         };
         this->cryptoThread->scheduleTask(job);
@@ -1209,7 +1217,8 @@
     jsi::Runtime &rt,
     jsi::String identityKeys,
     jsi::Object encryptedDataJSI,
-    jsi::String deviceID) {
+    jsi::String deviceID,
+    double sessionVersion) {
   auto identityKeysCpp{identityKeys.utf8(rt)};
   size_t messageType =
       std::lround(encryptedDataJSI.getProperty(rt, "messageType").asNumber());
@@ -1227,7 +1236,8 @@
                 std::vector<uint8_t>(
                     encryptedMessageCpp.begin(), encryptedMessageCpp.end()),
                 std::vector<uint8_t>(
-                    identityKeysCpp.begin(), identityKeysCpp.end()));
+                    identityKeysCpp.begin(), identityKeysCpp.end()),
+                static_cast<int>(sessionVersion));
             crypto::EncryptedData encryptedData{
                 std::vector<uint8_t>(
                     encryptedMessageCpp.begin(), encryptedMessageCpp.end()),
diff --git a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp
--- a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp
+++ b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp
@@ -109,7 +109,7 @@
   return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->initializeContentOutboundSession(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].asString(rt), args[4].asString(rt));
 }
 static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
-  return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->initializeContentInboundSession(rt, args[0].asString(rt), args[1].asObject(rt), args[2].asString(rt));
+  return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->initializeContentInboundSession(rt, args[0].asString(rt), args[1].asObject(rt), args[2].asString(rt), args[3].asNumber());
 }
 static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encrypt(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
   return static_cast<CommCoreModuleSchemaCxxSpecJSI *>(&turboModule)->encrypt(rt, args[0].asString(rt), args[1].asString(rt));
@@ -221,7 +221,7 @@
   methodMap_["removeKeyserverDataFromNotifStorage"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_removeKeyserverDataFromNotifStorage};
   methodMap_["getKeyserverDataFromNotifStorage"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getKeyserverDataFromNotifStorage};
   methodMap_["initializeContentOutboundSession"] = MethodMetadata {5, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentOutboundSession};
-  methodMap_["initializeContentInboundSession"] = MethodMetadata {3, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession};
+  methodMap_["initializeContentInboundSession"] = MethodMetadata {4, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_initializeContentInboundSession};
   methodMap_["encrypt"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_encrypt};
   methodMap_["decrypt"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_decrypt};
   methodMap_["signMessage"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_signMessage};
diff --git a/native/cpp/CommonCpp/_generated/commJSI.h b/native/cpp/CommonCpp/_generated/commJSI.h
--- a/native/cpp/CommonCpp/_generated/commJSI.h
+++ b/native/cpp/CommonCpp/_generated/commJSI.h
@@ -51,7 +51,7 @@
   virtual jsi::Value removeKeyserverDataFromNotifStorage(jsi::Runtime &rt, jsi::Array keyserverIDsToDelete) = 0;
   virtual jsi::Value getKeyserverDataFromNotifStorage(jsi::Runtime &rt, jsi::Array keyserverIDs) = 0;
   virtual jsi::Value initializeContentOutboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, jsi::String oneTimeKey, jsi::String deviceID) = 0;
-  virtual jsi::Value initializeContentInboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedContent, jsi::String deviceID) = 0;
+  virtual jsi::Value initializeContentInboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedContent, jsi::String deviceID, double sessionVersion) = 0;
   virtual jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) = 0;
   virtual jsi::Value decrypt(jsi::Runtime &rt, jsi::Object encryptedData, jsi::String deviceID) = 0;
   virtual jsi::Value signMessage(jsi::Runtime &rt, jsi::String message) = 0;
@@ -345,13 +345,13 @@
     return bridging::callFromJs<jsi::Value>(
       rt, &T::initializeContentOutboundSession, jsInvoker_, instance_, std::move(identityKeys), std::move(prekey), std::move(prekeySignature), std::move(oneTimeKey), std::move(deviceID));
   }
-  jsi::Value initializeContentInboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedContent, jsi::String deviceID) override {
+  jsi::Value initializeContentInboundSession(jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedContent, jsi::String deviceID, double sessionVersion) override {
     static_assert(
-        bridging::getParameterCount(&T::initializeContentInboundSession) == 4,
-        "Expected initializeContentInboundSession(...) to have 4 parameters");
+        bridging::getParameterCount(&T::initializeContentInboundSession) == 5,
+        "Expected initializeContentInboundSession(...) to have 5 parameters");
 
     return bridging::callFromJs<jsi::Value>(
-      rt, &T::initializeContentInboundSession, jsInvoker_, instance_, std::move(identityKeys), std::move(encryptedContent), std::move(deviceID));
+      rt, &T::initializeContentInboundSession, jsInvoker_, instance_, std::move(identityKeys), std::move(encryptedContent), std::move(deviceID), std::move(sessionVersion));
   }
   jsi::Value encrypt(jsi::Runtime &rt, jsi::String message, jsi::String deviceID) override {
     static_assert(
diff --git a/native/crypto/olm-api.js b/native/crypto/olm-api.js
--- a/native/crypto/olm-api.js
+++ b/native/crypto/olm-api.js
@@ -7,6 +7,7 @@
   type OlmAPI,
   type OLMIdentityKeys,
   type EncryptedData,
+  type OutboundSessionCreationResult,
 } from 'lib/types/crypto-types.js';
 import type { OlmSessionInitializationInfo } from 'lib/types/request-types.js';
 
@@ -19,10 +20,10 @@
   getUserPublicKey: commCoreModule.getUserPublicKey,
   encrypt: commCoreModule.encrypt,
   decrypt: commCoreModule.decrypt,
-  // $FlowFixMe
   async contentInboundSessionCreator(
     contentIdentityKeys: OLMIdentityKeys,
     initialEncryptedData: EncryptedData,
+    sessionVersion: number,
   ): Promise<string> {
     const identityKeys = JSON.stringify({
       curve25519: contentIdentityKeys.curve25519,
@@ -32,13 +33,13 @@
       identityKeys,
       initialEncryptedData,
       contentIdentityKeys.ed25519,
+      sessionVersion,
     );
   },
   async contentOutboundSessionCreator(
     contentIdentityKeys: OLMIdentityKeys,
     contentInitializationInfo: OlmSessionInitializationInfo,
-    // $FlowFixMe
-  ): Promise<EncryptedData> {
+  ): Promise<OutboundSessionCreationResult> {
     const { prekey, prekeySignature, oneTimeKey } = contentInitializationInfo;
     const identityKeys = JSON.stringify({
       curve25519: contentIdentityKeys.curve25519,
diff --git a/native/schema/CommCoreModuleSchema.js b/native/schema/CommCoreModuleSchema.js
--- a/native/schema/CommCoreModuleSchema.js
+++ b/native/schema/CommCoreModuleSchema.js
@@ -19,6 +19,7 @@
   SignedPrekeys,
   ClientPublicKeys,
   EncryptedData,
+  OutboundSessionCreationResult,
 } from 'lib/types/crypto-types.js';
 import type { ClientDBDraftStoreOperation } from 'lib/types/draft-types.js';
 import type { ClientDBMessageInfo } from 'lib/types/message-types.js';
@@ -110,11 +111,12 @@
     prekeySignature: string,
     oneTimeKey: string,
     deviceID: string,
-  ) => Promise<EncryptedData>;
+  ) => Promise<OutboundSessionCreationResult>;
   +initializeContentInboundSession: (
     identityKeys: string,
     encryptedContent: Object,
     deviceID: string,
+    sessionVersion: number,
   ) => Promise<string>;
   +encrypt: (message: string, deviceID: string) => Promise<EncryptedData>;
   +decrypt: (encryptedData: Object, deviceID: string) => Promise<string>;
@@ -160,6 +162,7 @@
     identityKeys: string,
     encryptedContent: EncryptedData,
     deviceID: string,
+    sessionVersion: number,
   ) => Promise<string>;
 }
 