diff --git a/services/identity/src/websockets/auth.rs b/services/identity/src/websockets/auth.rs index 2cbc83466..13fc2e071 100644 --- a/services/identity/src/websockets/auth.rs +++ b/services/identity/src/websockets/auth.rs @@ -1,79 +1,81 @@ use client_proto::VerifyUserAccessTokenRequest; use grpc_clients::identity; use grpc_clients::tonic::Request; use identity::get_unauthenticated_client; use identity::protos::unauthenticated as client_proto; use identity_search_messages::IdentitySearchAuthMessage; use tracing::{debug, error}; use crate::constants::DEFAULT_IDENTITY_ENDPOINT; use crate::websockets::errors::WebsocketError; const PLACEHOLDER_CODE_VERSION: u64 = 0; const DEVICE_TYPE: &str = "service"; +#[tracing::instrument(skip_all)] async fn verify_user_access_token( user_id: &str, device_id: &str, access_token: &str, ) -> Result { let grpc_client = get_unauthenticated_client( DEFAULT_IDENTITY_ENDPOINT, PLACEHOLDER_CODE_VERSION, DEVICE_TYPE.to_string(), ) .await; let mut grpc_client = match grpc_client { Ok(grpc_client) => grpc_client, Err(e) => { error!("Failed to get unauthenticated client: {}", e); return Err(WebsocketError::AuthError); } }; let message = VerifyUserAccessTokenRequest { user_id: user_id.to_string(), device_id: device_id.to_string(), access_token: access_token.to_string(), }; let request = Request::new(message); let response = match grpc_client.verify_user_access_token(request).await { Ok(response) => response, Err(_) => { error!("Failed to verify user access token"); return Err(WebsocketError::AuthError); } }; Ok(response.into_inner().token_valid) } +#[tracing::instrument(skip_all)] pub async fn handle_auth_message(message: &str) -> Result<(), WebsocketError> { let auth_message = serde_json::from_str(message.trim()); let auth_message: IdentitySearchAuthMessage = match auth_message { Ok(auth_message) => auth_message, Err(_) => { error!("Failed to parse auth message"); return Err(WebsocketError::InvalidMessage); } }; let user_id = auth_message.user_id; let device_id = auth_message.device_id; let access_token = auth_message.access_token; let is_valid_token = verify_user_access_token(&user_id, &device_id, &access_token).await?; if is_valid_token { debug!("User {} authenticated", user_id); } else { debug!("User {} not authenticated", user_id); return Err(WebsocketError::UnauthorizedDevice); } Ok(()) } diff --git a/services/identity/src/websockets/mod.rs b/services/identity/src/websockets/mod.rs index 532e9161c..6537fc08c 100644 --- a/services/identity/src/websockets/mod.rs +++ b/services/identity/src/websockets/mod.rs @@ -1,336 +1,342 @@ use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use futures::lock::Mutex; use futures_util::{SinkExt, StreamExt}; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use identity_search_messages::{ ConnectionInitializationResponse, ConnectionInitializationStatus, Heartbeat, IdentitySearchFailure, IdentitySearchMethod, IdentitySearchResponse, IdentitySearchResult, IdentitySearchUser, MessagesToServer, }; use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use tracing::{debug, error, info}; mod auth; mod opensearch; mod send; use crate::config::CONFIG; use crate::constants::{ IDENTITY_SEARCH_INDEX, IDENTITY_SEARCH_RESULT_SIZE, IDENTITY_SERVICE_WEBSOCKET_ADDR, SOCKET_HEARTBEAT_TIMEOUT, }; use opensearch::OpenSearchResponse; use send::{send_message, WebsocketSink}; pub mod errors; #[derive(Serialize, Deserialize)] struct Query { size: u32, query: Prefix, } #[derive(Serialize, Deserialize)] struct Prefix { prefix: Username, } #[derive(Serialize, Deserialize)] struct Username { username: String, } struct WebsocketService { addr: SocketAddr, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = errors::BoxedError; type Future = Pin> + Send>>; fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let future = async move { debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; tokio::spawn(async move { accept_connection(websocket, addr).await; }); return Ok(response); } let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } +#[tracing::instrument(skip_all)] pub async fn run_server() -> Result<(), errors::BoxedError> { let addr: SocketAddr = IDENTITY_SERVICE_WEBSOCKET_ADDR.parse()?; let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("Listening to WebSocket traffic on {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { let connection = http .serve_connection(stream, WebsocketService { addr }) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } +#[tracing::instrument(skip_all)] async fn send_search_request( url: &str, json_body: T, ) -> Result { let client = reqwest::Client::new(); client .post(url) .header(reqwest::header::CONTENT_TYPE, "application/json") .json(&json_body) .send() .await } +#[tracing::instrument(skip_all)] async fn close_connection(outgoing: WebsocketSink) { if let Err(e) = outgoing.lock().await.close().await { error!("Error closing connection: {}", e); } } +#[tracing::instrument(skip_all)] async fn handle_prefix_search( request_id: &str, prefix_request: identity_search_messages::IdentitySearchPrefix, ) -> Result { let prefix_query = Query { size: IDENTITY_SEARCH_RESULT_SIZE, query: Prefix { prefix: Username { username: prefix_request.prefix.trim().to_string(), }, }, }; let opensearch_url = format!( "https://{}/{}/_search/", &CONFIG.opensearch_endpoint, IDENTITY_SEARCH_INDEX ); let search_response = send_search_request(&opensearch_url, prefix_query) .await? .json::>() .await?; let usernames: Vec = search_response .hits .inner .into_iter() .filter_map(|hit| hit.source) .collect(); let search_result = IdentitySearchResult { id: request_id.to_string(), hits: usernames, }; Ok(search_result) } +#[tracing::instrument(skip_all)] async fn handle_websocket_frame( text: String, outgoing: WebsocketSink, ) -> Result<(), errors::WebsocketError> { let Ok(serialized_message) = serde_json::from_str::(&text) else { return Err(errors::WebsocketError::SerializationError); }; match serialized_message { MessagesToServer::Heartbeat(Heartbeat {}) => { debug!("Received heartbeat"); Ok(()) } MessagesToServer::IdentitySearchQuery(search_query) => { let handler_result = match search_query.search_method { IdentitySearchMethod::IdentitySearchPrefix(prefix_query) => { handle_prefix_search(&search_query.id, prefix_query).await } }; let search_response = match handler_result { Ok(search_result) => IdentitySearchResponse::Success(search_result), Err(e) => IdentitySearchResponse::Error(IdentitySearchFailure { id: search_query.id, error: e.to_string(), }), }; let serialized_message = serde_json::to_string(&search_response)?; send_message(Message::Text(serialized_message), outgoing.clone()).await; Ok(()) } _ => Err(errors::WebsocketError::InvalidMessage), } } +#[tracing::instrument(skip_all)] async fn accept_connection(hyper_ws: HyperWebsocket, addr: SocketAddr) { debug!("Incoming WebSocket connection from {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { error!("WebSocket handshake error: {}", e); return; } }; let (outgoing, mut incoming) = ws_stream.split(); let outgoing = Arc::new(Mutex::new(outgoing)); if let Some(Ok(auth_message)) = incoming.next().await { match auth_message { Message::Text(text) => { if let Err(auth_error) = auth::handle_auth_message(&text).await { let error_response = ConnectionInitializationResponse { status: ConnectionInitializationStatus::Error( auth_error.to_string(), ), }; let serialized_response = serde_json::to_string(&error_response) .expect("Error serializing auth error response"); send_message(Message::Text(serialized_response), outgoing.clone()) .await; close_connection(outgoing).await; return; } else { let success_response = ConnectionInitializationResponse { status: ConnectionInitializationStatus::Success, }; let serialized_response = serde_json::to_string(&success_response) .expect("Error serializing auth success response"); send_message(Message::Text(serialized_response), outgoing.clone()) .await; } } _ => { error!("Invalid authentication message from {}", addr); close_connection(outgoing).await; return; } } } else { error!("No authentication message from {}", addr); close_connection(outgoing).await; return; } let mut ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); let mut got_heartbeat_response = true; loop { tokio::select! { client_message = incoming.next() => { let message: Message = match client_message { Some(Ok(msg)) => msg, _ => { debug!("Connection to {} closed remotely.", addr); break; } }; match message { Message::Close(_) => { debug!("Connection to {} closed.", addr); break; } Message::Pong(_) => { debug!("Received Pong message from {}", addr); } Message::Ping(msg) => { debug!("Received Ping message from {}", addr); send_message(Message::Pong(msg), outgoing.clone()).await; } Message::Text(text) => { got_heartbeat_response = true; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); if let Err(e) = handle_websocket_frame(text, outgoing.clone()).await { error!("Error handling WebSocket frame: {}", e); continue; }; } _ => { error!("Client sent invalid message type"); break; } } } _ = &mut ping_timeout => { if !got_heartbeat_response { error!("Connection to {} died.", addr); break; } let serialized = serde_json::to_string(&Heartbeat {}).unwrap(); send_message(Message::text(serialized), outgoing.clone()).await; got_heartbeat_response = false; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); } else => { debug!("Unhealthy connection for: {}", addr); break; } } } info!("unregistering connection to: {}", addr); close_connection(outgoing).await; } diff --git a/services/identity/src/websockets/send.rs b/services/identity/src/websockets/send.rs index 0f8fffeed..50660e0fd 100644 --- a/services/identity/src/websockets/send.rs +++ b/services/identity/src/websockets/send.rs @@ -1,18 +1,19 @@ use std::sync::Arc; use futures::lock::Mutex; use futures_util::stream::SplitSink; use futures_util::SinkExt; use hyper::upgrade::Upgraded; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::WebSocketStream; use tracing::error; pub type WebsocketSink = Arc, Message>>>; +#[tracing::instrument(skip_all)] pub async fn send_message(message: Message, outgoing: WebsocketSink) { if let Err(e) = outgoing.lock().await.send(message).await { error!("Failed to send message to device: {}", e); } }