diff --git a/services/backup/src/constants.rs b/services/backup/src/constants.rs
index d4d58aca0..0fbb3bf60 100644
--- a/services/backup/src/constants.rs
+++ b/services/backup/src/constants.rs
@@ -1,32 +1,32 @@
 // Assorted constants
 
 pub const MPSC_CHANNEL_BUFFER_CAPACITY: usize = 1;
 pub const ID_SEPARATOR: &str = ":";
 pub const ATTACHMENT_HOLDER_SEPARATOR: &str = ";";
 
 // Configuration defaults
 
 pub const DEFAULT_HTTP_PORT: u16 = 50052;
 pub const DEFAULT_BLOB_SERVICE_URL: &str = "http://localhost:50053";
 
 // Environment variable names
 
 pub const LOG_LEVEL_ENV_VAR: &str =
   tracing_subscriber::filter::EnvFilter::DEFAULT_ENV;
 
 // DynamoDB constants
 
 pub const BACKUP_TABLE_NAME: &str = "backup-service-backup";
 pub const BACKUP_TABLE_FIELD_USER_ID: &str = "userID";
 pub const BACKUP_TABLE_FIELD_BACKUP_ID: &str = "backupID";
 pub const BACKUP_TABLE_FIELD_CREATED: &str = "created";
 pub const BACKUP_TABLE_FIELD_USER_DATA: &str = "userData";
 pub const BACKUP_TABLE_FIELD_USER_KEYS: &str = "userKeys";
-pub const BACKUP_TABLE_FIELD_ATTACHMENT_HOLDERS: &str = "attachmentHolders";
+pub const BACKUP_TABLE_FIELD_ATTACHMENTS: &str = "attachments";
 pub const BACKUP_TABLE_INDEX_USERID_CREATED: &str = "userID-created-index";
 
 pub const LOG_TABLE_NAME: &str = "backup-service-log";
 pub const LOG_TABLE_FIELD_BACKUP_ID: &str = "backupID";
 pub const LOG_TABLE_FIELD_LOG_ID: &str = "logID";
 pub const LOG_TABLE_FIELD_PERSISTED_IN_BLOB: &str = "persistedInBlob";
 pub const LOG_TABLE_FIELD_VALUE: &str = "value";
 pub const LOG_TABLE_FIELD_ATTACHMENT_HOLDERS: &str = "attachmentHolders";
 pub const LOG_TABLE_FIELD_DATA_HASH: &str = "dataHash";
diff --git a/services/backup/src/database/backup_item.rs b/services/backup/src/database/backup_item.rs
index 746691558..d4ea6da7a 100644
--- a/services/backup/src/database/backup_item.rs
+++ b/services/backup/src/database/backup_item.rs
@@ -1,170 +1,188 @@
 use aws_sdk_dynamodb::types::AttributeValue;
 use chrono::{DateTime, Utc};
 use comm_services_lib::{
-  blob::types::BlobInfo,
-  database::{DBItemError, TryFromAttribute},
+  blob::{client::BlobServiceClient, types::BlobInfo},
+  database::{AttributeTryInto, DBItemError, TryFromAttribute},
 };
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 
 use crate::constants::{
-  BACKUP_TABLE_FIELD_ATTACHMENT_HOLDERS, BACKUP_TABLE_FIELD_BACKUP_ID,
+  BACKUP_TABLE_FIELD_ATTACHMENTS, BACKUP_TABLE_FIELD_BACKUP_ID,
   BACKUP_TABLE_FIELD_CREATED, BACKUP_TABLE_FIELD_USER_DATA,
   BACKUP_TABLE_FIELD_USER_ID, BACKUP_TABLE_FIELD_USER_KEYS,
 };
 
 #[derive(Clone, Debug)]
 pub struct BackupItem {
   pub user_id: String,
   pub backup_id: String,
   pub created: DateTime<Utc>,
   pub user_keys: BlobInfo,
   pub user_data: BlobInfo,
-  pub attachment_holders: HashSet<String>,
+  pub attachments: Vec<BlobInfo>,
 }
 
 impl BackupItem {
   pub fn new(
     user_id: String,
     backup_id: String,
     user_keys: BlobInfo,
     user_data: BlobInfo,
-    attachment_holders: HashSet<String>,
+    attachments: Vec<BlobInfo>,
   ) -> Self {
     BackupItem {
       user_id,
       backup_id,
       created: chrono::Utc::now(),
       user_keys,
       user_data,
-      attachment_holders,
+      attachments,
+    }
+  }
+
+  pub async fn revoke_holders(self, blob_client: &BlobServiceClient) {
+    blob_client
+      .schedule_revoke_holder(self.user_keys.blob_hash, self.user_keys.holder);
+
+    blob_client
+      .schedule_revoke_holder(self.user_data.blob_hash, self.user_data.holder);
+
+    for attachment_info in self.attachments {
+      blob_client.schedule_revoke_holder(
+        attachment_info.blob_hash,
+        attachment_info.holder,
+      );
     }
   }
 }
 
 impl From<BackupItem> for HashMap<String, AttributeValue> {
   fn from(value: BackupItem) -> Self {
     let mut attrs = HashMap::from([
       (
         BACKUP_TABLE_FIELD_USER_ID.to_string(),
         AttributeValue::S(value.user_id),
       ),
       (
         BACKUP_TABLE_FIELD_BACKUP_ID.to_string(),
         AttributeValue::S(value.backup_id),
       ),
       (
         BACKUP_TABLE_FIELD_CREATED.to_string(),
         AttributeValue::S(value.created.to_rfc3339()),
       ),
       (
         BACKUP_TABLE_FIELD_USER_KEYS.to_string(),
         value.user_keys.into(),
       ),
       (
         BACKUP_TABLE_FIELD_USER_DATA.to_string(),
         value.user_data.into(),
       ),
     ]);
 
-    if !value.attachment_holders.is_empty() {
+    if !value.attachments.is_empty() {
       attrs.insert(
-        BACKUP_TABLE_FIELD_ATTACHMENT_HOLDERS.to_string(),
-        AttributeValue::Ss(value.attachment_holders.into_iter().collect()),
+        BACKUP_TABLE_FIELD_ATTACHMENTS.to_string(),
+        AttributeValue::L(
+          value
+            .attachments
+            .into_iter()
+            .map(AttributeValue::from)
+            .collect(),
+        ),
       );
     }
     attrs
   }
 }
 
 impl TryFrom<HashMap<String, AttributeValue>> for BackupItem {
   type Error = DBItemError;
 
   fn try_from(
     mut value: HashMap<String, AttributeValue>,
   ) -> Result<Self, Self::Error> {
     let user_id = String::try_from_attr(
       BACKUP_TABLE_FIELD_USER_ID,
       value.remove(BACKUP_TABLE_FIELD_USER_ID),
     )?;
     let backup_id = String::try_from_attr(
       BACKUP_TABLE_FIELD_BACKUP_ID,
       value.remove(BACKUP_TABLE_FIELD_BACKUP_ID),
     )?;
     let created = DateTime::<Utc>::try_from_attr(
       BACKUP_TABLE_FIELD_CREATED,
       value.remove(BACKUP_TABLE_FIELD_CREATED),
     )?;
     let user_keys = BlobInfo::try_from_attr(
       BACKUP_TABLE_FIELD_USER_KEYS,
       value.remove(BACKUP_TABLE_FIELD_USER_KEYS),
     )?;
     let user_data = BlobInfo::try_from_attr(
       BACKUP_TABLE_FIELD_USER_DATA,
       value.remove(BACKUP_TABLE_FIELD_USER_DATA),
     )?;
 
-    let attachments = value.remove(BACKUP_TABLE_FIELD_ATTACHMENT_HOLDERS);
-    let attachment_holders = if attachments.is_some() {
-      HashSet::<String>::try_from_attr(
-        BACKUP_TABLE_FIELD_ATTACHMENT_HOLDERS,
-        attachments,
-      )?
+    let attachments = value.remove(BACKUP_TABLE_FIELD_ATTACHMENTS);
+    let attachments = if attachments.is_some() {
+      attachments.attr_try_into(BACKUP_TABLE_FIELD_ATTACHMENTS)?
     } else {
-      HashSet::new()
+      Vec::new()
     };
 
     Ok(BackupItem {
       user_id,
       backup_id,
       created,
       user_keys,
       user_data,
-      attachment_holders,
+      attachments,
     })
   }
 }
 
 /// Corresponds to the items in the
 /// [`crate::constants::BACKUP_TABLE_INDEX_USERID_CREATED`] global index
 #[derive(Clone, Debug)]
 pub struct OrderedBackupItem {
   pub user_id: String,
   pub created: DateTime<Utc>,
   pub backup_id: String,
   pub user_keys: BlobInfo,
 }
 
 impl TryFrom<HashMap<String, AttributeValue>> for OrderedBackupItem {
   type Error = DBItemError;
 
   fn try_from(
     mut value: HashMap<String, AttributeValue>,
   ) -> Result<Self, Self::Error> {
     let user_id = String::try_from_attr(
       BACKUP_TABLE_FIELD_USER_ID,
       value.remove(BACKUP_TABLE_FIELD_USER_ID),
     )?;
     let created = DateTime::<Utc>::try_from_attr(
       BACKUP_TABLE_FIELD_CREATED,
       value.remove(BACKUP_TABLE_FIELD_CREATED),
     )?;
     let backup_id = String::try_from_attr(
       BACKUP_TABLE_FIELD_BACKUP_ID,
       value.remove(BACKUP_TABLE_FIELD_BACKUP_ID),
     )?;
     let user_keys = BlobInfo::try_from_attr(
       BACKUP_TABLE_FIELD_USER_KEYS,
       value.remove(BACKUP_TABLE_FIELD_USER_KEYS),
     )?;
 
     Ok(OrderedBackupItem {
       user_id,
       created,
       backup_id,
       user_keys,
     })
   }
 }
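Review note, not part of the patch: a quick sketch of how the new `attachments` field round-trips through the conversions above. `fake_info` is a hypothetical helper for this example only; the rest relies on the `From<BlobInfo> for AttributeValue` conversion that the map impl already uses.

```rust
use aws_sdk_dynamodb::types::AttributeValue;
use comm_services_lib::{blob::types::BlobInfo, database::DBItemError};
use std::collections::HashMap;

// Hypothetical helper, for illustration only.
fn fake_info(tag: &str) -> BlobInfo {
  BlobInfo {
    blob_hash: format!("{tag}-hash"),
    holder: uuid::Uuid::new_v4().to_string(),
  }
}

fn roundtrip() -> Result<(), DBItemError> {
  let item = BackupItem::new(
    "some-user".to_string(),
    "some-backup".to_string(),
    fake_info("user-keys"),
    fake_info("user-data"),
    vec![fake_info("attachment")],
  );

  // `attachments` is serialized as an `L` (list) attribute...
  let attrs: HashMap<String, AttributeValue> = item.into();
  // ...and parsed back into a `Vec<BlobInfo>` by the `TryFrom` impl.
  let parsed = BackupItem::try_from(attrs)?;
  assert_eq!(parsed.attachments.len(), 1);
  Ok(())
}
```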
diff --git a/services/backup/src/http/handlers/backup.rs b/services/backup/src/http/handlers/backup.rs
index b331b243b..00f0ac548 100644
--- a/services/backup/src/http/handlers/backup.rs
+++ b/services/backup/src/http/handlers/backup.rs
@@ -1,291 +1,322 @@
-use std::{collections::HashSet, convert::Infallible};
-
 use actix_web::{
   error::ErrorBadRequest,
   web::{self, Bytes},
   HttpResponse, Responder,
 };
 use comm_services_lib::{
   auth::UserIdentity,
   backup::LatestBackupIDResponse,
   blob::{client::BlobServiceClient, types::BlobInfo},
   http::multipart::{get_named_text_field, get_text_field},
   tools::Defer,
 };
+use std::convert::Infallible;
 use tokio_stream::{wrappers::ReceiverStream, StreamExt};
 use tracing::{info, instrument, trace, warn};
 
 use crate::{
   database::{backup_item::BackupItem, DatabaseClient},
   error::BackupError,
 };
 
 #[instrument(name = "upload_backup", skip_all, fields(backup_id))]
 pub async fn upload(
   user: UserIdentity,
   blob_client: web::Data<BlobServiceClient>,
   db_client: web::Data<DatabaseClient>,
   mut multipart: actix_multipart::Multipart,
 ) -> actix_web::Result<HttpResponse> {
   info!("Upload backup request");
 
   let backup_id = get_named_text_field("backup_id", &mut multipart).await?;
 
   tracing::Span::current().record("backup_id", &backup_id);
 
   let (user_keys_blob_info, user_keys_revoke) = forward_field_to_blob(
     &mut multipart,
     &blob_client,
     "user_keys_hash",
     "user_keys",
   )
   .await?;
 
   let (user_data_blob_info, user_data_revoke) = forward_field_to_blob(
     &mut multipart,
     &blob_client,
     "user_data_hash",
     "user_data",
   )
   .await?;
 
-  let attachments_holders: HashSet<String> =
+  let attachments_hashes: Vec<String> =
     match get_text_field(&mut multipart).await? {
       Some((name, attachments)) => {
         if name != "attachments" {
           warn!(
             name,
             "Malformed request: 'attachments' text field expected."
           );
           return Err(ErrorBadRequest("Bad request"));
         }
 
         attachments.lines().map(ToString::to_string).collect()
       }
-      None => HashSet::new(),
+      None => Vec::new(),
     };
 
+  let mut attachments = Vec::new();
+  let mut attachments_revokes = Vec::new();
+  for attachment_hash in attachments_hashes {
+    let (holder, revoke) =
+      create_attachment_holder(&attachment_hash, &blob_client).await?;
+
+    attachments.push(BlobInfo {
+      blob_hash: attachment_hash,
+      holder,
+    });
+    attachments_revokes.push(revoke);
+  }
+
   let item = BackupItem::new(
     user.user_id.clone(),
     backup_id,
     user_keys_blob_info,
     user_data_blob_info,
-    attachments_holders,
+    attachments,
   );
 
   db_client
     .put_backup_item(item)
     .await
     .map_err(BackupError::from)?;
 
   user_keys_revoke.cancel();
   user_data_revoke.cancel();
+  for attachment_revoke in attachments_revokes {
+    attachment_revoke.cancel();
+  }
 
   for backup in db_client
     .remove_old_backups(&user.user_id)
     .await
     .map_err(BackupError::from)?
   {
-    blob_client.schedule_revoke_holder(
-      backup.user_keys.blob_hash,
-      backup.user_keys.holder,
-    );
-
-    blob_client.schedule_revoke_holder(
-      backup.user_data.blob_hash,
-      backup.user_data.holder,
-    );
+    backup.revoke_holders(&blob_client).await;
   }
 
   Ok(HttpResponse::Ok().finish())
 }
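Review note, not part of the patch: the revoke handling above follows a cancel-on-success cleanup pattern. Assuming `Defer` runs its closure on drop unless `cancel()` is called, which matches how it is used here, the shape in isolation looks like this:

```rust
use comm_services_lib::tools::Defer;

fn guarded_step(fail: bool) -> Result<(), &'static str> {
  // Schedule a rollback that fires if we leave this scope without cancelling.
  let cleanup = Defer::new(|| println!("rolling back side effects"));

  if fail {
    // Early return: `cleanup` is dropped and the rollback closure runs,
    // just like the holder revokes above when `put_backup_item` fails.
    return Err("step failed");
  }

  // Success: suppress the rollback and keep the side effects.
  cleanup.cancel();
  Ok(())
}
```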
 
 #[instrument(
   skip_all,
   name = "forward_to_blob",
   fields(hash_field_name, data_field_name)
 )]
 async fn forward_field_to_blob<'revoke, 'blob: 'revoke>(
   multipart: &mut actix_multipart::Multipart,
   blob_client: &'blob web::Data<BlobServiceClient>,
   hash_field_name: &str,
   data_field_name: &str,
 ) -> actix_web::Result<(BlobInfo, Defer<'revoke>)> {
   trace!("Reading blob fields: {hash_field_name:?}, {data_field_name:?}");
 
   let blob_hash = get_named_text_field(hash_field_name, multipart).await?;
 
   let Some(mut field) = multipart.try_next().await? else {
     warn!("Malformed request: expected a field.");
     return Err(ErrorBadRequest("Bad request"))?;
   };
   if field.name() != data_field_name {
     warn!(
       hash_field_name,
       "Malformed request: '{data_field_name}' data field expected."
     );
     return Err(ErrorBadRequest("Bad request"))?;
   }
 
   let blob_info = BlobInfo {
     blob_hash,
     holder: uuid::Uuid::new_v4().to_string(),
   };
 
   // [`actix_multipart::Multipart`] isn't [`std::marker::Send`], and so we
   // cannot pass it to the blob client directly. Instead we have to forward it
   // to a channel and create a stream from the receiver.
   let (tx, rx) = tokio::sync::mpsc::channel(1);
   let receive_promise = async move {
     trace!("Receiving blob data");
     // [`actix_multipart::MultipartError`] isn't [`std::marker::Send`] so we
     // return it here, and pass [`Infallible`] as the error to the channel
     while let Some(chunk) = field.try_next().await? {
       if let Err(err) = tx.send(Result::<Bytes, Infallible>::Ok(chunk)).await {
         warn!("Error when sending data through a channel: '{err}'");
         // Error here means that the channel has been closed from the blob
         // client side. We don't want to return an error here, because
         // `tokio::try_join!` only returns the first error it receives and we
         // want to prioritize the backup client error.
         break;
       }
     }
     trace!("Finished receiving blob data");
     Result::<(), actix_web::Error>::Ok(())
   };
 
   let data_stream = ReceiverStream::new(rx);
   let send_promise = async {
     blob_client
       .simple_put(&blob_info.blob_hash, &blob_info.holder, data_stream)
       .await
       .map_err(BackupError::from)?;
 
     Ok(())
   };
 
   tokio::try_join!(receive_promise, send_promise)?;
 
   let revoke_info = blob_info.clone();
   let revoke_holder = Defer::new(|| {
     blob_client
       .schedule_revoke_holder(revoke_info.blob_hash, revoke_info.holder)
   });
 
   Ok((blob_info, revoke_holder))
 }
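Review note, not part of the patch: the channel trick above generalizes to any `!Send` stream. A minimal sketch with plain string chunks (all names are illustrative):

```rust
use std::convert::Infallible;
use tokio_stream::{wrappers::ReceiverStream, StreamExt};

async fn forward_demo() -> Result<(), &'static str> {
  let (tx, rx) = tokio::sync::mpsc::channel(1);

  // Producer half: drains the non-Send source on the current task.
  let produce = async move {
    for chunk in ["part1", "part2"] {
      if tx.send(Result::<_, Infallible>::Ok(chunk)).await.is_err() {
        break; // receiver hung up; let the consumer report its own error
      }
    }
    Ok::<_, &'static str>(())
  };

  // Consumer half: `ReceiverStream` is `Send`, so it can be handed off.
  let consume = async {
    let mut stream = ReceiverStream::new(rx);
    while let Some(Ok(chunk)) = stream.next().await {
      println!("got {chunk}");
    }
    Ok::<_, &'static str>(())
  };

  // Drives both halves to completion; the first error to occur wins.
  tokio::try_join!(produce, consume)?;
  Ok(())
}
```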
 
+#[instrument(skip_all, name = "create_attachment_holder")]
+async fn create_attachment_holder<'revoke, 'blob: 'revoke>(
+  attachment: &str,
+  blob_client: &'blob web::Data<BlobServiceClient>,
+) -> Result<(String, Defer<'revoke>), BackupError> {
+  let holder = uuid::Uuid::new_v4().to_string();
+
+  if !blob_client
+    .assign_holder(attachment, &holder)
+    .await
+    .map_err(BackupError::from)?
+  {
+    warn!("Blob attachment with hash {attachment:?} doesn't exist");
+  }
+
+  let revoke_hash = attachment.to_string();
+  let revoke_holder = holder.clone();
+  let revoke_holder = Defer::new(|| {
+    blob_client.schedule_revoke_holder(revoke_hash, revoke_holder)
+  });
+
+  Ok((holder, revoke_holder))
+}
+
 #[instrument(name = "download_user_keys", skip_all, fields(backup_id = %path.as_str()))]
 pub async fn download_user_keys(
   user: UserIdentity,
   path: web::Path<String>,
   blob_client: web::Data<BlobServiceClient>,
   db_client: web::Data<DatabaseClient>,
 ) -> actix_web::Result<HttpResponse> {
   info!("Download user keys request");
   let backup_id = path.into_inner();
   download_user_blob(
     |item| &item.user_keys,
     &user.user_id,
     &backup_id,
     blob_client,
     db_client,
   )
   .await
 }
 
 #[instrument(name = "download_user_data", skip_all, fields(backup_id = %path.as_str()))]
 pub async fn download_user_data(
   user: UserIdentity,
   path: web::Path<String>,
   blob_client: web::Data<BlobServiceClient>,
   db_client: web::Data<DatabaseClient>,
 ) -> actix_web::Result<HttpResponse> {
   info!("Download user data request");
   let backup_id = path.into_inner();
   download_user_blob(
     |item| &item.user_data,
     &user.user_id,
     &backup_id,
     blob_client,
     db_client,
   )
   .await
 }
 
 pub async fn download_user_blob(
   data_extractor: impl FnOnce(&BackupItem) -> &BlobInfo,
   user_id: &str,
   backup_id: &str,
   blob_client: web::Data<BlobServiceClient>,
   db_client: web::Data<DatabaseClient>,
 ) -> actix_web::Result<HttpResponse> {
   let backup_item = db_client
     .find_backup_item(user_id, backup_id)
     .await
     .map_err(BackupError::from)?
     .ok_or(BackupError::NoBackup)?;
 
   let stream = blob_client
     .get(&data_extractor(&backup_item).blob_hash)
     .await
     .map_err(BackupError::from)?;
 
   Ok(
     HttpResponse::Ok()
       .content_type("application/octet-stream")
       .streaming(stream),
   )
 }
 
 #[instrument(name = "get_latest_backup_id", skip_all, fields(username = %path.as_str()))]
 pub async fn get_latest_backup_id(
   path: web::Path<String>,
   db_client: web::Data<DatabaseClient>,
 ) -> actix_web::Result<impl Responder> {
   let username = path.into_inner();
   // Treat username as user_id in the initial version
   let user_id = username;
 
   let Some(backup_item) = db_client
     .find_last_backup_item(&user_id)
     .await
     .map_err(BackupError::from)?
   else {
     return Err(BackupError::NoBackup.into());
   };
 
   let response = LatestBackupIDResponse {
     backup_id: backup_item.backup_id,
   };
 
   Ok(web::Json(response))
 }
 
 #[instrument(name = "download_latest_backup_keys", skip_all, fields(username = %path.as_str()))]
 pub async fn download_latest_backup_keys(
   path: web::Path<String>,
   db_client: web::Data<DatabaseClient>,
   blob_client: web::Data<BlobServiceClient>,
 ) -> actix_web::Result<HttpResponse> {
   let username = path.into_inner();
   // Treat username as user_id in the initial version
   let user_id = username;
 
   let Some(backup_item) = db_client
     .find_last_backup_item(&user_id)
     .await
     .map_err(BackupError::from)?
   else {
     return Err(BackupError::NoBackup.into());
   };
 
   let stream = blob_client
     .get(&backup_item.user_keys.blob_hash)
     .await
     .map_err(BackupError::from)?;
 
   Ok(
     HttpResponse::Ok()
       .content_type("application/octet-stream")
       .streaming(stream),
   )
 }
diff --git a/services/comm-services-lib/src/database.rs b/services/comm-services-lib/src/database.rs
index b9d1b738a..60c8d6333 100644
--- a/services/comm-services-lib/src/database.rs
+++ b/services/comm-services-lib/src/database.rs
@@ -1,619 +1,648 @@
 use aws_sdk_dynamodb::types::AttributeValue;
 pub use aws_sdk_dynamodb::Error as DynamoDBError;
 use chrono::{DateTime, Utc};
 use std::collections::HashSet;
 use std::fmt::{Display, Formatter};
 use std::num::ParseIntError;
 use std::str::FromStr;
 
 // # Useful type aliases
 
 // Rust exports `pub type` only into the so-called "type namespace", but in
 // order to use them e.g. with the `TryFromAttribute` trait, they also need
 // to be exported into the "value namespace" which is what `pub use` does.
 //
 // To overcome that, a dummy module is created and aliases are re-exported
 // with the `pub use` construct
 mod aliases {
   use aws_sdk_dynamodb::types::AttributeValue;
   use std::collections::HashMap;
   pub type AttributeMap = HashMap<String, AttributeValue>;
 }
 pub use self::aliases::AttributeMap;
 
 // # Error handling
 
 #[derive(
   Debug, derive_more::Display, derive_more::From, derive_more::Error,
 )]
 pub enum Error {
   #[display(...)]
   AwsSdk(DynamoDBError),
   #[display(...)]
   Attribute(DBItemError),
   #[display(fmt = "Maximum retries exceeded")]
   MaxRetriesExceeded,
 }
 
 #[derive(Debug)]
 pub enum Value {
   AttributeValue(Option<AttributeValue>),
   String(String),
 }
 
 #[derive(Debug, derive_more::Error, derive_more::Constructor)]
 pub struct DBItemError {
   attribute_name: String,
   attribute_value: Value,
   attribute_error: DBItemAttributeError,
 }
 
 impl Display for DBItemError {
   fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
     match &self.attribute_error {
       DBItemAttributeError::Missing => {
         write!(f, "Attribute {} is missing", self.attribute_name)
       }
       DBItemAttributeError::IncorrectType => write!(
         f,
         "Value for attribute {} has incorrect type: {:?}",
         self.attribute_name, self.attribute_value
       ),
       error => write!(
         f,
         "Error regarding attribute {} with value {:?}: {}",
         self.attribute_name, self.attribute_value, error
       ),
     }
   }
 }
 
 #[derive(Debug, derive_more::Display, derive_more::Error)]
 pub enum DBItemAttributeError {
   #[display(...)]
   Missing,
   #[display(...)]
   IncorrectType,
   #[display(...)]
   TimestampOutOfRange,
   #[display(...)]
   InvalidTimestamp(chrono::ParseError),
   #[display(...)]
   InvalidNumberFormat(ParseIntError),
 }
 
 /// Conversion trait for [`AttributeValue`]
 ///
 /// Types implementing this trait are able to do the following:
 /// ```ignore
 /// use comm_services_lib::database::{TryFromAttribute, AttributeTryInto};
 ///
 /// let foo = SomeType::try_from_attr("MyAttribute", Some(attribute))?;
 ///
 /// // if `AttributeTryInto` is imported, also:
 /// let bar = Some(attribute).attr_try_into("MyAttribute")?;
 /// ```
 pub trait TryFromAttribute: Sized {
   fn try_from_attr(
     attribute_name: impl Into<String>,
     attribute: Option<AttributeValue>,
   ) -> Result<Self, DBItemError>;
 }
 
 /// Do NOT implement this trait directly.
 /// Implement [`TryFromAttribute`] instead
 pub trait AttributeTryInto<T> {
   fn attr_try_into(
     self,
     attribute_name: impl Into<String>,
   ) -> Result<T, DBItemError>;
 }
 
 // Automatic attr_try_into() for all attribute values
 // that have TryFromAttribute implemented
 impl<T: TryFromAttribute> AttributeTryInto<T> for Option<AttributeValue> {
   fn attr_try_into(
     self,
     attribute_name: impl Into<String>,
   ) -> Result<T, DBItemError> {
     T::try_from_attr(attribute_name, self)
   }
 }
 
 /// Helper trait for extracting attributes from a collection
 pub trait AttributeExtractor {
   /// Gets an attribute from the map and tries to convert it to the given type
   /// This method does not consume the raw attribute - it gets cloned
   /// See [`AttributeExtractor::take_attr`] for a non-cloning method
   fn get_attr<T: TryFromAttribute>(
     &self,
     attribute_name: &str,
   ) -> Result<T, DBItemError>;
   /// Takes an attribute from the map and tries to convert it to the given type
   /// This method consumes the raw attribute - it gets removed from the map
   /// See [`AttributeExtractor::get_attr`] for a non-mutating method
   fn take_attr<T: TryFromAttribute>(
     &mut self,
     attribute_name: &str,
   ) -> Result<T, DBItemError>;
 }
 impl AttributeExtractor for AttributeMap {
   fn get_attr<T: TryFromAttribute>(
     &self,
     attribute_name: &str,
   ) -> Result<T, DBItemError> {
     T::try_from_attr(attribute_name, self.get(attribute_name).cloned())
   }
   fn take_attr<T: TryFromAttribute>(
     &mut self,
     attribute_name: &str,
   ) -> Result<T, DBItemError> {
     T::try_from_attr(attribute_name, self.remove(attribute_name))
   }
 }
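Review note, not part of the patch: a small sketch of what `AttributeExtractor` buys at call sites, against a hand-built map.

```rust
use aws_sdk_dynamodb::types::AttributeValue;
use comm_services_lib::database::{
  AttributeExtractor, AttributeMap, DBItemError,
};

fn extract_demo() -> Result<(), DBItemError> {
  let mut item = AttributeMap::from([(
    "userID".to_string(),
    AttributeValue::S("some-user".to_string()),
  )]);

  // `get_attr` clones the raw attribute and leaves the map untouched...
  let user_id: String = item.get_attr("userID")?;
  // ...while `take_attr` removes it from the map, avoiding the clone.
  let taken: String = item.take_attr("userID")?;

  assert_eq!(user_id, taken);
  assert!(item.is_empty());
  Ok(())
}
```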
 
 impl TryFromAttribute for String {
   fn try_from_attr(
     attribute_name: impl Into<String>,
     attribute_value: Option<AttributeValue>,
   ) -> Result<Self, DBItemError> {
     match attribute_value {
       Some(AttributeValue::S(value)) => Ok(value),
       Some(_) => Err(DBItemError::new(
         attribute_name.into(),
         Value::AttributeValue(attribute_value),
         DBItemAttributeError::IncorrectType,
       )),
       None => Err(DBItemError::new(
         attribute_name.into(),
         Value::AttributeValue(attribute_value),
         DBItemAttributeError::Missing,
       )),
     }
   }
 }
 
 impl TryFromAttribute for bool {
   fn try_from_attr(
     attribute_name: impl Into<String>,
     attribute_value: Option<AttributeValue>,
   ) -> Result<Self, DBItemError> {
     match attribute_value {
       Some(AttributeValue::Bool(value)) => Ok(value),
       Some(_) => Err(DBItemError::new(
         attribute_name.into(),
         Value::AttributeValue(attribute_value),
         DBItemAttributeError::IncorrectType,
       )),
       None => Err(DBItemError::new(
         attribute_name.into(),
         Value::AttributeValue(attribute_value),
         DBItemAttributeError::Missing,
       )),
     }
   }
 }
 
 impl TryFromAttribute for DateTime<Utc> {
   fn try_from_attr(
     attribute_name: impl Into<String>,
     attribute: Option<AttributeValue>,
   ) -> Result<Self, DBItemError> {
     match &attribute {
       Some(AttributeValue::S(datetime)) => datetime.parse().map_err(|e| {
         DBItemError::new(
           attribute_name.into(),
           Value::AttributeValue(attribute),
           DBItemAttributeError::InvalidTimestamp(e),
         )
       }),
       Some(_) => Err(DBItemError::new(
         attribute_name.into(),
         Value::AttributeValue(attribute),
         DBItemAttributeError::IncorrectType,
       )),
       None => Err(DBItemError::new(
         attribute_name.into(),
         Value::AttributeValue(attribute),
         DBItemAttributeError::Missing,
       )),
     }
   }
 }
 
 impl TryFromAttribute for AttributeMap {
   fn try_from_attr(
     attribute_name: impl Into<String>,
     attribute_value: Option<AttributeValue>,
   ) -> Result<Self, DBItemError> {
     match attribute_value {
       Some(AttributeValue::M(map)) => Ok(map),
       Some(_) => Err(DBItemError::new(
         attribute_name.into(),
         Value::AttributeValue(attribute_value),
         DBItemAttributeError::IncorrectType,
       )),
       None => Err(DBItemError::new(
         attribute_name.into(),
         Value::AttributeValue(attribute_value),
         DBItemAttributeError::Missing,
       )),
     }
   }
 }
 
 impl TryFromAttribute for Vec<u8> {
   fn try_from_attr(
     attribute_name: impl Into<String>,
     attribute_value: Option<AttributeValue>,
   ) -> Result<Self, DBItemError> {
     match attribute_value {
       Some(AttributeValue::B(data)) => Ok(data.into_inner()),
       Some(_) => Err(DBItemError::new(
         attribute_name.into(),
         Value::AttributeValue(attribute_value),
         DBItemAttributeError::IncorrectType,
       )),
       None => Err(DBItemError::new(
         attribute_name.into(),
         Value::AttributeValue(attribute_value),
         DBItemAttributeError::Missing,
       )),
     }
   }
 }
 
 impl TryFromAttribute for HashSet<String> {
   fn try_from_attr(
     attribute_name: impl Into<String>,
     attribute_value: Option<AttributeValue>,
   ) -> Result<Self, DBItemError> {
     match attribute_value {
       Some(AttributeValue::Ss(set)) => Ok(set.into_iter().collect()),
       Some(_) => Err(DBItemError::new(
         attribute_name.into(),
         Value::AttributeValue(attribute_value),
         DBItemAttributeError::IncorrectType,
       )),
       None => Err(DBItemError::new(
         attribute_name.into(),
         Value::AttributeValue(attribute_value),
         DBItemAttributeError::Missing,
       )),
     }
   }
 }
 
+impl<T: TryFromAttribute> TryFromAttribute for Vec<T> {
+  fn try_from_attr(
+    attribute_name: impl Into<String>,
+    attribute: Option<AttributeValue>,
+  ) -> Result<Self, DBItemError> {
+    let attribute_name = attribute_name.into();
+    match attribute {
+      Some(AttributeValue::L(list)) => Ok(
+        list
+          .into_iter()
+          .map(|attribute| {
+            T::try_from_attr(format!("{attribute_name}[i]"), Some(attribute))
+          })
+          .collect::<Result<Vec<T>, _>>()?,
+      ),
+      Some(_) => Err(DBItemError::new(
+        attribute_name.into(),
+        Value::AttributeValue(attribute),
+        DBItemAttributeError::IncorrectType,
+      )),
+      None => Err(DBItemError::new(
+        attribute_name.into(),
+        Value::AttributeValue(attribute),
+        DBItemAttributeError::Missing,
+      )),
+    }
+  }
+}
+
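Review note, not part of the patch: exercising the new `Vec<T>` impl with plain strings. Note that element errors are named with a literal `[i]` suffix; interpolating the actual index could be a follow-up.

```rust
use aws_sdk_dynamodb::types::AttributeValue;
use comm_services_lib::database::{DBItemError, TryFromAttribute};

fn list_demo() -> Result<(), DBItemError> {
  let list = AttributeValue::L(vec![
    AttributeValue::S("first".to_string()),
    AttributeValue::S("second".to_string()),
  ]);
  let parsed = Vec::<String>::try_from_attr("myList", Some(list))?;
  assert_eq!(parsed, vec!["first", "second"]);

  // Non-list attributes are rejected with `IncorrectType`.
  let not_a_list = AttributeValue::S("oops".to_string());
  assert!(Vec::<String>::try_from_attr("myList", Some(not_a_list)).is_err());
  Ok(())
}
```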
 #[deprecated = "Use `String::try_from_attr()` instead"]
 pub fn parse_string_attribute(
   attribute_name: impl Into<String>,
   attribute_value: Option<AttributeValue>,
 ) -> Result<String, DBItemError> {
   String::try_from_attr(attribute_name, attribute_value)
 }
 
 #[deprecated = "Use `bool::try_from_attr()` instead"]
 pub fn parse_bool_attribute(
   attribute_name: impl Into<String>,
   attribute_value: Option<AttributeValue>,
 ) -> Result<bool, DBItemError> {
   bool::try_from_attr(attribute_name, attribute_value)
 }
 
 #[deprecated = "Use `DateTime::<Utc>::try_from_attr()` instead"]
 pub fn parse_datetime_attribute(
   attribute_name: impl Into<String>,
   attribute_value: Option<AttributeValue>,
 ) -> Result<DateTime<Utc>, DBItemError> {
   DateTime::<Utc>::try_from_attr(attribute_name, attribute_value)
 }
 
 #[deprecated = "Use `AttributeMap::try_from_attr()` instead"]
 pub fn parse_map_attribute(
   attribute_name: impl Into<String>,
   attribute_value: Option<AttributeValue>,
 ) -> Result<AttributeMap, DBItemError> {
   attribute_value.attr_try_into(attribute_name)
 }
 
 pub fn parse_int_attribute<T>(
   attribute_name: impl Into<String>,
   attribute_value: Option<AttributeValue>,
 ) -> Result<T, DBItemError>
 where
   T: FromStr<Err = ParseIntError>,
 {
   match &attribute_value {
     Some(AttributeValue::N(numeric_str)) => {
       parse_integer(attribute_name, numeric_str)
     }
     Some(_) => Err(DBItemError::new(
       attribute_name.into(),
       Value::AttributeValue(attribute_value),
       DBItemAttributeError::IncorrectType,
     )),
     None => Err(DBItemError::new(
       attribute_name.into(),
       Value::AttributeValue(attribute_value),
       DBItemAttributeError::Missing,
     )),
   }
 }
 
 /// Parses the UTC timestamp in milliseconds from a DynamoDB numeric attribute
 pub fn parse_timestamp_attribute(
   attribute_name: impl Into<String>,
   attribute_value: Option<AttributeValue>,
 ) -> Result<DateTime<Utc>, DBItemError> {
   let attribute_name: String = attribute_name.into();
   let timestamp = parse_int_attribute::<i64>(
     attribute_name.clone(),
     attribute_value.clone(),
   )?;
   let naive_datetime = chrono::NaiveDateTime::from_timestamp_millis(timestamp)
     .ok_or_else(|| {
       DBItemError::new(
         attribute_name,
         Value::AttributeValue(attribute_value),
         DBItemAttributeError::TimestampOutOfRange,
       )
     })?;
   Ok(DateTime::from_utc(naive_datetime, Utc))
 }
 
 pub fn parse_integer<T>(
   attribute_name: impl Into<String>,
   attribute_value: &str,
 ) -> Result<T, DBItemError>
 where
   T: FromStr<Err = ParseIntError>,
 {
   attribute_value.parse::<T>().map_err(|e| {
     DBItemError::new(
       attribute_name.into(),
       Value::String(attribute_value.into()),
       DBItemAttributeError::InvalidNumberFormat(e),
     )
   })
 }
 
 pub mod batch_operations {
   use aws_sdk_dynamodb::{
     error::SdkError, operation::batch_write_item::BatchWriteItemError,
     types::WriteRequest,
   };
   use rand::Rng;
   use std::time::Duration;
   use tracing::{debug, trace};
 
   /// DynamoDB hard limit for a single BatchWriteItem request
   const SINGLE_BATCH_ITEM_LIMIT: usize = 25;
 
   /// Exponential backoff configuration for batch write operation
   #[derive(derive_more::Constructor, Debug)]
   pub struct ExponentialBackoffConfig {
     /// Maximum retry attempts before the function fails.
     /// Set this to 0 to disable exponential backoff.
     /// Defaults to **8**.
     pub max_attempts: u32,
     /// Base wait duration before retry. Defaults to **25ms**.
     /// It is doubled with each attempt: 25ms, 50, 100, 200...
     pub base_duration: Duration,
     /// Jitter factor for retry delay. Factor 0.5 for 100ms delay
     /// means that wait time will be between 50ms and 150ms.
     /// The value must be in range 0.0 - 1.0. It will be clamped
     /// if out of these bounds. Defaults to **0.3**.
     pub jitter_factor: f32,
     /// Retry on [`ProvisionedThroughputExceededException`].
     /// Defaults to **true**.
     ///
     /// [`ProvisionedThroughputExceededException`]: aws_sdk_dynamodb::Error::ProvisionedThroughputExceededException
     pub retry_on_provisioned_capacity_exceeded: bool,
   }
 
   impl Default for ExponentialBackoffConfig {
     fn default() -> Self {
       ExponentialBackoffConfig {
         max_attempts: 8,
         base_duration: Duration::from_millis(25),
         jitter_factor: 0.3,
         retry_on_provisioned_capacity_exceeded: true,
       }
     }
   }
 
   impl ExponentialBackoffConfig {
     fn new_counter(&self) -> ExponentialBackoffHelper {
       ExponentialBackoffHelper::new(self)
     }
     fn backoff_enabled(&self) -> bool {
       self.max_attempts > 0
     }
     fn should_retry_on_capacity_exceeded(&self) -> bool {
       self.backoff_enabled() && self.retry_on_provisioned_capacity_exceeded
     }
   }
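Review note, not part of the patch: the delay schedule implied by the defaults above (25ms base, jitter factor 0.3), sketched as per-attempt bounds.

```rust
use std::time::Duration;

/// Bounds of the randomized sleep before retry `attempt` (0-based), for the
/// default config: base 25ms doubled per attempt, with ±30% jitter applied.
fn delay_bounds(attempt: u32) -> (Duration, Duration) {
  let base = Duration::from_millis(25) * 2u32.pow(attempt);
  (base.mul_f32(0.7), base.mul_f32(1.3))
}

// attempt 0: 17.5ms..32.5ms
// attempt 1: 35ms..65ms
// attempt 2: 70ms..130ms
```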
 
   /// Performs a single DynamoDB table batch write operation. If the batch
   /// contains more than 25 items, it is split into chunks.
   ///
   /// The function uses exponential backoff retries when AWS throttles
   /// the request or maximum provisioned capacity is exceeded
   #[tracing::instrument(name = "batch_write", skip(ddb, requests, config))]
   pub async fn batch_write(
     ddb: &aws_sdk_dynamodb::Client,
     table_name: &str,
     mut requests: Vec<WriteRequest>,
     config: ExponentialBackoffConfig,
   ) -> Result<(), super::Error> {
     tracing::debug!(
       ?config,
       "Starting batch write operation of {} items...",
       requests.len()
     );
 
     let mut exponential_backoff = config.new_counter();
     let mut backup = Vec::with_capacity(SINGLE_BATCH_ITEM_LIMIT);
 
     loop {
       let items_to_drain =
         std::cmp::min(requests.len(), SINGLE_BATCH_ITEM_LIMIT);
       let chunk = requests.drain(..items_to_drain).collect::<Vec<_>>();
       if chunk.is_empty() {
         // No more items
         tracing::trace!("No more items to process. Exiting");
         break;
       }
 
       // we don't need the backup when we don't retry
       if config.should_retry_on_capacity_exceeded() {
         chunk.clone_into(&mut backup);
       }
 
       tracing::trace!("Attempting to write chunk of {} items...", chunk.len());
       let result = ddb
         .batch_write_item()
         .request_items(table_name, chunk)
         .send()
         .await;
 
       match result {
         Ok(output) => {
           if let Some(mut items) = output.unprocessed_items {
             let requests_to_retry =
               items.remove(table_name).unwrap_or_default();
             if requests_to_retry.is_empty() {
               tracing::trace!("Chunk written successfully. Continuing.");
               exponential_backoff.reset();
               continue;
             }
 
             exponential_backoff.sleep_and_retry().await?;
             tracing::debug!(
               "Some items failed. Retrying {} requests",
               requests_to_retry.len()
             );
             requests.extend(requests_to_retry);
           } else {
             tracing::trace!("Unprocessed items was None");
           }
         }
         Err(error) => {
           if !is_provisioned_capacity_exceeded(&error) {
             tracing::error!("BatchWriteItem failed: {0:?} - {0}", error);
             return Err(super::Error::AwsSdk(error.into()));
           }
 
           tracing::warn!("Provisioned capacity exceeded!");
           if !config.retry_on_provisioned_capacity_exceeded {
             return Err(super::Error::AwsSdk(error.into()));
           }
 
           exponential_backoff.sleep_and_retry().await?;
           requests.append(&mut backup);
           trace!("Retrying now...");
         }
       };
     }
 
     debug!("Batch write completed.");
     Ok(())
   }
 
   /// internal helper struct
   struct ExponentialBackoffHelper<'cfg> {
     config: &'cfg ExponentialBackoffConfig,
     attempt: u32,
   }
 
   impl<'cfg> ExponentialBackoffHelper<'cfg> {
     fn new(config: &'cfg ExponentialBackoffConfig) -> Self {
       ExponentialBackoffHelper { config, attempt: 0 }
     }
 
     /// reset counter after successful operation
     fn reset(&mut self) {
       self.attempt = 0;
     }
 
     /// increase counter and sleep in case of failure
     async fn sleep_and_retry(&mut self) -> Result<(), super::Error> {
       let jitter_factor = 1f32.min(0f32.max(self.config.jitter_factor));
       let random_multiplier =
         1.0 + rand::thread_rng().gen_range(-jitter_factor..=jitter_factor);
       let backoff_multiplier = 2u32.pow(self.attempt);
       let base_duration = self.config.base_duration * backoff_multiplier;
       let sleep_duration = base_duration.mul_f32(random_multiplier);
 
       self.attempt += 1;
       if self.attempt > self.config.max_attempts {
         tracing::warn!("Retry limit exceeded!");
         return Err(super::Error::MaxRetriesExceeded);
       }
       tracing::debug!(
         attempt = self.attempt,
         "Batch failed. Sleeping for {}ms before retrying...",
         sleep_duration.as_millis()
       );
       tokio::time::sleep(sleep_duration).await;
       Ok(())
     }
   }
 
   /// Check if transaction failed due to
   /// `ProvisionedThroughputExceededException` exception
   fn is_provisioned_capacity_exceeded(
     err: &SdkError<BatchWriteItemError>,
   ) -> bool {
     let SdkError::ServiceError(service_error) = err else {
       return false;
     };
     matches!(
       service_error.err(),
       BatchWriteItemError::ProvisionedThroughputExceededException(_)
     )
   }
 }
 
 #[cfg(test)]
 mod tests {
   use super::*;
 
   #[test]
   fn test_parse_integer() {
     assert!(parse_integer::<i32>("some_attr", "123").is_ok());
     assert!(parse_integer::<i32>("negative", "-123").is_ok());
     assert!(parse_integer::<i32>("float", "3.14").is_err());
     assert!(parse_integer::<i32>("NaN", "foo").is_err());
     assert!(parse_integer::<u32>("negative_uint", "-123").is_err());
     assert!(parse_integer::<u16>("too_large", "65536").is_err());
   }
 
   #[test]
   fn test_parse_timestamp() {
     let timestamp = Utc::now().timestamp_millis();
     let attr = AttributeValue::N(timestamp.to_string());
     let parsed_timestamp = parse_timestamp_attribute("some_attr", Some(attr));
     assert!(parsed_timestamp.is_ok());
     assert_eq!(parsed_timestamp.unwrap().timestamp_millis(), timestamp);
   }
 
   #[test]
   fn test_parse_invalid_timestamp() {
     let attr = AttributeValue::N("foo".to_string());
     let parsed_timestamp = parse_timestamp_attribute("some_attr", Some(attr));
     assert!(parsed_timestamp.is_err());
   }
 
   #[test]
   fn test_parse_timestamp_out_of_range() {
     let attr = AttributeValue::N(i64::MAX.to_string());
     let parsed_timestamp = parse_timestamp_attribute("some_attr", Some(attr));
     assert!(parsed_timestamp.is_err());
     assert!(matches!(
       parsed_timestamp.unwrap_err().attribute_error,
       DBItemAttributeError::TimestampOutOfRange
     ));
   }
 }
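Review note, not part of the patch: a hedged sketch of driving `batch_write` from a service. The table name is a placeholder and the `WriteRequest`s are assumed to be built by the caller.

```rust
use aws_sdk_dynamodb::types::WriteRequest;
use comm_services_lib::database::batch_operations::{
  batch_write, ExponentialBackoffConfig,
};

async fn write_all(
  ddb: &aws_sdk_dynamodb::Client,
  requests: Vec<WriteRequest>,
) -> Result<(), comm_services_lib::database::Error> {
  // 25-item chunking, retries, and exponential backoff all happen inside.
  batch_write(ddb, "some-table", requests, ExponentialBackoffConfig::default())
    .await
}
```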