diff --git a/services/identity/src/device_list.rs b/services/identity/src/device_list.rs
--- a/services/identity/src/device_list.rs
+++ b/services/identity/src/device_list.rs
@@ -50,16 +50,152 @@
   Ok(())
 }
 
-/// Returns `true` if `new_device_list` contains exactly one more new device
-/// compared to `previous_device_list`
-pub fn is_device_added(
-  previous_device_list: &[&str],
-  new_device_list: &[&str],
-) -> bool {
-  let previous_set: HashSet<_> = previous_device_list.iter().collect();
-  let new_set: HashSet<_> = new_device_list.iter().collect();
-
-  return new_set.difference(&previous_set).count() == 1;
+pub mod validation {
+  use super::*;
+
+  /// Returns the number of devices present only in `new_device_list`
+  /// (added) and the number present only in `previous_device_list`
+  /// (removed), as an `(added, removed)` pair. Order and repeated IDs
+  /// are ignored because the comparison is done on sets.
+  fn count_device_changes(
+    previous_device_list: &[&str],
+    new_device_list: &[&str],
+  ) -> (usize, usize) {
+    let previous_set: HashSet<_> = previous_device_list.iter().collect();
+    let new_set: HashSet<_> = new_device_list.iter().collect();
+
+    let added = new_set.difference(&previous_set).count();
+    let removed = previous_set.difference(&new_set).count();
+    (added, removed)
+  }
+
+  /// Returns `true` if `new_device_list` contains exactly one device that is
+  /// not in `previous_device_list` and no device has been removed
+  fn is_device_added(
+    previous_device_list: &[&str],
+    new_device_list: &[&str],
+  ) -> bool {
+    count_device_changes(previous_device_list, new_device_list) == (1, 0)
+  }
+
+  /// Returns `true` if exactly one device from `previous_device_list` is
+  /// missing in `new_device_list` and no device has been added
+  fn is_device_removed(
+    previous_device_list: &[&str],
+    new_device_list: &[&str],
+  ) -> bool {
+    count_device_changes(previous_device_list, new_device_list) == (0, 1)
+  }
+
+  /// Returns `true` if the first (primary) device differs between the lists.
+  /// An empty list has no primary device, so emptying a list or populating
+  /// an empty one counts as a change.
+  fn primary_device_changed(
+    previous_device_list: &[&str],
+    new_device_list: &[&str],
+  ) -> bool {
+    previous_device_list.first() != new_device_list.first()
+  }
+
+  /// Returns `true` if any device ID occurs more than once in `device_list`
+  fn has_duplicates(device_list: &[&str]) -> bool {
+    let unique: HashSet<_> = device_list.iter().collect();
+    unique.len() != device_list.len()
+  }
+
+  /// The `UpdateDeviceList` RPC should be able to either add or remove
+  /// one device, and it cannot currently switch primary devices.
+  ///
+  /// Returns `true` only when the primary (first) device is unchanged, the
+  /// new list has no duplicate device IDs, and the lists differ by exactly
+  /// one added device or exactly one removed device (never both, so a
+  /// device cannot be swapped for another in a single update). Reordering
+  /// non-primary devices is allowed.
+  pub fn update_device_list_rpc_validator(
+    previous_device_list: &[&str],
+    new_device_list: &[&str],
+  ) -> bool {
+    if primary_device_changed(previous_device_list, new_device_list) {
+      return false;
+    }
+
+    // The set-based comparisons below collapse repeated IDs, so without this
+    // check e.g. [a] -> [a, b, b] would pass as a single-device addition.
+    if has_duplicates(new_device_list) {
+      return false;
+    }
+
+    is_device_added(previous_device_list, new_device_list)
+      || is_device_removed(previous_device_list, new_device_list)
+  }
+
+  #[cfg(test)]
+  mod tests {
+    use super::*;
+    use std::ops::Not;
+
+    #[test]
+    fn test_device_added_or_removed() {
+      let list1 = vec!["device1"];
+      let list2 = vec!["device1", "device2"];
+
+      assert!(is_device_added(&list1, &list2));
+      assert!(is_device_removed(&list1, &list2).not());
+
+      assert!(is_device_added(&list2, &list1).not());
+      assert!(is_device_removed(&list2, &list1));
+
+      assert!(is_device_added(&list1, &list1).not());
+      assert!(is_device_removed(&list1, &list1).not());
+    }
+
+    #[test]
+    fn test_device_swap_is_neither_added_nor_removed() {
+      let list1 = vec!["device1", "device2", "device3"];
+      let list2 = vec!["device1", "device4"];
+
+      assert!(is_device_added(&list1, &list2).not());
+      assert!(is_device_removed(&list1, &list2).not());
+      assert!(is_device_added(&list2, &list1).not());
+      assert!(is_device_removed(&list2, &list1).not());
+    }
+
+    #[test]
+    fn test_primary_device_changed() {
+      let list1 = vec!["device1"];
+      let list2 = vec!["device1", "device2"];
+      let list3 = vec!["device2"];
+
+      assert!(primary_device_changed(&list1, &list2).not());
+      assert!(primary_device_changed(&list1, &list3));
+    }
+
+    #[test]
+    fn test_update_device_list_rpc_validator() {
+      let one = vec!["device1"];
+      let two = vec!["device1", "device2"];
+      let three = vec!["device1", "device2", "device3"];
+      let swapped = vec!["device1", "device4"];
+      let new_primary = vec!["device2", "device1"];
+      let duplicated = vec!["device1", "device2", "device2"];
+
+      // a single device added or removed
+      assert!(update_device_list_rpc_validator(&one, &two));
+      assert!(update_device_list_rpc_validator(&two, &one));
+      // nothing changed
+      assert!(update_device_list_rpc_validator(&one, &one).not());
+      // more than one device added or removed
+      assert!(update_device_list_rpc_validator(&one, &three).not());
+      assert!(update_device_list_rpc_validator(&three, &one).not());
+      // one device added while others were removed, and vice versa
+      assert!(update_device_list_rpc_validator(&three, &swapped).not());
+      assert!(update_device_list_rpc_validator(&swapped, &three).not());
+      // primary device changed
+      assert!(update_device_list_rpc_validator(&two, &new_primary).not());
+      // duplicate device IDs
+      assert!(update_device_list_rpc_validator(&one, &duplicated).not());
+    }
+  }
 }
 
 #[cfg(test)]
@@ -98,15 +234,4 @@
       "No provided timestamp should pass"
     );
   }
-
-  #[test]
-  fn test_is_device_added_check() {
-    use std::ops::Not;
-
-    let list1 = vec!["device1"];
-    let list2 = vec!["device1", "device2"];
-
-    assert!(is_device_added(&list1, &list2));
-    assert!(is_device_added(&list2, &list1).not());
-  }
 }
diff --git a/services/identity/src/grpc_services/authenticated.rs b/services/identity/src/grpc_services/authenticated.rs
--- a/services/identity/src/grpc_services/authenticated.rs
+++ b/services/identity/src/grpc_services/authenticated.rs
@@ -508,10 +508,6 @@
     &self,
     request: tonic::Request,
   ) -> Result, tonic::Status> {
-    // TODO: Add proper validation according to the whitepaper
-    // currently only adding new device is supported (new.len - old.len = 1)
-    use crate::device_list::is_device_added as validator;
-
     let (user_id, _device_id) = get_user_and_device_id(&request)?;
     // TODO: when we stop doing "primary device rotation" (migration procedure)
     // we should verify if this RPC is called by primary device only
@@ -520,7 +516,11 @@
     let update = DeviceListUpdate::try_from(new_list)?;
     self
       .db_client
-      .apply_devicelist_update(&user_id, update, validator)
+      .apply_devicelist_update(
+        &user_id,
+        update,
+        crate::device_list::validation::update_device_list_rpc_validator,
+      )
       .await
       .map_err(handle_db_error)?;
 