diff --git a/keyserver/src/creators/message-creator.js b/keyserver/src/creators/message-creator.js --- a/keyserver/src/creators/message-creator.js +++ b/keyserver/src/creators/message-creator.js @@ -52,6 +52,7 @@ { +platform: string, +deviceToken: string, + +cookieID: string, +codeVersion: ?string, }, >, @@ -323,7 +324,7 @@ const time = earliestFocusedTimeConsideredExpired(); const visibleExtractString = `$.${threadPermissions.VISIBLE}.value`; const query = SQL` - SELECT m.user, m.thread, c.platform, c.device_token, c.versions, + SELECT m.user, m.thread, c.platform, c.device_token, c.versions, c.id, f.user AS focused_user `; query.append(subthreadSelects); @@ -349,6 +350,7 @@ const focusedUser = !!row.focused_user; const { platform } = row; const versions = JSON.parse(row.versions); + const cookieID = row.id; let thisUserInfo = perUserInfo.get(userID); if (!thisUserInfo) { thisUserInfo = { @@ -384,10 +386,11 @@ } } } - if (deviceToken) { + if (deviceToken && cookieID) { thisUserInfo.devices.set(deviceToken, { platform, deviceToken, + cookieID: cookieID.toString(), codeVersion: versions ? versions.codeVersion : null, }); } diff --git a/keyserver/src/push/rescind.js b/keyserver/src/push/rescind.js --- a/keyserver/src/push/rescind.js +++ b/keyserver/src/push/rescind.js @@ -58,9 +58,11 @@ row.unread_count, threadID, ); + const targetedNotifications = delivery.iosDeviceTokens.map( + deviceToken => ({ deviceToken, notification }), + ); deliveryPromises[id] = apnPush({ - notification, - deviceTokens: delivery.iosDeviceTokens, + targetedNotifications, platformDetails: { platform: 'ios' }, }); } else if (delivery.androidID) { @@ -84,9 +86,12 @@ threadID, codeVersion, ); - deliveryPromises[id] = apnPush({ + const targetedNotifications = deviceTokens.map(deviceToken => ({ + deviceToken, notification, - deviceTokens, + })); + deliveryPromises[id] = apnPush({ + targetedNotifications, platformDetails: { platform: 'ios', codeVersion }, }); } else if (delivery.deviceType === 'android') { diff --git a/keyserver/src/push/send.js b/keyserver/src/push/send.js --- a/keyserver/src/push/send.js +++ b/keyserver/src/push/send.js @@ -23,6 +23,7 @@ rawThreadInfoFromServerThreadInfo, threadInfoFromRawThreadInfo, } from 'lib/shared/thread-utils.js'; +import { FUTURE_CODE_VERSION } from 'lib/shared/version-utils.js'; import type { Platform, PlatformDetails } from 'lib/types/device-types.js'; import { messageTypes } from 'lib/types/message-types-enum.js'; import { @@ -41,6 +42,7 @@ import { promiseAll } from 'lib/utils/promises.js'; import { tID, tPlatformDetails, tShape } from 'lib/utils/validation-utils.js'; +import { prepareEncryptedIOSNotifications } from './crypto.js'; import { getAPNsNotificationTopic } from './providers.js'; import { rescindPushNotifs } from './rescind.js'; import { @@ -69,6 +71,7 @@ type Device = { +platform: Platform, +deviceToken: string, + +cookieID: string, +codeVersion: ?number, }; type PushUserInfo = { @@ -195,38 +198,43 @@ const iosVersionsToTokens = byPlatform.get('ios'); if (iosVersionsToTokens) { - for (const [codeVersion, deviceTokens] of iosVersionsToTokens) { + for (const [codeVersion, devices] of iosVersionsToTokens) { const platformDetails = { platform: 'ios', codeVersion }; const shimmedNewRawMessageInfos = shimUnsupportedRawMessageInfos( newRawMessageInfos, platformDetails, ); + const cookieIDs = devices.map(({ cookieID }) => cookieID); const deliveryPromise = (async () => { - const notification = await prepareAPNsNotification({ - notifTexts, - newRawMessageInfos: shimmedNewRawMessageInfos, - threadID: threadInfo.id, - collapseKey: notifInfo.collapseKey, - badgeOnly, - unreadCount: unreadCounts[userID], - platformDetails, - }); - return await sendAPNsNotification( - 'ios', - notification, - [...deviceTokens], + const notificationsArray = await prepareAPNsNotification( { - ...notificationInfo, - codeVersion, + notifTexts, + newRawMessageInfos: shimmedNewRawMessageInfos, + threadID: threadInfo.id, + collapseKey: notifInfo.collapseKey, + badgeOnly, + unreadCount: unreadCounts[userID], + platformDetails, }, + cookieIDs, + ); + const targetedNotifications = devices.map( + ({ deviceToken }, idx) => ({ + deviceToken, + notification: notificationsArray[idx], + }), ); + return await sendAPNsNotification('ios', targetedNotifications, { + ...notificationInfo, + codeVersion, + }); })(); deliveryPromises.push(deliveryPromise); } } const androidVersionsToTokens = byPlatform.get('android'); if (androidVersionsToTokens) { - for (const [codeVersion, deviceTokens] of androidVersionsToTokens) { + for (const [codeVersion, devices] of androidVersionsToTokens) { const platformDetails = { platform: 'android', codeVersion }; const shimmedNewRawMessageInfos = shimUnsupportedRawMessageInfos( newRawMessageInfos, @@ -243,21 +251,18 @@ platformDetails, dbID, }); - return await sendAndroidNotification( - notification, - [...deviceTokens], - { - ...notificationInfo, - codeVersion, - }, - ); + const deviceTokens = devices.map(({ deviceToken }) => deviceToken); + return await sendAndroidNotification(notification, deviceTokens, { + ...notificationInfo, + codeVersion, + }); })(); deliveryPromises.push(deliveryPromise); } } const webVersionsToTokens = byPlatform.get('web'); if (webVersionsToTokens) { - for (const [codeVersion, deviceTokens] of webVersionsToTokens) { + for (const [codeVersion, devices] of webVersionsToTokens) { const platformDetails = { platform: 'web', codeVersion }; const deliveryPromise = (async () => { const notification = await prepareWebNotification({ @@ -266,7 +271,8 @@ unreadCount: unreadCounts[userID], platformDetails, }); - return await sendWebNotification(notification, [...deviceTokens], { + const deviceTokens = devices.map(({ deviceToken }) => deviceToken); + return await sendWebNotification(notification, deviceTokens, { ...notificationInfo, codeVersion, }); @@ -276,38 +282,43 @@ } const macosVersionsToTokens = byPlatform.get('macos'); if (macosVersionsToTokens) { - for (const [codeVersion, deviceTokens] of macosVersionsToTokens) { + for (const [codeVersion, devices] of macosVersionsToTokens) { const platformDetails = { platform: 'macos', codeVersion }; const shimmedNewRawMessageInfos = shimUnsupportedRawMessageInfos( newRawMessageInfos, platformDetails, ); + const cookieIDs = devices.map(({ cookieID }) => cookieID); const deliveryPromise = (async () => { - const notification = await prepareAPNsNotification({ - notifTexts, - newRawMessageInfos: shimmedNewRawMessageInfos, - threadID: threadInfo.id, - collapseKey: notifInfo.collapseKey, - badgeOnly, - unreadCount: unreadCounts[userID], - platformDetails, - }); - return await sendAPNsNotification( - 'macos', - notification, - [...deviceTokens], + const notificationsArray = await prepareAPNsNotification( { - ...notificationInfo, - codeVersion, + notifTexts, + newRawMessageInfos: shimmedNewRawMessageInfos, + threadID: threadInfo.id, + collapseKey: notifInfo.collapseKey, + badgeOnly, + unreadCount: unreadCounts[userID], + platformDetails, }, + cookieIDs, + ); + const targetedNotifications = devices.map( + ({ deviceToken }, idx) => ({ + deviceToken, + notification: notificationsArray[idx], + }), ); + return await sendAPNsNotification('macos', targetedNotifications, { + ...notificationInfo, + codeVersion, + }); })(); deliveryPromises.push(deliveryPromise); } } const windowsVersionsToTokens = byPlatform.get('windows'); if (windowsVersionsToTokens) { - for (const [codeVersion, deviceTokens] of windowsVersionsToTokens) { + for (const [codeVersion, devices] of windowsVersionsToTokens) { const platformDetails = { platform: 'windows', codeVersion }; const deliveryPromise = (async () => { const notification = await prepareWNSNotification({ @@ -316,7 +327,8 @@ unreadCount: unreadCounts[userID], platformDetails, }); - return await sendWNSNotification(notification, [...deviceTokens], { + const deviceTokens = devices.map(({ deviceToken }) => deviceToken); + return await sendWNSNotification(notification, deviceTokens, { ...notificationInfo, codeVersion, }); @@ -586,8 +598,11 @@ } function getDevicesByPlatform( - devices: Device[], -): Map>> { + devices: $ReadOnlyArray, +): Map< + Platform, + Map>, +> { const byPlatform = new Map(); for (const device of devices) { let innerMap = byPlatform.get(device.platform); @@ -599,12 +614,16 @@ device.codeVersion !== null && device.codeVersion !== undefined ? device.codeVersion : -1; - let innerMostSet = innerMap.get(codeVersion); - if (!innerMostSet) { - innerMostSet = new Set(); - innerMap.set(codeVersion, innerMostSet); + let innerMostArray = innerMap.get(codeVersion); + if (!innerMostArray) { + innerMostArray = []; + innerMap.set(codeVersion, innerMostArray); } - innerMostSet.add(device.deviceToken); + + innerMostArray.push({ + cookieID: device.cookieID, + deviceToken: device.deviceToken, + }); } return byPlatform; } @@ -629,7 +648,8 @@ }); async function prepareAPNsNotification( inputData: APNsNotifInputData, -): Promise { + cookieIDs: $ReadOnlyArray, +): Promise> { const convertedData = validateOutput( inputData.platformDetails, apnsNotifInputDataValidator, @@ -645,6 +665,12 @@ platformDetails, } = convertedData; + const isTextNotification = newRawMessageInfos.every( + newRawMessageInfo => newRawMessageInfo.type === messageTypes.TEXT, + ); + const shouldBeEncrypted = + platformDetails.platform === 'ios' && !collapseKey && isTextNotification; + const uniqueID = uuidv4(); const notification = new apn.Notification(); notification.topic = getAPNsNotificationTopic(platformDetails); @@ -685,18 +711,42 @@ ...copyWithMessageInfos.payload, messageInfos, }; - if (copyWithMessageInfos.length() <= apnMaxNotificationPayloadByteSize) { - notification.payload.messageInfos = messageInfos; - return notification; - } - const notificationCopy = _cloneDeep(notification); - if (notificationCopy.length() > apnMaxNotificationPayloadByteSize) { - console.warn( - `${platformDetails.platform} notification ${uniqueID} ` + - `exceeds size limit, even with messageInfos omitted`, + + const evaluateAndSelectNotifPayload = (notif, notifWithMessageInfos) => { + const notifWithMessageInfosCopy = _cloneDeep(notifWithMessageInfos); + if ( + notifWithMessageInfosCopy.length() <= apnMaxNotificationPayloadByteSize + ) { + return notifWithMessageInfos; + } + const notifCopy = _cloneDeep(notif); + if (notifCopy.length() > apnMaxNotificationPayloadByteSize) { + console.warn( + `${platformDetails.platform} notification ${uniqueID} ` + + `exceeds size limit, even with messageInfos omitted`, + ); + } + return notif; + }; + + if ( + shouldBeEncrypted && + platformDetails.codeVersion && + platformDetails.codeVersion > FUTURE_CODE_VERSION + ) { + const [notifications, notificationsWithMessageInfos] = await Promise.all([ + prepareEncryptedIOSNotifications(cookieIDs, notification), + prepareEncryptedIOSNotifications(cookieIDs, copyWithMessageInfos), + ]); + return notificationsWithMessageInfos.map((notif, idx) => + evaluateAndSelectNotifPayload(notifications[idx], notif), ); } - return notification; + const notificationToSend = evaluateAndSelectNotifPayload( + notification, + copyWithMessageInfos, + ); + return cookieIDs.map(() => notificationToSend); } type AndroidNotifInputData = { @@ -875,20 +925,35 @@ }; async function sendAPNsNotification( platform: 'ios' | 'macos', - notification: apn.Notification, - deviceTokens: $ReadOnlyArray, + targetedNotifications: $ReadOnlyArray<{ + +notification: apn.Notification, + +deviceToken: string, + }>, notificationInfo: NotificationInfo, ): Promise { const { source, codeVersion } = notificationInfo; + const response = await apnPush({ - notification, - deviceTokens, + targetedNotifications, platformDetails: { platform, codeVersion }, }); + invariant( + new Set(targetedNotifications.map(({ notification }) => notification.id)) + .size === 1, + 'Encrypted versions of the same notification must share id value', + ); + const [ + { + notification: { id }, + }, + ] = targetedNotifications; + const deviceTokens = targetedNotifications.map( + ({ deviceToken }) => deviceToken, + ); const delivery: APNsDelivery = { source, deviceType: platform, - iosID: notification.id, + iosID: id, deviceTokens, codeVersion, }; @@ -1108,7 +1173,7 @@ const { userID } = viewer; const deviceTokenQuery = SQL` - SELECT platform, device_token, versions + SELECT platform, device_token, versions, id FROM cookies WHERE user = ${userID} AND device_token IS NOT NULL @@ -1125,6 +1190,7 @@ const devices = deviceTokenResult.map(row => ({ platform: row.platform, + cookieID: row.id, deviceToken: row.device_token, codeVersion: JSON.parse(row.versions)?.codeVersion, })); @@ -1134,7 +1200,7 @@ const iosVersionsToTokens = byPlatform.get('ios'); if (iosVersionsToTokens) { - for (const [codeVersion, deviceTokens] of iosVersionsToTokens) { + for (const [codeVersion, deviceInfos] of iosVersionsToTokens) { const notification = new apn.Notification(); notification.topic = getAPNsNotificationTopic({ platform: 'ios', @@ -1142,27 +1208,46 @@ }); notification.badge = unreadCount; notification.pushType = 'alert'; - deliveryPromises.push( - sendAPNsNotification('ios', notification, [...deviceTokens], { + const deliveryPromise = (async () => { + const cookieIDs = deviceInfos.map(({ cookieID }) => cookieID); + let notificationsArray; + if (codeVersion > FUTURE_CODE_VERSION) { + notificationsArray = await prepareEncryptedIOSNotifications( + cookieIDs, + notification, + ); + } else { + notificationsArray = cookieIDs.map(() => notification); + } + const targetedNotifications = deviceInfos.map( + ({ deviceToken }, idx) => ({ + deviceToken, + notification: notificationsArray[idx], + }), + ); + return await sendAPNsNotification('ios', targetedNotifications, { source, dbID, userID, codeVersion, - }), - ); + }); + })(); + + deliveryPromises.push(deliveryPromise); } } const androidVersionsToTokens = byPlatform.get('android'); if (androidVersionsToTokens) { - for (const [codeVersion, deviceTokens] of androidVersionsToTokens) { + for (const [codeVersion, deviceInfos] of androidVersionsToTokens) { const notificationData = codeVersion < 69 ? { badge: unreadCount.toString() } : { badge: unreadCount.toString(), badgeOnly: '1' }; const notification = { data: notificationData }; + const deviceTokens = deviceInfos.map(({ deviceToken }) => deviceToken); deliveryPromises.push( - sendAndroidNotification(notification, [...deviceTokens], { + sendAndroidNotification(notification, deviceTokens, { source, dbID, userID, @@ -1174,7 +1259,7 @@ const macosVersionsToTokens = byPlatform.get('macos'); if (macosVersionsToTokens) { - for (const [codeVersion, deviceTokens] of macosVersionsToTokens) { + for (const [codeVersion, deviceInfos] of macosVersionsToTokens) { const notification = new apn.Notification(); notification.topic = getAPNsNotificationTopic({ platform: 'macos', @@ -1182,8 +1267,12 @@ }); notification.badge = unreadCount; notification.pushType = 'alert'; + const targetedNotifications = deviceInfos.map(({ deviceToken }) => ({ + deviceToken, + notification, + })); deliveryPromises.push( - sendAPNsNotification('macos', notification, [...deviceTokens], { + sendAPNsNotification('macos', targetedNotifications, { source, dbID, userID, diff --git a/keyserver/src/push/utils.js b/keyserver/src/push/utils.js --- a/keyserver/src/push/utils.js +++ b/keyserver/src/push/utils.js @@ -46,12 +46,13 @@ +invalidTokens?: $ReadOnlyArray, }; async function apnPush({ - notification, - deviceTokens, + targetedNotifications, platformDetails, }: { - +notification: apn.Notification, - +deviceTokens: $ReadOnlyArray, + +targetedNotifications: $ReadOnlyArray<{ + +notification: apn.Notification, + +deviceToken: string, + }>, +platformDetails: PlatformDetails, }): Promise { const pushProfile = getAPNPushProfileForCodeVersion(platformDetails); @@ -61,10 +62,22 @@ return { success: true }; } invariant(apnProvider, `keyserver/secrets/${pushProfile}.json should exist`); - const result = await apnProvider.send(notification, deviceTokens); + + const results = await Promise.all( + targetedNotifications.map(({ notification, deviceToken }) => { + return apnProvider.send(notification, deviceToken); + }), + ); + + const mergedResults = { sent: [], failed: [] }; + for (const result of results) { + mergedResults.sent.push(...result.sent); + mergedResults.failed.push(...result.failed); + } + const errors = []; const invalidTokens = []; - for (const error of result.failed) { + for (const error of mergedResults.failed) { errors.push(error); /* eslint-disable eqeqeq */ if (