diff --git a/services/identity/src/ddb_utils.rs b/services/identity/src/ddb_utils.rs index 4a9b56ab2..126ecaab8 100644 --- a/services/identity/src/ddb_utils.rs +++ b/services/identity/src/ddb_utils.rs @@ -1,293 +1,282 @@ -use chrono::{DateTime, NaiveDateTime, Utc}; +use chrono::{DateTime, Utc}; use comm_lib::{ aws::{ ddb::types::{ error::TransactionCanceledException, AttributeValue, Put, TransactWriteItem, Update, }, DynamoDBError, }, database::{AttributeExtractor, AttributeMap}, }; use std::collections::{HashMap, HashSet}; use std::iter::IntoIterator; use crate::{ constants::{ USERS_TABLE_FARCASTER_ID_ATTRIBUTE_NAME, USERS_TABLE_SOCIAL_PROOF_ATTRIBUTE_NAME, USERS_TABLE_USERNAME_ATTRIBUTE, USERS_TABLE_WALLET_ADDRESS_ATTRIBUTE, }, database::{DeviceIDAttribute, OTKRow}, siwe::SocialProof, }; #[derive(Copy, Clone, Debug)] pub enum OlmAccountType { Content, Notification, } pub fn create_one_time_key_partition_key( user_id: &str, device_id: &str, account_type: OlmAccountType, ) -> String { let account_type = match account_type { OlmAccountType::Content => "content", OlmAccountType::Notification => "notif", }; format!("{user_id}#{device_id}#{account_type}") } fn create_one_time_key_sort_key( key_number: usize, current_time: DateTime<Utc>, ) -> String { let timestamp = current_time.to_rfc3339(); format!("{timestamp}#{:02}", key_number) } fn create_one_time_key_put_request( user_id: &str, device_id: &str, one_time_key: String, key_number: usize, account_type: OlmAccountType, current_time: DateTime<Utc>, ) -> Put { use crate::constants::one_time_keys_table::*; let partition_key = create_one_time_key_partition_key(user_id, device_id, account_type); let sort_key = create_one_time_key_sort_key(key_number, current_time); let builder = Put::builder(); let attrs = HashMap::from([ (PARTITION_KEY.to_string(), AttributeValue::S(partition_key)), (SORT_KEY.to_string(), AttributeValue::S(sort_key)), ( ATTR_ONE_TIME_KEY.to_string(), AttributeValue::S(one_time_key), ), ]); builder.table_name(NAME).set_item(Some(attrs)).build() } pub fn into_one_time_put_requests<T>( user_id: &str, device_id: &str, one_time_keys: T, account_type: OlmAccountType, current_time: DateTime<Utc>, ) -> Vec<TransactWriteItem> where T: IntoIterator, <T as IntoIterator>::Item: ToString, { one_time_keys .into_iter() .enumerate() .map(|(index, otk)| { create_one_time_key_put_request( user_id, device_id, otk.to_string(), index, account_type, current_time, ) }) .map(|put_request| TransactWriteItem::builder().put(put_request).build()) .collect() } pub fn into_one_time_update_and_delete_requests( user_id: &str, device_id: &str, num_content_keys_to_append: usize, num_notif_keys_to_append: usize, content_keys_to_delete: Vec<OTKRow>, notif_keys_to_delete: Vec<OTKRow>, ) -> Vec<TransactWriteItem> { use crate::constants::devices_table; let mut transactions = Vec::new(); for otk_row in content_keys_to_delete.iter().chain(&notif_keys_to_delete) { let delete_otk_operation = otk_row.as_delete_request(); transactions.push(delete_otk_operation) } let content_key_count_delta = num_content_keys_to_append - content_keys_to_delete.len(); let notif_key_count_delta = num_notif_keys_to_append - notif_keys_to_delete.len(); let update_otk_count = Update::builder() .table_name(devices_table::NAME) .key( devices_table::ATTR_USER_ID, AttributeValue::S(user_id.to_string()), ) .key( devices_table::ATTR_ITEM_ID, DeviceIDAttribute(device_id.into()).into(), ) .update_expression(format!( "ADD {} :num_content, {} :num_notif", devices_table::ATTR_CONTENT_OTK_COUNT, devices_table::ATTR_NOTIF_OTK_COUNT )) .expression_attribute_values( ":num_content",
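// A minimal standalone sketch (not from the patch) of the one-time-key
// addressing scheme above. The zero-padded `{:02}` counter in the sort key
// matters because DynamoDB compares string sort keys lexicographically:
// "…#02" sorts before "…#10", while an unpadded "2" would sort after "10".
// This assumes a single upload batch stays under 100 keys per timestamp,
// matching the two-digit padding.
use chrono::{DateTime, Utc};

fn partition_key(user_id: &str, device_id: &str, account_type: &str) -> String {
  format!("{user_id}#{device_id}#{account_type}")
}

fn sort_key(key_number: usize, current_time: DateTime<Utc>) -> String {
  format!("{}#{:02}", current_time.to_rfc3339(), key_number)
}

fn main() {
  let now = Utc::now();
  assert_eq!(partition_key("abc", "123", "content"), "abc#123#content");
  assert!(sort_key(2, now) < sort_key(10, now)); // padding preserves order
}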
AttributeValue::N(content_key_count_delta.to_string()), ) .expression_attribute_values( ":num_notif", AttributeValue::N(notif_key_count_delta.to_string()), ) .build(); let update_otk_count_operation = TransactWriteItem::builder() .update(update_otk_count) .build(); transactions.push(update_otk_count_operation); transactions } -pub trait DateTimeExt { - fn from_utc_timestamp_millis(timestamp: i64) -> Option<DateTime<Utc>>; - } - -impl DateTimeExt for DateTime<Utc> { - fn from_utc_timestamp_millis(timestamp: i64) -> Option<Self> { - let naive = NaiveDateTime::from_timestamp_millis(timestamp)?; - Some(Self::from_naive_utc_and_offset(naive, Utc)) - } - } - pub struct DBIdentity { pub identifier: Identifier, pub farcaster_id: Option<String>, } pub enum Identifier { Username(String), WalletAddress(EthereumIdentity), } impl Identifier { pub fn username(&self) -> &str { match self { Identifier::Username(username) => username, Identifier::WalletAddress(eth_identity) => &eth_identity.wallet_address, } } } pub struct EthereumIdentity { pub wallet_address: String, pub social_proof: SocialProof, } impl TryFrom<AttributeMap> for DBIdentity { type Error = crate::error::Error; fn try_from(mut value: AttributeMap) -> Result<Self, Self::Error> { let farcaster_id = value.take_attr(USERS_TABLE_FARCASTER_ID_ATTRIBUTE_NAME)?; let username_result = value.take_attr(USERS_TABLE_USERNAME_ATTRIBUTE); if let Ok(username) = username_result { return Ok(DBIdentity { identifier: Identifier::Username(username), farcaster_id, }); } let wallet_address_result = value.take_attr(USERS_TABLE_WALLET_ADDRESS_ATTRIBUTE); let social_proof_result = value.take_attr(USERS_TABLE_SOCIAL_PROOF_ATTRIBUTE_NAME); if let (Ok(wallet_address), Ok(social_proof)) = (wallet_address_result, social_proof_result) { Ok(DBIdentity { identifier: Identifier::WalletAddress(EthereumIdentity { wallet_address, social_proof, }), farcaster_id, }) } else { Err(Self::Error::MalformedItem) } } } pub fn is_transaction_retryable( err: &DynamoDBError, retryable_codes: &HashSet<&str>, ) -> bool { match err { DynamoDBError::TransactionCanceledException( TransactionCanceledException { cancellation_reasons: Some(reasons), ..
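// A hedged aside on the count update above: the deltas are computed with
// usize subtraction, which assumes callers never delete more one-time keys
// than they append in a single call (otherwise the subtraction would panic
// in debug builds). DynamoDB's ADD action accepts negative numbers, so a
// defensive variant could compute a signed delta instead; minimal sketch:
fn signed_count_delta(num_to_append: usize, num_to_delete: usize) -> i64 {
  num_to_append as i64 - num_to_delete as i64
}

fn main() {
  assert_eq!(signed_count_delta(5, 2), 3);  // counter grows by 3
  assert_eq!(signed_count_delta(0, 3), -3); // ADD -3 decrements the counter
}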
}, ) => reasons.iter().any(|reason| { retryable_codes.contains(&reason.code().unwrap_or_default()) }), _ => false, } } #[cfg(test)] mod tests { use crate::constants::one_time_keys_table; use super::*; #[test] fn test_into_one_time_put_requests() { let otks = ["not", "real", "keys"]; let current_time = Utc::now(); let requests = into_one_time_put_requests( "abc", "123", otks, OlmAccountType::Content, current_time, ); assert_eq!(requests.len(), 3); for (index, request) in requests.into_iter().enumerate() { let mut item = request.put.unwrap().item.unwrap(); assert_eq!( item.remove(one_time_keys_table::PARTITION_KEY).unwrap(), AttributeValue::S("abc#123#content".to_string()) ); assert_eq!( item.remove(one_time_keys_table::SORT_KEY).unwrap(), AttributeValue::S(format!( "{}#{:02}", current_time.to_rfc3339(), index )) ); assert_eq!( item.remove(one_time_keys_table::ATTR_ONE_TIME_KEY).unwrap(), AttributeValue::S(otks[index].to_string()) ); } } } diff --git a/services/identity/src/device_list.rs b/services/identity/src/device_list.rs index 5263d96f9..340b9090b 100644 --- a/services/identity/src/device_list.rs +++ b/services/identity/src/device_list.rs @@ -1,555 +1,553 @@ use chrono::{DateTime, Duration, Utc}; use std::{collections::HashSet, str::FromStr}; use tracing::{debug, error, warn}; use crate::{ constants::{error_types, DEVICE_LIST_TIMESTAMP_VALID_FOR}, database::{DeviceListRow, DeviceListUpdate}, - ddb_utils::DateTimeExt, error::DeviceListError, grpc_services::protos::auth::UpdateDeviceListRequest, }; // serde helper for serializing/deserializing // device list JSON payload #[derive(serde::Serialize, serde::Deserialize)] struct RawDeviceList { devices: Vec, timestamp: i64, } /// Signed device list payload that is serializable to JSON. /// For the DDB payload, see [`DeviceListUpdate`] #[derive(Clone, serde::Serialize, serde::Deserialize)] #[serde(rename_all = "camelCase")] pub struct SignedDeviceList { /// JSON-stringified [`RawDeviceList`] raw_device_list: String, /// Current primary device signature. /// NOTE: Present only when the payload is received from primary device. /// It's `None` for Identity-generated device-lists #[serde(default)] #[serde(skip_serializing_if = "Option::is_none")] cur_primary_signature: Option, /// Previous primary device signature. Present only /// if primary device has changed since last update. #[serde(default)] #[serde(skip_serializing_if = "Option::is_none")] last_primary_signature: Option, } impl SignedDeviceList { fn as_raw(&self) -> Result { // The device list payload is sent as an escaped JSON payload. 
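// A hedged, generic sketch of how a predicate like `is_transaction_retryable`
// above is typically consumed: retry the operation while the error is
// classified as transient, up to a fixed attempt budget. The error values
// here are plain strings standing in for DynamoDB cancellation codes.
fn with_retries<T, E>(
  max_attempts: u32,
  is_retryable: impl Fn(&E) -> bool,
  mut op: impl FnMut() -> Result<T, E>,
) -> Result<T, E> {
  let mut attempt = 0;
  loop {
    match op() {
      Ok(value) => return Ok(value),
      Err(err) if attempt + 1 < max_attempts && is_retryable(&err) => {
        attempt += 1; // transient failure: try again
      }
      Err(err) => return Err(err),
    }
  }
}

fn main() {
  let mut calls = 0;
  let result: Result<&str, &str> = with_retries(
    3,
    |err| *err == "TransactionConflict",
    || {
      calls += 1;
      if calls < 3 { Err("TransactionConflict") } else { Ok("committed") }
    },
  );
  assert_eq!(result, Ok("committed"));
  assert_eq!(calls, 3);
}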
// Escaped double quotes need to be trimmed before attempting to deserialize serde_json::from_str(&self.raw_device_list.replace(r#"\""#, r#"""#)) .map_err(|err| { warn!("Failed to deserialize raw device list: {}", err); tonic::Status::invalid_argument("invalid device list payload") }) } /// Serializes the signed device list to a JSON string pub fn as_json_string(&self) -> Result { serde_json::to_string(self).map_err(|err| { error!( errorType = error_types::GRPC_SERVICES_LOG, "Failed to serialize device list updates: {}", err ); tonic::Status::failed_precondition("unexpected error") }) } } impl TryFrom for SignedDeviceList { type Error = tonic::Status; fn try_from(row: DeviceListRow) -> Result { let raw_list = RawDeviceList { devices: row.device_ids, timestamp: row.timestamp.timestamp_millis(), }; let stringified_list = serde_json::to_string(&raw_list).map_err(|err| { error!( errorType = error_types::GRPC_SERVICES_LOG, "Failed to serialize raw device list: {}", err ); tonic::Status::failed_precondition("unexpected error") })?; Ok(Self { raw_device_list: stringified_list, cur_primary_signature: row.current_primary_signature, last_primary_signature: row.last_primary_signature, }) } } impl TryFrom for SignedDeviceList { type Error = tonic::Status; fn try_from(request: UpdateDeviceListRequest) -> Result { request.new_device_list.parse().map_err(|err| { warn!("Failed to deserialize device list update: {}", err); tonic::Status::invalid_argument("invalid device list payload") }) } } impl FromStr for SignedDeviceList { type Err = serde_json::Error; fn from_str(s: &str) -> Result { serde_json::from_str(s) } } impl TryFrom for DeviceListUpdate { type Error = tonic::Status; fn try_from(signed_list: SignedDeviceList) -> Result { let RawDeviceList { devices, timestamp: raw_timestamp, } = signed_list.as_raw()?; - let timestamp = DateTime::::from_utc_timestamp_millis(raw_timestamp) - .ok_or_else(|| { - error!( - errorType = error_types::GRPC_SERVICES_LOG, - "Failed to parse RawDeviceList timestamp!" - ); - tonic::Status::invalid_argument("invalid timestamp") - })?; + let timestamp = + DateTime::from_timestamp_millis(raw_timestamp).ok_or_else(|| { + error!( + errorType = error_types::GRPC_SERVICES_LOG, + "Failed to parse RawDeviceList timestamp!" + ); + tonic::Status::invalid_argument("invalid timestamp") + })?; Ok(DeviceListUpdate { devices, timestamp, current_primary_signature: signed_list.cur_primary_signature, last_primary_signature: signed_list.last_primary_signature, raw_payload: signed_list.raw_device_list, }) } } /// Returns `true` if given timestamp is valid. The timestamp is considered /// valid under the following condition: /// - `new_timestamp` is greater than `previous_timestamp` (if provided) /// - `new_timestamp` is not older than [`DEVICE_LIST_TIMESTAMP_VALID_FOR`] /// /// Note: For Identity-managed device lists, the timestamp can be `None`. /// Verification is then skipped fn is_new_timestamp_valid( previous_timestamp: Option<&DateTime>, new_timestamp: Option<&DateTime>, ) -> bool { let Some(new_timestamp) = new_timestamp else { return true; }; if let Some(previous_timestamp) = previous_timestamp { if new_timestamp < previous_timestamp { return false; } } let timestamp_valid_duration = Duration::from_std(DEVICE_LIST_TIMESTAMP_VALID_FOR) .expect("FATAL - Invalid duration constant provided"); Utc::now().signed_duration_since(new_timestamp) < timestamp_valid_duration } /// Returns error if new timestamp is invalid. 
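// A hedged illustration of the unescape step in `as_raw` above: when the
// inner device list arrives doubly-escaped (literal `\"` sequences survive
// the outer JSON parse), replacing `\"` with `"` restores parseable JSON.
fn main() {
  let escaped = r#"{\"devices\":[\"device1\"],\"timestamp\":111111111}"#;
  let unescaped = escaped.replace(r#"\""#, r#"""#);
  assert_eq!(
    unescaped,
    r#"{"devices":["device1"],"timestamp":111111111}"#
  );
  let value: serde_json::Value = serde_json::from_str(&unescaped).unwrap();
  assert_eq!(value["timestamp"], 111111111);
}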
The timestamp is considered /// valid under the following condition: /// - `new_timestamp` is greater than `previous_timestamp` (if provided) /// - `new_timestamp` is not older than [`DEVICE_LIST_TIMESTAMP_VALID_FOR`] /// /// Note: For Identity-managed device lists, the timestamp can be `None`. /// Verification is then skipped pub fn verify_device_list_timestamp( previous_timestamp: Option<&DateTime>, new_timestamp: Option<&DateTime>, ) -> Result<(), DeviceListError> { if !is_new_timestamp_valid(previous_timestamp, new_timestamp) { return Err(DeviceListError::InvalidDeviceListUpdate); } Ok(()) } pub fn verify_device_list_signatures( previous_primary_device_id: Option<&String>, new_device_list: &DeviceListUpdate, ) -> Result<(), DeviceListError> { let Some(primary_device_id) = new_device_list.devices.first() else { return Ok(()); }; // verify current signature if let Some(signature) = &new_device_list.current_primary_signature { crate::grpc_utils::ed25519_verify( primary_device_id, &new_device_list.raw_payload, signature, ) .map_err(|err| { debug!("curPrimarySignature verification failed: {err}"); DeviceListError::InvalidSignature })?; } // verify last signature if primary device changed if let (Some(previous_primary_id), Some(last_signature)) = ( previous_primary_device_id.filter(|prev| *prev != primary_device_id), &new_device_list.last_primary_signature, ) { crate::grpc_utils::ed25519_verify( previous_primary_id, &new_device_list.raw_payload, last_signature, ) .map_err(|err| { debug!("lastPrimarySignature verification failed: {err}"); DeviceListError::InvalidSignature })?; } Ok(()) } pub fn verify_initial_device_list( device_list: &DeviceListUpdate, expected_primary_device_id: &str, ) -> Result<(), tonic::Status> { use tonic::Status; if device_list.last_primary_signature.is_some() { debug!("Received lastPrimarySignature for initial device list"); return Err(Status::invalid_argument( "invalid device list: unexpected lastPrimarySignature", )); } let Some(signature) = &device_list.current_primary_signature else { debug!("Missing curPrimarySignature for initial device list"); return Err(Status::invalid_argument( "invalid device list: signature missing", )); }; crate::grpc_utils::ed25519_verify( expected_primary_device_id, &device_list.raw_payload, signature, )?; if device_list.devices.len() != 1 { debug!("Invalid device list length"); return Err(Status::invalid_argument( "invalid device list: invalid length", )); } if device_list .devices .first() .filter(|it| **it == expected_primary_device_id) .is_none() { debug!("Invalid primary device ID for initial device list"); return Err(Status::invalid_argument( "invalid device list: invalid primary device", )); } Ok(()) } pub mod validation { use super::*; /// Returns `true` if `new_device_list` contains exactly one more new device /// compared to `previous_device_list` fn is_device_added( previous_device_list: &[&str], new_device_list: &[&str], ) -> bool { let previous_set: HashSet<_> = previous_device_list.iter().collect(); let new_set: HashSet<_> = new_device_list.iter().collect(); return new_set.difference(&previous_set).count() == 1; } /// Returns `true` if `new_device_list` contains exactly one fewer device /// compared to `previous_device_list` fn is_device_removed( previous_device_list: &[&str], new_device_list: &[&str], ) -> bool { let previous_set: HashSet<_> = previous_device_list.iter().collect(); let new_set: HashSet<_> = new_device_list.iter().collect(); return previous_set.difference(&new_set).count() == 1; } fn 
primary_device_changed( previous_device_list: &[&str], new_device_list: &[&str], ) -> bool { let previous_primary = previous_device_list.first(); let new_primary = new_device_list.first(); new_primary != previous_primary } /// Verifies if exactly one device has been replaced. /// No reorders are permitted. Both lists must have the same length. fn is_device_replaced( previous_device_list: &[&str], new_device_list: &[&str], ) -> bool { if previous_device_list.len() != new_device_list.len() { return false; } // exactly 1 different device ID std::iter::zip(previous_device_list, new_device_list) .filter(|(a, b)| a != b) .count() == 1 } // This is going to be used when doing primary device keys rotation #[allow(unused)] pub fn primary_device_rotation_validator( previous_device_list: &[&str], new_device_list: &[&str], ) -> bool { primary_device_changed(previous_device_list, new_device_list) && !is_device_replaced(&previous_device_list[1..], &new_device_list[1..]) } /// The `UpdateDeviceList` RPC should be able to either add or remove /// one device, and it cannot currently switch primary devices. /// The RPC is also able to replace a keyserver device pub fn update_device_list_rpc_validator( previous_device_list: &[&str], new_device_list: &[&str], ) -> bool { if primary_device_changed(previous_device_list, new_device_list) { return false; } // allow replacing a keyserver if is_device_replaced(previous_device_list, new_device_list) { return true; } let is_added = is_device_added(previous_device_list, new_device_list); let is_removed = is_device_removed(previous_device_list, new_device_list); is_added != is_removed } #[cfg(test)] mod tests { use super::*; #[test] fn test_device_added_or_removed() { use std::ops::Not; let list1 = vec!["device1"]; let list2 = vec!["device1", "device2"]; assert!(is_device_added(&list1, &list2)); assert!(is_device_removed(&list1, &list2).not()); assert!(is_device_added(&list2, &list1).not()); assert!(is_device_removed(&list2, &list1)); assert!(is_device_added(&list1, &list1).not()); assert!(is_device_removed(&list1, &list1).not()); } #[test] fn test_primary_device_changed() { use std::ops::Not; let list1 = vec!["device1"]; let list2 = vec!["device1", "device2"]; let list3 = vec!["device2"]; assert!(primary_device_changed(&list1, &list2).not()); assert!(primary_device_changed(&list1, &list3)); } #[test] fn test_device_replaced() { use std::ops::Not; let list1 = vec!["device1"]; let list2 = vec!["device2"]; let list3 = vec!["device1", "device2"]; let list4 = vec!["device2", "device1"]; let list5 = vec!["device2", "device3"]; assert!(is_device_replaced(&list1, &list2), "Singleton replacement"); assert!(is_device_replaced(&list4, &list5), "Standard replacement"); assert!(is_device_replaced(&list1, &list3).not(), "Length unequal"); assert!(is_device_replaced(&list3, &list3).not(), "Unchanged"); assert!(is_device_replaced(&list3, &list4).not(), "Reorder"); } } } #[cfg(test)] mod tests { use super::*; #[test] fn deserialize_device_list_signature() { let payload_with_signature = r#"{"rawDeviceList":"{\"devices\":[\"device1\"],\"timestamp\":111111111}","curPrimarySignature":"foo"}"#; let payload_without_signatures = r#"{"rawDeviceList":"{\"devices\":[\"device1\",\"device2\"],\"timestamp\":222222222}"}"#; let list_with_signature: SignedDeviceList = serde_json::from_str(payload_with_signature).unwrap(); let list_without_signatures: SignedDeviceList = serde_json::from_str(payload_without_signatures).unwrap(); assert_eq!( list_with_signature.cur_primary_signature,
Some("foo".to_string()) ); assert!(list_with_signature.last_primary_signature.is_none()); assert!(list_without_signatures.cur_primary_signature.is_none()); assert!(list_without_signatures.last_primary_signature.is_none()); } #[test] fn serialize_device_list_signatures() { let raw_list = r#"{"devices":["device1"],"timestamp":111111111}"#; let expected_payload_without_signatures = r#"{"rawDeviceList":"{\"devices\":[\"device1\"],\"timestamp\":111111111}"}"#; let device_list_without_signature = SignedDeviceList { raw_device_list: raw_list.to_string(), cur_primary_signature: None, last_primary_signature: None, }; assert_eq!( device_list_without_signature.as_json_string().unwrap(), expected_payload_without_signatures ); let expected_payload_with_signature = r#"{"rawDeviceList":"{\"devices\":[\"device1\"],\"timestamp\":111111111}","curPrimarySignature":"foo"}"#; let device_list_with_cur_signature = SignedDeviceList { raw_device_list: raw_list.to_string(), cur_primary_signature: Some("foo".to_string()), last_primary_signature: None, }; assert_eq!( device_list_with_cur_signature.as_json_string().unwrap(), expected_payload_with_signature ); } #[test] fn serialize_device_list_updates() { let raw_updates = vec![ create_device_list_row(RawDeviceList { devices: vec!["device1".into()], timestamp: 111111111, }), create_device_list_row(RawDeviceList { devices: vec!["device1".into(), "device2".into()], timestamp: 222222222, }), ]; let expected_raw_list1 = r#"{"devices":["device1"],"timestamp":111111111}"#; let expected_raw_list2 = r#"{"devices":["device1","device2"],"timestamp":222222222}"#; let signed_updates = raw_updates .into_iter() .map(SignedDeviceList::try_from) .collect::, _>>() .expect("signing device list updates failed"); assert_eq!(signed_updates[0].raw_device_list, expected_raw_list1); assert_eq!(signed_updates[1].raw_device_list, expected_raw_list2); let stringified_updates = signed_updates .iter() .map(serde_json::to_string) .collect::, _>>() .expect("serialize signed device lists failed"); let expected_stringified_list1 = r#"{"rawDeviceList":"{\"devices\":[\"device1\"],\"timestamp\":111111111}"}"#; let expected_stringified_list2 = r#"{"rawDeviceList":"{\"devices\":[\"device1\",\"device2\"],\"timestamp\":222222222}"}"#; assert_eq!(stringified_updates[0], expected_stringified_list1); assert_eq!(stringified_updates[1], expected_stringified_list2); } #[test] fn deserialize_device_list_update() { let raw_payload = r#"{"rawDeviceList":"{\"devices\":[\"device1\",\"device2\"],\"timestamp\":123456789}"}"#; let request = UpdateDeviceListRequest { new_device_list: raw_payload.to_string(), }; let signed_list = SignedDeviceList::try_from(request) .expect("Failed to parse SignedDeviceList"); let update = DeviceListUpdate::try_from(signed_list) .expect("Failed to parse DeviceListUpdate from signed list"); let expected_timestamp = - DateTime::::from_utc_timestamp_millis(123456789).unwrap(); + DateTime::from_timestamp_millis(123456789).unwrap(); assert_eq!(update.timestamp, expected_timestamp); assert_eq!( update.devices, vec!["device1".to_string(), "device2".to_string()] ); } #[test] fn test_timestamp_validation() { let valid_timestamp = Utc::now() - Duration::milliseconds(100); let previous_timestamp = Utc::now() - Duration::seconds(10); let too_old_timestamp = previous_timestamp - Duration::seconds(1); let expired_timestamp = Utc::now() - Duration::minutes(20); assert!( verify_device_list_timestamp( Some(&previous_timestamp), Some(&valid_timestamp) ) .is_ok(), "Valid timestamp should pass verification" ); 
assert!( verify_device_list_timestamp( Some(&previous_timestamp), Some(&too_old_timestamp) ) .is_err(), "Timestamp older than previous, should fail verification" ); assert!( verify_device_list_timestamp(None, Some(&expired_timestamp)).is_err(), "Expired timestamp should fail verification" ); assert!( verify_device_list_timestamp(None, None).is_ok(), "No provided timestamp should pass" ); } /// helper for mocking DB rows from raw device list payloads fn create_device_list_row(raw_list: RawDeviceList) -> DeviceListRow { DeviceListRow { user_id: "".to_string(), device_ids: raw_list.devices, - timestamp: DateTime::::from_utc_timestamp_millis(raw_list.timestamp) - .unwrap(), + timestamp: DateTime::from_timestamp_millis(raw_list.timestamp).unwrap(), current_primary_signature: None, last_primary_signature: None, } } } diff --git a/services/identity/src/grpc_services/authenticated.rs b/services/identity/src/grpc_services/authenticated.rs index 66febd175..33dd565e5 100644 --- a/services/identity/src/grpc_services/authenticated.rs +++ b/services/identity/src/grpc_services/authenticated.rs @@ -1,746 +1,745 @@ use std::collections::HashMap; use crate::config::CONFIG; use crate::database::DeviceListUpdate; use crate::device_list::SignedDeviceList; use crate::{ client_service::{handle_db_error, UpdateState, WorkflowInProgress}, constants::{error_types, request_metadata}, database::DatabaseClient, - ddb_utils::DateTimeExt, grpc_services::shared::get_value, }; -use chrono::{DateTime, Utc}; +use chrono::DateTime; use comm_opaque2::grpc::protocol_error_to_grpc_status; use tonic::{Request, Response, Status}; use tracing::{debug, error, trace, warn}; use super::protos::auth::{ identity_client_service_server::IdentityClientService, DeletePasswordUserFinishRequest, DeletePasswordUserStartRequest, DeletePasswordUserStartResponse, GetDeviceListRequest, GetDeviceListResponse, InboundKeyInfo, InboundKeysForUserRequest, InboundKeysForUserResponse, KeyserverKeysResponse, LinkFarcasterAccountRequest, OutboundKeyInfo, OutboundKeysForUserRequest, OutboundKeysForUserResponse, PeersDeviceListsRequest, PeersDeviceListsResponse, RefreshUserPrekeysRequest, UpdateDeviceListRequest, UpdateUserPasswordFinishRequest, UpdateUserPasswordStartRequest, UpdateUserPasswordStartResponse, UploadOneTimeKeysRequest, UserIdentitiesRequest, UserIdentitiesResponse, }; use super::protos::unauth::Empty; #[derive(derive_more::Constructor)] pub struct AuthenticatedService { db_client: DatabaseClient, } fn get_auth_info(req: &Request<()>) -> Option<(String, String, String)> { trace!("Retrieving auth info for request: {:?}", req); let user_id = get_value(req, request_metadata::USER_ID)?; let device_id = get_value(req, request_metadata::DEVICE_ID)?; let access_token = get_value(req, request_metadata::ACCESS_TOKEN)?; Some((user_id, device_id, access_token)) } pub fn auth_interceptor( req: Request<()>, db_client: &DatabaseClient, ) -> Result, Status> { trace!("Intercepting request to check auth info: {:?}", req); let (user_id, device_id, access_token) = get_auth_info(&req) .ok_or_else(|| Status::unauthenticated("Missing credentials"))?; let handle = tokio::runtime::Handle::current(); let new_db_client = db_client.clone(); // This function cannot be `async`, yet must call the async db call // Force tokio to resolve future in current thread without an explicit .await let valid_token = tokio::task::block_in_place(move || { handle .block_on(new_db_client.verify_access_token( user_id, device_id, access_token, )) .map_err(handle_db_error) })?; if 
!valid_token { return Err(Status::aborted("Bad Credentials")); } Ok(req) } pub fn get_user_and_device_id( request: &Request, ) -> Result<(String, String), Status> { let user_id = get_value(request, request_metadata::USER_ID) .ok_or_else(|| Status::unauthenticated("Missing user_id field"))?; let device_id = get_value(request, request_metadata::DEVICE_ID) .ok_or_else(|| Status::unauthenticated("Missing device_id field"))?; Ok((user_id, device_id)) } #[tonic::async_trait] impl IdentityClientService for AuthenticatedService { #[tracing::instrument(skip_all)] async fn refresh_user_prekeys( &self, request: Request, ) -> Result, Status> { let (user_id, device_id) = get_user_and_device_id(&request)?; let message = request.into_inner(); debug!("Refreshing prekeys for user: {}", user_id); let content_keys = message .new_content_prekeys .ok_or_else(|| Status::invalid_argument("Missing content keys"))?; let notif_keys = message .new_notif_prekeys .ok_or_else(|| Status::invalid_argument("Missing notification keys"))?; self .db_client .update_device_prekeys( user_id, device_id, content_keys.into(), notif_keys.into(), ) .await .map_err(handle_db_error)?; let response = Response::new(Empty {}); Ok(response) } #[tracing::instrument(skip_all)] async fn get_outbound_keys_for_user( &self, request: tonic::Request, ) -> Result, tonic::Status> { let message = request.into_inner(); let user_id = &message.user_id; let devices_map = self .db_client .get_keys_for_user(user_id, true) .await .map_err(handle_db_error)? .ok_or_else(|| tonic::Status::not_found("user not found"))?; let transformed_devices = devices_map .into_iter() .map(|(key, device_info)| (key, OutboundKeyInfo::from(device_info))) .collect::>(); Ok(tonic::Response::new(OutboundKeysForUserResponse { devices: transformed_devices, })) } #[tracing::instrument(skip_all)] async fn get_inbound_keys_for_user( &self, request: tonic::Request, ) -> Result, tonic::Status> { let message = request.into_inner(); let user_id = &message.user_id; let devices_map = self .db_client .get_keys_for_user(user_id, false) .await .map_err(handle_db_error)? .ok_or_else(|| tonic::Status::not_found("user not found"))?; let transformed_devices = devices_map .into_iter() .map(|(key, device_info)| (key, InboundKeyInfo::from(device_info))) .collect::>(); let identifier = self .db_client .get_user_identity(user_id) .await .map_err(handle_db_error)? .ok_or_else(|| tonic::Status::not_found("user not found"))?; Ok(tonic::Response::new(InboundKeysForUserResponse { devices: transformed_devices, identity: Some(identifier.into()), })) } #[tracing::instrument(skip_all)] async fn get_keyserver_keys( &self, request: Request, ) -> Result, Status> { let message = request.into_inner(); let identifier = self .db_client .get_user_identity(&message.user_id) .await .map_err(handle_db_error)? .ok_or_else(|| tonic::Status::not_found("user not found"))?; let Some(keyserver_info) = self .db_client .get_keyserver_keys_for_user(&message.user_id) .await .map_err(handle_db_error)? 
else { return Err(Status::not_found("keyserver not found")); }; let primary_device_data = self .db_client .get_primary_device_data(&message.user_id) .await .map_err(handle_db_error)?; let primary_device_keys = primary_device_data.device_key_info; let response = Response::new(KeyserverKeysResponse { keyserver_info: Some(keyserver_info.into()), identity: Some(identifier.into()), primary_device_identity_info: Some(primary_device_keys.into()), }); return Ok(response); } #[tracing::instrument(skip_all)] async fn upload_one_time_keys( &self, request: tonic::Request, ) -> Result, tonic::Status> { let (user_id, device_id) = get_user_and_device_id(&request)?; let message = request.into_inner(); debug!("Attempting to update one time keys for user: {}", user_id); self .db_client .append_one_time_prekeys( &user_id, &device_id, &message.content_one_time_prekeys, &message.notif_one_time_prekeys, ) .await .map_err(handle_db_error)?; Ok(tonic::Response::new(Empty {})) } #[tracing::instrument(skip_all)] async fn update_user_password_start( &self, request: tonic::Request, ) -> Result, tonic::Status> { let (user_id, _) = get_user_and_device_id(&request)?; let message = request.into_inner(); let server_registration = comm_opaque2::server::Registration::new(); let server_message = server_registration .start( &CONFIG.server_setup, &message.opaque_registration_request, user_id.as_bytes(), ) .map_err(protocol_error_to_grpc_status)?; let update_state = UpdateState { user_id }; let session_id = self .db_client .insert_workflow(WorkflowInProgress::Update(update_state)) .await .map_err(handle_db_error)?; let response = UpdateUserPasswordStartResponse { session_id, opaque_registration_response: server_message, }; Ok(Response::new(response)) } #[tracing::instrument(skip_all)] async fn update_user_password_finish( &self, request: tonic::Request, ) -> Result, tonic::Status> { let message = request.into_inner(); let Some(WorkflowInProgress::Update(state)) = self .db_client .get_workflow(message.session_id) .await .map_err(handle_db_error)? 
else { return Err(tonic::Status::not_found("session not found")); }; let server_registration = comm_opaque2::server::Registration::new(); let password_file = server_registration .finish(&message.opaque_registration_upload) .map_err(protocol_error_to_grpc_status)?; self .db_client .update_user_password(state.user_id, password_file) .await .map_err(handle_db_error)?; let response = Empty {}; Ok(Response::new(response)) } #[tracing::instrument(skip_all)] async fn log_out_user( &self, request: tonic::Request, ) -> Result, tonic::Status> { let (user_id, device_id) = get_user_and_device_id(&request)?; self .db_client .remove_device(&user_id, &device_id) .await .map_err(handle_db_error)?; self .db_client .delete_otks_table_rows_for_user_device(&user_id, &device_id) .await .map_err(handle_db_error)?; self .db_client .delete_access_token_data(user_id, device_id) .await .map_err(handle_db_error)?; let response = Empty {}; Ok(Response::new(response)) } #[tracing::instrument(skip_all)] async fn log_out_secondary_device( &self, request: tonic::Request, ) -> Result, tonic::Status> { let (user_id, device_id) = get_user_and_device_id(&request)?; debug!( "Secondary device logout request for user_id={}, device_id={}", user_id, device_id ); self .verify_device_on_device_list( &user_id, &device_id, DeviceListItemKind::Secondary, ) .await?; self .db_client .delete_access_token_data(&user_id, &device_id) .await .map_err(handle_db_error)?; self .db_client .remove_device_data(&user_id, &device_id) .await .map_err(handle_db_error)?; self .db_client .delete_otks_table_rows_for_user_device(&user_id, &device_id) .await .map_err(handle_db_error)?; let response = Empty {}; Ok(Response::new(response)) } #[tracing::instrument(skip_all)] async fn delete_wallet_user( &self, request: tonic::Request, ) -> Result, tonic::Status> { let (user_id, _) = get_user_and_device_id(&request)?; debug!("Attempting to delete wallet user: {}", user_id); self .db_client .delete_user(user_id) .await .map_err(handle_db_error)?; let response = Empty {}; Ok(Response::new(response)) } #[tracing::instrument(skip_all)] async fn delete_password_user_start( &self, request: tonic::Request, ) -> Result, tonic::Status> { let (user_id, _) = get_user_and_device_id(&request)?; let message = request.into_inner(); debug!("Attempting to start deleting password user: {}", user_id); let maybe_username_and_password_file = self .db_client .get_username_and_password_file(&user_id) .await .map_err(handle_db_error)?; let Some((username, password_file_bytes)) = maybe_username_and_password_file else { return Err(tonic::Status::not_found("user not found")); }; let mut server_login = comm_opaque2::server::Login::new(); let server_response = server_login .start( &CONFIG.server_setup, &password_file_bytes, &message.opaque_login_request, username.as_bytes(), ) .map_err(protocol_error_to_grpc_status)?; let delete_state = construct_delete_password_user_info(server_login); let session_id = self .db_client .insert_workflow(WorkflowInProgress::PasswordUserDeletion(Box::new( delete_state, ))) .await .map_err(handle_db_error)?; let response = Response::new(DeletePasswordUserStartResponse { session_id, opaque_login_response: server_response, }); Ok(response) } #[tracing::instrument(skip_all)] async fn delete_password_user_finish( &self, request: tonic::Request, ) -> Result, tonic::Status> { let (user_id, _) = get_user_and_device_id(&request)?; let message = request.into_inner(); debug!("Attempting to finish deleting password user: {}", user_id); let 
Some(WorkflowInProgress::PasswordUserDeletion(state)) = self .db_client .get_workflow(message.session_id) .await .map_err(handle_db_error)? else { return Err(tonic::Status::not_found("session not found")); }; let mut server_login = state.opaque_server_login; server_login .finish(&message.opaque_login_upload) .map_err(protocol_error_to_grpc_status)?; self .db_client .delete_user(user_id) .await .map_err(handle_db_error)?; let response = Empty {}; Ok(Response::new(response)) } #[tracing::instrument(skip_all)] async fn get_device_list_for_user( &self, request: tonic::Request, ) -> Result, tonic::Status> { let GetDeviceListRequest { user_id, since_timestamp, } = request.into_inner(); let since = since_timestamp .map(|timestamp| { - DateTime::::from_utc_timestamp_millis(timestamp) + DateTime::from_timestamp_millis(timestamp) .ok_or_else(|| tonic::Status::invalid_argument("Invalid timestamp")) }) .transpose()?; let mut db_result = self .db_client .get_device_list_history(user_id, since) .await .map_err(handle_db_error)?; // these should be sorted already, but just in case db_result.sort_by_key(|list| list.timestamp); let device_list_updates: Vec = db_result .into_iter() .map(SignedDeviceList::try_from) .collect::, _>>()?; let stringified_updates = device_list_updates .iter() .map(SignedDeviceList::as_json_string) .collect::, _>>()?; Ok(Response::new(GetDeviceListResponse { device_list_updates: stringified_updates, })) } #[tracing::instrument(skip_all)] async fn get_device_lists_for_users( &self, request: tonic::Request, ) -> Result, tonic::Status> { let PeersDeviceListsRequest { user_ids } = request.into_inner(); // do all fetches concurrently let mut fetch_tasks = tokio::task::JoinSet::new(); let mut device_lists = HashMap::with_capacity(user_ids.len()); for user_id in user_ids { let db_client = self.db_client.clone(); fetch_tasks.spawn(async move { let result = db_client.get_current_device_list(&user_id).await; (user_id, result) }); } while let Some(task_result) = fetch_tasks.join_next().await { match task_result { Ok((user_id, Ok(Some(device_list_row)))) => { let signed_list = SignedDeviceList::try_from(device_list_row)?; let serialized_list = signed_list.as_json_string()?; device_lists.insert(user_id, serialized_list); } Ok((user_id, Ok(None))) => { warn!(user_id, "User has no device list, skipping!"); } Ok((user_id, Err(err))) => { error!( user_id, errorType = error_types::GRPC_SERVICES_LOG, "Failed fetching device list: {err}" ); // abort fetching other users fetch_tasks.abort_all(); return Err(handle_db_error(err)); } Err(join_error) => { error!( errorType = error_types::GRPC_SERVICES_LOG, "Failed to join device list task: {join_error}" ); fetch_tasks.abort_all(); return Err(Status::aborted("unexpected error")); } } } let response = PeersDeviceListsResponse { users_device_lists: device_lists, }; Ok(Response::new(response)) } #[tracing::instrument(skip_all)] async fn update_device_list( &self, request: tonic::Request, ) -> Result, tonic::Status> { let (user_id, _device_id) = get_user_and_device_id(&request)?; // TODO: when we stop doing "primary device rotation" (migration procedure) // we should verify if this RPC is called by primary device only let new_list = SignedDeviceList::try_from(request.into_inner())?; let update = DeviceListUpdate::try_from(new_list)?; self .db_client .apply_devicelist_update( &user_id, update, crate::device_list::validation::update_device_list_rpc_validator, ) .await .map_err(handle_db_error)?; Ok(Response::new(Empty {})) } #[tracing::instrument(skip_all)] 
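// A hedged, self-contained reduction of the fan-out pattern used by
// `get_device_lists_for_users` above: spawn one task per key on a JoinSet,
// collect `(key, result)` pairs as they finish, and abort the remaining
// tasks on the first failure. Types are simplified stand-ins.
use std::collections::HashMap;
use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
  let mut fetch_tasks: JoinSet<(u32, Result<String, ()>)> = JoinSet::new();
  for user_id in [1u32, 2, 3] {
    fetch_tasks.spawn(async move { (user_id, Ok(format!("list-{user_id}"))) });
  }
  let mut device_lists = HashMap::new();
  while let Some(task_result) = fetch_tasks.join_next().await {
    match task_result {
      Ok((user_id, Ok(list))) => {
        device_lists.insert(user_id, list);
      }
      Ok((_, Err(()))) | Err(_) => {
        fetch_tasks.abort_all(); // abort fetching other users
        return;
      }
    }
  }
  assert_eq!(device_lists.len(), 3);
}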
async fn link_farcaster_account( &self, request: tonic::Request, ) -> Result, tonic::Status> { let (user_id, _) = get_user_and_device_id(&request)?; let message = request.into_inner(); let mut get_farcaster_users_response = self .db_client .get_farcaster_users(vec![message.farcaster_id.clone()]) .await .map_err(handle_db_error)?; if get_farcaster_users_response.len() > 1 { error!( errorType = error_types::GRPC_SERVICES_LOG, "multiple users associated with the same Farcaster ID" ); return Err(Status::failed_precondition("cannot link Farcaster ID")); } if let Some(u) = get_farcaster_users_response.pop() { if u.0.user_id == user_id { return Ok(Response::new(Empty {})); } else { return Err(Status::already_exists( "farcaster ID already associated with different user", )); } } self .db_client .add_farcaster_id(user_id, message.farcaster_id) .await .map_err(handle_db_error)?; let response = Empty {}; Ok(Response::new(response)) } #[tracing::instrument(skip_all)] async fn unlink_farcaster_account( &self, request: tonic::Request, ) -> Result, tonic::Status> { let (user_id, _) = get_user_and_device_id(&request)?; self .db_client .remove_farcaster_id(user_id) .await .map_err(handle_db_error)?; let response = Empty {}; Ok(Response::new(response)) } #[tracing::instrument(skip_all)] async fn find_user_identities( &self, request: tonic::Request, ) -> Result, tonic::Status> { let message = request.into_inner(); let results = self .db_client .find_db_user_identities(message.user_ids) .await .map_err(handle_db_error)?; let mapped_results = results .into_iter() .map(|(user_id, identifier)| (user_id, identifier.into())) .collect(); let response = UserIdentitiesResponse { identities: mapped_results, }; return Ok(Response::new(response)); } } enum DeviceListItemKind { Any, Primary, Secondary, } impl AuthenticatedService { async fn verify_device_on_device_list( &self, user_id: &String, device_id: &String, device_kind: DeviceListItemKind, ) -> Result<(), tonic::Status> { let device_list = self .db_client .get_current_device_list(user_id) .await .map_err(|err| { error!( user_id, errorType = error_types::GRPC_SERVICES_LOG, "Failed fetching device list: {err}" ); handle_db_error(err) })?; let Some(device_list) = device_list else { error!( user_id, errorType = error_types::GRPC_SERVICES_LOG, "User has no device list!" 
); return Err(Status::failed_precondition("no device list")); }; use DeviceListItemKind as DeviceKind; let device_on_list = match device_kind { DeviceKind::Any => device_list.has_device(device_id), DeviceKind::Primary => device_list.is_primary_device(device_id), DeviceKind::Secondary => device_list.has_secondary_device(device_id), }; if !device_on_list { debug!( "Device {} not on device list for user {}", device_id, user_id ); return Err(Status::permission_denied("device not on device list")); } Ok(()) } } #[derive(Clone, serde::Serialize, serde::Deserialize)] pub struct DeletePasswordUserInfo { pub opaque_server_login: comm_opaque2::server::Login, } fn construct_delete_password_user_info( opaque_server_login: comm_opaque2::server::Login, ) -> DeletePasswordUserInfo { DeletePasswordUserInfo { opaque_server_login, } } diff --git a/shared/comm-lib/src/auth/service.rs b/shared/comm-lib/src/auth/service.rs index 18127e370..7d0c79942 100644 --- a/shared/comm-lib/src/auth/service.rs +++ b/shared/comm-lib/src/auth/service.rs @@ -1,165 +1,164 @@ use aws_sdk_secretsmanager::Client as SecretsManagerClient; -use chrono::{DateTime, Duration, NaiveDateTime, Utc}; +use chrono::{DateTime, Duration, Utc}; use grpc_clients::identity::unauthenticated::client as identity_client; use super::{AuthorizationCredential, ServicesAuthToken, UserIdentity}; const SECRET_NAME: &str = "servicesToken"; /// duration for which we consider previous token valid /// after rotation const ROTATION_PROTECTION_PERIOD: i64 = 3; // seconds // AWS managed version tags for secrets const AWSCURRENT: &str = "AWSCURRENT"; const AWSPREVIOUS: &str = "AWSPREVIOUS"; // Identity service gRPC clients require a code version and device type. // We can supply some placeholder values for services for the time being, since // this metadata is only relevant for devices. const PLACEHOLDER_CODE_VERSION: u64 = 0; const DEVICE_TYPE: &str = "service"; #[derive( Debug, derive_more::Display, derive_more::Error, derive_more::From, )] pub enum AuthServiceError { SecretManagerError(aws_sdk_secretsmanager::Error), GrpcError(grpc_clients::error::Error), Unexpected, } type AuthServiceResult = Result; /// This service is responsible for handling request authentication. /// For HTTP services, it should be added as app data to the server: /// ```ignore /// let auth_service = AuthService::new(&aws_config, &config.identity_endpoint); /// let auth_middleware = get_comm_authentication_middleware(); /// App::new() /// .app_data(auth_service.clone()) /// .wrap(auth_middleware) /// // ... /// ``` #[derive(Clone)] pub struct AuthService { secrets_manager: SecretsManagerClient, identity_service_url: String, } impl AuthService { pub fn new( aws_cfg: &aws_config::SdkConfig, identity_service_url: impl Into, ) -> Self { let secrets_client = SecretsManagerClient::new(aws_cfg); AuthService { secrets_manager: secrets_client, identity_service_url: identity_service_url.into(), } } /// Obtains a service-to-service token which can be used to authenticate /// when calling other services endpoints. It should be only used when /// no [`UserIdentity`] is provided from client pub async fn get_services_token( &self, ) -> AuthServiceResult { get_services_token_version(&self.secrets_manager, AWSCURRENT) .await .map_err(AuthServiceError::from) } /// Verifies the provided [`AuthorizationCredential`]. Returns `true` if /// authentication was successful. 
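// A hedged sketch of the rotation-protection window checked by
// `verify_services_token` below: a token that fails the constant-time
// comparison is re-checked against the AWSPREVIOUS version only if the
// secret was rotated within the last ROTATION_PROTECTION_PERIOD seconds.
use chrono::{DateTime, Duration, Utc};

fn was_recently_rotated(last_rotated_millis: i64) -> bool {
  DateTime::from_timestamp_millis(last_rotated_millis)
    .map(|last_rotated| Utc::now().signed_duration_since(last_rotated))
    .filter(|elapsed| *elapsed < Duration::seconds(3))
    .is_some()
}

fn main() {
  let now_millis = Utc::now().timestamp_millis();
  assert!(was_recently_rotated(now_millis));
  assert!(!was_recently_rotated(now_millis - 10_000)); // rotated 10s ago
}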
pub async fn verify_auth_credential( &self, credential: &AuthorizationCredential, ) -> AuthServiceResult { match credential { AuthorizationCredential::UserToken(user) => { let UserIdentity { user_id, device_id, access_token, } = user; identity_client::verify_user_access_token( &self.identity_service_url, user_id, device_id, access_token, PLACEHOLDER_CODE_VERSION, DEVICE_TYPE.to_string(), ) .await .map_err(AuthServiceError::from) } AuthorizationCredential::ServicesToken(token) => { verify_services_token(&self.secrets_manager, token) .await .map_err(AuthServiceError::from) } } } } async fn get_services_token_version( client: &SecretsManagerClient, version: impl Into, ) -> Result { let result = client .get_secret_value() .secret_id(SECRET_NAME) .version_stage(version) .send() .await?; let token = result .secret_string() .expect("Services token secret is not a string. This should not happen"); Ok(ServicesAuthToken::new(token.to_string())) } async fn time_since_rotation( client: &SecretsManagerClient, ) -> Result, aws_sdk_secretsmanager::Error> { let result = client .describe_secret() .secret_id(SECRET_NAME) .send() .await?; let duration = result .last_rotated_date() .and_then(|date| date.to_millis().ok()) - .and_then(NaiveDateTime::from_timestamp_millis) - .map(|naive| DateTime::::from_naive_utc_and_offset(naive, Utc)) + .and_then(DateTime::from_timestamp_millis) .map(|last_rotated| Utc::now().signed_duration_since(last_rotated)); Ok(duration) } async fn verify_services_token( client: &SecretsManagerClient, token_to_verify: &ServicesAuthToken, ) -> Result { let actual_token = get_services_token_version(client, AWSCURRENT).await?; // we need to always get it to achieve constant time eq let last_rotated = time_since_rotation(client).await?; let was_recently_rotated = last_rotated .filter(|rotation_time| { *rotation_time < Duration::seconds(ROTATION_PROTECTION_PERIOD) }) .is_some(); let is_valid = *token_to_verify == actual_token; // token might have just been rotated. In this case check the previous token // this case makes the function non-constant time, but it happens very rarely if !is_valid && was_recently_rotated { let previous_token = get_services_token_version(client, AWSPREVIOUS).await?; let previous_valid = *token_to_verify == previous_token; return Ok(previous_valid); } Ok(is_valid) } diff --git a/shared/comm-lib/src/database.rs b/shared/comm-lib/src/database.rs index 76e28dd30..36dfbc4b3 100644 --- a/shared/comm-lib/src/database.rs +++ b/shared/comm-lib/src/database.rs @@ -1,860 +1,858 @@ use aws_sdk_dynamodb::types::AttributeValue; pub use aws_sdk_dynamodb::Error as DynamoDBError; use chrono::{DateTime, Utc}; use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::num::ParseIntError; use std::str::FromStr; #[cfg(feature = "blob-client")] pub mod blob; // # Useful type aliases // Rust exports `pub type` only into the so-called "type namespace", but in // order to use them e.g. with the `TryFromAttribute` trait, they also need // to be exported into the "value namespace" which is what `pub use` does. 
// // To overcome that, a dummy module is created and aliases are re-exported // with `pub use` construct mod aliases { use aws_sdk_dynamodb::types::AttributeValue; use std::collections::HashMap; pub type AttributeMap = HashMap; } pub use self::aliases::AttributeMap; // # Error handling #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum Error { #[display(...)] AwsSdk(DynamoDBError), #[display(...)] Attribute(DBItemError), #[display(fmt = "Maximum retries exceeded")] MaxRetriesExceeded, } #[derive(Debug, derive_more::From)] pub enum Value { AttributeValue(Option), String(String), } #[derive(Debug, derive_more::Error, derive_more::Constructor)] pub struct DBItemError { pub attribute_name: String, pub attribute_value: Value, pub attribute_error: DBItemAttributeError, } impl Display for DBItemError { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match &self.attribute_error { DBItemAttributeError::Missing => { write!(f, "Attribute {} is missing", self.attribute_name) } DBItemAttributeError::IncorrectType => write!( f, "Value for attribute {} has incorrect type: {:?}", self.attribute_name, self.attribute_value ), error => write!( f, "Error regarding attribute {} with value {:?}: {}", self.attribute_name, self.attribute_value, error ), } } } #[derive(Debug, derive_more::Display, derive_more::Error)] pub enum DBItemAttributeError { #[display(...)] Missing, #[display(...)] IncorrectType, #[display(...)] TimestampOutOfRange, #[display(...)] InvalidTimestamp(chrono::ParseError), #[display(...)] InvalidNumberFormat(ParseIntError), #[display(...)] ExpiredTimestamp, #[display(...)] InvalidValue, } /// Conversion trait for [`AttributeValue`] /// /// Types implementing this trait are able to do the following: /// ```ignore /// use comm_lib::database::{TryFromAttribute, AttributeTryInto}; /// /// let foo = SomeType::try_from_attr("MyAttribute", Some(attribute))?; /// /// // if `AttributeTryInto` is imported, also: /// let bar = Some(attribute).attr_try_into("MyAttribute")?; /// ``` pub trait TryFromAttribute: Sized { fn try_from_attr( attribute_name: impl Into, attribute: Option, ) -> Result; } /// Do NOT implement this trait directly. 
Implement [`TryFromAttribute`] instead pub trait AttributeTryInto { fn attr_try_into( self, attribute_name: impl Into, ) -> Result; } // Automatic attr_try_into() for all attribute values // that have TryFromAttribute implemented impl AttributeTryInto for Option { fn attr_try_into( self, attribute_name: impl Into, ) -> Result { T::try_from_attr(attribute_name, self) } } /// Helper trait for extracting attributes from a collection pub trait AttributeExtractor { /// Gets an attribute from the map and tries to convert it to the given type /// This method does not consume the raw attribute - it gets cloned /// See [`AttributeExtractor::take_attr`] for a non-cloning method fn get_attr( &self, attribute_name: &str, ) -> Result; /// Takes an attribute from the map and tries to convert it to the given type /// This method consumes the raw attribute - it gets removed from the map /// See [`AttributeExtractor::get_attr`] for a non-mutating method fn take_attr( &mut self, attribute_name: &str, ) -> Result; } impl AttributeExtractor for AttributeMap { fn get_attr( &self, attribute_name: &str, ) -> Result { T::try_from_attr(attribute_name, self.get(attribute_name).cloned()) } fn take_attr( &mut self, attribute_name: &str, ) -> Result { T::try_from_attr(attribute_name, self.remove(attribute_name)) } } // this allows us to get optional attributes impl TryFromAttribute for Option where T: TryFromAttribute, { fn try_from_attr( attribute_name: impl Into, attribute: Option, ) -> Result { if attribute.is_none() { return Ok(None); } match T::try_from_attr(attribute_name, attribute) { Ok(value) => Ok(Some(value)), Err(DBItemError { attribute_error: DBItemAttributeError::Missing, .. }) => Ok(None), Err(error) => Err(error), } } } impl TryFromAttribute for String { fn try_from_attr( attribute_name: impl Into, attribute_value: Option, ) -> Result { match attribute_value { Some(AttributeValue::S(value)) => Ok(value), Some(_) => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute_value), DBItemAttributeError::IncorrectType, )), None => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute_value), DBItemAttributeError::Missing, )), } } } impl TryFromAttribute for bool { fn try_from_attr( attribute_name: impl Into, attribute_value: Option, ) -> Result { match attribute_value { Some(AttributeValue::Bool(value)) => Ok(value), Some(_) => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute_value), DBItemAttributeError::IncorrectType, )), None => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute_value), DBItemAttributeError::Missing, )), } } } impl TryFromAttribute for DateTime { fn try_from_attr( attribute_name: impl Into, attribute: Option, ) -> Result { match &attribute { Some(AttributeValue::S(datetime)) => datetime.parse().map_err(|e| { DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute), DBItemAttributeError::InvalidTimestamp(e), ) }), Some(_) => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute), DBItemAttributeError::IncorrectType, )), None => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute), DBItemAttributeError::Missing, )), } } } impl TryFromAttribute for AttributeMap { fn try_from_attr( attribute_name: impl Into, attribute_value: Option, ) -> Result { match attribute_value { Some(AttributeValue::M(map)) => Ok(map), Some(_) => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute_value), 
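// A hedged example of extending the conversion machinery above with a
// custom type. `UserId` is hypothetical; the impl just delegates to the
// String conversion and wraps the result, which is enough to make
// `map.take_attr::<UserId>("userID")` work. Assumes this lives in the same
// module as `TryFromAttribute` and `DBItemError`.
struct UserId(String);

impl TryFromAttribute for UserId {
  fn try_from_attr(
    attribute_name: impl Into<String>,
    attribute: Option<AttributeValue>,
  ) -> Result<Self, DBItemError> {
    String::try_from_attr(attribute_name, attribute).map(UserId)
  }
}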
DBItemAttributeError::IncorrectType, )), None => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute_value), DBItemAttributeError::Missing, )), } } } impl TryFromAttribute for Vec { fn try_from_attr( attribute_name: impl Into, attribute_value: Option, ) -> Result { match attribute_value { Some(AttributeValue::B(data)) => Ok(data.into_inner()), Some(_) => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute_value), DBItemAttributeError::IncorrectType, )), None => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute_value), DBItemAttributeError::Missing, )), } } } impl TryFromAttribute for HashSet { fn try_from_attr( attribute_name: impl Into, attribute_value: Option, ) -> Result { match attribute_value { Some(AttributeValue::Ss(set)) => Ok(set.into_iter().collect()), Some(_) => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute_value), DBItemAttributeError::IncorrectType, )), None => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute_value), DBItemAttributeError::Missing, )), } } } impl TryFromAttribute for Vec { fn try_from_attr( attribute_name: impl Into, attribute: Option, ) -> Result { let attribute_name = attribute_name.into(); match attribute { Some(AttributeValue::L(list)) => Ok( list .into_iter() .map(|attribute| { T::try_from_attr(format!("{attribute_name}[i]"), Some(attribute)) }) .collect::, _>>()?, ), Some(_) => Err(DBItemError::new( attribute_name, Value::AttributeValue(attribute), DBItemAttributeError::IncorrectType, )), None => Err(DBItemError::new( attribute_name, Value::AttributeValue(attribute), DBItemAttributeError::Missing, )), } } } #[deprecated = "Use `String::try_from_attr()` instead"] pub fn parse_string_attribute( attribute_name: impl Into, attribute_value: Option, ) -> Result { String::try_from_attr(attribute_name, attribute_value) } #[deprecated = "Use `bool::try_from_attr()` instead"] pub fn parse_bool_attribute( attribute_name: impl Into, attribute_value: Option, ) -> Result { bool::try_from_attr(attribute_name, attribute_value) } #[deprecated = "Use `DateTime::::try_from_attr()` instead"] pub fn parse_datetime_attribute( attribute_name: impl Into, attribute_value: Option, ) -> Result, DBItemError> { DateTime::::try_from_attr(attribute_name, attribute_value) } #[deprecated = "Use `AttributeMap::try_from_attr()` instead"] pub fn parse_map_attribute( attribute_name: impl Into, attribute_value: Option, ) -> Result { attribute_value.attr_try_into(attribute_name) } pub fn parse_int_attribute( attribute_name: impl Into, attribute_value: Option, ) -> Result where T: FromStr, { match &attribute_value { Some(AttributeValue::N(numeric_str)) => { parse_integer(attribute_name, numeric_str) } Some(_) => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute_value), DBItemAttributeError::IncorrectType, )), None => Err(DBItemError::new( attribute_name.into(), Value::AttributeValue(attribute_value), DBItemAttributeError::Missing, )), } } /// Parses the UTC timestamp in milliseconds from a DynamoDB numeric attribute pub fn parse_timestamp_attribute( attribute_name: impl Into, attribute_value: Option, ) -> Result, DBItemError> { let attribute_name: String = attribute_name.into(); let timestamp = parse_int_attribute::( attribute_name.clone(), attribute_value.clone(), )?; - let naive_datetime = chrono::NaiveDateTime::from_timestamp_millis(timestamp) - .ok_or_else(|| { - DBItemError::new( - attribute_name, - 
Value::AttributeValue(attribute_value), - DBItemAttributeError::TimestampOutOfRange, - ) - })?; - Ok(DateTime::from_naive_utc_and_offset(naive_datetime, Utc)) + chrono::DateTime::from_timestamp_millis(timestamp).ok_or_else(|| { + DBItemError::new( + attribute_name, + Value::AttributeValue(attribute_value), + DBItemAttributeError::TimestampOutOfRange, + ) + }) } pub fn parse_integer( attribute_name: impl Into, attribute_value: &str, ) -> Result where T: FromStr, { attribute_value.parse::().map_err(|e| { DBItemError::new( attribute_name.into(), Value::String(attribute_value.into()), DBItemAttributeError::InvalidNumberFormat(e), ) }) } pub mod batch_operations { use aws_sdk_dynamodb::{ error::SdkError, operation::batch_write_item::BatchWriteItemError, types::{KeysAndAttributes, WriteRequest}, Error as DynamoDBError, }; use rand::Rng; use std::time::Duration; use tracing::{debug, trace}; use super::AttributeMap; /// DynamoDB hard limit for single BatchWriteItem request const SINGLE_BATCH_WRITE_ITEM_LIMIT: usize = 25; const SINGLE_BATCH_GET_ITEM_LIMIT: usize = 100; /// Exponential backoff configuration for batch write operation #[derive(derive_more::Constructor, Debug)] pub struct ExponentialBackoffConfig { /// Maximum retry attempts before the function fails. /// Set this to 0 to disable exponential backoff. /// Defaults to **8**. pub max_attempts: u32, /// Base wait duration before retry. Defaults to **25ms**. /// It is doubled with each attempt: 25ms, 50, 100, 200... pub base_duration: Duration, /// Jitter factor for retry delay. Factor 0.5 for 100ms delay /// means that wait time will be between 50ms and 150ms. /// The value must be in range 0.0 - 1.0. It will be clamped /// if out of these bounds. Defaults to **0.3** pub jitter_factor: f32, /// Retry on [`ProvisionedThroughputExceededException`]. /// Defaults to **true**. /// /// [`ProvisionedThroughputExceededException`]: aws_sdk_dynamodb::Error::ProvisionedThroughputExceededException pub retry_on_provisioned_capacity_exceeded: bool, } impl Default for ExponentialBackoffConfig { fn default() -> Self { ExponentialBackoffConfig { max_attempts: 8, base_duration: Duration::from_millis(25), jitter_factor: 0.3, retry_on_provisioned_capacity_exceeded: true, } } } impl ExponentialBackoffConfig { fn new_counter(&self) -> ExponentialBackoffHelper { ExponentialBackoffHelper::new(self) } fn backoff_enabled(&self) -> bool { self.max_attempts > 0 } fn should_retry_on_capacity_exceeded(&self) -> bool { self.backoff_enabled() && self.retry_on_provisioned_capacity_exceeded } } #[tracing::instrument(name = "batch_get", skip(ddb, primary_keys, config))] pub async fn batch_get( ddb: &aws_sdk_dynamodb::Client, table_name: &str, primary_keys: K, projection_expression: Option, config: ExponentialBackoffConfig, ) -> Result, super::Error> where K: IntoIterator, K::Item: Into, { let mut primary_keys: Vec<_> = primary_keys.into_iter().map(Into::into).collect(); let mut results = Vec::with_capacity(primary_keys.len()); tracing::debug!( ?config, "Starting batch read operation of {} items...", primary_keys.len() ); let mut exponential_backoff = config.new_counter(); let mut backup = Vec::with_capacity(SINGLE_BATCH_GET_ITEM_LIMIT); loop { let items_to_drain = std::cmp::min(primary_keys.len(), SINGLE_BATCH_GET_ITEM_LIMIT); let chunk = primary_keys.drain(..items_to_drain).collect::>(); if chunk.is_empty() { // No more items tracing::trace!("No more items to process. 
Exiting"); break; } // we don't need the backup when we don't retry if config.should_retry_on_capacity_exceeded() { chunk.clone_into(&mut backup); } tracing::trace!("Attempting to get chunk of {} items...", chunk.len()); let result = ddb .batch_get_item() .request_items( table_name, KeysAndAttributes::builder() .set_keys(Some(chunk)) .consistent_read(true) .set_projection_expression(projection_expression.clone()) .build(), ) .send() .await; match result { Ok(output) => { if let Some(mut responses) = output.responses { if let Some(items) = responses.remove(table_name) { tracing::trace!("Successfully read {} items", items.len()); results.extend(items); } } else { tracing::warn!("Responses was None"); } if let Some(mut unprocessed) = output.unprocessed_keys { let keys_to_retry = match unprocessed.remove(table_name) { Some(KeysAndAttributes { keys: Some(keys), .. }) if !keys.is_empty() => keys, _ => { tracing::trace!("Chunk read successfully. Continuing."); exponential_backoff.reset(); continue; } }; exponential_backoff.sleep_and_retry().await?; tracing::debug!( "Some items failed. Retrying {} requests", keys_to_retry.len() ); primary_keys.extend(keys_to_retry); } else { tracing::trace!("Unprocessed items was None"); } } Err(error) => { let error: DynamoDBError = error.into(); if !matches!( error, DynamoDBError::ProvisionedThroughputExceededException(_) ) { tracing::error!("BatchGetItem failed: {0:?} - {0}", error); return Err(error.into()); } tracing::warn!("Provisioned capacity exceeded!"); if !config.retry_on_provisioned_capacity_exceeded { return Err(error.into()); } exponential_backoff.sleep_and_retry().await?; primary_keys.append(&mut backup); trace!("Retrying now..."); } }; } debug!("Batch read completed."); Ok(results) } /// Performs a single DynamoDB table batch write operation. If the batch /// contains more than 25 items, it is split into chunks. /// /// The function uses exponential backoff retries when AWS throttles /// the request or maximum provisioned capacity is exceeded #[tracing::instrument(name = "batch_write", skip(ddb, requests, config))] pub async fn batch_write( ddb: &aws_sdk_dynamodb::Client, table_name: &str, mut requests: Vec, config: ExponentialBackoffConfig, ) -> Result<(), super::Error> { tracing::debug!( ?config, "Starting batch write operation of {} items...", requests.len() ); let mut exponential_backoff = config.new_counter(); let mut backup = Vec::with_capacity(SINGLE_BATCH_WRITE_ITEM_LIMIT); loop { let items_to_drain = std::cmp::min(requests.len(), SINGLE_BATCH_WRITE_ITEM_LIMIT); let chunk = requests.drain(..items_to_drain).collect::>(); if chunk.is_empty() { // No more items tracing::trace!("No more items to process. Exiting"); break; } // we don't need the backup when we don't retry if config.should_retry_on_capacity_exceeded() { chunk.clone_into(&mut backup); } tracing::trace!("Attempting to write chunk of {} items...", chunk.len()); let result = ddb .batch_write_item() .request_items(table_name, chunk) .send() .await; match result { Ok(output) => { if let Some(mut items) = output.unprocessed_items { let requests_to_retry = items.remove(table_name).unwrap_or_default(); if requests_to_retry.is_empty() { tracing::trace!("Chunk written successfully. Continuing."); exponential_backoff.reset(); continue; } exponential_backoff.sleep_and_retry().await?; tracing::debug!( "Some items failed. 
Retrying {} requests", requests_to_retry.len() ); requests.extend(requests_to_retry); } else { tracing::trace!("Unprocessed items was None"); } } Err(error) => { if !is_provisioned_capacity_exceeded(&error) { tracing::error!("BatchWriteItem failed: {0:?} - {0}", error); return Err(super::Error::AwsSdk(error.into())); } tracing::warn!("Provisioned capacity exceeded!"); if !config.retry_on_provisioned_capacity_exceeded { return Err(super::Error::AwsSdk(error.into())); } exponential_backoff.sleep_and_retry().await?; requests.append(&mut backup); trace!("Retrying now..."); } }; } debug!("Batch write completed."); Ok(()) } /// internal helper struct struct ExponentialBackoffHelper<'cfg> { config: &'cfg ExponentialBackoffConfig, attempt: u32, } impl<'cfg> ExponentialBackoffHelper<'cfg> { fn new(config: &'cfg ExponentialBackoffConfig) -> Self { ExponentialBackoffHelper { config, attempt: 0 } } /// reset counter after successfull operation fn reset(&mut self) { self.attempt = 0; } /// increase counter and sleep in case of failure async fn sleep_and_retry(&mut self) -> Result<(), super::Error> { let jitter_factor = 1f32.min(0f32.max(self.config.jitter_factor)); let random_multiplier = 1.0 + rand::thread_rng().gen_range(-jitter_factor..=jitter_factor); let backoff_multiplier = 2u32.pow(self.attempt); let base_duration = self.config.base_duration * backoff_multiplier; let sleep_duration = base_duration.mul_f32(random_multiplier); self.attempt += 1; if self.attempt > self.config.max_attempts { tracing::warn!("Retry limit exceeded!"); return Err(super::Error::MaxRetriesExceeded); } tracing::debug!( attempt = self.attempt, "Batch failed. Sleeping for {}ms before retrying...", sleep_duration.as_millis() ); tokio::time::sleep(sleep_duration).await; Ok(()) } } /// Check if transaction failed due to /// `ProvisionedThroughputExceededException` exception fn is_provisioned_capacity_exceeded( err: &SdkError, ) -> bool { let SdkError::ServiceError(service_error) = err else { return false; }; matches!( service_error.err(), BatchWriteItemError::ProvisionedThroughputExceededException(_) ) } } #[derive(Debug, Clone, Copy, derive_more::Display, derive_more::Error)] pub struct UnknownAttributeTypeError; fn calculate_attr_value_size_in_db( value: &AttributeValue, ) -> Result { const ELEMENT_BYTE_OVERHEAD: usize = 1; const CONTAINER_BYTE_OVERHEAD: usize = 3; /// AWS doesn't provide an exact algorithm for calculating number size in bytes /// in case they change the internal representation. We know that number can use /// between 2 and 21 bytes so we use the maximum value as the byte size. const NUMBER_BYTE_SIZE: usize = 21; let result = match value { AttributeValue::B(blob) => blob.as_ref().len(), AttributeValue::L(list) => { CONTAINER_BYTE_OVERHEAD + list.len() * ELEMENT_BYTE_OVERHEAD + list .iter() .try_fold(0, |a, v| Ok(a + calculate_attr_value_size_in_db(v)?))? } AttributeValue::M(map) => { CONTAINER_BYTE_OVERHEAD + map.len() * ELEMENT_BYTE_OVERHEAD + calculate_size_in_db(map)? 
#[derive(Debug, Clone, Copy, derive_more::Display, derive_more::Error)]
pub struct UnknownAttributeTypeError;

fn calculate_attr_value_size_in_db(
  value: &AttributeValue,
) -> Result<usize, UnknownAttributeTypeError> {
  const ELEMENT_BYTE_OVERHEAD: usize = 1;
  const CONTAINER_BYTE_OVERHEAD: usize = 3;
  /// AWS doesn't provide an exact algorithm for calculating number size in
  /// bytes, in case they change the internal representation. We know that a
  /// number can use between 2 and 21 bytes, so we use the maximum value as
  /// the byte size.
  const NUMBER_BYTE_SIZE: usize = 21;

  let result = match value {
    AttributeValue::B(blob) => blob.as_ref().len(),
    AttributeValue::L(list) => {
      CONTAINER_BYTE_OVERHEAD
        + list.len() * ELEMENT_BYTE_OVERHEAD
        + list
          .iter()
          .try_fold(0, |a, v| Ok(a + calculate_attr_value_size_in_db(v)?))?
    }
    AttributeValue::M(map) => {
      CONTAINER_BYTE_OVERHEAD
        + map.len() * ELEMENT_BYTE_OVERHEAD
        + calculate_size_in_db(map)?
    }
    AttributeValue::Bool(_) | AttributeValue::Null(_) => 1,
    AttributeValue::Bs(set) => set.len(),
    AttributeValue::N(_) => NUMBER_BYTE_SIZE,
    AttributeValue::Ns(set) => set.len() * NUMBER_BYTE_SIZE,
    AttributeValue::S(string) => string.as_bytes().len(),
    AttributeValue::Ss(set) => {
      set.iter().map(|string| string.as_bytes().len()).sum()
    }
    _ => return Err(UnknownAttributeTypeError),
  };
  Ok(result)
}

pub fn calculate_size_in_db(
  value: &AttributeMap,
) -> Result<usize, UnknownAttributeTypeError> {
  value.iter().try_fold(0, |a, (attr, value)| {
    Ok(a + attr.as_bytes().len() + calculate_attr_value_size_in_db(value)?)
  })
}

#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_parse_integer() {
    assert!(parse_integer::<usize>("some_attr", "123").is_ok());
    assert!(parse_integer::<i32>("negative", "-123").is_ok());
    assert!(parse_integer::<i32>("float", "3.14").is_err());
    assert!(parse_integer::<i32>("NaN", "foo").is_err());
    assert!(parse_integer::<u32>("negative_uint", "-123").is_err());
    assert!(parse_integer::<u16>("too_large", "65536").is_err());
  }

  #[test]
  fn test_parse_timestamp() {
    let timestamp = Utc::now().timestamp_millis();
    let attr = AttributeValue::N(timestamp.to_string());
    let parsed_timestamp = parse_timestamp_attribute("some_attr", Some(attr));
    assert!(parsed_timestamp.is_ok());
    assert_eq!(parsed_timestamp.unwrap().timestamp_millis(), timestamp);
  }

  #[test]
  fn test_parse_invalid_timestamp() {
    let attr = AttributeValue::N("foo".to_string());
    let parsed_timestamp = parse_timestamp_attribute("some_attr", Some(attr));
    assert!(parsed_timestamp.is_err());
  }

  #[test]
  fn test_parse_timestamp_out_of_range() {
    let attr = AttributeValue::N(i64::MAX.to_string());
    let parsed_timestamp = parse_timestamp_attribute("some_attr", Some(attr));
    assert!(parsed_timestamp.is_err());
    assert!(matches!(
      parsed_timestamp.unwrap_err().attribute_error,
      DBItemAttributeError::TimestampOutOfRange
    ));
  }

  #[test]
  fn test_optional_attribute() {
    let mut attrs = AttributeMap::from([(
      "foo".to_string(),
      AttributeValue::S("bar".to_string()),
    )]);

    let foo: Option<String> =
      attrs.take_attr("foo").expect("failed to parse arg 'foo'");
    let bar: Option<String> =
      attrs.take_attr("bar").expect("failed to parse arg 'bar'");

    assert!(foo.is_some());
    assert!(bar.is_none());
  }
}
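// Added example (not part of the original change): how the item-size estimate
// combines attribute-name bytes with value sizes. The 21-byte figure is the
// worst-case number size documented above; the attribute names are made up.
#[cfg(test)]
mod size_calculation_examples {
  use super::*;

  #[test]
  fn map_size_estimate() {
    let attrs = AttributeMap::from([
      // 2 bytes of attribute name + 5 bytes of string payload
      ("id".to_string(), AttributeValue::S("alice".to_string())),
      // 5 bytes of attribute name + 21-byte worst-case number estimate
      ("count".to_string(), AttributeValue::N("42".to_string())),
    ]);
    let size = calculate_size_in_db(&attrs).expect("known attribute types");
    assert_eq!(size, 2 + 5 + 5 + 21);
  }
}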