# Patch: encrypt SQLite backup-log patchsets before they are written to disk.
#
# NOTE(review): this copy of the patch has had its newlines collapsed and all
# text between angle brackets stripped (e.g. `#include` with no header,
# `rust::Slice buffer`, `template class AESCrypto>;`, `std::vector logBytes(`).
# It will not apply with `git apply`/`patch` as-is; regenerate it from the
# source tree rather than hand-repairing it.
#
# Summary of changes:
#  - android/app/CMakeLists.txt: adds -DSQLITE_ENABLE_SESSION and
#    -DSQLITE_ENABLE_PREUPDATE_HOOK next to the existing SQLCipher defines.
#  - Tools/AESCrypto.h, android AESCrypto.cpp, ios AESCrypto.mm:
#    - AESCrypto::generateKey/encrypt/decrypt become members of a class
#      template over the buffer type T. The Android JNI wrapper class is
#      templated the same way.
#    - Member definitions stay in the platform .cpp/.mm files, each ending in
#      two explicit instantiations. Only those two buffer types will link.
#  - RustAESCrypto.cpp: the Rust bridge now names the explicit instantiation
#    it uses.
#  - SQLiteConnectionManager: persistLog and captureLogs take an extra
#    `encryptionKey` argument.
#    - persistLog encrypts the patchset into a buffer sized
#      changesetSize + IV_LENGTH (12) + TAG_LENGTH (16).
#    - It writes that buffer to "<final>_tmp" and renames it into place.
#    - captureLogs' early-return guard changes from
#      `changesetSize == 0 && !changesetPtr` to
#      `changesetSize == 0 || !changesetPtr`.
#  - SQLiteQueryExecutor: adds a second secure-store entry,
#    "comm.backupLogsEncryptionKey" (backupLogsEncryptionKeySize = 32).
#    - Initialization loads it.
#    - assign_encryption_key() generates and stores it.
#    - captureBackupLogs() passes it to captureLogs().
#
# Review notes (please confirm before merging):
#  - Upgrade path: initialization now returns early only if the DB file,
#    "comm.encryptionKey" AND "comm.backupLogsEncryptionKey" all exist.
#    - Otherwise it calls assign_encryption_key(). That function also
#      regenerates and overwrites the SQLCipher key
#      (CommSecureStore::set(secureStoreEncryptionKeyID, ...)).
#    - So on an existing install that predates this patch, the stored DB key
#      would presumably be replaced. Verify how the code outside this diff
#      reacts to a DB it can no longer open.
#    - Consider generating only the missing backup-logs key in that case.
#  - Leak of the patchset buffer: in captureLogs, sqlite3_free(changesetPtr)
#    runs only after persistLog returns normally.
#    - If persistLog throws, the buffer leaks. It can throw on a failed
#      open, a failed rename, or anything AESCrypto::encrypt throws.
#    - If SQLite ever hands back a non-null pointer with size 0, the new `||`
#      guard returns without freeing it.
#    - A scope guard / unique_ptr with sqlite3_free as deleter would cover
#      both cases.
#  - Key encoding: generateRandomHexString(backupLogsEncryptionKeySize)
#    produces the key string.
#    - persistLog copies its characters byte-for-byte into
#      encryptionKeyBytes; the hex string is not decoded.
#    - If the helper returns 32 hex characters, the 32-byte AES key carries
#      only ~128 bits of entropy. Confirm the helper's output length and
#      whether hex-decoding was intended.
#  - Write errors: persistLog does not check the stream state after
#    `tempFile << ...` or `tempFile.close()` before renaming.
#    - A short or failed write could therefore be renamed into place as a
#      truncated ciphertext.
diff --git a/native/android/app/CMakeLists.txt b/native/android/app/CMakeLists.txt --- a/native/android/app/CMakeLists.txt +++ b/native/android/app/CMakeLists.txt @@ -217,6 +217,8 @@ -DSQLITE_HAS_CODEC -DSQLITE_TEMP_STORE=2 -DSQLCIPHER_CRYPTO_OPENSSL + -DSQLITE_ENABLE_SESSION + -DSQLITE_ENABLE_PREUPDATE_HOOK ) target_link_libraries( diff --git a/native/android/app/src/cpp/AESCrypto.cpp b/native/android/app/src/cpp/AESCrypto.cpp --- a/native/android/app/src/cpp/AESCrypto.cpp +++ b/native/android/app/src/cpp/AESCrypto.cpp @@ -5,26 +5,27 @@ using namespace facebook::jni; -class AESCryptoJavaClass : public JavaClass { +template +class AESCryptoJavaClass : public JavaClass> { public: // app.comm.android.aescrypto.AESCryptoModuleCompat static auto constexpr kJavaDescriptor = "Lapp/comm/android/aescrypto/AESCryptoModuleCompat;"; - static void generateKey(rust::Slice buffer) { + using JavaClass>::javaClassStatic; + + static void generateKey(T buffer) { local_ref byteBuffer = JByteBuffer::wrapBytes(buffer.data(), buffer.size()); static const auto cls = javaClassStatic(); static auto method = - cls->getStaticMethod)>("generateKey"); + cls->template getStaticMethod)>( + "generateKey"); method(cls, byteBuffer); } - static void encrypt( - rust::Slice key, - rust::Slice plaintext, - rust::Slice sealedData) { + static void encrypt(T key, T plaintext, T sealedData) { local_ref keyBuffer = JByteBuffer::wrapBytes(key.data(), key.size()); local_ref plaintextBuffer = @@ -32,17 +33,14 @@ local_ref sealedDataBuffer = JByteBuffer::wrapBytes(sealedData.data(), sealedData.size()); static const auto cls = javaClassStatic(); - static auto method = cls->getStaticMethodtemplate getStaticMethod, local_ref, local_ref)>("encrypt"); method(cls, keyBuffer, plaintextBuffer, sealedDataBuffer); } - static void decrypt( - rust::Slice key, - rust::Slice sealedData, - rust::Slice plaintext) { + static void decrypt(T key, T sealedData, T plaintext) { local_ref keyBuffer =
JByteBuffer::wrapBytes(key.data(), key.size()); local_ref sealedDataBuffer = @@ -50,7 +48,7 @@ local_ref plaintextBuffer = JByteBuffer::wrapBytes(plaintext.data(), plaintext.size()); static const auto cls = javaClassStatic(); - static auto method = cls->getStaticMethodtemplate getStaticMethod, local_ref, local_ref)>("decrypt"); @@ -60,25 +58,24 @@ namespace comm { -void AESCrypto::generateKey(rust::Slice buffer) { +template void AESCrypto::generateKey(T buffer) { NativeAndroidAccessProvider::runTask( - [&]() { AESCryptoJavaClass::generateKey(buffer); }); + [&]() { AESCryptoJavaClass::generateKey(buffer); }); } -void AESCrypto::encrypt( - rust::Slice key, - rust::Slice plaintext, - rust::Slice sealedData) { +template +void AESCrypto::encrypt(T key, T plaintext, T sealedData) { NativeAndroidAccessProvider::runTask( - [&]() { AESCryptoJavaClass::encrypt(key, plaintext, sealedData); }); + [&]() { AESCryptoJavaClass::encrypt(key, plaintext, sealedData); }); } -void AESCrypto::decrypt( - rust::Slice key, - rust::Slice sealedData, - rust::Slice plaintext) { +template +void AESCrypto::decrypt(T key, T sealedData, T plaintext) { NativeAndroidAccessProvider::runTask( - [&]() { AESCryptoJavaClass::decrypt(key, sealedData, plaintext); }); + [&]() { AESCryptoJavaClass::decrypt(key, sealedData, plaintext); }); } +template class AESCrypto>; +template class AESCrypto &>; + } // namespace comm diff --git a/native/cpp/CommonCpp/DatabaseManagers/SQLiteConnectionManager.h b/native/cpp/CommonCpp/DatabaseManagers/SQLiteConnectionManager.h --- a/native/cpp/CommonCpp/DatabaseManagers/SQLiteConnectionManager.h +++ b/native/cpp/CommonCpp/DatabaseManagers/SQLiteConnectionManager.h @@ -19,7 +19,8 @@ std::string backupID, std::string logID, std::uint8_t *changesetPtr, - int changesetSize); + int changesetSize, + std::string encryptionKey); std::vector getAttachmentsFromLog(std::uint8_t *changesetPtr, int changesetSize); @@ -32,6 +33,9 @@ void closeConnection(); ~SQLiteConnectionManager(); bool
shouldIncrementLogID(std::string backupID, std::string logID); - bool captureLogs(std::string backupID, std::string logID); + bool captureLogs( + std::string backupID, + std::string logID, + std::string encryptionKey); }; } // namespace comm diff --git a/native/cpp/CommonCpp/DatabaseManagers/SQLiteConnectionManager.cpp b/native/cpp/CommonCpp/DatabaseManagers/SQLiteConnectionManager.cpp --- a/native/cpp/CommonCpp/DatabaseManagers/SQLiteConnectionManager.cpp +++ b/native/cpp/CommonCpp/DatabaseManagers/SQLiteConnectionManager.cpp @@ -1,4 +1,5 @@ #include "SQLiteConnectionManager.h" +#include "AESCrypto.h" #include "PlatformSpecificTools.h" #include @@ -12,6 +13,9 @@ const int MAX_LOGS_BUFFER_SIZE = 5; const std::string BLOB_SERVICE_PREFIX = "comm-blob-service://"; +const int IV_LENGTH = 12; +const int TAG_LENGTH = 16; + void SQLiteConnectionManager::attachSession() { int sessionCreationResult = sqlite3session_create(dbConnection, "main", &backupLogsSession); @@ -35,7 +39,8 @@ std::string backupID, std::string logID, std::uint8_t *changesetPtr, - int changesetSize) { + int changesetSize, + std::string encryptionKey) { std::string finalFilePath = PlatformSpecificTools::getBackupLogFilePath(backupID, logID, false); std::string tempFilePath = finalFilePath + "_tmp"; @@ -46,8 +51,22 @@ if (!tempFile.is_open()) { throw std::runtime_error("Failed to open temporary log file."); } - tempFile << std::string(changesetPtr, changesetPtr + changesetSize); + + std::vector logBytes( + changesetPtr, changesetPtr + changesetSize); + + std::vector encryptedLog; + encryptedLog.resize(logBytes.size() + IV_LENGTH + TAG_LENGTH); + + std::vector encryptionKeyBytes( + encryptionKey.begin(), encryptionKey.end()); + + AESCrypto &>::encrypt( + encryptionKeyBytes, logBytes, encryptedLog); + + tempFile << std::string(encryptedLog.begin(), encryptedLog.end()); tempFile.close(); + if (std::rename(tempFilePath.c_str(), finalFilePath.c_str())) { throw std::runtime_error( "Failed to rename complete log
file from temporary path to target " @@ -238,7 +257,8 @@ bool SQLiteConnectionManager::captureLogs( std::string backupID, - std::string logID) { + std::string logID, + std::string encryptionKey) { int changesetSize; std::uint8_t *changesetPtr; int getChangesetResult = sqlite3session_patchset( @@ -250,11 +270,11 @@ std::string(sqlite3_errstr(getChangesetResult))); } - if (changesetSize == 0 && !changesetPtr) { + if (changesetSize == 0 || !changesetPtr) { return false; } - persistLog(backupID, logID, changesetPtr, changesetSize); + persistLog(backupID, logID, changesetPtr, changesetSize, encryptionKey); sqlite3_free(changesetPtr); logsCount++; diff --git a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.h b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.h --- a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.h +++ b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.h @@ -19,6 +19,9 @@ static std::once_flag initialized; static int sqlcipherEncryptionKeySize; static std::string secureStoreEncryptionKeyID; + static int backupLogsEncryptionKeySize; + static std::string secureStoreBackupLogsEncryptionKeyID; + static std::string backupLogsEncryptionKey; #ifndef EMSCRIPTEN static SQLiteConnectionManager connectionManager; diff --git a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp --- a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp +++ b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp @@ -23,6 +23,10 @@ int SQLiteQueryExecutor::sqlcipherEncryptionKeySize = 64; std::string SQLiteQueryExecutor::secureStoreEncryptionKeyID = "comm.encryptionKey"; +int SQLiteQueryExecutor::backupLogsEncryptionKeySize = 32; +std::string SQLiteQueryExecutor::secureStoreBackupLogsEncryptionKeyID = + "comm.backupLogsEncryptionKey"; +std::string SQLiteQueryExecutor::backupLogsEncryptionKey; #ifndef EMSCRIPTEN SQLiteConnectionManager
SQLiteQueryExecutor::connectionManager; @@ -1591,9 +1595,15 @@ SQLiteQueryExecutor::sqliteFilePath = databasePath; folly::Optional maybeEncryptionKey = CommSecureStore::get(SQLiteQueryExecutor::secureStoreEncryptionKeyID); + folly::Optional maybeBackupLogsEncryptionKey = + CommSecureStore::get( + SQLiteQueryExecutor::secureStoreBackupLogsEncryptionKeyID); - if (file_exists(databasePath) && maybeEncryptionKey) { + if (file_exists(databasePath) && maybeEncryptionKey && + maybeBackupLogsEncryptionKey) { SQLiteQueryExecutor::encryptionKey = maybeEncryptionKey.value(); + SQLiteQueryExecutor::backupLogsEncryptionKey = + maybeBackupLogsEncryptionKey.value(); return; } SQLiteQueryExecutor::assign_encryption_key(); @@ -1694,9 +1704,16 @@ void SQLiteQueryExecutor::assign_encryption_key() { std::string encryptionKey = comm::crypto::Tools::generateRandomHexString( SQLiteQueryExecutor::sqlcipherEncryptionKeySize); + std::string backupLogsEncryptionKey = + comm::crypto::Tools::generateRandomHexString( + SQLiteQueryExecutor::backupLogsEncryptionKeySize); CommSecureStore::set( SQLiteQueryExecutor::secureStoreEncryptionKeyID, encryptionKey); + CommSecureStore::set( + SQLiteQueryExecutor::secureStoreBackupLogsEncryptionKeyID, + backupLogsEncryptionKey); SQLiteQueryExecutor::encryptionKey = encryptionKey; + SQLiteQueryExecutor::backupLogsEncryptionKey = backupLogsEncryptionKey; } void SQLiteQueryExecutor::captureBackupLogs() const { @@ -1717,7 +1734,8 @@ } bool shouldIncrementLogID = - SQLiteQueryExecutor::connectionManager.captureLogs(backupID, logID); + SQLiteQueryExecutor::connectionManager.captureLogs( + backupID, logID, SQLiteQueryExecutor::backupLogsEncryptionKey); if (shouldIncrementLogID) { this->setMetadata("logID", std::to_string(std::stoi(logID) + 1)); } diff --git a/native/cpp/CommonCpp/Tools/AESCrypto.h b/native/cpp/CommonCpp/Tools/AESCrypto.h --- a/native/cpp/CommonCpp/Tools/AESCrypto.h +++ b/native/cpp/CommonCpp/Tools/AESCrypto.h @@ -4,17 +4,11 @@ namespace comm {
-class AESCrypto { +template class AESCrypto { public: - static void generateKey(rust::Slice buffer); - static void encrypt( - rust::Slice key, - rust::Slice plaintext, - rust::Slice sealedData); - static void decrypt( - rust::Slice key, - rust::Slice sealedData, - rust::Slice plaintext); + static void generateKey(T buffer); + static void encrypt(T key, T plaintext, T sealedData); + static void decrypt(T key, T sealedData, T plaintext); }; } // namespace comm diff --git a/native/ios/Comm/AESCrypto.mm b/native/ios/Comm/AESCrypto.mm --- a/native/ios/Comm/AESCrypto.mm +++ b/native/ios/Comm/AESCrypto.mm @@ -4,7 +4,7 @@ namespace comm { -void AESCrypto::generateKey(rust::Slice buffer) { +template void AESCrypto::generateKey(T buffer) { NSError *keyGenerationError = nil; [AESCryptoModuleObjCCompat generateKey:buffer.data() destinationLength:buffer.size() @@ -15,10 +15,8 @@ } } -void AESCrypto::encrypt( - rust::Slice key, - rust::Slice plaintext, - rust::Slice sealedData) { +template +void AESCrypto::encrypt(T key, T plaintext, T sealedData) { NSData *keyBuffer = [NSData dataWithBytesNoCopy:key.data() length:key.size() freeWhenDone:NO]; @@ -37,10 +35,8 @@ } } -void AESCrypto::decrypt( - rust::Slice key, - rust::Slice sealedData, - rust::Slice plaintext) { +template +void AESCrypto::decrypt(T key, T sealedData, T plaintext) { NSData *keyBuffer = [NSData dataWithBytesNoCopy:key.data() length:key.size() freeWhenDone:NO]; @@ -59,4 +55,7 @@ } } +template class AESCrypto>; +template class AESCrypto &>; + } // namespace comm diff --git a/native/native_rust_library/RustAESCrypto.cpp b/native/native_rust_library/RustAESCrypto.cpp --- a/native/native_rust_library/RustAESCrypto.cpp +++ b/native/native_rust_library/RustAESCrypto.cpp @@ -4,21 +4,21 @@ namespace comm { void aesGenerateKey(rust::Slice buffer) { - AESCrypto::generateKey(buffer); + AESCrypto>::generateKey(buffer); } void aesEncrypt( rust::Slice key, rust::Slice plaintext, rust::Slice sealedData) { -
AESCrypto::encrypt(key, plaintext, sealedData); + AESCrypto>::encrypt(key, plaintext, sealedData); } void aesDecrypt( rust::Slice key, rust::Slice sealedData, rust::Slice plaintext) { - AESCrypto::decrypt(key, sealedData, plaintext); + AESCrypto>::decrypt(key, sealedData, plaintext); } } // namespace comm