diff --git a/services/identity/src/constants.rs b/services/identity/src/constants.rs
--- a/services/identity/src/constants.rs
+++ b/services/identity/src/constants.rs
@@ -1,3 +1,5 @@
+use tokio::time::Duration;
+
 // Secrets
 
 pub const SECRETS_DIRECTORY: &str = "secrets";
@@ -133,6 +135,8 @@
 pub const IDENTITY_SERVICE_SOCKET_ADDR: &str = "[::]:50054";
 pub const IDENTITY_SERVICE_WEBSOCKET_ADDR: &str = "[::]:51004";
 
+pub const SOCKET_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(3);
+
 // Token
 
 pub const ACCESS_TOKEN_LENGTH: usize = 512;
diff --git a/services/identity/src/websockets/mod.rs b/services/identity/src/websockets/mod.rs
--- a/services/identity/src/websockets/mod.rs
+++ b/services/identity/src/websockets/mod.rs
@@ -10,8 +10,8 @@
 use hyper_tungstenite::tungstenite::Message;
 use hyper_tungstenite::HyperWebsocket;
 use identity_search_messages::{
-  ConnectionInitializationResponse, ConnectionInitializationStatus,
-  SearchQuery, SearchResult, User,
+  ConnectionInitializationResponse, ConnectionInitializationStatus, Heartbeat,
+  Messages, SearchQuery, SearchResult, User,
 };
 use serde::{Deserialize, Serialize};
 use tokio::net::TcpListener;
@@ -21,7 +21,9 @@
 mod send;
 
 use crate::config::CONFIG;
-use crate::constants::IDENTITY_SERVICE_WEBSOCKET_ADDR;
+use crate::constants::{
+  IDENTITY_SERVICE_WEBSOCKET_ADDR, SOCKET_HEARTBEAT_TIMEOUT,
+};
 use send::{send_error_response, send_message, WebsocketSink};
 
 pub mod errors;
@@ -165,6 +167,51 @@
   Ok(search_result)
 }
 
+/// Decodes one text frame received from a client and acts on it.
+///
+/// `text` is parsed as a JSON-encoded `Messages` value:
+/// - a heartbeat is only logged;
+/// - a search query is executed and its result is sent to `outgoing` as a
+///   text frame;
+/// - any other message kind is rejected.
+///
+/// # Errors
+///
+/// Returns `SerializationError` when `text` cannot be parsed as `Messages`,
+/// `InvalidMessage` for an unsupported message kind, and propagates any
+/// error returned by the search itself.
+async fn handle_websocket_frame(
+  text: String,
+  outgoing: WebsocketSink,
+) -> Result<(), errors::WebsocketError> {
+  let Ok(serialized_message) = serde_json::from_str::<Messages>(&text) else {
+    return Err(errors::WebsocketError::SerializationError);
+  };
+
+  match serialized_message {
+    Messages::Heartbeat(Heartbeat {}) => {
+      debug!("Received heartbeat");
+      Ok(())
+    }
+    Messages::SearchQuery(search_request) => {
+      let search_result = match search_request {
+        SearchQuery::Prefix(prefix_request) => {
+          handle_prefix_search(&prefix_request.prefix).await
+        }
+      }?;
+
+      send_message(
+        Message::Text(search_result.to_string()),
+        outgoing.clone(),
+      )
+      .await;
+
+      Ok(())
+    }
+    _ => Err(errors::WebsocketError::InvalidMessage),
+  }
+}
+
 async fn accept_connection(hyper_ws: HyperWebsocket, addr: SocketAddr) {
   debug!("Incoming WebSocket connection from {}", addr);
 
@@ -220,61 +267,70 @@
     return;
   }
 
-  while let Some(message) = incoming.next().await {
-    match message {
-      Ok(Message::Close(_)) => {
-        debug!("Connection to {} closed.", addr);
-        break;
-      }
-      Ok(Message::Pong(_)) => {
-        debug!("Received Pong message from {}", addr);
-      }
-      Ok(Message::Ping(msg)) => {
-        debug!("Received Ping message from {}", addr);
-        send_message(Message::Pong(msg), outgoing.clone()).await;
-      }
-      Ok(Message::Text(text)) => {
-        let Ok(search_request) = serde_json::from_str(&text) else {
-          send_error_response(
-            errors::WebsocketError::InvalidSearchQuery,
-            outgoing.clone(),
-          )
-          .await;
-          continue;
-        };
-
-        let search_result = match search_request {
-          SearchQuery::Prefix(prefix_request) => {
-            handle_prefix_search(&prefix_request.prefix).await
+  // Heartbeat bookkeeping: `ping_timeout` is re-armed whenever a text frame
+  // arrives or the timer itself fires; `got_heartbeat_response` records
+  // whether a text frame arrived since the last heartbeat we sent.
+  let mut ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT));
+  let mut got_heartbeat_response = true;
+
+  loop {
+    tokio::select! {
+      client_message = incoming.next() => {
+        let message: Message = match client_message {
+          Some(Ok(msg)) => msg,
+          _ => {
+            debug!("Connection to {} closed remotely.", addr);
+            break;
           }
         };
 
-        let response_msg = match search_result {
-          Ok(response_msg) => response_msg,
-          Err(e) => {
-            send_error_response(e, outgoing.clone()).await;
-            continue;
+        match message {
+          Message::Close(_) => {
+            debug!("Connection to {} closed.", addr);
+            break;
           }
-        };
+          Message::Pong(_) => {
+            debug!("Received Pong message from {}", addr);
+          }
+          Message::Ping(msg) => {
+            debug!("Received Ping message from {}", addr);
+            send_message(Message::Pong(msg), outgoing.clone()).await;
+          }
+          Message::Text(text) => {
+            // Any text frame proves the client is alive; re-arm the timer.
+            got_heartbeat_response = true;
+            ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT));
+
+            if let Err(e) = handle_websocket_frame(text, outgoing.clone()).await {
+              send_error_response(e, outgoing.clone()).await;
+            }
+          }
+          _ => {
+            error!("Client sent invalid message type");
+            break;
+          }
+        }
+      }
+      _ = &mut ping_timeout => {
+        // No text frame since the previous heartbeat: treat the peer as dead.
+        if !got_heartbeat_response {
+          error!("Connection to {} died.", addr);
+          break;
+        }
+        // Probe the client and wait one more timeout for a text frame.
+        let serialized = serde_json::to_string(&Heartbeat {}).unwrap();
+        send_message(Message::text(serialized), outgoing.clone()).await;
 
-        send_message(
-          Message::Text(format!("{}", response_msg.to_string())),
-          outgoing.clone(),
-        )
-        .await;
+        got_heartbeat_response = false;
+        ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT));
       }
-      Err(e) => {
-        error!("Error in WebSocket message: {}", e);
-        send_error_response(
-          errors::WebsocketError::InvalidMessage,
-          outgoing.clone(),
-        )
-        .await;
-        continue;
+      else => {
+        debug!("Unhealthy connection for: {}", addr);
+        break;
       }
-      _ => {}
     }
   }
 
+
   info!("unregistering connection to: {}", addr);
   close_connection(outgoing).await;
 }