diff --git a/keyserver/src/push/rescind.js b/keyserver/src/push/rescind.js --- a/keyserver/src/push/rescind.js +++ b/keyserver/src/push/rescind.js @@ -309,7 +309,7 @@ platformDetails: PlatformDetails, devices: $ReadOnlyArray, ): Promise<$ReadOnlyArray> { - threadID = validateOutput(platformDetails, tID, threadID); + threadID = await validateOutput(platformDetails, tID, threadID); const { codeVersion } = platformDetails; const notification = new apn.Notification(); @@ -356,7 +356,7 @@ platformDetails: PlatformDetails, devices: $ReadOnlyArray, ): Promise<$ReadOnlyArray> { - threadID = validateOutput(platformDetails, tID, threadID); + threadID = await validateOutput(platformDetails, tID, threadID); const { codeVersion } = platformDetails; const notification = { diff --git a/keyserver/src/push/send.js b/keyserver/src/push/send.js --- a/keyserver/src/push/send.js +++ b/keyserver/src/push/send.js @@ -907,7 +907,7 @@ inputData: APNsNotifInputData, devices: $ReadOnlyArray, ): Promise<$ReadOnlyArray> { - const convertedData = validateOutput( + const convertedData = await validateOutput( inputData.platformDetails, apnsNotifInputDataValidator, inputData, @@ -1087,7 +1087,7 @@ inputData: AndroidNotifInputData, devices: $ReadOnlyArray, ): Promise<$ReadOnlyArray> { - const convertedData = validateOutput( + const convertedData = await validateOutput( inputData.platformDetails, androidNotifInputDataValidator, inputData, @@ -1241,7 +1241,7 @@ inputData: WebNotifInputData, devices: $ReadOnlyArray, ): Promise<$ReadOnlyArray> { - const convertedData = validateOutput( + const convertedData = await validateOutput( inputData.platformDetails, webNotifInputDataValidator, inputData, @@ -1283,7 +1283,7 @@ devices: $ReadOnlyArray, inputData: WNSNotifInputData, ): Promise<$ReadOnlyArray> { - const convertedData = validateOutput( + const convertedData = await validateOutput( inputData.platformDetails, wnsNotifInputDataValidator, inputData, diff --git a/keyserver/src/responders/handlers.js b/keyserver/src/responders/handlers.js --- a/keyserver/src/responders/handlers.js +++ b/keyserver/src/responders/handlers.js @@ -44,7 +44,11 @@ responder: async (viewer, input) => { const request = await validateInput(viewer, inputValidator, input); const result = await responder(viewer, request); - return validateOutput(viewer.platformDetails, outputValidator, result); + return await validateOutput( + viewer.platformDetails, + outputValidator, + result, + ); }, requiredPolicies, }; diff --git a/keyserver/src/responders/redux-state-responders.js b/keyserver/src/responders/redux-state-responders.js --- a/keyserver/src/responders/redux-state-responders.js +++ b/keyserver/src/responders/redux-state-responders.js @@ -44,7 +44,7 @@ import { ServerError } from 'lib/utils/errors.js'; import { promiseAll } from 'lib/utils/promises.js'; import { urlInfoValidator } from 'lib/utils/url-utils.js'; -import { tShape, ashoatKeyserverID, tID } from 'lib/utils/validation-utils.js'; +import { tShape, tID } from 'lib/utils/validation-utils.js'; import type { InitialReduxStateRequest, ExcludedData, @@ -64,6 +64,7 @@ import { getWebPushConfig } from '../push/providers.js'; import { setNewSession } from '../session/cookies.js'; import { Viewer } from '../session/viewer.js'; +import { thisKeyserverID } from '../user/identity.js'; const excludedDataValidator: TInterface = tShape({ threadStore: t.maybe(t.Bool), @@ -190,24 +191,26 @@ { threadInfos }, { rawMessageInfos, truncationStatuses }, hasNotAcknowledgedPolicies, + keyserverID, ] = await Promise.all([ threadInfoPromise, messageInfoPromise, hasNotAcknowledgedPoliciesPromise, + thisKeyserverID(), ]); if (hasNotAcknowledgedPolicies) { return { messages: {}, threads: {}, local: {}, - currentAsOf: { [ashoatKeyserverID]: 0 }, + currentAsOf: { [keyserverID]: 0 }, }; } const { messageStore: freshStore } = freshMessageStore( rawMessageInfos, truncationStatuses, { - [ashoatKeyserverID]: mostRecentMessageTimestamp( + [keyserverID]: mostRecentMessageTimestamp( rawMessageInfos, serverUpdatesCurrentAsOf, ), diff --git a/keyserver/src/shared/state-sync/current-user-state-sync-spec.js b/keyserver/src/shared/state-sync/current-user-state-sync-spec.js --- a/keyserver/src/shared/state-sync/current-user-state-sync-spec.js +++ b/keyserver/src/shared/state-sync/current-user-state-sync-spec.js @@ -22,7 +22,7 @@ }, async fetchServerInfosHash(viewer: Viewer) { const info = await fetch(viewer); - return getHash(info); + return await getHash(info); }, getServerInfosHash: getHash, getServerInfoHash: getHash, @@ -33,6 +33,11 @@ return fetchCurrentUserInfo(viewer); } -function getHash(currentUserInfo: CurrentUserInfo) { - return hash(validateOutput(null, currentUserInfoValidator, currentUserInfo)); +async function getHash(currentUserInfo: CurrentUserInfo) { + const output = await validateOutput( + null, + currentUserInfoValidator, + currentUserInfo, + ); + return hash(output); } diff --git a/keyserver/src/shared/state-sync/entries-state-sync-spec.js b/keyserver/src/shared/state-sync/entries-state-sync-spec.js --- a/keyserver/src/shared/state-sync/entries-state-sync-spec.js +++ b/keyserver/src/shared/state-sync/entries-state-sync-spec.js @@ -35,7 +35,7 @@ }, async fetchServerInfosHash(viewer: Viewer, ids?: $ReadOnlySet) { const info = await fetch(viewer, ids); - return getServerInfosHash(info); + return await getServerInfosHash(info); }, getServerInfosHash, getServerInfoHash, @@ -51,10 +51,12 @@ return serverEntryInfosObject(entriesResult.rawEntryInfos); } -function getServerInfosHash(infos: RawEntryInfos) { - return combineUnorderedHashes(values(infos).map(getServerInfoHash)); +async function getServerInfosHash(infos: RawEntryInfos) { + const results = await Promise.all(values(infos).map(getServerInfoHash)); + return combineUnorderedHashes(results); } -function getServerInfoHash(info: RawEntryInfo) { - return hash(validateOutput(null, rawEntryInfoValidator, info)); +async function getServerInfoHash(info: RawEntryInfo) { + const output = await validateOutput(null, rawEntryInfoValidator, info); + return hash(output); } diff --git a/keyserver/src/shared/state-sync/state-sync-spec.js b/keyserver/src/shared/state-sync/state-sync-spec.js --- a/keyserver/src/shared/state-sync/state-sync-spec.js +++ b/keyserver/src/shared/state-sync/state-sync-spec.js @@ -20,7 +20,7 @@ viewer: Viewer, ids?: $ReadOnlySet, ) => Promise, - +getServerInfosHash: (infos: Infos) => number, - +getServerInfoHash: (info: Info) => number, + +getServerInfosHash: (infos: Infos) => Promise, + +getServerInfoHash: (info: Info) => Promise, ...StateSyncSpec, }; diff --git a/keyserver/src/shared/state-sync/threads-state-sync-spec.js b/keyserver/src/shared/state-sync/threads-state-sync-spec.js --- a/keyserver/src/shared/state-sync/threads-state-sync-spec.js +++ b/keyserver/src/shared/state-sync/threads-state-sync-spec.js @@ -28,7 +28,7 @@ }, async fetchServerInfosHash(viewer: Viewer, ids?: $ReadOnlySet) { const infos = await fetch(viewer, ids); - return getServerInfosHash(infos); + return await getServerInfosHash(infos); }, getServerInfosHash, getServerInfoHash, @@ -41,10 +41,12 @@ return result.threadInfos; } -function getServerInfosHash(infos: MixedRawThreadInfos) { - return combineUnorderedHashes(values(infos).map(getServerInfoHash)); +async function getServerInfosHash(infos: MixedRawThreadInfos) { + const results = await Promise.all(values(infos).map(getServerInfoHash)); + return combineUnorderedHashes(results); } -function getServerInfoHash(info: LegacyRawThreadInfo | RawThreadInfo) { - return hash(validateOutput(null, mixedRawThreadInfoValidator, info)); +async function getServerInfoHash(info: LegacyRawThreadInfo | RawThreadInfo) { + const output = await validateOutput(null, mixedRawThreadInfoValidator, info); + return hash(output); } diff --git a/keyserver/src/shared/state-sync/users-state-sync-spec.js b/keyserver/src/shared/state-sync/users-state-sync-spec.js --- a/keyserver/src/shared/state-sync/users-state-sync-spec.js +++ b/keyserver/src/shared/state-sync/users-state-sync-spec.js @@ -24,7 +24,7 @@ }, async fetchServerInfosHash(viewer: Viewer, ids?: $ReadOnlySet) { const infos = await fetch(viewer, ids); - return getServerInfosHash(infos); + return await getServerInfosHash(infos); }, getServerInfosHash, getServerInfoHash, @@ -39,10 +39,12 @@ return fetchKnownUserInfos(viewer); } -function getServerInfosHash(infos: UserInfos) { - return combineUnorderedHashes(values(infos).map(getServerInfoHash)); +async function getServerInfosHash(infos: UserInfos) { + const results = await Promise.all(values(infos).map(getServerInfoHash)); + return combineUnorderedHashes(results); } -function getServerInfoHash(info: UserInfo) { - return hash(validateOutput(null, userInfoValidator, info)); +async function getServerInfoHash(info: UserInfo) { + const output = await validateOutput(null, userInfoValidator, info); + return hash(output); } diff --git a/keyserver/src/socket/session-utils.js b/keyserver/src/socket/session-utils.js --- a/keyserver/src/socket/session-utils.js +++ b/keyserver/src/socket/session-utils.js @@ -478,7 +478,9 @@ ) { // We have a type error here because Flow has no way to determine that // spec and infos are matched up - hashValue = spec.getServerInfoHash((iterableInfos[infoID]: any)); + hashValue = await spec.getServerInfoHash( + (iterableInfos[infoID]: any), + ); } else { hashValue = hash(iterableInfos[infoID]); } diff --git a/keyserver/src/socket/socket.js b/keyserver/src/socket/socket.js --- a/keyserver/src/socket/socket.js +++ b/keyserver/src/socket/socket.js @@ -379,7 +379,7 @@ } const { viewer } = this; - const validatedMessage = validateOutput( + const validatedMessage = await validateOutput( viewer?.platformDetails, serverServerSocketMessageValidator, message, @@ -411,7 +411,7 @@ payload: compressionResult.result, }; - const validatedCompressedMessage = validateOutput( + const validatedCompressedMessage = await validateOutput( viewer?.platformDetails, serverServerSocketMessageValidator, compressedMessage, diff --git a/keyserver/src/uploads/uploads.js b/keyserver/src/uploads/uploads.js --- a/keyserver/src/uploads/uploads.js +++ b/keyserver/src/uploads/uploads.js @@ -102,7 +102,7 @@ throw new ServerError('invalid_parameters'); } const results = await createUploads(viewer, uploadInfos); - return validateOutput( + return await validateOutput( viewer.platformDetails, MultimediaUploadResultValidator, { diff --git a/keyserver/src/user/checks.js b/keyserver/src/user/checks.js --- a/keyserver/src/user/checks.js +++ b/keyserver/src/user/checks.js @@ -1,7 +1,11 @@ // @flow import { getCommConfig } from 'lib/utils/comm-config.js'; -export type UserCredentials = { +username: string, +password: string }; +export type UserCredentials = { + +username: string, + +password: string, + +id?: string, +}; async function ensureUserCredentials() { const userCredentials = await getCommConfig({ diff --git a/keyserver/src/user/identity.js b/keyserver/src/user/identity.js --- a/keyserver/src/user/identity.js +++ b/keyserver/src/user/identity.js @@ -2,6 +2,10 @@ import type { QueryResults } from 'mysql'; +import { getCommConfig } from 'lib/utils/comm-config.js'; +import { ashoatKeyserverID } from 'lib/utils/validation-utils.js'; + +import type { UserCredentials } from './checks.js'; import { SQL, dbQuery } from '../database/database.js'; const userIDMetadataKey = 'user_id'; @@ -34,6 +38,14 @@ return { userId: userID, accessToken }; } +async function thisKeyserverID(): Promise { + const userCredentials = await getCommConfig({ + folder: 'secrets', + name: 'user_credentials', + }); + return userCredentials?.id ?? ashoatKeyserverID; +} + function saveIdentityInfo(userInfo: IdentityInfo): Promise { const updateQuery = SQL` REPLACE INTO metadata (name, data) @@ -44,4 +56,4 @@ return dbQuery(updateQuery); } -export { fetchIdentityInfo, saveIdentityInfo }; +export { fetchIdentityInfo, thisKeyserverID, saveIdentityInfo }; diff --git a/keyserver/src/user/login.js b/keyserver/src/user/login.js --- a/keyserver/src/user/login.js +++ b/keyserver/src/user/login.js @@ -7,6 +7,7 @@ import { ServerError } from 'lib/utils/errors.js'; import { retrieveAccountKeysSet } from 'lib/utils/olm-utils.js'; +import type { UserCredentials } from './checks.js'; import { saveIdentityInfo, fetchIdentityInfo, @@ -15,8 +16,6 @@ import { getMessageForException } from '../responders/utils.js'; import { fetchCallUpdateOlmAccount } from '../updaters/olm-account-updater.js'; -type UserCredentials = { +username: string, +password: string }; - // After register or login is successful function markKeysAsPublished(account: OlmAccount) { account.mark_prekey_as_published(); @@ -109,6 +108,9 @@ fetchCallUpdateOlmAccount('content', markKeysAsPublished), fetchCallUpdateOlmAccount('notifications', markKeysAsPublished), ]); + if (userInfo.id && userInfo.id !== identity_info.userId) { + throw new Error('User id inconsistent with environment config'); + } return identity_info; } catch (e) { console.warn('Failed to login user: ' + getMessageForException(e)); @@ -128,6 +130,9 @@ fetchCallUpdateOlmAccount('content', markKeysAsPublished), fetchCallUpdateOlmAccount('notifications', markKeysAsPublished), ]); + if (userInfo.id && userInfo.id !== identity_info.userId) { + throw new Error('User id inconsistent with environment config'); + } return identity_info; } catch (err) { console.warn('Failed to register user: ' + getMessageForException(err)); diff --git a/keyserver/src/utils/validation-utils.js b/keyserver/src/utils/validation-utils.js --- a/keyserver/src/utils/validation-utils.js +++ b/keyserver/src/utils/validation-utils.js @@ -20,12 +20,12 @@ tPlatform, tPlatformDetails, assertWithValidator, - ashoatKeyserverID, } from 'lib/utils/validation-utils.js'; import { fetchNotAcknowledgedPolicies } from '../fetchers/policy-acknowledgment-fetchers.js'; import { verifyClientSupported } from '../session/version.js'; import type { Viewer } from '../session/viewer.js'; +import { thisKeyserverID } from '../user/identity.js'; async function validateInput( viewer: Viewer, @@ -37,6 +37,8 @@ } const convertedInput = checkInputValidator(inputValidator, input); + const keyserverID = await thisKeyserverID(); + if ( hasMinStateVersion(viewer.platformDetails, { native: 43, @@ -45,7 +47,7 @@ ) { try { return convertClientIDsToServerIDs( - ashoatKeyserverID, + keyserverID, inputValidator, convertedInput, ); @@ -57,11 +59,11 @@ return convertedInput; } -function validateOutput( +async function validateOutput( platformDetails: ?PlatformDetails, outputValidator: TType, data: T, -): T { +): Promise { if (!outputValidator.is(data)) { console.trace( 'Output validation failed, validator is:', @@ -70,17 +72,15 @@ return data; } + const keyserverID = await thisKeyserverID(); + if ( hasMinStateVersion(platformDetails, { native: 43, web: 3, }) ) { - return convertServerIDsToClientIDs( - ashoatKeyserverID, - outputValidator, - data, - ); + return convertServerIDsToClientIDs(keyserverID, outputValidator, data); } return data;