diff --git a/services/identity/src/database/device_list.rs b/services/identity/src/database/device_list.rs --- a/services/identity/src/database/device_list.rs +++ b/services/identity/src/database/device_list.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use chrono::{DateTime, Utc}; use comm_lib::{ @@ -845,6 +845,9 @@ &self, user_id: &str, update: DeviceListUpdate, + // A function that receives previous and new device IDs and + // returns boolean determining if the new device list is valid. + validator_fn: impl Fn(&[&str], &[&str]) -> bool, ) -> Result { let DeviceListUpdate { devices: new_list, @@ -852,21 +855,18 @@ } = update; self .transact_update_devicelist(user_id, |current_list, _| { - // TODO: Add proper validation according to the whitepaper - // currently only adding new device is supported (new.len - old.len = 1) - - let new_set: HashSet<_> = new_list.iter().collect(); - let current_set: HashSet<_> = current_list.iter().collect(); - // difference is A - B (only new devices) - let difference: HashSet<_> = new_set.difference(&current_set).collect(); - if difference.len() != 1 { + let previous_device_ids: Vec<&str> = + current_list.iter().map(AsRef::as_ref).collect(); + let new_device_ids: Vec<&str> = + new_list.iter().map(AsRef::as_ref).collect(); + if !validator_fn(&previous_device_ids, &new_device_ids) { warn!("Received invalid device list update"); return Err(Error::DeviceList( DeviceListError::InvalidDeviceListUpdate, )); } - debug!("Applying device list update. 
Difference: {:?}", difference); + debug!("Applying device list update"); *current_list = new_list; Ok((None, Some(timestamp))) diff --git a/services/identity/src/device_list.rs b/services/identity/src/device_list.rs --- a/services/identity/src/device_list.rs +++ b/services/identity/src/device_list.rs @@ -1,4 +1,5 @@ use chrono::{DateTime, Duration, Utc}; +use std::collections::HashSet; use crate::{ constants::DEVICE_LIST_TIMESTAMP_VALID_FOR, error::DeviceListError, @@ -49,6 +50,18 @@ Ok(()) } +/// Returns `true` if `new_device_list` contains exactly one more new device +/// compared to `previous_device_list` +pub fn is_device_added( + previous_device_list: &[&str], + new_device_list: &[&str], +) -> bool { + let previous_set: HashSet<_> = previous_device_list.iter().collect(); + let new_set: HashSet<_> = new_device_list.iter().collect(); + + return new_set.difference(&previous_set).count() == 1; +} + #[cfg(test)] mod tests { use super::*; @@ -85,4 +98,15 @@ "No provided timestamp should pass" ); } + + #[test] + fn test_is_device_added_check() { + use std::ops::Not; + + let list1 = vec!["device1"]; + let list2 = vec!["device1", "device2"]; + + assert!(is_device_added(&list1, &list2)); + assert!(is_device_added(&list2, &list1).not()); + } } diff --git a/services/identity/src/grpc_services/authenticated.rs b/services/identity/src/grpc_services/authenticated.rs --- a/services/identity/src/grpc_services/authenticated.rs +++ b/services/identity/src/grpc_services/authenticated.rs @@ -514,6 +514,10 @@ &self, request: tonic::Request, ) -> Result, tonic::Status> { + // TODO: Add proper validation according to the whitepaper + // currently only adding new device is supported (new.len - old.len = 1) + use crate::device_list::is_device_added as validator; + let (user_id, _device_id) = get_user_and_device_id(&request)?; // TODO: when we stop doing "primary device rotation" (migration procedure) // we should verify if this RPC is called by primary device only @@ -522,7 +526,7 @@ 
let update = DeviceListUpdate::try_from(new_list)?; self .db_client - .apply_devicelist_update(&user_id, update) + .apply_devicelist_update(&user_id, update, validator) .await .map_err(handle_db_error)?;