diff --git a/services/identity/src/websockets/errors.rs b/services/identity/src/websockets/errors.rs
--- a/services/identity/src/websockets/errors.rs
+++ b/services/identity/src/websockets/errors.rs
@@ -5,6 +5,7 @@
 )]
 pub enum WebsocketError {
   InvalidMessage,
+  InvalidSearchQuery,
   UnauthorizedDevice,
   SendError,
   SearchError,
diff --git a/services/identity/src/websockets/mod.rs b/services/identity/src/websockets/mod.rs
--- a/services/identity/src/websockets/mod.rs
+++ b/services/identity/src/websockets/mod.rs
@@ -11,7 +11,7 @@
 use hyper_tungstenite::HyperWebsocket;
 use identity_search_messages::{
   ConnectionInitializationResponse, ConnectionInitializationStatus,
-  SearchResult, User,
+  SearchQuery, SearchResult, User,
 };
 use serde::{Deserialize, Serialize};
 use tokio::net::TcpListener;
@@ -141,6 +141,56 @@
   }
 }
 
+/// Runs a username prefix search against the OpenSearch `users` index.
+///
+/// Returns the JSON-encoded `SearchResult` to send back to the client, or
+/// the `WebsocketError` to report to it. Each failure is logged here so the
+/// caller only has to forward the error.
+async fn handle_prefix_search(
+  prefix: &str,
+) -> Result<String, errors::WebsocketError> {
+  let prefix_query = Query {
+    query: Prefix {
+      prefix: Username {
+        username: prefix.trim().to_string(),
+      },
+    },
+  };
+
+  let json_body = serde_json::to_string(&prefix_query).map_err(|e| {
+    error!("Error serializing prefix query: {}", e);
+    errors::WebsocketError::SerializationError
+  })?;
+
+  let opensearch_url =
+    format!("https://{}/users/_search/", &CONFIG.opensearch_endpoint);
+
+  let response = send_search_request(&opensearch_url, json_body)
+    .await
+    .map_err(|e| {
+      error!("Error getting search response: {}", e);
+      errors::WebsocketError::SearchError
+    })?;
+
+  let response_text = response.text().await.map_err(|e| {
+    error!("Error getting response text: {}", e);
+    errors::WebsocketError::SearchError
+  })?;
+
+  let search_response: SearchResponse<User> =
+    serde_json::from_str(&response_text).map_err(|e| {
+      error!("Error deserializing search response: {}", e);
+      errors::WebsocketError::SerializationError
+    })?;
+
+  let usernames: Vec<User> = search_response.into_documents().collect();
+
+  serde_json::to_string(&SearchResult { payload: usernames }).map_err(|e| {
+    error!("Error serializing search result: {}", e);
+    errors::WebsocketError::SerializationError
+  })
+}
+
 async fn accept_connection(hyper_ws: HyperWebsocket, addr: SocketAddr) {
   debug!("Incoming WebSocket connection from {}", addr);
 
@@ -156,9 +206,6 @@
 
   let outgoing = Arc::new(Mutex::new(outgoing));
 
-  let opensearch_url =
-    format!("https://{}/users/_search/", &CONFIG.opensearch_endpoint);
-
   if let Some(Ok(auth_message)) = incoming.next().await {
     match auth_message {
       Message::Text(text) => {
@@ -215,20 +262,11 @@
         }
       }
       Ok(Message::Text(text)) => {
-        let prefix_query = Query {
-          query: Prefix {
-            prefix: Username {
-              username: text.trim().to_string(),
-            },
-          },
-        };
-
-        let json_body = match serde_json::to_string(&prefix_query) {
-          Ok(json_body) => json_body,
-          Err(e) => {
-            error!("Error serializing prefix query: {}", e);
+        let search_request: SearchQuery = match serde_json::from_str(&text) {
+          Ok(search_request) => search_request,
+          Err(_) => {
             send_error_response(
-              errors::WebsocketError::SerializationError,
+              errors::WebsocketError::InvalidSearchQuery,
               outgoing.clone(),
             )
             .await;
@@ -236,51 +274,20 @@
           }
         };
 
-        let response = send_search_request(&opensearch_url, json_body).await;
-
-        let response_text = match response {
-          Ok(response) => match response.text().await {
-            Ok(text) => text,
-            Err(e) => {
-              error!("Error getting response text: {}", e);
-              send_error_response(
-                errors::WebsocketError::SearchError,
-                outgoing.clone(),
-              )
-              .await;
-              continue;
-            }
-          },
+        let search_result = match search_request {
+          SearchQuery::Prefix(prefix_request) => {
+            handle_prefix_search(&prefix_request.prefix).await
+          }
+        };
+
+        let response_msg = match search_result {
+          Ok(response_msg) => response_msg,
           Err(e) => {
-            error!("Error getting search response: {}", e);
-            send_error_response(
-              errors::WebsocketError::SearchError,
-              outgoing.clone(),
-            )
-            .await;
+            send_error_response(e, outgoing.clone()).await;
             continue;
           }
         };
 
-        let search_response: SearchResponse<User> =
-          match serde_json::from_str(&response_text) {
-            Ok(search_response) => search_response,
-            Err(e) => {
-              error!("Error deserializing search response: {}", e);
-              send_error_response(
-                errors::WebsocketError::SerializationError,
-                outgoing.clone(),
-              )
-              .await;
-              continue;
-            }
-          };
-
-        let usernames: Vec<User> = search_response.into_documents().collect();
-
-        let response_msg =
-          serde_json::json!(SearchResult { payload: usernames });
-
         if let Err(e) = outgoing
           .lock()
           .await
diff --git a/shared/identity_search_messages/src/messages/mod.rs b/shared/identity_search_messages/src/messages/mod.rs
--- a/shared/identity_search_messages/src/messages/mod.rs
+++ b/shared/identity_search_messages/src/messages/mod.rs
@@ -1,9 +1,11 @@
-//! Messages sent from Identity Search server to client
+//! Messages exchanged between the Identity Search server and its clients
 
 pub mod auth_messages;
+pub mod search_query;
 pub mod search_result;
 
 pub use auth_messages::*;
+pub use search_query::*;
 pub use search_result::*;
 
 use serde::{Deserialize, Serialize};
@@ -15,6 +17,7 @@
 #[serde(untagged)]
 pub enum Messages {
   AuthMessage(AuthMessage),
+  SearchQuery(SearchQuery),
   Heartbeat(Heartbeat),
   ConnectionInitializationStatus(ConnectionInitializationStatus),
   ConnectionInitializationResponse(ConnectionInitializationResponse),
diff --git a/shared/identity_search_messages/src/messages/search_query.rs b/shared/identity_search_messages/src/messages/search_query.rs
new file mode 100644
--- /dev/null
+++ b/shared/identity_search_messages/src/messages/search_query.rs
@@ -0,0 +1,19 @@
+//! Search request messages sent by a client to Identity Search via WebSocket.
+
+use serde::{Deserialize, Serialize};
+
+/// Asks for users whose username begins with `prefix`.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Prefix {
+  pub prefix: String,
+}
+
+/// A search request sent by the client.
+///
+/// The enum is `untagged`, so on the wire each variant is just its inner
+/// object, e.g. `{"prefix": "ali"}` for [`SearchQuery::Prefix`].
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum SearchQuery {
+  Prefix(Prefix),
+}