diff --git a/native/backup/restore-siwe-backup.react.js b/native/backup/restore-siwe-backup.react.js
index 623daba2c..f4e10d58f 100644
--- a/native/backup/restore-siwe-backup.react.js
+++ b/native/backup/restore-siwe-backup.react.js
@@ -1,78 +1,83 @@
// @flow
import * as React from 'react';
import { Alert } from 'react-native';
import { SafeAreaView } from 'react-native-safe-area-context';
import { type SIWEResult } from 'lib/types/siwe-types.js';
import { getMessageForException } from 'lib/utils/errors.js';
import { SignSIWEBackupMessageForRestore } from '../account/registration/siwe-backup-message-creation.react.js';
import { commCoreModule } from '../native-modules.js';
import { type RootNavigationProp } from '../navigation/root-navigator.react.js';
import { type NavigationRoute } from '../navigation/route-names.js';
+import { persistConfig } from '../redux/persist.js';
import { useStyles } from '../themes/colors.js';
export type RestoreSIWEBackupParams = {
+backupID: string,
+siweNonce: string,
+siweStatement: string,
+siweIssuedAt: string,
};
type Props = {
+navigation: RootNavigationProp<'RestoreSIWEBackup'>,
+route: NavigationRoute<'RestoreSIWEBackup'>,
};
function RestoreSIWEBackup(props: Props): React.Node {
const styles = useStyles(unboundStyles);
const { goBack } = props.navigation;
const { route } = props;
const {
params: { backupID, siweStatement, siweIssuedAt, siweNonce },
} = route;
const onSuccessfulWalletSignature = React.useCallback(
(result: SIWEResult) => {
void (async () => {
const { signature } = result;
let message = 'success';
try {
- await commCoreModule.restoreSIWEBackup(signature, backupID);
+ await commCoreModule.restoreSIWEBackup(
+ signature,
+ backupID,
+ persistConfig.version.toString(),
+ );
} catch (e) {
message = `Backup restore error: ${String(
getMessageForException(e),
)}`;
console.error(message);
}
Alert.alert('Restore protocol result', message);
goBack();
})();
},
[goBack, backupID],
);
return (
);
}
const safeAreaEdges = ['top'];
const unboundStyles = {
container: {
flex: 1,
backgroundColor: 'panelBackground',
justifyContent: 'space-between',
},
};
export default RestoreSIWEBackup;
diff --git a/native/backup/use-client-backup.js b/native/backup/use-client-backup.js
index c1939a6ef..e56462482 100644
--- a/native/backup/use-client-backup.js
+++ b/native/backup/use-client-backup.js
@@ -1,146 +1,150 @@
// @flow
import * as React from 'react';
import { isLoggedIn } from 'lib/selectors/user-selectors.js';
import { accountHasPassword } from 'lib/shared/account-utils.js';
import type { SIWEBackupSecrets } from 'lib/types/siwe-types.js';
import { getContentSigningKey } from 'lib/utils/crypto-utils.js';
import { fetchNativeKeychainCredentials } from '../account/native-credentials.js';
import { commCoreModule } from '../native-modules.js';
+import { persistConfig } from '../redux/persist.js';
import { useSelector } from '../redux/redux-utils.js';
type SIWEBackupData = {
+backupID: string,
+siweBackupMsg: string,
+siweBackupMsgNonce: string,
+siweBackupMsgStatement: string,
+siweBackupMsgIssuedAt: string,
};
type ClientBackup = {
+uploadBackupProtocol: () => Promise,
+restorePasswordUserBackupProtocol: () => Promise,
+retrieveLatestSIWEBackupData: () => Promise,
};
async function getBackupSecret(): Promise {
const nativeCredentials = await fetchNativeKeychainCredentials();
if (!nativeCredentials) {
throw new Error('Native credentials are missing');
}
return nativeCredentials.password;
}
async function getSIWEBackupSecrets(): Promise {
const siweBackupSecrets = await commCoreModule.getSIWEBackupSecrets();
if (!siweBackupSecrets) {
throw new Error('SIWE backup message and its signature are missing');
}
return siweBackupSecrets;
}
function useClientBackup(): ClientBackup {
const accessToken = useSelector(state => state.commServicesAccessToken);
const currentUserID = useSelector(
state => state.currentUserInfo && state.currentUserInfo.id,
);
const currentUserInfo = useSelector(state => state.currentUserInfo);
const loggedIn = useSelector(isLoggedIn);
const setMockCommServicesAuthMetadata = React.useCallback(async () => {
if (!currentUserID) {
return;
}
const ed25519 = await getContentSigningKey();
await commCoreModule.setCommServicesAuthMetadata(
currentUserID,
ed25519,
accessToken ? accessToken : '',
);
}, [accessToken, currentUserID]);
const uploadBackupProtocol = React.useCallback(async () => {
if (!loggedIn || !currentUserID) {
throw new Error('Attempt to upload backup for not logged in user.');
}
console.info('Start uploading backup...');
await setMockCommServicesAuthMetadata();
if (accountHasPassword(currentUserInfo)) {
const backupSecret = await getBackupSecret();
await commCoreModule.createNewBackup(backupSecret);
} else {
const { message, signature } = await getSIWEBackupSecrets();
await commCoreModule.createNewSIWEBackup(signature, message);
}
console.info('Backup uploaded.');
}, [
currentUserID,
loggedIn,
setMockCommServicesAuthMetadata,
currentUserInfo,
]);
const restorePasswordUserBackupProtocol = React.useCallback(async () => {
if (!loggedIn || !currentUserID) {
throw new Error('Attempt to restore backup for not logged in user.');
}
if (!accountHasPassword(currentUserInfo)) {
throw new Error(
'Attempt to restore from password for non-password user.',
);
}
console.info('Start restoring backup...');
await setMockCommServicesAuthMetadata();
const backupSecret = await getBackupSecret();
- await commCoreModule.restoreBackup(backupSecret);
+ await commCoreModule.restoreBackup(
+ backupSecret,
+ persistConfig.version.toString(),
+ );
console.info('Backup restored.');
return;
}, [
currentUserID,
loggedIn,
setMockCommServicesAuthMetadata,
currentUserInfo,
]);
const retrieveLatestSIWEBackupData = React.useCallback(async () => {
if (!loggedIn || !currentUserID) {
throw new Error('Attempt to restore backup for not logged in user.');
}
if (accountHasPassword(currentUserInfo)) {
throw new Error(
'Attempt to retrieve siwe backup data for password user.',
);
}
await setMockCommServicesAuthMetadata();
const serializedBackupData =
await commCoreModule.retrieveLatestSIWEBackupData();
const siweBackupData: SIWEBackupData = JSON.parse(serializedBackupData);
return siweBackupData;
}, [
currentUserID,
currentUserInfo,
loggedIn,
setMockCommServicesAuthMetadata,
]);
return {
uploadBackupProtocol,
restorePasswordUserBackupProtocol,
retrieveLatestSIWEBackupData,
};
}
export { getBackupSecret, useClientBackup };
diff --git a/native/cpp/CommonCpp/DatabaseManagers/DatabaseQueryExecutor.h b/native/cpp/CommonCpp/DatabaseManagers/DatabaseQueryExecutor.h
index 06ed84997..bc8d60c1d 100644
--- a/native/cpp/CommonCpp/DatabaseManagers/DatabaseQueryExecutor.h
+++ b/native/cpp/CommonCpp/DatabaseManagers/DatabaseQueryExecutor.h
@@ -1,175 +1,176 @@
#pragma once
#include "../CryptoTools/Persist.h"
#include "entities/AuxUserInfo.h"
#include "entities/CommunityInfo.h"
#include "entities/Draft.h"
#include "entities/EntryInfo.h"
#include "entities/InboundP2PMessage.h"
#include "entities/IntegrityThreadHash.h"
#include "entities/KeyserverInfo.h"
#include "entities/Message.h"
#include "entities/MessageStoreThread.h"
#include "entities/OlmPersistAccount.h"
#include "entities/OlmPersistSession.h"
#include "entities/OutboundP2PMessage.h"
#include "entities/PersistItem.h"
#include "entities/Report.h"
#include "entities/SyncedMetadataEntry.h"
#include "entities/Thread.h"
#include "entities/ThreadActivityEntry.h"
#include "entities/UserInfo.h"
#include
namespace comm {
/**
* if any initialization/cleaning up steps are required for specific
* database managers they should appear in constructors/destructors
* following the RAII pattern
*/
class DatabaseQueryExecutor {
public:
virtual std::string getDraft(std::string key) const = 0;
virtual std::unique_ptr getThread(std::string threadID) const = 0;
virtual void updateDraft(std::string key, std::string text) const = 0;
virtual bool moveDraft(std::string oldKey, std::string newKey) const = 0;
virtual std::vector getAllDrafts() const = 0;
virtual void removeAllDrafts() const = 0;
virtual void removeDrafts(const std::vector &ids) const = 0;
virtual void removeAllMessages() const = 0;
virtual std::vector>>
getAllMessages() const = 0;
virtual void removeMessages(const std::vector &ids) const = 0;
virtual void
removeMessagesForThreads(const std::vector &threadIDs) const = 0;
virtual void replaceMessage(const Message &message) const = 0;
virtual void rekeyMessage(std::string from, std::string to) const = 0;
virtual void removeAllMedia() const = 0;
virtual void replaceMessageStoreThreads(
const std::vector &threads) const = 0;
virtual void
removeMessageStoreThreads(const std::vector &ids) const = 0;
virtual void removeAllMessageStoreThreads() const = 0;
virtual std::vector getAllMessageStoreThreads() const = 0;
virtual void
removeMediaForMessages(const std::vector &msg_ids) const = 0;
virtual void removeMediaForMessage(std::string msg_id) const = 0;
virtual void
removeMediaForThreads(const std::vector &thread_ids) const = 0;
virtual void replaceMedia(const Media &media) const = 0;
virtual void rekeyMediaContainers(std::string from, std::string to) const = 0;
virtual std::vector getAllThreads() const = 0;
virtual void removeThreads(std::vector ids) const = 0;
virtual void replaceThread(const Thread &thread) const = 0;
virtual void removeAllThreads() const = 0;
virtual void replaceReport(const Report &report) const = 0;
virtual void removeReports(const std::vector &ids) const = 0;
virtual void removeAllReports() const = 0;
virtual std::vector getAllReports() const = 0;
virtual void
setPersistStorageItem(std::string key, std::string item) const = 0;
virtual void removePersistStorageItem(std::string key) const = 0;
virtual std::string getPersistStorageItem(std::string key) const = 0;
virtual void replaceUser(const UserInfo &user_info) const = 0;
virtual void removeUsers(const std::vector &ids) const = 0;
virtual void removeAllUsers() const = 0;
virtual std::vector getAllUsers() const = 0;
virtual void replaceKeyserver(const KeyserverInfo &keyserver_info) const = 0;
virtual void removeKeyservers(const std::vector &ids) const = 0;
virtual void removeAllKeyservers() const = 0;
virtual std::vector getAllKeyservers() const = 0;
virtual void replaceCommunity(const CommunityInfo &community_info) const = 0;
virtual void removeCommunities(const std::vector &ids) const = 0;
virtual void removeAllCommunities() const = 0;
virtual std::vector getAllCommunities() const = 0;
virtual void replaceIntegrityThreadHashes(
const std::vector &thread_hashes) const = 0;
virtual void
removeIntegrityThreadHashes(const std::vector &ids) const = 0;
virtual void removeAllIntegrityThreadHashes() const = 0;
virtual std::vector
getAllIntegrityThreadHashes() const = 0;
virtual void replaceSyncedMetadataEntry(
const SyncedMetadataEntry &synced_metadata_entry) const = 0;
virtual void
removeSyncedMetadata(const std::vector &names) const = 0;
virtual void removeAllSyncedMetadata() const = 0;
virtual std::vector getAllSyncedMetadata() const = 0;
virtual void replaceAuxUserInfo(const AuxUserInfo &aux_user_info) const = 0;
virtual void
removeAuxUserInfos(const std::vector &ids) const = 0;
virtual void removeAllAuxUserInfos() const = 0;
virtual std::vector getAllAuxUserInfos() const = 0;
virtual void replaceThreadActivityEntry(
const ThreadActivityEntry &thread_activity_entry) const = 0;
virtual void
removeThreadActivityEntries(const std::vector &ids) const = 0;
virtual void removeAllThreadActivityEntries() const = 0;
virtual std::vector
getAllThreadActivityEntries() const = 0;
virtual void replaceEntry(const EntryInfo &entry_info) const = 0;
virtual void removeEntries(const std::vector &ids) const = 0;
virtual void removeAllEntries() const = 0;
virtual std::vector getAllEntries() const = 0;
virtual void beginTransaction() const = 0;
virtual void commitTransaction() const = 0;
virtual void rollbackTransaction() const = 0;
virtual int getContentAccountID() const = 0;
virtual int getNotifsAccountID() const = 0;
virtual std::vector getOlmPersistSessionsData() const = 0;
virtual std::optional
getOlmPersistAccountData(int accountID) const = 0;
virtual void
storeOlmPersistSession(const OlmPersistSession &session) const = 0;
virtual void storeOlmPersistAccount(
int accountID,
const std::string &accountData) const = 0;
virtual void
storeOlmPersistData(int accountID, crypto::Persist persist) const = 0;
virtual void setNotifyToken(std::string token) const = 0;
virtual void clearNotifyToken() const = 0;
virtual void stampSQLiteDBUserID(std::string userID) const = 0;
virtual std::string getSQLiteStampedUserID() const = 0;
virtual void setMetadata(std::string entry_name, std::string data) const = 0;
virtual void clearMetadata(std::string entry_name) const = 0;
virtual std::string getMetadata(std::string entry_name) const = 0;
virtual void restoreFromMainCompaction(
std::string mainCompactionPath,
- std::string mainCompactionEncryptionKey) const = 0;
+ std::string mainCompactionEncryptionKey,
+ std::string maxVersion) const = 0;
virtual void
restoreFromBackupLog(const std::vector &backupLog) const = 0;
virtual void addOutboundP2PMessages(
const std::vector &messages) const = 0;
virtual std::vector getAllOutboundP2PMessages() const = 0;
virtual void removeOutboundP2PMessagesOlderThan(
std::string lastConfirmedMessageID,
std::string deviceID) const = 0;
virtual void
removeAllOutboundP2PMessages(const std::string &deviceID) const = 0;
virtual void setCiphertextForOutboundP2PMessage(
std::string messageID,
std::string deviceID,
std::string ciphertext) const = 0;
virtual void markOutboundP2PMessageAsSent(
std::string messageID,
std::string deviceID) const = 0;
virtual void addInboundP2PMessage(InboundP2PMessage message) const = 0;
virtual std::vector getAllInboundP2PMessage() const = 0;
virtual void
removeInboundP2PMessages(const std::vector &ids) const = 0;
#ifdef EMSCRIPTEN
virtual std::vector getAllThreadsWeb() const = 0;
virtual void replaceThreadWeb(const WebThread &thread) const = 0;
virtual std::vector getAllMessagesWeb() const = 0;
virtual void replaceMessageWeb(const WebMessage &message) const = 0;
virtual NullableString getOlmPersistAccountDataWeb(int accountID) const = 0;
#else
virtual void createMainCompaction(std::string backupID) const = 0;
virtual void captureBackupLogs() const = 0;
#endif
};
} // namespace comm
diff --git a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp
index c57a4b6a4..c8fbf1e89 100644
--- a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp
+++ b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.cpp
@@ -1,2652 +1,2678 @@
#include "SQLiteQueryExecutor.h"
#include "Logger.h"
#include "entities/CommunityInfo.h"
#include "entities/EntityQueryHelpers.h"
#include "entities/EntryInfo.h"
#include "entities/IntegrityThreadHash.h"
#include "entities/KeyserverInfo.h"
#include "entities/Metadata.h"
#include "entities/SyncedMetadataEntry.h"
#include "entities/UserInfo.h"
#include
#include
#include
#ifndef EMSCRIPTEN
#include "../CryptoTools/CryptoModule.h"
#include "CommSecureStore.h"
#include "PlatformSpecificTools.h"
#include "StaffUtils.h"
#endif
const int CONTENT_ACCOUNT_ID = 1;
const int NOTIFS_ACCOUNT_ID = 2;
namespace comm {
std::string SQLiteQueryExecutor::sqliteFilePath;
std::string SQLiteQueryExecutor::encryptionKey;
std::once_flag SQLiteQueryExecutor::initialized;
int SQLiteQueryExecutor::sqlcipherEncryptionKeySize = 64;
// Should match constant defined in `native_rust_library/src/constants.rs`
std::string SQLiteQueryExecutor::secureStoreEncryptionKeyID =
"comm.encryptionKey";
int SQLiteQueryExecutor::backupLogsEncryptionKeySize = 32;
std::string SQLiteQueryExecutor::secureStoreBackupLogsEncryptionKeyID =
"comm.backupLogsEncryptionKey";
std::string SQLiteQueryExecutor::backupLogsEncryptionKey;
#ifndef EMSCRIPTEN
NativeSQLiteConnectionManager SQLiteQueryExecutor::connectionManager;
std::unordered_set SQLiteQueryExecutor::backedUpTablesBlocklist = {
"olm_persist_account",
"olm_persist_sessions",
"metadata",
"outbound_p2p_messages",
"received_messages_to_device",
"integrity_store",
"persist_storage",
"keyservers",
};
#else
SQLiteConnectionManager SQLiteQueryExecutor::connectionManager;
#endif
bool create_table(sqlite3 *db, std::string query, std::string tableName) {
char *error;
sqlite3_exec(db, query.c_str(), nullptr, nullptr, &error);
if (!error) {
return true;
}
std::ostringstream stringStream;
stringStream << "Error creating '" << tableName << "' table: " << error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
bool create_drafts_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS drafts (threadID TEXT UNIQUE PRIMARY KEY, "
"text TEXT);";
return create_table(db, query, "drafts");
}
bool rename_threadID_to_key(sqlite3 *db) {
sqlite3_stmt *key_column_stmt;
sqlite3_prepare_v2(
db,
"SELECT name AS col_name FROM pragma_table_xinfo ('drafts') WHERE "
"col_name='key';",
-1,
&key_column_stmt,
nullptr);
sqlite3_step(key_column_stmt);
auto num_bytes = sqlite3_column_bytes(key_column_stmt, 0);
sqlite3_finalize(key_column_stmt);
if (num_bytes) {
return true;
}
char *error;
sqlite3_exec(
db,
"ALTER TABLE drafts RENAME COLUMN `threadID` TO `key`;",
nullptr,
nullptr,
&error);
if (error) {
std::ostringstream stringStream;
stringStream << "Error occurred renaming threadID column in drafts table "
<< "to key: " << error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
return true;
}
bool create_persist_account_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS olm_persist_account("
"id INTEGER UNIQUE PRIMARY KEY NOT NULL, "
"account_data TEXT NOT NULL);";
return create_table(db, query, "olm_persist_account");
}
bool create_persist_sessions_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS olm_persist_sessions("
"target_user_id TEXT UNIQUE PRIMARY KEY NOT NULL, "
"session_data TEXT NOT NULL);";
return create_table(db, query, "olm_persist_sessions");
}
bool drop_messages_table(sqlite3 *db) {
char *error;
sqlite3_exec(db, "DROP TABLE IF EXISTS messages;", nullptr, nullptr, &error);
if (!error) {
return true;
}
std::ostringstream stringStream;
stringStream << "Error dropping 'messages' table: " << error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
bool recreate_messages_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS messages ( "
"id TEXT UNIQUE PRIMARY KEY NOT NULL, "
"local_id TEXT, "
"thread TEXT NOT NULL, "
"user TEXT NOT NULL, "
"type INTEGER NOT NULL, "
"future_type INTEGER, "
"content TEXT, "
"time INTEGER NOT NULL);";
return create_table(db, query, "messages");
}
bool create_messages_idx_thread_time(sqlite3 *db) {
char *error;
sqlite3_exec(
db,
"CREATE INDEX IF NOT EXISTS messages_idx_thread_time "
"ON messages (thread, time);",
nullptr,
nullptr,
&error);
if (!error) {
return true;
}
std::ostringstream stringStream;
stringStream << "Error creating (thread, time) index on messages table: "
<< error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
bool create_media_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS media ( "
"id TEXT UNIQUE PRIMARY KEY NOT NULL, "
"container TEXT NOT NULL, "
"thread TEXT NOT NULL, "
"uri TEXT NOT NULL, "
"type TEXT NOT NULL, "
"extras TEXT NOT NULL);";
return create_table(db, query, "media");
}
bool create_media_idx_container(sqlite3 *db) {
char *error;
sqlite3_exec(
db,
"CREATE INDEX IF NOT EXISTS media_idx_container "
"ON media (container);",
nullptr,
nullptr,
&error);
if (!error) {
return true;
}
std::ostringstream stringStream;
stringStream << "Error creating (container) index on media table: " << error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
bool create_threads_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS threads ( "
"id TEXT UNIQUE PRIMARY KEY NOT NULL, "
"type INTEGER NOT NULL, "
"name TEXT, "
"description TEXT, "
"color TEXT NOT NULL, "
"creation_time BIGINT NOT NULL, "
"parent_thread_id TEXT, "
"containing_thread_id TEXT, "
"community TEXT, "
"members TEXT NOT NULL, "
"roles TEXT NOT NULL, "
"current_user TEXT NOT NULL, "
"source_message_id TEXT, "
"replies_count INTEGER NOT NULL);";
return create_table(db, query, "threads");
}
bool update_threadID_for_pending_threads_in_drafts(sqlite3 *db) {
char *error;
sqlite3_exec(
db,
"UPDATE drafts SET key = "
"REPLACE(REPLACE(REPLACE(REPLACE(key, 'type4/', ''),"
"'type5/', ''),'type6/', ''),'type7/', '')"
"WHERE key LIKE 'pending/%'",
nullptr,
nullptr,
&error);
if (!error) {
return true;
}
std::ostringstream stringStream;
stringStream << "Error update pending threadIDs on drafts table: " << error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
bool enable_write_ahead_logging_mode(sqlite3 *db) {
char *error;
sqlite3_exec(db, "PRAGMA journal_mode=wal;", nullptr, nullptr, &error);
if (!error) {
return true;
}
std::ostringstream stringStream;
stringStream << "Error enabling write-ahead logging mode: " << error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
bool create_metadata_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS metadata ( "
"name TEXT UNIQUE PRIMARY KEY NOT NULL, "
"data TEXT);";
return create_table(db, query, "metadata");
}
bool add_not_null_constraint_to_drafts(sqlite3 *db) {
char *error;
sqlite3_exec(
db,
"CREATE TABLE IF NOT EXISTS temporary_drafts ("
"key TEXT UNIQUE PRIMARY KEY NOT NULL, "
"text TEXT NOT NULL);"
"INSERT INTO temporary_drafts SELECT * FROM drafts "
"WHERE key IS NOT NULL AND text IS NOT NULL;"
"DROP TABLE drafts;"
"ALTER TABLE temporary_drafts RENAME TO drafts;",
nullptr,
nullptr,
&error);
if (!error) {
return true;
}
std::ostringstream stringStream;
stringStream << "Error adding NOT NULL constraint to drafts table: " << error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
bool add_not_null_constraint_to_metadata(sqlite3 *db) {
char *error;
sqlite3_exec(
db,
"CREATE TABLE IF NOT EXISTS temporary_metadata ("
"name TEXT UNIQUE PRIMARY KEY NOT NULL, "
"data TEXT NOT NULL);"
"INSERT INTO temporary_metadata SELECT * FROM metadata "
"WHERE data IS NOT NULL;"
"DROP TABLE metadata;"
"ALTER TABLE temporary_metadata RENAME TO metadata;",
nullptr,
nullptr,
&error);
if (!error) {
return true;
}
std::ostringstream stringStream;
stringStream << "Error adding NOT NULL constraint to metadata table: "
<< error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
bool add_avatar_column_to_threads_table(sqlite3 *db) {
char *error;
sqlite3_exec(
db,
"ALTER TABLE threads ADD COLUMN avatar TEXT;",
nullptr,
nullptr,
&error);
if (!error) {
return true;
}
std::ostringstream stringStream;
stringStream << "Error adding avatar column to threads table: " << error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
bool add_pinned_count_column_to_threads(sqlite3 *db) {
sqlite3_stmt *pinned_column_stmt;
sqlite3_prepare_v2(
db,
"SELECT name AS col_name FROM pragma_table_xinfo ('threads') WHERE "
"col_name='pinned_count';",
-1,
&pinned_column_stmt,
nullptr);
sqlite3_step(pinned_column_stmt);
auto num_bytes = sqlite3_column_bytes(pinned_column_stmt, 0);
sqlite3_finalize(pinned_column_stmt);
if (num_bytes) {
return true;
}
char *error;
sqlite3_exec(
db,
"ALTER TABLE threads ADD COLUMN pinned_count INTEGER NOT NULL DEFAULT 0;",
nullptr,
nullptr,
&error);
if (!error) {
return true;
}
std::ostringstream stringStream;
stringStream << "Error adding pinned_count column to threads table: "
<< error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
bool create_message_store_threads_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS message_store_threads ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" start_reached INTEGER NOT NULL,"
" last_navigated_to BIGINT NOT NULL,"
" last_pruned BIGINT NOT NULL"
");";
return create_table(db, query, "message_store_threads");
}
bool create_reports_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS reports ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" report TEXT NOT NULL"
");";
return create_table(db, query, "reports");
}
bool create_persist_storage_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS persist_storage ("
" key TEXT UNIQUE PRIMARY KEY NOT NULL,"
" item TEXT NOT NULL"
");";
return create_table(db, query, "persist_storage");
}
bool recreate_message_store_threads_table(sqlite3 *db) {
char *errMsg = 0;
// 1. Create table without `last_navigated_to` or `last_pruned`.
std::string create_new_table_query =
"CREATE TABLE IF NOT EXISTS temp_message_store_threads ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" start_reached INTEGER NOT NULL"
");";
if (sqlite3_exec(db, create_new_table_query.c_str(), NULL, NULL, &errMsg) !=
SQLITE_OK) {
Logger::log(
"Error creating temp_message_store_threads: " + std::string{errMsg});
sqlite3_free(errMsg);
return false;
}
// 2. Dump data from existing `message_store_threads` table into temp table.
std::string copy_data_query =
"INSERT INTO temp_message_store_threads (id, start_reached)"
"SELECT id, start_reached FROM message_store_threads;";
if (sqlite3_exec(db, copy_data_query.c_str(), NULL, NULL, &errMsg) !=
SQLITE_OK) {
Logger::log(
"Error dumping data from existing message_store_threads to "
"temp_message_store_threads: " +
std::string{errMsg});
sqlite3_free(errMsg);
return false;
}
// 3. Drop the existing `message_store_threads` table.
std::string drop_old_table_query = "DROP TABLE message_store_threads;";
if (sqlite3_exec(db, drop_old_table_query.c_str(), NULL, NULL, &errMsg) !=
SQLITE_OK) {
Logger::log(
"Error dropping message_store_threads table: " + std::string{errMsg});
sqlite3_free(errMsg);
return false;
}
// 4. Rename the temp table back to `message_store_threads`.
std::string rename_table_query =
"ALTER TABLE temp_message_store_threads RENAME TO message_store_threads;";
if (sqlite3_exec(db, rename_table_query.c_str(), NULL, NULL, &errMsg) !=
SQLITE_OK) {
Logger::log(
"Error renaming temp_message_store_threads to message_store_threads: " +
std::string{errMsg});
sqlite3_free(errMsg);
return false;
}
return true;
}
bool create_users_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS users ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" user_info TEXT NOT NULL"
");";
return create_table(db, query, "users");
}
bool create_keyservers_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS keyservers ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" keyserver_info TEXT NOT NULL"
");";
return create_table(db, query, "keyservers");
}
bool enable_rollback_journal_mode(sqlite3 *db) {
char *error;
sqlite3_exec(db, "PRAGMA journal_mode=DELETE;", nullptr, nullptr, &error);
if (!error) {
return true;
}
std::stringstream error_message;
error_message << "Error disabling write-ahead logging mode: " << error;
Logger::log(error_message.str());
sqlite3_free(error);
return false;
}
bool create_communities_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS communities ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" community_info TEXT NOT NULL"
");";
return create_table(db, query, "communities");
}
bool create_messages_to_device_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS messages_to_device ("
" message_id TEXT NOT NULL,"
" device_id TEXT NOT NULL,"
" user_id TEXT NOT NULL,"
" timestamp BIGINT NOT NULL,"
" plaintext TEXT NOT NULL,"
" ciphertext TEXT NOT NULL,"
" PRIMARY KEY (message_id, device_id)"
");"
"CREATE INDEX IF NOT EXISTS messages_to_device_idx_id_timestamp"
" ON messages_to_device (device_id, timestamp);";
return create_table(db, query, "messages_to_device");
}
bool create_integrity_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS integrity_store ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" thread_hash TEXT NOT NULL"
");";
return create_table(db, query, "integrity_store");
}
bool create_synced_metadata_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS synced_metadata ("
" name TEXT UNIQUE PRIMARY KEY NOT NULL,"
" data TEXT NOT NULL"
");";
return create_table(db, query, "synced_metadata");
}
bool create_keyservers_synced(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS keyservers_synced ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" keyserver_info TEXT NOT NULL"
");";
bool success = create_table(db, query, "keyservers_synced");
if (!success) {
return false;
}
std::string copyData =
"INSERT INTO keyservers_synced (id, keyserver_info)"
"SELECT id, keyserver_info "
"FROM keyservers;";
char *error;
sqlite3_exec(db, copyData.c_str(), nullptr, nullptr, &error);
if (error) {
return false;
}
return true;
}
bool create_aux_user_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS aux_users ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" aux_user_info TEXT NOT NULL"
");";
return create_table(db, query, "aux_users");
}
bool add_version_column_to_olm_persist_sessions_table(sqlite3 *db) {
char *error;
sqlite3_exec(
db,
"ALTER TABLE olm_persist_sessions "
" RENAME COLUMN `target_user_id` TO `target_device_id`; "
"ALTER TABLE olm_persist_sessions "
" ADD COLUMN version INTEGER NOT NULL DEFAULT 1;",
nullptr,
nullptr,
&error);
if (!error) {
return true;
}
std::ostringstream stringStream;
stringStream << "Error updating olm_persist_sessions table: " << error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
bool create_thread_activity_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS thread_activity ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" thread_activity_store_entry TEXT NOT NULL"
");";
return create_table(db, query, "thread_activity");
}
bool create_received_messages_to_device(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS received_messages_to_device ("
" id INTEGER PRIMARY KEY,"
" message_id TEXT NOT NULL,"
" sender_device_id TEXT NOT NULL,"
" plaintext TEXT NOT NULL,"
" status TEXT NOT NULL"
");";
return create_table(db, query, "received_messages_to_device");
}
bool recreate_outbound_p2p_messages_table(sqlite3 *db) {
std::string query =
"DROP TABLE IF EXISTS messages_to_device;"
"CREATE TABLE IF NOT EXISTS outbound_p2p_messages ("
" message_id TEXT NOT NULL,"
" device_id TEXT NOT NULL,"
" user_id TEXT NOT NULL,"
" timestamp BIGINT NOT NULL,"
" plaintext TEXT NOT NULL,"
" ciphertext TEXT NOT NULL,"
" status TEXT NOT NULL,"
" PRIMARY KEY (message_id, device_id)"
");"
"CREATE INDEX IF NOT EXISTS outbound_p2p_messages_idx_id_timestamp"
" ON outbound_p2p_messages (device_id, timestamp);";
return create_table(db, query, "outbound_p2p_messages");
}
bool create_entries_table(sqlite3 *db) {
std::string query =
"CREATE TABLE IF NOT EXISTS entries ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" entry TEXT NOT NULL"
");";
return create_table(db, query, "entries");
}
bool create_schema(sqlite3 *db) {
char *error;
sqlite3_exec(
db,
"CREATE TABLE IF NOT EXISTS drafts ("
" key TEXT UNIQUE PRIMARY KEY NOT NULL,"
" text TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS messages ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" local_id TEXT,"
" thread TEXT NOT NULL,"
" user TEXT NOT NULL,"
" type INTEGER NOT NULL,"
" future_type INTEGER,"
" content TEXT,"
" time INTEGER NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS olm_persist_account ("
" id INTEGER UNIQUE PRIMARY KEY NOT NULL,"
" account_data TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS olm_persist_sessions ("
" target_device_id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" session_data TEXT NOT NULL,"
" version INTEGER NOT NULL DEFAULT 1"
");"
"CREATE TABLE IF NOT EXISTS media ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" container TEXT NOT NULL,"
" thread TEXT NOT NULL,"
" uri TEXT NOT NULL,"
" type TEXT NOT NULL,"
" extras TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS threads ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" type INTEGER NOT NULL,"
" name TEXT,"
" description TEXT,"
" color TEXT NOT NULL,"
" creation_time BIGINT NOT NULL,"
" parent_thread_id TEXT,"
" containing_thread_id TEXT,"
" community TEXT,"
" members TEXT NOT NULL,"
" roles TEXT NOT NULL,"
" current_user TEXT NOT NULL,"
" source_message_id TEXT,"
" replies_count INTEGER NOT NULL,"
" avatar TEXT,"
" pinned_count INTEGER NOT NULL DEFAULT 0"
");"
"CREATE TABLE IF NOT EXISTS metadata ("
" name TEXT UNIQUE PRIMARY KEY NOT NULL,"
" data TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS message_store_threads ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" start_reached INTEGER NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS reports ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" report TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS persist_storage ("
" key TEXT UNIQUE PRIMARY KEY NOT NULL,"
" item TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS users ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" user_info TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS keyservers ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" keyserver_info TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS keyservers_synced ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" keyserver_info TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS communities ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" community_info TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS outbound_p2p_messages ("
" message_id TEXT NOT NULL,"
" device_id TEXT NOT NULL,"
" user_id TEXT NOT NULL,"
" timestamp BIGINT NOT NULL,"
" plaintext TEXT NOT NULL,"
" ciphertext TEXT NOT NULL,"
" status TEXT NOT NULL,"
" PRIMARY KEY (message_id, device_id)"
");"
"CREATE TABLE IF NOT EXISTS integrity_store ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" thread_hash TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS synced_metadata ("
" name TEXT UNIQUE PRIMARY KEY NOT NULL,"
" data TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS aux_users ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" aux_user_info TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS thread_activity ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" thread_activity_store_entry TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS received_messages_to_device ("
" id INTEGER PRIMARY KEY,"
" message_id TEXT NOT NULL,"
" sender_device_id TEXT NOT NULL,"
" plaintext TEXT NOT NULL,"
" status TEXT NOT NULL"
");"
"CREATE TABLE IF NOT EXISTS entries ("
" id TEXT UNIQUE PRIMARY KEY NOT NULL,"
" entry TEXT NOT NULL"
");"
"CREATE INDEX IF NOT EXISTS media_idx_container"
" ON media (container);"
"CREATE INDEX IF NOT EXISTS messages_idx_thread_time"
" ON messages (thread, time);"
"CREATE INDEX IF NOT EXISTS outbound_p2p_messages_idx_id_timestamp"
" ON outbound_p2p_messages (device_id, timestamp);",
nullptr,
nullptr,
&error);
if (!error) {
return true;
}
std::ostringstream stringStream;
stringStream << "Error creating tables: " << error;
Logger::log(stringStream.str());
sqlite3_free(error);
return false;
}
void set_encryption_key(
sqlite3 *db,
const std::string &encryptionKey = SQLiteQueryExecutor::encryptionKey) {
std::string set_encryption_key_query =
"PRAGMA key = \"x'" + encryptionKey + "'\";";
char *error_set_key;
sqlite3_exec(
db, set_encryption_key_query.c_str(), nullptr, nullptr, &error_set_key);
if (error_set_key) {
std::ostringstream error_message;
error_message << "Failed to set encryption key: " << error_set_key;
throw std::system_error(
ECANCELED, std::generic_category(), error_message.str());
}
}
int get_database_version(sqlite3 *db) {
sqlite3_stmt *user_version_stmt;
sqlite3_prepare_v2(
db, "PRAGMA user_version;", -1, &user_version_stmt, nullptr);
sqlite3_step(user_version_stmt);
int current_user_version = sqlite3_column_int(user_version_stmt, 0);
sqlite3_finalize(user_version_stmt);
return current_user_version;
}
bool set_database_version(sqlite3 *db, int db_version) {
std::stringstream update_version;
update_version << "PRAGMA user_version=" << db_version << ";";
auto update_version_str = update_version.str();
char *error;
sqlite3_exec(db, update_version_str.c_str(), nullptr, nullptr, &error);
if (!error) {
return true;
}
std::ostringstream errorStream;
errorStream << "Error setting database version to " << db_version << ": "
<< error;
Logger::log(errorStream.str());
sqlite3_free(error);
return false;
}
// We don't want to run `PRAGMA key = ...;`
// on main web database. The context is here:
// https://linear.app/comm/issue/ENG-6398/issues-with-sqlcipher-on-web
void default_on_db_open_callback(sqlite3 *db) {
#ifndef EMSCRIPTEN
set_encryption_key(db);
#endif
}
// This is a temporary solution. In future we want to keep
// a separate table for blob hashes. Tracked on Linear:
// https://linear.app/comm/issue/ENG-6261/introduce-blob-hash-table
std::string blob_hash_from_blob_service_uri(const std::string &media_uri) {
static const std::string blob_service_prefix = "comm-blob-service://";
return media_uri.substr(blob_service_prefix.size());
}
bool file_exists(const std::string &file_path) {
std::ifstream file(file_path.c_str());
return file.good();
}
void attempt_delete_file(
const std::string &file_path,
const char *error_message) {
if (std::remove(file_path.c_str())) {
throw std::system_error(errno, std::generic_category(), error_message);
}
}
void attempt_rename_file(
const std::string &old_path,
const std::string &new_path,
const char *error_message) {
if (std::rename(old_path.c_str(), new_path.c_str())) {
throw std::system_error(errno, std::generic_category(), error_message);
}
}
bool is_database_queryable(
sqlite3 *db,
bool use_encryption_key,
const std::string &path = SQLiteQueryExecutor::sqliteFilePath,
const std::string &encryptionKey = SQLiteQueryExecutor::encryptionKey) {
char *err_msg;
sqlite3_open(path.c_str(), &db);
// According to SQLCipher documentation running some SELECT is the only way to
// check for key validity
if (use_encryption_key) {
set_encryption_key(db, encryptionKey);
}
sqlite3_exec(
db, "SELECT COUNT(*) FROM sqlite_master;", nullptr, nullptr, &err_msg);
sqlite3_close(db);
return !err_msg;
}
void validate_encryption() {
std::string temp_encrypted_db_path =
SQLiteQueryExecutor::sqliteFilePath + "_temp_encrypted";
bool temp_encrypted_exists = file_exists(temp_encrypted_db_path);
bool default_location_exists =
file_exists(SQLiteQueryExecutor::sqliteFilePath);
if (temp_encrypted_exists && default_location_exists) {
Logger::log(
"Previous encryption attempt failed. Repeating encryption process from "
"the beginning.");
attempt_delete_file(
temp_encrypted_db_path,
"Failed to delete corrupted encrypted database.");
} else if (temp_encrypted_exists && !default_location_exists) {
Logger::log(
"Moving temporary encrypted database to default location failed in "
"previous encryption attempt. Repeating rename step.");
attempt_rename_file(
temp_encrypted_db_path,
SQLiteQueryExecutor::sqliteFilePath,
"Failed to move encrypted database to default location.");
return;
} else if (!default_location_exists) {
Logger::log(
"Database not present yet. It will be created encrypted under default "
"path.");
return;
}
sqlite3 *db;
if (is_database_queryable(db, true)) {
Logger::log(
"Database exists under default path and it is correctly encrypted.");
return;
}
if (!is_database_queryable(db, false)) {
Logger::log(
"Database exists but it is encrypted with key that was lost. "
"Attempting database deletion. New encrypted one will be created.");
attempt_delete_file(
SQLiteQueryExecutor::sqliteFilePath.c_str(),
"Failed to delete database encrypted with lost key.");
return;
} else {
Logger::log(
"Database exists but it is not encrypted. Attempting encryption "
"process.");
}
sqlite3_open(SQLiteQueryExecutor::sqliteFilePath.c_str(), &db);
std::string createEncryptedCopySQL = "ATTACH DATABASE '" +
temp_encrypted_db_path +
"' AS encrypted_comm "
"KEY \"x'" +
SQLiteQueryExecutor::encryptionKey +
"'\";"
"SELECT sqlcipher_export('encrypted_comm');"
"DETACH DATABASE encrypted_comm;";
char *encryption_error;
sqlite3_exec(
db, createEncryptedCopySQL.c_str(), nullptr, nullptr, &encryption_error);
if (encryption_error) {
throw std::system_error(
ECANCELED,
std::generic_category(),
"Failed to create encrypted copy of the original database.");
}
sqlite3_close(db);
attempt_delete_file(
SQLiteQueryExecutor::sqliteFilePath,
"Failed to delete unencrypted database.");
attempt_rename_file(
temp_encrypted_db_path,
SQLiteQueryExecutor::sqliteFilePath,
"Failed to move encrypted database to default location.");
Logger::log("Encryption completed successfully.");
}
typedef bool ShouldBeInTransaction;
typedef std::function MigrateFunction;
typedef std::pair SQLiteMigration;
std::vector> migrations{
{{1, {create_drafts_table, true}},
{2, {rename_threadID_to_key, true}},
{4, {create_persist_account_table, true}},
{5, {create_persist_sessions_table, true}},
{15, {create_media_table, true}},
{16, {drop_messages_table, true}},
{17, {recreate_messages_table, true}},
{18, {create_messages_idx_thread_time, true}},
{19, {create_media_idx_container, true}},
{20, {create_threads_table, true}},
{21, {update_threadID_for_pending_threads_in_drafts, true}},
{22, {enable_write_ahead_logging_mode, false}},
{23, {create_metadata_table, true}},
{24, {add_not_null_constraint_to_drafts, true}},
{25, {add_not_null_constraint_to_metadata, true}},
{26, {add_avatar_column_to_threads_table, true}},
{27, {add_pinned_count_column_to_threads, true}},
{28, {create_message_store_threads_table, true}},
{29, {create_reports_table, true}},
{30, {create_persist_storage_table, true}},
{31, {recreate_message_store_threads_table, true}},
{32, {create_users_table, true}},
{33, {create_keyservers_table, true}},
{34, {enable_rollback_journal_mode, false}},
{35, {create_communities_table, true}},
{36, {create_messages_to_device_table, true}},
{37, {create_integrity_table, true}},
{38, {[](sqlite3 *) { return true; }, false}},
{39, {create_synced_metadata_table, true}},
{40, {create_keyservers_synced, true}},
{41, {create_aux_user_table, true}},
{42, {add_version_column_to_olm_persist_sessions_table, true}},
{43, {create_thread_activity_table, true}},
{44, {create_received_messages_to_device, true}},
{45, {recreate_outbound_p2p_messages_table, true}},
{46, {create_entries_table, true}}}};
enum class MigrationResult { SUCCESS, FAILURE, NOT_APPLIED };
MigrationResult applyMigrationWithTransaction(
sqlite3 *db,
const MigrateFunction &migrate,
int index) {
sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
auto db_version = get_database_version(db);
if (index <= db_version) {
sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr);
return MigrationResult::NOT_APPLIED;
}
auto rc = migrate(db);
if (!rc) {
sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr);
return MigrationResult::FAILURE;
}
auto database_version_set = set_database_version(db, index);
if (!database_version_set) {
sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr);
return MigrationResult::FAILURE;
}
sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr);
return MigrationResult::SUCCESS;
}
MigrationResult applyMigrationWithoutTransaction(
sqlite3 *db,
const MigrateFunction &migrate,
int index) {
auto db_version = get_database_version(db);
if (index <= db_version) {
return MigrationResult::NOT_APPLIED;
}
auto rc = migrate(db);
if (!rc) {
return MigrationResult::FAILURE;
}
sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
auto inner_db_version = get_database_version(db);
if (index <= inner_db_version) {
sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr);
return MigrationResult::NOT_APPLIED;
}
auto database_version_set = set_database_version(db, index);
if (!database_version_set) {
sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr);
return MigrationResult::FAILURE;
}
sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr);
return MigrationResult::SUCCESS;
}
bool set_up_database(sqlite3 *db) {
sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
auto db_version = get_database_version(db);
auto latest_version = migrations.back().first;
if (db_version == latest_version) {
sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr);
return true;
}
if (db_version != 0 || !create_schema(db) ||
!set_database_version(db, latest_version)) {
sqlite3_exec(db, "ROLLBACK;", nullptr, nullptr, nullptr);
return false;
}
sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr);
return true;
}
void SQLiteQueryExecutor::migrate() {
// We don't want to run `PRAGMA key = ...;`
// on main web database. The context is here:
// https://linear.app/comm/issue/ENG-6398/issues-with-sqlcipher-on-web
#ifndef EMSCRIPTEN
validate_encryption();
#endif
sqlite3 *db;
sqlite3_open(SQLiteQueryExecutor::sqliteFilePath.c_str(), &db);
default_on_db_open_callback(db);
std::stringstream db_path;
db_path << "db path: " << SQLiteQueryExecutor::sqliteFilePath.c_str()
<< std::endl;
Logger::log(db_path.str());
auto db_version = get_database_version(db);
std::stringstream version_msg;
version_msg << "db version: " << db_version << std::endl;
Logger::log(version_msg.str());
if (db_version == 0) {
auto db_created = set_up_database(db);
if (!db_created) {
sqlite3_close(db);
Logger::log("Database structure creation error.");
throw std::runtime_error("Database structure creation error");
}
Logger::log("Database structure created.");
sqlite3_close(db);
return;
}
for (const auto &[idx, migration] : migrations) {
const auto &[applyMigration, shouldBeInTransaction] = migration;
MigrationResult migrationResult;
if (shouldBeInTransaction) {
migrationResult = applyMigrationWithTransaction(db, applyMigration, idx);
} else {
migrationResult =
applyMigrationWithoutTransaction(db, applyMigration, idx);
}
if (migrationResult == MigrationResult::NOT_APPLIED) {
continue;
}
std::stringstream migration_msg;
if (migrationResult == MigrationResult::FAILURE) {
migration_msg << "migration " << idx << " failed." << std::endl;
Logger::log(migration_msg.str());
sqlite3_close(db);
throw std::runtime_error(migration_msg.str());
}
if (migrationResult == MigrationResult::SUCCESS) {
migration_msg << "migration " << idx << " succeeded." << std::endl;
Logger::log(migration_msg.str());
}
}
sqlite3_close(db);
}
SQLiteQueryExecutor::SQLiteQueryExecutor() {
SQLiteQueryExecutor::migrate();
#ifndef EMSCRIPTEN
SQLiteQueryExecutor::initializeTablesForLogMonitoring();
std::string currentBackupID = this->getMetadata("backupID");
if (!StaffUtils::isStaffRelease() || !currentBackupID.size()) {
return;
}
SQLiteQueryExecutor::connectionManager.setLogsMonitoring(true);
#endif
}
SQLiteQueryExecutor::SQLiteQueryExecutor(std::string sqliteFilePath) {
SQLiteQueryExecutor::sqliteFilePath = sqliteFilePath;
SQLiteQueryExecutor::migrate();
}
sqlite3 *SQLiteQueryExecutor::getConnection() {
if (SQLiteQueryExecutor::connectionManager.getConnection()) {
return SQLiteQueryExecutor::connectionManager.getConnection();
}
SQLiteQueryExecutor::connectionManager.initializeConnection(
SQLiteQueryExecutor::sqliteFilePath, default_on_db_open_callback);
return SQLiteQueryExecutor::connectionManager.getConnection();
}
void SQLiteQueryExecutor::closeConnection() {
SQLiteQueryExecutor::connectionManager.closeConnection();
}
SQLiteQueryExecutor::~SQLiteQueryExecutor() {
SQLiteQueryExecutor::closeConnection();
}
std::string SQLiteQueryExecutor::getDraft(std::string key) const {
static std::string getDraftByPrimaryKeySQL =
"SELECT * "
"FROM drafts "
"WHERE key = ?;";
std::unique_ptr draft = getEntityByPrimaryKey(
SQLiteQueryExecutor::getConnection(), getDraftByPrimaryKeySQL, key);
return (draft == nullptr) ? "" : draft->text;
}
std::unique_ptr
SQLiteQueryExecutor::getThread(std::string threadID) const {
static std::string getThreadByPrimaryKeySQL =
"SELECT * "
"FROM threads "
"WHERE id = ?;";
return getEntityByPrimaryKey(
SQLiteQueryExecutor::getConnection(), getThreadByPrimaryKeySQL, threadID);
}
void SQLiteQueryExecutor::updateDraft(std::string key, std::string text) const {
static std::string replaceDraftSQL =
"REPLACE INTO drafts (key, text) "
"VALUES (?, ?);";
Draft draft = {key, text};
replaceEntity(
SQLiteQueryExecutor::getConnection(), replaceDraftSQL, draft);
}
bool SQLiteQueryExecutor::moveDraft(std::string oldKey, std::string newKey)
const {
std::string draftText = this->getDraft(oldKey);
if (!draftText.size()) {
return false;
}
static std::string rekeyDraftSQL =
"UPDATE OR REPLACE drafts "
"SET key = ? "
"WHERE key = ?;";
rekeyAllEntities(
SQLiteQueryExecutor::getConnection(), rekeyDraftSQL, oldKey, newKey);
return true;
}
std::vector SQLiteQueryExecutor::getAllDrafts() const {
static std::string getAllDraftsSQL =
"SELECT * "
"FROM drafts;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllDraftsSQL);
}
void SQLiteQueryExecutor::removeAllDrafts() const {
static std::string removeAllDraftsSQL = "DELETE FROM drafts;";
removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllDraftsSQL);
}
void SQLiteQueryExecutor::removeDrafts(
const std::vector &ids) const {
if (!ids.size()) {
return;
}
std::stringstream removeDraftsByKeysSQLStream;
removeDraftsByKeysSQLStream << "DELETE FROM drafts "
"WHERE key IN "
<< getSQLStatementArray(ids.size()) << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeDraftsByKeysSQLStream.str(),
ids);
}
void SQLiteQueryExecutor::removeAllMessages() const {
static std::string removeAllMessagesSQL = "DELETE FROM messages;";
removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllMessagesSQL);
}
std::vector>>
SQLiteQueryExecutor::getAllMessages() const {
static std::string getAllMessagesSQL =
"SELECT * "
"FROM messages "
"LEFT JOIN media "
" ON messages.id = media.container "
"ORDER BY messages.id;";
SQLiteStatementWrapper preparedSQL(
SQLiteQueryExecutor::getConnection(),
getAllMessagesSQL,
"Failed to retrieve all messages.");
std::string prevMsgIdx{};
std::vector>> allMessages;
for (int stepResult = sqlite3_step(preparedSQL); stepResult == SQLITE_ROW;
stepResult = sqlite3_step(preparedSQL)) {
Message message = Message::fromSQLResult(preparedSQL, 0);
if (message.id == prevMsgIdx) {
allMessages.back().second.push_back(Media::fromSQLResult(preparedSQL, 8));
} else {
prevMsgIdx = message.id;
std::vector mediaForMsg;
if (sqlite3_column_type(preparedSQL, 8) != SQLITE_NULL) {
mediaForMsg.push_back(Media::fromSQLResult(preparedSQL, 8));
}
allMessages.push_back(std::make_pair(std::move(message), mediaForMsg));
}
}
return allMessages;
}
void SQLiteQueryExecutor::removeMessages(
const std::vector &ids) const {
if (!ids.size()) {
return;
}
std::stringstream removeMessagesByKeysSQLStream;
removeMessagesByKeysSQLStream << "DELETE FROM messages "
"WHERE id IN "
<< getSQLStatementArray(ids.size()) << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeMessagesByKeysSQLStream.str(),
ids);
}
void SQLiteQueryExecutor::removeMessagesForThreads(
const std::vector &threadIDs) const {
if (!threadIDs.size()) {
return;
}
std::stringstream removeMessagesByKeysSQLStream;
removeMessagesByKeysSQLStream << "DELETE FROM messages "
"WHERE thread IN "
<< getSQLStatementArray(threadIDs.size())
<< ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeMessagesByKeysSQLStream.str(),
threadIDs);
}
void SQLiteQueryExecutor::replaceMessage(const Message &message) const {
static std::string replaceMessageSQL =
"REPLACE INTO messages "
"(id, local_id, thread, user, type, future_type, content, time) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?);";
replaceEntity(
SQLiteQueryExecutor::getConnection(), replaceMessageSQL, message);
}
void SQLiteQueryExecutor::rekeyMessage(std::string from, std::string to) const {
static std::string rekeyMessageSQL =
"UPDATE OR REPLACE messages "
"SET id = ? "
"WHERE id = ?";
rekeyAllEntities(
SQLiteQueryExecutor::getConnection(), rekeyMessageSQL, from, to);
}
void SQLiteQueryExecutor::removeAllMedia() const {
static std::string removeAllMediaSQL = "DELETE FROM media;";
removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllMediaSQL);
}
void SQLiteQueryExecutor::removeMediaForMessages(
const std::vector &msg_ids) const {
if (!msg_ids.size()) {
return;
}
std::stringstream removeMediaByKeysSQLStream;
removeMediaByKeysSQLStream << "DELETE FROM media "
"WHERE container IN "
<< getSQLStatementArray(msg_ids.size()) << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeMediaByKeysSQLStream.str(),
msg_ids);
}
void SQLiteQueryExecutor::removeMediaForMessage(std::string msg_id) const {
static std::string removeMediaByKeySQL =
"DELETE FROM media "
"WHERE container IN (?);";
std::vector keys = {msg_id};
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(), removeMediaByKeySQL, keys);
}
void SQLiteQueryExecutor::removeMediaForThreads(
const std::vector &thread_ids) const {
if (!thread_ids.size()) {
return;
}
std::stringstream removeMediaByKeysSQLStream;
removeMediaByKeysSQLStream << "DELETE FROM media "
"WHERE thread IN "
<< getSQLStatementArray(thread_ids.size()) << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeMediaByKeysSQLStream.str(),
thread_ids);
}
void SQLiteQueryExecutor::replaceMedia(const Media &media) const {
static std::string replaceMediaSQL =
"REPLACE INTO media "
"(id, container, thread, uri, type, extras) "
"VALUES (?, ?, ?, ?, ?, ?)";
replaceEntity(
SQLiteQueryExecutor::getConnection(), replaceMediaSQL, media);
}
void SQLiteQueryExecutor::rekeyMediaContainers(std::string from, std::string to)
const {
static std::string rekeyMediaContainersSQL =
"UPDATE media SET container = ? WHERE container = ?;";
rekeyAllEntities(
SQLiteQueryExecutor::getConnection(), rekeyMediaContainersSQL, from, to);
}
void SQLiteQueryExecutor::replaceMessageStoreThreads(
const std::vector &threads) const {
static std::string replaceMessageStoreThreadSQL =
"REPLACE INTO message_store_threads "
"(id, start_reached) "
"VALUES (?, ?);";
for (auto &thread : threads) {
replaceEntity(
SQLiteQueryExecutor::getConnection(),
replaceMessageStoreThreadSQL,
thread);
}
}
void SQLiteQueryExecutor::removeAllMessageStoreThreads() const {
static std::string removeAllMessageStoreThreadsSQL =
"DELETE FROM message_store_threads;";
removeAllEntities(
SQLiteQueryExecutor::getConnection(), removeAllMessageStoreThreadsSQL);
}
void SQLiteQueryExecutor::removeMessageStoreThreads(
const std::vector &ids) const {
if (!ids.size()) {
return;
}
std::stringstream removeMessageStoreThreadsByKeysSQLStream;
removeMessageStoreThreadsByKeysSQLStream
<< "DELETE FROM message_store_threads "
"WHERE id IN "
<< getSQLStatementArray(ids.size()) << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeMessageStoreThreadsByKeysSQLStream.str(),
ids);
}
std::vector
SQLiteQueryExecutor::getAllMessageStoreThreads() const {
static std::string getAllMessageStoreThreadsSQL =
"SELECT * "
"FROM message_store_threads;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllMessageStoreThreadsSQL);
}
std::vector SQLiteQueryExecutor::getAllThreads() const {
static std::string getAllThreadsSQL =
"SELECT * "
"FROM threads;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllThreadsSQL);
};
void SQLiteQueryExecutor::removeThreads(std::vector ids) const {
if (!ids.size()) {
return;
}
std::stringstream removeThreadsByKeysSQLStream;
removeThreadsByKeysSQLStream << "DELETE FROM threads "
"WHERE id IN "
<< getSQLStatementArray(ids.size()) << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeThreadsByKeysSQLStream.str(),
ids);
};
void SQLiteQueryExecutor::replaceThread(const Thread &thread) const {
static std::string replaceThreadSQL =
"REPLACE INTO threads ("
" id, type, name, description, color, creation_time, parent_thread_id,"
" containing_thread_id, community, members, roles, current_user,"
" source_message_id, replies_count, avatar, pinned_count) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);";
replaceEntity(
SQLiteQueryExecutor::getConnection(), replaceThreadSQL, thread);
};
void SQLiteQueryExecutor::removeAllThreads() const {
static std::string removeAllThreadsSQL = "DELETE FROM threads;";
removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllThreadsSQL);
};
void SQLiteQueryExecutor::replaceReport(const Report &report) const {
static std::string replaceReportSQL =
"REPLACE INTO reports (id, report) "
"VALUES (?, ?);";
replaceEntity(
SQLiteQueryExecutor::getConnection(), replaceReportSQL, report);
}
void SQLiteQueryExecutor::removeAllReports() const {
static std::string removeAllReportsSQL = "DELETE FROM reports;";
removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllReportsSQL);
}
void SQLiteQueryExecutor::removeReports(
const std::vector &ids) const {
if (!ids.size()) {
return;
}
std::stringstream removeReportsByKeysSQLStream;
removeReportsByKeysSQLStream << "DELETE FROM reports "
"WHERE id IN "
<< getSQLStatementArray(ids.size()) << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeReportsByKeysSQLStream.str(),
ids);
}
std::vector SQLiteQueryExecutor::getAllReports() const {
static std::string getAllReportsSQL =
"SELECT * "
"FROM reports;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllReportsSQL);
}
void SQLiteQueryExecutor::setPersistStorageItem(
std::string key,
std::string item) const {
static std::string replacePersistStorageItemSQL =
"REPLACE INTO persist_storage (key, item) "
"VALUES (?, ?);";
PersistItem entry{
key,
item,
};
replaceEntity(
SQLiteQueryExecutor::getConnection(),
replacePersistStorageItemSQL,
entry);
}
void SQLiteQueryExecutor::removePersistStorageItem(std::string key) const {
static std::string removePersistStorageItemByKeySQL =
"DELETE FROM persist_storage "
"WHERE key IN (?);";
std::vector keys = {key};
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removePersistStorageItemByKeySQL,
keys);
}
std::string SQLiteQueryExecutor::getPersistStorageItem(std::string key) const {
static std::string getPersistStorageItemByPrimaryKeySQL =
"SELECT * "
"FROM persist_storage "
"WHERE key = ?;";
std::unique_ptr entry = getEntityByPrimaryKey(
SQLiteQueryExecutor::getConnection(),
getPersistStorageItemByPrimaryKeySQL,
key);
return (entry == nullptr) ? "" : entry->item;
}
void SQLiteQueryExecutor::replaceUser(const UserInfo &user_info) const {
static std::string replaceUserSQL =
"REPLACE INTO users (id, user_info) "
"VALUES (?, ?);";
replaceEntity(
SQLiteQueryExecutor::getConnection(), replaceUserSQL, user_info);
}
void SQLiteQueryExecutor::removeAllUsers() const {
static std::string removeAllUsersSQL = "DELETE FROM users;";
removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllUsersSQL);
}
void SQLiteQueryExecutor::removeUsers(
const std::vector &ids) const {
if (!ids.size()) {
return;
}
std::stringstream removeUsersByKeysSQLStream;
removeUsersByKeysSQLStream << "DELETE FROM users "
"WHERE id IN "
<< getSQLStatementArray(ids.size()) << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeUsersByKeysSQLStream.str(),
ids);
}
void SQLiteQueryExecutor::replaceKeyserver(
const KeyserverInfo &keyserver_info) const {
static std::string replaceKeyserverSQL =
"REPLACE INTO keyservers (id, keyserver_info) "
"VALUES (:id, :keyserver_info);";
replaceEntity(
SQLiteQueryExecutor::getConnection(),
replaceKeyserverSQL,
keyserver_info);
static std::string replaceKeyserverSyncedSQL =
"REPLACE INTO keyservers_synced (id, keyserver_info) "
"VALUES (:id, :synced_keyserver_info);";
replaceEntity(
SQLiteQueryExecutor::getConnection(),
replaceKeyserverSyncedSQL,
keyserver_info);
}
void SQLiteQueryExecutor::removeAllKeyservers() const {
static std::string removeAllKeyserversSQL = "DELETE FROM keyservers;";
removeAllEntities(
SQLiteQueryExecutor::getConnection(), removeAllKeyserversSQL);
static std::string removeAllKeyserversSyncedSQL =
"DELETE FROM keyservers_synced;";
removeAllEntities(
SQLiteQueryExecutor::getConnection(), removeAllKeyserversSyncedSQL);
}
void SQLiteQueryExecutor::removeKeyservers(
const std::vector &ids) const {
if (!ids.size()) {
return;
}
auto idArray = getSQLStatementArray(ids.size());
std::stringstream removeKeyserversByKeysSQLStream;
removeKeyserversByKeysSQLStream << "DELETE FROM keyservers "
"WHERE id IN "
<< idArray << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeKeyserversByKeysSQLStream.str(),
ids);
std::stringstream removeKeyserversSyncedByKeysSQLStream;
removeKeyserversSyncedByKeysSQLStream << "DELETE FROM keyservers_synced "
"WHERE id IN "
<< idArray << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeKeyserversSyncedByKeysSQLStream.str(),
ids);
}
std::vector SQLiteQueryExecutor::getAllKeyservers() const {
static std::string getAllKeyserversSQL =
"SELECT "
" synced.id, "
" COALESCE(keyservers.keyserver_info, ''), "
" synced.keyserver_info "
"FROM keyservers_synced synced "
"LEFT JOIN keyservers "
" ON synced.id = keyservers.id;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllKeyserversSQL);
}
std::vector SQLiteQueryExecutor::getAllUsers() const {
static std::string getAllUsersSQL =
"SELECT * "
"FROM users;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllUsersSQL);
}
void SQLiteQueryExecutor::replaceCommunity(
const CommunityInfo &community_info) const {
static std::string replaceCommunitySQL =
"REPLACE INTO communities (id, community_info) "
"VALUES (?, ?);";
replaceEntity(
SQLiteQueryExecutor::getConnection(),
replaceCommunitySQL,
community_info);
}
void SQLiteQueryExecutor::removeAllCommunities() const {
static std::string removeAllCommunitiesSQL = "DELETE FROM communities;";
removeAllEntities(
SQLiteQueryExecutor::getConnection(), removeAllCommunitiesSQL);
}
void SQLiteQueryExecutor::removeCommunities(
const std::vector &ids) const {
if (!ids.size()) {
return;
}
std::stringstream removeCommunitiesByKeysSQLStream;
removeCommunitiesByKeysSQLStream << "DELETE FROM communities "
"WHERE id IN "
<< getSQLStatementArray(ids.size()) << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeCommunitiesByKeysSQLStream.str(),
ids);
}
std::vector SQLiteQueryExecutor::getAllCommunities() const {
static std::string getAllCommunitiesSQL =
"SELECT * "
"FROM communities;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllCommunitiesSQL);
}
void SQLiteQueryExecutor::replaceIntegrityThreadHashes(
const std::vector &threadHashes) const {
static std::string replaceIntegrityThreadHashSQL =
"REPLACE INTO integrity_store (id, thread_hash) "
"VALUES (?, ?);";
for (const IntegrityThreadHash &integrityThreadHash : threadHashes) {
replaceEntity(
SQLiteQueryExecutor::getConnection(),
replaceIntegrityThreadHashSQL,
integrityThreadHash);
}
}
void SQLiteQueryExecutor::removeAllIntegrityThreadHashes() const {
static std::string removeAllIntegrityThreadHashesSQL =
"DELETE FROM integrity_store;";
removeAllEntities(
SQLiteQueryExecutor::getConnection(), removeAllIntegrityThreadHashesSQL);
}
void SQLiteQueryExecutor::removeIntegrityThreadHashes(
const std::vector &ids) const {
if (!ids.size()) {
return;
}
std::stringstream removeIntegrityThreadHashesByKeysSQLStream;
removeIntegrityThreadHashesByKeysSQLStream << "DELETE FROM integrity_store "
"WHERE id IN "
<< getSQLStatementArray(ids.size())
<< ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeIntegrityThreadHashesByKeysSQLStream.str(),
ids);
}
std::vector
SQLiteQueryExecutor::getAllIntegrityThreadHashes() const {
static std::string getAllIntegrityThreadHashesSQL =
"SELECT * "
"FROM integrity_store;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllIntegrityThreadHashesSQL);
}
void SQLiteQueryExecutor::replaceSyncedMetadataEntry(
const SyncedMetadataEntry &synced_metadata_entry) const {
static std::string replaceSyncedMetadataEntrySQL =
"REPLACE INTO synced_metadata (name, data) "
"VALUES (?, ?);";
replaceEntity(
SQLiteQueryExecutor::getConnection(),
replaceSyncedMetadataEntrySQL,
synced_metadata_entry);
}
void SQLiteQueryExecutor::removeAllSyncedMetadata() const {
static std::string removeAllSyncedMetadataSQL =
"DELETE FROM synced_metadata;";
removeAllEntities(
SQLiteQueryExecutor::getConnection(), removeAllSyncedMetadataSQL);
}
void SQLiteQueryExecutor::removeSyncedMetadata(
const std::vector &names) const {
if (!names.size()) {
return;
}
std::stringstream removeSyncedMetadataByNamesSQLStream;
removeSyncedMetadataByNamesSQLStream << "DELETE FROM synced_metadata "
"WHERE name IN "
<< getSQLStatementArray(names.size())
<< ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeSyncedMetadataByNamesSQLStream.str(),
names);
}
std::vector
SQLiteQueryExecutor::getAllSyncedMetadata() const {
static std::string getAllSyncedMetadataSQL =
"SELECT * "
"FROM synced_metadata;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllSyncedMetadataSQL);
}
+std::optional
+SQLiteQueryExecutor::getSyncedDatabaseVersion(sqlite3 *db) const {
+ static std::string getDBVersionSyncedMetadataSQL =
+ "SELECT * "
+ "FROM synced_metadata "
+ "WHERE name=\"db_version\";";
+ std::vector entries =
+ getAllEntities(db, getDBVersionSyncedMetadataSQL);
+ for (auto &entry : entries) {
+ return std::stoi(entry.data);
+ }
+ return std::nullopt;
+}
+
void SQLiteQueryExecutor::replaceAuxUserInfo(
const AuxUserInfo &aux_user_info) const {
static std::string replaceAuxUserInfoSQL =
"REPLACE INTO aux_users (id, aux_user_info) "
"VALUES (?, ?);";
replaceEntity(
SQLiteQueryExecutor::getConnection(),
replaceAuxUserInfoSQL,
aux_user_info);
}
void SQLiteQueryExecutor::removeAllAuxUserInfos() const {
static std::string removeAllAuxUserInfosSQL = "DELETE FROM aux_users;";
removeAllEntities(
SQLiteQueryExecutor::getConnection(), removeAllAuxUserInfosSQL);
}
void SQLiteQueryExecutor::removeAuxUserInfos(
const std::vector &ids) const {
if (!ids.size()) {
return;
}
std::stringstream removeAuxUserInfosByKeysSQLStream;
removeAuxUserInfosByKeysSQLStream << "DELETE FROM aux_users "
"WHERE id IN "
<< getSQLStatementArray(ids.size()) << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeAuxUserInfosByKeysSQLStream.str(),
ids);
}
std::vector SQLiteQueryExecutor::getAllAuxUserInfos() const {
static std::string getAllAuxUserInfosSQL =
"SELECT * "
"FROM aux_users;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllAuxUserInfosSQL);
}
void SQLiteQueryExecutor::replaceThreadActivityEntry(
const ThreadActivityEntry &thread_activity_entry) const {
static std::string replaceThreadActivityEntrySQL =
"REPLACE INTO thread_activity (id, thread_activity_store_entry) "
"VALUES (?, ?);";
replaceEntity(
SQLiteQueryExecutor::getConnection(),
replaceThreadActivityEntrySQL,
thread_activity_entry);
}
void SQLiteQueryExecutor::removeAllThreadActivityEntries() const {
static std::string removeAllThreadActivityEntriesSQL =
"DELETE FROM thread_activity;";
removeAllEntities(
SQLiteQueryExecutor::getConnection(), removeAllThreadActivityEntriesSQL);
}
void SQLiteQueryExecutor::removeThreadActivityEntries(
const std::vector &ids) const {
if (!ids.size()) {
return;
}
std::stringstream removeThreadActivityEntriesByKeysSQLStream;
removeThreadActivityEntriesByKeysSQLStream << "DELETE FROM thread_activity "
"WHERE id IN "
<< getSQLStatementArray(ids.size())
<< ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeThreadActivityEntriesByKeysSQLStream.str(),
ids);
}
std::vector
SQLiteQueryExecutor::getAllThreadActivityEntries() const {
static std::string getAllThreadActivityEntriesSQL =
"SELECT * "
"FROM thread_activity;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllThreadActivityEntriesSQL);
}
void SQLiteQueryExecutor::replaceEntry(const EntryInfo &entry_info) const {
static std::string replaceEntrySQL =
"REPLACE INTO entries (id, entry) "
"VALUES (?, ?);";
replaceEntity(
SQLiteQueryExecutor::getConnection(), replaceEntrySQL, entry_info);
}
void SQLiteQueryExecutor::removeAllEntries() const {
static std::string removeAllEntriesSQL = "DELETE FROM entries;";
removeAllEntities(SQLiteQueryExecutor::getConnection(), removeAllEntriesSQL);
}
void SQLiteQueryExecutor::removeEntries(
const std::vector &ids) const {
if (!ids.size()) {
return;
}
std::stringstream removeEntriesByKeysSQLStream;
removeEntriesByKeysSQLStream << "DELETE FROM entries "
"WHERE id IN "
<< getSQLStatementArray(ids.size()) << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(),
removeEntriesByKeysSQLStream.str(),
ids);
}
std::vector SQLiteQueryExecutor::getAllEntries() const {
static std::string getAllEntriesSQL =
"SELECT * "
"FROM entries;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllEntriesSQL);
}
void SQLiteQueryExecutor::beginTransaction() const {
executeQuery(SQLiteQueryExecutor::getConnection(), "BEGIN TRANSACTION;");
}
void SQLiteQueryExecutor::commitTransaction() const {
executeQuery(SQLiteQueryExecutor::getConnection(), "COMMIT;");
}
void SQLiteQueryExecutor::rollbackTransaction() const {
executeQuery(SQLiteQueryExecutor::getConnection(), "ROLLBACK;");
}
int SQLiteQueryExecutor::getContentAccountID() const {
return CONTENT_ACCOUNT_ID;
}
int SQLiteQueryExecutor::getNotifsAccountID() const {
return NOTIFS_ACCOUNT_ID;
}
std::vector
SQLiteQueryExecutor::getOlmPersistSessionsData() const {
static std::string getAllOlmPersistSessionsSQL =
"SELECT * "
"FROM olm_persist_sessions;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllOlmPersistSessionsSQL);
}
std::optional
SQLiteQueryExecutor::getOlmPersistAccountData(int accountID) const {
static std::string getOlmPersistAccountSQL =
"SELECT * "
"FROM olm_persist_account "
"WHERE id = ?;";
std::unique_ptr result =
getEntityByIntegerPrimaryKey(
SQLiteQueryExecutor::getConnection(),
getOlmPersistAccountSQL,
accountID);
if (result == nullptr) {
return std::nullopt;
}
return result->account_data;
}
void SQLiteQueryExecutor::storeOlmPersistAccount(
int accountID,
const std::string &accountData) const {
static std::string replaceOlmPersistAccountSQL =
"REPLACE INTO olm_persist_account (id, account_data) "
"VALUES (?, ?);";
OlmPersistAccount persistAccount = {accountID, accountData};
replaceEntity(
SQLiteQueryExecutor::getConnection(),
replaceOlmPersistAccountSQL,
persistAccount);
}
void SQLiteQueryExecutor::storeOlmPersistSession(
const OlmPersistSession &session) const {
static std::string replaceOlmPersistSessionSQL =
"REPLACE INTO olm_persist_sessions "
"(target_device_id, session_data, version) "
"VALUES (?, ?, ?);";
replaceEntity(
SQLiteQueryExecutor::getConnection(),
replaceOlmPersistSessionSQL,
session);
}
void SQLiteQueryExecutor::storeOlmPersistData(
int accountID,
crypto::Persist persist) const {
if (accountID != CONTENT_ACCOUNT_ID && persist.sessions.size() > 0) {
throw std::runtime_error(
"Attempt to store notifications sessions in SQLite. Notifications "
"sessions must be stored in storage shared with NSE.");
}
std::string accountData =
std::string(persist.account.begin(), persist.account.end());
this->storeOlmPersistAccount(accountID, accountData);
for (auto it = persist.sessions.begin(); it != persist.sessions.end(); it++) {
OlmPersistSession persistSession = {
it->first,
std::string(it->second.buffer.begin(), it->second.buffer.end()),
it->second.version};
this->storeOlmPersistSession(persistSession);
}
}
void SQLiteQueryExecutor::setNotifyToken(std::string token) const {
this->setMetadata("notify_token", token);
}
void SQLiteQueryExecutor::clearNotifyToken() const {
this->clearMetadata("notify_token");
}
void SQLiteQueryExecutor::stampSQLiteDBUserID(std::string userID) const {
this->setMetadata("current_user_id", userID);
}
std::string SQLiteQueryExecutor::getSQLiteStampedUserID() const {
return this->getMetadata("current_user_id");
}
void SQLiteQueryExecutor::setMetadata(std::string entry_name, std::string data)
const {
std::string replaceMetadataSQL =
"REPLACE INTO metadata (name, data) "
"VALUES (?, ?);";
Metadata entry{
entry_name,
data,
};
replaceEntity(
SQLiteQueryExecutor::getConnection(), replaceMetadataSQL, entry);
}
void SQLiteQueryExecutor::clearMetadata(std::string entry_name) const {
static std::string removeMetadataByKeySQL =
"DELETE FROM metadata "
"WHERE name IN (?);";
std::vector keys = {entry_name};
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(), removeMetadataByKeySQL, keys);
}
std::string SQLiteQueryExecutor::getMetadata(std::string entry_name) const {
std::string getMetadataByPrimaryKeySQL =
"SELECT * "
"FROM metadata "
"WHERE name = ?;";
std::unique_ptr entry = getEntityByPrimaryKey(
SQLiteQueryExecutor::getConnection(),
getMetadataByPrimaryKeySQL,
entry_name);
return (entry == nullptr) ? "" : entry->data;
}
void SQLiteQueryExecutor::addOutboundP2PMessages(
const std::vector &messages) const {
static std::string addMessage =
"REPLACE INTO outbound_p2p_messages ("
" message_id, device_id, user_id, timestamp,"
" plaintext, ciphertext, status) "
"VALUES (?, ?, ?, ?, ?, ?, ?);";
for (const OutboundP2PMessage &clientMessage : messages) {
SQLiteOutboundP2PMessage message =
clientMessage.toSQLiteOutboundP2PMessage();
replaceEntity(
SQLiteQueryExecutor::getConnection(), addMessage, message);
}
}
std::vector
SQLiteQueryExecutor::getAllOutboundP2PMessages() const {
std::string query =
"SELECT * FROM outbound_p2p_messages "
"ORDER BY timestamp;";
SQLiteStatementWrapper preparedSQL(
SQLiteQueryExecutor::getConnection(),
query,
"Failed to get all messages to device");
std::vector messages;
for (int stepResult = sqlite3_step(preparedSQL); stepResult == SQLITE_ROW;
stepResult = sqlite3_step(preparedSQL)) {
messages.emplace_back(OutboundP2PMessage(
SQLiteOutboundP2PMessage::fromSQLResult(preparedSQL, 0)));
}
return messages;
}
void SQLiteQueryExecutor::removeOutboundP2PMessagesOlderThan(
std::string lastConfirmedMessageID,
std::string deviceID) const {
std::string query =
"DELETE FROM outbound_p2p_messages "
"WHERE timestamp <= ("
" SELECT timestamp "
" FROM outbound_p2p_messages"
" WHERE message_id = ?"
") "
"AND device_id IN (?);";
comm::SQLiteStatementWrapper preparedSQL(
SQLiteQueryExecutor::getConnection(),
query,
"Failed to remove messages to device");
bindStringToSQL(lastConfirmedMessageID.c_str(), preparedSQL, 1);
bindStringToSQL(deviceID.c_str(), preparedSQL, 2);
sqlite3_step(preparedSQL);
}
void SQLiteQueryExecutor::removeAllOutboundP2PMessages(
const std::string &deviceID) const {
static std::string removeMessagesSQL =
"DELETE FROM outbound_p2p_messages "
"WHERE device_id IN (?);";
std::vector keys = {deviceID};
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(), removeMessagesSQL, keys);
}
void SQLiteQueryExecutor::setCiphertextForOutboundP2PMessage(
std::string messageID,
std::string deviceID,
std::string ciphertext) const {
static std::string query =
"UPDATE outbound_p2p_messages "
"SET ciphertext = ?, status = 'encrypted' "
"WHERE message_id = ? AND device_id = ?;";
comm::SQLiteStatementWrapper preparedSQL(
SQLiteQueryExecutor::getConnection(),
query,
"Failed to set ciphertext for OutboundP2PMessage");
bindStringToSQL(ciphertext.c_str(), preparedSQL, 1);
bindStringToSQL(messageID.c_str(), preparedSQL, 2);
bindStringToSQL(deviceID.c_str(), preparedSQL, 3);
sqlite3_step(preparedSQL);
}
void SQLiteQueryExecutor::markOutboundP2PMessageAsSent(
std::string messageID,
std::string deviceID) const {
static std::string query =
"UPDATE outbound_p2p_messages "
"SET status = 'sent' "
"WHERE message_id = ? AND device_id = ?;";
comm::SQLiteStatementWrapper preparedSQL(
SQLiteQueryExecutor::getConnection(),
query,
"Failed to mark OutboundP2PMessage as sent");
bindStringToSQL(messageID.c_str(), preparedSQL, 1);
bindStringToSQL(deviceID.c_str(), preparedSQL, 2);
sqlite3_step(preparedSQL);
}
void SQLiteQueryExecutor::addInboundP2PMessage(
InboundP2PMessage message) const {
static std::string addMessage =
"REPLACE INTO received_messages_to_device ("
" message_id, sender_device_id, plaintext, status)"
"VALUES (?, ?, ?, ?);";
replaceEntity(
SQLiteQueryExecutor::getConnection(), addMessage, message);
}
std::vector
SQLiteQueryExecutor::getAllInboundP2PMessage() const {
static std::string query =
"SELECT message_id, sender_device_id, plaintext, status "
"FROM received_messages_to_device;";
return getAllEntities(
SQLiteQueryExecutor::getConnection(), query);
}
void SQLiteQueryExecutor::removeInboundP2PMessages(
const std::vector &ids) const {
if (!ids.size()) {
return;
}
std::stringstream removeMessagesSQLStream;
removeMessagesSQLStream << "DELETE FROM received_messages_to_device "
"WHERE message_id IN "
<< getSQLStatementArray(ids.size()) << ";";
removeEntitiesByKeys(
SQLiteQueryExecutor::getConnection(), removeMessagesSQLStream.str(), ids);
}
#ifdef EMSCRIPTEN
std::vector SQLiteQueryExecutor::getAllThreadsWeb() const {
auto threads = this->getAllThreads();
std::vector webThreads;
webThreads.reserve(threads.size());
for (const auto &thread : threads) {
webThreads.emplace_back(thread);
}
return webThreads;
};
void SQLiteQueryExecutor::replaceThreadWeb(const WebThread &thread) const {
this->replaceThread(thread.toThread());
};
std::vector SQLiteQueryExecutor::getAllMessagesWeb() const {
auto allMessages = this->getAllMessages();
std::vector allMessageWithMedias;
for (auto &messageWitMedia : allMessages) {
allMessageWithMedias.push_back(
{std::move(messageWitMedia.first), messageWitMedia.second});
}
return allMessageWithMedias;
}
void SQLiteQueryExecutor::replaceMessageWeb(const WebMessage &message) const {
this->replaceMessage(message.toMessage());
};
NullableString
SQLiteQueryExecutor::getOlmPersistAccountDataWeb(int accountID) const {
std::optional accountData =
this->getOlmPersistAccountData(accountID);
if (!accountData.has_value()) {
return NullableString();
}
return std::make_unique(accountData.value());
}
#else
void SQLiteQueryExecutor::clearSensitiveData() {
SQLiteQueryExecutor::closeConnection();
if (file_exists(SQLiteQueryExecutor::sqliteFilePath) &&
std::remove(SQLiteQueryExecutor::sqliteFilePath.c_str())) {
std::ostringstream errorStream;
errorStream << "Failed to delete database file. Details: "
<< strerror(errno);
throw std::system_error(errno, std::generic_category(), errorStream.str());
}
SQLiteQueryExecutor::generateFreshEncryptionKey();
SQLiteQueryExecutor::migrate();
}
void SQLiteQueryExecutor::initialize(std::string &databasePath) {
std::call_once(SQLiteQueryExecutor::initialized, [&databasePath]() {
SQLiteQueryExecutor::sqliteFilePath = databasePath;
folly::Optional maybeEncryptionKey =
CommSecureStore::get(SQLiteQueryExecutor::secureStoreEncryptionKeyID);
folly::Optional maybeBackupLogsEncryptionKey =
CommSecureStore::get(
SQLiteQueryExecutor::secureStoreBackupLogsEncryptionKeyID);
if (file_exists(databasePath) && maybeEncryptionKey &&
maybeBackupLogsEncryptionKey) {
SQLiteQueryExecutor::encryptionKey = maybeEncryptionKey.value();
SQLiteQueryExecutor::backupLogsEncryptionKey =
maybeBackupLogsEncryptionKey.value();
return;
} else if (file_exists(databasePath) && maybeEncryptionKey) {
SQLiteQueryExecutor::encryptionKey = maybeEncryptionKey.value();
SQLiteQueryExecutor::generateFreshBackupLogsEncryptionKey();
return;
}
SQLiteQueryExecutor::generateFreshEncryptionKey();
});
}
void SQLiteQueryExecutor::initializeTablesForLogMonitoring() {
sqlite3 *db;
sqlite3_open(SQLiteQueryExecutor::sqliteFilePath.c_str(), &db);
default_on_db_open_callback(db);
std::vector tablesToMonitor;
{
SQLiteStatementWrapper preparedSQL(
db,
"SELECT name FROM sqlite_master WHERE type='table';",
"Failed to get all database tables");
for (int stepResult = sqlite3_step(preparedSQL); stepResult == SQLITE_ROW;
stepResult = sqlite3_step(preparedSQL)) {
std::string table_name =
reinterpret_cast(sqlite3_column_text(preparedSQL, 0));
if (SQLiteQueryExecutor::backedUpTablesBlocklist.find(table_name) ==
SQLiteQueryExecutor::backedUpTablesBlocklist.end()) {
tablesToMonitor.emplace_back(table_name);
}
}
// Runs preparedSQL destructor which finalizes the sqlite statement
}
sqlite3_close(db);
SQLiteQueryExecutor::connectionManager.tablesToMonitor = tablesToMonitor;
}
void SQLiteQueryExecutor::createMainCompaction(std::string backupID) const {
std::string finalBackupPath =
PlatformSpecificTools::getBackupFilePath(backupID, false);
std::string finalAttachmentsPath =
PlatformSpecificTools::getBackupFilePath(backupID, true);
std::string tempBackupPath = finalBackupPath + "_tmp";
std::string tempAttachmentsPath = finalAttachmentsPath + "_tmp";
if (file_exists(tempBackupPath)) {
Logger::log(
"Attempting to delete temporary backup file from previous backup "
"attempt.");
attempt_delete_file(
tempBackupPath,
"Failed to delete temporary backup file from previous backup attempt.");
}
if (file_exists(tempAttachmentsPath)) {
Logger::log(
"Attempting to delete temporary attachments file from previous backup "
"attempt.");
attempt_delete_file(
tempAttachmentsPath,
"Failed to delete temporary attachments file from previous backup "
"attempt.");
}
sqlite3 *backupDB;
sqlite3_open(tempBackupPath.c_str(), &backupDB);
set_encryption_key(backupDB);
sqlite3_backup *backupObj = sqlite3_backup_init(
backupDB, "main", SQLiteQueryExecutor::getConnection(), "main");
if (!backupObj) {
std::stringstream error_message;
error_message << "Failed to init backup for main compaction. Details: "
<< sqlite3_errmsg(backupDB) << std::endl;
sqlite3_close(backupDB);
throw std::runtime_error(error_message.str());
}
int backupResult = sqlite3_backup_step(backupObj, -1);
sqlite3_backup_finish(backupObj);
if (backupResult == SQLITE_BUSY || backupResult == SQLITE_LOCKED) {
sqlite3_close(backupDB);
throw std::runtime_error(
"Programmer error. Database in transaction during backup attempt.");
} else if (backupResult != SQLITE_DONE) {
sqlite3_close(backupDB);
std::stringstream error_message;
error_message << "Failed to create database backup. Details: "
<< sqlite3_errstr(backupResult);
throw std::runtime_error(error_message.str());
}
if (!SQLiteQueryExecutor::backedUpTablesBlocklist.empty()) {
std::string removeDeviceSpecificDataSQL = "";
for (const auto &table_name :
SQLiteQueryExecutor::backedUpTablesBlocklist) {
removeDeviceSpecificDataSQL.append("DELETE FROM " + table_name + ";\n");
}
executeQuery(backupDB, removeDeviceSpecificDataSQL);
}
executeQuery(backupDB, "VACUUM;");
sqlite3_close(backupDB);
attempt_rename_file(
tempBackupPath,
finalBackupPath,
"Failed to rename complete temporary backup file to final backup file.");
std::ofstream tempAttachmentsFile(tempAttachmentsPath);
if (!tempAttachmentsFile.is_open()) {
throw std::runtime_error(
"Unable to create attachments file for backup id: " + backupID);
}
std::string getAllBlobServiceMediaSQL =
"SELECT * FROM media WHERE uri LIKE 'comm-blob-service://%';";
std::vector blobServiceMedia = getAllEntities(
SQLiteQueryExecutor::getConnection(), getAllBlobServiceMediaSQL);
for (const auto &media : blobServiceMedia) {
std::string blobServiceURI = media.uri;
std::string blobHash = blob_hash_from_blob_service_uri(blobServiceURI);
tempAttachmentsFile << blobHash << "\n";
}
tempAttachmentsFile.close();
attempt_rename_file(
tempAttachmentsPath,
finalAttachmentsPath,
"Failed to rename complete temporary attachments file to final "
"attachments file.");
this->setMetadata("backupID", backupID);
this->clearMetadata("logID");
if (StaffUtils::isStaffRelease()) {
SQLiteQueryExecutor::connectionManager.setLogsMonitoring(true);
}
}
void SQLiteQueryExecutor::generateFreshEncryptionKey() {
std::string encryptionKey = comm::crypto::Tools::generateRandomHexString(
SQLiteQueryExecutor::sqlcipherEncryptionKeySize);
CommSecureStore::set(
SQLiteQueryExecutor::secureStoreEncryptionKeyID, encryptionKey);
SQLiteQueryExecutor::encryptionKey = encryptionKey;
SQLiteQueryExecutor::generateFreshBackupLogsEncryptionKey();
}
void SQLiteQueryExecutor::generateFreshBackupLogsEncryptionKey() {
std::string backupLogsEncryptionKey =
comm::crypto::Tools::generateRandomHexString(
SQLiteQueryExecutor::backupLogsEncryptionKeySize);
CommSecureStore::set(
SQLiteQueryExecutor::secureStoreBackupLogsEncryptionKeyID,
backupLogsEncryptionKey);
SQLiteQueryExecutor::backupLogsEncryptionKey = backupLogsEncryptionKey;
}
void SQLiteQueryExecutor::captureBackupLogs() const {
std::string backupID = this->getMetadata("backupID");
if (!backupID.size()) {
return;
}
std::string logID = this->getMetadata("logID");
if (!logID.size()) {
logID = "1";
}
bool newLogCreated = SQLiteQueryExecutor::connectionManager.captureLogs(
backupID, logID, SQLiteQueryExecutor::backupLogsEncryptionKey);
if (!newLogCreated) {
return;
}
this->setMetadata("logID", std::to_string(std::stoi(logID) + 1));
}
#endif
void SQLiteQueryExecutor::restoreFromMainCompaction(
std::string mainCompactionPath,
- std::string mainCompactionEncryptionKey) const {
+ std::string mainCompactionEncryptionKey,
+ std::string maxVersion) const {
if (!file_exists(mainCompactionPath)) {
throw std::runtime_error("Restore attempt but backup file does not exist.");
}
sqlite3 *backupDB;
if (!is_database_queryable(
backupDB, true, mainCompactionPath, mainCompactionEncryptionKey)) {
throw std::runtime_error("Backup file or encryption key corrupted.");
}
// We don't want to run `PRAGMA key = ...;`
// on main web database. The context is here:
// https://linear.app/comm/issue/ENG-6398/issues-with-sqlcipher-on-web
#ifdef EMSCRIPTEN
std::string plaintextBackupPath = mainCompactionPath + "_plaintext";
if (file_exists(plaintextBackupPath)) {
attempt_delete_file(
plaintextBackupPath,
"Failed to delete plaintext backup file from previous backup attempt.");
}
std::string plaintextMigrationDBQuery = "PRAGMA key = \"x'" +
mainCompactionEncryptionKey +
"'\";"
"ATTACH DATABASE '" +
plaintextBackupPath +
"' AS plaintext KEY '';"
"SELECT sqlcipher_export('plaintext');"
"DETACH DATABASE plaintext;";
sqlite3_open(mainCompactionPath.c_str(), &backupDB);
char *plaintextMigrationErr;
sqlite3_exec(
backupDB,
plaintextMigrationDBQuery.c_str(),
nullptr,
nullptr,
&plaintextMigrationErr);
sqlite3_close(backupDB);
if (plaintextMigrationErr) {
std::stringstream error_message;
error_message << "Failed to migrate backup SQLCipher file to plaintext "
"SQLite file. Details"
<< plaintextMigrationErr << std::endl;
std::string error_message_str = error_message.str();
sqlite3_free(plaintextMigrationErr);
throw std::runtime_error(error_message_str);
}
sqlite3_open(plaintextBackupPath.c_str(), &backupDB);
#else
sqlite3_open(mainCompactionPath.c_str(), &backupDB);
set_encryption_key(backupDB, mainCompactionEncryptionKey);
#endif
+ int version = this->getSyncedDatabaseVersion(backupDB).value_or(-1);
+ if (version > std::stoi(maxVersion)) {
+ std::stringstream error_message;
+ error_message << "Failed to restore a backup because it was created "
+ << "with version " << version
+ << " that is newer than the max supported version "
+ << maxVersion << std::endl;
+ sqlite3_close(backupDB);
+ throw std::runtime_error(error_message.str());
+ }
+
sqlite3_backup *backupObj = sqlite3_backup_init(
SQLiteQueryExecutor::getConnection(), "main", backupDB, "main");
if (!backupObj) {
std::stringstream error_message;
error_message << "Failed to init backup for main compaction. Details: "
<< sqlite3_errmsg(SQLiteQueryExecutor::getConnection())
<< std::endl;
sqlite3_close(backupDB);
throw std::runtime_error(error_message.str());
}
int backupResult = sqlite3_backup_step(backupObj, -1);
sqlite3_backup_finish(backupObj);
sqlite3_close(backupDB);
if (backupResult == SQLITE_BUSY || backupResult == SQLITE_LOCKED) {
throw std::runtime_error(
"Programmer error. Database in transaction during restore attempt.");
} else if (backupResult != SQLITE_DONE) {
std::stringstream error_message;
error_message << "Failed to restore database from backup. Details: "
<< sqlite3_errstr(backupResult);
throw std::runtime_error(error_message.str());
}
#ifdef EMSCRIPTEN
attempt_delete_file(
plaintextBackupPath,
"Failed to delete plaintext compaction file after successful restore.");
#endif
attempt_delete_file(
mainCompactionPath,
"Failed to delete main compaction file after successful restore.");
}
void SQLiteQueryExecutor::restoreFromBackupLog(
const std::vector &backupLog) const {
SQLiteQueryExecutor::connectionManager.restoreFromBackupLog(backupLog);
}
} // namespace comm
diff --git a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.h b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.h
index 5483c2cfc..87e659b20 100644
--- a/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.h
+++ b/native/cpp/CommonCpp/DatabaseManagers/SQLiteQueryExecutor.h
@@ -1,184 +1,187 @@
#pragma once
#include "../CryptoTools/Persist.h"
#include "DatabaseQueryExecutor.h"
#include "NativeSQLiteConnectionManager.h"
#include "entities/AuxUserInfo.h"
#include "entities/CommunityInfo.h"
#include "entities/Draft.h"
#include "entities/IntegrityThreadHash.h"
#include "entities/KeyserverInfo.h"
#include "entities/ThreadActivityEntry.h"
#include "entities/UserInfo.h"
#include
#include
#include
namespace comm {
class SQLiteQueryExecutor : public DatabaseQueryExecutor {
static void migrate();
static sqlite3 *getConnection();
static void closeConnection();
static std::once_flag initialized;
static int sqlcipherEncryptionKeySize;
static std::string secureStoreEncryptionKeyID;
static int backupLogsEncryptionKeySize;
static std::string secureStoreBackupLogsEncryptionKeyID;
static std::string backupLogsEncryptionKey;
#ifndef EMSCRIPTEN
static NativeSQLiteConnectionManager connectionManager;
static std::unordered_set backedUpTablesBlocklist;
static void generateFreshEncryptionKey();
static void generateFreshBackupLogsEncryptionKey();
static void initializeTablesForLogMonitoring();
#else
static SQLiteConnectionManager connectionManager;
#endif
+ std::optional getSyncedDatabaseVersion(sqlite3 *db) const;
+
public:
static std::string sqliteFilePath;
static std::string encryptionKey;
SQLiteQueryExecutor();
~SQLiteQueryExecutor();
SQLiteQueryExecutor(std::string sqliteFilePath);
std::unique_ptr getThread(std::string threadID) const override;
std::string getDraft(std::string key) const override;
void updateDraft(std::string key, std::string text) const override;
bool moveDraft(std::string oldKey, std::string newKey) const override;
std::vector getAllDrafts() const override;
void removeAllDrafts() const override;
void removeDrafts(const std::vector &ids) const override;
void removeAllMessages() const override;
std::vector>>
getAllMessages() const override;
void removeMessages(const std::vector &ids) const override;
void removeMessagesForThreads(
const std::vector &threadIDs) const override;
void replaceMessage(const Message &message) const override;
void rekeyMessage(std::string from, std::string to) const override;
void replaceMessageStoreThreads(
const std::vector &threads) const override;
void
removeMessageStoreThreads(const std::vector &ids) const override;
void removeAllMessageStoreThreads() const override;
std::vector getAllMessageStoreThreads() const override;
void removeAllMedia() const override;
void removeMediaForMessages(
const std::vector &msg_ids) const override;
void removeMediaForMessage(std::string msg_id) const override;
void removeMediaForThreads(
const std::vector &thread_ids) const override;
void replaceMedia(const Media &media) const override;
void rekeyMediaContainers(std::string from, std::string to) const override;
std::vector getAllThreads() const override;
void removeThreads(std::vector ids) const override;
void replaceThread(const Thread &thread) const override;
void removeAllThreads() const override;
void replaceReport(const Report &report) const override;
void removeReports(const std::vector &ids) const override;
void removeAllReports() const override;
std::vector getAllReports() const override;
void setPersistStorageItem(std::string key, std::string item) const override;
void removePersistStorageItem(std::string key) const override;
std::string getPersistStorageItem(std::string key) const override;
void replaceUser(const UserInfo &user_info) const override;
void removeUsers(const std::vector &ids) const override;
void removeAllUsers() const override;
std::vector getAllUsers() const override;
void replaceKeyserver(const KeyserverInfo &keyserver_info) const override;
void removeKeyservers(const std::vector &ids) const override;
void removeAllKeyservers() const override;
std::vector getAllKeyservers() const override;
void replaceCommunity(const CommunityInfo &community_info) const override;
void removeCommunities(const std::vector &ids) const override;
void removeAllCommunities() const override;
std::vector getAllCommunities() const override;
void replaceIntegrityThreadHashes(
const std::vector &thread_hashes) const override;
void removeIntegrityThreadHashes(
const std::vector &ids) const override;
void removeAllIntegrityThreadHashes() const override;
std::vector getAllIntegrityThreadHashes() const override;
void replaceSyncedMetadataEntry(
const SyncedMetadataEntry &synced_metadata_entry) const override;
void
removeSyncedMetadata(const std::vector &names) const override;
void removeAllSyncedMetadata() const override;
std::vector getAllSyncedMetadata() const override;
void replaceAuxUserInfo(const AuxUserInfo &aux_user_info) const override;
void removeAuxUserInfos(const std::vector &ids) const override;
void removeAllAuxUserInfos() const override;
virtual std::vector getAllAuxUserInfos() const override;
void replaceThreadActivityEntry(
const ThreadActivityEntry &thread_activity_entry) const override;
void removeThreadActivityEntries(
const std::vector &ids) const override;
void removeAllThreadActivityEntries() const override;
std::vector getAllThreadActivityEntries() const override;
void replaceEntry(const EntryInfo &entry_info) const override;
void removeEntries(const std::vector &ids) const override;
void removeAllEntries() const override;
std::vector getAllEntries() const override;
void beginTransaction() const override;
void commitTransaction() const override;
void rollbackTransaction() const override;
int getContentAccountID() const override;
int getNotifsAccountID() const override;
std::vector getOlmPersistSessionsData() const override;
std::optional
getOlmPersistAccountData(int accountID) const override;
void storeOlmPersistSession(const OlmPersistSession &session) const override;
void storeOlmPersistAccount(int accountID, const std::string &accountData)
const override;
void
storeOlmPersistData(int accountID, crypto::Persist persist) const override;
void setNotifyToken(std::string token) const override;
void clearNotifyToken() const override;
void stampSQLiteDBUserID(std::string userID) const override;
std::string getSQLiteStampedUserID() const override;
void setMetadata(std::string entry_name, std::string data) const override;
void clearMetadata(std::string entry_name) const override;
std::string getMetadata(std::string entry_name) const override;
void restoreFromMainCompaction(
std::string mainCompactionPath,
- std::string mainCompactionEncryptionKey) const override;
+ std::string mainCompactionEncryptionKey,
+ std::string maxVersion) const override;
void restoreFromBackupLog(
const std::vector &backupLog) const override;
void addOutboundP2PMessages(
const std::vector &messages) const override;
std::vector getAllOutboundP2PMessages() const override;
void removeOutboundP2PMessagesOlderThan(
std::string lastConfirmedMessageID,
std::string deviceID) const override;
void removeAllOutboundP2PMessages(const std::string &deviceID) const override;
void setCiphertextForOutboundP2PMessage(
std::string messageID,
std::string deviceID,
std::string ciphertext) const override;
void markOutboundP2PMessageAsSent(std::string messageID, std::string deviceID)
const override;
void addInboundP2PMessage(InboundP2PMessage message) const override;
std::vector getAllInboundP2PMessage() const override;
void
removeInboundP2PMessages(const std::vector &ids) const override;
#ifdef EMSCRIPTEN
std::vector getAllThreadsWeb() const override;
void replaceThreadWeb(const WebThread &thread) const override;
std::vector getAllMessagesWeb() const override;
void replaceMessageWeb(const WebMessage &message) const override;
NullableString getOlmPersistAccountDataWeb(int accountID) const override;
#else
static void clearSensitiveData();
static void initialize(std::string &databasePath);
void createMainCompaction(std::string backupID) const override;
void captureBackupLogs() const override;
#endif
};
} // namespace comm
diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
index e54019642..0ab9a7648 100644
--- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
+++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
@@ -1,2535 +1,2547 @@
#include "CommCoreModule.h"
#include "../Notifications/BackgroundDataStorage/NotificationsCryptoModule.h"
#include "BaseDataStore.h"
#include "CommServicesAuthMetadataEmitter.h"
#include "DatabaseManager.h"
#include "InternalModules/GlobalDBSingleton.h"
#include "InternalModules/RustPromiseManager.h"
#include "NativeModuleUtils.h"
#include "TerminateApp.h"
#include
#include
#include
#include
#include "JSIRust.h"
#include "lib.rs.h"
#include
#include
namespace comm {
using namespace facebook::react;
jsi::Value CommCoreModule::getDraft(jsi::Runtime &rt, jsi::String key) {
std::string keyStr = key.utf8(rt);
return createPromiseAsJSIValue(
rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) {
taskType job = [=, &innerRt]() {
std::string error;
std::string draftStr;
try {
draftStr = DatabaseManager::getQueryExecutor().getDraft(keyStr);
} catch (std::system_error &e) {
error = e.what();
}
this->jsInvoker_->invokeAsync([=, &innerRt]() {
if (error.size()) {
promise->reject(error);
return;
}
jsi::String draft = jsi::String::createFromUtf8(innerRt, draftStr);
promise->resolve(std::move(draft));
});
};
GlobalDBSingleton::instance.scheduleOrRunCancellable(
job, promise, this->jsInvoker_);
});
}
jsi::Value CommCoreModule::updateDraft(
jsi::Runtime &rt,
jsi::String key,
jsi::String text) {
std::string keyStr = key.utf8(rt);
std::string textStr = text.utf8(rt);
return createPromiseAsJSIValue(
rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) {
taskType job = [=]() {
std::string error;
try {
DatabaseManager::getQueryExecutor().updateDraft(keyStr, textStr);
} catch (std::system_error &e) {
error = e.what();
}
this->jsInvoker_->invokeAsync([=]() {
if (error.size()) {
promise->reject(error);
} else {
promise->resolve(true);
}
});
};
GlobalDBSingleton::instance.scheduleOrRunCancellable(
job, promise, this->jsInvoker_);
});
}
jsi::Value CommCoreModule::moveDraft(
jsi::Runtime &rt,
jsi::String oldKey,
jsi::String newKey) {
std::string oldKeyStr = oldKey.utf8(rt);
std::string newKeyStr = newKey.utf8(rt);
return createPromiseAsJSIValue(
rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) {
taskType job = [=]() {
std::string error;
bool result = false;
try {
result = DatabaseManager::getQueryExecutor().moveDraft(
oldKeyStr, newKeyStr);
} catch (std::system_error &e) {
error = e.what();
}
this->jsInvoker_->invokeAsync([=]() {
if (error.size()) {
promise->reject(error);
} else {
promise->resolve(result);
}
});
};
GlobalDBSingleton::instance.scheduleOrRunCancellable(
job, promise, this->jsInvoker_);
});
}
jsi::Value CommCoreModule::getClientDBStore(jsi::Runtime &rt) {
return createPromiseAsJSIValue(
rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) {
taskType job = [=, &innerRt]() {
std::string error;
std::vector draftsVector;
std::vector threadsVector;
std::vector>> messagesVector;
std::vector messageStoreThreadsVector;
std::vector reportStoreVector;
std::vector userStoreVector;
std::vector keyserverStoreVector;
std::vector communityStoreVector;
std::vector integrityStoreVector;
std::vector syncedMetadataStoreVector;
std::vector auxUserStoreVector;
std::vector threadActivityStoreVector;
std::vector entryStoreVector;
try {
draftsVector = DatabaseManager::getQueryExecutor().getAllDrafts();
messagesVector =
DatabaseManager::getQueryExecutor().getAllMessages();
threadsVector = DatabaseManager::getQueryExecutor().getAllThreads();
messageStoreThreadsVector =
DatabaseManager::getQueryExecutor().getAllMessageStoreThreads();
reportStoreVector =
DatabaseManager::getQueryExecutor().getAllReports();
userStoreVector = DatabaseManager::getQueryExecutor().getAllUsers();
keyserverStoreVector =
DatabaseManager::getQueryExecutor().getAllKeyservers();
communityStoreVector =
DatabaseManager::getQueryExecutor().getAllCommunities();
integrityStoreVector = DatabaseManager::getQueryExecutor()
.getAllIntegrityThreadHashes();
syncedMetadataStoreVector =
DatabaseManager::getQueryExecutor().getAllSyncedMetadata();
auxUserStoreVector =
DatabaseManager::getQueryExecutor().getAllAuxUserInfos();
threadActivityStoreVector = DatabaseManager::getQueryExecutor()
.getAllThreadActivityEntries();
entryStoreVector =
DatabaseManager::getQueryExecutor().getAllEntries();
} catch (std::system_error &e) {
error = e.what();
}
auto draftsVectorPtr =
std::make_shared>(std::move(draftsVector));
auto messagesVectorPtr = std::make_shared<
std::vector>>>(
std::move(messagesVector));
auto threadsVectorPtr =
std::make_shared>(std::move(threadsVector));
auto messageStoreThreadsVectorPtr =
std::make_shared>(
std::move(messageStoreThreadsVector));
auto reportStoreVectorPtr = std::make_shared>(
std::move(reportStoreVector));
auto userStoreVectorPtr = std::make_shared>(
std::move(userStoreVector));
auto keyserveStoreVectorPtr =
std::make_shared>(
std::move(keyserverStoreVector));
auto communityStoreVectorPtr =
std::make_shared>(
std::move(communityStoreVector));
auto integrityStoreVectorPtr =
std::make_shared>(
std::move(integrityStoreVector));
auto syncedMetadataStoreVectorPtr =
std::make_shared>(
std::move(syncedMetadataStoreVector));
auto auxUserStoreVectorPtr =
std::make_shared>(
std::move(auxUserStoreVector));
auto threadActivityStoreVectorPtr =
std::make_shared>(
std::move(threadActivityStoreVector));
auto entryStoreVectorPtr = std::make_shared>(
std::move(entryStoreVector));
this->jsInvoker_->invokeAsync([&innerRt,
draftsVectorPtr,
messagesVectorPtr,
threadsVectorPtr,
messageStoreThreadsVectorPtr,
reportStoreVectorPtr,
userStoreVectorPtr,
keyserveStoreVectorPtr,
communityStoreVectorPtr,
integrityStoreVectorPtr,
syncedMetadataStoreVectorPtr,
auxUserStoreVectorPtr,
threadActivityStoreVectorPtr,
entryStoreVectorPtr,
error,
promise,
draftStore = this->draftStore,
threadStore = this->threadStore,
messageStore = this->messageStore,
reportStore = this->reportStore,
userStore = this->userStore,
keyserverStore = this->keyserverStore,
communityStore = this->communityStore,
integrityStore = this->integrityStore,
syncedMetadataStore =
this->syncedMetadataStore,
auxUserStore = this->auxUserStore,
threadActivityStore =
this->threadActivityStore,
entryStore = this->entryStore]() {
if (error.size()) {
promise->reject(error);
return;
}
jsi::Array jsiDrafts =
draftStore.parseDBDataStore(innerRt, draftsVectorPtr);
jsi::Array jsiMessages =
messageStore.parseDBDataStore(innerRt, messagesVectorPtr);
jsi::Array jsiThreads =
threadStore.parseDBDataStore(innerRt, threadsVectorPtr);
jsi::Array jsiMessageStoreThreads =
messageStore.parseDBMessageStoreThreads(
innerRt, messageStoreThreadsVectorPtr);
jsi::Array jsiReportStore =
reportStore.parseDBDataStore(innerRt, reportStoreVectorPtr);
jsi::Array jsiUserStore =
userStore.parseDBDataStore(innerRt, userStoreVectorPtr);
jsi::Array jsiKeyserverStore = keyserverStore.parseDBDataStore(
innerRt, keyserveStoreVectorPtr);
jsi::Array jsiCommunityStore = communityStore.parseDBDataStore(
innerRt, communityStoreVectorPtr);
jsi::Array jsiIntegrityStore = integrityStore.parseDBDataStore(
innerRt, integrityStoreVectorPtr);
jsi::Array jsiSyncedMetadataStore =
syncedMetadataStore.parseDBDataStore(
innerRt, syncedMetadataStoreVectorPtr);
jsi::Array jsiAuxUserStore =
auxUserStore.parseDBDataStore(innerRt, auxUserStoreVectorPtr);
jsi::Array jsiThreadActivityStore =
threadActivityStore.parseDBDataStore(
innerRt, threadActivityStoreVectorPtr);
jsi::Array jsiEntryStore =
entryStore.parseDBDataStore(innerRt, entryStoreVectorPtr);
auto jsiClientDBStore = jsi::Object(innerRt);
jsiClientDBStore.setProperty(innerRt, "messages", jsiMessages);
jsiClientDBStore.setProperty(innerRt, "threads", jsiThreads);
jsiClientDBStore.setProperty(innerRt, "drafts", jsiDrafts);
jsiClientDBStore.setProperty(
innerRt, "messageStoreThreads", jsiMessageStoreThreads);
jsiClientDBStore.setProperty(innerRt, "reports", jsiReportStore);
jsiClientDBStore.setProperty(innerRt, "users", jsiUserStore);
jsiClientDBStore.setProperty(
innerRt, "keyservers", jsiKeyserverStore);
jsiClientDBStore.setProperty(
innerRt, "communities", jsiCommunityStore);
jsiClientDBStore.setProperty(
innerRt, "integrityThreadHashes", jsiIntegrityStore);
jsiClientDBStore.setProperty(
innerRt, "syncedMetadata", jsiSyncedMetadataStore);
jsiClientDBStore.setProperty(
innerRt, "auxUserInfos", jsiAuxUserStore);
jsiClientDBStore.setProperty(
innerRt, "threadActivityEntries", jsiThreadActivityStore);
jsiClientDBStore.setProperty(innerRt, "entries", jsiEntryStore);
promise->resolve(std::move(jsiClientDBStore));
});
};
GlobalDBSingleton::instance.scheduleOrRunCancellable(
job, promise, this->jsInvoker_);
});
}
jsi::Value CommCoreModule::removeAllDrafts(jsi::Runtime &rt) {
return createPromiseAsJSIValue(
rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) {
taskType job = [=]() {
std::string error;
try {
DatabaseManager::getQueryExecutor().removeAllDrafts();
} catch (std::system_error &e) {
error = e.what();
}
this->jsInvoker_->invokeAsync([=]() {
if (error.size()) {
promise->reject(error);
return;
}
promise->resolve(jsi::Value::undefined());
});
};
GlobalDBSingleton::instance.scheduleOrRunCancellable(
job, promise, this->jsInvoker_);
});
}
jsi::Array CommCoreModule::getAllMessagesSync(jsi::Runtime &rt) {
auto messagesVector = NativeModuleUtils::runSyncOrThrowJSError<
std::vector>>>(rt, []() {
return DatabaseManager::getQueryExecutor().getAllMessages();
});
auto messagesVectorPtr =
std::make_shared>>>(
std::move(messagesVector));
jsi::Array jsiMessages =
this->messageStore.parseDBDataStore(rt, messagesVectorPtr);
return jsiMessages;
}
void CommCoreModule::processMessageStoreOperationsSync(
jsi::Runtime &rt,
jsi::Array operations) {
return this->messageStore.processStoreOperationsSync(
rt, std::move(operations));
}
jsi::Array CommCoreModule::getAllThreadsSync(jsi::Runtime &rt) {
auto threadsVector =
NativeModuleUtils::runSyncOrThrowJSError>(rt, []() {
return DatabaseManager::getQueryExecutor().getAllThreads();
});
auto threadsVectorPtr =
std::make_shared>(std::move(threadsVector));
jsi::Array jsiThreads =
this->threadStore.parseDBDataStore(rt, threadsVectorPtr);
return jsiThreads;
}
void CommCoreModule::processThreadStoreOperationsSync(
jsi::Runtime &rt,
jsi::Array operations) {
this->threadStore.processStoreOperationsSync(rt, std::move(operations));
}
void CommCoreModule::processReportStoreOperationsSync(
jsi::Runtime &rt,
jsi::Array operations) {
this->reportStore.processStoreOperationsSync(rt, std::move(operations));
}
template
void CommCoreModule::appendDBStoreOps(
jsi::Runtime &rt,
jsi::Object &operations,
const char *key,
T &store,
std::shared_ptr>>
&destination) {
auto opsObject = operations.getProperty(rt, key);
if (opsObject.isObject()) {
auto ops = store.createOperations(rt, opsObject.asObject(rt).asArray(rt));
std::move(
std::make_move_iterator(ops.begin()),
std::make_move_iterator(ops.end()),
std::back_inserter(*destination));
}
}
jsi::Value CommCoreModule::processDBStoreOperations(
jsi::Runtime &rt,
jsi::Object operations) {
std::string createOperationsError;
auto storeOpsPtr =
std::make_shared>>();
try {
this->appendDBStoreOps(
rt, operations, "draftStoreOperations", this->draftStore, storeOpsPtr);
this->appendDBStoreOps(
rt,
operations,
"threadStoreOperations",
this->threadStore,
storeOpsPtr);
this->appendDBStoreOps(
rt,
operations,
"messageStoreOperations",
this->messageStore,
storeOpsPtr);
this->appendDBStoreOps(
rt,
operations,
"reportStoreOperations",
this->reportStore,
storeOpsPtr);
this->appendDBStoreOps(
rt, operations, "userStoreOperations", this->userStore, storeOpsPtr);
this->appendDBStoreOps(
rt,
operations,
"keyserverStoreOperations",
this->keyserverStore,
storeOpsPtr);
this->appendDBStoreOps(
rt,
operations,
"communityStoreOperations",
this->communityStore,
storeOpsPtr);
this->appendDBStoreOps(
rt,
operations,
"integrityStoreOperations",
this->integrityStore,
storeOpsPtr);
this->appendDBStoreOps(
rt,
operations,
"syncedMetadataStoreOperations",
this->syncedMetadataStore,
storeOpsPtr);
this->appendDBStoreOps(
rt,
operations,
"auxUserStoreOperations",
this->auxUserStore,
storeOpsPtr);
this->appendDBStoreOps(
rt,
operations,
"threadActivityStoreOperations",
this->threadActivityStore,
storeOpsPtr);
this->appendDBStoreOps(
rt, operations, "entryStoreOperations", this->entryStore, storeOpsPtr);
} catch (std::runtime_error &e) {
createOperationsError = e.what();
}
std::vector messages;
try {
auto messagesJSIObj = operations.getProperty(rt, "outboundP2PMessages");
if (messagesJSIObj.isObject()) {
auto messagesJSI = messagesJSIObj.asObject(rt).asArray(rt);
for (size_t idx = 0; idx < messagesJSI.size(rt); idx++) {
jsi::Object msgObj = messagesJSI.getValueAtIndex(rt, idx).asObject(rt);
std::string messageID =
msgObj.getProperty(rt, "messageID").asString(rt).utf8(rt);
std::string deviceID =
msgObj.getProperty(rt, "deviceID").asString(rt).utf8(rt);
std::string userID =
msgObj.getProperty(rt, "userID").asString(rt).utf8(rt);
std::string timestamp =
msgObj.getProperty(rt, "timestamp").asString(rt).utf8(rt);
std::string plaintext =
msgObj.getProperty(rt, "plaintext").asString(rt).utf8(rt);
std::string ciphertext =
msgObj.getProperty(rt, "ciphertext").asString(rt).utf8(rt);
std::string status =
msgObj.getProperty(rt, "status").asString(rt).utf8(rt);
OutboundP2PMessage outboundMessage{
messageID,
deviceID,
userID,
timestamp,
plaintext,
ciphertext,
status};
messages.push_back(outboundMessage);
}
}
} catch (std::runtime_error &e) {
createOperationsError = e.what();
}
return facebook::react::createPromiseAsJSIValue(
rt,
[=](jsi::Runtime &innerRt,
std::shared_ptr promise) {
taskType job = [=]() {
std::string error = createOperationsError;
if (!error.size()) {
try {
DatabaseManager::getQueryExecutor().beginTransaction();
for (const auto &operation : *storeOpsPtr) {
operation->execute();
}
if (messages.size() > 0) {
DatabaseManager::getQueryExecutor().addOutboundP2PMessages(
messages);
}
DatabaseManager::getQueryExecutor().captureBackupLogs();
DatabaseManager::getQueryExecutor().commitTransaction();
} catch (std::system_error &e) {
error = e.what();
DatabaseManager::getQueryExecutor().rollbackTransaction();
}
}
if (!error.size()) {
::triggerBackupFileUpload();
}
this->jsInvoker_->invokeAsync([=]() {
if (error.size()) {
promise->reject(error);
} else {
promise->resolve(jsi::Value::undefined());
}
});
};
GlobalDBSingleton::instance.scheduleOrRunCancellable(
job, promise, this->jsInvoker_);
});
}
void CommCoreModule::terminate(jsi::Runtime &rt) {
TerminateApp::terminate();
}
const std::string
getAccountDataKey(const std::string secureStoreAccountDataKey) {
folly::Optional storedSecretKey =
CommSecureStore::get(secureStoreAccountDataKey);
if (!storedSecretKey.hasValue()) {
storedSecretKey = crypto::Tools::generateRandomString(64);
CommSecureStore::set(secureStoreAccountDataKey, storedSecretKey.value());
}
return storedSecretKey.value();
}
void CommCoreModule::persistCryptoModules(
bool persistContentModule,
bool persistNotifsModule) {
std::string storedSecretKey = getAccountDataKey(secureStoreAccountDataKey);
if (!persistContentModule && !persistNotifsModule) {
return;
}
crypto::Persist newContentPersist;
if (persistContentModule) {
newContentPersist = this->contentCryptoModule->storeAsB64(storedSecretKey);
}
crypto::Persist newNotifsPersist;
if (persistNotifsModule) {
newNotifsPersist = this->notifsCryptoModule->storeAsB64(storedSecretKey);
}
std::promise persistencePromise;
std::future persistenceFuture = persistencePromise.get_future();
GlobalDBSingleton::instance.scheduleOrRunCancellable(
[=, &persistencePromise]() {
try {
DatabaseManager::getQueryExecutor().beginTransaction();
if (persistContentModule) {
DatabaseManager::getQueryExecutor().storeOlmPersistData(
DatabaseManager::getQueryExecutor().getContentAccountID(),
newContentPersist);
}
if (persistNotifsModule) {
DatabaseManager::getQueryExecutor().storeOlmPersistData(
DatabaseManager::getQueryExecutor().getNotifsAccountID(),
newNotifsPersist);
}
DatabaseManager::getQueryExecutor().commitTransaction();
persistencePromise.set_value();
} catch (std::system_error &e) {
DatabaseManager::getQueryExecutor().rollbackTransaction();
persistencePromise.set_exception(std::make_exception_ptr(e));
}
});
persistenceFuture.get();
}
jsi::Value CommCoreModule::initializeCryptoAccount(jsi::Runtime &rt) {
folly::Optional