diff --git a/services/identity/src/database/farcaster.rs b/services/identity/src/database/farcaster.rs --- a/services/identity/src/database/farcaster.rs +++ b/services/identity/src/database/farcaster.rs @@ -1,4 +1,5 @@ use comm_lib::aws::ddb::types::AttributeValue; +use comm_lib::aws::ddb::types::ReturnValue; use comm_lib::database::AttributeExtractor; use comm_lib::database::AttributeMap; use comm_lib::database::DBItemAttributeError; @@ -61,20 +62,37 @@ user_id: String, farcaster_id: String, ) -> Result<(), Error> { - let update_expression = - format!("SET {} = :val", USERS_TABLE_FARCASTER_ID_ATTRIBUTE_NAME); + let update_expression = format!( + "SET {0} = if_not_exists({0}, :val)", + USERS_TABLE_FARCASTER_ID_ATTRIBUTE_NAME, + ); - self + let response = self .client .update_item() .table_name(USERS_TABLE) .key(USERS_TABLE_PARTITION_KEY, AttributeValue::S(user_id)) .update_expression(update_expression) - .expression_attribute_values(":val", AttributeValue::S(farcaster_id)) + .expression_attribute_values( + ":val", + AttributeValue::S(farcaster_id.clone()), + ) + .return_values(ReturnValue::UpdatedNew) .send() .await .map_err(|e| Error::AwsSdk(e.into()))?; + match response.attributes { + None => return Err(Error::MissingItem), + Some(mut attrs) => { + let farcaster_id_from_table: String = + attrs.take_attr(USERS_TABLE_FARCASTER_ID_ATTRIBUTE_NAME)?; + if farcaster_id_from_table != farcaster_id { + return Err(Error::CannotOverwrite); + } + } + } + Ok(()) } } diff --git a/services/identity/src/error.rs b/services/identity/src/error.rs --- a/services/identity/src/error.rs +++ b/services/identity/src/error.rs @@ -22,6 +22,8 @@ MalformedItem, #[display(...)] Serde(serde_json::Error), + #[display(...)] + CannotOverwrite, } #[derive(Debug, derive_more::Display, derive_more::Error)]