diff --git a/keyserver/src/socket/tunnelbroker.js b/keyserver/src/socket/tunnelbroker.js index a79c88457..208aa05b9 100644 --- a/keyserver/src/socket/tunnelbroker.js +++ b/keyserver/src/socket/tunnelbroker.js @@ -1,473 +1,480 @@ // @flow import invariant from 'invariant'; import _debounce from 'lodash/debounce.js'; import { getRustAPI } from 'rust-node-addon'; import uuid from 'uuid'; import WebSocket from 'ws'; import { hexToUintArray } from 'lib/media/data-utils.js'; import { clientTunnelbrokerSocketReconnectDelay, tunnelbrokerHeartbeatTimeout, } from 'lib/shared/timeouts.js'; import type { TunnelbrokerClientMessageToDevice } from 'lib/tunnelbroker/tunnelbroker-context.js'; import type { MessageReceiveConfirmation } from 'lib/types/tunnelbroker/message-receive-confirmation-types.js'; import type { MessageSentStatus } from 'lib/types/tunnelbroker/message-to-device-request-status-types.js'; import type { MessageToDeviceRequest } from 'lib/types/tunnelbroker/message-to-device-request-types.js'; import { - type TunnelbrokerMessage, - tunnelbrokerMessageTypes, - tunnelbrokerMessageValidator, + deviceToTunnelbrokerMessageTypes, + tunnelbrokerToDeviceMessageTypes, + tunnelbrokerToDeviceMessageValidator, + type TunnelbrokerToDeviceMessage, } from 'lib/types/tunnelbroker/messages.js'; import { qrCodeAuthMessageValidator, type RefreshKeyRequest, refreshKeysRequestValidator, type QRCodeAuthMessage, } from 'lib/types/tunnelbroker/peer-to-peer-message-types.js'; import { peerToPeerMessageTypes } from 'lib/types/tunnelbroker/peer-to-peer-message-types.js'; import { type QRCodeAuthMessagePayload, qrCodeAuthMessagePayloadValidator, qrCodeAuthMessageTypes, } from 'lib/types/tunnelbroker/qr-code-auth-message-types.js'; import type { ConnectionInitializationMessage, AnonymousInitializationMessage, } from 'lib/types/tunnelbroker/session-types.js'; import type { Heartbeat } from 'lib/types/websocket/heartbeat-types.js'; import { getCommConfig } from 'lib/utils/comm-config.js'; import { convertBytesToObj, convertObjToBytes, } from 'lib/utils/conversion-utils.js'; import { getMessageForException } from 'lib/utils/errors.js'; import sleep from 'lib/utils/sleep.js'; import { fetchOlmAccount } from '../updaters/olm-account-updater.js'; import { fetchIdentityInfo, saveIdentityInfo } from '../user/identity.js'; import type { IdentityInfo } from '../user/identity.js'; import { encrypt, decrypt } from '../utils/aes-crypto-utils.js'; import { getContentSigningKey, uploadNewOneTimeKeys, getNewDeviceKeyUpload, markPrekeysAsPublished, } from '../utils/olm-utils.js'; type TBConnectionInfo = { +url: string, }; async function getTBConnectionInfo(): Promise { const tbConfig = await getCommConfig({ folder: 'facts', name: 'tunnelbroker', }); if (tbConfig) { return tbConfig; } console.warn('Defaulting to staging Tunnelbroker'); return { url: 'wss://tunnelbroker.staging.commtechnologies.org:51001', }; } async function createAndMaintainTunnelbrokerWebsocket(encryptionKey: ?string) { const [deviceID, tbConnectionInfo] = await Promise.all([ getContentSigningKey(), getTBConnectionInfo(), ]); const createNewTunnelbrokerSocket = async ( shouldNotifyPrimaryAfterReopening: boolean, primaryDeviceID: ?string, ) => { const identityInfo = await fetchIdentityInfo(); new TunnelbrokerSocket({ socketURL: tbConnectionInfo.url, onClose: async (successfullyAuthed: boolean, primaryID: ?string) => { await sleep(clientTunnelbrokerSocketReconnectDelay); await createNewTunnelbrokerSocket(successfullyAuthed, primaryID); }, identityInfo, deviceID, qrAuthEncryptionKey: encryptionKey, primaryDeviceID, shouldNotifyPrimaryAfterReopening, }); }; await createNewTunnelbrokerSocket(false, null); } type TunnelbrokerSocketParams = { +socketURL: string, +onClose: (boolean, ?string) => mixed, +identityInfo: ?IdentityInfo, +deviceID: string, +qrAuthEncryptionKey: ?string, +primaryDeviceID: ?string, +shouldNotifyPrimaryAfterReopening: boolean, }; type PromiseCallbacks = { +resolve: () => void, +reject: (error: string) => void, }; type Promises = { [clientMessageID: string]: PromiseCallbacks }; class TunnelbrokerSocket { ws: WebSocket; connected: boolean = false; closed: boolean = false; promises: Promises = {}; heartbeatTimeoutID: ?TimeoutID; oneTimeKeysPromise: ?Promise; identityInfo: ?IdentityInfo; qrAuthEncryptionKey: ?string; primaryDeviceID: ?string; shouldNotifyPrimaryAfterReopening: boolean = false; shouldNotifyPrimary: boolean = false; constructor(tunnelbrokerSocketParams: TunnelbrokerSocketParams) { const { socketURL, onClose, identityInfo, deviceID, qrAuthEncryptionKey, primaryDeviceID, shouldNotifyPrimaryAfterReopening, } = tunnelbrokerSocketParams; this.identityInfo = identityInfo; this.qrAuthEncryptionKey = qrAuthEncryptionKey; this.primaryDeviceID = primaryDeviceID; if (shouldNotifyPrimaryAfterReopening) { this.shouldNotifyPrimary = true; } const socket = new WebSocket(socketURL); socket.on('open', () => { this.onOpen(socket, deviceID); }); socket.on('close', async () => { if (this.closed) { return; } this.closed = true; this.connected = false; this.stopHeartbeatTimeout(); console.error('Connection to Tunnelbroker closed'); onClose(this.shouldNotifyPrimaryAfterReopening, this.primaryDeviceID); }); socket.on('error', (error: Error) => { console.error('Tunnelbroker socket error:', error.message); }); socket.on('message', this.onMessage); this.ws = socket; } onOpen: (socket: WebSocket, deviceID: string) => void = ( socket, deviceID, ) => { if (this.closed) { return; } if (this.identityInfo) { const initMessage: ConnectionInitializationMessage = { type: 'ConnectionInitializationMessage', deviceID, accessToken: this.identityInfo.accessToken, userID: this.identityInfo.userId, deviceType: 'keyserver', }; socket.send(JSON.stringify(initMessage)); } else { const initMessage: AnonymousInitializationMessage = { type: 'AnonymousInitializationMessage', deviceID, deviceType: 'keyserver', }; socket.send(JSON.stringify(initMessage)); } }; onMessage: (event: ArrayBuffer) => Promise = async ( event: ArrayBuffer, ) => { let rawMessage; try { rawMessage = JSON.parse(event.toString()); } catch (e) { console.error('error while parsing Tunnelbroker message:', e.message); return; } - if (!tunnelbrokerMessageValidator.is(rawMessage)) { - console.error('invalid TunnelbrokerMessage: ', rawMessage.toString()); + if (!tunnelbrokerToDeviceMessageValidator.is(rawMessage)) { + console.error( + 'invalid tunnelbrokerToDeviceMessage: ', + rawMessage.toString(), + ); return; } - const message: TunnelbrokerMessage = rawMessage; + const message: TunnelbrokerToDeviceMessage = rawMessage; this.resetHeartbeatTimeout(); if ( message.type === - tunnelbrokerMessageTypes.CONNECTION_INITIALIZATION_RESPONSE + tunnelbrokerToDeviceMessageTypes.CONNECTION_INITIALIZATION_RESPONSE ) { if (message.status.type === 'Success' && !this.connected) { this.connected = true; console.info( this.identityInfo ? 'session with Tunnelbroker created' : 'anonymous session with Tunnelbroker created', ); if (!this.shouldNotifyPrimary) { return; } const { primaryDeviceID } = this; invariant( primaryDeviceID, 'Primary device ID is not set but should be', ); const payload = await this.encodeQRAuthMessage({ type: qrCodeAuthMessageTypes.SECONDARY_DEVICE_REGISTRATION_SUCCESS, requestBackupKeys: false, }); if (!payload) { this.closeConnection(); return; } await this.sendMessage({ deviceID: primaryDeviceID, payload: JSON.stringify(payload), }); } else if (message.status.type === 'Success' && this.connected) { console.info( 'received ConnectionInitializationResponse with status: Success for already connected socket', ); } else { this.connected = false; console.error( 'creating session with Tunnelbroker error:', message.status.data, ); } - } else if (message.type === tunnelbrokerMessageTypes.MESSAGE_TO_DEVICE) { + } else if ( + message.type === tunnelbrokerToDeviceMessageTypes.MESSAGE_TO_DEVICE + ) { const confirmation: MessageReceiveConfirmation = { - type: tunnelbrokerMessageTypes.MESSAGE_RECEIVE_CONFIRMATION, + type: deviceToTunnelbrokerMessageTypes.MESSAGE_RECEIVE_CONFIRMATION, messageIDs: [message.messageID], }; this.ws.send(JSON.stringify(confirmation)); const { payload } = message; try { const messageToKeyserver = JSON.parse(payload); if (qrCodeAuthMessageValidator.is(messageToKeyserver)) { const request: QRCodeAuthMessage = messageToKeyserver; const [qrCodeAuthMessage, rustAPI, accountInfo] = await Promise.all([ this.parseQRCodeAuthMessage(request), getRustAPI(), fetchOlmAccount('content'), ]); if ( !qrCodeAuthMessage || qrCodeAuthMessage.type !== qrCodeAuthMessageTypes.DEVICE_LIST_UPDATE_SUCCESS ) { return; } const { primaryDeviceID, userID } = qrCodeAuthMessage; this.primaryDeviceID = primaryDeviceID; const [nonce, deviceKeyUpload] = await Promise.all([ rustAPI.generateNonce(), getNewDeviceKeyUpload(), ]); const signedIdentityKeysBlob = { payload: deviceKeyUpload.keyPayload, signature: deviceKeyUpload.keyPayloadSignature, }; const nonceSignature = accountInfo.account.sign(nonce); const identityInfo = await rustAPI.uploadSecondaryDeviceKeysAndLogIn( userID, nonce, nonceSignature, signedIdentityKeysBlob, deviceKeyUpload.contentPrekey, deviceKeyUpload.contentPrekeySignature, deviceKeyUpload.notifPrekey, deviceKeyUpload.notifPrekeySignature, deviceKeyUpload.contentOneTimeKeys, deviceKeyUpload.notifOneTimeKeys, ); await Promise.all([ markPrekeysAsPublished(), saveIdentityInfo(identityInfo), ]); this.shouldNotifyPrimaryAfterReopening = true; this.closeConnection(); } else if (refreshKeysRequestValidator.is(messageToKeyserver)) { const request: RefreshKeyRequest = messageToKeyserver; this.debouncedRefreshOneTimeKeys(request.numberOfKeys); } } catch (e) { console.error( 'error while processing message to keyserver:', e.message, ); } } else if ( - message.type === tunnelbrokerMessageTypes.MESSAGE_TO_DEVICE_REQUEST_STATUS + message.type === + tunnelbrokerToDeviceMessageTypes.MESSAGE_TO_DEVICE_REQUEST_STATUS ) { for (const status: MessageSentStatus of message.clientMessageIDs) { if (status.type === 'Success') { if (this.promises[status.data]) { this.promises[status.data].resolve(); delete this.promises[status.data]; } else { console.log( 'received successful response for a non-existent request', ); } } else if (status.type === 'Error') { if (this.promises[status.data.id]) { this.promises[status.data.id].reject(status.data.error); delete this.promises[status.data.id]; } else { console.log('received error response for a non-existent request'); } } else if (status.type === 'SerializationError') { console.error('SerializationError for message: ', status.data); } else if (status.type === 'InvalidRequest') { console.log('Tunnelbroker recorded InvalidRequest'); } } - } else if (message.type === tunnelbrokerMessageTypes.HEARTBEAT) { + } else if (message.type === tunnelbrokerToDeviceMessageTypes.HEARTBEAT) { const heartbeat: Heartbeat = { - type: tunnelbrokerMessageTypes.HEARTBEAT, + type: deviceToTunnelbrokerMessageTypes.HEARTBEAT, }; this.ws.send(JSON.stringify(heartbeat)); } }; refreshOneTimeKeys: (numberOfKeys: number) => void = numberOfKeys => { const oldOneTimeKeysPromise = this.oneTimeKeysPromise; this.oneTimeKeysPromise = (async () => { await oldOneTimeKeysPromise; await uploadNewOneTimeKeys(numberOfKeys); })(); }; debouncedRefreshOneTimeKeys: (numberOfKeys: number) => void = _debounce( this.refreshOneTimeKeys, 100, { leading: true, trailing: true }, ); sendMessage: (message: TunnelbrokerClientMessageToDevice) => Promise = ( message: TunnelbrokerClientMessageToDevice, ) => { if (!this.connected) { throw new Error('Tunnelbroker not connected'); } const clientMessageID = uuid.v4(); const messageToDevice: MessageToDeviceRequest = { - type: tunnelbrokerMessageTypes.MESSAGE_TO_DEVICE_REQUEST, + type: deviceToTunnelbrokerMessageTypes.MESSAGE_TO_DEVICE_REQUEST, clientMessageID, deviceID: message.deviceID, payload: message.payload, }; return new Promise((resolve, reject) => { this.promises[clientMessageID] = { resolve, reject, }; this.ws.send(JSON.stringify(messageToDevice)); }); }; stopHeartbeatTimeout() { if (this.heartbeatTimeoutID) { clearTimeout(this.heartbeatTimeoutID); this.heartbeatTimeoutID = null; } } resetHeartbeatTimeout() { this.stopHeartbeatTimeout(); this.heartbeatTimeoutID = setTimeout(() => { this.ws.close(); this.connected = false; }, tunnelbrokerHeartbeatTimeout); } closeConnection() { this.ws.close(); this.connected = false; } parseQRCodeAuthMessage: ( message: QRCodeAuthMessage, ) => Promise = async message => { const encryptionKey = this.qrAuthEncryptionKey; if (!encryptionKey) { return null; } const encryptedData = Buffer.from(message.encryptedContent, 'base64'); const decryptedData = await decrypt( hexToUintArray(encryptionKey), new Uint8Array(encryptedData), ); const payload = convertBytesToObj(decryptedData); if (!qrCodeAuthMessagePayloadValidator.is(payload)) { return null; } return payload; }; encodeQRAuthMessage: ( payload: QRCodeAuthMessagePayload, ) => Promise = async payload => { const encryptionKey = this.qrAuthEncryptionKey; if (!encryptionKey) { console.error('Encryption key missing - cannot send QR auth message.'); return null; } let encryptedContent; try { const payloadBytes = convertObjToBytes(payload); const keyBytes = hexToUintArray(encryptionKey); const encryptedBytes = await encrypt(keyBytes, payloadBytes); encryptedContent = Buffer.from(encryptedBytes).toString('base64'); } catch (e) { console.error( 'Error encoding QRCodeAuthMessagePayload:', getMessageForException(e), ); return null; } return { type: peerToPeerMessageTypes.QR_CODE_AUTH_MESSAGE, encryptedContent, }; }; } export { createAndMaintainTunnelbrokerWebsocket }; diff --git a/lib/components/qr-auth-provider.react.js b/lib/components/qr-auth-provider.react.js index 78adb95f0..bb46a24e5 100644 --- a/lib/components/qr-auth-provider.react.js +++ b/lib/components/qr-auth-provider.react.js @@ -1,247 +1,247 @@ // @flow import invariant from 'invariant'; import * as React from 'react'; import { useSecondaryDeviceLogIn } from '../hooks/login-hooks.js'; import { uintArrayToHexString } from '../media/data-utils.js'; import { IdentityClientContext } from '../shared/identity-client-context.js'; import { useTunnelbroker } from '../tunnelbroker/tunnelbroker-context.js'; import type { BackupKeys } from '../types/backup-types.js'; import { - tunnelbrokerMessageTypes, - type TunnelbrokerMessage, + tunnelbrokerToDeviceMessageTypes, + type TunnelbrokerToDeviceMessage, } from '../types/tunnelbroker/messages.js'; import { type QRCodeAuthMessage, peerToPeerMessageTypes, peerToPeerMessageValidator, } from '../types/tunnelbroker/peer-to-peer-message-types.js'; import { qrCodeAuthMessageTypes, type QRCodeAuthMessagePayload, } from '../types/tunnelbroker/qr-code-auth-message-types.js'; import { getContentSigningKey } from '../utils/crypto-utils.js'; type Props = { +children: React.Node, +onLogInError: (error: mixed) => void, +generateAESKey: () => Promise, +composeTunnelbrokerQRAuthMessage: ( encryptionKey: string, payload: QRCodeAuthMessagePayload, ) => Promise, +parseTunnelbrokerQRAuthMessage: ( encryptionKey: string, message: QRCodeAuthMessage, ) => Promise, +performBackupRestore?: (backupKeys: BackupKeys) => Promise, }; type QRData = ?{ +deviceID: string, +aesKey: string }; type QRAuthContextType = { +qrData: QRData, +generateQRCode: () => Promise, }; const QRAuthContext: React.Context = React.createContext({ qrData: null, generateQRCode: async () => {}, }); function QRAuthProvider(props: Props): React.Node { const { children, onLogInError, generateAESKey, composeTunnelbrokerQRAuthMessage, parseTunnelbrokerQRAuthMessage, performBackupRestore, } = props; const [primaryDeviceID, setPrimaryDeviceID] = React.useState(); const [qrData, setQRData] = React.useState(); const [qrAuthFinished, setQRAuthFinished] = React.useState(false); const { setUnauthorizedDeviceID, addListener, removeListener, socketState, sendMessage, } = useTunnelbroker(); const identityContext = React.useContext(IdentityClientContext); const identityClient = identityContext?.identityClient; const generateQRCode = React.useCallback(async () => { try { const [ed25519, rawAESKey] = await Promise.all([ getContentSigningKey(), generateAESKey(), ]); const aesKeyAsHexString: string = uintArrayToHexString(rawAESKey); setUnauthorizedDeviceID(ed25519); setQRData({ deviceID: ed25519, aesKey: aesKeyAsHexString }); setQRAuthFinished(false); } catch (err) { console.error('Failed to generate QR Code:', err); } }, [generateAESKey, setUnauthorizedDeviceID]); const logInSecondaryDevice = useSecondaryDeviceLogIn(); const performLogIn = React.useCallback( async (userID: string) => { try { await logInSecondaryDevice(userID); } catch (err) { onLogInError(err); void generateQRCode(); } }, [logInSecondaryDevice, onLogInError, generateQRCode], ); React.useEffect(() => { if ( !qrData || !socketState.isAuthorized || !primaryDeviceID || qrAuthFinished ) { return; } void (async () => { const message = await composeTunnelbrokerQRAuthMessage(qrData?.aesKey, { type: qrCodeAuthMessageTypes.SECONDARY_DEVICE_REGISTRATION_SUCCESS, requestBackupKeys: true, }); await sendMessage({ deviceID: primaryDeviceID, payload: JSON.stringify(message), }); })(); }, [ sendMessage, primaryDeviceID, qrData, socketState, composeTunnelbrokerQRAuthMessage, qrAuthFinished, ]); const tunnelbrokerMessageListener = React.useCallback( - async (message: TunnelbrokerMessage) => { + async (message: TunnelbrokerToDeviceMessage) => { invariant(identityClient, 'identity context not set'); if ( !qrData?.aesKey || - message.type !== tunnelbrokerMessageTypes.MESSAGE_TO_DEVICE + message.type !== tunnelbrokerToDeviceMessageTypes.MESSAGE_TO_DEVICE ) { return; } let innerMessage; try { innerMessage = JSON.parse(message.payload); } catch { return; } if ( !peerToPeerMessageValidator.is(innerMessage) || innerMessage.type !== peerToPeerMessageTypes.QR_CODE_AUTH_MESSAGE ) { return; } let qrCodeAuthMessage; try { qrCodeAuthMessage = await parseTunnelbrokerQRAuthMessage( qrData?.aesKey, innerMessage, ); } catch (err) { console.warn('Failed to decrypt Tunnelbroker QR auth message:', err); return; } if ( qrCodeAuthMessage && qrCodeAuthMessage.type === qrCodeAuthMessageTypes.BACKUP_DATA_KEY_MESSAGE ) { const { backupID, backupDataKey, backupLogDataKey } = qrCodeAuthMessage; void performBackupRestore?.({ backupID, backupDataKey, backupLogDataKey, }); setQRAuthFinished(true); return; } if ( !qrCodeAuthMessage || qrCodeAuthMessage.type !== qrCodeAuthMessageTypes.DEVICE_LIST_UPDATE_SUCCESS ) { return; } const { primaryDeviceID: receivedPrimaryDeviceID, userID } = qrCodeAuthMessage; setPrimaryDeviceID(receivedPrimaryDeviceID); await performLogIn(userID); setUnauthorizedDeviceID(null); }, [ identityClient, qrData?.aesKey, performLogIn, setUnauthorizedDeviceID, parseTunnelbrokerQRAuthMessage, performBackupRestore, ], ); React.useEffect(() => { if (!qrData?.deviceID || qrAuthFinished) { return undefined; } addListener(tunnelbrokerMessageListener); return () => { removeListener(tunnelbrokerMessageListener); }; }, [ addListener, removeListener, tunnelbrokerMessageListener, qrData?.deviceID, qrAuthFinished, ]); const value = React.useMemo( () => ({ qrData, generateQRCode, }), [qrData, generateQRCode], ); return ( {children} ); } function useQRAuthContext(): QRAuthContextType { const context = React.useContext(QRAuthContext); invariant(context, 'QRAuthContext not found'); return context; } export { QRAuthProvider, useQRAuthContext }; diff --git a/lib/tunnelbroker/peer-to-peer-message-handler.js b/lib/tunnelbroker/peer-to-peer-message-handler.js index 9c381f9c7..f84792d69 100644 --- a/lib/tunnelbroker/peer-to-peer-message-handler.js +++ b/lib/tunnelbroker/peer-to-peer-message-handler.js @@ -1,102 +1,103 @@ // @flow import * as React from 'react'; import { useTunnelbroker } from './tunnelbroker-context.js'; import { usePeerToPeerMessageHandler } from './use-peer-to-peer-message-handler.js'; import type { MessageReceiveConfirmation } from '../types/tunnelbroker/message-receive-confirmation-types.js'; import { - tunnelbrokerMessageTypes, - type TunnelbrokerMessage, + deviceToTunnelbrokerMessageTypes, + type TunnelbrokerToDeviceMessage, + tunnelbrokerToDeviceMessageTypes, } from '../types/tunnelbroker/messages.js'; import { peerToPeerMessageValidator, type PeerToPeerMessage, } from '../types/tunnelbroker/peer-to-peer-message-types.js'; type Props = { +socketSend: (message: string) => void, +getSessionCounter: () => number, +doesSocketExist: () => boolean, }; function PeerToPeerMessageHandler(props: Props): React.Node { const { socketSend, getSessionCounter, doesSocketExist } = props; const { addListener, removeListener } = useTunnelbroker(); const peerToPeerMessageHandler = usePeerToPeerMessageHandler(); const currentlyProcessedMessage = React.useRef>(null); const tunnelbrokerMessageListener = React.useCallback( - async (message: TunnelbrokerMessage) => { - if (message.type !== tunnelbrokerMessageTypes.MESSAGE_TO_DEVICE) { + async (message: TunnelbrokerToDeviceMessage) => { + if (message.type !== tunnelbrokerToDeviceMessageTypes.MESSAGE_TO_DEVICE) { return; } const confirmation: MessageReceiveConfirmation = { - type: tunnelbrokerMessageTypes.MESSAGE_RECEIVE_CONFIRMATION, + type: deviceToTunnelbrokerMessageTypes.MESSAGE_RECEIVE_CONFIRMATION, messageIDs: [message.messageID], }; let rawPeerToPeerMessage; try { rawPeerToPeerMessage = JSON.parse(message.payload); } catch (e) { console.log( 'error while parsing Tunnelbroker peer-to-peer message:', e.message, ); // Client received incorrect message, confirm to remove from // Tunnelbroker queue. socketSend(JSON.stringify(confirmation)); return; } if (!peerToPeerMessageValidator.is(rawPeerToPeerMessage)) { console.log('invalid Tunnelbroker PeerToPeerMessage'); // The client received an invalid Tunnelbroker message, // and cannot process this type of request. socketSend(JSON.stringify(confirmation)); return; } const peerToPeerMessage: PeerToPeerMessage = rawPeerToPeerMessage; currentlyProcessedMessage.current = (async () => { const localSocketSessionCounter = getSessionCounter(); await currentlyProcessedMessage.current; // Since scheduling processing this message socket is closed // or was closed and reopened, we have to stop processing // because Tunnelbroker flushes the message again when opening // the socket, and we want to process this only once // to maintain order. if ( localSocketSessionCounter !== getSessionCounter() || !doesSocketExist() ) { return; } try { await peerToPeerMessageHandler(peerToPeerMessage, message.messageID); } catch (e) { console.log(e.message); } finally { if ( localSocketSessionCounter === getSessionCounter() && doesSocketExist() ) { // We confirm regardless of success or error while processing. socketSend(JSON.stringify(confirmation)); } } })(); }, [getSessionCounter, peerToPeerMessageHandler, doesSocketExist, socketSend], ); React.useEffect(() => { addListener(tunnelbrokerMessageListener); return () => { removeListener(tunnelbrokerMessageListener); }; }, [addListener, removeListener, tunnelbrokerMessageListener]); } export { PeerToPeerMessageHandler }; diff --git a/lib/tunnelbroker/tunnelbroker-context.js b/lib/tunnelbroker/tunnelbroker-context.js index ae652321f..ce3293b7b 100644 --- a/lib/tunnelbroker/tunnelbroker-context.js +++ b/lib/tunnelbroker/tunnelbroker-context.js @@ -1,485 +1,488 @@ // @flow import invariant from 'invariant'; import _isEqual from 'lodash/fp/isEqual.js'; import * as React from 'react'; import uuid from 'uuid'; import { PeerToPeerProvider } from './peer-to-peer-context.js'; import { PeerToPeerMessageHandler } from './peer-to-peer-message-handler.js'; import type { SecondaryTunnelbrokerConnection } from './secondary-tunnelbroker-connection.js'; import { tunnnelbrokerURL } from '../facts/tunnelbroker.js'; import { IdentityClientContext } from '../shared/identity-client-context.js'; import { tunnelbrokerHeartbeatTimeout } from '../shared/timeouts.js'; import { isWebPlatform } from '../types/device-types.js'; import type { MessageSentStatus } from '../types/tunnelbroker/message-to-device-request-status-types.js'; import type { MessageToDeviceRequest } from '../types/tunnelbroker/message-to-device-request-types.js'; import type { MessageToTunnelbrokerRequest } from '../types/tunnelbroker/message-to-tunnelbroker-request-types.js'; import { - type TunnelbrokerMessage, - tunnelbrokerMessageTypes, - tunnelbrokerMessageValidator, + deviceToTunnelbrokerMessageTypes, + tunnelbrokerToDeviceMessageTypes, + tunnelbrokerToDeviceMessageValidator, + type TunnelbrokerToDeviceMessage, } from '../types/tunnelbroker/messages.js'; import type { TunnelbrokerAPNsNotif, TunnelbrokerFCMNotif, } from '../types/tunnelbroker/notif-types.js'; import type { AnonymousInitializationMessage, ConnectionInitializationMessage, TunnelbrokerInitializationMessage, TunnelbrokerDeviceTypes, } from '../types/tunnelbroker/session-types.js'; import type { Heartbeat } from '../types/websocket/heartbeat-types.js'; import { getConfig } from '../utils/config.js'; import { getContentSigningKey } from '../utils/crypto-utils.js'; import { useSelector } from '../utils/redux-utils.js'; export type TunnelbrokerClientMessageToDevice = { +deviceID: string, +payload: string, }; export type TunnelbrokerSocketListener = ( - message: TunnelbrokerMessage, + message: TunnelbrokerToDeviceMessage, ) => mixed; type PromiseCallbacks = { +resolve: () => void, +reject: (error: string) => void, }; type Promises = { [clientMessageID: string]: PromiseCallbacks }; type TunnelbrokerSocketState = | { +connected: true, +isAuthorized: boolean, } | { +connected: false, }; type TunnelbrokerContextType = { +sendMessage: ( message: TunnelbrokerClientMessageToDevice, messageID: ?string, ) => Promise, +sendNotif: ( notif: TunnelbrokerAPNsNotif | TunnelbrokerFCMNotif, ) => Promise, +sendMessageToTunnelbroker: (payload: string) => Promise, +addListener: (listener: TunnelbrokerSocketListener) => void, +removeListener: (listener: TunnelbrokerSocketListener) => void, +socketState: TunnelbrokerSocketState, +setUnauthorizedDeviceID: (unauthorizedDeviceID: ?string) => void, }; const TunnelbrokerContext: React.Context = React.createContext(); type Props = { +children: React.Node, +shouldBeClosed?: boolean, +onClose?: () => mixed, +secondaryTunnelbrokerConnection?: SecondaryTunnelbrokerConnection, }; function getTunnelbrokerDeviceType(): TunnelbrokerDeviceTypes { return isWebPlatform(getConfig().platformDetails.platform) ? 'web' : 'mobile'; } function createAnonymousInitMessage( deviceID: string, ): AnonymousInitializationMessage { return ({ type: 'AnonymousInitializationMessage', deviceID, deviceType: getTunnelbrokerDeviceType(), }: AnonymousInitializationMessage); } function TunnelbrokerProvider(props: Props): React.Node { const { children, shouldBeClosed, onClose, secondaryTunnelbrokerConnection } = props; const accessToken = useSelector(state => state.commServicesAccessToken); const userID = useSelector(state => state.currentUserInfo?.id); const [unauthorizedDeviceID, setUnauthorizedDeviceID] = React.useState(null); const isAuthorized = !unauthorizedDeviceID; const createInitMessage = React.useCallback(async () => { if (shouldBeClosed) { return null; } if (unauthorizedDeviceID) { return createAnonymousInitMessage(unauthorizedDeviceID); } if (!accessToken || !userID) { return null; } const deviceID = await getContentSigningKey(); if (!deviceID) { return null; } return ({ type: 'ConnectionInitializationMessage', deviceID, accessToken, userID, deviceType: getTunnelbrokerDeviceType(), }: ConnectionInitializationMessage); }, [accessToken, shouldBeClosed, unauthorizedDeviceID, userID]); const previousInitMessage = React.useRef(null); const [socketState, setSocketState] = React.useState( { connected: false }, ); const listeners = React.useRef>(new Set()); const socket = React.useRef(null); const socketSessionCounter = React.useRef(0); const promises = React.useRef({}); const heartbeatTimeoutID = React.useRef(); const identityContext = React.useContext(IdentityClientContext); invariant(identityContext, 'Identity context should be set'); const { identityClient } = identityContext; const stopHeartbeatTimeout = React.useCallback(() => { if (heartbeatTimeoutID.current) { clearTimeout(heartbeatTimeoutID.current); heartbeatTimeoutID.current = null; } }, []); const resetHeartbeatTimeout = React.useCallback(() => { stopHeartbeatTimeout(); heartbeatTimeoutID.current = setTimeout(() => { socket.current?.close(); setSocketState({ connected: false }); }, tunnelbrokerHeartbeatTimeout); }, [stopHeartbeatTimeout]); // determine if the socket is active (not closed or closing) const isSocketActive = socket.current?.readyState === WebSocket.OPEN || socket.current?.readyState === WebSocket.CONNECTING; const connectionChangePromise = React.useRef>(null); // The Tunnelbroker connection can have 4 states: // - DISCONNECTED: isSocketActive = false, connected = false // Should be in this state when initMessage is null // - CONNECTING: isSocketActive = true, connected = false // This lasts until Tunnelbroker sends ConnectionInitializationResponse // - CONNECTED: isSocketActive = true, connected = true // - DISCONNECTING: isSocketActive = false, connected = true // This lasts between socket.close() and socket.onclose() React.useEffect(() => { connectionChangePromise.current = (async () => { await connectionChangePromise.current; try { const initMessage = await createInitMessage(); const initMessageChanged = !_isEqual( previousInitMessage.current, initMessage, ); previousInitMessage.current = initMessage; // when initMessage changes, we need to close the socket // and open a new one if ( (!initMessage || initMessageChanged) && isSocketActive && socket.current ) { socket.current?.close(); return; } // when we're already connected (or pending disconnection), // or there's no init message to start with, we don't need // to do anything if (socketState.connected || !initMessage || socket.current) { return; } const tunnelbrokerSocket = new WebSocket(tunnnelbrokerURL); tunnelbrokerSocket.onopen = () => { tunnelbrokerSocket.send(JSON.stringify(initMessage)); }; tunnelbrokerSocket.onclose = () => { // this triggers the effect hook again and reconnect setSocketState({ connected: false }); onClose?.(); socket.current = null; console.log('Connection to Tunnelbroker closed'); }; tunnelbrokerSocket.onerror = e => { console.log('Tunnelbroker socket error:', e.message); }; tunnelbrokerSocket.onmessage = (event: MessageEvent) => { if (typeof event.data !== 'string') { console.log('socket received a non-string message'); return; } let rawMessage; try { rawMessage = JSON.parse(event.data); } catch (e) { console.log('error while parsing Tunnelbroker message:', e.message); return; } - if (!tunnelbrokerMessageValidator.is(rawMessage)) { + if (!tunnelbrokerToDeviceMessageValidator.is(rawMessage)) { console.log('invalid TunnelbrokerMessage'); return; } - const message: TunnelbrokerMessage = rawMessage; + const message: TunnelbrokerToDeviceMessage = rawMessage; resetHeartbeatTimeout(); for (const listener of listeners.current) { listener(message); } // MESSAGE_TO_DEVICE is handled in PeerToPeerMessageHandler if ( message.type === - tunnelbrokerMessageTypes.CONNECTION_INITIALIZATION_RESPONSE + tunnelbrokerToDeviceMessageTypes.CONNECTION_INITIALIZATION_RESPONSE ) { if (message.status.type === 'Success' && !socketState.connected) { setSocketState({ connected: true, isAuthorized }); console.log( 'session with Tunnelbroker created. isAuthorized:', isAuthorized, ); } else if ( message.status.type === 'Success' && socketState.connected ) { console.log( 'received ConnectionInitializationResponse with status: Success for already connected socket', ); } else { setSocketState({ connected: false }); console.log( 'creating session with Tunnelbroker error:', message.status.data, ); } } else if ( message.type === - tunnelbrokerMessageTypes.MESSAGE_TO_DEVICE_REQUEST_STATUS + tunnelbrokerToDeviceMessageTypes.MESSAGE_TO_DEVICE_REQUEST_STATUS ) { for (const status: MessageSentStatus of message.clientMessageIDs) { if (status.type === 'Success') { promises.current[status.data]?.resolve(); delete promises.current[status.data]; } else if (status.type === 'Error') { promises.current[status.data.id]?.reject(status.data.error); delete promises.current[status.data.id]; } else if (status.type === 'SerializationError') { console.log('SerializationError for message: ', status.data); } else if (status.type === 'InvalidRequest') { console.log('Tunnelbroker recorded InvalidRequest'); } } - } else if (message.type === tunnelbrokerMessageTypes.HEARTBEAT) { + } else if ( + message.type === tunnelbrokerToDeviceMessageTypes.HEARTBEAT + ) { const heartbeat: Heartbeat = { - type: tunnelbrokerMessageTypes.HEARTBEAT, + type: deviceToTunnelbrokerMessageTypes.HEARTBEAT, }; socket.current?.send(JSON.stringify(heartbeat)); } }; socket.current = tunnelbrokerSocket; socketSessionCounter.current = socketSessionCounter.current + 1; } catch (err) { console.log('Tunnelbroker connection error:', err); } })(); }, [ isSocketActive, isAuthorized, resetHeartbeatTimeout, stopHeartbeatTimeout, identityClient, onClose, createInitMessage, socketState.connected, ]); const sendMessageToDeviceRequest: ( request: | MessageToDeviceRequest | MessageToTunnelbrokerRequest | TunnelbrokerAPNsNotif | TunnelbrokerFCMNotif, ) => Promise = React.useCallback( request => { return new Promise((resolve, reject) => { const socketActive = socketState.connected && socket.current; if (!shouldBeClosed && !socketActive) { throw new Error('Tunnelbroker not connected'); } promises.current[request.clientMessageID] = { resolve, reject, }; if (socketActive) { socket.current?.send(JSON.stringify(request)); } else { secondaryTunnelbrokerConnection?.sendMessage(request); } }); }, [socketState, secondaryTunnelbrokerConnection, shouldBeClosed], ); const sendMessage: ( message: TunnelbrokerClientMessageToDevice, messageID: ?string, ) => Promise = React.useCallback( (message: TunnelbrokerClientMessageToDevice, messageID: ?string) => { const clientMessageID = messageID ?? uuid.v4(); const messageToDevice: MessageToDeviceRequest = { - type: tunnelbrokerMessageTypes.MESSAGE_TO_DEVICE_REQUEST, + type: deviceToTunnelbrokerMessageTypes.MESSAGE_TO_DEVICE_REQUEST, clientMessageID, deviceID: message.deviceID, payload: message.payload, }; return sendMessageToDeviceRequest(messageToDevice); }, [sendMessageToDeviceRequest], ); const sendMessageToTunnelbroker: (payload: string) => Promise = React.useCallback( (payload: string) => { const clientMessageID = uuid.v4(); const messageToTunnelbroker: MessageToTunnelbrokerRequest = { - type: tunnelbrokerMessageTypes.MESSAGE_TO_TUNNELBROKER_REQUEST, + type: deviceToTunnelbrokerMessageTypes.MESSAGE_TO_TUNNELBROKER_REQUEST, clientMessageID, payload, }; return sendMessageToDeviceRequest(messageToTunnelbroker); }, [sendMessageToDeviceRequest], ); React.useEffect( () => secondaryTunnelbrokerConnection?.onSendMessage(message => { if (shouldBeClosed) { // We aren't supposed to be handling it return; } void (async () => { try { await sendMessageToDeviceRequest(message); secondaryTunnelbrokerConnection.setMessageStatus( message.clientMessageID, ); } catch (error) { secondaryTunnelbrokerConnection.setMessageStatus( message.clientMessageID, error, ); } })(); }), [ secondaryTunnelbrokerConnection, sendMessageToDeviceRequest, shouldBeClosed, ], ); React.useEffect( () => secondaryTunnelbrokerConnection?.onMessageStatus((messageID, error) => { if (error) { promises.current[messageID].reject(error); } else { promises.current[messageID].resolve(); } delete promises.current[messageID]; }), [secondaryTunnelbrokerConnection], ); const addListener = React.useCallback( (listener: TunnelbrokerSocketListener) => { listeners.current.add(listener); }, [], ); const removeListener = React.useCallback( (listener: TunnelbrokerSocketListener) => { listeners.current.delete(listener); }, [], ); const getSessionCounter = React.useCallback( () => socketSessionCounter.current, [], ); const doesSocketExist = React.useCallback(() => !!socket.current, []); const socketSend = React.useCallback((message: string) => { socket.current?.send(message); }, []); const value: TunnelbrokerContextType = React.useMemo( () => ({ sendMessage, sendMessageToTunnelbroker, sendNotif: sendMessageToDeviceRequest, socketState, addListener, removeListener, setUnauthorizedDeviceID, }), [ sendMessage, sendMessageToDeviceRequest, sendMessageToTunnelbroker, socketState, addListener, removeListener, ], ); return ( {children} ); } function useTunnelbroker(): TunnelbrokerContextType { const context = React.useContext(TunnelbrokerContext); invariant(context, 'TunnelbrokerContext not found'); return context; } export { TunnelbrokerProvider, useTunnelbroker }; diff --git a/lib/types/tunnelbroker/messages.js b/lib/types/tunnelbroker/messages.js index fbc4d3518..828200cba 100644 --- a/lib/types/tunnelbroker/messages.js +++ b/lib/types/tunnelbroker/messages.js @@ -1,85 +1,87 @@ // @flow import type { TUnion } from 'tcomb'; import t from 'tcomb'; -import { - type MessageReceiveConfirmation, - messageReceiveConfirmationValidator, -} from './message-receive-confirmation-types.js'; +import { type MessageReceiveConfirmation } from './message-receive-confirmation-types.js'; import { type MessageToDeviceRequestStatus, messageToDeviceRequestStatusValidator, } from './message-to-device-request-status-types.js'; -import { - type MessageToDeviceRequest, - messageToDeviceRequestValidator, -} from './message-to-device-request-types.js'; +import { type MessageToDeviceRequest } from './message-to-device-request-types.js'; import { type MessageToDevice, messageToDeviceValidator, } from './message-to-device-types.js'; +import { type MessageToTunnelbrokerRequest } from './message-to-tunnelbroker-request-types.js'; import { - type MessageToTunnelbrokerRequest, - messageToTunnelbrokerRequestValidator, -} from './message-to-tunnelbroker-request-types.js'; -import { type TunnelbrokerAPNsNotif } from './notif-types.js'; + type TunnelbrokerAPNsNotif, + type TunnelbrokerFCMNotif, +} from './notif-types.js'; import { + type AnonymousInitializationMessage, type ConnectionInitializationMessage, - connectionInitializationMessageValidator, } from './session-types.js'; import { type ConnectionInitializationResponse, connectionInitializationResponseValidator, } from '../websocket/connection-initialization-response-types.js'; import { type Heartbeat, heartbeatValidator, } from '../websocket/heartbeat-types.js'; /* * This file defines types and validation for messages exchanged * with the Tunnelbroker. The definitions in this file should remain in sync * with the structures defined in the corresponding * Rust file at `shared/tunnelbroker_messages/src/messages/mod.rs`. * * If you edit the definitions in one file, * please make sure to update the corresponding definitions in the other. * */ -export const tunnelbrokerMessageTypes = Object.freeze({ +// Messages sent from Device to Tunnelbroker. +export const deviceToTunnelbrokerMessageTypes = Object.freeze({ CONNECTION_INITIALIZATION_MESSAGE: 'ConnectionInitializationMessage', - CONNECTION_INITIALIZATION_RESPONSE: 'ConnectionInitializationResponse', ANONYMOUS_INITIALIZATION_MESSAGE: 'AnonymousInitializationMessage', TUNNELBROKER_APNS_NOTIF: 'APNsNotif', - MESSAGE_TO_DEVICE_REQUEST_STATUS: 'MessageToDeviceRequestStatus', + TUNNELBROKER_FCM_NOTIF: 'FCMNotif', MESSAGE_TO_DEVICE_REQUEST: 'MessageToDeviceRequest', + MESSAGE_RECEIVE_CONFIRMATION: 'MessageReceiveConfirmation', MESSAGE_TO_TUNNELBROKER_REQUEST: 'MessageToTunnelbrokerRequest', + HEARTBEAT: 'Heartbeat', +}); + +export type DeviceToTunnelbrokerMessage = + | ConnectionInitializationMessage + | AnonymousInitializationMessage + | TunnelbrokerAPNsNotif + | TunnelbrokerFCMNotif + | MessageToDeviceRequest + | MessageReceiveConfirmation + | MessageToTunnelbrokerRequest + | Heartbeat; + +// Messages sent from Tunnelbroker to Device. +export const tunnelbrokerToDeviceMessageTypes = Object.freeze({ + CONNECTION_INITIALIZATION_RESPONSE: 'ConnectionInitializationResponse', + MESSAGE_TO_DEVICE_REQUEST_STATUS: 'MessageToDeviceRequestStatus', MESSAGE_TO_DEVICE: 'MessageToDevice', - MESSAGE_RECEIVE_CONFIRMATION: 'MessageReceiveConfirmation', HEARTBEAT: 'Heartbeat', }); -export const tunnelbrokerMessageValidator: TUnion = +export type TunnelbrokerToDeviceMessage = + | ConnectionInitializationResponse + | MessageToDeviceRequestStatus + | MessageToDevice + | Heartbeat; + +export const tunnelbrokerToDeviceMessageValidator: TUnion = t.union([ - connectionInitializationMessageValidator, connectionInitializationResponseValidator, messageToDeviceRequestStatusValidator, - messageToDeviceRequestValidator, messageToDeviceValidator, - messageReceiveConfirmationValidator, heartbeatValidator, - messageToTunnelbrokerRequestValidator, ]); - -export type TunnelbrokerMessage = - | ConnectionInitializationMessage - | ConnectionInitializationResponse - | MessageToDeviceRequestStatus - | MessageToDeviceRequest - | MessageToDevice - | MessageReceiveConfirmation - | Heartbeat - | MessageToTunnelbrokerRequest - | TunnelbrokerAPNsNotif; diff --git a/native/profile/secondary-device-qr-code-scanner.react.js b/native/profile/secondary-device-qr-code-scanner.react.js index cb41e0f93..5e9ed9704 100644 --- a/native/profile/secondary-device-qr-code-scanner.react.js +++ b/native/profile/secondary-device-qr-code-scanner.react.js @@ -1,428 +1,429 @@ // @flow import { useNavigation } from '@react-navigation/native'; import { BarCodeScanner, type BarCodeEvent } from 'expo-barcode-scanner'; import invariant from 'invariant'; import * as React from 'react'; import { View, Text } from 'react-native'; import { parseDataFromDeepLink } from 'lib/facts/links.js'; import { useBroadcastDeviceListUpdates, useGetAndUpdateDeviceListsForUsers, } from 'lib/hooks/peer-list-hooks.js'; import { getForeignPeerDevices } from 'lib/selectors/user-selectors.js'; import { addDeviceToDeviceList } from 'lib/shared/device-list-utils.js'; import { IdentityClientContext } from 'lib/shared/identity-client-context.js'; import { useTunnelbroker } from 'lib/tunnelbroker/tunnelbroker-context.js'; import { backupKeysValidator, type BackupKeys, } from 'lib/types/backup-types.js'; import { - tunnelbrokerMessageTypes, - type TunnelbrokerMessage, + tunnelbrokerToDeviceMessageTypes, + type TunnelbrokerToDeviceMessage, } from 'lib/types/tunnelbroker/messages.js'; import { peerToPeerMessageTypes, peerToPeerMessageValidator, type PeerToPeerMessage, } from 'lib/types/tunnelbroker/peer-to-peer-message-types.js'; import { qrCodeAuthMessageTypes } from 'lib/types/tunnelbroker/qr-code-auth-message-types.js'; import { rawDeviceListFromSignedList } from 'lib/utils/device-list-utils.js'; import { assertWithValidator } from 'lib/utils/validation-utils.js'; import type { ProfileNavigationProp } from './profile.react.js'; import { getBackupSecret } from '../backup/use-client-backup.js'; import TextInput from '../components/text-input.react.js'; import { commCoreModule } from '../native-modules.js'; import HeaderRightTextButton from '../navigation/header-right-text-button.react.js'; import type { NavigationRoute } from '../navigation/route-names.js'; import { composeTunnelbrokerQRAuthMessage, parseTunnelbrokerQRAuthMessage, } from '../qr-code/qr-code-utils.js'; import { useSelector } from '../redux/redux-utils.js'; import { useStyles, useColors } from '../themes/colors.js'; import Alert from '../utils/alert.js'; import { deviceIsEmulator } from '../utils/url-utils.js'; const barCodeTypes = [BarCodeScanner.Constants.BarCodeType.qr]; type Props = { +navigation: ProfileNavigationProp<'SecondaryDeviceQRCodeScanner'>, +route: NavigationRoute<'SecondaryDeviceQRCodeScanner'>, }; + // eslint-disable-next-line no-unused-vars function SecondaryDeviceQRCodeScanner(props: Props): React.Node { const [hasPermission, setHasPermission] = React.useState(null); const [scanned, setScanned] = React.useState(false); const [urlInput, setURLInput] = React.useState(''); const styles = useStyles(unboundStyles); const { goBack, setOptions } = useNavigation(); const tunnelbrokerContext = useTunnelbroker(); const identityContext = React.useContext(IdentityClientContext); invariant(identityContext, 'identity context not set'); const aes256Key = React.useRef(null); const secondaryDeviceID = React.useRef(null); const broadcastDeviceListUpdates = useBroadcastDeviceListUpdates(); const getAndUpdateDeviceListsForUsers = useGetAndUpdateDeviceListsForUsers(); const foreignPeerDevices = useSelector(getForeignPeerDevices); const { panelForegroundTertiaryLabel } = useColors(); const tunnelbrokerMessageListener = React.useCallback( - async (message: TunnelbrokerMessage) => { + async (message: TunnelbrokerToDeviceMessage) => { const encryptionKey = aes256Key.current; const targetDeviceID = secondaryDeviceID.current; if (!encryptionKey || !targetDeviceID) { return; } - if (message.type !== tunnelbrokerMessageTypes.MESSAGE_TO_DEVICE) { + if (message.type !== tunnelbrokerToDeviceMessageTypes.MESSAGE_TO_DEVICE) { return; } let innerMessage: PeerToPeerMessage; try { innerMessage = JSON.parse(message.payload); } catch { return; } if ( !peerToPeerMessageValidator.is(innerMessage) || innerMessage.type !== peerToPeerMessageTypes.QR_CODE_AUTH_MESSAGE ) { return; } const payload = await parseTunnelbrokerQRAuthMessage( encryptionKey, innerMessage, ); if ( !payload || payload.type !== qrCodeAuthMessageTypes.SECONDARY_DEVICE_REGISTRATION_SUCCESS ) { return; } invariant(identityContext, 'identity context not set'); const { getAuthMetadata, identityClient } = identityContext; const { userID, deviceID } = await getAuthMetadata(); if (!userID || !deviceID) { throw new Error('missing auth metadata'); } const deviceLists = await identityClient.getDeviceListHistoryForUser(userID); invariant(deviceLists.length > 0, 'received empty device list history'); const lastSignedDeviceList = deviceLists[deviceLists.length - 1]; const deviceList = rawDeviceListFromSignedList(lastSignedDeviceList); const ownOtherDevices = deviceList.devices.filter(it => it !== deviceID); await Promise.all([ broadcastDeviceListUpdates( [...ownOtherDevices, ...foreignPeerDevices], lastSignedDeviceList, ), getAndUpdateDeviceListsForUsers([userID]), ]); if (!payload.requestBackupKeys) { Alert.alert('Device added', 'Device registered successfully', [ { text: 'OK', onPress: goBack }, ]); return; } const backupSecret = await getBackupSecret(); const backupKeysResponse = await commCoreModule.retrieveBackupKeys(backupSecret); const backupKeys = assertWithValidator( JSON.parse(backupKeysResponse), backupKeysValidator, ); const backupKeyMessage = await composeTunnelbrokerQRAuthMessage( encryptionKey, { type: qrCodeAuthMessageTypes.BACKUP_DATA_KEY_MESSAGE, ...backupKeys, }, ); await tunnelbrokerContext.sendMessage({ deviceID: targetDeviceID, payload: JSON.stringify(backupKeyMessage), }); Alert.alert('Device added', 'Device registered successfully', [ { text: 'OK', onPress: goBack }, ]); }, [ identityContext, broadcastDeviceListUpdates, foreignPeerDevices, getAndUpdateDeviceListsForUsers, tunnelbrokerContext, goBack, ], ); React.useEffect(() => { tunnelbrokerContext.addListener(tunnelbrokerMessageListener); return () => { tunnelbrokerContext.removeListener(tunnelbrokerMessageListener); }; }, [tunnelbrokerMessageListener, tunnelbrokerContext]); React.useEffect(() => { void (async () => { const { status } = await BarCodeScanner.requestPermissionsAsync(); setHasPermission(status === 'granted'); if (status !== 'granted') { Alert.alert( 'No access to camera', 'Please allow Comm to access your camera in order to scan the QR code.', [{ text: 'OK' }], ); goBack(); } })(); }, [goBack]); const processDeviceListUpdate = React.useCallback(async () => { try { const { deviceID: primaryDeviceID, userID } = await identityContext.getAuthMetadata(); if (!primaryDeviceID || !userID) { throw new Error('missing auth metadata'); } const encryptionKey = aes256Key.current; const targetDeviceID = secondaryDeviceID.current; if (!encryptionKey || !targetDeviceID) { throw new Error('missing tunnelbroker message data'); } await addDeviceToDeviceList( identityContext.identityClient, userID, targetDeviceID, ); const message = await composeTunnelbrokerQRAuthMessage(encryptionKey, { type: qrCodeAuthMessageTypes.DEVICE_LIST_UPDATE_SUCCESS, userID, primaryDeviceID, }); await tunnelbrokerContext.sendMessage({ deviceID: targetDeviceID, payload: JSON.stringify(message), }); } catch (err) { console.log('Primary device error:', err); Alert.alert('Adding device failed', 'Failed to update the device list', [ { text: 'OK' }, ]); goBack(); } }, [goBack, identityContext, tunnelbrokerContext]); const onPressSave = React.useCallback(async () => { if (!urlInput) { return; } const parsedData = parseDataFromDeepLink(urlInput); const keysMatch = parsedData?.data?.keys; if (!parsedData || !keysMatch) { Alert.alert( 'Scan failed', 'QR code does not contain a valid pair of keys.', [{ text: 'OK' }], ); return; } try { const keys = JSON.parse(decodeURIComponent(keysMatch)); const { aes256, ed25519 } = keys; aes256Key.current = aes256; secondaryDeviceID.current = ed25519; } catch (err) { console.log('Failed to decode URI component:', err); } await processDeviceListUpdate(); }, [processDeviceListUpdate, urlInput]); const buttonDisabled = !urlInput; React.useEffect(() => { if (!deviceIsEmulator) { return; } setOptions({ headerRight: () => ( ), }); }, [buttonDisabled, onPressSave, setOptions]); const onChangeText = React.useCallback( (text: string) => setURLInput(text), [], ); const onConnect = React.useCallback( async (barCodeEvent: BarCodeEvent) => { const { data } = barCodeEvent; const parsedData = parseDataFromDeepLink(data); const keysMatch = parsedData?.data?.keys; if (!parsedData || !keysMatch) { Alert.alert( 'Scan failed', 'QR code does not contain a valid pair of keys.', [{ text: 'OK' }], ); return; } try { const keys = JSON.parse(decodeURIComponent(keysMatch)); const { aes256, ed25519 } = keys; aes256Key.current = aes256; secondaryDeviceID.current = ed25519; } catch (err) { console.log('Failed to decode URI component:', err); } await processDeviceListUpdate(); }, [processDeviceListUpdate], ); const onCancelScan = React.useCallback(() => setScanned(false), []); const handleBarCodeScanned = React.useCallback( (barCodeEvent: BarCodeEvent) => { setScanned(true); Alert.alert( 'Connect with this device?', 'Are you sure you want to allow this device to log in to your account?', [ { text: 'Cancel', style: 'cancel', onPress: onCancelScan, }, { text: 'Connect', onPress: () => onConnect(barCodeEvent), }, ], { cancelable: false }, ); }, [onCancelScan, onConnect], ); if (hasPermission === null) { return ; } if (deviceIsEmulator) { return ( QR Code URL ); } // Note: According to the BarCodeScanner Expo docs, we should adhere to two // guidances when using the BarCodeScanner: // 1. We should specify the potential barCodeTypes we want to scan for to // minimize battery usage. // 2. We should set the onBarCodeScanned callback to undefined if it scanned // in order to 'pause' the scanner from continuing to scan while we // process the data from the scan. // See: https://docs.expo.io/versions/latest/sdk/bar-code-scanner return ( ); } const unboundStyles = { scannerContainer: { flex: 1, flexDirection: 'column', justifyContent: 'center', }, scanner: { position: 'absolute', top: 0, left: 0, right: 0, bottom: 0, }, textInputContainer: { paddingTop: 8, }, header: { color: 'panelBackgroundLabel', fontSize: 12, fontWeight: '400', paddingBottom: 3, paddingHorizontal: 24, }, inputContainer: { backgroundColor: 'panelForeground', flexDirection: 'row', justifyContent: 'space-between', paddingHorizontal: 24, paddingVertical: 12, borderBottomWidth: 1, borderColor: 'panelForegroundBorder', borderTopWidth: 1, }, input: { color: 'panelForegroundLabel', flex: 1, fontFamily: 'Arial', fontSize: 16, paddingVertical: 0, borderBottomColor: 'transparent', }, }; export default SecondaryDeviceQRCodeScanner; diff --git a/native/profile/tunnelbroker-menu.react.js b/native/profile/tunnelbroker-menu.react.js index 233ea35eb..e7568f413 100644 --- a/native/profile/tunnelbroker-menu.react.js +++ b/native/profile/tunnelbroker-menu.react.js @@ -1,261 +1,261 @@ // @flow import * as React from 'react'; import { useState } from 'react'; import { Text, View } from 'react-native'; import { ScrollView } from 'react-native-gesture-handler'; import { IdentityClientContext } from 'lib/shared/identity-client-context.js'; import { useTunnelbroker } from 'lib/tunnelbroker/tunnelbroker-context.js'; import { - tunnelbrokerMessageTypes, - type TunnelbrokerMessage, + tunnelbrokerToDeviceMessageTypes, + type TunnelbrokerToDeviceMessage, } from 'lib/types/tunnelbroker/messages.js'; import { type EncryptedMessage, peerToPeerMessageTypes, } from 'lib/types/tunnelbroker/peer-to-peer-message-types.js'; import { createOlmSessionsWithOwnDevices, getContentSigningKey, } from 'lib/utils/crypto-utils.js'; import type { ProfileNavigationProp } from './profile.react.js'; import Button from '../components/button.react.js'; import TextInput from '../components/text-input.react.js'; import { olmAPI } from '../crypto/olm-api.js'; import type { NavigationRoute } from '../navigation/route-names.js'; import { useSelector } from '../redux/redux-utils.js'; import { useColors, useStyles } from '../themes/colors.js'; type Props = { +navigation: ProfileNavigationProp<'TunnelbrokerMenu'>, +route: NavigationRoute<'TunnelbrokerMenu'>, }; // eslint-disable-next-line no-unused-vars function TunnelbrokerMenu(props: Props): React.Node { const styles = useStyles(unboundStyles); const colors = useColors(); const currentUserID = useSelector( state => state.currentUserInfo && state.currentUserInfo.id, ); const identityContext = React.useContext(IdentityClientContext); const { socketState, addListener, sendMessage, removeListener } = useTunnelbroker(); - const [messages, setMessages] = useState([]); + const [messages, setMessages] = useState([]); const [recipient, setRecipient] = useState(''); const [message, setMessage] = useState(''); const [deviceID, setDeviceID] = React.useState(); React.useEffect(() => { void (async () => { const contentSigningKey = await getContentSigningKey(); setDeviceID(contentSigningKey); })(); }, []); - const listener = React.useCallback((msg: TunnelbrokerMessage) => { + const listener = React.useCallback((msg: TunnelbrokerToDeviceMessage) => { setMessages(prev => [...prev, msg]); }, []); React.useEffect(() => { addListener(listener); return () => removeListener(listener); }, [addListener, listener, removeListener]); const onSubmit = React.useCallback(async () => { try { await sendMessage({ deviceID: recipient, payload: message }); } catch (e) { console.log(e.message); } }, [message, recipient, sendMessage]); const onCreateSessions = React.useCallback(async () => { if (!identityContext) { return; } const authMetadata = await identityContext.getAuthMetadata(); try { await createOlmSessionsWithOwnDevices( authMetadata, identityContext.identityClient, sendMessage, ); } catch (e) { console.log(`Error creating olm sessions with own devices: ${e.message}`); } }, [identityContext, sendMessage]); const onSendEncryptedMessage = React.useCallback(async () => { try { if (!currentUserID) { return; } await olmAPI.initializeCryptoAccount(); const encryptedData = await olmAPI.encrypt(message, recipient); const signingKey = await getContentSigningKey(); const encryptedMessage: EncryptedMessage = { type: peerToPeerMessageTypes.ENCRYPTED_MESSAGE, senderInfo: { deviceID: signingKey, userID: currentUserID, }, encryptedData, }; await sendMessage({ deviceID: recipient, payload: JSON.stringify(encryptedMessage), }); } catch (e) { console.log(`Error sending encrypted content to device: ${e.message}`); } }, [message, currentUserID, recipient, sendMessage]); return ( INFO Connected {socketState.connected.toString()} DEVICE ID {deviceID} USER ID {currentUserID} SEND MESSAGE Recipient Message MESSAGES {messages - .filter(msg => msg.type !== tunnelbrokerMessageTypes.HEARTBEAT) + .filter(msg => msg.type !== tunnelbrokerToDeviceMessageTypes.HEARTBEAT) .map((msg, id) => ( {JSON.stringify(msg)} ))} ); } const unboundStyles = { scrollViewContentContainer: { paddingTop: 24, }, scrollView: { backgroundColor: 'panelBackground', }, section: { backgroundColor: 'panelForeground', borderBottomWidth: 1, borderColor: 'panelForegroundBorder', borderTopWidth: 1, marginBottom: 24, marginVertical: 2, }, header: { color: 'panelBackgroundLabel', fontSize: 12, fontWeight: '400', paddingBottom: 3, paddingHorizontal: 24, }, submenuButton: { flexDirection: 'row', paddingHorizontal: 24, paddingVertical: 10, alignItems: 'center', }, submenuText: { color: 'panelForegroundLabel', flex: 1, fontSize: 16, }, text: { color: 'panelForegroundLabel', fontSize: 16, }, row: { flexDirection: 'row', justifyContent: 'space-between', paddingHorizontal: 24, paddingVertical: 14, }, textInput: { color: 'modalBackgroundLabel', flex: 1, fontSize: 16, margin: 0, padding: 0, borderBottomColor: 'transparent', }, }; export default TunnelbrokerMenu; diff --git a/services/commtest/src/tunnelbroker/socket.rs b/services/commtest/src/tunnelbroker/socket.rs index fca00f1d1..ae65ab8d3 100644 --- a/services/commtest/src/tunnelbroker/socket.rs +++ b/services/commtest/src/tunnelbroker/socket.rs @@ -1,117 +1,119 @@ use crate::identity::device::DeviceInfo; use crate::service_addr; use futures_util::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; use tokio::net::TcpStream; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tunnelbroker_messages::{ ConnectionInitializationMessage, ConnectionInitializationResponse, ConnectionInitializationStatus, DeviceTypes, Heartbeat, MessageSentStatus, - MessageToDeviceRequest, MessageToDeviceRequestStatus, Messages, + MessageToDeviceRequest, MessageToDeviceRequestStatus, + TunnelbrokerToDeviceMessage, }; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] #[serde(tag = "type", rename_all = "camelCase")] pub struct WebSocketMessageToDevice { #[serde(rename = "deviceID")] pub device_id: String, pub payload: String, } pub async fn create_socket( device_info: &DeviceInfo, ) -> Result< WebSocketStream>, Box, > { let (mut socket, _) = connect_async(service_addr::TUNNELBROKER_WS) .await .expect("Can't connect"); let session_request = ConnectionInitializationMessage { device_id: device_info.device_id.to_string(), access_token: device_info.access_token.to_string(), user_id: device_info.user_id.to_string(), notify_token: None, device_type: DeviceTypes::Keyserver, device_app_version: None, device_os: None, }; let serialized_request = serde_json::to_string(&session_request) .expect("Failed to serialize connection request"); socket .send(Message::Text(serialized_request)) .await .expect("Failed to send message"); if let Some(Ok(response)) = socket.next().await { let response: ConnectionInitializationResponse = serde_json::from_str(response.to_text().unwrap())?; return match response.status { ConnectionInitializationStatus::Success => Ok(socket), ConnectionInitializationStatus::Error(err) => Err(err.into()), }; } Err("Failed to get response from Tunnelbroker".into()) } pub async fn send_message( socket: &mut WebSocketStream>, message: WebSocketMessageToDevice, ) -> Result> { let client_message_id = uuid::Uuid::new_v4().to_string(); let request = MessageToDeviceRequest { client_message_id: client_message_id.clone(), device_id: message.device_id, payload: message.payload, }; let serialized_request = serde_json::to_string(&request)?; socket.send(Message::Text(serialized_request)).await?; if let Some(Ok(response)) = socket.next().await { let confirmation: MessageToDeviceRequestStatus = serde_json::from_str(response.to_text().unwrap())?; if confirmation .client_message_ids .contains(&MessageSentStatus::Success(client_message_id.clone())) { return Ok(client_message_id); } } Err("Failed to confirm sent message".into()) } pub async fn receive_message( socket: &mut WebSocketStream>, ) -> Result> { while let Some(Ok(response)) = socket.next().await { let message_str = response.to_text().expect("Failed to get response content"); - let message = serde_json::from_str::(message_str).unwrap(); + let message = + serde_json::from_str::(message_str).unwrap(); match message { - Messages::MessageToDevice(msg) => { + TunnelbrokerToDeviceMessage::MessageToDevice(msg) => { let confirmation = tunnelbroker_messages::MessageReceiveConfirmation { message_ids: vec![msg.message_id], }; let serialized_confirmation = serde_json::to_string(&confirmation).unwrap(); socket.send(Message::Text(serialized_confirmation)).await?; return Ok(msg.payload); } - Messages::Heartbeat(Heartbeat {}) => { + TunnelbrokerToDeviceMessage::Heartbeat(Heartbeat {}) => { let msg = Heartbeat {}; let serialized = serde_json::to_string(&msg).unwrap(); socket.send(Message::Text(serialized)).await?; } _ => return Err(format!("Unexpected message type {message:?}").into()), } } Err("Failed to receive message".into()) } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs index 0c5be9a99..a293a2dcc 100644 --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,526 +1,535 @@ use crate::constants::{ CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY, MAX_RMQ_MSG_PRIORITY, RMQ_CONSUMER_TAG, }; use comm_lib::aws::ddb::error::SdkError; use comm_lib::aws::ddb::operation::put_item::PutItemError; use derive_more; use futures_util::stream::SplitSink; use futures_util::SinkExt; use futures_util::StreamExt; use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use lapin::message::Delivery; use lapin::options::{ BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions, QueueDeclareOptions, QueueDeleteOptions, }; use lapin::types::FieldTable; use lapin::BasicProperties; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tracing::{debug, error, info, trace}; use tunnelbroker_messages::{ message_to_device_request_status::Failure, message_to_device_request_status::MessageSentStatus, session::DeviceTypes, - Heartbeat, MessageToDevice, MessageToDeviceRequest, MessageToTunnelbroker, - Messages, + DeviceToTunnelbrokerMessage, Heartbeat, MessageToDevice, + MessageToDeviceRequest, MessageToTunnelbroker, }; use crate::database::{self, DatabaseClient, MessageToDeviceExt}; use crate::identity; use crate::notifs::apns::headers::NotificationHeaders; use crate::notifs::apns::APNsNotif; use crate::notifs::fcm::firebase_message::{ AndroidConfig, AndroidMessagePriority, FCMMessage, }; use crate::notifs::NotifClient; pub struct DeviceInfo { pub device_id: String, pub notify_token: Option, pub device_type: DeviceTypes, pub device_app_version: Option, pub device_os: Option, pub is_authenticated: bool, } pub struct WebsocketSession { tx: SplitSink, Message>, db_client: DatabaseClient, pub device_info: DeviceInfo, amqp_channel: lapin::Channel, // Stream of messages from AMQP endpoint amqp_consumer: lapin::Consumer, notif_client: NotifClient, } #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum SessionError { InvalidMessage, SerializationError(serde_json::Error), MessageError(database::MessageErrors), AmqpError(lapin::Error), InternalError, UnauthorizedDevice, PersistenceError(SdkError), DatabaseError(comm_lib::database::Error), MissingAPNsClient, MissingFCMClient, MissingDeviceToken, } // Parse a session request and retrieve the device information pub async fn handle_first_message_from_device( message: &str, ) -> Result { - let serialized_message = serde_json::from_str::(message)?; + let serialized_message = + serde_json::from_str::(message)?; match serialized_message { - Messages::ConnectionInitializationMessage(mut session_info) => { + DeviceToTunnelbrokerMessage::ConnectionInitializationMessage( + mut session_info, + ) => { let device_info = DeviceInfo { device_id: session_info.device_id.clone(), notify_token: session_info.notify_token.take(), device_type: session_info.device_type, device_app_version: session_info.device_app_version.take(), device_os: session_info.device_os.take(), is_authenticated: true, }; // Authenticate device debug!("Authenticating device: {}", &session_info.device_id); let auth_request = identity::verify_user_access_token( &session_info.user_id, &device_info.device_id, &session_info.access_token, ) .await; match auth_request { Err(e) => { error!("Failed to complete request to identity service: {:?}", e); return Err(SessionError::InternalError); } Ok(false) => { info!("Device failed authentication: {}", &session_info.device_id); return Err(SessionError::UnauthorizedDevice); } Ok(true) => { debug!( "Successfully authenticated device: {}", &session_info.device_id ); } } Ok(device_info) } - Messages::AnonymousInitializationMessage(session_info) => { + DeviceToTunnelbrokerMessage::AnonymousInitializationMessage( + session_info, + ) => { debug!( "Starting unauthenticated session with device: {}", &session_info.device_id ); let device_info = DeviceInfo { device_id: session_info.device_id, device_type: session_info.device_type, device_app_version: session_info.device_app_version, device_os: session_info.device_os, is_authenticated: false, notify_token: None, }; Ok(device_info) } _ => { debug!("Received invalid request"); Err(SessionError::InvalidMessage) } } } async fn publish_persisted_messages( db_client: &DatabaseClient, amqp_channel: &lapin::Channel, device_info: &DeviceInfo, ) -> Result<(), SessionError> { let messages = db_client .retrieve_messages(&device_info.device_id) .await .unwrap_or_else(|e| { error!("Error while retrieving messages: {}", e); Vec::new() }); for message in messages { let message_to_device = MessageToDevice::from_hashmap(message)?; let serialized_message = serde_json::to_string(&message_to_device)?; amqp_channel .basic_publish( "", &message_to_device.device_id, BasicPublishOptions::default(), serialized_message.as_bytes(), BasicProperties::default().with_priority(DDB_RMQ_MSG_PRIORITY), ) .await?; } debug!("Flushed messages for device: {}", &device_info.device_id); Ok(()) } pub async fn initialize_amqp( db_client: DatabaseClient, frame: Message, amqp_channel: &lapin::Channel, ) -> Result<(DeviceInfo, lapin::Consumer), SessionError> { let device_info = match frame { Message::Text(payload) => { handle_first_message_from_device(&payload).await? } _ => { error!("Client sent wrong frame type for establishing connection"); return Err(SessionError::InvalidMessage); } }; let mut args = FieldTable::default(); args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into()); amqp_channel .queue_declare(&device_info.device_id, QueueDeclareOptions::default(), args) .await?; publish_persisted_messages(&db_client, amqp_channel, &device_info).await?; let amqp_consumer = amqp_channel .basic_consume( &device_info.device_id, RMQ_CONSUMER_TAG, BasicConsumeOptions::default(), FieldTable::default(), ) .await?; Ok((device_info, amqp_consumer)) } impl WebsocketSession { pub fn new( tx: SplitSink, Message>, db_client: DatabaseClient, device_info: DeviceInfo, amqp_channel: lapin::Channel, amqp_consumer: lapin::Consumer, notif_client: NotifClient, ) -> Self { Self { tx, db_client, device_info, amqp_channel, amqp_consumer, notif_client, } } pub async fn handle_message_to_device( &self, message_request: &MessageToDeviceRequest, ) -> Result<(), SessionError> { let message_id = self .db_client .persist_message( &message_request.device_id, &message_request.payload, &message_request.client_message_id, ) .await?; let message_to_device = MessageToDevice { device_id: message_request.device_id.clone(), payload: message_request.payload.clone(), message_id: message_id.clone(), }; let serialized_message = serde_json::to_string(&message_to_device)?; let publish_result = self .amqp_channel .basic_publish( "", &message_request.device_id, BasicPublishOptions::default(), serialized_message.as_bytes(), BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY), ) .await; if let Err(publish_error) = publish_result { self .db_client .delete_message(&self.device_info.device_id, &message_id) .await .expect("Error deleting message"); return Err(SessionError::AmqpError(publish_error)); } Ok(()) } pub async fn handle_message_to_tunnelbroker( &self, message_to_tunnelbroker: &MessageToTunnelbroker, ) -> Result<(), SessionError> { match message_to_tunnelbroker { MessageToTunnelbroker::SetDeviceToken(token) => { self .db_client .set_device_token(&self.device_info.device_id, &token.device_token) .await?; } } Ok(()) } pub async fn handle_websocket_frame_from_device( &mut self, msg: String, ) -> Option { - let Ok(serialized_message) = serde_json::from_str::(&msg) else { + let Ok(serialized_message) = + serde_json::from_str::(&msg) + else { return Some(MessageSentStatus::SerializationError(msg)); }; match serialized_message { - Messages::Heartbeat(Heartbeat {}) => { + DeviceToTunnelbrokerMessage::Heartbeat(Heartbeat {}) => { trace!("Received heartbeat from: {}", self.device_info.device_id); None } - Messages::MessageReceiveConfirmation(confirmation) => { + DeviceToTunnelbrokerMessage::MessageReceiveConfirmation(confirmation) => { for message_id in confirmation.message_ids { if let Err(e) = self .db_client .delete_message(&self.device_info.device_id, &message_id) .await { error!("Failed to delete message: {}:", e); } } None } - Messages::MessageToDeviceRequest(message_request) => { + DeviceToTunnelbrokerMessage::MessageToDeviceRequest(message_request) => { // unauthenticated clients cannot send messages if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text message. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received message for {}", message_request.device_id); let result = self.handle_message_to_device(&message_request).await; Some(self.get_message_to_device_status( &message_request.client_message_id, result, )) } - Messages::MessageToTunnelbrokerRequest(message_request) => { + DeviceToTunnelbrokerMessage::MessageToTunnelbrokerRequest( + message_request, + ) => { // unauthenticated clients cannot send messages if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text message. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received message for Tunnelbroker"); let Ok(message_to_tunnelbroker) = serde_json::from_str(&message_request.payload) else { return Some(MessageSentStatus::SerializationError( message_request.payload, )); }; let result = self .handle_message_to_tunnelbroker(&message_to_tunnelbroker) .await; Some(self.get_message_to_device_status( &message_request.client_message_id, result, )) } - Messages::APNsNotif(notif) => { + DeviceToTunnelbrokerMessage::APNsNotif(notif) => { // unauthenticated clients cannot send notifs if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text notif. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received APNs notif for {}", notif.device_id); let Ok(headers) = serde_json::from_str::(¬if.headers) else { return Some(MessageSentStatus::SerializationError(notif.headers)); }; let device_token = match self.get_device_token(notif.device_id).await { Ok(token) => token, Err(e) => { return Some( self .get_message_to_device_status(¬if.client_message_id, Err(e)), ) } }; let apns_notif = APNsNotif { device_token, headers, payload: notif.payload, }; if let Some(apns) = self.notif_client.apns.clone() { let response = apns.send(apns_notif).await; return Some( self .get_message_to_device_status(¬if.client_message_id, response), ); } Some(self.get_message_to_device_status( ¬if.client_message_id, Err(SessionError::MissingAPNsClient), )) } - Messages::FCMNotif(notif) => { + DeviceToTunnelbrokerMessage::FCMNotif(notif) => { // unauthenticated clients cannot send notifs if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text notif. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received FCM notif for {}", notif.device_id); let Some(priority) = AndroidMessagePriority::from_str(¬if.priority) else { return Some(MessageSentStatus::SerializationError(notif.priority)); }; let Ok(data) = serde_json::from_str(¬if.data) else { return Some(MessageSentStatus::SerializationError(notif.data)); }; let device_token = match self.get_device_token(notif.device_id).await { Ok(token) => token, Err(e) => { return Some( self .get_message_to_device_status(¬if.client_message_id, Err(e)), ) } }; let fcm_message = FCMMessage { data, token: device_token.to_string(), android: AndroidConfig { priority }, }; if let Some(fcm) = self.notif_client.fcm.clone() { let response = fcm.send(fcm_message).await; return Some( self .get_message_to_device_status(¬if.client_message_id, response), ); } Some(self.get_message_to_device_status( ¬if.client_message_id, Err(SessionError::MissingFCMClient), )) } _ => { error!("Client sent invalid message type"); Some(MessageSentStatus::InvalidRequest) } } } pub async fn next_amqp_message( &mut self, ) -> Option> { self.amqp_consumer.next().await } pub async fn send_message_to_device(&mut self, message: Message) { if let Err(e) = self.tx.send(message).await { error!("Failed to send message to device: {}", e); } } // Release WebSocket and remove from active connections pub async fn close(&mut self) { if let Err(e) = self.tx.close().await { debug!("Failed to close WebSocket session: {}", e); } if let Err(e) = self .amqp_channel .basic_cancel( self.amqp_consumer.tag().as_str(), BasicCancelOptions::default(), ) .await { error!("Failed to cancel consumer: {}", e); } if let Err(e) = self .amqp_channel .queue_delete( self.device_info.device_id.as_str(), QueueDeleteOptions::default(), ) .await { error!("Failed to delete queue: {}", e); } } pub fn get_message_to_device_status( &mut self, client_message_id: &str, result: Result<(), E>, ) -> MessageSentStatus where E: std::error::Error, { match result { Ok(()) => MessageSentStatus::Success(client_message_id.to_string()), Err(err) => MessageSentStatus::Error(Failure { id: client_message_id.to_string(), error: err.to_string(), }), } } async fn get_device_token( &self, device_id: String, ) -> Result { let db_token = self .db_client .get_device_token(&device_id) .await .map_err(SessionError::DatabaseError)?; db_token.ok_or_else(|| SessionError::MissingDeviceToken) } } diff --git a/shared/tunnelbroker_messages/src/messages/mod.rs b/shared/tunnelbroker_messages/src/messages/mod.rs index 5ae3b693d..095e2ddf9 100644 --- a/shared/tunnelbroker_messages/src/messages/mod.rs +++ b/shared/tunnelbroker_messages/src/messages/mod.rs @@ -1,69 +1,77 @@ //! Messages sent between Tunnelbroker and a device. pub mod device_list_updated; pub mod keys; pub mod message_receive_confirmation; pub mod message_to_device; pub mod message_to_device_request; pub mod message_to_device_request_status; pub mod message_to_tunnelbroker; pub mod message_to_tunnelbroker_request; pub mod notif; pub mod session; pub use device_list_updated::*; pub use keys::*; pub use message_receive_confirmation::*; pub use message_to_device::*; pub use message_to_device_request::*; pub use message_to_device_request_status::*; pub use message_to_tunnelbroker::*; pub use message_to_tunnelbroker_request::*; pub use session::*; pub use websocket_messages::{ ConnectionInitializationResponse, ConnectionInitializationStatus, Heartbeat, }; use crate::notif::*; use serde::{Deserialize, Serialize}; // This file defines types and validation for messages exchanged // with the Tunnelbroker. The definitions in this file should remain in sync // with the structures defined in the corresponding // JavaScript file at `lib/types/tunnelbroker/messages.js`. // If you edit the definitions in one file, // please make sure to update the corresponding definitions in the other. +// Messages sent from Device to Tunnelbroker. #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] -pub enum Messages { +pub enum DeviceToTunnelbrokerMessage { ConnectionInitializationMessage(ConnectionInitializationMessage), - ConnectionInitializationResponse(ConnectionInitializationResponse), AnonymousInitializationMessage(AnonymousInitializationMessage), - // MessageToDeviceRequestStatus must be placed before MessageToDeviceRequest. - // This is due to serde's pattern matching behavior where it prioritizes - // the first matching pattern it encounters. APNsNotif(APNsNotif), FCMNotif(FCMNotif), - MessageToDeviceRequestStatus(MessageToDeviceRequestStatus), MessageToDeviceRequest(MessageToDeviceRequest), - MessageToDevice(MessageToDevice), MessageReceiveConfirmation(MessageReceiveConfirmation), MessageToTunnelbrokerRequest(MessageToTunnelbrokerRequest), Heartbeat(Heartbeat), - IdentityDeviceListUpdated(IdentityDeviceListUpdated), } +// Messages sent from Tunnelbroker to Device. +#[derive(Serialize, Deserialize, Debug)] +#[serde(untagged)] +pub enum TunnelbrokerToDeviceMessage { + ConnectionInitializationResponse(ConnectionInitializationResponse), + MessageToDeviceRequestStatus(MessageToDeviceRequestStatus), + MessageToDevice(MessageToDevice), + Heartbeat(Heartbeat), +} + +// Messages sent from Services (e.g. Identity) to Device. +// This type is sent to a Device as a payload of MessageToDevice. #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] -pub enum PeerToPeerMessages { +pub enum ServiceToDeviceMessages { RefreshKeysRequest(RefreshKeyRequest), IdentityDeviceListUpdated(IdentityDeviceListUpdated), } +// Messages sent from Device to Tunnelbroker which Tunnelbroker itself should handle. +// This type is sent to a Tunnelbroker as a payload of MessageToTunnelbrokerRequest. #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum MessageToTunnelbroker { SetDeviceToken(SetDeviceToken), } diff --git a/web/settings/tunnelbroker-message-list.react.js b/web/settings/tunnelbroker-message-list.react.js index 1610e8ea4..8cd5bd7a3 100644 --- a/web/settings/tunnelbroker-message-list.react.js +++ b/web/settings/tunnelbroker-message-list.react.js @@ -1,55 +1,59 @@ // @flow import * as React from 'react'; import type { TunnelbrokerSocketListener } from 'lib/tunnelbroker/tunnelbroker-context.js'; import { - tunnelbrokerMessageTypes, - type TunnelbrokerMessage, + type TunnelbrokerToDeviceMessage, + tunnelbrokerToDeviceMessageTypes, } from 'lib/types/tunnelbroker/messages.js'; import css from './tunnelbroker-message-list.css'; import Modal from '../modals/modal.react.js'; type Props = { +addListener: (listener: TunnelbrokerSocketListener) => void, +removeListener: (listener: TunnelbrokerSocketListener) => void, +onClose: () => void, }; function TunnelbrokerMessagesScreen(props: Props): React.Node { const { addListener, onClose, removeListener } = props; - const [messages, setMessages] = React.useState([]); + const [messages, setMessages] = React.useState( + [], + ); - const listener = React.useCallback((msg: TunnelbrokerMessage) => { + const listener = React.useCallback((msg: TunnelbrokerToDeviceMessage) => { setMessages(prev => [...prev, msg]); }, []); React.useEffect(() => { addListener(listener); return () => removeListener(listener); }, [addListener, listener, removeListener]); let messageList: React.Node = (
No messages
); if (messages.length) { messageList = messages - .filter(message => message.type !== tunnelbrokerMessageTypes.HEARTBEAT) + .filter( + message => message.type !== tunnelbrokerToDeviceMessageTypes.HEARTBEAT, + ) .map((message, id) => (
{JSON.stringify(message)}
)); } return (
{messageList}
); } export default TunnelbrokerMessagesScreen;