Review notes (prose only; git apply ignores text before the first
"diff --git" header, and the patch below is unchanged).

Summary: authentication of the backup log websocket moves from the
HTTP upgrade request to an in-band first message,
`LogWSRequest::Authenticate(UserIdentity)`. The old mechanism was
the `.wrap(get_comm_authentication_middleware())` on the `/logs`
scope plus the `Authorization: Bearer` header set in backup_client;
this patch removes both.

SECURITY(review): in `LogWSActor::handle_msg_sync`, an `Authenticate`
message is accepted with `self.user.replace(user); return;`. No
verification call is visible anywhere in this patch. Every later
request is then executed as `user.user_id`, a value decoded by
`bincode::deserialize(&bytes)` straight from client-supplied bytes.
Unless the identity is verified somewhere not shown here, any client
can act as any user by sending a forged `UserIdentity`. The removed
middleware presumably performed this check (TODO confirm). The
credentials should be verified the same way before `self.user` is
set, replying `LogWSResponse::Unauthenticated` on failure.

NOTE(review): a second `Authenticate` on the same connection silently
replaces the current identity. Consider rejecting it once `self.user`
is already set.

NOTE(review): a successful `Authenticate` sends no response. The
client sends it right after `connect_async` and cannot tell whether
it was accepted. If the message is lost or fails to decode, the
client only finds out from `ServerError` / `Unauthenticated`
responses to its later requests.

NOTE(review): `LogWSRequest` derives `Debug` and now carries a
`UserIdentity`. If that type's `Debug` output includes the access
token, avoid logging whole requests with `{:?}` (verify).

NOTE: in this copy of the patch, line breaks, indentation and the
contents of angle brackets (e.g. `user: Option,`) appear to have been
lost, so it will not apply as-is; regenerate it from the branch.

diff --git a/services/backup/src/http/handlers/log.rs b/services/backup/src/http/handlers/log.rs --- a/services/backup/src/http/handlers/log.rs +++ b/services/backup/src/http/handlers/log.rs @@ -1,5 +1,6 @@ use crate::constants::WS_FRAME_SIZE; use crate::database::{log_item::LogItem, DatabaseClient}; +use actix::fut::ready; use actix::{Actor, ActorContext, ActorFutureExt, AsyncContext, StreamHandler}; use actix_http::ws::{CloseCode, Item}; use actix_web::{ @@ -18,19 +19,19 @@ }, database::{self, blob::BlobOrDBContent}, }; +use std::future::Future; use std::time::{Duration, Instant}; use tracing::{error, info, instrument, warn}; pub async fn handle_ws( req: HttpRequest, - user: UserIdentity, stream: web::Payload, blob_client: web::Data, db_client: web::Data, ) -> Result { ws::WsResponseBuilder::new( LogWSActor { - user, + user: None, blob_client: blob_client.as_ref().clone(), db_client: db_client.as_ref().clone(), last_msg_time: Instant::now(), @@ -53,7 +54,7 @@ } struct LogWSActor { - user: UserIdentity, + user: Option, blob_client: BlobServiceClient, db_client: DatabaseClient, @@ -66,18 +67,51 @@ const CONNECTION_TIMEOUT: Duration = Duration::from_secs(10); fn handle_msg_sync( - &self, + &mut self, ctx: &mut WebsocketContext, bytes: Bytes, ) { - let fut = Self::handle_msg( - self.user.user_id.clone(), - self.blob_client.clone(), - self.db_client.clone(), - bytes, - ); + match bincode::deserialize(&bytes) { + Ok(request) => { + if let LogWSRequest::Authenticate(user) = request { + self.user.replace(user); + return; + } + + let Some(user) = &self.user else { + Self::spawn_response_future( + ctx, + ready(Ok(vec![LogWSResponse::Unauthenticated])), + ); + return; + }; + + Self::spawn_response_future( + ctx, + Self::handle_msg( + user.user_id.clone(), + self.blob_client.clone(), + self.db_client.clone(), + request, + ), + ); + } + Err(err) => { + error!("Error: {err:?}"); + + Self::spawn_response_future( + ctx, + ready(Ok(vec![LogWSResponse::ServerError])), + ); + } + 
}; + } - let fut = actix::fut::wrap_future(fut).map( + fn spawn_response_future( + ctx: &mut WebsocketContext, + future: impl Future, LogWSError>> + 'static, + ) { + let fut = actix::fut::wrap_future(future).map( |responses, _: &mut LogWSActor, ctx: &mut WebsocketContext| { @@ -109,10 +143,8 @@ user_id: String, blob_client: BlobServiceClient, db_client: DatabaseClient, - bytes: Bytes, + request: LogWSRequest, ) -> Result, LogWSError> { - let request = bincode::deserialize(&bytes)?; - match request { LogWSRequest::UploadLog(UploadLogRequest { backup_id, @@ -180,6 +212,10 @@ Ok(messages) } + LogWSRequest::Authenticate(_) => { + warn!("LogWSRequest::Authenticate should have been handled earlier."); + Ok(Vec::new()) + } } } diff --git a/services/backup/src/http/mod.rs b/services/backup/src/http/mod.rs --- a/services/backup/src/http/mod.rs +++ b/services/backup/src/http/mod.rs @@ -63,7 +63,6 @@ ) .service( web::scope("/logs") - .wrap(get_comm_authentication_middleware()) .service(web::resource("").route(web::get().to(handle_ws))), ) }) diff --git a/shared/backup_client/src/lib.rs b/shared/backup_client/src/lib.rs --- a/shared/backup_client/src/lib.rs +++ b/shared/backup_client/src/lib.rs @@ -18,8 +18,6 @@ use tokio_tungstenite::{ connect_async, tungstenite::{ - client::IntoClientRequest, - http::{header, Request}, Error as TungsteniteError, Message::{Binary, Ping}, }, @@ -254,14 +252,19 @@ ), Error, > { - let request = self.create_ws_request(user_identity)?; - let (stream, response) = connect_async(request).await?; + let url = self.create_ws_url()?; + let (stream, response) = connect_async(url).await?; if response.status().is_client_error() { return Err(Error::TungsteniteError(TungsteniteError::Http(response))); } - let (tx, rx) = stream.split(); + let (mut tx, rx) = stream.split(); + + tx.send(Binary(bincode::serialize(&LogWSRequest::Authenticate( + user_identity.clone(), + ))?)) + .await?; let tx = tx.with(|request: Request| async { let request: LogWSRequest = 
request.into(); @@ -287,10 +290,7 @@ Ok((tx, rx)) } - fn create_ws_request( - &self, - user_identity: &UserIdentity, - ) -> Result, Error> { + fn create_ws_url(&self) -> Result { let mut url = self.url.clone(); match url.scheme() { @@ -300,14 +300,7 @@ }; let url = url.join("logs")?; - let mut request = url.into_client_request().unwrap(); - - let token = user_identity.as_authorization_token()?; - request - .headers_mut() - .insert(header::AUTHORIZATION, format!("Bearer {token}").parse()?); - - Ok(request) + Ok(url) } } diff --git a/shared/comm-lib/src/backup/mod.rs b/shared/comm-lib/src/backup/mod.rs --- a/shared/comm-lib/src/backup/mod.rs +++ b/shared/comm-lib/src/backup/mod.rs @@ -1,3 +1,4 @@ +use crate::auth::UserIdentity; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -22,6 +23,7 @@ #[derive(Debug, Clone, Serialize, Deserialize, derive_more::From)] pub enum LogWSRequest { + Authenticate(UserIdentity), UploadLog(UploadLogRequest), DownloadLogs(DownloadLogsRequest), } @@ -41,4 +43,5 @@ last_log_id: Option, }, ServerError, + Unauthenticated, }