diff --git a/services/identity/src/constants.rs b/services/identity/src/constants.rs
--- a/services/identity/src/constants.rs
+++ b/services/identity/src/constants.rs
@@ -164,6 +164,10 @@
 pub const NONCE_LENGTH: usize = 17;
 pub const NONCE_TTL_DURATION: i64 = 30;
 
+// Identity
+
+pub const DEFAULT_IDENTITY_ENDPOINT: &str = "http://localhost:50054";
+
 // LocalStack
 
 pub const LOCALSTACK_ENDPOINT: &str = "LOCALSTACK_ENDPOINT";
diff --git a/services/identity/src/websockets/auth.rs b/services/identity/src/websockets/auth.rs
new file mode 100644
--- /dev/null
+++ b/services/identity/src/websockets/auth.rs
@@ -0,0 +1,98 @@
+//! Authentication handshake for incoming WebSocket connections.
+
+use client_proto::VerifyUserAccessTokenRequest;
+use grpc_clients::identity;
+use grpc_clients::tonic::Request;
+use identity::get_unauthenticated_client;
+use identity::protos::unauthenticated as client_proto;
+use serde::{Deserialize, Serialize};
+use tracing::{debug, error};
+
+use crate::constants::DEFAULT_IDENTITY_ENDPOINT;
+use crate::websockets::errors::WebsocketError;
+
+/// First message a client sends after connecting, carrying its credentials.
+///
+/// On the wire the fields are named `userID`, `deviceID` and `accessToken`.
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(tag = "type", rename_all = "camelCase")]
+pub struct AuthMessage {
+  #[serde(rename = "userID")]
+  pub user_id: String,
+  #[serde(rename = "deviceID")]
+  pub device_id: String,
+  pub access_token: String,
+}
+
+const PLACEHOLDER_CODE_VERSION: u64 = 0;
+const DEVICE_TYPE: &str = "service";
+
+/// Asks the identity gRPC service whether `access_token` is valid for the
+/// given user and device.
+///
+/// Returns the service's `token_valid` verdict. Failing to obtain a client or
+/// to complete the call is logged and mapped to `WebsocketError::AuthError`.
+async fn verify_user_access_token(
+  user_id: &str,
+  device_id: &str,
+  access_token: &str,
+) -> Result<bool, WebsocketError> {
+  let mut grpc_client = get_unauthenticated_client(
+    DEFAULT_IDENTITY_ENDPOINT,
+    PLACEHOLDER_CODE_VERSION,
+    DEVICE_TYPE.to_string(),
+  )
+  .await
+  .map_err(|e| {
+    error!("Failed to get unauthenticated client: {}", e);
+    WebsocketError::AuthError
+  })?;
+
+  let request = Request::new(VerifyUserAccessTokenRequest {
+    user_id: user_id.to_string(),
+    device_id: device_id.to_string(),
+    access_token: access_token.to_string(),
+  });
+
+  let response = grpc_client
+    .verify_user_access_token(request)
+    .await
+    .map_err(|e| {
+      error!("Failed to verify user access token: {}", e);
+      WebsocketError::AuthError
+    })?;
+
+  Ok(response.into_inner().token_valid)
+}
+
+/// Parses `message` as a JSON [`AuthMessage`] and verifies its access token.
+///
+/// # Errors
+///
+/// * `WebsocketError::InvalidMessage` if `message` fails to parse.
+/// * `WebsocketError::AuthError` if the token could not be verified.
+/// * `WebsocketError::UnauthorizedDevice` if the token was reported invalid.
+pub async fn handle_auth_message(message: &str) -> Result<(), WebsocketError> {
+  // Never log the raw message: it carries the caller's access token.
+  debug!("Handling auth message");
+
+  let AuthMessage {
+    user_id,
+    device_id,
+    access_token,
+  } = serde_json::from_str::<AuthMessage>(message.trim()).map_err(|e| {
+    error!("Failed to parse auth message: {}", e);
+    WebsocketError::InvalidMessage
+  })?;
+
+  let is_valid_token =
+    verify_user_access_token(&user_id, &device_id, &access_token).await?;
+
+  if !is_valid_token {
+    debug!("User {} not authenticated", user_id);
+    return Err(WebsocketError::UnauthorizedDevice);
+  }
+
+  debug!("User {} authenticated", user_id);
+  Ok(())
+}
diff --git a/services/identity/src/websockets/errors.rs b/services/identity/src/websockets/errors.rs
--- a/services/identity/src/websockets/errors.rs
+++ b/services/identity/src/websockets/errors.rs
@@ -5,7 +5,9 @@
 )]
 pub enum WebsocketError {
   InvalidMessage,
+  UnauthorizedDevice,
   SendError,
   SearchError,
+  AuthError,
   SerializationError,
 }
diff --git a/services/identity/src/websockets/mod.rs b/services/identity/src/websockets/mod.rs
--- a/services/identity/src/websockets/mod.rs
+++ b/services/identity/src/websockets/mod.rs
@@ -16,6 +16,8 @@
 use tokio::net::TcpListener;
 use tracing::{debug, error, info};
 
+mod auth;
+
 use crate::config::CONFIG;
 use crate::constants::IDENTITY_SERVICE_WEBSOCKET_ADDR;
 
@@ -166,6 +168,16 @@
   }
 }
 
+/// Closes the shared outgoing sink, logging (rather than propagating) any
+/// error.
+async fn close_connection(
+  outgoing: Arc<Mutex<SplitSink<WebSocketStream<Upgraded>, Message>>>,
+) {
+  if let Err(e) = outgoing.lock().await.close().await {
+    error!("Error closing connection: {}", e);
+  }
+}
+
 async fn accept_connection(hyper_ws: HyperWebsocket, addr: SocketAddr) {
   debug!("Incoming WebSocket connection from {}", addr);
 
@@ -184,6 +196,28 @@
   let opensearch_url =
     format!("https://{}/users/_search/", &CONFIG.opensearch_endpoint);
 
+  // The first frame on every connection must be a text auth message.
+  if let Some(Ok(auth_message)) = incoming.next().await {
+    match auth_message {
+      Message::Text(text) => {
+        if let Err(auth_error) = auth::handle_auth_message(&text).await {
+          send_error_response(auth_error, outgoing.clone()).await;
+          close_connection(outgoing).await;
+          return;
+        }
+      }
+      _ => {
+        error!("Invalid authentication message from {}", addr);
+        close_connection(outgoing).await;
+        return;
+      }
+    }
+  } else {
+    error!("No authentication message from {}", addr);
+    close_connection(outgoing).await;
+    return;
+  }
+
   while let Some(message) = incoming.next().await {
     match message {
       Ok(Message::Close(_)) => {
@@ -296,7 +330,5 @@
     }
   }
 
-  if let Err(e) = outgoing.lock().await.close().await {
-    error!("Failed to close WebSocket connection: {}", e);
-  };
+  close_connection(outgoing).await;
 }