diff --git a/services/identity/src/client_service.rs b/services/identity/src/client_service.rs --- a/services/identity/src/client_service.rs +++ b/services/identity/src/client_service.rs @@ -830,18 +830,11 @@ &message.signature, )?; - let mut filtered_usernames = Vec::new(); - - for username in usernames { - if !self - .client - .username_taken(username.clone()) - .await - .map_err(handle_db_error)? - { - filtered_usernames.push(username); - } - } + let filtered_usernames = self + .client + .filter_out_taken_usernames(usernames) + .await + .map_err(handle_db_error)?; self .client diff --git a/services/identity/src/database.rs b/services/identity/src/database.rs --- a/services/identity/src/database.rs +++ b/services/identity/src/database.rs @@ -1,5 +1,5 @@ use constant_time_eq::constant_time_eq; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt::{Display, Formatter, Result as FmtResult}; use std::str::FromStr; use std::sync::Arc; @@ -485,6 +485,23 @@ Ok(result.is_some()) } + pub async fn filter_out_taken_usernames( + &self, + usernames: Vec, + ) -> Result, Error> { + let db_usernames = self.get_all_usernames().await?; + + let db_usernames_set: HashSet = db_usernames.into_iter().collect(); + let usernames_set: HashSet = usernames.into_iter().collect(); + + let available_usernames: Vec = usernames_set + .difference(&db_usernames_set) + .cloned() + .collect(); + + Ok(available_usernames) + } + async fn get_user_from_user_info( &self, user_info: String, @@ -631,12 +648,12 @@ .map_err(|e| Error::AwsSdk(e.into())) } - pub async fn get_users(&self) -> Result, Error> { + async fn get_all_usernames(&self) -> Result, Error> { let scan_output = self .client .scan() .table_name(USERS_TABLE) - .projection_expression(USERS_TABLE_PARTITION_KEY) + .projection_expression(USERS_TABLE_USERNAME_ATTRIBUTE) .send() .await .map_err(|e| Error::AwsSdk(e.into()))?; @@ -644,12 +661,12 @@ let mut result = Vec::new(); if let Some(attributes) = scan_output.items { for mut attribute in attributes { - let id = parse_string_attribute( - USERS_TABLE_PARTITION_KEY, - attribute.remove(USERS_TABLE_PARTITION_KEY), - ) - .map_err(Error::Attribute)?; - result.push(id); + if let Ok(username) = parse_string_attribute( + USERS_TABLE_USERNAME_ATTRIBUTE, + attribute.remove(USERS_TABLE_USERNAME_ATTRIBUTE), + ) { + result.push(username); + } } } Ok(result) @@ -683,49 +700,29 @@ &self, usernames: Vec, ) -> Result<(), Error> { - let mut write_requests = vec![]; - - for username in usernames { - let item: HashMap = vec![( - RESERVED_USERNAMES_TABLE_PARTITION_KEY.to_string(), - AttributeValue::S(username), - )] - .into_iter() - .collect(); - - let write_request = WriteRequest::builder() - .put_request(PutRequest::builder().set_item(Some(item)).build()) - .build(); - - write_requests.push(write_request); - } - - loop { - let output = self + // A single call to BatchWriteItem can consist of up to 25 operations + for usernames_chunk in usernames.chunks(25) { + let write_requests = usernames_chunk + .iter() + .map(|username| { + let put_request = PutRequest::builder() + .item( + RESERVED_USERNAMES_TABLE_PARTITION_KEY, + AttributeValue::S(username.to_string()), + ) + .build(); + + WriteRequest::builder().put_request(put_request).build() + }) + .collect(); + + self .client .batch_write_item() .request_items(RESERVED_USERNAMES_TABLE, write_requests) .send() .await .map_err(|e| Error::AwsSdk(e.into()))?; - - let unprocessed_items_map = match output.unprocessed_items() { - Some(map) => map, - None => break, - }; - - let unprocessed_requests = - match unprocessed_items_map.get(RESERVED_USERNAMES_TABLE) { - Some(requests) => requests, - None => break, - }; - - info!( - "{} unprocessed items, retrying...", - unprocessed_requests.len() - ); - - write_requests = unprocessed_requests.to_vec(); } info!("Batch write item to reserved usernames table succeeded");