diff --git a/services/backup/src/http/handlers/backup.rs b/services/backup/src/http/handlers/backup.rs
--- a/services/backup/src/http/handlers/backup.rs
+++ b/services/backup/src/http/handlers/backup.rs
@@ -10,6 +10,7 @@
     backup::LatestBackupIDResponse,
     blob::{client::BlobServiceClient, types::BlobInfo},
     http::multipart::{get_named_text_field, get_text_field},
+    tools::Defer,
 };
 use tokio_stream::{wrappers::ReceiverStream, StreamExt};
 use tracing::{info, instrument, trace, warn};
@@ -32,7 +33,7 @@
 
     tracing::Span::current().record("backup_id", &backup_id);
 
-    let user_keys_blob_info = forward_field_to_blob(
+    let (user_keys_blob_info, user_keys_revoke) = forward_field_to_blob(
         &mut multipart,
         &blob_client,
         "user_keys_hash",
@@ -40,7 +41,7 @@
     )
     .await?;
 
-    let user_data_blob_info = forward_field_to_blob(
+    let (user_data_blob_info, user_data_revoke) = forward_field_to_blob(
         &mut multipart,
         &blob_client,
         "user_data_hash",
@@ -76,6 +77,11 @@
         .put_backup_item(item)
         .await
         .map_err(BackupError::from)?;
+
+    // The backup item has been persisted, so keep both blob holders instead of revoking them.
+    user_keys_revoke.cancel();
+    user_data_revoke.cancel();
+
     Ok(HttpResponse::Ok().finish())
 }
 
@@ -84,12 +90,12 @@
     name = "forward_to_blob",
     fields(hash_field_name, data_field_name)
 )]
-async fn forward_field_to_blob(
+async fn forward_field_to_blob<'blob>(
     multipart: &mut actix_multipart::Multipart,
-    blob_client: &web::Data<BlobServiceClient>,
+    blob_client: &'blob web::Data<BlobServiceClient>,
     hash_field_name: &str,
     data_field_name: &str,
-) -> actix_web::Result<BlobInfo> {
+) -> actix_web::Result<(BlobInfo, Defer<'blob>)> {
     trace!("Reading blob fields: {hash_field_name:?}, {data_field_name:?}");
 
     let blob_hash = get_named_text_field(hash_field_name, multipart).await?;
@@ -143,7 +149,14 @@
 
     tokio::try_join!(receive_promise, send_promise)?;
 
-    Ok(blob_info)
+    // Guard that schedules revocation of the blob holder on drop unless the caller cancels it.
+    let revoke_info = blob_info.clone();
+    let revoke_holder = Defer::new(|| {
+        blob_client
+            .schedule_revoke_holder(revoke_info.blob_hash, revoke_info.holder)
+    });
+
+    Ok((blob_info, revoke_holder))
 }
 
 #[instrument(name = "download_user_keys", skip_all, fields(backup_id = %path.as_str()))]
diff --git a/services/comm-services-lib/src/tools.rs b/services/comm-services-lib/src/tools.rs
--- a/services/comm-services-lib/src/tools.rs
+++ b/services/comm-services-lib/src/tools.rs
@@ -20,6 +20,68 @@
 
 pub type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;
 
+/// Defers call of the provided function to when [Defer] goes out of scope.
+/// This can be used for cleanup code that must be run when e.g. the enclosing
+/// function exits either by return or try operator `?`.
+///
+/// The returned guard must be bound to a named variable, e.g.
+/// `let _guard = Defer::new(...)`. `let _ = Defer::new(...)` does not bind the
+/// value, so the guard is dropped (and the function called) immediately.
+///
+/// # Example
+/// ```ignore
+/// fn f() -> Result<(), Error> {
+///     let _cleanup = Defer::new(|| println!("cleanup"));
+///
+///     // Cleanup will run if function would exit here
+///     operation_that_can_fail()?;
+///
+///     if should_exit_early {
+///         // Cleanup will run if function would exit here
+///         return Ok(());
+///     }
+///
+///     // Cleanup will also run on normal return
+///     Ok(())
+/// }
+/// ```
+#[must_use = "the deferred function runs as soon as the guard is dropped"]
+pub struct Defer<'s>(Option<Box<dyn FnOnce() + 's>>);
+
+impl<'s> Defer<'s> {
+    /// Creates a guard that calls `f` once when dropped, unless
+    /// [Defer::cancel] is called first.
+    pub fn new(f: impl FnOnce() + 's) -> Self {
+        Self(Some(Box::new(f)))
+    }
+
+    /// Consumes the value, without calling the provided function
+    ///
+    /// # Example
+    /// ```ignore
+    /// // Start a "transaction"
+    /// operation_that_should_be_reverted();
+    /// let revert = Defer::new(|| println!("revert"));
+    /// operation_that_can_fail()?;
+    /// operation_that_can_fail()?;
+    /// operation_that_can_fail()?;
+    /// // Now we can "commit" the changes
+    /// revert.cancel();
+    /// ```
+    pub fn cancel(mut self) {
+        self.0 = None;
+        // Implicit drop
+    }
+}
+
+impl Drop for Defer<'_> {
+    fn drop(&mut self) {
+        if let Some(f) = self.0.take() {
+            f();
+        }
+    }
+}
+
 #[cfg(test)]
 mod valid_identifier_tests {
     use super::*;
@@ -67,4 +129,54 @@
     fn empty_is_invalid() {
         assert!(!is_valid_identifier(""));
     }
+
+    #[test]
+    fn defer_runs() {
+        fn f(a: &mut bool) {
+            let _defer = Defer::new(|| *a = true);
+        }
+
+        let mut v = false;
+        f(&mut v);
+        assert!(v)
+    }
+
+    #[test]
+    fn defer_runs_on_scope_exit_not_immediately() {
+        let v = std::cell::Cell::new(false);
+        {
+            let _defer = Defer::new(|| v.set(true));
+            assert!(!v.get());
+        }
+        assert!(v.get());
+    }
+
+    #[test]
+    fn consumed_defer_doesnt_run() {
+        fn f(a: &mut bool) {
+            let defer = Defer::new(|| *a = true);
+            defer.cancel();
+        }
+
+        let mut v = false;
+        f(&mut v);
+        assert!(!v)
+    }
+
+    #[test]
+    fn defer_runs_on_try() {
+        fn fail() -> Result<(), ()> {
+            Err(())
+        }
+
+        fn f(a: &mut bool) -> Result<(), ()> {
+            let _defer = Defer::new(|| *a = true);
+            fail()?;
+            Ok(())
+        }
+
+        let mut v = false;
+        assert!(f(&mut v).is_err());
+        assert!(v)
+    }
 }