diff --git a/lib/ops/keyserver-store-ops.js b/lib/ops/keyserver-store-ops.js
--- a/lib/ops/keyserver-store-ops.js
+++ b/lib/ops/keyserver-store-ops.js
@@ -130,7 +130,16 @@
   },
 };
 
+function getKeyserversToRemoveFromNotifsStore(
+  ops: $ReadOnlyArray<KeyserverStoreOperation>,
+): $ReadOnlyArray<string> {
+  return ops.flatMap(op =>
+    op.type === 'remove_keyservers' ? op.payload.ids : [],
+  );
+}
+
 export {
   keyserverStoreOpsHandlers,
   convertKeyserverInfoToClientDBKeyserverInfo,
+  getKeyserversToRemoveFromNotifsStore,
 };
diff --git a/lib/selectors/thread-selectors.js b/lib/selectors/thread-selectors.js
--- a/lib/selectors/thread-selectors.js
+++ b/lib/selectors/thread-selectors.js
@@ -18,6 +18,7 @@
 } from './calendar-filter-selectors.js';
 import { relativeMemberInfoSelectorForMembersOfThread } from './user-selectors.js';
 import genesis from '../facts/genesis.js';
+import { extractKeyserverIDFromID } from '../keyserver-conn/keyserver-call-utils.js';
 import {
   getAvatarForThread,
   getRandomDefaultEmojiAvatar,
@@ -288,6 +289,39 @@
     ).length,
 );
 
+const extendedUnreadCount: (state: BaseAppState<>) => {
+  +[keyserverID: string]: number,
+} = createSelector(
+  (state: BaseAppState<>) => state.threadStore.threadInfos,
+  (state: BaseAppState<>) => state.keyserverStore.keyserverInfos,
+  (
+    threadInfos: RawThreadInfos,
+    keyserverInfos: { +[keyserverID: string]: mixed },
+  ): { +[keyserverID: string]: number } => {
+    // Start every known keyserver at 0 so that one whose home-list threads
+    // have all disappeared still reports 0; otherwise its previous count
+    // would linger in notification storage and inflate the badge total the
+    // native notification handlers compute.
+    const unreadCounts: { [keyserverID: string]: number } = {};
+    for (const keyserverID in keyserverInfos) {
+      unreadCounts[keyserverID] = 0;
+    }
+
+    for (const threadInfo of values(threadInfos)) {
+      if (!threadInHomeChatList(threadInfo)) {
+        continue;
+      }
+      const keyserverID = extractKeyserverIDFromID(threadInfo.id);
+      const currentCount = unreadCounts[keyserverID] ?? 0;
+      unreadCounts[keyserverID] = threadInfo.currentUser.unread
+        ? currentCount + 1
+        : currentCount;
+    }
+
+    return unreadCounts;
+  },
+);
+
 const unreadBackgroundCount: (state: BaseAppState<>) => number = createSelector(
   (state: BaseAppState<>) => state.threadStore.threadInfos,
   (threadInfos: RawThreadInfos): number =>
@@ -547,6 +581,7 @@
   childThreadInfos,
   containedThreadInfos,
   unreadCount,
+  extendedUnreadCount,
   unreadBackgroundCount,
   unreadCountSelectorForCommunity,
   otherUsersButNoOtherAdmins,
diff --git a/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java b/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java
--- a/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java
+++ b/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java
@@ -18,6 +18,7 @@
 import app.comm.android.ExpoUtils;
 import app.comm.android.MainActivity;
 import app.comm.android.R;
+import app.comm.android.fbjni.CommMMKV;
 import app.comm.android.fbjni.CommSecureStore;
 import app.comm.android.fbjni.GlobalDBSingleton;
 import app.comm.android.fbjni.MessageOperationsUtilities;
@@ -45,8 +46,13 @@
   private static final String ENCRYPTION_FAILED_KEY = "encryptionFailed";
   private static final String GROUP_NOTIF_IDS_KEY = "groupNotifIDs";
   private static final String COLLAPSE_ID_KEY = "collapseKey";
+  private static final String KEYSERVER_ID_KEY = "keyserverID";
   private static final String CHANNEL_ID = "default";
   private static final long[] VIBRATION_SPEC = {500, 500};
+  // Introduced temporarily
+  private static final String ASHOAT_KEYSERVER_ID = "256";
+  private static final String MMKV_KEYSERVER_PREFIX = "KEYSERVER.";
+  private static final String MMKV_UNREAD_COUNT_SUFFIX = ".UNREAD_COUNT";
   private Bitmap displayableNotificationLargeIcon;
   private NotificationManager notificationManager;
   private LocalBroadcastManager localBroadcastManager;
@@ -108,18 +114,10 @@
       handleNotificationRescind(message);
     }
 
-    String badge = message.getData().get(BADGE_KEY);
-    if (badge != null) {
-      try {
-        int badgeCount = Integer.parseInt(badge);
-        if (badgeCount > 0) {
-          ShortcutBadger.applyCount(this, badgeCount);
-        } else {
-          ShortcutBadger.removeCount(this);
-        }
-      } catch (NumberFormatException e) {
-        Log.w("COMM", "Invalid badge count", e);
-      }
+    try {
+      handleUnreadCountUpdate(message);
+    } catch (Exception e) {
+      Log.w("COMM", "Unread count update failure.", e);
     }
 
     String badgeOnly = message.getData().get(BADGE_ONLY_KEY);
@@ -214,6 +212,55 @@
     }
   }
 
+  private void handleUnreadCountUpdate(RemoteMessage message) {
+    String badge = message.getData().get(BADGE_KEY);
+    if (badge == null) {
+      return;
+    }
+
+    String senderKeyserverID = message.getData().get(KEYSERVER_ID_KEY);
+    if (senderKeyserverID == null) {
+      senderKeyserverID = ASHOAT_KEYSERVER_ID;
+    }
+    String senderKeyserverUnreadCountKey =
+        MMKV_KEYSERVER_PREFIX + senderKeyserverID + MMKV_UNREAD_COUNT_SUFFIX;
+
+    int senderKeyserverUnreadCount;
+    try {
+      senderKeyserverUnreadCount = Integer.parseInt(badge);
+    } catch (NumberFormatException e) {
+      Log.w("COMM", "Invalid badge count", e);
+      return;
+    }
+    CommMMKV.setInt(senderKeyserverUnreadCountKey, senderKeyserverUnreadCount);
+
+    int totalUnreadCount = senderKeyserverUnreadCount;
+    String[] allKeys = CommMMKV.getAllKeys();
+    for (String key : allKeys) {
+      if (key.equals(senderKeyserverUnreadCountKey)) {
+        continue;
+      }
+
+      if (!key.startsWith(MMKV_KEYSERVER_PREFIX) ||
+          !key.endsWith(MMKV_UNREAD_COUNT_SUFFIX)) {
+        continue;
+      }
+
+      Integer unreadCount = CommMMKV.getInt(key, -1);
+      if (unreadCount == null) {
+        continue;
+      }
+
+      totalUnreadCount += unreadCount;
+    }
+
+    if (totalUnreadCount > 0) {
+      ShortcutBadger.applyCount(this, totalUnreadCount);
+    } else {
+      ShortcutBadger.removeCount(this);
+    }
+  }
+
   private void addToThreadGroupAndDisplay(
       String notificationID,
       NotificationCompat.Builder notificationBuilder,
diff --git a/native/ios/NotificationService/NotificationService.mm b/native/ios/NotificationService/NotificationService.mm
--- a/native/ios/NotificationService/NotificationService.mm
+++ b/native/ios/NotificationService/NotificationService.mm
@@ -1,4 +1,5 @@
 #import "NotificationService.h"
+#import "CommMMKV.h"
 #import "Logger.h"
 #import "NotificationsCryptoModule.h"
 #import "StaffUtils.h"
@@ -10,6 +11,7 @@
 NSString *const encryptedPayloadKey = @"encryptedPayload";
 NSString *const encryptionFailureKey = @"encryptionFailure";
 NSString *const collapseIDKey = @"collapseID";
+NSString *const keyserverIDKey = @"keyserverID";
 const std::string callingProcessName = "NSE";
 // The context for this constant can be found here:
 // https://linear.app/comm/issue/ENG-3074#comment-bd2f5e28
@@ -20,6 +22,8 @@
 CFStringRef newMessageInfosDarwinNotification =
     CFSTR("app.comm.darwin_new_message_infos");
 
+// Introduced temporarily
+const std::string ashoatKeyserverID = "256";
 
 // Implementation below was inspired by the
 // following discussion with Apple staff member:
@@ -140,7 +144,30 @@
         addObject:[NSString stringWithUTF8String:persistErrorMessage.c_str()]];
   }
 
-  // Step 3: (optional) rescind read notifications
+  // Step 3: Cumulative unread count calculation
+  if (content.badge) {
+    std::string unreadCountCalculationError;
+    try {
+      @try {
+        [self calculateTotalUnreadCountInPlace:content];
+      } @catch (NSException *e) {
+        unreadCountCalculationError =
+            "Obj-C exception: " + std::string([e.name UTF8String]) +
+            " during unread count calculation.";
+      }
+    } catch (const std::exception &e) {
+      unreadCountCalculationError = "C++ exception: " + std::string(e.what()) +
+          " during unread count calculation.";
+    }
+
+    if (unreadCountCalculationError.size()) {
+      [errorMessages
+          addObject:[NSString stringWithUTF8String:unreadCountCalculationError
+                                                       .c_str()]];
+    }
+  }
+
+  // Step 4: (optional) rescind read notifications
 
   // Message payload persistence is a higher priority task, so it has
   // to happen prior to potential notification center clearing.
@@ -172,7 +199,7 @@
     publicUserContent = [[UNNotificationContent alloc] init];
   }
 
-  // Step 4: (optional) execute notification coalescing
+  // Step 5: (optional) execute notification coalescing
   if ([self isCollapsible:content.userInfo]) {
     std::string coalescingErrorMessage;
     try {
@@ -202,7 +229,7 @@
     }
   }
 
-  // Step 5: (optional) create empty notification that
+  // Step 6: (optional) create empty notification that
   // only provides badge count.
   if ([self needsSilentBadgeUpdate:content.userInfo]) {
     UNMutableNotificationContent *badgeOnlyContent =
@@ -211,7 +238,7 @@
     publicUserContent = badgeOnlyContent;
   }
 
-  // Step 5: notify main app that there is data
+  // Step 7: notify main app that there is data
   // to transfer to SQLite and redux.
   [self sendNewMessageInfosNotification];
 
@@ -415,6 +442,49 @@
       [payload[backgroundNotificationTypeKey] isEqualToString:@"CLEAR"];
 }
 
+- (void)calculateTotalUnreadCountInPlace:
+    (UNMutableNotificationContent *)content {
+  std::string senderKeyserverID = ashoatKeyserverID;
+  if (content.userInfo[keyserverIDKey]) {
+    senderKeyserverID =
+        std::string([content.userInfo[keyserverIDKey] UTF8String]);
+  }
+
+  static const std::string keyserverPrefix = "KEYSERVER.";
+  static const std::string unreadCountSuffix = ".UNREAD_COUNT";
+  std::string senderKeyserverUnreadCountKey =
+      keyserverPrefix + senderKeyserverID + unreadCountSuffix;
+
+  int senderKeyserverUnreadCount = [content.badge intValue];
+  comm::CommMMKV::setInt(
+      senderKeyserverUnreadCountKey, senderKeyserverUnreadCount);
+
+  int totalUnreadCount = senderKeyserverUnreadCount;
+  std::vector<std::string> allKeys = comm::CommMMKV::getAllKeys();
+  for (const auto &key : allKeys) {
+    if (key == senderKeyserverUnreadCountKey) {
+      continue;
+    }
+
+    if (key.size() < keyserverPrefix.size() + unreadCountSuffix.size() ||
+        key.compare(0, keyserverPrefix.size(), keyserverPrefix) ||
+        key.compare(
+            key.size() - unreadCountSuffix.size(),
+            unreadCountSuffix.size(),
+            unreadCountSuffix)) {
+      continue;
+    }
+
+    std::optional<int> unreadCount = comm::CommMMKV::getInt(key, -1);
+    if (!unreadCount.has_value()) {
+      continue;
+    }
+    totalUnreadCount += unreadCount.value();
+  }
+
+  content.badge = @(totalUnreadCount);
+}
+
 - (BOOL)needsSilentBadgeUpdate:(NSDictionary *)payload {
   // TODO: refactor this check by introducing
   // badgeOnly property in iOS notification payload
diff --git a/native/push/push-handler.react.js b/native/push/push-handler.react.js
--- a/native/push/push-handler.react.js
+++ b/native/push/push-handler.react.js
@@ -1,7 +1,6 @@
 // @flow
 
 import * as Haptics from 'expo-haptics';
-import invariant from 'invariant';
 import * as React from 'react';
 import { LogBox, Platform } from 'react-native';
 import { Notification as InAppNotification } from 'react-native-in-app-message';
@@ -17,20 +16,18 @@
 } from 'lib/actions/device-actions.js';
 import { saveMessagesActionType } from 'lib/actions/message-actions.js';
 import {
-  connectionSelector,
   deviceTokensSelector,
   updatesCurrentAsOfSelector,
 } from 'lib/selectors/keyserver-selectors.js';
 import {
   threadInfoSelector,
-  unreadCount,
+  extendedUnreadCount,
 } from 'lib/selectors/thread-selectors.js';
 import { isLoggedIn } from 'lib/selectors/user-selectors.js';
 import { mergePrefixIntoBody } from 'lib/shared/notif-utils.js';
 import type { RawMessageInfo } from 'lib/types/message-types.js';
 import type { ThreadInfo } from 'lib/types/minimally-encoded-thread-permissions-types.js';
 import type { Dispatch } from 'lib/types/redux-types.js';
-import { type ConnectionInfo } from 'lib/types/socket-types.js';
 import type { GlobalTheme } from 'lib/types/theme-types.js';
 import {
   convertNonPendingIDToNewSchema,
@@ -78,6 +75,7 @@
   addLifecycleListener,
   getCurrentLifecycleState,
 } from '../lifecycle/lifecycle.js';
+import { commCoreModule } from '../native-modules.js';
 import { replaceWithThreadActionType } from '../navigation/action-types.js';
 import { activeMessageListSelector } from '../navigation/nav-selectors.js';
 import { NavContext } from '../navigation/navigation-context.js';
@@ -100,7 +98,7 @@
   // Navigation state
   +activeThread: ?string,
   // Redux state
-  +unreadCount: number,
+  +unreadCount: { +[keyserverID: string]: number },
   +deviceTokens: {
     +[keyserverID: string]: ?string,
   },
@@ -108,7 +106,6 @@
     +[id: string]: ThreadInfo,
   },
   +notifPermissionAlertInfo: NotifPermissionAlertInfo,
-  +connection: ConnectionInfo,
   +updatesCurrentAsOf: number,
   +activeTheme: ?GlobalTheme,
   +loggedIn: boolean,
@@ -211,9 +208,7 @@
       );
     }
 
-    if (this.props.connection.status === 'connected') {
-      this.updateBadgeCount();
-    }
+    void this.updateBadgeCount();
   }
 
   componentWillUnmount() {
@@ -268,13 +263,11 @@
       this.clearNotifsOfThread();
     }
 
-    if (
-      this.props.connection.status === 'connected' &&
-      (prevProps.connection.status !== 'connected' ||
-        this.props.unreadCount !== prevProps.unreadCount)
-    ) {
-      this.updateBadgeCount();
-    }
+    // extendedUnreadCount is memoized, so an unchanged reference means the
+    // counts are unchanged; skip the native badge and MMKV calls in that case.
+    if (this.props.unreadCount !== prevProps.unreadCount) {
+      void this.updateBadgeCount();
+    }
 
     for (const threadID of this.openThreadOnceReceived) {
       const threadInfo = this.props.threadInfos[threadID];
@@ -312,12 +305,41 @@
     }
   }
 
-  updateBadgeCount() {
-    const curUnreadCount = this.props.unreadCount;
+  async updateBadgeCount() {
+    const curUnreadCounts = this.props.unreadCount;
+
+    let totalUnreadCount = 0;
+    const notifStorageUnreadCounts: Array<{
+      +id: string,
+      unreadCount: number,
+    }> = [];
+
+    for (const keyserverID in curUnreadCounts) {
+      totalUnreadCount += curUnreadCounts[keyserverID];
+      notifStorageUnreadCounts.push({
+        id: keyserverID,
+        unreadCount: curUnreadCounts[keyserverID],
+      });
+    }
+
     if (Platform.OS === 'ios') {
-      CommIOSNotifications.setBadgesCount(curUnreadCount);
+      CommIOSNotifications.setBadgesCount(totalUnreadCount);
     } else if (Platform.OS === 'android') {
-      CommAndroidNotifications.setBadge(curUnreadCount);
+      CommAndroidNotifications.setBadge(totalUnreadCount);
+    }
+
+    try {
+      await commCoreModule.updateKeyserverDataInNotifStorage(
+        notifStorageUnreadCounts,
+      );
+    } catch (e) {
+      if (__DEV__) {
+        Alert.alert(
+          'MMKV error',
+          'Failed to update keyserver data in MMKV. ' + e.message,
+        );
+      }
+      console.log(e);
     }
   }
 
@@ -699,14 +721,12 @@
   React.memo<BaseProps>(function ConnectedPushHandler(props: BaseProps) {
     const navContext = React.useContext(NavContext);
     const activeThread = activeMessageListSelector(navContext);
-    const boundUnreadCount = useSelector(unreadCount);
+    const boundUnreadCount = useSelector(extendedUnreadCount);
     const deviceTokens = useSelector(deviceTokensSelector);
     const threadInfos = useSelector(threadInfoSelector);
     const notifPermissionAlertInfo = useSelector(
       state => state.notifPermissionAlertInfo,
     );
-    const connection = useSelector(connectionSelector(ashoatKeyserverID));
-    invariant(connection, 'keyserver missing from keyserverStore');
     const updatesCurrentAsOf = useSelector(
       updatesCurrentAsOfSelector(ashoatKeyserverID),
     );
@@ -726,7 +746,6 @@
         deviceTokens={deviceTokens}
         threadInfos={threadInfos}
         notifPermissionAlertInfo={notifPermissionAlertInfo}
-        connection={connection}
         updatesCurrentAsOf={updatesCurrentAsOf}
         activeTheme={activeTheme}
         loggedIn={loggedIn}
diff --git a/native/redux/redux-utils.js b/native/redux/redux-utils.js
--- a/native/redux/redux-utils.js
+++ b/native/redux/redux-utils.js
@@ -2,7 +2,10 @@
 
 import { useSelector as reactReduxUseSelector } from 'react-redux';
 
-import { keyserverStoreOpsHandlers } from 'lib/ops/keyserver-store-ops.js';
+import {
+  keyserverStoreOpsHandlers,
+  getKeyserversToRemoveFromNotifsStore,
+} from 'lib/ops/keyserver-store-ops.js';
 import { messageStoreOpsHandlers } from 'lib/ops/message-store-ops.js';
 import { reportStoreOpsHandlers } from 'lib/ops/report-store-ops.js';
 import { threadStoreOpsHandlers } from 'lib/ops/thread-store-ops.js';
@@ -42,6 +45,8 @@
     userStoreOpsHandlers.convertOpsToClientDBOps(userStoreOperations);
   const convertedKeyserverStoreOperations =
     keyserverStoreOpsHandlers.convertOpsToClientDBOps(keyserverStoreOperations);
+  const keyserversToRemoveFromNotifsStore =
+    getKeyserversToRemoveFromNotifsStore(keyserverStoreOperations);
 
   try {
     const promises = [];
@@ -83,6 +88,13 @@
         ),
       );
     }
+    if (keyserversToRemoveFromNotifsStore.length > 0) {
+      promises.push(
+        commCoreModule.removeKeyserverDataFromNotifStorage(
+          keyserversToRemoveFromNotifsStore,
+        ),
+      );
+    }
     await Promise.all(promises);
   } catch (e) {
     if (isTaskCancelledError(e)) {