diff --git a/keyserver/src/socket/tunnelbroker.js b/keyserver/src/socket/tunnelbroker.js index 455aeae5b..2f057491e 100644 --- a/keyserver/src/socket/tunnelbroker.js +++ b/keyserver/src/socket/tunnelbroker.js @@ -1,480 +1,480 @@ // @flow import invariant from 'invariant'; import _debounce from 'lodash/debounce.js'; import { getRustAPI } from 'rust-node-addon'; import uuid from 'uuid'; import WebSocket from 'ws'; import { hexToUintArray } from 'lib/media/data-utils.js'; import { clientTunnelbrokerSocketReconnectDelay, tunnelbrokerHeartbeatTimeout, } from 'lib/shared/timeouts.js'; import type { TunnelbrokerClientMessageToDevice } from 'lib/tunnelbroker/tunnelbroker-context.js'; +import type { MessageSentStatus } from 'lib/types/tunnelbroker/device-to-tunnelbroker-request-status-types.js'; import type { MessageReceiveConfirmation } from 'lib/types/tunnelbroker/message-receive-confirmation-types.js'; -import type { MessageSentStatus } from 'lib/types/tunnelbroker/message-to-device-request-status-types.js'; import type { MessageToDeviceRequest } from 'lib/types/tunnelbroker/message-to-device-request-types.js'; import { deviceToTunnelbrokerMessageTypes, tunnelbrokerToDeviceMessageTypes, tunnelbrokerToDeviceMessageValidator, type TunnelbrokerToDeviceMessage, } from 'lib/types/tunnelbroker/messages.js'; import { qrCodeAuthMessageValidator, type RefreshKeyRequest, refreshKeysRequestValidator, type QRCodeAuthMessage, } from 'lib/types/tunnelbroker/peer-to-peer-message-types.js'; import { peerToPeerMessageTypes } from 'lib/types/tunnelbroker/peer-to-peer-message-types.js'; import { type QRCodeAuthMessagePayload, qrCodeAuthMessagePayloadValidator, qrCodeAuthMessageTypes, } from 'lib/types/tunnelbroker/qr-code-auth-message-types.js'; import type { ConnectionInitializationMessage, AnonymousInitializationMessage, } from 'lib/types/tunnelbroker/session-types.js'; import type { Heartbeat } from 'lib/types/websocket/heartbeat-types.js'; import { getCommConfig } from 'lib/utils/comm-config.js'; import { convertBytesToObj, convertObjToBytes, } from 'lib/utils/conversion-utils.js'; import { getMessageForException } from 'lib/utils/errors.js'; import sleep from 'lib/utils/sleep.js'; import { fetchOlmAccount } from '../updaters/olm-account-updater.js'; import { fetchIdentityInfo, saveIdentityInfo } from '../user/identity.js'; import type { IdentityInfo } from '../user/identity.js'; import { encrypt, decrypt } from '../utils/aes-crypto-utils.js'; import { getContentSigningKey, uploadNewOneTimeKeys, getNewDeviceKeyUpload, markPrekeysAsPublished, } from '../utils/olm-utils.js'; type TBConnectionInfo = { +url: string, }; async function getTBConnectionInfo(): Promise { const tbConfig = await getCommConfig({ folder: 'facts', name: 'tunnelbroker', }); if (tbConfig) { return tbConfig; } console.warn('Defaulting to staging Tunnelbroker'); return { url: 'wss://tunnelbroker.staging.commtechnologies.org:51001', }; } async function createAndMaintainTunnelbrokerWebsocket(encryptionKey: ?string) { const [deviceID, tbConnectionInfo] = await Promise.all([ getContentSigningKey(), getTBConnectionInfo(), ]); const createNewTunnelbrokerSocket = async ( shouldNotifyPrimaryAfterReopening: boolean, primaryDeviceID: ?string, ) => { const identityInfo = await fetchIdentityInfo(); new TunnelbrokerSocket({ socketURL: tbConnectionInfo.url, onClose: async (successfullyAuthed: boolean, primaryID: ?string) => { await sleep(clientTunnelbrokerSocketReconnectDelay); await createNewTunnelbrokerSocket(successfullyAuthed, primaryID); }, identityInfo, deviceID, qrAuthEncryptionKey: encryptionKey, primaryDeviceID, shouldNotifyPrimaryAfterReopening, }); }; await createNewTunnelbrokerSocket(false, null); } type TunnelbrokerSocketParams = { +socketURL: string, +onClose: (boolean, ?string) => mixed, +identityInfo: ?IdentityInfo, +deviceID: string, +qrAuthEncryptionKey: ?string, +primaryDeviceID: ?string, +shouldNotifyPrimaryAfterReopening: boolean, }; type PromiseCallbacks = { +resolve: () => void, +reject: (error: string) => void, }; type Promises = { [clientMessageID: string]: PromiseCallbacks }; class TunnelbrokerSocket { ws: WebSocket; connected: boolean = false; closed: boolean = false; promises: Promises = {}; heartbeatTimeoutID: ?TimeoutID; oneTimeKeysPromise: ?Promise; identityInfo: ?IdentityInfo; qrAuthEncryptionKey: ?string; primaryDeviceID: ?string; shouldNotifyPrimaryAfterReopening: boolean = false; shouldNotifyPrimary: boolean = false; constructor(tunnelbrokerSocketParams: TunnelbrokerSocketParams) { const { socketURL, onClose, identityInfo, deviceID, qrAuthEncryptionKey, primaryDeviceID, shouldNotifyPrimaryAfterReopening, } = tunnelbrokerSocketParams; this.identityInfo = identityInfo; this.qrAuthEncryptionKey = qrAuthEncryptionKey; this.primaryDeviceID = primaryDeviceID; if (shouldNotifyPrimaryAfterReopening) { this.shouldNotifyPrimary = true; } const socket = new WebSocket(socketURL); socket.on('open', () => { this.onOpen(socket, deviceID); }); socket.on('close', async () => { if (this.closed) { return; } this.closed = true; this.connected = false; this.stopHeartbeatTimeout(); console.error('Connection to Tunnelbroker closed'); onClose(this.shouldNotifyPrimaryAfterReopening, this.primaryDeviceID); }); socket.on('error', (error: Error) => { console.error('Tunnelbroker socket error:', error.message); }); socket.on('message', this.onMessage); this.ws = socket; } onOpen: (socket: WebSocket, deviceID: string) => void = ( socket, deviceID, ) => { if (this.closed) { return; } if (this.identityInfo) { const initMessage: ConnectionInitializationMessage = { type: 'ConnectionInitializationMessage', deviceID, accessToken: this.identityInfo.accessToken, userID: this.identityInfo.userId, deviceType: 'keyserver', }; socket.send(JSON.stringify(initMessage)); } else { const initMessage: AnonymousInitializationMessage = { type: 'AnonymousInitializationMessage', deviceID, deviceType: 'keyserver', }; socket.send(JSON.stringify(initMessage)); } }; onMessage: (event: ArrayBuffer) => Promise = async ( event: ArrayBuffer, ) => { let rawMessage; try { rawMessage = JSON.parse(event.toString()); } catch (e) { console.error('error while parsing Tunnelbroker message:', e.message); return; } if (!tunnelbrokerToDeviceMessageValidator.is(rawMessage)) { console.error( 'invalid tunnelbrokerToDeviceMessage: ', rawMessage.toString(), ); return; } const message: TunnelbrokerToDeviceMessage = rawMessage; this.resetHeartbeatTimeout(); if ( message.type === tunnelbrokerToDeviceMessageTypes.CONNECTION_INITIALIZATION_RESPONSE ) { if (message.status.type === 'Success' && !this.connected) { this.connected = true; console.info( this.identityInfo ? 'session with Tunnelbroker created' : 'anonymous session with Tunnelbroker created', ); if (!this.shouldNotifyPrimary) { return; } const { primaryDeviceID } = this; invariant( primaryDeviceID, 'Primary device ID is not set but should be', ); const payload = await this.encodeQRAuthMessage({ type: qrCodeAuthMessageTypes.SECONDARY_DEVICE_REGISTRATION_SUCCESS, requestBackupKeys: false, }); if (!payload) { this.closeConnection(); return; } await this.sendMessageToDevice({ deviceID: primaryDeviceID, payload: JSON.stringify(payload), }); } else if (message.status.type === 'Success' && this.connected) { console.info( 'received ConnectionInitializationResponse with status: Success for already connected socket', ); } else { this.connected = false; console.error( 'creating session with Tunnelbroker error:', message.status.data, ); } } else if ( message.type === tunnelbrokerToDeviceMessageTypes.MESSAGE_TO_DEVICE ) { const confirmation: MessageReceiveConfirmation = { type: deviceToTunnelbrokerMessageTypes.MESSAGE_RECEIVE_CONFIRMATION, messageIDs: [message.messageID], }; this.ws.send(JSON.stringify(confirmation)); const { payload } = message; try { const messageToKeyserver = JSON.parse(payload); if (qrCodeAuthMessageValidator.is(messageToKeyserver)) { const request: QRCodeAuthMessage = messageToKeyserver; const [qrCodeAuthMessage, rustAPI, accountInfo] = await Promise.all([ this.parseQRCodeAuthMessage(request), getRustAPI(), fetchOlmAccount('content'), ]); if ( !qrCodeAuthMessage || qrCodeAuthMessage.type !== qrCodeAuthMessageTypes.DEVICE_LIST_UPDATE_SUCCESS ) { return; } const { primaryDeviceID, userID } = qrCodeAuthMessage; this.primaryDeviceID = primaryDeviceID; const [nonce, deviceKeyUpload] = await Promise.all([ rustAPI.generateNonce(), getNewDeviceKeyUpload(), ]); const signedIdentityKeysBlob = { payload: deviceKeyUpload.keyPayload, signature: deviceKeyUpload.keyPayloadSignature, }; const nonceSignature = accountInfo.account.sign(nonce); const identityInfo = await rustAPI.uploadSecondaryDeviceKeysAndLogIn( userID, nonce, nonceSignature, signedIdentityKeysBlob, deviceKeyUpload.contentPrekey, deviceKeyUpload.contentPrekeySignature, deviceKeyUpload.notifPrekey, deviceKeyUpload.notifPrekeySignature, deviceKeyUpload.contentOneTimeKeys, deviceKeyUpload.notifOneTimeKeys, ); await Promise.all([ markPrekeysAsPublished(), saveIdentityInfo(identityInfo), ]); this.shouldNotifyPrimaryAfterReopening = true; this.closeConnection(); } else if (refreshKeysRequestValidator.is(messageToKeyserver)) { const request: RefreshKeyRequest = messageToKeyserver; this.debouncedRefreshOneTimeKeys(request.numberOfKeys); } } catch (e) { console.error( 'error while processing message to keyserver:', e.message, ); } } else if ( message.type === - tunnelbrokerToDeviceMessageTypes.MESSAGE_TO_DEVICE_REQUEST_STATUS + tunnelbrokerToDeviceMessageTypes.DEVICE_TO_TUNNELBROKER_REQUEST_STATUS ) { for (const status: MessageSentStatus of message.clientMessageIDs) { if (status.type === 'Success') { if (this.promises[status.data]) { this.promises[status.data].resolve(); delete this.promises[status.data]; } else { console.log( 'received successful response for a non-existent request', ); } } else if (status.type === 'Error') { if (this.promises[status.data.id]) { this.promises[status.data.id].reject(status.data.error); delete this.promises[status.data.id]; } else { console.log('received error response for a non-existent request'); } } else if (status.type === 'SerializationError') { console.error('SerializationError for message: ', status.data); } else if (status.type === 'InvalidRequest') { console.log('Tunnelbroker recorded InvalidRequest'); } } } else if (message.type === tunnelbrokerToDeviceMessageTypes.HEARTBEAT) { const heartbeat: Heartbeat = { type: deviceToTunnelbrokerMessageTypes.HEARTBEAT, }; this.ws.send(JSON.stringify(heartbeat)); } }; refreshOneTimeKeys: (numberOfKeys: number) => void = numberOfKeys => { const oldOneTimeKeysPromise = this.oneTimeKeysPromise; this.oneTimeKeysPromise = (async () => { await oldOneTimeKeysPromise; await uploadNewOneTimeKeys(numberOfKeys); })(); }; debouncedRefreshOneTimeKeys: (numberOfKeys: number) => void = _debounce( this.refreshOneTimeKeys, 100, { leading: true, trailing: true }, ); sendMessageToDevice: ( message: TunnelbrokerClientMessageToDevice, ) => Promise = (message: TunnelbrokerClientMessageToDevice) => { if (!this.connected) { throw new Error('Tunnelbroker not connected'); } const clientMessageID = uuid.v4(); const messageToDevice: MessageToDeviceRequest = { type: deviceToTunnelbrokerMessageTypes.MESSAGE_TO_DEVICE_REQUEST, clientMessageID, deviceID: message.deviceID, payload: message.payload, }; return new Promise((resolve, reject) => { this.promises[clientMessageID] = { resolve, reject, }; this.ws.send(JSON.stringify(messageToDevice)); }); }; stopHeartbeatTimeout() { if (this.heartbeatTimeoutID) { clearTimeout(this.heartbeatTimeoutID); this.heartbeatTimeoutID = null; } } resetHeartbeatTimeout() { this.stopHeartbeatTimeout(); this.heartbeatTimeoutID = setTimeout(() => { this.ws.close(); this.connected = false; }, tunnelbrokerHeartbeatTimeout); } closeConnection() { this.ws.close(); this.connected = false; } parseQRCodeAuthMessage: ( message: QRCodeAuthMessage, ) => Promise = async message => { const encryptionKey = this.qrAuthEncryptionKey; if (!encryptionKey) { return null; } const encryptedData = Buffer.from(message.encryptedContent, 'base64'); const decryptedData = await decrypt( hexToUintArray(encryptionKey), new Uint8Array(encryptedData), ); const payload = convertBytesToObj(decryptedData); if (!qrCodeAuthMessagePayloadValidator.is(payload)) { return null; } return payload; }; encodeQRAuthMessage: ( payload: QRCodeAuthMessagePayload, ) => Promise = async payload => { const encryptionKey = this.qrAuthEncryptionKey; if (!encryptionKey) { console.error('Encryption key missing - cannot send QR auth message.'); return null; } let encryptedContent; try { const payloadBytes = convertObjToBytes(payload); const keyBytes = hexToUintArray(encryptionKey); const encryptedBytes = await encrypt(keyBytes, payloadBytes); encryptedContent = Buffer.from(encryptedBytes).toString('base64'); } catch (e) { console.error( 'Error encoding QRCodeAuthMessagePayload:', getMessageForException(e), ); return null; } return { type: peerToPeerMessageTypes.QR_CODE_AUTH_MESSAGE, encryptedContent, }; }; } export { createAndMaintainTunnelbrokerWebsocket }; diff --git a/lib/tunnelbroker/tunnelbroker-context.js b/lib/tunnelbroker/tunnelbroker-context.js index 39d6b3e03..48889424e 100644 --- a/lib/tunnelbroker/tunnelbroker-context.js +++ b/lib/tunnelbroker/tunnelbroker-context.js @@ -1,480 +1,480 @@ // @flow import invariant from 'invariant'; import _isEqual from 'lodash/fp/isEqual.js'; import * as React from 'react'; import uuid from 'uuid'; import { PeerToPeerProvider } from './peer-to-peer-context.js'; import { PeerToPeerMessageHandler } from './peer-to-peer-message-handler.js'; import type { SecondaryTunnelbrokerConnection } from './secondary-tunnelbroker-connection.js'; import { tunnnelbrokerURL } from '../facts/tunnelbroker.js'; import { IdentityClientContext } from '../shared/identity-client-context.js'; import { tunnelbrokerHeartbeatTimeout } from '../shared/timeouts.js'; import { isWebPlatform } from '../types/device-types.js'; -import type { MessageSentStatus } from '../types/tunnelbroker/message-to-device-request-status-types.js'; +import type { MessageSentStatus } from '../types/tunnelbroker/device-to-tunnelbroker-request-status-types.js'; import type { MessageToDeviceRequest } from '../types/tunnelbroker/message-to-device-request-types.js'; import type { MessageToTunnelbrokerRequest } from '../types/tunnelbroker/message-to-tunnelbroker-request-types.js'; import { deviceToTunnelbrokerMessageTypes, tunnelbrokerToDeviceMessageTypes, tunnelbrokerToDeviceMessageValidator, type TunnelbrokerToDeviceMessage, type DeviceToTunnelbrokerRequest, } from '../types/tunnelbroker/messages.js'; import type { TunnelbrokerAPNsNotif, TunnelbrokerFCMNotif, } from '../types/tunnelbroker/notif-types.js'; import type { AnonymousInitializationMessage, ConnectionInitializationMessage, TunnelbrokerInitializationMessage, TunnelbrokerDeviceTypes, } from '../types/tunnelbroker/session-types.js'; import type { Heartbeat } from '../types/websocket/heartbeat-types.js'; import { getConfig } from '../utils/config.js'; import { getContentSigningKey } from '../utils/crypto-utils.js'; import { useSelector } from '../utils/redux-utils.js'; export type TunnelbrokerClientMessageToDevice = { +deviceID: string, +payload: string, }; export type TunnelbrokerSocketListener = ( message: TunnelbrokerToDeviceMessage, ) => mixed; type PromiseCallbacks = { +resolve: () => void, +reject: (error: string) => void, }; type Promises = { [clientMessageID: string]: PromiseCallbacks }; type TunnelbrokerSocketState = | { +connected: true, +isAuthorized: boolean, } | { +connected: false, }; type TunnelbrokerContextType = { +sendMessageToDevice: ( message: TunnelbrokerClientMessageToDevice, messageID: ?string, ) => Promise, +sendNotif: ( notif: TunnelbrokerAPNsNotif | TunnelbrokerFCMNotif, ) => Promise, +sendMessageToTunnelbroker: (payload: string) => Promise, +addListener: (listener: TunnelbrokerSocketListener) => void, +removeListener: (listener: TunnelbrokerSocketListener) => void, +socketState: TunnelbrokerSocketState, +setUnauthorizedDeviceID: (unauthorizedDeviceID: ?string) => void, }; const TunnelbrokerContext: React.Context = React.createContext(); type Props = { +children: React.Node, +shouldBeClosed?: boolean, +onClose?: () => mixed, +secondaryTunnelbrokerConnection?: SecondaryTunnelbrokerConnection, }; function getTunnelbrokerDeviceType(): TunnelbrokerDeviceTypes { return isWebPlatform(getConfig().platformDetails.platform) ? 'web' : 'mobile'; } function createAnonymousInitMessage( deviceID: string, ): AnonymousInitializationMessage { return ({ type: 'AnonymousInitializationMessage', deviceID, deviceType: getTunnelbrokerDeviceType(), }: AnonymousInitializationMessage); } function TunnelbrokerProvider(props: Props): React.Node { const { children, shouldBeClosed, onClose, secondaryTunnelbrokerConnection } = props; const accessToken = useSelector(state => state.commServicesAccessToken); const userID = useSelector(state => state.currentUserInfo?.id); const [unauthorizedDeviceID, setUnauthorizedDeviceID] = React.useState(null); const isAuthorized = !unauthorizedDeviceID; const createInitMessage = React.useCallback(async () => { if (shouldBeClosed) { return null; } if (unauthorizedDeviceID) { return createAnonymousInitMessage(unauthorizedDeviceID); } if (!accessToken || !userID) { return null; } const deviceID = await getContentSigningKey(); if (!deviceID) { return null; } return ({ type: 'ConnectionInitializationMessage', deviceID, accessToken, userID, deviceType: getTunnelbrokerDeviceType(), }: ConnectionInitializationMessage); }, [accessToken, shouldBeClosed, unauthorizedDeviceID, userID]); const previousInitMessage = React.useRef(null); const [socketState, setSocketState] = React.useState( { connected: false }, ); const listeners = React.useRef>(new Set()); const socket = React.useRef(null); const socketSessionCounter = React.useRef(0); const promises = React.useRef({}); const heartbeatTimeoutID = React.useRef(); const identityContext = React.useContext(IdentityClientContext); invariant(identityContext, 'Identity context should be set'); const { identityClient } = identityContext; const stopHeartbeatTimeout = React.useCallback(() => { if (heartbeatTimeoutID.current) { clearTimeout(heartbeatTimeoutID.current); heartbeatTimeoutID.current = null; } }, []); const resetHeartbeatTimeout = React.useCallback(() => { stopHeartbeatTimeout(); heartbeatTimeoutID.current = setTimeout(() => { socket.current?.close(); setSocketState({ connected: false }); }, tunnelbrokerHeartbeatTimeout); }, [stopHeartbeatTimeout]); // determine if the socket is active (not closed or closing) const isSocketActive = socket.current?.readyState === WebSocket.OPEN || socket.current?.readyState === WebSocket.CONNECTING; const connectionChangePromise = React.useRef>(null); // The Tunnelbroker connection can have 4 states: // - DISCONNECTED: isSocketActive = false, connected = false // Should be in this state when initMessage is null // - CONNECTING: isSocketActive = true, connected = false // This lasts until Tunnelbroker sends ConnectionInitializationResponse // - CONNECTED: isSocketActive = true, connected = true // - DISCONNECTING: isSocketActive = false, connected = true // This lasts between socket.close() and socket.onclose() React.useEffect(() => { connectionChangePromise.current = (async () => { await connectionChangePromise.current; try { const initMessage = await createInitMessage(); const initMessageChanged = !_isEqual( previousInitMessage.current, initMessage, ); previousInitMessage.current = initMessage; // when initMessage changes, we need to close the socket // and open a new one if ( (!initMessage || initMessageChanged) && isSocketActive && socket.current ) { socket.current?.close(); return; } // when we're already connected (or pending disconnection), // or there's no init message to start with, we don't need // to do anything if (socketState.connected || !initMessage || socket.current) { return; } const tunnelbrokerSocket = new WebSocket(tunnnelbrokerURL); tunnelbrokerSocket.onopen = () => { tunnelbrokerSocket.send(JSON.stringify(initMessage)); }; tunnelbrokerSocket.onclose = () => { // this triggers the effect hook again and reconnect setSocketState({ connected: false }); onClose?.(); socket.current = null; console.log('Connection to Tunnelbroker closed'); }; tunnelbrokerSocket.onerror = e => { console.log('Tunnelbroker socket error:', e.message); }; tunnelbrokerSocket.onmessage = (event: MessageEvent) => { if (typeof event.data !== 'string') { console.log('socket received a non-string message'); return; } let rawMessage; try { rawMessage = JSON.parse(event.data); } catch (e) { console.log('error while parsing Tunnelbroker message:', e.message); return; } if (!tunnelbrokerToDeviceMessageValidator.is(rawMessage)) { console.log('invalid TunnelbrokerMessage'); return; } const message: TunnelbrokerToDeviceMessage = rawMessage; resetHeartbeatTimeout(); for (const listener of listeners.current) { listener(message); } // MESSAGE_TO_DEVICE is handled in PeerToPeerMessageHandler if ( message.type === tunnelbrokerToDeviceMessageTypes.CONNECTION_INITIALIZATION_RESPONSE ) { if (message.status.type === 'Success' && !socketState.connected) { setSocketState({ connected: true, isAuthorized }); console.log( 'session with Tunnelbroker created. isAuthorized:', isAuthorized, ); } else if ( message.status.type === 'Success' && socketState.connected ) { console.log( 'received ConnectionInitializationResponse with status: Success for already connected socket', ); } else { setSocketState({ connected: false }); console.log( 'creating session with Tunnelbroker error:', message.status.data, ); } } else if ( message.type === - tunnelbrokerToDeviceMessageTypes.MESSAGE_TO_DEVICE_REQUEST_STATUS + tunnelbrokerToDeviceMessageTypes.DEVICE_TO_TUNNELBROKER_REQUEST_STATUS ) { for (const status: MessageSentStatus of message.clientMessageIDs) { if (status.type === 'Success') { promises.current[status.data]?.resolve(); delete promises.current[status.data]; } else if (status.type === 'Error') { promises.current[status.data.id]?.reject(status.data.error); delete promises.current[status.data.id]; } else if (status.type === 'SerializationError') { console.log('SerializationError for message: ', status.data); } else if (status.type === 'InvalidRequest') { console.log('Tunnelbroker recorded InvalidRequest'); } } } else if ( message.type === tunnelbrokerToDeviceMessageTypes.HEARTBEAT ) { const heartbeat: Heartbeat = { type: deviceToTunnelbrokerMessageTypes.HEARTBEAT, }; socket.current?.send(JSON.stringify(heartbeat)); } }; socket.current = tunnelbrokerSocket; socketSessionCounter.current = socketSessionCounter.current + 1; } catch (err) { console.log('Tunnelbroker connection error:', err); } })(); }, [ isSocketActive, isAuthorized, resetHeartbeatTimeout, stopHeartbeatTimeout, identityClient, onClose, createInitMessage, socketState.connected, ]); const sendMessage: (request: DeviceToTunnelbrokerRequest) => Promise = React.useCallback( request => { return new Promise((resolve, reject) => { const socketActive = socketState.connected && socket.current; if (!shouldBeClosed && !socketActive) { throw new Error('Tunnelbroker not connected'); } promises.current[request.clientMessageID] = { resolve, reject, }; if (socketActive) { socket.current?.send(JSON.stringify(request)); } else { secondaryTunnelbrokerConnection?.sendMessage(request); } }); }, [socketState, secondaryTunnelbrokerConnection, shouldBeClosed], ); const sendMessageToDevice: ( message: TunnelbrokerClientMessageToDevice, messageID: ?string, ) => Promise = React.useCallback( (message: TunnelbrokerClientMessageToDevice, messageID: ?string) => { const clientMessageID = messageID ?? uuid.v4(); const messageToDevice: MessageToDeviceRequest = { type: deviceToTunnelbrokerMessageTypes.MESSAGE_TO_DEVICE_REQUEST, clientMessageID, deviceID: message.deviceID, payload: message.payload, }; return sendMessage(messageToDevice); }, [sendMessage], ); const sendMessageToTunnelbroker: (payload: string) => Promise = React.useCallback( (payload: string) => { const clientMessageID = uuid.v4(); const messageToTunnelbroker: MessageToTunnelbrokerRequest = { type: deviceToTunnelbrokerMessageTypes.MESSAGE_TO_TUNNELBROKER_REQUEST, clientMessageID, payload, }; return sendMessage(messageToTunnelbroker); }, [sendMessage], ); React.useEffect( () => secondaryTunnelbrokerConnection?.onSendMessage(message => { if (shouldBeClosed) { // We aren't supposed to be handling it return; } void (async () => { try { await sendMessage(message); secondaryTunnelbrokerConnection.setMessageStatus( message.clientMessageID, ); } catch (error) { secondaryTunnelbrokerConnection.setMessageStatus( message.clientMessageID, error, ); } })(); }), [secondaryTunnelbrokerConnection, sendMessage, shouldBeClosed], ); React.useEffect( () => secondaryTunnelbrokerConnection?.onMessageStatus((messageID, error) => { if (error) { promises.current[messageID].reject(error); } else { promises.current[messageID].resolve(); } delete promises.current[messageID]; }), [secondaryTunnelbrokerConnection], ); const addListener = React.useCallback( (listener: TunnelbrokerSocketListener) => { listeners.current.add(listener); }, [], ); const removeListener = React.useCallback( (listener: TunnelbrokerSocketListener) => { listeners.current.delete(listener); }, [], ); const getSessionCounter = React.useCallback( () => socketSessionCounter.current, [], ); const doesSocketExist = React.useCallback(() => !!socket.current, []); const socketSend = React.useCallback((message: string) => { socket.current?.send(message); }, []); const value: TunnelbrokerContextType = React.useMemo( () => ({ sendMessageToDevice, sendMessageToTunnelbroker, sendNotif: sendMessage, socketState, addListener, removeListener, setUnauthorizedDeviceID, }), [ sendMessageToDevice, sendMessage, sendMessageToTunnelbroker, socketState, addListener, removeListener, ], ); return ( {children} ); } function useTunnelbroker(): TunnelbrokerContextType { const context = React.useContext(TunnelbrokerContext); invariant(context, 'TunnelbrokerContext not found'); return context; } export { TunnelbrokerProvider, useTunnelbroker }; diff --git a/lib/types/tunnelbroker/message-to-device-request-status-types.js b/lib/types/tunnelbroker/device-to-tunnelbroker-request-status-types.js similarity index 88% rename from lib/types/tunnelbroker/message-to-device-request-status-types.js rename to lib/types/tunnelbroker/device-to-tunnelbroker-request-status-types.js index 8d8410a10..c9b82f3db 100644 --- a/lib/types/tunnelbroker/message-to-device-request-status-types.js +++ b/lib/types/tunnelbroker/device-to-tunnelbroker-request-status-types.js @@ -1,60 +1,60 @@ // @flow import type { TInterface } from 'tcomb'; import t from 'tcomb'; import { tShape, tString } from '../../utils/validation-utils.js'; export type Failure = { +id: string, +error: string, }; const failureValidator: TInterface = tShape({ id: t.String, error: t.String, }); type MessageSentSuccessStatus = { +type: 'Success', +data: string }; type MessageSentErrorStatus = { +type: 'Error', +data: Failure }; type MessageSentInvalidRequestStatus = { +type: 'InvalidRequest' }; type MessageSentUnauthenticatedStatus = { +type: 'Unauthenticated' }; type MessageSentSerializationErrorStatus = { +type: 'SerializationError', +data: string, }; export type MessageSentStatus = | MessageSentSuccessStatus | MessageSentErrorStatus | MessageSentInvalidRequestStatus | MessageSentSerializationErrorStatus; const messageSentStatusValidator = t.union([ tShape({ type: tString('Success'), data: t.String, }), tShape({ type: tString('Error'), data: failureValidator, }), tShape({ type: tString('InvalidRequest') }), tShape({ type: tString('Unauthenticated'), }), tShape({ type: tString('SerializationError'), data: t.String, }), ]); -export type MessageToDeviceRequestStatus = { +export type DeviceToTunnelbrokerRequestStatus = { +type: 'MessageToDeviceRequestStatus', +clientMessageIDs: $ReadOnlyArray, }; -export const messageToDeviceRequestStatusValidator: TInterface = - tShape({ +export const messageToDeviceRequestStatusValidator: TInterface = + tShape({ type: tString('MessageToDeviceRequestStatus'), clientMessageIDs: t.list(messageSentStatusValidator), }); diff --git a/lib/types/tunnelbroker/messages.js b/lib/types/tunnelbroker/messages.js index 6e483bc7d..21ad01477 100644 --- a/lib/types/tunnelbroker/messages.js +++ b/lib/types/tunnelbroker/messages.js @@ -1,97 +1,97 @@ // @flow import type { TUnion } from 'tcomb'; import t from 'tcomb'; -import { type MessageReceiveConfirmation } from './message-receive-confirmation-types.js'; import { - type MessageToDeviceRequestStatus, messageToDeviceRequestStatusValidator, -} from './message-to-device-request-status-types.js'; + type DeviceToTunnelbrokerRequestStatus, +} from './device-to-tunnelbroker-request-status-types.js'; +import { type MessageReceiveConfirmation } from './message-receive-confirmation-types.js'; import { type MessageToDeviceRequest } from './message-to-device-request-types.js'; import { type MessageToDevice, messageToDeviceValidator, } from './message-to-device-types.js'; import { type MessageToTunnelbrokerRequest } from './message-to-tunnelbroker-request-types.js'; import { type TunnelbrokerAPNsNotif, type TunnelbrokerFCMNotif, } from './notif-types.js'; import { type AnonymousInitializationMessage, type ConnectionInitializationMessage, } from './session-types.js'; import { type ConnectionInitializationResponse, connectionInitializationResponseValidator, } from '../websocket/connection-initialization-response-types.js'; import { type Heartbeat, heartbeatValidator, } from '../websocket/heartbeat-types.js'; /* * This file defines types and validation for messages exchanged * with the Tunnelbroker. The definitions in this file should remain in sync * with the structures defined in the corresponding * Rust file at `shared/tunnelbroker_messages/src/messages/mod.rs`. * * If you edit the definitions in one file, * please make sure to update the corresponding definitions in the other. * */ // Messages sent from Device to Tunnelbroker. export const deviceToTunnelbrokerMessageTypes = Object.freeze({ CONNECTION_INITIALIZATION_MESSAGE: 'ConnectionInitializationMessage', ANONYMOUS_INITIALIZATION_MESSAGE: 'AnonymousInitializationMessage', TUNNELBROKER_APNS_NOTIF: 'APNsNotif', TUNNELBROKER_FCM_NOTIF: 'FCMNotif', MESSAGE_TO_DEVICE_REQUEST: 'MessageToDeviceRequest', MESSAGE_RECEIVE_CONFIRMATION: 'MessageReceiveConfirmation', MESSAGE_TO_TUNNELBROKER_REQUEST: 'MessageToTunnelbrokerRequest', HEARTBEAT: 'Heartbeat', }); export type DeviceToTunnelbrokerMessage = | ConnectionInitializationMessage | AnonymousInitializationMessage | TunnelbrokerAPNsNotif | TunnelbrokerFCMNotif | MessageToDeviceRequest | MessageReceiveConfirmation | MessageToTunnelbrokerRequest | Heartbeat; // Types having `clientMessageID` prop. // When using this type, it is possible to use Promise abstraction, // and await sending a message until Tunnelbroker responds that // the request was processed. export type DeviceToTunnelbrokerRequest = | TunnelbrokerAPNsNotif | TunnelbrokerFCMNotif | MessageToDeviceRequest | MessageToTunnelbrokerRequest; // Messages sent from Tunnelbroker to Device. export const tunnelbrokerToDeviceMessageTypes = Object.freeze({ CONNECTION_INITIALIZATION_RESPONSE: 'ConnectionInitializationResponse', - MESSAGE_TO_DEVICE_REQUEST_STATUS: 'MessageToDeviceRequestStatus', + DEVICE_TO_TUNNELBROKER_REQUEST_STATUS: 'MessageToDeviceRequestStatus', MESSAGE_TO_DEVICE: 'MessageToDevice', HEARTBEAT: 'Heartbeat', }); export type TunnelbrokerToDeviceMessage = | ConnectionInitializationResponse - | MessageToDeviceRequestStatus + | DeviceToTunnelbrokerRequestStatus | MessageToDevice | Heartbeat; export const tunnelbrokerToDeviceMessageValidator: TUnion = t.union([ connectionInitializationResponseValidator, messageToDeviceRequestStatusValidator, messageToDeviceValidator, heartbeatValidator, ]); diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs index ae65ab8d3..0cce246be 100644 --- a/services/commtest/src/tunnelbroker/socket.rs +++ b/services/commtest/src/tunnelbroker/socket.rs @@ -1,119 +1,119 @@ use crate::identity::device::DeviceInfo; use crate::service_addr; use futures_util::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; use tokio::net::TcpStream; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tunnelbroker_messages::{ ConnectionInitializationMessage, ConnectionInitializationResponse, - ConnectionInitializationStatus, DeviceTypes, Heartbeat, MessageSentStatus, - MessageToDeviceRequest, MessageToDeviceRequestStatus, + ConnectionInitializationStatus, DeviceToTunnelbrokerRequestStatus, + DeviceTypes, Heartbeat, MessageSentStatus, MessageToDeviceRequest, TunnelbrokerToDeviceMessage, }; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] #[serde(tag = "type", rename_all = "camelCase")] pub struct WebSocketMessageToDevice { #[serde(rename = "deviceID")] pub device_id: String, pub payload: String, } pub async fn create_socket( device_info: &DeviceInfo, ) -> Result< WebSocketStream>, Box, > { let (mut socket, _) = connect_async(service_addr::TUNNELBROKER_WS) .await .expect("Can't connect"); let session_request = ConnectionInitializationMessage { device_id: device_info.device_id.to_string(), access_token: device_info.access_token.to_string(), user_id: device_info.user_id.to_string(), notify_token: None, device_type: DeviceTypes::Keyserver, device_app_version: None, device_os: None, }; let serialized_request = serde_json::to_string(&session_request) .expect("Failed to serialize connection request"); socket .send(Message::Text(serialized_request)) .await .expect("Failed to send message"); if let Some(Ok(response)) = socket.next().await { let response: ConnectionInitializationResponse = serde_json::from_str(response.to_text().unwrap())?; return match response.status { ConnectionInitializationStatus::Success => Ok(socket), ConnectionInitializationStatus::Error(err) => Err(err.into()), }; } Err("Failed to get response from Tunnelbroker".into()) } pub async fn send_message( socket: &mut WebSocketStream>, message: WebSocketMessageToDevice, ) -> Result> { let client_message_id = uuid::Uuid::new_v4().to_string(); let request = MessageToDeviceRequest { client_message_id: client_message_id.clone(), device_id: message.device_id, payload: message.payload, }; let serialized_request = serde_json::to_string(&request)?; socket.send(Message::Text(serialized_request)).await?; if let Some(Ok(response)) = socket.next().await { - let confirmation: MessageToDeviceRequestStatus = + let confirmation: DeviceToTunnelbrokerRequestStatus = serde_json::from_str(response.to_text().unwrap())?; if confirmation .client_message_ids .contains(&MessageSentStatus::Success(client_message_id.clone())) { return Ok(client_message_id); } } Err("Failed to confirm sent message".into()) } pub async fn receive_message( socket: &mut WebSocketStream>, ) -> Result> { while let Some(Ok(response)) = socket.next().await { let message_str = response.to_text().expect("Failed to get response content"); let message = serde_json::from_str::(message_str).unwrap(); match message { TunnelbrokerToDeviceMessage::MessageToDevice(msg) => { let confirmation = tunnelbroker_messages::MessageReceiveConfirmation { message_ids: vec![msg.message_id], }; let serialized_confirmation = serde_json::to_string(&confirmation).unwrap(); socket.send(Message::Text(serialized_confirmation)).await?; return Ok(msg.payload); } TunnelbrokerToDeviceMessage::Heartbeat(Heartbeat {}) => { let msg = Heartbeat {}; let serialized = serde_json::to_string(&msg).unwrap(); socket.send(Message::Text(serialized)).await?; } _ => return Err(format!("Unexpected message type {message:?}").into()), } } Err("Failed to receive message".into()) } diff --git a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs index 26e29cd18..b8aeccfe9 100644 --- a/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs +++ b/services/commtest/tests/tunnelbroker_sender_confirmation_tests.rs @@ -1,90 +1,90 @@ use commtest::identity::device::register_user_device; use commtest::tunnelbroker::socket::{create_socket, receive_message}; use futures_util::{SinkExt, StreamExt}; use tokio_tungstenite::tungstenite::Message; use tunnelbroker_messages::{ - MessageSentStatus, MessageToDeviceRequest, MessageToDeviceRequestStatus, + DeviceToTunnelbrokerRequestStatus, MessageSentStatus, MessageToDeviceRequest, }; /// Tests of responses sent from Tunnelberoker to client /// trying to send message to other device #[tokio::test] async fn get_confirmation() { let sender = register_user_device(None, None).await; let receiver = register_user_device(None, None).await; let client_message_id = "mockID".to_string(); // Send message to not connected client let payload = "persisted message"; let request = MessageToDeviceRequest { client_message_id: client_message_id.clone(), device_id: receiver.device_id.clone(), payload: payload.to_string(), }; let serialized_request = serde_json::to_string(&request) .expect("Failed to serialize message to device"); let mut sender_socket = create_socket(&sender).await.unwrap(); sender_socket .send(Message::Text(serialized_request)) .await .expect("Failed to send message"); if let Some(Ok(response)) = sender_socket.next().await { - let expected_response = MessageToDeviceRequestStatus { + let expected_response = DeviceToTunnelbrokerRequestStatus { client_message_ids: vec![MessageSentStatus::Success(client_message_id)], }; let expected_payload = serde_json::to_string(&expected_response).unwrap(); let received_payload = response.to_text().unwrap(); assert_eq!(received_payload, expected_payload); }; // Connect receiver to flush DDB and avoid polluting other tests let mut receiver_socket = create_socket(&receiver).await.unwrap(); let receiver_response = receive_message(&mut receiver_socket).await.unwrap(); assert_eq!(payload, receiver_response); } #[tokio::test] async fn get_serialization_error() { let sender = register_user_device(None, None).await; let message = "some bad json".to_string(); let mut sender_socket = create_socket(&sender).await.unwrap(); sender_socket .send(Message::Text(message.clone())) .await .expect("Failed to send message"); if let Some(Ok(response)) = sender_socket.next().await { - let expected_response = MessageToDeviceRequestStatus { + let expected_response = DeviceToTunnelbrokerRequestStatus { client_message_ids: vec![MessageSentStatus::SerializationError(message)], }; let expected_payload = serde_json::to_string(&expected_response).unwrap(); let received_payload = response.to_text().unwrap(); assert_eq!(received_payload, expected_payload); }; } #[tokio::test] async fn get_invalid_request_error() { let sender = register_user_device(None, None).await; let mut sender_socket = create_socket(&sender).await.unwrap(); sender_socket .send(Message::Binary(vec![])) .await .expect("Failed to send message"); if let Some(Ok(response)) = sender_socket.next().await { - let expected_response = MessageToDeviceRequestStatus { + let expected_response = DeviceToTunnelbrokerRequestStatus { client_message_ids: vec![MessageSentStatus::InvalidRequest], }; let expected_payload = serde_json::to_string(&expected_response).unwrap(); let received_payload = response.to_text().unwrap(); assert_eq!(received_payload, expected_payload); }; } diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs index cf9479e59..4f6ec64e7 100644 --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,331 +1,331 @@ pub mod session; use crate::constants::SOCKET_HEARTBEAT_TIMEOUT; use crate::database::DatabaseClient; use crate::notifs::NotifClient; use crate::websockets::session::{initialize_amqp, SessionError}; use crate::CONFIG; use futures_util::stream::SplitSink; use futures_util::{SinkExt, StreamExt}; use hyper::upgrade::Upgraded; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use hyper_tungstenite::WebSocketStream; use std::env; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tracing::{debug, error, info, trace}; use tunnelbroker_messages::{ - ConnectionInitializationStatus, Heartbeat, MessageSentStatus, - MessageToDeviceRequestStatus, + ConnectionInitializationStatus, DeviceToTunnelbrokerRequestStatus, Heartbeat, + MessageSentStatus, }; type BoxedError = Box; pub type ErrorWithStreamHandle = ( session::SessionError, SplitSink, Message>, ); use self::session::WebsocketSession; /// Hyper HTTP service that handles incoming HTTP and websocket connections /// It handles the initial websocket upgrade request and spawns a task to /// handle the websocket connection. /// It also handles regular HTTP requests (currently health check) struct WebsocketService { addr: SocketAddr, channel: lapin::Channel, db_client: DatabaseClient, notif_client: NotifClient, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = BoxedError; type Future = Pin> + Send>>; // This function is called to check if the service is ready to accept // connections. Since we don't have any state to check, we're always ready. fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let db_client = self.db_client.clone(); let channel = self.channel.clone(); let notif_client = self.notif_client.clone(); let future = async move { // Check if the request is a websocket upgrade request. if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; // Spawn a task to handle the websocket connection. tokio::spawn(async move { accept_connection(websocket, addr, db_client, channel, notif_client) .await; }); // Return the response so the spawned future can continue. return Ok(response); } debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); // A simple router for regular HTTP requests let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } pub async fn run_server( db_client: DatabaseClient, amqp_connection: &lapin::Connection, notif_client: NotifClient, ) -> Result<(), BoxedError> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") .unwrap_or_else(|_| format!("0.0.0.0:{}", &CONFIG.http_port)); let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("WebSocket listening on: {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { let channel = amqp_connection .create_channel() .await .expect("Failed to create AMQP channel"); let connection = http .serve_connection( stream, WebsocketService { channel, db_client: db_client.clone(), addr, notif_client: notif_client.clone(), }, ) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } async fn send_error_init_response( error: SessionError, mut outgoing: SplitSink, Message>, ) { let error_response = tunnelbroker_messages::ConnectionInitializationResponse { status: ConnectionInitializationStatus::Error(error.to_string()), }; match serde_json::to_string(&error_response) { Ok(serialized_response) => { if let Err(send_error) = outgoing.send(Message::Text(serialized_response)).await { error!("Failed to send init error response: {:?}", send_error); } } Err(ser_error) => { error!("Failed to serialize the error response: {:?}", ser_error); } } } /// Handler for any incoming websocket connections async fn accept_connection( hyper_ws: HyperWebsocket, addr: SocketAddr, db_client: DatabaseClient, amqp_channel: lapin::Channel, notif_client: NotifClient, ) { debug!("Incoming connection from: {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { info!( "Failed to establish connection with {}. Reason: {}", addr, e ); return; } }; let (outgoing, mut incoming) = ws_stream.split(); // We don't know the identity of the device until it sends the session // request over the websocket connection let mut session = if let Some(Ok(first_msg)) = incoming.next().await { match initiate_session( outgoing, first_msg, db_client, amqp_channel, notif_client, ) .await { Ok(mut session) => { let response = tunnelbroker_messages::ConnectionInitializationResponse { status: ConnectionInitializationStatus::Success, }; let serialized_response = serde_json::to_string(&response).unwrap(); session .send_message_to_device(Message::Text(serialized_response)) .await; session } Err((err, outgoing)) => { error!("Failed to create session with device"); send_error_init_response(err, outgoing).await; return; } } } else { error!("Failed to create session with device"); send_error_init_response(SessionError::InvalidMessage, outgoing).await; return; }; let mut ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); let mut got_heartbeat_response = true; // Poll for messages either being sent to the device (rx) // or messages being received from the device (incoming) loop { trace!("Polling for messages from: {}", addr); tokio::select! { Some(Ok(delivery)) = session.next_amqp_message() => { if let Ok(message) = std::str::from_utf8(&delivery.data) { session.send_message_to_device(Message::Text(message.to_string())).await; } else { error!("Invalid payload"); } }, device_message = incoming.next() => { let message: Message = match device_message { Some(Ok(msg)) => msg, _ => { debug!("Connection to {} closed remotely.", addr); break; } }; match message { Message::Close(_) => { debug!("Connection to {} closed.", addr); break; } Message::Pong(_) => { debug!("Received Pong message from {}", addr); } Message::Ping(msg) => { debug!("Received Ping message from {}", addr); session.send_message_to_device(Message::Pong(msg)).await; } Message::Text(msg) => { got_heartbeat_response = true; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); let Some(message_status) = session.handle_websocket_frame_from_device(msg).await else { continue; }; - let request_status = MessageToDeviceRequestStatus { + let request_status = DeviceToTunnelbrokerRequestStatus { client_message_ids: vec![message_status] }; if let Ok(response) = serde_json::to_string(&request_status) { session.send_message_to_device(Message::text(response)).await; } else { break; } } _ => { error!("Client sent invalid message type"); - let confirmation = MessageToDeviceRequestStatus {client_message_ids: vec![MessageSentStatus::InvalidRequest]}; + let confirmation = DeviceToTunnelbrokerRequestStatus {client_message_ids: vec![MessageSentStatus::InvalidRequest]}; if let Ok(response) = serde_json::to_string(&confirmation) { session.send_message_to_device(Message::text(response)).await; } else { break; } } } }, _ = &mut ping_timeout => { if !got_heartbeat_response { error!("Connection to {} died", addr); break; } let serialized = serde_json::to_string(&Heartbeat {}).unwrap(); session.send_message_to_device(Message::text(serialized)).await; got_heartbeat_response = false; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); } else => { debug!("Unhealthy connection for: {}", addr); break; }, } } info!("Unregistering connection to: {}", addr); session.close().await } async fn initiate_session( outgoing: SplitSink, Message>, frame: Message, db_client: DatabaseClient, amqp_channel: lapin::Channel, notif_client: NotifClient, ) -> Result, ErrorWithStreamHandle> { let initialized_session = initialize_amqp(db_client.clone(), frame, &amqp_channel).await; match initialized_session { Ok((device_info, amqp_consumer)) => Ok(WebsocketSession::new( outgoing, db_client, device_info, amqp_channel, amqp_consumer, notif_client, )), Err(e) => Err((e, outgoing)), } } diff --git a/shared/tunnelbroker_messages/src/messages/message_to_device_request_status.rs b/shared/tunnelbroker_messages/src/messages/message_to_device_request_status.rs index a9a51ef3c..dd363d0b0 100644 --- a/shared/tunnelbroker_messages/src/messages/message_to_device_request_status.rs +++ b/shared/tunnelbroker_messages/src/messages/message_to_device_request_status.rs @@ -1,92 +1,98 @@ //! Message sent from Tunnelbroker to WebSocket clients to inform that message //! was processed, saved in DDB, and will be delivered. use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug)] pub struct Failure { pub id: String, pub error: String, } // NOTE: Keep this in sync with -// lib/types/tunnelbroker/message-to-device-request-status-types.js +// lib/types/tunnelbroker/device-to-tunnelbroker-request-status-types.js #[derive(Serialize, Deserialize, PartialEq, Debug)] #[serde(tag = "type", content = "data")] pub enum MessageSentStatus { /// The message with the provided ID (String) has been processed /// by the Tunnelbroker and is queued for delivery. Success(String), /// 'Failure' contains information about the message ID /// along with the specific error message. Error(Failure), /// The request was invalid (e.g., Bytes instead of Text). /// In this case, the ID cannot be retrieved. InvalidRequest, /// Unauthenticated client tried to send a message. Unauthenticated, /// The JSON could not be serialized, which is why the entire message is /// returned back. /// It becomes impossible to retrieve the message ID in such circumstances. SerializationError(String), } // NOTE: Keep this in sync with -// lib/types/tunnelbroker/message-to-device-request-status-types.js +// lib/types/tunnelbroker/device-to-tunnelbroker-request-status-types.js #[derive(Serialize, Deserialize, PartialEq, Debug)] -#[serde(tag = "type", rename_all = "camelCase")] -pub struct MessageToDeviceRequestStatus { +#[serde( + tag = "type", + rename = "MessageToDeviceRequestStatus", + rename_all = "camelCase" +)] +pub struct DeviceToTunnelbrokerRequestStatus { #[serde(rename = "clientMessageIDs")] pub client_message_ids: Vec, } #[cfg(test)] mod send_confirmation_tests { use super::*; #[test] fn test_send_confirmation_deserialization() { let example_payload = r#"{ "type": "MessageToDeviceRequestStatus", "clientMessageIDs": [ {"type": "Success", "data": "id123"}, {"type": "Success", "data": "id456"}, {"type": "Error", "data": {"id": "id789", "error": "Something went wrong"}}, {"type": "SerializationError", "data": "message"}, {"type": "InvalidRequest"} ] }"#; - let request = - serde_json::from_str::(example_payload) - .unwrap(); + let request = serde_json::from_str::( + example_payload, + ) + .unwrap(); let expected_client_message_ids = vec![ MessageSentStatus::Success("id123".to_string()), MessageSentStatus::Success("id456".to_string()), MessageSentStatus::Error(Failure { id: String::from("id789"), error: String::from("Something went wrong"), }), MessageSentStatus::SerializationError("message".to_string()), MessageSentStatus::InvalidRequest, ]; assert_eq!(request.client_message_ids, expected_client_message_ids); } #[test] fn test_send_confirmation_deserialization_empty_vec() { let example_payload = r#"{ "type": "MessageToDeviceRequestStatus", "clientMessageIDs": [] }"#; - let request = - serde_json::from_str::(example_payload) - .unwrap(); + let request = serde_json::from_str::( + example_payload, + ) + .unwrap(); let expected_client_message_ids: Vec = Vec::new(); assert_eq!(request.client_message_ids, expected_client_message_ids); } } diff --git a/shared/tunnelbroker_messages/src/messages/mod.rs b/shared/tunnelbroker_messages/src/messages/mod.rs index 095e2ddf9..d54ebe822 100644 --- a/shared/tunnelbroker_messages/src/messages/mod.rs +++ b/shared/tunnelbroker_messages/src/messages/mod.rs @@ -1,77 +1,77 @@ //! Messages sent between Tunnelbroker and a device. pub mod device_list_updated; pub mod keys; pub mod message_receive_confirmation; pub mod message_to_device; pub mod message_to_device_request; pub mod message_to_device_request_status; pub mod message_to_tunnelbroker; pub mod message_to_tunnelbroker_request; pub mod notif; pub mod session; pub use device_list_updated::*; pub use keys::*; pub use message_receive_confirmation::*; pub use message_to_device::*; pub use message_to_device_request::*; pub use message_to_device_request_status::*; pub use message_to_tunnelbroker::*; pub use message_to_tunnelbroker_request::*; pub use session::*; pub use websocket_messages::{ ConnectionInitializationResponse, ConnectionInitializationStatus, Heartbeat, }; use crate::notif::*; use serde::{Deserialize, Serialize}; // This file defines types and validation for messages exchanged // with the Tunnelbroker. The definitions in this file should remain in sync // with the structures defined in the corresponding // JavaScript file at `lib/types/tunnelbroker/messages.js`. // If you edit the definitions in one file, // please make sure to update the corresponding definitions in the other. // Messages sent from Device to Tunnelbroker. #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum DeviceToTunnelbrokerMessage { ConnectionInitializationMessage(ConnectionInitializationMessage), AnonymousInitializationMessage(AnonymousInitializationMessage), APNsNotif(APNsNotif), FCMNotif(FCMNotif), MessageToDeviceRequest(MessageToDeviceRequest), MessageReceiveConfirmation(MessageReceiveConfirmation), MessageToTunnelbrokerRequest(MessageToTunnelbrokerRequest), Heartbeat(Heartbeat), } // Messages sent from Tunnelbroker to Device. #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum TunnelbrokerToDeviceMessage { ConnectionInitializationResponse(ConnectionInitializationResponse), - MessageToDeviceRequestStatus(MessageToDeviceRequestStatus), + DeviceToTunnelbrokerRequestStatus(DeviceToTunnelbrokerRequestStatus), MessageToDevice(MessageToDevice), Heartbeat(Heartbeat), } // Messages sent from Services (e.g. Identity) to Device. // This type is sent to a Device as a payload of MessageToDevice. #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum ServiceToDeviceMessages { RefreshKeysRequest(RefreshKeyRequest), IdentityDeviceListUpdated(IdentityDeviceListUpdated), } // Messages sent from Device to Tunnelbroker which Tunnelbroker itself should handle. // This type is sent to a Tunnelbroker as a payload of MessageToTunnelbrokerRequest. #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum MessageToTunnelbroker { SetDeviceToken(SetDeviceToken), }