Review notes for this patch (keyserver/src/fetchers/thread-fetchers.js).
These lines are a preamble and are not part of any hunk.

Purpose: add an accessibleToUserID filter and make fetchThreadInfos default
it to viewer.id. When it is set, rolesQuery and threadsQuery both read
FROM memberships mm joined to threads t, instead of FROM threads t.
constructWhereClause then adds "mm.user = <id> AND mm.role > -1".

NOTE(review): the join from memberships to threads is a LEFT JOIN. A
memberships row with no matching threads row would yield t.id = NULL in
both result sets. An inner JOIN would drop such rows. Whether orphaned
membership rows can exist is not visible here; confirm before keeping
LEFT JOIN.

NOTE(review): fetchThreadInfos builds { accessibleToUserID: viewer.id,
...inputFilter }. A caller that explicitly passes accessibleToUserID as
undefined (or any falsy value) overrides viewer.id. The membership
restriction is then silently dropped.

NOTE(review): with the default in place, callers of fetchThreadInfos only
get threads where the viewer has a memberships row with role > -1. Verify
that no call site relied on seeing other threads.

NOTE(review): filter is now always an object on this path. As a result,
fetchServerThreadInfos always calls constructWhereClause.

NOTE(review): the new condition spells the role check "mm.role > -1",
while the existing member join uses "m.role >= 0". The two are equivalent
only if role is an integer column (presumably it is). Consider using one
spelling.

NOTE(review): when reading through memberships, the threadID/threadIDs
filters match mm.thread rather than t.id. Per the ON clause, the two are
equal on every row where a threads row was found.

NOTE(review): this copy of the patch appears damaged. Newlines are
collapsed, and angle-bracketed type arguments look stripped (for example
"$ReadOnlySet," and "): Promise {"). It will not apply as-is; regenerate
it from the source commit.

diff --git a/keyserver/src/fetchers/thread-fetchers.js b/keyserver/src/fetchers/thread-fetchers.js --- a/keyserver/src/fetchers/thread-fetchers.js +++ b/keyserver/src/fetchers/thread-fetchers.js @@ -24,6 +24,7 @@ import type { Viewer } from '../session/viewer.js'; type FetchThreadInfosFilter = $Shape<{ + +accessibleToUserID: string, +threadID: string, +threadIDs: $ReadOnlySet, +parentThreadID: string, @@ -32,13 +33,25 @@ function constructWhereClause( filter: FetchThreadInfosFilter, ): SQLStatementType { + const fromTable = filter.accessibleToUserID ? 'memberships' : 'threads'; + const conditions = []; - if (filter.threadID) { + if (filter.accessibleToUserID) { + conditions.push( + SQL`mm.user = ${filter.accessibleToUserID} AND mm.role > -1`, + ); + } + + if (filter.threadID && fromTable === 'memberships') { + conditions.push(SQL`mm.thread = ${filter.threadID}`); + } else if (filter.threadID) { conditions.push(SQL`t.id = ${filter.threadID}`); } - if (filter.threadIDs) { + if (filter.threadIDs && fromTable === 'memberships') { + conditions.push(SQL`mm.thread IN (${[...filter.threadIDs]})`); + } else if (filter.threadIDs) { conditions.push(SQL`t.id IN (${[...filter.threadIDs]})`); } @@ -65,25 +78,46 @@ async function fetchServerThreadInfos( filter?: FetchThreadInfosFilter, ): Promise { + let primaryFetchClause; + if (filter?.accessibleToUserID) { + primaryFetchClause = SQL` + FROM memberships mm + LEFT JOIN threads t ON t.id = mm.thread + `; + } else { + primaryFetchClause = SQL` + FROM threads t + `; + } + const whereClause = filter ? 
constructWhereClause(filter) : ''; const rolesQuery = SQL` SELECT t.id, t.default_role, r.id AS role, r.name, r.permissions - FROM threads t - LEFT JOIN roles r ON r.thread = t.id - `.append(whereClause); + ` + .append(primaryFetchClause) + .append( + SQL` + LEFT JOIN roles r ON r.thread = t.id + `, + ) + .append(whereClause); const threadsQuery = SQL` - SELECT t.id, t.name, t.parent_thread_id, t.containing_thread_id, - t.community, t.depth, t.color, t.description, t.type, t.creation_time, - t.source_message, t.replies_count, t.avatar, t.pinned_count, m.user, - m.role, m.permissions, m.subscription, - m.last_read_message < m.last_message AS unread, m.sender, - up.id AS upload_id, up.secret AS upload_secret - FROM threads t - LEFT JOIN memberships m ON m.thread = t.id AND m.role >= 0 - LEFT JOIN uploads up ON up.container = t.id + SELECT t.id, t.name, t.parent_thread_id, t.containing_thread_id, + t.community, t.depth, t.color, t.description, t.type, t.creation_time, + t.source_message, t.replies_count, t.avatar, t.pinned_count, m.user, + m.role, m.permissions, m.subscription, + m.last_read_message < m.last_message AS unread, m.sender, + up.id AS upload_id, up.secret AS upload_secret ` + .append(primaryFetchClause) + .append( + SQL` + LEFT JOIN memberships m ON m.thread = t.id AND m.role >= 0 + LEFT JOIN uploads up ON up.container = t.id + `, + ) .append(whereClause) .append(SQL` ORDER BY m.user ASC`); const [[threadsResult], [rolesResult]] = await Promise.all([ @@ -191,8 +225,12 @@ async function fetchThreadInfos( viewer: Viewer, - filter?: FetchThreadInfosFilter, + inputFilter?: FetchThreadInfosFilter, ): Promise { + const filter = { + accessibleToUserID: viewer.id, + ...inputFilter, + }; const serverResult = await fetchServerThreadInfos(filter); return rawThreadInfosFromServerThreadInfos(viewer, serverResult); }