diff --git a/lib/types/sqlite-types.js b/lib/types/sqlite-types.js --- a/lib/types/sqlite-types.js +++ b/lib/types/sqlite-types.js @@ -90,6 +90,7 @@ ) => Promise, +resetOutboundP2PMessagesForDevice: ( deviceID: string, + newDeviceID?: ?string, ) => Promise>, +removeOutboundP2PMessage: ( messageID: string, diff --git a/native/cpp/CommonCpp/DatabaseManagers/DatabaseQueryExecutor.h b/native/cpp/CommonCpp/DatabaseManagers/DatabaseQueryExecutor.h --- a/native/cpp/CommonCpp/DatabaseManagers/DatabaseQueryExecutor.h +++ b/native/cpp/CommonCpp/DatabaseManagers/DatabaseQueryExecutor.h @@ -176,8 +176,9 @@ virtual void markOutboundP2PMessageAsSent( std::string messageID, std::string deviceID) const = 0; - virtual std::vector - resetOutboundP2PMessagesForDevice(std::string deviceID) const = 0; + virtual std::vector resetOutboundP2PMessagesForDevice( + std::string deviceID, + std::optional newDeviceID) const = 0; virtual void addInboundP2PMessage(InboundP2PMessage message) const = 0; virtual std::vector getAllInboundP2PMessage() const = 0; virtual void diff --git a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.h b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.h --- a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.h +++ b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.h @@ -168,8 +168,9 @@ std::string ciphertext) const override; void markOutboundP2PMessageAsSent(std::string messageID, std::string deviceID) const override; - std::vector - resetOutboundP2PMessagesForDevice(std::string deviceID) const override; + std::vector resetOutboundP2PMessagesForDevice( + std::string deviceID, + std::optional newDeviceID) const override; void addInboundP2PMessage(InboundP2PMessage message) const override; std::vector getAllInboundP2PMessage() const override; void diff --git a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp --- a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp +++ b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp @@ -20,6 +20,7 @@ #include "entities/UserInfo.h" #include #include +#include #include const int CONTENT_ACCOUNT_ID = 1; @@ -1361,7 +1362,9 @@ } std::vector SQLiteQueryExecutor::resetOutboundP2PMessagesForDevice( - std::string deviceID) const { + std::string deviceID, + std::optional newDeviceID) const { + // Query all messages that need to be resent - all message that supports // auto retry or already sent messages. std::string queryMessageIDsToResend = @@ -1394,17 +1397,33 @@ std::stringstream resetMessagesSQLStream; resetMessagesSQLStream << "UPDATE outbound_p2p_messages " - << "SET supports_auto_retry = 1, status = 'persisted', ciphertext = '' " - << "WHERE message_id IN " << getSQLStatementArray(messageIDs.size()) - << ";"; + << "SET supports_auto_retry = 1, status = 'persisted', ciphertext = ''"; + + if (newDeviceID.has_value()) { + resetMessagesSQLStream << ", device_id = :new_device_id "; + } else { + resetMessagesSQLStream << " "; + } + + resetMessagesSQLStream << "WHERE message_id IN " + << getSQLStatementArray(messageIDs.size()) << ";"; SQLiteStatementWrapper preparedUpdateSQL( this->getConnection(), resetMessagesSQLStream.str(), "Failed to reset messages."); + int paramIdx = 1; + if (newDeviceID.has_value()) { + paramIdx = + sqlite3_bind_parameter_index(preparedUpdateSQL, ":new_device_id"); + bindStringToSQL(newDeviceID.value().c_str(), preparedUpdateSQL, paramIdx++); + } + for (int i = 0; i < messageIDs.size(); i++) { - int bindResult = bindStringToSQL(messageIDs[i], preparedUpdateSQL, i + 1); + int bindResult = + bindStringToSQL(messageIDs[i], preparedUpdateSQL, i + paramIdx); + if (bindResult != SQLITE_OK) { std::stringstream error_message; error_message << "Failed to bind key to SQL statement. Details: " @@ -1420,17 +1439,35 @@ // session) but not yet queued on Tunnelbroker. In this case, this message // is not considered to be sent (from the UI perspective), // and supports_auto_retry is not updated. - std::string updateCiphertextQuery = - "UPDATE outbound_p2p_messages " - "SET ciphertext = '', status = 'persisted'" - "WHERE device_id = ? " - " AND supports_auto_retry = 0 " - " AND status = 'encrypted';"; + std::stringstream updateCiphertextQuery; + updateCiphertextQuery << "UPDATE outbound_p2p_messages " + << "SET ciphertext = '', status = 'persisted'"; + + if (newDeviceID.has_value()) { + updateCiphertextQuery << ", device_id = :new_device_id "; + } else { + updateCiphertextQuery << " "; + } + + updateCiphertextQuery << "WHERE device_id = :device_id " + << " AND supports_auto_retry = 0 " + << " AND status = 'encrypted';"; SQLiteStatementWrapper preparedUpdateCiphertextSQL( - this->getConnection(), updateCiphertextQuery, "Failed to set ciphertext"); + this->getConnection(), + updateCiphertextQuery.str(), + "Failed to set ciphertext"); + + if (newDeviceID.has_value()) { + paramIdx = sqlite3_bind_parameter_index( + preparedUpdateCiphertextSQL, ":new_device_id"); + bindStringToSQL( + newDeviceID.value().c_str(), preparedUpdateCiphertextSQL, paramIdx); + } - bindStringToSQL(deviceID.c_str(), preparedUpdateCiphertextSQL, 1); + paramIdx = + sqlite3_bind_parameter_index(preparedUpdateCiphertextSQL, ":device_id"); + bindStringToSQL(deviceID.c_str(), preparedUpdateCiphertextSQL, paramIdx); sqlite3_step(preparedUpdateCiphertextSQL); return messageIDs; diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h @@ -256,7 +256,8 @@ jsi::String deviceID) override; virtual jsi::Value resetOutboundP2PMessagesForDevice( jsi::Runtime &rt, - jsi::String deviceID) override; + jsi::String deviceID, + std::optional newDeviceID) override; virtual jsi::Value getDatabaseVersion(jsi::Runtime &rt, jsi::String dbID) override; virtual jsi::Value getSyncedMetadata( diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp @@ -18,6 +18,7 @@ #include "JSIRust.h" #include "lib.rs.h" #include +#include #include namespace comm { @@ -3007,8 +3008,12 @@ jsi::Value CommCoreModule::resetOutboundP2PMessagesForDevice( jsi::Runtime &rt, - jsi::String deviceID) { + jsi::String deviceID, + std::optional newDeviceID) { std::string deviceIDCpp{deviceID.utf8(rt)}; + std::optional newDeviceIDCpp = newDeviceID.has_value() + ? std::optional{newDeviceID.value().utf8(rt)} + : std::nullopt; return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { @@ -3019,7 +3024,8 @@ try { DatabaseManager::getQueryExecutor().beginTransaction(); messageIDs = DatabaseManager::getQueryExecutor() - .resetOutboundP2PMessagesForDevice(deviceIDCpp); + .resetOutboundP2PMessagesForDevice( + deviceIDCpp, newDeviceIDCpp); DatabaseManager::getQueryExecutor().commitTransaction(); } catch (std::system_error &e) { error = e.what(); diff --git a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp --- a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp +++ b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp @@ -206,7 +206,7 @@ return static_cast(&turboModule)->removeOutboundP2PMessage(rt, args[0].asString(rt), args[1].asString(rt)); } static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_resetOutboundP2PMessagesForDevice(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { - return static_cast(&turboModule)->resetOutboundP2PMessagesForDevice(rt, args[0].asString(rt)); + return static_cast(&turboModule)->resetOutboundP2PMessagesForDevice(rt, args[0].asString(rt), args[1].isNull() || args[1].isUndefined() ? std::nullopt : std::make_optional(args[1].asString(rt))); } static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getDatabaseVersion(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { return static_cast(&turboModule)->getDatabaseVersion(rt, args[0].asString(rt)); @@ -312,7 +312,7 @@ methodMap_["getUnsentOutboundP2PMessages"] = MethodMetadata {0, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getUnsentOutboundP2PMessages}; methodMap_["markOutboundP2PMessageAsSent"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_markOutboundP2PMessageAsSent}; methodMap_["removeOutboundP2PMessage"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_removeOutboundP2PMessage}; - methodMap_["resetOutboundP2PMessagesForDevice"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_resetOutboundP2PMessagesForDevice}; + methodMap_["resetOutboundP2PMessagesForDevice"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_resetOutboundP2PMessagesForDevice}; methodMap_["getDatabaseVersion"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getDatabaseVersion}; methodMap_["getSyncedMetadata"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getSyncedMetadata}; methodMap_["markPrekeysAsPublished"] = MethodMetadata {0, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_markPrekeysAsPublished}; diff --git a/native/cpp/CommonCpp/_generated/commJSI.h b/native/cpp/CommonCpp/_generated/commJSI.h --- a/native/cpp/CommonCpp/_generated/commJSI.h +++ b/native/cpp/CommonCpp/_generated/commJSI.h @@ -82,7 +82,7 @@ virtual jsi::Value getUnsentOutboundP2PMessages(jsi::Runtime &rt) = 0; virtual jsi::Value markOutboundP2PMessageAsSent(jsi::Runtime &rt, jsi::String messageID, jsi::String deviceID) = 0; virtual jsi::Value removeOutboundP2PMessage(jsi::Runtime &rt, jsi::String messageID, jsi::String deviceID) = 0; - virtual jsi::Value resetOutboundP2PMessagesForDevice(jsi::Runtime &rt, jsi::String deviceID) = 0; + virtual jsi::Value resetOutboundP2PMessagesForDevice(jsi::Runtime &rt, jsi::String deviceID, std::optional newDeviceID) = 0; virtual jsi::Value getDatabaseVersion(jsi::Runtime &rt, jsi::String dbID) = 0; virtual jsi::Value getSyncedMetadata(jsi::Runtime &rt, jsi::String entryName, jsi::String dbID) = 0; virtual jsi::Value markPrekeysAsPublished(jsi::Runtime &rt) = 0; @@ -613,13 +613,13 @@ return bridging::callFromJs( rt, &T::removeOutboundP2PMessage, jsInvoker_, instance_, std::move(messageID), std::move(deviceID)); } - jsi::Value resetOutboundP2PMessagesForDevice(jsi::Runtime &rt, jsi::String deviceID) override { + jsi::Value resetOutboundP2PMessagesForDevice(jsi::Runtime &rt, jsi::String deviceID, std::optional newDeviceID) override { static_assert( - bridging::getParameterCount(&T::resetOutboundP2PMessagesForDevice) == 2, - "Expected resetOutboundP2PMessagesForDevice(...) to have 2 parameters"); + bridging::getParameterCount(&T::resetOutboundP2PMessagesForDevice) == 3, + "Expected resetOutboundP2PMessagesForDevice(...) to have 3 parameters"); return bridging::callFromJs( - rt, &T::resetOutboundP2PMessagesForDevice, jsInvoker_, instance_, std::move(deviceID)); + rt, &T::resetOutboundP2PMessagesForDevice, jsInvoker_, instance_, std::move(deviceID), std::move(newDeviceID)); } jsi::Value getDatabaseVersion(jsi::Runtime &rt, jsi::String dbID) override { static_assert( diff --git a/native/schema/CommCoreModuleSchema.js b/native/schema/CommCoreModuleSchema.js --- a/native/schema/CommCoreModuleSchema.js +++ b/native/schema/CommCoreModuleSchema.js @@ -203,6 +203,7 @@ ) => Promise; +resetOutboundP2PMessagesForDevice: ( deviceID: string, + newDeviceID?: ?string, ) => Promise>; // This type should be DatabaseIdentifier +getDatabaseVersion: (dbID: string) => Promise; diff --git a/web/database/sqlite-api.js b/web/database/sqlite-api.js --- a/web/database/sqlite-api.js +++ b/web/database/sqlite-api.js @@ -206,12 +206,14 @@ async resetOutboundP2PMessagesForDevice( deviceID: string, + newDeviceID?: ?string, ): Promise> { const sharedWorker = await getCommSharedWorker(); const data = await sharedWorker.schedule({ type: workerRequestMessageTypes.RESET_OUTBOUND_P2P_MESSAGES, deviceID, + newDeviceID, }); const messageIDs: ?$ReadOnlyArray = data?.messageIDs; return messageIDs ? [...messageIDs] : []; diff --git a/web/shared-worker/_generated/comm_query_executor.wasm b/web/shared-worker/_generated/comm_query_executor.wasm index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000 GIT binary patch literal 0 Hc$@ { + const deviceID = 'deviceID'; + const newDeviceID = 'newDeviceID'; + const MSG_TO_RESET_1: OutboundP2PMessage = { + messageID: 'reset-1a', + deviceID, + userID: 'user-1', + timestamp: '1', + plaintext: 'decrypted-1', + ciphertext: 'encrypted-1', + status: 'encrypted', + supportsAutoRetry: true, + }; + + const MSG_TO_RESET_2: OutboundP2PMessage = { + messageID: 'reset-2a', + deviceID, + userID: 'user-1', + timestamp: '1', + plaintext: 'decrypted-1', + ciphertext: 'encrypted-1', + status: 'sent', + supportsAutoRetry: false, + }; + + const MSG_NOT_RESET: OutboundP2PMessage = { + messageID: 'reset-3a', + deviceID, + userID: 'user-1', + timestamp: '3', + plaintext: 'decrypted-1', + ciphertext: 'encrypted-1', + status: 'encrypted', + supportsAutoRetry: false, + }; + + queryExecutor?.addOutboundP2PMessages([ + MSG_TO_RESET_1, + MSG_TO_RESET_2, + MSG_NOT_RESET, + ]); + + const messageIDs = queryExecutor?.resetOutboundP2PMessagesForDevice( + deviceID, + newDeviceID, + ); + + expect(messageIDs).toEqual([ + MSG_TO_RESET_1.messageID, + MSG_TO_RESET_2.messageID, + ]); + + const expectedMessagesAfterReset = [ + { + ...MSG_TO_RESET_1, + deviceID: newDeviceID, + status: outboundP2PMessageStatuses.persisted, + ciphertext: '', + supportsAutoRetry: true, + }, + { + ...MSG_TO_RESET_2, + deviceID: newDeviceID, + status: outboundP2PMessageStatuses.persisted, + ciphertext: '', + supportsAutoRetry: true, + }, + ]; + expect(queryExecutor?.getOutboundP2PMessagesByID(messageIDs ?? [])).toEqual( + expectedMessagesAfterReset, + ); + expect( + queryExecutor?.getOutboundP2PMessagesByID([MSG_NOT_RESET.messageID]), + ).toEqual([ + { + ...MSG_NOT_RESET, + deviceID: newDeviceID, + ciphertext: '', + status: outboundP2PMessageStatuses.persisted, + }, + ]); + }); }); diff --git a/web/shared-worker/types/sqlite-query-executor.js b/web/shared-worker/types/sqlite-query-executor.js --- a/web/shared-worker/types/sqlite-query-executor.js +++ b/web/shared-worker/types/sqlite-query-executor.js @@ -198,7 +198,10 @@ ciphertext: string, ): void; markOutboundP2PMessageAsSent(messageID: string, deviceID: string): void; - resetOutboundP2PMessagesForDevice(deviceID: string): $ReadOnlyArray; + resetOutboundP2PMessagesForDevice( + deviceID: string, + newDeviceID?: ?string, + ): $ReadOnlyArray; addInboundP2PMessage(message: InboundP2PMessage): void; getAllInboundP2PMessage(): $ReadOnlyArray; diff --git a/web/types/worker-types.js b/web/types/worker-types.js --- a/web/types/worker-types.js +++ b/web/types/worker-types.js @@ -234,6 +234,7 @@ export type ResetOutboundP2PMessagesRequestMessage = { +type: 24, +deviceID: string, + +newDeviceID?: ?string, }; export type FetchMessagesRequestMessage = {