# Review notes for this patch. Text before the first `diff --git` header is a
# preamble that `git apply` skips.
#
# Purpose: deduplicate the log WebSocket plumbing in backup_client.
#  * The upload and download log stream functions now both call a new generic
#    `create_log_ws_connection`. It builds the WS request with
#    `create_ws_request`, connects with `connect_async`, and splits the stream
#    into a sink and a stream.
#  * Outgoing: any request value is converted into `LogWSRequest` with
#    `.into()`, serialized with bincode, and sent as a Binary frame.
#  * Incoming: Binary frames are deserialized with bincode into a response.
#    Ping frames are skipped (the in-code comment says tungstenite handles
#    them). Every other frame kind yields `Error::InvalidWSMessage`. Transport
#    and deserialization errors are converted with `.into()`.
#  * Each caller then maps the decoded `LogWSResponse`:
#    - upload accepts only `LogUploaded`, which becomes `LogUploadConfirmation`;
#    - download accepts only `LogDownload` / `LogDownloadFinished`;
#    - in both, `ServerError` becomes `Error::ServerError`, and any other
#      variant becomes `Error::InvalidBackupMessage`.
#  * The free function `get_log_ws_response` is removed (its body is inlined
#    into `create_log_ws_connection`), along with the
#    `tokio_tungstenite::tungstenite::Message` import.
#  * comm-lib: `LogWSRequest` gains `derive_more::From`. This is what lets
#    `UploadLogRequest` / `DownloadLogsRequest` go through the `.into()` above.
#
# NOTE(review): this patch only touches the two .rs files. Confirm that
#   `derive_more` (with its `From` derive available) is already a dependency
#   of comm-lib.
# NOTE(review): Pong, Close and Text frames all map to InvalidWSMessage, so a
#   server-sent Close surfaces as an error item rather than ending the stream.
#   This matches the removed get_log_ws_response, but confirm it is intended.
# NOTE(review): in this copy the newlines are collapsed and angle-bracketed
#   generic arguments appear to be missing (e.g. `impl Sink,`,
#   `create_log_ws_connection>(`), so it cannot be applied verbatim.
diff --git a/shared/backup_client/src/lib.rs b/shared/backup_client/src/lib.rs --- a/shared/backup_client/src/lib.rs +++ b/shared/backup_client/src/lib.rs @@ -12,7 +12,6 @@ Body, }; use sha2::{Digest, Sha256}; -use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{ connect_async, tungstenite::{ @@ -125,38 +124,17 @@ ), Error, > { - let request = self.create_ws_request(user_identity)?; - let (stream, response) = connect_async(request).await?; + let (tx, rx) = self.create_log_ws_connection(user_identity).await?; - if response.status().is_client_error() { - return Err(Error::TungsteniteError(TungsteniteError::Http(response))); - } - - let (tx, rx) = stream.split(); - - let tx = tx.with(|request: UploadLogRequest| async { - let request = LogWSRequest::UploadLog(request); - let request = bincode::serialize(&request)?; - Ok(Binary(request)) - }); - - let rx = rx.filter_map(|msg| async { - let response = match get_log_ws_response(msg) { - Some(Ok(response)) => response, - Some(Err(err)) => return Some(Err(err)), - None => return None, - }; - - match response { - LogWSResponse::LogUploaded { backup_id, log_id } => { - Some(Ok(LogUploadConfirmation { backup_id, log_id })) - } - LogWSResponse::LogDownload { .. } - | LogWSResponse::LogDownloadFinished { .. } => { - Some(Err(Error::InvalidBackupMessage)) - } - LogWSResponse::ServerError => Some(Err(Error::ServerError)), + let rx = rx.map(|response| match response? { + LogWSResponse::LogUploaded { backup_id, log_id } => { + Ok(LogUploadConfirmation { backup_id, log_id }) } + LogWSResponse::LogDownload { .. } + | LogWSResponse::LogDownloadFinished { .. } => { + Err(Error::InvalidBackupMessage) + } + LogWSResponse::ServerError => Err(Error::ServerError), }); Ok((tx, rx)) @@ -171,6 +149,28 @@ impl Stream>, ), Error, + > { + let (tx, rx) = self.create_log_ws_connection(user_identity).await?; + + let rx = rx.map(|response| match response? { + msg @ (LogWSResponse::LogDownloadFinished { .. 
} + | LogWSResponse::LogDownload { .. }) => Ok(msg), + LogWSResponse::LogUploaded { .. } => Err(Error::InvalidBackupMessage), + LogWSResponse::ServerError => Err(Error::ServerError), + }); + + Ok((tx, rx)) + } + + async fn create_log_ws_connection>( + &self, + user_identity: &UserIdentity, + ) -> Result< + ( + impl Sink, + impl Stream>, + ), + Error, > { let request = self.create_ws_request(user_identity)?; let (stream, response) = connect_async(request).await?; @@ -181,26 +181,24 @@ let (tx, rx) = stream.split(); - let tx = tx.with(|request: DownloadLogsRequest| async { - let request = LogWSRequest::DownloadLogs(request); + let tx = tx.with(|request: Request| async { + let request: LogWSRequest = request.into(); let request = bincode::serialize(&request)?; Ok(Binary(request)) }); let rx = rx.filter_map(|msg| async { - let response = match get_log_ws_response(msg) { - Some(Ok(response)) => response, - Some(Err(err)) => return Some(Err(err)), - None => return None, + let bytes = match msg { + Ok(Binary(bytes)) => bytes, + // Handled by tungstenite + Ok(Ping(_)) => return None, + Ok(_) => return Some(Err(Error::InvalidWSMessage)), + Err(err) => return Some(Err(err.into())), }; - match response { - LogWSResponse::LogDownloadFinished { .. } - | LogWSResponse::LogDownload { .. } => Some(Ok(response)), - LogWSResponse::LogUploaded { .. 
} => { - Some(Err(Error::InvalidBackupMessage)) - } - LogWSResponse::ServerError => Some(Err(Error::ServerError)), + match bincode::deserialize(&bytes) { + Ok(response) => Some(Ok(response)), + Err(err) => Some(Err(err.into())), } }); @@ -284,20 +282,3 @@ Self::InvalidAuthorizationHeader } } - -fn get_log_ws_response( - msg: Result, -) -> Option> { - let bytes = match msg { - Ok(Binary(bytes)) => bytes, - // Handled by tungstenite - Ok(Ping(_)) => return None, - Ok(_) => return Some(Err(Error::InvalidWSMessage)), - Err(err) => return Some(Err(err.into())), - }; - - match bincode::deserialize(&bytes) { - Ok(response) => Some(Ok(response)), - Err(err) => Some(Err(err.into())), - } -} diff --git a/shared/comm-lib/src/backup/mod.rs b/shared/comm-lib/src/backup/mod.rs --- a/shared/comm-lib/src/backup/mod.rs +++ b/shared/comm-lib/src/backup/mod.rs @@ -20,7 +20,7 @@ pub from_id: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, derive_more::From)] pub enum LogWSRequest { UploadLog(UploadLogRequest), DownloadLogs(DownloadLogsRequest),