Page Menu · Home · Phabricator

D12750.diff
No One · Temporary

D12750.diff

diff --git a/native/android/app/src/cpp/NotificationsCryptoModuleJNIHelper.cpp b/native/android/app/src/cpp/NotificationsCryptoModuleJNIHelper.cpp
--- a/native/android/app/src/cpp/NotificationsCryptoModuleJNIHelper.cpp
+++ b/native/android/app/src/cpp/NotificationsCryptoModuleJNIHelper.cpp
@@ -17,12 +17,24 @@
return decryptedData;
}
+// JNI entry for Java's NotificationsCryptoModule.peerDecrypt; jThis is unused.
+std::string NotificationsCryptoModuleJNIHelper::peerDecrypt(
+ facebook::jni::alias_ref<NotificationsCryptoModuleJNIHelper> jThis,
+ std::string deviceID,
+ std::string data,
+ int messageType) {
+ return NotificationsCryptoModule::peerDecrypt(
+ deviceID, data, messageType);
+}
+
void NotificationsCryptoModuleJNIHelper::registerNatives() {
javaClassStatic()->registerNatives({
makeNativeMethod(
"olmEncryptedTypeMessage",
NotificationsCryptoModuleJNIHelper::olmEncryptedTypeMessage),
makeNativeMethod("decrypt", NotificationsCryptoModuleJNIHelper::decrypt),
+ makeNativeMethod(
+ "peerDecrypt", NotificationsCryptoModuleJNIHelper::peerDecrypt),
});
}
} // namespace comm
diff --git a/native/android/app/src/main/java/app/comm/android/fbjni/NotificationsCryptoModule.java b/native/android/app/src/main/java/app/comm/android/fbjni/NotificationsCryptoModule.java
--- a/native/android/app/src/main/java/app/comm/android/fbjni/NotificationsCryptoModule.java
+++ b/native/android/app/src/main/java/app/comm/android/fbjni/NotificationsCryptoModule.java
@@ -4,4 +4,6 @@
public static native int olmEncryptedTypeMessage();
public static native String
decrypt(String keyserverID, String data, int messageType);
+ public static native String // registered by NotificationsCryptoModuleJNIHelper::registerNatives
+ peerDecrypt(String deviceID, String data, int messageType); // decrypts a notification sent by another device, looked up by deviceID
}
diff --git a/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java b/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java
--- a/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java
+++ b/native/android/app/src/main/java/app/comm/android/notifications/CommNotificationsHandler.java
@@ -56,6 +56,8 @@
private static final String GROUP_NOTIF_IDS_KEY = "groupNotifIDs";
private static final String COLLAPSE_ID_KEY = "collapseKey";
private static final String KEYSERVER_ID_KEY = "keyserverID";
+ private static final String SENDER_DEVICE_ID_KEY = "senderDeviceID";
+ private static final String MESSAGE_TYPE_KEY = "type";
private static final String CHANNEL_ID = "default";
private static final long[] VIBRATION_SPEC = {500, 500};
private static final Map<Integer, String> NOTIF_PRIORITY_VERBOSE =
@@ -107,19 +109,18 @@
handleAlteredNotificationPriority(message);
if (StaffUtils.isStaffRelease() &&
- message.getData().get(KEYSERVER_ID_KEY) == null) {
+ message.getData().get(KEYSERVER_ID_KEY) == null &&
+ message.getData().get(SENDER_DEVICE_ID_KEY) == null) {
displayErrorMessageNotification(
- "Received notification without keyserver ID.",
- "Missing keyserver ID.",
+ "Received notification without keyserver ID nor sender device ID",
+ "Missing keyserver ID and sender device ID.",
null);
return;
}
- String senderKeyserverID = message.getData().get(KEYSERVER_ID_KEY);
-
if (message.getData().get(ENCRYPTED_PAYLOAD_KEY) != null) {
try {
- message = this.olmDecryptRemoteMessage(message, senderKeyserverID);
+ message = this.olmDecryptRemoteMessage(message);
} catch (JSONException e) {
Log.w("COMM", "Malformed notification JSON payload.", e);
return;
@@ -139,6 +140,7 @@
"Unencrypted notification",
null);
}
+
if ("1".equals(message.getData().get(ENCRYPTION_FAILED_KEY))) {
Log.w("COMM", "Received erroneously unencrypted notification.");
}
@@ -269,14 +271,15 @@
}
private void handleUnreadCountUpdate(RemoteMessage message) {
+ if (message.getData().get(KEYSERVER_ID_KEY) == null) {
+ return;
+ }
+
String badge = message.getData().get(BADGE_KEY);
if (badge == null) {
return;
}
- if (message.getData().get(KEYSERVER_ID_KEY) == null) {
- throw new RuntimeException("Received badge update without keyserver ID.");
- }
String senderKeyserverID = message.getData().get(KEYSERVER_ID_KEY);
String senderKeyserverUnreadCountKey = String.join(
MMKV_KEY_SEPARATOR,
@@ -480,15 +483,47 @@
return message;
}
- private RemoteMessage
- olmDecryptRemoteMessage(RemoteMessage message, String senderKeyserverID)
- throws JSONException, IllegalStateException {
+ /**
+ * Decrypts the olm-encrypted payload of {@code message} and returns the
+ * result of updateRemoteMessageWithDecryptedPayload.
+ *
+ * Keyserver notifications (keyserverID present) are decrypted with
+ * {@link NotificationsCryptoModule#decrypt}; peer notifications
+ * (senderDeviceID present) with {@link NotificationsCryptoModule#peerDecrypt}
+ * using the olm message type carried in the "type" field.
+ *
+ * @throws IllegalStateException if a peer notification has no "type" field
+ * @throws NumberFormatException if the "type" field is not an integer
+ * @throws RuntimeException if neither sender identifier is present
+ */
+ private RemoteMessage olmDecryptRemoteMessage(RemoteMessage message)
+ throws JSONException, IllegalStateException, NumberFormatException {
String encryptedSerializedPayload =
message.getData().get(ENCRYPTED_PAYLOAD_KEY);
- String decryptedSerializedPayload = NotificationsCryptoModule.decrypt(
- senderKeyserverID,
- encryptedSerializedPayload,
- NotificationsCryptoModule.olmEncryptedTypeMessage());
+
+ String senderKeyserverID = message.getData().get(KEYSERVER_ID_KEY);
+ String senderDeviceID = message.getData().get(SENDER_DEVICE_ID_KEY);
+ String decryptedSerializedPayload;
+ if (senderKeyserverID != null) {
+ decryptedSerializedPayload = NotificationsCryptoModule.decrypt(
+ senderKeyserverID,
+ encryptedSerializedPayload,
+ NotificationsCryptoModule.olmEncryptedTypeMessage());
+ } else if (senderDeviceID != null) {
+ String messageTypeString = message.getData().get(MESSAGE_TYPE_KEY);
+ // Integer.parseInt(null) would throw an opaque NumberFormatException;
+ // fail with a message that names the missing field instead.
+ if (messageTypeString == null) {
+ throw new IllegalStateException(
+ "Received notification from sender device ID without message type.");
+ }
+ int messageType = Integer.parseInt(messageTypeString);
+ decryptedSerializedPayload = NotificationsCryptoModule.peerDecrypt(
+ senderDeviceID, encryptedSerializedPayload, messageType);
+ } else {
+ throw new RuntimeException(
+ "Received notification without keyserver ID nor sender device ID.");
+ }
return updateRemoteMessageWithDecryptedPayload(
message, decryptedSerializedPayload);
diff --git a/native/cpp/CommonCpp/CryptoTools/Session.cpp b/native/cpp/CommonCpp/CryptoTools/Session.cpp
--- a/native/cpp/CommonCpp/CryptoTools/Session.cpp
+++ b/native/cpp/CommonCpp/CryptoTools/Session.cpp
@@ -83,9 +83,11 @@
session->olmSessionBuffer.resize(::olm_session_size());
::olm_session(session->olmSessionBuffer.data());
if (-1 ==
- ::olm_create_inbound_session(
+ ::olm_create_inbound_session_from( // _from variant: also takes the sender identity key (per libolm docs, checked against the pre-key message)
session->getOlmSession(),
account,
+ idKeys.data() + ID_KEYS_PREFIX_OFFSET, // NOTE(review): assumes the sender identity key starts at this offset in idKeys — confirm
+ KEYSIZE, // length of that identity-key slice
tmpEncryptedMessage.data(),
encryptedMessage.size())) {
throw std::runtime_error(
diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h
--- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h
+++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.h
@@ -114,6 +114,49 @@
void flushState() override;
};
+ class StatefulPeerInitDecryptResult : public BaseStatefulDecryptResult { // flushState persists the new session and the account
+ StatefulPeerInitDecryptResult(
+ std::shared_ptr<crypto::Session> session,
+ std::shared_ptr<crypto::CryptoModule> account,
+ std::string sessionPicklingKey,
+ std::string accountPicklingKey,
+ std::string deviceID,
+ std::string decryptedData);
+ std::shared_ptr<crypto::Session> sessionState;
+ std::shared_ptr<crypto::CryptoModule> accountState;
+ std::string accountPicklingKey;
+ std::string deviceID;
+ friend NotificationsCryptoModule; // ctor is private; only the module constructs these
+
+ public:
+ void flushState() override;
+ };
+
+ class StatefulPeerDecryptResult : public BaseStatefulDecryptResult { // flushState persists the existing session
+ StatefulPeerDecryptResult(
+ std::unique_ptr<crypto::Session> session,
+ std::string deviceID,
+ std::string picklingKey,
+ std::string decryptedData);
+
+ std::unique_ptr<crypto::Session> sessionState;
+ std::string deviceID;
+ friend NotificationsCryptoModule; // ctor is private; only the module constructs these
+
+ public:
+ void flushState() override;
+ };
+
+ class StatefulPeerConflictDecryptResult : public BaseStatefulDecryptResult { // race won: flushState persists nothing
+ StatefulPeerConflictDecryptResult(
+ std::string picklingKey,
+ std::string decryptedData);
+ friend NotificationsCryptoModule; // ctor is private; only the module constructs these
+
+ public:
+ void flushState() override;
+ };
+
private:
static std::unique_ptr<NotificationsCryptoModule::BaseStatefulDecryptResult>
prepareLegacyDecryptedState(
@@ -134,6 +177,16 @@
static crypto::EncryptedData
encrypt(const std::string &deviceID, const std::string &payload);
+ static std::unique_ptr<BaseStatefulDecryptResult> statefulPeerDecrypt( // state is written only when flushState() is called
+ const std::string &deviceID,
+ const std::string &data,
+ const size_t messageType); // must be OLM_MESSAGE_TYPE_MESSAGE or OLM_MESSAGE_TYPE_PRE_KEY
+
+ static std::string peerDecrypt( // statefulPeerDecrypt followed by flushState
+ const std::string &deviceID,
+ const std::string &data,
+ const size_t messageType);
+
static void
flushState(std::unique_ptr<BaseStatefulDecryptResult> statefulDecryptResult);
};
diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp
--- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp
+++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp
@@ -3,8 +3,9 @@
#include "../../CryptoTools/Tools.h"
#include "../../Tools/CommMMKV.h"
#include "../../Tools/CommSecureStore.h"
-#include "../../Tools/Logger.h"
#include "../../Tools/PlatformSpecificTools.h"
+#include "NotificationsInboundKeysProvider.h"
+#include "olm/session.hh"
#include "Logger.h"
#include <fcntl.h>
@@ -445,6 +446,64 @@
}
}
+// Result for a message decrypted by a freshly created inbound session; on
+// flush both that session and the notifications account it came from are
+// persisted. All arguments are by-value sinks, so they are moved into place.
+NotificationsCryptoModule::StatefulPeerInitDecryptResult::
+ StatefulPeerInitDecryptResult(
+ std::shared_ptr<crypto::Session> session,
+ std::shared_ptr<crypto::CryptoModule> account,
+ std::string sessionPicklingKey,
+ std::string accountPicklingKey,
+ std::string deviceID,
+ std::string decryptedData)
+ : BaseStatefulDecryptResult(
+ std::move(sessionPicklingKey),
+ std::move(decryptedData)),
+ sessionState(std::move(session)),
+ accountState(std::move(account)),
+ accountPicklingKey(std::move(accountPicklingKey)),
+ deviceID(std::move(deviceID)) {
+}
+
+void NotificationsCryptoModule::StatefulPeerInitDecryptResult::flushState() {
+ NotificationsCryptoModule::persistNotificationsSessionInternal(
+ false, this->deviceID, this->picklingKey, std::move(this->sessionState));
+ NotificationsCryptoModule::persistNotificationsAccount(
+ std::move(this->accountState), this->accountPicklingKey);
+}
+
+// Result for a message decrypted with an already persisted peer session;
+// on flush the session is written back under the same device ID.
+NotificationsCryptoModule::StatefulPeerDecryptResult::StatefulPeerDecryptResult(
+ std::unique_ptr<crypto::Session> session,
+ std::string deviceID,
+ std::string picklingKey,
+ std::string decryptedData)
+ : BaseStatefulDecryptResult(std::move(picklingKey), std::move(decryptedData)),
+ sessionState(std::move(session)),
+ deviceID(std::move(deviceID)) {
+}
+
+void NotificationsCryptoModule::StatefulPeerDecryptResult::flushState() {
+ NotificationsCryptoModule::persistNotificationsSessionInternal(
+ false, this->deviceID, this->picklingKey, std::move(this->sessionState));
+}
+
+// Result for a session race this device won: only the plaintext is kept,
+// nothing new is persisted and the existing session stays in place.
+NotificationsCryptoModule::StatefulPeerConflictDecryptResult::
+ StatefulPeerConflictDecryptResult(
+ std::string picklingKey,
+ std::string decryptedData)
+ : BaseStatefulDecryptResult(std::move(picklingKey), std::move(decryptedData)) {
+}
+
+void NotificationsCryptoModule::StatefulPeerConflictDecryptResult::
+ flushState() {
+ // Intentionally empty: this device keeps its existing session.
+}
+
std::unique_ptr<NotificationsCryptoModule::BaseStatefulDecryptResult>
NotificationsCryptoModule::prepareLegacyDecryptedState(
const std::string &data,
@@ -550,6 +602,142 @@
std::move(statefulDecryptResult));
}
+std::unique_ptr<NotificationsCryptoModule::BaseStatefulDecryptResult>
+NotificationsCryptoModule::statefulPeerDecrypt(
+ const std::string &deviceID,
+ const std::string &data,
+ const size_t messageType) {
+ if (messageType != OLM_MESSAGE_TYPE_MESSAGE &&
+ messageType != OLM_MESSAGE_TYPE_PRE_KEY) {
+ throw std::runtime_error(
+ "Received message of invalid type from device: " + deviceID);
+ }
+
+ auto maybeSessionWithPicklingKey =
+ NotificationsCryptoModule::fetchNotificationsSession(false, deviceID);
+
+ if (!maybeSessionWithPicklingKey.has_value() &&
+ messageType == OLM_MESSAGE_TYPE_MESSAGE) {
+ throw std::runtime_error(
+ "Received MESSAGE_TYPE_MESSAGE message from device: " + deviceID +
+ " but session not initialized.");
+ }
+
+ crypto::EncryptedData encryptedData{
+ std::vector<uint8_t>(data.begin(), data.end()), messageType};
+
+ bool isSenderChainEmpty = true;
+ bool hasReceivedMessage = false;
+ bool sessionExists = maybeSessionWithPicklingKey.has_value();
+
+ if (sessionExists) {
+ ::olm::Session *olmSessionAsCppClass = reinterpret_cast<::olm::Session *>( // NOTE(review): relies on libolm internals (OlmSession is an ::olm::Session); re-check on libolm upgrades
+ maybeSessionWithPicklingKey.value().first->getOlmSession());
+ isSenderChainEmpty = olmSessionAsCppClass->ratchet.sender_chain.empty();
+ hasReceivedMessage = olmSessionAsCppClass->received_message;
+ }
+
+ // Regular case: decrypt with the existing session (MESSAGE, or PRE_KEY when its sender chain is empty and it has already received).
+ bool isRegularMessage =
+ sessionExists && messageType == OLM_MESSAGE_TYPE_MESSAGE;
+
+ bool isRegularPrekeyMessage = sessionExists &&
+ messageType == OLM_MESSAGE_TYPE_PRE_KEY && isSenderChainEmpty &&
+ hasReceivedMessage;
+
+ if (isRegularMessage || isRegularPrekeyMessage) {
+ std::string decryptedData =
+ maybeSessionWithPicklingKey.value().first->decrypt(encryptedData);
+ StatefulPeerDecryptResult decryptResult = StatefulPeerDecryptResult(
+ std::move(maybeSessionWithPicklingKey.value().first),
+ deviceID,
+ maybeSessionWithPicklingKey.value().second,
+ decryptedData);
+ return std::make_unique<StatefulPeerDecryptResult>( // main ctor is private, so build above and move in
+ std::move(decryptResult));
+ }
+
+ // At this point we face either a race condition, a session reset attempt or
+ // a session initialization attempt. Each of these scenarios requires a new
+ // inbound session in order to decrypt the message.
+ std::string notifInboundKeys =
+ NotificationsInboundKeysProvider::getNotifsInboundKeysForDeviceID(
+ deviceID);
+ auto maybeAccountWithPicklingKey =
+ NotificationsCryptoModule::fetchNotificationsAccount();
+
+ if (!maybeAccountWithPicklingKey.has_value()) {
+ throw std::runtime_error("Notifications account not initialized.");
+ }
+
+ auto accountWithPicklingKey = maybeAccountWithPicklingKey.value();
+ accountWithPicklingKey.first->initializeInboundForReceivingSession(
+ deviceID,
+ {data.begin(), data.end()},
+ {notifInboundKeys.begin(), notifInboundKeys.end()},
+ // The argument below is relevant for content only
+ 0,
+ true);
+ std::shared_ptr<crypto::Session> newInboundSession =
+ accountWithPicklingKey.first->getSessionByDeviceId(deviceID);
+ accountWithPicklingKey.first->removeSessionByDeviceId(deviceID); // newInboundSession keeps it alive; the account no longer holds it
+ std::string decryptedData = newInboundSession->decrypt(encryptedData);
+
+ // session reset attempt or session initialization - handled the same
+ bool sessionResetAttempt =
+ sessionExists && !isSenderChainEmpty && hasReceivedMessage;
+
+ // race condition: existing session has a non-empty sender chain but has not received
+ bool raceCondition =
+ sessionExists && !isSenderChainEmpty && !hasReceivedMessage;
+
+ // device ID comparison: the greater device ID keeps its existing session
+ folly::Optional<std::string> maybeOurDeviceID =
+ CommSecureStore::get(CommSecureStore::deviceID);
+ if (!maybeOurDeviceID.hasValue()) {
+ throw std::runtime_error("Session creation attempt but no device id");
+ }
+ std::string ourDeviceID = maybeOurDeviceID.value();
+ bool thisDeviceWinsRaceCondition = ourDeviceID > deviceID; // lexicographic std::string comparison
+
+ // If there is no session, a session reset attempt, or a race condition in
+ // which we lose the device ID comparison, the new inbound session becomes
+ // this device's session for the peer.
+ if (!sessionExists || sessionResetAttempt ||
+ (raceCondition && !thisDeviceWinsRaceCondition)) {
+ std::string sessionPicklingKey = crypto::Tools::generateRandomString(64);
+ StatefulPeerInitDecryptResult decryptResult = StatefulPeerInitDecryptResult(
+ newInboundSession,
+ accountWithPicklingKey.first,
+ sessionPicklingKey,
+ accountWithPicklingKey.second,
+ deviceID,
+ decryptedData);
+ return std::make_unique<StatefulPeerInitDecryptResult>(
+ std::move(decryptResult));
+ }
+
+ // If there is a race condition but we win device id comparison
+ // we return object that carries decrypted data but won't persist
+ // any session state
+ StatefulPeerConflictDecryptResult decryptResult =
+ StatefulPeerConflictDecryptResult(
+ maybeSessionWithPicklingKey.value().second, decryptedData);
+ return std::make_unique<StatefulPeerConflictDecryptResult>(
+ std::move(decryptResult));
+}
+
+std::string NotificationsCryptoModule::peerDecrypt(
+ const std::string &deviceID,
+ const std::string &data,
+ const size_t messageType) {
+ // Decrypt, copy out the plaintext, then persist whatever state changed.
+ auto result = statefulPeerDecrypt(deviceID, data, messageType);
+ std::string plaintext = result->decryptedData;
+ result->flushState();
+ return plaintext;
+}
+
void NotificationsCryptoModule::flushState(
std::unique_ptr<BaseStatefulDecryptResult> baseStatefulDecryptResult) {
baseStatefulDecryptResult->flushState();
diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModuleJNIHelper.h b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModuleJNIHelper.h
--- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModuleJNIHelper.h
+++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModuleJNIHelper.h
@@ -18,6 +18,12 @@
std::string data,
int messageType);
+ static std::string peerDecrypt( // registered as Java NotificationsCryptoModule.peerDecrypt
+ facebook::jni::alias_ref<NotificationsCryptoModuleJNIHelper> jThis,
+ std::string deviceID,
+ std::string data,
+ int messageType); // converted to size_t for NotificationsCryptoModule::peerDecrypt
+
static void registerNatives();
};
} // namespace comm
diff --git a/native/ios/NotificationService/NotificationService.mm b/native/ios/NotificationService/NotificationService.mm
--- a/native/ios/NotificationService/NotificationService.mm
+++ b/native/ios/NotificationService/NotificationService.mm
@@ -16,6 +16,8 @@
NSString *const encryptionFailedKey = @"encryptionFailed";
NSString *const collapseIDKey = @"collapseID";
NSString *const keyserverIDKey = @"keyserverID";
+NSString *const senderDeviceIDKey = @"senderDeviceID";
+NSString *const messageTypeKey = @"type";
NSString *const blobHashKey = @"blobHash";
NSString *const blobHolderKey = @"blobHolder";
NSString *const encryptionKeyLabel = @"encryptionKey";
@@ -504,7 +506,7 @@
- (void)calculateTotalUnreadCountInPlace:
(UNMutableNotificationContent *)content {
if (!content.userInfo[keyserverIDKey]) {
- throw std::runtime_error("Received badge update without keyserver ID.");
+ return;
}
std::string senderKeyserverID =
std::string([content.userInfo[keyserverIDKey] UTF8String]);
@@ -609,17 +611,31 @@
std::string encryptedData =
std::string([content.userInfo[encryptedPayloadKey] UTF8String]);
- if (!content.userInfo[keyserverIDKey]) {
+ std::unique_ptr<comm::NotificationsCryptoModule::BaseStatefulDecryptResult>
+ decryptResult;
+ if (content.userInfo[keyserverIDKey]) {
+ std::string senderKeyserverID =
+ std::string([content.userInfo[keyserverIDKey] UTF8String]);
+ decryptResult = comm::NotificationsCryptoModule::statefulDecrypt(
+ senderKeyserverID,
+ encryptedData,
+ comm::NotificationsCryptoModule::olmEncryptedTypeMessage);
+ } else if (content.userInfo[senderDeviceIDKey]) {
+ // A peer notification without its olm message type cannot be decrypted;
+ // report that precisely instead of falling through to the generic error.
+ if (!content.userInfo[messageTypeKey]) {
+ throw std::runtime_error(
+ "Received notification from sender device ID without message type.");
+ }
+ std::string senderDeviceID =
+ std::string([content.userInfo[senderDeviceIDKey] UTF8String]);
+ size_t messageType = [content.userInfo[messageTypeKey] intValue];
+ decryptResult = comm::NotificationsCryptoModule::statefulPeerDecrypt(
+ senderDeviceID, encryptedData, messageType);
+ } else {
throw std::runtime_error(
- "Received encrypted notification without keyserverID.");
+ "Received notification without keyserver ID nor sender device ID.");
}
- std::string senderKeyserverID =
- std::string([content.userInfo[keyserverIDKey] UTF8String]);
-
- auto decryptResult = comm::NotificationsCryptoModule::statefulDecrypt(
- senderKeyserverID,
- encryptedData,
- comm::NotificationsCryptoModule::olmEncryptedTypeMessage);
NSString *decryptedSerializedPayload =
[NSString stringWithUTF8String:decryptResult->getDecryptedData().c_str()];

File Metadata

Mime Type
text/plain
Expires
Fri, Nov 22, 7:25 PM (17 h, 54 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
2562477
Default Alt Text
D12750.diff (20 KB)

Event Timeline