diff --git a/services/identity/src/database/device_list.rs b/services/identity/src/database/device_list.rs index d3644b671..04fa117cc 100644 --- a/services/identity/src/database/device_list.rs +++ b/services/identity/src/database/device_list.rs @@ -1,461 +1,620 @@ // TODO: get rid of this #![allow(dead_code)] use std::collections::HashMap; use aws_sdk_dynamodb::{ - client::fluent_builders::Query, model::AttributeValue, output::GetItemOutput, + client::fluent_builders::Query, + error::TransactionCanceledException, + model::{AttributeValue, Put, TransactWriteItem, Update}, + output::GetItemOutput, }; use chrono::{DateTime, Utc}; use tracing::{error, warn}; use crate::{ constants::{ devices_table::{self, *}, USERS_TABLE, USERS_TABLE_DEVICELIST_TIMESTAMP_ATTRIBUTE_NAME, USERS_TABLE_PARTITION_KEY, }, database::parse_string_attribute, ddb_utils::AttributesOptionExt, - error::{DBItemAttributeError, DBItemError, Error, FromAttributeValue}, + error::{ + DBItemAttributeError, DBItemError, DeviceListError, Error, + FromAttributeValue, + }, grpc_services::protos::unauth::DeviceType, }; use super::{parse_date_time_attribute, DatabaseClient}; type RawAttributes = HashMap; #[derive(Clone, Debug)] pub enum DevicesTableRow { Device(DeviceRow), DeviceList(DeviceListRow), } #[derive(Clone, Debug)] pub struct DeviceRow { pub user_id: String, pub device_id: String, pub device_type: DeviceType, pub device_key_info: IdentityKeyInfo, pub content_prekey: PreKey, pub notif_prekey: PreKey, } #[derive(Clone, Debug)] pub struct DeviceListRow { pub user_id: String, pub timestamp: DateTime, pub device_ids: Vec, } #[derive(Clone, Debug)] pub struct IdentityKeyInfo { pub key_payload: String, pub key_payload_signature: String, pub social_proof: Option, } #[derive(Clone, Debug)] pub struct PreKey { pub pre_key: String, pub pre_key_signature: String, } +impl DeviceListRow { + /// Generates new device list row from given devices + fn new(user_id: impl Into, device_ids: Vec) -> Self { + Self { + user_id: user_id.into(), + device_ids, + timestamp: Utc::now(), + } + } +} + // helper structs for converting to/from attribute values for sort key (a.k.a itemID) struct DeviceIDAttribute(String); struct DeviceListKeyAttribute(DateTime); impl From for AttributeValue { fn from(value: DeviceIDAttribute) -> Self { AttributeValue::S(format!("{DEVICE_ITEM_KEY_PREFIX}{}", value.0)) } } impl From for AttributeValue { fn from(value: DeviceListKeyAttribute) -> Self { AttributeValue::S(format!( "{DEVICE_LIST_KEY_PREFIX}{}", value.0.to_rfc3339() )) } } impl TryFrom> for DeviceIDAttribute { type Error = DBItemError; fn try_from(value: Option) -> Result { let item_id = parse_string_attribute(ATTR_ITEM_ID, value)?; // remove the device- prefix let device_id = item_id .strip_prefix(DEVICE_ITEM_KEY_PREFIX) .ok_or_else(|| DBItemError { attribute_name: ATTR_ITEM_ID.to_string(), attribute_value: Some(AttributeValue::S(item_id.clone())), attribute_error: DBItemAttributeError::InvalidValue, })? .to_string(); Ok(Self(device_id)) } } impl TryFrom> for DeviceListKeyAttribute { type Error = DBItemError; fn try_from(value: Option) -> Result { let item_id = parse_string_attribute(ATTR_ITEM_ID, value)?; // remove the device-list- prefix, then parse the timestamp let timestamp: DateTime = item_id .strip_prefix(DEVICE_LIST_KEY_PREFIX) .ok_or_else(|| DBItemError { attribute_name: ATTR_ITEM_ID.to_string(), attribute_value: Some(AttributeValue::S(item_id.clone())), attribute_error: DBItemAttributeError::InvalidValue, }) .and_then(|s| { s.parse().map_err(|e| { DBItemError::new( ATTR_ITEM_ID.to_string(), Some(AttributeValue::S(item_id.clone())), DBItemAttributeError::InvalidTimestamp(e), ) }) })?; Ok(Self(timestamp)) } } impl TryFrom for DeviceRow { type Error = DBItemError; fn try_from(mut attrs: RawAttributes) -> Result { let user_id = parse_string_attribute(ATTR_USER_ID, attrs.remove(ATTR_USER_ID))?; let DeviceIDAttribute(device_id) = attrs.remove(ATTR_ITEM_ID).try_into()?; let raw_device_type = parse_string_attribute(ATTR_DEVICE_TYPE, attrs.remove(ATTR_DEVICE_TYPE))?; let device_type = DeviceType::from_str_name(&raw_device_type).ok_or_else(|| { DBItemError::new( ATTR_DEVICE_TYPE.to_string(), Some(AttributeValue::S(raw_device_type)), DBItemAttributeError::InvalidValue, ) })?; let device_key_info = attrs .remove(ATTR_DEVICE_KEY_INFO) .ok_or_missing(ATTR_DEVICE_KEY_INFO)? .to_hashmap(ATTR_DEVICE_KEY_INFO) .cloned() .and_then(IdentityKeyInfo::try_from)?; let content_prekey = attrs .remove(ATTR_CONTENT_PREKEY) .ok_or_missing(ATTR_CONTENT_PREKEY)? .to_hashmap(ATTR_CONTENT_PREKEY) .cloned() .and_then(PreKey::try_from)?; let notif_prekey = attrs .remove(ATTR_NOTIF_PREKEY) .ok_or_missing(ATTR_NOTIF_PREKEY)? .to_hashmap(ATTR_NOTIF_PREKEY) .cloned() .and_then(PreKey::try_from)?; Ok(Self { user_id, device_id, device_type, device_key_info, content_prekey, notif_prekey, }) } } impl From for RawAttributes { fn from(value: DeviceRow) -> Self { HashMap::from([ (ATTR_USER_ID.to_string(), AttributeValue::S(value.user_id)), ( ATTR_ITEM_ID.to_string(), DeviceIDAttribute(value.device_id).into(), ), ( ATTR_DEVICE_TYPE.to_string(), AttributeValue::S(value.device_type.as_str_name().to_string()), ), ( ATTR_DEVICE_KEY_INFO.to_string(), value.device_key_info.into(), ), (ATTR_CONTENT_PREKEY.to_string(), value.content_prekey.into()), (ATTR_NOTIF_PREKEY.to_string(), value.notif_prekey.into()), ]) } } impl From for AttributeValue { fn from(value: IdentityKeyInfo) -> Self { let mut attrs = HashMap::from([ ( ATTR_KEY_PAYLOAD.to_string(), AttributeValue::S(value.key_payload), ), ( ATTR_KEY_PAYLOAD_SIGNATURE.to_string(), AttributeValue::S(value.key_payload_signature), ), ]); if let Some(social_proof) = value.social_proof { attrs.insert( ATTR_SOCIAL_PROOF.to_string(), AttributeValue::S(social_proof), ); } AttributeValue::M(attrs) } } impl TryFrom for IdentityKeyInfo { type Error = DBItemError; fn try_from(mut attrs: RawAttributes) -> Result { let key_payload = parse_string_attribute(ATTR_KEY_PAYLOAD, attrs.remove(ATTR_KEY_PAYLOAD))?; let key_payload_signature = parse_string_attribute( ATTR_KEY_PAYLOAD_SIGNATURE, attrs.remove(ATTR_KEY_PAYLOAD_SIGNATURE), )?; // social proof is optional let social_proof = attrs .remove(ATTR_SOCIAL_PROOF) .map(|attr| attr.to_string(ATTR_SOCIAL_PROOF).cloned()) .transpose()?; Ok(Self { key_payload, key_payload_signature, social_proof, }) } } impl From for AttributeValue { fn from(value: PreKey) -> Self { let attrs = HashMap::from([ (ATTR_PREKEY.to_string(), AttributeValue::S(value.pre_key)), ( ATTR_PREKEY_SIGNATURE.to_string(), AttributeValue::S(value.pre_key_signature), ), ]); AttributeValue::M(attrs) } } impl TryFrom for PreKey { type Error = DBItemError; fn try_from(mut attrs: RawAttributes) -> Result { let pre_key = parse_string_attribute(ATTR_PREKEY, attrs.remove(ATTR_PREKEY))?; let pre_key_signature = parse_string_attribute( ATTR_PREKEY_SIGNATURE, attrs.remove(ATTR_PREKEY_SIGNATURE), )?; Ok(Self { pre_key, pre_key_signature, }) } } impl TryFrom for DeviceListRow { type Error = DBItemError; fn try_from(mut attrs: RawAttributes) -> Result { let user_id = parse_string_attribute(ATTR_USER_ID, attrs.remove(ATTR_USER_ID))?; let DeviceListKeyAttribute(timestamp) = attrs.remove(ATTR_ITEM_ID).try_into()?; // validate timestamps are in sync let timestamps_match = attrs .remove(ATTR_TIMESTAMP) .and_then(|attr| attr.as_n().ok().cloned()) .and_then(|val| val.parse::().ok()) .filter(|val| *val == timestamp.timestamp_millis()) .is_some(); if !timestamps_match { warn!( "DeviceList timestamp mismatch for (userID={}, itemID={})", &user_id, timestamp.to_rfc3339() ); } // this should be a list of strings let device_ids = attrs .remove(ATTR_DEVICE_IDS) .ok_or_else(|| { DBItemError::new( ATTR_DEVICE_IDS.to_string(), None, DBItemAttributeError::Missing, ) })? .to_vec(ATTR_DEVICE_IDS)? .iter() .map(|v| v.to_string("device_ids[?]").cloned()) .collect::, DBItemError>>()?; Ok(Self { user_id, timestamp, device_ids, }) } } impl From for RawAttributes { fn from(device_list: DeviceListRow) -> Self { let mut attrs = HashMap::new(); attrs.insert( ATTR_USER_ID.to_string(), AttributeValue::S(device_list.user_id.clone()), ); attrs.insert( ATTR_ITEM_ID.to_string(), DeviceListKeyAttribute(device_list.timestamp).into(), ); attrs.insert( ATTR_TIMESTAMP.to_string(), AttributeValue::N(device_list.timestamp.timestamp_millis().to_string()), ); attrs.insert( ATTR_DEVICE_IDS.to_string(), AttributeValue::L( device_list .device_ids .into_iter() .map(AttributeValue::S) .collect(), ), ); attrs } } impl DatabaseClient { /// Retrieves user's current devices and their full data pub async fn get_current_devices( &self, user_id: impl Into, ) -> Result, Error> { let response = query_rows_with_prefix(self, user_id, DEVICE_ITEM_KEY_PREFIX) .send() .await .map_err(|e| { error!("Failed to get current devices: {:?}", e); Error::AwsSdk(e.into()) })?; let Some(rows) = response.items else { return Ok(Vec::new()); }; rows .into_iter() .map(DeviceRow::try_from) .collect::, DBItemError>>() .map_err(Error::from) } /// Checks if given device exists on user's current device list pub async fn device_exists( &self, user_id: impl Into, device_id: impl Into, ) -> Result { let GetItemOutput { item, .. } = self .client .get_item() .table_name(devices_table::NAME) .key(ATTR_USER_ID, AttributeValue::S(user_id.into())) .key(ATTR_ITEM_ID, DeviceIDAttribute(device_id.into()).into()) // only fetch the primary key, we don't need the rest .projection_expression(format!("{ATTR_USER_ID}, {ATTR_ITEM_ID}")) .send() .await .map_err(|e| { error!("Failed to check if device exists: {:?}", e); Error::AwsSdk(e.into()) })?; Ok(item.is_some()) } + + pub async fn get_current_device_list( + &self, + user_id: impl Into, + ) -> Result, Error> { + self + .client + .query() + .table_name(devices_table::NAME) + .index_name(devices_table::TIMESTAMP_INDEX_NAME) + .consistent_read(true) + .key_condition_expression("#user_id = :user_id") + // sort descending + .scan_index_forward(false) + .expression_attribute_names("#user_id", ATTR_USER_ID) + .expression_attribute_values( + ":user_id", + AttributeValue::S(user_id.into()), + ) + .limit(1) + .send() + .await + .map_err(|e| { + error!("Failed to query device list updates by index: {:?}", e); + Error::AwsSdk(e.into()) + })? + .items + .and_then(|mut items| items.pop()) + .map(DeviceListRow::try_from) + .transpose() + .map_err(Error::from) + } + + /// Performs a transactional update of the device list for the user. Afterwards + /// generates a new device list and updates the timestamp in the users table. + /// This is done in a transaction. Operation fails if the device list has been + /// updated concurrently (timestamp mismatch). + async fn transact_update_devicelist( + &self, + user_id: &str, + // The closure performing a transactional update of the device list. It receives a mutable + // reference to the current device list. The closure should return a transactional DDB + // operation to be performed when updating the device list. + action: impl FnOnce(&mut Vec) -> Result, + ) -> Result<(), Error> { + let previous_timestamp = + get_current_devicelist_timestamp(self, user_id).await?; + let mut device_ids = self + .get_current_device_list(user_id) + .await? + .map(|device_list| device_list.device_ids) + .unwrap_or_default(); + + // Perform the update action, then generate new device list + let operation = action(&mut device_ids)?; + let new_device_list = DeviceListRow::new(user_id, device_ids); + + // Update timestamp in users table + let timestamp_update_operation = device_list_timestamp_update_operation( + user_id, + previous_timestamp, + new_device_list.timestamp, + ); + + // Put updated device list (a new version) + let put_device_list = Put::builder() + .table_name(devices_table::NAME) + .set_item(Some(new_device_list.into())) + .condition_expression( + "attribute_not_exists(#user_id) AND attribute_not_exists(#item_id)", + ) + .expression_attribute_names("#user_id", ATTR_USER_ID) + .expression_attribute_names("#item_id", ATTR_ITEM_ID) + .build(); + let put_device_list_operation = + TransactWriteItem::builder().put(put_device_list).build(); + + self + .client + .transact_write_items() + .transact_items(operation) + .transact_items(put_device_list_operation) + .transact_items(timestamp_update_operation) + .send() + .await + .map_err(|e| match aws_sdk_dynamodb::Error::from(e) { + aws_sdk_dynamodb::Error::TransactionCanceledException( + TransactionCanceledException { + cancellation_reasons: Some(reasons), + .. + }, + ) if reasons + .iter() + .any(|reason| reason.code() == Some("ConditionalCheckFailed")) => + { + Error::DeviceList(DeviceListError::ConcurrentUpdateError) + } + other => { + error!("Device list update transaction failed: {:?}", other); + Error::AwsSdk(other) + } + })?; + + Ok(()) + } } /// Gets timestamp of user's current device list. Returns None if the user /// doesn't have a device list yet. Storing the timestamp in the users table is /// required for consistency. It's used as a condition when updating the device /// list. async fn get_current_devicelist_timestamp( db: &crate::database::DatabaseClient, user_id: impl Into, ) -> Result>, Error> { let response = db .client .get_item() .table_name(USERS_TABLE) .key(USERS_TABLE_PARTITION_KEY, AttributeValue::S(user_id.into())) .projection_expression(USERS_TABLE_DEVICELIST_TIMESTAMP_ATTRIBUTE_NAME) .send() .await .map_err(|e| { error!("Failed to get user's device list timestamp: {:?}", e); Error::AwsSdk(e.into()) })?; let mut user_item = response.item.unwrap_or_default(); let raw_datetime = user_item.remove(USERS_TABLE_DEVICELIST_TIMESTAMP_ATTRIBUTE_NAME); // existing records will not have this field when // updating device list for the first time if raw_datetime.is_none() { return Ok(None); } let timestamp = parse_date_time_attribute( USERS_TABLE_DEVICELIST_TIMESTAMP_ATTRIBUTE_NAME, raw_datetime, )?; Ok(Some(timestamp)) } +/// Generates update expression for current device list timestamp in users table. +/// The previous timestamp is used as a condition to ensure that the value hasn't changed +/// since we got it. This avoids race conditions when updating the device list. +fn device_list_timestamp_update_operation( + user_id: impl Into, + previous_timestamp: Option>, + new_timestamp: DateTime, +) -> TransactWriteItem { + let update_builder = match previous_timestamp { + Some(previous_timestamp) => Update::builder() + .condition_expression("#device_list_timestamp = :previous_timestamp") + .expression_attribute_values( + ":previous_timestamp", + AttributeValue::S(previous_timestamp.to_rfc3339()), + ), + // If there's no previous timestamp, the attribute shouldn't exist yet + None => Update::builder() + .condition_expression("attribute_not_exists(#device_list_timestamp)"), + }; + + let update = update_builder + .table_name(USERS_TABLE) + .key(USERS_TABLE_PARTITION_KEY, AttributeValue::S(user_id.into())) + .update_expression("SET #device_list_timestamp = :new_timestamp") + .expression_attribute_names( + "#device_list_timestamp", + USERS_TABLE_DEVICELIST_TIMESTAMP_ATTRIBUTE_NAME, + ) + .expression_attribute_values( + ":new_timestamp", + AttributeValue::S(new_timestamp.to_rfc3339()), + ) + .build(); + + TransactWriteItem::builder().update(update).build() +} + /// Helper function to query rows by given sort key prefix fn query_rows_with_prefix( db: &crate::database::DatabaseClient, user_id: impl Into, prefix: &'static str, ) -> Query { db.client .query() .table_name(devices_table::NAME) .key_condition_expression( "#user_id = :user_id AND begins_with(#item_id, :device_prefix)", ) .expression_attribute_names("#user_id", ATTR_USER_ID) .expression_attribute_names("#item_id", ATTR_ITEM_ID) .expression_attribute_values(":user_id", AttributeValue::S(user_id.into())) .expression_attribute_values( ":device_prefix", AttributeValue::S(prefix.to_string()), ) .consistent_read(true) } diff --git a/services/identity/src/error.rs b/services/identity/src/error.rs index c41844788..1dcbb3cdb 100644 --- a/services/identity/src/error.rs +++ b/services/identity/src/error.rs @@ -1,157 +1,164 @@ use aws_sdk_dynamodb::{model::AttributeValue, Error as DynamoDBError}; use std::collections::hash_map::HashMap; use std::fmt::{Display, Formatter, Result as FmtResult}; use tracing::error; #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum Error { #[display(...)] AwsSdk(DynamoDBError), #[display(...)] Attribute(DBItemError), #[display(...)] Transport(tonic::transport::Error), #[display(...)] Status(tonic::Status), #[display(...)] MissingItem, + #[display(...)] + DeviceList(DeviceListError), +} + +#[derive(Debug, derive_more::Display, derive_more::Error)] +pub enum DeviceListError { + ConcurrentUpdateError, } #[derive(Debug, derive_more::Error, derive_more::Constructor)] pub struct DBItemError { pub attribute_name: String, pub attribute_value: Option, pub attribute_error: DBItemAttributeError, } impl Display for DBItemError { fn fmt(&self, f: &mut Formatter) -> FmtResult { match &self.attribute_error { DBItemAttributeError::Missing => { write!(f, "Attribute {} is missing", self.attribute_name) } DBItemAttributeError::IncorrectType => write!( f, "Value for attribute {} has incorrect type: {:?}", self.attribute_name, self.attribute_value ), error => write!( f, "Error regarding attribute {} with value {:?}: {}", self.attribute_name, self.attribute_value, error ), } } } #[derive(Debug, derive_more::Display, derive_more::Error)] pub enum DBItemAttributeError { #[display(...)] Missing, #[display(...)] IncorrectType, #[display(...)] InvalidTimestamp(chrono::ParseError), #[display(...)] ExpiredTimestamp, #[display(...)] InvalidValue, } pub trait FromAttributeValue { fn to_vec( &self, attr_name: &str, ) -> Result<&Vec, DBItemError>; fn to_string(&self, attr_name: &str) -> Result<&String, DBItemError>; fn to_hashmap( &self, attr_name: &str, ) -> Result<&HashMap, DBItemError>; } fn handle_attr_failure(value: &AttributeValue, attr_name: &str) -> DBItemError { DBItemError { attribute_name: attr_name.to_string(), attribute_value: Some(value.clone()), attribute_error: DBItemAttributeError::IncorrectType, } } impl FromAttributeValue for AttributeValue { fn to_vec( &self, attr_name: &str, ) -> Result<&Vec, DBItemError> { self.as_l().map_err(|e| handle_attr_failure(e, attr_name)) } fn to_string(&self, attr_name: &str) -> Result<&String, DBItemError> { self.as_s().map_err(|e| handle_attr_failure(e, attr_name)) } fn to_hashmap( &self, attr_name: &str, ) -> Result<&HashMap, DBItemError> { self.as_m().map_err(|e| handle_attr_failure(e, attr_name)) } } pub trait AttributeValueFromHashMap { fn get_string(&self, key: &str) -> Result<&String, DBItemError>; fn get_map( &self, key: &str, ) -> Result<&HashMap, DBItemError>; fn get_vec(&self, key: &str) -> Result<&Vec, DBItemError>; } impl AttributeValueFromHashMap for HashMap { fn get_string(&self, key: &str) -> Result<&String, DBItemError> { self .get(key) .ok_or(DBItemError { attribute_name: key.to_string(), attribute_value: None, attribute_error: DBItemAttributeError::Missing, })? .to_string(key) } fn get_map( &self, key: &str, ) -> Result<&HashMap, DBItemError> { self .get(key) .ok_or(DBItemError { attribute_name: key.to_string(), attribute_value: None, attribute_error: DBItemAttributeError::Missing, })? .to_hashmap(key) } fn get_vec(&self, key: &str) -> Result<&Vec, DBItemError> { self .get(key) .ok_or(DBItemError { attribute_name: key.to_string(), attribute_value: None, attribute_error: DBItemAttributeError::Missing, })? .to_vec(key) } } pub fn consume_error(result: Result) { match result { Ok(_) => (), Err(e) => { error!("{}", e); } } }