diff --git a/services/identity/src/http/mod.rs b/services/identity/src/http/mod.rs
new file mode 100644
index 000000000..4efa0b144
--- /dev/null
+++ b/services/identity/src/http/mod.rs
@@ -0,0 +1,23 @@
+use http::StatusCode;
+use hyper::{Body, Request, Response};
+
+type HttpRequest = Request
;
+type HttpResponse = Response;
+
+/// Main router for HTTP requests
+#[tracing::instrument(skip_all, name = "http_request", fields(request_id))]
+pub(super) async fn handle_http_request(
+ req: HttpRequest,
+ _db_client: crate::DatabaseClient,
+) -> Result {
+ tracing::Span::current()
+ .record("request_id", uuid::Uuid::new_v4().to_string());
+
+ let response = match req.uri().path() {
+ "/health" => Response::new(Body::from("OK")),
+ _ => Response::builder()
+ .status(StatusCode::NOT_FOUND)
+ .body(Body::from("Not found"))?,
+ };
+ Ok(response)
+}
diff --git a/services/identity/src/main.rs b/services/identity/src/main.rs
index 79a8c6cdf..19c1cec36 100644
--- a/services/identity/src/main.rs
+++ b/services/identity/src/main.rs
@@ -1,120 +1,122 @@
use comm_lib::aws;
use config::Command;
use database::DatabaseClient;
use tonic::transport::Server;
use tonic_web::GrpcWebLayer;
mod client_service;
mod config;
pub mod constants;
mod cors;
mod database;
pub mod ddb_utils;
mod device_list;
pub mod error;
mod grpc_services;
mod grpc_utils;
+mod http;
mod id;
mod keygen;
mod nonce;
mod olm;
mod regex;
mod reserved_users;
mod siwe;
mod sync_identity_search;
mod token;
mod tunnelbroker;
mod websockets;
use constants::{COMM_SERVICES_USE_JSON_LOGS, IDENTITY_SERVICE_SOCKET_ADDR};
use cors::cors_layer;
use keygen::generate_and_persist_keypair;
use std::env;
use sync_identity_search::sync_index;
use tracing::{self, info, Level};
use tracing_subscriber::EnvFilter;
use client_service::{ClientService, IdentityClientServiceServer};
use grpc_services::authenticated::AuthenticatedService;
use grpc_services::protos::auth::identity_client_service_server::IdentityClientServiceServer as AuthServer;
use websockets::errors::BoxedError;
#[tokio::main]
async fn main() -> Result<(), BoxedError> {
let filter = EnvFilter::builder()
.with_default_directive(Level::INFO.into())
.with_env_var(EnvFilter::DEFAULT_ENV)
.from_env_lossy();
let use_json_logs: bool = env::var(COMM_SERVICES_USE_JSON_LOGS)
.unwrap_or("false".to_string())
.parse()
.unwrap_or_default();
if use_json_logs {
let subscriber = tracing_subscriber::fmt()
.json()
.with_env_filter(filter)
.finish();
tracing::subscriber::set_global_default(subscriber)?;
} else {
let subscriber = tracing_subscriber::fmt().with_env_filter(filter).finish();
tracing::subscriber::set_global_default(subscriber)?;
}
match config::parse_cli_command() {
Command::Keygen { dir } => {
generate_and_persist_keypair(dir)?;
}
Command::Server => {
config::load_server_config();
let addr = IDENTITY_SERVICE_SOCKET_ADDR.parse()?;
let aws_config = aws::config::from_env().region("us-east-2").load().await;
let database_client = DatabaseClient::new(&aws_config);
let inner_client_service = ClientService::new(database_client.clone());
let client_service = IdentityClientServiceServer::with_interceptor(
inner_client_service,
grpc_services::shared::version_interceptor,
);
let inner_auth_service =
AuthenticatedService::new(database_client.clone());
+ let db_client = database_client.clone();
let auth_service =
AuthServer::with_interceptor(inner_auth_service, move |req| {
- grpc_services::authenticated::auth_interceptor(req, &database_client)
+ grpc_services::authenticated::auth_interceptor(req, &db_client)
.and_then(grpc_services::shared::version_interceptor)
});
info!("Listening to gRPC traffic on {}", addr);
let grpc_server = Server::builder()
.accept_http1(true)
.layer(cors_layer())
.layer(GrpcWebLayer::new())
.trace_fn(|_| {
tracing::info_span!(
"grpc_request",
request_id = uuid::Uuid::new_v4().to_string()
)
})
.add_service(client_service)
.add_service(auth_service)
.serve(addr);
- let websocket_server = websockets::run_server();
+ let websocket_server = websockets::run_server(database_client);
return tokio::select! {
websocket_result = websocket_server => websocket_result,
grpc_result = grpc_server => { grpc_result.map_err(|e| e.into()) },
};
}
Command::SyncIdentitySearch => {
let aws_config = aws::config::from_env().region("us-east-2").load().await;
let database_client = DatabaseClient::new(&aws_config);
let sync_result = sync_index(&database_client).await;
error::consume_error(sync_result);
}
}
Ok(())
}
diff --git a/services/identity/src/websockets/mod.rs b/services/identity/src/websockets/mod.rs
index 1041deda2..c6d576fa5 100644
--- a/services/identity/src/websockets/mod.rs
+++ b/services/identity/src/websockets/mod.rs
@@ -1,354 +1,354 @@
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use futures::lock::Mutex;
use futures_util::{SinkExt, StreamExt};
-use hyper::{Body, Request, Response, StatusCode};
+use hyper::{Body, Request, Response};
use hyper_tungstenite::tungstenite::Message;
use hyper_tungstenite::HyperWebsocket;
use identity_search_messages::{
ConnectionInitializationResponse, ConnectionInitializationStatus, Heartbeat,
IdentitySearchFailure, IdentitySearchMethod, IdentitySearchResponse,
IdentitySearchResult, IdentitySearchUser, MessagesToServer,
};
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tracing::{debug, error, info, warn};
mod auth;
mod opensearch;
mod send;
use crate::config::CONFIG;
use crate::constants::{
error_types, IDENTITY_SEARCH_INDEX, IDENTITY_SEARCH_RESULT_SIZE,
IDENTITY_SERVICE_WEBSOCKET_ADDR, SOCKET_HEARTBEAT_TIMEOUT,
};
use opensearch::OpenSearchResponse;
use send::{send_message, WebsocketSink};
pub mod errors;
#[derive(Serialize, Deserialize)]
struct Query {
size: u32,
query: Prefix,
}
#[derive(Serialize, Deserialize)]
struct Prefix {
prefix: Username,
}
#[derive(Serialize, Deserialize)]
struct Username {
username: String,
}
struct WebsocketService {
addr: SocketAddr,
+ db_client: crate::DatabaseClient,
}
impl hyper::service::Service> for WebsocketService {
type Response = Response;
type Error = errors::BoxedError;
type Future =
Pin> + Send>>;
fn poll_ready(
&mut self,
_: &mut std::task::Context<'_>,
) -> std::task::Poll> {
std::task::Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request) -> Self::Future {
let addr = self.addr;
+ let db_client = self.db_client.clone();
let future = async move {
debug!(
"Incoming HTTP request on WebSocket port: {} {}",
req.method(),
req.uri().path()
);
if hyper_tungstenite::is_upgrade_request(&req) {
let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?;
tokio::spawn(async move {
accept_connection(websocket, addr).await;
});
return Ok(response);
}
- let response = match req.uri().path() {
- "/health" => Response::new(Body::from("OK")),
- _ => Response::builder()
- .status(StatusCode::NOT_FOUND)
- .body(Body::from("Not found"))?,
- };
- Ok(response)
+ // If not a websocker upgrade, treat it as regular HTTP request
+ crate::http::handle_http_request(req, db_client).await
};
Box::pin(future)
}
}
#[tracing::instrument(skip_all)]
-pub async fn run_server() -> Result<(), errors::BoxedError> {
+pub async fn run_server(
+ db_client: crate::DatabaseClient,
+) -> Result<(), errors::BoxedError> {
let addr: SocketAddr = IDENTITY_SERVICE_WEBSOCKET_ADDR.parse()?;
let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
info!("Listening to WebSocket traffic on {}", addr);
let mut http = hyper::server::conn::Http::new();
http.http1_only(true);
http.http1_keep_alive(true);
while let Ok((stream, addr)) = listener.accept().await {
+ let db_client = db_client.clone();
let connection = http
- .serve_connection(stream, WebsocketService { addr })
+ .serve_connection(stream, WebsocketService { addr, db_client })
.with_upgrades();
tokio::spawn(async move {
if let Err(err) = connection.await {
error!(
errorType = error_types::SEARCH_LOG,
"Error serving HTTP/WebSocket connection: {:?}", err
);
}
});
}
Ok(())
}
#[tracing::instrument(skip_all)]
async fn send_search_request(
url: &str,
json_body: T,
) -> Result {
let client = reqwest::Client::new();
client
.post(url)
.header(reqwest::header::CONTENT_TYPE, "application/json")
.json(&json_body)
.send()
.await
}
#[tracing::instrument(skip_all)]
async fn close_connection(outgoing: WebsocketSink) {
if let Err(e) = outgoing.lock().await.close().await {
error!(
errorType = error_types::SEARCH_LOG,
"Error closing connection: {}", e
);
}
}
#[tracing::instrument(skip_all)]
async fn handle_prefix_search(
request_id: &str,
prefix_request: identity_search_messages::IdentitySearchPrefix,
) -> Result {
let prefix_query = Query {
size: IDENTITY_SEARCH_RESULT_SIZE,
query: Prefix {
prefix: Username {
username: prefix_request.prefix.trim().to_string(),
},
},
};
let opensearch_url = format!(
"https://{}/{}/_search/",
&CONFIG.opensearch_endpoint, IDENTITY_SEARCH_INDEX
);
let search_response = send_search_request(&opensearch_url, prefix_query)
.await?
.json::>()
.await?;
let usernames: Vec = search_response
.hits
.inner
.into_iter()
.filter_map(|hit| hit.source)
.collect();
let search_result = IdentitySearchResult {
id: request_id.to_string(),
hits: usernames,
};
Ok(search_result)
}
#[tracing::instrument(skip_all)]
async fn handle_websocket_frame(
text: String,
outgoing: WebsocketSink,
) -> Result<(), errors::WebsocketError> {
let Ok(serialized_message) = serde_json::from_str::(&text)
else {
return Err(errors::WebsocketError::SerializationError);
};
match serialized_message {
MessagesToServer::Heartbeat(Heartbeat {}) => {
debug!("Received heartbeat");
Ok(())
}
MessagesToServer::IdentitySearchQuery(search_query) => {
let handler_result = match search_query.search_method {
IdentitySearchMethod::IdentitySearchPrefix(prefix_query) => {
handle_prefix_search(&search_query.id, prefix_query).await
}
};
let search_response = match handler_result {
Ok(search_result) => IdentitySearchResponse::Success(search_result),
Err(e) => IdentitySearchResponse::Error(IdentitySearchFailure {
id: search_query.id,
error: e.to_string(),
}),
};
let serialized_message = serde_json::to_string(&search_response)?;
send_message(Message::Text(serialized_message), outgoing.clone()).await;
Ok(())
}
_ => Err(errors::WebsocketError::InvalidMessage),
}
}
#[tracing::instrument(skip_all)]
async fn accept_connection(hyper_ws: HyperWebsocket, addr: SocketAddr) {
debug!("Incoming WebSocket connection from {}", addr);
let ws_stream = match hyper_ws.await {
Ok(stream) => stream,
Err(e) => {
error!(
errorType = error_types::SEARCH_LOG,
"WebSocket handshake error: {}", e
);
return;
}
};
let (outgoing, mut incoming) = ws_stream.split();
let outgoing = Arc::new(Mutex::new(outgoing));
if let Some(Ok(auth_message)) = incoming.next().await {
match auth_message {
Message::Text(text) => {
if let Err(auth_error) = auth::handle_auth_message(&text).await {
let error_response = ConnectionInitializationResponse {
status: ConnectionInitializationStatus::Error(
auth_error.to_string(),
),
};
let serialized_response = serde_json::to_string(&error_response)
.expect("Error serializing auth error response");
send_message(Message::Text(serialized_response), outgoing.clone())
.await;
close_connection(outgoing).await;
return;
} else {
let success_response = ConnectionInitializationResponse {
status: ConnectionInitializationStatus::Success,
};
let serialized_response = serde_json::to_string(&success_response)
.expect("Error serializing auth success response");
send_message(Message::Text(serialized_response), outgoing.clone())
.await;
}
}
_ => {
warn!("Invalid authentication message from {}", addr);
close_connection(outgoing).await;
return;
}
}
} else {
error!(
errorType = error_types::SEARCH_LOG,
"No authentication message from {}", addr
);
close_connection(outgoing).await;
return;
}
let mut ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT));
let mut got_heartbeat_response = true;
loop {
tokio::select! {
client_message = incoming.next() => {
let message: Message = match client_message {
Some(Ok(msg)) => msg,
_ => {
debug!("Connection to {} closed remotely.", addr);
break;
}
};
match message {
Message::Close(_) => {
debug!("Connection to {} closed.", addr);
break;
}
Message::Pong(_) => {
debug!("Received Pong message from {}", addr);
}
Message::Ping(msg) => {
debug!("Received Ping message from {}", addr);
send_message(Message::Pong(msg), outgoing.clone()).await;
}
Message::Text(text) => {
got_heartbeat_response = true;
ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT));
if let Err(e) = handle_websocket_frame(text, outgoing.clone()).await {
error!(errorType = error_types::SEARCH_LOG, "Error handling WebSocket frame: {}", e);
continue;
};
}
_ => {
error!(errorType = error_types::SEARCH_LOG, "Client sent invalid message type");
break;
}
}
}
_ = &mut ping_timeout => {
if !got_heartbeat_response {
debug!("Connection to {} died.", addr);
break;
}
let serialized = serde_json::to_string(&Heartbeat {}).unwrap();
send_message(Message::text(serialized), outgoing.clone()).await;
got_heartbeat_response = false;
ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT));
}
else => {
debug!("Unhealthy connection for: {}", addr);
break;
}
}
}
info!("unregistering connection to: {}", addr);
close_connection(outgoing).await;
}