diff --git a/lib/types/crypto-types.js b/lib/types/crypto-types.js --- a/lib/types/crypto-types.js +++ b/lib/types/crypto-types.js @@ -117,11 +117,13 @@ export type EncryptedData = { +message: string, +messageType: OlmEncryptedMessageTypes, + +sessionVersion?: number, }; export const encryptedDataValidator: TInterface = tShape({ message: t.String, messageType: t.Number, + sessionVersion: t.maybe(t.Number), }); export type ClientPublicKeys = { diff --git a/lib/utils/olm-utils.js b/lib/utils/olm-utils.js --- a/lib/utils/olm-utils.js +++ b/lib/utils/olm-utils.js @@ -133,6 +133,12 @@ // the corresponding .cpp file // at `native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp`. sessionNotExists: 'SESSION_NOT_EXISTS', + // Error thrown when attempting to decrypt a message encrypted + // with an already replaced old session. + // This definition should remain in sync with the value defined in + // the corresponding .cpp file + // at `native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp`. + invalidSessionVersion: 'INVALID_SESSION_VERSION', }); function hasHigherDeviceID( diff --git a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp --- a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp +++ b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp @@ -15,6 +15,7 @@ // This definition should remain in sync with the value defined in // the corresponding JavaScript file at `lib/utils/olm-utils.js`. const std::string SESSION_NOT_EXISTS_ERROR{"SESSION_NOT_EXISTS"}; +const std::string INVALID_SESSION_VERSION_ERROR{"INVALID_SESSION_VERSION"}; CryptoModule::CryptoModule(std::string id) : id{id} { this->createAccount(); @@ -396,7 +397,12 @@ if (!this->hasSessionFor(targetDeviceId)) { throw std::runtime_error{SESSION_NOT_EXISTS_ERROR}; } - return this->sessions.at(targetDeviceId)->decrypt(encryptedData); + auto session = this->sessions.at(targetDeviceId); + if (encryptedData.sessionVersion.has_value() && + encryptedData.sessionVersion.value() < session->getVersion()) { + throw std::runtime_error{INVALID_SESSION_VERSION_ERROR}; + } + return session->decrypt(encryptedData); } std::string CryptoModule::signMessage(const std::string &message) { diff --git a/native/cpp/CommonCpp/CryptoTools/Session.cpp b/native/cpp/CommonCpp/CryptoTools/Session.cpp --- a/native/cpp/CommonCpp/CryptoTools/Session.cpp +++ b/native/cpp/CommonCpp/CryptoTools/Session.cpp @@ -201,7 +201,7 @@ throw std::runtime_error{ "error encrypt => " + std::string{::olm_session_last_error(session)}}; } - return {encryptedMessage, messageType}; + return {encryptedMessage, messageType, this->getVersion()}; } int Session::getVersion() { diff --git a/native/cpp/CommonCpp/CryptoTools/Tools.h b/native/cpp/CommonCpp/CryptoTools/Tools.h --- a/native/cpp/CommonCpp/CryptoTools/Tools.h +++ b/native/cpp/CommonCpp/CryptoTools/Tools.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -24,6 +25,7 @@ struct EncryptedData { OlmBuffer message; size_t messageType; + std::optional sessionVersion; }; class Tools { diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp @@ -867,6 +867,12 @@ encryptedDataJSI.setProperty(rt, "message", messageJSI); encryptedDataJSI.setProperty( rt, "messageType", static_cast(encryptedData.messageType)); + if (encryptedData.sessionVersion.has_value()) { + encryptedDataJSI.setProperty( + rt, + "sessionVersion", + static_cast(encryptedData.sessionVersion.value())); + } return encryptedDataJSI; } @@ -1838,6 +1844,13 @@ std::string message = encryptedDataJSI.getProperty(rt, "message").asString(rt).utf8(rt); auto deviceIDCpp{deviceID.utf8(rt)}; + + std::optional sessionVersion; + if (encryptedDataJSI.hasProperty(rt, "sessionVersion")) { + sessionVersion = std::lround( + encryptedDataJSI.getProperty(rt, "sessionVersion").asNumber()); + } + return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { @@ -1846,7 +1859,8 @@ try { crypto::EncryptedData encryptedData{ std::vector(message.begin(), message.end()), - messageType}; + messageType, + sessionVersion}; decryptedMessage = this->contentCryptoModule->decrypt(deviceIDCpp, encryptedData); this->persistCryptoModules(true, std::nullopt); @@ -1875,6 +1889,13 @@ std::lround(encryptedDataJSI.getProperty(rt, "messageType").asNumber()); std::string message = encryptedDataJSI.getProperty(rt, "message").asString(rt).utf8(rt); + + std::optional sessionVersion; + if (encryptedDataJSI.hasProperty(rt, "sessionVersion")) { + sessionVersion = std::lround( + encryptedDataJSI.getProperty(rt, "sessionVersion").asNumber()); + } + auto deviceIDCpp{deviceID.utf8(rt)}; auto messageIDCpp{messageID.utf8(rt)}; return createPromiseAsJSIValue( @@ -1885,7 +1906,8 @@ try { crypto::EncryptedData encryptedData{ std::vector(message.begin(), message.end()), - messageType}; + messageType, + sessionVersion}; decryptedMessage = this->contentCryptoModule->decrypt(deviceIDCpp, encryptedData); diff --git a/web/shared-worker/worker/worker-crypto.js b/web/shared-worker/worker/worker-crypto.js --- a/web/shared-worker/worker/worker-crypto.js +++ b/web/shared-worker/worker/worker-crypto.js @@ -565,6 +565,7 @@ return { message: encryptedContent.body, messageType: encryptedContent.type, + sessionVersion: olmSession.version, }; }, async encryptAndPersist( @@ -594,6 +595,7 @@ const result: EncryptedData = { message: encryptedContent.body, messageType: encryptedContent.type, + sessionVersion: olmSession.version, }; sqliteQueryExecutor.beginTransaction(); @@ -636,6 +638,13 @@ throw new Error(olmSessionErrors.sessionNotExists); } + if ( + encryptedData.sessionVersion && + encryptedData.sessionVersion < olmSession.version + ) { + throw new Error(olmSessionErrors.invalidSessionVersion); + } + const result = olmSession.session.decrypt( encryptedData.messageType, encryptedData.message, @@ -660,6 +669,13 @@ throw new Error(olmSessionErrors.sessionNotExists); } + if ( + encryptedData.sessionVersion && + encryptedData.sessionVersion < olmSession.version + ) { + throw new Error(olmSessionErrors.invalidSessionVersion); + } + const result = olmSession.session.decrypt( encryptedData.messageType, encryptedData.message, @@ -776,6 +792,7 @@ const encryptedData: EncryptedData = { message: initialEncryptedData.body, messageType: initialEncryptedData.type, + sessionVersion: newSessionVersion, }; return { encryptedData, sessionVersion: newSessionVersion };