diff --git a/services/backup/src/http/handlers/backup.rs b/services/backup/src/http/handlers/backup.rs index cd1228f44..3b551f6df 100644 --- a/services/backup/src/http/handlers/backup.rs +++ b/services/backup/src/http/handlers/backup.rs @@ -1,264 +1,275 @@ use std::{collections::HashSet, convert::Infallible}; use actix_web::{ error::ErrorBadRequest, web::{self, Bytes}, HttpResponse, Responder, }; use comm_services_lib::{ auth::UserIdentity, backup::LatestBackupIDResponse, blob::{client::BlobServiceClient, types::BlobInfo}, http::multipart::{get_named_text_field, get_text_field}, + tools::Defer, }; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; use tracing::{info, instrument, trace, warn}; use crate::{ database::{backup_item::BackupItem, DatabaseClient}, error::BackupError, }; #[instrument(name = "upload_backup", skip_all, fields(backup_id))] pub async fn upload( user: UserIdentity, blob_client: web::Data, db_client: web::Data, mut multipart: actix_multipart::Multipart, ) -> actix_web::Result { info!("Upload backup request"); let backup_id = get_named_text_field("backup_id", &mut multipart).await?; tracing::Span::current().record("backup_id", &backup_id); - let user_keys_blob_info = forward_field_to_blob( + let (user_keys_blob_info, user_keys_revoke) = forward_field_to_blob( &mut multipart, &blob_client, "user_keys_hash", "user_keys", ) .await?; - let user_data_blob_info = forward_field_to_blob( + let (user_data_blob_info, user_data_revoke) = forward_field_to_blob( &mut multipart, &blob_client, "user_data_hash", "user_data", ) .await?; let attachments_holders: HashSet = match get_text_field(&mut multipart).await? { Some((name, attachments)) => { if name != "attachments" { warn!( name, "Malformed request: 'attachments' text field expected." ); return Err(ErrorBadRequest("Bad request")); } attachments.lines().map(ToString::to_string).collect() } None => HashSet::new(), }; let item = BackupItem::new( user.user_id, backup_id, user_keys_blob_info, user_data_blob_info, attachments_holders, ); db_client .put_backup_item(item) .await .map_err(BackupError::from)?; + + user_keys_revoke.cancel(); + user_data_revoke.cancel(); + Ok(HttpResponse::Ok().finish()) } #[instrument( skip_all, name = "forward_to_blob", fields(hash_field_name, data_field_name) )] -async fn forward_field_to_blob( +async fn forward_field_to_blob<'revoke, 'blob: 'revoke>( multipart: &mut actix_multipart::Multipart, - blob_client: &web::Data, + blob_client: &'blob web::Data, hash_field_name: &str, data_field_name: &str, -) -> actix_web::Result { +) -> actix_web::Result<(BlobInfo, Defer<'revoke>)> { trace!("Reading blob fields: {hash_field_name:?}, {data_field_name:?}"); let blob_hash = get_named_text_field(hash_field_name, multipart).await?; let Some(mut field) = multipart.try_next().await? else { warn!("Malformed request: expected a field."); return Err(ErrorBadRequest("Bad request"))?; }; if field.name() != data_field_name { warn!( hash_field_name, "Malformed request: '{data_field_name}' data field expected." ); return Err(ErrorBadRequest("Bad request"))?; } let blob_info = BlobInfo { blob_hash, holder: uuid::Uuid::new_v4().to_string(), }; // [`actix_multipart::Multipart`] isn't [`std::marker::Send`], and so we cannot pass it to the blob client directly. // Instead we have to forward it to a channel and create stream from the receiver. let (tx, rx) = tokio::sync::mpsc::channel(1); let receive_promise = async move { trace!("Receiving blob data"); // [`actix_multipart::MultipartError`] isn't [`std::marker::Send`] so we return it here, and pass [`Infallible`] // as the error to the channel while let Some(chunk) = field.try_next().await? { if let Err(err) = tx.send(Result::::Ok(chunk)).await { warn!("Error when sending data through a channel: '{err}'"); // Error here means that the channel has been closed from the blob client side. We don't want to return an error // here, because `tokio::try_join!` only returns the first error it receives and we want to prioritize the backup // client error. break; } } trace!("Finished receiving blob data"); Result::<(), actix_web::Error>::Ok(()) }; let data_stream = ReceiverStream::new(rx); let send_promise = async { blob_client .simple_put(&blob_info.blob_hash, &blob_info.holder, data_stream) .await .map_err(BackupError::from)?; Ok(()) }; tokio::try_join!(receive_promise, send_promise)?; - Ok(blob_info) + let revoke_info = blob_info.clone(); + let revoke_holder = Defer::new(|| { + blob_client + .schedule_revoke_holder(revoke_info.blob_hash, revoke_info.holder) + }); + + Ok((blob_info, revoke_holder)) } #[instrument(name = "download_user_keys", skip_all, fields(backup_id = %path.as_str()))] pub async fn download_user_keys( user: UserIdentity, path: web::Path, blob_client: web::Data, db_client: web::Data, ) -> actix_web::Result { info!("Download user keys request"); let backup_id = path.into_inner(); download_user_blob( |item| &item.user_keys, &user.user_id, &backup_id, blob_client, db_client, ) .await } #[instrument(name = "download_user_data", skip_all, fields(backup_id = %path.as_str()))] pub async fn download_user_data( user: UserIdentity, path: web::Path, blob_client: web::Data, db_client: web::Data, ) -> actix_web::Result { info!("Download user data request"); let backup_id = path.into_inner(); download_user_blob( |item| &item.user_data, &user.user_id, &backup_id, blob_client, db_client, ) .await } pub async fn download_user_blob( data_extractor: impl FnOnce(&BackupItem) -> &BlobInfo, user_id: &str, backup_id: &str, blob_client: web::Data, db_client: web::Data, ) -> actix_web::Result { let backup_item = db_client .find_backup_item(user_id, backup_id) .await .map_err(BackupError::from)? .ok_or(BackupError::NoBackup)?; let stream = blob_client .get(&data_extractor(&backup_item).blob_hash) .await .map_err(BackupError::from)?; Ok( HttpResponse::Ok() .content_type("application/octet-stream") .streaming(stream), ) } #[instrument(name = "get_latest_backup_id", skip_all, fields(username = %path.as_str()))] pub async fn get_latest_backup_id( path: web::Path, db_client: web::Data, ) -> actix_web::Result { let username = path.into_inner(); // Treat username as user_id in the initial version let user_id = username; let Some(backup_item) = db_client .find_last_backup_item(&user_id) .await .map_err(BackupError::from)? else { return Err(BackupError::NoBackup.into()); }; let response = LatestBackupIDResponse { backup_id: backup_item.backup_id, }; Ok(web::Json(response)) } #[instrument(name = "download_latest_backup_keys", skip_all, fields(username = %path.as_str()))] pub async fn download_latest_backup_keys( path: web::Path, db_client: web::Data, blob_client: web::Data, ) -> actix_web::Result { let username = path.into_inner(); // Treat username as user_id in the initial version let user_id = username; let Some(backup_item) = db_client .find_last_backup_item(&user_id) .await .map_err(BackupError::from)? else { return Err(BackupError::NoBackup.into()); }; let stream = blob_client .get(&backup_item.user_keys.blob_hash) .await .map_err(BackupError::from)?; Ok( HttpResponse::Ok() .content_type("application/octet-stream") .streaming(stream), ) } diff --git a/services/comm-services-lib/src/tools.rs b/services/comm-services-lib/src/tools.rs index 9c5c73e11..2b4e41d22 100644 --- a/services/comm-services-lib/src/tools.rs +++ b/services/comm-services-lib/src/tools.rs @@ -1,70 +1,157 @@ // colon is valid because it is used as a separator // in some backup service identifiers const VALID_IDENTIFIER_CHARS: &'static [char] = &['_', '-', '=', ':']; /// Checks if the given string is a valid identifier for an entity /// (e.g. backup ID, blob hash, blob holder). /// /// Some popular identifier formats are considered valid, including UUID, /// nanoid, base64url. On the other hand, path or url-like identifiers /// are not supposed to be valid pub fn is_valid_identifier(identifier: &str) -> bool { if identifier.is_empty() { return false; } identifier .chars() .all(|c| c.is_ascii_alphanumeric() || VALID_IDENTIFIER_CHARS.contains(&c)) } pub type BoxedError = Box; +/// Defers call of the provided function to when [Defer] goes out of scope. +/// This can be used for cleanup code that must be run when e.g. the enclosing +/// function exits either by return or try operator `?`. +/// +/// # Example +/// ```ignore +/// fn f(){ +/// let _ = Defer::new(|| println!("cleanup")) +/// +/// // Cleanup will run if function would exit here +/// operation_that_can_fail()?; +/// +/// if should_exit_early { +/// // Cleanup will run if function would exit here +/// return; +/// } +/// } +/// ``` +pub struct Defer<'s>(Option>); + +impl<'s> Defer<'s> { + pub fn new(f: impl FnOnce() + 's) -> Self { + Self(Some(Box::new(f))) + } + + /// Consumes the value, without calling the provided function + /// + /// # Example + /// ```ignore + /// // Start a "transaction" + /// operation_that_should_be_reverted(); + /// let revert = Defer::new(|| println!("revert")) + /// operation_that_can_fail()?; + /// operation_that_can_fail()?; + /// operation_that_can_fail()?; + /// // Now we can "commit" the changes + /// revert.cancel(); + /// ``` + pub fn cancel(mut self) { + self.0 = None; + // Implicit drop + } +} + +impl Drop for Defer<'_> { + fn drop(&mut self) { + if let Some(f) = self.0.take() { + f(); + } + } +} + #[cfg(test)] mod valid_identifier_tests { use super::*; #[test] fn alphanumeric_identifier() { assert!(is_valid_identifier("some_identifier_v123")); } #[test] fn alphanumeric_with_colon() { assert!(is_valid_identifier("some_identifier:with_colon")); } #[test] fn uuid_is_valid() { let example_uuid = "a2b9e4d4-8d1f-4c7f-9c3d-5f4e4e6b1d1d"; assert!(is_valid_identifier(example_uuid)); } #[test] fn base64url_is_valid() { let example_base64url = "VGhlIP3-aWNrIGJyb3duIGZveCBqciAxMyBsYXp5IGRvZ_7_"; assert!(is_valid_identifier(example_base64url)) } #[test] fn standard_base64_is_invalid() { let example_base64 = "VGhlIP3+aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIDEzIGxhenkgZG9n/v8="; assert!(!is_valid_identifier(example_base64)); } #[test] fn path_is_invalid() { assert!(!is_valid_identifier("some/path")); } #[test] fn url_is_invalid() { assert!(!is_valid_identifier("https://example.com")); } #[test] fn empty_is_invalid() { assert!(!is_valid_identifier("")); } + + #[test] + fn defer_runs() { + fn f(a: &mut bool) { + let _ = Defer::new(|| *a = true); + } + + let mut v = false; + f(&mut v); + assert!(v) + } + + #[test] + fn consumed_defer_doesnt_run() { + fn f(a: &mut bool) { + let defer = Defer::new(|| *a = true); + defer.cancel(); + } + + let mut v = false; + f(&mut v); + assert!(!v) + } + + #[test] + fn defer_runs_on_try() { + fn f(a: &mut bool) -> Result<(), ()> { + let _ = Defer::new(|| *a = true); + Err(()) + } + + let mut v = false; + let _ = f(&mut v); + assert!(v) + } }