diff --git a/services/identity/src/websockets/auth.rs b/services/identity/src/websockets/auth.rs index 98db38d36..cae8001a7 100644 --- a/services/identity/src/websockets/auth.rs +++ b/services/identity/src/websockets/auth.rs @@ -1,90 +1,89 @@ use client_proto::VerifyUserAccessTokenRequest; use grpc_clients::identity; use grpc_clients::tonic::Request; use identity::get_unauthenticated_client; use identity::protos::unauthenticated as client_proto; use serde::{Deserialize, Serialize}; use tracing::{debug, error}; use crate::constants::DEFAULT_IDENTITY_ENDPOINT; use crate::websockets::errors::WebsocketError; #[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "camelCase")] pub struct AuthMessage { #[serde(rename = "userID")] pub user_id: String, #[serde(rename = "deviceID")] pub device_id: String, pub access_token: String, } const PLACEHOLDER_CODE_VERSION: u64 = 0; const DEVICE_TYPE: &str = "service"; async fn verify_user_access_token( user_id: &str, device_id: &str, access_token: &str, ) -> Result { let grpc_client = get_unauthenticated_client( DEFAULT_IDENTITY_ENDPOINT, PLACEHOLDER_CODE_VERSION, DEVICE_TYPE.to_string(), ) .await; let mut grpc_client = match grpc_client { Ok(grpc_client) => grpc_client, Err(e) => { error!("Failed to get unauthenticated client: {}", e); return Err(WebsocketError::AuthError); } }; let message = VerifyUserAccessTokenRequest { user_id: user_id.to_string(), device_id: device_id.to_string(), access_token: access_token.to_string(), }; let request = Request::new(message); let response = match grpc_client.verify_user_access_token(request).await { Ok(response) => response, - Err(e) => { - error!("Failed to verify user access token: {}", e); + Err(_) => { + error!("Failed to verify user access token"); return Err(WebsocketError::AuthError); } }; Ok(response.into_inner().token_valid) } pub async fn handle_auth_message(message: &str) -> Result<(), WebsocketError> { - error!("Handling auth message: {}", message); let auth_message = serde_json::from_str(message.trim()); let auth_message: AuthMessage = match auth_message { Ok(auth_message) => auth_message, - Err(e) => { - error!("Failed to parse auth message: {}", e); + Err(_) => { + error!("Failed to parse auth message"); return Err(WebsocketError::InvalidMessage); } }; let user_id = auth_message.user_id; let device_id = auth_message.device_id; let access_token = auth_message.access_token; let is_valid_token = verify_user_access_token(&user_id, &device_id, &access_token).await?; if is_valid_token { debug!("User {} authenticated", user_id); } else { debug!("User {} not authenticated", user_id); return Err(WebsocketError::UnauthorizedDevice); } Ok(()) } diff --git a/services/identity/src/websockets/mod.rs b/services/identity/src/websockets/mod.rs index c0b984261..a0a130c89 100644 --- a/services/identity/src/websockets/mod.rs +++ b/services/identity/src/websockets/mod.rs @@ -1,331 +1,330 @@ use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use elastic::client::responses::SearchResponse; use futures::lock::Mutex; use futures_util::stream::SplitSink; use futures_util::{SinkExt, StreamExt}; use hyper::upgrade::Upgraded; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use hyper_tungstenite::WebSocketStream; use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use tracing::{debug, error, info}; mod auth; use crate::config::CONFIG; use crate::constants::IDENTITY_SERVICE_WEBSOCKET_ADDR; pub mod errors; #[derive(Serialize, Deserialize)] struct Query { query: Prefix, } #[derive(Serialize, Deserialize)] struct Prefix { prefix: Username, } #[derive(Serialize, Deserialize)] struct Username { username: String, } #[derive(Serialize, Deserialize)] struct User { #[serde(rename = "userID")] user_id: String, username: String, } struct WebsocketService { addr: SocketAddr, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = errors::BoxedError; type Future = Pin> + Send>>; fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let future = async move { - tracing::info!( + tracing::debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; - debug!("Upgraded WebSocket connection from {}", addr); tokio::spawn(async move { accept_connection(websocket, addr).await; }); return Ok(response); } debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } pub async fn run_server() -> Result<(), errors::BoxedError> { let addr: SocketAddr = IDENTITY_SERVICE_WEBSOCKET_ADDR.parse()?; let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); - info!("WebSocket Listening on {}", addr); + info!("Listening to WebSocket traffic on {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { let connection = http .serve_connection(stream, WebsocketService { addr }) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } async fn send_search_request( url: &str, json_body: String, ) -> Result { let client = reqwest::Client::new(); client .post(url) .header(reqwest::header::CONTENT_TYPE, "application/json") .body(json_body) .send() .await } async fn send_error_response( error: errors::WebsocketError, outgoing: Arc, Message>>>, ) { let response_msg = serde_json::json!({ "action": "errorMessage", "error": format!("{}", error) }); match serde_json::to_string(&response_msg) { Ok(serialized_response) => { if let Err(send_error) = outgoing .lock() .await .send(Message::Text(serialized_response)) .await { error!("Failed to send error response: {:?}", send_error); } } Err(serialize_error) => { error!( "Failed to serialize the error response: {:?}", serialize_error ); } } } async fn close_connection( outgoing: Arc, Message>>>, ) { if let Err(e) = outgoing.lock().await.close().await { error!("Error closing connection: {}", e); } } async fn accept_connection(hyper_ws: HyperWebsocket, addr: SocketAddr) { debug!("Incoming WebSocket connection from {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { error!("WebSocket handshake error: {}", e); return; } }; let (outgoing, mut incoming) = ws_stream.split(); let outgoing = Arc::new(Mutex::new(outgoing)); let opensearch_url = format!("https://{}/users/_search/", &CONFIG.opensearch_endpoint); if let Some(Ok(auth_message)) = incoming.next().await { match auth_message { Message::Text(text) => { if let Err(auth_error) = auth::handle_auth_message(&text).await { send_error_response(auth_error, outgoing.clone()).await; close_connection(outgoing).await; return; } } _ => { error!("Invalid authentication message from {}", addr); close_connection(outgoing).await; return; } } } else { error!("No authentication message from {}", addr); close_connection(outgoing).await; return; } while let Some(message) = incoming.next().await { match message { Ok(Message::Close(_)) => { debug!("Connection to {} closed.", addr); break; } Ok(Message::Pong(_)) => { debug!("Received Pong message from {}", addr); } Ok(Message::Ping(msg)) => { debug!("Received Ping message from {}", addr); if let Err(e) = outgoing.lock().await.send(Message::Pong(msg)).await { error!("Error sending message: {}", e); } } Ok(Message::Text(text)) => { let prefix_query = Query { query: Prefix { prefix: Username { username: text.trim().to_string(), }, }, }; let json_body = match serde_json::to_string(&prefix_query) { Ok(json_body) => json_body, Err(e) => { error!("Error serializing prefix query: {}", e); send_error_response( errors::WebsocketError::SerializationError, outgoing.clone(), ) .await; continue; } }; let response = send_search_request(&opensearch_url, json_body).await; let response_text = match response { Ok(response) => match response.text().await { Ok(text) => text, Err(e) => { error!("Error getting response text: {}", e); send_error_response( errors::WebsocketError::SearchError, outgoing.clone(), ) .await; continue; } }, Err(e) => { error!("Error getting search response: {}", e); send_error_response( errors::WebsocketError::SearchError, outgoing.clone(), ) .await; continue; } }; let search_response: SearchResponse = match serde_json::from_str(&response_text) { Ok(search_response) => search_response, Err(e) => { error!("Error deserializing search response: {}", e); send_error_response( errors::WebsocketError::SerializationError, outgoing.clone(), ) .await; continue; } }; let usernames: Vec<&User> = search_response.documents().collect(); let response_msg = serde_json::json!({ "action": "searchResults", "results": usernames }); if let Err(e) = outgoing .lock() .await .send(Message::Text(format!("{}", response_msg.to_string()))) .await { error!("Error sending message: {}", e); send_error_response( errors::WebsocketError::SendError, outgoing.clone(), ) .await; continue; } } Err(e) => { error!("Error in WebSocket message: {}", e); send_error_response( errors::WebsocketError::InvalidMessage, outgoing.clone(), ) .await; continue; } _ => {} } } close_connection(outgoing).await; }