diff --git a/services/identity/src/client_service.rs b/services/identity/src/client_service.rs
--- a/services/identity/src/client_service.rs
+++ b/services/identity/src/client_service.rs
@@ -28,7 +28,9 @@
   VerifyUserAccessTokenResponse, WalletLoginRequest, WalletLoginResponse,
 };
 use crate::config::CONFIG;
-use crate::database::{DatabaseClient, Device, KeyPayload};
+use crate::database::{
+  DBDeviceTypeInt, DatabaseClient, DeviceType, KeyPayload,
+};
 use crate::error::Error as DBError;
 use crate::grpc_utils::DeviceInfoWithAuth;
 use crate::id::generate_uuid;
@@ -84,7 +86,7 @@
   pub notif_prekey: String,
   pub notif_prekey_signature: String,
   pub notif_one_time_keys: Vec<String>,
-  pub device_type: Device,
+  pub device_type: DeviceType,
 }
 
 #[derive(derive_more::Constructor)]
@@ -168,7 +170,7 @@
         notif_prekey,
         notif_prekey_signature,
         notif_one_time_keys: one_time_notif_prekeys,
-        device_type: Device::try_from(device_type)
+        device_type: DeviceType::try_from(DBDeviceTypeInt(device_type))
           .map_err(handle_db_error)?,
       },
     };
@@ -271,7 +273,7 @@
         notif_prekey,
         notif_prekey_signature,
         notif_one_time_keys: one_time_notif_prekeys,
-        device_type: Device::try_from(device_type)
+        device_type: DeviceType::try_from(DBDeviceTypeInt(device_type))
           .map_err(handle_db_error)?,
       },
     };
@@ -508,7 +510,7 @@
         notif_prekey,
         notif_prekey_signature,
         notif_one_time_keys: one_time_notif_prekeys,
-        device_type: Device::try_from(device_type)
+        device_type: DeviceType::try_from(DBDeviceTypeInt(device_type))
           .map_err(handle_db_error)?,
       },
     };
@@ -650,7 +652,7 @@
         notif_prekey,
         notif_prekey_signature,
         notif_one_time_keys: one_time_notif_prekeys,
-        device_type: Device::try_from(device_type)
+        device_type: DeviceType::try_from(DBDeviceTypeInt(device_type))
           .map_err(handle_db_error)?,
       },
       social_proof,
diff --git a/services/identity/src/database.rs b/services/identity/src/database.rs
--- a/services/identity/src/database.rs
+++ b/services/identity/src/database.rs
@@ -1,6 +1,5 @@
 use constant_time_eq::constant_time_eq;
 use std::collections::{HashMap, HashSet};
-use std::fmt::{Display, Formatter, Result as FmtResult};
 use std::str::FromStr;
 use std::sync::Arc;
 
@@ -48,6 +47,7 @@
 use crate::id::generate_uuid;
 use crate::nonce::NonceData;
 use crate::token::{AccessTokenData, AuthType};
+pub use grpc_clients::identity::DeviceType;
 
 #[derive(Serialize, Deserialize)]
 pub struct OlmKeys {
@@ -72,49 +72,28 @@
   }
 }
 
-#[derive(Clone, Copy)]
-#[allow(non_camel_case_types)]
-pub enum Device {
-  // Numeric values should match the protobuf definition
-  Keyserver = 0,
-  Web,
-  Ios,
-  Android,
-  Windows,
-  MacOS,
-}
+/// Raw numeric `device_type` value read from the users table devices map.
+///
+/// `DeviceType` is defined in the `grpc_clients` crate, so this local
+/// newtype is what lets this crate implement `TryFrom` for it and attach
+/// database-specific error context to out-of-range values.
+#[derive(Clone, Copy, Debug)]
+pub struct DBDeviceTypeInt(pub i32);
 
-impl TryFrom<i32> for Device {
+impl TryFrom<DBDeviceTypeInt> for DeviceType {
   type Error = crate::error::Error;
 
-  fn try_from(value: i32) -> Result<Self, Self::Error> {
-    match value {
-      0 => Ok(Device::Keyserver),
-      1 => Ok(Device::Web),
-      2 => Ok(Device::Ios),
-      3 => Ok(Device::Android),
-      4 => Ok(Device::Windows),
-      5 => Ok(Device::MacOS),
-      _ => Err(Error::Attribute(DBItemError {
+  fn try_from(value: DBDeviceTypeInt) -> Result<Self, Self::Error> {
+    let device_result = DeviceType::try_from(value.0);
+
+    device_result.map_err(|_| {
+      Error::Attribute(DBItemError {
         attribute_name: USERS_TABLE_DEVICES_MAP_DEVICE_TYPE_ATTRIBUTE_NAME
           .to_string(),
-        attribute_value: Some(AttributeValue::N(value.to_string())),
+        attribute_value: Some(AttributeValue::N(value.0.to_string())),
         attribute_error: DBItemAttributeError::InvalidValue,
-      })),
-    }
-  }
-}
-
-impl Display for Device {
-  fn fmt(&self, f: &mut Formatter) -> FmtResult {
-    match self {
-      Device::Keyserver => write!(f, "keyserver"),
-      Device::Web => write!(f, "web"),
-      Device::Ios => write!(f, "ios"),
-      Device::Android => write!(f, "android"),
-      Device::Windows => write!(f, "windows"),
-      Device::MacOS => write!(f, "macos"),
-    }
+      })
+    })
   }
 }
 
@@ -1586,4 +1565,13 @@
       "DYmV8VdkjwG/VtC8C53morogNJhpTPT/4jzW0/cxzQo"
     );
   }
+
+  #[test]
+  fn test_int_to_device_type() {
+    let valid_result = DeviceType::try_from(DBDeviceTypeInt(3));
+    assert!(matches!(valid_result, Ok(DeviceType::Android)));
+
+    let invalid_result = DeviceType::try_from(DBDeviceTypeInt(6));
+    assert!(invalid_result.is_err());
+  }
 }