diff --git a/native/native_rust_library/src/backup.rs b/native/native_rust_library/src/backup.rs --- a/native/native_rust_library/src/backup.rs +++ b/native/native_rust_library/src/backup.rs @@ -6,8 +6,8 @@ use crate::RUNTIME; use backup_client::{ BackupClient, BackupData, BackupDescriptor, DownloadLogsRequest, - LatestBackupIDResponse, LogWSResponse, RequestedData, SinkExt, StreamExt, - UploadLogRequest, UserIdentity, + LatestBackupIDResponse, LogUploadConfirmation, LogWSResponse, RequestedData, + SinkExt, StreamExt, UploadLogRequest, UserIdentity, }; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -85,21 +85,27 @@ .upload_backup(&user_identity, backup_data) .await?; - let (tx, rx) = backup_client - .upload_logs(&user_identity, &backup_id) - .await?; + let (tx, rx) = backup_client.upload_logs(&user_identity).await?; tokio::pin!(tx); tokio::pin!(rx); let log_data = UploadLogRequest { + backup_id: backup_id.clone(), log_id: 1, content: (1..100).collect(), attachments: None, }; tx.send(log_data.clone()).await?; match rx.next().await { - Some(Ok(1)) => (), + Some(Ok(LogUploadConfirmation { + backup_id: response_backup_id, + log_id: 1, + })) + if backup_id == response_backup_id => + { + // Correctly uploaded + } response => { return Err(Box::new(InvalidWSLogResponse(format!("{response:?}")))) } @@ -149,14 +155,16 @@ let user_data: serde_json::Value = serde_json::from_slice(&user_data)?; - let (tx, rx) = backup_client - .download_logs(&user_identity, &backup_id) - .await?; + let (tx, rx) = backup_client.download_logs(&user_identity).await?; tokio::pin!(tx); tokio::pin!(rx); - tx.send(DownloadLogsRequest { from_id: None }).await?; + tx.send(DownloadLogsRequest { + backup_id: backup_id.clone(), + from_id: None, + }) + .await?; match rx.next().await { Some(Ok(LogWSResponse::LogDownload { diff --git a/services/backup/src/http/handlers/log.rs b/services/backup/src/http/handlers/log.rs --- a/services/backup/src/http/handlers/log.rs +++ b/services/backup/src/http/handlers/log.rs @@ -17,23 +17,17 @@ }, database::{self, blob::BlobOrDBContent}, }; -use std::{ - sync::Arc, - time::{Duration, Instant}, -}; +use std::time::{Duration, Instant}; use tracing::{error, info, instrument, warn}; pub async fn handle_ws( - path: web::Path, req: HttpRequest, stream: web::Payload, blob_client: web::Data, db_client: web::Data, ) -> Result { - let backup_id = path.into_inner(); ws::WsResponseBuilder::new( LogWSActor { - info: Arc::new(ConnectionInfo { backup_id }), blob_client: blob_client.as_ref().clone(), db_client: db_client.as_ref().clone(), last_msg_time: Instant::now(), @@ -46,10 +40,6 @@ .start() } -struct ConnectionInfo { - backup_id: String, -} - #[derive( Debug, derive_more::From, derive_more::Display, derive_more::Error, )] @@ -60,7 +50,6 @@ } struct LogWSActor { - info: Arc, blob_client: BlobServiceClient, db_client: DatabaseClient, @@ -77,12 +66,8 @@ ctx: &mut WebsocketContext, bytes: Bytes, ) { - let fut = Self::handle_msg( - self.info.clone(), - self.blob_client.clone(), - self.db_client.clone(), - bytes, - ); + let fut = + Self::handle_msg(self.blob_client.clone(), self.db_client.clone(), bytes); let fut = actix::fut::wrap_future(fut).map( |responses, @@ -113,7 +98,6 @@ } async fn handle_msg( - info: Arc, blob_client: BlobServiceClient, db_client: DatabaseClient, bytes: Bytes, @@ -122,6 +106,7 @@ match request { LogWSRequest::UploadLog(UploadLogRequest { + backup_id, log_id, content, attachments, @@ -136,7 +121,7 @@ } let mut log_item = LogItem { - backup_id: info.backup_id.clone(), + backup_id: backup_id.clone(), log_id, content: BlobOrDBContent::new(content), attachments: attachment_blob_infos, @@ -145,11 +130,14 @@ log_item.ensure_size_constraints(&blob_client).await?; db_client.put_log_item(log_item).await?; - Ok(vec![LogWSResponse::LogUploaded { log_id }]) + Ok(vec![LogWSResponse::LogUploaded { backup_id, log_id }]) } - LogWSRequest::DownloadLogs(DownloadLogsRequest { from_id }) => { + LogWSRequest::DownloadLogs(DownloadLogsRequest { + backup_id, + from_id, + }) => { let (log_items, last_id) = - db_client.fetch_log_items(&info.backup_id, from_id).await?; + db_client.fetch_log_items(&backup_id, from_id).await?; let mut messages = vec![]; @@ -210,7 +198,7 @@ impl Actor for LogWSActor { type Context = ws::WebsocketContext; - #[instrument(skip_all, fields(backup_id = self.info.backup_id))] + #[instrument(skip_all)] fn started(&mut self, ctx: &mut Self::Context) { info!("Socket opened"); ctx.run_interval(Self::HEARTBEAT_INTERVAL, |actor, ctx| { @@ -226,14 +214,14 @@ }); } - #[instrument(skip_all, fields(backup_id = self.info.backup_id))] + #[instrument(skip_all)] fn stopped(&mut self, _: &mut Self::Context) { info!("Socket closed"); } } impl StreamHandler> for LogWSActor { - #[instrument(skip_all, fields(backup_id = self.info.backup_id))] + #[instrument(skip_all)] fn handle( &mut self, msg: Result, diff --git a/services/backup/src/http/mod.rs b/services/backup/src/http/mod.rs --- a/services/backup/src/http/mod.rs +++ b/services/backup/src/http/mod.rs @@ -64,7 +64,7 @@ .service( web::scope("/logs") .wrap(get_comm_authentication_middleware()) - .service(web::resource("{backup_d}").route(web::get().to(handle_ws))), + .service(web::resource("").route(web::get().to(handle_ws))), ) }) .bind(("0.0.0.0", CONFIG.http_port))? diff --git a/services/commtest/tests/backup_integration_test.rs b/services/commtest/tests/backup_integration_test.rs --- a/services/commtest/tests/backup_integration_test.rs +++ b/services/commtest/tests/backup_integration_test.rs @@ -2,7 +2,7 @@ use backup_client::{ BackupClient, BackupData, BackupDescriptor, Error as BackupClientError, - RequestedData, SinkExt, StreamExt, TryStreamExt, + LogUploadConfirmation, RequestedData, SinkExt, StreamExt, TryStreamExt, }; use bytesize::ByteSize; use comm_lib::{ @@ -37,10 +37,7 @@ .upload_backup(&user_identity, backup_data.clone()) .await?; - let (tx, rx) = backup_client - .upload_logs(&user_identity, &backup_data.backup_id) - .await - .unwrap(); + let (tx, rx) = backup_client.upload_logs(&user_identity).await.unwrap(); tokio::pin!(tx); tokio::pin!(rx); @@ -49,9 +46,16 @@ tx.send(log_data.clone()).await.unwrap(); } - let result: HashSet = + let result: HashSet = rx.take(log_datas.len()).try_collect().await.unwrap(); - let expected = log_datas.iter().map(|data| data.log_id).collect(); + let expected = log_datas + .iter() + .map(|data| LogUploadConfirmation { + backup_id: data.backup_id.clone(), + log_id: data.log_id, + }) + .collect(); + assert_eq!(result, expected); } @@ -92,17 +96,17 @@ assert_eq!(user_keys, backup_data.user_keys); // Test log download - let (tx, rx) = backup_client - .download_logs(&user_identity, &backup_data.backup_id) - .await - .unwrap(); + let (tx, rx) = backup_client.download_logs(&user_identity).await.unwrap(); tokio::pin!(tx); tokio::pin!(rx); - tx.send(DownloadLogsRequest { from_id: None }) - .await - .unwrap(); + tx.send(DownloadLogsRequest { + backup_id: backup_data.backup_id.clone(), + from_id: None, + }) + .await + .unwrap(); let mut downloaded_logs = HashMap::new(); 'download: loop { @@ -118,6 +122,7 @@ LogWSResponse::LogDownloadFinished { last_log_id } => { if let Some(last_log_id) = last_log_id { tx.send(DownloadLogsRequest { + backup_id: backup_data.backup_id.clone(), from_id: Some(last_log_id), }) .await @@ -159,17 +164,17 @@ ); // Test log cleanup - let (tx, rx) = backup_client - .download_logs(&user_identity, &removed_backup.backup_id) - .await - .unwrap(); + let (tx, rx) = backup_client.download_logs(&user_identity).await.unwrap(); tokio::pin!(tx); tokio::pin!(rx); - tx.send(DownloadLogsRequest { from_id: None }) - .await - .unwrap(); + tx.send(DownloadLogsRequest { + backup_id: removed_backup.backup_id.clone(), + from_id: None, + }) + .await + .unwrap(); match rx.next().await.unwrap().unwrap() { LogWSResponse::LogDownloadFinished { last_log_id: None } => (), @@ -199,7 +204,7 @@ ), attachments: vec![], }, - generate_log_data(b'a'), + generate_log_data("b1", b'a'), ), ( BackupData { @@ -214,12 +219,12 @@ ), attachments: vec![], }, - generate_log_data(b'b'), + generate_log_data("b2", b'b'), ), ] } -fn generate_log_data(value: u8) -> Vec { +fn generate_log_data(backup_id: &str, value: u8) -> Vec { const IN_DB_SIZE: usize = ByteSize::kib(4).as_u64() as usize; const IN_BLOB_SIZE: usize = ByteSize::kib(400).as_u64() as usize; @@ -240,6 +245,7 @@ content.extend(unique_suffix.as_bytes()); UploadLogRequest { + backup_id: backup_id.to_string(), log_id, content, attachments, diff --git a/shared/backup_client/src/lib.rs b/shared/backup_client/src/lib.rs --- a/shared/backup_client/src/lib.rs +++ b/shared/backup_client/src/lib.rs @@ -119,15 +119,14 @@ pub async fn upload_logs( &self, user_identity: &UserIdentity, - backup_id: &str, ) -> Result< ( impl Sink, - impl Stream>, + impl Stream>, ), Error, > { - let request = self.create_ws_request(user_identity, backup_id)?; + let request = self.create_ws_request(user_identity)?; let (stream, response) = connect_async(request).await?; if response.status().is_client_error() { @@ -150,7 +149,9 @@ }; match response { - LogWSResponse::LogUploaded { log_id } => Some(Ok(log_id)), + LogWSResponse::LogUploaded { backup_id, log_id } => { + Some(Ok(LogUploadConfirmation { backup_id, log_id })) + } LogWSResponse::LogDownload { .. } | LogWSResponse::LogDownloadFinished { .. } => { Some(Err(WSError::InvalidBackupMessage)) @@ -165,7 +166,6 @@ pub async fn download_logs( &self, user_identity: &UserIdentity, - backup_id: &str, ) -> Result< ( impl Sink, @@ -173,7 +173,7 @@ ), Error, > { - let request = self.create_ws_request(user_identity, backup_id)?; + let request = self.create_ws_request(user_identity)?; let (stream, response) = connect_async(request).await?; if response.status().is_client_error() { @@ -211,7 +211,6 @@ fn create_ws_request( &self, user_identity: &UserIdentity, - backup_id: &str, ) -> Result, Error> { let mut url = self.url.clone(); @@ -220,7 +219,7 @@ "https" => url.set_scheme("wss")?, _ => (), }; - let url = url.join("logs/")?.join(backup_id)?; + let url = url.join("logs")?; let mut request = url.into_client_request().unwrap(); @@ -259,6 +258,12 @@ UserData, } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LogUploadConfirmation { + pub backup_id: String, + pub log_id: usize, +} + #[derive( Debug, derive_more::Display, derive_more::Error, derive_more::From, )] diff --git a/shared/comm-lib/src/backup/mod.rs b/shared/comm-lib/src/backup/mod.rs --- a/shared/comm-lib/src/backup/mod.rs +++ b/shared/comm-lib/src/backup/mod.rs @@ -8,6 +8,7 @@ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UploadLogRequest { + pub backup_id: String, pub log_id: usize, pub content: Vec, pub attachments: Option>, @@ -15,6 +16,7 @@ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DownloadLogsRequest { + pub backup_id: String, pub from_id: Option, } @@ -27,6 +29,7 @@ #[derive(Debug, Clone, Serialize, Deserialize)] pub enum LogWSResponse { LogUploaded { + backup_id: String, log_id: usize, }, LogDownload {