diff --git a/services/identity/src/websockets/auth.rs b/services/identity/src/websockets/auth.rs --- a/services/identity/src/websockets/auth.rs +++ b/services/identity/src/websockets/auth.rs @@ -1,89 +1,44 @@ -use client_proto::VerifyUserAccessTokenRequest; -use grpc_clients::identity::{self, PlatformMetadata}; -use grpc_clients::tonic::Request; -use identity::get_unauthenticated_client; -use identity::protos::unauthenticated as client_proto; use identity_search_messages::IdentitySearchAuthMessage; use tracing::{debug, error}; -use crate::constants::{error_types, DEFAULT_IDENTITY_ENDPOINT}; +use crate::constants::error_types; use crate::websockets::errors::WebsocketError; -const PLACEHOLDER_CODE_VERSION: u64 = 0; -const DEVICE_TYPE: &str = "service"; - #[tracing::instrument(skip_all)] -async fn verify_user_access_token( - user_id: &str, - device_id: &str, - access_token: &str, -) -> Result { - let grpc_client = get_unauthenticated_client( - DEFAULT_IDENTITY_ENDPOINT, - PlatformMetadata::new(PLACEHOLDER_CODE_VERSION, DEVICE_TYPE), - ) - .await; - - let mut grpc_client = match grpc_client { - Ok(grpc_client) => grpc_client, - Err(e) => { - error!( - errorType = error_types::SEARCH_LOG, - "Failed to get unauthenticated client: {}", e - ); - return Err(WebsocketError::AuthError); - } +pub async fn handle_auth_message( + db_client: &crate::DatabaseClient, + message: &str, +) -> Result<(), WebsocketError> { + let Ok(auth_message) = serde_json::from_str(message.trim()) else { + error!( + errorType = error_types::SEARCH_LOG, + "Failed to parse auth message" + ); + return Err(WebsocketError::InvalidMessage); }; - let message = VerifyUserAccessTokenRequest { - user_id: user_id.to_string(), - device_id: device_id.to_string(), - access_token: access_token.to_string(), - }; + let IdentitySearchAuthMessage { + user_id, + device_id, + access_token, + } = auth_message; - let request = Request::new(message); - let response = match grpc_client.verify_user_access_token(request).await { - Ok(response) => response, - Err(_) => { + let is_valid_token = db_client + .verify_access_token(user_id.clone(), device_id, access_token) + .await + .map_err(|err| { error!( errorType = error_types::SEARCH_LOG, - "Failed to verify user access token" + "Failed to verify user access token: {:?}", err ); - return Err(WebsocketError::AuthError); - } - }; - - Ok(response.into_inner().token_valid) -} - -#[tracing::instrument(skip_all)] -pub async fn handle_auth_message(message: &str) -> Result<(), WebsocketError> { - let auth_message = serde_json::from_str(message.trim()); - - let auth_message: IdentitySearchAuthMessage = match auth_message { - Ok(auth_message) => auth_message, - Err(_) => { - error!( - errorType = error_types::SEARCH_LOG, - "Failed to parse auth message" - ); - return Err(WebsocketError::InvalidMessage); - } - }; - - let user_id = auth_message.user_id; - let device_id = auth_message.device_id; - let access_token = auth_message.access_token; - - let is_valid_token = - verify_user_access_token(&user_id, &device_id, &access_token).await?; + WebsocketError::AuthError + })?; if is_valid_token { debug!("User {} authenticated", user_id); + Ok(()) } else { debug!("User {} not authenticated", user_id); - return Err(WebsocketError::UnauthorizedDevice); + Err(WebsocketError::UnauthorizedDevice) } - - Ok(()) } diff --git a/services/identity/src/websockets/mod.rs b/services/identity/src/websockets/mod.rs --- a/services/identity/src/websockets/mod.rs +++ b/services/identity/src/websockets/mod.rs @@ -79,7 +79,7 @@ let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; tokio::spawn(async move { - accept_connection(websocket, addr).await; + accept_connection(websocket, addr, db_client).await; }); return Ok(response); @@ -229,7 +229,11 @@ } #[tracing::instrument(skip_all)] -async fn accept_connection(hyper_ws: HyperWebsocket, addr: SocketAddr) { +async fn accept_connection( + hyper_ws: HyperWebsocket, + addr: SocketAddr, + db_client: crate::DatabaseClient, +) { debug!("Incoming WebSocket connection from {}", addr); let ws_stream = match hyper_ws.await { @@ -250,7 +254,9 @@ if let Some(Ok(auth_message)) = incoming.next().await { match auth_message { Message::Text(text) => { - if let Err(auth_error) = auth::handle_auth_message(&text).await { + if let Err(auth_error) = + auth::handle_auth_message(&db_client, &text).await + { let error_response = ConnectionInitializationResponse { status: ConnectionInitializationStatus::Error( auth_error.to_string(),