diff --git a/services/identity/src/database/device_list.rs b/services/identity/src/database/device_list.rs
--- a/services/identity/src/database/device_list.rs
+++ b/services/identity/src/database/device_list.rs
@@ -4,7 +4,10 @@
 use std::collections::HashMap;
 
 use aws_sdk_dynamodb::{
-  client::fluent_builders::Query, model::AttributeValue, output::GetItemOutput,
+  client::fluent_builders::Query,
+  error::TransactionCanceledException,
+  model::{AttributeValue, Put, TransactWriteItem, Update},
+  output::GetItemOutput,
 };
 use chrono::{DateTime, Utc};
 use tracing::{error, warn};
@@ -17,7 +20,10 @@
   },
   database::parse_string_attribute,
   ddb_utils::AttributesOptionExt,
-  error::{DBItemAttributeError, DBItemError, Error, FromAttributeValue},
+  error::{
+    DBItemAttributeError, DBItemError, DeviceListError, Error,
+    FromAttributeValue,
+  },
   grpc_services::protos::unauth::DeviceType,
 };
 
@@ -61,6 +67,17 @@
   pub pre_key_signature: String,
 }
 
+impl DeviceListRow {
+  /// Creates a device list row for the given devices, timestamped with now (UTC)
+  fn new(user_id: impl Into<String>, device_ids: Vec<String>) -> Self {
+    Self {
+      user_id: user_id.into(),
+      device_ids,
+      timestamp: Utc::now(),
+    }
+  }
+}
+
 // helper structs for converting to/from attribute values for sort key (a.k.a itemID)
 struct DeviceIDAttribute(String);
 struct DeviceListKeyAttribute(DateTime<Utc>);
@@ -398,6 +415,118 @@
 
     Ok(item.is_some())
   }
+
+  /// Returns the newest device list row stored for the user, or `None` when
+  /// the query yields no items. The timestamp index is read in descending
+  /// order with a limit of one, so only the latest entry is fetched
+  /// (assuming the index sorts by device list timestamp).
+  pub async fn get_current_device_list(
+    &self,
+    user_id: impl Into<String>,
+  ) -> Result<Option<DeviceListRow>, Error> {
+    self
+      .client
+      .query()
+      .table_name(devices_table::NAME)
+      .index_name(devices_table::TIMESTAMP_INDEX_NAME)
+      // NOTE(review): DynamoDB only supports strongly consistent reads on
+      // local secondary indexes - confirm this index is an LSI, not a GSI.
+      .consistent_read(true)
+      .key_condition_expression("#user_id = :user_id")
+      // sort descending
+      .scan_index_forward(false)
+      .expression_attribute_names("#user_id", ATTR_USER_ID)
+      .expression_attribute_values(
+        ":user_id",
+        AttributeValue::S(user_id.into()),
+      )
+      .limit(1)
+      .send()
+      .await
+      .map_err(|e| {
+        error!("Failed to query device list updates by index: {:?}", e);
+        Error::AwsSdk(e.into())
+      })?
+      .items
+      .and_then(|mut items| items.pop())
+      .map(DeviceListRow::try_from)
+      .transpose()
+      .map_err(Error::from)
+  }
+
+  /// Performs a transactional update of the device list for the user. Runs
+  /// `action` on the current list, then in one DDB transaction applies the
+  /// operation it returned, puts the new device list version and updates the
+  /// device list timestamp in the users table. Returns `ConcurrentUpdateError`
+  /// when a condition check fails (e.g. the list was updated concurrently).
+  async fn transact_update_devicelist(
+    &self,
+    user_id: &str,
+    // The closure performing a transactional update of the device list. It receives a mutable
+    // reference to the current device list. The closure should return a transactional DDB
+    // operation to be performed when updating the device list.
+    action: impl FnOnce(&mut Vec<String>) -> Result<TransactWriteItem, Error>,
+  ) -> Result<(), Error> {
+    let previous_timestamp =
+      get_current_devicelist_timestamp(self, user_id).await?;
+    let mut device_ids = self
+      .get_current_device_list(user_id)
+      .await?
+      .map(|device_list| device_list.device_ids)
+      .unwrap_or_default();
+
+    // Perform the update action, then generate new device list
+    let operation = action(&mut device_ids)?;
+    let new_device_list = DeviceListRow::new(user_id, device_ids);
+
+    // Update timestamp in users table
+    let timestamp_update_operation = device_list_timestamp_update_operation(
+      user_id,
+      previous_timestamp,
+      new_device_list.timestamp,
+    );
+
+    // Put updated device list (a new version)
+    let put_device_list = Put::builder()
+      .table_name(devices_table::NAME)
+      .set_item(Some(new_device_list.into()))
+      .condition_expression(
+        "attribute_not_exists(#user_id) AND attribute_not_exists(#item_id)",
+      )
+      .expression_attribute_names("#user_id", ATTR_USER_ID)
+      .expression_attribute_names("#item_id", ATTR_ITEM_ID)
+      .build();
+    let put_device_list_operation =
+      TransactWriteItem::builder().put(put_device_list).build();
+
+    self
+      .client
+      .transact_write_items()
+      .transact_items(operation)
+      .transact_items(put_device_list_operation)
+      .transact_items(timestamp_update_operation)
+      .send()
+      .await
+      .map_err(|e| match aws_sdk_dynamodb::Error::from(e) {
+        aws_sdk_dynamodb::Error::TransactionCanceledException(
+          TransactionCanceledException {
+            cancellation_reasons: Some(reasons),
+            ..
+          },
+        ) if reasons
+          .iter()
+          .any(|reason| reason.code() == Some("ConditionalCheckFailed")) =>
+        {
+          Error::DeviceList(DeviceListError::ConcurrentUpdateError)
+        }
+        other => {
+          error!("Device list update transaction failed: {:?}", other);
+          Error::AwsSdk(other)
+        }
+      })?;
+
+    Ok(())
+  }
 }
 
 /// Gets timestamp of user's current device list. Returns None if the user
@@ -438,6 +567,43 @@
   Ok(Some(timestamp))
 }
 
+/// Generates update expression for current device list timestamp in users table.
+/// The previous timestamp is used as a condition to ensure that the value hasn't changed
+/// since we got it. This avoids race conditions when updating the device list.
+fn device_list_timestamp_update_operation(
+  user_id: impl Into<String>,
+  previous_timestamp: Option<DateTime<Utc>>,
+  new_timestamp: DateTime<Utc>,
+) -> TransactWriteItem {
+  let update_builder = match previous_timestamp {
+    Some(previous_timestamp) => Update::builder()
+      .condition_expression("#device_list_timestamp = :previous_timestamp")
+      .expression_attribute_values(
+        ":previous_timestamp",
+        AttributeValue::S(previous_timestamp.to_rfc3339()),
+      ),
+    // If there's no previous timestamp, the attribute shouldn't exist yet
+    None => Update::builder()
+      .condition_expression("attribute_not_exists(#device_list_timestamp)"),
+  };
+
+  let update = update_builder
+    .table_name(USERS_TABLE)
+    .key(USERS_TABLE_PARTITION_KEY, AttributeValue::S(user_id.into()))
+    .update_expression("SET #device_list_timestamp = :new_timestamp")
+    .expression_attribute_names(
+      "#device_list_timestamp",
+      USERS_TABLE_DEVICELIST_TIMESTAMP_ATTRIBUTE_NAME,
+    )
+    .expression_attribute_values(
+      ":new_timestamp",
+      AttributeValue::S(new_timestamp.to_rfc3339()),
+    )
+    .build();
+
+  TransactWriteItem::builder().update(update).build()
+}
+
 /// Helper function to query rows by given sort key prefix
 fn query_rows_with_prefix(
   db: &crate::database::DatabaseClient,
diff --git a/services/identity/src/error.rs b/services/identity/src/error.rs
--- a/services/identity/src/error.rs
+++ b/services/identity/src/error.rs
@@ -17,6 +17,15 @@
   Status(tonic::Status),
   #[display(...)]
   MissingItem,
+  #[display(...)]
+  DeviceList(DeviceListError),
+}
+
+/// Errors specific to device list operations
+#[derive(Debug, derive_more::Display, derive_more::Error)]
+pub enum DeviceListError {
+  /// A transaction condition check failed, e.g. a concurrent list update
+  ConcurrentUpdateError,
 }
 
 #[derive(Debug, derive_more::Error, derive_more::Constructor)]