diff --git a/native/android/app/src/cpp/NotificationsCryptoModuleJNIHelper.cpp b/native/android/app/src/cpp/NotificationsCryptoModuleJNIHelper.cpp --- a/native/android/app/src/cpp/NotificationsCryptoModuleJNIHelper.cpp +++ b/native/android/app/src/cpp/NotificationsCryptoModuleJNIHelper.cpp @@ -17,12 +17,24 @@ return decryptedData; } +std::string NotificationsCryptoModuleJNIHelper::peerDecrypt( + facebook::jni::alias_ref jThis, + std::string deviceID, + std::string data, + int messageType) { + std::string decryptedData = + NotificationsCryptoModule::peerDecrypt(deviceID, data, messageType); + return decryptedData; +} + void NotificationsCryptoModuleJNIHelper::registerNatives() { javaClassStatic()->registerNatives({ makeNativeMethod( "olmEncryptedTypeMessage", NotificationsCryptoModuleJNIHelper::olmEncryptedTypeMessage), makeNativeMethod("decrypt", NotificationsCryptoModuleJNIHelper::decrypt), + makeNativeMethod( + "peerDecrypt", NotificationsCryptoModuleJNIHelper::peerDecrypt), }); } } // namespace comm diff --git a/native/android/app/src/main/java/app/comm/android/fbjni/NotificationsCryptoModule.java b/native/android/app/src/main/java/app/comm/android/fbjni/NotificationsCryptoModule.java --- a/native/android/app/src/main/java/app/comm/android/fbjni/NotificationsCryptoModule.java +++ b/native/android/app/src/main/java/app/comm/android/fbjni/NotificationsCryptoModule.java @@ -4,4 +4,6 @@ public static native int olmEncryptedTypeMessage(); public static native String decrypt(String keyserverID, String data, int messageType); + public static native String + peerDecrypt(String deviceID, String data, int messageType); } diff --git a/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java b/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java --- a/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java +++ 
b/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java @@ -56,6 +56,8 @@ private static final String GROUP_NOTIF_IDS_KEY = "groupNotifIDs"; private static final String COLLAPSE_ID_KEY = "collapseKey"; private static final String KEYSERVER_ID_KEY = "keyserverID"; + private static final String SENDER_DEVICE_ID_KEY = "senderDeviceID"; + private static final String MESSAGE_TYPE_KEY = "type"; private static final String CHANNEL_ID = "default"; private static final long[] VIBRATION_SPEC = {500, 500}; private static final Map NOTIF_PRIORITY_VERBOSE = @@ -107,19 +109,18 @@ handleAlteredNotificationPriority(message); if (StaffUtils.isStaffRelease() && - message.getData().get(KEYSERVER_ID_KEY) == null) { + message.getData().get(KEYSERVER_ID_KEY) == null && + message.getData().get(SENDER_DEVICE_ID_KEY) == null) { displayErrorMessageNotification( - "Received notification without keyserver ID.", + "Received notification without keyserver ID nor sender device ID", "Missing keyserver ID.", null); return; } - String senderKeyserverID = message.getData().get(KEYSERVER_ID_KEY); - if (message.getData().get(ENCRYPTED_PAYLOAD_KEY) != null) { try { - message = this.olmDecryptRemoteMessage(message, senderKeyserverID); + message = this.olmDecryptRemoteMessage(message); } catch (JSONException e) { Log.w("COMM", "Malformed notification JSON payload.", e); return; @@ -139,6 +140,7 @@ "Unencrypted notification", null); } + if ("1".equals(message.getData().get(ENCRYPTION_FAILED_KEY))) { Log.w("COMM", "Received erroneously unencrypted notification."); } @@ -269,14 +271,15 @@ } private void handleUnreadCountUpdate(RemoteMessage message) { + if (message.getData().get(KEYSERVER_ID_KEY) == null) { + return; + } + String badge = message.getData().get(BADGE_KEY); if (badge == null) { return; } - if (message.getData().get(KEYSERVER_ID_KEY) == null) { - throw new RuntimeException("Received badge update without keyserver ID."); - } String senderKeyserverID = 
message.getData().get(KEYSERVER_ID_KEY); String senderKeyserverUnreadCountKey = String.join( MMKV_KEY_SEPARATOR, @@ -480,15 +483,28 @@ return message; } - private RemoteMessage - olmDecryptRemoteMessage(RemoteMessage message, String senderKeyserverID) - throws JSONException, IllegalStateException { + private RemoteMessage olmDecryptRemoteMessage(RemoteMessage message) + throws JSONException, IllegalStateException, NumberFormatException { String encryptedSerializedPayload = message.getData().get(ENCRYPTED_PAYLOAD_KEY); - String decryptedSerializedPayload = NotificationsCryptoModule.decrypt( - senderKeyserverID, - encryptedSerializedPayload, - NotificationsCryptoModule.olmEncryptedTypeMessage()); + + String decryptedSerializedPayload; + if (message.getData().get(KEYSERVER_ID_KEY) != null) { + String senderKeyserverID = message.getData().get(KEYSERVER_ID_KEY); + decryptedSerializedPayload = NotificationsCryptoModule.decrypt( + senderKeyserverID, + encryptedSerializedPayload, + NotificationsCryptoModule.olmEncryptedTypeMessage()); + } else if (message.getData().get(SENDER_DEVICE_ID_KEY) != null) { + String senderDeviceID = message.getData().get(SENDER_DEVICE_ID_KEY); + String messageTypeString = message.getData().get(MESSAGE_TYPE_KEY); + int messageType = Integer.parseInt(messageTypeString); + decryptedSerializedPayload = NotificationsCryptoModule.peerDecrypt( + senderDeviceID, encryptedSerializedPayload, messageType); + } else { + throw new RuntimeException( + "Received notification without keyserver ID nor sender device ID."); + } return updateRemoteMessageWithDecryptedPayload( message, decryptedSerializedPayload); diff --git a/native/cpp/CommonCpp/CryptoTools/Session.cpp b/native/cpp/CommonCpp/CryptoTools/Session.cpp --- a/native/cpp/CommonCpp/CryptoTools/Session.cpp +++ b/native/cpp/CommonCpp/CryptoTools/Session.cpp @@ -83,9 +83,11 @@ session->olmSessionBuffer.resize(::olm_session_size()); ::olm_session(session->olmSessionBuffer.data()); if (-1 == - 
::olm_create_inbound_session( + ::olm_create_inbound_session_from( session->getOlmSession(), account, + idKeys.data() + ID_KEYS_PREFIX_OFFSET, + KEYSIZE, tmpEncryptedMessage.data(), encryptedMessage.size())) { throw std::runtime_error( diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h @@ -111,6 +111,49 @@ void flushState() override; }; + class StatefulPeerInitDecryptResult : public BaseStatefulDecryptResult { + StatefulPeerInitDecryptResult( + std::shared_ptr session, + std::shared_ptr account, + std::string sessionPicklingKey, + std::string accountPicklingKey, + std::string deviceID, + std::string decryptedData); + std::shared_ptr sessionState; + std::shared_ptr accountState; + std::string accountPicklingKey; + std::string deviceID; + friend NotificationsCryptoModule; + + public: + void flushState() override; + }; + + class StatefulPeerDecryptResult : public BaseStatefulDecryptResult { + StatefulPeerDecryptResult( + std::unique_ptr session, + std::string deviceID, + std::string picklingKey, + std::string decryptedData); + + std::unique_ptr sessionState; + std::string deviceID; + friend NotificationsCryptoModule; + + public: + void flushState() override; + }; + + class StatefulPeerConflictDecryptResult : public BaseStatefulDecryptResult { + StatefulPeerConflictDecryptResult( + std::string picklingKey, + std::string decryptedData); + friend NotificationsCryptoModule; + + public: + void flushState() override; + }; + private: static std::unique_ptr prepareLegacyDecryptedState( @@ -131,6 +174,16 @@ static crypto::EncryptedData encryptNotification(const std::string &deviceID, const std::string &payload); + static std::unique_ptr statefulPeerDecrypt( + const 
std::string &deviceID, + const std::string &data, + const size_t messageType); + + static std::string peerDecrypt( + const std::string &deviceID, + const std::string &data, + const size_t messageType); + static void flushState(std::unique_ptr statefulDecryptResult); }; diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp @@ -3,8 +3,9 @@ #include "../../CryptoTools/Tools.h" #include "../../Tools/CommMMKV.h" #include "../../Tools/CommSecureStore.h" -#include "../../Tools/Logger.h" #include "../../Tools/PlatformSpecificTools.h" +#include "NotificationsInboundKeysProvider.h" +#include "olm/session.hh" #include #include @@ -424,6 +425,57 @@ } } +NotificationsCryptoModule::StatefulPeerInitDecryptResult:: + StatefulPeerInitDecryptResult( + std::shared_ptr session, + std::shared_ptr account, + std::string sessionPicklingKey, + std::string accountPicklingKey, + std::string deviceID, + std::string decryptedData) + : BaseStatefulDecryptResult(sessionPicklingKey, decryptedData), + sessionState(session), + accountState(account), + accountPicklingKey(accountPicklingKey), + deviceID(deviceID) { +} + +void NotificationsCryptoModule::StatefulPeerInitDecryptResult::flushState() { + NotificationsCryptoModule::persistNotificationsSessionInternal( + false, this->deviceID, this->picklingKey, std::move(this->sessionState)); + NotificationsCryptoModule::persistNotificationsAccount( + std::move(this->accountState), this->accountPicklingKey); +} + +NotificationsCryptoModule::StatefulPeerDecryptResult::StatefulPeerDecryptResult( + std::unique_ptr session, + std::string deviceID, + std::string picklingKey, + std::string decryptedData) + : 
NotificationsCryptoModule::BaseStatefulDecryptResult:: + BaseStatefulDecryptResult(picklingKey, decryptedData), + sessionState(std::move(session)), + deviceID(deviceID) { +} + +void NotificationsCryptoModule::StatefulPeerDecryptResult::flushState() { + NotificationsCryptoModule::persistNotificationsSessionInternal( + false, this->deviceID, this->picklingKey, std::move(this->sessionState)); +} + +NotificationsCryptoModule::StatefulPeerConflictDecryptResult:: + StatefulPeerConflictDecryptResult( + std::string picklingKey, + std::string decryptedData) + : NotificationsCryptoModule::BaseStatefulDecryptResult( + picklingKey, + decryptedData) { +} +void NotificationsCryptoModule::StatefulPeerConflictDecryptResult:: + flushState() { + return; +} + std::unique_ptr NotificationsCryptoModule::prepareLegacyDecryptedState( const std::string &data, @@ -529,6 +581,142 @@ std::move(statefulDecryptResult)); } +std::unique_ptr +NotificationsCryptoModule::statefulPeerDecrypt( + const std::string &deviceID, + const std::string &data, + const size_t messageType) { + if (messageType != OLM_MESSAGE_TYPE_MESSAGE && + messageType != OLM_MESSAGE_TYPE_PRE_KEY) { + throw std::runtime_error( + "Received message of invalid type from device: " + deviceID); + } + + auto maybeSessionWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsSession(false, deviceID); + + if (!maybeSessionWithPicklingKey.has_value() && + messageType == OLM_MESSAGE_TYPE_MESSAGE) { + throw std::runtime_error( + "Received MESSAGE_TYPE_MESSAGE message from device: " + deviceID + + " but session not initialized."); + } + + crypto::EncryptedData encryptedData{ + std::vector(data.begin(), data.end()), messageType}; + + bool isSenderChainEmpty = true; + bool hasReceivedMessage = false; + bool sessionExists = maybeSessionWithPicklingKey.has_value(); + + if (sessionExists) { + ::olm::Session *olmSessionAsCppClass = reinterpret_cast<::olm::Session *>( + maybeSessionWithPicklingKey.value().first->getOlmSession()); + 
isSenderChainEmpty = olmSessionAsCppClass->ratchet.sender_chain.empty(); + hasReceivedMessage = olmSessionAsCppClass->received_message; + } + + // regular message + bool isRegularMessage = + sessionExists && messageType == OLM_MESSAGE_TYPE_MESSAGE; + + bool isRegularPrekeyMessage = sessionExists && + messageType == OLM_MESSAGE_TYPE_PRE_KEY && isSenderChainEmpty && + hasReceivedMessage; + + if (isRegularMessage || isRegularPrekeyMessage) { + std::string decryptedData = + maybeSessionWithPicklingKey.value().first->decrypt(encryptedData); + StatefulPeerDecryptResult decryptResult = StatefulPeerDecryptResult( + std::move(maybeSessionWithPicklingKey.value().first), + deviceID, + maybeSessionWithPicklingKey.value().second, + decryptedData); + return std::make_unique( + std::move(decryptResult)); + } + + // At this point we face a race condition, a session reset attempt, or a + // session initialization attempt. In each of these scenarios a new inbound + // session must be created in order to decrypt the message. + std::string notifInboundKeys = + NotificationsInboundKeysProvider::getNotifsInboundKeysForDeviceID( + deviceID); + auto maybeAccountWithPicklingKey = + NotificationsCryptoModule::fetchNotificationsAccount(); + + if (!maybeAccountWithPicklingKey.has_value()) { + throw std::runtime_error("Notifications account not initialized."); + } + + auto accountWithPicklingKey = maybeAccountWithPicklingKey.value(); + accountWithPicklingKey.first->initializeInboundForReceivingSession( + deviceID, + {data.begin(), data.end()}, + {notifInboundKeys.begin(), notifInboundKeys.end()}, + // The argument below is relevant for content only + 0, + true); + std::shared_ptr newInboundSession = + accountWithPicklingKey.first->getSessionByDeviceId(deviceID); + accountWithPicklingKey.first->removeSessionByDeviceId(deviceID); + std::string decryptedData = newInboundSession->decrypt(encryptedData); + + // session reset attempt or session initialization - handled the same + bool sessionResetAttempt
= + sessionExists && !isSenderChainEmpty && hasReceivedMessage; + + // race condition + bool raceCondition = + sessionExists && !isSenderChainEmpty && !hasReceivedMessage; + + // device ID comparison + folly::Optional maybeOurDeviceID = + CommSecureStore::get(CommSecureStore::deviceID); + if (!maybeOurDeviceID.hasValue()) { + throw std::runtime_error("Session creation attempt but no device id"); + } + std::string ourDeviceID = maybeOurDeviceID.value(); + bool thisDeviceWinsRaceCondition = ourDeviceID > deviceID; + + // If there is no session, or there is a session reset attempt, or + // there is a race condition but we lose the device id comparison, + // we end up creating a new session as inbound + if (!sessionExists || sessionResetAttempt || + (raceCondition && !thisDeviceWinsRaceCondition)) { + std::string sessionPicklingKey = crypto::Tools::generateRandomString(64); + StatefulPeerInitDecryptResult decryptResult = StatefulPeerInitDecryptResult( + newInboundSession, + accountWithPicklingKey.first, + sessionPicklingKey, + accountWithPicklingKey.second, + deviceID, + decryptedData); + return std::make_unique( + std::move(decryptResult)); + } + + // If there is a race condition but we win the device id comparison + // we return an object that carries decrypted data but won't persist + // any session state + StatefulPeerConflictDecryptResult decryptResult = + StatefulPeerConflictDecryptResult( + maybeSessionWithPicklingKey.value().second, decryptedData); + return std::make_unique( + std::move(decryptResult)); +} + +std::string NotificationsCryptoModule::peerDecrypt( + const std::string &deviceID, + const std::string &data, + const size_t messageType) { + auto statefulDecryptResult = NotificationsCryptoModule::statefulPeerDecrypt( + deviceID, data, messageType); + std::string decryptedData = statefulDecryptResult->decryptedData; + statefulDecryptResult->flushState(); + return decryptedData; +} + void NotificationsCryptoModule::flushState( std::unique_ptr baseStatefulDecryptResult) {
baseStatefulDecryptResult->flushState(); diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModuleJNIHelper.h b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModuleJNIHelper.h --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModuleJNIHelper.h +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModuleJNIHelper.h @@ -18,6 +18,12 @@ std::string data, int messageType); + static std::string peerDecrypt( + facebook::jni::alias_ref jThis, + std::string deviceID, + std::string data, + int messageType); + static void registerNatives(); }; } // namespace comm diff --git a/native/ios/NotificationService/NotificationService.mm b/native/ios/NotificationService/NotificationService.mm --- a/native/ios/NotificationService/NotificationService.mm +++ b/native/ios/NotificationService/NotificationService.mm @@ -16,6 +16,8 @@ NSString *const encryptionFailedKey = @"encryptionFailed"; NSString *const collapseIDKey = @"collapseID"; NSString *const keyserverIDKey = @"keyserverID"; +NSString *const senderDeviceIDKey = @"senderDeviceID"; +NSString *const messageTypeKey = @"type"; NSString *const blobHashKey = @"blobHash"; NSString *const blobHolderKey = @"blobHolder"; NSString *const encryptionKeyLabel = @"encryptionKey"; @@ -504,7 +506,7 @@ - (void)calculateTotalUnreadCountInPlace: (UNMutableNotificationContent *)content { if (!content.userInfo[keyserverIDKey]) { - throw std::runtime_error("Received badge update without keyserver ID."); + return; } std::string senderKeyserverID = std::string([content.userInfo[keyserverIDKey] UTF8String]); @@ -609,17 +611,26 @@ std::string encryptedData = std::string([content.userInfo[encryptedPayloadKey] UTF8String]); - if (!content.userInfo[keyserverIDKey]) { + std::unique_ptr + decryptResult; + if (content.userInfo[keyserverIDKey]) { + std::string senderKeyserverID = + std::string([content.userInfo[keyserverIDKey] 
UTF8String]); + decryptResult = comm::NotificationsCryptoModule::statefulDecrypt( + senderKeyserverID, + encryptedData, + comm::NotificationsCryptoModule::olmEncryptedTypeMessage); + } else if ( + content.userInfo[senderDeviceIDKey] && content.userInfo[messageTypeKey]) { + std::string senderDeviceID = + std::string([content.userInfo[senderDeviceIDKey] UTF8String]); + size_t messageType = [content.userInfo[messageTypeKey] intValue]; + decryptResult = comm::NotificationsCryptoModule::statefulPeerDecrypt( + senderDeviceID, encryptedData, messageType); + } else { throw std::runtime_error( - "Received encrypted notification without keyserverID."); + "Received notification without keyserver ID nor sender device ID."); } - std::string senderKeyserverID = - std::string([content.userInfo[keyserverIDKey] UTF8String]); - - auto decryptResult = comm::NotificationsCryptoModule::statefulDecrypt( - senderKeyserverID, - encryptedData, - comm::NotificationsCryptoModule::olmEncryptedTypeMessage); NSString *decryptedSerializedPayload = [NSString stringWithUTF8String:decryptResult->getDecryptedData().c_str()];