diff --git a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp index d5a06b4a2..ea37a5542 100644 --- a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp +++ b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp @@ -1,1923 +1,1923 @@ #include "SQLiteQueryExecutor.h" #include "Logger.h" #include "entities/EntityQueryHelpers.h" #include "entities/KeyserverInfo.h" #include "entities/Metadata.h" #include "entities/UserInfo.h" #include #include #include #ifndef EMSCRIPTEN #include "CommSecureStore.h" #include "PlatformSpecificTools.h" #include "StaffUtils.h" #endif #define ACCOUNT_ID 1 namespace comm { std::string SQLiteQueryExecutor::sqliteFilePath; std::string SQLiteQueryExecutor::encryptionKey; std::once_flag SQLiteQueryExecutor::initialized; int SQLiteQueryExecutor::sqlcipherEncryptionKeySize = 64; // Should match constant defined in `native_rust_library/src/constants.rs` std::string SQLiteQueryExecutor::secureStoreEncryptionKeyID = "comm.encryptionKey"; int SQLiteQueryExecutor::backupLogsEncryptionKeySize = 32; std::string SQLiteQueryExecutor::secureStoreBackupLogsEncryptionKeyID = "comm.backupLogsEncryptionKey"; std::string SQLiteQueryExecutor::backupLogsEncryptionKey; #ifndef EMSCRIPTEN NativeSQLiteConnectionManager SQLiteQueryExecutor::connectionManager; #else SQLiteConnectionManager SQLiteQueryExecutor::connectionManager; #endif bool create_table(sqlite3 *db, std::string query, std::string tableName) { char *error; sqlite3_exec(db, query.c_str(), nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error creating '" << tableName << "' table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool create_drafts_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS drafts (threadID TEXT UNIQUE PRIMARY KEY, " "text TEXT);"; return create_table(db, query, "drafts"); } bool rename_threadID_to_key(sqlite3 *db) { sqlite3_stmt *key_column_stmt; sqlite3_prepare_v2( db, "SELECT name AS col_name FROM pragma_table_xinfo ('drafts') WHERE " "col_name='key';", -1, &key_column_stmt, nullptr); sqlite3_step(key_column_stmt); auto num_bytes = sqlite3_column_bytes(key_column_stmt, 0); sqlite3_finalize(key_column_stmt); if (num_bytes) { return true; } char *error; sqlite3_exec( db, "ALTER TABLE drafts RENAME COLUMN `threadID` TO `key`;", nullptr, nullptr, &error); if (error) { std::ostringstream stringStream; stringStream << "Error occurred renaming threadID column in drafts table " << "to key: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } return true; } bool create_persist_account_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS olm_persist_account(" "id INTEGER UNIQUE PRIMARY KEY NOT NULL, " "account_data TEXT NOT NULL);"; return create_table(db, query, "olm_persist_account"); } bool create_persist_sessions_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS olm_persist_sessions(" "target_user_id TEXT UNIQUE PRIMARY KEY NOT NULL, " "session_data TEXT NOT NULL);"; return create_table(db, query, "olm_persist_sessions"); } bool drop_messages_table(sqlite3 *db) { char *error; sqlite3_exec(db, "DROP TABLE IF EXISTS messages;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error dropping 'messages' table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool recreate_messages_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS messages ( " "id TEXT UNIQUE PRIMARY KEY NOT NULL, " "local_id TEXT, " "thread TEXT NOT NULL, " "user TEXT NOT NULL, " "type INTEGER NOT NULL, " "future_type INTEGER, " "content TEXT, " "time INTEGER NOT NULL);"; return create_table(db, query, "messages"); } bool create_messages_idx_thread_time(sqlite3 *db) { char *error; sqlite3_exec( db, "CREATE INDEX IF NOT EXISTS messages_idx_thread_time " "ON messages (thread, time);", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error creating (thread, time) index on messages table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool create_media_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS media ( " "id TEXT UNIQUE PRIMARY KEY NOT NULL, " "container TEXT NOT NULL, " "thread TEXT NOT NULL, " "uri TEXT NOT NULL, " "type TEXT NOT NULL, " "extras TEXT NOT NULL);"; return create_table(db, query, "media"); } bool create_media_idx_container(sqlite3 *db) { char *error; sqlite3_exec( db, "CREATE INDEX IF NOT EXISTS media_idx_container " "ON media (container);", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error creating (container) index on media table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool create_threads_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS threads ( " "id TEXT UNIQUE PRIMARY KEY NOT NULL, " "type INTEGER NOT NULL, " "name TEXT, " "description TEXT, " "color TEXT NOT NULL, " "creation_time BIGINT NOT NULL, " "parent_thread_id TEXT, " "containing_thread_id TEXT, " "community TEXT, " "members TEXT NOT NULL, " "roles TEXT NOT NULL, " "current_user TEXT NOT NULL, " "source_message_id TEXT, " "replies_count INTEGER NOT NULL);"; return create_table(db, query, "threads"); } bool update_threadID_for_pending_threads_in_drafts(sqlite3 *db) { char *error; sqlite3_exec( db, "UPDATE drafts SET key = " "REPLACE(REPLACE(REPLACE(REPLACE(key, 'type4/', '')," "'type5/', ''),'type6/', ''),'type7/', '')" "WHERE key LIKE 'pending/%'", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error update pending threadIDs on drafts table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool enable_write_ahead_logging_mode(sqlite3 *db) { char *error; sqlite3_exec(db, "PRAGMA journal_mode=wal;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error enabling write-ahead logging mode: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool create_metadata_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS metadata ( " "name TEXT UNIQUE PRIMARY KEY NOT NULL, " "data TEXT);"; return create_table(db, query, "metadata"); } bool add_not_null_constraint_to_drafts(sqlite3 *db) { char *error; sqlite3_exec( db, "CREATE TABLE IF NOT EXISTS temporary_drafts (" "key TEXT UNIQUE PRIMARY KEY NOT NULL, " "text TEXT NOT NULL);" "INSERT INTO temporary_drafts SELECT * FROM drafts " "WHERE key IS NOT NULL AND text IS NOT NULL;" "DROP TABLE drafts;" "ALTER TABLE temporary_drafts RENAME TO drafts;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error adding NOT NULL constraint to drafts table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool add_not_null_constraint_to_metadata(sqlite3 *db) { char *error; sqlite3_exec( db, "CREATE TABLE IF NOT EXISTS temporary_metadata (" "name TEXT UNIQUE PRIMARY KEY NOT NULL, " "data TEXT NOT NULL);" "INSERT INTO temporary_metadata SELECT * FROM metadata " "WHERE data IS NOT NULL;" "DROP TABLE metadata;" "ALTER TABLE temporary_metadata RENAME TO metadata;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error adding NOT NULL constraint to metadata table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool add_avatar_column_to_threads_table(sqlite3 *db) { char *error; sqlite3_exec( db, "ALTER TABLE threads ADD COLUMN avatar TEXT;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error adding avatar column to threads table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool add_pinned_count_column_to_threads(sqlite3 *db) { sqlite3_stmt *pinned_column_stmt; sqlite3_prepare_v2( db, "SELECT name AS col_name FROM pragma_table_xinfo ('threads') WHERE " "col_name='pinned_count';", -1, &pinned_column_stmt, nullptr); sqlite3_step(pinned_column_stmt); auto num_bytes = sqlite3_column_bytes(pinned_column_stmt, 0); sqlite3_finalize(pinned_column_stmt); if (num_bytes) { return true; } char *error; sqlite3_exec( db, "ALTER TABLE threads ADD COLUMN pinned_count INTEGER NOT NULL DEFAULT 0;", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error adding pinned_count column to threads table: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } bool create_message_store_threads_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS message_store_threads (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " start_reached INTEGER NOT NULL," " last_navigated_to BIGINT NOT NULL," " last_pruned BIGINT NOT NULL" ");"; return create_table(db, query, "message_store_threads"); } bool create_reports_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS reports (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " report TEXT NOT NULL" ");"; return create_table(db, query, "reports"); } bool create_persist_storage_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS persist_storage (" " key TEXT UNIQUE PRIMARY KEY NOT NULL," " item TEXT NOT NULL" ");"; return create_table(db, query, "persist_storage"); } bool recreate_message_store_threads_table(sqlite3 *db) { char *errMsg = 0; // 1. Create table without `last_navigated_to` or `last_pruned`. std::string create_new_table_query = "CREATE TABLE IF NOT EXISTS temp_message_store_threads (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " start_reached INTEGER NOT NULL" ");"; if (sqlite3_exec(db, create_new_table_query.c_str(), NULL, NULL, &errMsg) != SQLITE_OK) { Logger::log( "Error creating temp_message_store_threads: " + std::string{errMsg}); sqlite3_free(errMsg); return false; } // 2. Dump data from existing `message_store_threads` table into temp table. std::string copy_data_query = "INSERT INTO temp_message_store_threads (id, start_reached)" "SELECT id, start_reached FROM message_store_threads;"; if (sqlite3_exec(db, copy_data_query.c_str(), NULL, NULL, &errMsg) != SQLITE_OK) { Logger::log( "Error dumping data from existing message_store_threads to " "temp_message_store_threads: " + std::string{errMsg}); sqlite3_free(errMsg); return false; } // 3. Drop the existing `message_store_threads` table. std::string drop_old_table_query = "DROP TABLE message_store_threads;"; if (sqlite3_exec(db, drop_old_table_query.c_str(), NULL, NULL, &errMsg) != SQLITE_OK) { Logger::log( "Error dropping message_store_threads table: " + std::string{errMsg}); sqlite3_free(errMsg); return false; } // 4. Rename the temp table back to `message_store_threads`. std::string rename_table_query = "ALTER TABLE temp_message_store_threads RENAME TO message_store_threads;"; if (sqlite3_exec(db, rename_table_query.c_str(), NULL, NULL, &errMsg) != SQLITE_OK) { Logger::log( "Error renaming temp_message_store_threads to message_store_threads: " + std::string{errMsg}); sqlite3_free(errMsg); return false; } return true; } bool create_users_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS users (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " user_info TEXT NOT NULL" ");"; return create_table(db, query, "users"); } bool create_keyservers_table(sqlite3 *db) { std::string query = "CREATE TABLE IF NOT EXISTS keyservers (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " keyserver_info TEXT NOT NULL" ");"; return create_table(db, query, "keyservers"); } bool enable_rollback_journal_mode(sqlite3 *db) { char *error; sqlite3_exec(db, "PRAGMA journal_mode=DELETE;", nullptr, nullptr, &error); if (!error) { return true; } std::stringstream error_message; error_message << "Error disabling write-ahead logging mode: " << error; Logger::log(error_message.str()); sqlite3_free(error); return false; } bool create_schema(sqlite3 *db) { char *error; sqlite3_exec( db, "CREATE TABLE IF NOT EXISTS drafts (" " key TEXT UNIQUE PRIMARY KEY NOT NULL," " text TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS messages (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " local_id TEXT," " thread TEXT NOT NULL," " user TEXT NOT NULL," " type INTEGER NOT NULL," " future_type INTEGER," " content TEXT," " time INTEGER NOT NULL" ");" "CREATE TABLE IF NOT EXISTS olm_persist_account (" " id INTEGER UNIQUE PRIMARY KEY NOT NULL," " account_data TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS olm_persist_sessions (" " target_user_id TEXT UNIQUE PRIMARY KEY NOT NULL," " session_data TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS media (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " container TEXT NOT NULL," " thread TEXT NOT NULL," " uri TEXT NOT NULL," " type TEXT NOT NULL," " extras TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS threads (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " type INTEGER NOT NULL," " name TEXT," " description TEXT," " color TEXT NOT NULL," " creation_time BIGINT NOT NULL," " parent_thread_id TEXT," " containing_thread_id TEXT," " community TEXT," " members TEXT NOT NULL," " roles TEXT NOT NULL," " current_user TEXT NOT NULL," " source_message_id TEXT," " replies_count INTEGER NOT NULL," " avatar TEXT," " pinned_count INTEGER NOT NULL DEFAULT 0" ");" "CREATE TABLE IF NOT EXISTS metadata (" " name TEXT UNIQUE PRIMARY KEY NOT NULL," " data TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS message_store_threads (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " start_reached INTEGER NOT NULL" ");" "CREATE TABLE IF NOT EXISTS reports (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " report TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS persist_storage (" " key TEXT UNIQUE PRIMARY KEY NOT NULL," " item TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS users (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " user_info TEXT NOT NULL" ");" "CREATE TABLE IF NOT EXISTS keyservers (" " id TEXT UNIQUE PRIMARY KEY NOT NULL," " keyserver_info TEXT NOT NULL" ");" "CREATE INDEX IF NOT EXISTS media_idx_container" " ON media (container);" "CREATE INDEX IF NOT EXISTS messages_idx_thread_time" " ON messages (thread, time);", nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream stringStream; stringStream << "Error creating tables: " << error; Logger::log(stringStream.str()); sqlite3_free(error); return false; } void set_encryption_key( sqlite3 *db, const std::string &encryptionKey = SQLiteQueryExecutor::encryptionKey) { std::string set_encryption_key_query = "PRAGMA key = \"x'" + encryptionKey + "'\";"; char *error_set_key; sqlite3_exec( db, set_encryption_key_query.c_str(), nullptr, nullptr, &error_set_key); if (error_set_key) { std::ostringstream error_message; error_message << "Failed to set encryption key: " << error_set_key; throw std::system_error( ECANCELED, std::generic_category(), error_message.str()); } } int get_database_version(sqlite3 *db) { sqlite3_stmt *user_version_stmt; sqlite3_prepare_v2( db, "PRAGMA user_version;", -1, &user_version_stmt, nullptr); sqlite3_step(user_version_stmt); int current_user_version = sqlite3_column_int(user_version_stmt, 0); sqlite3_finalize(user_version_stmt); return current_user_version; } bool set_database_version(sqlite3 *db, int db_version) { std::stringstream update_version; update_version << "PRAGMA user_version=" << db_version << ";"; auto update_version_str = update_version.str(); char *error; sqlite3_exec(db, update_version_str.c_str(), nullptr, nullptr, &error); if (!error) { return true; } std::ostringstream errorStream; errorStream << "Error setting database version to " << db_version << ": " << error; Logger::log(errorStream.str()); sqlite3_free(error); return false; } // We don't want to run `PRAGMA key = ...;` // on main web database. The context is here: // https://linear.app/comm/issue/ENG-6398/issues-with-sqlcipher-on-web void default_on_db_open_callback(sqlite3 *db) { #ifndef EMSCRIPTEN set_encryption_key(db); #endif } // This is a temporary solution. In future we want to keep // a separate table for blob hashes. Tracked on Linear: // https://linear.app/comm/issue/ENG-6261/introduce-blob-hash-table std::string blob_hash_from_blob_service_uri(const std::string &media_uri) { static const std::string blob_service_prefix = "comm-blob-service://"; return media_uri.substr(blob_service_prefix.size()); } bool file_exists(const std::string &file_path) { std::ifstream file(file_path.c_str()); return file.good(); } void attempt_delete_file( const std::string &file_path, const char *error_message) { if (std::remove(file_path.c_str())) { throw std::system_error(errno, std::generic_category(), error_message); } } void attempt_rename_file( const std::string &old_path, const std::string &new_path, const char *error_message) { if (std::rename(old_path.c_str(), new_path.c_str())) { throw std::system_error(errno, std::generic_category(), error_message); } } bool is_database_queryable( sqlite3 *db, bool use_encryption_key, const std::string &path = SQLiteQueryExecutor::sqliteFilePath, const std::string &encryptionKey = SQLiteQueryExecutor::encryptionKey) { char *err_msg; sqlite3_open(path.c_str(), &db); // According to SQLCipher documentation running some SELECT is the only way to // check for key validity if (use_encryption_key) { set_encryption_key(db, encryptionKey); } sqlite3_exec( db, "SELECT COUNT(*) FROM sqlite_master;", nullptr, nullptr, &err_msg); sqlite3_close(db); return !err_msg; } void validate_encryption() { std::string temp_encrypted_db_path = SQLiteQueryExecutor::sqliteFilePath + "_temp_encrypted"; bool temp_encrypted_exists = file_exists(temp_encrypted_db_path); bool default_location_exists = file_exists(SQLiteQueryExecutor::sqliteFilePath); if (temp_encrypted_exists && default_location_exists) { Logger::log( "Previous encryption attempt failed. Repeating encryption process from " "the beginning."); attempt_delete_file( temp_encrypted_db_path, "Failed to delete corrupted encrypted database."); } else if (temp_encrypted_exists && !default_location_exists) { Logger::log( "Moving temporary encrypted database to default location failed in " "previous encryption attempt. Repeating rename step."); attempt_rename_file( temp_encrypted_db_path, SQLiteQueryExecutor::sqliteFilePath, "Failed to move encrypted database to default location."); return; } else if (!default_location_exists) { Logger::log( "Database not present yet. It will be created encrypted under default " "path."); return; } sqlite3 *db; if (is_database_queryable(db, true)) { Logger::log( "Database exists under default path and it is correctly encrypted."); return; } if (!is_database_queryable(db, false)) { Logger::log( "Database exists but it is encrypted with key that was lost. " "Attempting database deletion. New encrypted one will be created."); attempt_delete_file( SQLiteQueryExecutor::sqliteFilePath.c_str(), "Failed to delete database encrypted with lost key."); return; } else { Logger::log( "Database exists but it is not encrypted. Attempting encryption " "process."); } sqlite3_open(SQLiteQueryExecutor::sqliteFilePath.c_str(), &db); std::string createEncryptedCopySQL = "ATTACH DATABASE '" + temp_encrypted_db_path + "' AS encrypted_comm " "KEY \"x'" + SQLiteQueryExecutor::encryptionKey + "'\";" "SELECT sqlcipher_export('encrypted_comm');" "DETACH DATABASE encrypted_comm;"; char *encryption_error; sqlite3_exec( db, createEncryptedCopySQL.c_str(), nullptr, nullptr, &encryption_error); if (encryption_error) { throw std::system_error( ECANCELED, std::generic_category(), "Failed to create encrypted copy of the original database."); } sqlite3_close(db); attempt_delete_file( SQLiteQueryExecutor::sqliteFilePath, "Failed to delete unencrypted database."); attempt_rename_file( temp_encrypted_db_path, SQLiteQueryExecutor::sqliteFilePath, "Failed to move encrypted database to default location."); Logger::log("Encryption completed successfully."); } typedef bool ShouldBeInTransaction; typedef std::function MigrateFunction; typedef std::pair SQLiteMigration; std::vector> migrations{ {{1, {create_drafts_table, true}}, {2, {rename_threadID_to_key, true}}, {4, {create_persist_account_table, true}}, {5, {create_persist_sessions_table, true}}, {15, {create_media_table, true}}, {16, {drop_messages_table, true}}, {17, {recreate_messages_table, true}}, {18, {create_messages_idx_thread_time, true}}, {19, {create_media_idx_container, true}}, {20, {create_threads_table, true}}, {21, {update_threadID_for_pending_threads_in_drafts, true}}, {22, {enable_write_ahead_logging_mode, false}}, {23, {create_metadata_table, true}}, {24, {add_not_null_constraint_to_drafts, true}}, {25, {add_not_null_constraint_to_metadata, true}}, {26, {add_avatar_column_to_threads_table, true}}, {27, {add_pinned_count_column_to_threads, true}}, {28, {create_message_store_threads_table, true}}, {29, {create_reports_table, true}}, {30, {create_persist_storage_table, true}}, {31, {recreate_message_store_threads_table, true}}, {32, {create_users_table, true}}, {33, {create_keyservers_table, true}}, {34, {enable_rollback_journal_mode, false}}}}; enum class MigrationResult { SUCCESS, FAILURE, NOT_APPLIED }; MigrationResult applyMigrationWithTransaction( sqlite3 *db, const MigrateFunction &migrate, int index) { sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); auto db_version = get_database_version(db); if (index <= db_version) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return MigrationResult::NOT_APPLIED; } auto rc = migrate(db); if (!rc) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return MigrationResult::FAILURE; } auto database_version_set = set_database_version(db, index); if (!database_version_set) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return MigrationResult::FAILURE; } sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr); return MigrationResult::SUCCESS; } MigrationResult applyMigrationWithoutTransaction( sqlite3 *db, const MigrateFunction &migrate, int index) { auto db_version = get_database_version(db); if (index <= db_version) { return MigrationResult::NOT_APPLIED; } auto rc = migrate(db); if (!rc) { return MigrationResult::FAILURE; } sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); auto inner_db_version = get_database_version(db); if (index <= inner_db_version) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return MigrationResult::NOT_APPLIED; } auto database_version_set = set_database_version(db, index); if (!database_version_set) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return MigrationResult::FAILURE; } sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr); return MigrationResult::SUCCESS; } bool set_up_database(sqlite3 *db) { sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); auto db_version = get_database_version(db); auto latest_version = migrations.back().first; if (db_version == latest_version) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return true; } if (db_version != 0 || !create_schema(db) || !set_database_version(db, latest_version)) { sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr); return false; } sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr); return true; } void SQLiteQueryExecutor::migrate() { // We don't want to run `PRAGMA key = ...;` // on main web database. The context is here: // https://linear.app/comm/issue/ENG-6398/issues-with-sqlcipher-on-web #ifndef EMSCRIPTEN validate_encryption(); #endif sqlite3 *db; sqlite3_open(SQLiteQueryExecutor::sqliteFilePath.c_str(), &db); default_on_db_open_callback(db); std::stringstream db_path; db_path << "db path: " << SQLiteQueryExecutor::sqliteFilePath.c_str() << std::endl; Logger::log(db_path.str()); auto db_version = get_database_version(db); std::stringstream version_msg; version_msg << "db version: " << db_version << std::endl; Logger::log(version_msg.str()); if (db_version == 0) { auto db_created = set_up_database(db); if (!db_created) { sqlite3_close(db); Logger::log("Database structure creation error."); throw std::runtime_error("Database structure creation error"); } Logger::log("Database structure created."); sqlite3_close(db); return; } for (const auto &[idx, migration] : migrations) { const auto &[applyMigration, shouldBeInTransaction] = migration; MigrationResult migrationResult; if (shouldBeInTransaction) { migrationResult = applyMigrationWithTransaction(db, applyMigration, idx); } else { migrationResult = applyMigrationWithoutTransaction(db, applyMigration, idx); } if (migrationResult == MigrationResult::NOT_APPLIED) { continue; } std::stringstream migration_msg; if (migrationResult == MigrationResult::FAILURE) { migration_msg << "migration " << idx << " failed." << std::endl; Logger::log(migration_msg.str()); sqlite3_close(db); throw std::runtime_error(migration_msg.str()); } if (migrationResult == MigrationResult::SUCCESS) { migration_msg << "migration " << idx << " succeeded." << std::endl; Logger::log(migration_msg.str()); } } sqlite3_close(db); } SQLiteQueryExecutor::SQLiteQueryExecutor() { SQLiteQueryExecutor::migrate(); #ifndef EMSCRIPTEN std::string currentBackupID = this->getMetadata("backupID"); if (!StaffUtils::isStaffRelease() || !currentBackupID.size()) { return; } SQLiteQueryExecutor::connectionManager.setLogsMonitoring(true); #endif } SQLiteQueryExecutor::SQLiteQueryExecutor(std::string sqliteFilePath) { SQLiteQueryExecutor::sqliteFilePath = sqliteFilePath; SQLiteQueryExecutor::migrate(); } sqlite3 *SQLiteQueryExecutor::getConnection() { if (SQLiteQueryExecutor::connectionManager.getConnection()) { return SQLiteQueryExecutor::connectionManager.getConnection(); } SQLiteQueryExecutor::connectionManager.initializeConnection( SQLiteQueryExecutor::sqliteFilePath, default_on_db_open_callback); return SQLiteQueryExecutor::connectionManager.getConnection(); } void SQLiteQueryExecutor::closeConnection() { SQLiteQueryExecutor::connectionManager.closeConnection(); } SQLiteQueryExecutor::~SQLiteQueryExecutor() { SQLiteQueryExecutor::closeConnection(); } std::string SQLiteQueryExecutor::getDraft(std::string key) const { static std::string getDraftByPrimaryKeySQL = "SELECT * " "FROM drafts " "WHERE key = ?;"; std::unique_ptr draft = getEntityByPrimaryKey( SQLiteQueryExecutor::getConnection(), getDraftByPrimaryKeySQL, key); return (draft == nullptr) ? "" : draft->text; } std::unique_ptr SQLiteQueryExecutor::getThread(std::string threadID) const { static std::string getThreadByPrimaryKeySQL = "SELECT * " "FROM threads " "WHERE id = ?;"; return getEntityByPrimaryKey( SQLiteQueryExecutor::getConnection(), getThreadByPrimaryKeySQL, threadID); } void SQLiteQueryExecutor::updateDraft(std::string key, std::string text) const { static std::string replaceDraftSQL = "REPLACE INTO drafts (key, text) " "VALUES (?, ?);"; Draft draft = {key, text}; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceDraftSQL, draft); } bool SQLiteQueryExecutor::moveDraft(std::string oldKey, std::string newKey) const { std::string draftText = this->getDraft(oldKey); if (!draftText.size()) { return false; } static std::string rekeyDraftSQL = "UPDATE OR REPLACE drafts " "SET key = ? " "WHERE key = ?;"; rekeyAllEntities( SQLiteQueryExecutor::getConnection(), rekeyDraftSQL, oldKey, newKey); return true; } std::vector SQLiteQueryExecutor::getAllDrafts() const { static std::string getAllDraftsSQL = "SELECT * " "FROM drafts;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllDraftsSQL); } void SQLiteQueryExecutor::removeAllDrafts() const { static std::string removeAllDraftsSQL = "DELETE FROM drafts;"; removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllDraftsSQL); } void SQLiteQueryExecutor::removeDrafts( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeDraftsByKeysSQLStream; removeDraftsByKeysSQLStream << "DELETE FROM drafts " "WHERE key IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeDraftsByKeysSQLStream.str(), ids); } void SQLiteQueryExecutor::removeAllMessages() const { static std::string removeAllMessagesSQL = "DELETE FROM messages;"; removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllMessagesSQL); } std::vector>> SQLiteQueryExecutor::getAllMessages() const { static std::string getAllMessagesSQL = "SELECT * " "FROM messages " "LEFT JOIN media " " ON messages.id = media.container " "ORDER BY messages.id;"; SQLiteStatementWrapper preparedSQL( SQLiteQueryExecutor::getConnection(), getAllMessagesSQL, "Failed to retrieve all messages."); std::string prevMsgIdx{}; std::vector>> allMessages; for (int stepResult = sqlite3_step(preparedSQL); stepResult == SQLITE_ROW; stepResult = sqlite3_step(preparedSQL)) { Message message = Message::fromSQLResult(preparedSQL, 0); if (message.id == prevMsgIdx) { allMessages.back().second.push_back(Media::fromSQLResult(preparedSQL, 8)); } else { prevMsgIdx = message.id; std::vector mediaForMsg; if (sqlite3_column_type(preparedSQL, 8) != SQLITE_NULL) { mediaForMsg.push_back(Media::fromSQLResult(preparedSQL, 8)); } allMessages.push_back(std::make_pair(std::move(message), mediaForMsg)); } } return allMessages; } void SQLiteQueryExecutor::removeMessages( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeMessagesByKeysSQLStream; removeMessagesByKeysSQLStream << "DELETE FROM messages " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMessagesByKeysSQLStream.str(), ids); } void SQLiteQueryExecutor::removeMessagesForThreads( const std::vector &threadIDs) const { if (!threadIDs.size()) { return; } std::stringstream removeMessagesByKeysSQLStream; removeMessagesByKeysSQLStream << "DELETE FROM messages " "WHERE thread IN " << getSQLStatementArray(threadIDs.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMessagesByKeysSQLStream.str(), threadIDs); } void SQLiteQueryExecutor::replaceMessage(const Message &message) const { static std::string replaceMessageSQL = "REPLACE INTO messages " "(id, local_id, thread, user, type, future_type, content, time) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceMessageSQL, message); } void SQLiteQueryExecutor::rekeyMessage(std::string from, std::string to) const { static std::string rekeyMessageSQL = "UPDATE OR REPLACE messages " "SET id = ? " "WHERE id = ?"; rekeyAllEntities( SQLiteQueryExecutor::getConnection(), rekeyMessageSQL, from, to); } void SQLiteQueryExecutor::removeAllMedia() const { static std::string removeAllMediaSQL = "DELETE FROM media;"; removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllMediaSQL); } void SQLiteQueryExecutor::removeMediaForMessages( const std::vector &msg_ids) const { if (!msg_ids.size()) { return; } std::stringstream removeMediaByKeysSQLStream; removeMediaByKeysSQLStream << "DELETE FROM media " "WHERE container IN " << getSQLStatementArray(msg_ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMediaByKeysSQLStream.str(), msg_ids); } void SQLiteQueryExecutor::removeMediaForMessage(std::string msg_id) const { static std::string removeMediaByKeySQL = "DELETE FROM media " "WHERE container IN (?);"; std::vector keys = {msg_id}; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMediaByKeySQL, keys); } void SQLiteQueryExecutor::removeMediaForThreads( const std::vector &thread_ids) const { if (!thread_ids.size()) { return; } std::stringstream removeMediaByKeysSQLStream; removeMediaByKeysSQLStream << "DELETE FROM media " "WHERE thread IN " << getSQLStatementArray(thread_ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMediaByKeysSQLStream.str(), thread_ids); } void SQLiteQueryExecutor::replaceMedia(const Media &media) const { static std::string replaceMediaSQL = "REPLACE INTO media " "(id, container, thread, uri, type, extras) " "VALUES (?, ?, ?, ?, ?, ?)"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceMediaSQL, media); } void SQLiteQueryExecutor::rekeyMediaContainers(std::string from, std::string to) const { static std::string rekeyMediaContainersSQL = "UPDATE media SET container = ? WHERE container = ?;"; rekeyAllEntities( SQLiteQueryExecutor::getConnection(), rekeyMediaContainersSQL, from, to); } void SQLiteQueryExecutor::replaceMessageStoreThreads( const std::vector &threads) const { static std::string replaceMessageStoreThreadSQL = "REPLACE INTO message_store_threads " "(id, start_reached) " "VALUES (?, ?);"; for (auto &thread : threads) { replaceEntity( SQLiteQueryExecutor::getConnection(), replaceMessageStoreThreadSQL, thread); } } void SQLiteQueryExecutor::removeAllMessageStoreThreads() const { static std::string removeAllMessageStoreThreadsSQL = "DELETE FROM message_store_threads;"; removeAllEntities( SQLiteQueryExecutor::getConnection(), removeAllMessageStoreThreadsSQL); } void SQLiteQueryExecutor::removeMessageStoreThreads( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeMessageStoreThreadsByKeysSQLStream; removeMessageStoreThreadsByKeysSQLStream << "DELETE FROM message_store_threads " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMessageStoreThreadsByKeysSQLStream.str(), ids); } std::vector SQLiteQueryExecutor::getAllMessageStoreThreads() const { static std::string getAllMessageStoreThreadsSQL = "SELECT * " "FROM message_store_threads;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllMessageStoreThreadsSQL); } std::vector SQLiteQueryExecutor::getAllThreads() const { static std::string getAllThreadsSQL = "SELECT * " "FROM threads;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllThreadsSQL); }; void SQLiteQueryExecutor::removeThreads(std::vector ids) const { if (!ids.size()) { return; } std::stringstream removeThreadsByKeysSQLStream; removeThreadsByKeysSQLStream << "DELETE FROM threads " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeThreadsByKeysSQLStream.str(), ids); }; void SQLiteQueryExecutor::replaceThread(const Thread &thread) const { static std::string replaceThreadSQL = "REPLACE INTO threads (" " id, type, name, description, color, creation_time, parent_thread_id," " containing_thread_id, community, members, roles, current_user," " source_message_id, replies_count, avatar, pinned_count) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceThreadSQL, thread); }; void SQLiteQueryExecutor::removeAllThreads() const { static std::string removeAllThreadsSQL = "DELETE FROM threads;"; removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllThreadsSQL); }; void SQLiteQueryExecutor::replaceReport(const Report &report) const { static std::string replaceReportSQL = "REPLACE INTO reports (id, report) " "VALUES (?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceReportSQL, report); } void SQLiteQueryExecutor::removeAllReports() const { static std::string removeAllReportsSQL = "DELETE FROM reports;"; removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllReportsSQL); } void SQLiteQueryExecutor::removeReports( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeReportsByKeysSQLStream; removeReportsByKeysSQLStream << "DELETE FROM reports " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeReportsByKeysSQLStream.str(), ids); } std::vector SQLiteQueryExecutor::getAllReports() const { static std::string getAllReportsSQL = "SELECT * " "FROM reports;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllReportsSQL); } void SQLiteQueryExecutor::setPersistStorageItem( std::string key, std::string item) const { static std::string replacePersistStorageItemSQL = "REPLACE INTO persist_storage (key, item) " "VALUES (?, ?);"; PersistItem entry{ key, item, }; replaceEntity( SQLiteQueryExecutor::getConnection(), replacePersistStorageItemSQL, entry); } void SQLiteQueryExecutor::removePersistStorageItem(std::string key) const { static std::string removePersistStorageItemByKeySQL = "DELETE FROM persist_storage " "WHERE key IN (?);"; std::vector keys = {key}; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removePersistStorageItemByKeySQL, keys); } std::string SQLiteQueryExecutor::getPersistStorageItem(std::string key) const { static std::string getPersistStorageItemByPrimaryKeySQL = "SELECT * " "FROM persist_storage " "WHERE key = ?;"; std::unique_ptr entry = getEntityByPrimaryKey( SQLiteQueryExecutor::getConnection(), getPersistStorageItemByPrimaryKeySQL, key); return (entry == nullptr) ? "" : entry->item; } void SQLiteQueryExecutor::replaceUser(const UserInfo &user_info) const { static std::string replaceUserSQL = "REPLACE INTO users (id, user_info) " "VALUES (?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceUserSQL, user_info); } void SQLiteQueryExecutor::removeAllUsers() const { static std::string removeAllUsersSQL = "DELETE FROM users;"; removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllUsersSQL); } void SQLiteQueryExecutor::removeUsers( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeUsersByKeysSQLStream; removeUsersByKeysSQLStream << "DELETE FROM users " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeUsersByKeysSQLStream.str(), ids); } void SQLiteQueryExecutor::replaceKeyserver( const KeyserverInfo &keyserver_info) const { static std::string replaceKeyserverSQL = "REPLACE INTO keyservers (id, keyserver_info) " "VALUES (?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceKeyserverSQL, keyserver_info); } void SQLiteQueryExecutor::removeAllKeyservers() const { static std::string removeAllKeyserversSQL = "DELETE FROM keyservers;"; removeAllEntities( SQLiteQueryExecutor::getConnection(), removeAllKeyserversSQL); } void SQLiteQueryExecutor::removeKeyservers( const std::vector &ids) const { if (!ids.size()) { return; } std::stringstream removeKeyserversByKeysSQLStream; removeKeyserversByKeysSQLStream << "DELETE FROM keyservers " "WHERE id IN " << getSQLStatementArray(ids.size()) << ";"; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeKeyserversByKeysSQLStream.str(), ids); } std::vector SQLiteQueryExecutor::getAllKeyservers() const { static std::string getAllKeyserversSQL = "SELECT * " "FROM keyservers;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllKeyserversSQL); } std::vector SQLiteQueryExecutor::getAllUsers() const { static std::string getAllUsersSQL = "SELECT * " "FROM users;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllUsersSQL); } void SQLiteQueryExecutor::beginTransaction() const { executeQuery(SQLiteQueryExecutor::getConnection(), "BEGIN TRANSACTION;"); } void SQLiteQueryExecutor::commitTransaction() const { executeQuery(SQLiteQueryExecutor::getConnection(), "COMMIT;"); } void SQLiteQueryExecutor::rollbackTransaction() const { executeQuery(SQLiteQueryExecutor::getConnection(), "ROLLBACK;"); } std::vector SQLiteQueryExecutor::getOlmPersistSessionsData() const { static std::string getAllOlmPersistSessionsSQL = "SELECT * " "FROM olm_persist_sessions;"; return getAllEntities( SQLiteQueryExecutor::getConnection(), getAllOlmPersistSessionsSQL); } std::optional SQLiteQueryExecutor::getOlmPersistAccountData() const { static std::string getAllOlmPersistAccountSQL = "SELECT * " "FROM olm_persist_account;"; std::vector result = getAllEntities( SQLiteQueryExecutor::getConnection(), getAllOlmPersistAccountSQL); if (result.size() > 1) { throw std::system_error( ECANCELED, std::generic_category(), "Multiple records found for the olm_persist_account table"); } return (result.size() == 0) ? std::nullopt : std::optional(result[0].account_data); } void SQLiteQueryExecutor::storeOlmPersistAccount( const std::string &accountData) const { static std::string replaceOlmPersistAccountSQL = "REPLACE INTO olm_persist_account (id, account_data) " "VALUES (?, ?);"; OlmPersistAccount persistAccount = {ACCOUNT_ID, accountData}; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceOlmPersistAccountSQL, persistAccount); } void SQLiteQueryExecutor::storeOlmPersistSession( const OlmPersistSession &session) const { static std::string replaceOlmPersistSessionSQL = "REPLACE INTO olm_persist_sessions (target_user_id, session_data) " "VALUES (?, ?);"; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceOlmPersistSessionSQL, session); } void SQLiteQueryExecutor::storeOlmPersistData(crypto::Persist persist) const { std::string accountData = std::string(persist.account.begin(), persist.account.end()); this->storeOlmPersistAccount(accountData); for (auto it = persist.sessions.begin(); it != persist.sessions.end(); it++) { OlmPersistSession persistSession = { it->first, std::string(it->second.begin(), it->second.end())}; this->storeOlmPersistSession(persistSession); } } void SQLiteQueryExecutor::setNotifyToken(std::string token) const { this->setMetadata("notify_token", token); } void SQLiteQueryExecutor::clearNotifyToken() const { this->clearMetadata("notify_token"); } void SQLiteQueryExecutor::setCurrentUserID(std::string userID) const { this->setMetadata("current_user_id", userID); } std::string SQLiteQueryExecutor::getCurrentUserID() const { return this->getMetadata("current_user_id"); } void SQLiteQueryExecutor::setMetadata(std::string entry_name, std::string data) const { std::string replaceMetadataSQL = "REPLACE INTO metadata (name, data) " "VALUES (?, ?);"; Metadata entry{ entry_name, data, }; replaceEntity( SQLiteQueryExecutor::getConnection(), replaceMetadataSQL, entry); } void SQLiteQueryExecutor::clearMetadata(std::string entry_name) const { static std::string removeMetadataByKeySQL = "DELETE FROM metadata " "WHERE name IN (?);"; std::vector keys = {entry_name}; removeEntitiesByKeys( SQLiteQueryExecutor::getConnection(), removeMetadataByKeySQL, keys); } std::string SQLiteQueryExecutor::getMetadata(std::string entry_name) const { std::string getMetadataByPrimaryKeySQL = "SELECT * " "FROM metadata " "WHERE name = ?;"; std::unique_ptr entry = getEntityByPrimaryKey( SQLiteQueryExecutor::getConnection(), getMetadataByPrimaryKeySQL, entry_name); return (entry == nullptr) ? "" : entry->data; } #ifdef EMSCRIPTEN std::vector SQLiteQueryExecutor::getAllThreadsWeb() const { auto threads = this->getAllThreads(); std::vector webThreads; webThreads.reserve(threads.size()); for (const auto &thread : threads) { webThreads.emplace_back(thread); } return webThreads; }; void SQLiteQueryExecutor::replaceThreadWeb(const WebThread &thread) const { this->replaceThread(thread.toThread()); }; std::vector SQLiteQueryExecutor::getAllMessagesWeb() const { auto allMessages = this->getAllMessages(); std::vector allMessageWithMedias; for (auto &messageWitMedia : allMessages) { allMessageWithMedias.push_back( {std::move(messageWitMedia.first), messageWitMedia.second}); } return allMessageWithMedias; } void SQLiteQueryExecutor::replaceMessageWeb(const WebMessage &message) const { this->replaceMessage(message.toMessage()); }; NullableString SQLiteQueryExecutor::getOlmPersistAccountDataWeb() const { std::optional accountData = this->getOlmPersistAccountData(); if (!accountData.has_value()) { return NullableString(); } return std::make_unique(accountData.value()); } #else void SQLiteQueryExecutor::clearSensitiveData() { SQLiteQueryExecutor::closeConnection(); if (file_exists(SQLiteQueryExecutor::sqliteFilePath) && std::remove(SQLiteQueryExecutor::sqliteFilePath.c_str())) { std::ostringstream errorStream; errorStream << "Failed to delete database file. Details: " << strerror(errno); throw std::system_error(errno, std::generic_category(), errorStream.str()); } SQLiteQueryExecutor::generateFreshEncryptionKey(); SQLiteQueryExecutor::migrate(); } void SQLiteQueryExecutor::initialize(std::string &databasePath) { std::call_once(SQLiteQueryExecutor::initialized, [&databasePath]() { SQLiteQueryExecutor::sqliteFilePath = databasePath; folly::Optional maybeEncryptionKey = CommSecureStore::get(SQLiteQueryExecutor::secureStoreEncryptionKeyID); folly::Optional maybeBackupLogsEncryptionKey = CommSecureStore::get( SQLiteQueryExecutor::secureStoreBackupLogsEncryptionKeyID); if (file_exists(databasePath) && maybeEncryptionKey && maybeBackupLogsEncryptionKey) { SQLiteQueryExecutor::encryptionKey = maybeEncryptionKey.value(); SQLiteQueryExecutor::backupLogsEncryptionKey = maybeBackupLogsEncryptionKey.value(); return; } else if (file_exists(databasePath) && maybeEncryptionKey) { SQLiteQueryExecutor::encryptionKey = maybeEncryptionKey.value(); SQLiteQueryExecutor::generateFreshBackupLogsEncryptionKey(); return; } SQLiteQueryExecutor::generateFreshEncryptionKey(); }); } void SQLiteQueryExecutor::createMainCompaction(std::string backupID) const { std::string finalBackupPath = PlatformSpecificTools::getBackupFilePath(backupID, false); std::string finalAttachmentsPath = PlatformSpecificTools::getBackupFilePath(backupID, true); std::string tempBackupPath = finalBackupPath + "_tmp"; std::string tempAttachmentsPath = finalAttachmentsPath + "_tmp"; if (file_exists(tempBackupPath)) { Logger::log( "Attempting to delete temporary backup file from previous backup " "attempt."); attempt_delete_file( tempBackupPath, "Failed to delete temporary backup file from previous backup attempt."); } if (file_exists(tempAttachmentsPath)) { Logger::log( "Attempting to delete temporary attachments file from previous backup " "attempt."); attempt_delete_file( tempAttachmentsPath, "Failed to delete temporary attachments file from previous backup " "attempt."); } sqlite3 *backupDB; sqlite3_open(tempBackupPath.c_str(), &backupDB); set_encryption_key(backupDB); sqlite3_backup *backupObj = sqlite3_backup_init( backupDB, "main", SQLiteQueryExecutor::getConnection(), "main"); if (!backupObj) { std::stringstream error_message; error_message << "Failed to init backup for main compaction. Details: " << sqlite3_errmsg(backupDB) << std::endl; sqlite3_close(backupDB); throw std::runtime_error(error_message.str()); } int backupResult = sqlite3_backup_step(backupObj, -1); sqlite3_backup_finish(backupObj); if (backupResult == SQLITE_BUSY || backupResult == SQLITE_LOCKED) { sqlite3_close(backupDB); throw std::runtime_error( "Programmer error. Database in transaction during backup attempt."); } else if (backupResult != SQLITE_DONE) { sqlite3_close(backupDB); std::stringstream error_message; error_message << "Failed to create database backup. Details: " << sqlite3_errstr(backupResult); throw std::runtime_error(error_message.str()); } std::string removeDeviceSpecificDataSQL = "DELETE FROM olm_persist_account;" "DELETE FROM olm_persist_sessions;" "DELETE FROM metadata;"; executeQuery(backupDB, removeDeviceSpecificDataSQL); executeQuery(backupDB, "VACUUM;"); sqlite3_close(backupDB); attempt_rename_file( tempBackupPath, finalBackupPath, "Failed to rename complete temporary backup file to final backup file."); std::ofstream tempAttachmentsFile(tempAttachmentsPath); if (!tempAttachmentsFile.is_open()) { throw std::runtime_error( "Unable to create attachments file for backup id: " + backupID); } std::string getAllBlobServiceMediaSQL = "SELECT * FROM media WHERE uri LIKE 'comm-blob-service://%';"; std::vector blobServiceMedia = getAllEntities( SQLiteQueryExecutor::getConnection(), getAllBlobServiceMediaSQL); for (const auto &media : blobServiceMedia) { std::string blobServiceURI = media.uri; std::string blobHash = blob_hash_from_blob_service_uri(blobServiceURI); tempAttachmentsFile << blobHash << "\n"; } tempAttachmentsFile.close(); attempt_rename_file( tempAttachmentsPath, finalAttachmentsPath, "Failed to rename complete temporary attachments file to final " "attachments file."); this->setMetadata("backupID", backupID); this->clearMetadata("logID"); if (StaffUtils::isStaffRelease()) { SQLiteQueryExecutor::connectionManager.setLogsMonitoring(true); } } void SQLiteQueryExecutor::generateFreshEncryptionKey() { std::string encryptionKey = comm::crypto::Tools::generateRandomHexString( SQLiteQueryExecutor::sqlcipherEncryptionKeySize); CommSecureStore::set( SQLiteQueryExecutor::secureStoreEncryptionKeyID, encryptionKey); SQLiteQueryExecutor::encryptionKey = encryptionKey; SQLiteQueryExecutor::generateFreshBackupLogsEncryptionKey(); } void SQLiteQueryExecutor::generateFreshBackupLogsEncryptionKey() { std::string backupLogsEncryptionKey = comm::crypto::Tools::generateRandomHexString( SQLiteQueryExecutor::backupLogsEncryptionKeySize); CommSecureStore::set( SQLiteQueryExecutor::secureStoreBackupLogsEncryptionKeyID, backupLogsEncryptionKey); SQLiteQueryExecutor::backupLogsEncryptionKey = backupLogsEncryptionKey; } void SQLiteQueryExecutor::captureBackupLogs() const { std::string backupID = this->getMetadata("backupID"); if (!backupID.size()) { return; } std::string logID = this->getMetadata("logID"); if (!logID.size()) { - logID = "0"; + logID = "1"; } bool newLogCreated = SQLiteQueryExecutor::connectionManager.captureLogs( backupID, logID, SQLiteQueryExecutor::backupLogsEncryptionKey); if (!newLogCreated) { return; } this->setMetadata("logID", std::to_string(std::stoi(logID) + 1)); } #endif void SQLiteQueryExecutor::restoreFromMainCompaction( std::string mainCompactionPath, std::string mainCompactionEncryptionKey) const { if (!file_exists(mainCompactionPath)) { throw std::runtime_error("Restore attempt but backup file does not exist."); } sqlite3 *backupDB; if (!is_database_queryable( backupDB, true, mainCompactionPath, mainCompactionEncryptionKey)) { throw std::runtime_error("Backup file or encryption key corrupted."); } // We don't want to run `PRAGMA key = ...;` // on main web database. The context is here: // https://linear.app/comm/issue/ENG-6398/issues-with-sqlcipher-on-web #ifdef EMSCRIPTEN std::string plaintextBackupPath = mainCompactionPath + "_plaintext"; if (file_exists(plaintextBackupPath)) { attempt_delete_file( plaintextBackupPath, "Failed to delete plaintext backup file from previous backup attempt."); } std::string plaintextMigrationDBQuery = "PRAGMA key = \"x'" + mainCompactionEncryptionKey + "'\";" "ATTACH DATABASE '" + plaintextBackupPath + "' AS plaintext KEY '';" "SELECT sqlcipher_export('plaintext');" "DETACH DATABASE plaintext;"; sqlite3_open(mainCompactionPath.c_str(), &backupDB); char *plaintextMigrationErr; sqlite3_exec( backupDB, plaintextMigrationDBQuery.c_str(), nullptr, nullptr, &plaintextMigrationErr); sqlite3_close(backupDB); if (plaintextMigrationErr) { std::stringstream error_message; error_message << "Failed to migrate backup SQLCipher file to plaintext " "SQLite file. Details" << plaintextMigrationErr << std::endl; std::string error_message_str = error_message.str(); sqlite3_free(plaintextMigrationErr); throw std::runtime_error(error_message_str); } sqlite3_open(plaintextBackupPath.c_str(), &backupDB); #else sqlite3_open(mainCompactionPath.c_str(), &backupDB); set_encryption_key(backupDB, mainCompactionEncryptionKey); #endif sqlite3_backup *backupObj = sqlite3_backup_init( SQLiteQueryExecutor::getConnection(), "main", backupDB, "main"); if (!backupObj) { std::stringstream error_message; error_message << "Failed to init backup for main compaction. Details: " << sqlite3_errmsg(SQLiteQueryExecutor::getConnection()) << std::endl; sqlite3_close(backupDB); throw std::runtime_error(error_message.str()); } int backupResult = sqlite3_backup_step(backupObj, -1); sqlite3_backup_finish(backupObj); sqlite3_close(backupDB); if (backupResult == SQLITE_BUSY || backupResult == SQLITE_LOCKED) { throw std::runtime_error( "Programmer error. Database in transaction during restore attempt."); } else if (backupResult != SQLITE_DONE) { std::stringstream error_message; error_message << "Failed to restore database from backup. Details: " << sqlite3_errstr(backupResult); throw std::runtime_error(error_message.str()); } #ifdef EMSCRIPTEN attempt_delete_file( plaintextBackupPath, "Failed to delete plaintext compaction file after successful restore."); #endif attempt_delete_file( mainCompactionPath, "Failed to delete main compaction file after successful restore."); } void SQLiteQueryExecutor::restoreFromBackupLog( const std::vector &backupLog) const { SQLiteQueryExecutor::connectionManager.restoreFromBackupLog(backupLog); } } // namespace comm diff --git a/native/cpp/CommonCpp/NativeModules/PersistentStorageUtilities/BackupOperationsUtilities/BackupOperationsExecutor.cpp b/native/cpp/CommonCpp/NativeModules/PersistentStorageUtilities/BackupOperationsUtilities/BackupOperationsExecutor.cpp index 1f3148efc..d187044ba 100644 --- a/native/cpp/CommonCpp/NativeModules/PersistentStorageUtilities/BackupOperationsUtilities/BackupOperationsExecutor.cpp +++ b/native/cpp/CommonCpp/NativeModules/PersistentStorageUtilities/BackupOperationsUtilities/BackupOperationsExecutor.cpp @@ -1,59 +1,60 @@ #include "BackupOperationsExecutor.h" #include "DatabaseManager.h" #include "GlobalDBSingleton.h" #include "Logger.h" #include "RustPromiseManager.h" #include "WorkerThread.h" #include "lib.rs.h" namespace comm { void BackupOperationsExecutor::createMainCompaction( std::string backupID, size_t futureID) { taskType job = [backupID, futureID]() { try { DatabaseManager::getQueryExecutor().createMainCompaction(backupID); ::resolveUnitFuture(futureID); } catch (const std::exception &e) { ::rejectFuture(futureID, rust::String(e.what())); Logger::log( "Main compaction creation failed. Details: " + std::string(e.what())); } }; GlobalDBSingleton::instance.scheduleOrRunCancellable(job); } void BackupOperationsExecutor::restoreFromMainCompaction( std::string mainCompactionPath, std::string mainCompactionEncryptionKey, size_t futureID) { taskType job = [mainCompactionPath, mainCompactionEncryptionKey, futureID]() { try { DatabaseManager::getQueryExecutor().restoreFromMainCompaction( mainCompactionPath, mainCompactionEncryptionKey); ::resolveUnitFuture(futureID); } catch (const std::exception &e) { std::string errorDetails = std::string(e.what()); Logger::log( "Restore from main compaction failed. Details: " + errorDetails); ::rejectFuture(futureID, errorDetails); } }; GlobalDBSingleton::instance.scheduleOrRunCancellable(job); } void BackupOperationsExecutor::restoreFromBackupLog( - const std::vector &backupLog) { - taskType job = [backupLog]() { + const std::vector &backupLog, + size_t futureID) { + taskType job = [backupLog, futureID]() { try { DatabaseManager::getQueryExecutor().restoreFromBackupLog(backupLog); + ::resolveUnitFuture(futureID); } catch (const std::exception &e) { - // TODO: Inform Rust networking about failure - // of restoration from backup log. - Logger::log( - "Restore from backup log failed. Details: " + std::string(e.what())); + std::string errorDetails = std::string(e.what()); + Logger::log("Restore from backup log failed. Details: " + errorDetails); + ::rejectFuture(futureID, errorDetails); } }; GlobalDBSingleton::instance.scheduleOrRunCancellable(job); } } // namespace comm diff --git a/native/cpp/CommonCpp/NativeModules/PersistentStorageUtilities/BackupOperationsUtilities/BackupOperationsExecutor.h b/native/cpp/CommonCpp/NativeModules/PersistentStorageUtilities/BackupOperationsUtilities/BackupOperationsExecutor.h index 4d92b23f3..9e8587eb6 100644 --- a/native/cpp/CommonCpp/NativeModules/PersistentStorageUtilities/BackupOperationsUtilities/BackupOperationsExecutor.h +++ b/native/cpp/CommonCpp/NativeModules/PersistentStorageUtilities/BackupOperationsUtilities/BackupOperationsExecutor.h @@ -1,16 +1,18 @@ #pragma once #include #include namespace comm { class BackupOperationsExecutor { public: static void createMainCompaction(std::string backupID, size_t futureID); static void restoreFromMainCompaction( std::string mainCompactionPath, std::string mainCompactionEncryptionKey, size_t futureID); - static void restoreFromBackupLog(const std::vector &backupLog); + static void restoreFromBackupLog( + const std::vector &backupLog, + size_t futureID); }; } // namespace comm diff --git a/native/cpp/CommonCpp/NativeModules/PersistentStorageUtilities/DataStores/BaseDataStore.h b/native/cpp/CommonCpp/NativeModules/PersistentStorageUtilities/DataStores/BaseDataStore.h index 424195685..1f847b850 100644 --- a/native/cpp/CommonCpp/NativeModules/PersistentStorageUtilities/DataStores/BaseDataStore.h +++ b/native/cpp/CommonCpp/NativeModules/PersistentStorageUtilities/DataStores/BaseDataStore.h @@ -1,107 +1,120 @@ #pragma once #include "DatabaseManager.h" #include "GlobalDBSingleton.h" #include "NativeModuleUtils.h" #include "WorkerThread.h" +#include "lib.rs.h" #include #include #include namespace comm { namespace jsi = facebook::jsi; using OperationType = const std::string; template class BaseDataStore { private: std::shared_ptr jsInvoker; public: BaseDataStore(std::shared_ptr _jsInvoker) : jsInvoker(_jsInvoker) { } virtual ~BaseDataStore(){}; virtual std::vector> createOperations(jsi::Runtime &rt, const jsi::Array &operations) const = 0; virtual jsi::Array parseDBDataStore( jsi::Runtime &rt, std::shared_ptr> dataVectorPtr) const = 0; jsi::Value processStoreOperations(jsi::Runtime &rt, jsi::Array &&operations) { std::string createOperationsError; std::shared_ptr>> storeOpsPtr; try { auto storeOps = createOperations(rt, operations); storeOpsPtr = std::make_shared>>( std::move(storeOps)); } catch (std::runtime_error &e) { createOperationsError = e.what(); } return facebook::react::createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=]() { std::string error = createOperationsError; if (!error.size()) { try { DatabaseManager::getQueryExecutor().beginTransaction(); for (const auto &operation : *storeOpsPtr) { operation->execute(); } DatabaseManager::getQueryExecutor().captureBackupLogs(); DatabaseManager::getQueryExecutor().commitTransaction(); } catch (std::system_error &e) { error = e.what(); DatabaseManager::getQueryExecutor().rollbackTransaction(); } } + if (!error.size()) { + ::triggerBackupFileUpload(); + } + this->jsInvoker->invokeAsync([=]() { if (error.size()) { promise->reject(error); } else { promise->resolve(jsi::Value::undefined()); } }); }; GlobalDBSingleton::instance.scheduleOrRunCancellable( job, promise, this->jsInvoker); }); } void processStoreOperationsSync(jsi::Runtime &rt, jsi::Array &&operations) { std::vector> storeOps; try { storeOps = createOperations(rt, operations); } catch (const std::exception &e) { throw jsi::JSError(rt, e.what()); } NativeModuleUtils::runSyncOrThrowJSError(rt, [&storeOps]() { + std::string error; + try { DatabaseManager::getQueryExecutor().beginTransaction(); for (const auto &operation : storeOps) { operation->execute(); } DatabaseManager::getQueryExecutor().captureBackupLogs(); DatabaseManager::getQueryExecutor().commitTransaction(); } catch (const std::exception &e) { + error = e.what(); DatabaseManager::getQueryExecutor().rollbackTransaction(); - throw e; + } + + if (error.size()) { + throw std::runtime_error(error); + } else if (!error.size()) { + ::triggerBackupFileUpload(); } }); } }; } // namespace comm diff --git a/native/native_rust_library/RustBackupExecutor.cpp b/native/native_rust_library/RustBackupExecutor.cpp index 77d7f0e4c..da6cee062 100644 --- a/native/native_rust_library/RustBackupExecutor.cpp +++ b/native/native_rust_library/RustBackupExecutor.cpp @@ -1,48 +1,49 @@ #include "RustBackupExecutor.h" #include "../cpp/CommonCpp/NativeModules/PersistentStorageUtilities/BackupOperationsUtilities/BackupOperationsExecutor.h" #include "../cpp/CommonCpp/Tools/PlatformSpecificTools.h" #include namespace comm { rust::String getBackupDirectoryPath() { return rust::String(PlatformSpecificTools::getBackupDirectoryPath()); } rust::String getBackupFilePath(rust::Str backupID, bool isAttachments) { return rust::String(PlatformSpecificTools::getBackupFilePath( std::string(backupID), isAttachments)); } rust::String getBackupLogFilePath(rust::Str backupID, rust::Str logID, bool isAttachments) { return rust::String(PlatformSpecificTools::getBackupLogFilePath( std::string(backupID), std::string(logID), isAttachments)); } rust::String getBackupUserKeysFilePath(rust::Str backupID) { return rust::String( PlatformSpecificTools::getBackupUserKeysFilePath(std::string(backupID))); } void createMainCompaction(rust::Str backupID, size_t futureID) { BackupOperationsExecutor::createMainCompaction( std::string(backupID), futureID); } void restoreFromMainCompaction( rust::Str mainCompactionPath, rust::Str mainCompactionEncryptionKey, size_t futureID) { BackupOperationsExecutor::restoreFromMainCompaction( std::string(mainCompactionPath), std::string(mainCompactionEncryptionKey), futureID); } -void restoreFromBackupLog(rust::Vec backupLog) { +void restoreFromBackupLog(rust::Vec backupLog, size_t futureID) { BackupOperationsExecutor::restoreFromBackupLog( - std::move(std::vector(backupLog.begin(), backupLog.end()))); + std::move(std::vector(backupLog.begin(), backupLog.end())), + futureID); } } // namespace comm diff --git a/native/native_rust_library/RustBackupExecutor.h b/native/native_rust_library/RustBackupExecutor.h index 0c17dccf1..1e1b64f7d 100644 --- a/native/native_rust_library/RustBackupExecutor.h +++ b/native/native_rust_library/RustBackupExecutor.h @@ -1,19 +1,19 @@ #pragma once #include "cxx.h" namespace comm { rust::String getBackupDirectoryPath(); rust::String getBackupFilePath(rust::Str backupID, bool isAttachments); rust::String getBackupLogFilePath(rust::Str backupID, rust::Str logID, bool isAttachments); rust::String getBackupUserKeysFilePath(rust::Str backupID); void createMainCompaction(rust::Str backupID, size_t futureID); void restoreFromMainCompaction( rust::Str mainCompactionPath, rust::Str mainCompactionEncryptionKey, size_t futureID); -void restoreFromBackupLog(rust::Vec backupLog); +void restoreFromBackupLog(rust::Vec backupLog, size_t futureID); } // namespace comm diff --git a/native/native_rust_library/src/backup.rs b/native/native_rust_library/src/backup.rs index 258bfda69..3ca1704d3 100644 --- a/native/native_rust_library/src/backup.rs +++ b/native/native_rust_library/src/backup.rs @@ -1,223 +1,264 @@ mod compaction_upload_promises; mod file_info; mod upload_handler; use crate::argon2_tools::{compute_backup_key, compute_backup_key_str}; use crate::constants::{aes, secure_store}; use crate::ffi::{ create_main_compaction, get_backup_directory_path, - get_backup_user_keys_file_path, restore_from_main_compaction, - secure_store_get, void_callback, + get_backup_user_keys_file_path, restore_from_backup_log, + restore_from_main_compaction, secure_store_get, void_callback, }; use crate::future_manager; use crate::BACKUP_SOCKET_ADDR; use crate::RUNTIME; use backup_client::{ BackupClient, BackupDescriptor, LatestBackupIDResponse, RequestedData, - UserIdentity, + TryStreamExt, UserIdentity, }; use serde::{Deserialize, Serialize}; use std::error::Error; use std::path::PathBuf; pub mod ffi { use super::*; pub use upload_handler::ffi::*; pub fn create_backup( backup_id: String, backup_secret: String, pickle_key: String, pickled_account: String, promise_id: u32, ) { compaction_upload_promises::insert(backup_id.clone(), promise_id); RUNTIME.spawn(async move { let result = create_userkeys_compaction( backup_id.clone(), backup_secret, pickle_key, pickled_account, ) .await .map_err(|err| err.to_string()); if let Err(err) = result { compaction_upload_promises::resolve(&backup_id, Err(err)); return; } let (future_id, future) = future_manager::new_future::<()>().await; create_main_compaction(&backup_id, future_id); if let Err(err) = future.await { compaction_upload_promises::resolve(&backup_id, Err(err)); tokio::spawn(upload_handler::compaction::cleanup_files(backup_id)); return; } trigger_backup_file_upload(); // The promise will be resolved when the backup is uploaded }); } pub fn restore_backup(backup_secret: String, promise_id: u32) { RUNTIME.spawn(async move { let result = download_backup(backup_secret) .await .map_err(|err| err.to_string()); let result = match result { Ok(result) => result, Err(error) => { void_callback(error, promise_id); return; } }; let (future_id, future) = future_manager::new_future::<()>().await; restore_from_main_compaction( &result.backup_restoration_path.to_string_lossy(), &result.backup_data_key, future_id, ); if let Err(error) = future.await { void_callback(error, promise_id); return; } + if let Err(error) = + download_and_apply_logs(&result.backup_id, result.backup_log_data_key) + .await + { + void_callback(error.to_string(), promise_id); + return; + } + void_callback(String::new(), promise_id); }); } } pub async fn create_userkeys_compaction( backup_id: String, backup_secret: String, pickle_key: String, pickled_account: String, ) -> Result<(), Box> { let mut backup_key = compute_backup_key(backup_secret.as_bytes(), backup_id.as_bytes())?; let backup_data_key = secure_store_get(secure_store::SECURE_STORE_ENCRYPTION_KEY_ID)?; + let backup_log_data_key = + secure_store_get(secure_store::SECURE_STORE_BACKUP_LOGS_ENCRYPTION_KEY_ID)?; + let user_keys = UserKeys { backup_data_key, + backup_log_data_key, pickle_key, pickled_account, }; let encrypted_user_keys = user_keys.encrypt(&mut backup_key)?; let user_keys_file = get_backup_user_keys_file_path(&backup_id)?; tokio::fs::write(user_keys_file, encrypted_user_keys).await?; Ok(()) } async fn download_backup( backup_secret: String, ) -> Result> { let backup_client = BackupClient::new(BACKUP_SOCKET_ADDR)?; let user_identity = get_user_identity_from_secure_store()?; let latest_backup_descriptor = BackupDescriptor::Latest { username: user_identity.user_id.clone(), }; let backup_id_response = backup_client .download_backup_data(&latest_backup_descriptor, RequestedData::BackupID) .await?; let LatestBackupIDResponse { backup_id } = serde_json::from_slice(&backup_id_response)?; let mut backup_key = compute_backup_key_str(&backup_secret, &backup_id)?; let mut encrypted_user_keys = backup_client .download_backup_data(&latest_backup_descriptor, RequestedData::UserKeys) .await?; let user_keys = UserKeys::from_encrypted(&mut encrypted_user_keys, &mut backup_key)?; let backup_data_descriptor = BackupDescriptor::BackupID { backup_id: backup_id.clone(), user_identity: user_identity.clone(), }; let encrypted_user_data = backup_client .download_backup_data(&backup_data_descriptor, RequestedData::UserData) .await?; let backup_restoration_path = PathBuf::from(get_backup_directory_path()?).join("restore_compaction"); tokio::fs::write(&backup_restoration_path, encrypted_user_data).await?; Ok(CompactionDownloadResult { backup_id, backup_restoration_path, backup_data_key: user_keys.backup_data_key, + backup_log_data_key: user_keys.backup_log_data_key, }) } +async fn download_and_apply_logs( + backup_id: &str, + backup_log_data_key: String, +) -> Result<(), Box> { + let mut backup_log_data_key = backup_log_data_key.into_bytes(); + + let backup_client = BackupClient::new(BACKUP_SOCKET_ADDR)?; + let user_identity = get_user_identity_from_secure_store()?; + + let stream = backup_client.download_logs(&user_identity, backup_id).await; + let mut stream = Box::pin(stream); + + while let Some(mut log) = stream.try_next().await? { + let data = decrypt( + backup_log_data_key.as_mut_slice(), + log.content.as_mut_slice(), + )?; + + let (future_id, future) = future_manager::new_future::<()>().await; + restore_from_backup_log(data, future_id); + future.await?; + } + + Ok(()) +} + fn get_user_identity_from_secure_store() -> Result { Ok(UserIdentity { user_id: secure_store_get(secure_store::USER_ID)?, access_token: secure_store_get(secure_store::COMM_SERVICES_ACCESS_TOKEN)?, device_id: secure_store_get(secure_store::DEVICE_ID)?, }) } struct CompactionDownloadResult { backup_id: String, backup_restoration_path: PathBuf, backup_data_key: String, + backup_log_data_key: String, } #[derive(Debug, Serialize, Deserialize)] struct UserKeys { backup_data_key: String, + backup_log_data_key: String, pickle_key: String, pickled_account: String, } impl UserKeys { fn encrypt(&self, backup_key: &mut [u8]) -> Result, Box> { let mut json = serde_json::to_vec(self)?; encrypt(backup_key, &mut json) } fn from_encrypted( data: &mut [u8], backup_key: &mut [u8], ) -> Result> { let decrypted = decrypt(backup_key, data)?; Ok(serde_json::from_slice(&decrypted)?) } } fn encrypt(key: &mut [u8], data: &mut [u8]) -> Result, Box> { let encrypted_len = data.len() + aes::IV_LENGTH + aes::TAG_LENGTH; let mut encrypted = vec![0; encrypted_len]; crate::ffi::encrypt(key, data, &mut encrypted)?; Ok(encrypted) } fn decrypt(key: &mut [u8], data: &mut [u8]) -> Result, Box> { let decrypted_len = data.len() - aes::IV_LENGTH - aes::TAG_LENGTH; let mut decrypted = vec![0; decrypted_len]; crate::ffi::decrypt(key, data, &mut decrypted)?; Ok(decrypted) } diff --git a/native/native_rust_library/src/constants.rs b/native/native_rust_library/src/constants.rs index 98e42f781..0bf2abac8 100644 --- a/native/native_rust_library/src/constants.rs +++ b/native/native_rust_library/src/constants.rs @@ -1,20 +1,22 @@ use std::time::Duration; #[allow(unused)] pub mod aes { pub const KEY_SIZE: usize = 32; // bytes pub const IV_LENGTH: usize = 12; // bytes - unique Initialization Vector (nonce) pub const TAG_LENGTH: usize = 16; // bytes - GCM auth tag } pub mod secure_store { /// Should match constants defined in `CommSecureStore.h` pub const COMM_SERVICES_ACCESS_TOKEN: &str = "accessToken"; pub const USER_ID: &str = "userID"; pub const DEVICE_ID: &str = "deviceID"; /// Should match constant defined in `SQLiteQueryExecutor.h` pub const SECURE_STORE_ENCRYPTION_KEY_ID: &str = "comm.encryptionKey"; + pub const SECURE_STORE_BACKUP_LOGS_ENCRYPTION_KEY_ID: &str = + "comm.backupLogsEncryptionKey"; } pub const BACKUP_SERVICE_CONNECTION_RETRY_DELAY: Duration = Duration::from_secs(5); diff --git a/native/native_rust_library/src/lib.rs b/native/native_rust_library/src/lib.rs index 8aa336cde..5f4094871 100644 --- a/native/native_rust_library/src/lib.rs +++ b/native/native_rust_library/src/lib.rs @@ -1,1505 +1,1504 @@ use backup::ffi::*; use comm_opaque2::client::{Login, Registration}; use comm_opaque2::grpc::opaque_error_to_grpc_status as handle_error; use exact_user_search::find_user_id_for_wallet_address; use ffi::{bool_callback, string_callback, void_callback}; use future_manager::ffi::*; use grpc_clients::identity::protos::auth::{ GetDeviceListRequest, UpdateDeviceListRequest, }; use grpc_clients::identity::protos::authenticated::{ InboundKeyInfo, InboundKeysForUserRequest, KeyserverKeysResponse, OutboundKeyInfo, OutboundKeysForUserRequest, RefreshUserPrekeysRequest, UpdateUserPasswordFinishRequest, UpdateUserPasswordStartRequest, UploadOneTimeKeysRequest, }; use grpc_clients::identity::protos::unauth::{ AuthResponse, DeviceKeyUpload, DeviceType, Empty, IdentityKeyInfo, OpaqueLoginFinishRequest, OpaqueLoginStartRequest, Prekey, RegistrationFinishRequest, RegistrationStartRequest, SecondaryDeviceKeysUploadRequest, WalletAuthRequest, }; use grpc_clients::identity::{ get_auth_client, get_unauthenticated_client, REQUEST_METADATA_COOKIE_KEY, RESPONSE_METADATA_COOKIE_KEY, }; use lazy_static::lazy_static; use serde::Serialize; use std::sync::Arc; use tokio::runtime::{Builder, Runtime}; use tonic::{Request, Status}; use tracing::instrument; use wallet_registration::register_wallet_user; mod argon2_tools; mod backup; mod constants; mod exact_user_search; mod future_manager; mod wallet_registration; use argon2_tools::compute_backup_key_str; mod generated { // We get the CODE_VERSION from this generated file include!(concat!(env!("OUT_DIR"), "/version.rs")); // We get the IDENTITY_SOCKET_ADDR from this generated file include!(concat!(env!("OUT_DIR"), "/socket_config.rs")); } pub use generated::CODE_VERSION; pub use generated::{BACKUP_SOCKET_ADDR, IDENTITY_SOCKET_ADDR}; #[cfg(not(target_os = "android"))] pub const DEVICE_TYPE: DeviceType = DeviceType::Ios; #[cfg(target_os = "android")] pub const DEVICE_TYPE: DeviceType = DeviceType::Android; lazy_static! { static ref RUNTIME: Arc = Arc::new(Builder::new_multi_thread().enable_all().build().unwrap()); } #[cxx::bridge] mod ffi { extern "Rust" { #[cxx_name = "identityRegisterPasswordUser"] fn register_password_user( username: String, password: String, key_payload: String, key_payload_signature: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, promise_id: u32, ); #[cxx_name = "identityLogInPasswordUser"] fn log_in_password_user( username: String, password: String, key_payload: String, key_payload_signature: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, promise_id: u32, ); #[cxx_name = "identityRegisterWalletUser"] fn register_wallet_user( siwe_message: String, siwe_signature: String, key_payload: String, key_payload_signature: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, promise_id: u32, ); #[cxx_name = "identityLogInWalletUser"] fn log_in_wallet_user( siwe_message: String, siwe_signature: String, key_payload: String, key_payload_signature: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, promise_id: u32, ); #[cxx_name = "identityUpdateUserPassword"] fn update_user_password( user_id: String, device_id: String, access_token: String, password: String, promise_id: u32, ); #[cxx_name = "identityDeleteUser"] fn delete_user( user_id: String, device_id: String, access_token: String, promise_id: u32, ); #[cxx_name = "identityGetOutboundKeysForUser"] fn get_outbound_keys_for_user( auth_user_id: String, auth_device_id: String, auth_access_token: String, user_id: String, promise_id: u32, ); #[cxx_name = "identityGetInboundKeysForUser"] fn get_inbound_keys_for_user( auth_user_id: String, auth_device_id: String, auth_access_token: String, user_id: String, promise_id: u32, ); #[cxx_name = "identityRefreshUserPrekeys"] fn refresh_user_prekeys( auth_user_id: String, auth_device_id: String, auth_access_token: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, promise_id: u32, ); #[cxx_name = "identityGenerateNonce"] fn generate_nonce(promise_id: u32); #[cxx_name = "identityVersionSupported"] fn version_supported(promise_id: u32); #[cxx_name = "identityUploadOneTimeKeys"] fn upload_one_time_keys( auth_user_id: String, auth_device_id: String, auth_access_token: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, promise_id: u32, ); #[cxx_name = "identityGetKeyserverKeys"] fn get_keyserver_keys( user_id: String, device_id: String, access_token: String, keyserver_id: String, promise_id: u32, ); #[cxx_name = "identityGetDeviceListForUser"] fn get_device_list_for_user( auth_user_id: String, auth_device_id: String, auth_access_token: String, user_id: String, since_timestamp: i64, promise_id: u32, ); #[cxx_name = "identityUpdateDeviceList"] fn update_device_list( auth_user_id: String, auth_device_id: String, auth_access_token: String, update_payload: String, promise_id: u32, ); #[cxx_name = "identityUploadSecondaryDeviceKeysAndLogIn"] fn upload_secondary_device_keys_and_log_in( user_id: String, challenge_response: String, key_payload: String, key_payload_signature: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, promise_id: u32, ); #[cxx_name = "identityFindUserIDForWalletAddress"] fn find_user_id_for_wallet_address(wallet_address: String, promise_id: u32); // Argon2 #[cxx_name = "compute_backup_key"] fn compute_backup_key_str( password: &str, backup_id: &str, ) -> Result<[u8; 32]>; } unsafe extern "C++" { include!("RustCallback.h"); #[namespace = "comm"] #[cxx_name = "stringCallback"] fn string_callback(error: String, promise_id: u32, ret: String); #[namespace = "comm"] #[cxx_name = "voidCallback"] fn void_callback(error: String, promise_id: u32); #[namespace = "comm"] #[cxx_name = "boolCallback"] fn bool_callback(error: String, promise_id: u32, ret: bool); } // AES cryptography #[namespace = "comm"] unsafe extern "C++" { include!("RustAESCrypto.h"); #[allow(unused)] #[cxx_name = "aesGenerateKey"] fn generate_key(buffer: &mut [u8]) -> Result<()>; /// The first two argument aren't mutated but creation of Java ByteBuffer /// requires the underlying bytes to be mutable. #[allow(unused)] #[cxx_name = "aesEncrypt"] fn encrypt( key: &mut [u8], plaintext: &mut [u8], sealed_data: &mut [u8], ) -> Result<()>; /// The first two argument aren't mutated but creation of Java ByteBuffer /// requires the underlying bytes to be mutable. #[allow(unused)] #[cxx_name = "aesDecrypt"] fn decrypt( key: &mut [u8], sealed_data: &mut [u8], plaintext: &mut [u8], ) -> Result<()>; } // Comm Services Auth Metadata Emission #[namespace = "comm"] unsafe extern "C++" { include!("RustCSAMetadataEmitter.h"); #[allow(unused)] #[cxx_name = "sendAuthMetadataToJS"] fn send_auth_metadata_to_js( access_token: String, user_id: String, ) -> Result<()>; } // Backup extern "Rust" { #[cxx_name = "startBackupHandler"] fn start_backup_handler() -> Result<()>; #[cxx_name = "stopBackupHandler"] fn stop_backup_handler() -> Result<()>; #[cxx_name = "triggerBackupFileUpload"] fn trigger_backup_file_upload(); #[cxx_name = "createBackup"] fn create_backup( backup_id: String, backup_secret: String, pickle_key: String, pickled_account: String, promise_id: u32, ); #[cxx_name = "restoreBackup"] fn restore_backup(backup_secret: String, promise_id: u32); } // Secure store #[namespace = "comm"] unsafe extern "C++" { include!("RustSecureStore.h"); #[allow(unused)] #[cxx_name = "secureStoreSet"] fn secure_store_set(key: &str, value: String) -> Result<()>; #[cxx_name = "secureStoreGet"] fn secure_store_get(key: &str) -> Result; } // C++ Backup creation #[namespace = "comm"] unsafe extern "C++" { include!("RustBackupExecutor.h"); #[cxx_name = "getBackupDirectoryPath"] fn get_backup_directory_path() -> Result; #[cxx_name = "getBackupFilePath"] fn get_backup_file_path( backup_id: &str, is_attachments: bool, ) -> Result; #[cxx_name = "getBackupLogFilePath"] fn get_backup_log_file_path( backup_id: &str, log_id: &str, is_attachments: bool, ) -> Result; #[cxx_name = "getBackupUserKeysFilePath"] fn get_backup_user_keys_file_path(backup_id: &str) -> Result; #[cxx_name = "createMainCompaction"] fn create_main_compaction(backup_id: &str, future_id: usize); #[cxx_name = "restoreFromMainCompaction"] fn restore_from_main_compaction( main_compaction_path: &str, main_compaction_encryption_key: &str, future_id: usize, ); - #[allow(unused)] #[cxx_name = "restoreFromBackupLog"] - fn restore_from_backup_log(backup_log: Vec) -> Result<()>; + fn restore_from_backup_log(backup_log: Vec, future_id: usize); } // Future handling from C++ extern "Rust" { #[cxx_name = "resolveUnitFuture"] fn resolve_unit_future(future_id: usize); #[cxx_name = "rejectFuture"] fn reject_future(future_id: usize, error: String); } } fn handle_string_result_as_callback( result: Result, promise_id: u32, ) where E: std::fmt::Display, { match result { Err(e) => string_callback(e.to_string(), promise_id, "".to_string()), Ok(r) => string_callback("".to_string(), promise_id, r), } } fn handle_void_result_as_callback(result: Result<(), E>, promise_id: u32) where E: std::fmt::Display, { match result { Err(e) => void_callback(e.to_string(), promise_id), Ok(_) => void_callback("".to_string(), promise_id), } } fn handle_bool_result_as_callback(result: Result, promise_id: u32) where E: std::fmt::Display, { match result { Err(e) => bool_callback(e.to_string(), promise_id, false), Ok(r) => bool_callback("".to_string(), promise_id, r), } } fn generate_nonce(promise_id: u32) { RUNTIME.spawn(async move { let result = fetch_nonce().await; handle_string_result_as_callback(result, promise_id); }); } async fn fetch_nonce() -> Result { let mut identity_client = get_unauthenticated_client( IDENTITY_SOCKET_ADDR, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; let nonce = identity_client .generate_nonce(Empty {}) .await? .into_inner() .nonce; Ok(nonce) } fn version_supported(promise_id: u32) { RUNTIME.spawn(async move { let result = version_supported_helper().await; handle_bool_result_as_callback(result, promise_id); }); } async fn version_supported_helper() -> Result { let mut identity_client = get_unauthenticated_client( IDENTITY_SOCKET_ADDR, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; let response = identity_client.ping(Empty {}).await; match response { Ok(_) => Ok(true), Err(e) => { if grpc_clients::error::is_version_unsupported(&e) { Ok(false) } else { Err(e.into()) } } } } fn get_keyserver_keys( user_id: String, device_id: String, access_token: String, keyserver_id: String, promise_id: u32, ) { RUNTIME.spawn(async move { let get_keyserver_keys_request = OutboundKeysForUserRequest { user_id: keyserver_id, }; let auth_info = AuthInfo { access_token, user_id, device_id, }; let result = get_keyserver_keys_helper(get_keyserver_keys_request, auth_info).await; handle_string_result_as_callback(result, promise_id); }); } async fn get_keyserver_keys_helper( get_keyserver_keys_request: OutboundKeysForUserRequest, auth_info: AuthInfo, ) -> Result { let mut identity_client = get_auth_client( IDENTITY_SOCKET_ADDR, auth_info.user_id, auth_info.device_id, auth_info.access_token, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; let response = identity_client .get_keyserver_keys(get_keyserver_keys_request) .await? .into_inner(); let keyserver_keys = OutboundKeyInfoResponse::try_from(response)?; Ok(serde_json::to_string(&keyserver_keys)?) } struct AuthInfo { user_id: String, device_id: String, access_token: String, } #[instrument] fn register_password_user( username: String, password: String, key_payload: String, key_payload_signature: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, promise_id: u32, ) { RUNTIME.spawn(async move { let password_user_info = PasswordUserInfo { username, password, key_payload, key_payload_signature, content_prekey, content_prekey_signature, notif_prekey, notif_prekey_signature, content_one_time_keys, notif_one_time_keys, }; let result = register_password_user_helper(password_user_info).await; handle_string_result_as_callback(result, promise_id); }); } struct PasswordUserInfo { username: String, password: String, key_payload: String, key_payload_signature: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, } #[derive(Serialize)] #[serde(rename_all = "camelCase")] struct UserIDAndDeviceAccessToken { #[serde(rename = "userID")] user_id: String, access_token: String, } impl From for UserIDAndDeviceAccessToken { fn from(value: AuthResponse) -> Self { let AuthResponse { user_id, access_token, } = value; Self { user_id, access_token, } } } async fn register_password_user_helper( password_user_info: PasswordUserInfo, ) -> Result { let mut client_registration = Registration::new(); let opaque_registration_request = client_registration .start(&password_user_info.password) .map_err(handle_error)?; let registration_start_request = RegistrationStartRequest { opaque_registration_request, username: password_user_info.username, device_key_upload: Some(DeviceKeyUpload { device_key_info: Some(IdentityKeyInfo { payload: password_user_info.key_payload, payload_signature: password_user_info.key_payload_signature, social_proof: None, }), content_upload: Some(Prekey { prekey: password_user_info.content_prekey, prekey_signature: password_user_info.content_prekey_signature, }), notif_upload: Some(Prekey { prekey: password_user_info.notif_prekey, prekey_signature: password_user_info.notif_prekey_signature, }), one_time_content_prekeys: password_user_info.content_one_time_keys, one_time_notif_prekeys: password_user_info.notif_one_time_keys, device_type: DEVICE_TYPE.into(), }), }; let mut identity_client = get_unauthenticated_client( IDENTITY_SOCKET_ADDR, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; let response = identity_client .register_password_user_start(registration_start_request) .await?; // We need to get the load balancer cookie from from the response and send it // in the subsequent request to ensure it is routed to the same identity // service instance as the first request let cookie = response .metadata() .get(RESPONSE_METADATA_COOKIE_KEY) .cloned(); let registration_start_response = response.into_inner(); let opaque_registration_upload = client_registration .finish( &password_user_info.password, ®istration_start_response.opaque_registration_response, ) .map_err(handle_error)?; let registration_finish_request = RegistrationFinishRequest { session_id: registration_start_response.session_id, opaque_registration_upload, }; let mut finish_request = Request::new(registration_finish_request); // Cookie won't be available in local dev environments if let Some(cookie_metadata) = cookie { finish_request .metadata_mut() .insert(REQUEST_METADATA_COOKIE_KEY, cookie_metadata); } let registration_finish_response = identity_client .register_password_user_finish(finish_request) .await? .into_inner(); let user_id_and_access_token = UserIDAndDeviceAccessToken::from(registration_finish_response); Ok(serde_json::to_string(&user_id_and_access_token)?) } #[instrument] fn log_in_password_user( username: String, password: String, key_payload: String, key_payload_signature: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, promise_id: u32, ) { RUNTIME.spawn(async move { let password_user_info = PasswordUserInfo { username, password, key_payload, key_payload_signature, content_prekey, content_prekey_signature, notif_prekey, notif_prekey_signature, content_one_time_keys, notif_one_time_keys, }; let result = log_in_password_user_helper(password_user_info).await; handle_string_result_as_callback(result, promise_id); }); } async fn log_in_password_user_helper( password_user_info: PasswordUserInfo, ) -> Result { let mut client_login = Login::new(); let opaque_login_request = client_login .start(&password_user_info.password) .map_err(handle_error)?; let login_start_request = OpaqueLoginStartRequest { opaque_login_request, username: password_user_info.username, device_key_upload: Some(DeviceKeyUpload { device_key_info: Some(IdentityKeyInfo { payload: password_user_info.key_payload, payload_signature: password_user_info.key_payload_signature, social_proof: None, }), content_upload: Some(Prekey { prekey: password_user_info.content_prekey, prekey_signature: password_user_info.content_prekey_signature, }), notif_upload: Some(Prekey { prekey: password_user_info.notif_prekey, prekey_signature: password_user_info.notif_prekey_signature, }), one_time_content_prekeys: password_user_info.content_one_time_keys, one_time_notif_prekeys: password_user_info.notif_one_time_keys, device_type: DEVICE_TYPE.into(), }), }; let mut identity_client = get_unauthenticated_client( IDENTITY_SOCKET_ADDR, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; let response = identity_client .log_in_password_user_start(login_start_request) .await?; // We need to get the load balancer cookie from from the response and send it // in the subsequent request to ensure it is routed to the same identity // service instance as the first request let cookie = response .metadata() .get(RESPONSE_METADATA_COOKIE_KEY) .cloned(); let login_start_response = response.into_inner(); let opaque_login_upload = client_login .finish(&login_start_response.opaque_login_response) .map_err(handle_error)?; let login_finish_request = OpaqueLoginFinishRequest { session_id: login_start_response.session_id, opaque_login_upload, }; let mut finish_request = Request::new(login_finish_request); // Cookie won't be available in local dev environments if let Some(cookie_metadata) = cookie { finish_request .metadata_mut() .insert(REQUEST_METADATA_COOKIE_KEY, cookie_metadata); } let login_finish_response = identity_client .log_in_password_user_finish(finish_request) .await? .into_inner(); let user_id_and_access_token = UserIDAndDeviceAccessToken::from(login_finish_response); Ok(serde_json::to_string(&user_id_and_access_token)?) } struct WalletUserInfo { siwe_message: String, siwe_signature: String, key_payload: String, key_payload_signature: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, } #[instrument] fn log_in_wallet_user( siwe_message: String, siwe_signature: String, key_payload: String, key_payload_signature: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, promise_id: u32, ) { RUNTIME.spawn(async move { let wallet_user_info = WalletUserInfo { siwe_message, siwe_signature, key_payload, key_payload_signature, content_prekey, content_prekey_signature, notif_prekey, notif_prekey_signature, content_one_time_keys, notif_one_time_keys, }; let result = log_in_wallet_user_helper(wallet_user_info).await; handle_string_result_as_callback(result, promise_id); }); } async fn log_in_wallet_user_helper( wallet_user_info: WalletUserInfo, ) -> Result { let login_request = WalletAuthRequest { siwe_message: wallet_user_info.siwe_message, siwe_signature: wallet_user_info.siwe_signature, device_key_upload: Some(DeviceKeyUpload { device_key_info: Some(IdentityKeyInfo { payload: wallet_user_info.key_payload, payload_signature: wallet_user_info.key_payload_signature, social_proof: None, // The SIWE message and signature are the social proof }), content_upload: Some(Prekey { prekey: wallet_user_info.content_prekey, prekey_signature: wallet_user_info.content_prekey_signature, }), notif_upload: Some(Prekey { prekey: wallet_user_info.notif_prekey, prekey_signature: wallet_user_info.notif_prekey_signature, }), one_time_content_prekeys: wallet_user_info.content_one_time_keys, one_time_notif_prekeys: wallet_user_info.notif_one_time_keys, device_type: DEVICE_TYPE.into(), }), }; let mut identity_client = get_unauthenticated_client( IDENTITY_SOCKET_ADDR, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; let login_response = identity_client .log_in_wallet_user(login_request) .await? .into_inner(); let user_id_and_access_token = UserIDAndDeviceAccessToken::from(login_response); Ok(serde_json::to_string(&user_id_and_access_token)?) } struct UpdatePasswordInfo { user_id: String, device_id: String, access_token: String, password: String, } fn update_user_password( user_id: String, device_id: String, access_token: String, password: String, promise_id: u32, ) { RUNTIME.spawn(async move { let update_password_info = UpdatePasswordInfo { access_token, user_id, device_id, password, }; let result = update_user_password_helper(update_password_info).await; handle_void_result_as_callback(result, promise_id); }); } async fn update_user_password_helper( update_password_info: UpdatePasswordInfo, ) -> Result<(), Error> { let mut client_registration = Registration::new(); let opaque_registration_request = client_registration .start(&update_password_info.password) .map_err(handle_error)?; let update_password_start_request = UpdateUserPasswordStartRequest { opaque_registration_request, }; let mut identity_client = get_auth_client( IDENTITY_SOCKET_ADDR, update_password_info.user_id, update_password_info.device_id, update_password_info.access_token, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; let response = identity_client .update_user_password_start(update_password_start_request) .await?; // We need to get the load balancer cookie from from the response and send it // in the subsequent request to ensure it is routed to the same identity // service instance as the first request let cookie = response .metadata() .get(RESPONSE_METADATA_COOKIE_KEY) .cloned(); let update_password_start_response = response.into_inner(); let opaque_registration_upload = client_registration .finish( &update_password_info.password, &update_password_start_response.opaque_registration_response, ) .map_err(handle_error)?; let update_password_finish_request = UpdateUserPasswordFinishRequest { session_id: update_password_start_response.session_id, opaque_registration_upload, }; let mut finish_request = Request::new(update_password_finish_request); // Cookie won't be available in local dev environments if let Some(cookie_metadata) = cookie { finish_request .metadata_mut() .insert(REQUEST_METADATA_COOKIE_KEY, cookie_metadata); } identity_client .update_user_password_finish(finish_request) .await?; Ok(()) } fn delete_user( user_id: String, device_id: String, access_token: String, promise_id: u32, ) { RUNTIME.spawn(async move { let auth_info = AuthInfo { access_token, user_id, device_id, }; let result = delete_user_helper(auth_info).await; handle_void_result_as_callback(result, promise_id); }); } async fn delete_user_helper(auth_info: AuthInfo) -> Result<(), Error> { let mut identity_client = get_auth_client( IDENTITY_SOCKET_ADDR, auth_info.user_id, auth_info.device_id, auth_info.access_token, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; identity_client.delete_user(Empty {}).await?; Ok(()) } struct GetOutboundKeysRequestInfo { user_id: String, } struct GetInboundKeysRequestInfo { user_id: String, } // This struct should not be altered without also updating // OutboundKeyInfoResponse in lib/types/identity-service-types.js #[derive(Serialize)] #[serde(rename_all = "camelCase")] struct OutboundKeyInfoResponse { pub payload: String, pub payload_signature: String, pub social_proof: Option, pub content_prekey: String, pub content_prekey_signature: String, pub notif_prekey: String, pub notif_prekey_signature: String, pub one_time_content_prekey: Option, pub one_time_notif_prekey: Option, } // This struct should not be altered without also updating // InboundKeyInfoResponse in lib/types/identity-service-types.js #[derive(Serialize)] #[serde(rename_all = "camelCase")] struct InboundKeyInfoResponse { pub payload: String, pub payload_signature: String, pub social_proof: Option, pub content_prekey: String, pub content_prekey_signature: String, pub notif_prekey: String, pub notif_prekey_signature: String, } impl TryFrom for OutboundKeyInfoResponse { type Error = Error; fn try_from(key_info: OutboundKeyInfo) -> Result { let identity_info = key_info.identity_info.ok_or(Error::MissingResponseData)?; let IdentityKeyInfo { payload, payload_signature, social_proof, } = identity_info; let content_prekey = key_info.content_prekey.ok_or(Error::MissingResponseData)?; let Prekey { prekey: content_prekey_value, prekey_signature: content_prekey_signature, } = content_prekey; let notif_prekey = key_info.notif_prekey.ok_or(Error::MissingResponseData)?; let Prekey { prekey: notif_prekey_value, prekey_signature: notif_prekey_signature, } = notif_prekey; let one_time_content_prekey = key_info.one_time_content_prekey; let one_time_notif_prekey = key_info.one_time_notif_prekey; Ok(Self { payload, payload_signature, social_proof, content_prekey: content_prekey_value, content_prekey_signature, notif_prekey: notif_prekey_value, notif_prekey_signature, one_time_content_prekey, one_time_notif_prekey, }) } } impl TryFrom for OutboundKeyInfoResponse { type Error = Error; fn try_from(response: KeyserverKeysResponse) -> Result { let key_info = response.keyserver_info.ok_or(Error::MissingResponseData)?; Self::try_from(key_info) } } fn get_outbound_keys_for_user( auth_user_id: String, auth_device_id: String, auth_access_token: String, user_id: String, promise_id: u32, ) { RUNTIME.spawn(async move { let get_outbound_keys_request_info = GetOutboundKeysRequestInfo { user_id }; let auth_info = AuthInfo { access_token: auth_access_token, user_id: auth_user_id, device_id: auth_device_id, }; let result = get_outbound_keys_for_user_helper( get_outbound_keys_request_info, auth_info, ) .await; handle_string_result_as_callback(result, promise_id); }); } async fn get_outbound_keys_for_user_helper( get_outbound_keys_request_info: GetOutboundKeysRequestInfo, auth_info: AuthInfo, ) -> Result { let mut identity_client = get_auth_client( IDENTITY_SOCKET_ADDR, auth_info.user_id, auth_info.device_id, auth_info.access_token, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; let response = identity_client .get_outbound_keys_for_user(OutboundKeysForUserRequest { user_id: get_outbound_keys_request_info.user_id, }) .await? .into_inner(); let outbound_key_info: Vec = response .devices .into_values() .map(OutboundKeyInfoResponse::try_from) .collect::, _>>()?; Ok(serde_json::to_string(&outbound_key_info)?) } impl TryFrom for InboundKeyInfoResponse { type Error = Error; fn try_from(key_info: InboundKeyInfo) -> Result { let identity_info = key_info.identity_info.ok_or(Error::MissingResponseData)?; let IdentityKeyInfo { payload, payload_signature, social_proof, } = identity_info; let content_prekey = key_info.content_prekey.ok_or(Error::MissingResponseData)?; let Prekey { prekey: content_prekey_value, prekey_signature: content_prekey_signature, } = content_prekey; let notif_prekey = key_info.notif_prekey.ok_or(Error::MissingResponseData)?; let Prekey { prekey: notif_prekey_value, prekey_signature: notif_prekey_signature, } = notif_prekey; Ok(Self { payload, payload_signature, social_proof, content_prekey: content_prekey_value, content_prekey_signature, notif_prekey: notif_prekey_value, notif_prekey_signature, }) } } fn get_inbound_keys_for_user( auth_user_id: String, auth_device_id: String, auth_access_token: String, user_id: String, promise_id: u32, ) { RUNTIME.spawn(async move { let get_inbound_keys_request_info = GetInboundKeysRequestInfo { user_id }; let auth_info = AuthInfo { access_token: auth_access_token, user_id: auth_user_id, device_id: auth_device_id, }; let result = get_inbound_keys_for_user_helper( get_inbound_keys_request_info, auth_info, ) .await; handle_string_result_as_callback(result, promise_id); }); } async fn get_inbound_keys_for_user_helper( get_inbound_keys_request_info: GetInboundKeysRequestInfo, auth_info: AuthInfo, ) -> Result { let mut identity_client = get_auth_client( IDENTITY_SOCKET_ADDR, auth_info.user_id, auth_info.device_id, auth_info.access_token, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; let response = identity_client .get_inbound_keys_for_user(InboundKeysForUserRequest { user_id: get_inbound_keys_request_info.user_id, }) .await? .into_inner(); let inbound_key_info: Vec = response .devices .into_values() .map(InboundKeyInfoResponse::try_from) .collect::, _>>()?; Ok(serde_json::to_string(&inbound_key_info)?) } fn refresh_user_prekeys( auth_user_id: String, auth_device_id: String, auth_access_token: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, promise_id: u32, ) { RUNTIME.spawn(async move { let refresh_request = RefreshUserPrekeysRequest { new_content_prekeys: Some(Prekey { prekey: content_prekey, prekey_signature: content_prekey_signature, }), new_notif_prekeys: Some(Prekey { prekey: notif_prekey, prekey_signature: notif_prekey_signature, }), }; let auth_info = AuthInfo { access_token: auth_access_token, user_id: auth_user_id, device_id: auth_device_id, }; let result = refresh_user_prekeys_helper(refresh_request, auth_info).await; handle_void_result_as_callback(result, promise_id); }); } async fn refresh_user_prekeys_helper( refresh_request: RefreshUserPrekeysRequest, auth_info: AuthInfo, ) -> Result<(), Error> { get_auth_client( IDENTITY_SOCKET_ADDR, auth_info.user_id, auth_info.device_id, auth_info.access_token, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await? .refresh_user_prekeys(refresh_request) .await?; Ok(()) } #[instrument] fn upload_one_time_keys( auth_user_id: String, auth_device_id: String, auth_access_token: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, promise_id: u32, ) { RUNTIME.spawn(async move { let upload_request = UploadOneTimeKeysRequest { content_one_time_prekeys: content_one_time_keys, notif_one_time_prekeys: notif_one_time_keys, }; let auth_info = AuthInfo { access_token: auth_access_token, user_id: auth_user_id, device_id: auth_device_id, }; let result = upload_one_time_keys_helper(auth_info, upload_request).await; handle_void_result_as_callback(result, promise_id); }); } async fn upload_one_time_keys_helper( auth_info: AuthInfo, upload_request: UploadOneTimeKeysRequest, ) -> Result<(), Error> { let mut identity_client = get_auth_client( IDENTITY_SOCKET_ADDR, auth_info.user_id, auth_info.device_id, auth_info.access_token, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; identity_client.upload_one_time_keys(upload_request).await?; Ok(()) } fn get_device_list_for_user( auth_user_id: String, auth_device_id: String, auth_access_token: String, user_id: String, since_timestamp: i64, promise_id: u32, ) { RUNTIME.spawn(async move { let auth_info = AuthInfo { access_token: auth_access_token, user_id: auth_user_id, device_id: auth_device_id, }; let since_timestamp = Option::from(since_timestamp).filter(|&t| t > 0); let result = get_device_list_for_user_helper(auth_info, user_id, since_timestamp) .await; handle_string_result_as_callback(result, promise_id); }); } async fn get_device_list_for_user_helper( auth_info: AuthInfo, user_id: String, since_timestamp: Option, ) -> Result { let mut identity_client = get_auth_client( IDENTITY_SOCKET_ADDR, auth_info.user_id, auth_info.device_id, auth_info.access_token, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; let response = identity_client .get_device_list_for_user(GetDeviceListRequest { user_id, since_timestamp, }) .await? .into_inner(); let payload = serde_json::to_string(&response.device_list_updates)?; Ok(payload) } fn update_device_list( auth_user_id: String, auth_device_id: String, auth_access_token: String, update_payload: String, promise_id: u32, ) { RUNTIME.spawn(async move { let auth_info = AuthInfo { access_token: auth_access_token, user_id: auth_user_id, device_id: auth_device_id, }; let result = update_device_list_helper(auth_info, update_payload).await; handle_void_result_as_callback(result, promise_id); }); } async fn update_device_list_helper( auth_info: AuthInfo, update_payload: String, ) -> Result<(), Error> { let mut identity_client = get_auth_client( IDENTITY_SOCKET_ADDR, auth_info.user_id, auth_info.device_id, auth_info.access_token, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; let update_request = UpdateDeviceListRequest { new_device_list: update_payload, }; identity_client .update_device_list_for_user(update_request) .await?; Ok(()) } fn upload_secondary_device_keys_and_log_in( user_id: String, challenge_response: String, key_payload: String, key_payload_signature: String, content_prekey: String, content_prekey_signature: String, notif_prekey: String, notif_prekey_signature: String, content_one_time_keys: Vec, notif_one_time_keys: Vec, promise_id: u32, ) { RUNTIME.spawn(async move { let device_key_upload = DeviceKeyUpload { device_key_info: Some(IdentityKeyInfo { payload: key_payload, payload_signature: key_payload_signature, social_proof: None, }), content_upload: Some(Prekey { prekey: content_prekey, prekey_signature: content_prekey_signature, }), notif_upload: Some(Prekey { prekey: notif_prekey, prekey_signature: notif_prekey_signature, }), one_time_content_prekeys: content_one_time_keys, one_time_notif_prekeys: notif_one_time_keys, device_type: DEVICE_TYPE.into(), }; let result = upload_secondary_device_keys_and_log_in_helper( user_id, challenge_response, device_key_upload, ) .await; handle_string_result_as_callback(result, promise_id); }); } async fn upload_secondary_device_keys_and_log_in_helper( user_id: String, challenge_response: String, device_key_upload: DeviceKeyUpload, ) -> Result { let mut identity_client = get_unauthenticated_client( IDENTITY_SOCKET_ADDR, CODE_VERSION, DEVICE_TYPE.as_str_name().to_lowercase(), ) .await?; let request = SecondaryDeviceKeysUploadRequest { user_id, challenge_response, device_key_upload: Some(device_key_upload), }; let response = identity_client .upload_keys_for_registered_device_and_log_in(request) .await? .into_inner(); let user_id_and_access_token = UserIDAndDeviceAccessToken::from(response); Ok(serde_json::to_string(&user_id_and_access_token)?) } #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum Error { #[display(fmt = "{}", "_0.message()")] TonicGRPC(Status), #[display(fmt = "{}", "_0")] SerdeJson(serde_json::Error), #[display(fmt = "Missing response data")] MissingResponseData, #[display(fmt = "{}", "_0")] GRPClient(grpc_clients::error::Error), } #[cfg(test)] mod tests { use super::{BACKUP_SOCKET_ADDR, CODE_VERSION, IDENTITY_SOCKET_ADDR}; #[test] fn test_code_version_exists() { assert!(CODE_VERSION > 0); } #[test] fn test_identity_socket_addr_exists() { assert!(IDENTITY_SOCKET_ADDR.len() > 0); assert!(BACKUP_SOCKET_ADDR.len() > 0); } }