diff --git a/keyserver/src/socket/socket.js b/keyserver/src/socket/socket.js --- a/keyserver/src/socket/socket.js +++ b/keyserver/src/socket/socket.js @@ -89,6 +89,7 @@ checkClientSupported, policiesValidator, validateOutput, + validateInput, } from '../utils/validation-utils.js'; const clientSocketMessageInputValidator: TUnion = t.union([ @@ -193,22 +194,26 @@ messageString: string | Buffer | ArrayBuffer | Array, ): Promise => { invariant(typeof messageString === 'string', 'message should be string'); - let clientSocketMessage: ?ClientSocketMessage; + let responseTo = null; try { this.resetTimeout(); const messageObject = JSON.parse(messageString); - clientSocketMessage = checkInputValidator( + const clientSocketMessageWithClientIDs = checkInputValidator( clientSocketMessageInputValidator, messageObject, ); - if (clientSocketMessage.type === clientSocketMessageTypes.INITIAL) { + responseTo = clientSocketMessageWithClientIDs.id; + if ( + clientSocketMessageWithClientIDs.type === + clientSocketMessageTypes.INITIAL + ) { if (this.viewer) { // This indicates that the user sent multiple INITIAL messages. throw new ServerError('socket_already_initialized'); } this.viewer = await fetchViewerForSocket( this.httpRequest, - clientSocketMessage, + clientSocketMessageWithClientIDs, ); } const { viewer } = this; @@ -229,9 +234,14 @@ await checkClientSupported( viewer, clientSocketMessageInputValidator, - clientSocketMessage, + clientSocketMessageWithClientIDs, ); await policiesValidator(viewer, baseLegalPolicies); + const clientSocketMessage = await validateInput( + viewer, + clientSocketMessageInputValidator, + clientSocketMessageWithClientIDs, + ); const serverResponses = await this.handleClientSocketMessage(clientSocketMessage); @@ -267,7 +277,6 @@ type: serverSocketMessageTypes.ERROR, message: error.message, }; - const responseTo = clientSocketMessage ? 
clientSocketMessage.id : null; if (responseTo !== null) { errorMessage.responseTo = responseTo; } @@ -275,8 +284,7 @@ await this.sendMessage(errorMessage); return; } - invariant(clientSocketMessage, 'should be set'); - const responseTo = clientSocketMessage.id; + invariant(responseTo !== null, 'should be set'); /* compare to null like the error branch above: a falsy id such as 0 is still a valid responseTo */ if (error.message === 'socket_deauthorized') { invariant(this.viewer, 'should be set'); const authErrorMessage: AuthErrorServerSocketMessage = {