diff --git a/services/identity/src/websockets/auth.rs b/services/identity/src/websockets/auth.rs index 424e26931..2cbc83466 100644 --- a/services/identity/src/websockets/auth.rs +++ b/services/identity/src/websockets/auth.rs @@ -1,79 +1,79 @@ use client_proto::VerifyUserAccessTokenRequest; use grpc_clients::identity; use grpc_clients::tonic::Request; use identity::get_unauthenticated_client; use identity::protos::unauthenticated as client_proto; -use identity_search_messages::AuthMessage; +use identity_search_messages::IdentitySearchAuthMessage; use tracing::{debug, error}; use crate::constants::DEFAULT_IDENTITY_ENDPOINT; use crate::websockets::errors::WebsocketError; const PLACEHOLDER_CODE_VERSION: u64 = 0; const DEVICE_TYPE: &str = "service"; async fn verify_user_access_token( user_id: &str, device_id: &str, access_token: &str, ) -> Result { let grpc_client = get_unauthenticated_client( DEFAULT_IDENTITY_ENDPOINT, PLACEHOLDER_CODE_VERSION, DEVICE_TYPE.to_string(), ) .await; let mut grpc_client = match grpc_client { Ok(grpc_client) => grpc_client, Err(e) => { error!("Failed to get unauthenticated client: {}", e); return Err(WebsocketError::AuthError); } }; let message = VerifyUserAccessTokenRequest { user_id: user_id.to_string(), device_id: device_id.to_string(), access_token: access_token.to_string(), }; let request = Request::new(message); let response = match grpc_client.verify_user_access_token(request).await { Ok(response) => response, Err(_) => { error!("Failed to verify user access token"); return Err(WebsocketError::AuthError); } }; Ok(response.into_inner().token_valid) } pub async fn handle_auth_message(message: &str) -> Result<(), WebsocketError> { let auth_message = serde_json::from_str(message.trim()); - let auth_message: AuthMessage = match auth_message { + let auth_message: IdentitySearchAuthMessage = match auth_message { Ok(auth_message) => auth_message, Err(_) => { error!("Failed to parse auth message"); return Err(WebsocketError::InvalidMessage); } }; let user_id = auth_message.user_id; let device_id = auth_message.device_id; let access_token = auth_message.access_token; let is_valid_token = verify_user_access_token(&user_id, &device_id, &access_token).await?; if is_valid_token { debug!("User {} authenticated", user_id); } else { debug!("User {} not authenticated", user_id); return Err(WebsocketError::UnauthorizedDevice); } Ok(()) } diff --git a/services/identity/src/websockets/mod.rs b/services/identity/src/websockets/mod.rs index 814fa291c..8e8a7622c 100644 --- a/services/identity/src/websockets/mod.rs +++ b/services/identity/src/websockets/mod.rs @@ -1,314 +1,316 @@ use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use elastic::client::responses::SearchResponse; use futures::lock::Mutex; use futures_util::{SinkExt, StreamExt}; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use identity_search_messages::{ ConnectionInitializationResponse, ConnectionInitializationStatus, Heartbeat, - Messages, SearchQuery, SearchResult, User, + IdentitySearchQuery, IdentitySearchResult, IdentitySearchUser, Messages, }; use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use tracing::{debug, error, info}; mod auth; mod send; use crate::config::CONFIG; use crate::constants::{ IDENTITY_SERVICE_WEBSOCKET_ADDR, SOCKET_HEARTBEAT_TIMEOUT, }; use send::{send_error_response, send_message, WebsocketSink}; pub mod errors; #[derive(Serialize, Deserialize)] struct Query { query: Prefix, } #[derive(Serialize, Deserialize)] struct Prefix { prefix: Username, } #[derive(Serialize, Deserialize)] struct Username { username: String, } struct WebsocketService { addr: SocketAddr, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = errors::BoxedError; type Future = Pin> + Send>>; fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let future = async move { tracing::debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; tokio::spawn(async move { accept_connection(websocket, addr).await; }); return Ok(response); } debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } pub async fn run_server() -> Result<(), errors::BoxedError> { let addr: SocketAddr = IDENTITY_SERVICE_WEBSOCKET_ADDR.parse()?; let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("Listening to WebSocket traffic on {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { let connection = http .serve_connection(stream, WebsocketService { addr }) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } async fn send_search_request( url: &str, json_body: T, ) -> Result { let client = reqwest::Client::new(); client .post(url) .header(reqwest::header::CONTENT_TYPE, "application/json") .json(&json_body) .send() .await } async fn close_connection(outgoing: WebsocketSink) { if let Err(e) = outgoing.lock().await.close().await { error!("Error closing connection: {}", e); } } async fn handle_prefix_search( prefix: &str, ) -> Result { let prefix_query = Query { query: Prefix { prefix: Username { username: prefix.trim().to_string(), }, }, }; let opensearch_url = format!("https://{}/users/_search/", &CONFIG.opensearch_endpoint); let search_response = send_search_request(&opensearch_url, prefix_query) .await? - .json::>() + .json::>() .await?; - let usernames: Vec = search_response.into_documents().collect(); + let usernames: Vec = + search_response.into_documents().collect(); - let search_result = serde_json::to_string(&SearchResult { hits: usernames })?; + let search_result = + serde_json::to_string(&IdentitySearchResult { hits: usernames })?; Ok(search_result) } async fn handle_websocket_frame( text: String, outgoing: WebsocketSink, ) -> Result<(), errors::WebsocketError> { let Ok(serialized_message) = serde_json::from_str::(&text) else { return Err(errors::WebsocketError::SerializationError); }; match serialized_message { Messages::Heartbeat(Heartbeat {}) => { debug!("Received heartbeat"); Ok(()) } - Messages::SearchQuery(search_request) => { + Messages::IdentitySearchQuery(search_request) => { let search_result = match search_request { - SearchQuery::Prefix(prefix_request) => { + IdentitySearchQuery::IdentitySearchPrefix(prefix_request) => { handle_prefix_search(&prefix_request.prefix).await } }?; send_message(Message::Text(search_result), outgoing.clone()).await; Ok(()) } _ => Err(errors::WebsocketError::InvalidMessage), } } async fn accept_connection(hyper_ws: HyperWebsocket, addr: SocketAddr) { debug!("Incoming WebSocket connection from {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { error!("WebSocket handshake error: {}", e); return; } }; let (outgoing, mut incoming) = ws_stream.split(); let outgoing = Arc::new(Mutex::new(outgoing)); if let Some(Ok(auth_message)) = incoming.next().await { match auth_message { Message::Text(text) => { if let Err(auth_error) = auth::handle_auth_message(&text).await { let error_response = ConnectionInitializationResponse { status: ConnectionInitializationStatus::Error( auth_error.to_string(), ), }; let serialized_response = serde_json::to_string(&error_response) .expect("Error serializing auth error response"); send_message(Message::Text(serialized_response), outgoing.clone()) .await; close_connection(outgoing).await; return; } else { let success_response = ConnectionInitializationResponse { status: ConnectionInitializationStatus::Success, }; let serialized_response = serde_json::to_string(&success_response) .expect("Error serializing auth success response"); send_message(Message::Text(serialized_response), outgoing.clone()) .await; } } _ => { error!("Invalid authentication message from {}", addr); close_connection(outgoing).await; return; } } } else { error!("No authentication message from {}", addr); close_connection(outgoing).await; return; } let mut ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); let mut got_heartbeat_response = true; loop { tokio::select! { client_message = incoming.next() => { let message: Message = match client_message { Some(Ok(msg)) => msg, _ => { debug!("Connection to {} closed remotely.", addr); break; } }; match message { Message::Close(_) => { debug!("Connection to {} closed.", addr); break; } Message::Pong(_) => { debug!("Received Pong message from {}", addr); } Message::Ping(msg) => { debug!("Received Ping message from {}", addr); send_message(Message::Pong(msg), outgoing.clone()).await; } Message::Text(text) => { got_heartbeat_response = true; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); if let Err(e) = handle_websocket_frame(text, outgoing.clone()).await { send_error_response(e, outgoing.clone()).await; continue; }; } _ => { error!("Client sent invalid message type"); break; } } } _ = &mut ping_timeout => { if !got_heartbeat_response { error!("Connection to {} died.", addr); break; } let serialized = serde_json::to_string(&Heartbeat {}).unwrap(); send_message(Message::text(serialized), outgoing.clone()).await; got_heartbeat_response = false; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); } else => { debug!("Unhealthy connection for: {}", addr); break; } } } info!("unregistering connection to: {}", addr); close_connection(outgoing).await; } diff --git a/shared/identity_search_messages/src/messages/auth_messages.rs b/shared/identity_search_messages/src/messages/auth_messages.rs index 042d3f0a0..baf55df8a 100644 --- a/shared/identity_search_messages/src/messages/auth_messages.rs +++ b/shared/identity_search_messages/src/messages/auth_messages.rs @@ -1,13 +1,13 @@ //! Message sent by client to authenticate with Identity Search Server use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "camelCase")] -pub struct AuthMessage { +pub struct IdentitySearchAuthMessage { #[serde(rename = "userID")] pub user_id: String, #[serde(rename = "deviceID")] pub device_id: String, pub access_token: String, } diff --git a/shared/identity_search_messages/src/messages/mod.rs b/shared/identity_search_messages/src/messages/mod.rs index d378f6661..429fc4e20 100644 --- a/shared/identity_search_messages/src/messages/mod.rs +++ b/shared/identity_search_messages/src/messages/mod.rs @@ -1,24 +1,24 @@ //! Messages sent from client to Identity Search Server pub mod auth_messages; pub mod search_query; pub mod search_result; pub use auth_messages::*; pub use search_query::*; pub use search_result::*; use serde::{Deserialize, Serialize}; pub use websocket_messages::{ ConnectionInitializationResponse, ConnectionInitializationStatus, Heartbeat, }; #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum Messages { - AuthMessage(AuthMessage), - SearchQuery(SearchQuery), + IdentitySearchAuthMessage(IdentitySearchAuthMessage), + IdentitySearchQuery(IdentitySearchQuery), Heartbeat(Heartbeat), ConnectionInitializationResponse(ConnectionInitializationResponse), - SearchResult(SearchResult), + IdentitySearchResult(IdentitySearchResult), } diff --git a/shared/identity_search_messages/src/messages/search_query.rs b/shared/identity_search_messages/src/messages/search_query.rs index 7599bdf3b..9961fc297 100644 --- a/shared/identity_search_messages/src/messages/search_query.rs +++ b/shared/identity_search_messages/src/messages/search_query.rs @@ -1,14 +1,14 @@ //! Search Request Messages sent by Client to Identity Search via WebSocket. use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] -pub struct Prefix { +pub struct IdentitySearchPrefix { pub prefix: String, } #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] -pub enum SearchQuery { - Prefix(Prefix), +pub enum IdentitySearchQuery { + IdentitySearchPrefix(IdentitySearchPrefix), } diff --git a/shared/identity_search_messages/src/messages/search_result.rs b/shared/identity_search_messages/src/messages/search_result.rs index 30ab7b3cd..9e208bc24 100644 --- a/shared/identity_search_messages/src/messages/search_result.rs +++ b/shared/identity_search_messages/src/messages/search_result.rs @@ -1,16 +1,16 @@ //! Search Result Messages sent by Identity Search via WebSocket. use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] -pub struct User { +pub struct IdentitySearchUser { #[serde(rename = "userID")] pub user_id: String, pub username: String, } #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] -pub struct SearchResult { - pub hits: Vec, +pub struct IdentitySearchResult { + pub hits: Vec, }