diff --git a/keyserver/src/socket/socket.js b/keyserver/src/socket/socket.js
--- a/keyserver/src/socket/socket.js
+++ b/keyserver/src/socket/socket.js
@@ -89,6 +89,7 @@
   checkClientSupported,
   policiesValidator,
   validateOutput,
+  validateInput,
 } from '../utils/validation-utils.js';
 
 const clientSocketMessageInputValidator: TUnion<ClientSocketMessage> = t.union([
@@ -193,22 +194,26 @@
     messageString: string | Buffer | ArrayBuffer | Array<Buffer>,
   ): Promise<void> => {
     invariant(typeof messageString === 'string', 'message should be string');
-    let clientSocketMessage: ?ClientSocketMessage;
+    let responseTo = null;
     try {
       this.resetTimeout();
       const messageObject = JSON.parse(messageString);
-      clientSocketMessage = checkInputValidator(
+      const clientSocketMessageWithClientIDs = checkInputValidator(
         clientSocketMessageInputValidator,
         messageObject,
       );
-      if (clientSocketMessage.type === clientSocketMessageTypes.INITIAL) {
+      responseTo = clientSocketMessageWithClientIDs.id;
+      if (
+        clientSocketMessageWithClientIDs.type ===
+        clientSocketMessageTypes.INITIAL
+      ) {
         if (this.viewer) {
           // This indicates that the user sent multiple INITIAL messages.
           throw new ServerError('socket_already_initialized');
         }
         this.viewer = await fetchViewerForSocket(
           this.httpRequest,
-          clientSocketMessage,
+          clientSocketMessageWithClientIDs,
         );
       }
       const { viewer } = this;
@@ -229,9 +234,14 @@
       await checkClientSupported(
         viewer,
         clientSocketMessageInputValidator,
-        clientSocketMessage,
+        clientSocketMessageWithClientIDs,
       );
       await policiesValidator(viewer, baseLegalPolicies);
+      const clientSocketMessage = await validateInput(
+        viewer,
+        clientSocketMessageInputValidator,
+        clientSocketMessageWithClientIDs,
+      );
 
       const serverResponses =
         await this.handleClientSocketMessage(clientSocketMessage);
@@ -267,7 +277,6 @@
           type: serverSocketMessageTypes.ERROR,
           message: error.message,
         };
-        const responseTo = clientSocketMessage ? clientSocketMessage.id : null;
         if (responseTo !== null) {
           errorMessage.responseTo = responseTo;
         }
@@ -275,8 +284,7 @@
         await this.sendMessage(errorMessage);
         return;
       }
-      invariant(clientSocketMessage, 'should be set');
-      const responseTo = clientSocketMessage.id;
+      invariant(responseTo, 'should be set');
       if (error.message === 'socket_deauthorized') {
         invariant(this.viewer, 'should be set');
         const authErrorMessage: AuthErrorServerSocketMessage = {