diff --git a/keyserver/src/socket/socket.js b/keyserver/src/socket/socket.js --- a/keyserver/src/socket/socket.js +++ b/keyserver/src/socket/socket.js @@ -9,11 +9,16 @@ import { baseLegalPolicies } from 'lib/facts/policies.js'; import { mostRecentMessageTimestamp } from 'lib/shared/message-utils.js'; +import { isStaff } from 'lib/shared/staff-utils.js'; import { serverRequestSocketTimeout, serverResponseTimeout, } from 'lib/shared/timeouts.js'; import { mostRecentUpdateTimestamp } from 'lib/shared/update-utils.js'; +import { + hasMinCodeVersion, + NEXT_CODE_VERSION, +} from 'lib/shared/version-utils.js'; import type { Shape } from 'lib/types/core.js'; import { endpointIsSocketSafe } from 'lib/types/endpoints.js'; import { defaultNumberPerThread } from 'lib/types/message-types.js'; @@ -79,6 +84,7 @@ import { Viewer } from '../session/viewer.js'; import { serverStateSyncSpecs } from '../shared/state-sync/state-sync-specs.js'; import { commitSessionUpdate } from '../updaters/session-updaters.js'; +import { compressMessage } from '../utils/compress.js'; import { assertSecureRequest } from '../utils/security-utils.js'; import { checkInputValidator, @@ -158,6 +164,11 @@ stateCheckOngoing: boolean, }; +const minVersionsForCompression = { + native: NEXT_CODE_VERSION, + web: NEXT_CODE_VERSION, +}; + class Socket { ws: WebSocket; httpRequest: $Request; @@ -373,15 +384,45 @@ this.ws.readyState > 0, "shouldn't send message until connection established", ); - if (this.ws.readyState === 1) { - const { viewer } = this; - const validatedMessage = validateOutput( - viewer?.platformDetails, - serverServerSocketMessageValidator, - message, - ); - this.ws.send(JSON.stringify(validatedMessage)); + if (this.ws.readyState !== 1) { + return; } + + const { viewer } = this; + const validatedMessage = validateOutput( + viewer?.platformDetails, + serverServerSocketMessageValidator, + message, + ); + const stringMessage = JSON.stringify(validatedMessage); + + if ( + !viewer?.platformDetails || + !hasMinCodeVersion(viewer.platformDetails, minVersionsForCompression) || + !isStaff(viewer.id) + ) { + this.ws.send(stringMessage); + return; + } + + const compressionResult = compressMessage(stringMessage); + if (!compressionResult.compressed) { + this.ws.send(stringMessage); + return; + } + + const compressedMessage = { + type: serverSocketMessageTypes.COMPRESSED_MESSAGE, + payload: compressionResult.result, + }; + + const validatedCompressedMessage = validateOutput( + viewer?.platformDetails, + serverServerSocketMessageValidator, + compressedMessage, + ); + const stringCompressedMessage = JSON.stringify(validatedCompressedMessage); + this.ws.send(stringCompressedMessage); }; async handleClientSocketMessage( diff --git a/keyserver/src/utils/compress.js b/keyserver/src/utils/compress.js new file mode 100644 --- /dev/null +++ b/keyserver/src/utils/compress.js @@ -0,0 +1,31 @@ +// @flow + +import zlib from 'zlib'; + +import type { CompressedData } from 'lib/types/compression-types.js'; + +const brotliOptions = { + params: { + [zlib.constants.BROTLI_PARAM_MODE]: zlib.constants.BROTLI_MODE_TEXT, + }, +}; +const minimumSizeForCompression = 4096; // bytes + +type CompressionResult = + | { +compressed: true, +result: CompressedData } + | { +compressed: false, +result: string }; +function compressMessage(message: string): CompressionResult { + const bytesInMessage = new Blob([message]).size; + if (bytesInMessage < minimumSizeForCompression) { + return { compressed: false, result: message }; + } + const brotliResult = zlib.brotliCompressSync(message, brotliOptions); + const base64Encoded = brotliResult.toString('base64'); + const result = { + algo: 'brotli+base64', + data: base64Encoded, + }; + return { compressed: true, result }; +} + +export { compressMessage }; diff --git a/lib/socket/socket.react.js b/lib/socket/socket.react.js --- a/lib/socket/socket.react.js +++ b/lib/socket/socket.react.js @@ -27,6 +27,7 @@ logInActionSources, type LogOutResult, } from '../types/account-types.js'; +import type { CompressedData } from '../types/compression-types.js'; import { isWebPlatform, type PlatformDetails } from '../types/device-types.js'; import type { CalendarQuery } from '../types/entry-types.js'; import { forcePolicyAcknowledgmentActionType } from '../types/policy-types.js'; @@ -101,6 +102,7 @@ +preRequestUserState: PreRequestUserState, +noDataAfterPolicyAcknowledgment?: boolean, +lastCommunicatedPlatformDetails: ?PlatformDetails, + +decompressSocketMessage: CompressedData => string, // Redux dispatch functions +dispatch: Dispatch, +dispatchActionPromise: DispatchActionPromise, @@ -392,13 +394,27 @@ socket.send(JSON.stringify(message)); } - static messageFromEvent(event: MessageEvent): ?ClientServerSocketMessage { + messageFromEvent(event: MessageEvent): ?ClientServerSocketMessage { if (typeof event.data !== 'string') { console.log('socket received a non-string message'); return null; } + + let rawMessage; + try { + rawMessage = JSON.parse(event.data); + } catch (e) { + console.log(e); + return null; + } + + if (rawMessage.type !== serverSocketMessageTypes.COMPRESSED_MESSAGE) { + return rawMessage; + } + + const result = this.props.decompressSocketMessage(rawMessage.payload); try { - return JSON.parse(event.data); + return JSON.parse(result); } catch (e) { console.log(e); return null; @@ -406,7 +422,7 @@ } receiveMessage: (event: MessageEvent) => Promise = async event => { - const message = Socket.messageFromEvent(event); + const message = this.messageFromEvent(event); if (!message) { return; } diff --git a/lib/types/compression-types.js b/lib/types/compression-types.js new file mode 100644 --- /dev/null +++ b/lib/types/compression-types.js @@ -0,0 +1,16 @@ +// @flow + +import t, { type TInterface } from 'tcomb'; + +import { tShape, tString } from '../utils/validation-utils.js'; + +export type CompressedData = { + +algo: 'brotli+base64', + +data: string, +}; + +export const compressedDataValidator: TInterface = + tShape({ + algo: tString('brotli+base64'), + data: t.String, + }); diff --git a/lib/types/socket-types.js b/lib/types/socket-types.js --- a/lib/types/socket-types.js +++ b/lib/types/socket-types.js @@ -9,6 +9,10 @@ type UpdateActivityResult, updateActivityResultValidator, } from './activity-types.js'; +import { + type CompressedData, + compressedDataValidator, +} from './compression-types.js'; import type { APIRequest } from './endpoints.js'; import { type RawEntryInfo, @@ -149,6 +153,7 @@ UPDATES: 6, MESSAGES: 7, API_RESPONSE: 8, + COMPRESSED_MESSAGE: 9, }); export type ServerSocketMessageType = $Values; export function assertServerSocketMessageType( @@ -163,7 +168,8 @@ ourServerSocketMessageType === 5 || ourServerSocketMessageType === 6 || ourServerSocketMessageType === 7 || - ourServerSocketMessageType === 8, + ourServerSocketMessageType === 8 || + ourServerSocketMessageType === 9, 'number is not ServerSocketMessageType enum', ); return ourServerSocketMessageType; @@ -401,6 +407,16 @@ payload: t.maybe(t.Object), }); +export type CompressedMessageServerSocketMessage = { + +type: 9, + +payload: CompressedData, +}; +export const compressedMessageServerSocketMessageValidator: TInterface = + tShape({ + type: tNumber(serverSocketMessageTypes.COMPRESSED_MESSAGE), + payload: compressedDataValidator, + }); + export type ServerServerSocketMessage = | ServerStateSyncServerSocketMessage | ServerRequestsServerSocketMessage @@ -410,7 +426,8 @@ | PongServerSocketMessage | ServerUpdatesServerSocketMessage | MessagesServerSocketMessage - | APIResponseServerSocketMessage; + | APIResponseServerSocketMessage + | CompressedMessageServerSocketMessage; export const serverServerSocketMessageValidator: TUnion = t.union([ serverStateSyncServerSocketMessageValidator, @@ -422,6 +439,7 @@ serverUpdatesServerSocketMessageValidator, messagesServerSocketMessageValidator, apiResponseServerSocketMessageValidator, + compressedMessageServerSocketMessageValidator, ]); export type ClientRequestsServerSocketMessage = { @@ -449,7 +467,8 @@ | PongServerSocketMessage | ClientUpdatesServerSocketMessage | MessagesServerSocketMessage - | APIResponseServerSocketMessage; + | APIResponseServerSocketMessage + | CompressedMessageServerSocketMessage; export type SocketListener = (message: ClientServerSocketMessage) => void; diff --git a/native/socket.react.js b/native/socket.react.js --- a/native/socket.react.js +++ b/native/socket.react.js @@ -38,6 +38,7 @@ } from './selectors/socket-selectors.js'; import Alert from './utils/alert.js'; import { useInitialNotificationsEncryptedMessage } from './utils/crypto-utils.js'; +import { decompressMessage } from './utils/decompress.js'; const NativeSocket: React.ComponentType = React.memo(function NativeSocket(props: BaseSocketProps) { @@ -165,6 +166,7 @@ getInitialNotificationsEncryptedMessage } lastCommunicatedPlatformDetails={lastCommunicatedPlatformDetails} + decompressSocketMessage={decompressMessage} /> ); }); diff --git a/native/utils/decompress.js b/native/utils/decompress.js new file mode 100644 --- /dev/null +++ b/native/utils/decompress.js @@ -0,0 +1,17 @@ +// @flow + +import decompress from 'brotli/decompress.js'; +import { Buffer } from 'buffer'; +import invariant from 'invariant'; + +import type { CompressedData } from 'lib/types/compression-types.js'; + +function decompressMessage(message: CompressedData): string { + invariant(message.algo === 'brotli+base64', 'only supports brotli+base64'); + const inBuffer = Buffer.from(message.data, 'base64'); + const decompressed = decompress(inBuffer); + const outBuffer = Buffer.from(decompressed); + return outBuffer.toString('utf-8'); +} + +export { decompressMessage }; diff --git a/web/.eslintrc.json b/web/.eslintrc.json --- a/web/.eslintrc.json +++ b/web/.eslintrc.json @@ -4,6 +4,7 @@ "jest": true }, "globals": { - "process": true + "process": true, + "Buffer": true } } diff --git a/web/socket.react.js b/web/socket.react.js --- a/web/socket.react.js +++ b/web/socket.react.js @@ -29,6 +29,7 @@ webGetClientResponsesSelector, webSessionStateFuncSelector, } from './selectors/socket-selectors.js'; +import { decompressMessage } from './utils/decompress.js'; const WebSocket: React.ComponentType = React.memo(function WebSocket(props) { @@ -89,6 +90,7 @@ dispatchActionPromise={dispatchActionPromise} logOut={callLogOut} lastCommunicatedPlatformDetails={lastCommunicatedPlatformDetails} + decompressSocketMessage={decompressMessage} /> ); }); diff --git a/web/utils/decompress.js b/web/utils/decompress.js new file mode 100644 --- /dev/null +++ b/web/utils/decompress.js @@ -0,0 +1,16 @@ +// @flow + +import decompress from 'brotli/decompress.js'; +import invariant from 'invariant'; + +import type { CompressedData } from 'lib/types/compression-types.js'; + +function decompressMessage(message: CompressedData): string { + invariant(message.algo === 'brotli+base64', 'only supports brotli+base64'); + const inBuffer = Buffer.from(message.data, 'base64'); + const decompressed = decompress(inBuffer); + const outBuffer = Buffer.from(decompressed); + return outBuffer.toString('utf-8'); +} + +export { decompressMessage };