diff --git a/services/identity/src/database/one_time_keys.rs b/services/identity/src/database/one_time_keys.rs
index 5ec10d60d..1228d4a2a 100644
--- a/services/identity/src/database/one_time_keys.rs
+++ b/services/identity/src/database/one_time_keys.rs
@@ -1,507 +1,507 @@
use std::collections::HashSet;

use comm_lib::{
  aws::{
    ddb::types::{
      AttributeValue, Delete, DeleteRequest, TransactWriteItem, Update,
      WriteRequest,
    },
    DynamoDBError,
  },
  database::{
    batch_operations::{batch_write, ExponentialBackoffConfig},
    parse_int_attribute, AttributeExtractor, AttributeMap,
    DBItemAttributeError, DBItemError,
  },
};
use tracing::{debug, error, info, warn};

use crate::{
  constants::{
    error_types, MAX_ONE_TIME_KEYS, ONE_TIME_KEY_UPLOAD_LIMIT_PER_ACCOUNT,
  },
  database::DeviceIDAttribute,
  ddb_utils::{
    create_one_time_key_partition_key, into_one_time_put_requests,
    into_one_time_update_and_delete_requests, is_transaction_retryable,
    OlmAccountType,
  },
  error::{consume_error, Error},
  olm::is_valid_olm_key,
};

use super::DatabaseClient;

impl DatabaseClient {
  /// Gets the next one-time key for the account and then, in a transaction,
  /// deletes the key and updates the key count
  ///
  /// Returns the retrieved one-time key if it exists and a boolean indicating
  /// whether `spawn_refresh_keys_task` was called
  #[tracing::instrument(skip_all)]
  pub(super) async fn get_one_time_key(
    &self,
    user_id: &str,
    device_id: &str,
    account_type: OlmAccountType,
    can_request_more_keys: bool,
  ) -> Result<(Option<String>, bool), Error> {
    use crate::constants::devices_table;
    use crate::constants::retry;
    use crate::constants::ONE_TIME_KEY_MINIMUM_THRESHOLD;

    let attr_otk_count = match account_type {
      OlmAccountType::Content => devices_table::ATTR_CONTENT_OTK_COUNT,
      OlmAccountType::Notification => devices_table::ATTR_NOTIF_OTK_COUNT,
    };

    fn spawn_refresh_keys_task(device_id: &str) {
      // Clone the string slice to move into the async block
      let device_id = device_id.to_string();
      tokio::spawn(async move {
        debug!("Attempting to request more keys for device: {}", &device_id);
        let result =
          crate::tunnelbroker::send_refresh_keys_request(&device_id).await;
        consume_error(result);
      });
    }

    // TODO: Introduce `transact_write_helper` similar to `batch_write_helper`
    // in `comm-lib` to handle transactions with retries
-    let mut attempt = 0;
+    let retry_config = ExponentialBackoffConfig {
+      max_attempts: retry::MAX_ATTEMPTS as u32,
+      ..Default::default()
+    };

    // TODO: Introduce nanny task that handles calling `spawn_refresh_keys_task`
    let mut requested_more_keys = false;
+    let mut exponential_backoff = retry_config.new_counter();

    loop {
-      attempt += 1;
-      if attempt > retry::MAX_ATTEMPTS {
-        return Err(Error::MaxRetriesExceeded);
-      }
-
      let otk_count =
        self.get_otk_count(user_id, device_id, account_type).await?;
      if otk_count < ONE_TIME_KEY_MINIMUM_THRESHOLD && can_request_more_keys {
        spawn_refresh_keys_task(device_id);
        requested_more_keys = true;
      }
      if otk_count < 1 {
        return Ok((None, requested_more_keys));
      }

      let Some(otk_row) = self
        .get_one_time_keys(user_id, device_id, account_type, Some(1))
        .await?
        .pop()
      else {
        return Ok((None, requested_more_keys));
      };

      let delete_otk_operation = otk_row.as_delete_request();

      let update_otk_count = Update::builder()
        .table_name(devices_table::NAME)
        .key(
          devices_table::ATTR_USER_ID,
          AttributeValue::S(user_id.to_string()),
        )
        .key(
          devices_table::ATTR_ITEM_ID,
          DeviceIDAttribute(device_id.into()).into(),
        )
        .update_expression(format!("ADD {} :decrement_val", attr_otk_count))
        .expression_attribute_values(
          ":decrement_val",
          AttributeValue::N("-1".to_string()),
        )
        .condition_expression(format!("{} = :old_val", attr_otk_count))
        .expression_attribute_values(
          ":old_val",
          AttributeValue::N(otk_count.to_string()),
        )
        .build()
        .expect(
          "table_name, key or update_expression not set in Update builder",
        );

      let update_otk_count_operation = TransactWriteItem::builder()
        .update(update_otk_count)
        .build();

      let transaction = self
        .client
        .transact_write_items()
        .set_transact_items(Some(vec![
          delete_otk_operation,
          update_otk_count_operation,
        ]))
        .send()
        .await;

      match transaction {
        Ok(_) => return Ok((Some(otk_row.otk), requested_more_keys)),
        Err(e) => {
          info!("Error retrieving one-time key: {:?}", e);
          let dynamo_db_error = DynamoDBError::from(e);
          let retryable_codes = HashSet::from([
            retry::CONDITIONAL_CHECK_FAILED,
            retry::TRANSACTION_CONFLICT,
          ]);
          if is_transaction_retryable(&dynamo_db_error, &retryable_codes) {
            info!("Encountered transaction conflict while retrieving one-time key - retrying");
+            exponential_backoff.sleep_and_retry().await?;
          } else {
            error!(
              errorType = error_types::OTK_DB_LOG,
              "One-time key retrieval transaction failed: {:?}",
              dynamo_db_error
            );
            return Err(Error::AwsSdk(dynamo_db_error));
          }
        }
      }
    }
  }
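The hunk above swaps the hand-rolled attempt counter for the exponential backoff counter that `comm-lib` already uses in its batch helpers, so retries now sleep with increasing delays instead of spinning. Below is a minimal, self-contained sketch of the resulting retry shape, assuming a tokio runtime; `BackoffCounter` and the simulated conflicts are illustrative stand-ins, not the comm-lib API:

```rust
use std::time::Duration;

/// Illustrative stand-in for `ExponentialBackoffConfig::new_counter()`.
struct BackoffCounter {
  base: Duration,
  max_attempts: u32,
  attempt: u32,
}

impl BackoffCounter {
  /// Doubles the delay after every failed attempt; errors out past the limit.
  async fn sleep_and_retry(&mut self) -> Result<(), &'static str> {
    self.attempt += 1;
    if self.attempt > self.max_attempts {
      return Err("MaxRetriesExceeded");
    }
    tokio::time::sleep(self.base * 2u32.pow(self.attempt - 1)).await;
    Ok(())
  }
}

#[tokio::main]
async fn main() -> Result<(), &'static str> {
  let mut backoff = BackoffCounter {
    base: Duration::from_millis(25),
    max_attempts: 8,
    attempt: 0,
  };
  // Pretend the first two "transactions" hit a retryable conflict.
  let mut remaining_conflicts = 2u32;
  loop {
    if remaining_conflicts > 0 {
      remaining_conflicts -= 1;
      // Retryable failure: sleep with exponential backoff, then loop again.
      backoff.sleep_and_retry().await?;
      continue;
    }
    break; // the transaction succeeded
  }
  Ok(())
}
```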
  #[tracing::instrument(skip_all)]
  async fn get_one_time_keys(
    &self,
    user_id: &str,
    device_id: &str,
    account_type: OlmAccountType,
    num_keys: Option<usize>,
  ) -> Result<Vec<OTKRow>, Error> {
    use crate::constants::one_time_keys_table::*;

    let partition_key =
      create_one_time_key_partition_key(user_id, device_id, account_type);

    let mut query = self
      .client
      .query()
      .table_name(NAME)
      .key_condition_expression("#pk = :pk")
      .expression_attribute_names("#pk", PARTITION_KEY)
      .expression_attribute_values(":pk", AttributeValue::S(partition_key));

    if let Some(limit) = num_keys {
      // DynamoDB will reject the `query` request if `limit < 1`
      if limit < 1 {
        return Ok(Vec::new());
      }
      query = query.limit(limit as i32);
    }

    let otk_rows = query
      .send()
      .await
      .map_err(|e| {
        error!(
          errorType = error_types::OTK_DB_LOG,
          "DDB client failed to query OTK rows: {:?}", e
        );
        Error::AwsSdk(e.into())
      })?
      .items
      .unwrap_or_default()
      .into_iter()
      .map(OTKRow::try_from)
      .collect::<Result<Vec<OTKRow>, _>>()
      .map_err(Error::from)?;

    if let Some(limit) = num_keys {
      if otk_rows.len() != limit {
        warn!("There are fewer one-time keys than the number requested");
      }
    }

    Ok(otk_rows)
  }

  #[tracing::instrument(skip_all)]
  pub async fn append_one_time_prekeys(
    &self,
    user_id: &str,
    device_id: &str,
    content_one_time_keys: &Vec<String>,
    notif_one_time_keys: &Vec<String>,
  ) -> Result<(), Error> {
    use crate::constants::retry;

    let num_content_keys_to_append = content_one_time_keys.len();
    let num_notif_keys_to_append = notif_one_time_keys.len();

    if num_content_keys_to_append > ONE_TIME_KEY_UPLOAD_LIMIT_PER_ACCOUNT
      || num_notif_keys_to_append > ONE_TIME_KEY_UPLOAD_LIMIT_PER_ACCOUNT
    {
      return Err(Error::OneTimeKeyUploadLimitExceeded);
    }

    if content_one_time_keys
      .iter()
      .any(|otk| !is_valid_olm_key(otk))
      || notif_one_time_keys.iter().any(|otk| !is_valid_olm_key(otk))
    {
      debug!("Invalid one-time key format");
      return Err(Error::InvalidFormat);
    }

    let current_time = chrono::Utc::now();

    let content_otk_requests = into_one_time_put_requests(
      user_id,
      device_id,
      content_one_time_keys,
      OlmAccountType::Content,
      current_time,
    );
    let notif_otk_requests = into_one_time_put_requests(
      user_id,
      device_id,
      notif_one_time_keys,
      OlmAccountType::Notification,
      current_time,
    );

    let current_content_otk_count = self
      .get_otk_count(user_id, device_id, OlmAccountType::Content)
      .await?;

    let current_notif_otk_count = self
      .get_otk_count(user_id, device_id, OlmAccountType::Notification)
      .await?;

    let num_content_keys_to_delete = (num_content_keys_to_append
      + current_content_otk_count)
      .saturating_sub(MAX_ONE_TIME_KEYS);

    let num_notif_keys_to_delete = (num_notif_keys_to_append
      + current_notif_otk_count)
      .saturating_sub(MAX_ONE_TIME_KEYS);

    let content_keys_to_delete = self
      .get_one_time_keys(
        user_id,
        device_id,
        OlmAccountType::Content,
        Some(num_content_keys_to_delete),
      )
      .await?;

    let notif_keys_to_delete = self
      .get_one_time_keys(
        user_id,
        device_id,
        OlmAccountType::Notification,
        Some(num_notif_keys_to_delete),
      )
      .await?;

    let update_and_delete_otk_count_operation =
      into_one_time_update_and_delete_requests(
        user_id,
        device_id,
        num_content_keys_to_append,
        num_notif_keys_to_append,
        content_keys_to_delete,
        notif_keys_to_delete,
      );

    let mut operations = Vec::new();
    operations.extend_from_slice(&content_otk_requests);
    operations.extend_from_slice(&notif_otk_requests);
    operations.extend_from_slice(&update_and_delete_otk_count_operation);

    // TODO: Introduce `transact_write_helper` similar to `batch_write_helper`
    // in `comm-lib` to handle transactions with retries
    let mut attempt = 0;

    loop {
      attempt += 1;
      if attempt > retry::MAX_ATTEMPTS {
        return Err(Error::MaxRetriesExceeded);
      }

      let transaction = self
        .client
        .transact_write_items()
        .set_transact_items(Some(operations.clone()))
        .send()
        .await;

      match transaction {
        Ok(_) => break,
        Err(e) => {
          let dynamo_db_error = DynamoDBError::from(e);
          let retryable_codes = HashSet::from([retry::TRANSACTION_CONFLICT]);
          if is_transaction_retryable(&dynamo_db_error, &retryable_codes) {
            info!("Encountered transaction conflict while uploading one-time keys - retrying");
          } else {
            error!(
              errorType = error_types::OTK_DB_LOG,
              "One-time key upload transaction failed: {:?}", dynamo_db_error
            );
            return Err(Error::AwsSdk(dynamo_db_error));
          }
        }
      }
    }

    Ok(())
  }
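`append_one_time_prekeys` trims the oldest keys rather than rejecting an upload that would push an account past the cap. A quick worked example of the `saturating_sub` arithmetic; the cap value here is a hypothetical stand-in for `MAX_ONE_TIME_KEYS`:

```rust
fn main() {
  let max_one_time_keys: usize = 100; // hypothetical stand-in for MAX_ONE_TIME_KEYS
  // 95 existing keys + 10 new ones exceeds the cap by 5, so the 5 oldest
  // keys are fetched and deleted in the same transaction as the puts.
  assert_eq!((10 + 95usize).saturating_sub(max_one_time_keys), 5);
  // Below the cap, saturating_sub clamps to 0; `get_one_time_keys` is then
  // asked for zero rows, which short-circuits to an empty Vec.
  assert_eq!((10 + 30usize).saturating_sub(max_one_time_keys), 0);
}
```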
  #[tracing::instrument(skip_all)]
  async fn get_otk_count(
    &self,
    user_id: &str,
    device_id: &str,
    account_type: OlmAccountType,
  ) -> Result<usize, Error> {
    use crate::constants::devices_table;

    let attr_name = match account_type {
      OlmAccountType::Content => devices_table::ATTR_CONTENT_OTK_COUNT,
      OlmAccountType::Notification => devices_table::ATTR_NOTIF_OTK_COUNT,
    };

    let response = self
      .client
      .get_item()
      .table_name(devices_table::NAME)
      .projection_expression(attr_name)
      .key(
        devices_table::ATTR_USER_ID,
        AttributeValue::S(user_id.to_string()),
      )
      .key(
        devices_table::ATTR_ITEM_ID,
        DeviceIDAttribute(device_id.into()).into(),
      )
      .send()
      .await
      .map_err(|e| {
        error!(
          errorType = error_types::OTK_DB_LOG,
          "Failed to get user's OTK count: {:?}", e
        );
        Error::AwsSdk(e.into())
      })?;

    let mut user_item = response.item.unwrap_or_default();
    match parse_int_attribute(attr_name, user_item.remove(attr_name)) {
      Ok(num) => Ok(num),
      Err(DBItemError {
        attribute_error: DBItemAttributeError::Missing,
        ..
      }) => Ok(0),
      Err(e) => Err(Error::Attribute(e)),
    }
  }

  /// Deletes all data for a user's device from one-time keys table
  pub async fn delete_otks_table_rows_for_user_device(
    &self,
    user_id: &str,
    device_id: &str,
  ) -> Result<(), Error> {
    use crate::constants::one_time_keys_table::*;

    let content_otk_primary_keys = self
      .get_one_time_keys(user_id, device_id, OlmAccountType::Content, None)
      .await?;

    let notif_otk_primary_keys = self
      .get_one_time_keys(user_id, device_id, OlmAccountType::Notification, None)
      .await?;

    let delete_requests = content_otk_primary_keys
      .into_iter()
      .chain(notif_otk_primary_keys)
      .map(|otk_row| {
        let request = DeleteRequest::builder()
          .key(PARTITION_KEY, AttributeValue::S(otk_row.partition_key))
          .key(SORT_KEY, AttributeValue::S(otk_row.sort_key))
          .build()
          .expect("no keys set in DeleteRequest builder");
        WriteRequest::builder().delete_request(request).build()
      })
      .collect::<Vec<_>>();

    batch_write(
      &self.client,
      NAME,
      delete_requests,
      ExponentialBackoffConfig::default(),
    )
    .await
    .map_err(Error::from)?;

    Ok(())
  }

  /// Deletes all data for a user from one-time keys table
  pub async fn delete_otks_table_rows_for_user(
    &self,
    user_id: &str,
  ) -> Result<(), Error> {
    let maybe_device_list_row = self.get_current_device_list(user_id).await?;

    let Some(device_list_row) = maybe_device_list_row else {
      info!("No devices associated with user. Skipping one-time key removal.");
      return Ok(());
    };

    for device_id in device_list_row.device_ids {
      self
        .delete_otks_table_rows_for_user_device(user_id, &device_id)
        .await?;
    }

    Ok(())
  }
}
pub struct OTKRow {
  pub partition_key: String,
  pub sort_key: String,
  pub otk: String,
}

impl OTKRow {
  pub fn as_delete_request(&self) -> TransactWriteItem {
    use crate::constants::one_time_keys_table as otk_table;

    let delete_otk = Delete::builder()
      .table_name(otk_table::NAME)
      .key(
        otk_table::PARTITION_KEY,
        AttributeValue::S(self.partition_key.to_string()),
      )
      .key(
        otk_table::SORT_KEY,
        AttributeValue::S(self.sort_key.to_string()),
      )
      .condition_expression("attribute_exists(#otk)")
      .expression_attribute_names("#otk", otk_table::ATTR_ONE_TIME_KEY)
      .build()
      .expect("table_name or key not set in Delete builder");

    TransactWriteItem::builder().delete(delete_otk).build()
  }
}

impl TryFrom<AttributeMap> for OTKRow {
  type Error = DBItemError;

  fn try_from(mut attrs: AttributeMap) -> Result<Self, Self::Error> {
    use crate::constants::one_time_keys_table as otk_table;

    let partition_key = attrs.take_attr(otk_table::PARTITION_KEY)?;
    let sort_key = attrs.take_attr(otk_table::SORT_KEY)?;
    let otk: String = attrs.take_attr(otk_table::ATTR_ONE_TIME_KEY)?;

    Ok(Self {
      partition_key,
      sort_key,
      otk,
    })
  }
}
diff --git a/shared/comm-lib/src/database.rs b/shared/comm-lib/src/database.rs
index 265db1c58..294af1011 100644
--- a/shared/comm-lib/src/database.rs
+++ b/shared/comm-lib/src/database.rs
@@ -1,857 +1,857 @@
use aws_sdk_dynamodb::types::AttributeValue;
pub use aws_sdk_dynamodb::Error as DynamoDBError;
use chrono::{DateTime, Utc};
use std::collections::HashSet;
use std::fmt::{Display, Formatter};
use std::num::ParseIntError;
use std::str::FromStr;

#[cfg(feature = "blob-client")]
pub mod blob;

// # Useful type aliases

// Rust exports `pub type` only into the so-called "type namespace", but in
// order to use them e.g. with the `TryFromAttribute` trait, they also need
// to be exported into the "value namespace", which is what `pub use` does.
//
// To overcome that, a dummy module is created and the aliases are re-exported
// with the `pub use` construct.
mod aliases {
  use aws_sdk_dynamodb::types::AttributeValue;
  use std::collections::HashMap;
  pub type AttributeMap = HashMap<String, AttributeValue>;
}
pub use self::aliases::AttributeMap;
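A minimal sketch of the re-export pattern the comment above describes; the module and alias names here are hypothetical, not part of comm-lib:

```rust
// Hypothetical module demonstrating the alias re-export pattern.
mod aliases {
  use std::collections::HashMap;

  // `pub type` alone registers the alias in the type namespace...
  pub type Counters = HashMap<String, u64>;
}
// ...and `pub use` re-exports it so downstream code can import it
// like any other item: `use crate::Counters;`
pub use self::aliases::Counters;

fn main() {
  let mut counters: Counters = Counters::new();
  counters.insert("otk".to_string(), 1);
  assert_eq!(counters["otk"], 1);
}
```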
// # Error handling

#[derive(
  Debug, derive_more::Display, derive_more::From, derive_more::Error,
)]
pub enum Error {
  #[display(...)]
  AwsSdk(DynamoDBError),
  #[display(...)]
  Attribute(DBItemError),
  #[display(fmt = "Maximum retries exceeded")]
  MaxRetriesExceeded,
}

#[derive(Debug, derive_more::From)]
pub enum Value {
  AttributeValue(Option<AttributeValue>),
  String(String),
}

#[derive(Debug, derive_more::Error, derive_more::Constructor)]
pub struct DBItemError {
  pub attribute_name: String,
  pub attribute_value: Value,
  pub attribute_error: DBItemAttributeError,
}

impl Display for DBItemError {
  fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
    match &self.attribute_error {
      DBItemAttributeError::Missing => {
        write!(f, "Attribute {} is missing", self.attribute_name)
      }
      DBItemAttributeError::IncorrectType => write!(
        f,
        "Value for attribute {} has incorrect type: {:?}",
        self.attribute_name, self.attribute_value
      ),
      error => write!(
        f,
        "Error regarding attribute {} with value {:?}: {}",
        self.attribute_name, self.attribute_value, error
      ),
    }
  }
}

#[derive(Debug, derive_more::Display, derive_more::Error)]
pub enum DBItemAttributeError {
  #[display(...)]
  Missing,
  #[display(...)]
  IncorrectType,
  #[display(...)]
  TimestampOutOfRange,
  #[display(...)]
  InvalidTimestamp(chrono::ParseError),
  #[display(...)]
  InvalidNumberFormat(ParseIntError),
  #[display(...)]
  ExpiredTimestamp,
  #[display(...)]
  InvalidValue,
}

/// Conversion trait for [`AttributeValue`]
///
/// Types implementing this trait are able to do the following:
/// ```ignore
/// use comm_lib::database::{TryFromAttribute, AttributeTryInto};
///
/// let foo = SomeType::try_from_attr("MyAttribute", Some(attribute))?;
///
/// // if `AttributeTryInto` is imported, also:
/// let bar = Some(attribute).attr_try_into("MyAttribute")?;
/// ```
pub trait TryFromAttribute: Sized {
  fn try_from_attr(
    attribute_name: impl Into<String>,
    attribute: Option<AttributeValue>,
  ) -> Result<Self, DBItemError>;
}

/// Do NOT implement this trait directly. Implement [`TryFromAttribute`] instead
pub trait AttributeTryInto<T> {
  fn attr_try_into(
    self,
    attribute_name: impl Into<String>,
  ) -> Result<T, DBItemError>;
}

// Automatic attr_try_into() for all attribute values
// that have TryFromAttribute implemented
impl<T: TryFromAttribute> AttributeTryInto<T> for Option<AttributeValue> {
  fn attr_try_into(
    self,
    attribute_name: impl Into<String>,
  ) -> Result<T, DBItemError> {
    T::try_from_attr(attribute_name, self)
  }
}

/// Helper trait for extracting attributes from a collection
pub trait AttributeExtractor {
  /// Gets an attribute from the map and tries to convert it to the given type
  /// This method does not consume the raw attribute - it gets cloned
  /// See [`AttributeExtractor::take_attr`] for a non-cloning method
  fn get_attr<T: TryFromAttribute>(
    &self,
    attribute_name: &str,
  ) -> Result<T, DBItemError>;
  /// Takes an attribute from the map and tries to convert it to the given type
  /// This method consumes the raw attribute - it gets removed from the map
  /// See [`AttributeExtractor::get_attr`] for a non-mutating method
  fn take_attr<T: TryFromAttribute>(
    &mut self,
    attribute_name: &str,
  ) -> Result<T, DBItemError>;
}

impl AttributeExtractor for AttributeMap {
  fn get_attr<T: TryFromAttribute>(
    &self,
    attribute_name: &str,
  ) -> Result<T, DBItemError> {
    T::try_from_attr(attribute_name, self.get(attribute_name).cloned())
  }
  fn take_attr<T: TryFromAttribute>(
    &mut self,
    attribute_name: &str,
  ) -> Result<T, DBItemError> {
    T::try_from_attr(attribute_name, self.remove(attribute_name))
  }
}
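This is the same extractor machinery the identity service uses for `OTKRow` in the first file. A hedged sketch of wiring it up for a custom row type; `SessionRow` and its attribute names are hypothetical, and the example assumes `comm-lib` and `aws-sdk-dynamodb` as dependencies:

```rust
use aws_sdk_dynamodb::types::AttributeValue;
use comm_lib::database::{AttributeExtractor, AttributeMap, DBItemError};

// Hypothetical row type, purely for illustration.
struct SessionRow {
  session_id: String,
  is_active: Option<bool>,
}

impl TryFrom<AttributeMap> for SessionRow {
  type Error = DBItemError;

  fn try_from(mut attrs: AttributeMap) -> Result<Self, Self::Error> {
    Ok(SessionRow {
      // `take_attr` removes the raw attribute from the map and converts
      // it via `TryFromAttribute`; a wrong type becomes a DBItemError.
      session_id: attrs.take_attr("sessionID")?,
      // The blanket `Option<T>` impl turns a missing attribute into
      // `Ok(None)` instead of a `Missing` error.
      is_active: attrs.take_attr("isActive")?,
    })
  }
}

fn main() -> Result<(), DBItemError> {
  let attrs = AttributeMap::from([(
    "sessionID".to_string(),
    AttributeValue::S("abc123".to_string()),
  )]);
  let row = SessionRow::try_from(attrs)?;
  assert_eq!(row.session_id, "abc123");
  assert_eq!(row.is_active, None);
  Ok(())
}
```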
// this allows us to get optional attributes
impl<T> TryFromAttribute for Option<T>
where
  T: TryFromAttribute,
{
  fn try_from_attr(
    attribute_name: impl Into<String>,
    attribute: Option<AttributeValue>,
  ) -> Result<Self, DBItemError> {
    if attribute.is_none() {
      return Ok(None);
    }

    match T::try_from_attr(attribute_name, attribute) {
      Ok(value) => Ok(Some(value)),
      Err(DBItemError {
        attribute_error: DBItemAttributeError::Missing,
        ..
      }) => Ok(None),
      Err(error) => Err(error),
    }
  }
}

impl TryFromAttribute for String {
  fn try_from_attr(
    attribute_name: impl Into<String>,
    attribute_value: Option<AttributeValue>,
  ) -> Result<Self, DBItemError> {
    match attribute_value {
      Some(AttributeValue::S(value)) => Ok(value),
      Some(_) => Err(DBItemError::new(
        attribute_name.into(),
        Value::AttributeValue(attribute_value),
        DBItemAttributeError::IncorrectType,
      )),
      None => Err(DBItemError::new(
        attribute_name.into(),
        Value::AttributeValue(attribute_value),
        DBItemAttributeError::Missing,
      )),
    }
  }
}

impl TryFromAttribute for bool {
  fn try_from_attr(
    attribute_name: impl Into<String>,
    attribute_value: Option<AttributeValue>,
  ) -> Result<Self, DBItemError> {
    match attribute_value {
      Some(AttributeValue::Bool(value)) => Ok(value),
      Some(_) => Err(DBItemError::new(
        attribute_name.into(),
        Value::AttributeValue(attribute_value),
        DBItemAttributeError::IncorrectType,
      )),
      None => Err(DBItemError::new(
        attribute_name.into(),
        Value::AttributeValue(attribute_value),
        DBItemAttributeError::Missing,
      )),
    }
  }
}

impl TryFromAttribute for DateTime<Utc> {
  fn try_from_attr(
    attribute_name: impl Into<String>,
    attribute: Option<AttributeValue>,
  ) -> Result<Self, DBItemError> {
    match &attribute {
      Some(AttributeValue::S(datetime)) => datetime.parse().map_err(|e| {
        DBItemError::new(
          attribute_name.into(),
          Value::AttributeValue(attribute),
          DBItemAttributeError::InvalidTimestamp(e),
        )
      }),
      Some(_) => Err(DBItemError::new(
        attribute_name.into(),
        Value::AttributeValue(attribute),
        DBItemAttributeError::IncorrectType,
      )),
      None => Err(DBItemError::new(
        attribute_name.into(),
        Value::AttributeValue(attribute),
        DBItemAttributeError::Missing,
      )),
    }
  }
}

impl TryFromAttribute for AttributeMap {
  fn try_from_attr(
    attribute_name: impl Into<String>,
    attribute_value: Option<AttributeValue>,
  ) -> Result<Self, DBItemError> {
    match attribute_value {
      Some(AttributeValue::M(map)) => Ok(map),
      Some(_) => Err(DBItemError::new(
        attribute_name.into(),
        Value::AttributeValue(attribute_value),
        DBItemAttributeError::IncorrectType,
      )),
      None => Err(DBItemError::new(
        attribute_name.into(),
        Value::AttributeValue(attribute_value),
        DBItemAttributeError::Missing,
      )),
    }
  }
}

impl TryFromAttribute for Vec<u8> {
  fn try_from_attr(
    attribute_name: impl Into<String>,
    attribute_value: Option<AttributeValue>,
  ) -> Result<Self, DBItemError> {
    match attribute_value {
      Some(AttributeValue::B(data)) => Ok(data.into_inner()),
      Some(_) => Err(DBItemError::new(
        attribute_name.into(),
        Value::AttributeValue(attribute_value),
        DBItemAttributeError::IncorrectType,
      )),
      None => Err(DBItemError::new(
        attribute_name.into(),
        Value::AttributeValue(attribute_value),
        DBItemAttributeError::Missing,
      )),
    }
  }
}

impl TryFromAttribute for HashSet<String> {
  fn try_from_attr(
    attribute_name: impl Into<String>,
    attribute_value: Option<AttributeValue>,
  ) -> Result<Self, DBItemError> {
    match attribute_value {
      Some(AttributeValue::Ss(set)) => Ok(set.into_iter().collect()),
      Some(_) => Err(DBItemError::new(
        attribute_name.into(),
        Value::AttributeValue(attribute_value),
        DBItemAttributeError::IncorrectType,
      )),
      None => Err(DBItemError::new(
        attribute_name.into(),
        Value::AttributeValue(attribute_value),
        DBItemAttributeError::Missing,
      )),
    }
  }
}

impl<T: TryFromAttribute> TryFromAttribute for Vec<T> {
  fn try_from_attr(
    attribute_name: impl Into<String>,
    attribute: Option<AttributeValue>,
  ) -> Result<Self, DBItemError> {
    let attribute_name = attribute_name.into();
    match attribute {
      Some(AttributeValue::L(list)) => Ok(
        list
          .into_iter()
          .map(|attribute| {
            T::try_from_attr(format!("{attribute_name}[i]"), Some(attribute))
          })
          .collect::<Result<Vec<T>, _>>()?,
      ),
      Some(_) => Err(DBItemError::new(
        attribute_name,
        Value::AttributeValue(attribute),
        DBItemAttributeError::IncorrectType,
      )),
      None => Err(DBItemError::new(
        attribute_name,
        Value::AttributeValue(attribute),
        DBItemAttributeError::Missing,
      )),
    }
  }
}

#[deprecated = "Use `String::try_from_attr()` instead"]
pub fn parse_string_attribute(
  attribute_name: impl Into<String>,
  attribute_value: Option<AttributeValue>,
) -> Result<String, DBItemError> {
  String::try_from_attr(attribute_name, attribute_value)
}

#[deprecated = "Use `bool::try_from_attr()` instead"]
pub fn parse_bool_attribute(
  attribute_name: impl Into<String>,
  attribute_value: Option<AttributeValue>,
) -> Result<bool, DBItemError> {
  bool::try_from_attr(attribute_name, attribute_value)
}

#[deprecated = "Use `DateTime::<Utc>::try_from_attr()` instead"]
pub fn parse_datetime_attribute(
  attribute_name: impl Into<String>,
  attribute_value: Option<AttributeValue>,
) -> Result<DateTime<Utc>, DBItemError> {
  DateTime::<Utc>::try_from_attr(attribute_name, attribute_value)
}

#[deprecated = "Use `AttributeMap::try_from_attr()` instead"]
pub fn parse_map_attribute(
  attribute_name: impl Into<String>,
  attribute_value: Option<AttributeValue>,
) -> Result<AttributeMap, DBItemError> {
  attribute_value.attr_try_into(attribute_name)
}

pub fn parse_int_attribute<T>(
  attribute_name: impl Into<String>,
  attribute_value: Option<AttributeValue>,
) -> Result<T, DBItemError>
where
  T: FromStr<Err = ParseIntError>,
{
  match &attribute_value {
    Some(AttributeValue::N(numeric_str)) => {
      parse_integer(attribute_name, numeric_str)
    }
    Some(_) => Err(DBItemError::new(
      attribute_name.into(),
      Value::AttributeValue(attribute_value),
      DBItemAttributeError::IncorrectType,
    )),
    None => Err(DBItemError::new(
      attribute_name.into(),
      Value::AttributeValue(attribute_value),
      DBItemAttributeError::Missing,
    )),
  }
}

/// Parses the UTC timestamp in milliseconds from a DynamoDB numeric attribute
pub fn parse_timestamp_attribute(
  attribute_name: impl Into<String>,
  attribute_value: Option<AttributeValue>,
) -> Result<DateTime<Utc>, DBItemError> {
  let attribute_name: String = attribute_name.into();
  let timestamp = parse_int_attribute::<i64>(
    attribute_name.clone(),
    attribute_value.clone(),
  )?;
  chrono::DateTime::from_timestamp_millis(timestamp).ok_or_else(|| {
    DBItemError::new(
      attribute_name,
      Value::AttributeValue(attribute_value),
      DBItemAttributeError::TimestampOutOfRange,
    )
  })
}
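A quick illustration of the `parse_int_attribute` semantics above: a numeric (`N`) attribute parses into any `FromStr` integer, a mistyped attribute is `IncorrectType`, and an absent one is `Missing`. The attribute name here is invented:

```rust
use aws_sdk_dynamodb::types::AttributeValue;
use comm_lib::database::parse_int_attribute;

fn main() {
  let n = AttributeValue::N("42".to_string());
  let count: u64 = parse_int_attribute("otkCount", Some(n)).unwrap();
  assert_eq!(count, 42);

  // A string attribute is the wrong type, and a missing attribute is
  // reported as DBItemAttributeError::Missing rather than panicking.
  let s = AttributeValue::S("42".to_string());
  assert!(parse_int_attribute::<u64>("otkCount", Some(s)).is_err());
  assert!(parse_int_attribute::<u64>("otkCount", None).is_err());
}
```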
pub fn parse_integer<T>(
  attribute_name: impl Into<String>,
  attribute_value: &str,
) -> Result<T, DBItemError>
where
  T: FromStr<Err = ParseIntError>,
{
  attribute_value.parse::<T>().map_err(|e| {
    DBItemError::new(
      attribute_name.into(),
      Value::String(attribute_value.into()),
      DBItemAttributeError::InvalidNumberFormat(e),
    )
  })
}

pub mod batch_operations {
  use aws_sdk_dynamodb::{
    error::SdkError,
    operation::batch_write_item::BatchWriteItemError,
    types::{KeysAndAttributes, WriteRequest},
    Error as DynamoDBError,
  };
  use rand::Rng;
  use std::time::Duration;
  use tracing::{debug, trace};

  use super::AttributeMap;

  /// DynamoDB hard limit for single BatchWriteItem request
  const SINGLE_BATCH_WRITE_ITEM_LIMIT: usize = 25;
  const SINGLE_BATCH_GET_ITEM_LIMIT: usize = 100;

  /// Exponential backoff configuration for batch write operation
  #[derive(derive_more::Constructor, Debug)]
  pub struct ExponentialBackoffConfig {
    /// Maximum retry attempts before the function fails.
    /// Set this to 0 to disable exponential backoff.
    /// Defaults to **8**.
    pub max_attempts: u32,
    /// Base wait duration before retry. Defaults to **25ms**.
    /// It is doubled with each attempt: 25ms, 50, 100, 200...
    pub base_duration: Duration,
    /// Jitter factor for retry delay. Factor 0.5 for 100ms delay
    /// means that wait time will be between 50ms and 150ms.
    /// The value must be in range 0.0 - 1.0. It will be clamped
    /// if out of these bounds. Defaults to **0.3**
    pub jitter_factor: f32,
    /// Retry on [`ProvisionedThroughputExceededException`].
    /// Defaults to **true**.
    ///
    /// [`ProvisionedThroughputExceededException`]: aws_sdk_dynamodb::Error::ProvisionedThroughputExceededException
    pub retry_on_provisioned_capacity_exceeded: bool,
  }

  impl Default for ExponentialBackoffConfig {
    fn default() -> Self {
      ExponentialBackoffConfig {
        max_attempts: 8,
        base_duration: Duration::from_millis(25),
        jitter_factor: 0.3,
        retry_on_provisioned_capacity_exceeded: true,
      }
    }
  }

  impl ExponentialBackoffConfig {
-    fn new_counter(&self) -> ExponentialBackoffHelper {
+    pub fn new_counter(&self) -> ExponentialBackoffHelper {
      ExponentialBackoffHelper::new(self)
    }
    fn backoff_enabled(&self) -> bool {
      self.max_attempts > 0
    }
    fn should_retry_on_capacity_exceeded(&self) -> bool {
      self.backoff_enabled() && self.retry_on_provisioned_capacity_exceeded
    }
  }
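For intuition, here is the delay formula from `sleep_and_retry` (defined further down) worked out with the defaults above. This is plain arithmetic rather than comm-lib code:

```rust
use std::time::Duration;

fn main() {
  let base = Duration::from_millis(25); // the default base_duration
  // `sleep_and_retry` computes the delay from the current attempt count,
  // so the fourth consecutive failure (attempt = 3) waits 25ms * 2^3.
  let backoff = base * 2u32.pow(3);
  assert_eq!(backoff, Duration::from_millis(200));
  // With the default jitter_factor of 0.3 the actual sleep is drawn
  // uniformly from roughly 140ms..=260ms, i.e. 200ms * (1.0 ± 0.3).
}
```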
Exiting"); break; } // we don't need the backup when we don't retry if config.should_retry_on_capacity_exceeded() { chunk.clone_into(&mut backup); } tracing::trace!("Attempting to get chunk of {} items...", chunk.len()); let result = ddb .batch_get_item() .request_items( table_name, KeysAndAttributes::builder() .set_keys(Some(chunk)) .consistent_read(true) .set_projection_expression(projection_expression.clone()) .build() .expect("set_keys() was not called on KeysAndAttributes builder."), ) .send() .await; match result { Ok(output) => { if let Some(mut responses) = output.responses { if let Some(items) = responses.remove(table_name) { tracing::trace!("Successfully read {} items", items.len()); results.extend(items); } } else { tracing::warn!("Responses was None"); } if let Some(mut unprocessed) = output.unprocessed_keys { let keys_to_retry = match unprocessed.remove(table_name) { Some(KeysAndAttributes { keys, .. }) if !keys.is_empty() => keys, _ => { tracing::trace!("Chunk read successfully. Continuing."); exponential_backoff.reset(); continue; } }; exponential_backoff.sleep_and_retry().await?; tracing::debug!( "Some items failed. Retrying {} requests", keys_to_retry.len() ); primary_keys.extend(keys_to_retry); } else { tracing::trace!("Unprocessed items was None"); } } Err(error) => { let error: DynamoDBError = error.into(); if !matches!( error, DynamoDBError::ProvisionedThroughputExceededException(_) ) { tracing::error!("BatchGetItem failed: {0:?} - {0}", error); return Err(error.into()); } tracing::warn!("Provisioned capacity exceeded!"); if !config.retry_on_provisioned_capacity_exceeded { return Err(error.into()); } exponential_backoff.sleep_and_retry().await?; primary_keys.append(&mut backup); trace!("Retrying now..."); } }; } debug!("Batch read completed."); Ok(results) } /// Performs a single DynamoDB table batch write operation. If the batch /// contains more than 25 items, it is split into chunks. /// /// The function uses exponential backoff retries when AWS throttles /// the request or maximum provisioned capacity is exceeded #[tracing::instrument(name = "batch_write", skip(ddb, requests, config))] pub async fn batch_write( ddb: &aws_sdk_dynamodb::Client, table_name: &str, mut requests: Vec, config: ExponentialBackoffConfig, ) -> Result<(), super::Error> { tracing::debug!( ?config, "Starting batch write operation of {} items...", requests.len() ); let mut exponential_backoff = config.new_counter(); let mut backup = Vec::with_capacity(SINGLE_BATCH_WRITE_ITEM_LIMIT); loop { let items_to_drain = std::cmp::min(requests.len(), SINGLE_BATCH_WRITE_ITEM_LIMIT); let chunk = requests.drain(..items_to_drain).collect::>(); if chunk.is_empty() { // No more items tracing::trace!("No more items to process. Exiting"); break; } // we don't need the backup when we don't retry if config.should_retry_on_capacity_exceeded() { chunk.clone_into(&mut backup); } tracing::trace!("Attempting to write chunk of {} items...", chunk.len()); let result = ddb .batch_write_item() .request_items(table_name, chunk) .send() .await; match result { Ok(output) => { if let Some(mut items) = output.unprocessed_items { let requests_to_retry = items.remove(table_name).unwrap_or_default(); if requests_to_retry.is_empty() { tracing::trace!("Chunk written successfully. Continuing."); exponential_backoff.reset(); continue; } exponential_backoff.sleep_and_retry().await?; tracing::debug!( "Some items failed. 
Retrying {} requests", requests_to_retry.len() ); requests.extend(requests_to_retry); } else { tracing::trace!("Unprocessed items was None"); } } Err(error) => { if !is_provisioned_capacity_exceeded(&error) { tracing::error!("BatchWriteItem failed: {0:?} - {0}", error); return Err(super::Error::AwsSdk(error.into())); } tracing::warn!("Provisioned capacity exceeded!"); if !config.retry_on_provisioned_capacity_exceeded { return Err(super::Error::AwsSdk(error.into())); } exponential_backoff.sleep_and_retry().await?; requests.append(&mut backup); trace!("Retrying now..."); } }; } debug!("Batch write completed."); Ok(()) } - /// internal helper struct - struct ExponentialBackoffHelper<'cfg> { + /// Utility for managing retries with exponential backoff + pub struct ExponentialBackoffHelper<'cfg> { config: &'cfg ExponentialBackoffConfig, attempt: u32, } impl<'cfg> ExponentialBackoffHelper<'cfg> { fn new(config: &'cfg ExponentialBackoffConfig) -> Self { ExponentialBackoffHelper { config, attempt: 0 } } /// reset counter after successfull operation - fn reset(&mut self) { + pub fn reset(&mut self) { self.attempt = 0; } /// increase counter and sleep in case of failure - async fn sleep_and_retry(&mut self) -> Result<(), super::Error> { + pub async fn sleep_and_retry(&mut self) -> Result<(), super::Error> { let jitter_factor = 1f32.min(0f32.max(self.config.jitter_factor)); let random_multiplier = 1.0 + rand::thread_rng().gen_range(-jitter_factor..=jitter_factor); let backoff_multiplier = 2u32.pow(self.attempt); let base_duration = self.config.base_duration * backoff_multiplier; let sleep_duration = base_duration.mul_f32(random_multiplier); self.attempt += 1; if self.attempt > self.config.max_attempts { tracing::warn!("Retry limit exceeded!"); return Err(super::Error::MaxRetriesExceeded); } tracing::debug!( attempt = self.attempt, "Batch failed. Sleeping for {}ms before retrying...", sleep_duration.as_millis() ); tokio::time::sleep(sleep_duration).await; Ok(()) } } /// Check if transaction failed due to /// `ProvisionedThroughputExceededException` exception fn is_provisioned_capacity_exceeded( err: &SdkError, ) -> bool { let SdkError::ServiceError(service_error) = err else { return false; }; matches!( service_error.err(), BatchWriteItemError::ProvisionedThroughputExceededException(_) ) } } #[derive(Debug, Clone, Copy, derive_more::Display, derive_more::Error)] pub struct UnknownAttributeTypeError; fn calculate_attr_value_size_in_db( value: &AttributeValue, ) -> Result { const ELEMENT_BYTE_OVERHEAD: usize = 1; const CONTAINER_BYTE_OVERHEAD: usize = 3; /// AWS doesn't provide an exact algorithm for calculating number size in bytes /// in case they change the internal representation. We know that number can use /// between 2 and 21 bytes so we use the maximum value as the byte size. const NUMBER_BYTE_SIZE: usize = 21; let result = match value { AttributeValue::B(blob) => blob.as_ref().len(), AttributeValue::L(list) => { CONTAINER_BYTE_OVERHEAD + list.len() * ELEMENT_BYTE_OVERHEAD + list .iter() .try_fold(0, |a, v| Ok(a + calculate_attr_value_size_in_db(v)?))? } AttributeValue::M(map) => { CONTAINER_BYTE_OVERHEAD + map.len() * ELEMENT_BYTE_OVERHEAD + calculate_size_in_db(map)? 
#[derive(Debug, Clone, Copy, derive_more::Display, derive_more::Error)]
pub struct UnknownAttributeTypeError;

fn calculate_attr_value_size_in_db(
  value: &AttributeValue,
) -> Result<usize, UnknownAttributeTypeError> {
  const ELEMENT_BYTE_OVERHEAD: usize = 1;
  const CONTAINER_BYTE_OVERHEAD: usize = 3;
  /// AWS doesn't provide an exact algorithm for calculating number size in
  /// bytes, in case they change the internal representation. We know that a
  /// number can use between 2 and 21 bytes, so we use the maximum value as
  /// the byte size.
  const NUMBER_BYTE_SIZE: usize = 21;

  let result = match value {
    AttributeValue::B(blob) => blob.as_ref().len(),
    AttributeValue::L(list) => {
      CONTAINER_BYTE_OVERHEAD
        + list.len() * ELEMENT_BYTE_OVERHEAD
        + list
          .iter()
          .try_fold(0, |a, v| Ok(a + calculate_attr_value_size_in_db(v)?))?
    }
    AttributeValue::M(map) => {
      CONTAINER_BYTE_OVERHEAD
        + map.len() * ELEMENT_BYTE_OVERHEAD
        + calculate_size_in_db(map)?
    }
    AttributeValue::Bool(_) | AttributeValue::Null(_) => 1,
    AttributeValue::Bs(set) => set.len(),
    AttributeValue::N(_) => NUMBER_BYTE_SIZE,
    AttributeValue::Ns(set) => set.len() * NUMBER_BYTE_SIZE,
    AttributeValue::S(string) => string.as_bytes().len(),
    AttributeValue::Ss(set) => {
      set.iter().map(|string| string.as_bytes().len()).sum()
    }
    _ => return Err(UnknownAttributeTypeError),
  };

  Ok(result)
}

pub fn calculate_size_in_db(
  value: &AttributeMap,
) -> Result<usize, UnknownAttributeTypeError> {
  value.iter().try_fold(0, |a, (attr, value)| {
    Ok(a + attr.as_bytes().len() + calculate_attr_value_size_in_db(value)?)
  })
}

#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_parse_integer() {
    assert!(parse_integer::<i32>("some_attr", "123").is_ok());
    assert!(parse_integer::<i32>("negative", "-123").is_ok());
    assert!(parse_integer::<i32>("float", "3.14").is_err());
    assert!(parse_integer::<i32>("NaN", "foo").is_err());
    assert!(parse_integer::<u32>("negative_uint", "-123").is_err());
    assert!(parse_integer::<u16>("too_large", "65536").is_err());
  }

  #[test]
  fn test_parse_timestamp() {
    let timestamp = Utc::now().timestamp_millis();
    let attr = AttributeValue::N(timestamp.to_string());
    let parsed_timestamp = parse_timestamp_attribute("some_attr", Some(attr));
    assert!(parsed_timestamp.is_ok());
    assert_eq!(parsed_timestamp.unwrap().timestamp_millis(), timestamp);
  }

  #[test]
  fn test_parse_invalid_timestamp() {
    let attr = AttributeValue::N("foo".to_string());
    let parsed_timestamp = parse_timestamp_attribute("some_attr", Some(attr));
    assert!(parsed_timestamp.is_err());
  }

  #[test]
  fn test_parse_timestamp_out_of_range() {
    let attr = AttributeValue::N(i64::MAX.to_string());
    let parsed_timestamp = parse_timestamp_attribute("some_attr", Some(attr));
    assert!(parsed_timestamp.is_err());
    assert!(matches!(
      parsed_timestamp.unwrap_err().attribute_error,
      DBItemAttributeError::TimestampOutOfRange
    ));
  }

  #[test]
  fn test_optional_attribute() {
    let mut attrs = AttributeMap::from([(
      "foo".to_string(),
      AttributeValue::S("bar".to_string()),
    )]);

    let foo: Option<String> =
      attrs.take_attr("foo").expect("failed to parse arg 'foo'");
    let bar: Option<String> =
      attrs.take_attr("bar").expect("failed to parse arg 'bar'");

    assert!(foo.is_some());
    assert!(bar.is_none());
  }
}
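To make the size rules in `calculate_attr_value_size_in_db` concrete, here is a small worked example: strings count their UTF-8 bytes, numbers are charged the 21-byte maximum, and attribute names count toward the total. The item contents are invented for illustration:

```rust
use aws_sdk_dynamodb::types::AttributeValue;
use comm_lib::database::{calculate_size_in_db, AttributeMap};

fn main() {
  let item = AttributeMap::from([
    // attribute name "pk" (2 bytes) + string value "user#123" (8 bytes)
    ("pk".to_string(), AttributeValue::S("user#123".to_string())),
    // attribute name "count" (5 bytes) + number charged at the 21-byte max
    ("count".to_string(), AttributeValue::N("42".to_string())),
  ]);
  assert_eq!(calculate_size_in_db(&item).unwrap(), 2 + 8 + 5 + 21);
}
```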