identity: accept typed SearchQuery messages on the search websocket

Free-form patch description. Patch tools skip any text that comes before
the first `diff --git` header.

What this patch does:
- shared/identity_search_messages: adds `messages/search_query.rs`.
  It defines `SearchQuery`, an internally tagged (`#[serde(tag = "type")]`)
  enum whose only variant is `Prefix(Prefix { prefix: String })`. On the
  wire this is `{"type": "Prefix", "prefix": "<text>"}`. The enum is
  re-exported from `messages/mod.rs` and added as a variant of the
  `#[serde(untagged)]` `Messages` enum.
- services/identity/Cargo.toml: enables reqwest's "json" feature so that
  `.json(&body)` on the request builder and `.json::<...>()` on the
  response are available.
- services/identity/src/websockets/errors.rs: adds the
  `InvalidSearchQuery` variant and two `From` impls, so that `?` works on
  the search path. Each impl logs the error, then maps:
  - `serde_json::Error` to `SerializationError`;
  - `reqwest::Error` to `SearchError`.
- services/identity/src/websockets/mod.rs:
  - `send_search_request` now takes a generic body and serializes it
    with `.json(&json_body)` instead of `.body(String)`.
  - The prefix-search pipeline moves out of `accept_connection` into
    `handle_prefix_search`. That function builds the OpenSearch URL on
    every call and returns the serialized `SearchResult`.
  - Incoming text frames are parsed as `SearchQuery`. A parse failure is
    answered with `InvalidSearchQuery`.

NOTE(review): open points to confirm before merging.
1. This copy is damaged: newlines are collapsed and angle-bracketed text
   is stripped (e.g. `impl From for WebsocketError`, `Box;`,
   `-> Result {`, `.json::>()`, `Vec =`). It cannot be applied as-is.
   Regenerate it from the source commit rather than hand-repairing it.
2. errors.rs adds `use tracing::error;`, but both new `From` impls call
   `tracing::error!` fully qualified. In the lines shown the import looks
   unused, which gives an `unused_imports` warning (an error under
   `-D warnings`). Either drop the `use` or call `error!(...)` directly.
3. Wire-protocol change. The old handler used the trimmed raw frame text
   as the prefix. The new one requires the JSON form above and rejects
   anything else with `InvalidSearchQuery`. Clients must change in
   lockstep.
4. Error-classification change. A search response that failed to
   deserialize used to be reported as `SerializationError`. With
   `Response::json` that failure is a `reqwest::Error`, so the new
   `From` impl maps it to `SearchError`. Confirm this is intended.
5. `handle_prefix_search` decodes the body without checking the HTTP
   status. Consider calling `.error_for_status()?` on the response so
   that an OpenSearch error status is reported as such, not as a decode
   failure.
6. `Messages` is `#[serde(untagged)]`, so variants are tried in order.
   `SearchQuery` now sits after `AuthMessage` and before `Heartbeat` and
   the later variants. Confirm that no existing message shape also
   deserializes as `SearchQuery`.
7. A whitespace-only prefix is trimmed to "" and still sent as a prefix
   query. Consider rejecting it with `InvalidSearchQuery`. (How
   OpenSearch treats an empty prefix query is not verified here.)
8. The blank line between `use send::{...};` and `pub mod errors;` is
   removed. This is unrelated whitespace churn.

diff --git a/services/identity/Cargo.toml b/services/identity/Cargo.toml --- a/services/identity/Cargo.toml +++ b/services/identity/Cargo.toml @@ -37,7 +37,7 @@ tower-http = { version = "0.4", features = ["cors"] } http = "0.2" elastic = "0.21.0-pre.5" -reqwest = "0.11" +reqwest = { version = "0.11", features = ["json"] } futures = "0.3.30" [build-dependencies] diff --git a/services/identity/src/websockets/errors.rs b/services/identity/src/websockets/errors.rs --- a/services/identity/src/websockets/errors.rs +++ b/services/identity/src/websockets/errors.rs @@ -1,3 +1,5 @@ +use tracing::error; + pub type BoxedError = Box; #[derive( @@ -5,9 +7,24 @@ )] pub enum WebsocketError { InvalidMessage, + InvalidSearchQuery, UnauthorizedDevice, SendError, SearchError, AuthError, SerializationError, } + +impl From for WebsocketError { + fn from(err: serde_json::Error) -> Self { + tracing::error!("Error serializing: {}", err); + WebsocketError::SerializationError + } +} + +impl From for WebsocketError { + fn from(err: reqwest::Error) -> Self { + tracing::error!("Error with search request: {}", err); + WebsocketError::SearchError + } +} diff --git a/services/identity/src/websockets/mod.rs b/services/identity/src/websockets/mod.rs --- a/services/identity/src/websockets/mod.rs +++ b/services/identity/src/websockets/mod.rs @@ -11,7 +11,7 @@ use hyper_tungstenite::HyperWebsocket; use identity_search_messages::{ ConnectionInitializationResponse, ConnectionInitializationStatus, - SearchResult, User, + SearchQuery, SearchResult, User, }; use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; @@ -23,7 +23,6 @@ use crate::config::CONFIG; use crate::constants::IDENTITY_SERVICE_WEBSOCKET_ADDR; use send::{send_error_response, send_message, WebsocketSink}; - pub mod errors; #[derive(Serialize, Deserialize)] @@ -120,16 +119,16 @@ Ok(()) } -async fn send_search_request( +async fn send_search_request( url: &str, - json_body: String, + json_body: T, ) -> Result { let client = 
reqwest::Client::new(); client .post(url) .header(reqwest::header::CONTENT_TYPE, "application/json") - .body(json_body) + .json(&json_body) .send() .await } @@ -140,6 +139,32 @@ } } +async fn handle_prefix_search( + prefix: &str, +) -> Result { + let prefix_query = Query { + query: Prefix { + prefix: Username { + username: prefix.trim().to_string(), + }, + }, + }; + + let opensearch_url = + format!("https://{}/users/_search/", &CONFIG.opensearch_endpoint); + + let search_response = send_search_request(&opensearch_url, prefix_query) + .await? + .json::>() + .await?; + + let usernames: Vec = search_response.into_documents().collect(); + + let search_result = serde_json::to_string(&SearchResult { hits: usernames })?; + + Ok(search_result) +} + async fn accept_connection(hyper_ws: HyperWebsocket, addr: SocketAddr) { debug!("Incoming WebSocket connection from {}", addr); @@ -155,9 +180,6 @@ let outgoing = Arc::new(Mutex::new(outgoing)); - let opensearch_url = - format!("https://{}/users/_search/", &CONFIG.opensearch_endpoint); - if let Some(Ok(auth_message)) = incoming.next().await { match auth_message { Message::Text(text) => { @@ -214,71 +236,29 @@ } } Ok(Message::Text(text)) => { - let prefix_query = Query { - query: Prefix { - prefix: Username { - username: text.trim().to_string(), - }, - }, + let Ok(search_request) = serde_json::from_str(&text) else { + send_error_response( + errors::WebsocketError::InvalidSearchQuery, + outgoing.clone(), + ) + .await; + continue; }; - let json_body = match serde_json::to_string(&prefix_query) { - Ok(json_body) => json_body, - Err(e) => { - error!("Error serializing prefix query: {}", e); - send_error_response( - errors::WebsocketError::SerializationError, - outgoing.clone(), - ) - .await; - continue; + let search_result = match search_request { + SearchQuery::Prefix(prefix_request) => { + handle_prefix_search(&prefix_request.prefix).await } }; - let response = send_search_request(&opensearch_url, json_body).await; - - let 
response_text = match response { - Ok(response) => match response.text().await { - Ok(text) => text, - Err(e) => { - error!("Error getting response text: {}", e); - send_error_response( - errors::WebsocketError::SearchError, - outgoing.clone(), - ) - .await; - continue; - } - }, + let response_msg = match search_result { + Ok(response_msg) => response_msg, Err(e) => { - error!("Error getting search response: {}", e); - send_error_response( - errors::WebsocketError::SearchError, - outgoing.clone(), - ) - .await; + send_error_response(e, outgoing.clone()).await; continue; } }; - let search_response: SearchResponse = - match serde_json::from_str(&response_text) { - Ok(search_response) => search_response, - Err(e) => { - error!("Error deserializing search response: {}", e); - send_error_response( - errors::WebsocketError::SerializationError, - outgoing.clone(), - ) - .await; - continue; - } - }; - - let usernames: Vec = search_response.into_documents().collect(); - - let response_msg = serde_json::json!(SearchResult { hits: usernames }); - if let Err(e) = outgoing .lock() .await diff --git a/shared/identity_search_messages/src/messages/mod.rs b/shared/identity_search_messages/src/messages/mod.rs --- a/shared/identity_search_messages/src/messages/mod.rs +++ b/shared/identity_search_messages/src/messages/mod.rs @@ -1,9 +1,11 @@ //! 
Messages sent from client to Identity Search Server pub mod auth_messages; +pub mod search_query; pub mod search_result; pub use auth_messages::*; +pub use search_query::*; pub use search_result::*; use serde::{Deserialize, Serialize}; @@ -15,6 +17,7 @@ #[serde(untagged)] pub enum Messages { AuthMessage(AuthMessage), + SearchQuery(SearchQuery), Heartbeat(Heartbeat), ConnectionInitializationResponse(ConnectionInitializationResponse), SearchResult(SearchResult), diff --git a/shared/identity_search_messages/src/messages/search_query.rs b/shared/identity_search_messages/src/messages/search_query.rs new file mode 100644 --- /dev/null +++ b/shared/identity_search_messages/src/messages/search_query.rs @@ -0,0 +1,14 @@ +//! Search Request Messages sent by Client to Identity Search via WebSocket. + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Prefix { + pub prefix: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum SearchQuery { + Prefix(Prefix), +}