diff --git a/lib/components/secondary-device-qr-auth-context-provider.react.js b/lib/components/secondary-device-qr-auth-context-provider.react.js --- a/lib/components/secondary-device-qr-auth-context-provider.react.js +++ b/lib/components/secondary-device-qr-auth-context-provider.react.js @@ -11,6 +11,7 @@ import { isLoggedIn } from '../selectors/user-selectors.js'; import { IdentityClientContext } from '../shared/identity-client-context.js'; import { useTunnelbroker } from '../tunnelbroker/tunnelbroker-context.js'; +import { databaseIdentifier } from '../types/database-identifier-types.js'; import { platformToIdentityDeviceType } from '../types/identity-service-types.js'; import type { IdentityAuthResult } from '../types/identity-service-types.js'; import { @@ -191,6 +192,7 @@ await sqliteAPI.migrateBackupSchema(); await sqliteAPI.copyContentFromBackupDatabase(); const clientDBStore = await sqliteAPI.getClientDBStore( + databaseIdentifier.MAIN, identityAuthResult.userID, ); dispatch({ diff --git a/lib/types/sqlite-types.js b/lib/types/sqlite-types.js --- a/lib/types/sqlite-types.js +++ b/lib/types/sqlite-types.js @@ -70,7 +70,10 @@ +fetchDMOperationsByType: ( type: string, ) => Promise>, - +getClientDBStore: (currentUserID: ?string) => Promise, + +getClientDBStore: ( + dbID: DatabaseIdentifier, + currentUserID: ?string, + ) => Promise, // write operations +removeInboundP2PMessages: (ids: $ReadOnlyArray) => Promise, diff --git a/native/account/restore.js b/native/account/restore.js --- a/native/account/restore.js +++ b/native/account/restore.js @@ -15,6 +15,7 @@ useWalletLogIn, } from 'lib/hooks/login-hooks.js'; import { IdentityClientContext } from 'lib/shared/identity-client-context.js'; +import { databaseIdentifier } from 'lib/types/database-identifier-types.js'; import { type IdentityAuthResult, type SignedDeviceList, @@ -198,6 +199,7 @@ await sqliteAPI.restoreUserData(backupData, identityAuthResult); const clientDBStore = await sqliteAPI.getClientDBStore( + databaseIdentifier.MAIN, identityAuthResult.userID, ); dispatch({ diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.h +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.h @@ -64,7 +64,8 @@ virtual jsi::Value updateDraft(jsi::Runtime &rt, jsi::String key, jsi::String text) override; - virtual jsi::Value getClientDBStore(jsi::Runtime &rt) override; + virtual jsi::Value + getClientDBStore(jsi::Runtime &rt, std::optional dbID) override; virtual jsi::Array getInitialMessagesSync(jsi::Runtime &rt) override; virtual void processReportStoreOperationsSync( jsi::Runtime &rt, diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp --- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp +++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp @@ -52,7 +52,14 @@ }); } -jsi::Value CommCoreModule::getClientDBStore(jsi::Runtime &rt) { +jsi::Value CommCoreModule::getClientDBStore( + jsi::Runtime &rt, + std::optional dbID) { + DatabaseIdentifier identifier = DatabaseIdentifier::MAIN; + if (dbID.has_value()) { + identifier = stringToDatabaseIdentifier(dbID->utf8(rt)); + } + return createPromiseAsJSIValue( rt, [=](jsi::Runtime &innerRt, std::shared_ptr promise) { taskType job = [=, &innerRt]() { @@ -73,34 +80,40 @@ std::vector messageStoreLocalMessageInfosVector; std::vector dmOperationsVector; try { - draftsVector = DatabaseManager::getQueryExecutor().getAllDrafts(); - messagesVector = - DatabaseManager::getQueryExecutor().getInitialMessages(); - threadsVector = DatabaseManager::getQueryExecutor().getAllThreads(); + draftsVector = + DatabaseManager::getQueryExecutor(identifier).getAllDrafts(); + messagesVector = DatabaseManager::getQueryExecutor(identifier) + .getInitialMessages(); + threadsVector = + DatabaseManager::getQueryExecutor(identifier).getAllThreads(); messageStoreThreadsVector = - DatabaseManager::getQueryExecutor().getAllMessageStoreThreads(); + DatabaseManager::getQueryExecutor(identifier) + .getAllMessageStoreThreads(); reportStoreVector = - DatabaseManager::getQueryExecutor().getAllReports(); - userStoreVector = DatabaseManager::getQueryExecutor().getAllUsers(); - keyserverStoreVector = - DatabaseManager::getQueryExecutor().getAllKeyservers(); - communityStoreVector = - DatabaseManager::getQueryExecutor().getAllCommunities(); - integrityStoreVector = DatabaseManager::getQueryExecutor() + DatabaseManager::getQueryExecutor(identifier).getAllReports(); + userStoreVector = + DatabaseManager::getQueryExecutor(identifier).getAllUsers(); + keyserverStoreVector = DatabaseManager::getQueryExecutor(identifier) + .getAllKeyservers(); + communityStoreVector = DatabaseManager::getQueryExecutor(identifier) + .getAllCommunities(); + integrityStoreVector = DatabaseManager::getQueryExecutor(identifier) .getAllIntegrityThreadHashes(); syncedMetadataStoreVector = - DatabaseManager::getQueryExecutor().getAllSyncedMetadata(); - auxUserStoreVector = - DatabaseManager::getQueryExecutor().getAllAuxUserInfos(); - threadActivityStoreVector = DatabaseManager::getQueryExecutor() - .getAllThreadActivityEntries(); + DatabaseManager::getQueryExecutor(identifier) + .getAllSyncedMetadata(); + auxUserStoreVector = DatabaseManager::getQueryExecutor(identifier) + .getAllAuxUserInfos(); + threadActivityStoreVector = + DatabaseManager::getQueryExecutor(identifier) + .getAllThreadActivityEntries(); entryStoreVector = - DatabaseManager::getQueryExecutor().getAllEntries(); + DatabaseManager::getQueryExecutor(identifier).getAllEntries(); messageStoreLocalMessageInfosVector = - DatabaseManager::getQueryExecutor() + DatabaseManager::getQueryExecutor(identifier) .getAllMessageStoreLocalMessageInfos(); dmOperationsVector = - DatabaseManager::getQueryExecutor().getDMOperations(); + DatabaseManager::getQueryExecutor(identifier).getDMOperations(); } catch (std::system_error &e) { error = e.what(); } diff --git a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp --- a/native/cpp/CommonCpp/_generated/commJSI-generated.cpp +++ b/native/cpp/CommonCpp/_generated/commJSI-generated.cpp @@ -16,7 +16,7 @@ return static_cast(&turboModule)->updateDraft(rt, args[0].asString(rt), args[1].asString(rt)); } static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getClientDBStore(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { - return static_cast(&turboModule)->getClientDBStore(rt); + return static_cast(&turboModule)->getClientDBStore(rt, args[0].isNull() || args[0].isUndefined() ? std::nullopt : std::make_optional(args[0].asString(rt))); } static jsi::Value __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getInitialMessagesSync(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { return static_cast(&turboModule)->getInitialMessagesSync(rt); @@ -239,7 +239,7 @@ CommCoreModuleSchemaCxxSpecJSI::CommCoreModuleSchemaCxxSpecJSI(std::shared_ptr jsInvoker) : TurboModule("CommTurboModule", jsInvoker) { methodMap_["updateDraft"] = MethodMetadata {2, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_updateDraft}; - methodMap_["getClientDBStore"] = MethodMetadata {0, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getClientDBStore}; + methodMap_["getClientDBStore"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getClientDBStore}; methodMap_["getInitialMessagesSync"] = MethodMetadata {0, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getInitialMessagesSync}; methodMap_["processMessageStoreOperationsSync"] = MethodMetadata {1, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_processMessageStoreOperationsSync}; methodMap_["getAllThreadsSync"] = MethodMetadata {0, __hostFunction_CommCoreModuleSchemaCxxSpecJSI_getAllThreadsSync}; diff --git a/native/cpp/CommonCpp/_generated/commJSI.h b/native/cpp/CommonCpp/_generated/commJSI.h --- a/native/cpp/CommonCpp/_generated/commJSI.h +++ b/native/cpp/CommonCpp/_generated/commJSI.h @@ -21,7 +21,7 @@ public: virtual jsi::Value updateDraft(jsi::Runtime &rt, jsi::String key, jsi::String text) = 0; - virtual jsi::Value getClientDBStore(jsi::Runtime &rt) = 0; + virtual jsi::Value getClientDBStore(jsi::Runtime &rt, std::optional dbID) = 0; virtual jsi::Array getInitialMessagesSync(jsi::Runtime &rt) = 0; virtual void processMessageStoreOperationsSync(jsi::Runtime &rt, jsi::Array operations) = 0; virtual jsi::Array getAllThreadsSync(jsi::Runtime &rt) = 0; @@ -121,13 +121,13 @@ return bridging::callFromJs( rt, &T::updateDraft, jsInvoker_, instance_, std::move(key), std::move(text)); } - jsi::Value getClientDBStore(jsi::Runtime &rt) override { + jsi::Value getClientDBStore(jsi::Runtime &rt, std::optional dbID) override { static_assert( - bridging::getParameterCount(&T::getClientDBStore) == 1, - "Expected getClientDBStore(...) to have 1 parameters"); + bridging::getParameterCount(&T::getClientDBStore) == 2, + "Expected getClientDBStore(...) to have 2 parameters"); return bridging::callFromJs( - rt, &T::getClientDBStore, jsInvoker_, instance_); + rt, &T::getClientDBStore, jsInvoker_, instance_, std::move(dbID)); } jsi::Array getInitialMessagesSync(jsi::Runtime &rt) override { static_assert( diff --git a/native/data/sqlite-data-handler.js b/native/data/sqlite-data-handler.js --- a/native/data/sqlite-data-handler.js +++ b/native/data/sqlite-data-handler.js @@ -15,6 +15,7 @@ recoveryFromDataHandlerActionSources, type RecoveryFromDataHandlerActionSource, } from 'lib/types/account-types.js'; +import { databaseIdentifier } from 'lib/types/database-identifier-types.js'; import { getConfig } from 'lib/utils/config.js'; import { getMessageForException } from 'lib/utils/errors.js'; import { useDispatchActionPromise } from 'lib/utils/redux-promise-utils.js'; @@ -229,6 +230,7 @@ try { const { sqliteAPI } = getConfig(); const clientDBStore = await sqliteAPI.getClientDBStore( + databaseIdentifier.MAIN, currentLoggedInUserID, ); dispatch({ diff --git a/native/database/store.js b/native/database/store.js --- a/native/database/store.js +++ b/native/database/store.js @@ -10,12 +10,16 @@ import { threadActivityStoreOpsHandlers } from 'lib/ops/thread-activity-store-ops.js'; import { threadStoreOpsHandlers } from 'lib/ops/thread-store-ops.js'; import { userStoreOpsHandlers } from 'lib/ops/user-store-ops.js'; +import type { DatabaseIdentifier } from 'lib/types/database-identifier-types'; import type { ClientStore } from 'lib/types/store-ops-types.js'; import { translateClientDBLocalMessageInfos } from 'lib/utils/message-ops-utils.js'; import { commCoreModule } from '../native-modules.js'; -async function getClientDBStore(currentUserID: ?string): Promise { +async function getClientDBStore( + dbID: DatabaseIdentifier, + currentUserID: ?string, +): Promise { const { threads, messages, @@ -31,7 +35,7 @@ threadActivityEntries, entries, messageStoreLocalMessageInfos, - } = await commCoreModule.getClientDBStore(); + } = await commCoreModule.getClientDBStore(dbID); const threadInfosFromDB = threadStoreOpsHandlers.translateClientDBData(threads); const reportsFromDB = reportStoreOpsHandlers.translateClientDBData(reports); diff --git a/native/schema/CommCoreModuleSchema.js b/native/schema/CommCoreModuleSchema.js --- a/native/schema/CommCoreModuleSchema.js +++ b/native/schema/CommCoreModuleSchema.js @@ -37,7 +37,8 @@ interface Spec extends TurboModule { +updateDraft: (key: string, text: string) => Promise; - +getClientDBStore: () => Promise; + // This type should be DatabaseIdentifier + +getClientDBStore: (dbID?: ?string) => Promise; +getInitialMessagesSync: () => $ReadOnlyArray; +processMessageStoreOperationsSync: ( operations: $ReadOnlyArray, @@ -253,7 +254,7 @@ +getSIWEBackupSecrets: () => Promise; +processDBStoreOperations: ( operations: ClientDBStoreOperations, - //This type should be DatabaseIdentifier + // This type should be DatabaseIdentifier dbID?: ?string, ) => Promise; +getQRAuthBackupData: () => Promise; diff --git a/web/database/store.js b/web/database/store.js --- a/web/database/store.js +++ b/web/database/store.js @@ -10,6 +10,7 @@ import { threadActivityStoreOpsHandlers } from 'lib/ops/thread-activity-store-ops.js'; import { threadStoreOpsHandlers } from 'lib/ops/thread-store-ops.js'; import { userStoreOpsHandlers } from 'lib/ops/user-store-ops.js'; +import type { DatabaseIdentifier } from 'lib/types/database-identifier-types.js'; import type { ClientStore } from 'lib/types/store-ops-types.js'; import { translateClientDBLocalMessageInfos } from 'lib/utils/message-ops-utils.js'; @@ -17,7 +18,10 @@ import { getCommSharedWorker } from '../shared-worker/shared-worker-provider.js'; import { workerRequestMessageTypes } from '../types/worker-types.js'; -async function getClientDBStore(currentUserID: ?string): Promise { +async function getClientDBStore( + dbID: DatabaseIdentifier, + currentUserID: ?string, +): Promise { const sharedWorker = await getCommSharedWorker(); let result: ClientStore = { currentUserID, @@ -38,6 +42,7 @@ }; const data = await sharedWorker.schedule({ type: workerRequestMessageTypes.GET_CLIENT_STORE, + dbID, }); if (data?.store?.drafts) { result = { diff --git a/web/redux/initial-state-gate.js b/web/redux/initial-state-gate.js --- a/web/redux/initial-state-gate.js +++ b/web/redux/initial-state-gate.js @@ -77,7 +77,10 @@ void (async () => { try { const { sqliteAPI } = getConfig(); - const clientDBStore = await sqliteAPI.getClientDBStore(null); + const clientDBStore = await sqliteAPI.getClientDBStore( + databaseIdentifier.MAIN, + null, + ); dispatch({ type: setClientDBStoreActionType, payload: clientDBStore, diff --git a/web/shared-worker/worker/shared-worker.js b/web/shared-worker/worker/shared-worker.js --- a/web/shared-worker/worker/shared-worker.js +++ b/web/shared-worker/worker/shared-worker.js @@ -224,6 +224,21 @@ // read-only operations if (message.type === workerRequestMessageTypes.GET_CLIENT_STORE) { + if (message.dbID && message.dbID === databaseIdentifier.RESTORED) { + const backupQueryExecutor = getSQLiteQueryExecutor( + databaseIdentifier.RESTORED, + ); + if (!backupQueryExecutor) { + throw new Error( + `Backup not initialized, unable to process request type: ${message.type}`, + ); + } + return { + type: workerResponseMessageTypes.CLIENT_STORE, + store: getClientStoreFromQueryExecutor(backupQueryExecutor), + }; + } + return { type: workerResponseMessageTypes.CLIENT_STORE, store: getClientStoreFromQueryExecutor(sqliteQueryExecutor), diff --git a/web/types/worker-types.js b/web/types/worker-types.js --- a/web/types/worker-types.js +++ b/web/types/worker-types.js @@ -108,6 +108,7 @@ export type GetClientStoreRequestMessage = { +type: 4, + +dbID?: DatabaseIdentifier, }; export type SetCurrentUserIDRequestMessage = {