diff --git a/lib/ops/keyserver-store-ops.js b/lib/ops/keyserver-store-ops.js --- a/lib/ops/keyserver-store-ops.js +++ b/lib/ops/keyserver-store-ops.js @@ -130,7 +130,24 @@ }, }; +function getKeyserversToRemoveFromNotifsStore( + ops: $ReadOnlyArray, +): $ReadOnlyArray { + const idsToRemove: Set = new Set(); + for (const op of ops) { + if (op.type !== 'remove_keyservers') { + continue; + } + for (const id of op.payload.ids) { + idsToRemove.add(id); + } + } + + return [...idsToRemove]; +} + export { keyserverStoreOpsHandlers, convertKeyserverInfoToClientDBKeyserverInfo, + getKeyserversToRemoveFromNotifsStore, }; diff --git a/lib/selectors/keyserver-selectors.js b/lib/selectors/keyserver-selectors.js --- a/lib/selectors/keyserver-selectors.js +++ b/lib/selectors/keyserver-selectors.js @@ -1,5 +1,6 @@ // @flow +import _mapValues from 'lodash/fp/mapValues.js'; import _memoize from 'lodash/memoize.js'; import { createSelector } from 'reselect'; @@ -104,6 +105,13 @@ keyserverID: string, ) => (state: AppState) => ?ConnectionInfo = _memoize(baseConnectionSelector); +const allConnectionInfosSelector: (state: AppState) => { + +[keyserverID: string]: ?ConnectionInfo, +} = createSelector( + (state: AppState) => state.keyserverStore.keyserverInfos, + (infos: KeyserverInfos) => _mapValues(info => info.connection)(infos), +); + const baseLastCommunicatedPlatformDetailsSelector: ( keyserverID: string, ) => (state: AppState) => ?PlatformDetails = keyserverID => (state: AppState) => @@ -183,4 +191,5 @@ deviceTokenSelector, selectedKeyserversSelector, allUpdatesCurrentAsOfSelector, + allConnectionInfosSelector, }; diff --git a/lib/selectors/thread-selectors.js b/lib/selectors/thread-selectors.js --- a/lib/selectors/thread-selectors.js +++ b/lib/selectors/thread-selectors.js @@ -3,6 +3,7 @@ import _compact from 'lodash/fp/compact.js'; import _filter from 'lodash/fp/filter.js'; import _flow from 'lodash/fp/flow.js'; +import _groupBy from 'lodash/fp/groupBy.js'; import _map from 'lodash/fp/map.js'; import _mapValues from 'lodash/fp/mapValues.js'; import _orderBy from 'lodash/fp/orderBy.js'; @@ -18,6 +19,7 @@ } from './calendar-filter-selectors.js'; import { relativeMemberInfoSelectorForMembersOfThread } from './user-selectors.js'; import genesis from '../facts/genesis.js'; +import { extractKeyserverIDFromID } from '../keyserver-conn/keyserver-call-utils.js'; import { getAvatarForThread, getRandomDefaultEmojiAvatar, @@ -288,6 +290,31 @@ ).length, ); +const allUnreadCounts: (state: BaseAppState<>) => { + +[keyserverID: string]: number, +} = createSelector( + (state: BaseAppState<>) => state.threadStore.threadInfos, + (threadInfos: RawThreadInfos): { +[keyserverID: string]: number } => { + const keyserverToThreads = _groupBy(threadInfo => + extractKeyserverIDFromID(threadInfo.id), + )( + values(threadInfos).filter(threadInfo => + threadInHomeChatList(threadInfo), + ), + ); + + const keyserverUnreadCountPairs = Object.entries(keyserverToThreads).map( + ([keyserverID, keyserverThreadInfos]) => [ + keyserverID, + keyserverThreadInfos.filter(threadInfo => threadInfo.currentUser.unread) + .length, + ], + ); + + return Object.fromEntries(keyserverUnreadCountPairs); + }, +); + const unreadBackgroundCount: (state: BaseAppState<>) => number = createSelector( (state: BaseAppState<>) => state.threadStore.threadInfos, (threadInfos: RawThreadInfos): number => @@ -547,6 +574,7 @@ childThreadInfos, containedThreadInfos, unreadCount, + allUnreadCounts, unreadBackgroundCount, unreadCountSelectorForCommunity, otherUsersButNoOtherAdmins, diff --git a/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java b/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java --- a/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java +++ b/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java @@ -18,6 +18,7 @@ import app.comm.android.ExpoUtils; import app.comm.android.MainActivity; import app.comm.android.R; +import app.comm.android.fbjni.CommMMKV; import app.comm.android.fbjni.CommSecureStore; import app.comm.android.fbjni.GlobalDBSingleton; import app.comm.android.fbjni.MessageOperationsUtilities; @@ -45,8 +46,15 @@ private static final String ENCRYPTION_FAILED_KEY = "encryptionFailed"; private static final String GROUP_NOTIF_IDS_KEY = "groupNotifIDs"; private static final String COLLAPSE_ID_KEY = "collapseKey"; + private static final String KEYSERVER_ID_KEY = "keyserverID"; private static final String CHANNEL_ID = "default"; private static final long[] VIBRATION_SPEC = {500, 500}; + + // Those and future MMKV-related constants should match + // similar constants in NotificationService.mm + private static final String MMKV_KEY_SEPARATOR = "."; + private static final String MMKV_KEYSERVER_PREFIX = "KEYSERVER"; + private static final String MMKV_UNREAD_COUNT_SUFFIX = "UNREAD_COUNT"; private Bitmap displayableNotificationLargeIcon; private NotificationManager notificationManager; private LocalBroadcastManager localBroadcastManager; @@ -108,18 +116,10 @@ handleNotificationRescind(message); } - String badge = message.getData().get(BADGE_KEY); - if (badge != null) { - try { - int badgeCount = Integer.parseInt(badge); - if (badgeCount > 0) { - ShortcutBadger.applyCount(this, badgeCount); - } else { - ShortcutBadger.removeCount(this); - } - } catch (NumberFormatException e) { - Log.w("COMM", "Invalid badge count", e); - } + try { + handleUnreadCountUpdate(message); + } catch (Exception e) { + Log.w("COMM", "Unread count update failure.", e); } String badgeOnly = message.getData().get(BADGE_ONLY_KEY); @@ -214,6 +214,55 @@ } } + private void handleUnreadCountUpdate(RemoteMessage message) { + String badge = message.getData().get(BADGE_KEY); + if (badge == null) { + return; + } + + if (message.getData().get(KEYSERVER_ID_KEY) == null) { + throw new RuntimeException("Received badge update without keyserver ID."); + } + String senderKeyserverID = message.getData().get(KEYSERVER_ID_KEY); + String senderKeyserverUnreadCountKey = String.join( + MMKV_KEY_SEPARATOR, + MMKV_KEYSERVER_PREFIX, + senderKeyserverID, + MMKV_UNREAD_COUNT_SUFFIX); + + int senderKeyserverUnreadCount; + try { + senderKeyserverUnreadCount = Integer.parseInt(badge); + } catch (NumberFormatException e) { + Log.w("COMM", "Invalid badge count", e); + return; + } + CommMMKV.setInt(senderKeyserverUnreadCountKey, senderKeyserverUnreadCount); + + int totalUnreadCount = 0; + String[] allKeys = CommMMKV.getAllKeys(); + for (String key : allKeys) { + + if (!key.startsWith(MMKV_KEYSERVER_PREFIX) || + !key.endsWith(MMKV_UNREAD_COUNT_SUFFIX)) { + continue; + } + + Integer unreadCount = CommMMKV.getInt(key, -1); + if (unreadCount == null) { + continue; + } + + totalUnreadCount += unreadCount; + } + + if (totalUnreadCount > 0) { + ShortcutBadger.applyCount(this, totalUnreadCount); + } else { + ShortcutBadger.removeCount(this); + } + } + private void addToThreadGroupAndDisplay( String notificationID, NotificationCompat.Builder notificationBuilder, diff --git a/native/ios/NotificationService/NotificationService.mm b/native/ios/NotificationService/NotificationService.mm --- a/native/ios/NotificationService/NotificationService.mm +++ b/native/ios/NotificationService/NotificationService.mm @@ -1,15 +1,25 @@ #import "NotificationService.h" +#import "CommMMKV.h" #import "Logger.h" #import "NotificationsCryptoModule.h" #import "StaffUtils.h" #import "TemporaryMessageStorage.h" #import +#include +#include NSString *const backgroundNotificationTypeKey = @"backgroundNotifType"; NSString *const messageInfosKey = @"messageInfos"; NSString *const encryptedPayloadKey = @"encryptedPayload"; NSString *const encryptionFailureKey = @"encryptionFailure"; NSString *const collapseIDKey = @"collapseID"; +NSString *const keyserverIDKey = @"keyserverID"; + +// Those and future MMKV-related constants should match +// similar constants in CommNotificationsHandler.java +const std::string mmkvKeySeparator = "."; +const std::string mmkvKeyserverPrefix = "KEYSERVER"; +const std::string mmkvUnreadCountSuffix = "UNREAD_COUNT"; const std::string callingProcessName = "NSE"; // The context for this constant can be found here: // https://linear.app/comm/issue/ENG-3074#comment-bd2f5e28 @@ -39,6 +49,18 @@ return memory_usage; } +std::string joinStrings( + const std::string &separator, + const std::vector &array) { + std::ostringstream joinedStream; + std::copy( + array.begin(), + array.end(), + std::ostream_iterator(joinedStream, separator.c_str())); + std::string joined = joinedStream.str(); + return joined.empty() ? joined : joined.substr(0, joined.size() - 1); +} + @interface NotificationService () @property(strong) NSMutableDictionary *contentHandlers; @@ -140,7 +162,30 @@ addObject:[NSString stringWithUTF8String:persistErrorMessage.c_str()]]; } - // Step 3: (optional) rescind read notifications + // Step 3: Cumulative unread count calculation + if (content.badge) { + std::string unreadCountCalculationError; + try { + @try { + [self calculateTotalUnreadCountInPlace:content]; + } @catch (NSException *e) { + unreadCountCalculationError = + "Obj-C exception: " + std::string([e.name UTF8String]) + + " during unread count calculation."; + } + } catch (const std::exception &e) { + unreadCountCalculationError = "C++ exception: " + std::string(e.what()) + + " during unread count calculation."; + } + + if (unreadCountCalculationError.size()) { + [errorMessages + addObject:[NSString stringWithUTF8String:unreadCountCalculationError + .c_str()]]; + } + } + + // Step 4: (optional) rescind read notifications // Message payload persistence is a higher priority task, so it has // to happen prior to potential notification center clearing. @@ -172,7 +217,7 @@ publicUserContent = [[UNNotificationContent alloc] init]; } - // Step 4: (optional) execute notification coalescing + // Step 5: (optional) execute notification coalescing if ([self isCollapsible:content.userInfo]) { std::string coalescingErrorMessage; try { @@ -202,7 +247,7 @@ } } - // Step 5: (optional) create empty notification that + // Step 6: (optional) create empty notification that // only provides badge count. if ([self needsSilentBadgeUpdate:content.userInfo]) { UNMutableNotificationContent *badgeOnlyContent = @@ -211,7 +256,7 @@ publicUserContent = badgeOnlyContent; } - // Step 5: notify main app that there is data + // Step 7: notify main app that there is data // to transfer to SQLite and redux. [self sendNewMessageInfosNotification]; @@ -415,6 +460,45 @@ [payload[backgroundNotificationTypeKey] isEqualToString:@"CLEAR"]; } +- (void)calculateTotalUnreadCountInPlace: + (UNMutableNotificationContent *)content { + if (!content.userInfo[keyserverIDKey]) { + throw std::runtime_error("Received badge update without keyserver ID."); + } + std::string senderKeyserverID = + std::string([content.userInfo[keyserverIDKey] UTF8String]); + + std::string senderKeyserverUnreadCountKey = joinStrings( + mmkvKeySeparator, + {mmkvKeyserverPrefix, senderKeyserverID, mmkvUnreadCountSuffix}); + + int senderKeyserverUnreadCount = [content.badge intValue]; + comm::CommMMKV::setInt( + senderKeyserverUnreadCountKey, senderKeyserverUnreadCount); + + int totalUnreadCount = 0; + std::vector allKeys = comm::CommMMKV::getAllKeys(); + for (const auto &key : allKeys) { + if (key.size() < + mmkvKeyserverPrefix.size() + mmkvUnreadCountSuffix.size() || + key.compare(0, mmkvKeyserverPrefix.size(), mmkvKeyserverPrefix) || + key.compare( + key.size() - mmkvUnreadCountSuffix.size(), + mmkvUnreadCountSuffix.size(), + mmkvUnreadCountSuffix)) { + continue; + } + + std::optional unreadCount = comm::CommMMKV::getInt(key, -1); + if (!unreadCount.has_value()) { + continue; + } + totalUnreadCount += unreadCount.value(); + } + + content.badge = @(totalUnreadCount); +} + - (BOOL)needsSilentBadgeUpdate:(NSDictionary *)payload { // TODO: refactor this check by introducing // badgeOnly property in iOS notification payload diff --git a/native/push/push-handler.react.js b/native/push/push-handler.react.js --- a/native/push/push-handler.react.js +++ b/native/push/push-handler.react.js @@ -1,7 +1,7 @@ // @flow import * as Haptics from 'expo-haptics'; -import invariant from 'invariant'; +import _groupBy from 'lodash/fp/groupBy.js'; import * as React from 'react'; import { LogBox, Platform } from 'react-native'; import { Notification as InAppNotification } from 'react-native-in-app-message'; @@ -16,21 +16,22 @@ useSetDeviceTokenFanout, } from 'lib/actions/device-actions.js'; import { saveMessagesActionType } from 'lib/actions/message-actions.js'; +import { extractKeyserverIDFromID } from 'lib/keyserver-conn/keyserver-call-utils.js'; import { - connectionSelector, deviceTokensSelector, - updatesCurrentAsOfSelector, + allUpdatesCurrentAsOfSelector, + allConnectionInfosSelector, } from 'lib/selectors/keyserver-selectors.js'; import { threadInfoSelector, - unreadCount, + allUnreadCounts, } from 'lib/selectors/thread-selectors.js'; import { isLoggedIn } from 'lib/selectors/user-selectors.js'; import { mergePrefixIntoBody } from 'lib/shared/notif-utils.js'; import type { RawMessageInfo } from 'lib/types/message-types.js'; import type { ThreadInfo } from 'lib/types/minimally-encoded-thread-permissions-types.js'; import type { Dispatch } from 'lib/types/redux-types.js'; -import { type ConnectionInfo } from 'lib/types/socket-types.js'; +import type { ConnectionInfo } from 'lib/types/socket-types.js'; import type { GlobalTheme } from 'lib/types/theme-types.js'; import { convertNonPendingIDToNewSchema, @@ -78,6 +79,7 @@ addLifecycleListener, getCurrentLifecycleState, } from '../lifecycle/lifecycle.js'; +import { commCoreModule } from '../native-modules.js'; import { replaceWithThreadActionType } from '../navigation/action-types.js'; import { activeMessageListSelector } from '../navigation/nav-selectors.js'; import { NavContext } from '../navigation/navigation-context.js'; @@ -100,7 +102,8 @@ // Navigation state +activeThread: ?string, // Redux state - +unreadCount: number, + +unreadCount: { +[keyserverID: string]: number }, + +connection: { +[keyserverID: string]: ?ConnectionInfo }, +deviceTokens: { +[keyserverID: string]: ?string, }, @@ -108,8 +111,9 @@ +[id: string]: ThreadInfo, }, +notifPermissionAlertInfo: NotifPermissionAlertInfo, - +connection: ConnectionInfo, - +updatesCurrentAsOf: number, + +allUpdatesCurrentAsOf: { + +[keyserverID: string]: number, + }, +activeTheme: ?GlobalTheme, +loggedIn: boolean, +navigateToThread: (params: MessageListParams) => void, @@ -211,9 +215,7 @@ ); } - if (this.props.connection.status === 'connected') { - this.updateBadgeCount(); - } + void this.updateBadgeCount(); } componentWillUnmount() { @@ -267,14 +269,7 @@ if (this.props.activeThread !== prevProps.activeThread) { this.clearNotifsOfThread(); } - - if ( - this.props.connection.status === 'connected' && - (prevProps.connection.status !== 'connected' || - this.props.unreadCount !== prevProps.unreadCount) - ) { - this.updateBadgeCount(); - } + void this.updateBadgeCount(); for (const threadID of this.openThreadOnceReceived) { const threadInfo = this.props.threadInfos[threadID]; @@ -300,7 +295,7 @@ if (!this.props.loggedIn && prevProps.loggedIn) { this.clearAllNotifs(); - this.resetBadgeCount(); + void this.resetBadgeCount(); } if ( @@ -312,16 +307,81 @@ } } - updateBadgeCount() { - const curUnreadCount = this.props.unreadCount; + async updateBadgeCount() { + const curUnreadCounts = this.props.unreadCount; + const curConnections = this.props.connection; + + const notifStorageUpdates: Array<{ + +id: string, + +unreadCount: number, + }> = []; + const notifsStorageQueries: Array = []; + + for (const keyserverID in curUnreadCounts) { + if (curConnections[keyserverID]?.status !== 'connected') { + notifsStorageQueries.push(keyserverID); + continue; + } + + notifStorageUpdates.push({ + id: keyserverID, + unreadCount: curUnreadCounts[keyserverID], + }); + } + + let queriedKeyserverData: $ReadOnlyArray<{ + +id: string, + +unreadCount: number, + }> = []; + + try { + [queriedKeyserverData] = await Promise.all([ + commCoreModule.getKeyserverDataFromNotifStorage(notifsStorageQueries), + commCoreModule.updateKeyserverDataInNotifStorage(notifStorageUpdates), + ]); + } catch (e) { + if (__DEV__) { + Alert.alert( + 'MMKV error', + 'Failed to update keyserver data in MMKV.' + e.message, + ); + } + console.log(e); + return; + } + + let totalUnreadCount = 0; + for (const keyserverData of notifStorageUpdates) { + totalUnreadCount += keyserverData.unreadCount; + } + for (const keyserverData of queriedKeyserverData) { + totalUnreadCount += keyserverData.unreadCount; + } + if (Platform.OS === 'ios') { - CommIOSNotifications.setBadgesCount(curUnreadCount); + CommIOSNotifications.setBadgesCount(totalUnreadCount); } else if (Platform.OS === 'android') { - CommAndroidNotifications.setBadge(curUnreadCount); + CommAndroidNotifications.setBadge(totalUnreadCount); } } - resetBadgeCount() { + async resetBadgeCount() { + const keyserversDataToRemove = Object.keys(this.props.unreadCount); + try { + await commCoreModule.removeKeyserverDataFromNotifStorage( + keyserversDataToRemove, + ); + } catch (e) { + if (__DEV__) { + Alert.alert( + 'MMKV error', + 'Failed to remove keyserver from MMKV.' + e.message, + ); + } + console.log(e); + return; + } + if (Platform.OS === 'ios') { CommIOSNotifications.setBadgesCount(0); } else if (Platform.OS === 'android') { @@ -547,11 +607,23 @@ if (!rawMessageInfos) { return; } - const { updatesCurrentAsOf } = this.props; - this.props.dispatch({ - type: saveMessagesActionType, - payload: { rawMessageInfos, updatesCurrentAsOf }, - }); + + const keyserverIDToMessageInfos = _groupBy(messageInfos => + extractKeyserverIDFromID(messageInfos.threadID), + )(rawMessageInfos); + + for (const keyserverID in keyserverIDToMessageInfos) { + const updatesCurrentAsOf = this.props.allUpdatesCurrentAsOf[keyserverID]; + const messageInfos = keyserverIDToMessageInfos[keyserverID]; + if (!updatesCurrentAsOf) { + continue; + } + + this.props.dispatch({ + type: saveMessagesActionType, + payload: { rawMessageInfos: messageInfos, updatesCurrentAsOf }, + }); + } } iosForegroundNotificationReceived = ( @@ -667,9 +739,12 @@ const { messageInfos } = parsedMessage; this.saveMessageInfos(messageInfos); + const keyserverID = extractKeyserverIDFromID(message.threadID); + const updateCurrentAsOf = this.props.allUpdatesCurrentAsOf[keyserverID]; + handleAndroidMessage( parsedMessage, - this.props.updatesCurrentAsOf, + updateCurrentAsOf, this.handleAndroidNotificationIfActive, ); }; @@ -699,19 +774,14 @@ React.memo(function ConnectedPushHandler(props: BaseProps) { const navContext = React.useContext(NavContext); const activeThread = activeMessageListSelector(navContext); - const boundUnreadCount = useSelector(unreadCount); + const boundUnreadCount = useSelector(allUnreadCounts); + const boundConnection = useSelector(allConnectionInfosSelector); const deviceTokens = useSelector(deviceTokensSelector); const threadInfos = useSelector(threadInfoSelector); const notifPermissionAlertInfo = useSelector( state => state.notifPermissionAlertInfo, ); - const connection = useSelector( - connectionSelector(authoritativeKeyserverID), - ); - invariant(connection, 'keyserver missing from keyserverStore'); - const updatesCurrentAsOf = useSelector( - updatesCurrentAsOfSelector(authoritativeKeyserverID), - ); + const allUpdatesCurrentAsOf = useSelector(allUpdatesCurrentAsOfSelector); const activeTheme = useSelector(state => state.globalThemeInfo.activeTheme); const loggedIn = useSelector(isLoggedIn); const navigateToThread = useNavigateToThread(); @@ -725,11 +795,11 @@ {...props} activeThread={activeThread} unreadCount={boundUnreadCount} + connection={boundConnection} deviceTokens={deviceTokens} threadInfos={threadInfos} notifPermissionAlertInfo={notifPermissionAlertInfo} - connection={connection} - updatesCurrentAsOf={updatesCurrentAsOf} + allUpdatesCurrentAsOf={allUpdatesCurrentAsOf} activeTheme={activeTheme} loggedIn={loggedIn} navigateToThread={navigateToThread} diff --git a/native/redux/redux-utils.js b/native/redux/redux-utils.js --- a/native/redux/redux-utils.js +++ b/native/redux/redux-utils.js @@ -3,7 +3,10 @@ import { useSelector as reactReduxUseSelector } from 'react-redux'; import { communityStoreOpsHandlers } from 'lib/ops/community-store-ops.js'; -import { keyserverStoreOpsHandlers } from 'lib/ops/keyserver-store-ops.js'; +import { + keyserverStoreOpsHandlers, + getKeyserversToRemoveFromNotifsStore, +} from 'lib/ops/keyserver-store-ops.js'; import { messageStoreOpsHandlers } from 'lib/ops/message-store-ops.js'; import { reportStoreOpsHandlers } from 'lib/ops/report-store-ops.js'; import { threadStoreOpsHandlers } from 'lib/ops/thread-store-ops.js'; @@ -46,6 +49,8 @@ keyserverStoreOpsHandlers.convertOpsToClientDBOps(keyserverStoreOperations); const convertedCommunityStoreOperations = communityStoreOpsHandlers.convertOpsToClientDBOps(communityStoreOperations); + const keyserversToRemoveFromNotifsStore = + getKeyserversToRemoveFromNotifsStore(keyserverStoreOperations); try { const promises = []; @@ -94,6 +99,13 @@ ), ); } + if (keyserversToRemoveFromNotifsStore.length > 0) { + promises.push( + commCoreModule.removeKeyserverDataFromNotifStorage( + keyserversToRemoveFromNotifsStore, + ), + ); + } await Promise.all(promises); } catch (e) { if (isTaskCancelledError(e)) {