Page Menu · Home · Phabricator

D10986.diff
No One · Temporary

D10986.diff

diff --git a/native/native_rust_library/Cargo.lock b/native/native_rust_library/Cargo.lock
--- a/native/native_rust_library/Cargo.lock
+++ b/native/native_rust_library/Cargo.lock
@@ -182,6 +182,7 @@
"reqwest",
"serde_json",
"sha2",
+ "tokio",
"tokio-tungstenite",
"url",
]
diff --git a/native/native_rust_library/src/backup.rs b/native/native_rust_library/src/backup.rs
--- a/native/native_rust_library/src/backup.rs
+++ b/native/native_rust_library/src/backup.rs
@@ -12,8 +12,8 @@
use crate::BACKUP_SOCKET_ADDR;
use crate::RUNTIME;
use backup_client::{
- BackupClient, BackupDescriptor, DownloadLogsRequest, LatestBackupIDResponse,
- LogWSResponse, RequestedData, SinkExt, StreamExt, UserIdentity,
+ BackupClient, BackupDescriptor, LatestBackupIDResponse, RequestedData,
+ UserIdentity,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
@@ -138,36 +138,6 @@
let user_data: serde_json::Value = serde_json::from_slice(&user_data)?;
- let (tx, rx) = backup_client.download_logs(&user_identity).await?;
-
- tokio::pin!(tx);
- tokio::pin!(rx);
-
- tx.send(DownloadLogsRequest {
- backup_id: backup_id.clone(),
- from_id: None,
- })
- .await?;
-
- match rx.next().await {
- Some(Ok(LogWSResponse::LogDownload {
- log_id: 1,
- content,
- attachments: None,
- }))
- if content == (1..100).collect::<Vec<u8>>() => {}
- response => {
- return Err(Box::new(InvalidWSLogResponse(format!("{response:?}"))))
- }
- };
-
- match rx.next().await {
- Some(Ok(LogWSResponse::LogDownloadFinished { last_log_id: None })) => {}
- response => {
- return Err(Box::new(InvalidWSLogResponse(format!("{response:?}"))))
- }
- };
-
Ok(
json!({
"userData": user_data,
@@ -226,7 +196,3 @@
Ok(decrypted)
}
-
-#[derive(Debug, derive_more::Display)]
-struct InvalidWSLogResponse(String);
-impl Error for InvalidWSLogResponse {}
diff --git a/services/commtest/Cargo.lock b/services/commtest/Cargo.lock
--- a/services/commtest/Cargo.lock
+++ b/services/commtest/Cargo.lock
@@ -170,6 +170,7 @@
"reqwest",
"serde_json",
"sha2",
+ "tokio",
"tokio-tungstenite",
"url",
]
diff --git a/services/commtest/tests/backup_integration_test.rs b/services/commtest/tests/backup_integration_test.rs
--- a/services/commtest/tests/backup_integration_test.rs
+++ b/services/commtest/tests/backup_integration_test.rs
@@ -1,22 +1,19 @@
-use std::collections::{HashMap, HashSet};
-
use backup_client::{
- BackupClient, BackupData, BackupDescriptor, Error as BackupClientError,
- LogUploadConfirmation, RequestedData, SinkExt, StreamExt, TryStreamExt,
+ BackupClient, BackupData, BackupDescriptor, DownloadedLog,
+ Error as BackupClientError, LogUploadConfirmation, RequestedData, SinkExt,
+ StreamExt, TryStreamExt,
};
use bytesize::ByteSize;
use comm_lib::{
auth::UserIdentity,
- backup::{
- DownloadLogsRequest, LatestBackupIDResponse, LogWSResponse,
- UploadLogRequest,
- },
+ backup::{LatestBackupIDResponse, UploadLogRequest},
};
use commtest::{
service_addr,
tools::{generate_stable_nbytes, Error},
};
use reqwest::StatusCode;
+use std::collections::HashSet;
use uuid::Uuid;
#[tokio::test]
@@ -37,17 +34,17 @@
.upload_backup(&user_identity, backup_data.clone())
.await?;
- let (tx, rx) = backup_client.upload_logs(&user_identity).await.unwrap();
+ let (tx, rx) = backup_client.upload_logs(&user_identity).await?;
tokio::pin!(tx);
tokio::pin!(rx);
for log_data in log_datas {
- tx.send(log_data.clone()).await.unwrap();
+ tx.send(log_data.clone()).await?;
}
let result: HashSet<LogUploadConfirmation> =
- rx.take(log_datas.len()).try_collect().await.unwrap();
+ rx.take(log_datas.len()).try_collect().await?;
let expected = log_datas
.iter()
.map(|data| LogUploadConfirmation {
@@ -96,49 +93,18 @@
assert_eq!(user_keys, backup_data.user_keys);
// Test log download
- let (tx, rx) = backup_client.download_logs(&user_identity).await.unwrap();
+ let log_stream = backup_client
+ .download_logs(&user_identity, &backup_data.backup_id)
+ .await;
- tokio::pin!(tx);
- tokio::pin!(rx);
+ let downloaded_logs: Vec<DownloadedLog> = log_stream.try_collect().await?;
- tx.send(DownloadLogsRequest {
- backup_id: backup_data.backup_id.clone(),
- from_id: None,
- })
- .await
- .unwrap();
-
- let mut downloaded_logs = HashMap::new();
- 'download: loop {
- loop {
- match rx.next().await.unwrap().unwrap() {
- LogWSResponse::LogDownload {
- log_id,
- content,
- attachments,
- } => {
- downloaded_logs.insert(log_id, (content, attachments));
- }
- LogWSResponse::LogDownloadFinished { last_log_id } => {
- if let Some(last_log_id) = last_log_id {
- tx.send(DownloadLogsRequest {
- backup_id: backup_data.backup_id.clone(),
- from_id: Some(last_log_id),
- })
- .await
- .unwrap();
- } else {
- break 'download;
- }
- }
- msg => panic!("Got response: {msg:?}"),
- };
- }
- }
- let expected_logs = log_datas
+ let expected_logs: Vec<DownloadedLog> = log_datas
.iter()
- .cloned()
- .map(|data| (data.log_id, (data.content, data.attachments)))
+ .map(|data| DownloadedLog {
+ content: data.content.clone(),
+ attachments: data.attachments.clone(),
+ })
.collect();
assert_eq!(downloaded_logs, expected_logs);
@@ -164,27 +130,18 @@
);
// Test log cleanup
- let (tx, rx) = backup_client.download_logs(&user_identity).await.unwrap();
+ let log_stream = backup_client
+ .download_logs(&user_identity, &removed_backup.backup_id)
+ .await;
- tokio::pin!(tx);
- tokio::pin!(rx);
+ let downloaded_logs: Vec<DownloadedLog> = log_stream.try_collect().await?;
- tx.send(DownloadLogsRequest {
- backup_id: removed_backup.backup_id.clone(),
- from_id: None,
- })
- .await
- .unwrap();
-
- match rx.next().await.unwrap().unwrap() {
- LogWSResponse::LogDownloadFinished { last_log_id: None } => (),
- msg => {
- panic!(
- "Logs for first backup should have been removed, \
- instead got response: {msg:?}"
- )
- }
- };
+ if !downloaded_logs.is_empty() {
+ panic!(
+ "Logs for first backup should have been removed, \
+ instead got: {downloaded_logs:?}"
+ )
+ }
Ok(())
}
diff --git a/shared/backup_client/Cargo.toml b/shared/backup_client/Cargo.toml
--- a/shared/backup_client/Cargo.toml
+++ b/shared/backup_client/Cargo.toml
@@ -18,6 +18,7 @@
tokio-tungstenite = "0.18.0"
futures-util = "0.3"
bincode = "1.3.3"
+tokio = "1.24"
[features]
default = ["native-tls"]
diff --git a/shared/backup_client/src/lib.rs b/shared/backup_client/src/lib.rs
--- a/shared/backup_client/src/lib.rs
+++ b/shared/backup_client/src/lib.rs
@@ -1,3 +1,6 @@
+use std::time::Duration;
+
+use async_stream::{stream, try_stream};
pub use comm_lib::auth::UserIdentity;
pub use comm_lib::backup::{
DownloadLogsRequest, LatestBackupIDResponse, LogWSRequest, LogWSResponse,
@@ -22,6 +25,9 @@
},
};
+const LOG_DOWNLOAD_RETRY_DELAY: Duration = Duration::from_secs(5);
+const LOG_DOWNLOAD_MAX_RETRY: usize = 3;
+
#[derive(Debug, Clone)]
pub struct BackupClient {
url: reqwest::Url,
@@ -130,36 +136,112 @@
LogWSResponse::LogUploaded { backup_id, log_id } => {
Ok(LogUploadConfirmation { backup_id, log_id })
}
- LogWSResponse::LogDownload { .. }
- | LogWSResponse::LogDownloadFinished { .. } => {
- Err(Error::InvalidBackupMessage)
- }
LogWSResponse::ServerError => Err(Error::ServerError),
+ msg => Err(Error::InvalidBackupMessage(msg)),
});
Ok((tx, rx))
}
- pub async fn download_logs(
- &self,
- user_identity: &UserIdentity,
- ) -> Result<
- (
- impl Sink<DownloadLogsRequest, Error = Error>,
- impl Stream<Item = Result<LogWSResponse, Error>>,
- ),
- Error,
- > {
- let (tx, rx) = self.create_log_ws_connection(user_identity).await?;
+ /// Downloads all logs of the given backup as a stream.
+ /// Failed connections are retried after a delay, but once the total number of
+ /// failures reaches `LOG_DOWNLOAD_MAX_RETRY`, the last received error is
+ /// yielded as the final item and the stream ends.
+ pub async fn download_logs<'this>(
+ &'this self,
+ user_identity: &'this UserIdentity,
+ backup_id: &'this str,
+ ) -> impl Stream<Item = Result<DownloadedLog, Error>> + 'this {
+ stream! {
+ let mut last_downloaded_log = None;
+ let mut fail_count = 0;
+
+ 'retry: loop {
+ let stream = self.log_download_stream(user_identity, backup_id, &mut last_downloaded_log).await;
+ let mut stream = Box::pin(stream);
+
+ while let Some(item) = stream.next().await {
+ match item {
+ Ok(log) => yield Ok(log),
+ Err(err) => {
+ println!("Error when downloading logs: {err:?}");
+
+ fail_count += 1;
+ if fail_count >= LOG_DOWNLOAD_MAX_RETRY {
+ yield Err(err);
+ break 'retry;
+ }
+
+ tokio::time::sleep(LOG_DOWNLOAD_RETRY_DELAY).await;
+ continue 'retry;
+ }
+ }
+ }
+
+ // Everything downloaded
+ return;
+ }
- let rx = rx.map(|response| match response? {
- msg @ (LogWSResponse::LogDownloadFinished { .. }
- | LogWSResponse::LogDownload { .. }) => Ok(msg),
- LogWSResponse::LogUploaded { .. } => Err(Error::InvalidBackupMessage),
- LogWSResponse::ServerError => Err(Error::ServerError),
- });
+ println!("Log download failed!");
+ }
+ }
- Ok((tx, rx))
+ /// Handles a single websocket connection. Returns an error in case
+ /// anything goes wrong, e.g. a missing log or a connection error.
+ async fn log_download_stream<'stream>(
+ &'stream self,
+ user_identity: &'stream UserIdentity,
+ backup_id: &'stream str,
+ last_downloaded_log: &'stream mut Option<usize>,
+ ) -> impl Stream<Item = Result<DownloadedLog, Error>> + 'stream {
+ try_stream! {
+ let (tx, rx) = self.create_log_ws_connection(user_identity).await?;
+
+ let mut tx = Box::pin(tx);
+ let mut rx = Box::pin(rx);
+
+ tx.send(DownloadLogsRequest {
+ backup_id: backup_id.to_string(),
+ from_id: *last_downloaded_log,
+ })
+ .await?;
+
+ while let Some(response) = rx.try_next().await? {
+ let expected_log_id = last_downloaded_log.unwrap_or(0);
+ match response {
+ LogWSResponse::LogDownload {
+ content,
+ attachments,
+ log_id,
+ } if log_id == expected_log_id + 1 => {
+ *last_downloaded_log = Some(log_id);
+ yield DownloadedLog {
+ content,
+ attachments,
+ };
+ }
+ LogWSResponse::LogDownload { .. } => {
+ Err(Error::LogMissing)?;
+ }
+ LogWSResponse::LogDownloadFinished {
+ last_log_id: Some(log_id),
+ } if log_id == expected_log_id => {
+ tx.send(DownloadLogsRequest {
+ backup_id: backup_id.to_string(),
+ from_id: *last_downloaded_log,
+ })
+ .await?
+ }
+ LogWSResponse::LogDownloadFinished { last_log_id: None } => return,
+ LogWSResponse::LogDownloadFinished { .. } => {
+ Err(Error::LogMissing)?;
+ }
+ msg => Err(Error::InvalidBackupMessage(msg))?,
+ }
+ }
+
+ Err(Error::WSClosed)?;
+ }
}
async fn create_log_ws_connection<Request: Into<LogWSRequest>>(
@@ -261,9 +343,13 @@
pub log_id: usize,
}
-#[derive(
- Debug, derive_more::Display, derive_more::Error, derive_more::From,
-)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct DownloadedLog {
+ pub content: Vec<u8>,
+ pub attachments: Option<Vec<String>>,
+}
+
+#[derive(Debug, derive_more::Display, derive_more::From)]
pub enum Error {
InvalidAuthorizationHeader,
UrlSchemaError,
@@ -273,9 +359,13 @@
JsonError(serde_json::Error),
BincodeError(bincode::Error),
InvalidWSMessage,
- InvalidBackupMessage,
+ #[display(fmt = "Error::InvalidBackupMessage({:?})", _0)]
+ InvalidBackupMessage(LogWSResponse),
ServerError,
+ LogMissing,
+ WSClosed,
}
+impl std::error::Error for Error {}
impl From<InvalidHeaderValue> for Error {
fn from(_: InvalidHeaderValue) -> Self {

File Metadata

Mime Type
text/plain
Expires
Thu, Dec 19, 8:58 PM (17 h, 53 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
2678854
Default Alt Text
D10986.diff (12 KB)

Event Timeline