diff --git a/services/commtest/tests/identity_one_time_key_tests.rs b/services/commtest/tests/identity_one_time_key_tests.rs --- a/services/commtest/tests/identity_one_time_key_tests.rs +++ b/services/commtest/tests/identity_one_time_key_tests.rs @@ -4,7 +4,8 @@ use commtest::identity::olm_account_infos::generate_random_olm_key; use commtest::service_addr; use grpc_clients::identity::{ - get_auth_client, protos::authenticated::UploadOneTimeKeysRequest, + get_auth_client, protos::authenticated::OutboundKeysForUserRequest, + protos::authenticated::UploadOneTimeKeysRequest, }; #[tokio::test] @@ -38,3 +39,151 @@ .await .unwrap(); } + +#[tokio::test] +async fn max_hundred_keys_in_ddb() { + let device_info = register_user_device(None, None).await; + + let mut identity_client = get_auth_client( + &service_addr::IDENTITY_GRPC.to_string(), + device_info.user_id.clone(), + device_info.device_id, + device_info.access_token, + PLACEHOLDER_CODE_VERSION, + DEVICE_TYPE.to_string(), + ) + .await + .expect("Couldn't connect to identity service"); + + // We expect these keys to be removed by the identity service before we + // retrieve any OTKs + let first_upload_request = UploadOneTimeKeysRequest { + content_one_time_prekeys: vec![generate_random_olm_key()], + notif_one_time_prekeys: vec![generate_random_olm_key()], + }; + + identity_client + .upload_one_time_keys(first_upload_request) + .await + .unwrap(); + + let mut expected_first_retrieved_content_key = None; + let mut expected_first_retrieved_notif_key = None; + + let mut expected_second_retrieved_content_key = None; + let mut expected_second_retrieved_notif_key = None; + + // Upload 100 content and notif one-time keys in batches of 20 keys + for request_num in 0..5 { + let content_keys: Vec<_> = + (0..20).map(|_| generate_random_olm_key()).collect(); + let notif_keys: Vec<_> = + (0..20).map(|_| generate_random_olm_key()).collect(); + + if request_num == 0 { + expected_first_retrieved_content_key = content_keys.get(0).cloned(); + expected_first_retrieved_notif_key = notif_keys.get(0).cloned(); + expected_second_retrieved_content_key = content_keys.get(5).cloned(); + expected_second_retrieved_notif_key = notif_keys.get(5).cloned(); + } + + let upload_request = UploadOneTimeKeysRequest { + content_one_time_prekeys: content_keys, + notif_one_time_prekeys: notif_keys, + }; + + identity_client + .upload_one_time_keys(upload_request) + .await + .unwrap(); + } + + let keyserver_request = OutboundKeysForUserRequest { + user_id: device_info.user_id, + }; + + let first_keyserver_response = identity_client + .get_keyserver_keys(keyserver_request.clone()) + .await + .unwrap() + .into_inner() + .keyserver_info + .unwrap(); + + assert!(first_keyserver_response.one_time_content_prekey.is_some()); + assert!(first_keyserver_response.one_time_notif_prekey.is_some()); + + assert_eq!( + expected_first_retrieved_content_key, + first_keyserver_response.one_time_content_prekey + ); + assert_eq!( + expected_first_retrieved_notif_key, + first_keyserver_response.one_time_notif_prekey + ); + + // Upload 5 more keys for each account + let content_keys: Vec<_> = + (0..5).map(|_| generate_random_olm_key()).collect(); + let notif_keys: Vec<_> = (0..5).map(|_| generate_random_olm_key()).collect(); + + let final_upload_request = UploadOneTimeKeysRequest { + content_one_time_prekeys: content_keys, + notif_one_time_prekeys: notif_keys, + }; + + identity_client + .upload_one_time_keys(final_upload_request) + .await + .unwrap(); + + let second_keyserver_response = identity_client + .get_keyserver_keys(keyserver_request) + .await + .unwrap() + .into_inner() + .keyserver_info + .unwrap(); + + assert!(second_keyserver_response.one_time_content_prekey.is_some()); + assert!(second_keyserver_response.one_time_notif_prekey.is_some()); + + assert_eq!( + expected_second_retrieved_content_key, + second_keyserver_response.one_time_content_prekey + ); + assert_eq!( + expected_second_retrieved_notif_key, + second_keyserver_response.one_time_notif_prekey + ); +} + +#[tokio::test] +async fn max_24_keys_per_account_per_upload() { + let device_info = register_user_device(None, None).await; + + let mut identity_client = get_auth_client( + &service_addr::IDENTITY_GRPC.to_string(), + device_info.user_id, + device_info.device_id, + device_info.access_token, + PLACEHOLDER_CODE_VERSION, + DEVICE_TYPE.to_string(), + ) + .await + .expect("Couldn't connect to identity service"); + + // The limit is 24 keys per account per upload, so this should fail + let content_keys = (0..26).map(|_| generate_random_olm_key()).collect(); + let notif_keys = (0..20).map(|_| generate_random_olm_key()).collect(); + + let upload_request = UploadOneTimeKeysRequest { + content_one_time_prekeys: content_keys, + notif_one_time_prekeys: notif_keys, + }; + + assert!(identity_client + .upload_one_time_keys(upload_request) + .await + .is_err()); +} diff --git a/services/identity/src/constants.rs b/services/identity/src/constants.rs --- a/services/identity/src/constants.rs +++ b/services/identity/src/constants.rs @@ -248,5 +248,6 @@ } // One-time keys -pub const ONE_TIME_KEY_UPLOAD_LIMIT_PER_ACCOUNT: usize = 49; +pub const ONE_TIME_KEY_UPLOAD_LIMIT_PER_ACCOUNT: usize = 24; pub const ONE_TIME_KEY_SIZE: usize = 43; // as defined in olm +pub const MAX_ONE_TIME_KEYS: usize = 100; // as defined in olm diff --git a/services/identity/src/database.rs b/services/identity/src/database.rs --- a/services/identity/src/database.rs +++ b/services/identity/src/database.rs @@ -17,6 +17,7 @@ use std::sync::Arc; pub use crate::database::device_list::DeviceIDAttribute; +pub use crate::database::one_time_keys::OTKRow; use crate::{ constants::USERS_TABLE_SOCIAL_PROOF_ATTRIBUTE_NAME, ddb_utils::EthereumIdentity, reserved_users::UserDetail, siwe::SocialProof, diff --git a/services/identity/src/database/one_time_keys.rs b/services/identity/src/database/one_time_keys.rs --- a/services/identity/src/database/one_time_keys.rs +++ b/services/identity/src/database/one_time_keys.rs @@ -2,24 +2,23 @@ use comm_lib::{ aws::{ - ddb::{ - operation::query::QueryOutput, - types::{AttributeValue, Delete, TransactWriteItem, Update}, - }, + ddb::types::{AttributeValue, Delete, TransactWriteItem, Update}, DynamoDBError, }, database::{ - parse_int_attribute, AttributeExtractor, DBItemAttributeError, DBItemError, + parse_int_attribute, AttributeExtractor, AttributeMap, + DBItemAttributeError, DBItemError, }, }; use tracing::{debug, error, info}; use crate::{ - constants::ONE_TIME_KEY_UPLOAD_LIMIT_PER_ACCOUNT, + constants::{MAX_ONE_TIME_KEYS, ONE_TIME_KEY_UPLOAD_LIMIT_PER_ACCOUNT}, database::DeviceIDAttribute, ddb_utils::{ create_one_time_key_partition_key, into_one_time_put_requests, - into_one_time_update_requests, is_transaction_retryable, OlmAccountType, + into_one_time_update_and_delete_requests, is_transaction_retryable, + OlmAccountType, }, error::{consume_error, Error}, olm::is_valid_olm_key, @@ -41,7 +40,6 @@ can_request_more_keys: bool, ) -> Result<(Option, bool), Error> { use crate::constants::devices_table; - use crate::constants::one_time_keys_table as otk_table; use crate::constants::retry; use crate::constants::ONE_TIME_KEY_MINIMUM_THRESHOLD; @@ -84,25 +82,15 @@ return Ok((None, requested_more_keys)); } - let query_result = self - .get_next_one_time_key(user_id, device_id, account_type) - .await?; - let mut items = query_result.items.unwrap_or_default(); - let mut item = items.pop().unwrap_or_default(); - let pk = item.take_attr(otk_table::PARTITION_KEY)?; - let sk = item.take_attr(otk_table::SORT_KEY)?; - let otk: String = item.take_attr(otk_table::ATTR_ONE_TIME_KEY)?; - - let delete_otk = Delete::builder() - .table_name(otk_table::NAME) - .key(otk_table::PARTITION_KEY, AttributeValue::S(pk)) - .key(otk_table::SORT_KEY, AttributeValue::S(sk)) - .condition_expression("attribute_exists(#otk)") - .expression_attribute_names("#otk", otk_table::ATTR_ONE_TIME_KEY) - .build(); + let Some(otk_row) = self + .get_one_time_keys(user_id, device_id, account_type, 1) + .await? + .pop() + else { + return Err(Error::NotEnoughOneTimeKeys); + }; - let delete_otk_operation = - TransactWriteItem::builder().delete(delete_otk).build(); + let delete_otk_operation = otk_row.as_delete_request(); let update_otk_count = Update::builder() .table_name(devices_table::NAME) @@ -141,7 +129,7 @@ .await; match transaction { - Ok(_) => return Ok((Some(otk), requested_more_keys)), + Ok(_) => return Ok((Some(otk_row.otk), requested_more_keys)), Err(e) => { let dynamo_db_error = DynamoDBError::from(e); let retryable_codes = HashSet::from([ @@ -162,28 +150,47 @@ } } - async fn get_next_one_time_key( + async fn get_one_time_keys( &self, user_id: &str, device_id: &str, account_type: OlmAccountType, - ) -> Result { + num_keys: usize, + ) -> Result, Error> { use crate::constants::one_time_keys_table::*; + // DynamoDB will reject the `query` request if `limit < 1` + if num_keys < 1 { + return Ok(Vec::new()); + } + let partition_key = create_one_time_key_partition_key(user_id, device_id, account_type); - self + let otk_rows = self .client .query() .table_name(NAME) .key_condition_expression("#pk = :pk") .expression_attribute_names("#pk", PARTITION_KEY) .expression_attribute_values(":pk", AttributeValue::S(partition_key)) - .limit(1) + .limit(num_keys as i32) .send() .await - .map_err(|e| Error::AwsSdk(e.into())) + .map_err(|e| Error::AwsSdk(e.into()))? + .items + .unwrap_or_default() + .into_iter() + .map(OTKRow::try_from) + .collect::, _>>() + .map_err(Error::from)?; + + if otk_rows.len() != num_keys { + error!("There are fewer one-time keys than the number requested"); + return Err(Error::NotEnoughOneTimeKeys); + } + + Ok(otk_rows) } pub async fn append_one_time_prekeys( @@ -195,11 +202,11 @@ ) -> Result<(), Error> { use crate::constants::retry; - let num_content_keys = content_one_time_keys.len(); - let num_notif_keys = notif_one_time_keys.len(); + let num_content_keys_to_append = content_one_time_keys.len(); + let num_notif_keys_to_append = notif_one_time_keys.len(); - if num_content_keys > ONE_TIME_KEY_UPLOAD_LIMIT_PER_ACCOUNT - || num_notif_keys > ONE_TIME_KEY_UPLOAD_LIMIT_PER_ACCOUNT + if num_content_keys_to_append > ONE_TIME_KEY_UPLOAD_LIMIT_PER_ACCOUNT + || num_notif_keys_to_append > ONE_TIME_KEY_UPLOAD_LIMIT_PER_ACCOUNT { return Err(Error::OneTimeKeyUploadLimitExceeded); } @@ -230,17 +237,54 @@ current_time, ); - let update_otk_count_operation = into_one_time_update_requests( - user_id, - device_id, - num_content_keys, - num_notif_keys, - ); + let current_content_otk_count = self + .get_otk_count(user_id, device_id, OlmAccountType::Content) + .await?; + + let current_notif_otk_count = self + .get_otk_count(user_id, device_id, OlmAccountType::Notification) + .await?; + + let num_content_keys_to_delete = (num_content_keys_to_append + + current_content_otk_count) + .saturating_sub(MAX_ONE_TIME_KEYS); + + let num_notif_keys_to_delete = (num_notif_keys_to_append + + current_notif_otk_count) + .saturating_sub(MAX_ONE_TIME_KEYS); + + let content_keys_to_delete = self + .get_one_time_keys( + user_id, + device_id, + OlmAccountType::Content, + num_content_keys_to_delete, + ) + .await?; + + let notif_keys_to_delete = self + .get_one_time_keys( + user_id, + device_id, + OlmAccountType::Notification, + num_notif_keys_to_delete, + ) + .await?; + + let update_and_delete_otk_count_operation = + into_one_time_update_and_delete_requests( + user_id, + device_id, + num_content_keys_to_append, + num_notif_keys_to_append, + content_keys_to_delete, + notif_keys_to_delete, + ); let mut operations = Vec::new(); operations.extend_from_slice(&content_otk_requests); operations.extend_from_slice(¬if_otk_requests); - operations.push(update_otk_count_operation); + operations.extend_from_slice(&update_and_delete_otk_count_operation); // TODO: Introduce `transact_write_helper` similar to `batch_write_helper` // in `comm-lib` to handle transactions with retries @@ -324,3 +368,49 @@ } } } + +pub struct OTKRow { + pub partition_key: String, + pub sort_key: String, + pub otk: String, +} + +impl OTKRow { + pub fn as_delete_request(&self) -> TransactWriteItem { + use crate::constants::one_time_keys_table as otk_table; + + let delete_otk = Delete::builder() + .table_name(otk_table::NAME) + .key( + otk_table::PARTITION_KEY, + AttributeValue::S(self.partition_key.to_string()), + ) + .key( + otk_table::SORT_KEY, + AttributeValue::S(self.sort_key.to_string()), + ) + .condition_expression("attribute_exists(#otk)") + .expression_attribute_names("#otk", otk_table::ATTR_ONE_TIME_KEY) + .build(); + + TransactWriteItem::builder().delete(delete_otk).build() + } +} + +impl TryFrom for OTKRow { + type Error = DBItemError; + + fn try_from(mut attrs: AttributeMap) -> Result { + use crate::constants::one_time_keys_table as otk_table; + + let partition_key = attrs.take_attr(otk_table::PARTITION_KEY)?; + let sort_key = attrs.take_attr(otk_table::SORT_KEY)?; + let otk: String = attrs.take_attr(otk_table::ATTR_ONE_TIME_KEY)?; + + Ok(Self { + partition_key, + sort_key, + otk, + }) + } +} diff --git a/services/identity/src/ddb_utils.rs b/services/identity/src/ddb_utils.rs --- a/services/identity/src/ddb_utils.rs +++ b/services/identity/src/ddb_utils.rs @@ -17,7 +17,7 @@ USERS_TABLE_SOCIAL_PROOF_ATTRIBUTE_NAME, USERS_TABLE_USERNAME_ATTRIBUTE, USERS_TABLE_WALLET_ADDRESS_ATTRIBUTE, }, - database::DeviceIDAttribute, + database::{DeviceIDAttribute, OTKRow}, siwe::SocialProof, }; @@ -102,14 +102,29 @@ .collect() } -pub fn into_one_time_update_requests( +pub fn into_one_time_update_and_delete_requests( user_id: &str, device_id: &str, - num_content_keys: usize, - num_notif_keys: usize, -) -> TransactWriteItem { + num_content_keys_to_append: usize, + num_notif_keys_to_append: usize, + content_keys_to_delete: Vec, + notif_keys_to_delete: Vec, +) -> Vec { use crate::constants::devices_table; + let mut transactions = Vec::new(); + + for otk_row in content_keys_to_delete.iter().chain(¬if_keys_to_delete) { + let delete_otk_operation = otk_row.as_delete_request(); + transactions.push(delete_otk_operation) + } + + let content_key_count_delta = + num_content_keys_to_append - content_keys_to_delete.len(); + + let notif_key_count_delta = + num_notif_keys_to_append - notif_keys_to_delete.len(); + let update_otk_count = Update::builder() .table_name(devices_table::NAME) .key( @@ -127,17 +142,21 @@ )) .expression_attribute_values( ":num_content", - AttributeValue::N(num_content_keys.to_string()), + AttributeValue::N(content_key_count_delta.to_string()), ) .expression_attribute_values( ":num_notif", - AttributeValue::N(num_notif_keys.to_string()), + AttributeValue::N(notif_key_count_delta.to_string()), ) .build(); - TransactWriteItem::builder() + let update_otk_count_operation = TransactWriteItem::builder() .update(update_otk_count) - .build() + .build(); + + transactions.push(update_otk_count_operation); + + transactions } pub trait DateTimeExt { diff --git a/services/identity/src/error.rs b/services/identity/src/error.rs --- a/services/identity/src/error.rs +++ b/services/identity/src/error.rs @@ -30,6 +30,8 @@ MaxRetriesExceeded, #[display(...)] InvalidFormat, + #[display(...)] + NotEnoughOneTimeKeys, } #[derive(Debug, derive_more::Display, derive_more::Error)]