diff --git a/native/android/app/src/cpp/CommMMKV.cpp b/native/android/app/src/cpp/CommMMKV.cpp index b3eaeeeca..afd632e4c 100644 --- a/native/android/app/src/cpp/CommMMKV.cpp +++ b/native/android/app/src/cpp/CommMMKV.cpp @@ -1,140 +1,160 @@ #include "jniHelpers.h" #include #include using namespace facebook::jni; class CommMMKVJavaClass : public JavaClass { public: static auto constexpr kJavaDescriptor = "Lapp/comm/android/fbjni/CommMMKV;"; static void initialize() { static const auto cls = javaClassStatic(); static auto method = cls->getStaticMethod("initialize"); method(cls); } + static void lock() { + static const auto cls = javaClassStatic(); + static auto method = cls->getStaticMethod("lock"); + method(cls); + } + + static void unlock() { + static const auto cls = javaClassStatic(); + static auto method = cls->getStaticMethod("unlock"); + method(cls); + } + static void clearSensitiveData() { static const auto cls = javaClassStatic(); static auto method = cls->getStaticMethod("clearSensitiveData"); method(cls); } static bool setString(std::string key, std::string value) { static const auto cls = javaClassStatic(); static auto method = cls->getStaticMethod("setString"); return method(cls, key, value); } static std::optional getString(std::string key) { static const auto cls = javaClassStatic(); static auto method = cls->getStaticMethod("getString"); const auto result = method(cls, key); if (result) { return result->toStdString(); } return std::nullopt; } static bool setInt(std::string key, int value) { static const auto cls = javaClassStatic(); static auto method = cls->getStaticMethod("setInt"); return method(cls, key, value); } static std::optional getInt(std::string key, int noValue) { static const auto cls = javaClassStatic(); static auto method = cls->getStaticMethod("getInt"); const auto result = method(cls, key, noValue); if (result) { return result->value(); } return std::nullopt; } static std::vector getAllKeys() { static const auto cls = javaClassStatic(); static auto method = cls->getStaticMethod()>("getAllKeys"); auto methodResult = method(cls); std::vector result; for (int i = 0; i < methodResult->size(); i++) { result.push_back(methodResult->getElement(i)->toStdString()); } return result; } static void removeKeys(const std::vector &keys) { static const auto cls = javaClassStatic(); static auto method = cls->getStaticMethod>)>( "removeKeys"); local_ref> keysJava = JArrayClass::newArray(keys.size()); for (int i = 0; i < keys.size(); i++) { keysJava->setElement(i, *make_jstring(keys[i])); } method(cls, keysJava); } }; namespace comm { void CommMMKV::initialize() { NativeAndroidAccessProvider::runTask( []() { CommMMKVJavaClass::initialize(); }); } +CommMMKV::ScopedCommMMKVLock::ScopedCommMMKVLock() { + NativeAndroidAccessProvider::runTask([]() { CommMMKVJavaClass::lock(); }); +} + +CommMMKV::ScopedCommMMKVLock::~ScopedCommMMKVLock() { + NativeAndroidAccessProvider::runTask([]() { CommMMKVJavaClass::unlock(); }); +} + void CommMMKV::clearSensitiveData() { NativeAndroidAccessProvider::runTask( []() { CommMMKVJavaClass::clearSensitiveData(); }); } bool CommMMKV::setString(std::string key, std::string value) { bool result; NativeAndroidAccessProvider::runTask( [&]() { result = CommMMKVJavaClass::setString(key, value); }); return result; } std::optional CommMMKV::getString(std::string key) { std::optional result; NativeAndroidAccessProvider::runTask( [&]() { result = CommMMKVJavaClass::getString(key); }); return result; } bool CommMMKV::setInt(std::string key, int value) { bool result; NativeAndroidAccessProvider::runTask( [&]() { result = CommMMKVJavaClass::setInt(key, value); }); return result; } std::optional CommMMKV::getInt(std::string key, int noValue) { std::optional result; NativeAndroidAccessProvider::runTask( [&]() { result = CommMMKVJavaClass::getInt(key, noValue); }); return result; } std::vector CommMMKV::getAllKeys() { std::vector result; NativeAndroidAccessProvider::runTask( [&]() { result = CommMMKVJavaClass::getAllKeys(); }); return result; } void CommMMKV::removeKeys(const std::vector &keys) { NativeAndroidAccessProvider::runTask( [&]() { CommMMKVJavaClass::removeKeys(keys); }); } } // namespace comm diff --git a/native/android/app/src/main/java/app/comm/android/fbjni/CommMMKV.java b/native/android/app/src/main/java/app/comm/android/fbjni/CommMMKV.java index 23f81caca..74e05af9d 100644 --- a/native/android/app/src/main/java/app/comm/android/fbjni/CommMMKV.java +++ b/native/android/app/src/main/java/app/comm/android/fbjni/CommMMKV.java @@ -1,129 +1,138 @@ package app.comm.android.fbjni; import android.util.Log; import app.comm.android.MainApplication; import app.comm.android.fbjni.CommSecureStore; import app.comm.android.fbjni.PlatformSpecificTools; import com.tencent.mmkv.MMKV; import java.util.Base64; public class CommMMKV { private static final int MMKV_ENCRYPTION_KEY_SIZE = 16; private static final int MMKV_ID_SIZE = 8; private static final String SECURE_STORE_MMKV_ENCRYPTION_KEY_ID = "comm.mmkvEncryptionKey"; private static final String SECURE_STORE_MMKV_IDENTIFIER_KEY_ID = "comm.mmkvID"; private static String mmkvEncryptionKey; private static String mmkvIdentifier; private static MMKV getMMKVInstance(String mmkvID, String encryptionKey) { MMKV mmkv = MMKV.mmkvWithID(mmkvID, MMKV.MULTI_PROCESS_MODE, encryptionKey); if (mmkv == null) { throw new RuntimeException("Failed to instantiate MMKV object."); } return mmkv; } private static void assignInitializationData() { byte[] encryptionKeyBytes = PlatformSpecificTools.generateSecureRandomBytes( MMKV_ENCRYPTION_KEY_SIZE); byte[] identifierBytes = PlatformSpecificTools.generateSecureRandomBytes(MMKV_ID_SIZE); String encryptionKey = Base64.getEncoder() .encodeToString(encryptionKeyBytes) .substring(0, MMKV_ENCRYPTION_KEY_SIZE); String identifier = Base64.getEncoder() .encodeToString(identifierBytes) .substring(0, MMKV_ID_SIZE); CommSecureStore.set(SECURE_STORE_MMKV_ENCRYPTION_KEY_ID, encryptionKey); CommSecureStore.set(SECURE_STORE_MMKV_IDENTIFIER_KEY_ID, identifier); mmkvEncryptionKey = encryptionKey; mmkvIdentifier = identifier; } public static void initialize() { if (mmkvEncryptionKey != null && mmkvIdentifier != null) { return; } synchronized (CommMMKV.class) { if (mmkvEncryptionKey != null && mmkvIdentifier != null) { return; } String encryptionKey = null, identifier = null; try { encryptionKey = CommSecureStore.get(SECURE_STORE_MMKV_ENCRYPTION_KEY_ID); identifier = CommSecureStore.get(SECURE_STORE_MMKV_IDENTIFIER_KEY_ID); } catch (Exception e) { Log.w("COMM", "Failed to get MMKV keys from CommSecureStore", e); } if (encryptionKey == null || identifier == null) { assignInitializationData(); } else { mmkvEncryptionKey = encryptionKey; mmkvIdentifier = identifier; } MMKV.initialize(MainApplication.getMainApplicationContext()); getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); } } + public static void lock() { + initialize(); + getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey).lock(); + } + + public static void unlock() { + getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey).unlock(); + } + public static void clearSensitiveData() { initialize(); synchronized (mmkvEncryptionKey) { getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey).clearAll(); boolean storageRemoved = MMKV.removeStorage(mmkvIdentifier); if (!storageRemoved) { throw new RuntimeException("Failed to remove MMKV storage."); } assignInitializationData(); MMKV.initialize(MainApplication.getMainApplicationContext()); getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); } } public static boolean setString(String key, String value) { initialize(); return getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey) .encode(key, value); } public static String getString(String key) { initialize(); return getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey).decodeString(key); } public static boolean setInt(String key, int value) { initialize(); return getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey) .encode(key, value); } public static Integer getInt(String key, int noValue) { initialize(); int value = getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey) .decodeInt(key, noValue); if (value == noValue) { return null; } return value; } public static String[] getAllKeys() { initialize(); return getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey).allKeys(); } public static void removeKeys(String[] keys) { initialize(); getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey) .removeValuesForKeys(keys); } } diff --git a/native/cpp/CommonCpp/Tools/CommMMKV.h b/native/cpp/CommonCpp/Tools/CommMMKV.h index 6faef14af..3c9acc4cd 100644 --- a/native/cpp/CommonCpp/Tools/CommMMKV.h +++ b/native/cpp/CommonCpp/Tools/CommMMKV.h @@ -1,35 +1,41 @@ #pragma once #include #include #include namespace comm { class CommMMKV { public: static void initialize(); static void clearSensitiveData(); static bool setString(std::string key, std::string value); static std::optional getString(std::string key); static bool setInt(std::string key, int value); // MMKV API can't return null when we try to get integer that // doesn't exist. It allows us to set default value that is // returned instead in case the integer isn't present. The // developer should pass as `noValue` the value that they // know should never be set under certain key. Implementation // will pass `noValue` as default value and return `std::nullopt` // in case MMKV returns default value. static std::optional getInt(std::string key, int noValue); static std::vector getAllKeys(); static void removeKeys(const std::vector &keys); class InitFromNSEForbiddenError : public std::runtime_error { public: using std::runtime_error::runtime_error; }; + + class ScopedCommMMKVLock { + public: + ScopedCommMMKVLock(); + ~ScopedCommMMKVLock(); + }; }; } // namespace comm diff --git a/native/ios/Comm/CommMMKV.mm b/native/ios/Comm/CommMMKV.mm index 4814bfe34..906d7444b 100644 --- a/native/ios/Comm/CommMMKV.mm +++ b/native/ios/Comm/CommMMKV.mm @@ -1,192 +1,215 @@ #import "CommMMKV.h" #import "../../cpp/CommonCpp/CryptoTools/Tools.h" #import "CommSecureStore.h" #import "Logger.h" #import "Tools.h" #import #import +#import + +// Core MMKV C++ implementation and Android wrapper have public `lock` and +// `unlock` methods while Obj-C wrapper doesn't. However Obj-C wrapper has +// private instance variable of type `mmkv::MMKV *`. Interface redeclaration +// below lets us access it. This pattern is hacky but it is well known in +// Obj-C and used in our codebase in CommSecureStore. +@interface MMKV () { +@public + mmkv::MMKV *m_mmkv; +} +@end namespace comm { const int mmkvEncryptionKeySize = 16; const int mmkvIDsize = 8; const std::string secureStoreMMKVEncryptionKeyID = "comm.mmkvEncryptionKey"; const std::string secureStoreMMKVIdentifierKeyID = "comm.mmkvID"; static NSString *mmkvEncryptionKey; static NSString *mmkvIdentifier; MMKV *getMMKVInstance(NSString *mmkvID, NSString *encryptionKey) { MMKV *mmkv = [MMKV mmkvWithID:mmkvID cryptKey:[encryptionKey dataUsingEncoding:NSUTF8StringEncoding] mode:MMKVMultiProcess]; if (!mmkv) { throw std::runtime_error("Failed to instantiate MMKV object."); } return mmkv; } +CommMMKV::ScopedCommMMKVLock::ScopedCommMMKVLock() { + CommMMKV::initialize(); + MMKV *mmkv = getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); + mmkv->m_mmkv->lock(); +} + +CommMMKV::ScopedCommMMKVLock::~ScopedCommMMKVLock() { + MMKV *mmkv = getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); + mmkv->m_mmkv->unlock(); +} + void assignInitializationData() { std::string encryptionKey = crypto::Tools::generateRandomString(mmkvEncryptionKeySize); std::string identifier = crypto::Tools::generateRandomString(mmkvIDsize); CommSecureStore::set(secureStoreMMKVEncryptionKeyID, encryptionKey); CommSecureStore::set(secureStoreMMKVIdentifierKeyID, identifier); mmkvEncryptionKey = [NSString stringWithCString:encryptionKey.c_str() encoding:NSUTF8StringEncoding]; mmkvIdentifier = [NSString stringWithCString:identifier.c_str() encoding:NSUTF8StringEncoding]; } void CommMMKV::initialize() { // This way of checking if we are running in app extension is // taken from MMKV implementation. See the code linked below: // https://github.com/Tencent/MMKV/blob/master/iOS/MMKV/MMKV/libMMKV.mm#L109 bool isRunningInAppExtension = [[[NSBundle mainBundle] bundlePath] hasSuffix:@".appex"]; void (^initializeBlock)(void) = ^{ auto maybeEncryptionKey = CommSecureStore::get(secureStoreMMKVEncryptionKeyID); auto maybeIdentifier = CommSecureStore::get(secureStoreMMKVIdentifierKeyID); if (maybeEncryptionKey.hasValue() && maybeIdentifier.hasValue()) { mmkvEncryptionKey = [NSString stringWithCString:maybeEncryptionKey.value().c_str() encoding:NSUTF8StringEncoding]; mmkvIdentifier = [NSString stringWithCString:maybeIdentifier.value().c_str() encoding:NSUTF8StringEncoding]; } else if (!isRunningInAppExtension) { assignInitializationData(); } else { throw CommMMKV::InitFromNSEForbiddenError( std::string("NSE can't initialize MMKV encryption key.")); } [MMKV initializeMMKV:nil groupDir:[Tools getAppGroupDirectoryPath] logLevel:MMKVLogNone]; getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); }; if (isRunningInAppExtension) { initializeBlock(); return; } static dispatch_once_t onceToken; dispatch_once(&onceToken, initializeBlock); } void CommMMKV::clearSensitiveData() { CommMMKV::initialize(); @synchronized(mmkvEncryptionKey) { MMKV *mmkv = getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); [mmkv clearAll]; BOOL storageRemoved = [MMKV removeStorage:mmkvIdentifier mode:MMKVMultiProcess]; if (!storageRemoved) { throw std::runtime_error("Failed to remove mmkv storage."); } assignInitializationData(); [MMKV initializeMMKV:nil groupDir:[Tools getAppGroupDirectoryPath] logLevel:MMKVLogNone]; getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); } } bool CommMMKV::setString(std::string key, std::string value) { CommMMKV::initialize(); MMKV *mmkv = getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); BOOL result = [mmkv setString:[NSString stringWithCString:value.c_str() encoding:NSUTF8StringEncoding] forKey:[NSString stringWithCString:key.c_str() encoding:NSUTF8StringEncoding]]; if (!result) { Logger::log("Attempt to write in background or failure during write."); } return result; } std::optional CommMMKV::getString(std::string key) { CommMMKV::initialize(); MMKV *mmkv = getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); NSString *value = [mmkv getStringForKey:[NSString stringWithCString:key.c_str() encoding:NSUTF8StringEncoding]]; if (!value) { return std::nullopt; } return std::string([value UTF8String]); } bool CommMMKV::setInt(std::string key, int value) { CommMMKV::initialize(); MMKV *mmkv = getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); BOOL result = [mmkv setInt64:value forKey:[NSString stringWithCString:key.c_str() encoding:NSUTF8StringEncoding]]; if (!result) { Logger::log("Attempt to write in background or failure during write."); } return result; } std::optional CommMMKV::getInt(std::string key, int noValue) { CommMMKV::initialize(); MMKV *mmkv = getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); int value = [mmkv getInt64ForKey:[NSString stringWithCString:key.c_str() encoding:NSUTF8StringEncoding] defaultValue:noValue hasValue:nil]; if (value == noValue) { return std::nullopt; } return value; } std::vector CommMMKV::getAllKeys() { CommMMKV::initialize(); MMKV *mmkv = getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); NSArray *allKeys = [mmkv allKeys]; std::vector result; for (NSString *key in allKeys) { result.emplace_back(std::string([key UTF8String])); } return result; } void CommMMKV::removeKeys(const std::vector &keys) { CommMMKV::initialize(); MMKV *mmkv = getMMKVInstance(mmkvIdentifier, mmkvEncryptionKey); NSMutableArray *keysObjC = [[NSMutableArray alloc] init]; for (const auto &key : keys) { [keysObjC addObject:[NSString stringWithCString:key.c_str() encoding:NSUTF8StringEncoding]]; } [mmkv removeValuesForKeys:keysObjC]; } } // namespace comm