diff --git a/native/native_rust_library/Cargo.lock b/native/native_rust_library/Cargo.lock --- a/native/native_rust_library/Cargo.lock +++ b/native/native_rust_library/Cargo.lock @@ -182,6 +182,7 @@ "reqwest", "serde_json", "sha2", + "tokio", "tokio-tungstenite", "url", ] diff --git a/native/native_rust_library/src/backup.rs b/native/native_rust_library/src/backup.rs --- a/native/native_rust_library/src/backup.rs +++ b/native/native_rust_library/src/backup.rs @@ -12,8 +12,8 @@ use crate::BACKUP_SOCKET_ADDR; use crate::RUNTIME; use backup_client::{ - BackupClient, BackupDescriptor, DownloadLogsRequest, LatestBackupIDResponse, - LogWSResponse, RequestedData, SinkExt, StreamExt, UserIdentity, + BackupClient, BackupDescriptor, LatestBackupIDResponse, RequestedData, + UserIdentity, }; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -138,36 +138,6 @@ let user_data: serde_json::Value = serde_json::from_slice(&user_data)?; - let (tx, rx) = backup_client.download_logs(&user_identity).await?; - - tokio::pin!(tx); - tokio::pin!(rx); - - tx.send(DownloadLogsRequest { - backup_id: backup_id.clone(), - from_id: None, - }) - .await?; - - match rx.next().await { - Some(Ok(LogWSResponse::LogDownload { - log_id: 1, - content, - attachments: None, - })) - if content == (1..100).collect::>() => {} - response => { - return Err(Box::new(InvalidWSLogResponse(format!("{response:?}")))) - } - }; - - match rx.next().await { - Some(Ok(LogWSResponse::LogDownloadFinished { last_log_id: None })) => {} - response => { - return Err(Box::new(InvalidWSLogResponse(format!("{response:?}")))) - } - }; - Ok( json!({ "userData": user_data, @@ -226,7 +196,3 @@ Ok(decrypted) } - -#[derive(Debug, derive_more::Display)] -struct InvalidWSLogResponse(String); -impl Error for InvalidWSLogResponse {} diff --git a/services/commtest/Cargo.lock b/services/commtest/Cargo.lock --- a/services/commtest/Cargo.lock +++ b/services/commtest/Cargo.lock @@ -170,6 +170,7 @@ "reqwest", "serde_json", "sha2", + "tokio", "tokio-tungstenite", "url", ] diff --git a/services/commtest/tests/backup_integration_test.rs b/services/commtest/tests/backup_integration_test.rs --- a/services/commtest/tests/backup_integration_test.rs +++ b/services/commtest/tests/backup_integration_test.rs @@ -1,22 +1,19 @@ -use std::collections::{HashMap, HashSet}; - use backup_client::{ - BackupClient, BackupData, BackupDescriptor, Error as BackupClientError, - LogUploadConfirmation, RequestedData, SinkExt, StreamExt, TryStreamExt, + BackupClient, BackupData, BackupDescriptor, DownloadedLog, + Error as BackupClientError, LogUploadConfirmation, RequestedData, SinkExt, + StreamExt, TryStreamExt, }; use bytesize::ByteSize; use comm_lib::{ auth::UserIdentity, - backup::{ - DownloadLogsRequest, LatestBackupIDResponse, LogWSResponse, - UploadLogRequest, - }, + backup::{LatestBackupIDResponse, UploadLogRequest}, }; use commtest::{ service_addr, tools::{generate_stable_nbytes, Error}, }; use reqwest::StatusCode; +use std::collections::HashSet; use uuid::Uuid; #[tokio::test] @@ -37,17 +34,17 @@ .upload_backup(&user_identity, backup_data.clone()) .await?; - let (tx, rx) = backup_client.upload_logs(&user_identity).await.unwrap(); + let (tx, rx) = backup_client.upload_logs(&user_identity).await?; tokio::pin!(tx); tokio::pin!(rx); for log_data in log_datas { - tx.send(log_data.clone()).await.unwrap(); + tx.send(log_data.clone()).await?; } let result: HashSet = - rx.take(log_datas.len()).try_collect().await.unwrap(); + rx.take(log_datas.len()).try_collect().await?; let expected = log_datas .iter() .map(|data| LogUploadConfirmation { @@ -96,49 +93,18 @@ assert_eq!(user_keys, backup_data.user_keys); // Test log download - let (tx, rx) = backup_client.download_logs(&user_identity).await.unwrap(); + let log_stream = backup_client + .download_logs(&user_identity, &backup_data.backup_id) + .await; - tokio::pin!(tx); - tokio::pin!(rx); + let downloaded_logs: Vec = log_stream.try_collect().await?; - tx.send(DownloadLogsRequest { - backup_id: backup_data.backup_id.clone(), - from_id: None, - }) - .await - .unwrap(); - - let mut downloaded_logs = HashMap::new(); - 'download: loop { - loop { - match rx.next().await.unwrap().unwrap() { - LogWSResponse::LogDownload { - log_id, - content, - attachments, - } => { - downloaded_logs.insert(log_id, (content, attachments)); - } - LogWSResponse::LogDownloadFinished { last_log_id } => { - if let Some(last_log_id) = last_log_id { - tx.send(DownloadLogsRequest { - backup_id: backup_data.backup_id.clone(), - from_id: Some(last_log_id), - }) - .await - .unwrap(); - } else { - break 'download; - } - } - msg => panic!("Got response: {msg:?}"), - }; - } - } - let expected_logs = log_datas + let expected_logs: Vec = log_datas .iter() - .cloned() - .map(|data| (data.log_id, (data.content, data.attachments))) + .map(|data| DownloadedLog { + content: data.content.clone(), + attachments: data.attachments.clone(), + }) .collect(); assert_eq!(downloaded_logs, expected_logs); @@ -164,27 +130,18 @@ ); // Test log cleanup - let (tx, rx) = backup_client.download_logs(&user_identity).await.unwrap(); + let log_stream = backup_client + .download_logs(&user_identity, &removed_backup.backup_id) + .await; - tokio::pin!(tx); - tokio::pin!(rx); + let downloaded_logs: Vec = log_stream.try_collect().await?; - tx.send(DownloadLogsRequest { - backup_id: removed_backup.backup_id.clone(), - from_id: None, - }) - .await - .unwrap(); - - match rx.next().await.unwrap().unwrap() { - LogWSResponse::LogDownloadFinished { last_log_id: None } => (), - msg => { - panic!( - "Logs for first backup should have been removed, \ - instead got response: {msg:?}" - ) - } - }; + if !downloaded_logs.is_empty() { + panic!( + "Logs for first backup should have been removed, \ + instead got: {downloaded_logs:?}" + ) + } Ok(()) } diff --git a/shared/backup_client/Cargo.toml b/shared/backup_client/Cargo.toml --- a/shared/backup_client/Cargo.toml +++ b/shared/backup_client/Cargo.toml @@ -18,6 +18,7 @@ tokio-tungstenite = "0.18.0" futures-util = "0.3" bincode = "1.3.3" +tokio = "1.24" [features] default = ["native-tls"] diff --git a/shared/backup_client/src/lib.rs b/shared/backup_client/src/lib.rs --- a/shared/backup_client/src/lib.rs +++ b/shared/backup_client/src/lib.rs @@ -1,3 +1,6 @@ +use std::time::Duration; + +use async_stream::{stream, try_stream}; pub use comm_lib::auth::UserIdentity; pub use comm_lib::backup::{ DownloadLogsRequest, LatestBackupIDResponse, LogWSRequest, LogWSResponse, @@ -22,6 +25,9 @@ }, }; +const LOG_DOWNLOAD_RETRY_DELAY: Duration = Duration::from_secs(5); +const LOG_DOWNLOAD_MAX_RETRY: usize = 3; + #[derive(Debug, Clone)] pub struct BackupClient { url: reqwest::Url, @@ -130,36 +136,112 @@ LogWSResponse::LogUploaded { backup_id, log_id } => { Ok(LogUploadConfirmation { backup_id, log_id }) } - LogWSResponse::LogDownload { .. } - | LogWSResponse::LogDownloadFinished { .. } => { - Err(Error::InvalidBackupMessage) - } LogWSResponse::ServerError => Err(Error::ServerError), + msg => Err(Error::InvalidBackupMessage(msg)), }); Ok((tx, rx)) } - pub async fn download_logs( - &self, - user_identity: &UserIdentity, - ) -> Result< - ( - impl Sink, - impl Stream>, - ), - Error, - > { - let (tx, rx) = self.create_log_ws_connection(user_identity).await?; + /// Handles complete log download. + /// It will try and retry download a few times, but if the issues persist + /// the next item returned will be the last received error and the stream + /// will be closed. + pub async fn download_logs<'this>( + &'this self, + user_identity: &'this UserIdentity, + backup_id: &'this str, + ) -> impl Stream> + 'this { + stream! { + let mut last_downloaded_log = None; + let mut fail_count = 0; + + 'retry: loop { + let stream = self.log_download_stream(user_identity, backup_id, &mut last_downloaded_log).await; + let mut stream = Box::pin(stream); + + while let Some(item) = stream.next().await { + match item { + Ok(log) => yield Ok(log), + Err(err) => { + println!("Error when downloading logs: {err:?}"); + + fail_count += 1; + if fail_count >= LOG_DOWNLOAD_MAX_RETRY { + yield Err(err); + break 'retry; + } + + tokio::time::sleep(LOG_DOWNLOAD_RETRY_DELAY).await; + continue 'retry; + } + } + } + + // Everything downloaded + return; + } - let rx = rx.map(|response| match response? { - msg @ (LogWSResponse::LogDownloadFinished { .. } - | LogWSResponse::LogDownload { .. }) => Ok(msg), - LogWSResponse::LogUploaded { .. } => Err(Error::InvalidBackupMessage), - LogWSResponse::ServerError => Err(Error::ServerError), - }); + println!("Log download failed!"); + } + } - Ok((tx, rx)) + /// Handles singular connection websocket connection. Returns error in case + /// anything goes wrong e.g. missing log or connection error. + async fn log_download_stream<'stream>( + &'stream self, + user_identity: &'stream UserIdentity, + backup_id: &'stream str, + last_downloaded_log: &'stream mut Option, + ) -> impl Stream> + 'stream { + try_stream! { + let (tx, rx) = self.create_log_ws_connection(user_identity).await?; + + let mut tx = Box::pin(tx); + let mut rx = Box::pin(rx); + + tx.send(DownloadLogsRequest { + backup_id: backup_id.to_string(), + from_id: *last_downloaded_log, + }) + .await?; + + while let Some(response) = rx.try_next().await? { + let expected_log_id = last_downloaded_log.unwrap_or(0); + match response { + LogWSResponse::LogDownload { + content, + attachments, + log_id, + } if log_id == expected_log_id + 1 => { + *last_downloaded_log = Some(log_id); + yield DownloadedLog { + content, + attachments, + }; + } + LogWSResponse::LogDownload { .. } => { + Err(Error::LogMissing)?; + } + LogWSResponse::LogDownloadFinished { + last_log_id: Some(log_id), + } if log_id == expected_log_id => { + tx.send(DownloadLogsRequest { + backup_id: backup_id.to_string(), + from_id: *last_downloaded_log, + }) + .await? + } + LogWSResponse::LogDownloadFinished { last_log_id: None } => return, + LogWSResponse::LogDownloadFinished { .. } => { + Err(Error::LogMissing)?; + } + msg => Err(Error::InvalidBackupMessage(msg))?, + } + } + + Err(Error::WSClosed)?; + } } async fn create_log_ws_connection>( @@ -261,9 +343,13 @@ pub log_id: usize, } -#[derive( - Debug, derive_more::Display, derive_more::Error, derive_more::From, -)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct DownloadedLog { + pub content: Vec, + pub attachments: Option>, +} + +#[derive(Debug, derive_more::Display, derive_more::From)] pub enum Error { InvalidAuthorizationHeader, UrlSchemaError, @@ -273,9 +359,13 @@ JsonError(serde_json::Error), BincodeError(bincode::Error), InvalidWSMessage, - InvalidBackupMessage, + #[display(fmt = "Error::InvalidBackupMessage({:?})", _0)] + InvalidBackupMessage(LogWSResponse), ServerError, + LogMissing, + WSClosed, } +impl std::error::Error for Error {} impl From for Error { fn from(_: InvalidHeaderValue) -> Self {