diff --git a/services/identity/src/database/one_time_keys.rs b/services/identity/src/database/one_time_keys.rs --- a/services/identity/src/database/one_time_keys.rs +++ b/services/identity/src/database/one_time_keys.rs @@ -2,10 +2,14 @@ use comm_lib::{ aws::{ - ddb::types::{AttributeValue, Delete, TransactWriteItem, Update}, + ddb::types::{ + AttributeValue, Delete, DeleteRequest, TransactWriteItem, Update, + WriteRequest, + }, DynamoDBError, }, database::{ + batch_operations::{batch_write, ExponentialBackoffConfig}, parse_int_attribute, AttributeExtractor, AttributeMap, DBItemAttributeError, DBItemError, }, @@ -83,7 +87,7 @@ } let Some(otk_row) = self - .get_one_time_keys(user_id, device_id, account_type, 1) + .get_one_time_keys(user_id, device_id, account_type, Some(1)) .await? .pop() else { @@ -155,26 +159,30 @@ user_id: &str, device_id: &str, account_type: OlmAccountType, - num_keys: usize, + num_keys: Option<usize>, ) -> Result<Vec<OTKRow>, Error> { use crate::constants::one_time_keys_table::*; - // DynamoDB will reject the `query` request if `limit < 1` - if num_keys < 1 { - return Ok(Vec::new()); - } - let partition_key = create_one_time_key_partition_key(user_id, device_id, account_type); - let otk_rows = self + let mut query = self .client .query() .table_name(NAME) .key_condition_expression("#pk = :pk") .expression_attribute_names("#pk", PARTITION_KEY) - .expression_attribute_values(":pk", AttributeValue::S(partition_key)) - .limit(num_keys as i32) + .expression_attribute_values(":pk", AttributeValue::S(partition_key)); + + if let Some(limit) = num_keys { + // DynamoDB will reject the `query` request if `limit < 1` + if limit < 1 { + return Ok(Vec::new()); + } + query = query.limit(limit as i32); + } + + let otk_rows = query .send() .await .map_err(|e| Error::AwsSdk(e.into()))? 
@@ -185,9 +193,11 @@ .collect::<Result<Vec<_>, _>>() .map_err(Error::from)?; - if otk_rows.len() != num_keys { - error!("There are fewer one-time keys than the number requested"); - return Err(Error::NotEnoughOneTimeKeys); + if let Some(limit) = num_keys { + if otk_rows.len() != limit { + error!("There are fewer one-time keys than the number requested"); + return Err(Error::NotEnoughOneTimeKeys); + } } Ok(otk_rows) @@ -258,7 +268,7 @@ user_id, device_id, OlmAccountType::Content, - num_content_keys_to_delete, + Some(num_content_keys_to_delete), ) .await?; @@ -267,7 +277,7 @@ user_id, device_id, OlmAccountType::Notification, - num_notif_keys_to_delete, + Some(num_notif_keys_to_delete), ) .await?; @@ -367,6 +377,46 @@ Err(e) => Err(Error::Attribute(e)), } } + + /// Deletes all data for a user's device from one-time keys table + pub async fn delete_otks_table_rows_for_user_device( + &self, + user_id: &str, + device_id: &str, + ) -> Result<(), Error> { + use crate::constants::one_time_keys_table::*; + + let content_otk_primary_keys = self + .get_one_time_keys(user_id, device_id, OlmAccountType::Content, None) + .await?; + + let notif_otk_primary_keys = self + .get_one_time_keys(user_id, device_id, OlmAccountType::Notification, None) + .await?; + + let delete_requests = content_otk_primary_keys + .into_iter() + .chain(notif_otk_primary_keys) + .map(|otk_row| { + let request = DeleteRequest::builder() + .key(PARTITION_KEY, AttributeValue::S(otk_row.partition_key)) + .key(SORT_KEY, AttributeValue::S(otk_row.sort_key)) + .build(); + WriteRequest::builder().delete_request(request).build() + }) + .collect::<Vec<_>>(); + + batch_write( + &self.client, + NAME, + delete_requests, + ExponentialBackoffConfig::default(), + ) + .await + .map_err(Error::from)?; + + Ok(()) + } } pub struct OTKRow { diff --git a/services/identity/src/error.rs b/services/identity/src/error.rs --- a/services/identity/src/error.rs +++ b/services/identity/src/error.rs @@ -50,3 +50,14 @@ } } } + +impl From<comm_lib::database::Error> for Error { + fn 
from(value: comm_lib::database::Error) -> Self { + use comm_lib::database::Error as E; + match value { + E::AwsSdk(err) => Self::AwsSdk(err), + E::Attribute(err) => Self::Attribute(err), + E::MaxRetriesExceeded => Self::MaxRetriesExceeded, + } + } +} diff --git a/services/identity/src/grpc_services/authenticated.rs b/services/identity/src/grpc_services/authenticated.rs --- a/services/identity/src/grpc_services/authenticated.rs +++ b/services/identity/src/grpc_services/authenticated.rs @@ -307,6 +307,12 @@ .await .map_err(handle_db_error)?; + self + .db_client + .delete_otks_table_rows_for_user_device(&user_id, &device_id) + .await + .map_err(handle_db_error)?; + self .db_client .delete_access_token_data(user_id, device_id)