diff --git a/keyserver/src/deleters/thread-deleters.js b/keyserver/src/deleters/thread-deleters.js
--- a/keyserver/src/deleters/thread-deleters.js
+++ b/keyserver/src/deleters/thread-deleters.js
@@ -15,6 +15,7 @@
 import {
   fetchThreadInfos,
   fetchServerThreadInfos,
+  fetchContainedThreadIDs,
 } from '../fetchers/thread-fetchers.js';
 import { fetchThreadPermissionsBlob } from '../fetchers/thread-permission-fetchers.js';
 import { fetchUpdateInfoForThreadDeletion } from '../fetchers/update-fetchers.js';
@@ -30,11 +31,7 @@
   }
 
   const { threadID } = threadDeletionRequest;
-  const [permissionsBlob, { threadInfos: serverThreadInfos }] =
-    await Promise.all([
-      fetchThreadPermissionsBlob(viewer, threadID),
-      fetchServerThreadInfos(SQL`t.id = ${threadID}`),
-    ]);
+  const permissionsBlob = await fetchThreadPermissionsBlob(viewer, threadID);
 
   if (!permissionsBlob) {
     // This should only occur if the first request goes through but the client
@@ -60,14 +57,21 @@
     throw new ServerError('invalid_credentials');
   }
 
-  await rescindPushNotifs(
-    SQL`n.thread = ${threadID}`,
-    SQL`IF(m.thread = ${threadID}, NULL, m.thread)`,
-  );
-
-  // TODO: if org, delete all descendant threads as well. make sure to warn user
   // TODO: handle descendant thread permission update correctly.
   // thread-permission-updaters should be used for descendant threads.
+  // threadIDs holds threadID itself plus every thread nested beneath it, so
+  // the whole subtree is deleted together. NOTE(review): the permissions blob
+  // above is fetched for the root thread only; descendants are not checked.
+  const threadIDs = await fetchContainedThreadIDs(threadID);
+
+  const [{ threadInfos: serverThreadInfos }] = await Promise.all([
+    fetchServerThreadInfos(SQL`t.id IN (${threadIDs})`),
+    rescindPushNotifs(
+      SQL`n.thread IN (${threadIDs})`,
+      SQL`IF(m.thread IN (${threadIDs}), NULL, m.thread)`,
+    ),
+  ]);
+
   const query = SQL`
     DELETE t, ic, d, id, e, ie, re, ire, mm, r, ir, ms, im, up, iu, f, n, ino
     FROM threads t
@@ -88,19 +92,20 @@
     LEFT JOIN focused f ON f.thread = t.id
     LEFT JOIN notifications n ON n.thread = t.id
     LEFT JOIN ids ino ON ino.id = n.id
-    WHERE t.id = ${threadID}
+    WHERE t.id IN (${threadIDs})
   `;
 
-  const serverThreadInfo = serverThreadInfos[threadID];
   const time = Date.now();
   const updateDatas = [];
-  for (const memberInfo of serverThreadInfo.members) {
-    updateDatas.push({
-      type: updateTypes.DELETE_THREAD,
-      userID: memberInfo.id,
-      time,
-      threadID,
-    });
+  for (const containedThreadID of threadIDs) {
+    for (const memberInfo of serverThreadInfos[containedThreadID].members) {
+      updateDatas.push({
+        type: updateTypes.DELETE_THREAD,
+        userID: memberInfo.id,
+        time,
+        threadID: containedThreadID,
+      });
+    }
   }
 
   const [{ viewerUpdates }] = await Promise.all([
diff --git a/keyserver/src/fetchers/thread-fetchers.js b/keyserver/src/fetchers/thread-fetchers.js
--- a/keyserver/src/fetchers/thread-fetchers.js
+++ b/keyserver/src/fetchers/thread-fetchers.js
@@ -308,6 +308,29 @@
   return threads.threadInfos[threadID];
 }
 
+// Resolves to parentThreadID plus the IDs of every thread transitively
+// contained in it (following threads.containing_thread_id), or to an empty
+// array if parentThreadID does not exist. UNION (not UNION ALL) drops
+// duplicate rows, so the recursion also terminates if the data has a cycle.
+async function fetchContainedThreadIDs(
+  parentThreadID: string,
+): Promise<$ReadOnlyArray<string>> {
+  const query = SQL`
+    WITH RECURSIVE thread_tree AS (
+      SELECT id, containing_thread_id
+      FROM threads
+      WHERE id = ${parentThreadID}
+      UNION
+      SELECT t.id, t.containing_thread_id
+      FROM threads t
+      JOIN thread_tree tt ON t.containing_thread_id = tt.id
+    )
+    SELECT id FROM thread_tree
+  `;
+  const [result] = await dbQuery(query);
+  return result.map(row => row.id.toString());
+}
+
 export {
   fetchServerThreadInfos,
   fetchThreadInfos,
@@ -318,4 +341,5 @@
   personalThreadQuery,
   fetchPersonalThreadID,
   serverThreadInfoFromMessageInfo,
+  fetchContainedThreadIDs,
 };