diff --git a/services/commtest/tests/identity_one_time_key_tests.rs b/services/commtest/tests/identity_one_time_key_tests.rs
new file mode 100644
--- /dev/null
+++ b/services/commtest/tests/identity_one_time_key_tests.rs
@@ -0,0 +1,37 @@
+mod client {
+  tonic::include_proto!("identity.client");
+}
+mod auth_proto {
+  tonic::include_proto!("identity.authenticated");
+}
+use client::identity_client_service_client::IdentityClientServiceClient;
+use client::UploadOneTimeKeysRequest;
+use commtest::identity::device::create_device;
+
+/// Uploads content and notification one-time keys for a freshly created
+/// device and checks that the identity service accepts them.
+#[tokio::test]
+async fn upload_one_time_keys() {
+  let device_info = create_device().await;
+
+  let mut identity_client =
+    IdentityClientServiceClient::connect("http://127.0.0.1:50054")
+      .await
+      .expect("Couldn't connect to identity service");
+
+  let upload_request = UploadOneTimeKeysRequest {
+    user_id: device_info.user_id,
+    device_id: device_info.device_id,
+    access_token: device_info.access_token,
+    content_one_time_pre_keys: vec![
+      "content1".to_string(),
+      "content2".to_string(),
+    ],
+    notif_one_time_pre_keys: vec!["notif1".to_string(), "notif2".to_string()],
+  };
+
+  identity_client
+    .upload_one_time_keys(upload_request)
+    .await
+    .expect("Failed to upload one-time keys");
+}
diff --git a/services/identity/src/client_service.rs b/services/identity/src/client_service.rs
--- a/services/identity/src/client_service.rs
+++ b/services/identity/src/client_service.rs
@@ -791,7 +791,6 @@
     self
       .client
       .append_one_time_prekeys(
-        message.user_id,
         message.device_id,
         message.content_one_time_pre_keys,
         message.notif_one_time_pre_keys,
diff --git a/services/identity/src/constants.rs b/services/identity/src/constants.rs
--- a/services/identity/src/constants.rs
+++ b/services/identity/src/constants.rs
@@ -85,6 +85,17 @@
 pub const RESERVED_USERNAMES_TABLE: &str = "identity-reserved-usernames";
 pub const RESERVED_USERNAMES_TABLE_PARTITION_KEY: &str = "username";
 
+// One-time keys are stored in their own table, one item per key, so that
+// individual keys can be added and removed atomically.
+pub mod one_time_keys_table {
+  pub const NAME: &str = "identity-one-time-keys";
+  // Holds the device ID prefixed with its olm account type (see ddb_utils).
+  pub const PARTITION_KEY: &str = "deviceID";
+  pub const DEVICE_ID: &str = PARTITION_KEY;
+  pub const SORT_KEY: &str = "oneTimeKey";
+  pub const ONE_TIME_KEY: &str = SORT_KEY;
+}
+
 // Tokio
 
 pub const MPSC_CHANNEL_BUFFER_CAPACITY: usize = 1;
diff --git a/services/identity/src/database.rs b/services/identity/src/database.rs
--- a/services/identity/src/database.rs
+++ b/services/identity/src/database.rs
@@ -4,6 +4,9 @@
 use std::str::FromStr;
 use std::sync::Arc;
 
+use crate::ddb_utils::{
+  batch_write_with_retry, into_one_time_put_requests, OlmAccountType,
+};
 use crate::error::{DBItemAttributeError, DBItemError, Error};
 use aws_config::SdkConfig;
 use aws_sdk_dynamodb::model::{AttributeValue, PutRequest, WriteRequest};
@@ -25,13 +28,11 @@
   NONCE_TABLE_CREATED_ATTRIBUTE, NONCE_TABLE_PARTITION_KEY,
   RESERVED_USERNAMES_TABLE, RESERVED_USERNAMES_TABLE_PARTITION_KEY,
   USERS_TABLE, USERS_TABLE_DEVICES_ATTRIBUTE,
-  USERS_TABLE_DEVICES_MAP_CONTENT_ONETIME_KEYS_ATTRIBUTE_NAME,
   USERS_TABLE_DEVICES_MAP_CONTENT_PREKEY_ATTRIBUTE_NAME,
   USERS_TABLE_DEVICES_MAP_CONTENT_PREKEY_SIGNATURE_ATTRIBUTE_NAME,
   USERS_TABLE_DEVICES_MAP_DEVICE_TYPE_ATTRIBUTE_NAME,
   USERS_TABLE_DEVICES_MAP_KEY_PAYLOAD_ATTRIBUTE_NAME,
   USERS_TABLE_DEVICES_MAP_KEY_PAYLOAD_SIGNATURE_ATTRIBUTE_NAME,
-  USERS_TABLE_DEVICES_MAP_NOTIF_ONETIME_KEYS_ATTRIBUTE_NAME,
   USERS_TABLE_DEVICES_MAP_NOTIF_PREKEY_ATTRIBUTE_NAME,
   USERS_TABLE_DEVICES_MAP_NOTIF_PREKEY_SIGNATURE_ATTRIBUTE_NAME,
   USERS_TABLE_DEVICES_MAP_SOCIAL_PROOF_ATTRIBUTE_NAME,
@@ -295,45 +296,33 @@
 
+  /// Stores `content_one_time_keys` and `notif_one_time_keys` for
+  /// `device_id` in the one-time keys table, one item per key.
   pub async fn append_one_time_prekeys(
     &self,
-    user_id: String,
     device_id: String,
     content_one_time_keys: Vec<String>,
     notif_one_time_keys: Vec<String>,
   ) -> Result<(), Error> {
-    let notif_keys_av: Vec<AttributeValue> = notif_one_time_keys
-      .into_iter()
-      .map(AttributeValue::S)
-      .collect();
-    let content_keys_av: Vec<AttributeValue> = content_one_time_keys
-      .into_iter()
-      .map(AttributeValue::S)
-      .collect();
+    use crate::constants::one_time_keys_table;
 
-    let update_expression =
-      format!("SET {0}.#{1}.{2} = list_append({0}.#{1}.{2}, :n), {0}.#{1}.{3} = list_append({0}.#{1}.{3}, :i)",
-        USERS_TABLE_DEVICES_ATTRIBUTE,
-        "deviceID",
-        USERS_TABLE_DEVICES_MAP_NOTIF_ONETIME_KEYS_ATTRIBUTE_NAME,
-        USERS_TABLE_DEVICES_MAP_CONTENT_ONETIME_KEYS_ATTRIBUTE_NAME
-      );
-    let expression_attribute_names =
-      HashMap::from([(format!("#{}", "deviceID"), device_id)]);
-    let expression_attribute_values = HashMap::from([
-      (":n".to_string(), AttributeValue::L(notif_keys_av)),
-      (":i".to_string(), AttributeValue::L(content_keys_av)),
-    ]);
+    let mut otk_requests = into_one_time_put_requests(
+      &device_id,
+      content_one_time_keys,
+      OlmAccountType::Content,
+    );
+    otk_requests.extend(into_one_time_put_requests(
+      &device_id,
+      notif_one_time_keys,
+      OlmAccountType::Notification,
+    ));
 
-    self
-      .client
-      .update_item()
-      .table_name(USERS_TABLE)
-      .key(USERS_TABLE_PARTITION_KEY, AttributeValue::S(user_id))
-      .update_expression(update_expression)
-      .set_expression_attribute_names(Some(expression_attribute_names))
-      .set_expression_attribute_values(Some(expression_attribute_values))
-      .send()
-      .await
-      .map_err(|e| Error::AwsSdk(e.into()))?;
+    // The helper splits the writes into BatchWriteItem-sized chunks and
+    // retries any requests DynamoDB reports back as unprocessed.
+    batch_write_with_retry(
+      &self.client,
+      one_time_keys_table::NAME,
+      otk_requests,
+    )
+    .await?;
 
     Ok(())
   }
@@ -344,14 +333,20 @@
     flattened_device_key_upload: FlattenedDeviceKeyUpload,
     social_proof: Option<String>,
   ) -> Result<(), Error> {
+    // `create_device_info` consumes the upload, so copy out the fields that
+    // are still needed afterwards.
+    let device_id = flattened_device_key_upload.device_id_key.clone();
+    let content_one_time_keys =
+      flattened_device_key_upload.content_onetime_keys.clone();
+    let notif_one_time_keys =
+      flattened_device_key_upload.notif_onetime_keys.clone();
+
     let device_info =
-      create_device_info(flattened_device_key_upload.clone(), social_proof);
+      create_device_info(flattened_device_key_upload, social_proof);
 
     let update_expression =
       format!("SET {}.#{} = :v", USERS_TABLE_DEVICES_ATTRIBUTE, "deviceID",);
-    let expression_attribute_names = HashMap::from([(
-      format!("#{}", "deviceID"),
-      flattened_device_key_upload.device_id_key,
-    )]);
+    let expression_attribute_names =
+      HashMap::from([(format!("#{}", "deviceID"), device_id.clone())]);
     let expression_attribute_values =
       HashMap::from([(":v".to_string(), AttributeValue::M(device_info))]);
@@ -367,6 +362,14 @@
       .await
       .map_err(|e| Error::AwsSdk(e.into()))?;
 
+    self
+      .append_one_time_prekeys(
+        device_id,
+        content_one_time_keys,
+        notif_one_time_keys,
+      )
+      .await?;
+
     Ok(())
   }
 
@@ -1042,16 +1045,6 @@
         .to_string(),
       AttributeValue::S(flattened_device_key_upload.content_prekey_signature),
     ),
-    (
-      USERS_TABLE_DEVICES_MAP_CONTENT_ONETIME_KEYS_ATTRIBUTE_NAME.to_string(),
-      AttributeValue::L(
-        flattened_device_key_upload
-          .content_onetime_keys
-          .into_iter()
-          .map(AttributeValue::S)
-          .collect(),
-      ),
-    ),
     (
       USERS_TABLE_DEVICES_MAP_NOTIF_PREKEY_ATTRIBUTE_NAME.to_string(),
       AttributeValue::S(flattened_device_key_upload.notif_prekey),
@@ -1060,16 +1053,6 @@
       USERS_TABLE_DEVICES_MAP_NOTIF_PREKEY_SIGNATURE_ATTRIBUTE_NAME.to_string(),
       AttributeValue::S(flattened_device_key_upload.notif_prekey_signature),
     ),
-    (
-      USERS_TABLE_DEVICES_MAP_NOTIF_ONETIME_KEYS_ATTRIBUTE_NAME.to_string(),
-      AttributeValue::L(
-        flattened_device_key_upload
-          .notif_onetime_keys
-          .into_iter()
-          .map(AttributeValue::S)
-          .collect(),
-      ),
-    ),
   ]);
 
   if let Some(social_proof) = social_proof {
diff --git a/services/identity/src/ddb_utils.rs b/services/identity/src/ddb_utils.rs
new file mode 100644
--- /dev/null
+++ b/services/identity/src/ddb_utils.rs
@@ -0,0 +1,155 @@
+//! Helpers for building and sending DynamoDB write requests.
+
+use aws_sdk_dynamodb::model::{AttributeValue, PutRequest, WriteRequest};
+use aws_sdk_dynamodb::Client;
+use std::collections::{HashMap, HashSet};
+use tokio::time::{sleep, Duration};
+
+use crate::error::Error;
+
+/// Which of a device's two olm accounts a one-time key belongs to.
+#[derive(Copy, Clone, Debug)]
+pub enum OlmAccountType {
+  Content,
+  Notification,
+}
+
+// Prefix the device ID (the table's partition key) with the olm account
+// variant. This lets a single DDB table hold both the content and the
+// notification one-time keys of a device without the two sets colliding.
+fn create_one_time_key_partition_key(
+  device_id: &str,
+  account_type: OlmAccountType,
+) -> String {
+  match account_type {
+    OlmAccountType::Content => format!("content_{device_id}"),
+    OlmAccountType::Notification => format!("notification_{device_id}"),
+  }
+}
+
+/// Builds the write request that stores a single one-time key.
+fn create_one_time_key_put_request(
+  device_id: &str,
+  one_time_key: String,
+  account_type: OlmAccountType,
+) -> WriteRequest {
+  use crate::constants::one_time_keys_table::*;
+
+  let partition_key =
+    create_one_time_key_partition_key(device_id, account_type);
+  let attrs = HashMap::from([
+    (PARTITION_KEY.to_string(), AttributeValue::S(partition_key)),
+    (SORT_KEY.to_string(), AttributeValue::S(one_time_key)),
+  ]);
+  let put_request = PutRequest::builder().set_item(Some(attrs)).build();
+
+  WriteRequest::builder().put_request(put_request).build()
+}
+
+/// Converts `one_time_keys` into put requests for the one-time keys table.
+///
+/// Duplicate keys are dropped: DynamoDB rejects a `BatchWriteItem` call that
+/// contains two requests for the same primary key, and writing an identical
+/// item twice has no additional effect.
+pub fn into_one_time_put_requests<T>(
+  device_id: &str,
+  one_time_keys: T,
+  account_type: OlmAccountType,
+) -> Vec<WriteRequest>
+where
+  T: IntoIterator,
+  <T as IntoIterator>::Item: ToString,
+{
+  let mut seen = HashSet::new();
+  one_time_keys
+    .into_iter()
+    .map(|otk| otk.to_string())
+    .filter(|otk| seen.insert(otk.clone()))
+    .map(|otk| create_one_time_key_put_request(device_id, otk, account_type))
+    .collect()
+}
+
+/// DynamoDB rejects `BatchWriteItem` calls with more than 25 write requests.
+const MAX_BATCH_WRITE_SIZE: usize = 25;
+/// How many times unprocessed requests are resubmitted through
+/// `BatchWriteItem` before the remaining ones are written one at a time.
+const MAX_BATCH_WRITE_RETRIES: u32 = 5;
+/// Delay before the first resubmission; doubled after every attempt.
+const INITIAL_BATCH_WRITE_BACKOFF: Duration = Duration::from_millis(50);
+
+/// Writes all of `requests` to `table_name` using `BatchWriteItem`.
+///
+/// Requests are sent in chunks of at most `MAX_BATCH_WRITE_SIZE`. DynamoDB
+/// can accept a batch yet return some of its requests as unprocessed (for
+/// example when throttled); those are resubmitted with exponential backoff.
+/// Requests still unprocessed after `MAX_BATCH_WRITE_RETRIES` resubmissions
+/// are sent individually, so a persistent failure surfaces as an error
+/// instead of the writes being silently dropped.
+pub async fn batch_write_with_retry(
+  client: &Client,
+  table_name: &str,
+  requests: Vec<WriteRequest>,
+) -> Result<(), Error> {
+  for chunk in requests.chunks(MAX_BATCH_WRITE_SIZE) {
+    let mut pending = chunk.to_vec();
+    let mut backoff = INITIAL_BATCH_WRITE_BACKOFF;
+
+    for attempt in 0..=MAX_BATCH_WRITE_RETRIES {
+      if pending.is_empty() {
+        break;
+      }
+      if attempt > 0 {
+        sleep(backoff).await;
+        backoff *= 2;
+      }
+
+      let output = client
+        .batch_write_item()
+        .request_items(table_name, pending)
+        .send()
+        .await
+        .map_err(|e| Error::AwsSdk(e.into()))?;
+
+      pending = output
+        .unprocessed_items()
+        .and_then(|unprocessed| unprocessed.get(table_name))
+        .cloned()
+        .unwrap_or_default();
+    }
+
+    for request in &pending {
+      write_individually(client, table_name, request).await?;
+    }
+  }
+
+  Ok(())
+}
+
+/// Sends a single write request on its own; used for requests that
+/// `BatchWriteItem` kept returning as unprocessed.
+async fn write_individually(
+  client: &Client,
+  table_name: &str,
+  request: &WriteRequest,
+) -> Result<(), Error> {
+  if let Some(put) = request.put_request() {
+    client
+      .put_item()
+      .table_name(table_name)
+      .set_item(put.item().cloned())
+      .send()
+      .await
+      .map_err(|e| Error::AwsSdk(e.into()))?;
+  }
+  if let Some(delete) = request.delete_request() {
+    client
+      .delete_item()
+      .table_name(table_name)
+      .set_key(delete.key().cloned())
+      .send()
+      .await
+      .map_err(|e| Error::AwsSdk(e.into()))?;
+  }
+
+  Ok(())
+}
diff --git a/services/identity/src/main.rs b/services/identity/src/main.rs
--- a/services/identity/src/main.rs
+++ b/services/identity/src/main.rs
@@ -9,6 +9,7 @@
 mod config;
 pub mod constants;
 mod database;
+pub mod ddb_utils;
 pub mod error;
 mod grpc_services;
 mod id;
diff --git a/services/terraform/modules/shared/dynamodb.tf b/services/terraform/modules/shared/dynamodb.tf
--- a/services/terraform/modules/shared/dynamodb.tf
+++ b/services/terraform/modules/shared/dynamodb.tf
@@ -254,6 +254,23 @@
   }
 }
 
+resource "aws_dynamodb_table" "identity-one-time-keys" {
+  name         = "identity-one-time-keys"
+  hash_key     = "deviceID"
+  range_key    = "oneTimeKey"
+  billing_mode = "PAY_PER_REQUEST"
+
+  attribute {
+    name = "deviceID"
+    type = "S"
+  }
+
+  attribute {
+    name = "oneTimeKey"
+    type = "S"
+  }
+}
+
 resource "aws_dynamodb_table" "feature-flags" {
   name         = "feature-flags"
   hash_key     = "platform"