diff --git a/keyserver/src/session/cookies.js b/keyserver/src/session/cookies.js index d18ed3618..ef66550fe 100644 --- a/keyserver/src/session/cookies.js +++ b/keyserver/src/session/cookies.js @@ -1,831 +1,789 @@ // @flow import crypto from 'crypto'; import type { $Response, $Request } from 'express'; import invariant from 'invariant'; import url from 'url'; import { isStaff } from 'lib/shared/staff-utils.js'; import { hasMinCodeVersion } from 'lib/shared/version-utils.js'; import type { Shape } from 'lib/types/core.js'; import type { SignedIdentityKeysBlob } from 'lib/types/crypto-types.js'; import type { Platform, PlatformDetails } from 'lib/types/device-types.js'; import type { CalendarQuery } from 'lib/types/entry-types.js'; import { type ServerSessionChange, cookieLifetime, - cookieSources, - type CookieSource, cookieTypes, sessionIdentifierTypes, type SessionIdentifierType, } from 'lib/types/session-types.js'; import type { SIWESocialProof } from 'lib/types/siwe-types.js'; import type { InitialClientSocketMessage } from 'lib/types/socket-types.js'; import type { UserInfo } from 'lib/types/user-types.js'; import { isDev } from 'lib/utils/dev-utils.js'; import { values } from 'lib/utils/objects.js'; -import { promiseAll } from 'lib/utils/promises.js'; import { isBcryptHash, getCookieHash, verifyCookieHash, } from './cookie-hash.js'; import { Viewer } from './viewer.js'; import type { AnonymousViewerData, UserViewerData } from './viewer.js'; import createIDs from '../creators/id-creator.js'; import { createSession } from '../creators/session-creator.js'; import { dbQuery, SQL } from '../database/database.js'; import { deleteCookie } from '../deleters/cookie-deleters.js'; import { handleAsyncPromise } from '../responders/handlers.js'; import { clearDeviceToken } from '../updaters/device-token-updaters.js'; import { assertSecureRequest } from '../utils/security-utils.js'; import { type AppURLFacts, getAppURLFactsFromRequestURL, } from '../utils/urls.js'; function cookieIsExpired(lastUsed: number) { return lastUsed + cookieLifetime <= Date.now(); } type SessionParameterInfo = { isSocket: boolean, sessionID: ?string, sessionIdentifierType: SessionIdentifierType, ipAddress: string, userAgent: ?string, }; type FetchViewerResult = - | { type: 'valid', viewer: Viewer } + | { +type: 'valid', +viewer: Viewer } | InvalidFetchViewerResult; type InvalidFetchViewerResult = | { - type: 'nonexistant', - cookieName: ?string, - cookieSource: ?CookieSource, - sessionParameterInfo: SessionParameterInfo, + +type: 'nonexistant', + +cookieName: ?string, + +sessionParameterInfo: SessionParameterInfo, } | { - type: 'invalidated', - cookieName: string, - cookieID: string, - cookieSource: CookieSource, - sessionParameterInfo: SessionParameterInfo, - platformDetails: ?PlatformDetails, - deviceToken: ?string, + +type: 'invalidated', + +cookieName: string, + +cookieID: string, + +sessionParameterInfo: SessionParameterInfo, + +platformDetails: ?PlatformDetails, + +deviceToken: ?string, }; async function fetchUserViewer( cookie: string, - cookieSource: CookieSource, sessionParameterInfo: SessionParameterInfo, ): Promise { const [cookieID, cookiePassword] = cookie.split(':'); if (!cookieID || !cookiePassword) { return { type: 'nonexistant', cookieName: cookieTypes.USER, - cookieSource, sessionParameterInfo, }; } const query = SQL` SELECT hash, user, last_used, platform, device_token, versions FROM cookies WHERE id = ${cookieID} AND user IS NOT NULL `; const [[result], allSessionInfo] = await Promise.all([ dbQuery(query), fetchSessionInfo(sessionParameterInfo, cookieID), ]); if (result.length === 0) { return { type: 'nonexistant', cookieName: cookieTypes.USER, - cookieSource, sessionParameterInfo, }; } let sessionID = null, sessionInfo = null; if (allSessionInfo) { ({ sessionID, ...sessionInfo } = allSessionInfo); } const cookieRow = result[0]; let platformDetails = null; if (cookieRow.versions) { const versions = JSON.parse(cookieRow.versions); platformDetails = { platform: cookieRow.platform, codeVersion: versions.codeVersion, stateVersion: versions.stateVersion, }; } else { platformDetails = { platform: cookieRow.platform }; } const deviceToken = cookieRow.device_token; const cookieHash = cookieRow.hash; if ( !verifyCookieHash(cookiePassword, cookieHash) || cookieIsExpired(cookieRow.last_used) ) { return { type: 'invalidated', cookieName: cookieTypes.USER, cookieID, - cookieSource, sessionParameterInfo, platformDetails, deviceToken, }; } const userID = cookieRow.user.toString(); const viewer = new Viewer({ isSocket: sessionParameterInfo.isSocket, loggedIn: true, id: userID, platformDetails, deviceToken, userID, - cookieSource, cookieID, cookiePassword, cookieHash, sessionIdentifierType: sessionParameterInfo.sessionIdentifierType, sessionID, sessionInfo, isScriptViewer: false, ipAddress: sessionParameterInfo.ipAddress, userAgent: sessionParameterInfo.userAgent, }); return { type: 'valid', viewer }; } async function fetchAnonymousViewer( cookie: string, - cookieSource: CookieSource, sessionParameterInfo: SessionParameterInfo, ): Promise { const [cookieID, cookiePassword] = cookie.split(':'); if (!cookieID || !cookiePassword) { return { type: 'nonexistant', cookieName: cookieTypes.ANONYMOUS, - cookieSource, sessionParameterInfo, }; } const query = SQL` SELECT last_used, hash, platform, device_token, versions FROM cookies WHERE id = ${cookieID} AND user IS NULL `; const [[result], allSessionInfo] = await Promise.all([ dbQuery(query), fetchSessionInfo(sessionParameterInfo, cookieID), ]); if (result.length === 0) { return { type: 'nonexistant', cookieName: cookieTypes.ANONYMOUS, - cookieSource, sessionParameterInfo, }; } let sessionID = null, sessionInfo = null; if (allSessionInfo) { ({ sessionID, ...sessionInfo } = allSessionInfo); } const cookieRow = result[0]; let platformDetails = null; if (cookieRow.platform && cookieRow.versions) { const versions = JSON.parse(cookieRow.versions); platformDetails = { platform: cookieRow.platform, codeVersion: versions.codeVersion, stateVersion: versions.stateVersion, }; } else if (cookieRow.platform) { platformDetails = { platform: cookieRow.platform }; } const deviceToken = cookieRow.device_token; const cookieHash = cookieRow.hash; if ( !verifyCookieHash(cookiePassword, cookieHash) || cookieIsExpired(cookieRow.last_used) ) { return { type: 'invalidated', cookieName: cookieTypes.ANONYMOUS, cookieID, - cookieSource, sessionParameterInfo, platformDetails, deviceToken, }; } const viewer = new Viewer({ isSocket: sessionParameterInfo.isSocket, loggedIn: false, id: cookieID, platformDetails, deviceToken, - cookieSource, cookieID, cookiePassword, cookieHash, sessionIdentifierType: sessionParameterInfo.sessionIdentifierType, sessionID, sessionInfo, isScriptViewer: false, ipAddress: sessionParameterInfo.ipAddress, userAgent: sessionParameterInfo.userAgent, }); return { type: 'valid', viewer }; } type SessionInfo = { +sessionID: ?string, +lastValidated: number, +lastUpdate: number, +calendarQuery: CalendarQuery, }; async function fetchSessionInfo( sessionParameterInfo: SessionParameterInfo, cookieID: string, ): Promise { const { sessionID } = sessionParameterInfo; const session = sessionID !== undefined ? sessionID : cookieID; if (!session) { return null; } const query = SQL` SELECT query, last_validated, last_update FROM sessions WHERE id = ${session} AND cookie = ${cookieID} `; const [result] = await dbQuery(query); if (result.length === 0) { return null; } return { sessionID, lastValidated: result[0].last_validated, lastUpdate: result[0].last_update, calendarQuery: JSON.parse(result[0].query), }; } async function fetchViewerFromRequestBody( body: mixed, sessionParameterInfo: SessionParameterInfo, ): Promise { if (!body || typeof body !== 'object') { return { type: 'nonexistant', cookieName: null, - cookieSource: null, sessionParameterInfo, }; } const cookiePair = body.cookie; if (cookiePair === null || cookiePair === '') { return { type: 'nonexistant', cookieName: null, - cookieSource: cookieSources.BODY, sessionParameterInfo, }; } if (!cookiePair || typeof cookiePair !== 'string') { return { type: 'nonexistant', cookieName: null, - cookieSource: null, sessionParameterInfo, }; } const [type, cookie] = cookiePair.split('='); if (type === cookieTypes.USER && cookie) { - return await fetchUserViewer( - cookie, - cookieSources.BODY, - sessionParameterInfo, - ); + return await fetchUserViewer(cookie, sessionParameterInfo); } else if (type === cookieTypes.ANONYMOUS && cookie) { - return await fetchAnonymousViewer( - cookie, - cookieSources.BODY, - sessionParameterInfo, - ); + return await fetchAnonymousViewer(cookie, sessionParameterInfo); } return { type: 'nonexistant', cookieName: null, - cookieSource: null, sessionParameterInfo, }; } function getRequestIPAddress(req: $Request) { const { proxy } = getAppURLFactsFromRequestURL(req.originalUrl); let ipAddress; if (proxy === 'none') { ipAddress = req.socket.remoteAddress; } else if (proxy === 'apache') { ipAddress = req.get('X-Forwarded-For'); } invariant(ipAddress, 'could not determine requesting IP address'); return ipAddress; } function getSessionParameterInfoFromRequestBody( req: $Request, ): SessionParameterInfo { const body = (req.body: any); let sessionID = body.sessionID !== undefined || req.method !== 'GET' ? body.sessionID : null; if (sessionID === '') { sessionID = null; } const sessionIdentifierType = req.method === 'GET' || sessionID !== undefined ? sessionIdentifierTypes.BODY_SESSION_ID : sessionIdentifierTypes.COOKIE_ID; return { isSocket: false, sessionID, sessionIdentifierType, ipAddress: getRequestIPAddress(req), userAgent: req.get('User-Agent'), }; } async function fetchViewerForJSONRequest(req: $Request): Promise { assertSecureRequest(req); const sessionParameterInfo = getSessionParameterInfoFromRequestBody(req); const result = await fetchViewerFromRequestBody( req.body, sessionParameterInfo, ); return await handleFetchViewerResult(result); } async function fetchViewerForSocket( req: $Request, clientMessage: InitialClientSocketMessage, -): Promise { +): Promise { assertSecureRequest(req); const { sessionIdentification } = clientMessage.payload; const { sessionID } = sessionIdentification; const sessionParameterInfo = { isSocket: true, sessionID, sessionIdentifierType: sessionID !== undefined ? sessionIdentifierTypes.BODY_SESSION_ID : sessionIdentifierTypes.COOKIE_ID, ipAddress: getRequestIPAddress(req), userAgent: req.get('User-Agent'), }; const result = await fetchViewerFromRequestBody( clientMessage.payload.sessionIdentification, sessionParameterInfo, ); if (result.type === 'valid') { return result.viewer; } - const promises = {}; - if (result.cookieSource === cookieSources.BODY) { - // We initialize a socket's Viewer after the WebSocket handshake, since to - // properly initialize the Viewer we need a bunch of data, but that data - // can't be sent until after the handshake. Consequently, by the time we - // know that a cookie may be invalid, we are no longer communicating via - // HTTP, and have no way to set a new cookie for HEADER (web) clients. - const platformDetails = - result.type === 'invalidated' ? result.platformDetails : null; - const deviceToken = - result.type === 'invalidated' ? result.deviceToken : null; - promises.anonymousViewerData = createNewAnonymousCookie({ - platformDetails, - deviceToken, - }); - } - if (result.type === 'invalidated') { - promises.deleteCookie = deleteCookie(result.cookieID); - } - const { anonymousViewerData } = await promiseAll(promises); - - if (!anonymousViewerData) { - return null; - } + const anonymousViewerDataPromise: Promise = + (async () => { + const platformDetails = + result.type === 'invalidated' ? result.platformDetails : null; + const deviceToken = + result.type === 'invalidated' ? result.deviceToken : null; + return await createNewAnonymousCookie({ + platformDetails, + deviceToken, + }); + })(); + const deleteCookiePromise = (async () => { + if (result.type === 'invalidated') { + await deleteCookie(result.cookieID); + } + })(); + const [anonymousViewerData] = await Promise.all([ + anonymousViewerDataPromise, + deleteCookiePromise, + ]); return createViewerForInvalidFetchViewerResult(result, anonymousViewerData); } async function handleFetchViewerResult( result: FetchViewerResult, inputPlatformDetails?: PlatformDetails, ) { if (result.type === 'valid') { return result.viewer; } let platformDetails = inputPlatformDetails; if (!platformDetails && result.type === 'invalidated') { platformDetails = result.platformDetails; } const deviceToken = result.type === 'invalidated' ? result.deviceToken : null; const [anonymousViewerData] = await Promise.all([ createNewAnonymousCookie({ platformDetails, deviceToken }), result.type === 'invalidated' ? deleteCookie(result.cookieID) : null, ]); return createViewerForInvalidFetchViewerResult(result, anonymousViewerData); } function createViewerForInvalidFetchViewerResult( result: InvalidFetchViewerResult, anonymousViewerData: AnonymousViewerData, ): Viewer { - // If a null cookie was specified in the request body, result.cookieSource - // will still be BODY here. The only way it would be null or undefined here - // is if there was no cookie specified in either the body or the header, in - // which case we default to returning the new cookie in the response header. - const cookieSource = - result.cookieSource !== null && result.cookieSource !== undefined - ? result.cookieSource - : cookieSources.HEADER; const viewer = new Viewer({ ...anonymousViewerData, - cookieSource, sessionIdentifierType: result.sessionParameterInfo.sessionIdentifierType, isSocket: result.sessionParameterInfo.isSocket, ipAddress: result.sessionParameterInfo.ipAddress, userAgent: result.sessionParameterInfo.userAgent, }); viewer.sessionChanged = true; // If cookieName is falsey, that tells us that there was no cookie specified // in the request, which means we can't be invalidating anything. if (result.cookieName) { viewer.cookieInvalidated = true; viewer.initialCookieName = result.cookieName; } return viewer; } function addSessionChangeInfoToResult( viewer: Viewer, res: $Response, result: Object, ) { let threadInfos = {}, userInfos = {}; if (result.cookieChange) { ({ threadInfos, userInfos } = result.cookieChange); } let sessionChange; if (viewer.cookieInvalidated) { sessionChange = ({ cookieInvalidated: true, threadInfos, userInfos: (values(userInfos).map(a => a): UserInfo[]), currentUserInfo: { anonymous: true, }, }: ServerSessionChange); } else { sessionChange = ({ cookieInvalidated: false, threadInfos, userInfos: (values(userInfos).map(a => a): UserInfo[]), }: ServerSessionChange); } - if (viewer.cookieSource === cookieSources.BODY) { - sessionChange.cookie = viewer.cookiePairString; - } + sessionChange.cookie = viewer.cookiePairString; if (viewer.sessionIdentifierType === sessionIdentifierTypes.BODY_SESSION_ID) { sessionChange.sessionID = viewer.sessionID ? viewer.sessionID : null; } result.cookieChange = sessionChange; } type AnonymousCookieCreationParams = Shape<{ +platformDetails: ?PlatformDetails, +deviceToken: ?string, }>; const defaultPlatformDetails = {}; // The result of this function should not be passed directly to the Viewer // constructor. Instead, it should be passed to viewer.setNewCookie. There are // several fields on AnonymousViewerData that are not set by this function: -// sessionIdentifierType, cookieSource, ipAddress, and userAgent. These -// parameters all depend on the initial request. If the result of this function -// is passed to the Viewer constructor directly, the resultant Viewer object -// will throw whenever anybody attempts to access the relevant properties. +// sessionIdentifierType, ipAddress, and userAgent. These parameters all depend +// on the initial request. If the result of this function is passed to the +// Viewer constructor directly, the resultant Viewer object will throw whenever +// anybody attempts to access the relevant properties. async function createNewAnonymousCookie( params: AnonymousCookieCreationParams, ): Promise { const { platformDetails, deviceToken } = params; const { platform, ...versions } = platformDetails || defaultPlatformDetails; const versionsString = Object.keys(versions).length > 0 ? JSON.stringify(versions) : null; const time = Date.now(); const cookiePassword = crypto.randomBytes(32).toString('hex'); const cookieHash = getCookieHash(cookiePassword); const [[id]] = await Promise.all([ createIDs('cookies', 1), deviceToken ? clearDeviceToken(deviceToken) : undefined, ]); const cookieRow = [ id, cookieHash, null, platform, time, time, deviceToken, versionsString, ]; const query = SQL` INSERT INTO cookies(id, hash, user, platform, creation_time, last_used, device_token, versions) VALUES ${[cookieRow]} `; await dbQuery(query); return { loggedIn: false, id, platformDetails, deviceToken, cookieID: id, cookiePassword, cookieHash, sessionID: undefined, sessionInfo: null, cookieInsertedThisRequest: true, isScriptViewer: false, }; } type UserCookieCreationParams = { +platformDetails: PlatformDetails, +deviceToken?: ?string, +socialProof?: ?SIWESocialProof, +signedIdentityKeysBlob?: ?SignedIdentityKeysBlob, }; // The result of this function should never be passed directly to the Viewer // constructor. Instead, it should be passed to viewer.setNewCookie. There are // several fields on UserViewerData that are not set by this function: -// sessionID, sessionIdentifierType, cookieSource, and ipAddress. These -// parameters all depend on the initial request. If the result of this function -// is passed to the Viewer constructor directly, the resultant Viewer object -// will throw whenever anybody attempts to access the relevant properties. +// sessionID, sessionIdentifierType, and ipAddress. These parameters all depend +// on the initial request. If the result of this function is passed to the +// Viewer constructor directly, the resultant Viewer object will throw whenever +// anybody attempts to access the relevant properties. async function createNewUserCookie( userID: string, params: UserCookieCreationParams, ): Promise { const { platformDetails, deviceToken, socialProof, signedIdentityKeysBlob } = params; const { platform, ...versions } = platformDetails || defaultPlatformDetails; const versionsString = Object.keys(versions).length > 0 ? JSON.stringify(versions) : null; const time = Date.now(); const cookiePassword = crypto.randomBytes(32).toString('hex'); const cookieHash = getCookieHash(cookiePassword); const [[cookieID]] = await Promise.all([ createIDs('cookies', 1), deviceToken ? clearDeviceToken(deviceToken) : undefined, ]); const cookieRow = [ cookieID, cookieHash, userID, platform, time, time, deviceToken, versionsString, JSON.stringify(socialProof), signedIdentityKeysBlob ? JSON.stringify(signedIdentityKeysBlob) : null, ]; const query = SQL` INSERT INTO cookies(id, hash, user, platform, creation_time, last_used, device_token, versions, social_proof, signed_identity_keys) VALUES ${[cookieRow]} `; await dbQuery(query); return { loggedIn: true, id: userID, platformDetails, deviceToken, userID, cookieID, sessionID: undefined, sessionInfo: null, cookiePassword, cookieHash, cookieInsertedThisRequest: true, isScriptViewer: false, }; } // This gets called after createNewUserCookie and from websiteResponder. If the // Viewer's sessionIdentifierType is COOKIE_ID then the cookieID is used as the // session identifier; otherwise, a new ID is created for the session. async function setNewSession( viewer: Viewer, calendarQuery: CalendarQuery, initialLastUpdate: number, ): Promise { if (viewer.sessionIdentifierType !== sessionIdentifierTypes.COOKIE_ID) { const [sessionID] = await createIDs('sessions', 1); viewer.setSessionID(sessionID); } await createSession(viewer, calendarQuery, initialLastUpdate); } async function updateCookie(viewer: Viewer) { const time = Date.now(); const { cookieID, cookieHash, cookiePassword } = viewer; const updateObj = {}; updateObj.last_used = time; if (isBcryptHash(cookieHash)) { updateObj.hash = getCookieHash(cookiePassword); } const query = SQL` UPDATE cookies SET ${updateObj} WHERE id = ${cookieID} `; await dbQuery(query); } function addCookieToJSONResponse( viewer: Viewer, res: $Response, result: Object, expectCookieInvalidation: boolean, ) { if (expectCookieInvalidation) { viewer.cookieInvalidated = false; } if (!viewer.getData().cookieInsertedThisRequest) { handleAsyncPromise(updateCookie(viewer)); } if (viewer.sessionChanged) { addSessionChangeInfoToResult(viewer, res, result); } } function addCookieToHomeResponse( req: $Request, res: $Response, appURLFacts: AppURLFacts, ) { const { user, anonymous } = req.cookies; if (user) { res.cookie(cookieTypes.USER, user, getCookieOptions(appURLFacts)); } if (anonymous) { res.cookie(cookieTypes.ANONYMOUS, anonymous, getCookieOptions(appURLFacts)); } } function getCookieOptions(appURLFacts: AppURLFacts) { const { baseDomain, basePath, https } = appURLFacts; const domainAsURL = new url.URL(baseDomain); return { domain: domainAsURL.hostname, path: basePath, httpOnly: false, secure: https, maxAge: cookieLifetime, sameSite: 'Strict', }; } async function setCookieSignedIdentityKeysBlob( cookieID: string, signedIdentityKeysBlob: SignedIdentityKeysBlob, ) { const signedIdentityKeysStr = JSON.stringify(signedIdentityKeysBlob); const query = SQL` UPDATE cookies SET signed_identity_keys = ${signedIdentityKeysStr} WHERE id = ${cookieID} `; await dbQuery(query); } // Returns `true` if row with `id = cookieID` exists AND // `signed_identity_keys` is `NULL`. Otherwise, returns `false`. async function isCookieMissingSignedIdentityKeysBlob( cookieID: string, ): Promise { const query = SQL` SELECT signed_identity_keys FROM cookies WHERE id = ${cookieID} `; const [queryResult] = await dbQuery(query); return ( queryResult.length === 1 && queryResult[0].signed_identity_keys === null ); } async function isCookieMissingOlmNotificationsSession( viewer: Viewer, ): Promise { const isStaffOrDev = isStaff(viewer.userID) || isDev; if ( !viewer.platformDetails || (viewer.platformDetails.platform !== 'ios' && viewer.platformDetails.platform !== 'android' && !(viewer.platformDetails.platform === 'web' && isStaffOrDev)) || !hasMinCodeVersion(viewer.platformDetails, { native: 222, web: 43, }) ) { return false; } const query = SQL` SELECT COUNT(*) AS count FROM olm_sessions WHERE cookie_id = ${viewer.cookieID} AND is_content = FALSE `; const [queryResult] = await dbQuery(query); return queryResult[0].count === 0; } async function setCookiePlatform( viewer: Viewer, platform: Platform, ): Promise { const newPlatformDetails = { ...viewer.platformDetails, platform }; viewer.setPlatformDetails(newPlatformDetails); const query = SQL` UPDATE cookies SET platform = ${platform} WHERE id = ${viewer.cookieID} `; await dbQuery(query); } async function setCookiePlatformDetails( viewer: Viewer, platformDetails: PlatformDetails, ): Promise { viewer.setPlatformDetails(platformDetails); const { platform, ...versions } = platformDetails; const versionsString = Object.keys(versions).length > 0 ? JSON.stringify(versions) : null; const query = SQL` UPDATE cookies SET platform = ${platform}, versions = ${versionsString} WHERE id = ${viewer.cookieID} `; await dbQuery(query); } export { fetchViewerForJSONRequest, fetchViewerForSocket, createNewAnonymousCookie, createNewUserCookie, setNewSession, updateCookie, addCookieToJSONResponse, addCookieToHomeResponse, setCookieSignedIdentityKeysBlob, isCookieMissingSignedIdentityKeysBlob, setCookiePlatform, setCookiePlatformDetails, isCookieMissingOlmNotificationsSession, }; diff --git a/keyserver/src/session/viewer.js b/keyserver/src/session/viewer.js index 96ab2e0f8..ba2414a9f 100644 --- a/keyserver/src/session/viewer.js +++ b/keyserver/src/session/viewer.js @@ -1,351 +1,331 @@ // @flow import geoip from 'geoip-lite'; import invariant from 'invariant'; import type { Platform, PlatformDetails } from 'lib/types/device-types.js'; import type { CalendarQuery } from 'lib/types/entry-types.js'; import { - type CookieSource, type SessionIdentifierType, cookieTypes, type CookieType, sessionIdentifierTypes, } from 'lib/types/session-types.js'; import { ServerError } from 'lib/utils/errors.js'; export type UserViewerData = { +loggedIn: true, +id: string, +platformDetails: ?PlatformDetails, +deviceToken: ?string, +userID: string, +cookieID: ?string, - +cookieSource?: CookieSource, +cookiePassword: ?string, +cookieHash: ?string, +cookieInsertedThisRequest?: boolean, +sessionIdentifierType?: SessionIdentifierType, +sessionID: ?string, +sessionInfo: ?SessionInfo, +isScriptViewer: boolean, +isSocket?: boolean, +ipAddress?: string, +userAgent?: ?string, }; export type AnonymousViewerData = { +loggedIn: false, +id: string, +platformDetails: ?PlatformDetails, +deviceToken: ?string, - +cookieSource?: CookieSource, +cookieID: string, +cookiePassword: ?string, +cookieHash: ?string, +cookieInsertedThisRequest?: boolean, +sessionIdentifierType?: SessionIdentifierType, +sessionID: ?string, +sessionInfo: ?SessionInfo, +isScriptViewer: boolean, +isSocket?: boolean, +ipAddress?: string, +userAgent?: ?string, }; type SessionInfo = { +lastValidated: number, +lastUpdate: number, +calendarQuery: CalendarQuery, }; export type ViewerData = UserViewerData | AnonymousViewerData; class Viewer { data: ViewerData; sessionChanged: boolean = false; cookieInvalidated: boolean = false; initialCookieName: string; cachedTimeZone: ?string; constructor(data: ViewerData) { this.data = data; this.initialCookieName = Viewer.cookieNameFromViewerData(data); } static cookieNameFromViewerData(data: ViewerData): CookieType { return data.loggedIn ? cookieTypes.USER : cookieTypes.ANONYMOUS; } getData(): ViewerData { return this.data; } setNewCookie(data: ViewerData) { - if (data.cookieSource === null || data.cookieSource === undefined) { - if (data.loggedIn) { - data = { ...data, cookieSource: this.cookieSource }; - } else { - // This is a separate condition because of Flow - data = { ...data, cookieSource: this.cookieSource }; - } - } if ( data.sessionIdentifierType === null || data.sessionIdentifierType === undefined ) { if (data.loggedIn) { data = { ...data, sessionIdentifierType: this.sessionIdentifierType }; } else { // This is a separate condition because of Flow data = { ...data, sessionIdentifierType: this.sessionIdentifierType }; } } if (data.isSocket === null || data.isSocket === undefined) { if (data.loggedIn) { data = { ...data, isSocket: this.isSocket }; } else { // This is a separate condition because of Flow data = { ...data, isSocket: this.isSocket }; } } if (data.ipAddress === null || data.ipAddress === undefined) { if (data.loggedIn) { data = { ...data, ipAddress: this.ipAddress }; } else { // This is a separate condition because of Flow data = { ...data, ipAddress: this.ipAddress }; } } else { this.cachedTimeZone = undefined; } if (data.userAgent === null || data.userAgent === undefined) { if (data.loggedIn) { data = { ...data, userAgent: this.userAgent }; } else { // This is a separate condition because of Flow data = { ...data, userAgent: this.userAgent }; } } this.data = data; this.sessionChanged = true; // If the request explicitly sets a new cookie, there's no point in telling // the client that their old cookie is invalid. Note that clients treat // cookieInvalidated as a forced log-out, which isn't necessary here. this.cookieInvalidated = false; } setSessionID(sessionID: string) { if (sessionID === this.sessionID) { return; } this.sessionChanged = true; if (this.data.loggedIn) { this.data = { ...this.data, sessionID }; } else { // This is a separate condition because of Flow this.data = { ...this.data, sessionID }; } } setSessionInfo(sessionInfo: SessionInfo) { if (this.data.loggedIn) { this.data = { ...this.data, sessionInfo }; } else { // This is a separate condition because of Flow this.data = { ...this.data, sessionInfo }; } } setDeviceToken(deviceToken: string) { if (this.data.loggedIn) { this.data = { ...this.data, deviceToken }; } else { // This is a separate condition because of Flow this.data = { ...this.data, deviceToken }; } } setPlatformDetails(platformDetails: PlatformDetails) { if (this.data.loggedIn) { this.data = { ...this.data, platformDetails }; } else { // This is a separate condition because of Flow this.data = { ...this.data, platformDetails }; } } get id(): string { return this.data.id; } get loggedIn(): boolean { return this.data.loggedIn; } - get cookieSource(): CookieSource { - const { cookieSource } = this.data; - invariant( - cookieSource !== null && cookieSource !== undefined, - 'Viewer.cookieSource should be set', - ); - return cookieSource; - } - get cookieID(): string { const { cookieID } = this.data; invariant( cookieID !== null && cookieID !== undefined, 'Viewer.cookieID should be set', ); return cookieID; } get cookiePassword(): string { const { cookiePassword } = this.data; invariant( cookiePassword !== null && cookiePassword !== undefined, 'Viewer.cookieID should be set', ); return cookiePassword; } get cookieHash(): string { const { cookieHash } = this.data; invariant( cookieHash !== null && cookieHash !== undefined, 'Viewer.cookieHash should be set', ); return cookieHash; } get sessionIdentifierType(): SessionIdentifierType { const { sessionIdentifierType } = this.data; invariant( sessionIdentifierType !== null && sessionIdentifierType !== undefined, 'Viewer.sessionIdentifierType should be set', ); return sessionIdentifierType; } // This is used in the case of sessionIdentifierTypes.BODY_SESSION_ID only. // It will be falsey otherwise. Use session below if you want the actual // session identifier in all cases. get sessionID(): ?string { return this.data.sessionID; } get session(): string { if (this.sessionIdentifierType === sessionIdentifierTypes.COOKIE_ID) { return this.cookieID; } else if (this.sessionID) { return this.sessionID; } else if (!this.loggedIn) { throw new ServerError('not_logged_in'); } else { // If the session identifier is sessionIdentifierTypes.BODY_SESSION_ID and // the user is logged in, then the sessionID should be set. throw new ServerError('unknown_error'); } } get hasSessionInfo(): boolean { const { sessionInfo } = this.data; return !!sessionInfo; } get sessionLastValidated(): number { const { sessionInfo } = this.data; invariant( sessionInfo !== null && sessionInfo !== undefined, 'Viewer.sessionInfo should be set', ); return sessionInfo.lastValidated; } get sessionLastUpdated(): number { const { sessionInfo } = this.data; invariant( sessionInfo !== null && sessionInfo !== undefined, 'Viewer.sessionInfo should be set', ); return sessionInfo.lastUpdate; } get calendarQuery(): CalendarQuery { const { sessionInfo } = this.data; invariant( sessionInfo !== null && sessionInfo !== undefined, 'Viewer.sessionInfo should be set', ); return sessionInfo.calendarQuery; } get userID(): string { if (!this.data.userID) { throw new ServerError('not_logged_in'); } return this.data.userID; } get cookieName(): string { return Viewer.cookieNameFromViewerData(this.data); } get cookieString(): string { return `${this.cookieID}:${this.cookiePassword}`; } get cookiePairString(): string { return `${this.cookieName}=${this.cookieString}`; } get platformDetails(): ?PlatformDetails { return this.data.platformDetails; } get platform(): ?Platform { return this.data.platformDetails ? this.data.platformDetails.platform : null; } get deviceToken(): ?string { return this.data.deviceToken; } get isScriptViewer(): boolean { return this.data.isScriptViewer; } get isSocket(): boolean { invariant( this.data.isSocket !== null && this.data.isSocket !== undefined, 'isSocket should be set', ); return this.data.isSocket; } get ipAddress(): string { invariant( this.data.ipAddress !== null && this.data.ipAddress !== undefined, 'ipAddress should be set', ); return this.data.ipAddress; } get userAgent(): ?string { return this.data.userAgent; } get timeZone(): ?string { if (this.cachedTimeZone === undefined) { const geoData = geoip.lookup(this.ipAddress); this.cachedTimeZone = geoData ? geoData.timezone : null; } return this.cachedTimeZone; } } export { Viewer }; diff --git a/keyserver/src/socket/socket.js b/keyserver/src/socket/socket.js index 21568eb10..0564fe39f 100644 --- a/keyserver/src/socket/socket.js +++ b/keyserver/src/socket/socket.js @@ -1,884 +1,875 @@ // @flow import type { $Request } from 'express'; import invariant from 'invariant'; import _debounce from 'lodash/debounce.js'; import t from 'tcomb'; import type { TUnion } from 'tcomb'; import WebSocket from 'ws'; import { baseLegalPolicies } from 'lib/facts/policies.js'; import { mostRecentMessageTimestamp } from 'lib/shared/message-utils.js'; import { isStaff } from 'lib/shared/staff-utils.js'; import { serverRequestSocketTimeout, serverResponseTimeout, } from 'lib/shared/timeouts.js'; import { mostRecentUpdateTimestamp } from 'lib/shared/update-utils.js'; import { hasMinCodeVersion } from 'lib/shared/version-utils.js'; import type { Shape } from 'lib/types/core.js'; import { endpointIsSocketSafe } from 'lib/types/endpoints.js'; import { defaultNumberPerThread } from 'lib/types/message-types.js'; import { redisMessageTypes, type RedisMessage } from 'lib/types/redis-types.js'; import { serverRequestTypes } from 'lib/types/request-types.js'; import { - cookieSources, sessionCheckFrequency, stateCheckInactivityActivationInterval, } from 'lib/types/session-types.js'; import { type ClientSocketMessage, type InitialClientSocketMessage, type ResponsesClientSocketMessage, type ServerStateSyncFullSocketPayload, type ServerServerSocketMessage, type ErrorServerSocketMessage, type AuthErrorServerSocketMessage, type PingClientSocketMessage, type AckUpdatesClientSocketMessage, type APIRequestClientSocketMessage, clientSocketMessageTypes, stateSyncPayloadTypes, serverSocketMessageTypes, serverServerSocketMessageValidator, } from 'lib/types/socket-types.js'; import { ServerError } from 'lib/utils/errors.js'; import { values } from 'lib/utils/objects.js'; import { promiseAll } from 'lib/utils/promises.js'; import SequentialPromiseResolver from 'lib/utils/sequential-promise-resolver.js'; import sleep from 'lib/utils/sleep.js'; import { tShape, tCookie } from 'lib/utils/validation-utils.js'; import { RedisSubscriber } from './redis.js'; import { clientResponseInputValidator, processClientResponses, initializeSession, checkState, } from './session-utils.js'; import { fetchUpdateInfosWithRawUpdateInfos } from '../creators/update-creator.js'; import { deleteActivityForViewerSession } from '../deleters/activity-deleters.js'; import { deleteCookie } from '../deleters/cookie-deleters.js'; import { deleteUpdatesBeforeTimeTargetingSession } from '../deleters/update-deleters.js'; import { jsonEndpoints } from '../endpoints.js'; import { fetchMessageInfosSince, getMessageFetchResultFromRedisMessages, } from '../fetchers/message-fetchers.js'; import { fetchUpdateInfos } from '../fetchers/update-fetchers.js'; import { newEntryQueryInputValidator, verifyCalendarQueryThreadIDs, } from '../responders/entry-responders.js'; import { handleAsyncPromise } from '../responders/handlers.js'; import { fetchViewerForSocket, updateCookie, - createNewAnonymousCookie, isCookieMissingSignedIdentityKeysBlob, isCookieMissingOlmNotificationsSession, + createNewAnonymousCookie, } from '../session/cookies.js'; import { Viewer } from '../session/viewer.js'; +import type { AnonymousViewerData } from '../session/viewer.js'; import { serverStateSyncSpecs } from '../shared/state-sync/state-sync-specs.js'; import { commitSessionUpdate } from '../updaters/session-updaters.js'; import { compressMessage } from '../utils/compress.js'; import { assertSecureRequest } from '../utils/security-utils.js'; import { checkInputValidator, checkClientSupported, policiesValidator, validateOutput, } from '../utils/validation-utils.js'; const clientSocketMessageInputValidator: TUnion = t.union([ tShape({ type: t.irreducible( 'clientSocketMessageTypes.INITIAL', x => x === clientSocketMessageTypes.INITIAL, ), id: t.Number, payload: tShape({ sessionIdentification: tShape({ cookie: t.maybe(tCookie), sessionID: t.maybe(t.String), }), sessionState: tShape({ calendarQuery: newEntryQueryInputValidator, messagesCurrentAsOf: t.Number, updatesCurrentAsOf: t.Number, watchedIDs: t.list(t.String), }), clientResponses: t.list(clientResponseInputValidator), }), }), tShape({ type: t.irreducible( 'clientSocketMessageTypes.RESPONSES', x => x === clientSocketMessageTypes.RESPONSES, ), id: t.Number, payload: tShape({ clientResponses: t.list(clientResponseInputValidator), }), }), tShape({ type: t.irreducible( 'clientSocketMessageTypes.PING', x => x === clientSocketMessageTypes.PING, ), id: t.Number, }), tShape({ type: t.irreducible( 'clientSocketMessageTypes.ACK_UPDATES', x => x === clientSocketMessageTypes.ACK_UPDATES, ), id: t.Number, payload: tShape({ currentAsOf: t.Number, }), }), tShape({ type: t.irreducible( 'clientSocketMessageTypes.API_REQUEST', x => x === clientSocketMessageTypes.API_REQUEST, ), id: t.Number, payload: tShape({ endpoint: t.String, input: t.maybe(t.Object), }), }), ]); function onConnection(ws: WebSocket, req: $Request) { assertSecureRequest(req); new Socket(ws, req); } type StateCheckConditions = { activityRecentlyOccurred: boolean, stateCheckOngoing: boolean, }; const minVersionsForCompression = { native: 265, web: 30, }; class Socket { ws: WebSocket; httpRequest: $Request; viewer: ?Viewer; redis: ?RedisSubscriber; redisPromiseResolver: SequentialPromiseResolver; stateCheckConditions: StateCheckConditions = { activityRecentlyOccurred: true, stateCheckOngoing: false, }; stateCheckTimeoutID: ?TimeoutID; constructor(ws: WebSocket, httpRequest: $Request) { this.ws = ws; this.httpRequest = httpRequest; ws.on('message', this.onMessage); ws.on('close', this.onClose); this.resetTimeout(); this.redisPromiseResolver = new SequentialPromiseResolver(this.sendMessage); } onMessage = async ( messageString: string | Buffer | ArrayBuffer | Array, ) => { invariant(typeof messageString === 'string', 'message should be string'); let clientSocketMessage: ?ClientSocketMessage; try { this.resetTimeout(); const messageObject = JSON.parse(messageString); clientSocketMessage = checkInputValidator( clientSocketMessageInputValidator, messageObject, ); if (clientSocketMessage.type === clientSocketMessageTypes.INITIAL) { if (this.viewer) { // This indicates that the user sent multiple INITIAL messages. throw new ServerError('socket_already_initialized'); } this.viewer = await fetchViewerForSocket( this.httpRequest, clientSocketMessage, ); - if (!this.viewer) { - // This indicates that the cookie was invalid, but the client is using - // cookieSources.HEADER and thus can't accept a new cookie over - // WebSockets. See comment under catch block for socket_deauthorized. - throw new ServerError('socket_deauthorized'); - } } const { viewer } = this; if (!viewer) { // This indicates a non-INITIAL message was sent by the client before // the INITIAL message. throw new ServerError('socket_uninitialized'); } if (viewer.sessionChanged) { // This indicates that the cookie was invalid, and we've assigned a new // anonymous one. throw new ServerError('socket_deauthorized'); } if (!viewer.loggedIn) { // This indicates that the specified cookie was an anonymous one. throw new ServerError('not_logged_in'); } await checkClientSupported( viewer, clientSocketMessageInputValidator, clientSocketMessage, ); await policiesValidator(viewer, baseLegalPolicies); const serverResponses = await this.handleClientSocketMessage( clientSocketMessage, ); if (!this.redis) { this.redis = new RedisSubscriber( { userID: viewer.userID, sessionID: viewer.session }, this.onRedisMessage, ); } if (viewer.sessionChanged) { // This indicates that something has caused the session to change, which // shouldn't happen from inside a WebSocket since we can't handle cookie // invalidation. throw new ServerError('session_mutated_from_socket'); } if (clientSocketMessage.type !== clientSocketMessageTypes.PING) { handleAsyncPromise(updateCookie(viewer)); } for (const response of serverResponses) { // Normally it's an anti-pattern to await in sequence like this. But in // this case, we have a requirement that this array of serverResponses // is delivered in order. See here: // https://github.com/CommE2E/comm/blob/101eb34481deb49c609bfd2c785f375886e52666/keyserver/src/socket/socket.js#L566-L568 await this.sendMessage(response); } if (clientSocketMessage.type === clientSocketMessageTypes.INITIAL) { this.onSuccessfulConnection(); } } catch (error) { console.warn(error); if (!(error instanceof ServerError)) { const errorMessage: ErrorServerSocketMessage = { type: serverSocketMessageTypes.ERROR, message: error.message, }; const responseTo = clientSocketMessage ? clientSocketMessage.id : null; if (responseTo !== null) { errorMessage.responseTo = responseTo; } this.markActivityOccurred(); await this.sendMessage(errorMessage); return; } invariant(clientSocketMessage, 'should be set'); const responseTo = clientSocketMessage.id; if (error.message === 'socket_deauthorized') { + invariant(this.viewer, 'should be set'); const authErrorMessage: AuthErrorServerSocketMessage = { type: serverSocketMessageTypes.AUTH_ERROR, responseTo, message: error.message, - }; - if (this.viewer) { - // viewer should only be falsey for cookieSources.HEADER (web) - // clients. Usually if the cookie is invalid we construct a new - // anonymous Viewer with a new cookie, and then pass the cookie down - // in the error. But we can't pass HTTP cookies in WebSocket messages. - authErrorMessage.sessionChange = { + sessionChange: { cookie: this.viewer.cookiePairString, currentUserInfo: { anonymous: true, }, - }; - } + }, + }; + await this.sendMessage(authErrorMessage); this.ws.close(4100, error.message); return; } else if (error.message === 'client_version_unsupported') { const { viewer } = this; invariant(viewer, 'should be set'); - const promises = {}; - promises.deleteCookie = deleteCookie(viewer.cookieID); - if (viewer.cookieSource !== cookieSources.BODY) { - promises.anonymousViewerData = createNewAnonymousCookie({ + + const anonymousViewerDataPromise: Promise = + createNewAnonymousCookie({ platformDetails: error.platformDetails, deviceToken: viewer.deviceToken, }); - } - const { anonymousViewerData } = await promiseAll(promises); + const deleteCookiePromise = deleteCookie(viewer.cookieID); + const [anonymousViewerData] = await Promise.all([ + anonymousViewerDataPromise, + deleteCookiePromise, + ]); + + // It is normally not safe to pass the result of + // createNewAnonymousCookie to the Viewer constructor. That is because + // createNewAnonymousCookie leaves several fields of + // AnonymousViewerData unset, and consequently Viewer will throw when + // access is attempted. It is only safe here because we can guarantee + // that only cookiePairString and cookieID are accessed on anonViewer + // below. + const anonViewer = new Viewer(anonymousViewerData); const authErrorMessage: AuthErrorServerSocketMessage = { type: serverSocketMessageTypes.AUTH_ERROR, responseTo, message: error.message, - }; - if (anonymousViewerData) { - // It is normally not safe to pass the result of - // createNewAnonymousCookie to the Viewer constructor. That is because - // createNewAnonymousCookie leaves several fields of - // AnonymousViewerData unset, and consequently Viewer will throw when - // access is attempted. It is only safe here because we can guarantee - // that only cookiePairString and cookieID are accessed on anonViewer - // below. - const anonViewer = new Viewer(anonymousViewerData); - authErrorMessage.sessionChange = { + sessionChange: { cookie: anonViewer.cookiePairString, currentUserInfo: { anonymous: true, }, - }; - } + }, + }; await this.sendMessage(authErrorMessage); this.ws.close(4101, error.message); return; } if (error.payload) { await this.sendMessage({ type: serverSocketMessageTypes.ERROR, responseTo, message: error.message, payload: error.payload, }); } else { await this.sendMessage({ type: serverSocketMessageTypes.ERROR, responseTo, message: error.message, }); } if (error.message === 'not_logged_in') { this.ws.close(4102, error.message); } else if (error.message === 'session_mutated_from_socket') { this.ws.close(4103, error.message); } else { this.markActivityOccurred(); } } }; onClose = async () => { this.clearStateCheckTimeout(); this.resetTimeout.cancel(); this.debouncedAfterActivity.cancel(); if (this.viewer && this.viewer.hasSessionInfo) { await deleteActivityForViewerSession(this.viewer); } if (this.redis) { this.redis.quit(); this.redis = null; } }; sendMessage = async (message: ServerServerSocketMessage) => { invariant( this.ws.readyState > 0, "shouldn't send message until connection established", ); if (this.ws.readyState !== 1) { return; } const { viewer } = this; const validatedMessage = validateOutput( viewer?.platformDetails, serverServerSocketMessageValidator, message, ); const stringMessage = JSON.stringify(validatedMessage); if ( !viewer?.platformDetails || !hasMinCodeVersion(viewer.platformDetails, minVersionsForCompression) || !isStaff(viewer.id) ) { this.ws.send(stringMessage); return; } const compressionResult = await compressMessage(stringMessage); if (this.ws.readyState !== 1) { return; } if (!compressionResult.compressed) { this.ws.send(stringMessage); return; } const compressedMessage = { type: serverSocketMessageTypes.COMPRESSED_MESSAGE, payload: compressionResult.result, }; const validatedCompressedMessage = validateOutput( viewer?.platformDetails, serverServerSocketMessageValidator, compressedMessage, ); const stringCompressedMessage = JSON.stringify(validatedCompressedMessage); this.ws.send(stringCompressedMessage); }; async handleClientSocketMessage( message: ClientSocketMessage, ): Promise { const resultPromise = (async () => { if (message.type === clientSocketMessageTypes.INITIAL) { this.markActivityOccurred(); return await this.handleInitialClientSocketMessage(message); } else if (message.type === clientSocketMessageTypes.RESPONSES) { this.markActivityOccurred(); return await this.handleResponsesClientSocketMessage(message); } else if (message.type === clientSocketMessageTypes.PING) { return this.handlePingClientSocketMessage(message); } else if (message.type === clientSocketMessageTypes.ACK_UPDATES) { this.markActivityOccurred(); return await this.handleAckUpdatesClientSocketMessage(message); } else if (message.type === clientSocketMessageTypes.API_REQUEST) { this.markActivityOccurred(); return await this.handleAPIRequestClientSocketMessage(message); } return []; })(); const timeoutPromise = (async () => { await sleep(serverResponseTimeout); throw new ServerError('socket_response_timeout'); })(); return await Promise.race([resultPromise, timeoutPromise]); } async handleInitialClientSocketMessage( message: InitialClientSocketMessage, ): Promise { const { viewer } = this; invariant(viewer, 'should be set'); const responses = []; const { sessionState, clientResponses } = message.payload; const { calendarQuery, updatesCurrentAsOf: oldUpdatesCurrentAsOf, messagesCurrentAsOf: oldMessagesCurrentAsOf, watchedIDs, } = sessionState; await verifyCalendarQueryThreadIDs(calendarQuery); const sessionInitializationResult = await initializeSession( viewer, calendarQuery, oldUpdatesCurrentAsOf, ); const threadCursors = {}; for (const watchedThreadID of watchedIDs) { threadCursors[watchedThreadID] = null; } const messageSelectionCriteria = { threadCursors, joinedThreads: true, newerThan: oldMessagesCurrentAsOf, }; const [fetchMessagesResult, { serverRequests, activityUpdateResult }] = await Promise.all([ fetchMessageInfosSince( viewer, messageSelectionCriteria, defaultNumberPerThread, ), processClientResponses(viewer, clientResponses), ]); const messagesResult = { rawMessageInfos: fetchMessagesResult.rawMessageInfos, truncationStatuses: fetchMessagesResult.truncationStatuses, currentAsOf: mostRecentMessageTimestamp( fetchMessagesResult.rawMessageInfos, oldMessagesCurrentAsOf, ), }; const isCookieMissingSignedIdentityKeysBlobPromise = isCookieMissingSignedIdentityKeysBlob(viewer.cookieID); const isCookieMissingOlmNotificationsSessionPromise = isCookieMissingOlmNotificationsSession(viewer); if (!sessionInitializationResult.sessionContinued) { const promises = Object.fromEntries( values(serverStateSyncSpecs).map(spec => [ spec.hashKey, spec.fetchFullSocketSyncPayload(viewer, [calendarQuery]), ]), ); const results = await promiseAll(promises); const payload: ServerStateSyncFullSocketPayload = { type: stateSyncPayloadTypes.FULL, messagesResult, threadInfos: results.threadInfos, currentUserInfo: results.currentUserInfo, rawEntryInfos: results.entryInfos, userInfos: results.userInfos, updatesCurrentAsOf: oldUpdatesCurrentAsOf, }; if (viewer.sessionChanged) { // If initializeSession encounters, // sessionIdentifierTypes.BODY_SESSION_ID but the session // is unspecified or expired, // it will set a new sessionID and specify viewer.sessionChanged const { sessionID } = viewer; invariant( sessionID !== null && sessionID !== undefined, 'should be set', ); payload.sessionID = sessionID; viewer.sessionChanged = false; } responses.push({ type: serverSocketMessageTypes.STATE_SYNC, responseTo: message.id, payload, }); } else { const { sessionUpdate, deltaEntryInfoResult } = sessionInitializationResult; const promises = {}; promises.deleteExpiredUpdates = deleteUpdatesBeforeTimeTargetingSession( viewer, oldUpdatesCurrentAsOf, ); promises.fetchUpdateResult = fetchUpdateInfos( viewer, oldUpdatesCurrentAsOf, calendarQuery, ); promises.sessionUpdate = commitSessionUpdate(viewer, sessionUpdate); const { fetchUpdateResult } = await promiseAll(promises); const { updateInfos, userInfos } = fetchUpdateResult; const newUpdatesCurrentAsOf = mostRecentUpdateTimestamp( [...updateInfos], oldUpdatesCurrentAsOf, ); const updatesResult = { newUpdates: updateInfos, currentAsOf: newUpdatesCurrentAsOf, }; responses.push({ type: serverSocketMessageTypes.STATE_SYNC, responseTo: message.id, payload: { type: stateSyncPayloadTypes.INCREMENTAL, messagesResult, updatesResult, deltaEntryInfos: deltaEntryInfoResult.rawEntryInfos, deletedEntryIDs: deltaEntryInfoResult.deletedEntryIDs, userInfos: values(userInfos), }, }); } const [signedIdentityKeysBlobMissing, olmNotificationsSessionMissing] = await Promise.all([ isCookieMissingSignedIdentityKeysBlobPromise, isCookieMissingOlmNotificationsSessionPromise, ]); if (signedIdentityKeysBlobMissing) { serverRequests.push({ type: serverRequestTypes.SIGNED_IDENTITY_KEYS_BLOB, }); } if (olmNotificationsSessionMissing) { serverRequests.push({ type: serverRequestTypes.INITIAL_NOTIFICATIONS_ENCRYPTED_MESSAGE, }); } if (serverRequests.length > 0 || clientResponses.length > 0) { // We send this message first since the STATE_SYNC triggers the client's // connection status to shift to "connected", and we want to make sure the // client responses are cleared from Redux before that happens responses.unshift({ type: serverSocketMessageTypes.REQUESTS, responseTo: message.id, payload: { serverRequests }, }); } if (activityUpdateResult) { // Same reason for unshifting as above responses.unshift({ type: serverSocketMessageTypes.ACTIVITY_UPDATE_RESPONSE, responseTo: message.id, payload: activityUpdateResult, }); } return responses; } async handleResponsesClientSocketMessage( message: ResponsesClientSocketMessage, ): Promise { const { viewer } = this; invariant(viewer, 'should be set'); const { clientResponses } = message.payload; const { stateCheckStatus } = await processClientResponses( viewer, clientResponses, ); const serverRequests = []; if (stateCheckStatus && stateCheckStatus.status !== 'state_check') { const { sessionUpdate, checkStateRequest } = await checkState( viewer, stateCheckStatus, ); if (sessionUpdate) { await commitSessionUpdate(viewer, sessionUpdate); this.setStateCheckConditions({ stateCheckOngoing: false }); } if (checkStateRequest) { serverRequests.push(checkStateRequest); } } // We send a response message regardless of whether we have any requests, // since we need to ack the client's responses return [ { type: serverSocketMessageTypes.REQUESTS, responseTo: message.id, payload: { serverRequests }, }, ]; } handlePingClientSocketMessage( message: PingClientSocketMessage, ): ServerServerSocketMessage[] { return [ { type: serverSocketMessageTypes.PONG, responseTo: message.id, }, ]; } async handleAckUpdatesClientSocketMessage( message: AckUpdatesClientSocketMessage, ): Promise { const { viewer } = this; invariant(viewer, 'should be set'); const { currentAsOf } = message.payload; await Promise.all([ deleteUpdatesBeforeTimeTargetingSession(viewer, currentAsOf), commitSessionUpdate(viewer, { lastUpdate: currentAsOf }), ]); return []; } async handleAPIRequestClientSocketMessage( message: APIRequestClientSocketMessage, ): Promise { if (!endpointIsSocketSafe(message.payload.endpoint)) { throw new ServerError('endpoint_unsafe_for_socket'); } const { viewer } = this; invariant(viewer, 'should be set'); const responder = jsonEndpoints[message.payload.endpoint]; await policiesValidator(viewer, responder.requiredPolicies); const response = await responder.responder(viewer, message.payload.input); return [ { type: serverSocketMessageTypes.API_RESPONSE, responseTo: message.id, payload: response, }, ]; } onRedisMessage = async (message: RedisMessage) => { try { await this.processRedisMessage(message); } catch (e) { console.warn(e); } }; async processRedisMessage(message: RedisMessage) { if (message.type === redisMessageTypes.START_SUBSCRIPTION) { this.ws.terminate(); } else if (message.type === redisMessageTypes.NEW_UPDATES) { const { viewer } = this; invariant(viewer, 'should be set'); if (message.ignoreSession && message.ignoreSession === viewer.session) { return; } const rawUpdateInfos = message.updates; this.redisPromiseResolver.add( (async () => { const { updateInfos, userInfos } = await fetchUpdateInfosWithRawUpdateInfos(rawUpdateInfos, { viewer, }); if (updateInfos.length === 0) { console.warn( 'could not get any UpdateInfos from redisMessageTypes.NEW_UPDATES', ); return null; } this.markActivityOccurred(); return { type: serverSocketMessageTypes.UPDATES, payload: { updatesResult: { currentAsOf: mostRecentUpdateTimestamp([...updateInfos], 0), newUpdates: updateInfos, }, userInfos: values(userInfos), }, }; })(), ); } else if (message.type === redisMessageTypes.NEW_MESSAGES) { const { viewer } = this; invariant(viewer, 'should be set'); const rawMessageInfos = message.messages; const messageFetchResult = getMessageFetchResultFromRedisMessages( viewer, rawMessageInfos, ); if (messageFetchResult.rawMessageInfos.length === 0) { console.warn( 'could not get any rawMessageInfos from ' + 'redisMessageTypes.NEW_MESSAGES', ); return; } this.redisPromiseResolver.add( (async () => { this.markActivityOccurred(); return { type: serverSocketMessageTypes.MESSAGES, payload: { messagesResult: { rawMessageInfos: messageFetchResult.rawMessageInfos, truncationStatuses: messageFetchResult.truncationStatuses, currentAsOf: mostRecentMessageTimestamp( messageFetchResult.rawMessageInfos, 0, ), }, }, }; })(), ); } } onSuccessfulConnection() { if (this.ws.readyState !== 1) { return; } this.handleStateCheckConditionsUpdate(); } // The Socket will timeout by calling this.ws.terminate() // serverRequestSocketTimeout milliseconds after the last // time resetTimeout is called resetTimeout = _debounce( () => this.ws.terminate(), serverRequestSocketTimeout, ); debouncedAfterActivity = _debounce( () => this.setStateCheckConditions({ activityRecentlyOccurred: false }), stateCheckInactivityActivationInterval, ); markActivityOccurred = () => { if (this.ws.readyState !== 1) { return; } this.setStateCheckConditions({ activityRecentlyOccurred: true }); this.debouncedAfterActivity(); }; clearStateCheckTimeout() { const { stateCheckTimeoutID } = this; if (stateCheckTimeoutID) { clearTimeout(stateCheckTimeoutID); this.stateCheckTimeoutID = null; } } setStateCheckConditions(newConditions: Shape) { this.stateCheckConditions = { ...this.stateCheckConditions, ...newConditions, }; this.handleStateCheckConditionsUpdate(); } get stateCheckCanStart() { return Object.values(this.stateCheckConditions).every(cond => !cond); } handleStateCheckConditionsUpdate() { if (!this.stateCheckCanStart) { this.clearStateCheckTimeout(); return; } if (this.stateCheckTimeoutID) { return; } const { viewer } = this; if (!viewer) { return; } const timeUntilStateCheck = viewer.sessionLastValidated + sessionCheckFrequency - Date.now(); if (timeUntilStateCheck <= 0) { this.initiateStateCheck(); } else { this.stateCheckTimeoutID = setTimeout( this.initiateStateCheck, timeUntilStateCheck, ); } } initiateStateCheck = async () => { this.setStateCheckConditions({ stateCheckOngoing: true }); const { viewer } = this; invariant(viewer, 'should be set'); const { checkStateRequest } = await checkState(viewer, { status: 'state_check', }); invariant(checkStateRequest, 'should be set'); await this.sendMessage({ type: serverSocketMessageTypes.REQUESTS, payload: { serverRequests: [checkStateRequest] }, }); }; } export { onConnection }; diff --git a/lib/socket/socket.react.js b/lib/socket/socket.react.js index 81836da92..85e515666 100644 --- a/lib/socket/socket.react.js +++ b/lib/socket/socket.react.js @@ -1,815 +1,813 @@ // @flow import invariant from 'invariant'; import _isEqual from 'lodash/fp/isEqual.js'; import _throttle from 'lodash/throttle.js'; import * as React from 'react'; import ActivityHandler from './activity-handler.react.js'; import APIRequestHandler from './api-request-handler.react.js'; import CalendarQueryHandler from './calendar-query-handler.react.js'; import { InflightRequests } from './inflight-requests.js'; import MessageHandler from './message-handler.react.js'; import ReportHandler from './report-handler.react.js'; import RequestResponseHandler from './request-response-handler.react.js'; import UpdateHandler from './update-handler.react.js'; import { updateActivityActionTypes } from '../actions/activity-actions.js'; import { updateLastCommunicatedPlatformDetailsActionType } from '../actions/device-actions.js'; import { logOutActionTypes } from '../actions/user-actions.js'; import { unsupervisedBackgroundActionType } from '../reducers/lifecycle-state-reducer.js'; import { pingFrequency, serverRequestSocketTimeout, clientRequestVisualTimeout, clientRequestSocketTimeout, } from '../shared/timeouts.js'; import { logInActionSources, type LogOutResult, } from '../types/account-types.js'; import type { CompressedData } from '../types/compression-types.js'; import { type PlatformDetails } from '../types/device-types.js'; import type { CalendarQuery } from '../types/entry-types.js'; import { forcePolicyAcknowledgmentActionType } from '../types/policy-types.js'; import type { Dispatch } from '../types/redux-types.js'; import { serverRequestTypes, type ClientClientResponse, type ClientServerRequest, } from '../types/request-types.js'; import { type SessionState, type SessionIdentification, type PreRequestUserState, } from '../types/session-types.js'; import { clientSocketMessageTypes, type ClientClientSocketMessage, serverSocketMessageTypes, type ClientServerSocketMessage, stateSyncPayloadTypes, fullStateSyncActionType, incrementalStateSyncActionType, updateConnectionStatusActionType, type ConnectionInfo, type ClientInitialClientSocketMessage, type ClientResponsesClientSocketMessage, type PingClientSocketMessage, type AckUpdatesClientSocketMessage, type APIRequestClientSocketMessage, type ClientSocketMessageWithoutID, type SocketListener, type ConnectionStatus, setLateResponseActionType, type CommTransportLayer, } from '../types/socket-types.js'; import { actionLogger } from '../utils/action-logger.js'; import type { DispatchActionPromise } from '../utils/action-utils.js'; import { setNewSessionActionType, fetchNewCookieFromNativeCredentials, } from '../utils/action-utils.js'; import { getConfig } from '../utils/config.js'; import { ServerError, SocketTimeout, SocketOffline } from '../utils/errors.js'; import { promiseAll } from '../utils/promises.js'; import sleep from '../utils/sleep.js'; import { ashoatKeyserverID } from '../utils/validation-utils.js'; const remainingTimeAfterVisualTimeout = clientRequestSocketTimeout - clientRequestVisualTimeout; export type BaseSocketProps = { +detectUnsupervisedBackgroundRef?: ( detectUnsupervisedBackground: (alreadyClosed: boolean) => boolean, ) => void, }; type Props = { ...BaseSocketProps, // Redux state +active: boolean, +openSocket: () => CommTransportLayer, +getClientResponses: ( activeServerRequests: $ReadOnlyArray, ) => Promise<$ReadOnlyArray>, +activeThread: ?string, +sessionStateFunc: () => SessionState, +sessionIdentification: SessionIdentification, +cookie: ?string, +urlPrefix: string, +connection: ConnectionInfo, +currentCalendarQuery: () => CalendarQuery, +canSendReports: boolean, +frozen: boolean, +preRequestUserState: PreRequestUserState, +noDataAfterPolicyAcknowledgment?: boolean, +lastCommunicatedPlatformDetails: ?PlatformDetails, +decompressSocketMessage: CompressedData => string, // Redux dispatch functions +dispatch: Dispatch, +dispatchActionPromise: DispatchActionPromise, // async functions that hit server APIs +logOut: (preRequestUserState: PreRequestUserState) => Promise, +socketCrashLoopRecovery?: () => Promise, // keyserver olm sessions specific props +getInitialNotificationsEncryptedMessage?: () => Promise, }; type State = { +inflightRequests: ?InflightRequests, }; class Socket extends React.PureComponent { state: State = { inflightRequests: null, }; socket: ?CommTransportLayer; nextClientMessageID: number = 0; listeners: Set = new Set(); pingTimeoutID: ?TimeoutID; messageLastReceived: ?number; reopenConnectionAfterClosing: boolean = false; invalidationRecoveryInProgress: boolean = false; initializedWithUserState: ?PreRequestUserState; failuresAfterPolicyAcknowledgment: number = 0; openSocket(newStatus: ConnectionStatus) { if ( this.props.frozen || !this.props.cookie || !this.props.cookie.startsWith('user=') ) { return; } if (this.socket) { const { status } = this.props.connection; if (status === 'forcedDisconnecting') { this.reopenConnectionAfterClosing = true; return; } else if (status === 'disconnecting' && this.socket.readyState === 1) { this.markSocketInitialized(); return; } else if ( status === 'connected' || status === 'connecting' || status === 'reconnecting' ) { return; } if (this.socket.readyState < 2) { this.socket.close(); console.log(`this.socket seems open, but Redux thinks it's ${status}`); } } this.props.dispatch({ type: updateConnectionStatusActionType, payload: { status: newStatus, keyserverID: ashoatKeyserverID }, }); const socket = this.props.openSocket(); const openObject = {}; socket.onopen = () => { if (this.socket === socket) { this.initializeSocket(); openObject.initializeMessageSent = true; } }; socket.onmessage = this.receiveMessage; socket.onclose = () => { if (this.socket === socket) { this.onClose(); } }; this.socket = socket; (async () => { await sleep(clientRequestVisualTimeout); if (this.socket !== socket || openObject.initializeMessageSent) { return; } this.setLateResponse(-1, true); await sleep(remainingTimeAfterVisualTimeout); if (this.socket !== socket || openObject.initializeMessageSent) { return; } this.finishClosingSocket(); })(); this.setState({ inflightRequests: new InflightRequests({ timeout: () => { if (this.socket === socket) { this.finishClosingSocket(); } }, setLateResponse: (messageID: number, isLate: boolean) => { if (this.socket === socket) { this.setLateResponse(messageID, isLate); } }, }), }); } markSocketInitialized() { this.props.dispatch({ type: updateConnectionStatusActionType, payload: { status: 'connected', keyserverID: ashoatKeyserverID }, }); this.resetPing(); } closeSocket( // This param is a hack. When closing a socket there is a race between this // function and the one to propagate the activity update. We make sure that // the activity update wins the race by passing in this param. activityUpdatePending: boolean, ) { const { status } = this.props.connection; if (status === 'disconnected') { return; } else if (status === 'disconnecting' || status === 'forcedDisconnecting') { this.reopenConnectionAfterClosing = false; return; } this.stopPing(); this.props.dispatch({ type: updateConnectionStatusActionType, payload: { status: 'disconnecting', keyserverID: ashoatKeyserverID }, }); if (!activityUpdatePending) { this.finishClosingSocket(); } } forceCloseSocket() { this.stopPing(); const { status } = this.props.connection; if (status !== 'forcedDisconnecting' && status !== 'disconnected') { this.props.dispatch({ type: updateConnectionStatusActionType, payload: { status: 'forcedDisconnecting', keyserverID: ashoatKeyserverID, }, }); } this.finishClosingSocket(); } finishClosingSocket(receivedResponseTo?: ?number) { const { inflightRequests } = this.state; if ( inflightRequests && !inflightRequests.allRequestsResolvedExcept(receivedResponseTo) ) { return; } if (this.socket && this.socket.readyState < 2) { // If it's not closing already, close it this.socket.close(); } this.socket = null; this.stopPing(); this.setState({ inflightRequests: null }); if (this.props.connection.status !== 'disconnected') { this.props.dispatch({ type: updateConnectionStatusActionType, payload: { status: 'disconnected', keyserverID: ashoatKeyserverID }, }); } if (this.reopenConnectionAfterClosing) { this.reopenConnectionAfterClosing = false; if (this.props.active) { this.openSocket('connecting'); } } } reconnect: $Call void, number> = _throttle( () => this.openSocket('reconnecting'), 2000, ); componentDidMount() { if (this.props.detectUnsupervisedBackgroundRef) { this.props.detectUnsupervisedBackgroundRef( this.detectUnsupervisedBackground, ); } if (this.props.active) { this.openSocket('connecting'); } } componentWillUnmount() { this.closeSocket(false); this.reconnect.cancel(); } componentDidUpdate(prevProps: Props) { if (this.props.active && !prevProps.active) { this.openSocket('connecting'); } else if (!this.props.active && prevProps.active) { this.closeSocket(!!prevProps.activeThread); } else if ( this.props.active && prevProps.openSocket !== this.props.openSocket ) { // This case happens when the baseURL/urlPrefix is changed this.reopenConnectionAfterClosing = true; this.forceCloseSocket(); } else if ( this.props.active && this.props.connection.status === 'disconnected' && prevProps.connection.status !== 'disconnected' && !this.invalidationRecoveryInProgress ) { this.reconnect(); } } render(): React.Node { // It's important that APIRequestHandler get rendered first here. This is so // that it is registered with Redux first, so that its componentDidUpdate // processes before the other Handlers. This allows APIRequestHandler to // register itself with action-utils before other Handlers call // dispatchActionPromise in response to the componentDidUpdate triggered by // the same Redux change (state.connection.status). return ( ); } sendMessageWithoutID: (message: ClientSocketMessageWithoutID) => number = message => { const id = this.nextClientMessageID++; // These conditions all do the same thing and the runtime checks are only // necessary for Flow if (message.type === clientSocketMessageTypes.INITIAL) { this.sendMessage( ({ ...message, id }: ClientInitialClientSocketMessage), ); } else if (message.type === clientSocketMessageTypes.RESPONSES) { this.sendMessage( ({ ...message, id }: ClientResponsesClientSocketMessage), ); } else if (message.type === clientSocketMessageTypes.PING) { this.sendMessage(({ ...message, id }: PingClientSocketMessage)); } else if (message.type === clientSocketMessageTypes.ACK_UPDATES) { this.sendMessage(({ ...message, id }: AckUpdatesClientSocketMessage)); } else if (message.type === clientSocketMessageTypes.API_REQUEST) { this.sendMessage(({ ...message, id }: APIRequestClientSocketMessage)); } return id; }; sendMessage(message: ClientClientSocketMessage) { const socket = this.socket; invariant(socket, 'should be set'); socket.send(JSON.stringify(message)); } messageFromEvent(event: MessageEvent): ?ClientServerSocketMessage { if (typeof event.data !== 'string') { console.log('socket received a non-string message'); return null; } let rawMessage; try { rawMessage = JSON.parse(event.data); } catch (e) { console.log(e); return null; } if (rawMessage.type !== serverSocketMessageTypes.COMPRESSED_MESSAGE) { return rawMessage; } const result = this.props.decompressSocketMessage(rawMessage.payload); try { return JSON.parse(result); } catch (e) { console.log(e); return null; } } receiveMessage: (event: MessageEvent) => Promise = async event => { const message = this.messageFromEvent(event); if (!message) { return; } this.failuresAfterPolicyAcknowledgment = 0; const { inflightRequests } = this.state; if (!inflightRequests) { // inflightRequests can be falsey here if we receive a message after we've // begun shutting down the socket. It's possible for a React Native // WebSocket to deliver a message even after close() is called on it. In // this case the message is probably a PONG, which we can safely ignore. // If it's not a PONG, it has to be something server-initiated (like // UPDATES or MESSAGES), since InflightRequests.allRequestsResolvedExcept // will wait for all responses to client-initiated requests to be // delivered before closing a socket. UPDATES and MESSAGES are both // checkpointed on the client, so should be okay to just ignore here and // redownload them later, probably in an incremental STATE_SYNC. return; } // If we receive any message, that indicates that our connection is healthy, // so we can reset the ping timeout. this.resetPing(); inflightRequests.resolveRequestsForMessage(message); const { status } = this.props.connection; if (status === 'disconnecting' || status === 'forcedDisconnecting') { this.finishClosingSocket( // We do this for Flow message.responseTo !== undefined ? message.responseTo : null, ); } for (const listener of this.listeners) { listener(message); } if (message.type === serverSocketMessageTypes.ERROR) { const { message: errorMessage, payload } = message; if (payload) { console.log(`socket sent error ${errorMessage} with payload`, payload); } else { console.log(`socket sent error ${errorMessage}`); } if (errorMessage === 'policies_not_accepted' && this.props.active) { this.props.dispatch({ type: forcePolicyAcknowledgmentActionType, payload, }); } } else if (message.type === serverSocketMessageTypes.AUTH_ERROR) { const { sessionChange } = message; const cookie = sessionChange ? sessionChange.cookie : this.props.cookie; this.invalidationRecoveryInProgress = true; const recoverySessionChange = await fetchNewCookieFromNativeCredentials( this.props.dispatch, cookie, this.props.urlPrefix, logInActionSources.socketAuthErrorResolutionAttempt, ashoatKeyserverID, this.props.getInitialNotificationsEncryptedMessage, ); if (!recoverySessionChange && sessionChange) { - // This should only happen in the cookieSources.BODY (native) case when - // the resolution attempt failed const { cookie: newerCookie, currentUserInfo } = sessionChange; this.props.dispatch({ type: setNewSessionActionType, payload: { sessionChange: { cookieInvalidated: true, currentUserInfo, cookie: newerCookie, }, preRequestUserState: this.initializedWithUserState, error: null, logInActionSource: logInActionSources.socketAuthErrorResolutionAttempt, }, }); } else if (!recoverySessionChange) { this.props.dispatchActionPromise( logOutActionTypes, this.props.logOut(this.props.preRequestUserState), ); } this.invalidationRecoveryInProgress = false; } }; addListener: (listener: SocketListener) => void = listener => { this.listeners.add(listener); }; removeListener: (listener: SocketListener) => void = listener => { this.listeners.delete(listener); }; onClose: () => void = () => { const { status } = this.props.connection; this.socket = null; this.stopPing(); if (this.state.inflightRequests) { this.state.inflightRequests.rejectAll(new Error('socket closed')); this.setState({ inflightRequests: null }); } const handled = this.detectUnsupervisedBackground(true); if (!handled && status !== 'disconnected') { this.props.dispatch({ type: updateConnectionStatusActionType, payload: { status: 'disconnected', keyserverID: ashoatKeyserverID }, }); } }; async sendInitialMessage() { const { inflightRequests } = this.state; invariant( inflightRequests, 'inflightRequests falsey inside sendInitialMessage', ); const messageID = this.nextClientMessageID++; const promises = {}; const shouldSendInitialPlatformDetails = !_isEqual( this.props.lastCommunicatedPlatformDetails, )(getConfig().platformDetails); const clientResponses = []; if (shouldSendInitialPlatformDetails) { clientResponses.push({ type: serverRequestTypes.PLATFORM_DETAILS, platformDetails: getConfig().platformDetails, }); } const { queuedActivityUpdates } = this.props.connection; if (queuedActivityUpdates.length > 0) { clientResponses.push({ type: serverRequestTypes.INITIAL_ACTIVITY_UPDATES, activityUpdates: queuedActivityUpdates, }); promises.activityUpdateMessage = inflightRequests.fetchResponse( messageID, serverSocketMessageTypes.ACTIVITY_UPDATE_RESPONSE, ); } const sessionState = this.props.sessionStateFunc(); const { sessionIdentification } = this.props; const initialMessage = { type: clientSocketMessageTypes.INITIAL, id: messageID, payload: { clientResponses, sessionState, sessionIdentification, }, }; this.initializedWithUserState = this.props.preRequestUserState; this.sendMessage(initialMessage); promises.stateSyncMessage = inflightRequests.fetchResponse( messageID, serverSocketMessageTypes.STATE_SYNC, ); const { stateSyncMessage, activityUpdateMessage } = await promiseAll( promises, ); if (shouldSendInitialPlatformDetails) { this.props.dispatch({ type: updateLastCommunicatedPlatformDetailsActionType, payload: { platformDetails: getConfig().platformDetails, keyserverID: ashoatKeyserverID, }, }); } if (activityUpdateMessage) { this.props.dispatch({ type: updateActivityActionTypes.success, payload: { activityUpdates: { [ashoatKeyserverID]: queuedActivityUpdates }, result: activityUpdateMessage.payload, }, }); } if (stateSyncMessage.payload.type === stateSyncPayloadTypes.FULL) { const { sessionID, type, ...actionPayload } = stateSyncMessage.payload; this.props.dispatch({ type: fullStateSyncActionType, payload: { ...actionPayload, calendarQuery: sessionState.calendarQuery, keyserverID: ashoatKeyserverID, }, }); if (sessionID !== null && sessionID !== undefined) { invariant( this.initializedWithUserState, 'initializedWithUserState should be set when state sync received', ); this.props.dispatch({ type: setNewSessionActionType, payload: { sessionChange: { cookieInvalidated: false, sessionID }, preRequestUserState: this.initializedWithUserState, error: null, logInActionSource: undefined, keyserverID: ashoatKeyserverID, }, }); } } else { const { type, ...actionPayload } = stateSyncMessage.payload; this.props.dispatch({ type: incrementalStateSyncActionType, payload: { ...actionPayload, calendarQuery: sessionState.calendarQuery, keyserverID: ashoatKeyserverID, }, }); } const currentAsOf = stateSyncMessage.payload.type === stateSyncPayloadTypes.FULL ? stateSyncMessage.payload.updatesCurrentAsOf : stateSyncMessage.payload.updatesResult.currentAsOf; this.sendMessageWithoutID({ type: clientSocketMessageTypes.ACK_UPDATES, payload: { currentAsOf }, }); this.markSocketInitialized(); } initializeSocket: (retriesLeft?: number) => Promise = async ( retriesLeft = 1, ) => { try { await this.sendInitialMessage(); } catch (e) { if (this.props.noDataAfterPolicyAcknowledgment) { this.failuresAfterPolicyAcknowledgment++; } else { this.failuresAfterPolicyAcknowledgment = 0; } if ( this.failuresAfterPolicyAcknowledgment >= 2 && this.props.socketCrashLoopRecovery ) { this.failuresAfterPolicyAcknowledgment = 0; try { await this.props.socketCrashLoopRecovery(); } catch (error) { console.log(error); this.props.dispatchActionPromise( logOutActionTypes, this.props.logOut(this.props.preRequestUserState), ); } return; } console.log(e); const { status } = this.props.connection; if ( e instanceof SocketTimeout || e instanceof SocketOffline || (status !== 'connecting' && status !== 'reconnecting') ) { // This indicates that the socket will be closed. Do nothing, since the // connection status update will trigger a reconnect. } else if ( retriesLeft === 0 || (e instanceof ServerError && e.message !== 'unknown_error') ) { if (e.message === 'not_logged_in') { this.props.dispatchActionPromise( logOutActionTypes, this.props.logOut(this.props.preRequestUserState), ); } else if (this.socket) { this.socket.close(); } } else { await this.initializeSocket(retriesLeft - 1); } } }; stopPing() { if (this.pingTimeoutID) { clearTimeout(this.pingTimeoutID); this.pingTimeoutID = null; } } resetPing() { this.stopPing(); const socket = this.socket; this.messageLastReceived = Date.now(); this.pingTimeoutID = setTimeout(() => { if (this.socket === socket) { this.sendPing(); } }, pingFrequency); } async sendPing() { if (this.props.connection.status !== 'connected') { // This generally shouldn't happen because anything that changes the // connection status should call stopPing(), but it's good to make sure return; } const messageID = this.sendMessageWithoutID({ type: clientSocketMessageTypes.PING, }); try { invariant( this.state.inflightRequests, 'inflightRequests falsey inside sendPing', ); await this.state.inflightRequests.fetchResponse( messageID, serverSocketMessageTypes.PONG, ); } catch (e) {} } setLateResponse: (messageID: number, isLate: boolean) => void = ( messageID, isLate, ) => { this.props.dispatch({ type: setLateResponseActionType, payload: { messageID, isLate, keyserverID: ashoatKeyserverID }, }); }; cleanUpServerTerminatedSocket() { if (this.socket && this.socket.readyState < 2) { this.socket.close(); } else { this.onClose(); } } detectUnsupervisedBackground: (alreadyClosed: boolean) => boolean = alreadyClosed => { // On native, sometimes the app is backgrounded without the proper // callbacks getting triggered. This leaves us in an incorrect state for // two reasons: // (1) The connection is still considered to be active, causing API // requests to be processed via socket and failing. // (2) We rely on flipping foreground state in Redux to detect activity // changes, and thus won't think we need to update activity. if ( this.props.connection.status !== 'connected' || !this.messageLastReceived || this.messageLastReceived + serverRequestSocketTimeout >= Date.now() || (actionLogger.mostRecentActionTime && actionLogger.mostRecentActionTime + 3000 < Date.now()) ) { return false; } if (!alreadyClosed) { this.cleanUpServerTerminatedSocket(); } this.props.dispatch({ type: unsupervisedBackgroundActionType, payload: { keyserverID: ashoatKeyserverID }, }); return true; }; } export default Socket; diff --git a/lib/types/session-types.js b/lib/types/session-types.js index 14e06862f..420c2113e 100644 --- a/lib/types/session-types.js +++ b/lib/types/session-types.js @@ -1,125 +1,110 @@ // @flow import t, { type TInterface } from 'tcomb'; import type { LogInActionSource } from './account-types.js'; import type { Shape } from './core.js'; import type { CalendarQuery } from './entry-types.js'; import type { RawThreadInfos } from './thread-types.js'; import { type UserInfo, type CurrentUserInfo, type LoggedOutUserInfo, } from './user-types.js'; import { tShape } from '../utils/validation-utils.js'; export const cookieLifetime = 30 * 24 * 60 * 60 * 1000; // in milliseconds // Interval the server waits after a state check before starting a new one export const sessionCheckFrequency = 3 * 60 * 1000; // in milliseconds // How long the server debounces after activity before initiating a state check export const stateCheckInactivityActivationInterval = 3 * 1000; // in milliseconds -// On native, we specify the cookie directly in the request and response body. -// We do this because: -// (1) We don't have the same XSS risks as we do on web, so there is no need to -// prevent JavaScript from knowing the cookie password. -// (2) In the past the internal cookie logic on Android has been buggy. -// https://github.com/facebook/react-native/issues/12956 is an example -// issue. By specifying the cookie in the body we retain full control of how -// that data is passed, without necessitating any native modules like -// react-native-cookies. -export const cookieSources = Object.freeze({ - BODY: 0, - HEADER: 1, -}); -export type CookieSource = $Values; - // On native, we use the cookieID as a unique session identifier. This is // because there is no way to have two instances of an app running. On the other // hand, on web it is possible to have two sessions open using the same cookie, // so we have a unique sessionID specified in the request body. export const sessionIdentifierTypes = Object.freeze({ COOKIE_ID: 0, BODY_SESSION_ID: 1, }); export type SessionIdentifierType = $Values; export const cookieTypes = Object.freeze({ USER: 'user', ANONYMOUS: 'anonymous', }); export type CookieType = $Values; export type ServerSessionChange = | { cookieInvalidated: false, threadInfos: RawThreadInfos, userInfos: $ReadOnlyArray, sessionID?: null | string, cookie?: string, } | { cookieInvalidated: true, threadInfos: RawThreadInfos, userInfos: $ReadOnlyArray, currentUserInfo: LoggedOutUserInfo, sessionID?: null | string, cookie?: string, }; export type ClientSessionChange = | { +cookieInvalidated: false, +currentUserInfo?: ?CurrentUserInfo, +sessionID?: null | string, +cookie?: string, } | { +cookieInvalidated: true, +currentUserInfo: LoggedOutUserInfo, +sessionID?: null | string, +cookie?: string, }; export type PreRequestUserKeyserverSessionInfo = { +cookie: ?string, +sessionID: ?string, }; export type PreRequestUserState = { +currentUserInfo: ?CurrentUserInfo, +cookiesAndSessions: { +[keyserverID: string]: PreRequestUserKeyserverSessionInfo, }, }; export type SetSessionPayload = { +sessionChange: ClientSessionChange, +preRequestUserState: ?PreRequestUserState, +error: ?string, +logInActionSource: ?LogInActionSource, +keyserverID: string, }; export type SessionState = { calendarQuery: CalendarQuery, messagesCurrentAsOf: number, updatesCurrentAsOf: number, watchedIDs: $ReadOnlyArray, }; export type SessionIdentification = Shape<{ cookie: ?string, sessionID: ?string, }>; export type SessionPublicKeys = { +identityKey: string, +oneTimeKey?: string, }; export const sessionPublicKeysValidator: TInterface = tShape({ identityKey: t.String, oneTimeKey: t.maybe(t.String), }); diff --git a/lib/types/socket-types.js b/lib/types/socket-types.js index 6a6e7cfd4..aaf1c7fb3 100644 --- a/lib/types/socket-types.js +++ b/lib/types/socket-types.js @@ -1,540 +1,538 @@ // @flow import invariant from 'invariant'; import t, { type TInterface, type TUnion } from 'tcomb'; import { type ActivityUpdate, activityUpdateValidator, type UpdateActivityResult, updateActivityResultValidator, } from './activity-types.js'; import { type CompressedData, compressedDataValidator, } from './compression-types.js'; import type { APIRequest } from './endpoints.js'; import { type RawEntryInfo, rawEntryInfoValidator, type CalendarQuery, } from './entry-types.js'; import { type MessagesResponse, messagesResponseValidator, type NewMessagesPayload, newMessagesPayloadValidator, } from './message-types.js'; import { type ServerServerRequest, serverServerRequestValidator, type ClientServerRequest, type ClientResponse, type ClientClientResponse, } from './request-types.js'; import type { SessionState, SessionIdentification } from './session-types.js'; import { rawThreadInfoValidator, type RawThreadInfos } from './thread-types.js'; import { type ClientUpdatesResult, type ClientUpdatesResultWithUserInfos, type ServerUpdatesResult, serverUpdatesResultValidator, type ServerUpdatesResultWithUserInfos, serverUpdatesResultWithUserInfosValidator, } from './update-types.js'; import { type UserInfo, userInfoValidator, type CurrentUserInfo, currentUserInfoValidator, type LoggedOutUserInfo, loggedOutUserInfoValidator, } from './user-types.js'; import { tShape, tNumber, tID } from '../utils/validation-utils.js'; // The types of messages that the client sends across the socket export const clientSocketMessageTypes = Object.freeze({ INITIAL: 0, RESPONSES: 1, //ACTIVITY_UPDATES: 2, (DEPRECATED) PING: 3, ACK_UPDATES: 4, API_REQUEST: 5, }); export type ClientSocketMessageType = $Values; export function assertClientSocketMessageType( ourClientSocketMessageType: number, ): ClientSocketMessageType { invariant( ourClientSocketMessageType === 0 || ourClientSocketMessageType === 1 || ourClientSocketMessageType === 3 || ourClientSocketMessageType === 4 || ourClientSocketMessageType === 5, 'number is not ClientSocketMessageType enum', ); return ourClientSocketMessageType; } export type InitialClientSocketMessage = { +type: 0, +id: number, +payload: { +sessionIdentification: SessionIdentification, +sessionState: SessionState, +clientResponses: $ReadOnlyArray, }, }; export type ResponsesClientSocketMessage = { +type: 1, +id: number, +payload: { +clientResponses: $ReadOnlyArray, }, }; export type PingClientSocketMessage = { +type: 3, +id: number, }; export type AckUpdatesClientSocketMessage = { +type: 4, +id: number, +payload: { +currentAsOf: number, }, }; export type APIRequestClientSocketMessage = { +type: 5, +id: number, +payload: APIRequest, }; export type ClientSocketMessage = | InitialClientSocketMessage | ResponsesClientSocketMessage | PingClientSocketMessage | AckUpdatesClientSocketMessage | APIRequestClientSocketMessage; export type ClientInitialClientSocketMessage = { +type: 0, +id: number, +payload: { +sessionIdentification: SessionIdentification, +sessionState: SessionState, +clientResponses: $ReadOnlyArray, }, }; export type ClientResponsesClientSocketMessage = { +type: 1, +id: number, +payload: { +clientResponses: $ReadOnlyArray, }, }; export type ClientClientSocketMessage = | ClientInitialClientSocketMessage | ClientResponsesClientSocketMessage | PingClientSocketMessage | AckUpdatesClientSocketMessage | APIRequestClientSocketMessage; export type ClientSocketMessageWithoutID = $Diff< ClientClientSocketMessage, { id: number }, >; // The types of messages that the server sends across the socket export const serverSocketMessageTypes = Object.freeze({ STATE_SYNC: 0, REQUESTS: 1, ERROR: 2, AUTH_ERROR: 3, ACTIVITY_UPDATE_RESPONSE: 4, PONG: 5, UPDATES: 6, MESSAGES: 7, API_RESPONSE: 8, COMPRESSED_MESSAGE: 9, }); export type ServerSocketMessageType = $Values; export function assertServerSocketMessageType( ourServerSocketMessageType: number, ): ServerSocketMessageType { invariant( ourServerSocketMessageType === 0 || ourServerSocketMessageType === 1 || ourServerSocketMessageType === 2 || ourServerSocketMessageType === 3 || ourServerSocketMessageType === 4 || ourServerSocketMessageType === 5 || ourServerSocketMessageType === 6 || ourServerSocketMessageType === 7 || ourServerSocketMessageType === 8 || ourServerSocketMessageType === 9, 'number is not ServerSocketMessageType enum', ); return ourServerSocketMessageType; } export const stateSyncPayloadTypes = Object.freeze({ FULL: 0, INCREMENTAL: 1, }); export const fullStateSyncActionType = 'FULL_STATE_SYNC'; export type BaseFullStateSync = { +messagesResult: MessagesResponse, +threadInfos: RawThreadInfos, +rawEntryInfos: $ReadOnlyArray, +userInfos: $ReadOnlyArray, +updatesCurrentAsOf: number, }; const baseFullStateSyncValidator = tShape({ messagesResult: messagesResponseValidator, threadInfos: t.dict(tID, rawThreadInfoValidator), rawEntryInfos: t.list(rawEntryInfoValidator), userInfos: t.list(userInfoValidator), updatesCurrentAsOf: t.Number, }); export type ClientFullStateSync = { ...BaseFullStateSync, +currentUserInfo: CurrentUserInfo, }; export type StateSyncFullActionPayload = { ...ClientFullStateSync, +calendarQuery: CalendarQuery, +keyserverID: string, }; export type ClientStateSyncFullSocketPayload = { ...ClientFullStateSync, +type: 0, // Included iff client is using sessionIdentifierTypes.BODY_SESSION_ID +sessionID?: string, }; export type ServerFullStateSync = { ...BaseFullStateSync, +currentUserInfo: CurrentUserInfo, }; const serverFullStateSyncValidator = tShape({ ...baseFullStateSyncValidator.meta.props, currentUserInfo: currentUserInfoValidator, }); export type ServerStateSyncFullSocketPayload = { ...ServerFullStateSync, +type: 0, // Included iff client is using sessionIdentifierTypes.BODY_SESSION_ID +sessionID?: string, }; const serverStateSyncFullSocketPayloadValidator = tShape({ ...serverFullStateSyncValidator.meta.props, type: tNumber(stateSyncPayloadTypes.FULL), sessionID: t.maybe(t.String), }); export const incrementalStateSyncActionType = 'INCREMENTAL_STATE_SYNC'; export type BaseIncrementalStateSync = { +messagesResult: MessagesResponse, +deltaEntryInfos: $ReadOnlyArray, +deletedEntryIDs: $ReadOnlyArray, +userInfos: $ReadOnlyArray, }; const baseIncrementalStateSyncValidator = tShape({ messagesResult: messagesResponseValidator, deltaEntryInfos: t.list(rawEntryInfoValidator), deletedEntryIDs: t.list(tID), userInfos: t.list(userInfoValidator), }); export type ClientIncrementalStateSync = { ...BaseIncrementalStateSync, +updatesResult: ClientUpdatesResult, }; export type StateSyncIncrementalActionPayload = { ...ClientIncrementalStateSync, +calendarQuery: CalendarQuery, +keyserverID: string, }; type ClientStateSyncIncrementalSocketPayload = { +type: 1, ...ClientIncrementalStateSync, }; export type ServerIncrementalStateSync = { ...BaseIncrementalStateSync, +updatesResult: ServerUpdatesResult, }; const serverIncrementalStateSyncValidator = tShape({ ...baseIncrementalStateSyncValidator.meta.props, updatesResult: serverUpdatesResultValidator, }); type ServerStateSyncIncrementalSocketPayload = { +type: 1, ...ServerIncrementalStateSync, }; const serverStateSyncIncrementalSocketPayloadValidator = tShape({ type: tNumber(stateSyncPayloadTypes.INCREMENTAL), ...serverIncrementalStateSyncValidator.meta.props, }); export type ClientStateSyncSocketPayload = | ClientStateSyncFullSocketPayload | ClientStateSyncIncrementalSocketPayload; export type ServerStateSyncSocketPayload = | ServerStateSyncFullSocketPayload | ServerStateSyncIncrementalSocketPayload; const serverStateSyncSocketPayloadValidator = t.union([ serverStateSyncFullSocketPayloadValidator, serverStateSyncIncrementalSocketPayloadValidator, ]); export type ServerStateSyncServerSocketMessage = { +type: 0, +responseTo: number, +payload: ServerStateSyncSocketPayload, }; export const serverStateSyncServerSocketMessageValidator: TInterface = tShape({ type: tNumber(serverSocketMessageTypes.STATE_SYNC), responseTo: t.Number, payload: serverStateSyncSocketPayloadValidator, }); export type ServerRequestsServerSocketMessage = { +type: 1, +responseTo?: number, +payload: { +serverRequests: $ReadOnlyArray, }, }; export const serverRequestsServerSocketMessageValidator: TInterface = tShape({ type: tNumber(serverSocketMessageTypes.REQUESTS), responseTo: t.maybe(t.Number), payload: tShape({ serverRequests: t.list(serverServerRequestValidator), }), }); export type ErrorServerSocketMessage = { type: 2, responseTo?: number, message: string, payload?: Object, }; export const errorServerSocketMessageValidator: TInterface = tShape({ type: tNumber(serverSocketMessageTypes.ERROR), responseTo: t.maybe(t.Number), message: t.String, payload: t.maybe(t.Object), }); export type AuthErrorServerSocketMessage = { - type: 3, - responseTo: number, - message: string, - // If unspecified, it is because the client is using cookieSources.HEADER, - // which means the server can't update the cookie from a socket message. - sessionChange?: { - cookie: string, - currentUserInfo: LoggedOutUserInfo, + +type: 3, + +responseTo: number, + +message: string, + +sessionChange: { + +cookie: string, + +currentUserInfo: LoggedOutUserInfo, }, }; export const authErrorServerSocketMessageValidator: TInterface = tShape({ type: tNumber(serverSocketMessageTypes.AUTH_ERROR), responseTo: t.Number, message: t.String, sessionChange: t.maybe( tShape({ cookie: t.String, currentUserInfo: loggedOutUserInfoValidator }), ), }); export type ActivityUpdateResponseServerSocketMessage = { +type: 4, +responseTo: number, +payload: UpdateActivityResult, }; export const activityUpdateResponseServerSocketMessageValidator: TInterface = tShape({ type: tNumber(serverSocketMessageTypes.ACTIVITY_UPDATE_RESPONSE), responseTo: t.Number, payload: updateActivityResultValidator, }); export type PongServerSocketMessage = { +type: 5, +responseTo: number, }; export const pongServerSocketMessageValidator: TInterface = tShape({ type: tNumber(serverSocketMessageTypes.PONG), responseTo: t.Number, }); export type ServerUpdatesServerSocketMessage = { +type: 6, +payload: ServerUpdatesResultWithUserInfos, }; export const serverUpdatesServerSocketMessageValidator: TInterface = tShape({ type: tNumber(serverSocketMessageTypes.UPDATES), payload: serverUpdatesResultWithUserInfosValidator, }); export type MessagesServerSocketMessage = { +type: 7, +payload: NewMessagesPayload, }; export const messagesServerSocketMessageValidator: TInterface = tShape({ type: tNumber(serverSocketMessageTypes.MESSAGES), payload: newMessagesPayloadValidator, }); export type APIResponseServerSocketMessage = { +type: 8, +responseTo: number, +payload?: Object, }; export const apiResponseServerSocketMessageValidator: TInterface = tShape({ type: tNumber(serverSocketMessageTypes.API_RESPONSE), responseTo: t.Number, payload: t.maybe(t.Object), }); export type CompressedMessageServerSocketMessage = { +type: 9, +payload: CompressedData, }; export const compressedMessageServerSocketMessageValidator: TInterface = tShape({ type: tNumber(serverSocketMessageTypes.COMPRESSED_MESSAGE), payload: compressedDataValidator, }); export type ServerServerSocketMessage = | ServerStateSyncServerSocketMessage | ServerRequestsServerSocketMessage | ErrorServerSocketMessage | AuthErrorServerSocketMessage | ActivityUpdateResponseServerSocketMessage | PongServerSocketMessage | ServerUpdatesServerSocketMessage | MessagesServerSocketMessage | APIResponseServerSocketMessage | CompressedMessageServerSocketMessage; export const serverServerSocketMessageValidator: TUnion = t.union([ serverStateSyncServerSocketMessageValidator, serverRequestsServerSocketMessageValidator, errorServerSocketMessageValidator, authErrorServerSocketMessageValidator, activityUpdateResponseServerSocketMessageValidator, pongServerSocketMessageValidator, serverUpdatesServerSocketMessageValidator, messagesServerSocketMessageValidator, apiResponseServerSocketMessageValidator, compressedMessageServerSocketMessageValidator, ]); export type ClientRequestsServerSocketMessage = { +type: 1, +responseTo?: number, +payload: { +serverRequests: $ReadOnlyArray, }, }; export type ClientStateSyncServerSocketMessage = { +type: 0, +responseTo: number, +payload: ClientStateSyncSocketPayload, }; export type ClientUpdatesServerSocketMessage = { +type: 6, +payload: ClientUpdatesResultWithUserInfos, }; export type ClientServerSocketMessage = | ClientStateSyncServerSocketMessage | ClientRequestsServerSocketMessage | ErrorServerSocketMessage | AuthErrorServerSocketMessage | ActivityUpdateResponseServerSocketMessage | PongServerSocketMessage | ClientUpdatesServerSocketMessage | MessagesServerSocketMessage | APIResponseServerSocketMessage | CompressedMessageServerSocketMessage; export type SocketListener = (message: ClientServerSocketMessage) => void; export type ConnectionStatus = | 'connecting' | 'connected' | 'reconnecting' | 'disconnecting' | 'forcedDisconnecting' | 'disconnected'; export type ConnectionInfo = { +status: ConnectionStatus, +queuedActivityUpdates: $ReadOnlyArray, +lateResponses: $ReadOnlyArray, +showDisconnectedBar: boolean, }; export const connectionInfoValidator: TInterface = tShape({ status: t.enums.of([ 'connecting', 'connected', 'reconnecting', 'disconnecting', 'forcedDisconnecting', 'disconnected', ]), queuedActivityUpdates: t.list(activityUpdateValidator), lateResponses: t.list(t.Number), showDisconnectedBar: t.Boolean, }); export const defaultConnectionInfo: ConnectionInfo = { status: 'connecting', queuedActivityUpdates: [], lateResponses: [], showDisconnectedBar: false, }; export const updateConnectionStatusActionType = 'UPDATE_CONNECTION_STATUS'; export type UpdateConnectionStatusPayload = { +status: ConnectionStatus, +keyserverID: string, }; export const setLateResponseActionType = 'SET_LATE_RESPONSE'; export type SetLateResponsePayload = { +messageID: number, +isLate: boolean, +keyserverID: string, }; export const updateDisconnectedBarActionType = 'UPDATE_DISCONNECTED_BAR'; export type UpdateDisconnectedBarPayload = { +visible: boolean, +keyserverID: string, }; export type OneTimeKeyGenerator = (inc: number) => string; export type GRPCStream = { readyState: number, onopen: (ev: any) => mixed, onmessage: (ev: MessageEvent) => mixed, onclose: (ev: CloseEvent) => mixed, close(code?: number, reason?: string): void, send(data: string | Blob | ArrayBuffer | $ArrayBufferView): void, }; export type CommTransportLayer = GRPCStream | WebSocket;