diff --git a/keyserver/src/push/rescind.js b/keyserver/src/push/rescind.js index a2b8855ea..b33853c81 100644 --- a/keyserver/src/push/rescind.js +++ b/keyserver/src/push/rescind.js @@ -1,379 +1,379 @@ // @flow import apn from '@parse/node-apn'; import type { ResponseFailure } from '@parse/node-apn'; import type { FirebaseError } from 'firebase-admin'; import invariant from 'invariant'; import type { PlatformDetails } from 'lib/types/device-types.js'; import { threadSubscriptions } from 'lib/types/subscription-types.js'; import { threadPermissions } from 'lib/types/thread-permission-types.js'; import { promiseAll } from 'lib/utils/promises.js'; import { tID } from 'lib/utils/validation-utils.js'; import { prepareEncryptedAndroidNotificationRescinds, prepareEncryptedIOSNotificationRescind, } from './crypto.js'; import { getAPNsNotificationTopic } from './providers.js'; import type { NotificationTargetDevice, TargetedAndroidNotification, TargetedAPNsNotification, } from './types.js'; import { apnPush, fcmPush, type APNPushResult, type FCMPushResult, } from './utils.js'; import createIDs from '../creators/id-creator.js'; import { dbQuery, SQL } from '../database/database.js'; import type { SQLStatementType } from '../database/types.js'; import { validateOutput } from '../utils/validation-utils.js'; type ParsedDelivery = { +platform: 'ios' | 'macos' | 'android', +codeVersion: ?number, +stateVersion: ?number, +notificationID: string, +deviceTokens: $ReadOnlyArray, }; type RescindDelivery = { source: 'rescind', rescindedID: string, errors?: | $ReadOnlyArray | $ReadOnlyArray, }; async function rescindPushNotifs( notifCondition: SQLStatementType, inputCountCondition?: SQLStatementType, ) { const notificationExtractString = `$.${threadSubscriptions.home}`; const visPermissionExtractString = `$.${threadPermissions.VISIBLE}.value`; const fetchQuery = SQL` SELECT n.id, n.user, n.thread, n.message, n.delivery, n.collapse_key, COUNT( `; fetchQuery.append(inputCountCondition ? inputCountCondition : SQL`m.thread`); fetchQuery.append(SQL` ) AS unread_count FROM notifications n LEFT JOIN memberships m ON m.user = n.user AND m.last_message > m.last_read_message AND m.role > 0 AND JSON_EXTRACT(subscription, ${notificationExtractString}) AND JSON_EXTRACT(permissions, ${visPermissionExtractString}) WHERE n.rescinded = 0 AND `); fetchQuery.append(notifCondition); fetchQuery.append(SQL` GROUP BY n.id, m.user`); const [fetchResult] = await dbQuery(fetchQuery); const allDeviceTokens = new Set(); const parsedDeliveries: { [string]: $ReadOnlyArray } = {}; for (const row of fetchResult) { const rawDelivery = JSON.parse(row.delivery); const deliveries = Array.isArray(rawDelivery) ? rawDelivery : [rawDelivery]; const id = row.id.toString(); const rowParsedDeliveries = []; for (const delivery of deliveries) { if ( delivery.iosID || delivery.deviceType === 'ios' || delivery.deviceType === 'macos' ) { const deviceTokens = delivery.iosDeviceTokens ?? delivery.deviceTokens; rowParsedDeliveries.push({ notificationID: delivery.iosID, codeVersion: delivery.codeVersion, stateVersion: delivery.stateVersion, platform: delivery.deviceType ?? 'ios', deviceTokens, }); deviceTokens.forEach(deviceToken => allDeviceTokens.add(deviceToken)); } else if (delivery.androidID || delivery.deviceType === 'android') { const deviceTokens = delivery.androidDeviceTokens ?? delivery.deviceTokens; rowParsedDeliveries.push({ notificationID: row.collapse_key ? row.collapse_key : id, codeVersion: delivery.codeVersion, stateVersion: delivery.stateVersion, platform: 'android', deviceTokens, }); deviceTokens.forEach(deviceToken => allDeviceTokens.add(deviceToken)); } } parsedDeliveries[id] = rowParsedDeliveries; } const deviceTokenToCookieID = await getDeviceTokenToCookieID(allDeviceTokens); const deliveryPromises: { [string]: Promise | Promise, } = {}; const notifInfo = {}; const rescindedIDs = []; for (const row of fetchResult) { const id = row.id.toString(); const threadID = row.thread.toString(); notifInfo[id] = { userID: row.user.toString(), threadID, messageID: row.message.toString(), }; for (const delivery of parsedDeliveries[id]) { let platformDetails: PlatformDetails = { platform: delivery.platform }; if (delivery.codeVersion) { platformDetails = { ...platformDetails, codeVersion: delivery.codeVersion, }; } if (delivery.stateVersion) { platformDetails = { ...platformDetails, stateVersion: delivery.stateVersion, }; } if (delivery.platform === 'ios') { const devices = delivery.deviceTokens.map(deviceToken => ({ deviceToken, cookieID: deviceTokenToCookieID[deviceToken], })); const deliveryPromise = (async () => { const targetedNotifications = await prepareIOSNotification( delivery.notificationID, row.unread_count, threadID, platformDetails, devices, ); return await apnPush({ targetedNotifications, platformDetails: { platform: 'ios', codeVersion: delivery.codeVersion, }, }); })(); deliveryPromises[id] = deliveryPromise; } else if (delivery.platform === 'android') { const devices = delivery.deviceTokens.map(deviceToken => ({ deviceToken, cookieID: deviceTokenToCookieID[deviceToken], })); const deliveryPromise = (async () => { const targetedNotifications = await prepareAndroidNotification( delivery.notificationID, row.unread_count, threadID, platformDetails, devices, ); return await fcmPush({ targetedNotifications, codeVersion: delivery.codeVersion, }); })(); deliveryPromises[id] = deliveryPromise; } } rescindedIDs.push(id); } const numRescinds = Object.keys(deliveryPromises).length; const dbIDsPromise: Promise> = (async () => { if (numRescinds === 0) { return undefined; } return await createIDs('notifications', numRescinds); })(); const rescindPromise: Promise = (async () => { if (rescindedIDs.length === 0) { return undefined; } const rescindQuery = SQL` UPDATE notifications SET rescinded = 1 WHERE id IN (${rescindedIDs}) `; return await dbQuery(rescindQuery); })(); const [deliveryResults, dbIDs] = await Promise.all([ promiseAll(deliveryPromises), dbIDsPromise, rescindPromise, ]); const newNotifRows = []; if (numRescinds > 0) { invariant(dbIDs, 'dbIDs should be set'); for (const rescindedID in deliveryResults) { const delivery: RescindDelivery = { source: 'rescind', rescindedID, }; const { errors } = deliveryResults[rescindedID]; if (errors) { delivery.errors = errors; } const dbID = dbIDs.shift(); const { userID, threadID, messageID } = notifInfo[rescindedID]; newNotifRows.push([ dbID, userID, threadID, messageID, null, JSON.stringify([delivery]), 1, ]); } } if (newNotifRows.length > 0) { const insertQuery = SQL` INSERT INTO notifications (id, user, thread, message, collapse_key, delivery, rescinded) VALUES ${newNotifRows} `; await dbQuery(insertQuery); } } async function getDeviceTokenToCookieID( deviceTokens: Set, ): Promise<{ +[string]: string }> { if (deviceTokens.size === 0) { return {}; } const deviceTokenToCookieID = {}; const fetchCookiesQuery = SQL` SELECT id, device_token FROM cookies WHERE device_token IN (${[...deviceTokens]}) `; const [fetchResult] = await dbQuery(fetchCookiesQuery); for (const row of fetchResult) { deviceTokenToCookieID[row.device_token.toString()] = row.id.toString(); } return deviceTokenToCookieID; } async function conditionallyEncryptNotification( notification: T, codeVersion: ?number, devices: $ReadOnlyArray, encryptCallback: ( devices: $ReadOnlyArray, notification: T, codeVersion?: ?number, ) => Promise< $ReadOnlyArray<{ +notification: T, +cookieID: string, +deviceToken: string, +encryptionOrder?: number, }>, >, ): Promise<$ReadOnlyArray<{ +deviceToken: string, +notification: T }>> { const shouldBeEncrypted = codeVersion && codeVersion >= 233; if (!shouldBeEncrypted) { return devices.map(({ deviceToken }) => ({ notification, deviceToken, })); } const notifications = await encryptCallback( devices, notification, codeVersion, ); return notifications.map(({ deviceToken, notification: notif }) => ({ deviceToken, notification: notif, })); } async function prepareIOSNotification( iosID: string, unreadCount: number, threadID: string, platformDetails: PlatformDetails, devices: $ReadOnlyArray, ): Promise<$ReadOnlyArray> { - threadID = validateOutput(platformDetails, tID, threadID); + threadID = await validateOutput(platformDetails, tID, threadID); const { codeVersion } = platformDetails; const notification = new apn.Notification(); notification.topic = getAPNsNotificationTopic({ platform: 'ios', codeVersion, }); if (codeVersion && codeVersion > 198) { notification.mutableContent = true; notification.pushType = 'alert'; notification.badge = unreadCount; } else { notification.priority = 5; notification.contentAvailable = true; notification.pushType = 'background'; } notification.payload = codeVersion && codeVersion > 135 ? { backgroundNotifType: 'CLEAR', notificationId: iosID, setUnreadStatus: true, threadID, } : { managedAps: { action: 'CLEAR', notificationId: iosID, }, }; return await conditionallyEncryptNotification( notification, codeVersion, devices, prepareEncryptedIOSNotificationRescind, ); } async function prepareAndroidNotification( notifID: string, unreadCount: number, threadID: string, platformDetails: PlatformDetails, devices: $ReadOnlyArray, ): Promise<$ReadOnlyArray> { - threadID = validateOutput(platformDetails, tID, threadID); + threadID = await validateOutput(platformDetails, tID, threadID); const { codeVersion } = platformDetails; const notification = { data: { badge: unreadCount.toString(), rescind: 'true', rescindID: notifID, setUnreadStatus: 'true', threadID, }, }; return await conditionallyEncryptNotification( notification, codeVersion, devices, prepareEncryptedAndroidNotificationRescinds, ); } export { rescindPushNotifs }; diff --git a/keyserver/src/push/send.js b/keyserver/src/push/send.js index 810ae8df2..5c3b5aca8 100644 --- a/keyserver/src/push/send.js +++ b/keyserver/src/push/send.js @@ -1,1810 +1,1810 @@ // @flow import type { ResponseFailure } from '@parse/node-apn'; import apn from '@parse/node-apn'; import invariant from 'invariant'; import _cloneDeep from 'lodash/fp/cloneDeep.js'; import _flow from 'lodash/fp/flow.js'; import _groupBy from 'lodash/fp/groupBy.js'; import _mapValues from 'lodash/fp/mapValues.js'; import _pickBy from 'lodash/fp/pickBy.js'; import type { QueryResults } from 'mysql'; import t from 'tcomb'; import uuidv4 from 'uuid/v4.js'; import { oldValidUsernameRegex } from 'lib/shared/account-utils.js'; import { isUserMentioned } from 'lib/shared/mention-utils.js'; import { createMessageInfo, shimUnsupportedRawMessageInfos, sortMessageInfoList, } from 'lib/shared/message-utils.js'; import { messageSpecs } from 'lib/shared/messages/message-specs.js'; import { notifTextsForMessageInfo } from 'lib/shared/notif-utils.js'; import { rawThreadInfoFromServerThreadInfo, threadInfoFromRawThreadInfo, } from 'lib/shared/thread-utils.js'; import { hasMinCodeVersion } from 'lib/shared/version-utils.js'; import type { Platform, PlatformDetails } from 'lib/types/device-types.js'; import { messageTypes } from 'lib/types/message-types-enum.js'; import { type MessageData, type RawMessageInfo, rawMessageInfoValidator, } from 'lib/types/message-types.js'; import type { ThreadInfo } from 'lib/types/minimally-encoded-thread-permissions-types.js'; import type { ResolvedNotifTexts } from 'lib/types/notif-types.js'; import { resolvedNotifTextsValidator } from 'lib/types/notif-types.js'; import type { ServerThreadInfo } from 'lib/types/thread-types.js'; import { updateTypes } from 'lib/types/update-types-enum.js'; import { type GlobalUserInfo } from 'lib/types/user-types.js'; import { values } from 'lib/utils/objects.js'; import { tID, tPlatformDetails, tShape } from 'lib/utils/validation-utils.js'; import { prepareEncryptedAndroidNotifications, prepareEncryptedAPNsNotifications, prepareEncryptedWebNotifications, prepareEncryptedWNSNotifications, } from './crypto.js'; import { getAPNsNotificationTopic } from './providers.js'; import { rescindPushNotifs } from './rescind.js'; import type { AndroidNotification, NotificationTargetDevice, TargetedAndroidNotification, TargetedAPNsNotification, TargetedWebNotification, TargetedWNSNotification, } from './types.js'; import { apnMaxNotificationPayloadByteSize, apnPush, fcmMaxNotificationPayloadByteSize, fcmPush, getUnreadCounts, webPush, type WebPushError, wnsMaxNotificationPayloadByteSize, wnsPush, type WNSPushError, } from './utils.js'; import createIDs from '../creators/id-creator.js'; import { createUpdates } from '../creators/update-creator.js'; import { dbQuery, mergeOrConditions, SQL } from '../database/database.js'; import type { CollapsableNotifInfo } from '../fetchers/message-fetchers.js'; import { fetchCollapsableNotifs } from '../fetchers/message-fetchers.js'; import { fetchServerThreadInfos } from '../fetchers/thread-fetchers.js'; import { fetchUserInfos } from '../fetchers/user-fetchers.js'; import type { Viewer } from '../session/viewer.js'; import { getENSNames } from '../utils/ens-cache.js'; import { validateOutput } from '../utils/validation-utils.js'; export type Device = { +platform: Platform, +deviceToken: string, +cookieID: string, +codeVersion: ?number, +stateVersion: ?number, +majorDesktopVersion: ?number, }; export type PushUserInfo = { +devices: Device[], // messageInfos and messageDatas have the same key +messageInfos: RawMessageInfo[], +messageDatas: MessageData[], }; type Delivery = PushDelivery | { collapsedInto: string }; type NotificationRow = { +dbID: string, +userID: string, +threadID?: ?string, +messageID?: ?string, +collapseKey?: ?string, +deliveries: Delivery[], }; export type PushInfo = { [userID: string]: PushUserInfo }; async function sendPushNotifs(pushInfo: PushInfo) { if (Object.keys(pushInfo).length === 0) { return; } const [ unreadCounts, { usersToCollapsableNotifInfo, serverThreadInfos, userInfos }, dbIDs, ] = await Promise.all([ getUnreadCounts(Object.keys(pushInfo)), fetchInfos(pushInfo), createDBIDs(pushInfo), ]); const preparePromises: Array>> = []; const notifications: Map = new Map(); for (const userID in usersToCollapsableNotifInfo) { const threadInfos = _flow( _mapValues((serverThreadInfo: ServerThreadInfo) => { const rawThreadInfo = rawThreadInfoFromServerThreadInfo( serverThreadInfo, userID, { minimallyEncodePermissions: true }, ); if (!rawThreadInfo) { return null; } invariant( rawThreadInfo.minimallyEncoded, 'rawThreadInfo from rawThreadInfoFromServerThreadInfo must be ' + 'minimallyEncoded when minimallyEncodePermissions option is set', ); return threadInfoFromRawThreadInfo(rawThreadInfo, userID, userInfos); }), _pickBy(threadInfo => threadInfo), )(serverThreadInfos); for (const notifInfo of usersToCollapsableNotifInfo[userID]) { preparePromises.push( preparePushNotif({ notifInfo, userID, pushUserInfo: pushInfo[userID], unreadCount: unreadCounts[userID], threadInfos, userInfos, dbIDs, rowsToSave: notifications, }), ); } } const prepareResults = await Promise.all(preparePromises); const flattenedPrepareResults = prepareResults.filter(Boolean).flat(); const deliveryResults = await deliverPushNotifsInEncryptionOrder( flattenedPrepareResults, ); const cleanUpPromise = (async () => { if (dbIDs.length === 0) { return; } const query = SQL`DELETE FROM ids WHERE id IN (${dbIDs})`; await dbQuery(query); })(); await Promise.all([ cleanUpPromise, saveNotifResults(deliveryResults, notifications, true), ]); } type PreparePushResult = { +platform: Platform, +notificationInfo: NotificationInfo, +notification: | TargetedAPNsNotification | TargetedAndroidNotification | TargetedWebNotification | TargetedWNSNotification, }; async function preparePushNotif(input: { notifInfo: CollapsableNotifInfo, userID: string, pushUserInfo: PushUserInfo, unreadCount: number, threadInfos: { +[threadID: string]: ThreadInfo, }, userInfos: { +[userID: string]: GlobalUserInfo }, dbIDs: string[], // mutable rowsToSave: Map, // mutable }): Promise> { const { notifInfo, userID, pushUserInfo, unreadCount, threadInfos, userInfos, dbIDs, rowsToSave, } = input; const hydrateMessageInfo = (rawMessageInfo: RawMessageInfo) => createMessageInfo(rawMessageInfo, userID, userInfos, threadInfos); const newMessageInfos = []; const newRawMessageInfos = []; for (const newRawMessageInfo of notifInfo.newMessageInfos) { const newMessageInfo = hydrateMessageInfo(newRawMessageInfo); if (newMessageInfo) { newMessageInfos.push(newMessageInfo); newRawMessageInfos.push(newRawMessageInfo); } } if (newMessageInfos.length === 0) { return null; } const existingMessageInfos = notifInfo.existingMessageInfos .map(hydrateMessageInfo) .filter(Boolean); const allMessageInfos = sortMessageInfoList([ ...newMessageInfos, ...existingMessageInfos, ]); const [firstNewMessageInfo, ...remainingNewMessageInfos] = newMessageInfos; const { threadID } = firstNewMessageInfo; const threadInfo = threadInfos[threadID]; const parentThreadInfo = threadInfo.parentThreadID ? threadInfos[threadInfo.parentThreadID] : null; const updateBadge = threadInfo.currentUser.subscription.home; const displayBanner = threadInfo.currentUser.subscription.pushNotifs; const username = userInfos[userID] && userInfos[userID].username; let resolvedUsername; if (getENSNames) { const userInfosWithENSNames = await getENSNames([userInfos[userID]]); resolvedUsername = userInfosWithENSNames[0].username; } const userWasMentioned = username && threadInfo.currentUser.role && oldValidUsernameRegex.test(username) && newMessageInfos.some(newMessageInfo => { const unwrappedMessageInfo = newMessageInfo.type === messageTypes.SIDEBAR_SOURCE ? newMessageInfo.sourceMessage : newMessageInfo; return ( unwrappedMessageInfo.type === messageTypes.TEXT && (isUserMentioned(username, unwrappedMessageInfo.text) || (resolvedUsername && isUserMentioned(resolvedUsername, unwrappedMessageInfo.text))) ); }); if (!updateBadge && !displayBanner && !userWasMentioned) { return null; } const badgeOnly = !displayBanner && !userWasMentioned; const notifTargetUserInfo = { id: userID, username }; const notifTexts = await notifTextsForMessageInfo( allMessageInfos, threadInfo, parentThreadInfo, notifTargetUserInfo, getENSNames, ); if (!notifTexts) { return null; } const dbID = dbIDs.shift(); invariant(dbID, 'should have sufficient DB IDs'); const byPlatform = getDevicesByPlatform(pushUserInfo.devices); const firstMessageID = firstNewMessageInfo.id; invariant(firstMessageID, 'RawMessageInfo.id should be set on server'); const notificationInfo = { source: 'new_message', dbID, userID, threadID, messageID: firstMessageID, collapseKey: notifInfo.collapseKey, }; const preparePromises: Array>> = []; const iosVersionsToTokens = byPlatform.get('ios'); if (iosVersionsToTokens) { for (const [versionKey, devices] of iosVersionsToTokens) { const { codeVersion, stateVersion } = stringToVersionKey(versionKey); const platformDetails: PlatformDetails = { platform: 'ios', codeVersion, stateVersion, }; const shimmedNewRawMessageInfos = shimUnsupportedRawMessageInfos( newRawMessageInfos, platformDetails, ); const preparePromise: Promise<$ReadOnlyArray> = (async () => { const targetedNotifications = await prepareAPNsNotification( { notifTexts, newRawMessageInfos: shimmedNewRawMessageInfos, threadID: threadInfo.id, collapseKey: notifInfo.collapseKey, badgeOnly, unreadCount, platformDetails, }, devices, ); return targetedNotifications.map(notification => ({ notification, platform: 'ios', notificationInfo: { ...notificationInfo, codeVersion, stateVersion, }, })); })(); preparePromises.push(preparePromise); } } const androidVersionsToTokens = byPlatform.get('android'); if (androidVersionsToTokens) { for (const [versionKey, devices] of androidVersionsToTokens) { const { codeVersion, stateVersion } = stringToVersionKey(versionKey); const platformDetails = { platform: 'android', codeVersion, stateVersion, }; const shimmedNewRawMessageInfos = shimUnsupportedRawMessageInfos( newRawMessageInfos, platformDetails, ); const preparePromise: Promise<$ReadOnlyArray> = (async () => { const targetedNotifications = await prepareAndroidNotification( { notifTexts, newRawMessageInfos: shimmedNewRawMessageInfos, threadID: threadInfo.id, collapseKey: notifInfo.collapseKey, badgeOnly, unreadCount, platformDetails, dbID, }, devices, ); return targetedNotifications.map(notification => ({ notification, platform: 'android', notificationInfo: { ...notificationInfo, codeVersion, stateVersion, }, })); })(); preparePromises.push(preparePromise); } } const webVersionsToTokens = byPlatform.get('web'); if (webVersionsToTokens) { for (const [versionKey, devices] of webVersionsToTokens) { const { codeVersion, stateVersion } = stringToVersionKey(versionKey); const platformDetails = { platform: 'web', codeVersion, stateVersion, }; const preparePromise: Promise<$ReadOnlyArray> = (async () => { const targetedNotifications = await prepareWebNotification( { notifTexts, threadID: threadInfo.id, unreadCount, platformDetails, }, devices, ); return targetedNotifications.map(notification => ({ notification, platform: 'web', notificationInfo: { ...notificationInfo, codeVersion, stateVersion, }, })); })(); preparePromises.push(preparePromise); } } const macosVersionsToTokens = byPlatform.get('macos'); if (macosVersionsToTokens) { for (const [versionKey, devices] of macosVersionsToTokens) { const { codeVersion, stateVersion, majorDesktopVersion } = stringToVersionKey(versionKey); const platformDetails = { platform: 'macos', codeVersion, stateVersion, majorDesktopVersion, }; const shimmedNewRawMessageInfos = shimUnsupportedRawMessageInfos( newRawMessageInfos, platformDetails, ); const preparePromise: Promise<$ReadOnlyArray> = (async () => { const targetedNotifications = await prepareAPNsNotification( { notifTexts, newRawMessageInfos: shimmedNewRawMessageInfos, threadID: threadInfo.id, collapseKey: notifInfo.collapseKey, badgeOnly, unreadCount, platformDetails, }, devices, ); return targetedNotifications.map(notification => ({ notification, platform: 'macos', notificationInfo: { ...notificationInfo, codeVersion, stateVersion, }, })); })(); preparePromises.push(preparePromise); } } const windowsVersionsToTokens = byPlatform.get('windows'); if (windowsVersionsToTokens) { for (const [versionKey, devices] of windowsVersionsToTokens) { const { codeVersion, stateVersion, majorDesktopVersion } = stringToVersionKey(versionKey); const platformDetails = { platform: 'windows', codeVersion, stateVersion, majorDesktopVersion, }; const preparePromise: Promise<$ReadOnlyArray> = (async () => { const targetedNotifications = await prepareWNSNotification(devices, { notifTexts, threadID: threadInfo.id, unreadCount, platformDetails, }); return targetedNotifications.map(notification => ({ notification, platform: 'windows', notificationInfo: { ...notificationInfo, codeVersion, stateVersion, }, })); })(); preparePromises.push(preparePromise); } } for (const newMessageInfo of remainingNewMessageInfos) { const newDBID = dbIDs.shift(); invariant(newDBID, 'should have sufficient DB IDs'); const messageID = newMessageInfo.id; invariant(messageID, 'RawMessageInfo.id should be set on server'); rowsToSave.set(newDBID, { dbID: newDBID, userID, threadID: newMessageInfo.threadID, messageID, collapseKey: notifInfo.collapseKey, deliveries: [{ collapsedInto: dbID }], }); } const prepareResults = await Promise.all(preparePromises); return prepareResults.flat(); } // For better readability we don't differentiate between // encrypted and unencrypted notifs and order them together function compareEncryptionOrder( pushNotif1: PreparePushResult, pushNotif2: PreparePushResult, ): number { const order1 = pushNotif1.notification.encryptionOrder ?? 0; const order2 = pushNotif2.notification.encryptionOrder ?? 0; return order1 - order2; } async function deliverPushNotifsInEncryptionOrder( preparedPushNotifs: $ReadOnlyArray, ): Promise<$ReadOnlyArray> { const deliveryPromises: Array>> = []; const groupedByDevice = _groupBy( preparedPushNotif => preparedPushNotif.deviceToken, )(preparedPushNotifs); for (const preparedPushNotifsForDevice of values(groupedByDevice)) { const orderedPushNotifsForDevice = preparedPushNotifsForDevice.sort( compareEncryptionOrder, ); const deviceDeliveryPromise = (async () => { const deliveries = []; for (const preparedPushNotif of orderedPushNotifsForDevice) { const { platform, notification, notificationInfo } = preparedPushNotif; let delivery: PushResult; if (platform === 'ios' || platform === 'macos') { delivery = await sendAPNsNotification( platform, [notification], notificationInfo, ); } else if (platform === 'android') { delivery = await sendAndroidNotification( [notification], notificationInfo, ); } else if (platform === 'web') { delivery = await sendWebNotifications( [notification], notificationInfo, ); } else if (platform === 'windows') { delivery = await sendWNSNotification( [notification], notificationInfo, ); } if (delivery) { deliveries.push(delivery); } } return deliveries; })(); deliveryPromises.push(deviceDeliveryPromise); } const deliveryResults = await Promise.all(deliveryPromises); return deliveryResults.flat(); } async function sendRescindNotifs(rescindInfo: PushInfo) { if (Object.keys(rescindInfo).length === 0) { return; } const usersToCollapsableNotifInfo = await fetchCollapsableNotifs(rescindInfo); const promises = []; for (const userID in usersToCollapsableNotifInfo) { for (const notifInfo of usersToCollapsableNotifInfo[userID]) { for (const existingMessageInfo of notifInfo.existingMessageInfos) { const rescindCondition = SQL` n.user = ${userID} AND n.thread = ${existingMessageInfo.threadID} AND n.message = ${existingMessageInfo.id} `; promises.push(rescindPushNotifs(rescindCondition)); } } } await Promise.all(promises); } // The results in deliveryResults will be combined with the rows // in rowsToSave and then written to the notifications table async function saveNotifResults( deliveryResults: $ReadOnlyArray, inputRowsToSave: Map, rescindable: boolean, ) { const rowsToSave = new Map(inputRowsToSave); const allInvalidTokens = []; for (const deliveryResult of deliveryResults) { const { info, delivery, invalidTokens } = deliveryResult; const { dbID, userID } = info; const curNotifRow = rowsToSave.get(dbID); if (curNotifRow) { curNotifRow.deliveries.push(delivery); } else { // Ternary expressions for Flow const threadID = info.threadID ? info.threadID : null; const messageID = info.messageID ? info.messageID : null; const collapseKey = info.collapseKey ? info.collapseKey : null; rowsToSave.set(dbID, { dbID, userID, threadID, messageID, collapseKey, deliveries: [delivery], }); } if (invalidTokens) { allInvalidTokens.push({ userID, tokens: invalidTokens, }); } } const notificationRows = []; for (const notification of rowsToSave.values()) { notificationRows.push([ notification.dbID, notification.userID, notification.threadID, notification.messageID, notification.collapseKey, JSON.stringify(notification.deliveries), Number(!rescindable), ]); } const dbPromises: Array> = []; if (allInvalidTokens.length > 0) { dbPromises.push(removeInvalidTokens(allInvalidTokens)); } if (notificationRows.length > 0) { const query = SQL` INSERT INTO notifications (id, user, thread, message, collapse_key, delivery, rescinded) VALUES ${notificationRows} `; dbPromises.push(dbQuery(query)); } if (dbPromises.length > 0) { await Promise.all(dbPromises); } } async function fetchInfos(pushInfo: PushInfo) { const usersToCollapsableNotifInfo = await fetchCollapsableNotifs(pushInfo); const threadIDs = new Set(); const threadWithChangedNamesToMessages = new Map>(); const addThreadIDsFromMessageInfos = (rawMessageInfo: RawMessageInfo) => { const threadID = rawMessageInfo.threadID; threadIDs.add(threadID); const messageSpec = messageSpecs[rawMessageInfo.type]; if (messageSpec.threadIDs) { for (const id of messageSpec.threadIDs(rawMessageInfo)) { threadIDs.add(id); } } if ( rawMessageInfo.type === messageTypes.CHANGE_SETTINGS && rawMessageInfo.field === 'name' ) { const messages = threadWithChangedNamesToMessages.get(threadID); if (messages) { messages.push(rawMessageInfo.id); } else { threadWithChangedNamesToMessages.set(threadID, [rawMessageInfo.id]); } } }; for (const userID in usersToCollapsableNotifInfo) { for (const notifInfo of usersToCollapsableNotifInfo[userID]) { for (const rawMessageInfo of notifInfo.existingMessageInfos) { addThreadIDsFromMessageInfos(rawMessageInfo); } for (const rawMessageInfo of notifInfo.newMessageInfos) { addThreadIDsFromMessageInfos(rawMessageInfo); } } } // These threadInfos won't have currentUser set const threadPromise = fetchServerThreadInfos({ threadIDs }); const oldNamesPromise: Promise = (async () => { if (threadWithChangedNamesToMessages.size === 0) { return undefined; } const typesThatAffectName = [ messageTypes.CHANGE_SETTINGS, messageTypes.CREATE_THREAD, ]; const oldNameQuery = SQL` SELECT IF( JSON_TYPE(JSON_EXTRACT(m.content, "$.name")) = 'NULL', "", JSON_UNQUOTE(JSON_EXTRACT(m.content, "$.name")) ) AS name, m.thread FROM ( SELECT MAX(id) AS id FROM messages WHERE type IN (${typesThatAffectName}) AND JSON_EXTRACT(content, "$.name") IS NOT NULL AND`; const threadClauses = []; for (const [threadID, messages] of threadWithChangedNamesToMessages) { threadClauses.push( SQL`(thread = ${threadID} AND id NOT IN (${messages}))`, ); } oldNameQuery.append(mergeOrConditions(threadClauses)); oldNameQuery.append(SQL` GROUP BY thread ) x LEFT JOIN messages m ON m.id = x.id `); return await dbQuery(oldNameQuery); })(); const [threadResult, oldNames] = await Promise.all([ threadPromise, oldNamesPromise, ]); const serverThreadInfos = { ...threadResult.threadInfos }; if (oldNames) { const [result] = oldNames; for (const row of result) { const threadID = row.thread.toString(); serverThreadInfos[threadID] = { ...serverThreadInfos[threadID], name: row.name, }; } } const userInfos = await fetchNotifUserInfos( serverThreadInfos, usersToCollapsableNotifInfo, ); return { usersToCollapsableNotifInfo, serverThreadInfos, userInfos }; } async function fetchNotifUserInfos( serverThreadInfos: { +[threadID: string]: ServerThreadInfo }, usersToCollapsableNotifInfo: { +[userID: string]: CollapsableNotifInfo[] }, ) { const missingUserIDs = new Set(); for (const threadID in serverThreadInfos) { const serverThreadInfo = serverThreadInfos[threadID]; for (const member of serverThreadInfo.members) { missingUserIDs.add(member.id); } } const addUserIDsFromMessageInfos = (rawMessageInfo: RawMessageInfo) => { missingUserIDs.add(rawMessageInfo.creatorID); const userIDs = messageSpecs[rawMessageInfo.type].userIDs?.(rawMessageInfo) ?? []; for (const userID of userIDs) { missingUserIDs.add(userID); } }; for (const userID in usersToCollapsableNotifInfo) { missingUserIDs.add(userID); for (const notifInfo of usersToCollapsableNotifInfo[userID]) { for (const rawMessageInfo of notifInfo.existingMessageInfos) { addUserIDsFromMessageInfos(rawMessageInfo); } for (const rawMessageInfo of notifInfo.newMessageInfos) { addUserIDsFromMessageInfos(rawMessageInfo); } } } return await fetchUserInfos([...missingUserIDs]); } async function createDBIDs(pushInfo: PushInfo): Promise { let numIDsNeeded = 0; for (const userID in pushInfo) { numIDsNeeded += pushInfo[userID].messageInfos.length; } return await createIDs('notifications', numIDsNeeded); } type VersionKey = { +codeVersion: number, +stateVersion: number, +majorDesktopVersion?: number, }; const versionKeyRegex: RegExp = new RegExp(/^-?\d+\|-?\d+(\|-?\d+)?$/); function versionKeyToString(versionKey: VersionKey): string { const baseStringVersionKey = `${versionKey.codeVersion}|${versionKey.stateVersion}`; if (!versionKey.majorDesktopVersion) { return baseStringVersionKey; } return `${baseStringVersionKey}|${versionKey.majorDesktopVersion}`; } function stringToVersionKey(versionKeyString: string): VersionKey { invariant( versionKeyRegex.test(versionKeyString), 'should pass correct version key string', ); const [codeVersion, stateVersion, majorDesktopVersion] = versionKeyString .split('|') .map(Number); return { codeVersion, stateVersion, majorDesktopVersion }; } function getDevicesByPlatform( devices: $ReadOnlyArray, ): Map>> { const byPlatform = new Map< Platform, Map>, >(); for (const device of devices) { let innerMap = byPlatform.get(device.platform); if (!innerMap) { innerMap = new Map>(); byPlatform.set(device.platform, innerMap); } const codeVersion: number = device.codeVersion !== null && device.codeVersion !== undefined ? device.codeVersion : -1; const stateVersion: number = device.stateVersion ?? -1; let versionsObject = { codeVersion, stateVersion }; if (device.majorDesktopVersion) { versionsObject = { ...versionsObject, majorDesktopVersion: device.majorDesktopVersion, }; } const versionKey = versionKeyToString(versionsObject); let innerMostArrayTmp: ?Array = innerMap.get(versionKey); if (!innerMostArrayTmp) { innerMostArrayTmp = []; innerMap.set(versionKey, innerMostArrayTmp); } const innerMostArray = innerMostArrayTmp; innerMostArray.push({ cookieID: device.cookieID, deviceToken: device.deviceToken, }); } return byPlatform; } type APNsNotifInputData = { +notifTexts: ResolvedNotifTexts, +newRawMessageInfos: RawMessageInfo[], +threadID: string, +collapseKey: ?string, +badgeOnly: boolean, +unreadCount: number, +platformDetails: PlatformDetails, }; const apnsNotifInputDataValidator = tShape({ notifTexts: resolvedNotifTextsValidator, newRawMessageInfos: t.list(rawMessageInfoValidator), threadID: tID, collapseKey: t.maybe(t.String), badgeOnly: t.Boolean, unreadCount: t.Number, platformDetails: tPlatformDetails, }); async function prepareAPNsNotification( inputData: APNsNotifInputData, devices: $ReadOnlyArray, ): Promise<$ReadOnlyArray> { - const convertedData = validateOutput( + const convertedData = await validateOutput( inputData.platformDetails, apnsNotifInputDataValidator, inputData, ); const { notifTexts, newRawMessageInfos, threadID, collapseKey, badgeOnly, unreadCount, platformDetails, } = convertedData; const canDecryptNonCollapsibleTextIOSNotifs = platformDetails.codeVersion && platformDetails.codeVersion > 222; const isNonCollapsibleTextNotification = newRawMessageInfos.every( newRawMessageInfo => newRawMessageInfo.type === messageTypes.TEXT, ) && !collapseKey; const canDecryptAllIOSNotifs = platformDetails.codeVersion && platformDetails.codeVersion >= 267; const canDecryptIOSNotif = platformDetails.platform === 'ios' && (canDecryptAllIOSNotifs || (isNonCollapsibleTextNotification && canDecryptNonCollapsibleTextIOSNotifs)); const canDecryptMacOSNotifs = platformDetails.platform === 'macos' && hasMinCodeVersion(platformDetails, { web: 47, majorDesktop: 9, }); const shouldBeEncrypted = canDecryptIOSNotif || canDecryptMacOSNotifs; const uniqueID = uuidv4(); const notification = new apn.Notification(); notification.topic = getAPNsNotificationTopic(platformDetails); const { merged, ...rest } = notifTexts; // We don't include alert's body on macos because we // handle displaying the notification ourselves and // we don't want macOS to display it automatically. if (!badgeOnly && platformDetails.platform !== 'macos') { notification.body = merged; notification.sound = 'default'; } notification.payload = { ...notification.payload, ...rest, }; notification.badge = unreadCount; notification.threadId = threadID; notification.id = uniqueID; notification.pushType = 'alert'; notification.payload.id = uniqueID; notification.payload.threadID = threadID; if (platformDetails.codeVersion && platformDetails.codeVersion > 198) { notification.mutableContent = true; } if (collapseKey && (canDecryptAllIOSNotifs || canDecryptMacOSNotifs)) { notification.payload.collapseID = collapseKey; } else if (collapseKey) { notification.collapseId = collapseKey; } const messageInfos = JSON.stringify(newRawMessageInfos); // We make a copy before checking notification's length, because calling // length compiles the notification and makes it immutable. Further // changes to its properties won't be reflected in the final plaintext // data that is sent. const copyWithMessageInfos = _cloneDeep(notification); copyWithMessageInfos.payload = { ...copyWithMessageInfos.payload, messageInfos, }; const notificationSizeValidator = (notif: apn.Notification) => notif.length() <= apnMaxNotificationPayloadByteSize; if (!shouldBeEncrypted) { const notificationToSend = notificationSizeValidator( _cloneDeep(copyWithMessageInfos), ) ? copyWithMessageInfos : notification; return devices.map(({ deviceToken }) => ({ notification: notificationToSend, deviceToken, })); } const notifsWithMessageInfos = await prepareEncryptedAPNsNotifications( devices, copyWithMessageInfos, platformDetails.codeVersion, notificationSizeValidator, ); const devicesWithExcessiveSize = notifsWithMessageInfos .filter(({ payloadSizeExceeded }) => payloadSizeExceeded) .map(({ deviceToken, cookieID }) => ({ deviceToken, cookieID })); if (devicesWithExcessiveSize.length === 0) { return notifsWithMessageInfos.map( ({ notification: notif, deviceToken, encryptedPayloadHash, encryptionOrder, }) => ({ notification: notif, deviceToken, encryptedPayloadHash, encryptionOrder, }), ); } const notifsWithoutMessageInfos = await prepareEncryptedAPNsNotifications( devicesWithExcessiveSize, notification, platformDetails.codeVersion, ); const targetedNotifsWithMessageInfos = notifsWithMessageInfos .filter(({ payloadSizeExceeded }) => !payloadSizeExceeded) .map( ({ notification: notif, deviceToken, encryptedPayloadHash, encryptionOrder, }) => ({ notification: notif, deviceToken, encryptedPayloadHash, encryptionOrder, }), ); const targetedNotifsWithoutMessageInfos = notifsWithoutMessageInfos.map( ({ notification: notif, deviceToken, encryptedPayloadHash, encryptionOrder, }) => ({ notification: notif, deviceToken, encryptedPayloadHash, encryptionOrder, }), ); return [ ...targetedNotifsWithMessageInfos, ...targetedNotifsWithoutMessageInfos, ]; } type AndroidNotifInputData = { ...APNsNotifInputData, +dbID: string, }; const androidNotifInputDataValidator = tShape({ ...apnsNotifInputDataValidator.meta.props, dbID: t.String, }); async function prepareAndroidNotification( inputData: AndroidNotifInputData, devices: $ReadOnlyArray, ): Promise<$ReadOnlyArray> { - const convertedData = validateOutput( + const convertedData = await validateOutput( inputData.platformDetails, androidNotifInputDataValidator, inputData, ); const { notifTexts, newRawMessageInfos, threadID, collapseKey, badgeOnly, unreadCount, platformDetails: { codeVersion }, dbID, } = convertedData; const canDecryptNonCollapsibleTextNotifs = codeVersion && codeVersion > 228; const isNonCollapsibleTextNotif = newRawMessageInfos.every( newRawMessageInfo => newRawMessageInfo.type === messageTypes.TEXT, ) && !collapseKey; const canDecryptAllNotifTypes = codeVersion && codeVersion >= 267; const shouldBeEncrypted = canDecryptAllNotifTypes || (canDecryptNonCollapsibleTextNotifs && isNonCollapsibleTextNotif); const { merged, ...rest } = notifTexts; const notification = { data: { badge: unreadCount.toString(), ...rest, threadID, }, }; let notifID; if (collapseKey && canDecryptAllNotifTypes) { notifID = dbID; notification.data = { ...notification.data, collapseKey, }; } else if (collapseKey) { notifID = collapseKey; } else { notifID = dbID; } // The reason we only include `badgeOnly` for newer clients is because older // clients don't know how to parse it. The reason we only include `id` for // newer clients is that if the older clients see that field, they assume // the notif has a full payload, and then crash when trying to parse it. // By skipping `id` we allow old clients to still handle in-app notifs and // badge updating. if (!badgeOnly || (codeVersion && codeVersion >= 69)) { notification.data = { ...notification.data, id: notifID, badgeOnly: badgeOnly ? '1' : '0', }; } const messageInfos = JSON.stringify(newRawMessageInfos); const copyWithMessageInfos = { ...notification, data: { ...notification.data, messageInfos }, }; if (!shouldBeEncrypted) { const notificationToSend = Buffer.byteLength(JSON.stringify(copyWithMessageInfos)) <= fcmMaxNotificationPayloadByteSize ? copyWithMessageInfos : notification; return devices.map(({ deviceToken }) => ({ notification: notificationToSend, deviceToken, })); } const notificationsSizeValidator = (notif: AndroidNotification) => { const serializedNotif = JSON.stringify(notif); return ( !serializedNotif || Buffer.byteLength(serializedNotif) <= fcmMaxNotificationPayloadByteSize ); }; const notifsWithMessageInfos = await prepareEncryptedAndroidNotifications( devices, copyWithMessageInfos, notificationsSizeValidator, ); const devicesWithExcessiveSize = notifsWithMessageInfos .filter(({ payloadSizeExceeded }) => payloadSizeExceeded) .map(({ cookieID, deviceToken }) => ({ cookieID, deviceToken })); if (devicesWithExcessiveSize.length === 0) { return notifsWithMessageInfos.map( ({ notification: notif, deviceToken, encryptionOrder }) => ({ notification: notif, deviceToken, encryptionOrder, }), ); } const notifsWithoutMessageInfos = await prepareEncryptedAndroidNotifications( devicesWithExcessiveSize, notification, ); const targetedNotifsWithMessageInfos = notifsWithMessageInfos .filter(({ payloadSizeExceeded }) => !payloadSizeExceeded) .map(({ notification: notif, deviceToken, encryptionOrder }) => ({ notification: notif, deviceToken, encryptionOrder, })); const targetedNotifsWithoutMessageInfos = notifsWithoutMessageInfos.map( ({ notification: notif, deviceToken, encryptionOrder }) => ({ notification: notif, deviceToken, encryptionOrder, }), ); return [ ...targetedNotifsWithMessageInfos, ...targetedNotifsWithoutMessageInfos, ]; } type WebNotifInputData = { +notifTexts: ResolvedNotifTexts, +threadID: string, +unreadCount: number, +platformDetails: PlatformDetails, }; const webNotifInputDataValidator = tShape({ notifTexts: resolvedNotifTextsValidator, threadID: tID, unreadCount: t.Number, platformDetails: tPlatformDetails, }); async function prepareWebNotification( inputData: WebNotifInputData, devices: $ReadOnlyArray, ): Promise<$ReadOnlyArray> { - const convertedData = validateOutput( + const convertedData = await validateOutput( inputData.platformDetails, webNotifInputDataValidator, inputData, ); const { notifTexts, threadID, unreadCount } = convertedData; const id = uuidv4(); const { merged, ...rest } = notifTexts; const notification = { ...rest, unreadCount, id, threadID, }; const shouldBeEncrypted = hasMinCodeVersion(convertedData.platformDetails, { web: 43, }); if (!shouldBeEncrypted) { return devices.map(({ deviceToken }) => ({ deviceToken, notification })); } return prepareEncryptedWebNotifications(devices, notification); } type WNSNotifInputData = { +notifTexts: ResolvedNotifTexts, +threadID: string, +unreadCount: number, +platformDetails: PlatformDetails, }; const wnsNotifInputDataValidator = tShape({ notifTexts: resolvedNotifTextsValidator, threadID: tID, unreadCount: t.Number, platformDetails: tPlatformDetails, }); async function prepareWNSNotification( devices: $ReadOnlyArray, inputData: WNSNotifInputData, ): Promise<$ReadOnlyArray> { - const convertedData = validateOutput( + const convertedData = await validateOutput( inputData.platformDetails, wnsNotifInputDataValidator, inputData, ); const { notifTexts, threadID, unreadCount } = convertedData; const { merged, ...rest } = notifTexts; const notification = { ...rest, unreadCount, threadID, }; if ( Buffer.byteLength(JSON.stringify(notification)) > wnsMaxNotificationPayloadByteSize ) { console.warn('WNS notification exceeds size limit'); } const shouldBeEncrypted = hasMinCodeVersion(inputData.platformDetails, { majorDesktop: 10, }); if (!shouldBeEncrypted) { return devices.map(({ deviceToken }) => ({ deviceToken, notification, })); } return await prepareEncryptedWNSNotifications(devices, notification); } type NotificationInfo = | { +source: 'new_message', +dbID: string, +userID: string, +threadID: string, +messageID: string, +collapseKey: ?string, +codeVersion: number, +stateVersion: number, } | { +source: 'mark_as_unread' | 'mark_as_read' | 'activity_update', +dbID: string, +userID: string, +codeVersion: number, +stateVersion: number, }; type APNsDelivery = { +source: $PropertyType, +deviceType: 'ios' | 'macos', +iosID: string, +deviceTokens: $ReadOnlyArray, +codeVersion: number, +stateVersion: number, +errors?: $ReadOnlyArray, +encryptedPayloadHashes?: $ReadOnlyArray, +deviceTokensToPayloadHash?: { +[deviceToken: string]: string, }, }; type APNsResult = { info: NotificationInfo, delivery: APNsDelivery, invalidTokens?: $ReadOnlyArray, }; async function sendAPNsNotification( platform: 'ios' | 'macos', targetedNotifications: $ReadOnlyArray, notificationInfo: NotificationInfo, ): Promise { const { source, codeVersion, stateVersion } = notificationInfo; const response = await apnPush({ targetedNotifications, platformDetails: { platform, codeVersion }, }); invariant( new Set(targetedNotifications.map(({ notification }) => notification.id)) .size === 1, 'Encrypted versions of the same notification must share id value', ); const iosID = targetedNotifications[0].notification.id; const deviceTokens = targetedNotifications.map( ({ deviceToken }) => deviceToken, ); let delivery: APNsDelivery = { source, deviceType: platform, iosID, deviceTokens, codeVersion, stateVersion, }; if (response.errors) { delivery = { ...delivery, errors: response.errors, }; } const deviceTokensToPayloadHash: { [string]: string } = {}; for (const targetedNotification of targetedNotifications) { if (targetedNotification.encryptedPayloadHash) { deviceTokensToPayloadHash[targetedNotification.deviceToken] = targetedNotification.encryptedPayloadHash; } } if (Object.keys(deviceTokensToPayloadHash).length !== 0) { delivery = { ...delivery, deviceTokensToPayloadHash, }; } const result: APNsResult = { info: notificationInfo, delivery, }; if (response.invalidTokens) { result.invalidTokens = response.invalidTokens; } return result; } type PushResult = AndroidResult | APNsResult | WebResult | WNSResult; type PushDelivery = AndroidDelivery | APNsDelivery | WebDelivery | WNSDelivery; type AndroidDelivery = { source: $PropertyType, deviceType: 'android', androidIDs: $ReadOnlyArray, deviceTokens: $ReadOnlyArray, codeVersion: number, stateVersion: number, errors?: $ReadOnlyArray, }; type AndroidResult = { info: NotificationInfo, delivery: AndroidDelivery, invalidTokens?: $ReadOnlyArray, }; async function sendAndroidNotification( targetedNotifications: $ReadOnlyArray, notificationInfo: NotificationInfo, ): Promise { const collapseKey = notificationInfo.collapseKey ? notificationInfo.collapseKey : null; // for Flow... const { source, codeVersion, stateVersion } = notificationInfo; const response = await fcmPush({ targetedNotifications, collapseKey, codeVersion, }); const deviceTokens = targetedNotifications.map( ({ deviceToken }) => deviceToken, ); const androidIDs = response.fcmIDs ? response.fcmIDs : []; const delivery: AndroidDelivery = { source, deviceType: 'android', androidIDs, deviceTokens, codeVersion, stateVersion, }; if (response.errors) { delivery.errors = response.errors; } const result: AndroidResult = { info: notificationInfo, delivery, }; if (response.invalidTokens) { result.invalidTokens = response.invalidTokens; } return result; } type WebDelivery = { +source: $PropertyType, +deviceType: 'web', +deviceTokens: $ReadOnlyArray, +codeVersion?: number, +stateVersion: number, +errors?: $ReadOnlyArray, }; type WebResult = { +info: NotificationInfo, +delivery: WebDelivery, +invalidTokens?: $ReadOnlyArray, }; async function sendWebNotifications( targetedNotifications: $ReadOnlyArray, notificationInfo: NotificationInfo, ): Promise { const { source, codeVersion, stateVersion } = notificationInfo; const response = await webPush(targetedNotifications); const deviceTokens = targetedNotifications.map( ({ deviceToken }) => deviceToken, ); const delivery: WebDelivery = { source, deviceType: 'web', deviceTokens, codeVersion, errors: response.errors, stateVersion, }; const result: WebResult = { info: notificationInfo, delivery, invalidTokens: response.invalidTokens, }; return result; } type WNSDelivery = { +source: $PropertyType, +deviceType: 'windows', +wnsIDs: $ReadOnlyArray, +deviceTokens: $ReadOnlyArray, +codeVersion?: number, +stateVersion: number, +errors?: $ReadOnlyArray, }; type WNSResult = { +info: NotificationInfo, +delivery: WNSDelivery, +invalidTokens?: $ReadOnlyArray, }; async function sendWNSNotification( targetedNotifications: $ReadOnlyArray, notificationInfo: NotificationInfo, ): Promise { const { source, codeVersion, stateVersion } = notificationInfo; const response = await wnsPush(targetedNotifications); const deviceTokens = targetedNotifications.map( ({ deviceToken }) => deviceToken, ); const wnsIDs = response.wnsIDs ?? []; const delivery: WNSDelivery = { source, deviceType: 'windows', wnsIDs, deviceTokens, codeVersion, errors: response.errors, stateVersion, }; const result: WNSResult = { info: notificationInfo, delivery, invalidTokens: response.invalidTokens, }; return result; } type InvalidToken = { +userID: string, +tokens: $ReadOnlyArray, }; async function removeInvalidTokens( invalidTokens: $ReadOnlyArray, ): Promise { const sqlTuples = invalidTokens.map( invalidTokenUser => SQL`( user = ${invalidTokenUser.userID} AND device_token IN (${invalidTokenUser.tokens}) )`, ); const sqlCondition = mergeOrConditions(sqlTuples); const selectQuery = SQL` SELECT id, user, device_token FROM cookies WHERE `; selectQuery.append(sqlCondition); const [result] = await dbQuery(selectQuery); const userCookiePairsToInvalidDeviceTokens = new Map>(); for (const row of result) { const userCookiePair = `${row.user}|${row.id}`; const existing = userCookiePairsToInvalidDeviceTokens.get(userCookiePair); if (existing) { existing.add(row.device_token); } else { userCookiePairsToInvalidDeviceTokens.set( userCookiePair, new Set([row.device_token]), ); } } const time = Date.now(); const promises: Array> = []; for (const entry of userCookiePairsToInvalidDeviceTokens) { const [userCookiePair, deviceTokens] = entry; const [userID, cookieID] = userCookiePair.split('|'); const updateDatas = [...deviceTokens].map(deviceToken => ({ type: updateTypes.BAD_DEVICE_TOKEN, userID, time, deviceToken, targetCookie: cookieID, })); promises.push(createUpdates(updateDatas)); } const updateQuery = SQL` UPDATE cookies SET device_token = NULL WHERE `; updateQuery.append(sqlCondition); promises.push(dbQuery(updateQuery)); await Promise.all(promises); } async function updateBadgeCount( viewer: Viewer, source: 'mark_as_unread' | 'mark_as_read' | 'activity_update', ) { const { userID } = viewer; const deviceTokenQuery = SQL` SELECT platform, device_token, versions, id FROM cookies WHERE user = ${userID} AND device_token IS NOT NULL `; if (viewer.data.cookieID) { deviceTokenQuery.append(SQL`AND id != ${viewer.cookieID} `); } const [unreadCounts, [deviceTokenResult], [dbID]] = await Promise.all([ getUnreadCounts([userID]), dbQuery(deviceTokenQuery), createIDs('notifications', 1), ]); const unreadCount = unreadCounts[userID]; const devices = deviceTokenResult.map(row => { const versions = JSON.parse(row.versions); return { platform: row.platform, cookieID: row.id, deviceToken: row.device_token, codeVersion: versions?.codeVersion, stateVersion: versions?.stateVersion, }; }); const byPlatform = getDevicesByPlatform(devices); const preparePromises: Array>> = []; const iosVersionsToTokens = byPlatform.get('ios'); if (iosVersionsToTokens) { for (const [versionKey, deviceInfos] of iosVersionsToTokens) { const { codeVersion, stateVersion } = stringToVersionKey(versionKey); const notification = new apn.Notification(); notification.topic = getAPNsNotificationTopic({ platform: 'ios', codeVersion, stateVersion, }); notification.badge = unreadCount; notification.pushType = 'alert'; const preparePromise: Promise = (async () => { let targetedNotifications: $ReadOnlyArray; if (codeVersion > 222) { const notificationsArray = await prepareEncryptedAPNsNotifications( deviceInfos, notification, codeVersion, ); targetedNotifications = notificationsArray.map( ({ notification: notif, deviceToken, encryptionOrder }) => ({ notification: notif, deviceToken, encryptionOrder, }), ); } else { targetedNotifications = deviceInfos.map(({ deviceToken }) => ({ notification, deviceToken, })); } return targetedNotifications.map(targetedNotification => ({ notification: targetedNotification, platform: 'ios', notificationInfo: { source, dbID, userID, codeVersion, stateVersion, }, })); })(); preparePromises.push(preparePromise); } } const androidVersionsToTokens = byPlatform.get('android'); if (androidVersionsToTokens) { for (const [versionKey, deviceInfos] of androidVersionsToTokens) { const { codeVersion, stateVersion } = stringToVersionKey(versionKey); const notificationData = codeVersion < 69 ? { badge: unreadCount.toString() } : { badge: unreadCount.toString(), badgeOnly: '1' }; const notification = { data: notificationData }; const preparePromise: Promise = (async () => { let targetedNotifications: $ReadOnlyArray; if (codeVersion > 222) { const notificationsArray = await prepareEncryptedAndroidNotifications( deviceInfos, notification, ); targetedNotifications = notificationsArray.map( ({ notification: notif, deviceToken, encryptionOrder }) => ({ notification: notif, deviceToken, encryptionOrder, }), ); } else { targetedNotifications = deviceInfos.map(({ deviceToken }) => ({ deviceToken, notification, })); } return targetedNotifications.map(targetedNotification => ({ notification: targetedNotification, platform: 'android', notificationInfo: { source, dbID, userID, codeVersion, stateVersion, }, })); })(); preparePromises.push(preparePromise); } } const macosVersionsToTokens = byPlatform.get('macos'); if (macosVersionsToTokens) { for (const [versionKey, deviceInfos] of macosVersionsToTokens) { const { codeVersion, stateVersion, majorDesktopVersion } = stringToVersionKey(versionKey); const notification = new apn.Notification(); notification.topic = getAPNsNotificationTopic({ platform: 'macos', codeVersion, stateVersion, majorDesktopVersion, }); notification.badge = unreadCount; notification.pushType = 'alert'; const preparePromise: Promise = (async () => { const shouldBeEncrypted = hasMinCodeVersion(viewer.platformDetails, { web: 47, majorDesktop: 9, }); let targetedNotifications: $ReadOnlyArray; if (shouldBeEncrypted) { const notificationsArray = await prepareEncryptedAPNsNotifications( deviceInfos, notification, codeVersion, ); targetedNotifications = notificationsArray.map( ({ notification: notif, deviceToken, encryptionOrder }) => ({ notification: notif, deviceToken, encryptionOrder, }), ); } else { targetedNotifications = deviceInfos.map(({ deviceToken }) => ({ deviceToken, notification, })); } return targetedNotifications.map(targetedNotification => ({ notification: targetedNotification, platform: 'macos', notificationInfo: { source, dbID, userID, codeVersion, stateVersion, }, })); })(); preparePromises.push(preparePromise); } } const prepareResults = await Promise.all(preparePromises); const flattenedPrepareResults = prepareResults.filter(Boolean).flat(); const deliveryResults = await deliverPushNotifsInEncryptionOrder( flattenedPrepareResults, ); await saveNotifResults(deliveryResults, new Map(), false); } export { sendPushNotifs, sendRescindNotifs, updateBadgeCount }; diff --git a/keyserver/src/responders/handlers.js b/keyserver/src/responders/handlers.js index 2bbb67d27..72606fc28 100644 --- a/keyserver/src/responders/handlers.js +++ b/keyserver/src/responders/handlers.js @@ -1,237 +1,241 @@ // @flow import type { $Response, $Request } from 'express'; import type { TType } from 'tcomb'; import { ServerError } from 'lib/utils/errors.js'; import { assertWithValidator, tPlatformDetails, } from 'lib/utils/validation-utils.js'; import { getMessageForException } from './utils.js'; import { deleteCookie } from '../deleters/cookie-deleters.js'; import type { PolicyType } from '../lib/facts/policies.js'; import { fetchViewerForJSONRequest, addCookieToJSONResponse, addCookieToHomeResponse, createNewAnonymousCookie, setCookiePlatformDetails, } from '../session/cookies.js'; import type { Viewer } from '../session/viewer.js'; import { getAppURLFactsFromRequestURL } from '../utils/urls.js'; import { policiesValidator, validateInput, validateOutput, } from '../utils/validation-utils.js'; type InnerJSONResponder = { responder: (viewer: Viewer, input: any) => Promise<*>, requiredPolicies: $ReadOnlyArray, }; export opaque type JSONResponder: InnerJSONResponder = InnerJSONResponder; function createJSONResponder( responder: (Viewer, input: I) => Promise, inputValidator: TType, outputValidator: TType, requiredPolicies: $ReadOnlyArray, ): JSONResponder { return { responder: async (viewer, input) => { const request = await validateInput(viewer, inputValidator, input); const result = await responder(viewer, request); - return validateOutput(viewer.platformDetails, outputValidator, result); + return await validateOutput( + viewer.platformDetails, + outputValidator, + result, + ); }, requiredPolicies, }; } export type DownloadResponder = ( viewer: Viewer, req: $Request, res: $Response, ) => Promise; export type HTTPGetResponder = DownloadResponder; export type HTMLResponder = (req: $Request, res: $Response) => Promise; function jsonHandler( responder: JSONResponder, expectCookieInvalidation: boolean, ): (req: $Request, res: $Response) => Promise { return async (req: $Request, res: $Response) => { let viewer; try { if (!req.body || typeof req.body !== 'object') { throw new ServerError('invalid_parameters'); } const { input, platformDetails } = req.body; viewer = await fetchViewerForJSONRequest(req); const promises = [policiesValidator(viewer, responder.requiredPolicies)]; if (platformDetails) { if (!tPlatformDetails.is(platformDetails)) { throw new ServerError('invalid_platform_details'); } promises.push( setCookiePlatformDetails( viewer, assertWithValidator(platformDetails, tPlatformDetails), ), ); } await Promise.all(promises); const responderResult = await responder.responder(viewer, input); if (res.headersSent) { return; } const result = { ...responderResult }; addCookieToJSONResponse(viewer, res, result, expectCookieInvalidation); res.json({ success: true, ...result }); } catch (e) { await handleException(e, res, viewer, expectCookieInvalidation); } }; } function httpGetHandler( responder: HTTPGetResponder, ): (req: $Request, res: $Response) => Promise { return async (req: $Request, res: $Response) => { let viewer; try { viewer = await fetchViewerForJSONRequest(req); await responder(viewer, req, res); } catch (e) { await handleException(e, res, viewer); } }; } function downloadHandler( responder: DownloadResponder, ): (req: $Request, res: $Response) => Promise { return async (req: $Request, res: $Response) => { try { const viewer = await fetchViewerForJSONRequest(req); await responder(viewer, req, res); } catch (e) { // Passing viewer in only makes sense if we want to handle failures as // JSON. We don't, and presume all download handlers avoid ServerError. await handleException(e, res); } }; } async function handleException( error: Error, res: $Response, viewer?: ?Viewer, expectCookieInvalidation?: boolean, ) { console.warn(error); if (res.headersSent) { return; } if (!(error instanceof ServerError)) { res.status(500).send(getMessageForException(error)); return; } const result: Object = error.payload ? { error: error.message, payload: error.payload } : { error: error.message }; if (viewer) { if (error.message === 'client_version_unsupported' && viewer.loggedIn) { // If the client version is unsupported, log the user out const { platformDetails } = error; const [data] = await Promise.all([ createNewAnonymousCookie({ platformDetails, deviceToken: viewer.deviceToken, }), deleteCookie(viewer.cookieID), ]); viewer.setNewCookie(data); viewer.cookieInvalidated = true; } // This can mutate the result object addCookieToJSONResponse(viewer, res, result, !!expectCookieInvalidation); } res.json(result); } function htmlHandler( responder: HTMLResponder, ): (req: $Request, res: $Response) => Promise { return async (req: $Request, res: $Response) => { try { addCookieToHomeResponse( req, res, getAppURLFactsFromRequestURL(req.originalUrl), ); res.type('html'); await responder(req, res); } catch (e) { console.warn(e); if (!res.headersSent) { res.status(500).send(getMessageForException(e)); } } }; } type MulterFile = { fieldname: string, originalname: string, encoding: string, mimetype: string, buffer: Buffer, size: number, }; export type MulterRequest = $Request & { files?: $ReadOnlyArray, ... }; type UploadResponder = (viewer: Viewer, req: MulterRequest) => Promise; function uploadHandler( responder: UploadResponder, ): (req: $Request, res: $Response) => Promise { return async (req: $Request, res: $Response) => { let viewer; try { if (!req.body || typeof req.body !== 'object') { throw new ServerError('invalid_parameters'); } viewer = await fetchViewerForJSONRequest(req); const responderResult = await responder( viewer, ((req: any): MulterRequest), ); if (res.headersSent) { return; } const result = { ...responderResult }; addCookieToJSONResponse(viewer, res, result, false); res.json({ success: true, ...result }); } catch (e) { await handleException(e, res, viewer); } }; } export { createJSONResponder, jsonHandler, httpGetHandler, downloadHandler, htmlHandler, uploadHandler, }; diff --git a/keyserver/src/responders/redux-state-responders.js b/keyserver/src/responders/redux-state-responders.js index c34c1d089..a6928b80f 100644 --- a/keyserver/src/responders/redux-state-responders.js +++ b/keyserver/src/responders/redux-state-responders.js @@ -1,380 +1,383 @@ // @flow import _keyBy from 'lodash/fp/keyBy.js'; import t, { type TInterface } from 'tcomb'; import { baseLegalPolicies } from 'lib/facts/policies.js'; import { mixedRawThreadInfoValidator } from 'lib/permissions/minimally-encoded-raw-thread-info-validators.js'; import { daysToEntriesFromEntryInfos } from 'lib/reducers/entry-reducer.js'; import { freshMessageStore } from 'lib/reducers/message-reducer.js'; import { mostRecentlyReadThread } from 'lib/selectors/thread-selectors.js'; import { mostRecentMessageTimestamp } from 'lib/shared/message-utils.js'; import { threadHasPermission, threadIsPending, parsePendingThreadID, createPendingThread, } from 'lib/shared/thread-utils.js'; import { canUseDatabaseOnWeb } from 'lib/shared/web-database.js'; import { entryStoreValidator } from 'lib/types/entry-types.js'; import { defaultCalendarFilters } from 'lib/types/filter-types.js'; import { inviteLinksStoreValidator, type CommunityLinks, type InviteLinkWithHolder, } from 'lib/types/link-types.js'; import { defaultNumberPerThread, messageStoreValidator, type MessageStore, } from 'lib/types/message-types.js'; import { webNavInfoValidator } from 'lib/types/nav-types.js'; import type { WebInitialKeyserverInfo, ServerWebInitialReduxStateResponse, } from 'lib/types/redux-types.js'; import { threadPermissions } from 'lib/types/thread-permission-types.js'; import { threadTypes } from 'lib/types/thread-types-enum.js'; import { type ThreadStore } from 'lib/types/thread-types.js'; import { currentUserInfoValidator, userInfosValidator, type GlobalAccountUserInfo, } from 'lib/types/user-types.js'; import { currentDateInTimeZone } from 'lib/utils/date-utils.js'; import { ServerError } from 'lib/utils/errors.js'; import { promiseAll } from 'lib/utils/promises.js'; import { urlInfoValidator } from 'lib/utils/url-utils.js'; -import { tShape, ashoatKeyserverID, tID } from 'lib/utils/validation-utils.js'; +import { tShape, tID } from 'lib/utils/validation-utils.js'; import type { InitialReduxStateRequest, ExcludedData, } from 'web/types/redux-types.js'; import { navInfoFromURL } from 'web/url-utils.js'; import { fetchEntryInfos } from '../fetchers/entry-fetchers.js'; import { fetchPrimaryInviteLinks } from '../fetchers/link-fetchers.js'; import { fetchMessageInfos } from '../fetchers/message-fetchers.js'; import { hasAnyNotAcknowledgedPolicies } from '../fetchers/policy-acknowledgment-fetchers.js'; import { fetchThreadInfos } from '../fetchers/thread-fetchers.js'; import { fetchCurrentUserInfo, fetchKnownUserInfos, fetchUserInfos, } from '../fetchers/user-fetchers.js'; import { getWebPushConfig } from '../push/providers.js'; import { setNewSession } from '../session/cookies.js'; import { Viewer } from '../session/viewer.js'; +import { thisKeyserverID } from '../user/identity.js'; const excludedDataValidator: TInterface = tShape({ threadStore: t.maybe(t.Bool), }); export const initialReduxStateRequestValidator: TInterface = tShape({ urlInfo: urlInfoValidator, excludedData: excludedDataValidator, clientUpdatesCurrentAsOf: t.Number, }); const initialKeyserverInfoValidator = tShape({ sessionID: t.maybe(t.String), updatesCurrentAsOf: t.Number, }); export const threadStoreValidator: TInterface = tShape({ threadInfos: t.dict(tID, mixedRawThreadInfoValidator), }); export const initialReduxStateValidator: TInterface = tShape({ navInfo: webNavInfoValidator, currentUserInfo: currentUserInfoValidator, entryStore: entryStoreValidator, threadStore: threadStoreValidator, userInfos: userInfosValidator, messageStore: messageStoreValidator, pushApiPublicKey: t.maybe(t.String), commServicesAccessToken: t.Nil, inviteLinksStore: inviteLinksStoreValidator, keyserverInfo: initialKeyserverInfoValidator, }); async function getInitialReduxStateResponder( viewer: Viewer, request: InitialReduxStateRequest, ): Promise { const { urlInfo, excludedData, clientUpdatesCurrentAsOf } = request; const useDatabase = viewer.loggedIn && canUseDatabaseOnWeb(viewer.userID); const hasNotAcknowledgedPoliciesPromise = hasAnyNotAcknowledgedPolicies( viewer.id, baseLegalPolicies, ); const initialNavInfoPromise = (async () => { try { let backupInfo = { now: currentDateInTimeZone(viewer.timeZone), }; // Some user ids in selectedUserList might not exist in the userInfos // (e.g. they were included in the results of the user search endpoint) // Because of that we keep their userInfos inside the navInfo. if (urlInfo.selectedUserList) { const fetchedUserInfos = await fetchUserInfos(urlInfo.selectedUserList); const userInfos: { [string]: GlobalAccountUserInfo } = {}; for (const userID in fetchedUserInfos) { const userInfo = fetchedUserInfos[userID]; if (userInfo.username) { userInfos[userID] = { ...userInfo, username: userInfo.username, }; } } backupInfo = { userInfos, ...backupInfo }; } return navInfoFromURL(urlInfo, backupInfo); } catch (e) { throw new ServerError(e.message); } })(); const calendarQueryPromise = (async () => { const initialNavInfo = await initialNavInfoPromise; return { startDate: initialNavInfo.startDate, endDate: initialNavInfo.endDate, filters: defaultCalendarFilters, }; })(); const messageSelectionCriteria = { joinedThreads: true }; const serverUpdatesCurrentAsOf = useDatabase && clientUpdatesCurrentAsOf ? clientUpdatesCurrentAsOf : Date.now(); const threadInfoPromise = fetchThreadInfos(viewer); const messageInfoPromise = fetchMessageInfos( viewer, messageSelectionCriteria, defaultNumberPerThread, ); const entryInfoPromise = (async () => { const calendarQuery = await calendarQueryPromise; return await fetchEntryInfos(viewer, [calendarQuery]); })(); const currentUserInfoPromise = fetchCurrentUserInfo(viewer); const userInfoPromise = fetchKnownUserInfos(viewer); const sessionIDPromise = (async () => { const calendarQuery = await calendarQueryPromise; if (viewer.loggedIn) { await setNewSession(viewer, calendarQuery, serverUpdatesCurrentAsOf); } return viewer.sessionID; })(); const threadStorePromise = (async () => { if (excludedData.threadStore && useDatabase) { return { threadInfos: {} }; } const [{ threadInfos }, hasNotAcknowledgedPolicies] = await Promise.all([ threadInfoPromise, hasNotAcknowledgedPoliciesPromise, ]); return { threadInfos: hasNotAcknowledgedPolicies ? {} : threadInfos }; })(); const messageStorePromise: Promise = (async () => { const [ { threadInfos }, { rawMessageInfos, truncationStatuses }, hasNotAcknowledgedPolicies, + keyserverID, ] = await Promise.all([ threadInfoPromise, messageInfoPromise, hasNotAcknowledgedPoliciesPromise, + thisKeyserverID(), ]); if (hasNotAcknowledgedPolicies) { return { messages: {}, threads: {}, local: {}, - currentAsOf: { [ashoatKeyserverID]: 0 }, + currentAsOf: { [keyserverID]: 0 }, }; } const { messageStore: freshStore } = freshMessageStore( rawMessageInfos, truncationStatuses, { - [ashoatKeyserverID]: mostRecentMessageTimestamp( + [keyserverID]: mostRecentMessageTimestamp( rawMessageInfos, serverUpdatesCurrentAsOf, ), }, threadInfos, ); return freshStore; })(); const entryStorePromise = (async () => { const [{ rawEntryInfos }, hasNotAcknowledgedPolicies] = await Promise.all([ entryInfoPromise, hasNotAcknowledgedPoliciesPromise, ]); if (hasNotAcknowledgedPolicies) { return { entryInfos: {}, daysToEntries: {}, lastUserInteractionCalendar: 0, }; } return { entryInfos: _keyBy('id')(rawEntryInfos), daysToEntries: daysToEntriesFromEntryInfos(rawEntryInfos), lastUserInteractionCalendar: serverUpdatesCurrentAsOf, }; })(); const userInfosPromise = (async () => { const [userInfos, hasNotAcknowledgedPolicies] = await Promise.all([ userInfoPromise, hasNotAcknowledgedPoliciesPromise, ]); return hasNotAcknowledgedPolicies ? {} : userInfos; })(); const navInfoPromise = (async () => { const [ { threadInfos }, messageStore, currentUserInfo, userInfos, finalNavInfo, ] = await Promise.all([ threadInfoPromise, messageStorePromise, currentUserInfoPromise, userInfosPromise, initialNavInfoPromise, ]); const requestedActiveChatThreadID = finalNavInfo.activeChatThreadID; if ( requestedActiveChatThreadID && !threadIsPending(requestedActiveChatThreadID) && !threadHasPermission( threadInfos[requestedActiveChatThreadID], threadPermissions.VISIBLE, ) ) { finalNavInfo.activeChatThreadID = null; } if (!finalNavInfo.activeChatThreadID) { const mostRecentThread = mostRecentlyReadThread( messageStore, threadInfos, ); if (mostRecentThread) { finalNavInfo.activeChatThreadID = mostRecentThread; } } const curActiveChatThreadID = finalNavInfo.activeChatThreadID; if ( curActiveChatThreadID && threadIsPending(curActiveChatThreadID) && finalNavInfo.pendingThread?.id !== curActiveChatThreadID ) { const pendingThreadData = parsePendingThreadID(curActiveChatThreadID); if ( pendingThreadData && pendingThreadData.threadType !== threadTypes.SIDEBAR && currentUserInfo.id ) { const members = [...pendingThreadData.memberIDs, currentUserInfo.id] .map(id => { const userInfo = userInfos[id]; if (!userInfo || !userInfo.username) { return undefined; } const { username } = userInfo; return { id, username }; }) .filter(Boolean); const newPendingThread = createPendingThread({ viewerID: currentUserInfo.id, threadType: pendingThreadData.threadType, members, }); finalNavInfo.activeChatThreadID = newPendingThread.id; finalNavInfo.pendingThread = newPendingThread; } } return finalNavInfo; })(); const currentAsOfPromise = (async () => { const hasNotAcknowledgedPolicies = await hasNotAcknowledgedPoliciesPromise; return hasNotAcknowledgedPolicies ? 0 : serverUpdatesCurrentAsOf; })(); const pushApiPublicKeyPromise: Promise = (async () => { const pushConfig = await getWebPushConfig(); if (!pushConfig) { if (process.env.NODE_ENV !== 'development') { console.warn('keyserver/secrets/web_push_config.json should exist'); } return null; } return pushConfig.publicKey; })(); const inviteLinksStorePromise = (async () => { const primaryInviteLinks = await fetchPrimaryInviteLinks(viewer); const links: { [string]: CommunityLinks } = {}; for (const { blobHolder, ...link }: InviteLinkWithHolder of primaryInviteLinks) { if (link.primary) { links[link.communityID] = { primaryLink: link, }; } } return { links, }; })(); const keyserverInfoPromise = (async () => { const { sessionID, updatesCurrentAsOf } = await promiseAll({ sessionID: sessionIDPromise, updatesCurrentAsOf: currentAsOfPromise, }); return { sessionID, updatesCurrentAsOf, }; })(); const initialReduxState: ServerWebInitialReduxStateResponse = await promiseAll({ navInfo: navInfoPromise, currentUserInfo: currentUserInfoPromise, entryStore: entryStorePromise, threadStore: threadStorePromise, userInfos: userInfosPromise, messageStore: messageStorePromise, pushApiPublicKey: pushApiPublicKeyPromise, commServicesAccessToken: null, inviteLinksStore: inviteLinksStorePromise, keyserverInfo: keyserverInfoPromise, }); return initialReduxState; } export { getInitialReduxStateResponder }; diff --git a/keyserver/src/shared/state-sync/current-user-state-sync-spec.js b/keyserver/src/shared/state-sync/current-user-state-sync-spec.js index 813fc55ef..d90c34144 100644 --- a/keyserver/src/shared/state-sync/current-user-state-sync-spec.js +++ b/keyserver/src/shared/state-sync/current-user-state-sync-spec.js @@ -1,38 +1,43 @@ // @flow import { currentUserStateSyncSpec as libSpec } from 'lib/shared/state-sync/current-user-state-sync-spec.js'; import type { CurrentUserInfo } from 'lib/types/user-types.js'; import { currentUserInfoValidator } from 'lib/types/user-types.js'; import { hash } from 'lib/utils/objects.js'; import type { ServerStateSyncSpec } from './state-sync-spec.js'; import { fetchCurrentUserInfo } from '../../fetchers/user-fetchers.js'; import type { Viewer } from '../../session/viewer.js'; import { validateOutput } from '../../utils/validation-utils.js'; export const currentUserStateSyncSpec: ServerStateSyncSpec< CurrentUserInfo, CurrentUserInfo, CurrentUserInfo, void, > = Object.freeze({ fetch, fetchFullSocketSyncPayload(viewer: Viewer) { return fetchCurrentUserInfo(viewer); }, async fetchServerInfosHash(viewer: Viewer) { const info = await fetch(viewer); - return getHash(info); + return await getHash(info); }, getServerInfosHash: getHash, getServerInfoHash: getHash, ...libSpec, }); function fetch(viewer: Viewer) { return fetchCurrentUserInfo(viewer); } -function getHash(currentUserInfo: CurrentUserInfo) { - return hash(validateOutput(null, currentUserInfoValidator, currentUserInfo)); +async function getHash(currentUserInfo: CurrentUserInfo) { + const output = await validateOutput( + null, + currentUserInfoValidator, + currentUserInfo, + ); + return hash(output); } diff --git a/keyserver/src/shared/state-sync/entries-state-sync-spec.js b/keyserver/src/shared/state-sync/entries-state-sync-spec.js index 8a3436081..2a097502d 100644 --- a/keyserver/src/shared/state-sync/entries-state-sync-spec.js +++ b/keyserver/src/shared/state-sync/entries-state-sync-spec.js @@ -1,60 +1,62 @@ // @flow import { serverEntryInfosObject } from 'lib/shared/entry-utils.js'; import { entriesStateSyncSpec as libSpec } from 'lib/shared/state-sync/entries-state-sync-spec.js'; import { type CalendarQuery, type RawEntryInfo, type RawEntryInfos, rawEntryInfoValidator, } from 'lib/types/entry-types.js'; import type { ClientEntryInconsistencyReportCreationRequest } from 'lib/types/report-types.js'; import { hash, combineUnorderedHashes, values } from 'lib/utils/objects.js'; import type { ServerStateSyncSpec } from './state-sync-spec.js'; import { fetchEntryInfos, fetchEntryInfosByID, } from '../../fetchers/entry-fetchers.js'; import type { Viewer } from '../../session/viewer.js'; import { validateOutput } from '../../utils/validation-utils.js'; export const entriesStateSyncSpec: ServerStateSyncSpec< RawEntryInfos, $ReadOnlyArray, RawEntryInfo, $ReadOnlyArray, > = Object.freeze({ fetch, async fetchFullSocketSyncPayload( viewer: Viewer, query: $ReadOnlyArray, ) { const result = await fetchEntryInfos(viewer, query); return result.rawEntryInfos; }, async fetchServerInfosHash(viewer: Viewer, ids?: $ReadOnlySet) { const info = await fetch(viewer, ids); - return getServerInfosHash(info); + return await getServerInfosHash(info); }, getServerInfosHash, getServerInfoHash, ...libSpec, }); async function fetch(viewer: Viewer, ids?: $ReadOnlySet) { if (ids) { return fetchEntryInfosByID(viewer, ids); } const query = [viewer.calendarQuery]; const entriesResult = await fetchEntryInfos(viewer, query); return serverEntryInfosObject(entriesResult.rawEntryInfos); } -function getServerInfosHash(infos: RawEntryInfos) { - return combineUnorderedHashes(values(infos).map(getServerInfoHash)); +async function getServerInfosHash(infos: RawEntryInfos) { + const results = await Promise.all(values(infos).map(getServerInfoHash)); + return combineUnorderedHashes(results); } -function getServerInfoHash(info: RawEntryInfo) { - return hash(validateOutput(null, rawEntryInfoValidator, info)); +async function getServerInfoHash(info: RawEntryInfo) { + const output = await validateOutput(null, rawEntryInfoValidator, info); + return hash(output); } diff --git a/keyserver/src/shared/state-sync/state-sync-spec.js b/keyserver/src/shared/state-sync/state-sync-spec.js index af0ce2024..063d2fb2a 100644 --- a/keyserver/src/shared/state-sync/state-sync-spec.js +++ b/keyserver/src/shared/state-sync/state-sync-spec.js @@ -1,26 +1,26 @@ // @flow import type { StateSyncSpec } from 'lib/shared/state-sync/state-sync-spec.js'; import type { CalendarQuery } from 'lib/types/entry-types.js'; import type { Viewer } from '../../session/viewer.js'; export type ServerStateSyncSpec< Infos, FullSocketSyncPayload, Info, Inconsistencies, > = { +fetch: (viewer: Viewer, ids?: $ReadOnlySet) => Promise, +fetchFullSocketSyncPayload: ( viewer: Viewer, calendarQuery: $ReadOnlyArray, ) => Promise, +fetchServerInfosHash: ( viewer: Viewer, ids?: $ReadOnlySet, ) => Promise, - +getServerInfosHash: (infos: Infos) => number, - +getServerInfoHash: (info: Info) => number, + +getServerInfosHash: (infos: Infos) => Promise, + +getServerInfoHash: (info: Info) => Promise, ...StateSyncSpec, }; diff --git a/keyserver/src/shared/state-sync/threads-state-sync-spec.js b/keyserver/src/shared/state-sync/threads-state-sync-spec.js index 1ee243059..ae0faefe5 100644 --- a/keyserver/src/shared/state-sync/threads-state-sync-spec.js +++ b/keyserver/src/shared/state-sync/threads-state-sync-spec.js @@ -1,50 +1,52 @@ // @flow import { mixedRawThreadInfoValidator } from 'lib/permissions/minimally-encoded-raw-thread-info-validators.js'; import { threadsStateSyncSpec as libSpec } from 'lib/shared/state-sync/threads-state-sync-spec.js'; import type { RawThreadInfo } from 'lib/types/minimally-encoded-thread-permissions-types.js'; import type { ClientThreadInconsistencyReportCreationRequest } from 'lib/types/report-types.js'; import { type MixedRawThreadInfos, type LegacyRawThreadInfo, } from 'lib/types/thread-types.js'; import { hash, combineUnorderedHashes, values } from 'lib/utils/objects.js'; import type { ServerStateSyncSpec } from './state-sync-spec.js'; import { fetchThreadInfos } from '../../fetchers/thread-fetchers.js'; import type { Viewer } from '../../session/viewer.js'; import { validateOutput } from '../../utils/validation-utils.js'; export const threadsStateSyncSpec: ServerStateSyncSpec< MixedRawThreadInfos, MixedRawThreadInfos, LegacyRawThreadInfo | RawThreadInfo, $ReadOnlyArray, > = Object.freeze({ fetch, async fetchFullSocketSyncPayload(viewer: Viewer) { const result = await fetchThreadInfos(viewer); return result.threadInfos; }, async fetchServerInfosHash(viewer: Viewer, ids?: $ReadOnlySet) { const infos = await fetch(viewer, ids); - return getServerInfosHash(infos); + return await getServerInfosHash(infos); }, getServerInfosHash, getServerInfoHash, ...libSpec, }); async function fetch(viewer: Viewer, ids?: $ReadOnlySet) { const filter = ids ? { threadIDs: ids } : undefined; const result = await fetchThreadInfos(viewer, filter); return result.threadInfos; } -function getServerInfosHash(infos: MixedRawThreadInfos) { - return combineUnorderedHashes(values(infos).map(getServerInfoHash)); +async function getServerInfosHash(infos: MixedRawThreadInfos) { + const results = await Promise.all(values(infos).map(getServerInfoHash)); + return combineUnorderedHashes(results); } -function getServerInfoHash(info: LegacyRawThreadInfo | RawThreadInfo) { - return hash(validateOutput(null, mixedRawThreadInfoValidator, info)); +async function getServerInfoHash(info: LegacyRawThreadInfo | RawThreadInfo) { + const output = await validateOutput(null, mixedRawThreadInfoValidator, info); + return hash(output); } diff --git a/keyserver/src/shared/state-sync/users-state-sync-spec.js b/keyserver/src/shared/state-sync/users-state-sync-spec.js index 10db1efb3..0facd6cf1 100644 --- a/keyserver/src/shared/state-sync/users-state-sync-spec.js +++ b/keyserver/src/shared/state-sync/users-state-sync-spec.js @@ -1,48 +1,50 @@ // @flow import { usersStateSyncSpec as libSpec } from 'lib/shared/state-sync/users-state-sync-spec.js'; import type { ClientUserInconsistencyReportCreationRequest } from 'lib/types/report-types.js'; import type { UserInfos, UserInfo } from 'lib/types/user-types.js'; import { userInfoValidator } from 'lib/types/user-types.js'; import { values, hash, combineUnorderedHashes } from 'lib/utils/objects.js'; import type { ServerStateSyncSpec } from './state-sync-spec.js'; import { fetchKnownUserInfos } from '../../fetchers/user-fetchers.js'; import type { Viewer } from '../../session/viewer.js'; import { validateOutput } from '../../utils/validation-utils.js'; export const usersStateSyncSpec: ServerStateSyncSpec< UserInfos, $ReadOnlyArray, UserInfo, $ReadOnlyArray, > = Object.freeze({ fetch, async fetchFullSocketSyncPayload(viewer: Viewer) { const result = await fetchKnownUserInfos(viewer); return values(result); }, async fetchServerInfosHash(viewer: Viewer, ids?: $ReadOnlySet) { const infos = await fetch(viewer, ids); - return getServerInfosHash(infos); + return await getServerInfosHash(infos); }, getServerInfosHash, getServerInfoHash, ...libSpec, }); function fetch(viewer: Viewer, ids?: $ReadOnlySet) { if (ids) { return fetchKnownUserInfos(viewer, [...ids]); } return fetchKnownUserInfos(viewer); } -function getServerInfosHash(infos: UserInfos) { - return combineUnorderedHashes(values(infos).map(getServerInfoHash)); +async function getServerInfosHash(infos: UserInfos) { + const results = await Promise.all(values(infos).map(getServerInfoHash)); + return combineUnorderedHashes(results); } -function getServerInfoHash(info: UserInfo) { - return hash(validateOutput(null, userInfoValidator, info)); +async function getServerInfoHash(info: UserInfo) { + const output = await validateOutput(null, userInfoValidator, info); + return hash(output); } diff --git a/keyserver/src/socket/session-utils.js b/keyserver/src/socket/session-utils.js index 3a348f6dd..b30facbb7 100644 --- a/keyserver/src/socket/session-utils.js +++ b/keyserver/src/socket/session-utils.js @@ -1,558 +1,560 @@ // @flow import invariant from 'invariant'; import t from 'tcomb'; import type { TUnion } from 'tcomb'; import { hasMinCodeVersion } from 'lib/shared/version-utils.js'; import type { UpdateActivityResult, ActivityUpdate, } from 'lib/types/activity-types.js'; import type { IdentityKeysBlob } from 'lib/types/crypto-types.js'; import { isDeviceType } from 'lib/types/device-types.js'; import type { CalendarQuery, DeltaEntryInfosResponse, } from 'lib/types/entry-types.js'; import { reportTypes, type ThreadInconsistencyReportCreationRequest, type EntryInconsistencyReportCreationRequest, } from 'lib/types/report-types.js'; import { serverRequestTypes, type ThreadInconsistencyClientResponse, type EntryInconsistencyClientResponse, type ClientResponse, type ServerServerRequest, type ServerCheckStateServerRequest, } from 'lib/types/request-types.js'; import { sessionCheckFrequency } from 'lib/types/session-types.js'; import { signedIdentityKeysBlobValidator } from 'lib/utils/crypto-utils.js'; import { hash, values } from 'lib/utils/objects.js'; import { promiseAll, ignorePromiseRejections } from 'lib/utils/promises.js'; import { tShape, tPlatform, tPlatformDetails, } from 'lib/utils/validation-utils.js'; import { createAndPersistOlmSession } from '../creators/olm-session-creator.js'; import { saveOneTimeKeys } from '../creators/one-time-keys-creator.js'; import createReport from '../creators/report-creator.js'; import { fetchEntriesForSession } from '../fetchers/entry-fetchers.js'; import { checkIfSessionHasEnoughOneTimeKeys } from '../fetchers/key-fetchers.js'; import { activityUpdatesInputValidator } from '../responders/activity-responders.js'; import { threadInconsistencyReportValidatorShape, entryInconsistencyReportValidatorShape, } from '../responders/report-responders.js'; import { setNewSession, setCookiePlatform, setCookiePlatformDetails, setCookieSignedIdentityKeysBlob, } from '../session/cookies.js'; import type { Viewer } from '../session/viewer.js'; import { serverStateSyncSpecs } from '../shared/state-sync/state-sync-specs.js'; import { activityUpdater } from '../updaters/activity-updaters.js'; import { compareNewCalendarQuery } from '../updaters/entry-updaters.js'; import type { SessionUpdate } from '../updaters/session-updaters.js'; import { getOlmUtility } from '../utils/olm-utils.js'; const clientResponseInputValidator: TUnion = t.union([ tShape({ type: t.irreducible( 'serverRequestTypes.PLATFORM', x => x === serverRequestTypes.PLATFORM, ), platform: tPlatform, }), tShape({ ...threadInconsistencyReportValidatorShape, type: t.irreducible( 'serverRequestTypes.THREAD_INCONSISTENCY', x => x === serverRequestTypes.THREAD_INCONSISTENCY, ), }), tShape({ ...entryInconsistencyReportValidatorShape, type: t.irreducible( 'serverRequestTypes.ENTRY_INCONSISTENCY', x => x === serverRequestTypes.ENTRY_INCONSISTENCY, ), }), tShape({ type: t.irreducible( 'serverRequestTypes.PLATFORM_DETAILS', x => x === serverRequestTypes.PLATFORM_DETAILS, ), platformDetails: tPlatformDetails, }), tShape({ type: t.irreducible( 'serverRequestTypes.CHECK_STATE', x => x === serverRequestTypes.CHECK_STATE, ), hashResults: t.dict(t.String, t.Boolean), }), tShape({ type: t.irreducible( 'serverRequestTypes.INITIAL_ACTIVITY_UPDATES', x => x === serverRequestTypes.INITIAL_ACTIVITY_UPDATES, ), activityUpdates: activityUpdatesInputValidator, }), tShape({ type: t.irreducible( 'serverRequestTypes.MORE_ONE_TIME_KEYS', x => x === serverRequestTypes.MORE_ONE_TIME_KEYS, ), keys: t.list(t.String), }), tShape({ type: t.irreducible( 'serverRequestTypes.SIGNED_IDENTITY_KEYS_BLOB', x => x === serverRequestTypes.SIGNED_IDENTITY_KEYS_BLOB, ), signedIdentityKeysBlob: signedIdentityKeysBlobValidator, }), tShape({ type: t.irreducible( 'serverRequestTypes.INITIAL_NOTIFICATIONS_ENCRYPTED_MESSAGE', x => x === serverRequestTypes.INITIAL_NOTIFICATIONS_ENCRYPTED_MESSAGE, ), initialNotificationsEncryptedMessage: t.String, }), ]); type StateCheckStatus = | { status: 'state_validated' } | { status: 'state_invalid', invalidKeys: $ReadOnlyArray } | { status: 'state_check' }; type ProcessClientResponsesResult = { serverRequests: ServerServerRequest[], stateCheckStatus: ?StateCheckStatus, activityUpdateResult: ?UpdateActivityResult, }; async function processClientResponses( viewer: Viewer, clientResponses: $ReadOnlyArray, ): Promise { let viewerMissingPlatform = !viewer.platform; const { platformDetails } = viewer; let viewerMissingPlatformDetails = !platformDetails || (isDeviceType(viewer.platform) && (platformDetails.codeVersion === null || platformDetails.codeVersion === undefined || platformDetails.stateVersion === null || platformDetails.stateVersion === undefined)); const promises = []; let activityUpdates: Array = []; let stateCheckStatus = null; const clientSentPlatformDetails = clientResponses.some( response => response.type === serverRequestTypes.PLATFORM_DETAILS, ); for (const clientResponse of clientResponses) { if ( clientResponse.type === serverRequestTypes.PLATFORM && !clientSentPlatformDetails ) { promises.push(setCookiePlatform(viewer, clientResponse.platform)); viewerMissingPlatform = false; if (!isDeviceType(clientResponse.platform)) { viewerMissingPlatformDetails = false; } } else if ( clientResponse.type === serverRequestTypes.THREAD_INCONSISTENCY ) { promises.push(recordThreadInconsistency(viewer, clientResponse)); } else if (clientResponse.type === serverRequestTypes.ENTRY_INCONSISTENCY) { promises.push(recordEntryInconsistency(viewer, clientResponse)); } else if (clientResponse.type === serverRequestTypes.PLATFORM_DETAILS) { promises.push( setCookiePlatformDetails(viewer, clientResponse.platformDetails), ); viewerMissingPlatform = false; viewerMissingPlatformDetails = false; } else if ( clientResponse.type === serverRequestTypes.INITIAL_ACTIVITY_UPDATES ) { activityUpdates = [...activityUpdates, ...clientResponse.activityUpdates]; } else if (clientResponse.type === serverRequestTypes.CHECK_STATE) { const invalidKeys = []; for (const key in clientResponse.hashResults) { const result = clientResponse.hashResults[key]; if (!result) { invalidKeys.push(key); } } stateCheckStatus = invalidKeys.length > 0 ? { status: 'state_invalid', invalidKeys } : { status: 'state_validated' }; } else if (clientResponse.type === serverRequestTypes.MORE_ONE_TIME_KEYS) { invariant(clientResponse.keys, 'keys expected in client response'); ignorePromiseRejections(saveOneTimeKeys(viewer, clientResponse.keys)); } else if ( clientResponse.type === serverRequestTypes.SIGNED_IDENTITY_KEYS_BLOB ) { invariant( clientResponse.signedIdentityKeysBlob, 'signedIdentityKeysBlob expected in client response', ); const { signedIdentityKeysBlob } = clientResponse; const identityKeys: IdentityKeysBlob = JSON.parse( signedIdentityKeysBlob.payload, ); const olmUtil = getOlmUtility(); try { olmUtil.ed25519_verify( identityKeys.primaryIdentityPublicKeys.ed25519, signedIdentityKeysBlob.payload, signedIdentityKeysBlob.signature, ); ignorePromiseRejections( setCookieSignedIdentityKeysBlob( viewer.cookieID, signedIdentityKeysBlob, ), ); } catch (e) { continue; } } else if ( clientResponse.type === serverRequestTypes.INITIAL_NOTIFICATIONS_ENCRYPTED_MESSAGE ) { invariant( t.String.is(clientResponse.initialNotificationsEncryptedMessage), 'initialNotificationsEncryptedMessage expected in client response', ); const { initialNotificationsEncryptedMessage } = clientResponse; try { await createAndPersistOlmSession( initialNotificationsEncryptedMessage, 'notifications', viewer.cookieID, ); } catch (e) { continue; } } } const activityUpdatePromise: Promise = (async () => { if (activityUpdates.length === 0) { return undefined; } return await activityUpdater(viewer, { updates: activityUpdates }); })(); const serverRequests: Array = []; const checkOneTimeKeysPromise = (async () => { if (!viewer.loggedIn) { return; } const enoughOneTimeKeys = await checkIfSessionHasEnoughOneTimeKeys( viewer.session, ); if (!enoughOneTimeKeys) { serverRequests.push({ type: serverRequestTypes.MORE_ONE_TIME_KEYS }); } })(); const { activityUpdateResult } = await promiseAll({ all: Promise.all(promises), activityUpdateResult: activityUpdatePromise, checkOneTimeKeysPromise, }); if ( !stateCheckStatus && viewer.loggedIn && viewer.sessionLastValidated + sessionCheckFrequency < Date.now() ) { stateCheckStatus = { status: 'state_check' }; } if (viewerMissingPlatform) { serverRequests.push({ type: serverRequestTypes.PLATFORM }); } if (viewerMissingPlatformDetails) { serverRequests.push({ type: serverRequestTypes.PLATFORM_DETAILS }); } return { serverRequests, stateCheckStatus, activityUpdateResult }; } async function recordThreadInconsistency( viewer: Viewer, response: ThreadInconsistencyClientResponse, ): Promise { const { type, ...rest } = response; const reportCreationRequest = ({ ...rest, type: reportTypes.THREAD_INCONSISTENCY, }: ThreadInconsistencyReportCreationRequest); await createReport(viewer, reportCreationRequest); } async function recordEntryInconsistency( viewer: Viewer, response: EntryInconsistencyClientResponse, ): Promise { const { type, ...rest } = response; const reportCreationRequest = ({ ...rest, type: reportTypes.ENTRY_INCONSISTENCY, }: EntryInconsistencyReportCreationRequest); await createReport(viewer, reportCreationRequest); } type SessionInitializationResult = | { sessionContinued: false } | { sessionContinued: true, deltaEntryInfoResult: DeltaEntryInfosResponse, sessionUpdate: SessionUpdate, }; async function initializeSession( viewer: Viewer, calendarQuery: CalendarQuery, oldLastUpdate: number, ): Promise { if (!viewer.loggedIn) { return { sessionContinued: false }; } if (!viewer.hasSessionInfo) { // If the viewer has no session info but is logged in, that is indicative // of an expired / invalidated session and we should generate a new one await setNewSession(viewer, calendarQuery, oldLastUpdate); return { sessionContinued: false }; } if (oldLastUpdate < viewer.sessionLastUpdated) { // If the client has an older last_update than the server is tracking for // that client, then the client either had some issue persisting its store, // or the user restored the client app from a backup. Either way, we should // invalidate the existing session, since the server has assumed that the // checkpoint is further along than it is on the client, and might not still // have all of the updates necessary to do an incremental update await setNewSession(viewer, calendarQuery, oldLastUpdate); return { sessionContinued: false }; } let comparisonResult = null; try { comparisonResult = compareNewCalendarQuery(viewer, calendarQuery); } catch (e) { if (e.message !== 'unknown_error') { throw e; } } if (comparisonResult) { const { difference, oldCalendarQuery } = comparisonResult; const sessionUpdate = { ...comparisonResult.sessionUpdate, lastUpdate: oldLastUpdate, }; const deltaEntryInfoResult = await fetchEntriesForSession( viewer, difference, oldCalendarQuery, ); return { sessionContinued: true, deltaEntryInfoResult, sessionUpdate }; } else { await setNewSession(viewer, calendarQuery, oldLastUpdate); return { sessionContinued: false }; } } type StateCheckResult = { sessionUpdate?: SessionUpdate, checkStateRequest?: ServerCheckStateServerRequest, }; async function checkState( viewer: Viewer, status: StateCheckStatus, ): Promise { if (status.status === 'state_validated') { return { sessionUpdate: { lastValidated: Date.now() } }; } else if (status.status === 'state_check') { const promises = Object.fromEntries( values(serverStateSyncSpecs).map(spec => [ spec.hashKey, (async () => { if ( !hasMinCodeVersion(viewer.platformDetails, { native: 267, web: 32, }) ) { const data = await spec.fetch(viewer); return hash(data); } const infosHash = await spec.fetchServerInfosHash(viewer); return infosHash; })(), ]), ); const hashesToCheck = await promiseAll(promises); const checkStateRequest = { type: serverRequestTypes.CHECK_STATE, hashesToCheck, }; return { checkStateRequest }; } const invalidKeys = new Set(status.invalidKeys); const shouldFetchAll = Object.fromEntries( values(serverStateSyncSpecs).map(spec => [ spec.hashKey, invalidKeys.has(spec.hashKey), ]), ); const idsToFetch = Object.fromEntries( values(serverStateSyncSpecs) .filter(spec => spec.innerHashSpec?.hashKey) .map(spec => [spec.innerHashSpec?.hashKey, new Set()]), ); for (const key of invalidKeys) { const [innerHashKey, id] = key.split('|'); if (innerHashKey && id) { idsToFetch[innerHashKey]?.add(id); } } const fetchPromises: { [string]: Promise } = {}; for (const spec of values(serverStateSyncSpecs)) { if (shouldFetchAll[spec.hashKey]) { fetchPromises[spec.hashKey] = spec.fetch(viewer); } else if (idsToFetch[spec.innerHashSpec?.hashKey]?.size > 0) { fetchPromises[spec.hashKey] = spec.fetch( viewer, idsToFetch[spec.innerHashSpec?.hashKey], ); } } const fetchedData = await promiseAll(fetchPromises); const specPerHashKey = Object.fromEntries( values(serverStateSyncSpecs).map(spec => [spec.hashKey, spec]), ); const specPerInnerHashKey = Object.fromEntries( values(serverStateSyncSpecs) .filter(spec => spec.innerHashSpec?.hashKey) .map(spec => [spec.innerHashSpec?.hashKey, spec]), ); const hashesToCheck: { [string]: number } = {}, failUnmentioned: { [string]: boolean } = {}, stateChanges: { [string]: mixed } = {}; for (const key of invalidKeys) { const spec = specPerHashKey[key]; const innerHashKey = spec?.innerHashSpec?.hashKey; const isTopLevelKey = !!spec; if (isTopLevelKey && innerHashKey) { // Instead of returning all the infos, we want to narrow down and figure // out which infos don't match first const infos = fetchedData[key]; // We have a type error here because in fact the relationship between // Infos and Info is not guaranteed to be like this. In particular, // currentUserStateSyncSpec does not match this pattern. But this code // doesn't fire for it because no innerHashSpec is defined const iterableInfos: { +[string]: mixed } = (infos: any); for (const infoID in iterableInfos) { let hashValue; if ( hasMinCodeVersion(viewer.platformDetails, { native: 267, web: 32, }) ) { // We have a type error here because Flow has no way to determine that // spec and infos are matched up - hashValue = spec.getServerInfoHash((iterableInfos[infoID]: any)); + hashValue = await spec.getServerInfoHash( + (iterableInfos[infoID]: any), + ); } else { hashValue = hash(iterableInfos[infoID]); } hashesToCheck[`${innerHashKey}|${infoID}`] = hashValue; } failUnmentioned[key] = true; } else if (isTopLevelKey) { stateChanges[key] = fetchedData[key]; } else { const [keyPrefix, id] = key.split('|'); const innerSpec = specPerInnerHashKey[keyPrefix]; const innerHashSpec = innerSpec?.innerHashSpec; if (!innerHashSpec || !id) { continue; } const infos = fetchedData[innerSpec.hashKey]; // We have a type error here because in fact the relationship between // Infos and Info is not guaranteed to be like this. In particular, // currentUserStateSyncSpec does not match this pattern. But this code // doesn't fire for it because no innerHashSpec is defined const iterableInfos: { +[string]: mixed } = (infos: any); const info = iterableInfos[id]; // We have a type error here because Flow wants us to type iterableInfos // in this file, but we don't have access to the parameterization of // innerHashSpec here if (!info || innerHashSpec.additionalDeleteCondition?.((info: any))) { if (!stateChanges[innerHashSpec.deleteKey]) { stateChanges[innerHashSpec.deleteKey] = [id]; } else { // We have a type error here because in fact stateChanges values // aren't always arrays. In particular, currentUserStateSyncSpec does // not match this pattern. But this code doesn't fire for it because // no innerHashSpec is defined const curDeleteKeyChanges: Array = (stateChanges[ innerHashSpec.deleteKey ]: any); curDeleteKeyChanges.push(id); } continue; } if (!stateChanges[innerHashSpec.rawInfosKey]) { stateChanges[innerHashSpec.rawInfosKey] = [info]; } else { // We have a type error here because in fact stateChanges values aren't // always arrays. In particular, currentUserStateSyncSpec does not match // this pattern. But this code doesn't fire for it because no // innerHashSpec is defined const curRawInfosKeyChanges: Array = (stateChanges[ innerHashSpec.rawInfosKey ]: any); curRawInfosKeyChanges.push(info); } } } // We have a type error here because the keys that get set on some of these // collections aren't statically typed when they're set. Rather, they are set // as arbitrary strings const checkStateRequest: ServerCheckStateServerRequest = ({ type: serverRequestTypes.CHECK_STATE, hashesToCheck, failUnmentioned, stateChanges, }: any); if (Object.keys(hashesToCheck).length === 0) { return { checkStateRequest, sessionUpdate: { lastValidated: Date.now() } }; } else { return { checkStateRequest }; } } export { clientResponseInputValidator, processClientResponses, initializeSession, checkState, }; diff --git a/keyserver/src/socket/socket.js b/keyserver/src/socket/socket.js index 9a3a64c74..5ca0e925d 100644 --- a/keyserver/src/socket/socket.js +++ b/keyserver/src/socket/socket.js @@ -1,883 +1,883 @@ // @flow import type { $Request } from 'express'; import invariant from 'invariant'; import _debounce from 'lodash/debounce.js'; import t from 'tcomb'; import type { TUnion } from 'tcomb'; import WebSocket from 'ws'; import { baseLegalPolicies } from 'lib/facts/policies.js'; import { mostRecentMessageTimestamp } from 'lib/shared/message-utils.js'; import { isStaff } from 'lib/shared/staff-utils.js'; import { serverRequestSocketTimeout, serverResponseTimeout, } from 'lib/shared/timeouts.js'; import { mostRecentUpdateTimestamp } from 'lib/shared/update-utils.js'; import { hasMinCodeVersion } from 'lib/shared/version-utils.js'; import { endpointIsSocketSafe } from 'lib/types/endpoints.js'; import type { RawEntryInfo } from 'lib/types/entry-types.js'; import { defaultNumberPerThread } from 'lib/types/message-types.js'; import { redisMessageTypes, type RedisMessage } from 'lib/types/redis-types.js'; import { serverRequestTypes } from 'lib/types/request-types.js'; import { sessionCheckFrequency, stateCheckInactivityActivationInterval, } from 'lib/types/session-types.js'; import { type ClientSocketMessage, type InitialClientSocketMessage, type ResponsesClientSocketMessage, type ServerStateSyncFullSocketPayload, type ServerServerSocketMessage, type ErrorServerSocketMessage, type AuthErrorServerSocketMessage, type PingClientSocketMessage, type AckUpdatesClientSocketMessage, type APIRequestClientSocketMessage, clientSocketMessageTypes, stateSyncPayloadTypes, serverSocketMessageTypes, serverServerSocketMessageValidator, } from 'lib/types/socket-types.js'; import type { LegacyRawThreadInfos } from 'lib/types/thread-types.js'; import type { UserInfo, CurrentUserInfo } from 'lib/types/user-types.js'; import { ServerError } from 'lib/utils/errors.js'; import { values } from 'lib/utils/objects.js'; import { promiseAll, ignorePromiseRejections } from 'lib/utils/promises.js'; import SequentialPromiseResolver from 'lib/utils/sequential-promise-resolver.js'; import sleep from 'lib/utils/sleep.js'; import { tShape, tCookie } from 'lib/utils/validation-utils.js'; import { RedisSubscriber } from './redis.js'; import { clientResponseInputValidator, processClientResponses, initializeSession, checkState, } from './session-utils.js'; import { fetchUpdateInfosWithRawUpdateInfos } from '../creators/update-creator.js'; import { deleteActivityForViewerSession } from '../deleters/activity-deleters.js'; import { deleteCookie } from '../deleters/cookie-deleters.js'; import { deleteUpdatesBeforeTimeTargetingSession } from '../deleters/update-deleters.js'; import { jsonEndpoints } from '../endpoints.js'; import { fetchMessageInfosSince, getMessageFetchResultFromRedisMessages, } from '../fetchers/message-fetchers.js'; import { fetchUpdateInfos } from '../fetchers/update-fetchers.js'; import { newEntryQueryInputValidator, verifyCalendarQueryThreadIDs, } from '../responders/entry-responders.js'; import { fetchViewerForSocket, updateCookie, isCookieMissingSignedIdentityKeysBlob, isCookieMissingOlmNotificationsSession, createNewAnonymousCookie, } from '../session/cookies.js'; import { Viewer } from '../session/viewer.js'; import type { AnonymousViewerData } from '../session/viewer.js'; import { serverStateSyncSpecs } from '../shared/state-sync/state-sync-specs.js'; import { commitSessionUpdate } from '../updaters/session-updaters.js'; import { compressMessage } from '../utils/compress.js'; import { assertSecureRequest } from '../utils/security-utils.js'; import { checkInputValidator, checkClientSupported, policiesValidator, validateOutput, } from '../utils/validation-utils.js'; const clientSocketMessageInputValidator: TUnion = t.union([ tShape({ type: t.irreducible( 'clientSocketMessageTypes.INITIAL', x => x === clientSocketMessageTypes.INITIAL, ), id: t.Number, payload: tShape({ sessionIdentification: tShape({ cookie: t.maybe(tCookie), sessionID: t.maybe(t.String), }), sessionState: tShape({ calendarQuery: newEntryQueryInputValidator, messagesCurrentAsOf: t.Number, updatesCurrentAsOf: t.Number, watchedIDs: t.list(t.String), }), clientResponses: t.list(clientResponseInputValidator), }), }), tShape({ type: t.irreducible( 'clientSocketMessageTypes.RESPONSES', x => x === clientSocketMessageTypes.RESPONSES, ), id: t.Number, payload: tShape({ clientResponses: t.list(clientResponseInputValidator), }), }), tShape({ type: t.irreducible( 'clientSocketMessageTypes.PING', x => x === clientSocketMessageTypes.PING, ), id: t.Number, }), tShape({ type: t.irreducible( 'clientSocketMessageTypes.ACK_UPDATES', x => x === clientSocketMessageTypes.ACK_UPDATES, ), id: t.Number, payload: tShape({ currentAsOf: t.Number, }), }), tShape({ type: t.irreducible( 'clientSocketMessageTypes.API_REQUEST', x => x === clientSocketMessageTypes.API_REQUEST, ), id: t.Number, payload: tShape({ endpoint: t.String, input: t.maybe(t.Object), }), }), ]); function onConnection(ws: WebSocket, req: $Request) { assertSecureRequest(req); new Socket(ws, req); } type StateCheckConditions = { activityRecentlyOccurred: boolean, stateCheckOngoing: boolean, }; const minVersionsForCompression = { native: 265, web: 30, }; class Socket { ws: WebSocket; httpRequest: $Request; viewer: ?Viewer; redis: ?RedisSubscriber; redisPromiseResolver: SequentialPromiseResolver; stateCheckConditions: StateCheckConditions = { activityRecentlyOccurred: true, stateCheckOngoing: false, }; stateCheckTimeoutID: ?TimeoutID; constructor(ws: WebSocket, httpRequest: $Request) { this.ws = ws; this.httpRequest = httpRequest; ws.on('message', this.onMessage); ws.on('close', this.onClose); this.resetTimeout(); this.redisPromiseResolver = new SequentialPromiseResolver(this.sendMessage); } onMessage = async ( messageString: string | Buffer | ArrayBuffer | Array, ): Promise => { invariant(typeof messageString === 'string', 'message should be string'); let clientSocketMessage: ?ClientSocketMessage; try { this.resetTimeout(); const messageObject = JSON.parse(messageString); clientSocketMessage = checkInputValidator( clientSocketMessageInputValidator, messageObject, ); if (clientSocketMessage.type === clientSocketMessageTypes.INITIAL) { if (this.viewer) { // This indicates that the user sent multiple INITIAL messages. throw new ServerError('socket_already_initialized'); } this.viewer = await fetchViewerForSocket( this.httpRequest, clientSocketMessage, ); } const { viewer } = this; if (!viewer) { // This indicates a non-INITIAL message was sent by the client before // the INITIAL message. throw new ServerError('socket_uninitialized'); } if (viewer.sessionChanged) { // This indicates that the cookie was invalid, and we've assigned a new // anonymous one. throw new ServerError('socket_deauthorized'); } if (!viewer.loggedIn) { // This indicates that the specified cookie was an anonymous one. throw new ServerError('not_logged_in'); } await checkClientSupported( viewer, clientSocketMessageInputValidator, clientSocketMessage, ); await policiesValidator(viewer, baseLegalPolicies); const serverResponses = await this.handleClientSocketMessage(clientSocketMessage); if (!this.redis) { this.redis = new RedisSubscriber( { userID: viewer.userID, sessionID: viewer.session }, this.onRedisMessage, ); } if (viewer.sessionChanged) { // This indicates that something has caused the session to change, which // shouldn't happen from inside a WebSocket since we can't handle cookie // invalidation. throw new ServerError('session_mutated_from_socket'); } if (clientSocketMessage.type !== clientSocketMessageTypes.PING) { ignorePromiseRejections(updateCookie(viewer)); } for (const response of serverResponses) { // Normally it's an anti-pattern to await in sequence like this. But in // this case, we have a requirement that this array of serverResponses // is delivered in order. See here: // https://github.com/CommE2E/comm/blob/101eb34481deb49c609bfd2c785f375886e52666/keyserver/src/socket/socket.js#L566-L568 await this.sendMessage(response); } if (clientSocketMessage.type === clientSocketMessageTypes.INITIAL) { this.onSuccessfulConnection(); } } catch (error) { console.warn(error); if (!(error instanceof ServerError)) { const errorMessage: ErrorServerSocketMessage = { type: serverSocketMessageTypes.ERROR, message: error.message, }; const responseTo = clientSocketMessage ? clientSocketMessage.id : null; if (responseTo !== null) { errorMessage.responseTo = responseTo; } this.markActivityOccurred(); await this.sendMessage(errorMessage); return; } invariant(clientSocketMessage, 'should be set'); const responseTo = clientSocketMessage.id; if (error.message === 'socket_deauthorized') { invariant(this.viewer, 'should be set'); const authErrorMessage: AuthErrorServerSocketMessage = { type: serverSocketMessageTypes.AUTH_ERROR, responseTo, message: error.message, sessionChange: { cookie: this.viewer.cookiePairString, currentUserInfo: { anonymous: true, }, }, }; await this.sendMessage(authErrorMessage); this.ws.close(4100, error.message); return; } else if (error.message === 'client_version_unsupported') { const { viewer } = this; invariant(viewer, 'should be set'); const anonymousViewerDataPromise: Promise = createNewAnonymousCookie({ platformDetails: error.platformDetails, deviceToken: viewer.deviceToken, }); const deleteCookiePromise = deleteCookie(viewer.cookieID); const [anonymousViewerData] = await Promise.all([ anonymousViewerDataPromise, deleteCookiePromise, ]); // It is normally not safe to pass the result of // createNewAnonymousCookie to the Viewer constructor. That is because // createNewAnonymousCookie leaves several fields of // AnonymousViewerData unset, and consequently Viewer will throw when // access is attempted. It is only safe here because we can guarantee // that only cookiePairString and cookieID are accessed on anonViewer // below. const anonViewer = new Viewer(anonymousViewerData); const authErrorMessage: AuthErrorServerSocketMessage = { type: serverSocketMessageTypes.AUTH_ERROR, responseTo, message: error.message, sessionChange: { cookie: anonViewer.cookiePairString, currentUserInfo: { anonymous: true, }, }, }; await this.sendMessage(authErrorMessage); this.ws.close(4101, error.message); return; } if (error.payload) { await this.sendMessage({ type: serverSocketMessageTypes.ERROR, responseTo, message: error.message, payload: error.payload, }); } else { await this.sendMessage({ type: serverSocketMessageTypes.ERROR, responseTo, message: error.message, }); } if (error.message === 'not_logged_in') { this.ws.close(4102, error.message); } else if (error.message === 'session_mutated_from_socket') { this.ws.close(4103, error.message); } else { this.markActivityOccurred(); } } }; onClose = async () => { this.clearStateCheckTimeout(); this.resetTimeout.cancel(); this.debouncedAfterActivity.cancel(); if (this.viewer && this.viewer.hasSessionInfo) { await deleteActivityForViewerSession(this.viewer); } if (this.redis) { this.redis.quit(); this.redis = null; } }; sendMessage = async (message: ServerServerSocketMessage) => { invariant( this.ws.readyState > 0, "shouldn't send message until connection established", ); if (this.ws.readyState !== 1) { return; } const { viewer } = this; - const validatedMessage = validateOutput( + const validatedMessage = await validateOutput( viewer?.platformDetails, serverServerSocketMessageValidator, message, ); const stringMessage = JSON.stringify(validatedMessage); if ( !viewer?.platformDetails || !hasMinCodeVersion(viewer.platformDetails, minVersionsForCompression) || !isStaff(viewer.id) ) { this.ws.send(stringMessage); return; } const compressionResult = await compressMessage(stringMessage); if (this.ws.readyState !== 1) { return; } if (!compressionResult.compressed) { this.ws.send(stringMessage); return; } const compressedMessage = { type: serverSocketMessageTypes.COMPRESSED_MESSAGE, payload: compressionResult.result, }; - const validatedCompressedMessage = validateOutput( + const validatedCompressedMessage = await validateOutput( viewer?.platformDetails, serverServerSocketMessageValidator, compressedMessage, ); const stringCompressedMessage = JSON.stringify(validatedCompressedMessage); this.ws.send(stringCompressedMessage); }; async handleClientSocketMessage( message: ClientSocketMessage, ): Promise { const resultPromise: Promise = (async () => { if (message.type === clientSocketMessageTypes.INITIAL) { this.markActivityOccurred(); return await this.handleInitialClientSocketMessage(message); } else if (message.type === clientSocketMessageTypes.RESPONSES) { this.markActivityOccurred(); return await this.handleResponsesClientSocketMessage(message); } else if (message.type === clientSocketMessageTypes.PING) { return this.handlePingClientSocketMessage(message); } else if (message.type === clientSocketMessageTypes.ACK_UPDATES) { this.markActivityOccurred(); return await this.handleAckUpdatesClientSocketMessage(message); } else if (message.type === clientSocketMessageTypes.API_REQUEST) { this.markActivityOccurred(); return await this.handleAPIRequestClientSocketMessage(message); } return []; })(); const timeoutPromise: Promise = (async () => { await sleep(serverResponseTimeout); throw new ServerError('socket_response_timeout'); })(); return await Promise.race([resultPromise, timeoutPromise]); } async handleInitialClientSocketMessage( message: InitialClientSocketMessage, ): Promise { const { viewer } = this; invariant(viewer, 'should be set'); const responses: Array = []; const { sessionState, clientResponses } = message.payload; const { calendarQuery, updatesCurrentAsOf: oldUpdatesCurrentAsOf, messagesCurrentAsOf: oldMessagesCurrentAsOf, watchedIDs, } = sessionState; await verifyCalendarQueryThreadIDs(calendarQuery); const sessionInitializationResult = await initializeSession( viewer, calendarQuery, oldUpdatesCurrentAsOf, ); const threadCursors: { [string]: null } = {}; for (const watchedThreadID of watchedIDs) { threadCursors[watchedThreadID] = null; } const messageSelectionCriteria = { threadCursors, joinedThreads: true, newerThan: oldMessagesCurrentAsOf, }; const [fetchMessagesResult, { serverRequests, activityUpdateResult }] = await Promise.all([ fetchMessageInfosSince( viewer, messageSelectionCriteria, defaultNumberPerThread, ), processClientResponses(viewer, clientResponses), ]); const messagesResult = { rawMessageInfos: fetchMessagesResult.rawMessageInfos, truncationStatuses: fetchMessagesResult.truncationStatuses, currentAsOf: mostRecentMessageTimestamp( fetchMessagesResult.rawMessageInfos, oldMessagesCurrentAsOf, ), }; const isCookieMissingSignedIdentityKeysBlobPromise = isCookieMissingSignedIdentityKeysBlob(viewer.cookieID); const isCookieMissingOlmNotificationsSessionPromise = isCookieMissingOlmNotificationsSession(viewer); if (!sessionInitializationResult.sessionContinued) { const promises: { +[string]: Promise } = Object.fromEntries( values(serverStateSyncSpecs).map(spec => [ spec.hashKey, spec.fetchFullSocketSyncPayload(viewer, [calendarQuery]), ]), ); // We have a type error here because Flow doesn't know spec.hashKey const castPromises: { +threadInfos: Promise, +currentUserInfo: Promise, +entryInfos: Promise<$ReadOnlyArray>, +userInfos: Promise<$ReadOnlyArray>, } = (promises: any); const results = await promiseAll(castPromises); const payload: ServerStateSyncFullSocketPayload = { type: stateSyncPayloadTypes.FULL, messagesResult, threadInfos: results.threadInfos, currentUserInfo: results.currentUserInfo, rawEntryInfos: results.entryInfos, userInfos: results.userInfos, updatesCurrentAsOf: oldUpdatesCurrentAsOf, }; if (viewer.sessionChanged) { // If initializeSession encounters, // sessionIdentifierTypes.BODY_SESSION_ID but the session // is unspecified or expired, // it will set a new sessionID and specify viewer.sessionChanged const { sessionID } = viewer; invariant( sessionID !== null && sessionID !== undefined, 'should be set', ); payload.sessionID = sessionID; viewer.sessionChanged = false; } responses.push({ type: serverSocketMessageTypes.STATE_SYNC, responseTo: message.id, payload, }); } else { const { sessionUpdate, deltaEntryInfoResult } = sessionInitializationResult; const deleteExpiredUpdatesPromise = deleteUpdatesBeforeTimeTargetingSession(viewer, oldUpdatesCurrentAsOf); const fetchUpdateResultPromise = fetchUpdateInfos( viewer, oldUpdatesCurrentAsOf, calendarQuery, ); const sessionUpdatePromise = commitSessionUpdate(viewer, sessionUpdate); const [fetchUpdateResult] = await Promise.all([ fetchUpdateResultPromise, deleteExpiredUpdatesPromise, sessionUpdatePromise, ]); const { updateInfos, userInfos } = fetchUpdateResult; const newUpdatesCurrentAsOf = mostRecentUpdateTimestamp( [...updateInfos], oldUpdatesCurrentAsOf, ); const updatesResult = { newUpdates: updateInfos, currentAsOf: newUpdatesCurrentAsOf, }; responses.push({ type: serverSocketMessageTypes.STATE_SYNC, responseTo: message.id, payload: { type: stateSyncPayloadTypes.INCREMENTAL, messagesResult, updatesResult, deltaEntryInfos: deltaEntryInfoResult.rawEntryInfos, deletedEntryIDs: deltaEntryInfoResult.deletedEntryIDs, userInfos: values(userInfos), }, }); } const [signedIdentityKeysBlobMissing, olmNotificationsSessionMissing] = await Promise.all([ isCookieMissingSignedIdentityKeysBlobPromise, isCookieMissingOlmNotificationsSessionPromise, ]); if (signedIdentityKeysBlobMissing) { serverRequests.push({ type: serverRequestTypes.SIGNED_IDENTITY_KEYS_BLOB, }); } if (olmNotificationsSessionMissing) { serverRequests.push({ type: serverRequestTypes.INITIAL_NOTIFICATIONS_ENCRYPTED_MESSAGE, }); } if (serverRequests.length > 0 || clientResponses.length > 0) { // We send this message first since the STATE_SYNC triggers the client's // connection status to shift to "connected", and we want to make sure the // client responses are cleared from Redux before that happens responses.unshift({ type: serverSocketMessageTypes.REQUESTS, responseTo: message.id, payload: { serverRequests }, }); } if (activityUpdateResult) { // Same reason for unshifting as above responses.unshift({ type: serverSocketMessageTypes.ACTIVITY_UPDATE_RESPONSE, responseTo: message.id, payload: activityUpdateResult, }); } return responses; } async handleResponsesClientSocketMessage( message: ResponsesClientSocketMessage, ): Promise { const { viewer } = this; invariant(viewer, 'should be set'); const { clientResponses } = message.payload; const { stateCheckStatus } = await processClientResponses( viewer, clientResponses, ); const serverRequests = []; if (stateCheckStatus && stateCheckStatus.status !== 'state_check') { const { sessionUpdate, checkStateRequest } = await checkState( viewer, stateCheckStatus, ); if (sessionUpdate) { await commitSessionUpdate(viewer, sessionUpdate); this.setStateCheckConditions({ stateCheckOngoing: false }); } if (checkStateRequest) { serverRequests.push(checkStateRequest); } } // We send a response message regardless of whether we have any requests, // since we need to ack the client's responses return [ { type: serverSocketMessageTypes.REQUESTS, responseTo: message.id, payload: { serverRequests }, }, ]; } handlePingClientSocketMessage( message: PingClientSocketMessage, ): ServerServerSocketMessage[] { return [ { type: serverSocketMessageTypes.PONG, responseTo: message.id, }, ]; } async handleAckUpdatesClientSocketMessage( message: AckUpdatesClientSocketMessage, ): Promise { const { viewer } = this; invariant(viewer, 'should be set'); const { currentAsOf } = message.payload; await Promise.all([ deleteUpdatesBeforeTimeTargetingSession(viewer, currentAsOf), commitSessionUpdate(viewer, { lastUpdate: currentAsOf }), ]); return []; } async handleAPIRequestClientSocketMessage( message: APIRequestClientSocketMessage, ): Promise { if (!endpointIsSocketSafe(message.payload.endpoint)) { throw new ServerError('endpoint_unsafe_for_socket'); } const { viewer } = this; invariant(viewer, 'should be set'); const responder = jsonEndpoints[message.payload.endpoint]; await policiesValidator(viewer, responder.requiredPolicies); const response = await responder.responder(viewer, message.payload.input); return [ { type: serverSocketMessageTypes.API_RESPONSE, responseTo: message.id, payload: response, }, ]; } onRedisMessage = async (message: RedisMessage) => { try { await this.processRedisMessage(message); } catch (e) { console.warn(e); } }; async processRedisMessage(message: RedisMessage) { if (message.type === redisMessageTypes.START_SUBSCRIPTION) { this.ws.terminate(); } else if (message.type === redisMessageTypes.NEW_UPDATES) { const { viewer } = this; invariant(viewer, 'should be set'); if (message.ignoreSession && message.ignoreSession === viewer.session) { return; } const rawUpdateInfos = message.updates; this.redisPromiseResolver.add( (async () => { const { updateInfos, userInfos } = await fetchUpdateInfosWithRawUpdateInfos(rawUpdateInfos, { viewer, }); if (updateInfos.length === 0) { console.warn( 'could not get any UpdateInfos from redisMessageTypes.NEW_UPDATES', ); return null; } this.markActivityOccurred(); return { type: serverSocketMessageTypes.UPDATES, payload: { updatesResult: { currentAsOf: mostRecentUpdateTimestamp([...updateInfos], 0), newUpdates: updateInfos, }, userInfos: values(userInfos), }, }; })(), ); } else if (message.type === redisMessageTypes.NEW_MESSAGES) { const { viewer } = this; invariant(viewer, 'should be set'); const rawMessageInfos = message.messages; const messageFetchResult = getMessageFetchResultFromRedisMessages( viewer, rawMessageInfos, ); if (messageFetchResult.rawMessageInfos.length === 0) { console.warn( 'could not get any rawMessageInfos from ' + 'redisMessageTypes.NEW_MESSAGES', ); return; } this.redisPromiseResolver.add( (async () => { this.markActivityOccurred(); return { type: serverSocketMessageTypes.MESSAGES, payload: { messagesResult: { rawMessageInfos: messageFetchResult.rawMessageInfos, truncationStatuses: messageFetchResult.truncationStatuses, currentAsOf: mostRecentMessageTimestamp( messageFetchResult.rawMessageInfos, 0, ), }, }, }; })(), ); } } onSuccessfulConnection() { if (this.ws.readyState !== 1) { return; } this.handleStateCheckConditionsUpdate(); } // The Socket will timeout by calling this.ws.terminate() // serverRequestSocketTimeout milliseconds after the last // time resetTimeout is called resetTimeout: { +cancel: () => void } & (() => void) = _debounce( () => this.ws.terminate(), serverRequestSocketTimeout, ); debouncedAfterActivity: { +cancel: () => void } & (() => void) = _debounce( () => this.setStateCheckConditions({ activityRecentlyOccurred: false }), stateCheckInactivityActivationInterval, ); markActivityOccurred = () => { if (this.ws.readyState !== 1) { return; } this.setStateCheckConditions({ activityRecentlyOccurred: true }); this.debouncedAfterActivity(); }; clearStateCheckTimeout() { const { stateCheckTimeoutID } = this; if (stateCheckTimeoutID) { clearTimeout(stateCheckTimeoutID); this.stateCheckTimeoutID = null; } } setStateCheckConditions(newConditions: Partial) { this.stateCheckConditions = { ...this.stateCheckConditions, ...newConditions, }; this.handleStateCheckConditionsUpdate(); } get stateCheckCanStart(): boolean { return Object.values(this.stateCheckConditions).every(cond => !cond); } handleStateCheckConditionsUpdate() { if (!this.stateCheckCanStart) { this.clearStateCheckTimeout(); return; } if (this.stateCheckTimeoutID) { return; } const { viewer } = this; if (!viewer) { return; } const timeUntilStateCheck = viewer.sessionLastValidated + sessionCheckFrequency - Date.now(); if (timeUntilStateCheck <= 0) { ignorePromiseRejections(this.initiateStateCheck()); } else { this.stateCheckTimeoutID = setTimeout( this.initiateStateCheck, timeUntilStateCheck, ); } } initiateStateCheck = async () => { this.setStateCheckConditions({ stateCheckOngoing: true }); const { viewer } = this; invariant(viewer, 'should be set'); const { checkStateRequest } = await checkState(viewer, { status: 'state_check', }); invariant(checkStateRequest, 'should be set'); await this.sendMessage({ type: serverSocketMessageTypes.REQUESTS, payload: { serverRequests: [checkStateRequest] }, }); }; } export { onConnection }; diff --git a/keyserver/src/uploads/uploads.js b/keyserver/src/uploads/uploads.js index 0228eef5c..ba267fe9c 100644 --- a/keyserver/src/uploads/uploads.js +++ b/keyserver/src/uploads/uploads.js @@ -1,232 +1,232 @@ // @flow import type { $Request, $Response, Middleware } from 'express'; import invariant from 'invariant'; import multer from 'multer'; import { Readable } from 'stream'; import t, { type TInterface } from 'tcomb'; import { type UploadMediaMetadataRequest, type UploadMultimediaResult, uploadMultimediaResultValidator, type UploadDeletionRequest, type Dimensions, } from 'lib/types/media-types.js'; import { ServerError } from 'lib/utils/errors.js'; import { tShape, tID } from 'lib/utils/validation-utils.js'; import { getMediaType, validateAndConvert } from './media-utils.js'; import type { UploadInput } from '../creators/upload-creator.js'; import createUploads from '../creators/upload-creator.js'; import { deleteUpload } from '../deleters/upload-deleters.js'; import { fetchUpload, fetchUploadChunk, getUploadSize, } from '../fetchers/upload-fetchers.js'; import type { MulterRequest } from '../responders/handlers.js'; import type { Viewer } from '../session/viewer.js'; import { validateOutput } from '../utils/validation-utils.js'; const upload = multer(); const multerProcessor: Middleware<> = upload.array('multimedia'); type MultimediaUploadResult = { results: UploadMultimediaResult[], }; const MultimediaUploadResultValidator = tShape({ results: t.list(uploadMultimediaResultValidator), }); async function multimediaUploadResponder( viewer: Viewer, req: MulterRequest, ): Promise { const { files, body } = req; if (!files || !body || typeof body !== 'object') { throw new ServerError('invalid_parameters'); } const overrideFilename = files.length === 1 && body.filename ? body.filename : null; if (overrideFilename && typeof overrideFilename !== 'string') { throw new ServerError('invalid_parameters'); } const inputHeight = files.length === 1 && body.height ? parseInt(body.height) : null; const inputWidth = files.length === 1 && body.width ? parseInt(body.width) : null; if (!!inputHeight !== !!inputWidth) { throw new ServerError('invalid_parameters'); } const inputDimensions: ?Dimensions = inputHeight && inputWidth ? { height: inputHeight, width: inputWidth } : null; const inputLoop = !!(files.length === 1 && body.loop); const inputEncryptionKey = files.length === 1 && body.encryptionKey ? body.encryptionKey : null; if (inputEncryptionKey && typeof inputEncryptionKey !== 'string') { throw new ServerError('invalid_parameters'); } const inputMimeType = files.length === 1 && body.mimeType ? body.mimeType : null; if (inputMimeType && typeof inputMimeType !== 'string') { throw new ServerError('invalid_parameters'); } const inputThumbHash = files.length === 1 && body.thumbHash ? body.thumbHash : null; if (inputThumbHash && typeof inputThumbHash !== 'string') { throw new ServerError('invalid_parameters'); } const validationResults = await Promise.all( files.map(({ buffer, size, originalname }) => validateAndConvert({ initialBuffer: buffer, initialName: overrideFilename ? overrideFilename : originalname, inputDimensions, inputLoop, inputEncryptionKey, inputMimeType, inputThumbHash, size, }), ), ); const uploadInfos = validationResults.filter(Boolean); if (uploadInfos.length === 0) { throw new ServerError('invalid_parameters'); } const results = await createUploads(viewer, uploadInfos); - return validateOutput( + return await validateOutput( viewer.platformDetails, MultimediaUploadResultValidator, { results, }, ); } export const uploadMediaMetadataInputValidator: TInterface = tShape({ filename: t.String, width: t.Number, height: t.Number, blobHolder: t.String, blobHash: t.String, encryptionKey: t.String, mimeType: t.String, loop: t.maybe(t.Boolean), thumbHash: t.maybe(t.String), }); async function uploadMediaMetadataResponder( viewer: Viewer, request: UploadMediaMetadataRequest, ): Promise { const mediaType = getMediaType(request.mimeType); if (!mediaType) { throw new ServerError('invalid_parameters'); } const { filename, blobHolder, blobHash, encryptionKey, mimeType, width, height, loop, } = request; const uploadInfo: UploadInput = { name: filename, mime: mimeType, mediaType, content: { storage: 'blob_service', blobHolder, blobHash }, encryptionKey, dimensions: { width, height }, loop: loop ?? false, thumbHash: request.thumbHash, }; const [result] = await createUploads(viewer, [uploadInfo]); return result; } async function uploadDownloadResponder( viewer: Viewer, req: $Request, res: $Response, ): Promise { const { uploadID, secret } = req.params; if (!uploadID || !secret) { throw new ServerError('invalid_parameters'); } if (!req.headers.range) { const { content, mime } = await fetchUpload(viewer, uploadID, secret); res.type(mime); res.set('Cache-Control', 'public, max-age=31557600, immutable'); res.send(content); } else { const totalUploadSize = await getUploadSize(uploadID, secret); const range = req.range(totalUploadSize); if (typeof range === 'number' && range < 0) { throw new ServerError( range === -1 ? 'unsatisfiable_range' : 'malformed_header_string', ); } invariant( Array.isArray(range), 'range should be Array in uploadDownloadResponder!', ); const { start, end } = range[0]; const respWidth = end - start + 1; const { content, mime } = await fetchUploadChunk( uploadID, secret, start, respWidth, ); const respRange = `${start}-${end}/${totalUploadSize}`; const respHeaders: { [key: string]: string } = { 'Accept-Ranges': 'bytes', 'Content-Range': `bytes ${respRange}`, 'Content-Type': mime, 'Content-Length': respWidth.toString(), }; // HTTP 206 Partial Content // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/206 res.writeHead(206, respHeaders); const stream = new Readable(); stream.push(content); stream.push(null); stream.pipe(res); } } export const UploadDeletionRequestInputValidator: TInterface = tShape({ id: tID, }); async function uploadDeletionResponder( viewer: Viewer, { id }: UploadDeletionRequest, ): Promise { await deleteUpload(viewer, id); } export { multerProcessor, multimediaUploadResponder, uploadDownloadResponder, uploadDeletionResponder, uploadMediaMetadataResponder, }; diff --git a/keyserver/src/user/checks.js b/keyserver/src/user/checks.js index e68c64b2b..8fbe260b1 100644 --- a/keyserver/src/user/checks.js +++ b/keyserver/src/user/checks.js @@ -1,31 +1,35 @@ // @flow import { getCommConfig } from 'lib/utils/comm-config.js'; -export type UserCredentials = { +username: string, +password: string }; +export type UserCredentials = { + +username: string, + +password: string, + +usingIdentityCredentials: boolean, +}; async function ensureUserCredentials() { const userCredentials = await getCommConfig({ folder: 'secrets', name: 'user_credentials', }); if (!userCredentials) { console.warn( 'User credentials for the keyserver owner must be specified. They can be ' + 'specified either by setting the ' + '`COMM_JSONCONFIG_secrets_user_credentials` environmental variable, or by ' + 'setting a file at keyserver/secrets/user_credentials.json. The contents ' + 'should be a JSON blob that looks like this:\n' + '{\n' + ' "username": ,\n' + ' "password": \n' + '}\n', ); // Since we don't want to apply the migration until there are credentials; // throw the error and force keyserver to be configured next restart throw new Error('missing_keyserver_owner_credentials'); } } export { ensureUserCredentials }; diff --git a/keyserver/src/user/identity.js b/keyserver/src/user/identity.js index dffb01280..1004eacfe 100644 --- a/keyserver/src/user/identity.js +++ b/keyserver/src/user/identity.js @@ -1,47 +1,69 @@ // @flow import type { QueryResults } from 'mysql'; +import { getCommConfig } from 'lib/utils/comm-config.js'; +import { ashoatKeyserverID } from 'lib/utils/validation-utils.js'; + +import type { UserCredentials } from './checks'; import { SQL, dbQuery } from '../database/database.js'; const userIDMetadataKey = 'user_id'; const accessTokenMetadataKey = 'access_token'; // This information is minted when registering with identity service // Naming should reflect the rust-bindings: userId -> user_id export type IdentityInfo = { +userId: string, +accessToken: string }; async function fetchIdentityInfo(): Promise { const versionQuery = SQL` SELECT name, data FROM metadata WHERE name IN (${userIDMetadataKey}, ${accessTokenMetadataKey}) `; const [result] = await dbQuery(versionQuery); let userID, accessToken; for (const row of result) { if (row.name === userIDMetadataKey) { userID = row.data; } else if (row.name === accessTokenMetadataKey) { accessToken = row.data; } } if (!userID || !accessToken) { return null; } return { userId: userID, accessToken }; } +let cachedKeyserverID: ?string; + +async function thisKeyserverID(): Promise { + if (cachedKeyserverID) { + return cachedKeyserverID; + } + const config = await getCommConfig({ + folder: 'secrets', + name: 'user_credentials', + }); + if (!config?.usingIdentityCredentials) { + return ashoatKeyserverID; + } + const identityInfo = await fetchIdentityInfo(); + cachedKeyserverID = identityInfo?.userId ?? ashoatKeyserverID; + return cachedKeyserverID; +} + function saveIdentityInfo(userInfo: IdentityInfo): Promise { const updateQuery = SQL` REPLACE INTO metadata (name, data) VALUES (${userIDMetadataKey}, ${userInfo.userId}), (${accessTokenMetadataKey}, ${userInfo.accessToken}) `; return dbQuery(updateQuery); } -export { fetchIdentityInfo, saveIdentityInfo }; +export { fetchIdentityInfo, thisKeyserverID, saveIdentityInfo }; diff --git a/keyserver/src/utils/validation-utils.js b/keyserver/src/utils/validation-utils.js index a0e052342..05aa49b9b 100644 --- a/keyserver/src/utils/validation-utils.js +++ b/keyserver/src/utils/validation-utils.js @@ -1,235 +1,235 @@ // @flow import type { TType } from 'tcomb'; import type { PolicyType } from 'lib/facts/policies.js'; import { hasMinCodeVersion, hasMinStateVersion, } from 'lib/shared/version-utils.js'; import { type PlatformDetails } from 'lib/types/device-types.js'; import { convertClientIDsToServerIDs, convertObject, convertServerIDsToClientIDs, } from 'lib/utils/conversion-utils.js'; import { ServerError } from 'lib/utils/errors.js'; import { tCookie, tPassword, tPlatform, tPlatformDetails, assertWithValidator, - ashoatKeyserverID, } from 'lib/utils/validation-utils.js'; import { fetchNotAcknowledgedPolicies } from '../fetchers/policy-acknowledgment-fetchers.js'; import { verifyClientSupported } from '../session/version.js'; import type { Viewer } from '../session/viewer.js'; +import { thisKeyserverID } from '../user/identity.js'; async function validateInput( viewer: Viewer, inputValidator: TType, input: mixed, ): Promise { if (!viewer.isSocket) { await checkClientSupported(viewer, inputValidator, input); } const convertedInput = checkInputValidator(inputValidator, input); + const keyserverID = await thisKeyserverID(); + if ( hasMinStateVersion(viewer.platformDetails, { native: 43, web: 3, }) ) { try { return convertClientIDsToServerIDs( - ashoatKeyserverID, + keyserverID, inputValidator, convertedInput, ); } catch (err) { throw new ServerError(err.message); } } return convertedInput; } -function validateOutput( +async function validateOutput( platformDetails: ?PlatformDetails, outputValidator: TType, data: T, -): T { +): Promise { if (!outputValidator.is(data)) { console.trace( 'Output validation failed, validator is:', outputValidator.displayName, ); return data; } + const keyserverID = await thisKeyserverID(); + if ( hasMinStateVersion(platformDetails, { native: 43, web: 3, }) ) { - return convertServerIDsToClientIDs( - ashoatKeyserverID, - outputValidator, - data, - ); + return convertServerIDsToClientIDs(keyserverID, outputValidator, data); } return data; } function checkInputValidator(inputValidator: TType, input: mixed): T { if (inputValidator.is(input)) { return assertWithValidator(input, inputValidator); } const error = new ServerError('invalid_parameters'); error.sanitizedInput = input ? sanitizeInput(inputValidator, input) : null; throw error; } async function checkClientSupported( viewer: Viewer, inputValidator: ?TType, input: T, ) { let platformDetails; if (inputValidator) { platformDetails = findFirstInputMatchingValidator( inputValidator, tPlatformDetails, input, ); } if (!platformDetails && inputValidator) { const platform = findFirstInputMatchingValidator( inputValidator, tPlatform, input, ); if (platform) { platformDetails = { platform }; } } if (!platformDetails) { ({ platformDetails } = viewer); } await verifyClientSupported(viewer, platformDetails); } const redactedString = '********'; const redactedTypes = [tPassword, tCookie]; function sanitizeInput(inputValidator: TType, input: T): T { return convertObject( inputValidator, input, redactedTypes, () => redactedString, ); } function findFirstInputMatchingValidator( wholeInputValidator: *, inputValidatorToMatch: *, input: *, ): any { if (!wholeInputValidator || input === null || input === undefined) { return null; } if ( wholeInputValidator === inputValidatorToMatch && wholeInputValidator.is(input) ) { return input; } if (wholeInputValidator.meta.kind === 'maybe') { return findFirstInputMatchingValidator( wholeInputValidator.meta.type, inputValidatorToMatch, input, ); } if ( wholeInputValidator.meta.kind === 'interface' && typeof input === 'object' ) { for (const key in input) { const value = input[key]; const validator = wholeInputValidator.meta.props[key]; const innerResult = findFirstInputMatchingValidator( validator, inputValidatorToMatch, value, ); if (innerResult) { return innerResult; } } } if (wholeInputValidator.meta.kind === 'union') { for (const validator of wholeInputValidator.meta.types) { if (validator.is(input)) { return findFirstInputMatchingValidator( validator, inputValidatorToMatch, input, ); } } } if (wholeInputValidator.meta.kind === 'list' && Array.isArray(input)) { const validator = wholeInputValidator.meta.type; for (const value of input) { const innerResult = findFirstInputMatchingValidator( validator, inputValidatorToMatch, value, ); if (innerResult) { return innerResult; } } } return null; } async function policiesValidator( viewer: Viewer, policies: $ReadOnlyArray, ) { if (!policies.length || !viewer.loggedIn) { return; } if (!hasMinCodeVersion(viewer.platformDetails, { native: 181 })) { return; } const notAcknowledgedPolicies = await fetchNotAcknowledgedPolicies( viewer.id, policies, ); if (notAcknowledgedPolicies.length) { throw new ServerError('policies_not_accepted', { notAcknowledgedPolicies, }); } } export { validateInput, validateOutput, checkInputValidator, redactedString, sanitizeInput, findFirstInputMatchingValidator, checkClientSupported, policiesValidator, };