diff --git a/services/identity/src/http/mod.rs b/services/identity/src/http/mod.rs new file mode 100644 index 000000000..4efa0b144 --- /dev/null +++ b/services/identity/src/http/mod.rs @@ -0,0 +1,23 @@ +use http::StatusCode; +use hyper::{Body, Request, Response}; + +type HttpRequest = Request; +type HttpResponse = Response; + +/// Main router for HTTP requests +#[tracing::instrument(skip_all, name = "http_request", fields(request_id))] +pub(super) async fn handle_http_request( + req: HttpRequest, + _db_client: crate::DatabaseClient, +) -> Result { + tracing::Span::current() + .record("request_id", uuid::Uuid::new_v4().to_string()); + + let response = match req.uri().path() { + "/health" => Response::new(Body::from("OK")), + _ => Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::from("Not found"))?, + }; + Ok(response) +} diff --git a/services/identity/src/main.rs b/services/identity/src/main.rs index 79a8c6cdf..19c1cec36 100644 --- a/services/identity/src/main.rs +++ b/services/identity/src/main.rs @@ -1,120 +1,122 @@ use comm_lib::aws; use config::Command; use database::DatabaseClient; use tonic::transport::Server; use tonic_web::GrpcWebLayer; mod client_service; mod config; pub mod constants; mod cors; mod database; pub mod ddb_utils; mod device_list; pub mod error; mod grpc_services; mod grpc_utils; +mod http; mod id; mod keygen; mod nonce; mod olm; mod regex; mod reserved_users; mod siwe; mod sync_identity_search; mod token; mod tunnelbroker; mod websockets; use constants::{COMM_SERVICES_USE_JSON_LOGS, IDENTITY_SERVICE_SOCKET_ADDR}; use cors::cors_layer; use keygen::generate_and_persist_keypair; use std::env; use sync_identity_search::sync_index; use tracing::{self, info, Level}; use tracing_subscriber::EnvFilter; use client_service::{ClientService, IdentityClientServiceServer}; use grpc_services::authenticated::AuthenticatedService; use grpc_services::protos::auth::identity_client_service_server::IdentityClientServiceServer as AuthServer; use websockets::errors::BoxedError; #[tokio::main] async fn main() -> Result<(), BoxedError> { let filter = EnvFilter::builder() .with_default_directive(Level::INFO.into()) .with_env_var(EnvFilter::DEFAULT_ENV) .from_env_lossy(); let use_json_logs: bool = env::var(COMM_SERVICES_USE_JSON_LOGS) .unwrap_or("false".to_string()) .parse() .unwrap_or_default(); if use_json_logs { let subscriber = tracing_subscriber::fmt() .json() .with_env_filter(filter) .finish(); tracing::subscriber::set_global_default(subscriber)?; } else { let subscriber = tracing_subscriber::fmt().with_env_filter(filter).finish(); tracing::subscriber::set_global_default(subscriber)?; } match config::parse_cli_command() { Command::Keygen { dir } => { generate_and_persist_keypair(dir)?; } Command::Server => { config::load_server_config(); let addr = IDENTITY_SERVICE_SOCKET_ADDR.parse()?; let aws_config = aws::config::from_env().region("us-east-2").load().await; let database_client = DatabaseClient::new(&aws_config); let inner_client_service = ClientService::new(database_client.clone()); let client_service = IdentityClientServiceServer::with_interceptor( inner_client_service, grpc_services::shared::version_interceptor, ); let inner_auth_service = AuthenticatedService::new(database_client.clone()); + let db_client = database_client.clone(); let auth_service = AuthServer::with_interceptor(inner_auth_service, move |req| { - grpc_services::authenticated::auth_interceptor(req, &database_client) + grpc_services::authenticated::auth_interceptor(req, &db_client) .and_then(grpc_services::shared::version_interceptor) }); info!("Listening to gRPC traffic on {}", addr); let grpc_server = Server::builder() .accept_http1(true) .layer(cors_layer()) .layer(GrpcWebLayer::new()) .trace_fn(|_| { tracing::info_span!( "grpc_request", request_id = uuid::Uuid::new_v4().to_string() ) }) .add_service(client_service) .add_service(auth_service) .serve(addr); - let websocket_server = websockets::run_server(); + let websocket_server = websockets::run_server(database_client); return tokio::select! { websocket_result = websocket_server => websocket_result, grpc_result = grpc_server => { grpc_result.map_err(|e| e.into()) }, }; } Command::SyncIdentitySearch => { let aws_config = aws::config::from_env().region("us-east-2").load().await; let database_client = DatabaseClient::new(&aws_config); let sync_result = sync_index(&database_client).await; error::consume_error(sync_result); } } Ok(()) } diff --git a/services/identity/src/websockets/mod.rs b/services/identity/src/websockets/mod.rs index 1041deda2..c6d576fa5 100644 --- a/services/identity/src/websockets/mod.rs +++ b/services/identity/src/websockets/mod.rs @@ -1,354 +1,354 @@ use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use futures::lock::Mutex; use futures_util::{SinkExt, StreamExt}; -use hyper::{Body, Request, Response, StatusCode}; +use hyper::{Body, Request, Response}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use identity_search_messages::{ ConnectionInitializationResponse, ConnectionInitializationStatus, Heartbeat, IdentitySearchFailure, IdentitySearchMethod, IdentitySearchResponse, IdentitySearchResult, IdentitySearchUser, MessagesToServer, }; use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use tracing::{debug, error, info, warn}; mod auth; mod opensearch; mod send; use crate::config::CONFIG; use crate::constants::{ error_types, IDENTITY_SEARCH_INDEX, IDENTITY_SEARCH_RESULT_SIZE, IDENTITY_SERVICE_WEBSOCKET_ADDR, SOCKET_HEARTBEAT_TIMEOUT, }; use opensearch::OpenSearchResponse; use send::{send_message, WebsocketSink}; pub mod errors; #[derive(Serialize, Deserialize)] struct Query { size: u32, query: Prefix, } #[derive(Serialize, Deserialize)] struct Prefix { prefix: Username, } #[derive(Serialize, Deserialize)] struct Username { username: String, } struct WebsocketService { addr: SocketAddr, + db_client: crate::DatabaseClient, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = errors::BoxedError; type Future = Pin> + Send>>; fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; + let db_client = self.db_client.clone(); let future = async move { debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; tokio::spawn(async move { accept_connection(websocket, addr).await; }); return Ok(response); } - let response = match req.uri().path() { - "/health" => Response::new(Body::from("OK")), - _ => Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::from("Not found"))?, - }; - Ok(response) + // If not a websocker upgrade, treat it as regular HTTP request + crate::http::handle_http_request(req, db_client).await }; Box::pin(future) } } #[tracing::instrument(skip_all)] -pub async fn run_server() -> Result<(), errors::BoxedError> { +pub async fn run_server( + db_client: crate::DatabaseClient, +) -> Result<(), errors::BoxedError> { let addr: SocketAddr = IDENTITY_SERVICE_WEBSOCKET_ADDR.parse()?; let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("Listening to WebSocket traffic on {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { + let db_client = db_client.clone(); let connection = http - .serve_connection(stream, WebsocketService { addr }) + .serve_connection(stream, WebsocketService { addr, db_client }) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!( errorType = error_types::SEARCH_LOG, "Error serving HTTP/WebSocket connection: {:?}", err ); } }); } Ok(()) } #[tracing::instrument(skip_all)] async fn send_search_request( url: &str, json_body: T, ) -> Result { let client = reqwest::Client::new(); client .post(url) .header(reqwest::header::CONTENT_TYPE, "application/json") .json(&json_body) .send() .await } #[tracing::instrument(skip_all)] async fn close_connection(outgoing: WebsocketSink) { if let Err(e) = outgoing.lock().await.close().await { error!( errorType = error_types::SEARCH_LOG, "Error closing connection: {}", e ); } } #[tracing::instrument(skip_all)] async fn handle_prefix_search( request_id: &str, prefix_request: identity_search_messages::IdentitySearchPrefix, ) -> Result { let prefix_query = Query { size: IDENTITY_SEARCH_RESULT_SIZE, query: Prefix { prefix: Username { username: prefix_request.prefix.trim().to_string(), }, }, }; let opensearch_url = format!( "https://{}/{}/_search/", &CONFIG.opensearch_endpoint, IDENTITY_SEARCH_INDEX ); let search_response = send_search_request(&opensearch_url, prefix_query) .await? .json::>() .await?; let usernames: Vec = search_response .hits .inner .into_iter() .filter_map(|hit| hit.source) .collect(); let search_result = IdentitySearchResult { id: request_id.to_string(), hits: usernames, }; Ok(search_result) } #[tracing::instrument(skip_all)] async fn handle_websocket_frame( text: String, outgoing: WebsocketSink, ) -> Result<(), errors::WebsocketError> { let Ok(serialized_message) = serde_json::from_str::(&text) else { return Err(errors::WebsocketError::SerializationError); }; match serialized_message { MessagesToServer::Heartbeat(Heartbeat {}) => { debug!("Received heartbeat"); Ok(()) } MessagesToServer::IdentitySearchQuery(search_query) => { let handler_result = match search_query.search_method { IdentitySearchMethod::IdentitySearchPrefix(prefix_query) => { handle_prefix_search(&search_query.id, prefix_query).await } }; let search_response = match handler_result { Ok(search_result) => IdentitySearchResponse::Success(search_result), Err(e) => IdentitySearchResponse::Error(IdentitySearchFailure { id: search_query.id, error: e.to_string(), }), }; let serialized_message = serde_json::to_string(&search_response)?; send_message(Message::Text(serialized_message), outgoing.clone()).await; Ok(()) } _ => Err(errors::WebsocketError::InvalidMessage), } } #[tracing::instrument(skip_all)] async fn accept_connection(hyper_ws: HyperWebsocket, addr: SocketAddr) { debug!("Incoming WebSocket connection from {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { error!( errorType = error_types::SEARCH_LOG, "WebSocket handshake error: {}", e ); return; } }; let (outgoing, mut incoming) = ws_stream.split(); let outgoing = Arc::new(Mutex::new(outgoing)); if let Some(Ok(auth_message)) = incoming.next().await { match auth_message { Message::Text(text) => { if let Err(auth_error) = auth::handle_auth_message(&text).await { let error_response = ConnectionInitializationResponse { status: ConnectionInitializationStatus::Error( auth_error.to_string(), ), }; let serialized_response = serde_json::to_string(&error_response) .expect("Error serializing auth error response"); send_message(Message::Text(serialized_response), outgoing.clone()) .await; close_connection(outgoing).await; return; } else { let success_response = ConnectionInitializationResponse { status: ConnectionInitializationStatus::Success, }; let serialized_response = serde_json::to_string(&success_response) .expect("Error serializing auth success response"); send_message(Message::Text(serialized_response), outgoing.clone()) .await; } } _ => { warn!("Invalid authentication message from {}", addr); close_connection(outgoing).await; return; } } } else { error!( errorType = error_types::SEARCH_LOG, "No authentication message from {}", addr ); close_connection(outgoing).await; return; } let mut ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); let mut got_heartbeat_response = true; loop { tokio::select! { client_message = incoming.next() => { let message: Message = match client_message { Some(Ok(msg)) => msg, _ => { debug!("Connection to {} closed remotely.", addr); break; } }; match message { Message::Close(_) => { debug!("Connection to {} closed.", addr); break; } Message::Pong(_) => { debug!("Received Pong message from {}", addr); } Message::Ping(msg) => { debug!("Received Ping message from {}", addr); send_message(Message::Pong(msg), outgoing.clone()).await; } Message::Text(text) => { got_heartbeat_response = true; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); if let Err(e) = handle_websocket_frame(text, outgoing.clone()).await { error!(errorType = error_types::SEARCH_LOG, "Error handling WebSocket frame: {}", e); continue; }; } _ => { error!(errorType = error_types::SEARCH_LOG, "Client sent invalid message type"); break; } } } _ = &mut ping_timeout => { if !got_heartbeat_response { debug!("Connection to {} died.", addr); break; } let serialized = serde_json::to_string(&Heartbeat {}).unwrap(); send_message(Message::text(serialized), outgoing.clone()).await; got_heartbeat_response = false; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); } else => { debug!("Unhealthy connection for: {}", addr); break; } } } info!("unregistering connection to: {}", addr); close_connection(outgoing).await; }