diff --git a/services/identity/src/database/device_list.rs b/services/identity/src/database/device_list.rs
--- a/services/identity/src/database/device_list.rs
+++ b/services/identity/src/database/device_list.rs
@@ -4,7 +4,10 @@
 use std::collections::HashMap;
 
 use aws_sdk_dynamodb::{
-  client::fluent_builders::Query, model::AttributeValue, output::GetItemOutput,
+  client::fluent_builders::Query,
+  error::TransactionCanceledException,
+  model::{AttributeValue, Put, TransactWriteItem, Update},
+  output::GetItemOutput,
 };
 use chrono::{DateTime, Utc};
 use tracing::{error, warn};
@@ -16,7 +19,10 @@
     USERS_TABLE_PARTITION_KEY,
   },
   database::parse_string_attribute,
-  error::{DBItemAttributeError, DBItemError, Error, FromAttributeValue},
+  error::{
+    DBItemAttributeError, DBItemError, DeviceListError, Error,
+    FromAttributeValue,
+  },
   grpc_services::protos::unauth::DeviceType,
 };
 
@@ -54,6 +60,21 @@
   pub device_ids: Vec<String>,
 }
 
+impl DeviceListRow {
+  /// Creates a device list row with the IDs of the given devices, timestamped `Utc::now()`.
+  fn new(
+    user_id: impl Into<String>,
+    devices: impl IntoIterator<Item = DeviceRow>,
+  ) -> Self {
+    let device_ids = devices.into_iter().map(|d| d.device_id).collect();
+    Self {
+      user_id: user_id.into(),
+      timestamp: Utc::now(),
+      device_ids,
+    }
+  }
+}
+
 // helper structs for converting to/from attribute values for sort key (a.k.a itemID)
 struct DeviceIDAttribute(String);
 struct DeviceListKeyAttribute(DateTime<Utc>);
@@ -393,6 +414,111 @@
   Ok(Some(timestamp))
 }
 
+/// Applies `action` to the user's current devices and atomically writes, in one
+/// DynamoDB transaction, the operation it returns, a new device list row and the
+/// new device list timestamp in the users table. Returns `ConcurrentUpdateError`
+/// if any conditional check fails (e.g. the timestamp was changed concurrently).
+async fn transact_update_devicelist(
+  db: &crate::database::DatabaseClient,
+  user_id: &str,
+  // Receives a mutable reference to the user's current devices and returns the
+  // DynamoDB transactional operation to perform; the mutated list becomes the
+  // new device list.
+  action: impl FnOnce(&mut Vec<DeviceRow>) -> Result<TransactWriteItem, Error>,
+) -> Result<(), Error> {
+  let previous_timestamp =
+    get_current_devicelist_timestamp(db, user_id).await?;
+  let mut user_devices = get_current_devices(db, user_id).await?;
+
+  // Perform the update action, then generate new device list
+  let operation = action(&mut user_devices)?;
+  let new_device_list = DeviceListRow::new(user_id, user_devices);
+
+  // Update timestamp in users table
+  let timestamp_update_operation = device_list_timestamp_update_operation(
+    user_id,
+    previous_timestamp,
+    new_device_list.timestamp,
+  );
+
+  // Put updated device list (a new version)
+  let put_device_list = Put::builder()
+    .table_name(devices_table::NAME)
+    .set_item(Some(new_device_list.into()))
+    .condition_expression(
+      "attribute_not_exists(#user_id) AND attribute_not_exists(#item_id)",
+    )
+    .expression_attribute_names("#user_id", ATTR_USER_ID)
+    .expression_attribute_names("#item_id", ATTR_ITEM_ID)
+    .build();
+  let put_device_list_operation =
+    TransactWriteItem::builder().put(put_device_list).build();
+
+  db.client
+    .transact_write_items()
+    .transact_items(operation)
+    .transact_items(put_device_list_operation)
+    .transact_items(timestamp_update_operation)
+    .send()
+    .await
+    .map_err(|e| match aws_sdk_dynamodb::Error::from(e) {
+      aws_sdk_dynamodb::Error::TransactionCanceledException(
+        TransactionCanceledException {
+          cancellation_reasons: Some(reasons),
+          ..
+        },
+      ) if reasons
+        .iter()
+        .any(|reason| reason.code() == Some("ConditionalCheckFailed")) =>
+      {
+        Error::DeviceList(DeviceListError::ConcurrentUpdateError)
+      }
+      other => {
+        error!("Device list update transaction failed: {:?}", other);
+        Error::AwsSdk(other)
+      }
+    })?;
+
+  Ok(())
+}
+
+/// Builds a transactional update that sets the device list timestamp in the
+/// users table. It is conditioned on the previously read timestamp (or on its
+/// absence), so the transaction fails if the value changed since it was read.
+fn device_list_timestamp_update_operation(
+  user_id: impl Into<String>,
+  previous_timestamp: Option<DateTime<Utc>>,
+  new_timestamp: DateTime<Utc>,
+) -> TransactWriteItem {
+  let update_builder = match previous_timestamp {
+    Some(previous_timestamp) => Update::builder()
+      .condition_expression("#device_list_timestamp = :previous_timestamp")
+      .expression_attribute_values(
+        ":previous_timestamp",
+        AttributeValue::S(previous_timestamp.to_rfc3339()),
+      ),
+    // If there's no previous timestamp, the attribute shouldn't exist yet
+    None => Update::builder()
+      .condition_expression("attribute_not_exists(#device_list_timestamp)"),
+  };
+
+  let update = update_builder
+    .table_name(USERS_TABLE)
+    .key(USERS_TABLE_PARTITION_KEY, AttributeValue::S(user_id.into()))
+    .update_expression("SET #device_list_timestamp = :new_timestamp")
+    .expression_attribute_names(
+      "#device_list_timestamp",
+      USERS_TABLE_DEVICELIST_TIMESTAMP_ATTRIBUTE_NAME,
+    )
+    .expression_attribute_values(
+      ":new_timestamp",
+      AttributeValue::S(new_timestamp.to_rfc3339()),
+    )
+    .build();
+
+  TransactWriteItem::builder().update(update).build()
+}
+
 /// Helper function to query rows by given sort key prefix
 fn query_rows_with_prefix(
   db: &crate::database::DatabaseClient,
diff --git a/services/identity/src/error.rs b/services/identity/src/error.rs
--- a/services/identity/src/error.rs
+++ b/services/identity/src/error.rs
@@ -17,6 +17,15 @@
   Status(tonic::Status),
   #[display(...)]
   MissingItem,
+  #[display(...)]
+  DeviceList(DeviceListError),
+}
+
+/// Errors specific to device list operations.
+#[derive(Debug, derive_more::Display, derive_more::Error)]
+pub enum DeviceListError {
+  /// A conditional check in a device list update transaction failed.
+  ConcurrentUpdateError,
 }
 
 #[derive(Debug, derive_more::Error, derive_more::Constructor)]