diff --git a/native/native_rust_library/src/backup.rs b/native/native_rust_library/src/backup.rs index aa27c0ea2..25fc9981a 100644 --- a/native/native_rust_library/src/backup.rs +++ b/native/native_rust_library/src/backup.rs @@ -1,241 +1,249 @@ use crate::argon2_tools::{compute_backup_key, compute_backup_key_str}; use crate::constants::{aes, secure_store}; use crate::ffi::secure_store_get; use crate::handle_string_result_as_callback; use crate::BACKUP_SOCKET_ADDR; use crate::RUNTIME; use backup_client::{ BackupClient, BackupData, BackupDescriptor, DownloadLogsRequest, - LatestBackupIDResponse, LogWSResponse, RequestedData, SinkExt, StreamExt, - UploadLogRequest, UserIdentity, + LatestBackupIDResponse, LogUploadConfirmation, LogWSResponse, RequestedData, + SinkExt, StreamExt, UploadLogRequest, UserIdentity, }; use serde::{Deserialize, Serialize}; use serde_json::json; use std::error::Error; pub mod ffi { use crate::handle_void_result_as_callback; use super::*; pub fn create_backup_sync( backup_id: String, backup_secret: String, pickle_key: String, pickled_account: String, user_data: String, promise_id: u32, ) { RUNTIME.spawn(async move { let result = create_backup( backup_id, backup_secret, pickle_key, pickled_account, user_data, ) .await; handle_void_result_as_callback(result, promise_id); }); } pub fn restore_backup_sync(backup_secret: String, promise_id: u32) { RUNTIME.spawn(async move { let result = restore_backup(backup_secret).await; handle_string_result_as_callback(result, promise_id); }); } } pub async fn create_backup( backup_id: String, backup_secret: String, pickle_key: String, pickled_account: String, user_data: String, ) -> Result<(), Box> { let mut backup_key = compute_backup_key(backup_secret.as_bytes(), backup_id.as_bytes())?; let mut user_data = user_data.into_bytes(); let mut backup_data_key = [0; aes::KEY_SIZE]; crate::ffi::generate_key(&mut backup_data_key)?; let encrypted_user_data = encrypt(&mut backup_data_key, &mut user_data)?; let user_keys = UserKeys { backup_data_key, pickle_key, pickled_account, }; let encrypted_user_keys = user_keys.encrypt(&mut backup_key)?; let backup_client = BackupClient::new(BACKUP_SOCKET_ADDR)?; let user_identity = get_user_identity_from_secure_store()?; let backup_data = BackupData { backup_id: backup_id.clone(), user_data: encrypted_user_data, user_keys: encrypted_user_keys, attachments: Vec::new(), }; backup_client .upload_backup(&user_identity, backup_data) .await?; - let (tx, rx) = backup_client - .upload_logs(&user_identity, &backup_id) - .await?; + let (tx, rx) = backup_client.upload_logs(&user_identity).await?; tokio::pin!(tx); tokio::pin!(rx); let log_data = UploadLogRequest { + backup_id: backup_id.clone(), log_id: 1, content: (1..100).collect(), attachments: None, }; tx.send(log_data.clone()).await?; match rx.next().await { - Some(Ok(1)) => (), + Some(Ok(LogUploadConfirmation { + backup_id: response_backup_id, + log_id: 1, + })) + if backup_id == response_backup_id => + { + // Correctly uploaded + } response => { return Err(Box::new(InvalidWSLogResponse(format!("{response:?}")))) } }; Ok(()) } pub async fn restore_backup( backup_secret: String, ) -> Result> { let backup_client = BackupClient::new(BACKUP_SOCKET_ADDR)?; let user_identity = get_user_identity_from_secure_store()?; let latest_backup_descriptor = BackupDescriptor::Latest { username: user_identity.user_id.clone(), }; let backup_id_response = backup_client .download_backup_data(&latest_backup_descriptor, RequestedData::BackupID) .await?; let LatestBackupIDResponse { backup_id } = serde_json::from_slice(&backup_id_response)?; let mut backup_key = compute_backup_key_str(&backup_secret, &backup_id)?; let mut encrypted_user_keys = backup_client .download_backup_data(&latest_backup_descriptor, RequestedData::UserKeys) .await?; let mut user_keys = UserKeys::from_encrypted(&mut encrypted_user_keys, &mut backup_key)?; let backup_data_descriptor = BackupDescriptor::BackupID { backup_id: backup_id.clone(), user_identity: user_identity.clone(), }; let mut encrypted_user_data = backup_client .download_backup_data(&backup_data_descriptor, RequestedData::UserData) .await?; let user_data = decrypt(&mut user_keys.backup_data_key, &mut encrypted_user_data)?; let user_data: serde_json::Value = serde_json::from_slice(&user_data)?; - let (tx, rx) = backup_client - .download_logs(&user_identity, &backup_id) - .await?; + let (tx, rx) = backup_client.download_logs(&user_identity).await?; tokio::pin!(tx); tokio::pin!(rx); - tx.send(DownloadLogsRequest { from_id: None }).await?; + tx.send(DownloadLogsRequest { + backup_id: backup_id.clone(), + from_id: None, + }) + .await?; match rx.next().await { Some(Ok(LogWSResponse::LogDownload { log_id: 1, content, attachments: None, })) if content == (1..100).collect::>() => {} response => { return Err(Box::new(InvalidWSLogResponse(format!("{response:?}")))) } }; match rx.next().await { Some(Ok(LogWSResponse::LogDownloadFinished { last_log_id: None })) => {} response => { return Err(Box::new(InvalidWSLogResponse(format!("{response:?}")))) } }; Ok( json!({ "userData": user_data, "pickleKey": user_keys.pickle_key, "pickledAccount": user_keys.pickled_account, }) .to_string(), ) } fn get_user_identity_from_secure_store() -> Result { Ok(UserIdentity { user_id: secure_store_get(secure_store::USER_ID)?, access_token: secure_store_get(secure_store::COMM_SERVICES_ACCESS_TOKEN)?, device_id: secure_store_get(secure_store::DEVICE_ID)?, }) } #[derive(Debug, Serialize, Deserialize)] struct UserKeys { backup_data_key: [u8; 32], pickle_key: String, pickled_account: String, } impl UserKeys { fn encrypt(&self, backup_key: &mut [u8]) -> Result, Box> { let mut json = serde_json::to_vec(self)?; encrypt(backup_key, &mut json) } fn from_encrypted( data: &mut [u8], backup_key: &mut [u8], ) -> Result> { let decrypted = decrypt(backup_key, data)?; Ok(serde_json::from_slice(&decrypted)?) } } fn encrypt(key: &mut [u8], data: &mut [u8]) -> Result, Box> { let encrypted_len = data.len() + aes::IV_LENGTH + aes::TAG_LENGTH; let mut encrypted = vec![0; encrypted_len]; crate::ffi::encrypt(key, data, &mut encrypted)?; Ok(encrypted) } fn decrypt(key: &mut [u8], data: &mut [u8]) -> Result, Box> { let decrypted_len = data.len() - aes::IV_LENGTH - aes::TAG_LENGTH; let mut decrypted = vec![0; decrypted_len]; crate::ffi::decrypt(key, data, &mut decrypted)?; Ok(decrypted) } #[derive(Debug, derive_more::Display)] struct InvalidWSLogResponse(String); impl Error for InvalidWSLogResponse {} diff --git a/services/backup/src/http/handlers/log.rs b/services/backup/src/http/handlers/log.rs index dd9a3890e..ab3d3f527 100644 --- a/services/backup/src/http/handlers/log.rs +++ b/services/backup/src/http/handlers/log.rs @@ -1,307 +1,295 @@ use crate::constants::WS_FRAME_SIZE; use crate::database::{log_item::LogItem, DatabaseClient}; use actix::{Actor, ActorContext, ActorFutureExt, AsyncContext, StreamHandler}; use actix_http::ws::{CloseCode, Item}; use actix_web::{ web::{self, Bytes, BytesMut}, Error, HttpRequest, HttpResponse, }; use actix_web_actors::ws::{self, WebsocketContext}; use comm_lib::{ backup::{ DownloadLogsRequest, LogWSRequest, LogWSResponse, UploadLogRequest, }, blob::{ client::{BlobServiceClient, BlobServiceError}, types::BlobInfo, }, database::{self, blob::BlobOrDBContent}, }; -use std::{ - sync::Arc, - time::{Duration, Instant}, -}; +use std::time::{Duration, Instant}; use tracing::{error, info, instrument, warn}; pub async fn handle_ws( - path: web::Path, req: HttpRequest, stream: web::Payload, blob_client: web::Data, db_client: web::Data, ) -> Result { - let backup_id = path.into_inner(); ws::WsResponseBuilder::new( LogWSActor { - info: Arc::new(ConnectionInfo { backup_id }), blob_client: blob_client.as_ref().clone(), db_client: db_client.as_ref().clone(), last_msg_time: Instant::now(), buffer: BytesMut::new(), }, &req, stream, ) .frame_size(WS_FRAME_SIZE) .start() } -struct ConnectionInfo { - backup_id: String, -} - #[derive( Debug, derive_more::From, derive_more::Display, derive_more::Error, )] enum LogWSError { Bincode(bincode::Error), Blob(BlobServiceError), DB(database::Error), } struct LogWSActor { - info: Arc, blob_client: BlobServiceClient, db_client: DatabaseClient, last_msg_time: Instant, buffer: BytesMut, } impl LogWSActor { const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); const CONNECTION_TIMEOUT: Duration = Duration::from_secs(10); fn handle_msg_sync( &self, ctx: &mut WebsocketContext, bytes: Bytes, ) { - let fut = Self::handle_msg( - self.info.clone(), - self.blob_client.clone(), - self.db_client.clone(), - bytes, - ); + let fut = + Self::handle_msg(self.blob_client.clone(), self.db_client.clone(), bytes); let fut = actix::fut::wrap_future(fut).map( |responses, _: &mut LogWSActor, ctx: &mut WebsocketContext| { let responses = match responses { Ok(responses) => responses, Err(err) => { error!("Error: {err:?}"); vec![LogWSResponse::ServerError] } }; for response in responses { match bincode::serialize(&response) { Ok(bytes) => ctx.binary(bytes), Err(error) => { error!( "Error serializing a response: {response:?}. Error: {error}" ); } }; } }, ); ctx.spawn(fut); } async fn handle_msg( - info: Arc, blob_client: BlobServiceClient, db_client: DatabaseClient, bytes: Bytes, ) -> Result, LogWSError> { let request = bincode::deserialize(&bytes)?; match request { LogWSRequest::UploadLog(UploadLogRequest { + backup_id, log_id, content, attachments, }) => { let mut attachment_blob_infos = Vec::new(); for attachment in attachments.unwrap_or_default() { let blob_info = Self::create_attachment(&blob_client, attachment).await?; attachment_blob_infos.push(blob_info); } let mut log_item = LogItem { - backup_id: info.backup_id.clone(), + backup_id: backup_id.clone(), log_id, content: BlobOrDBContent::new(content), attachments: attachment_blob_infos, }; log_item.ensure_size_constraints(&blob_client).await?; db_client.put_log_item(log_item).await?; - Ok(vec![LogWSResponse::LogUploaded { log_id }]) + Ok(vec![LogWSResponse::LogUploaded { backup_id, log_id }]) } - LogWSRequest::DownloadLogs(DownloadLogsRequest { from_id }) => { + LogWSRequest::DownloadLogs(DownloadLogsRequest { + backup_id, + from_id, + }) => { let (log_items, last_id) = - db_client.fetch_log_items(&info.backup_id, from_id).await?; + db_client.fetch_log_items(&backup_id, from_id).await?; let mut messages = vec![]; for LogItem { log_id, content, attachments, .. } in log_items { let content = content.fetch_bytes(&blob_client).await?; let attachments: Vec = attachments.into_iter().map(|att| att.blob_hash).collect(); let attachments = if attachments.is_empty() { None } else { Some(attachments) }; messages.push(LogWSResponse::LogDownload { log_id, content, attachments, }) } messages.push(LogWSResponse::LogDownloadFinished { last_log_id: last_id, }); Ok(messages) } } } async fn create_attachment( blob_client: &BlobServiceClient, attachment: String, ) -> Result { let blob_info = BlobInfo { blob_hash: attachment, holder: uuid::Uuid::new_v4().to_string(), }; if !blob_client .assign_holder(&blob_info.blob_hash, &blob_info.holder) .await? { warn!( "Blob attachment with hash {:?} doesn't exist", blob_info.blob_hash ); } Ok(blob_info) } } impl Actor for LogWSActor { type Context = ws::WebsocketContext; - #[instrument(skip_all, fields(backup_id = self.info.backup_id))] + #[instrument(skip_all)] fn started(&mut self, ctx: &mut Self::Context) { info!("Socket opened"); ctx.run_interval(Self::HEARTBEAT_INTERVAL, |actor, ctx| { if Instant::now().duration_since(actor.last_msg_time) > Self::CONNECTION_TIMEOUT { warn!("Socket timeout, closing connection"); ctx.stop(); return; } ctx.ping(&[]); }); } - #[instrument(skip_all, fields(backup_id = self.info.backup_id))] + #[instrument(skip_all)] fn stopped(&mut self, _: &mut Self::Context) { info!("Socket closed"); } } impl StreamHandler> for LogWSActor { - #[instrument(skip_all, fields(backup_id = self.info.backup_id))] + #[instrument(skip_all)] fn handle( &mut self, msg: Result, ctx: &mut Self::Context, ) { let msg = match msg { Ok(msg) => msg, Err(err) => { warn!("Error during socket message handling: {err}"); ctx.close(Some(CloseCode::Error.into())); ctx.stop(); return; } }; self.last_msg_time = Instant::now(); match msg { ws::Message::Binary(bytes) => self.handle_msg_sync(ctx, bytes), // Continuations - this is mostly boilerplate code. Some websocket // clients may split a message into these ones ws::Message::Continuation(Item::FirstBinary(bytes)) => { if !self.buffer.is_empty() { warn!("Socket received continuation before previous was completed"); ctx.close(Some(CloseCode::Error.into())); ctx.stop(); return; } self.buffer.extend_from_slice(&bytes); } ws::Message::Continuation(Item::Continue(bytes)) => { if self.buffer.is_empty() { warn!("Socket received continuation message before it was started"); ctx.close(Some(CloseCode::Error.into())); ctx.stop(); return; } self.buffer.extend_from_slice(&bytes); } ws::Message::Continuation(Item::Last(bytes)) => { if self.buffer.is_empty() { warn!( "Socket received last continuation message before it was started" ); ctx.close(Some(CloseCode::Error.into())); ctx.stop(); return; } self.buffer.extend_from_slice(&bytes); let bytes = self.buffer.split(); self.handle_msg_sync(ctx, bytes.into()); } // Heartbeat ws::Message::Ping(message) => ctx.pong(&message), ws::Message::Pong(_) => (), // Other ws::Message::Text(_) | ws::Message::Continuation(Item::FirstText(_)) => { warn!("Socket received unsupported message"); ctx.close(Some(CloseCode::Unsupported.into())); ctx.stop(); } ws::Message::Close(reason) => { info!("Socket was closed"); ctx.close(reason); ctx.stop(); } ws::Message::Nop => (), } } } diff --git a/services/backup/src/http/mod.rs b/services/backup/src/http/mod.rs index 4dc0f86b6..5422b5d2c 100644 --- a/services/backup/src/http/mod.rs +++ b/services/backup/src/http/mod.rs @@ -1,75 +1,75 @@ use actix_web::{web, App, HttpResponse, HttpServer}; use anyhow::Result; use comm_lib::{ blob::client::BlobServiceClient, http::auth::get_comm_authentication_middleware, }; use tracing::info; use crate::{database::DatabaseClient, http::handlers::log::handle_ws, CONFIG}; mod handlers { pub(super) mod backup; pub(super) mod log; } pub async fn run_http_server( db_client: DatabaseClient, blob_client: BlobServiceClient, ) -> Result<()> { info!( "Starting HTTP server listening at port {}", CONFIG.http_port ); let db = web::Data::new(db_client); let blob = web::Data::new(blob_client); HttpServer::new(move || { App::new() .wrap(tracing_actix_web::TracingLogger::default()) .wrap(comm_lib::http::cors_config( CONFIG.localstack_endpoint.is_some(), )) .app_data(db.clone()) .app_data(blob.clone()) .route("/health", web::get().to(HttpResponse::Ok)) .service( // Backup services that don't require authetication web::scope("/backups/latest") .service( web::resource("{username}/backup_id") .route(web::get().to(handlers::backup::get_latest_backup_id)), ) .service(web::resource("{username}/user_keys").route( web::get().to(handlers::backup::download_latest_backup_keys), )), ) .service( // Backup services requiring authetication web::scope("/backups") .wrap(get_comm_authentication_middleware()) .service( web::resource("").route(web::post().to(handlers::backup::upload)), ) .service( web::resource("{backup_id}/user_keys") .route(web::get().to(handlers::backup::download_user_keys)), ) .service( web::resource("{backup_id}/user_data") .route(web::get().to(handlers::backup::download_user_data)), ), ) .service( web::scope("/logs") .wrap(get_comm_authentication_middleware()) - .service(web::resource("{backup_d}").route(web::get().to(handle_ws))), + .service(web::resource("").route(web::get().to(handle_ws))), ) }) .bind(("0.0.0.0", CONFIG.http_port))? .run() .await?; Ok(()) } diff --git a/services/commtest/tests/backup_integration_test.rs b/services/commtest/tests/backup_integration_test.rs index c12ba9441..fe958a494 100644 --- a/services/commtest/tests/backup_integration_test.rs +++ b/services/commtest/tests/backup_integration_test.rs @@ -1,249 +1,255 @@ use std::collections::{HashMap, HashSet}; use backup_client::{ BackupClient, BackupData, BackupDescriptor, Error as BackupClientError, - RequestedData, SinkExt, StreamExt, TryStreamExt, + LogUploadConfirmation, RequestedData, SinkExt, StreamExt, TryStreamExt, }; use bytesize::ByteSize; use comm_lib::{ auth::UserIdentity, backup::{ DownloadLogsRequest, LatestBackupIDResponse, LogWSResponse, UploadLogRequest, }, }; use commtest::{ service_addr, tools::{generate_stable_nbytes, Error}, }; use reqwest::StatusCode; use uuid::Uuid; #[tokio::test] async fn backup_integration_test() -> Result<(), Error> { let backup_client = BackupClient::new(service_addr::BACKUP_SERVICE_HTTP)?; let user_identity = UserIdentity { user_id: "1".to_string(), access_token: "dummy access token".to_string(), device_id: "dummy device_id".to_string(), }; let backup_datas = generate_backup_data(); // Upload backups for (backup_data, log_datas) in &backup_datas { backup_client .upload_backup(&user_identity, backup_data.clone()) .await?; - let (tx, rx) = backup_client - .upload_logs(&user_identity, &backup_data.backup_id) - .await - .unwrap(); + let (tx, rx) = backup_client.upload_logs(&user_identity).await.unwrap(); tokio::pin!(tx); tokio::pin!(rx); for log_data in log_datas { tx.send(log_data.clone()).await.unwrap(); } - let result: HashSet = + let result: HashSet = rx.take(log_datas.len()).try_collect().await.unwrap(); - let expected = log_datas.iter().map(|data| data.log_id).collect(); + let expected = log_datas + .iter() + .map(|data| LogUploadConfirmation { + backup_id: data.backup_id.clone(), + log_id: data.log_id, + }) + .collect(); + assert_eq!(result, expected); } // Test direct lookup let (backup_data, log_datas) = &backup_datas[1]; let second_backup_descriptor = BackupDescriptor::BackupID { backup_id: backup_data.backup_id.clone(), user_identity: user_identity.clone(), }; let user_keys = backup_client .download_backup_data(&second_backup_descriptor, RequestedData::UserKeys) .await?; assert_eq!(user_keys, backup_data.user_keys); let user_data = backup_client .download_backup_data(&second_backup_descriptor, RequestedData::UserData) .await?; assert_eq!(user_data, backup_data.user_data); // Test latest backup lookup let latest_backup_descriptor = BackupDescriptor::Latest { // Initial version of the backup service uses `user_id` in place of a username username: "1".to_string(), }; let backup_id_response = backup_client .download_backup_data(&latest_backup_descriptor, RequestedData::BackupID) .await?; let response: LatestBackupIDResponse = serde_json::from_slice(&backup_id_response)?; assert_eq!(response.backup_id, backup_data.backup_id); let user_keys = backup_client .download_backup_data(&latest_backup_descriptor, RequestedData::UserKeys) .await?; assert_eq!(user_keys, backup_data.user_keys); // Test log download - let (tx, rx) = backup_client - .download_logs(&user_identity, &backup_data.backup_id) - .await - .unwrap(); + let (tx, rx) = backup_client.download_logs(&user_identity).await.unwrap(); tokio::pin!(tx); tokio::pin!(rx); - tx.send(DownloadLogsRequest { from_id: None }) - .await - .unwrap(); + tx.send(DownloadLogsRequest { + backup_id: backup_data.backup_id.clone(), + from_id: None, + }) + .await + .unwrap(); let mut downloaded_logs = HashMap::new(); 'download: loop { loop { match rx.next().await.unwrap().unwrap() { LogWSResponse::LogDownload { log_id, content, attachments, } => { downloaded_logs.insert(log_id, (content, attachments)); } LogWSResponse::LogDownloadFinished { last_log_id } => { if let Some(last_log_id) = last_log_id { tx.send(DownloadLogsRequest { + backup_id: backup_data.backup_id.clone(), from_id: Some(last_log_id), }) .await .unwrap(); } else { break 'download; } } msg => panic!("Got response: {msg:?}"), }; } } let expected_logs = log_datas .iter() .cloned() .map(|data| (data.log_id, (data.content, data.attachments))) .collect(); assert_eq!(downloaded_logs, expected_logs); // Test backup cleanup let (removed_backup, _) = &backup_datas[0]; let removed_backup_descriptor = BackupDescriptor::BackupID { backup_id: removed_backup.backup_id.clone(), user_identity: user_identity.clone(), }; let response = backup_client .download_backup_data(&removed_backup_descriptor, RequestedData::UserKeys) .await; let Err(BackupClientError::ReqwestError(error)) = response else { panic!("First backup should have been removed, instead got response: {response:?}"); }; assert_eq!( error.status(), Some(StatusCode::NOT_FOUND), "Expected status 'not found'" ); // Test log cleanup - let (tx, rx) = backup_client - .download_logs(&user_identity, &removed_backup.backup_id) - .await - .unwrap(); + let (tx, rx) = backup_client.download_logs(&user_identity).await.unwrap(); tokio::pin!(tx); tokio::pin!(rx); - tx.send(DownloadLogsRequest { from_id: None }) - .await - .unwrap(); + tx.send(DownloadLogsRequest { + backup_id: removed_backup.backup_id.clone(), + from_id: None, + }) + .await + .unwrap(); match rx.next().await.unwrap().unwrap() { LogWSResponse::LogDownloadFinished { last_log_id: None } => (), msg => { panic!( "Logs for first backup should have been removed, \ instead got response: {msg:?}" ) } }; Ok(()) } fn generate_backup_data() -> [(BackupData, Vec); 2] { [ ( BackupData { backup_id: "b1".to_string(), user_keys: generate_stable_nbytes( ByteSize::kib(4).as_u64() as usize, Some(b'a'), ), user_data: generate_stable_nbytes( ByteSize::mib(4).as_u64() as usize, Some(b'A'), ), attachments: vec![], }, - generate_log_data(b'a'), + generate_log_data("b1", b'a'), ), ( BackupData { backup_id: "b2".to_string(), user_keys: generate_stable_nbytes( ByteSize::kib(4).as_u64() as usize, Some(b'b'), ), user_data: generate_stable_nbytes( ByteSize::mib(4).as_u64() as usize, Some(b'B'), ), attachments: vec![], }, - generate_log_data(b'b'), + generate_log_data("b2", b'b'), ), ] } -fn generate_log_data(value: u8) -> Vec { +fn generate_log_data(backup_id: &str, value: u8) -> Vec { const IN_DB_SIZE: usize = ByteSize::kib(4).as_u64() as usize; const IN_BLOB_SIZE: usize = ByteSize::kib(400).as_u64() as usize; (1..30) .map(|log_id| { let size = if log_id % 2 == 0 { IN_DB_SIZE } else { IN_BLOB_SIZE }; let attachments = if log_id % 10 == 0 { Some(vec![Uuid::new_v4().to_string()]) } else { None }; let mut content = generate_stable_nbytes(size, Some(value)); let unique_suffix = log_id.to_string(); content.extend(unique_suffix.as_bytes()); UploadLogRequest { + backup_id: backup_id.to_string(), log_id, content, attachments, } }) .collect() } diff --git a/shared/backup_client/src/lib.rs b/shared/backup_client/src/lib.rs index b9f499abf..a57b62511 100644 --- a/shared/backup_client/src/lib.rs +++ b/shared/backup_client/src/lib.rs @@ -1,305 +1,310 @@ pub use comm_lib::auth::UserIdentity; pub use comm_lib::backup::{ DownloadLogsRequest, LatestBackupIDResponse, LogWSRequest, LogWSResponse, UploadLogRequest, }; pub use futures_util::{SinkExt, StreamExt, TryStreamExt}; use futures_util::{Sink, Stream}; use hex::ToHex; use reqwest::{ header::InvalidHeaderValue, multipart::{Form, Part}, Body, }; use sha2::{Digest, Sha256}; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{ connect_async, tungstenite::{ client::IntoClientRequest, http::{header, Request}, Error as TungsteniteError, Message::{Binary, Ping}, }, }; #[derive(Debug, Clone)] pub struct BackupClient { url: reqwest::Url, } impl BackupClient { pub fn new>(url: T) -> Result { Ok(BackupClient { url: url.try_into()?, }) } } /// Backup functions impl BackupClient { pub async fn upload_backup( &self, user_identity: &UserIdentity, backup_data: BackupData, ) -> Result<(), Error> { let BackupData { backup_id, user_keys, user_data, attachments, } = backup_data; let client = reqwest::Client::new(); let form = Form::new() .text("backup_id", backup_id) .text( "user_keys_hash", Sha256::digest(&user_keys).encode_hex::(), ) .part("user_keys", Part::stream(Body::from(user_keys))) .text( "user_data_hash", Sha256::digest(&user_data).encode_hex::(), ) .part("user_data", Part::stream(Body::from(user_data))) .text("attachments", attachments.join("\n")); let response = client .post(self.url.join("backups")?) .bearer_auth(user_identity.as_authorization_token()?) .multipart(form) .send() .await?; response.error_for_status()?; Ok(()) } pub async fn download_backup_data( &self, backup_descriptor: &BackupDescriptor, requested_data: RequestedData, ) -> Result, Error> { let client = reqwest::Client::new(); let url = self.url.join("backups/")?; let url = match backup_descriptor { BackupDescriptor::BackupID { backup_id, .. } => { url.join(&format!("{backup_id}/"))? } BackupDescriptor::Latest { username } => { url.join(&format!("latest/{username}/"))? } }; let url = match &requested_data { RequestedData::BackupID => url.join("backup_id")?, RequestedData::UserKeys => url.join("user_keys")?, RequestedData::UserData => url.join("user_data")?, }; let mut request = client.get(url); if let BackupDescriptor::BackupID { user_identity, .. } = backup_descriptor { request = request.bearer_auth(user_identity.as_authorization_token()?) } let response = request.send().await?; let result = response.error_for_status()?.bytes().await?.to_vec(); Ok(result) } } /// Log functions impl BackupClient { pub async fn upload_logs( &self, user_identity: &UserIdentity, - backup_id: &str, ) -> Result< ( impl Sink, - impl Stream>, + impl Stream>, ), Error, > { - let request = self.create_ws_request(user_identity, backup_id)?; + let request = self.create_ws_request(user_identity)?; let (stream, response) = connect_async(request).await?; if response.status().is_client_error() { return Err(Error::WSInitError(TungsteniteError::Http(response))); } let (tx, rx) = stream.split(); let tx = tx.with(|request: UploadLogRequest| async { let request = LogWSRequest::UploadLog(request); let request = bincode::serialize(&request)?; Ok(Binary(request)) }); let rx = rx.filter_map(|msg| async { let response = match get_log_ws_response(msg) { Some(Ok(response)) => response, Some(Err(err)) => return Some(Err(err)), None => return None, }; match response { - LogWSResponse::LogUploaded { log_id } => Some(Ok(log_id)), + LogWSResponse::LogUploaded { backup_id, log_id } => { + Some(Ok(LogUploadConfirmation { backup_id, log_id })) + } LogWSResponse::LogDownload { .. } | LogWSResponse::LogDownloadFinished { .. } => { Some(Err(WSError::InvalidBackupMessage)) } LogWSResponse::ServerError => Some(Err(WSError::ServerError)), } }); Ok((tx, rx)) } pub async fn download_logs( &self, user_identity: &UserIdentity, - backup_id: &str, ) -> Result< ( impl Sink, impl Stream>, ), Error, > { - let request = self.create_ws_request(user_identity, backup_id)?; + let request = self.create_ws_request(user_identity)?; let (stream, response) = connect_async(request).await?; if response.status().is_client_error() { return Err(Error::WSInitError(TungsteniteError::Http(response))); } let (tx, rx) = stream.split(); let tx = tx.with(|request: DownloadLogsRequest| async { let request = LogWSRequest::DownloadLogs(request); let request = bincode::serialize(&request)?; Ok(Binary(request)) }); let rx = rx.filter_map(|msg| async { let response = match get_log_ws_response(msg) { Some(Ok(response)) => response, Some(Err(err)) => return Some(Err(err)), None => return None, }; match response { LogWSResponse::LogDownloadFinished { .. } | LogWSResponse::LogDownload { .. } => Some(Ok(response)), LogWSResponse::LogUploaded { .. } => { Some(Err(WSError::InvalidBackupMessage)) } LogWSResponse::ServerError => Some(Err(WSError::ServerError)), } }); Ok((tx, rx)) } fn create_ws_request( &self, user_identity: &UserIdentity, - backup_id: &str, ) -> Result, Error> { let mut url = self.url.clone(); match url.scheme() { "http" => url.set_scheme("ws")?, "https" => url.set_scheme("wss")?, _ => (), }; - let url = url.join("logs/")?.join(backup_id)?; + let url = url.join("logs")?; let mut request = url.into_client_request().unwrap(); let token = user_identity.as_authorization_token()?; request .headers_mut() .insert(header::AUTHORIZATION, format!("Bearer {token}").parse()?); Ok(request) } } #[derive(Debug, Clone)] pub struct BackupData { pub backup_id: String, pub user_keys: Vec, pub user_data: Vec, pub attachments: Vec, } #[derive(Debug, Clone)] pub enum BackupDescriptor { BackupID { backup_id: String, user_identity: UserIdentity, }, Latest { username: String, }, } #[derive(Debug, Clone, Copy)] pub enum RequestedData { BackupID, UserKeys, UserData, } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LogUploadConfirmation { + pub backup_id: String, + pub log_id: usize, +} + #[derive( Debug, derive_more::Display, derive_more::Error, derive_more::From, )] pub enum Error { InvalidAuthorizationHeader, UrlError(url::ParseError), ReqwestError(reqwest::Error), WSInitError(TungsteniteError), JsonError(serde_json::Error), } impl From for Error { fn from(_: InvalidHeaderValue) -> Self { Self::InvalidAuthorizationHeader } } #[derive( Debug, derive_more::Display, derive_more::Error, derive_more::From, )] pub enum WSError { BincodeError(bincode::Error), TungsteniteError(TungsteniteError), InvalidWSMessage, InvalidBackupMessage, ServerError, } fn get_log_ws_response( msg: Result, ) -> Option> { let bytes = match msg { Ok(Binary(bytes)) => bytes, // Handled by tungstenite Ok(Ping(_)) => return None, Ok(_) => return Some(Err(WSError::InvalidWSMessage)), Err(err) => return Some(Err(err.into())), }; match bincode::deserialize(&bytes) { Ok(response) => Some(Ok(response)), Err(err) => Some(Err(err.into())), } } diff --git a/shared/comm-lib/src/backup/mod.rs b/shared/comm-lib/src/backup/mod.rs index dc4775a2b..06d7bf245 100644 --- a/shared/comm-lib/src/backup/mod.rs +++ b/shared/comm-lib/src/backup/mod.rs @@ -1,41 +1,44 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LatestBackupIDResponse { #[serde(rename = "backupID")] pub backup_id: String, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UploadLogRequest { + pub backup_id: String, pub log_id: usize, pub content: Vec, pub attachments: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DownloadLogsRequest { + pub backup_id: String, pub from_id: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub enum LogWSRequest { UploadLog(UploadLogRequest), DownloadLogs(DownloadLogsRequest), } #[derive(Debug, Clone, Serialize, Deserialize)] pub enum LogWSResponse { LogUploaded { + backup_id: String, log_id: usize, }, LogDownload { log_id: usize, content: Vec, attachments: Option>, }, LogDownloadFinished { last_log_id: Option, }, ServerError, }