diff --git a/native/android/app/src/cpp/NotificationsCryptoModuleJNIHelper.cpp b/native/android/app/src/cpp/NotificationsCryptoModuleJNIHelper.cpp index 0b9508637..daf9e7cdb 100644 --- a/native/android/app/src/cpp/NotificationsCryptoModuleJNIHelper.cpp +++ b/native/android/app/src/cpp/NotificationsCryptoModuleJNIHelper.cpp @@ -1,28 +1,40 @@ #include #include namespace comm { int NotificationsCryptoModuleJNIHelper::olmEncryptedTypeMessage( facebook::jni::alias_ref jThis) { return NotificationsCryptoModule::olmEncryptedTypeMessage; } std::string NotificationsCryptoModuleJNIHelper::decrypt( facebook::jni::alias_ref jThis, std::string keyserverID, std::string data, int messageType) { std::string decryptedData = NotificationsCryptoModule::decrypt(keyserverID, data, messageType); return decryptedData; } +std::string NotificationsCryptoModuleJNIHelper::peerDecrypt( + facebook::jni::alias_ref jThis, + std::string deviceID, + std::string data, + int messageType) { + std::string decryptedData = + NotificationsCryptoModule::peerDecrypt(deviceID, data, messageType); + return decryptedData; +} + void NotificationsCryptoModuleJNIHelper::registerNatives() { javaClassStatic()->registerNatives({ makeNativeMethod( "olmEncryptedTypeMessage", NotificationsCryptoModuleJNIHelper::olmEncryptedTypeMessage), makeNativeMethod("decrypt", NotificationsCryptoModuleJNIHelper::decrypt), + makeNativeMethod( + "peerDecrypt", NotificationsCryptoModuleJNIHelper::peerDecrypt), }); } } // namespace comm diff --git a/native/android/app/src/main/java/app/comm/android/fbjni/NotificationsCryptoModule.java b/native/android/app/src/main/java/app/comm/android/fbjni/NotificationsCryptoModule.java index ca84ee63e..02a17089f 100644 --- a/native/android/app/src/main/java/app/comm/android/fbjni/NotificationsCryptoModule.java +++ b/native/android/app/src/main/java/app/comm/android/fbjni/NotificationsCryptoModule.java @@ -1,7 +1,9 @@ package app.comm.android.fbjni; public class NotificationsCryptoModule { public static native int olmEncryptedTypeMessage(); public static native String decrypt(String keyserverID, String data, int messageType); + public static native String + peerDecrypt(String deviceID, String data, int messageType); } diff --git a/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java b/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java index 554b18532..dbed6462d 100644 --- a/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java +++ b/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java @@ -1,657 +1,673 @@ package app.comm.android.notifications; import android.app.Notification; import android.app.NotificationManager; import android.app.PendingIntent; import android.content.Context; import android.content.Intent; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.os.Bundle; import android.service.notification.StatusBarNotification; import android.util.JsonReader; import android.util.Log; import androidx.core.app.NotificationCompat; import androidx.lifecycle.Lifecycle; import androidx.lifecycle.ProcessLifecycleOwner; import androidx.localbroadcastmanager.content.LocalBroadcastManager; import app.comm.android.ExpoUtils; import app.comm.android.MainActivity; import app.comm.android.R; import app.comm.android.aescrypto.AESCryptoModuleCompat; import app.comm.android.commservices.CommAndroidServicesClient; import app.comm.android.fbjni.CommMMKV; import app.comm.android.fbjni.CommSecureStore; import app.comm.android.fbjni.GlobalDBSingleton; import app.comm.android.fbjni.MessageOperationsUtilities; import app.comm.android.fbjni.NetworkModule; import app.comm.android.fbjni.NotificationsCryptoModule; import app.comm.android.fbjni.StaffUtils; import app.comm.android.fbjni.ThreadOperations; import com.google.firebase.messaging.FirebaseMessagingService; import com.google.firebase.messaging.RemoteMessage; import java.io.File; import java.io.IOException; import java.lang.OutOfMemoryError; import java.lang.StringBuilder; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Map; import me.leolin.shortcutbadger.ShortcutBadger; import org.json.JSONException; import org.json.JSONObject; public class CommNotificationsHandler extends FirebaseMessagingService { private static final String BADGE_KEY = "badge"; private static final String BADGE_ONLY_KEY = "badgeOnly"; private static final String SET_UNREAD_STATUS_KEY = "setUnreadStatus"; private static final String NOTIF_ID_KEY = "id"; private static final String ENCRYPTED_PAYLOAD_KEY = "encryptedPayload"; private static final String ENCRYPTION_FAILED_KEY = "encryptionFailed"; private static final String BLOB_HASH_KEY = "blobHash"; private static final String BLOB_HOLDER_KEY = "blobHolder"; private static final String AES_ENCRYPTION_KEY_LABEL = "encryptionKey"; private static final String GROUP_NOTIF_IDS_KEY = "groupNotifIDs"; private static final String COLLAPSE_ID_KEY = "collapseKey"; private static final String KEYSERVER_ID_KEY = "keyserverID"; + private static final String SENDER_DEVICE_ID_KEY = "senderDeviceID"; + private static final String MESSAGE_TYPE_KEY = "type"; private static final String CHANNEL_ID = "default"; private static final long[] VIBRATION_SPEC = {500, 500}; private static final Map NOTIF_PRIORITY_VERBOSE = Map.of(0, "UNKNOWN", 1, "HIGH", 2, "NORMAL"); // Those and future MMKV-related constants should match // similar constants in NotificationService.mm private static final String MMKV_KEY_SEPARATOR = "."; private static final String MMKV_KEYSERVER_PREFIX = "KEYSERVER"; private static final String MMKV_UNREAD_COUNT_SUFFIX = "UNREAD_COUNT"; private Bitmap displayableNotificationLargeIcon; private NotificationManager notificationManager; private LocalBroadcastManager localBroadcastManager; private AESCryptoModuleCompat aesCryptoModule; public static final String RESCIND_KEY = "rescind"; public static final String RESCIND_ID_KEY = "rescindID"; public static final String TITLE_KEY = "title"; public static final String PREFIX_KEY = "prefix"; public static final String BODY_KEY = "body"; public static final String MESSAGE_INFOS_KEY = "messageInfos"; public static final String THREAD_ID_KEY = "threadID"; public static final String TOKEN_EVENT = "TOKEN_EVENT"; public static final String MESSAGE_EVENT = "MESSAGE_EVENT"; @Override public void onCreate() { super.onCreate(); CommSecureStore.getInstance().initialize( ExpoUtils.createExpoSecureStoreSupplier(this.getApplicationContext())); notificationManager = (NotificationManager)this.getSystemService( Context.NOTIFICATION_SERVICE); localBroadcastManager = LocalBroadcastManager.getInstance(this); displayableNotificationLargeIcon = BitmapFactory.decodeResource( this.getApplicationContext().getResources(), R.mipmap.ic_launcher); aesCryptoModule = new AESCryptoModuleCompat(); } @Override public void onNewToken(String token) { Intent intent = new Intent(TOKEN_EVENT); intent.putExtra("token", token); localBroadcastManager.sendBroadcast(intent); } @Override public void onMessageReceived(RemoteMessage message) { handleAlteredNotificationPriority(message); if (StaffUtils.isStaffRelease() && - message.getData().get(KEYSERVER_ID_KEY) == null) { + message.getData().get(KEYSERVER_ID_KEY) == null && + message.getData().get(SENDER_DEVICE_ID_KEY) == null) { displayErrorMessageNotification( - "Received notification without keyserver ID.", + "Received notification without keyserver ID nor sender device ID", "Missing keyserver ID.", null); return; } - String senderKeyserverID = message.getData().get(KEYSERVER_ID_KEY); - if (message.getData().get(ENCRYPTED_PAYLOAD_KEY) != null) { try { - message = this.olmDecryptRemoteMessage(message, senderKeyserverID); + message = this.olmDecryptRemoteMessage(message); } catch (JSONException e) { Log.w("COMM", "Malformed notification JSON payload.", e); return; } catch (IllegalStateException e) { Log.w("COMM", "Android notification type violation.", e); return; } catch (Exception e) { Log.w("COMM", "Notification decryption failure.", e); return; } } if (StaffUtils.isStaffRelease() && "1".equals(message.getData().get(ENCRYPTION_FAILED_KEY))) { displayErrorMessageNotification( "Notification encryption failed on the keyserver. Please investigate", "Unencrypted notification", null); } + if ("1".equals(message.getData().get(ENCRYPTION_FAILED_KEY))) { Log.w("COMM", "Received erroneously unencrypted notification."); } String rescind = message.getData().get(RESCIND_KEY); if ("true".equals(rescind) && android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.M) { handleNotificationRescind(message); } try { handleUnreadCountUpdate(message); } catch (Exception e) { Log.w("COMM", "Unread count update failure.", e); } String badgeOnly = message.getData().get(BADGE_ONLY_KEY); if ("1".equals(badgeOnly)) { return; } if (message.getData().get(MESSAGE_INFOS_KEY) != null) { handleMessageInfosPersistence(message); } if (message.getData().get(BLOB_HASH_KEY) != null && message.getData().get(AES_ENCRYPTION_KEY_LABEL) != null && message.getData().get(BLOB_HOLDER_KEY) != null) { handleLargeNotification(message); } Intent intent = new Intent(MESSAGE_EVENT); intent.putExtra( "message", serializeMessageDataForIntentAttachment(message)); localBroadcastManager.sendBroadcast(intent); if (this.isAppInForeground()) { return; } this.displayNotification(message); } private void handleAlteredNotificationPriority(RemoteMessage message) { if (!StaffUtils.isStaffRelease()) { return; } int originalPriority = message.getOriginalPriority(); int priority = message.getPriority(); String priorityName = NOTIF_PRIORITY_VERBOSE.get(priority); String originalPriorityName = NOTIF_PRIORITY_VERBOSE.get(originalPriority); if (priorityName == null || originalPriorityName == null) { // Technically this will never happen as // it would violate FCM documentation return; } if (priority != originalPriority) { displayErrorMessageNotification( "System changed notification priority from " + priorityName + " to " + originalPriorityName, "Notification priority altered.", null); } } private boolean isAppInForeground() { return ProcessLifecycleOwner.get().getLifecycle().getCurrentState() == Lifecycle.State.RESUMED; } private boolean notificationGroupingSupported() { // Comm doesn't support notification grouping for clients running // Android versions older than 23 return android.os.Build.VERSION.SDK_INT > android.os.Build.VERSION_CODES.M; } private void handleNotificationRescind(RemoteMessage message) { String setUnreadStatus = message.getData().get(SET_UNREAD_STATUS_KEY); String threadID = message.getData().get(THREAD_ID_KEY); if ("true".equals(setUnreadStatus)) { File sqliteFile = this.getApplicationContext().getDatabasePath("comm.sqlite"); if (sqliteFile.exists()) { GlobalDBSingleton.scheduleOrRun(() -> { ThreadOperations.updateSQLiteUnreadStatus( sqliteFile.getPath(), threadID, false); }); } else { Log.w( "COMM", "Database not existing yet. Skipping thread status update."); } } String rescindID = message.getData().get(RESCIND_ID_KEY); boolean groupSummaryPresent = false; boolean threadGroupPresent = false; for (StatusBarNotification notification : notificationManager.getActiveNotifications()) { String tag = notification.getTag(); boolean isGroupMember = threadID.equals(notification.getNotification().getGroup()); boolean isGroupSummary = (notification.getNotification().flags & Notification.FLAG_GROUP_SUMMARY) == Notification.FLAG_GROUP_SUMMARY; if (tag != null && tag.equals(rescindID)) { notificationManager.cancel(notification.getTag(), notification.getId()); } else if ( isGroupMember && isGroupSummary && StaffUtils.isStaffRelease()) { groupSummaryPresent = true; removeNotificationFromGroupSummary(threadID, rescindID, notification); } else if (isGroupMember && isGroupSummary) { groupSummaryPresent = true; } else if (isGroupMember) { threadGroupPresent = true; } else if (isGroupSummary && StaffUtils.isStaffRelease()) { checkForUnmatchedRescind(threadID, rescindID, notification); } } if (groupSummaryPresent && !threadGroupPresent) { notificationManager.cancel(threadID, threadID.hashCode()); } } private void handleUnreadCountUpdate(RemoteMessage message) { + if (message.getData().get(KEYSERVER_ID_KEY) == null) { + return; + } + String badge = message.getData().get(BADGE_KEY); if (badge == null) { return; } - if (message.getData().get(KEYSERVER_ID_KEY) == null) { - throw new RuntimeException("Received badge update without keyserver ID."); - } String senderKeyserverID = message.getData().get(KEYSERVER_ID_KEY); String senderKeyserverUnreadCountKey = String.join( MMKV_KEY_SEPARATOR, MMKV_KEYSERVER_PREFIX, senderKeyserverID, MMKV_UNREAD_COUNT_SUFFIX); int senderKeyserverUnreadCount; try { senderKeyserverUnreadCount = Integer.parseInt(badge); } catch (NumberFormatException e) { Log.w("COMM", "Invalid badge count", e); return; } CommMMKV.setInt(senderKeyserverUnreadCountKey, senderKeyserverUnreadCount); int totalUnreadCount = 0; String[] allKeys = CommMMKV.getAllKeys(); for (String key : allKeys) { if (!key.startsWith(MMKV_KEYSERVER_PREFIX) || !key.endsWith(MMKV_UNREAD_COUNT_SUFFIX)) { continue; } Integer unreadCount = CommMMKV.getInt(key, -1); if (unreadCount == null) { continue; } totalUnreadCount += unreadCount; } if (totalUnreadCount > 0) { ShortcutBadger.applyCount(this, totalUnreadCount); } else { ShortcutBadger.removeCount(this); } } private void handleMessageInfosPersistence(RemoteMessage message) { String rawMessageInfosString = message.getData().get(MESSAGE_INFOS_KEY); File sqliteFile = this.getApplicationContext().getDatabasePath("comm.sqlite"); if (rawMessageInfosString != null && sqliteFile.exists()) { GlobalDBSingleton.scheduleOrRun(() -> { MessageOperationsUtilities.storeMessageInfos( sqliteFile.getPath(), rawMessageInfosString); }); } else if (rawMessageInfosString != null) { Log.w("COMM", "Database not existing yet. Skipping notification"); } } private void handleLargeNotification(RemoteMessage message) { String blobHash = message.getData().get(BLOB_HASH_KEY); String blobHolder = message.getData().get(BLOB_HOLDER_KEY); try { byte[] largePayload = CommAndroidServicesClient.getInstance().getBlobSync(blobHash); message = aesDecryptRemoteMessage(message, largePayload); handleMessageInfosPersistence(message); } catch (Exception e) { Log.w("COMM", "Failure when handling large notification.", e); } CommAndroidServicesClient.getInstance().scheduleDeferredBlobDeletion( blobHash, blobHolder, this.getApplicationContext()); } private void addToThreadGroupAndDisplay( String notificationID, NotificationCompat.Builder notificationBuilder, String threadID) { notificationBuilder.setGroup(threadID).setGroupAlertBehavior( NotificationCompat.GROUP_ALERT_CHILDREN); NotificationCompat.Builder groupSummaryNotificationBuilder = new NotificationCompat.Builder(this.getApplicationContext()) .setChannelId(CHANNEL_ID) .setSmallIcon(R.drawable.notif_icon) .setContentIntent( this.createStartMainActivityAction(threadID, threadID)) .setGroup(threadID) .setGroupSummary(true) .setGroupAlertBehavior(NotificationCompat.GROUP_ALERT_CHILDREN); if (StaffUtils.isStaffRelease()) { ArrayList groupNotifIDs = recordNotificationInGroupSummary(threadID, notificationID); String notificationSummaryBody = "Notif IDs: " + String.join(System.lineSeparator(), groupNotifIDs); Bundle data = new Bundle(); data.putStringArrayList(GROUP_NOTIF_IDS_KEY, groupNotifIDs); groupSummaryNotificationBuilder .setContentTitle("Summary for thread id " + threadID) .setExtras(data) .setStyle(new NotificationCompat.BigTextStyle().bigText( notificationSummaryBody)) .setAutoCancel(false); } else { groupSummaryNotificationBuilder.setAutoCancel(true); } notificationManager.notify( notificationID, notificationID.hashCode(), notificationBuilder.build()); notificationManager.notify( threadID, threadID.hashCode(), groupSummaryNotificationBuilder.build()); } private void displayNotification(RemoteMessage message) { if (message.getData().get(RESCIND_KEY) != null) { // don't attempt to display rescinds return; } String id = message.getData().get(NOTIF_ID_KEY); String collapseKey = message.getData().get(COLLAPSE_ID_KEY); String notificationID = id; if (collapseKey != null) { notificationID = collapseKey; } String title = message.getData().get(TITLE_KEY); String prefix = message.getData().get(PREFIX_KEY); String body = message.getData().get(BODY_KEY); String threadID = message.getData().get(THREAD_ID_KEY); if (prefix != null) { body = prefix + " " + body; } Bundle data = new Bundle(); data.putString(THREAD_ID_KEY, threadID); NotificationCompat.Builder notificationBuilder = new NotificationCompat.Builder(this.getApplicationContext()) .setDefaults(Notification.DEFAULT_ALL) .setContentText(body) .setExtras(data) .setChannelId(CHANNEL_ID) .setVibrate(VIBRATION_SPEC) .setSmallIcon(R.drawable.notif_icon) .setLargeIcon(displayableNotificationLargeIcon) .setAutoCancel(true); if (title != null) { notificationBuilder.setContentTitle(title); } if (threadID != null) { notificationBuilder.setContentIntent( this.createStartMainActivityAction(id, threadID)); } if (!this.notificationGroupingSupported() || threadID == null) { notificationManager.notify( notificationID, notificationID.hashCode(), notificationBuilder.build()); return; } this.addToThreadGroupAndDisplay( notificationID, notificationBuilder, threadID); } private PendingIntent createStartMainActivityAction(String notificationID, String threadID) { Intent intent = new Intent(this.getApplicationContext(), MainActivity.class); intent.addFlags(Intent.FLAG_ACTIVITY_SINGLE_TOP); intent.putExtra("threadID", threadID); return PendingIntent.getActivity( this.getApplicationContext(), notificationID.hashCode(), intent, PendingIntent.FLAG_UPDATE_CURRENT | PendingIntent.FLAG_MUTABLE); } private RemoteMessage updateRemoteMessageWithDecryptedPayload( RemoteMessage message, String decryptedSerializedPayload) throws JSONException, IllegalStateException { JSONObject decryptedPayload = new JSONObject(decryptedSerializedPayload); ((Iterable)() -> decryptedPayload.keys()) .forEach(payloadFieldName -> { if (decryptedPayload.optJSONArray(payloadFieldName) != null || decryptedPayload.optJSONObject(payloadFieldName) != null) { throw new IllegalStateException( "Notification payload JSON is not {[string]: string} type."); } String payloadFieldValue = decryptedPayload.optString(payloadFieldName); message.getData().put(payloadFieldName, payloadFieldValue); }); return message; } - private RemoteMessage - olmDecryptRemoteMessage(RemoteMessage message, String senderKeyserverID) - throws JSONException, IllegalStateException { + private RemoteMessage olmDecryptRemoteMessage(RemoteMessage message) + throws JSONException, IllegalStateException, NumberFormatException { String encryptedSerializedPayload = message.getData().get(ENCRYPTED_PAYLOAD_KEY); - String decryptedSerializedPayload = NotificationsCryptoModule.decrypt( - senderKeyserverID, - encryptedSerializedPayload, - NotificationsCryptoModule.olmEncryptedTypeMessage()); + + String decryptedSerializedPayload; + if (message.getData().get(KEYSERVER_ID_KEY) != null) { + String senderKeyserverID = message.getData().get(KEYSERVER_ID_KEY); + decryptedSerializedPayload = NotificationsCryptoModule.decrypt( + senderKeyserverID, + encryptedSerializedPayload, + NotificationsCryptoModule.olmEncryptedTypeMessage()); + } else if (message.getData().get(SENDER_DEVICE_ID_KEY) != null) { + String senderDeviceID = message.getData().get(SENDER_DEVICE_ID_KEY); + String messageTypeString = message.getData().get(MESSAGE_TYPE_KEY); + int messageType = Integer.parseInt(messageTypeString); + decryptedSerializedPayload = NotificationsCryptoModule.peerDecrypt( + senderDeviceID, encryptedSerializedPayload, messageType); + } else { + throw new RuntimeException( + "Received notification without keyserver ID nor sender device ID."); + } return updateRemoteMessageWithDecryptedPayload( message, decryptedSerializedPayload); } private RemoteMessage aesDecryptRemoteMessage(RemoteMessage message, byte[] blob) throws JSONException, IllegalStateException { String aesEncryptionKey = message.getData().get(AES_ENCRYPTION_KEY_LABEL); // On the keyserver AES key is generated as raw bytes // so to send it in JSON it is encoded to Base64 string. byte[] aesEncryptionKeyBytes = Base64.getDecoder().decode(aesEncryptionKey); // On the keyserver notification is a string so it is // first encoded into UTF8 bytes. Therefore bytes // obtained from blob decryption are correct UTF8 bytes. String decryptedSerializedPayload = new String( aesCryptoModule.decrypt(aesEncryptionKeyBytes, blob), StandardCharsets.UTF_8); return updateRemoteMessageWithDecryptedPayload( message, decryptedSerializedPayload); } private Bundle serializeMessageDataForIntentAttachment(RemoteMessage message) { Bundle bundle = new Bundle(); message.getData().forEach(bundle::putString); return bundle; } private void displayErrorMessageNotification( String errorMessage, String errorTitle, String largeErrorData) { NotificationCompat.Builder errorNotificationBuilder = new NotificationCompat.Builder(this.getApplicationContext()) .setDefaults(Notification.DEFAULT_ALL) .setChannelId(CHANNEL_ID) .setSmallIcon(R.drawable.notif_icon) .setLargeIcon(displayableNotificationLargeIcon); if (errorMessage != null) { errorNotificationBuilder.setContentText(errorMessage); } if (errorTitle != null) { errorNotificationBuilder.setContentTitle(errorTitle); } if (largeErrorData != null) { errorNotificationBuilder.setStyle( new NotificationCompat.BigTextStyle().bigText(largeErrorData)); } notificationManager.notify( errorMessage, errorMessage.hashCode(), errorNotificationBuilder.build()); } private boolean isGroupSummary(StatusBarNotification notification, String threadID) { boolean isAnySummary = (notification.getNotification().flags & Notification.FLAG_GROUP_SUMMARY) != 0; if (threadID == null) { return isAnySummary; } return isAnySummary && threadID.equals(notification.getNotification().getGroup()); } private ArrayList recordNotificationInGroupSummary(String threadID, String notificationID) { ArrayList groupNotifIDs = Arrays.stream(notificationManager.getActiveNotifications()) .filter(notif -> isGroupSummary(notif, threadID)) .findFirst() .map( notif -> notif.getNotification().extras.getStringArrayList( GROUP_NOTIF_IDS_KEY)) .orElse(new ArrayList<>()); groupNotifIDs.add(notificationID); return groupNotifIDs; } private void removeNotificationFromGroupSummary( String threadID, String notificationID, StatusBarNotification groupSummaryNotification) { ArrayList groupNotifIDs = groupSummaryNotification.getNotification().extras.getStringArrayList( GROUP_NOTIF_IDS_KEY); if (groupNotifIDs == null) { displayErrorMessageNotification( "Empty summary notif for thread ID " + threadID, "Empty Summary Notif", "Summary notification for thread ID " + threadID + " had empty body when rescinding " + notificationID); } boolean notificationRemoved = groupNotifIDs.removeIf(notifID -> notifID.equals(notificationID)); if (!notificationRemoved) { displayErrorMessageNotification( "Notif with ID " + notificationID + " not in " + threadID, "Unrecorded Notif", "Rescinded notification with id " + notificationID + " not found in group summary for thread id " + threadID); return; } String notificationSummaryBody = "Notif IDs: " + String.join(System.lineSeparator(), groupNotifIDs); Bundle data = new Bundle(); data.putStringArrayList(GROUP_NOTIF_IDS_KEY, groupNotifIDs); NotificationCompat.Builder groupSummaryNotificationBuilder = new NotificationCompat.Builder(this.getApplicationContext()) .setChannelId(CHANNEL_ID) .setSmallIcon(R.drawable.notif_icon) .setContentIntent( this.createStartMainActivityAction(threadID, threadID)) .setContentTitle("Summary for thread id " + threadID) .setExtras(data) .setStyle(new NotificationCompat.BigTextStyle().bigText( notificationSummaryBody)) .setGroup(threadID) .setGroupSummary(true) .setAutoCancel(false) .setGroupAlertBehavior(NotificationCompat.GROUP_ALERT_CHILDREN); notificationManager.notify( threadID, threadID.hashCode(), groupSummaryNotificationBuilder.build()); } private void checkForUnmatchedRescind( String threadID, String notificationID, StatusBarNotification anySummaryNotification) { ArrayList anyGroupNotifIDs = anySummaryNotification.getNotification().extras.getStringArrayList( GROUP_NOTIF_IDS_KEY); if (anyGroupNotifIDs == null) { return; } String groupID = anySummaryNotification.getNotification().getGroup(); for (String notifID : anyGroupNotifIDs) { if (!notificationID.equals(notifID)) { continue; } displayErrorMessageNotification( "Summary for thread id " + groupID + "has " + notifID, "Rescind Mismatch", "Summary notif for thread id " + groupID + " contains notif id " + notifID + " which was received in rescind with thread id " + threadID); } } } diff --git a/native/cpp/CommonCpp/CryptoTools/Session.cpp b/native/cpp/CommonCpp/CryptoTools/Session.cpp index f40d94cfc..3b39543ad 100644 --- a/native/cpp/CommonCpp/CryptoTools/Session.cpp +++ b/native/cpp/CommonCpp/CryptoTools/Session.cpp @@ -1,214 +1,216 @@ #include "Session.h" #include "PlatformSpecificTools.h" #include #include namespace comm { namespace crypto { OlmSession *Session::getOlmSession() { return reinterpret_cast(this->olmSessionBuffer.data()); } std::unique_ptr Session::createSessionAsInitializer( OlmAccount *account, std::uint8_t *ownerIdentityKeys, const OlmBuffer &idKeys, const OlmBuffer &preKeys, const OlmBuffer &preKeySignature, const std::optional &oneTimeKey) { std::unique_ptr session(new Session()); session->olmSessionBuffer.resize(::olm_session_size()); ::olm_session(session->olmSessionBuffer.data()); OlmBuffer randomBuffer; PlatformSpecificTools::generateSecureRandomBytes( randomBuffer, ::olm_create_outbound_session_random_length(session->getOlmSession())); if (oneTimeKey) { if (-1 == ::olm_create_outbound_session( session->getOlmSession(), account, idKeys.data() + ID_KEYS_PREFIX_OFFSET, KEYSIZE, idKeys.data() + SIGNING_KEYS_PREFIX_OFFSET, KEYSIZE, preKeys.data(), KEYSIZE, preKeySignature.data(), SIGNATURESIZE, oneTimeKey->data(), KEYSIZE, randomBuffer.data(), randomBuffer.size())) { throw std::runtime_error( "error createOutbound => " + std::string{::olm_session_last_error(session->getOlmSession())}); } return session; } if (-1 == ::olm_create_outbound_session_without_otk( session->getOlmSession(), account, idKeys.data() + ID_KEYS_PREFIX_OFFSET, KEYSIZE, idKeys.data() + SIGNING_KEYS_PREFIX_OFFSET, KEYSIZE, preKeys.data(), KEYSIZE, preKeySignature.data(), SIGNATURESIZE, randomBuffer.data(), randomBuffer.size())) { throw std::runtime_error( "error createOutbound => " + std::string{::olm_session_last_error(session->getOlmSession())}); } return session; } std::unique_ptr Session::createSessionAsResponder( OlmAccount *account, std::uint8_t *ownerIdentityKeys, const OlmBuffer &encryptedMessage, const OlmBuffer &idKeys) { std::unique_ptr session(new Session()); OlmBuffer tmpEncryptedMessage(encryptedMessage); session->olmSessionBuffer.resize(::olm_session_size()); ::olm_session(session->olmSessionBuffer.data()); if (-1 == - ::olm_create_inbound_session( + ::olm_create_inbound_session_from( session->getOlmSession(), account, + idKeys.data() + ID_KEYS_PREFIX_OFFSET, + KEYSIZE, tmpEncryptedMessage.data(), encryptedMessage.size())) { throw std::runtime_error( "error createInbound => " + std::string{::olm_session_last_error(session->getOlmSession())}); } if (-1 == ::olm_remove_one_time_keys(account, session->getOlmSession())) { throw std::runtime_error( "error createInbound (remove oneTimeKey) => " + std::string{::olm_session_last_error(session->getOlmSession())}); } return session; } OlmBuffer Session::storeAsB64(const std::string &secretKey) { size_t pickleLength = ::olm_pickle_session_length(this->getOlmSession()); OlmBuffer pickle(pickleLength); size_t res = ::olm_pickle_session( this->getOlmSession(), secretKey.data(), secretKey.size(), pickle.data(), pickleLength); if (pickleLength != res) { throw std::runtime_error("error pickleSession => ::olm_pickle_session"); } return pickle; } std::unique_ptr Session::restoreFromB64(const std::string &secretKey, OlmBuffer &b64) { std::unique_ptr session(new Session()); session->olmSessionBuffer.resize(::olm_session_size()); ::olm_session(session->olmSessionBuffer.data()); if (-1 == ::olm_unpickle_session( session->getOlmSession(), secretKey.data(), secretKey.size(), b64.data(), b64.size())) { throw std::runtime_error("error pickleSession => ::olm_unpickle_session"); } return session; } std::string Session::decrypt(EncryptedData &encryptedData) { OlmSession *session = this->getOlmSession(); OlmBuffer utilityBuffer(::olm_utility_size()); OlmUtility *olmUtility = ::olm_utility(utilityBuffer.data()); OlmBuffer messageHashBuffer(::olm_sha256_length(olmUtility)); ::olm_sha256( olmUtility, encryptedData.message.data(), encryptedData.message.size(), messageHashBuffer.data(), messageHashBuffer.size()); OlmBuffer tmpEncryptedMessage(encryptedData.message); size_t maxSize = ::olm_decrypt_max_plaintext_length( session, encryptedData.messageType, tmpEncryptedMessage.data(), tmpEncryptedMessage.size()); if (maxSize == -1) { throw std::runtime_error{ "error decrypt_max_plaintext_length => " + std::string{::olm_session_last_error(session)} + ". Hash: " + std::string{messageHashBuffer.begin(), messageHashBuffer.end()}}; } OlmBuffer decryptedMessage(maxSize); size_t decryptedSize = ::olm_decrypt( session, encryptedData.messageType, encryptedData.message.data(), encryptedData.message.size(), decryptedMessage.data(), decryptedMessage.size()); if (decryptedSize == -1) { throw std::runtime_error{ "error decrypt => " + std::string{::olm_session_last_error(session)} + ". Hash: " + std::string{messageHashBuffer.begin(), messageHashBuffer.end()}}; } return std::string{(char *)decryptedMessage.data(), decryptedSize}; } EncryptedData Session::encrypt(const std::string &content) { OlmSession *session = this->getOlmSession(); OlmBuffer encryptedMessage( ::olm_encrypt_message_length(session, content.size())); OlmBuffer messageRandom; PlatformSpecificTools::generateSecureRandomBytes( messageRandom, ::olm_encrypt_random_length(session)); size_t messageType = ::olm_encrypt_message_type(session); if (-1 == ::olm_encrypt( session, (uint8_t *)content.data(), content.size(), messageRandom.data(), messageRandom.size(), encryptedMessage.data(), encryptedMessage.size())) { throw std::runtime_error{ "error encrypt => " + std::string{::olm_session_last_error(session)}}; } return {encryptedMessage, messageType}; } int Session::getVersion() { return this->version; } void Session::setVersion(int newVersion) { this->version = newVersion; } } // namespace crypto } // namespace comm diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp index 9f33742f1..8e5e78498 100644 --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp @@ -1,557 +1,745 @@ #include "NotificationsCryptoModule.h" #include "../../CryptoTools/Persist.h" #include "../../CryptoTools/Tools.h" #include "../../Tools/CommMMKV.h" #include "../../Tools/CommSecureStore.h" -#include "../../Tools/Logger.h" #include "../../Tools/PlatformSpecificTools.h" +#include "NotificationsInboundKeysProvider.h" +#include "olm/session.hh" #include "Logger.h" #include #include #include #include #include #include #include #include #include namespace comm { const std::string NotificationsCryptoModule::secureStoreNotificationsAccountDataKey = "notificationsCryptoAccountDataKey"; const std::string NotificationsCryptoModule::notificationsCryptoAccountID = "notificationsCryptoAccountDataID"; const std::string NotificationsCryptoModule::keyserverHostedNotificationsID = "keyserverHostedNotificationsID"; const std::string NotificationsCryptoModule::initialEncryptedMessageContent = "{\"type\": \"init\"}"; const int NotificationsCryptoModule::olmEncryptedTypeMessage = 1; // This constant is only used to migrate the existing notifications // session with production keyserver from flat file to MMKV. This // migration will fire when user updates the app. It will also fire // on dev env provided old keyserver set up is used. Developers willing // to use new keyserver set up must log out before installing updated // app version. Do not introduce new usages of this constant in the code!!! const std::string ashoatKeyserverIDUsedOnlyForMigrationFromLegacyNotifStorage = "256"; const int temporaryFilePathRandomSuffixLength = 32; const std::string notificationsAccountKey = "NOTIFS.ACCOUNT"; std::unique_ptr NotificationsCryptoModule::deserializeCryptoModule( const std::string &path, const std::string &picklingKey) { std::ifstream pickledPersistStream(path, std::ifstream::in); if (!pickledPersistStream.good()) { throw std::runtime_error( "Attempt to deserialize non-existing notifications crypto account"); } std::stringstream pickledPersistStringStream; pickledPersistStringStream << pickledPersistStream.rdbuf(); pickledPersistStream.close(); std::string pickledPersist = pickledPersistStringStream.str(); folly::dynamic persistJSON; try { persistJSON = folly::parseJson(pickledPersist); } catch (const folly::json::parse_error &e) { throw std::runtime_error( "Notifications crypto account JSON deserialization failed with " "reason: " + std::string(e.what())); } std::string accountString = persistJSON["account"].asString(); crypto::OlmBuffer account = std::vector(accountString.begin(), accountString.end()); std::unordered_map sessions; if (persistJSON["sessions"].isNull()) { return std::make_unique( notificationsCryptoAccountID, picklingKey, crypto::Persist({account, sessions})); } for (auto &sessionKeyValuePair : persistJSON["sessions"].items()) { std::string targetUserID = sessionKeyValuePair.first.asString(); std::string sessionData = sessionKeyValuePair.second.asString(); sessions[targetUserID] = { std::vector(sessionData.begin(), sessionData.end()), 1}; } return std::make_unique( notificationsCryptoAccountID, picklingKey, crypto::Persist({account, sessions})); } void NotificationsCryptoModule::serializeAndFlushCryptoModule( std::unique_ptr cryptoModule, const std::string &path, const std::string &picklingKey) { crypto::Persist persist = cryptoModule->storeAsB64(picklingKey); folly::dynamic sessions = folly::dynamic::object; for (auto &sessionKeyValuePair : persist.sessions) { std::string targetUserID = sessionKeyValuePair.first; crypto::OlmBuffer sessionData = sessionKeyValuePair.second.buffer; sessions[targetUserID] = std::string(sessionData.begin(), sessionData.end()); } std::string account = std::string(persist.account.begin(), persist.account.end()); folly::dynamic persistJSON = folly::dynamic::object("account", account)("sessions", sessions); std::string pickledPersist = folly::toJson(persistJSON); std::string temporaryFilePathRandomSuffix = crypto::Tools::generateRandomHexString( temporaryFilePathRandomSuffixLength); std::string temporaryPath = path + temporaryFilePathRandomSuffix; mode_t readWritePermissionsMode = 0666; int temporaryFD = open(temporaryPath.c_str(), O_CREAT | O_WRONLY, readWritePermissionsMode); if (temporaryFD == -1) { throw std::runtime_error( "Failed to create temporary file. Unable to atomically update " "notifications crypto account. Details: " + std::string(strerror(errno))); } ssize_t bytesWritten = write(temporaryFD, pickledPersist.c_str(), pickledPersist.length()); if (bytesWritten == -1 || bytesWritten != pickledPersist.length()) { remove(temporaryPath.c_str()); throw std::runtime_error( "Failed to write all data to temporary file. Unable to atomically " "update notifications crypto account. Details: " + std::string(strerror(errno))); } if (fsync(temporaryFD) == -1) { remove(temporaryPath.c_str()); throw std::runtime_error( "Failed to synchronize temporary file data with hardware storage. " "Unable to atomically update notifications crypto account. Details: " + std::string(strerror(errno))); }; close(temporaryFD); if (rename(temporaryPath.c_str(), path.c_str()) == -1) { remove(temporaryPath.c_str()); throw std::runtime_error( "Failed to replace temporary file content with notifications crypto " "account. Unable to atomically update notifications crypto account. " "Details: " + std::string(strerror(errno))); } remove(temporaryPath.c_str()); } std::string NotificationsCryptoModule::getKeyserverNotificationsSessionKey( const std::string &keyserverID) { return "KEYSERVER." + keyserverID + ".NOTIFS_SESSION"; } std::string NotificationsCryptoModule::getDeviceNotificationsSessionKey( const std::string &deviceID) { return "DEVICE." + deviceID + ".NOTIFS_SESSION"; } std::string NotificationsCryptoModule::serializeNotificationsSession( std::shared_ptr session, std::string picklingKey) { crypto::OlmBuffer pickledSessionBytes = session->storeAsB64(picklingKey); std::string pickledSession = std::string{pickledSessionBytes.begin(), pickledSessionBytes.end()}; folly::dynamic serializedSessionJson = folly::dynamic::object( "session", pickledSession)("picklingKey", picklingKey); return folly::toJson(serializedSessionJson); } std::pair, std::string> NotificationsCryptoModule::deserializeNotificationsSession( const std::string &serializedSession) { folly::dynamic serializedSessionJson; try { serializedSessionJson = folly::parseJson(serializedSession); } catch (const folly::json::parse_error &e) { throw std::runtime_error( "Notifications session deserialization failed with reason: " + std::string(e.what())); } std::string pickledSession = serializedSessionJson["session"].asString(); crypto::OlmBuffer pickledSessionBytes = crypto::OlmBuffer{pickledSession.begin(), pickledSession.end()}; std::string picklingKey = serializedSessionJson["picklingKey"].asString(); std::unique_ptr session = crypto::Session::restoreFromB64(picklingKey, pickledSessionBytes); return {std::move(session), picklingKey}; } void NotificationsCryptoModule::clearSensitiveData() { std::string notificationsCryptoAccountPath = PlatformSpecificTools::getNotificationsCryptoAccountPath(); if (remove(notificationsCryptoAccountPath.c_str()) == -1 && errno != ENOENT) { throw std::runtime_error( "Unable to remove notifications crypto account. Security requirements " "might be violated."); } } void NotificationsCryptoModule::persistNotificationsSessionInternal( bool isKeyserverSession, const std::string &senderID, const std::string &picklingKey, std::shared_ptr session) { std::string serializedSession = NotificationsCryptoModule::serializeNotificationsSession( session, picklingKey); std::string notificationsSessionKey; std::string persistenceErrorMessage; if (isKeyserverSession) { notificationsSessionKey = NotificationsCryptoModule::getKeyserverNotificationsSessionKey( senderID); persistenceErrorMessage = "Failed to persist to MMKV notifications session for keyserver: " + senderID; } else { notificationsSessionKey = NotificationsCryptoModule::getDeviceNotificationsSessionKey(senderID); persistenceErrorMessage = "Failed to persist to MMKV notifications session for device: " + senderID; } bool sessionStored = CommMMKV::setString(notificationsSessionKey, serializedSession); if (!sessionStored) { throw std::runtime_error(persistenceErrorMessage); } } std::optional, std::string>> NotificationsCryptoModule::fetchNotificationsSession( bool isKeyserverSession, const std::string &senderID) { std::string notificationsSessionKey; if (isKeyserverSession) { notificationsSessionKey = NotificationsCryptoModule::getKeyserverNotificationsSessionKey( senderID); } else { notificationsSessionKey = NotificationsCryptoModule::getDeviceNotificationsSessionKey(senderID); } std::optional serializedSession; try { serializedSession = CommMMKV::getString(notificationsSessionKey); } catch (const CommMMKV::InitFromNSEForbiddenError &e) { serializedSession = std::nullopt; } if (!serializedSession.has_value() && isKeyserverSession && senderID != ashoatKeyserverIDUsedOnlyForMigrationFromLegacyNotifStorage) { throw std::runtime_error( "Missing notifications session for keyserver: " + senderID); } else if (!serializedSession.has_value()) { return std::nullopt; } return NotificationsCryptoModule::deserializeNotificationsSession( serializedSession.value()); } void NotificationsCryptoModule::persistNotificationsAccount( const std::shared_ptr cryptoModule, const std::string &picklingKey) { crypto::Persist serializedCryptoModule = cryptoModule->storeAsB64(picklingKey); crypto::OlmBuffer serializedAccount = serializedCryptoModule.account; std::string serializedAccountString{ serializedAccount.begin(), serializedAccount.end()}; folly::dynamic serializedAccountObject = folly::dynamic::object( "account", serializedAccountString)("picklingKey", picklingKey); std::string serializedAccountJson = folly::toJson(serializedAccountObject); bool accountPersisted = CommMMKV::setString(notificationsAccountKey, serializedAccountJson); if (!accountPersisted) { throw std::runtime_error("Failed to persist notifications crypto account."); } } std::optional, std::string>> NotificationsCryptoModule::fetchNotificationsAccount() { std::optional serializedAccountJson; try { serializedAccountJson = CommMMKV::getString(notificationsAccountKey); } catch (const CommMMKV::InitFromNSEForbiddenError &e) { serializedAccountJson = std::nullopt; } if (!serializedAccountJson.has_value()) { return std::nullopt; } folly::dynamic serializedAccountObject; try { serializedAccountObject = folly::parseJson(serializedAccountJson.value()); } catch (const folly::json::parse_error &e) { throw std::runtime_error( "Notifications account deserialization failed with reason: " + std::string(e.what())); } std::string picklingKey = serializedAccountObject["picklingKey"].asString(); std::string accountString = serializedAccountObject["account"].asString(); crypto::OlmBuffer account = crypto::OlmBuffer{accountString.begin(), accountString.end()}; crypto::Persist serializedCryptoModule{account, {}}; std::shared_ptr cryptoModule = std::make_shared( notificationsCryptoAccountID, picklingKey, serializedCryptoModule); return {{cryptoModule, picklingKey}}; } void NotificationsCryptoModule::persistNotificationsSession( const std::string &keyserverID, std::shared_ptr keyserverNotificationsSession) { std::string picklingKey = crypto::Tools::generateRandomString(64); NotificationsCryptoModule::persistNotificationsSessionInternal( true, keyserverID, picklingKey, keyserverNotificationsSession); } void NotificationsCryptoModule::persistDeviceNotificationsSession( const std::string &deviceID, std::shared_ptr peerNotificationsSession) { std::string picklingKey = crypto::Tools::generateRandomString(64); NotificationsCryptoModule::persistNotificationsSessionInternal( false, deviceID, picklingKey, peerNotificationsSession); } bool NotificationsCryptoModule::isNotificationsSessionInitialized( const std::string &keyserverID) { std::string keyserverNotificationsSessionKey = getKeyserverNotificationsSessionKey(keyserverID); return CommMMKV::getString(keyserverNotificationsSessionKey).has_value(); } bool NotificationsCryptoModule::isDeviceNotificationsSessionInitialized( const std::string &deviceID) { std::string peerNotificationsSessionKey = getDeviceNotificationsSessionKey(deviceID); return CommMMKV::getString(peerNotificationsSessionKey).has_value(); } std::vector> NotificationsCryptoModule::isNotificationsSessionInitializedWithDevices( const std::vector &deviceIDs) { std::vector allKeys = CommMMKV::getAllKeys(); std::unordered_set allKeysSet(allKeys.begin(), allKeys.end()); std::vector> result; for (const auto &deviceID : deviceIDs) { std::string mmkvDeviceIDKey = NotificationsCryptoModule::getDeviceNotificationsSessionKey(deviceID); if (allKeysSet.find(mmkvDeviceIDKey) == allKeysSet.end()) { result.push_back({deviceID, false}); } else { result.push_back({deviceID, true}); } } return result; } // notifications account bool NotificationsCryptoModule::isNotificationsAccountInitialized() { return fetchNotificationsAccount().has_value(); } std::string NotificationsCryptoModule::getIdentityKeys() { auto cryptoModuleWithPicklingKey = NotificationsCryptoModule::fetchNotificationsAccount(); if (!cryptoModuleWithPicklingKey.has_value()) { throw std::runtime_error("Notifications crypto account not initialized."); } return cryptoModuleWithPicklingKey.value().first->getIdentityKeys(); } NotificationsCryptoModule::BaseStatefulDecryptResult::BaseStatefulDecryptResult( std::string picklingKey, std::string decryptedData) : picklingKey(picklingKey), decryptedData(decryptedData) { } std::string NotificationsCryptoModule::BaseStatefulDecryptResult::getDecryptedData() { return this->decryptedData; } NotificationsCryptoModule::StatefulDecryptResult::StatefulDecryptResult( std::unique_ptr session, std::string keyserverID, std::string picklingKey, std::string decryptedData) : NotificationsCryptoModule::BaseStatefulDecryptResult:: BaseStatefulDecryptResult(picklingKey, decryptedData), sessionState(std::move(session)), keyserverID(keyserverID) { } void NotificationsCryptoModule::StatefulDecryptResult::flushState() { NotificationsCryptoModule::persistNotificationsSessionInternal( true, this->keyserverID, this->picklingKey, std::move(this->sessionState)); } NotificationsCryptoModule::LegacyStatefulDecryptResult:: LegacyStatefulDecryptResult( std::unique_ptr cryptoModule, std::string path, std::string picklingKey, std::string decryptedData) : NotificationsCryptoModule::BaseStatefulDecryptResult:: BaseStatefulDecryptResult(picklingKey, decryptedData), path(path), cryptoModule(std::move(cryptoModule)) { } void NotificationsCryptoModule::LegacyStatefulDecryptResult::flushState() { std::shared_ptr legacyNotificationsSession = this->cryptoModule->getSessionByDeviceId(keyserverHostedNotificationsID); NotificationsCryptoModule::serializeAndFlushCryptoModule( std::move(this->cryptoModule), this->path, this->picklingKey); try { NotificationsCryptoModule::persistNotificationsSession( ashoatKeyserverIDUsedOnlyForMigrationFromLegacyNotifStorage, legacyNotificationsSession); } catch (const CommMMKV::InitFromNSEForbiddenError &e) { return; } } +NotificationsCryptoModule::StatefulPeerInitDecryptResult:: + StatefulPeerInitDecryptResult( + std::shared_ptr session, + std::shared_ptr account, + std::string sessionPicklingKey, + std::string accountPicklingKey, + std::string deviceID, + std::string decryptedData) + : BaseStatefulDecryptResult(sessionPicklingKey, decryptedData), + sessionState(session), + accountState(account), + accountPicklingKey(accountPicklingKey), + deviceID(deviceID) { +} + +void NotificationsCryptoModule::StatefulPeerInitDecryptResult::flushState() { + NotificationsCryptoModule::persistNotificationsSessionInternal( + false, this->deviceID, this->picklingKey, std::move(this->sessionState)); + NotificationsCryptoModule::persistNotificationsAccount( + std::move(this->accountState), this->accountPicklingKey); +} + +NotificationsCryptoModule::StatefulPeerDecryptResult::StatefulPeerDecryptResult( + std::unique_ptr session, + std::string deviceID, + std::string picklingKey, + std::string decryptedData) + : NotificationsCryptoModule::BaseStatefulDecryptResult:: + BaseStatefulDecryptResult(picklingKey, decryptedData), + sessionState(std::move(session)), + deviceID(deviceID) { +} + +void NotificationsCryptoModule::StatefulPeerDecryptResult::flushState() { + NotificationsCryptoModule::persistNotificationsSessionInternal( + false, this->deviceID, this->picklingKey, std::move(this->sessionState)); +} + +NotificationsCryptoModule::StatefulPeerConflictDecryptResult:: + StatefulPeerConflictDecryptResult( + std::string picklingKey, + std::string decryptedData) + : NotificationsCryptoModule::BaseStatefulDecryptResult( + picklingKey, + decryptedData) { +} +void NotificationsCryptoModule::StatefulPeerConflictDecryptResult:: + flushState() { + return; +} + std::unique_ptr NotificationsCryptoModule::prepareLegacyDecryptedState( const std::string &data, const size_t messageType) { folly::Optional picklingKey = comm::CommSecureStore::get( NotificationsCryptoModule::secureStoreNotificationsAccountDataKey); if (!picklingKey.hasValue()) { throw std::runtime_error( "Legacy notifications session pickling key missing."); } std::string legacyNotificationsAccountPath = comm::PlatformSpecificTools::getNotificationsCryptoAccountPath(); crypto::EncryptedData encryptedData{ std::vector(data.begin(), data.end()), messageType}; auto cryptoModule = NotificationsCryptoModule::deserializeCryptoModule( legacyNotificationsAccountPath, picklingKey.value()); std::string decryptedData = cryptoModule->decrypt( NotificationsCryptoModule::keyserverHostedNotificationsID, encryptedData); LegacyStatefulDecryptResult statefulDecryptResult( std::move(cryptoModule), legacyNotificationsAccountPath, picklingKey.value(), decryptedData); return std::make_unique( std::move(statefulDecryptResult)); } std::string NotificationsCryptoModule::decrypt( const std::string &keyserverID, const std::string &data, const size_t messageType) { auto sessionWithPicklingKey = NotificationsCryptoModule::fetchNotificationsSession(true, keyserverID); if (!sessionWithPicklingKey.has_value()) { auto statefulDecryptResult = NotificationsCryptoModule::prepareLegacyDecryptedState( data, messageType); statefulDecryptResult->flushState(); return statefulDecryptResult->getDecryptedData(); } std::unique_ptr session = std::move(sessionWithPicklingKey.value().first); std::string picklingKey = sessionWithPicklingKey.value().second; crypto::EncryptedData encryptedData{ std::vector(data.begin(), data.end()), messageType}; std::string decryptedData = session->decrypt(encryptedData); NotificationsCryptoModule::persistNotificationsSessionInternal( true, keyserverID, picklingKey, std::move(session)); return decryptedData; } crypto::EncryptedData NotificationsCryptoModule::encrypt( const std::string &deviceID, const std::string &payload) { auto sessionWithPicklingKey = NotificationsCryptoModule::fetchNotificationsSession(false, deviceID); if (!sessionWithPicklingKey.has_value()) { throw std::runtime_error( "Session with deviceID: " + deviceID + " not initialized."); } std::unique_ptr session = std::move(sessionWithPicklingKey.value().first); std::string picklingKey = sessionWithPicklingKey.value().second; crypto::EncryptedData encryptedData = session->encrypt(payload); NotificationsCryptoModule::persistNotificationsSessionInternal( false, deviceID, picklingKey, std::move(session)); return encryptedData; } std::unique_ptr NotificationsCryptoModule::statefulDecrypt( const std::string &keyserverID, const std::string &data, const size_t messageType) { auto sessionWithPicklingKey = NotificationsCryptoModule::fetchNotificationsSession(true, keyserverID); if (!sessionWithPicklingKey.has_value()) { return NotificationsCryptoModule::prepareLegacyDecryptedState( data, messageType); } std::unique_ptr session = std::move(sessionWithPicklingKey.value().first); std::string picklingKey = sessionWithPicklingKey.value().second; crypto::EncryptedData encryptedData{ std::vector(data.begin(), data.end()), messageType}; std::string decryptedData = session->decrypt(encryptedData); StatefulDecryptResult statefulDecryptResult( std::move(session), keyserverID, picklingKey, decryptedData); return std::make_unique( std::move(statefulDecryptResult)); } +std::unique_ptr +NotificationsCryptoModule::statefulPeerDecrypt( + const std::string &deviceID, + const std::string &data, + const size_t messageType) { + if (messageType != OLM_MESSAGE_TYPE_MESSAGE && + messageType != OLM_MESSAGE_TYPE_PRE_KEY) { + throw std::runtime_error( + "Received message of invalid type from device: " + deviceID); + } + + auto maybeSessionWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsSession(false, deviceID); + + if (!maybeSessionWithPicklingKey.has_value() && + messageType == OLM_MESSAGE_TYPE_MESSAGE) { + throw std::runtime_error( + "Received MESSAGE_TYPE_MESSAGE message from device: " + deviceID + + " but session not initialized."); + } + + crypto::EncryptedData encryptedData{ + std::vector(data.begin(), data.end()), messageType}; + + bool isSenderChainEmpty = true; + bool hasReceivedMessage = false; + bool sessionExists = maybeSessionWithPicklingKey.has_value(); + + if (sessionExists) { + ::olm::Session *olmSessionAsCppClass = reinterpret_cast<::olm::Session *>( + maybeSessionWithPicklingKey.value().first->getOlmSession()); + isSenderChainEmpty = olmSessionAsCppClass->ratchet.sender_chain.empty(); + hasReceivedMessage = olmSessionAsCppClass->received_message; + } + + // regular message + bool isRegularMessage = + sessionExists && messageType == OLM_MESSAGE_TYPE_MESSAGE; + + bool isRegularPrekeyMessage = sessionExists && + messageType == OLM_MESSAGE_TYPE_PRE_KEY && isSenderChainEmpty && + hasReceivedMessage; + + if (isRegularMessage || isRegularPrekeyMessage) { + std::string decryptedData = + maybeSessionWithPicklingKey.value().first->decrypt(encryptedData); + StatefulPeerDecryptResult decryptResult = StatefulPeerDecryptResult( + std::move(maybeSessionWithPicklingKey.value().first), + deviceID, + maybeSessionWithPicklingKey.value().second, + decryptedData); + return std::make_unique( + std::move(decryptResult)); + } + + // At this point we either face race condition or session reset attempt or + // session initialization attempt. For each of this scenario new inbound + // session must be created in order to decrypt message + std::string notifInboundKeys = + NotificationsInboundKeysProvider::getNotifsInboundKeysForDeviceID( + deviceID); + auto maybeAccountWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsAccount(); + + if (!maybeAccountWithPicklingKey.has_value()) { + throw std::runtime_error("Notifications account not initialized."); + } + + auto accountWithPicklingKey = maybeAccountWithPicklingKey.value(); + accountWithPicklingKey.first->initializeInboundForReceivingSession( + deviceID, + {data.begin(), data.end()}, + {notifInboundKeys.begin(), notifInboundKeys.end()}, + // The argument below is relevant for content only + 0, + true); + std::shared_ptr newInboundSession = + accountWithPicklingKey.first->getSessionByDeviceId(deviceID); + accountWithPicklingKey.first->removeSessionByDeviceId(deviceID); + std::string decryptedData = newInboundSession->decrypt(encryptedData); + + // session reset attempt or session initialization - handled the same + bool sessionResetAttempt = + sessionExists && !isSenderChainEmpty && hasReceivedMessage; + + // race condition + bool raceCondition = + sessionExists && !isSenderChainEmpty && !hasReceivedMessage; + + // device ID comparison + folly::Optional maybeOurDeviceID = + CommSecureStore::get(CommSecureStore::deviceID); + if (!maybeOurDeviceID.hasValue()) { + throw std::runtime_error("Session creation attempt but no device id"); + } + std::string ourDeviceID = maybeOurDeviceID.value(); + bool thisDeviceWinsRaceCondition = ourDeviceID > deviceID; + + // If there is no session or there is session reset attempt or + // there is a race condition but we loos device id comparison + // we end up creating new session as inbound + if (!sessionExists || sessionResetAttempt || + (raceCondition && !thisDeviceWinsRaceCondition)) { + std::string sessionPicklingKey = crypto::Tools::generateRandomString(64); + StatefulPeerInitDecryptResult decryptResult = StatefulPeerInitDecryptResult( + newInboundSession, + accountWithPicklingKey.first, + sessionPicklingKey, + accountWithPicklingKey.second, + deviceID, + decryptedData); + return std::make_unique( + std::move(decryptResult)); + } + + // If there is a race condition but we win device id comparison + // we return object that carries decrypted data but won't persist + // any session state + StatefulPeerConflictDecryptResult decryptResult = + StatefulPeerConflictDecryptResult( + maybeSessionWithPicklingKey.value().second, decryptedData); + return std::make_unique( + std::move(decryptResult)); +} + +std::string NotificationsCryptoModule::peerDecrypt( + const std::string &deviceID, + const std::string &data, + const size_t messageType) { + auto statefulDecryptResult = NotificationsCryptoModule::statefulPeerDecrypt( + deviceID, data, messageType); + std::string decryptedData = statefulDecryptResult->decryptedData; + statefulDecryptResult->flushState(); + return decryptedData; +} + void NotificationsCryptoModule::flushState( std::unique_ptr baseStatefulDecryptResult) { baseStatefulDecryptResult->flushState(); } } // namespace comm diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h index 4270cfccd..7db22105d 100644 --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h @@ -1,140 +1,193 @@ #pragma once #include "../../CryptoTools/CryptoModule.h" #include #include namespace comm { class NotificationsCryptoModule { const static std::string notificationsCryptoAccountID; // Used for handling of legacy notifications sessions const static std::string secureStoreNotificationsAccountDataKey; const static std::string keyserverHostedNotificationsID; static std::unique_ptr deserializeCryptoModule( const std::string &path, const std::string &picklingKey); static void serializeAndFlushCryptoModule( std::unique_ptr cryptoModule, const std::string &path, const std::string &picklingKey); static std::string getKeyserverNotificationsSessionKey(const std::string &keyserverID); static std::string getDeviceNotificationsSessionKey(const std::string &deviceID); static std::string serializeNotificationsSession( std::shared_ptr session, std::string picklingKey); static std::pair, std::string> deserializeNotificationsSession(const std::string &serializedSession); static void persistNotificationsSessionInternal( bool isKeyserverSession, const std::string &senderID, const std::string &picklingKey, std::shared_ptr session); static std::optional, std::string>> fetchNotificationsSession( bool isKeyserverSession, const std::string &senderID); public: const static std::string initialEncryptedMessageContent; const static int olmEncryptedTypeMessage; static void clearSensitiveData(); // notifications sessions static void persistNotificationsSession( const std::string &keyserverID, std::shared_ptr keyserverNotificationsSession); static void persistDeviceNotificationsSession( const std::string &deviceID, std::shared_ptr peerNotificationsSession); static bool isNotificationsSessionInitialized(const std::string &keyserverID); static bool isDeviceNotificationsSessionInitialized(const std::string &deviceID); static std::vector> isNotificationsSessionInitializedWithDevices( const std::vector &deviceIDs); // notifications account static void persistNotificationsAccount( const std::shared_ptr cryptoModule, const std::string &picklingKey); static std::optional< std::pair, std::string>> fetchNotificationsAccount(); static bool isNotificationsAccountInitialized(); static std::string getIdentityKeys(); class BaseStatefulDecryptResult { BaseStatefulDecryptResult( std::string picklingKey, std::string decryptedData); std::string picklingKey; std::string decryptedData; friend NotificationsCryptoModule; public: std::string getDecryptedData(); virtual void flushState() = 0; virtual ~BaseStatefulDecryptResult() = default; }; class StatefulDecryptResult : public BaseStatefulDecryptResult { StatefulDecryptResult( std::unique_ptr session, std::string keyserverID, std::string picklingKey, std::string decryptedData); std::unique_ptr sessionState; std::string keyserverID; friend NotificationsCryptoModule; public: void flushState() override; }; class LegacyStatefulDecryptResult : public BaseStatefulDecryptResult { LegacyStatefulDecryptResult( std::unique_ptr cryptoModule, std::string path, std::string picklingKey, std::string decryptedData); std::unique_ptr cryptoModule; std::string path; friend NotificationsCryptoModule; public: void flushState() override; }; + class StatefulPeerInitDecryptResult : public BaseStatefulDecryptResult { + StatefulPeerInitDecryptResult( + std::shared_ptr session, + std::shared_ptr account, + std::string sessionPicklingKey, + std::string accountPicklingKey, + std::string deviceID, + std::string decryptedData); + std::shared_ptr sessionState; + std::shared_ptr accountState; + std::string accountPicklingKey; + std::string deviceID; + friend NotificationsCryptoModule; + + public: + void flushState() override; + }; + + class StatefulPeerDecryptResult : public BaseStatefulDecryptResult { + StatefulPeerDecryptResult( + std::unique_ptr session, + std::string deviceID, + std::string picklingKey, + std::string decryptedData); + + std::unique_ptr sessionState; + std::string deviceID; + friend NotificationsCryptoModule; + + public: + void flushState() override; + }; + + class StatefulPeerConflictDecryptResult : public BaseStatefulDecryptResult { + StatefulPeerConflictDecryptResult( + std::string picklingKey, + std::string decryptedData); + friend NotificationsCryptoModule; + + public: + void flushState() override; + }; + private: static std::unique_ptr prepareLegacyDecryptedState( const std::string &data, const size_t messageType); public: static std::string decrypt( const std::string &keyserverID, const std::string &data, const size_t messageType); static std::unique_ptr statefulDecrypt( const std::string &keyserverID, const std::string &data, const size_t messageType); static crypto::EncryptedData encrypt(const std::string &deviceID, const std::string &payload); + static std::unique_ptr statefulPeerDecrypt( + const std::string &deviceID, + const std::string &data, + const size_t messageType); + + static std::string peerDecrypt( + const std::string &deviceID, + const std::string &data, + const size_t messageType); + static void flushState(std::unique_ptr statefulDecryptResult); }; } // namespace comm diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModuleJNIHelper.h b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModuleJNIHelper.h index 8f3524a23..4d52a234e 100644 --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModuleJNIHelper.h +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModuleJNIHelper.h @@ -1,23 +1,29 @@ #pragma once #include namespace comm { class NotificationsCryptoModuleJNIHelper : public facebook::jni::JavaClass { public: static auto constexpr kJavaDescriptor = "Lapp/comm/android/fbjni/NotificationsCryptoModule;"; static int olmEncryptedTypeMessage( facebook::jni::alias_ref jThis); static std::string decrypt( facebook::jni::alias_ref jThis, std::string keyserverID, std::string data, int messageType); + static std::string peerDecrypt( + facebook::jni::alias_ref jThis, + std::string deviceID, + std::string data, + int messageType); + static void registerNatives(); }; } // namespace comm diff --git a/native/ios/NotificationService/NotificationService.mm b/native/ios/NotificationService/NotificationService.mm index 1eb356cf1..3e1712127 100644 --- a/native/ios/NotificationService/NotificationService.mm +++ b/native/ios/NotificationService/NotificationService.mm @@ -1,916 +1,927 @@ #import "NotificationService.h" #import "AESCryptoModuleObjCCompat.h" #import "CommIOSServicesClient.h" #import "CommMMKV.h" #import "Logger.h" #import "NotificationsCryptoModule.h" #import "StaffUtils.h" #import "TemporaryMessageStorage.h" #import #include #include NSString *const backgroundNotificationTypeKey = @"backgroundNotifType"; NSString *const messageInfosKey = @"messageInfos"; NSString *const encryptedPayloadKey = @"encryptedPayload"; NSString *const encryptionFailedKey = @"encryptionFailed"; NSString *const collapseIDKey = @"collapseID"; NSString *const keyserverIDKey = @"keyserverID"; +NSString *const senderDeviceIDKey = @"senderDeviceID"; +NSString *const messageTypeKey = @"type"; NSString *const blobHashKey = @"blobHash"; NSString *const blobHolderKey = @"blobHolder"; NSString *const encryptionKeyLabel = @"encryptionKey"; NSString *const needsSilentBadgeUpdateKey = @"needsSilentBadgeUpdate"; // Those and future MMKV-related constants should match // similar constants in CommNotificationsHandler.java const std::string mmkvKeySeparator = "."; const std::string mmkvKeyserverPrefix = "KEYSERVER"; const std::string mmkvUnreadCountSuffix = "UNREAD_COUNT"; // The context for this constant can be found here: // https://linear.app/comm/issue/ENG-3074#comment-bd2f5e28 int64_t const notificationRemovalDelay = (int64_t)(0.1 * NSEC_PER_SEC); // Apple gives us about 30 seconds to process single notification, // se we let any semaphore wait for at most 20 seconds int64_t const semaphoreAwaitTimeLimit = (int64_t)(20 * NSEC_PER_SEC); CFStringRef newMessageInfosDarwinNotification = CFSTR("app.comm.darwin_new_message_infos"); // Implementation below was inspired by the // following discussion with Apple staff member: // https://developer.apple.com/forums/thread/105088 size_t getMemoryUsageInBytes() { task_vm_info_data_t vmInfo; mach_msg_type_number_t count = TASK_VM_INFO_COUNT; kern_return_t result = task_info(mach_task_self(), TASK_VM_INFO, (task_info_t)&vmInfo, &count); if (result != KERN_SUCCESS) { return -1; } size_t memory_usage = static_cast(vmInfo.phys_footprint); return memory_usage; } std::string joinStrings( const std::string &separator, const std::vector &array) { std::ostringstream joinedStream; std::copy( array.begin(), array.end(), std::ostream_iterator(joinedStream, separator.c_str())); std::string joined = joinedStream.str(); return joined.empty() ? joined : joined.substr(0, joined.size() - 1); } @interface NotificationService () @property(strong) NSMutableDictionary *contentHandlers; @property(strong) NSMutableDictionary *contents; @end @implementation NotificationService - (void)didReceiveNotificationRequest:(UNNotificationRequest *)request withContentHandler: (void (^)(UNNotificationContent *_Nonnull)) contentHandler { // Set-up methods are idempotent [NotificationService setUpNSEProcess]; [self setUpNSEInstance]; NSString *contentHandlerKey = [request.identifier copy]; UNMutableNotificationContent *content = [request.content mutableCopy]; [self putContent:content withHandler:contentHandler forKey:contentHandlerKey]; UNNotificationContent *publicUserContent = content; // Step 1: notification decryption. std::unique_ptr statefulDecryptResultPtr; BOOL decryptionExecuted = NO; if ([self shouldBeDecrypted:content.userInfo]) { std::optional notifID; NSString *objcNotifID = content.userInfo[@"id"]; if (objcNotifID) { notifID = std::string([objcNotifID UTF8String]); } std::string decryptErrorMessage; try { @try { statefulDecryptResultPtr = [self decryptContentInPlace:content]; decryptionExecuted = YES; } @catch (NSException *e) { decryptErrorMessage = "NSE: Received Obj-C exception: " + std::string([e.name UTF8String]) + " during notification decryption."; if (notifID.has_value()) { decryptErrorMessage += " Notif ID: " + notifID.value(); } } } catch (const std::exception &e) { decryptErrorMessage = "NSE: Received C++ exception: " + std::string(e.what()) + " during notification decryption."; if (notifID.has_value()) { decryptErrorMessage += " Notif ID: " + notifID.value(); } } if (decryptErrorMessage.size()) { NSString *errorMessage = [NSString stringWithUTF8String:decryptErrorMessage.c_str()]; if (notifID.has_value() && [self isAppShowingNotificationWith: [NSString stringWithCString:notifID.value().c_str() encoding:NSUTF8StringEncoding]]) { errorMessage = [errorMessage stringByAppendingString:@" App shows notif with this ID."]; } [self callContentHandlerForKey:contentHandlerKey onErrorMessage:errorMessage withPublicUserContent:[[UNNotificationContent alloc] init]]; return; } } NSMutableArray *errorMessages = [[NSMutableArray alloc] init]; if (comm::StaffUtils::isStaffRelease() && [self shouldAlertUnencryptedNotification:content.userInfo]) { [errorMessages addObject: @"Notification encryption failed on the keyserver. " @"Please investigate!"]; } if ([self shouldAlertUnencryptedNotification:content.userInfo]) { comm::Logger::log("NSE: Received erroneously unencrypted notification."); } // Step 2: notification persistence in a temporary storage std::string persistErrorMessage; try { @try { [self persistMessagePayload:content.userInfo]; } @catch (NSException *e) { persistErrorMessage = "Obj-C exception: " + std::string([e.name UTF8String]) + " during notification persistence."; } } catch (const std::exception &e) { persistErrorMessage = "C++ exception: " + std::string(e.what()) + " during notification persistence."; } if (persistErrorMessage.size()) { [errorMessages addObject:[NSString stringWithUTF8String:persistErrorMessage.c_str()]]; } // Step 3: Cumulative unread count calculation if (content.badge) { std::string unreadCountCalculationError; try { @try { [self calculateTotalUnreadCountInPlace:content]; } @catch (NSException *e) { unreadCountCalculationError = "Obj-C exception: " + std::string([e.name UTF8String]) + " during unread count calculation."; } } catch (const std::exception &e) { unreadCountCalculationError = "C++ exception: " + std::string(e.what()) + " during unread count calculation."; } if (unreadCountCalculationError.size() && comm::StaffUtils::isStaffRelease()) { [errorMessages addObject:[NSString stringWithUTF8String:unreadCountCalculationError .c_str()]]; } } // Step 4: (optional) rescind read notifications // Message payload persistence is a higher priority task, so it has // to happen prior to potential notification center clearing. if ([self isRescind:content.userInfo]) { std::string rescindErrorMessage; try { @try { [self removeNotificationsWithCondition:^BOOL( UNNotification *_Nonnull notif) { return [content.userInfo[@"notificationId"] isEqualToString:notif.request.content.userInfo[@"id"]]; }]; } @catch (NSException *e) { rescindErrorMessage = "Obj-C exception: " + std::string([e.name UTF8String]) + " during notification rescind."; } } catch (const std::exception &e) { rescindErrorMessage = "C++ exception: " + std::string(e.what()) + " during notification rescind."; } if (rescindErrorMessage.size()) { [errorMessages addObject:[NSString stringWithUTF8String:persistErrorMessage.c_str()]]; } publicUserContent = [[UNNotificationContent alloc] init]; } // Step 5: (optional) execute notification coalescing if ([self isCollapsible:content.userInfo]) { std::string coalescingErrorMessage; try { @try { [self displayLocalNotificationFromContent:content forCollapseKey:content .userInfo[collapseIDKey]]; } @catch (NSException *e) { coalescingErrorMessage = "Obj-C exception: " + std::string([e.name UTF8String]) + " during notification coalescing."; } } catch (const std::exception &e) { coalescingErrorMessage = "C++ exception: " + std::string(e.what()) + " during notification coalescing."; } if (coalescingErrorMessage.size()) { [errorMessages addObject:[NSString stringWithUTF8String:coalescingErrorMessage.c_str()]]; // Even if we fail to execute coalescing then public users // should still see the original message. publicUserContent = content; } else { publicUserContent = [[UNNotificationContent alloc] init]; } } // Step 6: (optional) create empty notification that // only provides badge count. // For notifs that only contain badge update the // server sets BODY to "ENCRYPTED" for internal // builds for debugging purposes. So instead of // letting such notif go through, we construct // another notif that doesn't have a body. if (content.userInfo[needsSilentBadgeUpdateKey]) { publicUserContent = [self getBadgeOnlyContentFor:content]; } // Step 7: (optional) download notification paylaod // from blob service in case it is large notification if ([self isLargeNotification:content.userInfo]) { std::string processLargeNotificationError; try { @try { [self fetchAndPersistLargeNotifPayload:content]; } @catch (NSException *e) { processLargeNotificationError = "Obj-C exception: " + std::string([e.name UTF8String]) + " during large notification processing."; } } catch (const std::exception &e) { processLargeNotificationError = "C++ exception: " + std::string(e.what()) + " during large notification processing."; } if (processLargeNotificationError.size()) { [errorMessages addObject:[NSString stringWithUTF8String:processLargeNotificationError .c_str()]]; } } // Step 8: notify main app that there is data // to transfer to SQLite and redux. [self sendNewMessageInfosNotification]; if (NSString *currentMemoryEventMessage = [NotificationService getAndSetMemoryEventMessage:nil]) { [errorMessages addObject:currentMemoryEventMessage]; } if (errorMessages.count) { NSString *cumulatedErrorMessage = [@"NSE: Received " stringByAppendingString:[errorMessages componentsJoinedByString:@" "]]; [self callContentHandlerForKey:contentHandlerKey onErrorMessage:cumulatedErrorMessage withPublicUserContent:publicUserContent]; return; } [self callContentHandlerForKey:contentHandlerKey withContent:publicUserContent]; if (decryptionExecuted) { comm::NotificationsCryptoModule::flushState( std::move(statefulDecryptResultPtr)); } } - (void)serviceExtensionTimeWillExpire { // Called just before the extension will be terminated by the system. // Use this as an opportunity to deliver your "best attempt" at modified // content, otherwise the original push payload will be used. NSMutableArray *allHandlers = [[NSMutableArray alloc] init]; NSMutableArray *allContents = [[NSMutableArray alloc] init]; @synchronized(self.contentHandlers) { for (NSString *key in self.contentHandlers) { [allHandlers addObject:self.contentHandlers[key]]; [allContents addObject:self.contents[key]]; } [self.contentHandlers removeAllObjects]; [self.contents removeAllObjects]; } for (int i = 0; i < allContents.count; i++) { UNNotificationContent *content = allContents[i]; void (^handler)(UNNotificationContent *_Nonnull) = allHandlers[i]; if ([self isRescind:content.userInfo]) { // If we get to this place it means we were unable to // remove relevant notification from notification center in // in time given to NSE to process notification. // It is an extremely unlikely to happen. if (!comm::StaffUtils::isStaffRelease()) { handler([[UNNotificationContent alloc] init]); continue; } NSString *errorMessage = @"NSE: Exceeded time limit to rescind a notification."; UNNotificationContent *errorContent = [self buildContentForError:errorMessage]; handler(errorContent); continue; } if ([self isCollapsible:content.userInfo]) { // If we get to this place it means we were unable to // execute notification coalescing with local notification // mechanism in time given to NSE to process notification. if (!comm::StaffUtils::isStaffRelease()) { handler(content); continue; } NSString *errorMessage = @"NSE: Exceeded time limit to collapse a notitication."; UNNotificationContent *errorContent = [self buildContentForError:errorMessage]; handler(errorContent); continue; } if ([self shouldBeDecrypted:content.userInfo] && !content.userInfo[@"successfullyDecrypted"]) { // If we get to this place it means we were unable to // decrypt encrypted notification content in time // given to NSE to process notification. if (!comm::StaffUtils::isStaffRelease()) { handler([[UNNotificationContent alloc] init]); continue; } NSString *errorMessage = @"NSE: Exceeded time limit to decrypt a notification."; UNNotificationContent *errorContent = [self buildContentForError:errorMessage]; handler(errorContent); continue; } // At this point we know that the content is at least // correctly decrypted so we can display it to the user. // Another operation, like persistence, had failed. if (content.userInfo[needsSilentBadgeUpdateKey]) { UNNotificationContent *badgeOnlyContent = [self getBadgeOnlyContentFor:content]; handler(badgeOnlyContent); continue; } handler(content); } } - (void)removeNotificationsWithCondition: (BOOL (^)(UNNotification *_Nonnull))condition { dispatch_semaphore_t semaphore = dispatch_semaphore_create(0); void (^delayedSemaphorePostCallback)() = ^() { dispatch_time_t timeToPostSemaphore = dispatch_time(DISPATCH_TIME_NOW, notificationRemovalDelay); dispatch_after(timeToPostSemaphore, dispatch_get_main_queue(), ^{ dispatch_semaphore_signal(semaphore); }); }; [UNUserNotificationCenter.currentNotificationCenter getDeliveredNotificationsWithCompletionHandler:^( NSArray *_Nonnull notifications) { NSMutableArray *notificationsToRemove = [[NSMutableArray alloc] init]; for (UNNotification *notif in notifications) { if (condition(notif)) { [notificationsToRemove addObject:notif.request.identifier]; } } [UNUserNotificationCenter.currentNotificationCenter removeDeliveredNotificationsWithIdentifiers:notificationsToRemove]; delayedSemaphorePostCallback(); }]; dispatch_semaphore_wait( semaphore, dispatch_time(DISPATCH_TIME_NOW, semaphoreAwaitTimeLimit)); } - (void)displayLocalNotificationFromContent:(UNNotificationContent *)content forCollapseKey:(NSString *)collapseKey { UNMutableNotificationContent *localNotifContent = [[UNMutableNotificationContent alloc] init]; localNotifContent.title = content.title; localNotifContent.body = content.body; localNotifContent.badge = content.badge; localNotifContent.userInfo = content.userInfo; UNNotificationRequest *localNotifRequest = [UNNotificationRequest requestWithIdentifier:collapseKey content:localNotifContent trigger:nil]; [self displayLocalNotificationFor:localNotifRequest]; } - (void)persistMessagePayload:(NSDictionary *)payload { if (payload[messageInfosKey]) { TemporaryMessageStorage *temporaryStorage = [[TemporaryMessageStorage alloc] init]; [temporaryStorage writeMessage:payload[messageInfosKey]]; return; } if (![self isRescind:payload]) { return; } NSError *jsonError = nil; NSData *binarySerializedRescindPayload = [NSJSONSerialization dataWithJSONObject:payload options:0 error:&jsonError]; if (jsonError) { comm::Logger::log( "NSE: Failed to serialize rescind payload. Details: " + std::string([jsonError.localizedDescription UTF8String])); return; } NSString *serializedRescindPayload = [[NSString alloc] initWithData:binarySerializedRescindPayload encoding:NSUTF8StringEncoding]; TemporaryMessageStorage *temporaryRescindsStorage = [[TemporaryMessageStorage alloc] initForRescinds]; [temporaryRescindsStorage writeMessage:serializedRescindPayload]; } - (BOOL)isRescind:(NSDictionary *)payload { return payload[backgroundNotificationTypeKey] && [payload[backgroundNotificationTypeKey] isEqualToString:@"CLEAR"]; } - (void)calculateTotalUnreadCountInPlace: (UNMutableNotificationContent *)content { if (!content.userInfo[keyserverIDKey]) { - throw std::runtime_error("Received badge update without keyserver ID."); + return; } std::string senderKeyserverID = std::string([content.userInfo[keyserverIDKey] UTF8String]); std::string senderKeyserverUnreadCountKey = joinStrings( mmkvKeySeparator, {mmkvKeyserverPrefix, senderKeyserverID, mmkvUnreadCountSuffix}); int senderKeyserverUnreadCount = [content.badge intValue]; comm::CommMMKV::setInt( senderKeyserverUnreadCountKey, senderKeyserverUnreadCount); int totalUnreadCount = 0; std::vector allKeys = comm::CommMMKV::getAllKeys(); for (const auto &key : allKeys) { if (key.size() < mmkvKeyserverPrefix.size() + mmkvUnreadCountSuffix.size() || key.compare(0, mmkvKeyserverPrefix.size(), mmkvKeyserverPrefix) || key.compare( key.size() - mmkvUnreadCountSuffix.size(), mmkvUnreadCountSuffix.size(), mmkvUnreadCountSuffix)) { continue; } std::optional unreadCount = comm::CommMMKV::getInt(key, -1); if (!unreadCount.has_value()) { continue; } totalUnreadCount += unreadCount.value(); } content.badge = @(totalUnreadCount); } - (void)fetchAndPersistLargeNotifPayload: (UNMutableNotificationContent *)content { NSString *blobHash = content.userInfo[blobHashKey]; NSData *encryptionKey = [[NSData alloc] initWithBase64EncodedString:content.userInfo[encryptionKeyLabel] options:0]; __block NSError *fetchError = nil; NSData *largePayloadBinary = [CommIOSServicesClient.sharedInstance getBlobSync:blobHash orSetError:&fetchError]; if (fetchError) { comm::Logger::log( "Failed to fetch notif payload from blob service. Details: " + std::string([fetchError.localizedDescription UTF8String])); return; } NSDictionary *largePayload = [NotificationService aesDecryptAndParse:largePayloadBinary withKey:encryptionKey]; [self persistMessagePayload:largePayload]; [CommIOSServicesClient.sharedInstance storeBlobForDeletionWithHash:blobHash andHolder:content.userInfo[blobHolderKey]]; } - (BOOL)isCollapsible:(NSDictionary *)payload { return payload[collapseIDKey]; } - (BOOL)isLargeNotification:(NSDictionary *)payload { return payload[blobHashKey] && payload[encryptionKeyLabel] && payload[blobHolderKey]; } - (UNNotificationContent *)getBadgeOnlyContentFor: (UNNotificationContent *)content { UNMutableNotificationContent *badgeOnlyContent = [[UNMutableNotificationContent alloc] init]; badgeOnlyContent.badge = content.badge; return badgeOnlyContent; } - (void)sendNewMessageInfosNotification { CFNotificationCenterPostNotification( CFNotificationCenterGetDarwinNotifyCenter(), newMessageInfosDarwinNotification, (__bridge const void *)(self), nil, TRUE); } - (BOOL)shouldBeDecrypted:(NSDictionary *)payload { return payload[encryptedPayloadKey]; } - (BOOL)shouldAlertUnencryptedNotification:(NSDictionary *)payload { return payload[encryptionFailedKey] && [payload[encryptionFailedKey] isEqualToString:@"1"]; } - (std::unique_ptr) decryptContentInPlace:(UNMutableNotificationContent *)content { std::string encryptedData = std::string([content.userInfo[encryptedPayloadKey] UTF8String]); - if (!content.userInfo[keyserverIDKey]) { + std::unique_ptr + decryptResult; + if (content.userInfo[keyserverIDKey]) { + std::string senderKeyserverID = + std::string([content.userInfo[keyserverIDKey] UTF8String]); + decryptResult = comm::NotificationsCryptoModule::statefulDecrypt( + senderKeyserverID, + encryptedData, + comm::NotificationsCryptoModule::olmEncryptedTypeMessage); + } else if ( + content.userInfo[senderDeviceIDKey] && content.userInfo[messageTypeKey]) { + std::string senderDeviceID = + std::string([content.userInfo[senderDeviceIDKey] UTF8String]); + size_t messageType = [content.userInfo[messageTypeKey] intValue]; + decryptResult = comm::NotificationsCryptoModule::statefulPeerDecrypt( + senderDeviceID, encryptedData, messageType); + } else { throw std::runtime_error( - "Received encrypted notification without keyserverID."); + "Received notification without keyserver ID nor sender device ID."); } - std::string senderKeyserverID = - std::string([content.userInfo[keyserverIDKey] UTF8String]); - - auto decryptResult = comm::NotificationsCryptoModule::statefulDecrypt( - senderKeyserverID, - encryptedData, - comm::NotificationsCryptoModule::olmEncryptedTypeMessage); NSString *decryptedSerializedPayload = [NSString stringWithUTF8String:decryptResult->getDecryptedData().c_str()]; NSDictionary *decryptedPayload = [NSJSONSerialization JSONObjectWithData:[decryptedSerializedPayload dataUsingEncoding:NSUTF8StringEncoding] options:0 error:nil]; NSMutableDictionary *mutableUserInfo = [content.userInfo mutableCopy]; NSMutableDictionary *mutableAps = nil; if (mutableUserInfo[@"aps"]) { mutableAps = [mutableUserInfo[@"aps"] mutableCopy]; } NSString *body = decryptedPayload[@"merged"]; if (body) { content.body = body; if (mutableAps && mutableAps[@"alert"]) { mutableAps[@"alert"] = body; } } else { mutableUserInfo[needsSilentBadgeUpdateKey] = @(YES); } NSString *threadID = decryptedPayload[@"threadID"]; if (threadID) { content.threadIdentifier = threadID; mutableUserInfo[@"threadID"] = threadID; if (mutableAps) { mutableAps[@"thread-id"] = threadID; } } NSString *badgeStr = decryptedPayload[@"badge"]; if (badgeStr) { NSNumber *badge = @([badgeStr intValue]); content.badge = badge; if (mutableAps) { mutableAps[@"badge"] = badge; } } // The rest have been already decrypted and handled. static NSArray *handledKeys = @[ @"merged", @"badge", @"threadID" ]; for (NSString *payloadKey in decryptedPayload) { if ([handledKeys containsObject:payloadKey]) { continue; } mutableUserInfo[payloadKey] = decryptedPayload[payloadKey]; } if (mutableAps) { mutableUserInfo[@"aps"] = mutableAps; } [mutableUserInfo removeObjectForKey:encryptedPayloadKey]; mutableUserInfo[@"successfullyDecrypted"] = @(YES); content.userInfo = mutableUserInfo; return decryptResult; } // Apple documentation for NSE does not explicitly state // that single NSE instance will be used by only one thread // at a time. Even though UNNotificationServiceExtension API // suggests that it could be the case we don't trust it // and keep a synchronized collection of handlers and contents. // We keep reports of events that strongly suggest there is // parallelism in notifications processing. In particular we // have see notifications not being decrypted when access // to encryption keys had not been correctly implemented. // Similar behaviour is adopted by other apps such as Signal, // Telegram or Element. - (void)setUpNSEInstance { @synchronized(self) { if (self.contentHandlers) { return; } self.contentHandlers = [[NSMutableDictionary alloc] init]; self.contents = [[NSMutableDictionary alloc] init]; } } - (void)putContent:(UNNotificationContent *)content withHandler:(void (^)(UNNotificationContent *_Nonnull))handler forKey:(NSString *)key { @synchronized(self.contentHandlers) { [self.contentHandlers setObject:handler forKey:key]; [self.contents setObject:content forKey:key]; } } - (void)callContentHandlerForKey:(NSString *)key withContent:(UNNotificationContent *)content { void (^handler)(UNNotificationContent *_Nonnull); @synchronized(self.contentHandlers) { handler = [self.contentHandlers objectForKey:key]; [self.contentHandlers removeObjectForKey:key]; [self.contents removeObjectForKey:key]; } if (!handler) { return; } handler(content); } - (UNNotificationContent *)buildContentForError:(NSString *)error { UNMutableNotificationContent *content = [[UNMutableNotificationContent alloc] init]; content.body = error; return content; } - (void)callContentHandlerForKey:(NSString *)key onErrorMessage:(NSString *)errorMessage withPublicUserContent:(UNNotificationContent *)publicUserContent { comm::Logger::log(std::string([errorMessage UTF8String])); if (comm::StaffUtils::isStaffRelease()) { NSString *errorNotifId = [@"error_for_" stringByAppendingString:key]; UNNotificationContent *content = [self buildContentForError:errorMessage]; UNNotificationRequest *localNotifRequest = [UNNotificationRequest requestWithIdentifier:errorNotifId content:content trigger:nil]; [self displayLocalNotificationFor:localNotifRequest]; } [self callContentHandlerForKey:key withContent:publicUserContent]; } - (void)displayLocalNotificationFor:(UNNotificationRequest *)localNotifRequest { // We must wait until local notif display completion // handler returns. Context: // https://developer.apple.com/forums/thread/108340?answerId=331640022#331640022 dispatch_semaphore_t localNotifDisplaySemaphore = dispatch_semaphore_create(0); __block NSError *localNotifDisplayError = nil; [UNUserNotificationCenter.currentNotificationCenter addNotificationRequest:localNotifRequest withCompletionHandler:^(NSError *_Nullable error) { if (error) { localNotifDisplayError = error; } dispatch_semaphore_signal(localNotifDisplaySemaphore); }]; dispatch_semaphore_wait( localNotifDisplaySemaphore, dispatch_time(DISPATCH_TIME_NOW, semaphoreAwaitTimeLimit)); if (localNotifDisplayError) { throw std::runtime_error( std::string([localNotifDisplayError.localizedDescription UTF8String])); } } - (BOOL)isAppShowingNotificationWith:(NSString *)identifier { dispatch_semaphore_t getAllDeliveredNotifsSemaphore = dispatch_semaphore_create(0); __block BOOL foundNotification = NO; [UNUserNotificationCenter.currentNotificationCenter getDeliveredNotificationsWithCompletionHandler:^( NSArray *_Nonnull notifications) { for (UNNotification *notif in notifications) { if (notif.request.content.userInfo[@"id"] && [notif.request.content.userInfo[@"id"] isEqualToString:identifier]) { foundNotification = YES; break; } } dispatch_semaphore_signal(getAllDeliveredNotifsSemaphore); }]; dispatch_semaphore_wait( getAllDeliveredNotifsSemaphore, dispatch_time(DISPATCH_TIME_NOW, semaphoreAwaitTimeLimit)); return foundNotification; } // Monitor memory usage + (NSString *)getAndSetMemoryEventMessage:(NSString *)message { static NSString *memoryEventMessage = nil; static NSLock *memoryEventLock = [[NSLock alloc] init]; @try { if (![memoryEventLock tryLock]) { return nil; } NSString *currentMemoryEventMessage = memoryEventMessage ? [memoryEventMessage copy] : nil; memoryEventMessage = [message copy]; return currentMemoryEventMessage; } @finally { [memoryEventLock unlock]; } } + (dispatch_source_t)registerForMemoryEvents { dispatch_source_t memorySource = dispatch_source_create( DISPATCH_SOURCE_TYPE_MEMORYPRESSURE, 0L, DISPATCH_MEMORYPRESSURE_CRITICAL, dispatch_get_main_queue()); dispatch_block_t eventHandler = ^{ NSString *criticalMemoryEventMessage = [NSString stringWithFormat: @"NSE: Received CRITICAL memory event. Memory usage: %ld bytes", getMemoryUsageInBytes()]; comm::Logger::log(std::string([criticalMemoryEventMessage UTF8String])); if (!comm::StaffUtils::isStaffRelease()) { // If it is not a staff release we don't set // memoryEventMessage variable since it will // not be displayed to the client anyway return; } [NotificationService getAndSetMemoryEventMessage:criticalMemoryEventMessage]; }; dispatch_source_set_event_handler(memorySource, eventHandler); dispatch_activate(memorySource); return memorySource; } // AES Cryptography static AESCryptoModuleObjCCompat *_aesCryptoModule = nil; + (AESCryptoModuleObjCCompat *)processLocalAESCryptoModule { return _aesCryptoModule; } + (NSDictionary *)aesDecryptAndParse:(NSData *)sealedData withKey:(NSData *)key { NSError *decryptError = nil; NSInteger destinationLength = [[NotificationService processLocalAESCryptoModule] decryptedLength:sealedData]; NSMutableData *destination = [NSMutableData dataWithLength:destinationLength]; [[NotificationService processLocalAESCryptoModule] decryptWithKey:key sealedData:sealedData destination:destination withError:&decryptError]; if (decryptError) { comm::Logger::log( "NSE: Notification aes decryption failure. Details: " + std::string([decryptError.localizedDescription UTF8String])); return nil; } NSString *decryptedSerializedPayload = [[NSString alloc] initWithData:destination encoding:NSUTF8StringEncoding]; return [NSJSONSerialization JSONObjectWithData:[decryptedSerializedPayload dataUsingEncoding:NSUTF8StringEncoding] options:0 error:nil]; } // Process-local initialization code NSE may use different threads and instances // of this class to process notifs, but it usually keeps the same process for // extended period of time. Objects that can be initialized once and reused on // each notif should be declared in a method below to avoid wasting resources + (void)setUpNSEProcess { static dispatch_source_t memoryEventSource; static dispatch_once_t onceToken; dispatch_once(&onceToken, ^{ _aesCryptoModule = [[AESCryptoModuleObjCCompat alloc] init]; memoryEventSource = [NotificationService registerForMemoryEvents]; }); } @end