diff --git a/services/backup/src/http/handlers/log.rs b/services/backup/src/http/handlers/log.rs
--- a/services/backup/src/http/handlers/log.rs
+++ b/services/backup/src/http/handlers/log.rs
@@ -8,7 +8,7 @@
   Error, HttpRequest, HttpResponse,
 };
 use actix_web_actors::ws::{self, WebsocketContext};
-use comm_lib::auth::UserIdentity;
+use comm_lib::auth::{AuthService, AuthServiceError, UserIdentity};
 use comm_lib::{
   backup::{
     DownloadLogsRequest, LogWSRequest, LogWSResponse, UploadLogRequest,
@@ -20,6 +20,7 @@
   database::{self, blob::BlobOrDBContent},
 };
 use std::future::Future;
+use std::sync::{Arc, Mutex};
 use std::time::{Duration, Instant};
 use tracing::{error, info, instrument, warn};
 
@@ -28,12 +29,14 @@
   stream: web::Payload,
   blob_client: web::Data<BlobServiceClient>,
   db_client: web::Data<DatabaseClient>,
+  auth_service: AuthService,
 ) -> Result<HttpResponse, Error> {
   ws::WsResponseBuilder::new(
     LogWSActor {
-      user: None,
+      user: Arc::new(Mutex::new(None)),
       blob_client: blob_client.as_ref().clone(),
       db_client: db_client.as_ref().clone(),
+      auth_service,
       last_msg_time: Instant::now(),
       buffer: BytesMut::new(),
     },
@@ -51,13 +54,16 @@
   Bincode(bincode::Error),
   Blob(BlobServiceError),
   DB(database::Error),
+  Auth(AuthServiceError),
 }
 
 struct LogWSActor {
-  user: Option<UserIdentity>,
+  /// Identity verified for this connection. It is written by the future
+  /// spawned for `handle_auth_msg`, hence the shared `Arc<Mutex<..>>`.
+  user: Arc<Mutex<Option<UserIdentity>>>,
   blob_client: BlobServiceClient,
   db_client: DatabaseClient,
-
+  auth_service: AuthService,
   last_msg_time: Instant,
   buffer: BytesMut,
 }
@@ -74,11 +80,23 @@
     match bincode::deserialize(&bytes) {
       Ok(request) => {
         if let LogWSRequest::Authenticate(user) = request {
-          self.user.replace(user);
+          // Verification runs asynchronously. Until the spawned future
+          // stores its result, `self.user` keeps its previous value, so on
+          // a fresh connection requests sent before `AuthSuccess` arrives
+          // are answered with `Unauthenticated`.
+          Self::spawn_response_future(
+            ctx,
+            Self::handle_auth_msg(
+              self.auth_service.clone(),
+              Arc::clone(&self.user),
+              user,
+            ),
+          );
           return;
         }
 
-        let Some(user) = &self.user else {
+        let user_guard = self.user.lock().expect("user mutex poisoned");
+        let Some(user) = user_guard.as_ref() else {
           Self::spawn_response_future(
             ctx,
             ready(Ok(vec![LogWSResponse::Unauthenticated])),
@@ -139,6 +157,36 @@
     ctx.spawn(fut);
   }
 
+  /// Verifies the identity sent in a `LogWSRequest::Authenticate` message
+  /// and records the outcome in `current_user`.
+  ///
+  /// On success the identity is stored and `AuthSuccess` is returned. On
+  /// rejection the slot is cleared and `Unauthenticated` is returned, so a
+  /// connection that re-authenticates with bad credentials cannot keep
+  /// acting as a previously verified user. Errors from the auth service are
+  /// propagated with `?`.
+  async fn handle_auth_msg(
+    auth_service: AuthService,
+    current_user: Arc<Mutex<Option<UserIdentity>>>,
+    user_to_verify: UserIdentity,
+  ) -> Result<Vec<LogWSResponse>, LogWSError> {
+    use comm_lib::auth::AuthorizationCredential;
+
+    let credential = AuthorizationCredential::UserToken(user_to_verify.clone());
+    let user_valid = auth_service.verify_auth_credential(&credential).await?;
+
+    // Lock only after the await so the guard is never held across it.
+    let mut user_slot = current_user.lock().expect("mutex poisoned");
+    if user_valid {
+      *user_slot = Some(user_to_verify);
+      Ok(vec![LogWSResponse::AuthSuccess])
+    } else {
+      tracing::debug!("Invalid credentials");
+      *user_slot = None;
+      Ok(vec![LogWSResponse::Unauthenticated])
+    }
+  }
+
   async fn handle_msg(
     user_id: String,
     blob_client: BlobServiceClient,
diff --git a/services/commtest/tests/backup_integration_test.rs b/services/commtest/tests/backup_integration_test.rs
--- a/services/commtest/tests/backup_integration_test.rs
+++ b/services/commtest/tests/backup_integration_test.rs
@@ -34,10 +34,7 @@
     .upload_backup(&user_identity, backup_data.clone())
     .await?;
 
-  let (tx, rx) = backup_client.upload_logs(&user_identity).await?;
-
-  tokio::pin!(tx);
-  tokio::pin!(rx);
+  let (mut tx, rx) = backup_client.upload_logs(&user_identity).await?;
 
   for log_data in log_datas {
     tx.send(log_data.clone()).await?;
diff --git a/shared/backup_client/src/lib.rs b/shared/backup_client/src/lib.rs
--- a/shared/backup_client/src/lib.rs
+++ b/shared/backup_client/src/lib.rs
@@ -205,8 +205,5 @@
     try_stream! {
-      let (tx, rx) = self.create_log_ws_connection(user_identity).await?;
-
-      let mut tx = Box::pin(tx);
-      let mut rx = Box::pin(rx);
+      let (mut tx, mut rx) = self.create_log_ws_connection(user_identity).await?;
 
       tx.send(DownloadLogsRequest {
         backup_id: backup_id.to_string(),
@@ -291,6 +288,19 @@
       }
     });
 
+    let tx = Box::pin(tx);
+    let mut rx = Box::pin(rx);
+
+    // The first response must be the server's reply to authentication;
+    // anything else, including the socket closing, is an error.
+    match rx.try_next().await? {
+      Some(LogWSResponse::AuthSuccess) => {}
+      Some(LogWSResponse::Unauthenticated) => Err(Error::Unauthenticated)?,
+      Some(msg) => Err(Error::InvalidBackupMessage(msg))?,
+      // The socket closed before authentication was confirmed.
+      None => Err(Error::WSClosed)?,
+    }
+
     Ok((tx, rx))
   }
 
@@ -366,6 +376,7 @@
   ServerError,
   LogMissing,
   WSClosed,
+  Unauthenticated,
 }
 
 impl std::error::Error for Error {}
diff --git a/shared/comm-lib/src/backup/mod.rs b/shared/comm-lib/src/backup/mod.rs
--- a/shared/comm-lib/src/backup/mod.rs
+++ b/shared/comm-lib/src/backup/mod.rs
@@ -47,4 +47,7 @@
   },
   ServerError,
   Unauthenticated,
+  // Keep new variants at the end: index-based encodings such as bincode
+  // would otherwise renumber every variant that follows the insertion.
+  AuthSuccess,
 }