diff --git a/services/identity/src/constants.rs b/services/identity/src/constants.rs --- a/services/identity/src/constants.rs +++ b/services/identity/src/constants.rs @@ -11,6 +11,8 @@ pub const USERS_TABLE_REGISTRATION_ATTRIBUTE: &str = "pakeRegistrationData"; pub const USERS_TABLE_USERNAME_ATTRIBUTE: &str = "username"; pub const USERS_TABLE_USER_PUBLIC_KEY_ATTRIBUTE: &str = "userPublicKey"; +pub const USERS_TABLE_DEVICES_ATTRIBUTE: &str = "devices"; +pub const USERS_TABLE_DEVICES_MAP_ATTRIBUTE_NAME: &str = "deviceID"; pub const USERS_TABLE_WALLET_ADDRESS_ATTRIBUTE: &str = "walletAddress"; pub const USERS_TABLE_USERNAME_INDEX: &str = "username-index"; pub const USERS_TABLE_WALLET_ADDRESS_INDEX: &str = "walletAddress-index"; diff --git a/services/identity/src/database.rs b/services/identity/src/database.rs --- a/services/identity/src/database.rs +++ b/services/identity/src/database.rs @@ -17,10 +17,12 @@ ACCESS_TOKEN_SORT_KEY, ACCESS_TOKEN_TABLE, ACCESS_TOKEN_TABLE_AUTH_TYPE_ATTRIBUTE, ACCESS_TOKEN_TABLE_CREATED_ATTRIBUTE, ACCESS_TOKEN_TABLE_PARTITION_KEY, ACCESS_TOKEN_TABLE_TOKEN_ATTRIBUTE, - ACCESS_TOKEN_TABLE_VALID_ATTRIBUTE, USERS_TABLE, USERS_TABLE_PARTITION_KEY, - USERS_TABLE_REGISTRATION_ATTRIBUTE, USERS_TABLE_USERNAME_ATTRIBUTE, - USERS_TABLE_USERNAME_INDEX, USERS_TABLE_USER_PUBLIC_KEY_ATTRIBUTE, - USERS_TABLE_WALLET_ADDRESS_ATTRIBUTE, USERS_TABLE_WALLET_ADDRESS_INDEX, + ACCESS_TOKEN_TABLE_VALID_ATTRIBUTE, USERS_TABLE, + USERS_TABLE_DEVICES_ATTRIBUTE, USERS_TABLE_DEVICES_MAP_ATTRIBUTE_NAME, + USERS_TABLE_PARTITION_KEY, USERS_TABLE_REGISTRATION_ATTRIBUTE, + USERS_TABLE_USERNAME_ATTRIBUTE, USERS_TABLE_USERNAME_INDEX, + USERS_TABLE_USER_PUBLIC_KEY_ATTRIBUTE, USERS_TABLE_WALLET_ADDRESS_ATTRIBUTE, + USERS_TABLE_WALLET_ADDRESS_INDEX, }; use crate::opaque::Cipher; use crate::token::{AccessTokenData, AuthType}; @@ -82,11 +84,13 @@ pub async fn update_users_table( &self, user_id: String, + device_id: String, registration: Option>, username: Option, user_public_key: Option, ) -> Result { let mut update_expression_parts = Vec::new(); + let mut expression_attribute_names = HashMap::new(); let mut expression_attribute_values = HashMap::new(); if let Some(reg) = registration { update_expression_parts @@ -103,10 +107,21 @@ .insert(":u".to_string(), AttributeValue::S(username)); }; if let Some(public_key) = user_public_key { - update_expression_parts - .push(format!("{} = :k", USERS_TABLE_USER_PUBLIC_KEY_ATTRIBUTE)); - expression_attribute_values - .insert(":k".to_string(), AttributeValue::S(public_key)); + update_expression_parts.push(format!( + "{}.#{} = :k", + USERS_TABLE_DEVICES_ATTRIBUTE, USERS_TABLE_DEVICES_MAP_ATTRIBUTE_NAME, + )); + expression_attribute_names.insert( + format!("#{}", USERS_TABLE_DEVICES_MAP_ATTRIBUTE_NAME), + device_id, + ); + expression_attribute_values.insert( + ":k".to_string(), + AttributeValue::M(HashMap::from([( + USERS_TABLE_USER_PUBLIC_KEY_ATTRIBUTE.to_string(), + AttributeValue::S(public_key), + )])), + ); }; self @@ -115,7 +130,20 @@ .table_name(USERS_TABLE) .key(USERS_TABLE_PARTITION_KEY, AttributeValue::S(user_id)) .update_expression(format!("SET {}", update_expression_parts.join(","))) - .set_expression_attribute_values(Some(expression_attribute_values)) + .set_expression_attribute_names( + if expression_attribute_names.is_empty() { + None + } else { + Some(expression_attribute_names) + }, + ) + .set_expression_attribute_values( + if expression_attribute_values.is_empty() { + None + } else { + Some(expression_attribute_values) + }, + ) .send() .await .map_err(|e| Error::AwsSdk(e.into())) diff --git a/services/identity/src/service.rs b/services/identity/src/service.rs --- a/services/identity/src/service.rs +++ b/services/identity/src/service.rs @@ -128,6 +128,7 @@ let registration_finish_and_login_start_result = match pake_registration_finish( &user_id, + &device_id, client.clone(), &pake_registration_upload_and_credential_request .pake_registration_upload, @@ -626,6 +627,7 @@ async fn pake_registration_finish( user_id: &str, + device_id: &str, client: DatabaseClient, registration_upload_bytes: &[u8], server_registration: Option>, @@ -662,6 +664,7 @@ match client .update_users_table( user_id.to_string(), + device_id.to_string(), Some(server_registration_finish_result), Some(username.to_string()), Some(user_public_key.to_string()),