diff --git a/services/identity/src/websockets/mod.rs b/services/identity/src/websockets/mod.rs index 6de172ac1..b470a77b5 100644 --- a/services/identity/src/websockets/mod.rs +++ b/services/identity/src/websockets/mod.rs @@ -1,316 +1,330 @@ use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; -use elastic::client::responses::SearchResponse; +use elastic::client::responses::SearchResponse as ElasticSearchResponse; use futures::lock::Mutex; use futures_util::{SinkExt, StreamExt}; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use identity_search_messages::{ ConnectionInitializationResponse, ConnectionInitializationStatus, Heartbeat, - IdentitySearchMethod, IdentitySearchResult, IdentitySearchUser, Messages, + IdentitySearchFailure, IdentitySearchMethod, IdentitySearchResponse, + IdentitySearchResult, IdentitySearchUser, Messages, }; use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use tracing::{debug, error, info}; mod auth; mod send; use crate::config::CONFIG; use crate::constants::{ IDENTITY_SERVICE_WEBSOCKET_ADDR, SOCKET_HEARTBEAT_TIMEOUT, }; use send::{send_error_response, send_message, WebsocketSink}; pub mod errors; #[derive(Serialize, Deserialize)] struct Query { query: Prefix, } #[derive(Serialize, Deserialize)] struct Prefix { prefix: Username, } #[derive(Serialize, Deserialize)] struct Username { username: String, } struct WebsocketService { addr: SocketAddr, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = errors::BoxedError; type Future = Pin> + Send>>; fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let future = async move { tracing::debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; tokio::spawn(async move { accept_connection(websocket, addr).await; }); return Ok(response); } debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } pub async fn run_server() -> Result<(), errors::BoxedError> { let addr: SocketAddr = IDENTITY_SERVICE_WEBSOCKET_ADDR.parse()?; let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("Listening to WebSocket traffic on {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { let connection = http .serve_connection(stream, WebsocketService { addr }) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } async fn send_search_request( url: &str, json_body: T, ) -> Result { let client = reqwest::Client::new(); client .post(url) .header(reqwest::header::CONTENT_TYPE, "application/json") .json(&json_body) .send() .await } async fn close_connection(outgoing: WebsocketSink) { if let Err(e) = outgoing.lock().await.close().await { error!("Error closing connection: {}", e); } } async fn handle_prefix_search( - prefix: &str, -) -> Result { + request_id: &str, + prefix_request: identity_search_messages::IdentitySearchPrefix, +) -> Result { let prefix_query = Query { query: Prefix { prefix: Username { - username: prefix.trim().to_string(), + username: prefix_request.prefix.trim().to_string(), }, }, }; let opensearch_url = format!("https://{}/users/_search/", &CONFIG.opensearch_endpoint); let search_response = send_search_request(&opensearch_url, prefix_query) .await? - .json::>() + .json::>() .await?; let usernames: Vec = search_response.into_documents().collect(); - let search_result = - serde_json::to_string(&IdentitySearchResult { hits: usernames })?; + let search_result = IdentitySearchResult { + id: request_id.to_string(), + hits: usernames, + }; Ok(search_result) } async fn handle_websocket_frame( text: String, outgoing: WebsocketSink, ) -> Result<(), errors::WebsocketError> { let Ok(serialized_message) = serde_json::from_str::(&text) else { return Err(errors::WebsocketError::SerializationError); }; match serialized_message { Messages::Heartbeat(Heartbeat {}) => { debug!("Received heartbeat"); Ok(()) } - Messages::IdentitySearchQuery(search_request) => { - let search_result = match search_request.search_method { - IdentitySearchMethod::IdentitySearchPrefix(prefix_request) => { - handle_prefix_search(&prefix_request.prefix).await + Messages::IdentitySearchQuery(search_query) => { + let handler_result = match search_query.search_method { + IdentitySearchMethod::IdentitySearchPrefix(prefix_query) => { + handle_prefix_search(&search_query.id, prefix_query).await } - }?; + }; + + let search_response = match handler_result { + Ok(search_result) => IdentitySearchResponse::Success(search_result), + Err(e) => IdentitySearchResponse::Error(IdentitySearchFailure { + id: search_query.id, + error: e.to_string(), + }), + }; + + let serialized_message = serde_json::to_string(&search_response)?; - send_message(Message::Text(search_result), outgoing.clone()).await; + send_message(Message::Text(serialized_message), outgoing.clone()).await; Ok(()) } _ => Err(errors::WebsocketError::InvalidMessage), } } async fn accept_connection(hyper_ws: HyperWebsocket, addr: SocketAddr) { debug!("Incoming WebSocket connection from {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { error!("WebSocket handshake error: {}", e); return; } }; let (outgoing, mut incoming) = ws_stream.split(); let outgoing = Arc::new(Mutex::new(outgoing)); if let Some(Ok(auth_message)) = incoming.next().await { match auth_message { Message::Text(text) => { if let Err(auth_error) = auth::handle_auth_message(&text).await { let error_response = ConnectionInitializationResponse { status: ConnectionInitializationStatus::Error( auth_error.to_string(), ), }; let serialized_response = serde_json::to_string(&error_response) .expect("Error serializing auth error response"); send_message(Message::Text(serialized_response), outgoing.clone()) .await; close_connection(outgoing).await; return; } else { let success_response = ConnectionInitializationResponse { status: ConnectionInitializationStatus::Success, }; let serialized_response = serde_json::to_string(&success_response) .expect("Error serializing auth success response"); send_message(Message::Text(serialized_response), outgoing.clone()) .await; } } _ => { error!("Invalid authentication message from {}", addr); close_connection(outgoing).await; return; } } } else { error!("No authentication message from {}", addr); close_connection(outgoing).await; return; } let mut ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); let mut got_heartbeat_response = true; loop { tokio::select! { client_message = incoming.next() => { let message: Message = match client_message { Some(Ok(msg)) => msg, _ => { debug!("Connection to {} closed remotely.", addr); break; } }; match message { Message::Close(_) => { debug!("Connection to {} closed.", addr); break; } Message::Pong(_) => { debug!("Received Pong message from {}", addr); } Message::Ping(msg) => { debug!("Received Ping message from {}", addr); send_message(Message::Pong(msg), outgoing.clone()).await; } Message::Text(text) => { got_heartbeat_response = true; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); if let Err(e) = handle_websocket_frame(text, outgoing.clone()).await { send_error_response(e, outgoing.clone()).await; continue; }; } _ => { error!("Client sent invalid message type"); break; } } } _ = &mut ping_timeout => { if !got_heartbeat_response { error!("Connection to {} died.", addr); break; } let serialized = serde_json::to_string(&Heartbeat {}).unwrap(); send_message(Message::text(serialized), outgoing.clone()).await; got_heartbeat_response = false; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); } else => { debug!("Unhealthy connection for: {}", addr); break; } } } info!("unregistering connection to: {}", addr); close_connection(outgoing).await; } diff --git a/shared/identity_search_messages/src/messages/mod.rs b/shared/identity_search_messages/src/messages/mod.rs index 429fc4e20..bad1d091e 100644 --- a/shared/identity_search_messages/src/messages/mod.rs +++ b/shared/identity_search_messages/src/messages/mod.rs @@ -1,24 +1,24 @@ //! Messages sent from client to Identity Search Server pub mod auth_messages; pub mod search_query; -pub mod search_result; +pub mod search_response; pub use auth_messages::*; pub use search_query::*; -pub use search_result::*; +pub use search_response::*; use serde::{Deserialize, Serialize}; pub use websocket_messages::{ ConnectionInitializationResponse, ConnectionInitializationStatus, Heartbeat, }; #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum Messages { IdentitySearchAuthMessage(IdentitySearchAuthMessage), IdentitySearchQuery(IdentitySearchQuery), Heartbeat(Heartbeat), ConnectionInitializationResponse(ConnectionInitializationResponse), - IdentitySearchResult(IdentitySearchResult), + IdentitySearchResponse(IdentitySearchResponse), } diff --git a/shared/identity_search_messages/src/messages/search_query.rs b/shared/identity_search_messages/src/messages/search_query.rs index 75cc95bb3..26f7a3b69 100644 --- a/shared/identity_search_messages/src/messages/search_query.rs +++ b/shared/identity_search_messages/src/messages/search_query.rs @@ -1,20 +1,21 @@ //! Search Request Messages sent by Client to Identity Search via WebSocket. use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] pub struct IdentitySearchPrefix { pub prefix: String, } #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] pub enum IdentitySearchMethod { IdentitySearchPrefix(IdentitySearchPrefix), } #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "camelCase")] pub struct IdentitySearchQuery { + pub id: String, pub search_method: IdentitySearchMethod, } diff --git a/shared/identity_search_messages/src/messages/search_result.rs b/shared/identity_search_messages/src/messages/search_response.rs similarity index 53% rename from shared/identity_search_messages/src/messages/search_result.rs rename to shared/identity_search_messages/src/messages/search_response.rs index 9e208bc24..66e57bcfc 100644 --- a/shared/identity_search_messages/src/messages/search_result.rs +++ b/shared/identity_search_messages/src/messages/search_response.rs @@ -1,16 +1,29 @@ //! Search Result Messages sent by Identity Search via WebSocket. use serde::{Deserialize, Serialize}; +#[derive(Debug, Serialize, Deserialize)] +pub struct IdentitySearchFailure { + pub id: String, + pub error: String, +} + #[derive(Debug, Serialize, Deserialize)] pub struct IdentitySearchUser { #[serde(rename = "userID")] pub user_id: String, pub username: String, } #[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "type")] pub struct IdentitySearchResult { + pub id: String, pub hits: Vec, } + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", content = "data")] +pub enum IdentitySearchResponse { + Success(IdentitySearchResult), + Error(IdentitySearchFailure), +}