diff --git a/services/identity/src/database/farcaster.rs b/services/identity/src/database/farcaster.rs index c66a1b4f4..0b4550264 100644 --- a/services/identity/src/database/farcaster.rs +++ b/services/identity/src/database/farcaster.rs @@ -1,110 +1,128 @@ use comm_lib::aws::ddb::types::AttributeValue; +use comm_lib::aws::ddb::types::ReturnValue; use comm_lib::database::AttributeExtractor; use comm_lib::database::AttributeMap; use comm_lib::database::DBItemAttributeError; use comm_lib::database::DBItemError; use comm_lib::database::Value; use tracing::error; use crate::constants::USERS_TABLE; use crate::constants::USERS_TABLE_FARCASTER_ID_ATTRIBUTE_NAME; use crate::constants::USERS_TABLE_FARCASTER_ID_INDEX; use crate::constants::USERS_TABLE_PARTITION_KEY; use crate::constants::USERS_TABLE_USERNAME_ATTRIBUTE; use crate::constants::USERS_TABLE_WALLET_ADDRESS_ATTRIBUTE; use crate::grpc_services::protos::unauth::FarcasterUser; use super::DatabaseClient; use super::Error; pub struct FarcasterUserData(pub FarcasterUser); impl DatabaseClient { pub async fn get_farcaster_users( &self, farcaster_ids: Vec, ) -> Result, Error> { let mut users: Vec = Vec::new(); for id in farcaster_ids { let query_response = self .client .query() .table_name(USERS_TABLE) .index_name(USERS_TABLE_FARCASTER_ID_INDEX) .key_condition_expression(format!( "{} = :val", USERS_TABLE_FARCASTER_ID_ATTRIBUTE_NAME )) .expression_attribute_values(":val", AttributeValue::S(id)) .send() .await .map_err(|e| { error!("Failed to query users by farcasterID: {:?}", e); Error::AwsSdk(e.into()) })? .items .and_then(|mut items| items.pop()) .map(FarcasterUserData::try_from) .transpose() .map_err(Error::from)?; if let Some(data) = query_response { users.push(data); } } Ok(users) } pub async fn add_farcaster_id( &self, user_id: String, farcaster_id: String, ) -> Result<(), Error> { - let update_expression = - format!("SET {} = :val", USERS_TABLE_FARCASTER_ID_ATTRIBUTE_NAME); + let update_expression = format!( + "SET {0} = if_not_exists({0}, :val)", + USERS_TABLE_FARCASTER_ID_ATTRIBUTE_NAME, + ); - self + let response = self .client .update_item() .table_name(USERS_TABLE) .key(USERS_TABLE_PARTITION_KEY, AttributeValue::S(user_id)) .update_expression(update_expression) - .expression_attribute_values(":val", AttributeValue::S(farcaster_id)) + .expression_attribute_values( + ":val", + AttributeValue::S(farcaster_id.clone()), + ) + .return_values(ReturnValue::UpdatedNew) .send() .await .map_err(|e| Error::AwsSdk(e.into()))?; + match response.attributes { + None => return Err(Error::MissingItem), + Some(mut attrs) => { + let farcaster_id_from_table: String = + attrs.take_attr(USERS_TABLE_FARCASTER_ID_ATTRIBUTE_NAME)?; + if farcaster_id_from_table != farcaster_id { + return Err(Error::CannotOverwrite); + } + } + } + Ok(()) } } impl TryFrom for FarcasterUserData { type Error = DBItemError; fn try_from(mut attrs: AttributeMap) -> Result { let user_id = attrs.take_attr(USERS_TABLE_PARTITION_KEY)?; let maybe_username = attrs.take_attr(USERS_TABLE_USERNAME_ATTRIBUTE)?; let maybe_wallet_address = attrs.take_attr(USERS_TABLE_WALLET_ADDRESS_ATTRIBUTE)?; let username = match (maybe_username, maybe_wallet_address) { (Some(u), _) => u, (_, Some(w)) => w, (_, _) => { return Err(DBItemError { attribute_name: USERS_TABLE_USERNAME_ATTRIBUTE.to_string(), attribute_value: Value::AttributeValue(None), attribute_error: DBItemAttributeError::Missing, }); } }; let farcaster_id = attrs.take_attr(USERS_TABLE_FARCASTER_ID_ATTRIBUTE_NAME)?; Ok(Self(FarcasterUser { user_id, username, farcaster_id, })) } } diff --git a/services/identity/src/error.rs b/services/identity/src/error.rs index 329fb8009..ab2318ee1 100644 --- a/services/identity/src/error.rs +++ b/services/identity/src/error.rs @@ -1,42 +1,44 @@ use comm_lib::aws::DynamoDBError; use comm_lib::database::DBItemError; use tracing::error; #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum Error { #[display(...)] AwsSdk(DynamoDBError), #[display(...)] Attribute(DBItemError), #[display(...)] Transport(tonic::transport::Error), #[display(...)] Status(tonic::Status), #[display(...)] MissingItem, #[display(...)] DeviceList(DeviceListError), #[display(...)] MalformedItem, #[display(...)] Serde(serde_json::Error), + #[display(...)] + CannotOverwrite, } #[derive(Debug, derive_more::Display, derive_more::Error)] pub enum DeviceListError { DeviceAlreadyExists, DeviceNotFound, ConcurrentUpdateError, InvalidDeviceListUpdate, } pub fn consume_error(result: Result) { match result { Ok(_) => (), Err(e) => { error!("{}", e); } } }