diff --git a/keyserver/src/creators/account-creator.js b/keyserver/src/creators/account-creator.js index 881924a25..a5fb30d25 100644 --- a/keyserver/src/creators/account-creator.js +++ b/keyserver/src/creators/account-creator.js @@ -1,315 +1,319 @@ // @flow import { getRustAPI } from 'rust-node-addon'; import bcrypt from 'twin-bcrypt'; import ashoat from 'lib/facts/ashoat.js'; import bots from 'lib/facts/bots.js'; import genesis from 'lib/facts/genesis.js'; import { policyTypes } from 'lib/facts/policies.js'; import { validUsernameRegex } from 'lib/shared/account-utils.js'; import type { RegisterResponse, RegisterRequest, } from 'lib/types/account-types.js'; import type { ReservedUsernameMessage, SignedIdentityKeysBlob, } from 'lib/types/crypto-types.js'; import type { PlatformDetails, DeviceTokenUpdateRequest, } from 'lib/types/device-types.js'; import type { CalendarQuery } from 'lib/types/entry-types.js'; import { messageTypes } from 'lib/types/message-types-enum.js'; import type { SIWESocialProof } from 'lib/types/siwe-types.js'; import { threadTypes } from 'lib/types/thread-types-enum.js'; import { ServerError } from 'lib/utils/errors.js'; import { values } from 'lib/utils/objects.js'; +import { ignorePromiseRejections } from 'lib/utils/promises.js'; import { reservedUsernamesSet } from 'lib/utils/reserved-users.js'; import { isValidEthereumAddress } from 'lib/utils/siwe-utils.js'; import createIDs from './id-creator.js'; import createMessages from './message-creator.js'; import { createOlmSession } from './olm-session-creator.js'; import { createThread, createPrivateThread, privateThreadDescription, } from './thread-creator.js'; import { dbQuery, SQL } from '../database/database.js'; import { deleteCookie } from '../deleters/cookie-deleters.js'; import { fetchThreadInfos } from '../fetchers/thread-fetchers.js'; import { fetchLoggedInUserInfo, fetchKnownUserInfos, } from '../fetchers/user-fetchers.js'; import { verifyCalendarQueryThreadIDs } from '../responders/entry-responders.js'; -import { handleAsyncPromise } from '../responders/handlers.js'; import { searchForUser } from '../search/users.js'; import { createNewUserCookie, setNewSession } from '../session/cookies.js'; import { createScriptViewer } from '../session/scripts.js'; import type { Viewer } from '../session/viewer.js'; import { fetchOlmAccount } from '../updaters/olm-account-updater.js'; import { updateThread } from '../updaters/thread-updaters.js'; import { viewerAcknowledgmentUpdater } from '../updaters/viewer-acknowledgment-updater.js'; const { commbot } = bots; const ashoatMessages = [ 'welcome to Comm!', 'as you inevitably discover bugs, have feature requests, or design ' + 'suggestions, feel free to message them to me in the app.', ]; const privateMessages = [privateThreadDescription]; async function createAccount( viewer: Viewer, request: RegisterRequest, ): Promise { if (request.password.trim() === '') { throw new ServerError('empty_password'); } if (request.username.search(validUsernameRegex) === -1) { throw new ServerError('invalid_username'); } const promises = [searchForUser(request.username)]; const { calendarQuery, signedIdentityKeysBlob, initialNotificationsEncryptedMessage, } = request; if (calendarQuery) { promises.push(verifyCalendarQueryThreadIDs(calendarQuery)); } const [existingUser] = await Promise.all(promises); if ( reservedUsernamesSet.has(request.username.toLowerCase()) || isValidEthereumAddress(request.username.toLowerCase()) ) { throw new ServerError('username_reserved'); } if (existingUser) { throw new ServerError('username_taken'); } const hash = bcrypt.hashSync(request.password); const time = Date.now(); const deviceToken = request.deviceTokenUpdateRequest ? request.deviceTokenUpdateRequest.deviceToken : viewer.deviceToken; const [id] = await createIDs('users', 1); const newUserRow = [id, request.username, hash, time]; const newUserQuery = SQL` INSERT INTO users(id, username, hash, creation_time) VALUES ${[newUserRow]} `; const [userViewerData] = await Promise.all([ createNewUserCookie(id, { platformDetails: request.platformDetails, deviceToken, signedIdentityKeysBlob, }), deleteCookie(viewer.cookieID), dbQuery(newUserQuery), ]); viewer.setNewCookie(userViewerData); if (calendarQuery) { await setNewSession(viewer, calendarQuery, 0); } const olmSessionPromise = (async () => { if (userViewerData.cookieID && initialNotificationsEncryptedMessage) { await createOlmSession( initialNotificationsEncryptedMessage, 'notifications', userViewerData.cookieID, ); } })(); await Promise.all([ updateThread( createScriptViewer(ashoat.id), { threadID: genesis.id, changes: { newMemberIDs: [id] }, }, { forceAddMembers: true, silenceMessages: true, ignorePermissions: true }, ), viewerAcknowledgmentUpdater(viewer, policyTypes.tosAndPrivacyPolicy), olmSessionPromise, ]); const [privateThreadResult, ashoatThreadResult] = await Promise.all([ createPrivateThread(viewer), createThread( viewer, { type: threadTypes.PERSONAL, initialMemberIDs: [ashoat.id], }, { forceAddMembers: true }, ), ]); const ashoatThreadID = ashoatThreadResult.newThreadID; const privateThreadID = privateThreadResult.newThreadID; let messageTime = Date.now(); const ashoatMessageDatas = ashoatMessages.map(message => ({ type: messageTypes.TEXT, threadID: ashoatThreadID, creatorID: ashoat.id, time: messageTime++, text: message, })); const privateMessageDatas = privateMessages.map(message => ({ type: messageTypes.TEXT, threadID: privateThreadID, creatorID: commbot.userID, time: messageTime++, text: message, })); const messageDatas = [...ashoatMessageDatas, ...privateMessageDatas]; const [messageInfos, threadsResult, userInfos, currentUserInfo] = await Promise.all([ createMessages(viewer, messageDatas), fetchThreadInfos(viewer), fetchKnownUserInfos(viewer), fetchLoggedInUserInfo(viewer), ]); const rawMessageInfos = [ ...ashoatThreadResult.newMessageInfos, ...privateThreadResult.newMessageInfos, ...messageInfos, ]; - handleAsyncPromise(createAndSendReservedUsernameMessage([request.username])); + ignorePromiseRejections( + createAndSendReservedUsernameMessage([request.username]), + ); return { id, rawMessageInfos, currentUserInfo, cookieChange: { threadInfos: threadsResult.threadInfos, userInfos: values(userInfos), }, }; } export type ProcessSIWEAccountCreationRequest = { +address: string, +calendarQuery: CalendarQuery, +deviceTokenUpdateRequest?: ?DeviceTokenUpdateRequest, +platformDetails: PlatformDetails, +socialProof: SIWESocialProof, +signedIdentityKeysBlob?: ?SignedIdentityKeysBlob, }; // Note: `processSIWEAccountCreation(...)` assumes that the validity of // `ProcessSIWEAccountCreationRequest` was checked at call site. async function processSIWEAccountCreation( viewer: Viewer, request: ProcessSIWEAccountCreationRequest, ): Promise { const { calendarQuery, signedIdentityKeysBlob } = request; await verifyCalendarQueryThreadIDs(calendarQuery); const time = Date.now(); const deviceToken = request.deviceTokenUpdateRequest ? request.deviceTokenUpdateRequest.deviceToken : viewer.deviceToken; const [id] = await createIDs('users', 1); const newUserRow = [id, request.address, request.address, time]; const newUserQuery = SQL` INSERT INTO users(id, username, ethereum_address, creation_time) VALUES ${[newUserRow]} `; const [userViewerData] = await Promise.all([ createNewUserCookie(id, { platformDetails: request.platformDetails, deviceToken, socialProof: request.socialProof, signedIdentityKeysBlob, }), deleteCookie(viewer.cookieID), dbQuery(newUserQuery), ]); viewer.setNewCookie(userViewerData); await setNewSession(viewer, calendarQuery, 0); await Promise.all([ updateThread( createScriptViewer(ashoat.id), { threadID: genesis.id, changes: { newMemberIDs: [id] }, }, { forceAddMembers: true, silenceMessages: true, ignorePermissions: true }, ), viewerAcknowledgmentUpdater(viewer, policyTypes.tosAndPrivacyPolicy), ]); const [privateThreadResult, ashoatThreadResult] = await Promise.all([ createPrivateThread(viewer), createThread( viewer, { type: threadTypes.PERSONAL, initialMemberIDs: [ashoat.id], }, { forceAddMembers: true }, ), ]); const ashoatThreadID = ashoatThreadResult.newThreadID; const privateThreadID = privateThreadResult.newThreadID; let messageTime = Date.now(); const ashoatMessageDatas = ashoatMessages.map(message => ({ type: messageTypes.TEXT, threadID: ashoatThreadID, creatorID: ashoat.id, time: messageTime++, text: message, })); const privateMessageDatas = privateMessages.map(message => ({ type: messageTypes.TEXT, threadID: privateThreadID, creatorID: commbot.userID, time: messageTime++, text: message, })); const messageDatas = [...ashoatMessageDatas, ...privateMessageDatas]; await Promise.all([createMessages(viewer, messageDatas)]); - handleAsyncPromise(createAndSendReservedUsernameMessage([request.address])); + ignorePromiseRejections( + createAndSendReservedUsernameMessage([request.address]), + ); return id; } async function createAndSendReservedUsernameMessage( payload: $ReadOnlyArray, ) { const issuedAt = new Date().toISOString(); const reservedUsernameMessage: ReservedUsernameMessage = { statement: 'Add the following usernames to reserved list', payload, issuedAt, }; const stringifiedMessage = JSON.stringify(reservedUsernameMessage); const [rustAPI, accountInfo] = await Promise.all([ getRustAPI(), fetchOlmAccount('content'), ]); const signature = accountInfo.account.sign(stringifiedMessage); await rustAPI.addReservedUsernames(stringifiedMessage, signature); } export { createAccount, processSIWEAccountCreation }; diff --git a/keyserver/src/creators/message-creator.js b/keyserver/src/creators/message-creator.js index 636108899..dbd49db7a 100644 --- a/keyserver/src/creators/message-creator.js +++ b/keyserver/src/creators/message-creator.js @@ -1,701 +1,700 @@ // @flow import invariant from 'invariant'; import _pickBy from 'lodash/fp/pickBy.js'; import { permissionLookup } from 'lib/permissions/thread-permissions.js'; import { rawMessageInfoFromMessageData, shimUnsupportedRawMessageInfos, stripLocalIDs, } from 'lib/shared/message-utils.js'; import { pushTypes } from 'lib/shared/messages/message-spec.js'; import type { PushType } from 'lib/shared/messages/message-spec.js'; import { messageSpecs } from 'lib/shared/messages/message-specs.js'; import { messageTypes } from 'lib/types/message-types-enum.js'; import { messageDataLocalID, type MessageData, type RawMessageInfo, } from 'lib/types/message-types.js'; import { redisMessageTypes } from 'lib/types/redis-types.js'; import { threadPermissions } from 'lib/types/thread-permission-types.js'; import { updateTypes } from 'lib/types/update-types-enum.js'; -import { promiseAll } from 'lib/utils/promises.js'; +import { promiseAll, ignorePromiseRejections } from 'lib/utils/promises.js'; import createIDs from './id-creator.js'; import type { UpdatesForCurrentSession } from './update-creator.js'; import { createUpdates } from './update-creator.js'; import { dbQuery, SQL, appendSQLArray, mergeOrConditions, } from '../database/database.js'; import { processMessagesForSearch } from '../database/search-utils.js'; import { fetchMessageInfoForLocalID, fetchMessageInfoByID, } from '../fetchers/message-fetchers.js'; import { fetchOtherSessionsForViewer } from '../fetchers/session-fetchers.js'; import { fetchServerThreadInfos } from '../fetchers/thread-fetchers.js'; import type { Device, PushUserInfo } from '../push/send.js'; import { sendPushNotifs, sendRescindNotifs } from '../push/send.js'; -import { handleAsyncPromise } from '../responders/handlers.js'; import type { Viewer } from '../session/viewer.js'; import { earliestFocusedTimeConsideredExpired } from '../shared/focused-times.js'; import { publisher } from '../socket/redis.js'; import { creationString } from '../utils/idempotent.js'; type UserThreadInfo = { +devices: Map, +threadIDs: Set, +notFocusedThreadIDs: Set, +userNotMemberOfSubthreads: Set, +subthreadsCanSetToUnread: Set, }; type LatestMessagesPerUser = Map< string, $ReadOnlyMap< string, { +latestMessage: string, +latestReadMessage?: string, }, >, >; type LatestMessages = $ReadOnlyArray<{ +userID: string, +threadID: string, +latestMessage: string, +latestReadMessage: ?string, }>; // Does not do permission checks! (checkThreadPermission) async function createMessages( viewer: Viewer, messageDatas: $ReadOnlyArray, updatesForCurrentSession?: UpdatesForCurrentSession = 'return', ): Promise { if (messageDatas.length === 0) { return []; } const existingMessages = await Promise.all( messageDatas.map(messageData => fetchMessageInfoForLocalID(viewer, messageDataLocalID(messageData)), ), ); const existingMessageInfos: RawMessageInfo[] = []; const newMessageDatas: MessageData[] = []; for (let i = 0; i < messageDatas.length; i++) { const existingMessage = existingMessages[i]; if (existingMessage) { existingMessageInfos.push(existingMessage); } else { newMessageDatas.push(messageDatas[i]); } } if (newMessageDatas.length === 0) { return shimUnsupportedRawMessageInfos( existingMessageInfos, viewer.platformDetails, ); } const ids = await createIDs('messages', newMessageDatas.length); const returnMessageInfos: RawMessageInfo[] = []; const subthreadPermissionsToCheck: Set = new Set(); const messageInsertRows = []; // Indices in threadsToMessageIndices point to newMessageInfos const newMessageInfos: RawMessageInfo[] = []; const threadsToMessageIndices: Map = new Map(); let nextNewMessageIndex = 0; for (let i = 0; i < messageDatas.length; i++) { const existingMessage = existingMessages[i]; if (existingMessage) { returnMessageInfos.push(existingMessage); continue; } const messageData = messageDatas[i]; const threadID = messageData.threadID; const creatorID = messageData.creatorID; let messageIndices = threadsToMessageIndices.get(threadID); if (!messageIndices) { messageIndices = []; threadsToMessageIndices.set(threadID, messageIndices); } const newMessageIndex = nextNewMessageIndex++; messageIndices.push(newMessageIndex); const serverID = ids[newMessageIndex]; if (messageData.type === messageTypes.CREATE_SUB_THREAD) { subthreadPermissionsToCheck.add(messageData.childThreadID); } const content = messageSpecs[messageData.type].messageContentForServerDB?.(messageData); const creation = messageData.localID && viewer.hasSessionInfo ? creationString(viewer, messageData.localID) : null; let targetMessageID = null; if (messageData.targetMessageID) { targetMessageID = messageData.targetMessageID; } else if (messageData.sourceMessage) { targetMessageID = messageData.sourceMessage.id; } messageInsertRows.push([ serverID, threadID, creatorID, messageData.type, content, messageData.time, creation, targetMessageID, ]); const rawMessageInfo = rawMessageInfoFromMessageData(messageData, serverID); newMessageInfos.push(rawMessageInfo); // at newMessageIndex returnMessageInfos.push(rawMessageInfo); // at i } const messageInsertQuery = SQL` INSERT INTO messages(id, thread, user, type, content, time, creation, target_message) VALUES ${messageInsertRows} `; await dbQuery(messageInsertQuery); const postMessageSendPromise = postMessageSend( viewer, threadsToMessageIndices, subthreadPermissionsToCheck, stripLocalIDs(newMessageInfos), newMessageDatas, updatesForCurrentSession, ); if (!viewer.isScriptViewer) { // If we're not being called from a script, then we avoid awaiting // postMessageSendPromise below so that we don't delay the response to the - // user on external services. In that case, we use handleAsyncPromise to - // make sure any exceptions are caught and logged. - handleAsyncPromise(postMessageSendPromise); + // user on external services. In that case, we use ignorePromiseRejections + // to make sure any exceptions are caught and logged. + ignorePromiseRejections(postMessageSendPromise); } await Promise.all([ updateRepliesCount(threadsToMessageIndices, newMessageDatas), viewer.isScriptViewer ? postMessageSendPromise : undefined, ]); if (updatesForCurrentSession !== 'return') { return []; } return shimUnsupportedRawMessageInfos( returnMessageInfos, viewer.platformDetails, ); } async function updateRepliesCount( threadsToMessageIndices: Map, newMessageDatas: MessageData[], ) { const updatedThreads = []; const updateThreads = SQL` UPDATE threads SET replies_count = replies_count + (CASE `; const membershipConditions = []; for (const [threadID, messages] of threadsToMessageIndices.entries()) { const newRepliesIncrease = messages .map(i => newMessageDatas[i].type) .filter(type => messageSpecs[type].includedInRepliesCount).length; if (newRepliesIncrease === 0) { continue; } updateThreads.append(SQL` WHEN id = ${threadID} THEN ${newRepliesIncrease} `); updatedThreads.push(threadID); const senders = messages.map(i => newMessageDatas[i].creatorID); membershipConditions.push( SQL`thread = ${threadID} AND user IN (${senders})`, ); } updateThreads.append(SQL` ELSE 0 END) WHERE id IN (${updatedThreads}) AND source_message IS NOT NULL `); const updateMemberships = SQL` UPDATE memberships SET sender = 1 WHERE sender = 0 AND ( `; updateMemberships.append(mergeOrConditions(membershipConditions)); updateMemberships.append(SQL` ) `); if (updatedThreads.length > 0) { const [{ threadInfos: serverThreadInfos }] = await Promise.all([ fetchServerThreadInfos({ threadIDs: new Set(updatedThreads) }), dbQuery(updateThreads), dbQuery(updateMemberships), ]); const time = Date.now(); const updates = []; for (const threadID in serverThreadInfos) { for (const member of serverThreadInfos[threadID].members) { updates.push({ userID: member.id, time, threadID, type: updateTypes.UPDATE_THREAD, }); } } await createUpdates(updates); } } // Handles: // (1) Sending push notifs // (2) Setting threads to unread and generating corresponding UpdateInfos // (3) Publishing to Redis so that active sockets pass on new messages // (4) Processing messages for search async function postMessageSend( viewer: Viewer, threadsToMessageIndices: Map, subthreadPermissionsToCheck: Set, messageInfos: RawMessageInfo[], messageDatas: MessageData[], updatesForCurrentSession: UpdatesForCurrentSession, ) { const processForSearch = processMessagesForSearch(messageInfos); let joinIndex = 0; let subthreadSelects = ''; const subthreadJoins = []; for (const subthread of subthreadPermissionsToCheck) { const index = joinIndex++; subthreadSelects += ` , stm${index}.permissions AS subthread${subthread}_permissions, stm${index}.role AS subthread${subthread}_role `; const join = SQL`LEFT JOIN memberships `; join.append(`stm${index} ON stm${index}.`); join.append(SQL`thread = ${subthread} AND `); join.append(`stm${index}.user = m.user`); subthreadJoins.push(join); } const time = earliestFocusedTimeConsideredExpired(); const visibleExtractString = `$.${threadPermissions.VISIBLE}.value`; const query = SQL` SELECT m.user, m.thread, c.platform, c.device_token, c.versions, c.id, f.user AS focused_user `; query.append(subthreadSelects); query.append(SQL` FROM memberships m LEFT JOIN cookies c ON c.user = m.user AND c.device_token IS NOT NULL LEFT JOIN focused f ON f.user = m.user AND f.thread = m.thread AND f.time > ${time} `); appendSQLArray(query, subthreadJoins, SQL` `); query.append(SQL` WHERE (m.role > 0 OR f.user IS NOT NULL) AND JSON_EXTRACT(m.permissions, ${visibleExtractString}) IS TRUE AND m.thread IN (${[...threadsToMessageIndices.keys()]}) `); const perUserInfo = new Map(); const [result] = await dbQuery(query); for (const row of result) { const userID = row.user.toString(); const threadID = row.thread.toString(); const deviceToken = row.device_token; const focusedUser = !!row.focused_user; const { platform } = row; const versions = JSON.parse(row.versions); const cookieID = row.id; let thisUserInfo = perUserInfo.get(userID); if (!thisUserInfo) { thisUserInfo = { devices: new Map(), threadIDs: new Set(), notFocusedThreadIDs: new Set(), userNotMemberOfSubthreads: new Set(), subthreadsCanSetToUnread: new Set(), }; perUserInfo.set(userID, thisUserInfo); // Subthread info will be the same for each subthread, so we only parse // it once for (const subthread of subthreadPermissionsToCheck) { const isSubthreadMember = row[`subthread${subthread}_role`] > 0; const rawSubthreadPermissions = row[`subthread${subthread}_permissions`]; const subthreadPermissions = JSON.parse(rawSubthreadPermissions); const canSeeSubthread = permissionLookup( subthreadPermissions, threadPermissions.KNOW_OF, ); if (!canSeeSubthread) { continue; } thisUserInfo.subthreadsCanSetToUnread.add(subthread); // Only include the notification from the superthread if there is no // notification from the subthread if ( !isSubthreadMember || !permissionLookup(subthreadPermissions, threadPermissions.VISIBLE) ) { thisUserInfo.userNotMemberOfSubthreads.add(subthread); } } } if (deviceToken && cookieID) { thisUserInfo.devices.set(deviceToken, { platform, deviceToken, cookieID: cookieID.toString(), codeVersion: versions ? versions.codeVersion : null, stateVersion: versions ? versions.stateVersion : null, majorDesktopVersion: versions ? versions.majorDesktopVersion : null, }); } thisUserInfo.threadIDs.add(threadID); if (!focusedUser) { thisUserInfo.notFocusedThreadIDs.add(threadID); } } const messageInfosPerUser: { [userID: string]: $ReadOnlyArray, } = {}; const latestMessagesPerUser: LatestMessagesPerUser = new Map(); const userPushInfoPromises: { [string]: Promise } = {}; const userRescindInfoPromises: { [string]: Promise } = {}; for (const pair of perUserInfo) { const [userID, preUserPushInfo] = pair; const userMessageInfos = []; for (const threadID of preUserPushInfo.threadIDs) { const messageIndices = threadsToMessageIndices.get(threadID); invariant(messageIndices, `indices should exist for thread ${threadID}`); for (const messageIndex of messageIndices) { const messageInfo = messageInfos[messageIndex]; userMessageInfos.push(messageInfo); } } if (userMessageInfos.length > 0) { messageInfosPerUser[userID] = userMessageInfos; } latestMessagesPerUser.set( userID, determineLatestMessagesPerThread( preUserPushInfo, userID, threadsToMessageIndices, messageInfos, ), ); const { userNotMemberOfSubthreads } = preUserPushInfo; const userDevices = [...preUserPushInfo.devices.values()]; if (userDevices.length === 0) { continue; } const generateNotifUserInfoPromise = async (pushType: PushType) => { const promises: Array< Promise, > = []; for (const threadID of preUserPushInfo.notFocusedThreadIDs) { const messageIndices = threadsToMessageIndices.get(threadID); invariant( messageIndices, `indices should exist for thread ${threadID}`, ); promises.push( ...messageIndices.map(async messageIndex => { const messageInfo = messageInfos[messageIndex]; const { type } = messageInfo; if (messageInfo.creatorID === userID) { // We never send a user notifs about their own activity return undefined; } const { generatesNotifs } = messageSpecs[type]; const messageData = messageDatas[messageIndex]; if (!generatesNotifs) { return undefined; } const doesGenerateNotif = await generatesNotifs( messageInfo, messageData, { notifTargetUserID: userID, userNotMemberOfSubthreads, fetchMessageInfoByID: (messageID: string) => fetchMessageInfoByID(viewer, messageID), }, ); return doesGenerateNotif === pushType ? { messageInfo, messageData } : undefined; }), ); } const messagesToNotify = await Promise.all(promises); const filteredMessagesToNotify = messagesToNotify.filter(Boolean); if (filteredMessagesToNotify.length === 0) { return undefined; } return { devices: userDevices, messageInfos: filteredMessagesToNotify.map( ({ messageInfo }) => messageInfo, ), messageDatas: filteredMessagesToNotify.map( ({ messageData }) => messageData, ), }; }; const userPushInfoPromise = generateNotifUserInfoPromise(pushTypes.NOTIF); const userRescindInfoPromise = generateNotifUserInfoPromise( pushTypes.RESCIND, ); userPushInfoPromises[userID] = userPushInfoPromise; userRescindInfoPromises[userID] = userRescindInfoPromise; } const latestMessages = flattenLatestMessagesPerUser(latestMessagesPerUser); const [pushInfo, rescindInfo] = await Promise.all([ promiseAll(userPushInfoPromises), promiseAll(userRescindInfoPromises), createReadStatusUpdates(latestMessages), redisPublish(viewer, messageInfosPerUser, updatesForCurrentSession), updateLatestMessages(latestMessages), processForSearch, ]); await Promise.all([ sendPushNotifs(_pickBy(Boolean)(pushInfo)), sendRescindNotifs(_pickBy(Boolean)(rescindInfo)), ]); } async function redisPublish( viewer: Viewer, messageInfosPerUser: { [userID: string]: $ReadOnlyArray }, updatesForCurrentSession: UpdatesForCurrentSession, ) { const avoidBroadcastingToCurrentSession = viewer.hasSessionInfo && updatesForCurrentSession !== 'broadcast'; for (const userID in messageInfosPerUser) { if (userID === viewer.userID && avoidBroadcastingToCurrentSession) { continue; } const messageInfos = messageInfosPerUser[userID]; publisher.sendMessage( { userID }, { type: redisMessageTypes.NEW_MESSAGES, messages: messageInfos, }, ); } const viewerMessageInfos = messageInfosPerUser[viewer.userID]; if (!viewerMessageInfos || !avoidBroadcastingToCurrentSession) { return; } const sessionIDs = await fetchOtherSessionsForViewer(viewer); for (const sessionID of sessionIDs) { publisher.sendMessage( { userID: viewer.userID, sessionID }, { type: redisMessageTypes.NEW_MESSAGES, messages: viewerMessageInfos, }, ); } } type LatestMessagePerThread = { +latestMessage: string, +latestReadMessage?: string, }; function determineLatestMessagesPerThread( preUserPushInfo: UserThreadInfo, userID: string, threadsToMessageIndices: $ReadOnlyMap>, messageInfos: $ReadOnlyArray, ): $ReadOnlyMap { const { threadIDs, notFocusedThreadIDs, subthreadsCanSetToUnread } = preUserPushInfo; const latestMessagesPerThread = new Map(); for (const threadID of threadIDs) { const messageIndices = threadsToMessageIndices.get(threadID); invariant(messageIndices, `indices should exist for thread ${threadID}`); for (const messageIndex of messageIndices) { const messageInfo = messageInfos[messageIndex]; if ( messageInfo.type === messageTypes.CREATE_SUB_THREAD && !subthreadsCanSetToUnread.has(messageInfo.childThreadID) ) { continue; } const messageID = messageInfo.id; invariant( messageID, 'message ID should exist in determineLatestMessagesPerThread', ); if ( notFocusedThreadIDs.has(threadID) && messageInfo.creatorID !== userID ) { latestMessagesPerThread.set(threadID, { latestMessage: messageID, }); } else { latestMessagesPerThread.set(threadID, { latestMessage: messageID, latestReadMessage: messageID, }); } } } return latestMessagesPerThread; } function flattenLatestMessagesPerUser( latestMessagesPerUser: LatestMessagesPerUser, ): LatestMessages { const result = []; for (const [userID, latestMessagesPerThread] of latestMessagesPerUser) { for (const [threadID, latestMessages] of latestMessagesPerThread) { result.push({ userID, threadID, latestMessage: latestMessages.latestMessage, latestReadMessage: latestMessages.latestReadMessage, }); } } return result; } async function createReadStatusUpdates(latestMessages: LatestMessages) { const now = Date.now(); const readStatusUpdates = latestMessages .filter(message => !message.latestReadMessage) .map(({ userID, threadID }) => ({ type: updateTypes.UPDATE_THREAD_READ_STATUS, userID, time: now, threadID, unread: true, })); if (readStatusUpdates.length === 0) { return; } await createUpdates(readStatusUpdates); } function updateLatestMessages(latestMessages: LatestMessages) { if (latestMessages.length === 0) { return; } const query = SQL` UPDATE memberships SET `; const lastMessageExpression = SQL` last_message = GREATEST(last_message, CASE `; const lastReadMessageExpression = SQL` , last_read_message = GREATEST(last_read_message, CASE `; let shouldUpdateLastReadMessage = false; for (const { userID, threadID, latestMessage, latestReadMessage, } of latestMessages) { lastMessageExpression.append(SQL` WHEN user = ${userID} AND thread = ${threadID} THEN ${latestMessage} `); if (latestReadMessage) { shouldUpdateLastReadMessage = true; lastReadMessageExpression.append(SQL` WHEN user = ${userID} AND thread = ${threadID} THEN ${latestReadMessage} `); } } lastMessageExpression.append(SQL` ELSE last_message END) `); lastReadMessageExpression.append(SQL` ELSE last_read_message END) `); const conditions = latestMessages.map( ({ userID, threadID }) => SQL`(user = ${userID} AND thread = ${threadID})`, ); query.append(lastMessageExpression); if (shouldUpdateLastReadMessage) { query.append(lastReadMessageExpression); } query.append(SQL`WHERE `); query.append(mergeOrConditions(conditions)); dbQuery(query); } export default createMessages; diff --git a/keyserver/src/creators/report-creator.js b/keyserver/src/creators/report-creator.js index 038bfc734..0c76d70fe 100644 --- a/keyserver/src/creators/report-creator.js +++ b/keyserver/src/creators/report-creator.js @@ -1,238 +1,238 @@ // @flow import _isEqual from 'lodash/fp/isEqual.js'; import bots from 'lib/facts/bots.js'; import { filterRawEntryInfosByCalendarQuery, serverEntryInfosObject, } from 'lib/shared/entry-utils.js'; import { messageTypes } from 'lib/types/message-types-enum.js'; import { type ReportCreationRequest, type ReportCreationResponse, type ThreadInconsistencyReportCreationRequest, type EntryInconsistencyReportCreationRequest, type UserInconsistencyReportCreationRequest, reportTypes, } from 'lib/types/report-types.js'; import { values } from 'lib/utils/objects.js'; +import { ignorePromiseRejections } from 'lib/utils/promises.js'; import { sanitizeReduxReport, type ReduxCrashReport, } from 'lib/utils/sanitization.js'; import createIDs from './id-creator.js'; import createMessages from './message-creator.js'; import { dbQuery, SQL } from '../database/database.js'; import { fetchUsername } from '../fetchers/user-fetchers.js'; -import { handleAsyncPromise } from '../responders/handlers.js'; import { createBotViewer } from '../session/bots.js'; import type { Viewer } from '../session/viewer.js'; import { getAndAssertKeyserverURLFacts } from '../utils/urls.js'; const { commbot } = bots; async function createReport( viewer: Viewer, request: ReportCreationRequest, ): Promise { const shouldIgnore = await ignoreReport(viewer, request); if (shouldIgnore) { return null; } const [id] = await createIDs('reports', 1); let type, report, time; if (request.type === reportTypes.THREAD_INCONSISTENCY) { ({ type, time, ...report } = request); time = time ? time : Date.now(); } else if (request.type === reportTypes.ENTRY_INCONSISTENCY) { ({ type, time, ...report } = request); } else if (request.type === reportTypes.MEDIA_MISSION) { ({ type, time, ...report } = request); } else if (request.type === reportTypes.USER_INCONSISTENCY) { ({ type, time, ...report } = request); } else { ({ type, ...report } = request); time = Date.now(); const redactedReduxReport: ReduxCrashReport = sanitizeReduxReport({ preloadedState: report.preloadedState, currentState: report.currentState, actions: report.actions, }); report = { ...report, ...redactedReduxReport, }; } const row = [ id, viewer.id, type, request.platformDetails.platform, JSON.stringify(report), time, ]; const query = SQL` INSERT INTO reports (id, user, type, platform, report, creation_time) VALUES ${[row]} `; await dbQuery(query); - handleAsyncPromise(sendCommbotMessage(viewer, request, id)); + ignorePromiseRejections(sendCommbotMessage(viewer, request, id)); return { id }; } async function sendCommbotMessage( viewer: Viewer, request: ReportCreationRequest, reportID: string, ): Promise { const canGenerateMessage = getCommbotMessage(request, reportID, null); if (!canGenerateMessage) { return; } const username = await fetchUsername(viewer.id); const message = getCommbotMessage(request, reportID, username); if (!message) { return; } const time = Date.now(); await createMessages(createBotViewer(commbot.userID), [ { type: messageTypes.TEXT, threadID: commbot.staffThreadID, creatorID: commbot.userID, time, text: message, }, ]); } async function ignoreReport( viewer: Viewer, request: ReportCreationRequest, ): Promise { // The below logic is to avoid duplicate inconsistency reports if ( request.type !== reportTypes.THREAD_INCONSISTENCY && request.type !== reportTypes.ENTRY_INCONSISTENCY ) { return false; } const { type, platformDetails, time } = request; if (!time) { return false; } const { platform } = platformDetails; const query = SQL` SELECT id FROM reports WHERE user = ${viewer.id} AND type = ${type} AND platform = ${platform} AND creation_time = ${time} `; const [result] = await dbQuery(query); return result.length !== 0; } function getCommbotMessage( request: ReportCreationRequest, reportID: string, username: ?string, ): ?string { const name = username ? username : '[null]'; const { platformDetails } = request; const { platform, codeVersion } = platformDetails; const platformString = codeVersion ? `${platform} v${codeVersion}` : platform; if (request.type === reportTypes.ERROR) { const { baseDomain, basePath } = getAndAssertKeyserverURLFacts(); return ( `${name} got an error :(\n` + `using ${platformString}\n` + `${baseDomain}${basePath}download_error_report/${reportID}` ); } else if (request.type === reportTypes.THREAD_INCONSISTENCY) { const nonMatchingThreadIDs = getInconsistentThreadIDsFromReport(request); const nonMatchingString = [...nonMatchingThreadIDs].join(', '); return ( `system detected inconsistency for ${name}!\n` + `using ${platformString}\n` + `occurred during ${request.action.type}\n` + `thread IDs that are inconsistent: ${nonMatchingString}` ); } else if (request.type === reportTypes.ENTRY_INCONSISTENCY) { const nonMatchingEntryIDs = getInconsistentEntryIDsFromReport(request); const nonMatchingString = [...nonMatchingEntryIDs].join(', '); return ( `system detected inconsistency for ${name}!\n` + `using ${platformString}\n` + `occurred during ${request.action.type}\n` + `entry IDs that are inconsistent: ${nonMatchingString}` ); } else if (request.type === reportTypes.USER_INCONSISTENCY) { const nonMatchingUserIDs = getInconsistentUserIDsFromReport(request); const nonMatchingString = [...nonMatchingUserIDs].join(', '); return ( `system detected inconsistency for ${name}!\n` + `using ${platformString}\n` + `occurred during ${request.action.type}\n` + `user IDs that are inconsistent: ${nonMatchingString}` ); } else if (request.type === reportTypes.MEDIA_MISSION) { const mediaMissionJSON = JSON.stringify(request.mediaMission); const success = request.mediaMission.result.success ? 'media mission success!' : 'media mission failed :('; return `${name} ${success}\n` + mediaMissionJSON; } else { return null; } } function findInconsistentObjectKeys( first: { +[id: string]: O }, second: { +[id: string]: O }, ): Set { const nonMatchingIDs = new Set(); for (const id in first) { if (!_isEqual(first[id])(second[id])) { nonMatchingIDs.add(id); } } for (const id in second) { if (!first[id]) { nonMatchingIDs.add(id); } } return nonMatchingIDs; } function getInconsistentThreadIDsFromReport( request: ThreadInconsistencyReportCreationRequest, ): Set { const { pushResult, beforeAction } = request; return findInconsistentObjectKeys(beforeAction, pushResult); } function getInconsistentEntryIDsFromReport( request: EntryInconsistencyReportCreationRequest, ): Set { const { pushResult, beforeAction, calendarQuery } = request; const filteredBeforeAction = filterRawEntryInfosByCalendarQuery( serverEntryInfosObject(values(beforeAction)), calendarQuery, ); const filteredAfterAction = filterRawEntryInfosByCalendarQuery( serverEntryInfosObject(values(pushResult)), calendarQuery, ); return findInconsistentObjectKeys(filteredBeforeAction, filteredAfterAction); } function getInconsistentUserIDsFromReport( request: UserInconsistencyReportCreationRequest, ): Set { const { beforeStateCheck, afterStateCheck } = request; return findInconsistentObjectKeys(beforeStateCheck, afterStateCheck); } export default createReport; diff --git a/keyserver/src/cron/update-geoip-db.js b/keyserver/src/cron/update-geoip-db.js index f48be4d93..2018100a7 100644 --- a/keyserver/src/cron/update-geoip-db.js +++ b/keyserver/src/cron/update-geoip-db.js @@ -1,64 +1,63 @@ // @flow import childProcess from 'child_process'; import cluster from 'cluster'; import geoip from 'geoip-lite'; import { getCommConfig } from 'lib/utils/comm-config.js'; - -import { handleAsyncPromise } from '../responders/handlers.js'; +import { ignorePromiseRejections } from 'lib/utils/promises.js'; type GeoIPLicenseConfig = { +key: string }; async function updateGeoipDB(): Promise { const geoipLicense = await getCommConfig({ folder: 'secrets', name: 'geoip_license', }); if (!geoipLicense) { console.log('no keyserver/secrets/geoip_license.json so skipping update'); return; } await spawnUpdater(geoipLicense); } function spawnUpdater(geoipLicense: GeoIPLicenseConfig): Promise { const spawned = childProcess.spawn(process.execPath, [ '../node_modules/geoip-lite/scripts/updatedb.js', `license_key=${geoipLicense.key}`, ]); return new Promise((resolve, reject) => { spawned.on('error', reject); spawned.on('exit', () => resolve()); }); } function reloadGeoipDB(): Promise { return new Promise(resolve => geoip.reloadData(resolve)); } type IPCMessage = { type: 'geoip_reload', }; const reloadMessage: IPCMessage = { type: 'geoip_reload' }; async function updateAndReloadGeoipDB(): Promise { await updateGeoipDB(); await reloadGeoipDB(); if (!cluster.isMaster) { return; } for (const id in cluster.workers) { cluster.workers[Number(id)].send(reloadMessage); } } if (!cluster.isMaster) { process.on('message', (ipcMessage: IPCMessage) => { if (ipcMessage.type === 'geoip_reload') { - handleAsyncPromise(reloadGeoipDB()); + ignorePromiseRejections(reloadGeoipDB()); } }); } export { updateAndReloadGeoipDB }; diff --git a/keyserver/src/deleters/account-deleters.js b/keyserver/src/deleters/account-deleters.js index 836c60da2..d6c8a8c72 100644 --- a/keyserver/src/deleters/account-deleters.js +++ b/keyserver/src/deleters/account-deleters.js @@ -1,153 +1,152 @@ // @flow import { getRustAPI } from 'rust-node-addon'; import type { LogOutResponse } from 'lib/types/account-types.js'; import type { ReservedUsernameMessage } from 'lib/types/crypto-types.js'; import { updateTypes } from 'lib/types/update-types-enum.js'; import type { UserInfo } from 'lib/types/user-types.js'; import { ServerError } from 'lib/utils/errors.js'; import { values } from 'lib/utils/objects.js'; -import { promiseAll } from 'lib/utils/promises.js'; +import { promiseAll, ignorePromiseRejections } from 'lib/utils/promises.js'; import { createUpdates } from '../creators/update-creator.js'; import { dbQuery, SQL } from '../database/database.js'; import { fetchKnownUserInfos, fetchUsername, } from '../fetchers/user-fetchers.js'; import { rescindPushNotifs } from '../push/rescind.js'; -import { handleAsyncPromise } from '../responders/handlers.js'; import { createNewAnonymousCookie } from '../session/cookies.js'; import type { Viewer, AnonymousViewerData } from '../session/viewer.js'; import { fetchOlmAccount } from '../updaters/olm-account-updater.js'; async function deleteAccount(viewer: Viewer): Promise { if (!viewer.loggedIn) { throw new ServerError('not_logged_in'); } const deletedUserID = viewer.userID; await rescindPushNotifs(SQL`n.user = ${deletedUserID}`, SQL`NULL`); const knownUserInfos = await fetchKnownUserInfos(viewer); const usersToUpdate: $ReadOnlyArray = values(knownUserInfos).filter( (user: UserInfo): boolean => user.id !== deletedUserID, ); // TODO: if this results in any orphaned orgs, convert them to chats const deletionQuery = SQL` START TRANSACTION; DELETE FROM users WHERE id = ${deletedUserID}; DELETE FROM ids WHERE id = ${deletedUserID}; DELETE c, i FROM cookies c LEFT JOIN ids i ON i.id = c.id WHERE c.user = ${deletedUserID}; DELETE FROM memberships WHERE user = ${deletedUserID}; DELETE FROM focused WHERE user = ${deletedUserID}; DELETE n, i FROM notifications n LEFT JOIN ids i ON i.id = n.id WHERE n.user = ${deletedUserID}; DELETE u, i FROM updates u LEFT JOIN ids i ON i.id = u.id WHERE u.user = ${deletedUserID}; DELETE s, i FROM sessions s LEFT JOIN ids i ON i.id = s.id WHERE s.user = ${deletedUserID}; DELETE r, i FROM reports r LEFT JOIN ids i ON i.id = r.id WHERE r.user = ${deletedUserID}; DELETE u, i FROM uploads u LEFT JOIN ids i on i.id = u.id WHERE u.container = ${deletedUserID}; DELETE FROM relationships_undirected WHERE user1 = ${deletedUserID}; DELETE FROM relationships_undirected WHERE user2 = ${deletedUserID}; DELETE FROM relationships_directed WHERE user1 = ${deletedUserID}; DELETE FROM relationships_directed WHERE user2 = ${deletedUserID}; COMMIT; `; const deletionPromise = dbQuery(deletionQuery, { multipleStatements: true }); const anonymousViewerDataPromise: Promise = (async () => { if (viewer.isScriptViewer) { return undefined; } return await createNewAnonymousCookie({ platformDetails: viewer.platformDetails, deviceToken: viewer.deviceToken, }); })(); const usernamePromise = fetchUsername(deletedUserID); const { anonymousViewerData, username } = await promiseAll({ anonymousViewerData: anonymousViewerDataPromise, username: usernamePromise, deletion: deletionPromise, }); if (username) { const issuedAt = new Date().toISOString(); const reservedUsernameMessage: ReservedUsernameMessage = { statement: 'Remove the following username from reserved list', payload: username, issuedAt, }; const message = JSON.stringify(reservedUsernameMessage); - handleAsyncPromise( + ignorePromiseRejections( (async () => { const rustAPI = await getRustAPI(); const accountInfo = await fetchOlmAccount('content'); const signature = accountInfo.account.sign(message); await rustAPI.removeReservedUsername(message, signature); })(), ); } if (anonymousViewerData) { viewer.setNewCookie(anonymousViewerData); } const deletionUpdatesPromise = createAccountDeletionUpdates( usersToUpdate, deletedUserID, ); if (viewer.isScriptViewer) { await deletionUpdatesPromise; } else { - handleAsyncPromise(deletionUpdatesPromise); + ignorePromiseRejections(deletionUpdatesPromise); } if (viewer.isScriptViewer) { return null; } return { currentUserInfo: { anonymous: true, }, }; } async function createAccountDeletionUpdates( knownUserInfos: $ReadOnlyArray, deletedUserID: string, ): Promise { const time = Date.now(); const updateDatas = []; for (const userInfo of knownUserInfos) { const { id: userID } = userInfo; updateDatas.push({ type: updateTypes.DELETE_ACCOUNT, userID, time, deletedUserID, }); } await createUpdates(updateDatas); } export { deleteAccount }; diff --git a/keyserver/src/keyserver.js b/keyserver/src/keyserver.js index 10572b348..efa6a8c05 100644 --- a/keyserver/src/keyserver.js +++ b/keyserver/src/keyserver.js @@ -1,286 +1,288 @@ // @flow import olm from '@commapp/olm'; import cluster from 'cluster'; import compression from 'compression'; import cookieParser from 'cookie-parser'; import cors from 'cors'; import crypto from 'crypto'; import express from 'express'; import type { $Request, $Response } from 'express'; import expressWs from 'express-ws'; import os from 'os'; import qrcode from 'qrcode'; import './cron/cron.js'; import { qrCodeLinkURL } from 'lib/facts/links.js'; import { isDev } from 'lib/utils/dev-utils.js'; +import { ignorePromiseRejections } from 'lib/utils/promises.js'; import { migrate } from './database/migrations.js'; import { jsonEndpoints } from './endpoints.js'; import { logEndpointMetrics } from './middleware/endpoint-profiling.js'; import { emailSubscriptionResponder } from './responders/comm-landing-responders.js'; import { jsonHandler, downloadHandler, - handleAsyncPromise, htmlHandler, uploadHandler, } from './responders/handlers.js'; import landingHandler from './responders/landing-handler.js'; import { errorReportDownloadResponder } from './responders/report-responders.js'; import { inviteResponder, websiteResponder, } from './responders/website-responders.js'; import { webWorkerResponder } from './responders/webworker-responders.js'; import { onConnection } from './socket/socket.js'; import { createAndMaintainTunnelbrokerWebsocket } from './socket/tunnelbroker.js'; import { multerProcessor, multimediaUploadResponder, uploadDownloadResponder, } from './uploads/uploads.js'; import { verifyUserLoggedIn } from './user/login.js'; import { initENSCache } from './utils/ens-cache.js'; import { getContentSigningKey } from './utils/olm-utils.js'; import { prefetchAllURLFacts, getKeyserverURLFacts, getLandingURLFacts, getWebAppURLFacts, getWebAppCorsConfig, } from './utils/urls.js'; const shouldDisplayQRCodeInTerminal = false; (async () => { const [webAppCorsConfig] = await Promise.all([ getWebAppCorsConfig(), olm.init(), prefetchAllURLFacts(), initENSCache(), ]); const keyserverURLFacts = getKeyserverURLFacts(); const keyserverBaseRoutePath = keyserverURLFacts?.baseRoutePath; const landingBaseRoutePath = getLandingURLFacts()?.baseRoutePath; const webAppURLFacts = getWebAppURLFacts(); const webAppBaseRoutePath = webAppURLFacts?.baseRoutePath; const compiledFolderOptions = process.env.NODE_ENV === 'development' ? undefined : { maxAge: '1y', immutable: true }; let keyserverCorsOptions = null; if (webAppCorsConfig) { keyserverCorsOptions = { origin: webAppCorsConfig.domain, methods: ['GET', 'POST'], }; } const isCPUProfilingEnabled = process.env.KEYSERVER_CPU_PROFILING_ENABLED; const areEndpointMetricsEnabled = process.env.KEYSERVER_ENDPOINT_METRICS_ENABLED; if (cluster.isMaster) { const didMigrationsSucceed: boolean = await migrate(); if (!didMigrationsSucceed) { // The following line uses exit code 2 to ensure nodemon exits // in a dev environment, instead of restarting. Context provided // in https://github.com/remy/nodemon/issues/751 process.exit(2); } // Allow login to be optional until staging environment is available try { // We await here to ensure that the keyserver has been provisioned a // commServicesAccessToken. In the future, this will be necessary for // many keyserver operations. const identityInfo = await verifyUserLoggedIn(); // We don't await here, as Tunnelbroker communication is not needed for // normal keyserver behavior yet. In addition, this doesn't return // information useful for other keyserver functions. - handleAsyncPromise(createAndMaintainTunnelbrokerWebsocket(identityInfo)); + ignorePromiseRejections( + createAndMaintainTunnelbrokerWebsocket(identityInfo), + ); } catch (e) { console.warn( 'Failed identity login. Login optional until staging environment is available', ); } if (shouldDisplayQRCodeInTerminal) { try { const aes256Key = crypto.randomBytes(32).toString('hex'); const ed25519Key = await getContentSigningKey(); console.log( '\nOpen the Comm app on your phone and scan the QR code below\n', ); console.log('How to find the scanner:\n'); console.log('Go to \x1b[1mProfile\x1b[0m'); console.log('Select \x1b[1mLinked devices\x1b[0m'); console.log('Click \x1b[1mAdd\x1b[0m on the top right'); const url = qrCodeLinkURL(aes256Key, ed25519Key); qrcode.toString(url, (error, encodedURL) => console.log(encodedURL)); } catch (e) { console.log('Error generating QR code', e); } } if (!isCPUProfilingEnabled) { const cpuCount = os.cpus().length; for (let i = 0; i < cpuCount; i++) { cluster.fork(); } cluster.on('exit', () => cluster.fork()); } } if (!cluster.isMaster || isCPUProfilingEnabled) { const server = express(); server.use(compression()); expressWs(server); server.use(express.json({ limit: '250mb' })); server.use(cookieParser()); // Note - the order of router declarations matters. On prod we have // keyserverBaseRoutePath configured to '/', which means it's a catch-all. // If we call server.use on keyserverRouter first, it will catch all // requests and prevent webAppRouter and landingRouter from working // correctly. So we make sure that keyserverRouter goes last server.get('/invite/:secret', inviteResponder); if (landingBaseRoutePath) { const landingRouter = express.Router<$Request, $Response>(); landingRouter.get('/invite/:secret', inviteResponder); landingRouter.use( '/.well-known', express.static( '.well-known', // Necessary for apple-app-site-association file { setHeaders: res => res.setHeader('Content-Type', 'application/json'), }, ), ); landingRouter.use('/images', express.static('images')); landingRouter.use('/fonts', express.static('fonts')); landingRouter.use( '/compiled', express.static('landing_compiled', compiledFolderOptions), ); landingRouter.use('/', express.static('landing_icons')); landingRouter.post('/subscribe_email', emailSubscriptionResponder); landingRouter.get('*', landingHandler); server.use(landingBaseRoutePath, landingRouter); } if (webAppBaseRoutePath) { const webAppRouter = express.Router<$Request, $Response>(); webAppRouter.use('/images', express.static('images')); webAppRouter.use('/fonts', express.static('fonts')); webAppRouter.use('/misc', express.static('misc')); webAppRouter.use( '/.well-known', express.static( '.well-known', // Necessary for apple-app-site-association file { setHeaders: res => res.setHeader('Content-Type', 'application/json'), }, ), ); webAppRouter.use( '/compiled', express.static('app_compiled', compiledFolderOptions), ); webAppRouter.use('/', express.static('icons')); webAppRouter.get('/invite/:secret', inviteResponder); webAppRouter.get('/worker/:worker', webWorkerResponder); if (keyserverURLFacts) { webAppRouter.get( '/upload/:uploadID/:secret', (req: $Request, res: $Response) => { const { uploadID, secret } = req.params; const url = `${keyserverURLFacts.baseDomain}${keyserverURLFacts.basePath}upload/${uploadID}/${secret}`; res.redirect(url); }, ); } webAppRouter.get('*', htmlHandler(websiteResponder)); server.use(webAppBaseRoutePath, webAppRouter); } if (keyserverBaseRoutePath) { const keyserverRouter = express.Router<$Request, $Response>(); if (areEndpointMetricsEnabled) { keyserverRouter.use(logEndpointMetrics); } if (keyserverCorsOptions) { keyserverRouter.use(cors(keyserverCorsOptions)); } for (const endpoint in jsonEndpoints) { // $FlowFixMe Flow thinks endpoint is string const responder = jsonEndpoints[endpoint]; const expectCookieInvalidation = endpoint === 'log_out'; keyserverRouter.post( `/${endpoint}`, jsonHandler(responder, expectCookieInvalidation), ); } keyserverRouter.get( '/download_error_report/:reportID', downloadHandler(errorReportDownloadResponder), ); keyserverRouter.get( '/upload/:uploadID/:secret', downloadHandler(uploadDownloadResponder), ); // $FlowFixMe express-ws has side effects that can't be typed keyserverRouter.ws('/ws', onConnection); keyserverRouter.post( '/upload_multimedia', multerProcessor, uploadHandler(multimediaUploadResponder), ); server.use(keyserverBaseRoutePath, keyserverRouter); } if (isDev && webAppURLFacts) { const oldPath = '/comm/'; server.all(`${oldPath}*`, (req: $Request, res: $Response) => { const endpoint = req.url.slice(oldPath.length); const newURL = `${webAppURLFacts.baseDomain}${webAppURLFacts.basePath}${endpoint}`; res.redirect(newURL); }); } const listenAddress = (() => { if (process.env.COMM_LISTEN_ADDR) { return process.env.COMM_LISTEN_ADDR; } else if (process.env.NODE_ENV === 'development') { return undefined; } else { return 'localhost'; } })(); server.listen(parseInt(process.env.PORT, 10) || 3000, listenAddress); } })(); diff --git a/keyserver/src/responders/handlers.js b/keyserver/src/responders/handlers.js index 929c9bbaf..2bbb67d27 100644 --- a/keyserver/src/responders/handlers.js +++ b/keyserver/src/responders/handlers.js @@ -1,246 +1,237 @@ // @flow import type { $Response, $Request } from 'express'; import type { TType } from 'tcomb'; import { ServerError } from 'lib/utils/errors.js'; import { assertWithValidator, tPlatformDetails, } from 'lib/utils/validation-utils.js'; import { getMessageForException } from './utils.js'; import { deleteCookie } from '../deleters/cookie-deleters.js'; import type { PolicyType } from '../lib/facts/policies.js'; import { fetchViewerForJSONRequest, addCookieToJSONResponse, addCookieToHomeResponse, createNewAnonymousCookie, setCookiePlatformDetails, } from '../session/cookies.js'; import type { Viewer } from '../session/viewer.js'; import { getAppURLFactsFromRequestURL } from '../utils/urls.js'; import { policiesValidator, validateInput, validateOutput, } from '../utils/validation-utils.js'; type InnerJSONResponder = { responder: (viewer: Viewer, input: any) => Promise<*>, requiredPolicies: $ReadOnlyArray, }; export opaque type JSONResponder: InnerJSONResponder = InnerJSONResponder; function createJSONResponder( responder: (Viewer, input: I) => Promise, inputValidator: TType, outputValidator: TType, requiredPolicies: $ReadOnlyArray, ): JSONResponder { return { responder: async (viewer, input) => { const request = await validateInput(viewer, inputValidator, input); const result = await responder(viewer, request); return validateOutput(viewer.platformDetails, outputValidator, result); }, requiredPolicies, }; } export type DownloadResponder = ( viewer: Viewer, req: $Request, res: $Response, ) => Promise; export type HTTPGetResponder = DownloadResponder; export type HTMLResponder = (req: $Request, res: $Response) => Promise; function jsonHandler( responder: JSONResponder, expectCookieInvalidation: boolean, ): (req: $Request, res: $Response) => Promise { return async (req: $Request, res: $Response) => { let viewer; try { if (!req.body || typeof req.body !== 'object') { throw new ServerError('invalid_parameters'); } const { input, platformDetails } = req.body; viewer = await fetchViewerForJSONRequest(req); const promises = [policiesValidator(viewer, responder.requiredPolicies)]; if (platformDetails) { if (!tPlatformDetails.is(platformDetails)) { throw new ServerError('invalid_platform_details'); } promises.push( setCookiePlatformDetails( viewer, assertWithValidator(platformDetails, tPlatformDetails), ), ); } await Promise.all(promises); const responderResult = await responder.responder(viewer, input); if (res.headersSent) { return; } const result = { ...responderResult }; addCookieToJSONResponse(viewer, res, result, expectCookieInvalidation); res.json({ success: true, ...result }); } catch (e) { await handleException(e, res, viewer, expectCookieInvalidation); } }; } function httpGetHandler( responder: HTTPGetResponder, ): (req: $Request, res: $Response) => Promise { return async (req: $Request, res: $Response) => { let viewer; try { viewer = await fetchViewerForJSONRequest(req); await responder(viewer, req, res); } catch (e) { await handleException(e, res, viewer); } }; } function downloadHandler( responder: DownloadResponder, ): (req: $Request, res: $Response) => Promise { return async (req: $Request, res: $Response) => { try { const viewer = await fetchViewerForJSONRequest(req); await responder(viewer, req, res); } catch (e) { // Passing viewer in only makes sense if we want to handle failures as // JSON. We don't, and presume all download handlers avoid ServerError. await handleException(e, res); } }; } async function handleException( error: Error, res: $Response, viewer?: ?Viewer, expectCookieInvalidation?: boolean, ) { console.warn(error); if (res.headersSent) { return; } if (!(error instanceof ServerError)) { res.status(500).send(getMessageForException(error)); return; } const result: Object = error.payload ? { error: error.message, payload: error.payload } : { error: error.message }; if (viewer) { if (error.message === 'client_version_unsupported' && viewer.loggedIn) { // If the client version is unsupported, log the user out const { platformDetails } = error; const [data] = await Promise.all([ createNewAnonymousCookie({ platformDetails, deviceToken: viewer.deviceToken, }), deleteCookie(viewer.cookieID), ]); viewer.setNewCookie(data); viewer.cookieInvalidated = true; } // This can mutate the result object addCookieToJSONResponse(viewer, res, result, !!expectCookieInvalidation); } res.json(result); } function htmlHandler( responder: HTMLResponder, ): (req: $Request, res: $Response) => Promise { return async (req: $Request, res: $Response) => { try { addCookieToHomeResponse( req, res, getAppURLFactsFromRequestURL(req.originalUrl), ); res.type('html'); await responder(req, res); } catch (e) { console.warn(e); if (!res.headersSent) { res.status(500).send(getMessageForException(e)); } } }; } type MulterFile = { fieldname: string, originalname: string, encoding: string, mimetype: string, buffer: Buffer, size: number, }; export type MulterRequest = $Request & { files?: $ReadOnlyArray, ... }; type UploadResponder = (viewer: Viewer, req: MulterRequest) => Promise; function uploadHandler( responder: UploadResponder, ): (req: $Request, res: $Response) => Promise { return async (req: $Request, res: $Response) => { let viewer; try { if (!req.body || typeof req.body !== 'object') { throw new ServerError('invalid_parameters'); } viewer = await fetchViewerForJSONRequest(req); const responderResult = await responder( viewer, ((req: any): MulterRequest), ); if (res.headersSent) { return; } const result = { ...responderResult }; addCookieToJSONResponse(viewer, res, result, false); res.json({ success: true, ...result }); } catch (e) { await handleException(e, res, viewer); } }; } -async function handleAsyncPromise(promise: Promise) { - try { - await promise; - } catch (error) { - console.warn(error); - } -} - export { createJSONResponder, jsonHandler, httpGetHandler, downloadHandler, htmlHandler, uploadHandler, - handleAsyncPromise, }; diff --git a/keyserver/src/session/cookies.js b/keyserver/src/session/cookies.js index 1db6167ec..33d7df636 100644 --- a/keyserver/src/session/cookies.js +++ b/keyserver/src/session/cookies.js @@ -1,813 +1,813 @@ // @flow import crypto from 'crypto'; import type { $Response, $Request } from 'express'; import invariant from 'invariant'; import url from 'url'; import { isStaff } from 'lib/shared/staff-utils.js'; import { hasMinCodeVersion } from 'lib/shared/version-utils.js'; import type { SignedIdentityKeysBlob } from 'lib/types/crypto-types.js'; import type { Platform, PlatformDetails } from 'lib/types/device-types.js'; import type { CalendarQuery } from 'lib/types/entry-types.js'; import { type ServerSessionChange, cookieLifetime, cookieTypes, sessionIdentifierTypes, type SessionIdentifierType, } from 'lib/types/session-types.js'; import type { SIWESocialProof } from 'lib/types/siwe-types.js'; import type { InitialClientSocketMessage } from 'lib/types/socket-types.js'; import type { UserInfo } from 'lib/types/user-types.js'; import { isDev } from 'lib/utils/dev-utils.js'; +import { ignorePromiseRejections } from 'lib/utils/promises.js'; import { isBcryptHash, getCookieHash, verifyCookieHash, } from './cookie-hash.js'; import { Viewer } from './viewer.js'; import type { AnonymousViewerData, UserViewerData } from './viewer.js'; import createIDs from '../creators/id-creator.js'; import { createSession } from '../creators/session-creator.js'; import { dbQuery, SQL } from '../database/database.js'; import { deleteCookie } from '../deleters/cookie-deleters.js'; -import { handleAsyncPromise } from '../responders/handlers.js'; import { clearDeviceToken } from '../updaters/device-token-updaters.js'; import { assertSecureRequest } from '../utils/security-utils.js'; import { type AppURLFacts, getAppURLFactsFromRequestURL, } from '../utils/urls.js'; function cookieIsExpired(lastUsed: number) { return lastUsed + cookieLifetime <= Date.now(); } type SessionParameterInfo = { isSocket: boolean, sessionID: ?string, sessionIdentifierType: SessionIdentifierType, ipAddress: string, userAgent: ?string, }; type FetchViewerResult = | { +type: 'valid', +viewer: Viewer } | InvalidFetchViewerResult; type InvalidFetchViewerResult = | { +type: 'nonexistant', +cookieName: ?string, +sessionParameterInfo: SessionParameterInfo, } | { +type: 'invalidated', +cookieName: string, +cookieID: string, +sessionParameterInfo: SessionParameterInfo, +platformDetails: ?PlatformDetails, +deviceToken: ?string, }; async function fetchUserViewer( cookie: string, sessionParameterInfo: SessionParameterInfo, ): Promise { const [cookieID, cookiePassword] = cookie.split(':'); if (!cookieID || !cookiePassword) { return { type: 'nonexistant', cookieName: cookieTypes.USER, sessionParameterInfo, }; } const query = SQL` SELECT hash, user, last_used, platform, device_token, versions FROM cookies WHERE id = ${cookieID} AND user IS NOT NULL `; const [[result], allSessionInfo] = await Promise.all([ dbQuery(query), fetchSessionInfo(sessionParameterInfo, cookieID), ]); if (result.length === 0) { return { type: 'nonexistant', cookieName: cookieTypes.USER, sessionParameterInfo, }; } let sessionID = null, sessionInfo = null; if (allSessionInfo) { ({ sessionID, ...sessionInfo } = allSessionInfo); } const cookieRow = result[0]; let platformDetails; let versions = null; if (cookieRow.versions) { versions = JSON.parse(cookieRow.versions); platformDetails = { platform: cookieRow.platform, codeVersion: versions.codeVersion, stateVersion: versions.stateVersion, }; } else { platformDetails = { platform: cookieRow.platform }; } if (versions && versions.majorDesktopVersion) { platformDetails = { ...platformDetails, majorDesktopVersion: versions.majorDesktopVersion, }; } const deviceToken = cookieRow.device_token; const cookieHash = cookieRow.hash; if ( !verifyCookieHash(cookiePassword, cookieHash) || cookieIsExpired(cookieRow.last_used) ) { return { type: 'invalidated', cookieName: cookieTypes.USER, cookieID, sessionParameterInfo, platformDetails, deviceToken, }; } const userID = cookieRow.user.toString(); const viewer = new Viewer({ isSocket: sessionParameterInfo.isSocket, loggedIn: true, id: userID, platformDetails, deviceToken, userID, cookieID, cookiePassword, cookieHash, sessionIdentifierType: sessionParameterInfo.sessionIdentifierType, sessionID, sessionInfo, isScriptViewer: false, ipAddress: sessionParameterInfo.ipAddress, userAgent: sessionParameterInfo.userAgent, }); return { type: 'valid', viewer }; } async function fetchAnonymousViewer( cookie: string, sessionParameterInfo: SessionParameterInfo, ): Promise { const [cookieID, cookiePassword] = cookie.split(':'); if (!cookieID || !cookiePassword) { return { type: 'nonexistant', cookieName: cookieTypes.ANONYMOUS, sessionParameterInfo, }; } const query = SQL` SELECT last_used, hash, platform, device_token, versions FROM cookies WHERE id = ${cookieID} AND user IS NULL `; const [[result], allSessionInfo] = await Promise.all([ dbQuery(query), fetchSessionInfo(sessionParameterInfo, cookieID), ]); if (result.length === 0) { return { type: 'nonexistant', cookieName: cookieTypes.ANONYMOUS, sessionParameterInfo, }; } let sessionID = null, sessionInfo = null; if (allSessionInfo) { ({ sessionID, ...sessionInfo } = allSessionInfo); } const cookieRow = result[0]; let platformDetails = null; let versions = null; if (cookieRow.platform && cookieRow.versions) { versions = JSON.parse(cookieRow.versions); platformDetails = { platform: cookieRow.platform, codeVersion: versions.codeVersion, stateVersion: versions.stateVersion, }; } else if (cookieRow.platform) { platformDetails = { platform: cookieRow.platform }; } if (platformDetails && versions && versions.majorDesktopVersion) { platformDetails = { ...platformDetails, majorDesktopVersion: versions.majorDesktopVersion, }; } const deviceToken = cookieRow.device_token; const cookieHash = cookieRow.hash; if ( !verifyCookieHash(cookiePassword, cookieHash) || cookieIsExpired(cookieRow.last_used) ) { return { type: 'invalidated', cookieName: cookieTypes.ANONYMOUS, cookieID, sessionParameterInfo, platformDetails, deviceToken, }; } const viewer = new Viewer({ isSocket: sessionParameterInfo.isSocket, loggedIn: false, id: cookieID, platformDetails, deviceToken, cookieID, cookiePassword, cookieHash, sessionIdentifierType: sessionParameterInfo.sessionIdentifierType, sessionID, sessionInfo, isScriptViewer: false, ipAddress: sessionParameterInfo.ipAddress, userAgent: sessionParameterInfo.userAgent, }); return { type: 'valid', viewer }; } type SessionInfo = { +sessionID: ?string, +lastValidated: number, +lastUpdate: number, +calendarQuery: CalendarQuery, }; async function fetchSessionInfo( sessionParameterInfo: SessionParameterInfo, cookieID: string, ): Promise { const { sessionID } = sessionParameterInfo; const session = sessionID !== undefined ? sessionID : cookieID; if (!session) { return null; } const query = SQL` SELECT query, last_validated, last_update FROM sessions WHERE id = ${session} AND cookie = ${cookieID} `; const [result] = await dbQuery(query); if (result.length === 0) { return null; } return { sessionID, lastValidated: result[0].last_validated, lastUpdate: result[0].last_update, calendarQuery: JSON.parse(result[0].query), }; } async function fetchViewerFromRequestBody( body: mixed, sessionParameterInfo: SessionParameterInfo, ): Promise { if (!body || typeof body !== 'object') { return { type: 'nonexistant', cookieName: null, sessionParameterInfo, }; } const cookiePair = body.cookie; if (cookiePair === null || cookiePair === '') { return { type: 'nonexistant', cookieName: null, sessionParameterInfo, }; } if (!cookiePair || typeof cookiePair !== 'string') { return { type: 'nonexistant', cookieName: null, sessionParameterInfo, }; } const [type, cookie] = cookiePair.split('='); if (type === cookieTypes.USER && cookie) { return await fetchUserViewer(cookie, sessionParameterInfo); } else if (type === cookieTypes.ANONYMOUS && cookie) { return await fetchAnonymousViewer(cookie, sessionParameterInfo); } return { type: 'nonexistant', cookieName: null, sessionParameterInfo, }; } function getRequestIPAddress(req: $Request) { const { proxy } = getAppURLFactsFromRequestURL(req.originalUrl); let ipAddress; if (proxy === 'none') { ipAddress = req.socket.remoteAddress; } else if (proxy === 'apache') { ipAddress = req.get('X-Forwarded-For'); } invariant(ipAddress, 'could not determine requesting IP address'); return ipAddress; } function getSessionParameterInfoFromRequestBody( req: $Request, ): SessionParameterInfo { const body = (req.body: any); let sessionID = body.sessionID !== undefined || req.method !== 'GET' ? body.sessionID : null; if (sessionID === '') { sessionID = null; } const sessionIdentifierType = req.method === 'GET' || sessionID !== undefined ? sessionIdentifierTypes.BODY_SESSION_ID : sessionIdentifierTypes.COOKIE_ID; return { isSocket: false, sessionID, sessionIdentifierType, ipAddress: getRequestIPAddress(req), userAgent: req.get('User-Agent'), }; } async function fetchViewerForJSONRequest(req: $Request): Promise { assertSecureRequest(req); const sessionParameterInfo = getSessionParameterInfoFromRequestBody(req); const result = await fetchViewerFromRequestBody( req.body, sessionParameterInfo, ); return await handleFetchViewerResult(result); } async function fetchViewerForSocket( req: $Request, clientMessage: InitialClientSocketMessage, ): Promise { assertSecureRequest(req); const { sessionIdentification } = clientMessage.payload; const { sessionID } = sessionIdentification; const sessionParameterInfo = { isSocket: true, sessionID, sessionIdentifierType: sessionID !== undefined ? sessionIdentifierTypes.BODY_SESSION_ID : sessionIdentifierTypes.COOKIE_ID, ipAddress: getRequestIPAddress(req), userAgent: req.get('User-Agent'), }; const result = await fetchViewerFromRequestBody( clientMessage.payload.sessionIdentification, sessionParameterInfo, ); if (result.type === 'valid') { return result.viewer; } const anonymousViewerDataPromise: Promise = (async () => { const platformDetails = result.type === 'invalidated' ? result.platformDetails : null; const deviceToken = result.type === 'invalidated' ? result.deviceToken : null; return await createNewAnonymousCookie({ platformDetails, deviceToken, }); })(); const deleteCookiePromise = (async () => { if (result.type === 'invalidated') { await deleteCookie(result.cookieID); } })(); const [anonymousViewerData] = await Promise.all([ anonymousViewerDataPromise, deleteCookiePromise, ]); return createViewerForInvalidFetchViewerResult(result, anonymousViewerData); } async function handleFetchViewerResult( result: FetchViewerResult, inputPlatformDetails?: PlatformDetails, ) { if (result.type === 'valid') { return result.viewer; } let platformDetails: ?PlatformDetails = inputPlatformDetails; if (!platformDetails && result.type === 'invalidated') { platformDetails = result.platformDetails; } const deviceToken = result.type === 'invalidated' ? result.deviceToken : null; const deleteCookiePromise = (async () => { if (result.type === 'invalidated') { await deleteCookie(result.cookieID); } })(); const [anonymousViewerData] = await Promise.all([ createNewAnonymousCookie({ platformDetails, deviceToken }), deleteCookiePromise, ]); return createViewerForInvalidFetchViewerResult(result, anonymousViewerData); } function createViewerForInvalidFetchViewerResult( result: InvalidFetchViewerResult, anonymousViewerData: AnonymousViewerData, ): Viewer { const viewer = new Viewer({ ...anonymousViewerData, sessionIdentifierType: result.sessionParameterInfo.sessionIdentifierType, isSocket: result.sessionParameterInfo.isSocket, ipAddress: result.sessionParameterInfo.ipAddress, userAgent: result.sessionParameterInfo.userAgent, }); viewer.sessionChanged = true; // If cookieName is falsey, that tells us that there was no cookie specified // in the request, which means we can't be invalidating anything. if (result.cookieName) { viewer.cookieInvalidated = true; viewer.initialCookieName = result.cookieName; } return viewer; } function addSessionChangeInfoToResult( viewer: Viewer, res: $Response, result: Object, ) { let threadInfos = {}, userInfos: $ReadOnlyArray = []; if (result.cookieChange) { ({ threadInfos, userInfos } = result.cookieChange); } let sessionChange; if (viewer.cookieInvalidated) { sessionChange = ({ cookieInvalidated: true, threadInfos, userInfos, currentUserInfo: { anonymous: true, }, }: ServerSessionChange); } else { sessionChange = ({ cookieInvalidated: false, threadInfos, userInfos, }: ServerSessionChange); } sessionChange.cookie = viewer.cookiePairString; if (viewer.sessionIdentifierType === sessionIdentifierTypes.BODY_SESSION_ID) { sessionChange.sessionID = viewer.sessionID ? viewer.sessionID : null; } result.cookieChange = sessionChange; } type AnonymousCookieCreationParams = Partial<{ +platformDetails: ?PlatformDetails, +deviceToken: ?string, }>; const defaultPlatformDetails = {}; // The result of this function should not be passed directly to the Viewer // constructor. Instead, it should be passed to viewer.setNewCookie. There are // several fields on AnonymousViewerData that are not set by this function: // sessionIdentifierType, ipAddress, and userAgent. These parameters all depend // on the initial request. If the result of this function is passed to the // Viewer constructor directly, the resultant Viewer object will throw whenever // anybody attempts to access the relevant properties. async function createNewAnonymousCookie( params: AnonymousCookieCreationParams, ): Promise { const { platformDetails, deviceToken } = params; const { platform, ...versions } = platformDetails || defaultPlatformDetails; const versionsString = Object.keys(versions).length > 0 ? JSON.stringify(versions) : null; const time = Date.now(); const cookiePassword = crypto.randomBytes(32).toString('hex'); const cookieHash = getCookieHash(cookiePassword); const [[id]] = await Promise.all([ createIDs('cookies', 1), deviceToken ? clearDeviceToken(deviceToken) : undefined, ]); const cookieRow = [ id, cookieHash, null, platform, time, time, deviceToken, versionsString, ]; const query = SQL` INSERT INTO cookies(id, hash, user, platform, creation_time, last_used, device_token, versions) VALUES ${[cookieRow]} `; await dbQuery(query); return { loggedIn: false, id, platformDetails, deviceToken, cookieID: id, cookiePassword, cookieHash, sessionID: undefined, sessionInfo: null, cookieInsertedThisRequest: true, isScriptViewer: false, }; } type UserCookieCreationParams = { +platformDetails: PlatformDetails, +deviceToken?: ?string, +socialProof?: ?SIWESocialProof, +signedIdentityKeysBlob?: ?SignedIdentityKeysBlob, }; // The result of this function should never be passed directly to the Viewer // constructor. Instead, it should be passed to viewer.setNewCookie. There are // several fields on UserViewerData that are not set by this function: // sessionID, sessionIdentifierType, and ipAddress. These parameters all depend // on the initial request. If the result of this function is passed to the // Viewer constructor directly, the resultant Viewer object will throw whenever // anybody attempts to access the relevant properties. async function createNewUserCookie( userID: string, params: UserCookieCreationParams, ): Promise { const { platformDetails, deviceToken, socialProof, signedIdentityKeysBlob } = params; const { platform, ...versions } = platformDetails || defaultPlatformDetails; const versionsString = Object.keys(versions).length > 0 ? JSON.stringify(versions) : null; const time = Date.now(); const cookiePassword = crypto.randomBytes(32).toString('hex'); const cookieHash = getCookieHash(cookiePassword); const [[cookieID]] = await Promise.all([ createIDs('cookies', 1), deviceToken ? clearDeviceToken(deviceToken) : undefined, ]); const cookieRow = [ cookieID, cookieHash, userID, platform, time, time, deviceToken, versionsString, JSON.stringify(socialProof), signedIdentityKeysBlob ? JSON.stringify(signedIdentityKeysBlob) : null, ]; const query = SQL` INSERT INTO cookies(id, hash, user, platform, creation_time, last_used, device_token, versions, social_proof, signed_identity_keys) VALUES ${[cookieRow]} `; await dbQuery(query); return { loggedIn: true, id: userID, platformDetails, deviceToken, userID, cookieID, sessionID: undefined, sessionInfo: null, cookiePassword, cookieHash, cookieInsertedThisRequest: true, isScriptViewer: false, }; } // This gets called after createNewUserCookie and from websiteResponder. If the // Viewer's sessionIdentifierType is COOKIE_ID then the cookieID is used as the // session identifier; otherwise, a new ID is created for the session. async function setNewSession( viewer: Viewer, calendarQuery: CalendarQuery, initialLastUpdate: number, ): Promise { if (viewer.sessionIdentifierType !== sessionIdentifierTypes.COOKIE_ID) { const [sessionID] = await createIDs('sessions', 1); viewer.setSessionID(sessionID); } await createSession(viewer, calendarQuery, initialLastUpdate); } async function updateCookie(viewer: Viewer) { const time = Date.now(); const { cookieID, cookieHash, cookiePassword } = viewer; const updateObj: { [string]: string | number } = {}; updateObj.last_used = time; if (isBcryptHash(cookieHash)) { updateObj.hash = getCookieHash(cookiePassword); } const query = SQL` UPDATE cookies SET ${updateObj} WHERE id = ${cookieID} `; await dbQuery(query); } function addCookieToJSONResponse( viewer: Viewer, res: $Response, result: Object, expectCookieInvalidation: boolean, ) { if (expectCookieInvalidation) { viewer.cookieInvalidated = false; } if (!viewer.getData().cookieInsertedThisRequest) { - handleAsyncPromise(updateCookie(viewer)); + ignorePromiseRejections(updateCookie(viewer)); } if (viewer.sessionChanged) { addSessionChangeInfoToResult(viewer, res, result); } } function addCookieToHomeResponse( req: $Request, res: $Response, appURLFacts: AppURLFacts, ) { const { user, anonymous } = req.cookies; if (user) { res.cookie(cookieTypes.USER, user, getCookieOptions(appURLFacts)); } if (anonymous) { res.cookie(cookieTypes.ANONYMOUS, anonymous, getCookieOptions(appURLFacts)); } } function getCookieOptions(appURLFacts: AppURLFacts) { const { baseDomain, basePath, https } = appURLFacts; const domainAsURL = new url.URL(baseDomain); return { domain: domainAsURL.hostname, path: basePath, httpOnly: false, secure: https, maxAge: cookieLifetime, sameSite: 'Strict', }; } async function setCookieSignedIdentityKeysBlob( cookieID: string, signedIdentityKeysBlob: SignedIdentityKeysBlob, ) { const signedIdentityKeysStr = JSON.stringify(signedIdentityKeysBlob); const query = SQL` UPDATE cookies SET signed_identity_keys = ${signedIdentityKeysStr} WHERE id = ${cookieID} `; await dbQuery(query); } // Returns `true` if row with `id = cookieID` exists AND // `signed_identity_keys` is `NULL`. Otherwise, returns `false`. async function isCookieMissingSignedIdentityKeysBlob( cookieID: string, ): Promise { const query = SQL` SELECT signed_identity_keys FROM cookies WHERE id = ${cookieID} `; const [queryResult] = await dbQuery(query); return ( queryResult.length === 1 && queryResult[0].signed_identity_keys === null ); } async function isCookieMissingOlmNotificationsSession( viewer: Viewer, ): Promise { const isStaffOrDev = isStaff(viewer.userID) || isDev; if ( !viewer.platformDetails || (viewer.platformDetails.platform !== 'ios' && viewer.platformDetails.platform !== 'android' && !(viewer.platformDetails.platform === 'web' && isStaffOrDev) && !(viewer.platformDetails.platform === 'macos' && isStaffOrDev)) || !hasMinCodeVersion(viewer.platformDetails, { native: 222, web: 43, majorDesktop: 9, }) ) { return false; } const query = SQL` SELECT COUNT(*) AS count FROM olm_sessions WHERE cookie_id = ${viewer.cookieID} AND is_content = FALSE `; const [queryResult] = await dbQuery(query); return queryResult[0].count === 0; } async function setCookiePlatform( viewer: Viewer, platform: Platform, ): Promise { const newPlatformDetails = { ...viewer.platformDetails, platform }; viewer.setPlatformDetails(newPlatformDetails); const query = SQL` UPDATE cookies SET platform = ${platform} WHERE id = ${viewer.cookieID} `; await dbQuery(query); } async function setCookiePlatformDetails( viewer: Viewer, platformDetails: PlatformDetails, ): Promise { viewer.setPlatformDetails(platformDetails); const { platform, ...versions } = platformDetails; const versionsString = Object.keys(versions).length > 0 ? JSON.stringify(versions) : null; const query = SQL` UPDATE cookies SET platform = ${platform}, versions = ${versionsString} WHERE id = ${viewer.cookieID} `; await dbQuery(query); } export { fetchViewerForJSONRequest, fetchViewerForSocket, createNewAnonymousCookie, createNewUserCookie, setNewSession, updateCookie, addCookieToJSONResponse, addCookieToHomeResponse, setCookieSignedIdentityKeysBlob, isCookieMissingSignedIdentityKeysBlob, setCookiePlatform, setCookiePlatformDetails, isCookieMissingOlmNotificationsSession, }; diff --git a/keyserver/src/socket/session-utils.js b/keyserver/src/socket/session-utils.js index 12fc3d42e..07887ff8b 100644 --- a/keyserver/src/socket/session-utils.js +++ b/keyserver/src/socket/session-utils.js @@ -1,559 +1,558 @@ // @flow import invariant from 'invariant'; import t from 'tcomb'; import type { TUnion } from 'tcomb'; import { hasMinCodeVersion } from 'lib/shared/version-utils.js'; import type { UpdateActivityResult, ActivityUpdate, } from 'lib/types/activity-types.js'; import type { IdentityKeysBlob } from 'lib/types/crypto-types.js'; import { isDeviceType } from 'lib/types/device-types.js'; import type { CalendarQuery, DeltaEntryInfosResponse, } from 'lib/types/entry-types.js'; import { reportTypes, type ThreadInconsistencyReportCreationRequest, type EntryInconsistencyReportCreationRequest, } from 'lib/types/report-types.js'; import { serverRequestTypes, type ThreadInconsistencyClientResponse, type EntryInconsistencyClientResponse, type ClientResponse, type ServerServerRequest, type ServerCheckStateServerRequest, } from 'lib/types/request-types.js'; import { sessionCheckFrequency } from 'lib/types/session-types.js'; import { signedIdentityKeysBlobValidator } from 'lib/utils/crypto-utils.js'; import { hash, values } from 'lib/utils/objects.js'; -import { promiseAll } from 'lib/utils/promises.js'; +import { promiseAll, ignorePromiseRejections } from 'lib/utils/promises.js'; import { tShape, tPlatform, tPlatformDetails, } from 'lib/utils/validation-utils.js'; import { createOlmSession } from '../creators/olm-session-creator.js'; import { saveOneTimeKeys } from '../creators/one-time-keys-creator.js'; import createReport from '../creators/report-creator.js'; import { fetchEntriesForSession } from '../fetchers/entry-fetchers.js'; import { checkIfSessionHasEnoughOneTimeKeys } from '../fetchers/key-fetchers.js'; import { activityUpdatesInputValidator } from '../responders/activity-responders.js'; -import { handleAsyncPromise } from '../responders/handlers.js'; import { threadInconsistencyReportValidatorShape, entryInconsistencyReportValidatorShape, } from '../responders/report-responders.js'; import { setNewSession, setCookiePlatform, setCookiePlatformDetails, setCookieSignedIdentityKeysBlob, } from '../session/cookies.js'; import type { Viewer } from '../session/viewer.js'; import { serverStateSyncSpecs } from '../shared/state-sync/state-sync-specs.js'; import { activityUpdater } from '../updaters/activity-updaters.js'; import { compareNewCalendarQuery } from '../updaters/entry-updaters.js'; import type { SessionUpdate } from '../updaters/session-updaters.js'; import { getOlmUtility } from '../utils/olm-utils.js'; const clientResponseInputValidator: TUnion = t.union([ tShape({ type: t.irreducible( 'serverRequestTypes.PLATFORM', x => x === serverRequestTypes.PLATFORM, ), platform: tPlatform, }), tShape({ ...threadInconsistencyReportValidatorShape, type: t.irreducible( 'serverRequestTypes.THREAD_INCONSISTENCY', x => x === serverRequestTypes.THREAD_INCONSISTENCY, ), }), tShape({ ...entryInconsistencyReportValidatorShape, type: t.irreducible( 'serverRequestTypes.ENTRY_INCONSISTENCY', x => x === serverRequestTypes.ENTRY_INCONSISTENCY, ), }), tShape({ type: t.irreducible( 'serverRequestTypes.PLATFORM_DETAILS', x => x === serverRequestTypes.PLATFORM_DETAILS, ), platformDetails: tPlatformDetails, }), tShape({ type: t.irreducible( 'serverRequestTypes.CHECK_STATE', x => x === serverRequestTypes.CHECK_STATE, ), hashResults: t.dict(t.String, t.Boolean), }), tShape({ type: t.irreducible( 'serverRequestTypes.INITIAL_ACTIVITY_UPDATES', x => x === serverRequestTypes.INITIAL_ACTIVITY_UPDATES, ), activityUpdates: activityUpdatesInputValidator, }), tShape({ type: t.irreducible( 'serverRequestTypes.MORE_ONE_TIME_KEYS', x => x === serverRequestTypes.MORE_ONE_TIME_KEYS, ), keys: t.list(t.String), }), tShape({ type: t.irreducible( 'serverRequestTypes.SIGNED_IDENTITY_KEYS_BLOB', x => x === serverRequestTypes.SIGNED_IDENTITY_KEYS_BLOB, ), signedIdentityKeysBlob: signedIdentityKeysBlobValidator, }), tShape({ type: t.irreducible( 'serverRequestTypes.INITIAL_NOTIFICATIONS_ENCRYPTED_MESSAGE', x => x === serverRequestTypes.INITIAL_NOTIFICATIONS_ENCRYPTED_MESSAGE, ), initialNotificationsEncryptedMessage: t.String, }), ]); type StateCheckStatus = | { status: 'state_validated' } | { status: 'state_invalid', invalidKeys: $ReadOnlyArray } | { status: 'state_check' }; type ProcessClientResponsesResult = { serverRequests: ServerServerRequest[], stateCheckStatus: ?StateCheckStatus, activityUpdateResult: ?UpdateActivityResult, }; async function processClientResponses( viewer: Viewer, clientResponses: $ReadOnlyArray, ): Promise { let viewerMissingPlatform = !viewer.platform; const { platformDetails } = viewer; let viewerMissingPlatformDetails = !platformDetails || (isDeviceType(viewer.platform) && (platformDetails.codeVersion === null || platformDetails.codeVersion === undefined || platformDetails.stateVersion === null || platformDetails.stateVersion === undefined)); const promises = []; let activityUpdates: Array = []; let stateCheckStatus = null; const clientSentPlatformDetails = clientResponses.some( response => response.type === serverRequestTypes.PLATFORM_DETAILS, ); for (const clientResponse of clientResponses) { if ( clientResponse.type === serverRequestTypes.PLATFORM && !clientSentPlatformDetails ) { promises.push(setCookiePlatform(viewer, clientResponse.platform)); viewerMissingPlatform = false; if (!isDeviceType(clientResponse.platform)) { viewerMissingPlatformDetails = false; } } else if ( clientResponse.type === serverRequestTypes.THREAD_INCONSISTENCY ) { promises.push(recordThreadInconsistency(viewer, clientResponse)); } else if (clientResponse.type === serverRequestTypes.ENTRY_INCONSISTENCY) { promises.push(recordEntryInconsistency(viewer, clientResponse)); } else if (clientResponse.type === serverRequestTypes.PLATFORM_DETAILS) { promises.push( setCookiePlatformDetails(viewer, clientResponse.platformDetails), ); viewerMissingPlatform = false; viewerMissingPlatformDetails = false; } else if ( clientResponse.type === serverRequestTypes.INITIAL_ACTIVITY_UPDATES ) { activityUpdates = [...activityUpdates, ...clientResponse.activityUpdates]; } else if (clientResponse.type === serverRequestTypes.CHECK_STATE) { const invalidKeys = []; for (const key in clientResponse.hashResults) { const result = clientResponse.hashResults[key]; if (!result) { invalidKeys.push(key); } } stateCheckStatus = invalidKeys.length > 0 ? { status: 'state_invalid', invalidKeys } : { status: 'state_validated' }; } else if (clientResponse.type === serverRequestTypes.MORE_ONE_TIME_KEYS) { invariant(clientResponse.keys, 'keys expected in client response'); - handleAsyncPromise(saveOneTimeKeys(viewer, clientResponse.keys)); + ignorePromiseRejections(saveOneTimeKeys(viewer, clientResponse.keys)); } else if ( clientResponse.type === serverRequestTypes.SIGNED_IDENTITY_KEYS_BLOB ) { invariant( clientResponse.signedIdentityKeysBlob, 'signedIdentityKeysBlob expected in client response', ); const { signedIdentityKeysBlob } = clientResponse; const identityKeys: IdentityKeysBlob = JSON.parse( signedIdentityKeysBlob.payload, ); const olmUtil = getOlmUtility(); try { olmUtil.ed25519_verify( identityKeys.primaryIdentityPublicKeys.ed25519, signedIdentityKeysBlob.payload, signedIdentityKeysBlob.signature, ); - handleAsyncPromise( + ignorePromiseRejections( setCookieSignedIdentityKeysBlob( viewer.cookieID, signedIdentityKeysBlob, ), ); } catch (e) { continue; } } else if ( clientResponse.type === serverRequestTypes.INITIAL_NOTIFICATIONS_ENCRYPTED_MESSAGE ) { invariant( t.String.is(clientResponse.initialNotificationsEncryptedMessage), 'initialNotificationsEncryptedMessage expected in client response', ); const { initialNotificationsEncryptedMessage } = clientResponse; try { await createOlmSession( initialNotificationsEncryptedMessage, 'notifications', viewer.cookieID, ); } catch (e) { continue; } } } const activityUpdatePromise: Promise = (async () => { if (activityUpdates.length === 0) { return undefined; } return await activityUpdater(viewer, { updates: activityUpdates }); })(); const serverRequests: Array = []; const checkOneTimeKeysPromise = (async () => { if (!viewer.loggedIn) { return; } const enoughOneTimeKeys = await checkIfSessionHasEnoughOneTimeKeys( viewer.session, ); if (!enoughOneTimeKeys) { serverRequests.push({ type: serverRequestTypes.MORE_ONE_TIME_KEYS }); } })(); const { activityUpdateResult } = await promiseAll({ all: Promise.all(promises), activityUpdateResult: activityUpdatePromise, checkOneTimeKeysPromise, }); if ( !stateCheckStatus && viewer.loggedIn && viewer.sessionLastValidated + sessionCheckFrequency < Date.now() ) { stateCheckStatus = { status: 'state_check' }; } if (viewerMissingPlatform) { serverRequests.push({ type: serverRequestTypes.PLATFORM }); } if (viewerMissingPlatformDetails) { serverRequests.push({ type: serverRequestTypes.PLATFORM_DETAILS }); } return { serverRequests, stateCheckStatus, activityUpdateResult }; } async function recordThreadInconsistency( viewer: Viewer, response: ThreadInconsistencyClientResponse, ): Promise { const { type, ...rest } = response; const reportCreationRequest = ({ ...rest, type: reportTypes.THREAD_INCONSISTENCY, }: ThreadInconsistencyReportCreationRequest); await createReport(viewer, reportCreationRequest); } async function recordEntryInconsistency( viewer: Viewer, response: EntryInconsistencyClientResponse, ): Promise { const { type, ...rest } = response; const reportCreationRequest = ({ ...rest, type: reportTypes.ENTRY_INCONSISTENCY, }: EntryInconsistencyReportCreationRequest); await createReport(viewer, reportCreationRequest); } type SessionInitializationResult = | { sessionContinued: false } | { sessionContinued: true, deltaEntryInfoResult: DeltaEntryInfosResponse, sessionUpdate: SessionUpdate, }; async function initializeSession( viewer: Viewer, calendarQuery: CalendarQuery, oldLastUpdate: number, ): Promise { if (!viewer.loggedIn) { return { sessionContinued: false }; } if (!viewer.hasSessionInfo) { // If the viewer has no session info but is logged in, that is indicative // of an expired / invalidated session and we should generate a new one await setNewSession(viewer, calendarQuery, oldLastUpdate); return { sessionContinued: false }; } if (oldLastUpdate < viewer.sessionLastUpdated) { // If the client has an older last_update than the server is tracking for // that client, then the client either had some issue persisting its store, // or the user restored the client app from a backup. Either way, we should // invalidate the existing session, since the server has assumed that the // checkpoint is further along than it is on the client, and might not still // have all of the updates necessary to do an incremental update await setNewSession(viewer, calendarQuery, oldLastUpdate); return { sessionContinued: false }; } let comparisonResult = null; try { comparisonResult = compareNewCalendarQuery(viewer, calendarQuery); } catch (e) { if (e.message !== 'unknown_error') { throw e; } } if (comparisonResult) { const { difference, oldCalendarQuery } = comparisonResult; const sessionUpdate = { ...comparisonResult.sessionUpdate, lastUpdate: oldLastUpdate, }; const deltaEntryInfoResult = await fetchEntriesForSession( viewer, difference, oldCalendarQuery, ); return { sessionContinued: true, deltaEntryInfoResult, sessionUpdate }; } else { await setNewSession(viewer, calendarQuery, oldLastUpdate); return { sessionContinued: false }; } } type StateCheckResult = { sessionUpdate?: SessionUpdate, checkStateRequest?: ServerCheckStateServerRequest, }; async function checkState( viewer: Viewer, status: StateCheckStatus, ): Promise { if (status.status === 'state_validated') { return { sessionUpdate: { lastValidated: Date.now() } }; } else if (status.status === 'state_check') { const promises = Object.fromEntries( values(serverStateSyncSpecs).map(spec => [ spec.hashKey, (async () => { if ( !hasMinCodeVersion(viewer.platformDetails, { native: 267, web: 32, }) ) { const data = await spec.fetch(viewer); return hash(data); } const infosHash = await spec.fetchServerInfosHash(viewer); return infosHash; })(), ]), ); const hashesToCheck = await promiseAll(promises); const checkStateRequest = { type: serverRequestTypes.CHECK_STATE, hashesToCheck, }; return { checkStateRequest }; } const invalidKeys = new Set(status.invalidKeys); const shouldFetchAll = Object.fromEntries( values(serverStateSyncSpecs).map(spec => [ spec.hashKey, invalidKeys.has(spec.hashKey), ]), ); const idsToFetch = Object.fromEntries( values(serverStateSyncSpecs) .filter(spec => spec.innerHashSpec?.hashKey) .map(spec => [spec.innerHashSpec?.hashKey, new Set()]), ); for (const key of invalidKeys) { const [innerHashKey, id] = key.split('|'); if (innerHashKey && id) { idsToFetch[innerHashKey]?.add(id); } } const fetchPromises: { [string]: Promise } = {}; for (const spec of values(serverStateSyncSpecs)) { if (shouldFetchAll[spec.hashKey]) { fetchPromises[spec.hashKey] = spec.fetch(viewer); } else if (idsToFetch[spec.innerHashSpec?.hashKey]?.size > 0) { fetchPromises[spec.hashKey] = spec.fetch( viewer, idsToFetch[spec.innerHashSpec?.hashKey], ); } } const fetchedData = await promiseAll(fetchPromises); const specPerHashKey = Object.fromEntries( values(serverStateSyncSpecs).map(spec => [spec.hashKey, spec]), ); const specPerInnerHashKey = Object.fromEntries( values(serverStateSyncSpecs) .filter(spec => spec.innerHashSpec?.hashKey) .map(spec => [spec.innerHashSpec?.hashKey, spec]), ); const hashesToCheck: { [string]: number } = {}, failUnmentioned: { [string]: boolean } = {}, stateChanges: { [string]: mixed } = {}; for (const key of invalidKeys) { const spec = specPerHashKey[key]; const innerHashKey = spec?.innerHashSpec?.hashKey; const isTopLevelKey = !!spec; if (isTopLevelKey && innerHashKey) { // Instead of returning all the infos, we want to narrow down and figure // out which infos don't match first const infos = fetchedData[key]; // We have a type error here because in fact the relationship between // Infos and Info is not guaranteed to be like this. In particular, // currentUserStateSyncSpec does not match this pattern. But this code // doesn't fire for it because no innerHashSpec is defined const iterableInfos: { +[string]: mixed } = (infos: any); for (const infoID in iterableInfos) { let hashValue; if ( hasMinCodeVersion(viewer.platformDetails, { native: 267, web: 32, }) ) { // We have a type error here because Flow has no way to determine that // spec and infos are matched up hashValue = spec.getServerInfoHash((iterableInfos[infoID]: any)); } else { hashValue = hash(iterableInfos[infoID]); } hashesToCheck[`${innerHashKey}|${infoID}`] = hashValue; } failUnmentioned[key] = true; } else if (isTopLevelKey) { stateChanges[key] = fetchedData[key]; } else { const [keyPrefix, id] = key.split('|'); const innerSpec = specPerInnerHashKey[keyPrefix]; const innerHashSpec = innerSpec?.innerHashSpec; if (!innerHashSpec || !id) { continue; } const infos = fetchedData[innerSpec.hashKey]; // We have a type error here because in fact the relationship between // Infos and Info is not guaranteed to be like this. In particular, // currentUserStateSyncSpec does not match this pattern. But this code // doesn't fire for it because no innerHashSpec is defined const iterableInfos: { +[string]: mixed } = (infos: any); const info = iterableInfos[id]; // We have a type error here because Flow wants us to type iterableInfos // in this file, but we don't have access to the parameterization of // innerHashSpec here if (!info || innerHashSpec.additionalDeleteCondition?.((info: any))) { if (!stateChanges[innerHashSpec.deleteKey]) { stateChanges[innerHashSpec.deleteKey] = [id]; } else { // We have a type error here because in fact stateChanges values // aren't always arrays. In particular, currentUserStateSyncSpec does // not match this pattern. But this code doesn't fire for it because // no innerHashSpec is defined const curDeleteKeyChanges: Array = (stateChanges[ innerHashSpec.deleteKey ]: any); curDeleteKeyChanges.push(id); } continue; } if (!stateChanges[innerHashSpec.rawInfosKey]) { stateChanges[innerHashSpec.rawInfosKey] = [info]; } else { // We have a type error here because in fact stateChanges values aren't // always arrays. In particular, currentUserStateSyncSpec does not match // this pattern. But this code doesn't fire for it because no // innerHashSpec is defined const curRawInfosKeyChanges: Array = (stateChanges[ innerHashSpec.rawInfosKey ]: any); curRawInfosKeyChanges.push(info); } } } // We have a type error here because the keys that get set on some of these // collections aren't statically typed when they're set. Rather, they are set // as arbitrary strings const checkStateRequest: ServerCheckStateServerRequest = ({ type: serverRequestTypes.CHECK_STATE, hashesToCheck, failUnmentioned, stateChanges, }: any); if (Object.keys(hashesToCheck).length === 0) { return { checkStateRequest, sessionUpdate: { lastValidated: Date.now() } }; } else { return { checkStateRequest }; } } export { clientResponseInputValidator, processClientResponses, initializeSession, checkState, }; diff --git a/keyserver/src/socket/socket.js b/keyserver/src/socket/socket.js index c6e413533..422572b45 100644 --- a/keyserver/src/socket/socket.js +++ b/keyserver/src/socket/socket.js @@ -1,885 +1,884 @@ // @flow import type { $Request } from 'express'; import invariant from 'invariant'; import _debounce from 'lodash/debounce.js'; import t from 'tcomb'; import type { TUnion } from 'tcomb'; import WebSocket from 'ws'; import { baseLegalPolicies } from 'lib/facts/policies.js'; import { mostRecentMessageTimestamp } from 'lib/shared/message-utils.js'; import { isStaff } from 'lib/shared/staff-utils.js'; import { serverRequestSocketTimeout, serverResponseTimeout, } from 'lib/shared/timeouts.js'; import { mostRecentUpdateTimestamp } from 'lib/shared/update-utils.js'; import { hasMinCodeVersion } from 'lib/shared/version-utils.js'; import { endpointIsSocketSafe } from 'lib/types/endpoints.js'; import type { RawEntryInfo } from 'lib/types/entry-types.js'; import { defaultNumberPerThread } from 'lib/types/message-types.js'; import { redisMessageTypes, type RedisMessage } from 'lib/types/redis-types.js'; import { serverRequestTypes } from 'lib/types/request-types.js'; import { sessionCheckFrequency, stateCheckInactivityActivationInterval, } from 'lib/types/session-types.js'; import { type ClientSocketMessage, type InitialClientSocketMessage, type ResponsesClientSocketMessage, type ServerStateSyncFullSocketPayload, type ServerServerSocketMessage, type ErrorServerSocketMessage, type AuthErrorServerSocketMessage, type PingClientSocketMessage, type AckUpdatesClientSocketMessage, type APIRequestClientSocketMessage, clientSocketMessageTypes, stateSyncPayloadTypes, serverSocketMessageTypes, serverServerSocketMessageValidator, } from 'lib/types/socket-types.js'; import type { LegacyRawThreadInfos } from 'lib/types/thread-types.js'; import type { UserInfo, CurrentUserInfo } from 'lib/types/user-types.js'; import { ServerError } from 'lib/utils/errors.js'; import { values } from 'lib/utils/objects.js'; -import { promiseAll } from 'lib/utils/promises.js'; +import { promiseAll, ignorePromiseRejections } from 'lib/utils/promises.js'; import SequentialPromiseResolver from 'lib/utils/sequential-promise-resolver.js'; import sleep from 'lib/utils/sleep.js'; import { tShape, tCookie } from 'lib/utils/validation-utils.js'; import { RedisSubscriber } from './redis.js'; import { clientResponseInputValidator, processClientResponses, initializeSession, checkState, } from './session-utils.js'; import { fetchUpdateInfosWithRawUpdateInfos } from '../creators/update-creator.js'; import { deleteActivityForViewerSession } from '../deleters/activity-deleters.js'; import { deleteCookie } from '../deleters/cookie-deleters.js'; import { deleteUpdatesBeforeTimeTargetingSession } from '../deleters/update-deleters.js'; import { jsonEndpoints } from '../endpoints.js'; import { fetchMessageInfosSince, getMessageFetchResultFromRedisMessages, } from '../fetchers/message-fetchers.js'; import { fetchUpdateInfos } from '../fetchers/update-fetchers.js'; import { newEntryQueryInputValidator, verifyCalendarQueryThreadIDs, } from '../responders/entry-responders.js'; -import { handleAsyncPromise } from '../responders/handlers.js'; import { fetchViewerForSocket, updateCookie, isCookieMissingSignedIdentityKeysBlob, isCookieMissingOlmNotificationsSession, createNewAnonymousCookie, } from '../session/cookies.js'; import { Viewer } from '../session/viewer.js'; import type { AnonymousViewerData } from '../session/viewer.js'; import { serverStateSyncSpecs } from '../shared/state-sync/state-sync-specs.js'; import { commitSessionUpdate } from '../updaters/session-updaters.js'; import { compressMessage } from '../utils/compress.js'; import { assertSecureRequest } from '../utils/security-utils.js'; import { checkInputValidator, checkClientSupported, policiesValidator, validateOutput, } from '../utils/validation-utils.js'; const clientSocketMessageInputValidator: TUnion = t.union([ tShape({ type: t.irreducible( 'clientSocketMessageTypes.INITIAL', x => x === clientSocketMessageTypes.INITIAL, ), id: t.Number, payload: tShape({ sessionIdentification: tShape({ cookie: t.maybe(tCookie), sessionID: t.maybe(t.String), }), sessionState: tShape({ calendarQuery: newEntryQueryInputValidator, messagesCurrentAsOf: t.Number, updatesCurrentAsOf: t.Number, watchedIDs: t.list(t.String), }), clientResponses: t.list(clientResponseInputValidator), }), }), tShape({ type: t.irreducible( 'clientSocketMessageTypes.RESPONSES', x => x === clientSocketMessageTypes.RESPONSES, ), id: t.Number, payload: tShape({ clientResponses: t.list(clientResponseInputValidator), }), }), tShape({ type: t.irreducible( 'clientSocketMessageTypes.PING', x => x === clientSocketMessageTypes.PING, ), id: t.Number, }), tShape({ type: t.irreducible( 'clientSocketMessageTypes.ACK_UPDATES', x => x === clientSocketMessageTypes.ACK_UPDATES, ), id: t.Number, payload: tShape({ currentAsOf: t.Number, }), }), tShape({ type: t.irreducible( 'clientSocketMessageTypes.API_REQUEST', x => x === clientSocketMessageTypes.API_REQUEST, ), id: t.Number, payload: tShape({ endpoint: t.String, input: t.maybe(t.Object), }), }), ]); function onConnection(ws: WebSocket, req: $Request) { assertSecureRequest(req); new Socket(ws, req); } type StateCheckConditions = { activityRecentlyOccurred: boolean, stateCheckOngoing: boolean, }; const minVersionsForCompression = { native: 265, web: 30, }; class Socket { ws: WebSocket; httpRequest: $Request; viewer: ?Viewer; redis: ?RedisSubscriber; redisPromiseResolver: SequentialPromiseResolver; stateCheckConditions: StateCheckConditions = { activityRecentlyOccurred: true, stateCheckOngoing: false, }; stateCheckTimeoutID: ?TimeoutID; constructor(ws: WebSocket, httpRequest: $Request) { this.ws = ws; this.httpRequest = httpRequest; ws.on('message', this.onMessage); ws.on('close', this.onClose); this.resetTimeout(); this.redisPromiseResolver = new SequentialPromiseResolver(this.sendMessage); } onMessage = async ( messageString: string | Buffer | ArrayBuffer | Array, ): Promise => { invariant(typeof messageString === 'string', 'message should be string'); let clientSocketMessage: ?ClientSocketMessage; try { this.resetTimeout(); const messageObject = JSON.parse(messageString); clientSocketMessage = checkInputValidator( clientSocketMessageInputValidator, messageObject, ); if (clientSocketMessage.type === clientSocketMessageTypes.INITIAL) { if (this.viewer) { // This indicates that the user sent multiple INITIAL messages. throw new ServerError('socket_already_initialized'); } this.viewer = await fetchViewerForSocket( this.httpRequest, clientSocketMessage, ); } const { viewer } = this; if (!viewer) { // This indicates a non-INITIAL message was sent by the client before // the INITIAL message. throw new ServerError('socket_uninitialized'); } if (viewer.sessionChanged) { // This indicates that the cookie was invalid, and we've assigned a new // anonymous one. throw new ServerError('socket_deauthorized'); } if (!viewer.loggedIn) { // This indicates that the specified cookie was an anonymous one. throw new ServerError('not_logged_in'); } await checkClientSupported( viewer, clientSocketMessageInputValidator, clientSocketMessage, ); await policiesValidator(viewer, baseLegalPolicies); const serverResponses = await this.handleClientSocketMessage( clientSocketMessage, ); if (!this.redis) { this.redis = new RedisSubscriber( { userID: viewer.userID, sessionID: viewer.session }, this.onRedisMessage, ); } if (viewer.sessionChanged) { // This indicates that something has caused the session to change, which // shouldn't happen from inside a WebSocket since we can't handle cookie // invalidation. throw new ServerError('session_mutated_from_socket'); } if (clientSocketMessage.type !== clientSocketMessageTypes.PING) { - handleAsyncPromise(updateCookie(viewer)); + ignorePromiseRejections(updateCookie(viewer)); } for (const response of serverResponses) { // Normally it's an anti-pattern to await in sequence like this. But in // this case, we have a requirement that this array of serverResponses // is delivered in order. See here: // https://github.com/CommE2E/comm/blob/101eb34481deb49c609bfd2c785f375886e52666/keyserver/src/socket/socket.js#L566-L568 await this.sendMessage(response); } if (clientSocketMessage.type === clientSocketMessageTypes.INITIAL) { this.onSuccessfulConnection(); } } catch (error) { console.warn(error); if (!(error instanceof ServerError)) { const errorMessage: ErrorServerSocketMessage = { type: serverSocketMessageTypes.ERROR, message: error.message, }; const responseTo = clientSocketMessage ? clientSocketMessage.id : null; if (responseTo !== null) { errorMessage.responseTo = responseTo; } this.markActivityOccurred(); await this.sendMessage(errorMessage); return; } invariant(clientSocketMessage, 'should be set'); const responseTo = clientSocketMessage.id; if (error.message === 'socket_deauthorized') { invariant(this.viewer, 'should be set'); const authErrorMessage: AuthErrorServerSocketMessage = { type: serverSocketMessageTypes.AUTH_ERROR, responseTo, message: error.message, sessionChange: { cookie: this.viewer.cookiePairString, currentUserInfo: { anonymous: true, }, }, }; await this.sendMessage(authErrorMessage); this.ws.close(4100, error.message); return; } else if (error.message === 'client_version_unsupported') { const { viewer } = this; invariant(viewer, 'should be set'); const anonymousViewerDataPromise: Promise = createNewAnonymousCookie({ platformDetails: error.platformDetails, deviceToken: viewer.deviceToken, }); const deleteCookiePromise = deleteCookie(viewer.cookieID); const [anonymousViewerData] = await Promise.all([ anonymousViewerDataPromise, deleteCookiePromise, ]); // It is normally not safe to pass the result of // createNewAnonymousCookie to the Viewer constructor. That is because // createNewAnonymousCookie leaves several fields of // AnonymousViewerData unset, and consequently Viewer will throw when // access is attempted. It is only safe here because we can guarantee // that only cookiePairString and cookieID are accessed on anonViewer // below. const anonViewer = new Viewer(anonymousViewerData); const authErrorMessage: AuthErrorServerSocketMessage = { type: serverSocketMessageTypes.AUTH_ERROR, responseTo, message: error.message, sessionChange: { cookie: anonViewer.cookiePairString, currentUserInfo: { anonymous: true, }, }, }; await this.sendMessage(authErrorMessage); this.ws.close(4101, error.message); return; } if (error.payload) { await this.sendMessage({ type: serverSocketMessageTypes.ERROR, responseTo, message: error.message, payload: error.payload, }); } else { await this.sendMessage({ type: serverSocketMessageTypes.ERROR, responseTo, message: error.message, }); } if (error.message === 'not_logged_in') { this.ws.close(4102, error.message); } else if (error.message === 'session_mutated_from_socket') { this.ws.close(4103, error.message); } else { this.markActivityOccurred(); } } }; onClose = async () => { this.clearStateCheckTimeout(); this.resetTimeout.cancel(); this.debouncedAfterActivity.cancel(); if (this.viewer && this.viewer.hasSessionInfo) { await deleteActivityForViewerSession(this.viewer); } if (this.redis) { this.redis.quit(); this.redis = null; } }; sendMessage = async (message: ServerServerSocketMessage) => { invariant( this.ws.readyState > 0, "shouldn't send message until connection established", ); if (this.ws.readyState !== 1) { return; } const { viewer } = this; const validatedMessage = validateOutput( viewer?.platformDetails, serverServerSocketMessageValidator, message, ); const stringMessage = JSON.stringify(validatedMessage); if ( !viewer?.platformDetails || !hasMinCodeVersion(viewer.platformDetails, minVersionsForCompression) || !isStaff(viewer.id) ) { this.ws.send(stringMessage); return; } const compressionResult = await compressMessage(stringMessage); if (this.ws.readyState !== 1) { return; } if (!compressionResult.compressed) { this.ws.send(stringMessage); return; } const compressedMessage = { type: serverSocketMessageTypes.COMPRESSED_MESSAGE, payload: compressionResult.result, }; const validatedCompressedMessage = validateOutput( viewer?.platformDetails, serverServerSocketMessageValidator, compressedMessage, ); const stringCompressedMessage = JSON.stringify(validatedCompressedMessage); this.ws.send(stringCompressedMessage); }; async handleClientSocketMessage( message: ClientSocketMessage, ): Promise { const resultPromise: Promise = (async () => { if (message.type === clientSocketMessageTypes.INITIAL) { this.markActivityOccurred(); return await this.handleInitialClientSocketMessage(message); } else if (message.type === clientSocketMessageTypes.RESPONSES) { this.markActivityOccurred(); return await this.handleResponsesClientSocketMessage(message); } else if (message.type === clientSocketMessageTypes.PING) { return this.handlePingClientSocketMessage(message); } else if (message.type === clientSocketMessageTypes.ACK_UPDATES) { this.markActivityOccurred(); return await this.handleAckUpdatesClientSocketMessage(message); } else if (message.type === clientSocketMessageTypes.API_REQUEST) { this.markActivityOccurred(); return await this.handleAPIRequestClientSocketMessage(message); } return []; })(); const timeoutPromise: Promise = (async () => { await sleep(serverResponseTimeout); throw new ServerError('socket_response_timeout'); })(); return await Promise.race([resultPromise, timeoutPromise]); } async handleInitialClientSocketMessage( message: InitialClientSocketMessage, ): Promise { const { viewer } = this; invariant(viewer, 'should be set'); const responses: Array = []; const { sessionState, clientResponses } = message.payload; const { calendarQuery, updatesCurrentAsOf: oldUpdatesCurrentAsOf, messagesCurrentAsOf: oldMessagesCurrentAsOf, watchedIDs, } = sessionState; await verifyCalendarQueryThreadIDs(calendarQuery); const sessionInitializationResult = await initializeSession( viewer, calendarQuery, oldUpdatesCurrentAsOf, ); const threadCursors: { [string]: null } = {}; for (const watchedThreadID of watchedIDs) { threadCursors[watchedThreadID] = null; } const messageSelectionCriteria = { threadCursors, joinedThreads: true, newerThan: oldMessagesCurrentAsOf, }; const [fetchMessagesResult, { serverRequests, activityUpdateResult }] = await Promise.all([ fetchMessageInfosSince( viewer, messageSelectionCriteria, defaultNumberPerThread, ), processClientResponses(viewer, clientResponses), ]); const messagesResult = { rawMessageInfos: fetchMessagesResult.rawMessageInfos, truncationStatuses: fetchMessagesResult.truncationStatuses, currentAsOf: mostRecentMessageTimestamp( fetchMessagesResult.rawMessageInfos, oldMessagesCurrentAsOf, ), }; const isCookieMissingSignedIdentityKeysBlobPromise = isCookieMissingSignedIdentityKeysBlob(viewer.cookieID); const isCookieMissingOlmNotificationsSessionPromise = isCookieMissingOlmNotificationsSession(viewer); if (!sessionInitializationResult.sessionContinued) { const promises: { +[string]: Promise } = Object.fromEntries( values(serverStateSyncSpecs).map(spec => [ spec.hashKey, spec.fetchFullSocketSyncPayload(viewer, [calendarQuery]), ]), ); // We have a type error here because Flow doesn't know spec.hashKey const castPromises: { +threadInfos: Promise, +currentUserInfo: Promise, +entryInfos: Promise<$ReadOnlyArray>, +userInfos: Promise<$ReadOnlyArray>, } = (promises: any); const results = await promiseAll(castPromises); const payload: ServerStateSyncFullSocketPayload = { type: stateSyncPayloadTypes.FULL, messagesResult, threadInfos: results.threadInfos, currentUserInfo: results.currentUserInfo, rawEntryInfos: results.entryInfos, userInfos: results.userInfos, updatesCurrentAsOf: oldUpdatesCurrentAsOf, }; if (viewer.sessionChanged) { // If initializeSession encounters, // sessionIdentifierTypes.BODY_SESSION_ID but the session // is unspecified or expired, // it will set a new sessionID and specify viewer.sessionChanged const { sessionID } = viewer; invariant( sessionID !== null && sessionID !== undefined, 'should be set', ); payload.sessionID = sessionID; viewer.sessionChanged = false; } responses.push({ type: serverSocketMessageTypes.STATE_SYNC, responseTo: message.id, payload, }); } else { const { sessionUpdate, deltaEntryInfoResult } = sessionInitializationResult; const deleteExpiredUpdatesPromise = deleteUpdatesBeforeTimeTargetingSession(viewer, oldUpdatesCurrentAsOf); const fetchUpdateResultPromise = fetchUpdateInfos( viewer, oldUpdatesCurrentAsOf, calendarQuery, ); const sessionUpdatePromise = commitSessionUpdate(viewer, sessionUpdate); const [fetchUpdateResult] = await Promise.all([ fetchUpdateResultPromise, deleteExpiredUpdatesPromise, sessionUpdatePromise, ]); const { updateInfos, userInfos } = fetchUpdateResult; const newUpdatesCurrentAsOf = mostRecentUpdateTimestamp( [...updateInfos], oldUpdatesCurrentAsOf, ); const updatesResult = { newUpdates: updateInfos, currentAsOf: newUpdatesCurrentAsOf, }; responses.push({ type: serverSocketMessageTypes.STATE_SYNC, responseTo: message.id, payload: { type: stateSyncPayloadTypes.INCREMENTAL, messagesResult, updatesResult, deltaEntryInfos: deltaEntryInfoResult.rawEntryInfos, deletedEntryIDs: deltaEntryInfoResult.deletedEntryIDs, userInfos: values(userInfos), }, }); } const [signedIdentityKeysBlobMissing, olmNotificationsSessionMissing] = await Promise.all([ isCookieMissingSignedIdentityKeysBlobPromise, isCookieMissingOlmNotificationsSessionPromise, ]); if (signedIdentityKeysBlobMissing) { serverRequests.push({ type: serverRequestTypes.SIGNED_IDENTITY_KEYS_BLOB, }); } if (olmNotificationsSessionMissing) { serverRequests.push({ type: serverRequestTypes.INITIAL_NOTIFICATIONS_ENCRYPTED_MESSAGE, }); } if (serverRequests.length > 0 || clientResponses.length > 0) { // We send this message first since the STATE_SYNC triggers the client's // connection status to shift to "connected", and we want to make sure the // client responses are cleared from Redux before that happens responses.unshift({ type: serverSocketMessageTypes.REQUESTS, responseTo: message.id, payload: { serverRequests }, }); } if (activityUpdateResult) { // Same reason for unshifting as above responses.unshift({ type: serverSocketMessageTypes.ACTIVITY_UPDATE_RESPONSE, responseTo: message.id, payload: activityUpdateResult, }); } return responses; } async handleResponsesClientSocketMessage( message: ResponsesClientSocketMessage, ): Promise { const { viewer } = this; invariant(viewer, 'should be set'); const { clientResponses } = message.payload; const { stateCheckStatus } = await processClientResponses( viewer, clientResponses, ); const serverRequests = []; if (stateCheckStatus && stateCheckStatus.status !== 'state_check') { const { sessionUpdate, checkStateRequest } = await checkState( viewer, stateCheckStatus, ); if (sessionUpdate) { await commitSessionUpdate(viewer, sessionUpdate); this.setStateCheckConditions({ stateCheckOngoing: false }); } if (checkStateRequest) { serverRequests.push(checkStateRequest); } } // We send a response message regardless of whether we have any requests, // since we need to ack the client's responses return [ { type: serverSocketMessageTypes.REQUESTS, responseTo: message.id, payload: { serverRequests }, }, ]; } handlePingClientSocketMessage( message: PingClientSocketMessage, ): ServerServerSocketMessage[] { return [ { type: serverSocketMessageTypes.PONG, responseTo: message.id, }, ]; } async handleAckUpdatesClientSocketMessage( message: AckUpdatesClientSocketMessage, ): Promise { const { viewer } = this; invariant(viewer, 'should be set'); const { currentAsOf } = message.payload; await Promise.all([ deleteUpdatesBeforeTimeTargetingSession(viewer, currentAsOf), commitSessionUpdate(viewer, { lastUpdate: currentAsOf }), ]); return []; } async handleAPIRequestClientSocketMessage( message: APIRequestClientSocketMessage, ): Promise { if (!endpointIsSocketSafe(message.payload.endpoint)) { throw new ServerError('endpoint_unsafe_for_socket'); } const { viewer } = this; invariant(viewer, 'should be set'); const responder = jsonEndpoints[message.payload.endpoint]; await policiesValidator(viewer, responder.requiredPolicies); const response = await responder.responder(viewer, message.payload.input); return [ { type: serverSocketMessageTypes.API_RESPONSE, responseTo: message.id, payload: response, }, ]; } onRedisMessage = async (message: RedisMessage) => { try { await this.processRedisMessage(message); } catch (e) { console.warn(e); } }; async processRedisMessage(message: RedisMessage) { if (message.type === redisMessageTypes.START_SUBSCRIPTION) { this.ws.terminate(); } else if (message.type === redisMessageTypes.NEW_UPDATES) { const { viewer } = this; invariant(viewer, 'should be set'); if (message.ignoreSession && message.ignoreSession === viewer.session) { return; } const rawUpdateInfos = message.updates; this.redisPromiseResolver.add( (async () => { const { updateInfos, userInfos } = await fetchUpdateInfosWithRawUpdateInfos(rawUpdateInfos, { viewer, }); if (updateInfos.length === 0) { console.warn( 'could not get any UpdateInfos from redisMessageTypes.NEW_UPDATES', ); return null; } this.markActivityOccurred(); return { type: serverSocketMessageTypes.UPDATES, payload: { updatesResult: { currentAsOf: mostRecentUpdateTimestamp([...updateInfos], 0), newUpdates: updateInfos, }, userInfos: values(userInfos), }, }; })(), ); } else if (message.type === redisMessageTypes.NEW_MESSAGES) { const { viewer } = this; invariant(viewer, 'should be set'); const rawMessageInfos = message.messages; const messageFetchResult = getMessageFetchResultFromRedisMessages( viewer, rawMessageInfos, ); if (messageFetchResult.rawMessageInfos.length === 0) { console.warn( 'could not get any rawMessageInfos from ' + 'redisMessageTypes.NEW_MESSAGES', ); return; } this.redisPromiseResolver.add( (async () => { this.markActivityOccurred(); return { type: serverSocketMessageTypes.MESSAGES, payload: { messagesResult: { rawMessageInfos: messageFetchResult.rawMessageInfos, truncationStatuses: messageFetchResult.truncationStatuses, currentAsOf: mostRecentMessageTimestamp( messageFetchResult.rawMessageInfos, 0, ), }, }, }; })(), ); } } onSuccessfulConnection() { if (this.ws.readyState !== 1) { return; } this.handleStateCheckConditionsUpdate(); } // The Socket will timeout by calling this.ws.terminate() // serverRequestSocketTimeout milliseconds after the last // time resetTimeout is called resetTimeout: { +cancel: () => void } & (() => void) = _debounce( () => this.ws.terminate(), serverRequestSocketTimeout, ); debouncedAfterActivity: { +cancel: () => void } & (() => void) = _debounce( () => this.setStateCheckConditions({ activityRecentlyOccurred: false }), stateCheckInactivityActivationInterval, ); markActivityOccurred = () => { if (this.ws.readyState !== 1) { return; } this.setStateCheckConditions({ activityRecentlyOccurred: true }); this.debouncedAfterActivity(); }; clearStateCheckTimeout() { const { stateCheckTimeoutID } = this; if (stateCheckTimeoutID) { clearTimeout(stateCheckTimeoutID); this.stateCheckTimeoutID = null; } } setStateCheckConditions(newConditions: Partial) { this.stateCheckConditions = { ...this.stateCheckConditions, ...newConditions, }; this.handleStateCheckConditionsUpdate(); } get stateCheckCanStart(): boolean { return Object.values(this.stateCheckConditions).every(cond => !cond); } handleStateCheckConditionsUpdate() { if (!this.stateCheckCanStart) { this.clearStateCheckTimeout(); return; } if (this.stateCheckTimeoutID) { return; } const { viewer } = this; if (!viewer) { return; } const timeUntilStateCheck = viewer.sessionLastValidated + sessionCheckFrequency - Date.now(); if (timeUntilStateCheck <= 0) { this.initiateStateCheck(); } else { this.stateCheckTimeoutID = setTimeout( this.initiateStateCheck, timeUntilStateCheck, ); } } initiateStateCheck = async () => { this.setStateCheckConditions({ stateCheckOngoing: true }); const { viewer } = this; invariant(viewer, 'should be set'); const { checkStateRequest } = await checkState(viewer, { status: 'state_check', }); invariant(checkStateRequest, 'should be set'); await this.sendMessage({ type: serverSocketMessageTypes.REQUESTS, payload: { serverRequests: [checkStateRequest] }, }); }; } export { onConnection }; diff --git a/keyserver/src/updaters/thread-updaters.js b/keyserver/src/updaters/thread-updaters.js index 9abf5211d..92ba2aac7 100644 --- a/keyserver/src/updaters/thread-updaters.js +++ b/keyserver/src/updaters/thread-updaters.js @@ -1,982 +1,981 @@ // @flow import { specialRoles } from 'lib/permissions/special-roles.js'; import { getRolePermissionBlobs } from 'lib/permissions/thread-permissions.js'; import { filteredThreadIDs } from 'lib/selectors/calendar-filter-selectors.js'; import { getPinnedContentFromMessage } from 'lib/shared/message-utils.js'; import { threadHasAdminRole, roleIsAdminRole, viewerIsMember, getThreadTypeParentRequirement, } from 'lib/shared/thread-utils.js'; import { messageTypes } from 'lib/types/message-types-enum.js'; import type { RawMessageInfo, MessageData } from 'lib/types/message-types.js'; import { threadPermissions } from 'lib/types/thread-permission-types.js'; import { threadTypes } from 'lib/types/thread-types-enum.js'; import { type RoleChangeRequest, type ChangeThreadSettingsResult, type RemoveMembersRequest, type LeaveThreadRequest, type LeaveThreadResult, type UpdateThreadRequest, type ServerThreadJoinRequest, type ThreadJoinResult, type ToggleMessagePinRequest, type ToggleMessagePinResult, } from 'lib/types/thread-types.js'; import { updateTypes } from 'lib/types/update-types-enum.js'; import { ServerError } from 'lib/utils/errors.js'; -import { promiseAll } from 'lib/utils/promises.js'; +import { promiseAll, ignorePromiseRejections } from 'lib/utils/promises.js'; import { firstLine } from 'lib/utils/string-utils.js'; import { canToggleMessagePin } from 'lib/utils/toggle-pin-utils.js'; import { validChatNameRegex } from 'lib/utils/validation-utils.js'; import { reportLinkUsage } from './link-updaters.js'; import { updateRoles } from './role-updaters.js'; import { changeRole, recalculateThreadPermissions, commitMembershipChangeset, type MembershipChangeset, type MembershipRow, } from './thread-permission-updaters.js'; import createMessages from '../creators/message-creator.js'; import { createUpdates } from '../creators/update-creator.js'; import { dbQuery, SQL } from '../database/database.js'; import { checkIfInviteLinkIsValid } from '../fetchers/link-fetchers.js'; import { fetchMessageInfoByID } from '../fetchers/message-fetchers.js'; import { fetchThreadInfos, fetchServerThreadInfos, determineThreadAncestry, rawThreadInfosFromServerThreadInfos, } from '../fetchers/thread-fetchers.js'; import { checkThreadPermission, viewerIsMember as fetchViewerIsMember, checkThread, validateCandidateMembers, } from '../fetchers/thread-permission-fetchers.js'; import { verifyUserIDs, verifyUserOrCookieIDs, } from '../fetchers/user-fetchers.js'; -import { handleAsyncPromise } from '../responders/handlers.js'; import type { Viewer } from '../session/viewer.js'; import RelationshipChangeset from '../utils/relationship-changeset.js'; type UpdateRoleOptions = { +silenceNewMessages?: boolean, +forcePermissionRecalculation?: boolean, }; async function updateRole( viewer: Viewer, request: RoleChangeRequest, options?: UpdateRoleOptions, ): Promise { const silenceNewMessages = options?.silenceNewMessages; const forcePermissionRecalculation = options?.forcePermissionRecalculation; if (!viewer.loggedIn) { throw new ServerError('not_logged_in'); } const [memberIDs, hasPermission, fetchThreadResult] = await Promise.all([ verifyUserIDs(request.memberIDs), checkThreadPermission( viewer, request.threadID, threadPermissions.CHANGE_ROLE, ), fetchThreadInfos(viewer, { threadID: request.threadID }), ]); if (memberIDs.length === 0) { throw new ServerError('invalid_parameters'); } if (!hasPermission) { throw new ServerError('invalid_credentials'); } const threadInfo = fetchThreadResult.threadInfos[request.threadID]; if (!threadInfo) { throw new ServerError('invalid_parameters'); } const adminRoleID = Object.keys(threadInfo.roles).find( roleID => threadInfo.roles[roleID].name === 'Admins', ); // Ensure that there will always still be at least one admin in a community if (adminRoleID) { const memberRoles = memberIDs.map( memberID => threadInfo.members.find(member => member.id === memberID)?.role, ); const communityAdminsCount = threadInfo.members.filter( member => member.role === adminRoleID, ).length; const changedAdminsCount = memberRoles.filter( memberRole => memberRole === adminRoleID, ).length; if (changedAdminsCount >= communityAdminsCount) { throw new ServerError('invalid_parameters'); } } const query = SQL` SELECT user, role FROM memberships WHERE user IN (${memberIDs}) AND thread = ${request.threadID} `; const [result] = await dbQuery(query); let nonMemberUser = false; let numResults = 0; for (const row of result) { if (row.role <= 0) { nonMemberUser = true; break; } numResults++; } if (nonMemberUser || numResults < memberIDs.length) { throw new ServerError('invalid_parameters'); } const changeset = await changeRole( request.threadID, memberIDs, request.role, { forcePermissionRecalculation: !!forcePermissionRecalculation, }, ); const { viewerUpdates } = await commitMembershipChangeset( viewer, changeset, forcePermissionRecalculation ? { changedThreadIDs: new Set([request.threadID]) } : undefined, ); let newMessageInfos: Array = []; if (!silenceNewMessages) { const messageData = { type: messageTypes.CHANGE_ROLE, threadID: request.threadID, creatorID: viewer.userID, time: Date.now(), userIDs: memberIDs, newRole: request.role, roleName: threadInfo.roles[request.role].name, }; newMessageInfos = await createMessages(viewer, [messageData]); } return { updatesResult: { newUpdates: viewerUpdates }, newMessageInfos }; } async function removeMembers( viewer: Viewer, request: RemoveMembersRequest, ): Promise { const viewerID = viewer.userID; if (request.memberIDs.includes(viewerID)) { throw new ServerError('invalid_parameters'); } const [memberIDs, hasPermission] = await Promise.all([ verifyUserOrCookieIDs(request.memberIDs), checkThreadPermission( viewer, request.threadID, threadPermissions.REMOVE_MEMBERS, ), ]); if (memberIDs.length === 0) { throw new ServerError('invalid_parameters'); } if (!hasPermission) { throw new ServerError('invalid_credentials'); } const query = SQL` SELECT m.user, m.role, r.id AS default_role FROM memberships m LEFT JOIN roles r ON r.special_role = ${specialRoles.DEFAULT_ROLE} AND r.thread = ${request.threadID} WHERE m.user IN (${memberIDs}) AND m.thread = ${request.threadID} `; const [result] = await dbQuery(query); let nonDefaultRoleUser = false; const actualMemberIDs = []; for (const row of result) { if (row.role <= 0) { continue; } actualMemberIDs.push(row.user.toString()); if (row.role !== row.default_role) { nonDefaultRoleUser = true; } } if (nonDefaultRoleUser) { const hasChangeRolePermission = await checkThreadPermission( viewer, request.threadID, threadPermissions.CHANGE_ROLE, ); if (!hasChangeRolePermission) { throw new ServerError('invalid_credentials'); } } const changeset = await changeRole(request.threadID, actualMemberIDs, 0); const { viewerUpdates } = await commitMembershipChangeset(viewer, changeset); const newMessageInfos = await (async () => { if (actualMemberIDs.length === 0) { return []; } const messageData = { type: messageTypes.REMOVE_MEMBERS, threadID: request.threadID, creatorID: viewerID, time: Date.now(), removedUserIDs: actualMemberIDs, }; return await createMessages(viewer, [messageData]); })(); return { updatesResult: { newUpdates: viewerUpdates }, newMessageInfos }; } async function leaveThread( viewer: Viewer, request: LeaveThreadRequest, ): Promise { if (!viewer.loggedIn) { throw new ServerError('not_logged_in'); } const [fetchThreadResult, hasPermission] = await Promise.all([ fetchThreadInfos(viewer, { threadID: request.threadID }), checkThreadPermission( viewer, request.threadID, threadPermissions.LEAVE_THREAD, ), ]); const threadInfo = fetchThreadResult.threadInfos[request.threadID]; if (!viewerIsMember(threadInfo)) { return { updatesResult: { newUpdates: [] }, }; } if (!hasPermission) { throw new ServerError('invalid_parameters'); } const viewerID = viewer.userID; if (threadHasAdminRole(threadInfo)) { let otherUsersExist = false; let otherAdminsExist = false; for (const member of threadInfo.members) { const role = member.role; if (!role || member.id === viewerID) { continue; } otherUsersExist = true; if (roleIsAdminRole(threadInfo.roles[role])) { otherAdminsExist = true; break; } } if (otherUsersExist && !otherAdminsExist) { throw new ServerError('invalid_parameters'); } } const changeset = await changeRole(request.threadID, [viewerID], 0); const { viewerUpdates } = await commitMembershipChangeset(viewer, changeset); const messageData = { type: messageTypes.LEAVE_THREAD, threadID: request.threadID, creatorID: viewerID, time: Date.now(), }; await createMessages(viewer, [messageData]); return { updatesResult: { newUpdates: viewerUpdates } }; } type UpdateThreadOptions = Partial<{ +forceAddMembers: boolean, +forceUpdateRoot: boolean, +silenceMessages: boolean, +ignorePermissions: boolean, }>; async function updateThread( viewer: Viewer, request: UpdateThreadRequest, options?: UpdateThreadOptions, ): Promise { if (!viewer.loggedIn) { throw new ServerError('not_logged_in'); } const forceAddMembers = options?.forceAddMembers ?? false; const forceUpdateRoot = options?.forceUpdateRoot ?? false; const silenceMessages = options?.silenceMessages ?? false; const ignorePermissions = (options?.ignorePermissions && viewer.isScriptViewer) ?? false; const changedFields: { [string]: string | number } = {}; const sqlUpdate: { [string]: ?string | number } = {}; const untrimmedName = request.changes.name; if (untrimmedName !== undefined && untrimmedName !== null) { const name = firstLine(untrimmedName); if (name.search(validChatNameRegex) === -1) { throw new ServerError('invalid_chat_name'); } changedFields.name = name; sqlUpdate.name = name; } const { description } = request.changes; if (description !== undefined && description !== null) { changedFields.description = description; sqlUpdate.description = description; } if (request.changes.color) { const color = request.changes.color.toLowerCase(); changedFields.color = color; sqlUpdate.color = color; } const { parentThreadID } = request.changes; if (parentThreadID !== undefined) { // TODO some sort of message when this changes sqlUpdate.parent_thread_id = parentThreadID; } const { avatar } = request.changes; if (avatar) { changedFields.avatar = avatar.type !== 'remove' ? JSON.stringify(avatar) : ''; sqlUpdate.avatar = avatar.type !== 'remove' ? JSON.stringify(avatar) : null; } const threadType = request.changes.type; if (threadType !== null && threadType !== undefined) { changedFields.type = threadType; sqlUpdate.type = threadType; } if ( !ignorePermissions && threadType !== null && threadType !== undefined && threadType !== threadTypes.COMMUNITY_OPEN_SUBTHREAD && threadType !== threadTypes.COMMUNITY_SECRET_SUBTHREAD ) { throw new ServerError('invalid_parameters'); } const newMemberIDs = request.changes.newMemberIDs && request.changes.newMemberIDs.length > 0 ? [...new Set(request.changes.newMemberIDs)] : null; if ( Object.keys(sqlUpdate).length === 0 && !newMemberIDs && !forceUpdateRoot ) { throw new ServerError('invalid_parameters'); } const serverThreadInfosPromise = fetchServerThreadInfos({ threadID: request.threadID, }); const hasNecessaryPermissionsPromise = (async () => { if (ignorePermissions) { return; } const checks = []; if (sqlUpdate.name !== undefined) { checks.push({ check: 'permission', permission: threadPermissions.EDIT_THREAD_NAME, }); } if (sqlUpdate.description !== undefined) { checks.push({ check: 'permission', permission: threadPermissions.EDIT_THREAD_DESCRIPTION, }); } if (sqlUpdate.color !== undefined) { checks.push({ check: 'permission', permission: threadPermissions.EDIT_THREAD_COLOR, }); } if (sqlUpdate.avatar !== undefined) { checks.push({ check: 'permission', permission: threadPermissions.EDIT_THREAD_AVATAR, }); } if (parentThreadID !== undefined || sqlUpdate.type !== undefined) { checks.push({ check: 'permission', permission: threadPermissions.EDIT_PERMISSIONS, }); } if (newMemberIDs) { checks.push({ check: 'permission', permission: threadPermissions.ADD_MEMBERS, }); } const hasNecessaryPermissions = await checkThread( viewer, request.threadID, checks, ); if (!hasNecessaryPermissions) { throw new ServerError('invalid_credentials'); } })(); const [serverThreadInfos] = await Promise.all([ serverThreadInfosPromise, hasNecessaryPermissionsPromise, ]); const serverThreadInfo = serverThreadInfos.threadInfos[request.threadID]; if (!serverThreadInfo) { throw new ServerError('internal_error'); } // Threads with source message should be visible to everyone, but we can't // guarantee it for COMMUNITY_SECRET_SUBTHREAD threads so we forbid it for // now. In the future, if we want to support this, we would need to unlink the // source message. if ( threadType !== null && threadType !== undefined && threadType !== threadTypes.SIDEBAR && threadType !== threadTypes.COMMUNITY_OPEN_SUBTHREAD && serverThreadInfo.sourceMessageID ) { throw new ServerError('invalid_parameters'); } // You can't change the parent thread of a current or former SIDEBAR if (parentThreadID !== undefined && serverThreadInfo.sourceMessageID) { throw new ServerError('invalid_parameters'); } const oldThreadType = serverThreadInfo.type; const oldParentThreadID = serverThreadInfo.parentThreadID; const oldContainingThreadID = serverThreadInfo.containingThreadID; const oldCommunity = serverThreadInfo.community; const oldDepth = serverThreadInfo.depth; const nextThreadType = threadType !== null && threadType !== undefined ? threadType : oldThreadType; let nextParentThreadID = parentThreadID !== undefined ? parentThreadID : oldParentThreadID; // Does the new thread type preclude a parent? if ( threadType !== undefined && threadType !== null && getThreadTypeParentRequirement(threadType) === 'disabled' && nextParentThreadID !== null ) { nextParentThreadID = null; sqlUpdate.parent_thread_id = null; } // Does the new thread type require a parent? if ( threadType !== undefined && threadType !== null && getThreadTypeParentRequirement(threadType) === 'required' && nextParentThreadID === null ) { throw new ServerError('no_parent_thread_specified'); } const determineThreadAncestryPromise = determineThreadAncestry( nextParentThreadID, nextThreadType, ); const confirmParentPermissionPromise = (async () => { if (ignorePermissions || !nextParentThreadID) { return; } if ( nextParentThreadID === oldParentThreadID && (nextThreadType === threadTypes.SIDEBAR) === (oldThreadType === threadTypes.SIDEBAR) ) { return; } const hasParentPermission = await checkThreadPermission( viewer, nextParentThreadID, nextThreadType === threadTypes.SIDEBAR ? threadPermissions.CREATE_SIDEBARS : threadPermissions.CREATE_SUBCHANNELS, ); if (!hasParentPermission) { throw new ServerError('invalid_parameters'); } })(); const rolesNeedUpdate = forceUpdateRoot || nextThreadType !== oldThreadType; const validateNewMembersPromise = (async () => { if (!newMemberIDs || ignorePermissions) { return; } const defaultRolePermissionsPromise = (async () => { let rolePermissions; if (!rolesNeedUpdate) { const rolePermissionsQuery = SQL` SELECT r.permissions FROM threads t LEFT JOIN roles r ON r.special_role = ${specialRoles.DEFAULT_ROLE} AND r.thread = ${request.threadID} WHERE t.id = ${request.threadID} `; const [result] = await dbQuery(rolePermissionsQuery); if (result.length > 0) { rolePermissions = JSON.parse(result[0].permissions); } } if (!rolePermissions) { rolePermissions = getRolePermissionBlobs(nextThreadType).Members; } return rolePermissions; })(); const [defaultRolePermissions, nextThreadAncestry] = await Promise.all([ defaultRolePermissionsPromise, determineThreadAncestryPromise, ]); const { newMemberIDs: validatedIDs } = await validateCandidateMembers( viewer, { newMemberIDs }, { threadType: nextThreadType, parentThreadID: nextParentThreadID, containingThreadID: nextThreadAncestry.containingThreadID, defaultRolePermissions, }, { requireRelationship: !forceAddMembers }, ); if ( validatedIDs && Number(validatedIDs?.length) < Number(newMemberIDs?.length) ) { throw new ServerError('invalid_credentials'); } })(); const { nextThreadAncestry } = await promiseAll({ nextThreadAncestry: determineThreadAncestryPromise, confirmParentPermissionPromise, validateNewMembersPromise, }); if (nextThreadAncestry.containingThreadID !== oldContainingThreadID) { sqlUpdate.containing_thread_id = nextThreadAncestry.containingThreadID; } if (nextThreadAncestry.community !== oldCommunity) { if (!ignorePermissions) { throw new ServerError('invalid_parameters'); } sqlUpdate.community = nextThreadAncestry.community; } if (nextThreadAncestry.depth !== oldDepth) { sqlUpdate.depth = nextThreadAncestry.depth; } const updateQueryPromise = (async () => { if (Object.keys(sqlUpdate).length === 0) { return; } const { avatar: avatarUpdate, ...nonAvatarUpdates } = sqlUpdate; const updatePromises = []; if (Object.keys(nonAvatarUpdates).length > 0) { const nonAvatarUpdateQuery = SQL` UPDATE threads SET ${nonAvatarUpdates} WHERE id = ${request.threadID} `; updatePromises.push(dbQuery(nonAvatarUpdateQuery)); } if (avatarUpdate !== undefined) { const avatarUploadID = avatar && (avatar.type === 'image' || avatar.type === 'encrypted_image') ? avatar.uploadID : null; const avatarUpdateQuery = SQL` START TRANSACTION; UPDATE uploads SET container = NULL WHERE container = ${request.threadID} AND ( ${avatarUploadID} IS NULL OR EXISTS ( SELECT 1 FROM uploads WHERE id = ${avatarUploadID} AND ${avatarUploadID} IS NOT NULL AND uploader = ${viewer.userID} AND container IS NULL AND thread IS NULL ) ); UPDATE uploads SET container = ${request.threadID} WHERE id = ${avatarUploadID} AND ${avatarUploadID} IS NOT NULL AND uploader = ${viewer.userID} AND container IS NULL AND thread IS NULL; UPDATE threads SET avatar = ${avatarUpdate} WHERE id = ${request.threadID} AND ( ${avatarUploadID} IS NULL OR EXISTS ( SELECT 1 FROM uploads WHERE id = ${avatarUploadID} AND ${avatarUploadID} IS NOT NULL AND uploader = ${viewer.userID} AND container = ${request.threadID} AND thread IS NULL ) ); COMMIT; `; updatePromises.push( dbQuery(avatarUpdateQuery, { multipleStatements: true }), ); } await Promise.all(updatePromises); })(); const updateRolesPromise = (async () => { if (rolesNeedUpdate) { await updateRoles(viewer, request.threadID, nextThreadType); } })(); const addMembersChangesetPromise: Promise = (async () => { if (!newMemberIDs) { return undefined; } await Promise.all([updateQueryPromise, updateRolesPromise]); return await changeRole(request.threadID, newMemberIDs, null, { setNewMembersToUnread: true, }); })(); const recalculatePermissionsChangesetPromise: Promise = (async () => { const threadRootChanged = rolesNeedUpdate || nextParentThreadID !== oldParentThreadID; if (!threadRootChanged) { return undefined; } await Promise.all([updateQueryPromise, updateRolesPromise]); return await recalculateThreadPermissions(request.threadID); })(); const [addMembersChangeset, recalculatePermissionsChangeset] = await Promise.all([ addMembersChangesetPromise, recalculatePermissionsChangesetPromise, updateQueryPromise, updateRolesPromise, ]); const membershipRows: Array = []; const relationshipChangeset = new RelationshipChangeset(); if (recalculatePermissionsChangeset) { const { membershipRows: recalculateMembershipRows, relationshipChangeset: recalculateRelationshipChangeset, } = recalculatePermissionsChangeset; membershipRows.push(...recalculateMembershipRows); relationshipChangeset.addAll(recalculateRelationshipChangeset); } let addedMemberIDs; if (addMembersChangeset) { const { membershipRows: addMembersMembershipRows, relationshipChangeset: addMembersRelationshipChangeset, } = addMembersChangeset; addedMemberIDs = addMembersMembershipRows .filter( row => row.operation === 'save' && row.threadID === request.threadID && Number(row.role) > 0, ) .map(row => row.userID); membershipRows.push(...addMembersMembershipRows); relationshipChangeset.addAll(addMembersRelationshipChangeset); } const changeset = { membershipRows, relationshipChangeset }; const { viewerUpdates } = await commitMembershipChangeset(viewer, changeset, { // This forces an update for this thread, // regardless of whether any membership rows are changed changedThreadIDs: Object.keys(sqlUpdate).length > 0 ? new Set([request.threadID]) : new Set(), // last_message will be updated automatically if we send a message, // so we only need to handle it here when we silence new messages updateMembershipsLastMessage: silenceMessages, }); let newMessageInfos: Array = []; if (!silenceMessages) { const time = Date.now(); const messageDatas: Array = []; for (const fieldName in changedFields) { const newValue = changedFields[fieldName]; messageDatas.push({ type: messageTypes.CHANGE_SETTINGS, threadID: request.threadID, creatorID: viewer.userID, time, field: fieldName, value: newValue, }); } if (addedMemberIDs && addedMemberIDs.length > 0) { messageDatas.push({ type: messageTypes.ADD_MEMBERS, threadID: request.threadID, creatorID: viewer.userID, time, addedUserIDs: addedMemberIDs, }); } newMessageInfos = await createMessages(viewer, messageDatas); } return { updatesResult: { newUpdates: viewerUpdates }, newMessageInfos }; } async function joinThread( viewer: Viewer, request: ServerThreadJoinRequest, ): Promise { if (!viewer.loggedIn) { throw new ServerError('not_logged_in'); } const permissionCheck = request.inviteLinkSecret ? checkIfInviteLinkIsValid(request.inviteLinkSecret, request.threadID) : checkThreadPermission( viewer, request.threadID, threadPermissions.JOIN_THREAD, ); const [isMember, hasPermission] = await Promise.all([ fetchViewerIsMember(viewer, request.threadID), permissionCheck, ]); if (!hasPermission) { throw new ServerError('invalid_parameters'); } const { calendarQuery } = request; if (isMember) { const response: ThreadJoinResult = { rawMessageInfos: [], truncationStatuses: {}, userInfos: {}, updatesResult: { newUpdates: [], }, }; return response; } if (calendarQuery) { const threadFilterIDs = filteredThreadIDs(calendarQuery.filters); if ( !threadFilterIDs || threadFilterIDs.size !== 1 || threadFilterIDs.values().next().value !== request.threadID ) { throw new ServerError('invalid_parameters'); } } const changeset = await changeRole(request.threadID, [viewer.userID], null); const membershipResult = await commitMembershipChangeset(viewer, changeset, { calendarQuery, }); if (request.inviteLinkSecret) { - handleAsyncPromise(reportLinkUsage(request.inviteLinkSecret)); + ignorePromiseRejections(reportLinkUsage(request.inviteLinkSecret)); } const messageData = { type: messageTypes.JOIN_THREAD, threadID: request.threadID, creatorID: viewer.userID, time: Date.now(), }; const newMessages = await createMessages(viewer, [messageData]); return { rawMessageInfos: newMessages, truncationStatuses: {}, userInfos: membershipResult.userInfos, updatesResult: { newUpdates: membershipResult.viewerUpdates, }, }; } async function toggleMessagePinForThread( viewer: Viewer, request: ToggleMessagePinRequest, ): Promise { const { messageID, action } = request; const targetMessage = await fetchMessageInfoByID(viewer, messageID); if (!targetMessage) { throw new ServerError('invalid_parameters'); } const { threadID } = targetMessage; const fetchServerThreadInfosResult = await fetchServerThreadInfos({ threadID, }); const { threadInfos: rawThreadInfos } = rawThreadInfosFromServerThreadInfos( viewer, fetchServerThreadInfosResult, ); const rawThreadInfo = rawThreadInfos[threadID]; const canTogglePin = canToggleMessagePin(targetMessage, rawThreadInfo); if (!canTogglePin) { throw new ServerError('invalid_parameters'); } const pinnedValue = action === 'pin' ? 1 : 0; const pinTimeValue = action === 'pin' ? Date.now() : null; const pinnedCountValue = action === 'pin' ? 1 : -1; const query = SQL` UPDATE messages AS m, threads AS t SET m.pinned = ${pinnedValue}, m.pin_time = ${pinTimeValue}, t.pinned_count = t.pinned_count + ${pinnedCountValue} WHERE m.id = ${messageID} AND m.thread = ${threadID} AND t.id = ${threadID} AND m.pinned != ${pinnedValue} `; const [result] = await dbQuery(query); if (result.affectedRows === 0) { return { newMessageInfos: [], threadID, }; } const createMessagesAsync = async () => { const messageData = { type: messageTypes.TOGGLE_PIN, threadID, targetMessageID: messageID, action, pinnedContent: getPinnedContentFromMessage(targetMessage), creatorID: viewer.userID, time: Date.now(), }; const newMessageInfos = await createMessages(viewer, [messageData]); return newMessageInfos; }; const createUpdatesAsync = async () => { const { threadInfos: serverThreadInfos } = fetchServerThreadInfosResult; const time = Date.now(); const updates = []; for (const member of serverThreadInfos[threadID].members) { updates.push({ userID: member.id, time, threadID, type: updateTypes.UPDATE_THREAD, }); } await createUpdates(updates); }; const [newMessageInfos] = await Promise.all([ createMessagesAsync(), createUpdatesAsync(), ]); return { newMessageInfos, threadID, }; } export { updateRole, removeMembers, leaveThread, updateThread, joinThread, toggleMessagePinForThread, }; diff --git a/lib/utils/promises.js b/lib/utils/promises.js index a673060ad..7bd6df564 100644 --- a/lib/utils/promises.js +++ b/lib/utils/promises.js @@ -1,32 +1,40 @@ // @flow type Promisable = Promise | T; async function promiseAll }>( input: T, ): Promise<$ObjMap> { const promises = []; const keys = Object.keys(input); for (let i = 0; i < keys.length; i++) { const key = keys[i]; const promise = input[key]; promises.push(promise); } const results = await Promise.all(promises); const byName: { [string]: mixed } = {}; for (let i = 0; i < keys.length; i++) { const key = keys[i]; byName[key] = results[i]; } return byName; } async function promiseFilter( input: $ReadOnlyArray, filterFunc: (value: T) => Promise, ): Promise { const filterResults = await Promise.all(input.map(filterFunc)); return input.filter((value, index) => filterResults[index]); } -export { promiseAll, promiseFilter }; +async function ignorePromiseRejections(promise: Promise) { + try { + await promise; + } catch (error) { + console.warn(error); + } +} + +export { promiseAll, promiseFilter, ignorePromiseRejections };