diff --git a/keyserver/src/database/database.js b/keyserver/src/database/database.js
--- a/keyserver/src/database/database.js
+++ b/keyserver/src/database/database.js
@@ -8,7 +8,12 @@
 import { connectionLimit, queryWarnTime } from './consts.js';
 import { getDBConfig } from './db-config.js';
 import DatabaseMonitor from './monitor.js';
-import type { Pool, SQLOrString, SQLStatementType } from './types.js';
+import type {
+  Connection,
+  Pool,
+  SQLOrString,
+  SQLStatementType,
+} from './types.js';
 import { getScriptContext } from '../scripts/script-context.js';
 
 const SQLStatement: SQLStatementType = SQL.SQLStatement;
@@ -184,7 +189,7 @@
   return mysql.format(statement.sql, statement.values);
 }
 
-async function getMultipleStatementsConnection() {
+async function getMultipleStatementsConnection(): Promise<Connection> {
   const { dbType, ...dbConfig } = await getDBConfig();
   const options: ConnectionOptions = {
     ...dbConfig,
@@ -203,4 +208,5 @@
   setConnectionContext,
   dbQuery,
   rawSQL,
+  getMultipleStatementsConnection,
 };
diff --git a/keyserver/src/database/types.js b/keyserver/src/database/types.js
--- a/keyserver/src/database/types.js
+++ b/keyserver/src/database/types.js
@@ -25,3 +25,8 @@
 };
 
 export type SQLOrString = SQLStatementType | string;
+
+export type Connection = {
+  +query: (input: SQLOrString) => Promise<any>,
+  +end: () => void,
+};
diff --git a/keyserver/src/updaters/olm-account-updater.js b/keyserver/src/updaters/olm-account-updater.js
new file mode 100644
--- /dev/null
+++ b/keyserver/src/updaters/olm-account-updater.js
@@ -0,0 +1,114 @@
+// @flow
+
+import type { Account as OlmAccount } from '@commapp/olm';
+
+import { ServerError } from 'lib/utils/errors.js';
+
+import {
+  SQL,
+  dbQuery,
+  getMultipleStatementsConnection,
+} from '../database/database.js';
+import { unpickleOlmAccount } from '../utils/olm-utils.js';
+
+// Loads the olm account of the given type inside a transaction, locking its
+// row with SELECT ... FOR UPDATE, passes the unpickled account to `callback`,
+// then stores the re-pickled account and commits. Returns whatever `callback`
+// resolves to. Any error (including a missing account, reported as
+// ServerError('missing_olm_account')) rolls the transaction back and is
+// rethrown; the connection is ended on every path.
+async function fetchCallUpdateOlmAccount<T>(
+  olmAccountType: 'content' | 'notifications',
+  callback: (account: OlmAccount) => Promise<T>,
+): Promise<T> {
+  const isContent = olmAccountType === 'content';
+  const connection = await getMultipleStatementsConnection();
+  try {
+    await connection.query(
+      SQL`
+        START TRANSACTION
+      `,
+    );
+    try {
+      const [olmAccountResult] = await connection.query(
+        SQL`
+          SELECT pickling_key, pickled_olm_account
+          FROM olm_accounts
+          WHERE is_content = ${isContent}
+          FOR UPDATE
+        `,
+      );
+      if (olmAccountResult.length === 0) {
+        throw new ServerError('missing_olm_account');
+      }
+
+      const picklingKey = olmAccountResult[0].pickling_key;
+      const pickledAccount = olmAccountResult[0].pickled_olm_account;
+
+      const account = await unpickleOlmAccount({
+        picklingKey,
+        pickledAccount,
+      });
+      const result = await callback(account);
+      const updatedPickledAccount = account.pickle(picklingKey);
+
+      await connection.query(
+        SQL`
+          UPDATE olm_accounts
+          SET pickled_olm_account = ${updatedPickledAccount}
+          WHERE is_content = ${isContent}
+        `,
+      );
+      await connection.query(
+        SQL`
+          COMMIT
+        `,
+      );
+      return result;
+    } catch (e) {
+      // Discard any partial work and release the row lock before rethrowing.
+      await connection.query(
+        SQL`
+          ROLLBACK
+        `,
+      );
+      throw e;
+    }
+  } finally {
+    // This connection is opened separately from the pool (see
+    // getMultipleStatementsConnection), so it must be ended on every path.
+    await connection.end();
+  }
+}
+
+// Reads and unpickles the olm account of the given type without locking the
+// row or writing it back. Throws ServerError('missing_olm_account') if absent.
+async function fetchOlmAccount(
+  olmAccountType: 'content' | 'notifications',
+): Promise<{
+  account: OlmAccount,
+  picklingKey: string,
+}> {
+  const isContent = olmAccountType === 'content';
+  const [olmAccountResult] = await dbQuery(
+    SQL`
+      SELECT pickling_key, pickled_olm_account
+      FROM olm_accounts
+      WHERE is_content = ${isContent}
+    `,
+  );
+  if (olmAccountResult.length === 0) {
+    throw new ServerError('missing_olm_account');
+  }
+  const picklingKey = olmAccountResult[0].pickling_key;
+  const pickledAccount = olmAccountResult[0].pickled_olm_account;
+
+  const account = await unpickleOlmAccount({
+    picklingKey,
+    pickledAccount,
+  });
+
+  return { account, picklingKey };
+}
+
+export { fetchCallUpdateOlmAccount, fetchOlmAccount };