diff --git a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp index 59d8e197f..208f34cfe 100644 --- a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp +++ b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp @@ -1,463 +1,465 @@ #include "CryptoModule.h" #include "Logger.h" #include "PlatformSpecificTools.h" #include "olm/account.hh" #include "olm/session.hh" #include #include #include #include namespace comm { namespace crypto { CryptoModule::CryptoModule(std::string id) : id{id} { this->createAccount(); } CryptoModule::CryptoModule( std::string id, std::string secretKey, Persist persist) : id{id} { if (persist.isEmpty()) { this->createAccount(); } else { this->restoreFromB64(secretKey, persist); } } OlmAccount *CryptoModule::getOlmAccount() { return reinterpret_cast(this->accountBuffer.data()); } void CryptoModule::createAccount() { this->accountBuffer.resize(::olm_account_size()); ::olm_account(this->accountBuffer.data()); size_t randomSize = ::olm_create_account_random_length(this->getOlmAccount()); OlmBuffer randomBuffer; PlatformSpecificTools::generateSecureRandomBytes(randomBuffer, randomSize); if (-1 == ::olm_create_account( this->getOlmAccount(), randomBuffer.data(), randomSize)) { throw std::runtime_error{ "error createAccount => " + std::string{::olm_account_last_error(this->getOlmAccount())}}; }; } void CryptoModule::exposePublicIdentityKeys() { size_t identityKeysSize = ::olm_account_identity_keys_length(this->getOlmAccount()); if (this->keys.identityKeys.size() == identityKeysSize) { return; } this->keys.identityKeys.resize( ::olm_account_identity_keys_length(this->getOlmAccount())); if (-1 == ::olm_account_identity_keys( this->getOlmAccount(), this->keys.identityKeys.data(), this->keys.identityKeys.size())) { throw std::runtime_error{ "error generateIdentityKeys => " + std::string{::olm_account_last_error(this->getOlmAccount())}}; } } void CryptoModule::generateOneTimeKeys(size_t oneTimeKeysAmount) { size_t numRandomBytesRequired = ::olm_account_generate_one_time_keys_random_length( this->getOlmAccount(), oneTimeKeysAmount); OlmBuffer random; PlatformSpecificTools::generateSecureRandomBytes( random, numRandomBytesRequired); if (-1 == ::olm_account_generate_one_time_keys( this->getOlmAccount(), oneTimeKeysAmount, random.data(), random.size())) { throw std::runtime_error{ "error generateOneTimeKeys => " + std::string{::olm_account_last_error(this->getOlmAccount())}}; } } // returns number of published keys size_t CryptoModule::publishOneTimeKeys() { this->keys.oneTimeKeys.resize( ::olm_account_one_time_keys_length(this->getOlmAccount())); if (-1 == ::olm_account_one_time_keys( this->getOlmAccount(), this->keys.oneTimeKeys.data(), this->keys.oneTimeKeys.size())) { throw std::runtime_error{ "error publishOneTimeKeys => " + std::string{::olm_account_last_error(this->getOlmAccount())}}; } return ::olm_account_mark_keys_as_published(this->getOlmAccount()); } bool CryptoModule::prekeyExistsAndOlderThan(uint64_t threshold) { // Our fork of Olm only remembers two prekeys at a time. // If the new one hasn't been published, then the old one is still active. // In that scenario, we need to avoid rotating the prekey because it will // result in the old active prekey being discarded. if (this->getUnpublishedPrekey().has_value()) { return false; } uint64_t currentTime = std::time(nullptr); uint64_t lastPrekeyPublishTime = ::olm_account_get_last_prekey_publish_time(this->getOlmAccount()); return currentTime - lastPrekeyPublishTime >= threshold; } Keys CryptoModule::keysFromStrings( const std::string &identityKeys, const std::string &oneTimeKeys) { return { OlmBuffer(identityKeys.begin(), identityKeys.end()), OlmBuffer(oneTimeKeys.begin(), oneTimeKeys.end())}; } std::string CryptoModule::getIdentityKeys() { this->exposePublicIdentityKeys(); return std::string{ this->keys.identityKeys.begin(), this->keys.identityKeys.end()}; } std::string CryptoModule::getOneTimeKeysForPublishing(size_t oneTimeKeysAmount) { OlmBuffer unpublishedOneTimeKeys; unpublishedOneTimeKeys.resize( ::olm_account_one_time_keys_length(this->getOlmAccount())); if (-1 == ::olm_account_one_time_keys( this->getOlmAccount(), unpublishedOneTimeKeys.data(), unpublishedOneTimeKeys.size())) { throw std::runtime_error{ "error getOneTimeKeysForPublishing => " + std::string{::olm_account_last_error(this->getOlmAccount())}}; } std::string unpublishedKeysString = std::string{unpublishedOneTimeKeys.begin(), unpublishedOneTimeKeys.end()}; folly::dynamic parsedUnpublishedKeys = folly::parseJson(unpublishedKeysString); size_t numUnpublishedKeys = parsedUnpublishedKeys["curve25519"].size(); if (numUnpublishedKeys < oneTimeKeysAmount) { this->generateOneTimeKeys(oneTimeKeysAmount - numUnpublishedKeys); } this->publishOneTimeKeys(); return std::string{ this->keys.oneTimeKeys.begin(), this->keys.oneTimeKeys.end()}; } std::uint8_t CryptoModule::getNumPrekeys() { return reinterpret_cast(this->getOlmAccount())->num_prekeys; } std::string CryptoModule::getPrekey() { OlmBuffer prekey; prekey.resize(::olm_account_prekey_length(this->getOlmAccount())); if (-1 == ::olm_account_prekey( this->getOlmAccount(), prekey.data(), prekey.size())) { throw std::runtime_error{ "error getPrekey => " + std::string{::olm_account_last_error(this->getOlmAccount())}}; } return std::string{std::string{prekey.begin(), prekey.end()}}; } std::string CryptoModule::getPrekeySignature() { size_t signatureSize = ::olm_account_signature_length(this->getOlmAccount()); OlmBuffer signatureBuffer; signatureBuffer.resize(signatureSize); if (-1 == ::olm_account_prekey_signature( this->getOlmAccount(), signatureBuffer.data())) { throw std::runtime_error{ "error getPrekeySignature => " + std::string{::olm_account_last_error(this->getOlmAccount())}}; } return std::string{signatureBuffer.begin(), signatureBuffer.end()}; } std::optional CryptoModule::getUnpublishedPrekey() { OlmBuffer prekey; prekey.resize(::olm_account_prekey_length(this->getOlmAccount())); std::size_t retval = ::olm_account_unpublished_prekey( this->getOlmAccount(), prekey.data(), prekey.size()); if (0 == retval) { return std::nullopt; } else if (-1 == retval) { throw std::runtime_error{ "error getUnpublishedPrekey => " + std::string{::olm_account_last_error(this->getOlmAccount())}}; } return std::string{prekey.begin(), prekey.end()}; } std::string CryptoModule::generateAndGetPrekey() { size_t prekeySize = ::olm_account_generate_prekey_random_length(this->getOlmAccount()); OlmBuffer random; PlatformSpecificTools::generateSecureRandomBytes(random, prekeySize); if (-1 == ::olm_account_generate_prekey( this->getOlmAccount(), random.data(), random.size())) { throw std::runtime_error{ "error generateAndGetPrekey => " + std::string{::olm_account_last_error(this->getOlmAccount())}}; } OlmBuffer prekey; prekey.resize(::olm_account_prekey_length(this->getOlmAccount())); if (-1 == ::olm_account_prekey( this->getOlmAccount(), prekey.data(), prekey.size())) { throw std::runtime_error{ "error generateAndGetPrekey => " + std::string{::olm_account_last_error(this->getOlmAccount())}}; } return std::string{prekey.begin(), prekey.end()}; } void CryptoModule::markPrekeyAsPublished() { ::olm_account_mark_prekey_as_published(this->getOlmAccount()); } void CryptoModule::forgetOldPrekey() { ::olm_account_forget_old_prekey(this->getOlmAccount()); } void CryptoModule::initializeInboundForReceivingSession( const std::string &targetDeviceId, const OlmBuffer &encryptedMessage, const OlmBuffer &idKeys, const bool overwrite) { if (this->hasSessionFor(targetDeviceId)) { if (overwrite) { this->sessions.erase(this->sessions.find(targetDeviceId)); } else { throw std::runtime_error{ "error initializeInboundForReceivingSession => session already " "initialized"}; } } std::unique_ptr newSession = Session::createSessionAsResponder( this->getOlmAccount(), this->keys.identityKeys.data(), encryptedMessage, idKeys); this->sessions.insert(make_pair(targetDeviceId, std::move(newSession))); } void CryptoModule::initializeOutboundForSendingSession( const std::string &targetDeviceId, const OlmBuffer &idKeys, const OlmBuffer &preKeys, const OlmBuffer &preKeySignature, const OlmBuffer &oneTimeKey) { if (this->hasSessionFor(targetDeviceId)) { Logger::log( "olm session overwritten for the device with id: " + targetDeviceId); this->sessions.erase(this->sessions.find(targetDeviceId)); } std::unique_ptr newSession = Session::createSessionAsInitializer( this->getOlmAccount(), this->keys.identityKeys.data(), idKeys, preKeys, preKeySignature, oneTimeKey); this->sessions.insert(make_pair(targetDeviceId, std::move(newSession))); } bool CryptoModule::hasSessionFor(const std::string &targetDeviceId) { return (this->sessions.find(targetDeviceId) != this->sessions.end()); } std::shared_ptr CryptoModule::getSessionByDeviceId(const std::string &deviceId) { return this->sessions.at(deviceId); } void CryptoModule::removeSessionByDeviceId(const std::string &deviceId) { this->sessions.erase(deviceId); } Persist CryptoModule::storeAsB64(const std::string &secretKey) { Persist persist; size_t accountPickleLength = ::olm_pickle_account_length(this->getOlmAccount()); OlmBuffer accountPickleBuffer(accountPickleLength); if (accountPickleLength != ::olm_pickle_account( this->getOlmAccount(), secretKey.data(), secretKey.size(), accountPickleBuffer.data(), accountPickleLength)) { throw std::runtime_error{ "error storeAsB64 => " + std::string{::olm_account_last_error(this->getOlmAccount())}}; } persist.account = accountPickleBuffer; std::unordered_map>::iterator it; for (it = this->sessions.begin(); it != this->sessions.end(); ++it) { OlmBuffer buffer = it->second->storeAsB64(secretKey); - persist.sessions.insert(make_pair(it->first, buffer)); + SessionPersist sessionPersist{buffer, it->second->getVersion()}; + persist.sessions.insert(make_pair(it->first, sessionPersist)); } return persist; } void CryptoModule::restoreFromB64( const std::string &secretKey, Persist persist) { this->accountBuffer.resize(::olm_account_size()); ::olm_account(this->accountBuffer.data()); if (-1 == ::olm_unpickle_account( this->getOlmAccount(), secretKey.data(), secretKey.size(), persist.account.data(), persist.account.size())) { throw std::runtime_error{ "error restoreFromB64 => " + std::string{::olm_account_last_error(this->getOlmAccount())}}; } - std::unordered_map::iterator it; + std::unordered_map::iterator it; for (it = persist.sessions.begin(); it != persist.sessions.end(); ++it) { std::unique_ptr session = - session->restoreFromB64(secretKey, it->second); + session->restoreFromB64(secretKey, it->second.buffer); + session->setVersion(it->second.version); this->sessions.insert(make_pair(it->first, move(session))); } } EncryptedData CryptoModule::encrypt( const std::string &targetDeviceId, const std::string &content) { if (!this->hasSessionFor(targetDeviceId)) { throw std::runtime_error{"error encrypt => uninitialized session"}; } OlmSession *session = this->sessions.at(targetDeviceId)->getOlmSession(); OlmBuffer encryptedMessage( ::olm_encrypt_message_length(session, content.size())); OlmBuffer messageRandom; PlatformSpecificTools::generateSecureRandomBytes( messageRandom, ::olm_encrypt_random_length(session)); size_t messageType = ::olm_encrypt_message_type(session); if (-1 == ::olm_encrypt( session, (uint8_t *)content.data(), content.size(), messageRandom.data(), messageRandom.size(), encryptedMessage.data(), encryptedMessage.size())) { throw std::runtime_error{ "error encrypt => " + std::string{::olm_session_last_error(session)}}; } return {encryptedMessage, messageType}; } std::string CryptoModule::decrypt( const std::string &targetDeviceId, EncryptedData &encryptedData) { if (!this->hasSessionFor(targetDeviceId)) { throw std::runtime_error{"error decrypt => uninitialized session"}; } return this->sessions.at(targetDeviceId)->decrypt(encryptedData); } std::string CryptoModule::signMessage(const std::string &message) { OlmBuffer signature; signature.resize(::olm_account_signature_length(this->getOlmAccount())); size_t signatureLength = ::olm_account_sign( this->getOlmAccount(), (uint8_t *)message.data(), message.length(), signature.data(), signature.size()); if (signatureLength == -1) { throw std::runtime_error{ "olm error: " + std::string{::olm_account_last_error(this->getOlmAccount())}}; } return std::string{(char *)signature.data(), signatureLength}; } void CryptoModule::verifySignature( const std::string &publicKey, const std::string &message, const std::string &signature) { OlmBuffer utilityBuffer; utilityBuffer.resize(::olm_utility_size()); OlmUtility *olmUtility = ::olm_utility(utilityBuffer.data()); ssize_t verificationResult = ::olm_ed25519_verify( olmUtility, (uint8_t *)publicKey.data(), publicKey.length(), (uint8_t *)message.data(), message.length(), (uint8_t *)signature.data(), signature.length()); if (verificationResult == -1) { throw std::runtime_error{ "olm error: " + std::string{::olm_utility_last_error(olmUtility)}}; } } std::optional CryptoModule::validatePrekey() { static const uint64_t maxPrekeyPublishTime = 10 * 60; static const uint64_t maxOldPrekeyAge = 2 * 60; std::optional maybeNewPrekey; bool shouldRotatePrekey = this->prekeyExistsAndOlderThan(maxPrekeyPublishTime); if (shouldRotatePrekey) { maybeNewPrekey = this->generateAndGetPrekey(); } bool shouldForgetPrekey = this->prekeyExistsAndOlderThan(maxOldPrekeyAge); if (shouldForgetPrekey) { this->forgetOldPrekey(); } return maybeNewPrekey; } } // namespace crypto } // namespace comm diff --git a/native/cpp/CommonCpp/CryptoTools/Persist.h b/native/cpp/CommonCpp/CryptoTools/Persist.h index db038bd70..70d87e91c 100644 --- a/native/cpp/CommonCpp/CryptoTools/Persist.h +++ b/native/cpp/CommonCpp/CryptoTools/Persist.h @@ -1,21 +1,26 @@ #pragma once #include #include #include "Tools.h" namespace comm { namespace crypto { +struct SessionPersist { + OlmBuffer buffer; + int version; +}; + struct Persist { OlmBuffer account; - std::unordered_map sessions; + std::unordered_map sessions; bool isEmpty() const { return (this->account.size() == 0); } }; } // namespace crypto } // namespace comm diff --git a/native/cpp/CommonCpp/CryptoTools/Session.cpp b/native/cpp/CommonCpp/CryptoTools/Session.cpp index add51af74..0f0b3d94e 100644 --- a/native/cpp/CommonCpp/CryptoTools/Session.cpp +++ b/native/cpp/CommonCpp/CryptoTools/Session.cpp @@ -1,161 +1,169 @@ #include "Session.h" #include "PlatformSpecificTools.h" #include namespace comm { namespace crypto { OlmSession *Session::getOlmSession() { return reinterpret_cast(this->olmSessionBuffer.data()); } std::unique_ptr Session::createSessionAsInitializer( OlmAccount *account, std::uint8_t *ownerIdentityKeys, const OlmBuffer &idKeys, const OlmBuffer &preKeys, const OlmBuffer &preKeySignature, const OlmBuffer &oneTimeKey) { std::unique_ptr session(new Session()); session->olmSessionBuffer.resize(::olm_session_size()); ::olm_session(session->olmSessionBuffer.data()); OlmBuffer randomBuffer; PlatformSpecificTools::generateSecureRandomBytes( randomBuffer, ::olm_create_outbound_session_random_length(session->getOlmSession())); if (-1 == ::olm_create_outbound_session( session->getOlmSession(), account, idKeys.data() + ID_KEYS_PREFIX_OFFSET, KEYSIZE, idKeys.data() + SIGNING_KEYS_PREFIX_OFFSET, KEYSIZE, preKeys.data(), KEYSIZE, preKeySignature.data(), SIGNATURESIZE, oneTimeKey.data(), KEYSIZE, randomBuffer.data(), randomBuffer.size())) { throw std::runtime_error( "error createOutbound => " + std::string{::olm_session_last_error(session->getOlmSession())}); } return session; } std::unique_ptr Session::createSessionAsResponder( OlmAccount *account, std::uint8_t *ownerIdentityKeys, const OlmBuffer &encryptedMessage, const OlmBuffer &idKeys) { std::unique_ptr session(new Session()); OlmBuffer tmpEncryptedMessage(encryptedMessage); session->olmSessionBuffer.resize(::olm_session_size()); ::olm_session(session->olmSessionBuffer.data()); if (-1 == ::olm_create_inbound_session( session->getOlmSession(), account, tmpEncryptedMessage.data(), encryptedMessage.size())) { throw std::runtime_error( "error createInbound => " + std::string{::olm_session_last_error(session->getOlmSession())}); } if (-1 == ::olm_remove_one_time_keys(account, session->getOlmSession())) { throw std::runtime_error( "error createInbound (remove oneTimeKey) => " + std::string{::olm_session_last_error(session->getOlmSession())}); } return session; } OlmBuffer Session::storeAsB64(const std::string &secretKey) { size_t pickleLength = ::olm_pickle_session_length(this->getOlmSession()); OlmBuffer pickle(pickleLength); size_t res = ::olm_pickle_session( this->getOlmSession(), secretKey.data(), secretKey.size(), pickle.data(), pickleLength); if (pickleLength != res) { throw std::runtime_error("error pickleSession => ::olm_pickle_session"); } return pickle; } std::unique_ptr Session::restoreFromB64(const std::string &secretKey, OlmBuffer &b64) { std::unique_ptr session(new Session()); session->olmSessionBuffer.resize(::olm_session_size()); ::olm_session(session->olmSessionBuffer.data()); if (-1 == ::olm_unpickle_session( session->getOlmSession(), secretKey.data(), secretKey.size(), b64.data(), b64.size())) { throw std::runtime_error("error pickleSession => ::olm_unpickle_session"); } return session; } std::string Session::decrypt(EncryptedData &encryptedData) { OlmSession *session = this->getOlmSession(); OlmBuffer utilityBuffer(::olm_utility_size()); OlmUtility *olmUtility = ::olm_utility(utilityBuffer.data()); OlmBuffer messageHashBuffer(::olm_sha256_length(olmUtility)); ::olm_sha256( olmUtility, encryptedData.message.data(), encryptedData.message.size(), messageHashBuffer.data(), messageHashBuffer.size()); OlmBuffer tmpEncryptedMessage(encryptedData.message); size_t maxSize = ::olm_decrypt_max_plaintext_length( session, encryptedData.messageType, tmpEncryptedMessage.data(), tmpEncryptedMessage.size()); if (maxSize == -1) { throw std::runtime_error{ "error decrypt_max_plaintext_length => " + std::string{::olm_session_last_error(session)} + ". Hash: " + std::string{messageHashBuffer.begin(), messageHashBuffer.end()}}; } OlmBuffer decryptedMessage(maxSize); size_t decryptedSize = ::olm_decrypt( session, encryptedData.messageType, encryptedData.message.data(), encryptedData.message.size(), decryptedMessage.data(), decryptedMessage.size()); if (decryptedSize == -1) { throw std::runtime_error{ "error decrypt => " + std::string{::olm_session_last_error(session)} + ". Hash: " + std::string{messageHashBuffer.begin(), messageHashBuffer.end()}}; } return std::string{(char *)decryptedMessage.data(), decryptedSize}; } +int Session::getVersion() { + return this->version; +} + +void Session::setVersion(int newVersion) { + this->version = newVersion; +} + } // namespace crypto } // namespace comm diff --git a/native/cpp/CommonCpp/CryptoTools/Session.h b/native/cpp/CommonCpp/CryptoTools/Session.h index 1c1f2ec22..2129e3f44 100644 --- a/native/cpp/CommonCpp/CryptoTools/Session.h +++ b/native/cpp/CommonCpp/CryptoTools/Session.h @@ -1,37 +1,40 @@ #pragma once #include #include #include "Tools.h" #include "olm/olm.h" namespace comm { namespace crypto { class Session { OlmBuffer olmSessionBuffer; + int version; public: static std::unique_ptr createSessionAsInitializer( OlmAccount *account, std::uint8_t *ownerIdentityKeys, const OlmBuffer &idKeys, const OlmBuffer &preKeys, const OlmBuffer &preKeySignature, const OlmBuffer &oneTimeKey); static std::unique_ptr createSessionAsResponder( OlmAccount *account, std::uint8_t *ownerIdentityKeys, const OlmBuffer &encryptedMessage, const OlmBuffer &idKeys); OlmBuffer storeAsB64(const std::string &secretKey); static std::unique_ptr restoreFromB64(const std::string &secretKey, OlmBuffer &b64); OlmSession *getOlmSession(); std::string decrypt(EncryptedData &encryptedData); + int getVersion(); + void setVersion(int newVersion); }; } // namespace crypto } // namespace comm diff --git a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp index 7bf2df621..76ce39398 100644 --- a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp +++ b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp @@ -1,2425 +1,2428 @@ #include "SQLiteQueryExecutor.h" #include "Logger.h" #include "entities/CommunityInfo.h" #include "entities/EntityQueryHelpers.h" #include "entities/IntegrityThreadHash.h" #include "entities/KeyserverInfo.h" #include "entities/Metadata.h" #include "entities/SyncedMetadataEntry.h" #include "entities/UserInfo.h" #include #include #include #ifndef EMSCRIPTEN #include "../CryptoTools/CryptoModule.h" #include "CommSecureStore.h" #include "PlatformSpecificTools.h" #include "StaffUtils.h" #endif const int CONTENT_ACCOUNT_ID = 1; const int NOTIFS_ACCOUNT_ID = 2; namespace comm { std::string SQLiteQueryExecutor::sqliteFilePath; std::string SQLiteQueryExecutor::encryptionKey; std::once_flag SQLiteQueryExecutor::initialized; int SQLiteQueryExecutor::sqlcipherEncryptionKeySize = 64; // Should match constant defined in `native_rust_library/src/constants.rs` std::string SQLiteQueryExecutor::secureStoreEncryptionKeyID = "comm.encryptionKey"; int SQLiteQueryExecutor::backupLogsEncryptionKeySize = 32; std::string SQLiteQueryExecutor::secureStoreBackupLogsEncryptionKeyID = "comm.backupLogsEncryptionKey"; std::string SQLiteQueryExecutor::backupLogsEncryptionKey; #ifndef EMSCRIPTEN NativeSQLiteConnectionManager SQLiteQueryExecutor::connectionManager; std::unordered_set SQLiteQueryExecutor::backedUpTablesBlocklist = { "olm_persist_account", "olm_persist_sessions", "metadata", "messages_to_device", "integrity_store", "persist_storage", "keyservers", }; #else SQLiteConnectionManager SQLiteQueryExecutor::connectionManager; #endif bool create_table(sqlite3 *db, std::string query, std::string tableName) { char *error; sqlite3_exec(db, query.c_str(), nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error creating '" << tableName << "' table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool create_drafts_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS drafts (threadID TEXT UNIQUE PRIMARY KEY, " "text TEXT);"; return create_table(db, query, "drafts"); } bool rename_threadID_to_key(sqlite3 *db) { sqlite3_stmt *key_column_stmt; sqlite3_prepare_v2( db, "SELECT name AS col_name FROM pragma_table_xinfo ('drafts') WHERE " "col_name='key';", -1, &key_column_stmt, nullptr); sqlite3_step(key_column_stmt); auto num_bytes = sqlite3_column_bytes(key_column_stmt, 0); sqlite3_finalize(key_column_stmt); if (num_bytes) { return true; } char *error; sqlite3_exec( db, "ALTER TABLE drafts RENAME COLUMN `threadID` TO `key`;", nullptr, nullptr, &error); if (error) { std::ostringstream stringStream; stringStream << "Error occurred renaming threadID column in drafts table " << "to key: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } return true; } bool create_persist_account_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS olm_persist_account(" "id INTEGER UNIQUE PRIMARY KEY NOT NULL, " "account_data TEXT NOT NULL);"; return create_table(db, query, "olm_persist_account"); } bool create_persist_sessions_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS olm_persist_sessions(" "target_user_id TEXT UNIQUE PRIMARY KEY NOT NULL, " "session_data TEXT NOT NULL);"; return create_table(db, query, "olm_persist_sessions"); } bool drop_messages_table(sqlite3 *db) { char *error; sqlite3_exec(db, "DROP TABLE IF EXISTS messages;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error dropping 'messages' table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool recreate_messages_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS messages ( " "id TEXT UNIQUE PRIMARY KEY NOT NULL, " "local_id TEXT, " "thread TEXT NOT NULL, " "user TEXT NOT NULL, " "type INTEGER NOT NULL, " "future_type INTEGER, " "content TEXT, " "time INTEGER NOT NULL);"; return create_table(db, query, "messages"); } bool create_messages_idx_thread_time(sqlite3 *db) { char *error; sqlite3_exec( db, "CREATE INDEX IF NOT EXISTS messages_idx_thread_time " "ON messages (thread, time);", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error creating (thread, time) index on messages table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool create_media_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS media ( " "id TEXT UNIQUE PRIMARY KEY NOT NULL, " "container TEXT NOT NULL, " "thread TEXT NOT NULL, " "uri TEXT NOT NULL, " "type TEXT NOT NULL, " "extras TEXT NOT NULL);"; return create_table(db, query, "media"); } bool create_media_idx_container(sqlite3 *db) { char *error; sqlite3_exec( db, "CREATE INDEX IF NOT EXISTS media_idx_container " "ON media (container);", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error creating (container) index on media table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool create_threads_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS threads ( " "id TEXT UNIQUE PRIMARY KEY NOT NULL, " "type INTEGER NOT NULL, " "name TEXT, " "description TEXT, " "color TEXT NOT NULL, " "creation_time BIGINT NOT NULL, " "parent_thread_id TEXT, " "containing_thread_id TEXT, " "community TEXT, " "members TEXT NOT NULL, " "roles TEXT NOT NULL, " "current_user TEXT NOT NULL, " "source_message_id TEXT, " "replies_count INTEGER NOT NULL);"; return create_table(db, query, "threads"); } bool update_threadID_for_pending_threads_in_drafts(sqlite3 *db) { char *error; sqlite3_exec( db, "UPDATE drafts SET key = " "REPLACE(REPLACE(REPLACE(REPLACE(key, 'type4/', '')," "'type5/', ''),'type6/', ''),'type7/', '')" "WHERE key LIKE 'pending/%'", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error update pending threadIDs on drafts table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool enable_write_ahead_logging_mode(sqlite3 *db) { char *error; sqlite3_exec(db, "PRAGMA journal_mode=wal;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error enabling write-ahead logging mode: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool create_metadata_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS metadata ( " "name TEXT UNIQUE PRIMARY KEY NOT NULL, " "data TEXT);"; return create_table(db, query, "metadata"); } bool add_not_null_constraint_to_drafts(sqlite3 *db) { char *error; sqlite3_exec( db, "CREATE TABLE IF NOT EXISTS temporary_drafts (" "key TEXT UNIQUE PRIMARY KEY NOT NULL, " "text TEXT NOT NULL);" "INSERT INTO temporary_drafts SELECT * FROM drafts " "WHERE key IS NOT NULL AND text IS NOT NULL;" "DROP TABLE drafts;" "ALTER TABLE temporary_drafts RENAME TO drafts;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error adding NOT NULL constraint to drafts table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool add_not_null_constraint_to_metadata(sqlite3 *db) { char *error; sqlite3_exec( db, "CREATE TABLE IF NOT EXISTS temporary_metadata (" "name TEXT UNIQUE PRIMARY KEY NOT NULL, " "data TEXT NOT NULL);" "INSERT INTO temporary_metadata SELECT * FROM metadata " "WHERE data IS NOT NULL;" "DROP TABLE metadata;" "ALTER TABLE temporary_metadata RENAME TO metadata;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error adding NOT NULL constraint to metadata table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool add_avatar_column_to_threads_table(sqlite3 *db) { char *error; sqlite3_exec( db, "ALTER TABLE threads ADD COLUMN avatar TEXT;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error adding avatar column to threads table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool add_pinned_count_column_to_threads(sqlite3 *db) { sqlite3_stmt *pinned_column_stmt; sqlite3_prepare_v2( db, "SELECT name AS col_name FROM pragma_table_xinfo ('threads') WHERE " "col_name='pinned_count';", -1, &pinned_column_stmt, nullptr); sqlite3_step(pinned_column_stmt); auto num_bytes = sqlite3_column_bytes(pinned_column_stmt, 0); sqlite3_finalize(pinned_column_stmt); if (num_bytes) { return true; } char *error; sqlite3_exec( db, "ALTER TABLE threads ADD COLUMN pinned_count INTEGER NOT NULL DEFAULT 0;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error adding pinned_count column to threads table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool create_message_store_threads_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS message_store_threads (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " start_reached INTEGER NOT NULL," " last_navigated_to BIGINT NOT NULL," " last_pruned BIGINT NOT NULL" ");"; return create_table(db, query, "message_store_threads"); } bool create_reports_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS reports (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " report TEXT NOT NULL" ");"; return create_table(db, query, "reports"); } bool create_persist_storage_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS persist_storage (" " key TEXT UNIQUE PRIMARY KEY NOT NULL," " item TEXT NOT NULL" ");"; return create_table(db, query, "persist_storage"); } bool recreate_message_store_threads_table(sqlite3 *db) { char *errMsg = 0; // 1. Create table without `last_navigated_to` or `last_pruned`. std::string create_new_table_query = "CREATE TABLE IF NOT EXISTS temp_message_store_threads (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " start_reached INTEGER NOT NULL" ");"; if (sqlite3_exec(db, create_new_table_query.c_str(), NULL, NULL, &errMsg) != SQLITE_OK) { Logger::log( "Error creating temp_message_store_threads: " + std::string{errMsg}); sqlite3_free(errMsg); return false; } // 2. Dump data from existing `message_store_threads` table into temp table. std::string copy_data_query = "INSERT INTO temp_message_store_threads (id, start_reached)" "SELECT id, start_reached FROM message_store_threads;"; if (sqlite3_exec(db, copy_data_query.c_str(), NULL, NULL, &errMsg) != SQLITE_OK) { Logger::log( "Error dumping data from existing message_store_threads to " "temp_message_store_threads: " + std::string{errMsg}); sqlite3_free(errMsg); return false; } // 3. Drop the existing `message_store_threads` table. std::string drop_old_table_query = "DROP TABLE message_store_threads;"; if (sqlite3_exec(db, drop_old_table_query.c_str(), NULL, NULL, &errMsg) != SQLITE_OK) { Logger::log( "Error dropping message_store_threads table: " + std::string{errMsg}); sqlite3_free(errMsg); return false; } // 4. Rename the temp table back to `message_store_threads`. std::string rename_table_query = "ALTER TABLE temp_message_store_threads RENAME TO message_store_threads;"; if (sqlite3_exec(db, rename_table_query.c_str(), NULL, NULL, &errMsg) != SQLITE_OK) { Logger::log( "Error renaming temp_message_store_threads to message_store_threads: " + std::string{errMsg}); sqlite3_free(errMsg); return false; } return true; } bool create_users_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS users (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " user_info TEXT NOT NULL" ");"; return create_table(db, query, "users"); } bool create_keyservers_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS keyservers (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " keyserver_info TEXT NOT NULL" ");"; return create_table(db, query, "keyservers"); } bool enable_rollback_journal_mode(sqlite3 *db) { char *error; sqlite3_exec(db, "PRAGMA journal_mode=DELETE;", nullptr, nullptr, &error); if (!error) { return true; } std::stringstream error_message; error_message << "Error disabling write-ahead logging mode: " << error; Logger::log(error_message.str()); sqlite3_free(error); return false; } bool create_communities_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS communities (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " community_info TEXT NOT NULL" ");"; return create_table(db, query, "communities"); } bool create_messages_to_device_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS messages_to_device (" " message_id TEXT NOT NULL," " device_id TEXT NOT NULL," " user_id TEXT NOT NULL," " timestamp BIGINT NOT NULL," " plaintext TEXT NOT NULL," " ciphertext TEXT NOT NULL," " PRIMARY KEY (message_id, device_id)" ");" "CREATE INDEX IF NOT EXISTS messages_to_device_idx_id_timestamp" " ON messages_to_device (device_id, timestamp);"; return create_table(db, query, "messages_to_device"); } bool create_integrity_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS integrity_store (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " thread_hash TEXT NOT NULL" ");"; return create_table(db, query, "integrity_store"); } bool create_synced_metadata_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS synced_metadata (" " name TEXT UNIQUE PRIMARY KEY NOT NULL," " data TEXT NOT NULL" ");"; return create_table(db, query, "synced_metadata"); } bool create_keyservers_synced(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS keyservers_synced (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " keyserver_info TEXT NOT NULL" ");"; bool success = create_table(db, query, "keyservers_synced"); if (!success) { return false; } std::string copyData = "INSERT INTO keyservers_synced (id, keyserver_info)" "SELECT id, keyserver_info " "FROM keyservers;"; char *error; sqlite3_exec(db, copyData.c_str(), nullptr, nullptr, &error); if (error) { return false; } return true; } bool create_aux_user_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS aux_users (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " aux_user_info TEXT NOT NULL" ");"; return create_table(db, query, "aux_users"); } bool add_version_column_to_olm_persist_sessions_table(sqlite3 *db) { char *error; sqlite3_exec( db, "ALTER TABLE olm_persist_sessions " " RENAME COLUMN `target_user_id` TO `target_device_id`; " "ALTER TABLE olm_persist_sessions " " ADD COLUMN version INTEGER NOT NULL DEFAULT 1;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error updating olm_persist_sessions table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool create_schema(sqlite3 *db) { char *error; sqlite3_exec( db, "CREATE TABLE IF NOT EXISTS drafts (" " key TEXT UNIQUE PRIMARY KEY NOT NULL," " text TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS messages (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " local_id TEXT," " thread TEXT NOT NULL," " user TEXT NOT NULL," " type INTEGER NOT NULL," " future_type INTEGER," " content TEXT," " time INTEGER NOT NULL" ");" "CREATE TABLE IF NOT EXISTS olm_persist_account (" " id INTEGER UNIQUE PRIMARY KEY NOT NULL," " account_data TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS olm_persist_sessions (" " target_device_id TEXT UNIQUE PRIMARY KEY NOT NULL," " session_data TEXT NOT NULL," " version INTEGER NOT NULL DEFAULT 1" ");" "CREATE TABLE IF NOT EXISTS media (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " container TEXT NOT NULL," " thread TEXT NOT NULL," " uri TEXT NOT NULL," " type TEXT NOT NULL," " extras TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS threads (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " type INTEGER NOT NULL," " name TEXT," " description TEXT," " color TEXT NOT NULL," " creation_time BIGINT NOT NULL," " parent_thread_id TEXT," " containing_thread_id TEXT," " community TEXT," " members TEXT NOT NULL," " roles TEXT NOT NULL," " current_user TEXT NOT NULL," " source_message_id TEXT," " replies_count INTEGER NOT NULL," " avatar TEXT," " pinned_count INTEGER NOT NULL DEFAULT 0" ");" "CREATE TABLE IF NOT EXISTS metadata (" " name TEXT UNIQUE PRIMARY KEY NOT NULL," " data TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS message_store_threads (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " start_reached INTEGER NOT NULL" ");" "CREATE TABLE IF NOT EXISTS reports (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " report TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS persist_storage (" " key TEXT UNIQUE PRIMARY KEY NOT NULL," " item TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS users (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " user_info TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS keyservers (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " keyserver_info TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS keyservers_synced (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " keyserver_info TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS communities (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " community_info TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS messages_to_device (" " message_id TEXT NOT NULL," " device_id TEXT NOT NULL," " user_id TEXT NOT NULL," " timestamp BIGINT NOT NULL," " plaintext TEXT NOT NULL," " ciphertext TEXT NOT NULL," " PRIMARY KEY (message_id, device_id)" ");" "CREATE TABLE IF NOT EXISTS integrity_store (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " thread_hash TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS synced_metadata (" " name TEXT UNIQUE PRIMARY KEY NOT NULL," " data TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS aux_users (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " aux_user_info TEXT NOT NULL" ");" "CREATE INDEX IF NOT EXISTS media_idx_container" " ON media (container);" "CREATE INDEX IF NOT EXISTS messages_idx_thread_time" " ON messages (thread, time);" "CREATE INDEX IF NOT EXISTS messages_to_device_idx_id_timestamp" " ON messages_to_device (device_id, timestamp);", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error creating tables: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } void set_encryption_key( sqlite3 *db, const std::string &encryptionKey = SQLiteQueryExecutor::encryptionKey) { std::string set_encryption_key_query = "PRAGMA key = \"x'" + encryptionKey + "'\";"; char *error_set_key; sqlite3_exec( db, set_encryption_key_query.c_str(), nullptr, nullptr, &error_set_key); if (error_set_key) { std::ostringstream error_message; error_message << "Failed to set encryption key: " << error_set_key; throw std::system_error( ECANCELED, std::generic_category(), error_message.str()); } } int get_database_version(sqlite3 *db) { sqlite3_stmt *user_version_stmt; sqlite3_prepare_v2( db, "PRAGMA user_version;", -1, &user_version_stmt, nullptr); sqlite3_step(user_version_stmt); int current_user_version = sqlite3_column_int(user_version_stmt, 0); sqlite3_finalize(user_version_stmt); return current_user_version; } bool set_database_version(sqlite3 *db, int db_version) { std::stringstream update_version; update_version << "PRAGMA user_version=" << db_version << ";"; auto update_version_str = update_version.str(); char *error; sqlite3_exec(db, update_version_str.c_str(), nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream errorStream; errorStream << "Error setting database version to " << db_version << ": " << error; Logger::log(errorStream.str()); sqlite3_free(error); return false; } // We don't want to run `PRAGMA key = ...;` // on main web database. The context is here: // https://linear.app/comm/issue/ENG-6398/issues-with-sqlcipher-on-web void default_on_db_open_callback(sqlite3 *db) { #ifndef EMSCRIPTEN set_encryption_key(db); #endif } // This is a temporary solution. In future we want to keep // a separate table for blob hashes. Tracked on Linear: // https://linear.app/comm/issue/ENG-6261/introduce-blob-hash-table std::string blob_hash_from_blob_service_uri(const std::string &media_uri) { static const std::string blob_service_prefix = "comm-blob-service://"; return media_uri.substr(blob_service_prefix.size()); } bool file_exists(const std::string &file_path) { std::ifstream file(file_path.c_str()); return file.good(); } void attempt_delete_file( const std::string &file_path, const char *error_message) { if (std::remove(file_path.c_str())) { throw std::system_error(errno, std::generic_category(), error_message); } } void attempt_rename_file( const std::string &old_path, const std::string &new_path, const char *error_message) { if (std::rename(old_path.c_str(), new_path.c_str())) { throw std::system_error(errno, std::generic_category(), error_message); } } bool is_database_queryable( sqlite3 *db, bool use_encryption_key, const std::string &path = SQLiteQueryExecutor::sqliteFilePath, const std::string &encryptionKey = SQLiteQueryExecutor::encryptionKey) { char *err_msg; sqlite3_open(path.c_str(), &db); // According to SQLCipher documentation running some SELECT is the only way to // check for key validity if (use_encryption_key) { set_encryption_key(db, encryptionKey); } sqlite3_exec( db, "SELECT COUNT(*) FROM sqlite_master;", nullptr, nullptr, &err_msg); sqlite3_close(db); return !err_msg; } void validate_encryption() { std::string temp_encrypted_db_path = SQLiteQueryExecutor::sqliteFilePath + "_temp_encrypted"; bool temp_encrypted_exists = file_exists(temp_encrypted_db_path); bool default_location_exists = file_exists(SQLiteQueryExecutor::sqliteFilePath); if (temp_encrypted_exists && default_location_exists) { Logger::log( "Previous encryption attempt failed. Repeating encryption process from " "the beginning."); attempt_delete_file( temp_encrypted_db_path, "Failed to delete corrupted encrypted database."); } else if (temp_encrypted_exists && !default_location_exists) { Logger::log( "Moving temporary encrypted database to default location failed in " "previous encryption attempt. Repeating rename step."); attempt_rename_file( temp_encrypted_db_path, SQLiteQueryExecutor::sqliteFilePath, "Failed to move encrypted database to default location."); return; } else if (!default_location_exists) { Logger::log( "Database not present yet. It will be created encrypted under default " "path."); return; } sqlite3 *db; if (is_database_queryable(db, true)) { Logger::log( "Database exists under default path and it is correctly encrypted."); return; } if (!is_database_queryable(db, false)) { Logger::log( "Database exists but it is encrypted with key that was lost. " "Attempting database deletion. New encrypted one will be created."); attempt_delete_file( SQLiteQueryExecutor::sqliteFilePath.c_str(), "Failed to delete database encrypted with lost key."); return; } else { Logger::log( "Database exists but it is not encrypted. Attempting encryption " "process."); } sqlite3_open(SQLiteQueryExecutor::sqliteFilePath.c_str(), &db); std::string createEncryptedCopySQL = "ATTACH DATABASE '" + temp_encrypted_db_path + "' AS encrypted_comm " "KEY \"x'" + SQLiteQueryExecutor::encryptionKey + "'\";" "SELECT sqlcipher_export('encrypted_comm');" "DETACH DATABASE encrypted_comm;"; char *encryption_error; sqlite3_exec( db, createEncryptedCopySQL.c_str(), nullptr, nullptr, &encryption_error); if (encryption_error) { throw std::system_error( ECANCELED, std::generic_category(), "Failed to create encrypted copy of the original database."); } sqlite3_close(db); attempt_delete_file( SQLiteQueryExecutor::sqliteFilePath, "Failed to delete unencrypted database."); attempt_rename_file( temp_encrypted_db_path, SQLiteQueryExecutor::sqliteFilePath, "Failed to move encrypted database to default location."); Logger::log("Encryption completed successfully."); } typedef bool ShouldBeInTransaction; typedef std::function MigrateFunction; typedef std::pair SQLiteMigration; std::vector> migrations{ {{1, {create_drafts_table, true}}, {2, {rename_threadID_to_key, true}}, {4, {create_persist_account_table, true}}, {5, {create_persist_sessions_table, true}}, {15, {create_media_table, true}}, {16, {drop_messages_table, true}}, {17, {recreate_messages_table, true}}, {18, {create_messages_idx_thread_time, true}}, {19, {create_media_idx_container, true}}, {20, {create_threads_table, true}}, {21, {update_threadID_for_pending_threads_in_drafts, true}}, {22, {enable_write_ahead_logging_mode, false}}, {23, {create_metadata_table, true}}, {24, {add_not_null_constraint_to_drafts, true}}, {25, {add_not_null_constraint_to_metadata, true}}, {26, {add_avatar_column_to_threads_table, true}}, {27, {add_pinned_count_column_to_threads, true}}, {28, {create_message_store_threads_table, true}}, {29, {create_reports_table, true}}, {30, {create_persist_storage_table, true}}, {31, {recreate_message_store_threads_table, true}}, {32, {create_users_table, true}}, {33, {create_keyservers_table, true}}, {34, {enable_rollback_journal_mode, false}}, {35, {create_communities_table, true}}, {36, {create_messages_to_device_table, true}}, {37, {create_integrity_table, true}}, {38, {[](sqlite3 *) { return true; }, false}}, {39, {create_synced_metadata_table, true}}, {40, {create_keyservers_synced, true}}, {41, {create_aux_user_table, true}}, {42, {add_version_column_to_olm_persist_sessions_table, true}}}}; enum class MigrationResult { SUCCESS, FAILURE, NOT_APPLIED }; MigrationResult applyMigrationWithTransaction( sqlite3 *db, const MigrateFunction &migrate, int index) { sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); auto db_version = get_database_version(db); if (index <= db_version) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return MigrationResult::NOT_APPLIED; } auto rc = migrate(db); if (!rc) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return MigrationResult::FAILURE; } auto database_version_set = set_database_version(db, index); if (!database_version_set) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return MigrationResult::FAILURE; } sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr); return MigrationResult::SUCCESS; } MigrationResult applyMigrationWithoutTransaction( sqlite3 *db, const MigrateFunction &migrate, int index) { auto db_version = get_database_version(db); if (index <= db_version) { return MigrationResult::NOT_APPLIED; } auto rc = migrate(db); if (!rc) { return MigrationResult::FAILURE; } sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); auto inner_db_version = get_database_version(db); if (index <= inner_db_version) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return MigrationResult::NOT_APPLIED; } auto database_version_set = set_database_version(db, index); if (!database_version_set) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return MigrationResult::FAILURE; } sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr); return MigrationResult::SUCCESS; } bool set_up_database(sqlite3 *db) { sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); auto db_version = get_database_version(db); auto latest_version = migrations.back().first; if (db_version == latest_version) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return true; } if (db_version != 0 || !create_schema(db) || !set_database_version(db, latest_version)) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return false; } sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr); return true; } void SQLiteQueryExecutor::migrate() { // We don't want to run `PRAGMA key = ...;` // on main web database. The context is here: // https://linear.app/comm/issue/ENG-6398/issues-with-sqlcipher-on-web #ifndef EMSCRIPTEN validate_encryption(); #endif sqlite3 *db; sqlite3_open(SQLiteQueryExecutor::sqliteFilePath.c_str(), &db); default_on_db_open_callback(db); std::stringstream db_path; db_path << "db path: " << SQLiteQueryExecutor::sqliteFilePath.c_str() << std::endl; Logger::log(db_path.str()); auto db_version = get_database_version(db); std::stringstream version_msg; version_msg << "db version: " << db_version << std::endl; Logger::log(version_msg.str()); if (db_version == 0) { auto db_created = set_up_database(db); if (!db_created) { sqlite3_close(db); Logger::log("Database structure creation error."); throw std::runtime_error("Database structure creation error"); } Logger::log("Database structure created."); sqlite3_close(db); return; } for (const auto &[idx, migration] : migrations) { const auto &[applyMigration, shouldBeInTransaction] = migration; MigrationResult migrationResult; if (shouldBeInTransaction) { migrationResult = applyMigrationWithTransaction(db, applyMigration, idx); } else { migrationResult = applyMigrationWithoutTransaction(db, applyMigration, idx); } if (migrationResult == MigrationResult::NOT_APPLIED) { continue; } std::stringstream migration_msg; if (migrationResult == MigrationResult::FAILURE) { migration_msg << "migration " << idx << " failed." << std::endl; Logger::log(migration_msg.str()); sqlite3_close(db); throw std::runtime_error(migration_msg.str()); } if (migrationResult == MigrationResult::SUCCESS) { migration_msg << "migration " << idx << " succeeded." << std::endl; Logger::log(migration_msg.str()); } } sqlite3_close(db); } SQLiteQueryExecutor::SQLiteQueryExecutor() { SQLiteQueryExecutor::migrate(); #ifndef EMSCRIPTEN SQLiteQueryExecutor::initializeTablesForLogMonitoring(); std::string currentBackupID = this->getMetadata("backupID"); if (!StaffUtils::isStaffRelease() || !currentBackupID.size()) { return; } SQLiteQueryExecutor::connectionManager.setLogsMonitoring(true); #endif } SQLiteQueryExecutor::SQLiteQueryExecutor(std::string sqliteFilePath) { SQLiteQueryExecutor::sqliteFilePath = sqliteFilePath; SQLiteQueryExecutor::migrate(); } sqlite3 *SQLiteQueryExecutor::getConnection() { if (SQLiteQueryExecutor::connectionManager.getConnection()) { return SQLiteQueryExecutor::connectionManager.getConnection(); } SQLiteQueryExecutor::connectionManager.initializeConnection( SQLiteQueryExecutor::sqliteFilePath, default_on_db_open_callback); return SQLiteQueryExecutor::connectionManager.getConnection(); } void SQLiteQueryExecutor::closeConnection() { SQLiteQueryExecutor::connectionManager.closeConnection(); } SQLiteQueryExecutor::~SQLiteQueryExecutor() { SQLiteQueryExecutor::closeConnection(); } std::string SQLiteQueryExecutor::getDraft(std::string key) const { static std::string getDraftByPrimaryKeySQL = "SELECT * " "FROM drafts " "WHERE key = ?;"; std::unique_ptr draft = getEntityByPrimaryKey( SQLiteQueryExecutor::getConnection(), getDraftByPrimaryKeySQL, key); return (draft == nullptr) ? "" : draft->text; } std::unique_ptr SQLiteQueryExecutor::getThread(std::string threadID) const { static std::string getThreadByPrimaryKeySQL = "SELECT * " "FROM threads " "WHERE id = ?;"; return getEntityByPrimaryKey( SQLiteQueryExecutor::getConnection(), getThreadByPrimaryKeySQL, threadID); } void SQLiteQueryExecutor::updateDraft(std::string key, std::string text) const { static std::string replaceDraftSQL = "REPLACE INTO drafts (key, text) " "VALUES (?, ?);"; Draft draft = {key, text}; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceDraftSQL, draft); } bool SQLiteQueryExecutor::moveDraft(std::string oldKey, std::string newKey) const { std::string draftText = this->getDraft(oldKey); if (!draftText.size()) { return false; } static std::string rekeyDraftSQL = "UPDATE OR REPLACE drafts " "SET key = ? " "WHERE key = ?;"; rekeyAllEntities( SQLiteQueryExecutor::getConnection(), rekeyDraftSQL, oldKey, newKey); return true; } std::vector SQLiteQueryExecutor::getAllDrafts() const { static std::string getAllDraftsSQL = "SELECT * " "FROM drafts;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllDraftsSQL); } void SQLiteQueryExecutor::removeAllDrafts() const { static std::string removeAllDraftsSQL = "DELETE FROM drafts;"; removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllDraftsSQL); } void SQLiteQueryExecutor::removeDrafts( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeDraftsByKeysSQLStream; removeDraftsByKeysSQLStream << "DELETE FROM drafts " "WHERE key IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeDraftsByKeysSQLStream.str(), ids); } void SQLiteQueryExecutor::removeAllMessages() const { static std::string removeAllMessagesSQL = "DELETE FROM messages;"; removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllMessagesSQL); } std::vector>> SQLiteQueryExecutor::getAllMessages() const { static std::string getAllMessagesSQL = "SELECT * " "FROM messages " "LEFT JOIN media " " ON messages.id = media.container " "ORDER BY messages.id;"; SQLiteStatementWrapper preparedSQL( SQLiteQueryExecutor::getConnection(), getAllMessagesSQL, "Failed to retrieve all messages."); std::string prevMsgIdx{}; std::vector>> allMessages; for (int stepResult = sqlite3_step(preparedSQL); stepResult == SQLITE_ROW; stepResult = sqlite3_step(preparedSQL)) { Message message = Message::fromSQLResult(preparedSQL, 0); if (message.id == prevMsgIdx) { allMessages.back().second.push_back(Media::fromSQLResult(preparedSQL, 8)); } else { prevMsgIdx = message.id; std::vector mediaForMsg; if (sqlite3_column_type(preparedSQL, 8) != SQLITE_NULL) { mediaForMsg.push_back(Media::fromSQLResult(preparedSQL, 8)); } allMessages.push_back(std::make_pair(std::move(message), mediaForMsg)); } } return allMessages; } void SQLiteQueryExecutor::removeMessages( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeMessagesByKeysSQLStream; removeMessagesByKeysSQLStream << "DELETE FROM messages " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMessagesByKeysSQLStream.str(), ids); } void SQLiteQueryExecutor::removeMessagesForThreads( const std::vector &threadIDs) const { if (!threadIDs.size()) { return; } std::stringstream removeMessagesByKeysSQLStream; removeMessagesByKeysSQLStream << "DELETE FROM messages " "WHERE thread IN " << getSQLStatementArray(threadIDs.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMessagesByKeysSQLStream.str(), threadIDs); } void SQLiteQueryExecutor::replaceMessage(const Message &message) const { static std::string replaceMessageSQL = "REPLACE INTO messages " "(id, local_id, thread, user, type, future_type, content, time) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceMessageSQL, message); } void SQLiteQueryExecutor::rekeyMessage(std::string from, std::string to) const { static std::string rekeyMessageSQL = "UPDATE OR REPLACE messages " "SET id = ? " "WHERE id = ?"; rekeyAllEntities( SQLiteQueryExecutor::getConnection(), rekeyMessageSQL, from, to); } void SQLiteQueryExecutor::removeAllMedia() const { static std::string removeAllMediaSQL = "DELETE FROM media;"; removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllMediaSQL); } void SQLiteQueryExecutor::removeMediaForMessages( const std::vector &msg_ids) const { if (!msg_ids.size()) { return; } std::stringstream removeMediaByKeysSQLStream; removeMediaByKeysSQLStream << "DELETE FROM media " "WHERE container IN " << getSQLStatementArray(msg_ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMediaByKeysSQLStream.str(), msg_ids); } void SQLiteQueryExecutor::removeMediaForMessage(std::string msg_id) const { static std::string removeMediaByKeySQL = "DELETE FROM media " "WHERE container IN (?);"; std::vector keys = {msg_id}; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMediaByKeySQL, keys); } void SQLiteQueryExecutor::removeMediaForThreads( const std::vector &thread_ids) const { if (!thread_ids.size()) { return; } std::stringstream removeMediaByKeysSQLStream; removeMediaByKeysSQLStream << "DELETE FROM media " "WHERE thread IN " << getSQLStatementArray(thread_ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMediaByKeysSQLStream.str(), thread_ids); } void SQLiteQueryExecutor::replaceMedia(const Media &media) const { static std::string replaceMediaSQL = "REPLACE INTO media " "(id, container, thread, uri, type, extras) " "VALUES (?, ?, ?, ?, ?, ?)"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceMediaSQL, media); } void SQLiteQueryExecutor::rekeyMediaContainers(std::string from, std::string to) const { static std::string rekeyMediaContainersSQL = "UPDATE media SET container = ? WHERE container = ?;"; rekeyAllEntities( SQLiteQueryExecutor::getConnection(), rekeyMediaContainersSQL, from, to); } void SQLiteQueryExecutor::replaceMessageStoreThreads( const std::vector &threads) const { static std::string replaceMessageStoreThreadSQL = "REPLACE INTO message_store_threads " "(id, start_reached) " "VALUES (?, ?);"; for (auto &thread : threads) { replaceEntity( SQLiteQueryExecutor::getConnection(), replaceMessageStoreThreadSQL, thread); } } void SQLiteQueryExecutor::removeAllMessageStoreThreads() const { static std::string removeAllMessageStoreThreadsSQL = "DELETE FROM message_store_threads;"; removeAllEntities( SQLiteQueryExecutor::getConnection(), removeAllMessageStoreThreadsSQL); } void SQLiteQueryExecutor::removeMessageStoreThreads( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeMessageStoreThreadsByKeysSQLStream; removeMessageStoreThreadsByKeysSQLStream << "DELETE FROM message_store_threads " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMessageStoreThreadsByKeysSQLStream.str(), ids); } std::vector SQLiteQueryExecutor::getAllMessageStoreThreads() const { static std::string getAllMessageStoreThreadsSQL = "SELECT * " "FROM message_store_threads;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllMessageStoreThreadsSQL); } std::vector SQLiteQueryExecutor::getAllThreads() const { static std::string getAllThreadsSQL = "SELECT * " "FROM threads;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllThreadsSQL); }; void SQLiteQueryExecutor::removeThreads(std::vector ids) const { if (!ids.size()) { return; } std::stringstream removeThreadsByKeysSQLStream; removeThreadsByKeysSQLStream << "DELETE FROM threads " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeThreadsByKeysSQLStream.str(), ids); }; void SQLiteQueryExecutor::replaceThread(const Thread &thread) const { static std::string replaceThreadSQL = "REPLACE INTO threads (" " id, type, name, description, color, creation_time, parent_thread_id," " containing_thread_id, community, members, roles, current_user," " source_message_id, replies_count, avatar, pinned_count) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceThreadSQL, thread); }; void SQLiteQueryExecutor::removeAllThreads() const { static std::string removeAllThreadsSQL = "DELETE FROM threads;"; removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllThreadsSQL); }; void SQLiteQueryExecutor::replaceReport(const Report &report) const { static std::string replaceReportSQL = "REPLACE INTO reports (id, report) " "VALUES (?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceReportSQL, report); } void SQLiteQueryExecutor::removeAllReports() const { static std::string removeAllReportsSQL = "DELETE FROM reports;"; removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllReportsSQL); } void SQLiteQueryExecutor::removeReports( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeReportsByKeysSQLStream; removeReportsByKeysSQLStream << "DELETE FROM reports " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeReportsByKeysSQLStream.str(), ids); } std::vector SQLiteQueryExecutor::getAllReports() const { static std::string getAllReportsSQL = "SELECT * " "FROM reports;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllReportsSQL); } void SQLiteQueryExecutor::setPersistStorageItem( std::string key, std::string item) const { static std::string replacePersistStorageItemSQL = "REPLACE INTO persist_storage (key, item) " "VALUES (?, ?);"; PersistItem entry{ key, item, }; replaceEntity( SQLiteQueryExecutor::getConnection(), replacePersistStorageItemSQL, entry); } void SQLiteQueryExecutor::removePersistStorageItem(std::string key) const { static std::string removePersistStorageItemByKeySQL = "DELETE FROM persist_storage " "WHERE key IN (?);"; std::vector keys = {key}; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removePersistStorageItemByKeySQL, keys); } std::string SQLiteQueryExecutor::getPersistStorageItem(std::string key) const { static std::string getPersistStorageItemByPrimaryKeySQL = "SELECT * " "FROM persist_storage " "WHERE key = ?;"; std::unique_ptr entry = getEntityByPrimaryKey( SQLiteQueryExecutor::getConnection(), getPersistStorageItemByPrimaryKeySQL, key); return (entry == nullptr) ? "" : entry->item; } void SQLiteQueryExecutor::replaceUser(const UserInfo &user_info) const { static std::string replaceUserSQL = "REPLACE INTO users (id, user_info) " "VALUES (?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceUserSQL, user_info); } void SQLiteQueryExecutor::removeAllUsers() const { static std::string removeAllUsersSQL = "DELETE FROM users;"; removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllUsersSQL); } void SQLiteQueryExecutor::removeUsers( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeUsersByKeysSQLStream; removeUsersByKeysSQLStream << "DELETE FROM users " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeUsersByKeysSQLStream.str(), ids); } void SQLiteQueryExecutor::replaceKeyserver( const KeyserverInfo &keyserver_info) const { static std::string replaceKeyserverSQL = "REPLACE INTO keyservers (id, keyserver_info) " "VALUES (:id, :keyserver_info);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceKeyserverSQL, keyserver_info); static std::string replaceKeyserverSyncedSQL = "REPLACE INTO keyservers_synced (id, keyserver_info) " "VALUES (:id, :synced_keyserver_info);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceKeyserverSyncedSQL, keyserver_info); } void SQLiteQueryExecutor::removeAllKeyservers() const { static std::string removeAllKeyserversSQL = "DELETE FROM keyservers;"; removeAllEntities( SQLiteQueryExecutor::getConnection(), removeAllKeyserversSQL); static std::string removeAllKeyserversSyncedSQL = "DELETE FROM keyservers_synced;"; removeAllEntities( SQLiteQueryExecutor::getConnection(), removeAllKeyserversSyncedSQL); } void SQLiteQueryExecutor::removeKeyservers( const std::vector &ids) const { if (!ids.size()) { return; } auto idArray = getSQLStatementArray(ids.size()); std::stringstream removeKeyserversByKeysSQLStream; removeKeyserversByKeysSQLStream << "DELETE FROM keyservers " "WHERE id IN " << idArray << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeKeyserversByKeysSQLStream.str(), ids); std::stringstream removeKeyserversSyncedByKeysSQLStream; removeKeyserversSyncedByKeysSQLStream << "DELETE FROM keyservers_synced " "WHERE id IN " << idArray << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeKeyserversSyncedByKeysSQLStream.str(), ids); } std::vector SQLiteQueryExecutor::getAllKeyservers() const { static std::string getAllKeyserversSQL = "SELECT " " synced.id, " " COALESCE(keyservers.keyserver_info, ''), " " synced.keyserver_info " "FROM keyservers_synced synced " "LEFT JOIN keyservers " " ON synced.id = keyservers.id;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllKeyserversSQL); } std::vector SQLiteQueryExecutor::getAllUsers() const { static std::string getAllUsersSQL = "SELECT * " "FROM users;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllUsersSQL); } void SQLiteQueryExecutor::replaceCommunity( const CommunityInfo &community_info) const { static std::string replaceCommunitySQL = "REPLACE INTO communities (id, community_info) " "VALUES (?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceCommunitySQL, community_info); } void SQLiteQueryExecutor::removeAllCommunities() const { static std::string removeAllCommunitiesSQL = "DELETE FROM communities;"; removeAllEntities( SQLiteQueryExecutor::getConnection(), removeAllCommunitiesSQL); } void SQLiteQueryExecutor::removeCommunities( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeCommunitiesByKeysSQLStream; removeCommunitiesByKeysSQLStream << "DELETE FROM communities " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeCommunitiesByKeysSQLStream.str(), ids); } std::vector SQLiteQueryExecutor::getAllCommunities() const { static std::string getAllCommunitiesSQL = "SELECT * " "FROM communities;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllCommunitiesSQL); } void SQLiteQueryExecutor::replaceIntegrityThreadHashes( const std::vector &threadHashes) const { static std::string replaceIntegrityThreadHashSQL = "REPLACE INTO integrity_store (id, thread_hash) " "VALUES (?, ?);"; for (const IntegrityThreadHash &integrityThreadHash : threadHashes) { replaceEntity( SQLiteQueryExecutor::getConnection(), replaceIntegrityThreadHashSQL, integrityThreadHash); } } void SQLiteQueryExecutor::removeAllIntegrityThreadHashes() const { static std::string removeAllIntegrityThreadHashesSQL = "DELETE FROM integrity_store;"; removeAllEntities( SQLiteQueryExecutor::getConnection(), removeAllIntegrityThreadHashesSQL); } void SQLiteQueryExecutor::removeIntegrityThreadHashes( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeIntegrityThreadHashesByKeysSQLStream; removeIntegrityThreadHashesByKeysSQLStream << "DELETE FROM integrity_store " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeIntegrityThreadHashesByKeysSQLStream.str(), ids); } std::vector SQLiteQueryExecutor::getAllIntegrityThreadHashes() const { static std::string getAllIntegrityThreadHashesSQL = "SELECT * " "FROM integrity_store;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllIntegrityThreadHashesSQL); } void SQLiteQueryExecutor::replaceSyncedMetadataEntry( const SyncedMetadataEntry &synced_metadata_entry) const { static std::string replaceSyncedMetadataEntrySQL = "REPLACE INTO synced_metadata (name, data) " "VALUES (?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceSyncedMetadataEntrySQL, synced_metadata_entry); } void SQLiteQueryExecutor::removeAllSyncedMetadata() const { static std::string removeAllSyncedMetadataSQL = "DELETE FROM synced_metadata;"; removeAllEntities( SQLiteQueryExecutor::getConnection(), removeAllSyncedMetadataSQL); } void SQLiteQueryExecutor::removeSyncedMetadata( const std::vector &names) const { if (!names.size()) { return; } std::stringstream removeSyncedMetadataByNamesSQLStream; removeSyncedMetadataByNamesSQLStream << "DELETE FROM synced_metadata " "WHERE name IN " << getSQLStatementArray(names.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeSyncedMetadataByNamesSQLStream.str(), names); } std::vector SQLiteQueryExecutor::getAllSyncedMetadata() const { static std::string getAllSyncedMetadataSQL = "SELECT * " "FROM synced_metadata;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllSyncedMetadataSQL); } void SQLiteQueryExecutor::replaceAuxUserInfo( const AuxUserInfo &aux_user_info) const { static std::string replaceAuxUserInfoSQL = "REPLACE INTO aux_users (id, aux_user_info) " "VALUES (?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceAuxUserInfoSQL, aux_user_info); } void SQLiteQueryExecutor::removeAllAuxUserInfos() const { static std::string removeAllAuxUserInfosSQL = "DELETE FROM aux_users;"; removeAllEntities( SQLiteQueryExecutor::getConnection(), removeAllAuxUserInfosSQL); } void SQLiteQueryExecutor::removeAuxUserInfos( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeAuxUserInfosByKeysSQLStream; removeAuxUserInfosByKeysSQLStream << "DELETE FROM aux_users " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeAuxUserInfosByKeysSQLStream.str(), ids); } std::vector SQLiteQueryExecutor::getAllAuxUserInfos() const { static std::string getAllAuxUserInfosSQL = "SELECT * " "FROM aux_users;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllAuxUserInfosSQL); } void SQLiteQueryExecutor::beginTransaction() const { executeQuery(SQLiteQueryExecutor::getConnection(), "BEGIN TRANSACTION;"); } void SQLiteQueryExecutor::commitTransaction() const { executeQuery(SQLiteQueryExecutor::getConnection(), "COMMIT;"); } void SQLiteQueryExecutor::rollbackTransaction() const { executeQuery(SQLiteQueryExecutor::getConnection(), "ROLLBACK;"); } int SQLiteQueryExecutor::getContentAccountID() const { return CONTENT_ACCOUNT_ID; } int SQLiteQueryExecutor::getNotifsAccountID() const { return NOTIFS_ACCOUNT_ID; } std::vector SQLiteQueryExecutor::getOlmPersistSessionsData() const { static std::string getAllOlmPersistSessionsSQL = "SELECT * " "FROM olm_persist_sessions;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllOlmPersistSessionsSQL); } std::optional SQLiteQueryExecutor::getOlmPersistAccountData(int accountID) const { static std::string getOlmPersistAccountSQL = "SELECT * " "FROM olm_persist_account " "WHERE id = ?;"; std::unique_ptr result = getEntityByIntegerPrimaryKey( SQLiteQueryExecutor::getConnection(), getOlmPersistAccountSQL, accountID); if (result == nullptr) { return std::nullopt; } return result->account_data; } void SQLiteQueryExecutor::storeOlmPersistAccount( int accountID, const std::string &accountData) const { static std::string replaceOlmPersistAccountSQL = "REPLACE INTO olm_persist_account (id, account_data) " "VALUES (?, ?);"; OlmPersistAccount persistAccount = {accountID, accountData}; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceOlmPersistAccountSQL, persistAccount); } void SQLiteQueryExecutor::storeOlmPersistSession( const OlmPersistSession &session) const { static std::string replaceOlmPersistSessionSQL = "REPLACE INTO olm_persist_sessions " "(target_device_id, session_data, version) " "VALUES (?, ?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceOlmPersistSessionSQL, session); } void SQLiteQueryExecutor::storeOlmPersistData( int accountID, crypto::Persist persist) const { if (accountID != CONTENT_ACCOUNT_ID && persist.sessions.size() > 0) { throw std::runtime_error( "Attempt to store notifications sessions in SQLite. Notifications " "sessions must be stored in storage shared with NSE."); } std::string accountData = std::string(persist.account.begin(), persist.account.end()); this->storeOlmPersistAccount(accountID, accountData); for (auto it = persist.sessions.begin(); it != persist.sessions.end(); it++) { OlmPersistSession persistSession = { - it->first, std::string(it->second.begin(), it->second.end())}; + it->first, + std::string(it->second.buffer.begin(), it->second.buffer.end()), + it->second.version}; + this->storeOlmPersistSession(persistSession); } } void SQLiteQueryExecutor::setNotifyToken(std::string token) const { this->setMetadata("notify_token", token); } void SQLiteQueryExecutor::clearNotifyToken() const { this->clearMetadata("notify_token"); } void SQLiteQueryExecutor::setCurrentUserID(std::string userID) const { this->setMetadata("current_user_id", userID); } std::string SQLiteQueryExecutor::getCurrentUserID() const { return this->getMetadata("current_user_id"); } void SQLiteQueryExecutor::setMetadata(std::string entry_name, std::string data) const { std::string replaceMetadataSQL = "REPLACE INTO metadata (name, data) " "VALUES (?, ?);"; Metadata entry{ entry_name, data, }; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceMetadataSQL, entry); } void SQLiteQueryExecutor::clearMetadata(std::string entry_name) const { static std::string removeMetadataByKeySQL = "DELETE FROM metadata " "WHERE name IN (?);"; std::vector keys = {entry_name}; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMetadataByKeySQL, keys); } std::string SQLiteQueryExecutor::getMetadata(std::string entry_name) const { std::string getMetadataByPrimaryKeySQL = "SELECT * " "FROM metadata " "WHERE name = ?;"; std::unique_ptr entry = getEntityByPrimaryKey( SQLiteQueryExecutor::getConnection(), getMetadataByPrimaryKeySQL, entry_name); return (entry == nullptr) ? "" : entry->data; } void SQLiteQueryExecutor::addMessagesToDevice( const std::vector &messages) const { static std::string addMessageToDevice = "REPLACE INTO messages_to_device (" " message_id, device_id, user_id, timestamp, plaintext, ciphertext) " "VALUES (?, ?, ?, ?, ?, ?);"; for (const ClientMessageToDevice &clientMessage : messages) { MessageToDevice message = clientMessage.toMessageToDevice(); replaceEntity( SQLiteQueryExecutor::getConnection(), addMessageToDevice, message); } } std::vector SQLiteQueryExecutor::getAllMessagesToDevice(const std::string &deviceID) const { std::string query = "SELECT * FROM messages_to_device " "WHERE device_id = ? " "ORDER BY timestamp;"; SQLiteStatementWrapper preparedSQL( SQLiteQueryExecutor::getConnection(), query, "Failed to get all messages to device"); sqlite3_bind_text(preparedSQL, 1, deviceID.c_str(), -1, SQLITE_TRANSIENT); std::vector messages; for (int stepResult = sqlite3_step(preparedSQL); stepResult == SQLITE_ROW; stepResult = sqlite3_step(preparedSQL)) { messages.emplace_back( ClientMessageToDevice(MessageToDevice::fromSQLResult(preparedSQL, 0))); } return messages; } void SQLiteQueryExecutor::removeMessagesToDeviceOlderThan( const ClientMessageToDevice &lastConfirmedMessageClient) const { static std::string query = "DELETE FROM messages_to_device " "WHERE timestamp <= ? AND device_id IN (?);"; MessageToDevice lastConfirmedMessage = lastConfirmedMessageClient.toMessageToDevice(); comm::SQLiteStatementWrapper preparedSQL( SQLiteQueryExecutor::getConnection(), query, "Failed to remove messages to device"); sqlite3_bind_int64(preparedSQL, 1, lastConfirmedMessage.timestamp); sqlite3_bind_text( preparedSQL, 2, lastConfirmedMessage.device_id.c_str(), -1, SQLITE_TRANSIENT); int result = sqlite3_step(preparedSQL); if (result != SQLITE_DONE) { throw std::runtime_error( "Failed to execute removeMessagesToDeviceOlderThan statement: " + std::string(sqlite3_errmsg(SQLiteQueryExecutor::getConnection()))); } } void SQLiteQueryExecutor::removeAllMessagesForDevice( const std::string &deviceID) const { static std::string removeMessagesSQL = "DELETE FROM messages_to_device " "WHERE device_id IN (?);"; std::vector keys = {deviceID}; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMessagesSQL, keys); } #ifdef EMSCRIPTEN std::vector SQLiteQueryExecutor::getAllThreadsWeb() const { auto threads = this->getAllThreads(); std::vector webThreads; webThreads.reserve(threads.size()); for (const auto &thread : threads) { webThreads.emplace_back(thread); } return webThreads; }; void SQLiteQueryExecutor::replaceThreadWeb(const WebThread &thread) const { this->replaceThread(thread.toThread()); }; std::vector SQLiteQueryExecutor::getAllMessagesWeb() const { auto allMessages = this->getAllMessages(); std::vector allMessageWithMedias; for (auto &messageWitMedia : allMessages) { allMessageWithMedias.push_back( {std::move(messageWitMedia.first), messageWitMedia.second}); } return allMessageWithMedias; } void SQLiteQueryExecutor::replaceMessageWeb(const WebMessage &message) const { this->replaceMessage(message.toMessage()); }; NullableString SQLiteQueryExecutor::getOlmPersistAccountDataWeb(int accountID) const { std::optional accountData = this->getOlmPersistAccountData(accountID); if (!accountData.has_value()) { return NullableString(); } return std::make_unique(accountData.value()); } #else void SQLiteQueryExecutor::clearSensitiveData() { SQLiteQueryExecutor::closeConnection(); if (file_exists(SQLiteQueryExecutor::sqliteFilePath) && std::remove(SQLiteQueryExecutor::sqliteFilePath.c_str())) { std::ostringstream errorStream; errorStream << "Failed to delete database file. Details: " << strerror(errno); throw std::system_error(errno, std::generic_category(), errorStream.str()); } SQLiteQueryExecutor::generateFreshEncryptionKey(); SQLiteQueryExecutor::migrate(); } void SQLiteQueryExecutor::initialize(std::string &databasePath) { std::call_once(SQLiteQueryExecutor::initialized, [&databasePath]() { SQLiteQueryExecutor::sqliteFilePath = databasePath; folly::Optional maybeEncryptionKey = CommSecureStore::get(SQLiteQueryExecutor::secureStoreEncryptionKeyID); folly::Optional maybeBackupLogsEncryptionKey = CommSecureStore::get( SQLiteQueryExecutor::secureStoreBackupLogsEncryptionKeyID); if (file_exists(databasePath) && maybeEncryptionKey && maybeBackupLogsEncryptionKey) { SQLiteQueryExecutor::encryptionKey = maybeEncryptionKey.value(); SQLiteQueryExecutor::backupLogsEncryptionKey = maybeBackupLogsEncryptionKey.value(); return; } else if (file_exists(databasePath) && maybeEncryptionKey) { SQLiteQueryExecutor::encryptionKey = maybeEncryptionKey.value(); SQLiteQueryExecutor::generateFreshBackupLogsEncryptionKey(); return; } SQLiteQueryExecutor::generateFreshEncryptionKey(); }); } void SQLiteQueryExecutor::initializeTablesForLogMonitoring() { sqlite3 *db; sqlite3_open(SQLiteQueryExecutor::sqliteFilePath.c_str(), &db); default_on_db_open_callback(db); std::vector tablesToMonitor; { SQLiteStatementWrapper preparedSQL( db, "SELECT name FROM sqlite_master WHERE type='table';", "Failed to get all database tables"); for (int stepResult = sqlite3_step(preparedSQL); stepResult == SQLITE_ROW; stepResult = sqlite3_step(preparedSQL)) { std::string table_name = reinterpret_cast(sqlite3_column_text(preparedSQL, 0)); if (SQLiteQueryExecutor::backedUpTablesBlocklist.find(table_name) == SQLiteQueryExecutor::backedUpTablesBlocklist.end()) { tablesToMonitor.emplace_back(table_name); } } // Runs preparedSQL destructor which finalizes the sqlite statement } sqlite3_close(db); SQLiteQueryExecutor::connectionManager.tablesToMonitor = tablesToMonitor; } void SQLiteQueryExecutor::createMainCompaction(std::string backupID) const { std::string finalBackupPath = PlatformSpecificTools::getBackupFilePath(backupID, false); std::string finalAttachmentsPath = PlatformSpecificTools::getBackupFilePath(backupID, true); std::string tempBackupPath = finalBackupPath + "_tmp"; std::string tempAttachmentsPath = finalAttachmentsPath + "_tmp"; if (file_exists(tempBackupPath)) { Logger::log( "Attempting to delete temporary backup file from previous backup " "attempt."); attempt_delete_file( tempBackupPath, "Failed to delete temporary backup file from previous backup attempt."); } if (file_exists(tempAttachmentsPath)) { Logger::log( "Attempting to delete temporary attachments file from previous backup " "attempt."); attempt_delete_file( tempAttachmentsPath, "Failed to delete temporary attachments file from previous backup " "attempt."); } sqlite3 *backupDB; sqlite3_open(tempBackupPath.c_str(), &backupDB); set_encryption_key(backupDB); sqlite3_backup *backupObj = sqlite3_backup_init( backupDB, "main", SQLiteQueryExecutor::getConnection(), "main"); if (!backupObj) { std::stringstream error_message; error_message << "Failed to init backup for main compaction. Details: " << sqlite3_errmsg(backupDB) << std::endl; sqlite3_close(backupDB); throw std::runtime_error(error_message.str()); } int backupResult = sqlite3_backup_step(backupObj, -1); sqlite3_backup_finish(backupObj); if (backupResult == SQLITE_BUSY || backupResult == SQLITE_LOCKED) { sqlite3_close(backupDB); throw std::runtime_error( "Programmer error. Database in transaction during backup attempt."); } else if (backupResult != SQLITE_DONE) { sqlite3_close(backupDB); std::stringstream error_message; error_message << "Failed to create database backup. Details: " << sqlite3_errstr(backupResult); throw std::runtime_error(error_message.str()); } if (!SQLiteQueryExecutor::backedUpTablesBlocklist.empty()) { std::string removeDeviceSpecificDataSQL = ""; for (const auto &table_name : SQLiteQueryExecutor::backedUpTablesBlocklist) { removeDeviceSpecificDataSQL.append("DELETE FROM " + table_name + ";\n"); } executeQuery(backupDB, removeDeviceSpecificDataSQL); } executeQuery(backupDB, "VACUUM;"); sqlite3_close(backupDB); attempt_rename_file( tempBackupPath, finalBackupPath, "Failed to rename complete temporary backup file to final backup file."); std::ofstream tempAttachmentsFile(tempAttachmentsPath); if (!tempAttachmentsFile.is_open()) { throw std::runtime_error( "Unable to create attachments file for backup id: " + backupID); } std::string getAllBlobServiceMediaSQL = "SELECT * FROM media WHERE uri LIKE 'comm-blob-service://%';"; std::vector blobServiceMedia = getAllEntities( SQLiteQueryExecutor::getConnection(), getAllBlobServiceMediaSQL); for (const auto &media : blobServiceMedia) { std::string blobServiceURI = media.uri; std::string blobHash = blob_hash_from_blob_service_uri(blobServiceURI); tempAttachmentsFile << blobHash << "\n"; } tempAttachmentsFile.close(); attempt_rename_file( tempAttachmentsPath, finalAttachmentsPath, "Failed to rename complete temporary attachments file to final " "attachments file."); this->setMetadata("backupID", backupID); this->clearMetadata("logID"); if (StaffUtils::isStaffRelease()) { SQLiteQueryExecutor::connectionManager.setLogsMonitoring(true); } } void SQLiteQueryExecutor::generateFreshEncryptionKey() { std::string encryptionKey = comm::crypto::Tools::generateRandomHexString( SQLiteQueryExecutor::sqlcipherEncryptionKeySize); CommSecureStore::set( SQLiteQueryExecutor::secureStoreEncryptionKeyID, encryptionKey); SQLiteQueryExecutor::encryptionKey = encryptionKey; SQLiteQueryExecutor::generateFreshBackupLogsEncryptionKey(); } void SQLiteQueryExecutor::generateFreshBackupLogsEncryptionKey() { std::string backupLogsEncryptionKey = comm::crypto::Tools::generateRandomHexString( SQLiteQueryExecutor::backupLogsEncryptionKeySize); CommSecureStore::set( SQLiteQueryExecutor::secureStoreBackupLogsEncryptionKeyID, backupLogsEncryptionKey); SQLiteQueryExecutor::backupLogsEncryptionKey = backupLogsEncryptionKey; } void SQLiteQueryExecutor::captureBackupLogs() const { std::string backupID = this->getMetadata("backupID"); if (!backupID.size()) { return; } std::string logID = this->getMetadata("logID"); if (!logID.size()) { logID = "1"; } bool newLogCreated = SQLiteQueryExecutor::connectionManager.captureLogs( backupID, logID, SQLiteQueryExecutor::backupLogsEncryptionKey); if (!newLogCreated) { return; } this->setMetadata("logID", std::to_string(std::stoi(logID) + 1)); } #endif void SQLiteQueryExecutor::restoreFromMainCompaction( std::string mainCompactionPath, std::string mainCompactionEncryptionKey) const { if (!file_exists(mainCompactionPath)) { throw std::runtime_error("Restore attempt but backup file does not exist."); } sqlite3 *backupDB; if (!is_database_queryable( backupDB, true, mainCompactionPath, mainCompactionEncryptionKey)) { throw std::runtime_error("Backup file or encryption key corrupted."); } // We don't want to run `PRAGMA key = ...;` // on main web database. The context is here: // https://linear.app/comm/issue/ENG-6398/issues-with-sqlcipher-on-web #ifdef EMSCRIPTEN std::string plaintextBackupPath = mainCompactionPath + "_plaintext"; if (file_exists(plaintextBackupPath)) { attempt_delete_file( plaintextBackupPath, "Failed to delete plaintext backup file from previous backup attempt."); } std::string plaintextMigrationDBQuery = "PRAGMA key = \"x'" + mainCompactionEncryptionKey + "'\";" "ATTACH DATABASE '" + plaintextBackupPath + "' AS plaintext KEY '';" "SELECT sqlcipher_export('plaintext');" "DETACH DATABASE plaintext;"; sqlite3_open(mainCompactionPath.c_str(), &backupDB); char *plaintextMigrationErr; sqlite3_exec( backupDB, plaintextMigrationDBQuery.c_str(), nullptr, nullptr, &plaintextMigrationErr); sqlite3_close(backupDB); if (plaintextMigrationErr) { std::stringstream error_message; error_message << "Failed to migrate backup SQLCipher file to plaintext " "SQLite file. Details" << plaintextMigrationErr << std::endl; std::string error_message_str = error_message.str(); sqlite3_free(plaintextMigrationErr); throw std::runtime_error(error_message_str); } sqlite3_open(plaintextBackupPath.c_str(), &backupDB); #else sqlite3_open(mainCompactionPath.c_str(), &backupDB); set_encryption_key(backupDB, mainCompactionEncryptionKey); #endif sqlite3_backup *backupObj = sqlite3_backup_init( SQLiteQueryExecutor::getConnection(), "main", backupDB, "main"); if (!backupObj) { std::stringstream error_message; error_message << "Failed to init backup for main compaction. Details: " << sqlite3_errmsg(SQLiteQueryExecutor::getConnection()) << std::endl; sqlite3_close(backupDB); throw std::runtime_error(error_message.str()); } int backupResult = sqlite3_backup_step(backupObj, -1); sqlite3_backup_finish(backupObj); sqlite3_close(backupDB); if (backupResult == SQLITE_BUSY || backupResult == SQLITE_LOCKED) { throw std::runtime_error( "Programmer error. Database in transaction during restore attempt."); } else if (backupResult != SQLITE_DONE) { std::stringstream error_message; error_message << "Failed to restore database from backup. Details: " << sqlite3_errstr(backupResult); throw std::runtime_error(error_message.str()); } #ifdef EMSCRIPTEN attempt_delete_file( plaintextBackupPath, "Failed to delete plaintext compaction file after successful restore."); #endif attempt_delete_file( mainCompactionPath, "Failed to delete main compaction file after successful restore."); } void SQLiteQueryExecutor::restoreFromBackupLog( const std::vector &backupLog) const { SQLiteQueryExecutor::connectionManager.restoreFromBackupLog(backupLog); } } // namespace comm diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp index 03673f50f..39e158ea4 100644 --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp @@ -1,1831 +1,1833 @@ #include "CommCoreModule.h" #include "../Notifications/BackgroundDataStorage/NotificationsCryptoModule.h" #include "BaseDataStore.h" #include "CommServicesAuthMetadataEmitter.h" #include "DatabaseManager.h" #include "InternalModules/GlobalDBSingleton.h" #include "InternalModules/RustPromiseManager.h" #include "NativeModuleUtils.h" #include "TerminateApp.h" #include #include #include #include #include "JSIRust.h" #include "lib.rs.h" #include namespace comm { using namespace facebook::react; jsi::Value CommCoreModule::getDraft(jsi::Runtime &rt, jsi::String key) { std::string keyStr = key.utf8(rt); return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; std::string draftStr; try { draftStr = DatabaseManager::getQueryExecutor().getDraft(keyStr); } catch (std::system_error &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } jsi::String draft = jsi::String::createFromUtf8(innerRt, draftStr); promise->resolve(std::move(draft)); }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } jsi::Value CommCoreModule::updateDraft( jsi::Runtime &rt, jsi::String key, jsi::String text) { std::string keyStr = key.utf8(rt); std::string textStr = text.utf8(rt); return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=]() { std::string error; try { DatabaseManager::getQueryExecutor().updateDraft(keyStr, textStr); } catch (std::system_error &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=]() { if (error.size()) { promise->reject(error); } else { promise->resolve(true); } }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } jsi::Value CommCoreModule::moveDraft( jsi::Runtime &rt, jsi::String oldKey, jsi::String newKey) { std::string oldKeyStr = oldKey.utf8(rt); std::string newKeyStr = newKey.utf8(rt); return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=]() { std::string error; bool result = false; try { result = DatabaseManager::getQueryExecutor().moveDraft( oldKeyStr, newKeyStr); } catch (std::system_error &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=]() { if (error.size()) { promise->reject(error); } else { promise->resolve(result); } }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } jsi::Value CommCoreModule::getClientDBStore(jsi::Runtime &rt) { return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; std::vector draftsVector; std::vector threadsVector; std::vector>> messagesVector; std::vector messageStoreThreadsVector; std::vector reportStoreVector; std::vector userStoreVector; std::vector keyserverStoreVector; std::vector communityStoreVector; std::vector integrityStoreVector; std::vector syncedMetadataStoreVector; std::vector auxUserStoreVector; try { draftsVector = DatabaseManager::getQueryExecutor().getAllDrafts(); messagesVector = DatabaseManager::getQueryExecutor().getAllMessages(); threadsVector = DatabaseManager::getQueryExecutor().getAllThreads(); messageStoreThreadsVector = DatabaseManager::getQueryExecutor().getAllMessageStoreThreads(); reportStoreVector = DatabaseManager::getQueryExecutor().getAllReports(); userStoreVector = DatabaseManager::getQueryExecutor().getAllUsers(); keyserverStoreVector = DatabaseManager::getQueryExecutor().getAllKeyservers(); communityStoreVector = DatabaseManager::getQueryExecutor().getAllCommunities(); integrityStoreVector = DatabaseManager::getQueryExecutor() .getAllIntegrityThreadHashes(); syncedMetadataStoreVector = DatabaseManager::getQueryExecutor().getAllSyncedMetadata(); auxUserStoreVector = DatabaseManager::getQueryExecutor().getAllAuxUserInfos(); } catch (std::system_error &e) { error = e.what(); } auto draftsVectorPtr = std::make_shared>(std::move(draftsVector)); auto messagesVectorPtr = std::make_shared< std::vector>>>( std::move(messagesVector)); auto threadsVectorPtr = std::make_shared>(std::move(threadsVector)); auto messageStoreThreadsVectorPtr = std::make_shared>( std::move(messageStoreThreadsVector)); auto reportStoreVectorPtr = std::make_shared>( std::move(reportStoreVector)); auto userStoreVectorPtr = std::make_shared>( std::move(userStoreVector)); auto keyserveStoreVectorPtr = std::make_shared>( std::move(keyserverStoreVector)); auto communityStoreVectorPtr = std::make_shared>( std::move(communityStoreVector)); auto integrityStoreVectorPtr = std::make_shared>( std::move(integrityStoreVector)); auto syncedMetadataStoreVectorPtr = std::make_shared>( std::move(syncedMetadataStoreVector)); auto auxUserStoreVectorPtr = std::make_shared>( std::move(auxUserStoreVector)); this->jsInvoker_->invokeAsync([&innerRt, draftsVectorPtr, messagesVectorPtr, threadsVectorPtr, messageStoreThreadsVectorPtr, reportStoreVectorPtr, userStoreVectorPtr, keyserveStoreVectorPtr, communityStoreVectorPtr, integrityStoreVectorPtr, syncedMetadataStoreVectorPtr, auxUserStoreVectorPtr, error, promise, draftStore = this->draftStore, threadStore = this->threadStore, messageStore = this->messageStore, reportStore = this->reportStore, userStore = this->userStore, keyserverStore = this->keyserverStore, communityStore = this->communityStore, integrityStore = this->integrityStore, syncedMetadataStore = this->syncedMetadataStore, auxUserStore = this->auxUserStore]() { if (error.size()) { promise->reject(error); return; } jsi::Array jsiDrafts = draftStore.parseDBDataStore(innerRt, draftsVectorPtr); jsi::Array jsiMessages = messageStore.parseDBDataStore(innerRt, messagesVectorPtr); jsi::Array jsiThreads = threadStore.parseDBDataStore(innerRt, threadsVectorPtr); jsi::Array jsiMessageStoreThreads = messageStore.parseDBMessageStoreThreads( innerRt, messageStoreThreadsVectorPtr); jsi::Array jsiReportStore = reportStore.parseDBDataStore(innerRt, reportStoreVectorPtr); jsi::Array jsiUserStore = userStore.parseDBDataStore(innerRt, userStoreVectorPtr); jsi::Array jsiKeyserverStore = keyserverStore.parseDBDataStore( innerRt, keyserveStoreVectorPtr); jsi::Array jsiCommunityStore = communityStore.parseDBDataStore( innerRt, communityStoreVectorPtr); jsi::Array jsiIntegrityStore = integrityStore.parseDBDataStore( innerRt, integrityStoreVectorPtr); jsi::Array jsiSyncedMetadataStore = syncedMetadataStore.parseDBDataStore( innerRt, syncedMetadataStoreVectorPtr); jsi::Array jsiAuxUserStore = auxUserStore.parseDBDataStore(innerRt, auxUserStoreVectorPtr); auto jsiClientDBStore = jsi::Object(innerRt); jsiClientDBStore.setProperty(innerRt, "messages", jsiMessages); jsiClientDBStore.setProperty(innerRt, "threads", jsiThreads); jsiClientDBStore.setProperty(innerRt, "drafts", jsiDrafts); jsiClientDBStore.setProperty( innerRt, "messageStoreThreads", jsiMessageStoreThreads); jsiClientDBStore.setProperty(innerRt, "reports", jsiReportStore); jsiClientDBStore.setProperty(innerRt, "users", jsiUserStore); jsiClientDBStore.setProperty( innerRt, "keyservers", jsiKeyserverStore); jsiClientDBStore.setProperty( innerRt, "communities", jsiCommunityStore); jsiClientDBStore.setProperty( innerRt, "integrityThreadHashes", jsiIntegrityStore); jsiClientDBStore.setProperty( innerRt, "syncedMetadata", jsiSyncedMetadataStore); jsiClientDBStore.setProperty( innerRt, "auxUserInfos", jsiAuxUserStore); promise->resolve(std::move(jsiClientDBStore)); }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } jsi::Value CommCoreModule::removeAllDrafts(jsi::Runtime &rt) { return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=]() { std::string error; try { DatabaseManager::getQueryExecutor().removeAllDrafts(); } catch (std::system_error &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=]() { if (error.size()) { promise->reject(error); return; } promise->resolve(jsi::Value::undefined()); }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } jsi::Array CommCoreModule::getAllMessagesSync(jsi::Runtime &rt) { auto messagesVector = NativeModuleUtils::runSyncOrThrowJSError< std::vector>>>(rt, []() { return DatabaseManager::getQueryExecutor().getAllMessages(); }); auto messagesVectorPtr = std::make_shared>>>( std::move(messagesVector)); jsi::Array jsiMessages = this->messageStore.parseDBDataStore(rt, messagesVectorPtr); return jsiMessages; } jsi::Value CommCoreModule::processDraftStoreOperations( jsi::Runtime &rt, jsi::Array operations) { return this->draftStore.processStoreOperations(rt, std::move(operations)); } jsi::Value CommCoreModule::processMessageStoreOperations( jsi::Runtime &rt, jsi::Array operations) { return this->messageStore.processStoreOperations(rt, std::move(operations)); } void CommCoreModule::processMessageStoreOperationsSync( jsi::Runtime &rt, jsi::Array operations) { return this->messageStore.processStoreOperationsSync( rt, std::move(operations)); } jsi::Array CommCoreModule::getAllThreadsSync(jsi::Runtime &rt) { auto threadsVector = NativeModuleUtils::runSyncOrThrowJSError>(rt, []() { return DatabaseManager::getQueryExecutor().getAllThreads(); }); auto threadsVectorPtr = std::make_shared>(std::move(threadsVector)); jsi::Array jsiThreads = this->threadStore.parseDBDataStore(rt, threadsVectorPtr); return jsiThreads; } jsi::Value CommCoreModule::processThreadStoreOperations( jsi::Runtime &rt, jsi::Array operations) { return this->threadStore.processStoreOperations(rt, std::move(operations)); } void CommCoreModule::processThreadStoreOperationsSync( jsi::Runtime &rt, jsi::Array operations) { this->threadStore.processStoreOperationsSync(rt, std::move(operations)); } jsi::Value CommCoreModule::processReportStoreOperations( jsi::Runtime &rt, jsi::Array operations) { return this->reportStore.processStoreOperations(rt, std::move(operations)); } void CommCoreModule::processReportStoreOperationsSync( jsi::Runtime &rt, jsi::Array operations) { this->reportStore.processStoreOperationsSync(rt, std::move(operations)); } jsi::Value CommCoreModule::processUserStoreOperations( jsi::Runtime &rt, jsi::Array operations) { return this->userStore.processStoreOperations(rt, std::move(operations)); } jsi::Value CommCoreModule::processKeyserverStoreOperations( jsi::Runtime &rt, jsi::Array operations) { return this->keyserverStore.processStoreOperations(rt, std::move(operations)); } jsi::Value CommCoreModule::processCommunityStoreOperations( jsi::Runtime &rt, jsi::Array operations) { return this->communityStore.processStoreOperations(rt, std::move(operations)); } jsi::Value CommCoreModule::processIntegrityStoreOperations( jsi::Runtime &rt, jsi::Array operations) { return this->integrityStore.processStoreOperations(rt, std::move(operations)); } jsi::Value CommCoreModule::processSyncedMetadataStoreOperations( jsi::Runtime &rt, jsi::Array operations) { return this->syncedMetadataStore.processStoreOperations( rt, std::move(operations)); } jsi::Value CommCoreModule::processAuxUserStoreOperations( jsi::Runtime &rt, jsi::Array operations) { return this->auxUserStore.processStoreOperations(rt, std::move(operations)); } void CommCoreModule::terminate(jsi::Runtime &rt) { TerminateApp::terminate(); } void CommCoreModule::persistCryptoModules( bool persistContentModule, bool persistNotifsModule) { folly::Optional storedSecretKey = CommSecureStore::get(this->secureStoreAccountDataKey); if (!storedSecretKey.hasValue()) { storedSecretKey = crypto::Tools::generateRandomString(64); CommSecureStore::set( this->secureStoreAccountDataKey, storedSecretKey.value()); } if (!persistContentModule && !persistNotifsModule) { return; } crypto::Persist newContentPersist; if (persistContentModule) { newContentPersist = this->contentCryptoModule->storeAsB64(storedSecretKey.value()); } crypto::Persist newNotifsPersist; if (persistNotifsModule) { newNotifsPersist = this->notifsCryptoModule->storeAsB64(storedSecretKey.value()); } std::promise persistencePromise; std::future persistenceFuture = persistencePromise.get_future(); GlobalDBSingleton::instance.scheduleOrRunCancellable( [=, &persistencePromise]() { try { if (persistContentModule) { DatabaseManager::getQueryExecutor().storeOlmPersistData( DatabaseManager::getQueryExecutor().getContentAccountID(), newContentPersist); } if (persistNotifsModule) { DatabaseManager::getQueryExecutor().storeOlmPersistData( DatabaseManager::getQueryExecutor().getNotifsAccountID(), newNotifsPersist); } persistencePromise.set_value(); } catch (std::system_error &e) { persistencePromise.set_exception(std::make_exception_ptr(e)); } }); persistenceFuture.get(); } jsi::Value CommCoreModule::initializeCryptoAccount(jsi::Runtime &rt) { folly::Optional storedSecretKey = CommSecureStore::get(this->secureStoreAccountDataKey); if (!storedSecretKey.hasValue()) { storedSecretKey = crypto::Tools::generateRandomString(64); CommSecureStore::set( this->secureStoreAccountDataKey, storedSecretKey.value()); } return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=]() { crypto::Persist contentPersist; crypto::Persist notifsPersist; std::string error; try { std::optional contentAccountData = DatabaseManager::getQueryExecutor().getOlmPersistAccountData( DatabaseManager::getQueryExecutor().getContentAccountID()); if (contentAccountData.has_value()) { contentPersist.account = crypto::OlmBuffer( contentAccountData->begin(), contentAccountData->end()); // handle sessions data std::vector sessionsData = DatabaseManager::getQueryExecutor() .getOlmPersistSessionsData(); for (OlmPersistSession &sessionsDataItem : sessionsData) { crypto::OlmBuffer sessionDataBuffer( sessionsDataItem.session_data.begin(), sessionsDataItem.session_data.end()); + crypto::SessionPersist sessionPersist{ + sessionDataBuffer, sessionsDataItem.version}; contentPersist.sessions.insert(std::make_pair( - sessionsDataItem.target_device_id, sessionDataBuffer)); + sessionsDataItem.target_device_id, sessionPersist)); } } std::optional notifsAccountData = DatabaseManager::getQueryExecutor().getOlmPersistAccountData( DatabaseManager::getQueryExecutor().getNotifsAccountID()); if (notifsAccountData.has_value()) { notifsPersist.account = crypto::OlmBuffer( notifsAccountData->begin(), notifsAccountData->end()); } } catch (std::system_error &e) { error = e.what(); } this->cryptoThread->scheduleTask([=]() { std::string error; this->contentCryptoModule.reset(new crypto::CryptoModule( this->publicCryptoAccountID, storedSecretKey.value(), contentPersist)); this->notifsCryptoModule.reset(new crypto::CryptoModule( this->notifsCryptoAccountID, storedSecretKey.value(), notifsPersist)); try { this->persistCryptoModules( contentPersist.isEmpty(), notifsPersist.isEmpty()); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=]() { if (error.size()) { promise->reject(error); return; } }); this->jsInvoker_->invokeAsync( [=]() { promise->resolve(jsi::Value::undefined()); }); }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } jsi::Value CommCoreModule::getUserPublicKey(jsi::Runtime &rt) { return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; std::string primaryKeysResult; std::string notificationsKeysResult; if (this->contentCryptoModule == nullptr || this->notifsCryptoModule == nullptr) { error = "user has not been initialized"; } else { primaryKeysResult = this->contentCryptoModule->getIdentityKeys(); notificationsKeysResult = this->notifsCryptoModule->getIdentityKeys(); } std::string notificationsCurve25519Cpp, notificationsEd25519Cpp, blobPayloadCpp, signatureCpp, primaryCurve25519Cpp, primaryEd25519Cpp; if (!error.size()) { folly::dynamic parsedPrimary; try { parsedPrimary = folly::parseJson(primaryKeysResult); } catch (const folly::json::parse_error &e) { error = "parsing identity keys failed with: " + std::string(e.what()); } if (!error.size()) { primaryCurve25519Cpp = parsedPrimary["curve25519"].asString(); primaryEd25519Cpp = parsedPrimary["ed25519"].asString(); folly::dynamic parsedNotifications; try { parsedNotifications = folly::parseJson(notificationsKeysResult); } catch (const folly::json::parse_error &e) { error = "parsing notifications keys failed with: " + std::string(e.what()); } if (!error.size()) { notificationsCurve25519Cpp = parsedNotifications["curve25519"].asString(); notificationsEd25519Cpp = parsedNotifications["ed25519"].asString(); folly::dynamic blobPayloadJSON = folly::dynamic::object( "primaryIdentityPublicKeys", folly::dynamic::object("ed25519", primaryEd25519Cpp)( "curve25519", primaryCurve25519Cpp))( "notificationIdentityPublicKeys", folly::dynamic::object("ed25519", notificationsEd25519Cpp)( "curve25519", notificationsCurve25519Cpp)); blobPayloadCpp = folly::toJson(blobPayloadJSON); signatureCpp = this->contentCryptoModule->signMessage(blobPayloadCpp); } } } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } auto primaryCurve25519{ jsi::String::createFromUtf8(innerRt, primaryCurve25519Cpp)}; auto primaryEd25519{ jsi::String::createFromUtf8(innerRt, primaryEd25519Cpp)}; auto jsiPrimaryIdentityPublicKeys = jsi::Object(innerRt); jsiPrimaryIdentityPublicKeys.setProperty( innerRt, "ed25519", primaryEd25519); jsiPrimaryIdentityPublicKeys.setProperty( innerRt, "curve25519", primaryCurve25519); auto notificationsCurve25519{jsi::String::createFromUtf8( innerRt, notificationsCurve25519Cpp)}; auto notificationsEd25519{ jsi::String::createFromUtf8(innerRt, notificationsEd25519Cpp)}; auto jsiNotificationIdentityPublicKeys = jsi::Object(innerRt); jsiNotificationIdentityPublicKeys.setProperty( innerRt, "ed25519", notificationsEd25519); jsiNotificationIdentityPublicKeys.setProperty( innerRt, "curve25519", notificationsCurve25519); auto blobPayload{ jsi::String::createFromUtf8(innerRt, blobPayloadCpp)}; auto signature{jsi::String::createFromUtf8(innerRt, signatureCpp)}; auto jsiClientPublicKeys = jsi::Object(innerRt); jsiClientPublicKeys.setProperty( innerRt, "primaryIdentityPublicKeys", jsiPrimaryIdentityPublicKeys); jsiClientPublicKeys.setProperty( innerRt, "notificationIdentityPublicKeys", jsiNotificationIdentityPublicKeys); jsiClientPublicKeys.setProperty( innerRt, "blobPayload", blobPayload); jsiClientPublicKeys.setProperty(innerRt, "signature", signature); promise->resolve(std::move(jsiClientPublicKeys)); }); }; this->cryptoThread->scheduleTask(job); }); } jsi::Object parseOLMOneTimeKeys(jsi::Runtime &rt, std::string oneTimeKeysBlob) { folly::dynamic parsedOneTimeKeys = folly::parseJson(oneTimeKeysBlob); auto jsiOneTimeKeysInner = jsi::Object(rt); for (auto &kvPair : parsedOneTimeKeys["curve25519"].items()) { jsiOneTimeKeysInner.setProperty( rt, kvPair.first.asString().c_str(), jsi::String::createFromUtf8(rt, kvPair.second.asString())); } auto jsiOneTimeKeys = jsi::Object(rt); jsiOneTimeKeys.setProperty(rt, "curve25519", jsiOneTimeKeysInner); return jsiOneTimeKeys; } std::string parseOLMPrekey(std::string prekeyBlob) { folly::dynamic parsedPrekey; try { parsedPrekey = folly::parseJson(prekeyBlob); } catch (const folly::json::parse_error &e) { throw std::runtime_error( "parsing prekey failed with: " + std::string(e.what())); } folly::dynamic innerObject = parsedPrekey["curve25519"]; if (!innerObject.isObject()) { throw std::runtime_error("parsing prekey failed: inner object malformed"); } if (innerObject.values().begin() == innerObject.values().end()) { throw std::runtime_error("parsing prekey failed: prekey missing"); } return parsedPrekey["curve25519"].values().begin()->asString(); } jsi::Object parseOneTimeKeysResult( jsi::Runtime &rt, std::string contentOneTimeKeysBlob, std::string notifOneTimeKeysBlob) { auto contentOneTimeKeys = parseOLMOneTimeKeys(rt, contentOneTimeKeysBlob); auto notifOneTimeKeys = parseOLMOneTimeKeys(rt, notifOneTimeKeysBlob); auto jsiOneTimeKeysResult = jsi::Object(rt); jsiOneTimeKeysResult.setProperty( rt, "contentOneTimeKeys", contentOneTimeKeys); jsiOneTimeKeysResult.setProperty( rt, "notificationsOneTimeKeys", notifOneTimeKeys); return jsiOneTimeKeysResult; } jsi::Value CommCoreModule::getOneTimeKeys(jsi::Runtime &rt, double oneTimeKeysAmount) { return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; std::string contentResult; std::string notifResult; if (this->contentCryptoModule == nullptr || this->notifsCryptoModule == nullptr) { this->jsInvoker_->invokeAsync([=, &innerRt]() { promise->reject("user has not been initialized"); }); return; } try { contentResult = this->contentCryptoModule->getOneTimeKeysForPublishing( oneTimeKeysAmount); notifResult = this->notifsCryptoModule->getOneTimeKeysForPublishing( oneTimeKeysAmount); this->persistCryptoModules(true, true); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } promise->resolve( parseOneTimeKeysResult(innerRt, contentResult, notifResult)); }); }; this->cryptoThread->scheduleTask(job); }); } jsi::Value CommCoreModule::validateAndUploadPrekeys( jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken) { auto authUserIDRust = jsiStringToRustString(authUserID, rt); auto authDeviceIDRust = jsiStringToRustString(authDeviceID, rt); auto authAccessTokenRust = jsiStringToRustString(authAccessToken, rt); return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; std::optional maybeContentPrekeyToUpload; std::optional maybeNotifsPrekeyToUpload; if (this->contentCryptoModule == nullptr || this->notifsCryptoModule == nullptr) { this->jsInvoker_->invokeAsync([=, &innerRt]() { promise->reject("user has not been initialized"); }); return; } try { maybeContentPrekeyToUpload = this->contentCryptoModule->validatePrekey(); maybeNotifsPrekeyToUpload = this->notifsCryptoModule->validatePrekey(); this->persistCryptoModules(true, true); if (!maybeContentPrekeyToUpload.has_value()) { maybeContentPrekeyToUpload = this->contentCryptoModule->getUnpublishedPrekey(); } if (!maybeNotifsPrekeyToUpload.has_value()) { maybeNotifsPrekeyToUpload = this->notifsCryptoModule->getUnpublishedPrekey(); } } catch (const std::exception &e) { error = e.what(); } if (error.size()) { this->jsInvoker_->invokeAsync( [=, &innerRt]() { promise->reject(error); }); return; } if (!maybeContentPrekeyToUpload.has_value() && !maybeNotifsPrekeyToUpload.has_value()) { this->jsInvoker_->invokeAsync( [=]() { promise->resolve(jsi::Value::undefined()); }); return; } std::string contentPrekeyToUpload; if (maybeContentPrekeyToUpload.has_value()) { contentPrekeyToUpload = maybeContentPrekeyToUpload.value(); } else { contentPrekeyToUpload = this->contentCryptoModule->getPrekey(); } std::string notifsPrekeyToUpload; if (maybeNotifsPrekeyToUpload.has_value()) { notifsPrekeyToUpload = maybeNotifsPrekeyToUpload.value(); } else { notifsPrekeyToUpload = this->notifsCryptoModule->getPrekey(); } std::string prekeyUploadError; try { std::string contentPrekeySignature = this->contentCryptoModule->getPrekeySignature(); std::string notifsPrekeySignature = this->notifsCryptoModule->getPrekeySignature(); try { std::promise prekeyPromise; std::future prekeyFuture = prekeyPromise.get_future(); RustPromiseManager::CPPPromiseInfo promiseInfo = { std::move(prekeyPromise)}; auto currentID = RustPromiseManager::instance.addPromise( std::move(promiseInfo)); auto contentPrekeyToUploadRust = rust::String(parseOLMPrekey(contentPrekeyToUpload)); auto prekeySignatureRust = rust::string(contentPrekeySignature); auto notifsPrekeyToUploadRust = rust::String(parseOLMPrekey(notifsPrekeyToUpload)); auto notificationsPrekeySignatureRust = rust::string(notifsPrekeySignature); ::identityRefreshUserPrekeys( authUserIDRust, authDeviceIDRust, authAccessTokenRust, contentPrekeyToUploadRust, prekeySignatureRust, notifsPrekeyToUploadRust, notificationsPrekeySignatureRust, currentID); prekeyFuture.get(); } catch (const std::exception &e) { prekeyUploadError = e.what(); } if (!prekeyUploadError.size()) { this->contentCryptoModule->markPrekeyAsPublished(); this->notifsCryptoModule->markPrekeyAsPublished(); this->persistCryptoModules(true, true); } } catch (std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=]() { if (error.size()) { promise->reject(error); return; } if (prekeyUploadError.size()) { promise->reject(prekeyUploadError); return; } promise->resolve(jsi::Value::undefined()); }); }; this->cryptoThread->scheduleTask(job); }); } jsi::Value CommCoreModule::validateAndGetPrekeys(jsi::Runtime &rt) { return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; std::string contentPrekey, notifPrekey, contentPrekeySignature, notifPrekeySignature; std::optional contentPrekeyBlob; std::optional notifPrekeyBlob; if (this->contentCryptoModule == nullptr || this->notifsCryptoModule == nullptr) { this->jsInvoker_->invokeAsync([=, &innerRt]() { promise->reject("user has not been initialized"); }); return; } try { contentPrekeyBlob = this->contentCryptoModule->validatePrekey(); if (!contentPrekeyBlob) { contentPrekeyBlob = this->contentCryptoModule->getUnpublishedPrekey(); } if (!contentPrekeyBlob) { contentPrekeyBlob = this->contentCryptoModule->getPrekey(); } notifPrekeyBlob = this->notifsCryptoModule->validatePrekey(); if (!notifPrekeyBlob) { notifPrekeyBlob = this->notifsCryptoModule->getUnpublishedPrekey(); } if (!notifPrekeyBlob) { notifPrekeyBlob = this->notifsCryptoModule->getPrekey(); } this->persistCryptoModules(true, true); contentPrekeySignature = this->contentCryptoModule->getPrekeySignature(); notifPrekeySignature = this->notifsCryptoModule->getPrekeySignature(); contentPrekey = parseOLMPrekey(contentPrekeyBlob.value()); notifPrekey = parseOLMPrekey(notifPrekeyBlob.value()); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } auto contentPrekeyJSI = jsi::String::createFromUtf8(innerRt, contentPrekey); auto contentPrekeySignatureJSI = jsi::String::createFromUtf8(innerRt, contentPrekeySignature); auto notifPrekeyJSI = jsi::String::createFromUtf8(innerRt, notifPrekey); auto notifPrekeySignatureJSI = jsi::String::createFromUtf8(innerRt, notifPrekeySignature); auto signedPrekeysJSI = jsi::Object(innerRt); signedPrekeysJSI.setProperty( innerRt, "contentPrekey", contentPrekeyJSI); signedPrekeysJSI.setProperty( innerRt, "contentPrekeySignature", contentPrekeySignatureJSI); signedPrekeysJSI.setProperty( innerRt, "notifPrekey", notifPrekeyJSI); signedPrekeysJSI.setProperty( innerRt, "notifPrekeySignature", notifPrekeySignatureJSI); promise->resolve(std::move(signedPrekeysJSI)); }); }; this->cryptoThread->scheduleTask(job); }); } jsi::Value CommCoreModule::initializeNotificationsSession( jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, jsi::String oneTimeKey, jsi::String keyserverID) { auto identityKeysCpp{identityKeys.utf8(rt)}; auto prekeyCpp{prekey.utf8(rt)}; auto prekeySignatureCpp{prekeySignature.utf8(rt)}; auto oneTimeKeyCpp{oneTimeKey.utf8(rt)}; auto keyserverIDCpp{keyserverID.utf8(rt)}; return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; crypto::EncryptedData result; try { this->notifsCryptoModule->initializeOutboundForSendingSession( keyserverIDCpp, std::vector( identityKeysCpp.begin(), identityKeysCpp.end()), std::vector(prekeyCpp.begin(), prekeyCpp.end()), std::vector( prekeySignatureCpp.begin(), prekeySignatureCpp.end()), std::vector( oneTimeKeyCpp.begin(), oneTimeKeyCpp.end())); result = this->notifsCryptoModule->encrypt( keyserverIDCpp, NotificationsCryptoModule::initialEncryptedMessageContent); std::shared_ptr keyserverNotificationsSession = this->notifsCryptoModule->getSessionByDeviceId(keyserverIDCpp); NotificationsCryptoModule::persistNotificationsSession( keyserverIDCpp, keyserverNotificationsSession); this->notifsCryptoModule->removeSessionByDeviceId(keyserverIDCpp); this->persistCryptoModules(false, true); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } promise->resolve(jsi::String::createFromUtf8( innerRt, std::string{result.message.begin(), result.message.end()})); }); }; this->cryptoThread->scheduleTask(job); }); } jsi::Value CommCoreModule::isNotificationsSessionInitialized(jsi::Runtime &rt) { return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; bool result; try { result = NotificationsCryptoModule::isNotificationsSessionInitialized( "Comm"); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } promise->resolve(result); }); }; this->cryptoThread->scheduleTask(job); }); } jsi::Value CommCoreModule::updateKeyserverDataInNotifStorage( jsi::Runtime &rt, jsi::Array keyserversData) { std::vector> keyserversDataCpp; for (auto idx = 0; idx < keyserversData.size(rt); idx++) { auto data = keyserversData.getValueAtIndex(rt, idx).asObject(rt); std::string keyserverID = data.getProperty(rt, "id").asString(rt).utf8(rt); std::string keyserverUnreadCountKey = "KEYSERVER." + keyserverID + ".UNREAD_COUNT"; int unreadCount = data.getProperty(rt, "unreadCount").asNumber(); keyserversDataCpp.push_back({keyserverUnreadCountKey, unreadCount}); } return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { std::string error; try { for (const auto &keyserverData : keyserversDataCpp) { CommMMKV::setInt(keyserverData.first, keyserverData.second); } } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } promise->resolve(jsi::Value::undefined()); }); }); } jsi::Value CommCoreModule::removeKeyserverDataFromNotifStorage( jsi::Runtime &rt, jsi::Array keyserverIDsToDelete) { std::vector keyserverIDsToDeleteCpp{}; for (auto idx = 0; idx < keyserverIDsToDelete.size(rt); idx++) { std::string keyserverID = keyserverIDsToDelete.getValueAtIndex(rt, idx).asString(rt).utf8(rt); std::string keyserverUnreadCountKey = "KEYSERVER." + keyserverID + ".UNREAD_COUNT"; keyserverIDsToDeleteCpp.push_back(keyserverUnreadCountKey); } return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { std::string error; try { CommMMKV::removeKeys(keyserverIDsToDeleteCpp); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } promise->resolve(jsi::Value::undefined()); }); }); } jsi::Value CommCoreModule::getKeyserverDataFromNotifStorage( jsi::Runtime &rt, jsi::Array keyserverIDs) { std::vector keyserverIDsCpp{}; for (auto idx = 0; idx < keyserverIDs.size(rt); idx++) { std::string keyserverID = keyserverIDs.getValueAtIndex(rt, idx).asString(rt).utf8(rt); keyserverIDsCpp.push_back(keyserverID); } return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { std::string error; std::vector> keyserversDataVector{}; try { for (const auto &keyserverID : keyserverIDsCpp) { std::string keyserverUnreadCountKey = "KEYSERVER." + keyserverID + ".UNREAD_COUNT"; std::optional unreadCount = CommMMKV::getInt(keyserverUnreadCountKey, -1); if (!unreadCount.has_value()) { continue; } keyserversDataVector.push_back({keyserverID, unreadCount.value()}); } } catch (const std::exception &e) { error = e.what(); } auto keyserversDataVectorPtr = std::make_shared>>( std::move(keyserversDataVector)); this->jsInvoker_->invokeAsync( [&innerRt, keyserversDataVectorPtr, error, promise]() { if (error.size()) { promise->reject(error); return; } size_t numKeyserversData = keyserversDataVectorPtr->size(); jsi::Array jsiKeyserversData = jsi::Array(innerRt, numKeyserversData); size_t writeIdx = 0; for (const auto &keyserverData : *keyserversDataVectorPtr) { jsi::Object jsiKeyserverData = jsi::Object(innerRt); jsiKeyserverData.setProperty( innerRt, "id", keyserverData.first); jsiKeyserverData.setProperty( innerRt, "unreadCount", keyserverData.second); jsiKeyserversData.setValueAtIndex( innerRt, writeIdx++, jsiKeyserverData); } promise->resolve(std::move(jsiKeyserversData)); }); }); } jsi::Value CommCoreModule::initializeContentOutboundSession( jsi::Runtime &rt, jsi::String identityKeys, jsi::String prekey, jsi::String prekeySignature, jsi::String oneTimeKey, jsi::String deviceID) { auto identityKeysCpp{identityKeys.utf8(rt)}; auto prekeyCpp{prekey.utf8(rt)}; auto prekeySignatureCpp{prekeySignature.utf8(rt)}; auto oneTimeKeyCpp{oneTimeKey.utf8(rt)}; auto deviceIDCpp{deviceID.utf8(rt)}; return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; crypto::EncryptedData initialEncryptedData; try { this->contentCryptoModule->initializeOutboundForSendingSession( deviceIDCpp, std::vector( identityKeysCpp.begin(), identityKeysCpp.end()), std::vector(prekeyCpp.begin(), prekeyCpp.end()), std::vector( prekeySignatureCpp.begin(), prekeySignatureCpp.end()), std::vector( oneTimeKeyCpp.begin(), oneTimeKeyCpp.end())); const std::string initMessage = "{\"type\": \"init\"}"; initialEncryptedData = contentCryptoModule->encrypt(deviceIDCpp, initMessage); this->persistCryptoModules(true, false); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } auto initialEncryptedDataJSI = jsi::Object(innerRt); auto message = std::string{ initialEncryptedData.message.begin(), initialEncryptedData.message.end()}; auto messageJSI = jsi::String::createFromUtf8(innerRt, message); initialEncryptedDataJSI.setProperty(innerRt, "message", messageJSI); initialEncryptedDataJSI.setProperty( innerRt, "messageType", static_cast(initialEncryptedData.messageType)); promise->resolve(std::move(initialEncryptedDataJSI)); }); }; this->cryptoThread->scheduleTask(job); }); } jsi::Value CommCoreModule::initializeContentInboundSession( jsi::Runtime &rt, jsi::String identityKeys, jsi::Object encryptedDataJSI, jsi::String deviceID) { auto identityKeysCpp{identityKeys.utf8(rt)}; size_t messageType = std::lround(encryptedDataJSI.getProperty(rt, "messageType").asNumber()); std::string encryptedMessageCpp = encryptedDataJSI.getProperty(rt, "message").asString(rt).utf8(rt); auto deviceIDCpp{deviceID.utf8(rt)}; return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; std::string decryptedMessage; try { this->contentCryptoModule->initializeInboundForReceivingSession( deviceIDCpp, std::vector( encryptedMessageCpp.begin(), encryptedMessageCpp.end()), std::vector( identityKeysCpp.begin(), identityKeysCpp.end())); crypto::EncryptedData encryptedData{ std::vector( encryptedMessageCpp.begin(), encryptedMessageCpp.end()), messageType}; decryptedMessage = this->contentCryptoModule->decrypt(deviceIDCpp, encryptedData); this->persistCryptoModules(true, false); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } promise->resolve( jsi::String::createFromUtf8(innerRt, decryptedMessage)); }); }; this->cryptoThread->scheduleTask(job); }); } jsi::Value CommCoreModule::encrypt( jsi::Runtime &rt, jsi::String message, jsi::String deviceID) { auto messageCpp{message.utf8(rt)}; auto deviceIDCpp{deviceID.utf8(rt)}; return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; crypto::EncryptedData encryptedMessage; try { encryptedMessage = contentCryptoModule->encrypt(deviceIDCpp, messageCpp); this->persistCryptoModules(true, false); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } auto encryptedDataJSI = jsi::Object(innerRt); auto message = std::string{ encryptedMessage.message.begin(), encryptedMessage.message.end()}; auto messageJSI = jsi::String::createFromUtf8(innerRt, message); encryptedDataJSI.setProperty(innerRt, "message", messageJSI); encryptedDataJSI.setProperty( innerRt, "messageType", static_cast(encryptedMessage.messageType)); promise->resolve(std::move(encryptedDataJSI)); }); }; this->cryptoThread->scheduleTask(job); }); } jsi::Value CommCoreModule::decrypt( jsi::Runtime &rt, jsi::Object encryptedDataJSI, jsi::String deviceID) { size_t messageType = std::lround(encryptedDataJSI.getProperty(rt, "messageType").asNumber()); std::string message = encryptedDataJSI.getProperty(rt, "message").asString(rt).utf8(rt); auto deviceIDCpp{deviceID.utf8(rt)}; return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; std::string decryptedMessage; try { crypto::EncryptedData encryptedData{ std::vector(message.begin(), message.end()), messageType}; decryptedMessage = this->contentCryptoModule->decrypt(deviceIDCpp, encryptedData); this->persistCryptoModules(true, false); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } promise->resolve( jsi::String::createFromUtf8(innerRt, decryptedMessage)); }); }; this->cryptoThread->scheduleTask(job); }); } jsi::Value CommCoreModule::signMessage(jsi::Runtime &rt, jsi::String message) { std::string messageStr = message.utf8(rt); return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; std::string signature; try { signature = this->contentCryptoModule->signMessage(messageStr); } catch (const std::exception &e) { error = "signing message failed with: " + std::string(e.what()); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } auto jsiSignature{jsi::String::createFromUtf8(innerRt, signature)}; promise->resolve(std::move(jsiSignature)); }); }; this->cryptoThread->scheduleTask(job); }); } CommCoreModule::CommCoreModule( std::shared_ptr jsInvoker) : facebook::react::CommCoreModuleSchemaCxxSpecJSI(jsInvoker), cryptoThread(std::make_unique("crypto")), draftStore(jsInvoker), threadStore(jsInvoker), messageStore(jsInvoker), reportStore(jsInvoker), userStore(jsInvoker), keyserverStore(jsInvoker), communityStore(jsInvoker), integrityStore(jsInvoker), syncedMetadataStore(jsInvoker), auxUserStore(jsInvoker) { GlobalDBSingleton::instance.enableMultithreading(); } double CommCoreModule::getCodeVersion(jsi::Runtime &rt) { return this->codeVersion; } jsi::Value CommCoreModule::setNotifyToken(jsi::Runtime &rt, jsi::String token) { auto notifyToken{token.utf8(rt)}; return createPromiseAsJSIValue( rt, [this, notifyToken](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [this, notifyToken, promise]() { std::string error; try { DatabaseManager::getQueryExecutor().setNotifyToken(notifyToken); } catch (std::system_error &e) { error = e.what(); } this->jsInvoker_->invokeAsync([error, promise]() { if (error.size()) { promise->reject(error); } else { promise->resolve(jsi::Value::undefined()); } }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } jsi::Value CommCoreModule::clearNotifyToken(jsi::Runtime &rt) { return createPromiseAsJSIValue( rt, [this](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [this, promise]() { std::string error; try { DatabaseManager::getQueryExecutor().clearNotifyToken(); } catch (std::system_error &e) { error = e.what(); } this->jsInvoker_->invokeAsync([error, promise]() { if (error.size()) { promise->reject(error); } else { promise->resolve(jsi::Value::undefined()); } }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); }; jsi::Value CommCoreModule::setCurrentUserID(jsi::Runtime &rt, jsi::String userID) { auto currentUserID{userID.utf8(rt)}; return createPromiseAsJSIValue( rt, [this, currentUserID](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [this, promise, currentUserID]() { std::string error; try { DatabaseManager::getQueryExecutor().setCurrentUserID(currentUserID); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([error, promise]() { if (error.size()) { promise->reject(error); } else { promise->resolve(jsi::Value::undefined()); } }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } jsi::Value CommCoreModule::getCurrentUserID(jsi::Runtime &rt) { return createPromiseAsJSIValue( rt, [this](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [this, &innerRt, promise]() { std::string error; std::string result; try { result = DatabaseManager::getQueryExecutor().getCurrentUserID(); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([&innerRt, error, result, promise]() { if (error.size()) { promise->reject(error); } else { promise->resolve(jsi::String::createFromUtf8(innerRt, result)); } }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } jsi::Value CommCoreModule::clearSensitiveData(jsi::Runtime &rt) { return createPromiseAsJSIValue( rt, [this](jsi::Runtime &innerRt, std::shared_ptr promise) { GlobalDBSingleton::instance.setTasksCancelled(true); taskType job = [this, promise]() { std::string error; try { DatabaseManager::clearSensitiveData(); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([error, promise]() { if (error.size()) { promise->reject(error); } else { promise->resolve(jsi::Value::undefined()); } }); GlobalDBSingleton::instance.scheduleOrRun( []() { GlobalDBSingleton::instance.setTasksCancelled(false); }); }; GlobalDBSingleton::instance.scheduleOrRun(job); }); } bool CommCoreModule::checkIfDatabaseNeedsDeletion(jsi::Runtime &rt) { return DatabaseManager::checkIfDatabaseNeedsDeletion(); } void CommCoreModule::reportDBOperationsFailure(jsi::Runtime &rt) { DatabaseManager::reportDBOperationsFailure(); } jsi::Value CommCoreModule::computeBackupKey( jsi::Runtime &rt, jsi::String password, jsi::String backupID) { std::string passwordStr = password.utf8(rt); std::string backupIDStr = backupID.utf8(rt); return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; std::array<::std::uint8_t, 32> backupKey; try { backupKey = compute_backup_key(passwordStr, backupIDStr); } catch (const std::exception &e) { error = std::string{"Failed to compute backup key: "} + e.what(); } this->jsInvoker_->invokeAsync([=, &innerRt]() { if (error.size()) { promise->reject(error); return; } auto size = backupKey.size(); auto arrayBuffer = innerRt.global() .getPropertyAsFunction(innerRt, "ArrayBuffer") .callAsConstructor(innerRt, {static_cast(size)}) .asObject(innerRt) .getArrayBuffer(innerRt); auto bufferPtr = arrayBuffer.data(innerRt); memcpy(bufferPtr, backupKey.data(), size); promise->resolve(std::move(arrayBuffer)); }); }; this->cryptoThread->scheduleTask(job); }); } jsi::Value CommCoreModule::generateRandomString(jsi::Runtime &rt, double size) { return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { std::string error; std::string randomString; try { randomString = crypto::Tools::generateRandomString(static_cast(size)); } catch (const std::exception &e) { error = "Failed to generate random string for size " + std::to_string(size) + ": " + e.what(); } this->jsInvoker_->invokeAsync( [&innerRt, error, randomString, promise]() { if (error.size()) { promise->reject(error); } else { jsi::String jsiRandomString = jsi::String::createFromUtf8(innerRt, randomString); promise->resolve(std::move(jsiRandomString)); } }); }; this->cryptoThread->scheduleTask(job); }); } jsi::Value CommCoreModule::setCommServicesAuthMetadata( jsi::Runtime &rt, jsi::String userID, jsi::String deviceID, jsi::String accessToken) { auto userIDStr{userID.utf8(rt)}; auto deviceIDStr{deviceID.utf8(rt)}; auto accessTokenStr{accessToken.utf8(rt)}; return createPromiseAsJSIValue( rt, [this, userIDStr, deviceIDStr, accessTokenStr]( jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [this, promise, userIDStr, deviceIDStr, accessTokenStr]() { std::string error; try { CommSecureStore::set(CommSecureStore::userID, userIDStr); CommSecureStore::set(CommSecureStore::deviceID, deviceIDStr); CommSecureStore::set( CommSecureStore::commServicesAccessToken, accessTokenStr); CommServicesAuthMetadataEmitter::sendAuthMetadataToJS( accessTokenStr, userIDStr); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([error, promise]() { if (error.size()) { promise->reject(error); } else { promise->resolve(jsi::Value::undefined()); } }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } jsi::Value CommCoreModule::getCommServicesAuthMetadata(jsi::Runtime &rt) { return createPromiseAsJSIValue( rt, [this](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [this, &innerRt, promise]() { std::string error; std::string userID; std::string deviceID; std::string accessToken; try { folly::Optional userIDOpt = CommSecureStore::get(CommSecureStore::userID); if (userIDOpt.hasValue()) { userID = userIDOpt.value(); } folly::Optional deviceIDOpt = CommSecureStore::get(CommSecureStore::deviceID); if (deviceIDOpt.hasValue()) { deviceID = deviceIDOpt.value(); } folly::Optional accessTokenOpt = CommSecureStore::get(CommSecureStore::commServicesAccessToken); if (accessTokenOpt.hasValue()) { accessToken = accessTokenOpt.value(); } } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync( [&innerRt, error, userID, deviceID, accessToken, promise]() { if (error.size()) { promise->reject(error); } else { auto authMetadata = jsi::Object(innerRt); if (!userID.empty()) { authMetadata.setProperty( innerRt, "userID", jsi::String::createFromUtf8(innerRt, userID)); } if (!deviceID.empty()) { authMetadata.setProperty( innerRt, "deviceID", jsi::String::createFromUtf8(innerRt, deviceID)); } if (!accessToken.empty()) { authMetadata.setProperty( innerRt, "accessToken", jsi::String::createFromUtf8(innerRt, accessToken)); } promise->resolve(std::move(authMetadata)); } }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } jsi::Value CommCoreModule::setCommServicesAccessToken( jsi::Runtime &rt, jsi::String accessToken) { auto accessTokenStr{accessToken.utf8(rt)}; return createPromiseAsJSIValue( rt, [this, accessTokenStr]( jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [this, promise, accessTokenStr]() { std::string error; try { CommSecureStore::set( CommSecureStore::commServicesAccessToken, accessTokenStr); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([error, promise]() { if (error.size()) { promise->reject(error); } else { promise->resolve(jsi::Value::undefined()); } }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } jsi::Value CommCoreModule::clearCommServicesAccessToken(jsi::Runtime &rt) { return createPromiseAsJSIValue( rt, [this](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [this, promise]() { std::string error; try { CommSecureStore::set(CommSecureStore::commServicesAccessToken, ""); } catch (const std::exception &e) { error = e.what(); } this->jsInvoker_->invokeAsync([error, promise]() { if (error.size()) { promise->reject(error); } else { promise->resolve(jsi::Value::undefined()); } }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker_); }); } void CommCoreModule::startBackupHandler(jsi::Runtime &rt) { try { ::startBackupHandler(); } catch (const std::exception &e) { throw jsi::JSError(rt, e.what()); } } void CommCoreModule::stopBackupHandler(jsi::Runtime &rt) { try { ::stopBackupHandler(); } catch (const std::exception &e) { throw jsi::JSError(rt, e.what()); } } jsi::Value CommCoreModule::createNewBackup(jsi::Runtime &rt, jsi::String backupSecret) { std::string backupSecretStr = backupSecret.utf8(rt); return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { this->cryptoThread->scheduleTask([=, &innerRt]() { std::string error; std::string backupID; try { backupID = crypto::Tools::generateRandomString(32); } catch (const std::exception &e) { error = "Failed to generate backupID"; } std::string pickleKey; std::string pickledAccount; if (!error.size()) { try { pickleKey = crypto::Tools::generateRandomString(64); crypto::Persist persist = this->contentCryptoModule->storeAsB64(pickleKey); pickledAccount = std::string(persist.account.begin(), persist.account.end()); } catch (const std::exception &e) { error = "Failed to pickle crypto account"; } } if (!error.size()) { auto currentID = RustPromiseManager::instance.addPromise( {promise, this->jsInvoker_, innerRt}); ::createBackup( rust::string(backupID), rust::string(backupSecretStr), rust::string(pickleKey), rust::string(pickledAccount), currentID); } else { this->jsInvoker_->invokeAsync( [=, &innerRt]() { promise->reject(error); }); } }); }); } jsi::Value CommCoreModule::restoreBackup(jsi::Runtime &rt, jsi::String backupSecret) { std::string backupSecretStr = backupSecret.utf8(rt); return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { auto currentID = RustPromiseManager::instance.addPromise( {promise, this->jsInvoker_, innerRt}); ::restoreBackup(rust::string(backupSecretStr), currentID); }); } jsi::Value CommCoreModule::restoreBackupData( jsi::Runtime &rt, jsi::String backupID, jsi::String backupDataKey, jsi::String backupLogDataKey) { std::string backupIDStr = backupID.utf8(rt); std::string backupDataKeyStr = backupDataKey.utf8(rt); std::string backupLogDataKeyStr = backupLogDataKey.utf8(rt); return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { auto currentID = RustPromiseManager::instance.addPromise( {promise, this->jsInvoker_, innerRt}); ::restoreBackupData( rust::string(backupIDStr), rust::string(backupDataKeyStr), rust::string(backupLogDataKeyStr), currentID); }); } jsi::Value CommCoreModule::retrieveBackupKeys(jsi::Runtime &rt, jsi::String backupSecret) { std::string backupSecretStr = backupSecret.utf8(rt); return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { auto currentID = RustPromiseManager::instance.addPromise( {promise, this->jsInvoker_, innerRt}); ::retrieveBackupKeys(rust::string(backupSecretStr), currentID); }); } } // namespace comm diff --git a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp index 9ce9007a5..e334a9da0 100644 --- a/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp +++ b/native/cpp/CommonCpp/Notifications/BackgroundDataStorage/NotificationsCryptoModule.cpp @@ -1,403 +1,403 @@ #include "NotificationsCryptoModule.h" #include "../../CryptoTools/Persist.h" #include "../../CryptoTools/Tools.h" #include "../../Tools/CommMMKV.h" #include "../../Tools/CommSecureStore.h" #include "../../Tools/Logger.h" #include "../../Tools/PlatformSpecificTools.h" #include #include #include #include #include #include #include #include namespace comm { const std::string NotificationsCryptoModule::secureStoreNotificationsAccountDataKey = "notificationsCryptoAccountDataKey"; const std::string NotificationsCryptoModule::notificationsCryptoAccountID = "notificationsCryptoAccountDataID"; const std::string NotificationsCryptoModule::keyserverHostedNotificationsID = "keyserverHostedNotificationsID"; const std::string NotificationsCryptoModule::initialEncryptedMessageContent = "{\"type\": \"init\"}"; const int NotificationsCryptoModule::olmEncryptedTypeMessage = 1; // This constant is only used to migrate the existing notifications // session with production keyserver from flat file to MMKV. This // migration will fire when user updates the app. It will also fire // on dev env provided old keyserver set up is used. Developers willing // to use new keyserver set up must log out before installing updated // app version. Do not introduce new usages of this constant in the code!!! const std::string ashoatKeyserverIDUsedOnlyForMigrationFromLegacyNotifStorage = "256"; const int temporaryFilePathRandomSuffixLength = 32; std::unique_ptr NotificationsCryptoModule::deserializeCryptoModule( const std::string &path, const std::string &picklingKey) { std::ifstream pickledPersistStream(path, std::ifstream::in); if (!pickledPersistStream.good()) { throw std::runtime_error( "Attempt to deserialize non-existing notifications crypto account"); } std::stringstream pickledPersistStringStream; pickledPersistStringStream << pickledPersistStream.rdbuf(); pickledPersistStream.close(); std::string pickledPersist = pickledPersistStringStream.str(); folly::dynamic persistJSON; try { persistJSON = folly::parseJson(pickledPersist); } catch (const folly::json::parse_error &e) { throw std::runtime_error( "Notifications crypto account JSON deserialization failed with " "reason: " + std::string(e.what())); } std::string accountString = persistJSON["account"].asString(); crypto::OlmBuffer account = std::vector(accountString.begin(), accountString.end()); - std::unordered_map sessions; + std::unordered_map sessions; if (persistJSON["sessions"].isNull()) { return std::make_unique( notificationsCryptoAccountID, picklingKey, crypto::Persist({account, sessions})); } for (auto &sessionKeyValuePair : persistJSON["sessions"].items()) { std::string targetUserID = sessionKeyValuePair.first.asString(); std::string sessionData = sessionKeyValuePair.second.asString(); - sessions[targetUserID] = - std::vector(sessionData.begin(), sessionData.end()); + sessions[targetUserID] = { + std::vector(sessionData.begin(), sessionData.end()), 1}; } return std::make_unique( notificationsCryptoAccountID, picklingKey, crypto::Persist({account, sessions})); } void NotificationsCryptoModule::serializeAndFlushCryptoModule( std::unique_ptr cryptoModule, const std::string &path, const std::string &picklingKey) { crypto::Persist persist = cryptoModule->storeAsB64(picklingKey); folly::dynamic sessions = folly::dynamic::object; for (auto &sessionKeyValuePair : persist.sessions) { std::string targetUserID = sessionKeyValuePair.first; - crypto::OlmBuffer sessionData = sessionKeyValuePair.second; + crypto::OlmBuffer sessionData = sessionKeyValuePair.second.buffer; sessions[targetUserID] = std::string(sessionData.begin(), sessionData.end()); } std::string account = std::string(persist.account.begin(), persist.account.end()); folly::dynamic persistJSON = folly::dynamic::object("account", account)("sessions", sessions); std::string pickledPersist = folly::toJson(persistJSON); std::string temporaryFilePathRandomSuffix = crypto::Tools::generateRandomHexString( temporaryFilePathRandomSuffixLength); std::string temporaryPath = path + temporaryFilePathRandomSuffix; mode_t readWritePermissionsMode = 0666; int temporaryFD = open(temporaryPath.c_str(), O_CREAT | O_WRONLY, readWritePermissionsMode); if (temporaryFD == -1) { throw std::runtime_error( "Failed to create temporary file. Unable to atomically update " "notifications crypto account. Details: " + std::string(strerror(errno))); } ssize_t bytesWritten = write(temporaryFD, pickledPersist.c_str(), pickledPersist.length()); if (bytesWritten == -1 || bytesWritten != pickledPersist.length()) { remove(temporaryPath.c_str()); throw std::runtime_error( "Failed to write all data to temporary file. Unable to atomically " "update notifications crypto account. Details: " + std::string(strerror(errno))); } if (fsync(temporaryFD) == -1) { remove(temporaryPath.c_str()); throw std::runtime_error( "Failed to synchronize temporary file data with hardware storage. " "Unable to atomically update notifications crypto account. Details: " + std::string(strerror(errno))); }; close(temporaryFD); if (rename(temporaryPath.c_str(), path.c_str()) == -1) { remove(temporaryPath.c_str()); throw std::runtime_error( "Failed to replace temporary file content with notifications crypto " "account. Unable to atomically update notifications crypto account. " "Details: " + std::string(strerror(errno))); } remove(temporaryPath.c_str()); } std::string NotificationsCryptoModule::getKeyserverNotificationsSessionKey( const std::string &keyserverID) { return "KEYSERVER." + keyserverID + ".NOTIFS_SESSION"; } std::string NotificationsCryptoModule::serializeNotificationsSession( std::shared_ptr session, std::string picklingKey) { crypto::OlmBuffer pickledSessionBytes = session->storeAsB64(picklingKey); std::string pickledSession = std::string{pickledSessionBytes.begin(), pickledSessionBytes.end()}; folly::dynamic serializedSessionJson = folly::dynamic::object( "session", pickledSession)("picklingKey", picklingKey); return folly::toJson(serializedSessionJson); } std::pair, std::string> NotificationsCryptoModule::deserializeNotificationsSession( const std::string &serializedSession) { folly::dynamic serializedSessionJson; try { serializedSessionJson = folly::parseJson(serializedSession); } catch (const folly::json::parse_error &e) { throw std::runtime_error( "Notifications session deserialization failed with reason: " + std::string(e.what())); } std::string pickledSession = serializedSessionJson["session"].asString(); crypto::OlmBuffer pickledSessionBytes = crypto::OlmBuffer{pickledSession.begin(), pickledSession.end()}; std::string picklingKey = serializedSessionJson["picklingKey"].asString(); std::unique_ptr session = crypto::Session::restoreFromB64(picklingKey, pickledSessionBytes); return {std::move(session), picklingKey}; } void NotificationsCryptoModule::clearSensitiveData() { std::string notificationsCryptoAccountPath = PlatformSpecificTools::getNotificationsCryptoAccountPath(); if (remove(notificationsCryptoAccountPath.c_str()) == -1 && errno != ENOENT) { throw std::runtime_error( "Unable to remove notifications crypto account. Security requirements " "might be violated."); } } void NotificationsCryptoModule::persistNotificationsSessionInternal( const std::string &keyserverID, const std::string &picklingKey, std::shared_ptr session) { std::string serializedSession = NotificationsCryptoModule::serializeNotificationsSession( session, picklingKey); std::string keyserverNotificationsSessionKey = NotificationsCryptoModule::getKeyserverNotificationsSessionKey( keyserverID); bool sessionStored = CommMMKV::setString(keyserverNotificationsSessionKey, serializedSession); if (!sessionStored) { throw std::runtime_error( "Failed to persist to MMKV notifications session for keyserver: " + keyserverID); } } std::optional, std::string>> NotificationsCryptoModule::fetchNotificationsSession( const std::string &keyserverID) { std::string keyserverNotificationsSessionKey = NotificationsCryptoModule::getKeyserverNotificationsSessionKey( keyserverID); std::optional serializedSession; try { serializedSession = CommMMKV::getString(keyserverNotificationsSessionKey); } catch (const CommMMKV::InitFromNSEForbiddenError &e) { serializedSession = std::nullopt; } if (!serializedSession.has_value() && keyserverID != ashoatKeyserverIDUsedOnlyForMigrationFromLegacyNotifStorage) { throw std::runtime_error( "Missing notifications session for keyserver: " + keyserverID); } else if (!serializedSession.has_value()) { return std::nullopt; } return NotificationsCryptoModule::deserializeNotificationsSession( serializedSession.value()); } void NotificationsCryptoModule::persistNotificationsSession( const std::string &keyserverID, std::shared_ptr keyserverNotificationsSession) { std::string picklingKey = crypto::Tools::generateRandomString(64); NotificationsCryptoModule::persistNotificationsSessionInternal( keyserverID, picklingKey, keyserverNotificationsSession); } bool NotificationsCryptoModule::isNotificationsSessionInitialized( const std::string &keyserverID) { std::string keyserverNotificationsSessionKey = "KEYSERVER." + keyserverID + ".NOTIFS_SESSION"; return CommMMKV::getString(keyserverNotificationsSessionKey).has_value(); } NotificationsCryptoModule::BaseStatefulDecryptResult::BaseStatefulDecryptResult( std::string picklingKey, std::string decryptedData) : picklingKey(picklingKey), decryptedData(decryptedData) { } std::string NotificationsCryptoModule::BaseStatefulDecryptResult::getDecryptedData() { return this->decryptedData; } NotificationsCryptoModule::StatefulDecryptResult::StatefulDecryptResult( std::unique_ptr session, std::string keyserverID, std::string picklingKey, std::string decryptedData) : NotificationsCryptoModule::BaseStatefulDecryptResult:: BaseStatefulDecryptResult(picklingKey, decryptedData), sessionState(std::move(session)), keyserverID(keyserverID) { } void NotificationsCryptoModule::StatefulDecryptResult::flushState() { NotificationsCryptoModule::persistNotificationsSessionInternal( this->keyserverID, this->picklingKey, std::move(this->sessionState)); } NotificationsCryptoModule::LegacyStatefulDecryptResult:: LegacyStatefulDecryptResult( std::unique_ptr cryptoModule, std::string path, std::string picklingKey, std::string decryptedData) : NotificationsCryptoModule::BaseStatefulDecryptResult:: BaseStatefulDecryptResult(picklingKey, decryptedData), path(path), cryptoModule(std::move(cryptoModule)) { } void NotificationsCryptoModule::LegacyStatefulDecryptResult::flushState() { std::shared_ptr legacyNotificationsSession = this->cryptoModule->getSessionByDeviceId(keyserverHostedNotificationsID); NotificationsCryptoModule::serializeAndFlushCryptoModule( std::move(this->cryptoModule), this->path, this->picklingKey); try { NotificationsCryptoModule::persistNotificationsSession( ashoatKeyserverIDUsedOnlyForMigrationFromLegacyNotifStorage, legacyNotificationsSession); } catch (const CommMMKV::InitFromNSEForbiddenError &e) { return; } } std::unique_ptr NotificationsCryptoModule::prepareLegacyDecryptedState( const std::string &data, const size_t messageType) { folly::Optional picklingKey = comm::CommSecureStore::get( NotificationsCryptoModule::secureStoreNotificationsAccountDataKey); if (!picklingKey.hasValue()) { throw std::runtime_error( "Legacy notifications session pickling key missing."); } std::string legacyNotificationsAccountPath = comm::PlatformSpecificTools::getNotificationsCryptoAccountPath(); crypto::EncryptedData encryptedData{ std::vector(data.begin(), data.end()), messageType}; auto cryptoModule = NotificationsCryptoModule::deserializeCryptoModule( legacyNotificationsAccountPath, picklingKey.value()); std::string decryptedData = cryptoModule->decrypt( NotificationsCryptoModule::keyserverHostedNotificationsID, encryptedData); LegacyStatefulDecryptResult statefulDecryptResult( std::move(cryptoModule), legacyNotificationsAccountPath, picklingKey.value(), decryptedData); return std::make_unique( std::move(statefulDecryptResult)); } std::string NotificationsCryptoModule::decrypt( const std::string &keyserverID, const std::string &data, const size_t messageType) { auto sessionWithPicklingKey = NotificationsCryptoModule::fetchNotificationsSession(keyserverID); if (!sessionWithPicklingKey.has_value()) { auto statefulDecryptResult = NotificationsCryptoModule::prepareLegacyDecryptedState( data, messageType); statefulDecryptResult->flushState(); return statefulDecryptResult->getDecryptedData(); } std::unique_ptr session = std::move(sessionWithPicklingKey.value().first); std::string picklingKey = sessionWithPicklingKey.value().second; crypto::EncryptedData encryptedData{ std::vector(data.begin(), data.end()), messageType}; std::string decryptedData = session->decrypt(encryptedData); NotificationsCryptoModule::persistNotificationsSessionInternal( keyserverID, picklingKey, std::move(session)); return decryptedData; } std::unique_ptr NotificationsCryptoModule::statefulDecrypt( const std::string &keyserverID, const std::string &data, const size_t messageType) { auto sessionWithPicklingKey = NotificationsCryptoModule::fetchNotificationsSession(keyserverID); if (!sessionWithPicklingKey.has_value()) { return NotificationsCryptoModule::prepareLegacyDecryptedState( data, messageType); } std::unique_ptr session = std::move(sessionWithPicklingKey.value().first); std::string picklingKey = sessionWithPicklingKey.value().second; crypto::EncryptedData encryptedData{ std::vector(data.begin(), data.end()), messageType}; std::string decryptedData = session->decrypt(encryptedData); StatefulDecryptResult statefulDecryptResult( std::move(session), keyserverID, picklingKey, decryptedData); return std::make_unique( std::move(statefulDecryptResult)); } void NotificationsCryptoModule::flushState( std::unique_ptr baseStatefulDecryptResult) { baseStatefulDecryptResult->flushState(); } } // namespace comm diff --git a/web/shared-worker/_generated/comm_query_executor.wasm b/web/shared-worker/_generated/comm_query_executor.wasm index d652def06..f1d9a1e53 100755 Binary files a/web/shared-worker/_generated/comm_query_executor.wasm and b/web/shared-worker/_generated/comm_query_executor.wasm differ