diff --git a/lib/hooks/peer-list-hooks.js b/lib/hooks/peer-list-hooks.js --- a/lib/hooks/peer-list-hooks.js +++ b/lib/hooks/peer-list-hooks.js @@ -5,7 +5,6 @@ import { setPeerDeviceListsActionType } from '../actions/aux-user-actions.js'; import { - getAllPeerDevices, getAllPeerUserIDAndDeviceIDs, getPeersPrimaryDeviceIDs, } from '../selectors/user-selectors.js'; @@ -81,7 +80,7 @@ const dispatch = useDispatch(); const broadcastDeviceListUpdates = useBroadcastDeviceListUpdates(); - const allPeerDevices = useSelector(getAllPeerDevices); + const allPeerDevices = useSelector(getAllPeerUserIDAndDeviceIDs); const peerPrimaryDevices = useSelector(getPeersPrimaryDeviceIDs); return React.useCallback( @@ -103,21 +102,34 @@ return {}; } - const primaryDeviceChanges = userIDs - .map(userID => { - const prevPrimaryDeviceID = peerPrimaryDevices[userID]; - const newPrimaryDeviceID = result.deviceLists[userID]?.devices[0]; - if ( - !prevPrimaryDeviceID || - !newPrimaryDeviceID || - newPrimaryDeviceID === prevPrimaryDeviceID - ) { - return null; - } - - return { userID, prevPrimaryDeviceID, newPrimaryDeviceID }; - }) - .filter(Boolean); + const primaryDeviceChanges: Array = []; + const allRemovedDevices: Array = []; + for (const userID of userIDs) { + const peerDeviceList = result.deviceLists[userID]?.devices ??
[]; + + // detect primary device changes + const prevPrimaryDeviceID = peerPrimaryDevices[userID]; + const newPrimaryDeviceID = peerDeviceList[0]; + if ( + !!prevPrimaryDeviceID && + !!newPrimaryDeviceID && + newPrimaryDeviceID !== prevPrimaryDeviceID + ) { + primaryDeviceChanges.push({ + userID, + prevPrimaryDeviceID, + newPrimaryDeviceID, + }); + } + + // collect IDs of this peer's devices missing from the new list + const currentPeerDevicesSet = new Set(peerDeviceList); + const peerRemovedDevices = allPeerDevices.filter( + peer => + peer.userID === userID && !currentPeerDevicesSet.has(peer.deviceID), + ); + allRemovedDevices.push(...peerRemovedDevices.map(peer => peer.deviceID)); + } if (primaryDeviceChanges.length > 0) { try { @@ -127,6 +139,21 @@ } } + if (allRemovedDevices.length > 0) { + try { + const { sqliteAPI } = getConfig(); + const removalPromises = allRemovedDevices.map(deviceID => + sqliteAPI.removeAllOutboundP2PMessages(deviceID), + ); + await Promise.all(removalPromises); + } catch (err) { + console.warn( + 'Failed to clear outbound P2P messages for removed devices:', + err, + ); + } + } + dispatch({ type: setPeerDeviceListsActionType, payload: { deviceLists, usersPlatformDetails }, @@ -138,12 +165,14 @@ const thisDeviceID = await getContentSigningKey(); + const allPeerDeviceIDs = new Set( + allPeerDevices.map(peer => peer.deviceID), + ); const newDevices = values(deviceLists) - .map((deviceList: RawDeviceList) => deviceList.devices) - .flat() + .flatMap((deviceList: RawDeviceList) => deviceList.devices) .filter( deviceID => - !allPeerDevices.includes(deviceID) && deviceID !== thisDeviceID, + !allPeerDeviceIDs.has(deviceID) && deviceID !== thisDeviceID, ); await broadcastDeviceListUpdates(newDevices); diff --git a/lib/types/sqlite-types.js b/lib/types/sqlite-types.js --- a/lib/types/sqlite-types.js +++ b/lib/types/sqlite-types.js @@ -96,6 +96,7 @@ messageID: string, deviceID: string, ) => Promise, + +removeAllOutboundP2PMessages: (deviceID: string) => Promise, +removeLocalMessageInfos: (
includeNonLocalMessages: boolean, dbID: DatabaseIdentifier, diff --git a/lib/utils/__mocks__/config.js b/lib/utils/__mocks__/config.js --- a/lib/utils/__mocks__/config.js +++ b/lib/utils/__mocks__/config.js @@ -41,6 +41,7 @@ getUnsentOutboundP2PMessages: jest.fn(), markOutboundP2PMessageAsSent: jest.fn(), removeOutboundP2PMessage: jest.fn(), + removeAllOutboundP2PMessages: jest.fn(), resetOutboundP2PMessagesForDevice: jest.fn(), getRelatedMessages: jest.fn(), getOutboundP2PMessagesByID: jest.fn(), diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h @@ -254,6 +254,8 @@ jsi::Runtime &rt, jsi::String messageID, jsi::String deviceID) override; + virtual jsi::Value + removeAllOutboundP2PMessages(jsi::Runtime &rt, jsi::String deviceID) override; virtual jsi::Value resetOutboundP2PMessagesForDevice( jsi::Runtime &rt, jsi::String deviceID, diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp @@ -3006,6 +3006,34 @@ }); } +jsi::Value CommCoreModule::removeAllOutboundP2PMessages( + jsi::Runtime &rt, + jsi::String deviceID) { + auto deviceIDCpp{deviceID.utf8(rt)}; + + return createPromiseAsJSIValue( + rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { + taskType job = [=]() { + std::string error; + try { + DatabaseManager::getQueryExecutor().removeAllOutboundP2PMessages( + deviceIDCpp); + } catch (std::system_error &e) { + error = e.what(); + } + this->jsInvoker_->invokeAsync([error, promise]() { + if (error.size()) { + promise->reject(error); + } else { + promise->resolve(jsi::Value::undefined()); + } + }); + }; + GlobalDBSingleton::instance.scheduleOrRunCancellable( + job, promise, this->jsInvoker_); +
}); +} + jsi::Value CommCoreModule::resetOutboundP2PMessagesForDevice( jsi::Runtime &rt, jsi::String deviceID, diff --git a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp --- a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp +++ b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp @@ -205,6 +205,9 @@ static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_removeOutboundP2PMessage(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { return static_cast(&turboModule)->removeOutboundP2PMessage(rt, args[0].asString(rt), args[1].asString(rt)); } +static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_removeAllOutboundP2PMessages(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { + return static_cast(&turboModule)->removeAllOutboundP2PMessages(rt, args[0].asString(rt)); +} static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_resetOutboundP2PMessagesForDevice(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { return static_cast(&turboModule)->resetOutboundP2PMessagesForDevice(rt, args[0].asString(rt), args[1].isNull() || args[1].isUndefined() ?
std::nullopt : std::make_optional(args[1].asString(rt))); } @@ -312,6 +315,7 @@ methodMap_["getUnsentOutboundP2PMessages"] = MethodMetadata {0, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getUnsentOutboundP2PMessages}; methodMap_["markOutboundP2PMessageAsSent"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_markOutboundP2PMessageAsSent}; methodMap_["removeOutboundP2PMessage"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_removeOutboundP2PMessage}; + methodMap_["removeAllOutboundP2PMessages"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_removeAllOutboundP2PMessages}; methodMap_["resetOutboundP2PMessagesForDevice"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_resetOutboundP2PMessagesForDevice}; methodMap_["getDatabaseVersion"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getDatabaseVersion}; methodMap_["getSyncedMetadata"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getSyncedMetadata}; diff --git a/native/cpp/CommonCpp/_generated/commJSI.h b/native/cpp/CommonCpp/_generated/commJSI.h --- a/native/cpp/CommonCpp/_generated/commJSI.h +++ b/native/cpp/CommonCpp/_generated/commJSI.h @@ -82,6 +82,7 @@ virtual jsi::Value getUnsentOutboundP2PMessages(jsi::Runtime &rt) = 0; virtual jsi::Value markOutboundP2PMessageAsSent(jsi::Runtime &rt, jsi::String messageID, jsi::String deviceID) = 0; virtual jsi::Value removeOutboundP2PMessage(jsi::Runtime &rt, jsi::String messageID, jsi::String deviceID) = 0; + virtual jsi::Value removeAllOutboundP2PMessages(jsi::Runtime &rt, jsi::String deviceID) = 0; virtual jsi::Value resetOutboundP2PMessagesForDevice(jsi::Runtime &rt, jsi::String deviceID, std::optional newDeviceID) = 0; virtual jsi::Value getDatabaseVersion(jsi::Runtime &rt, jsi::String dbID) = 0; virtual jsi::Value getSyncedMetadata(jsi::Runtime &rt, jsi::String entryName, jsi::String dbID) = 0; @@ -613,6 +614,14 @@ return bridging::callFromJs( rt,
&T::removeOutboundP2PMessage, jsInvoker_, instance_, std::move(messageID), std::move(deviceID)); } + jsi::Value removeAllOutboundP2PMessages(jsi::Runtime &rt, jsi::String deviceID) override { + static_assert( + bridging::getParameterCount(&T::removeAllOutboundP2PMessages) == 2, + "Expected removeAllOutboundP2PMessages(...) to have 2 parameters"); + + return bridging::callFromJs( + rt, &T::removeAllOutboundP2PMessages, jsInvoker_, instance_, std::move(deviceID)); + } jsi::Value resetOutboundP2PMessagesForDevice(jsi::Runtime &rt, jsi::String deviceID, std::optional newDeviceID) override { static_assert( bridging::getParameterCount(&T::resetOutboundP2PMessagesForDevice) == 3, diff --git a/native/database/sqlite-api.js b/native/database/sqlite-api.js --- a/native/database/sqlite-api.js +++ b/native/database/sqlite-api.js @@ -44,6 +44,7 @@ resetOutboundP2PMessagesForDevice: commCoreModule.resetOutboundP2PMessagesForDevice, removeOutboundP2PMessage: commCoreModule.removeOutboundP2PMessage, + removeAllOutboundP2PMessages: commCoreModule.removeAllOutboundP2PMessages, async processDBStoreOperations( storeOperations: StoreOperations, diff --git a/native/schema/CommCoreModuleSchema.js b/native/schema/CommCoreModuleSchema.js --- a/native/schema/CommCoreModuleSchema.js +++ b/native/schema/CommCoreModuleSchema.js @@ -201,6 +201,7 @@ messageID: string, deviceID: string, ) => Promise; + +removeAllOutboundP2PMessages: (deviceID: string) => Promise; +resetOutboundP2PMessagesForDevice: ( deviceID: string, newDeviceID?: ?string, diff --git a/web/database/sqlite-api.js b/web/database/sqlite-api.js --- a/web/database/sqlite-api.js +++ b/web/database/sqlite-api.js @@ -232,6 +232,15 @@ }); }, + async removeAllOutboundP2PMessages(deviceID: string) { + const sharedWorker = await getCommSharedWorker(); + + await sharedWorker.schedule({ + type: workerRequestMessageTypes.REMOVE_ALL_OUTBOUND_P2P_MESSAGES, + deviceID, + }); + }, + async processDBStoreOperations( storeOperations: StoreOperations,
dbID: DatabaseIdentifier, diff --git a/web/shared-worker/worker/shared-worker.js b/web/shared-worker/worker/shared-worker.js --- a/web/shared-worker/worker/shared-worker.js +++ b/web/shared-worker/worker/shared-worker.js @@ -588,6 +588,10 @@ message.messageID, message.deviceID, ); + } else if ( + message.type === workerRequestMessageTypes.REMOVE_ALL_OUTBOUND_P2P_MESSAGES + ) { + sqliteQueryExecutor.removeAllOutboundP2PMessages(message.deviceID); } else if ( message.type === workerRequestMessageTypes.RESET_OUTBOUND_P2P_MESSAGES ) { diff --git a/web/types/worker-types.js b/web/types/worker-types.js --- a/web/types/worker-types.js +++ b/web/types/worker-types.js @@ -62,6 +62,7 @@ GET_HOLDERS: 32, REMOVE_LOCAL_MESSAGE_INFOS: 33, GET_AUX_USER_INFOS: 34, + REMOVE_ALL_OUTBOUND_P2P_MESSAGES: 35, }); export const workerWriteRequests: $ReadOnlyArray = [ @@ -78,6 +79,7 @@ workerRequestMessageTypes.MIGRATE_BACKUP_SCHEMA, workerRequestMessageTypes.COPY_CONTENT_FROM_BACKUP_DB, workerRequestMessageTypes.REMOVE_LOCAL_MESSAGE_INFOS, + workerRequestMessageTypes.REMOVE_ALL_OUTBOUND_P2P_MESSAGES, ]; export const workerOlmAPIRequests: $ReadOnlyArray = [ @@ -289,6 +291,11 @@ +dbID: DatabaseIdentifier, }; +export type RemoveAllOutboundP2PMessagesRequestMessage = { + +type: 35, + +deviceID: string, +}; + export type WorkerRequestMessage = | PingWorkerRequestMessage | InitWorkerRequestMessage @@ -324,7 +331,8 @@ | GetSyncedMetadataRequestMessage | GetHoldersRequestMessage | RemoveLocalMessageInfosRequestMessage - | GetAuxUserInfosRequestMessage; + | GetAuxUserInfosRequestMessage + | RemoveAllOutboundP2PMessagesRequestMessage; export type WorkerRequestProxyMessage = { +id: number,