diff --git a/services/identity/Cargo.toml b/services/identity/Cargo.toml
index ca2295184..9045b71cf 100644
--- a/services/identity/Cargo.toml
+++ b/services/identity/Cargo.toml
@@ -1,49 +1,50 @@
 [package]
 name = "identity"
 version = "0.1.0"
 edition.workspace = true
 license.workspace = true
 homepage.workspace = true

 [dependencies]
 tonic = "0.9.1"
 prost = { workspace = true }
 futures-util = { workspace = true }
 tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
 ed25519-dalek = { workspace = true }
 clap = { workspace = true, features = ["derive", "env"] }
 derive_more = { workspace = true }
 comm-lib = { path = "../../shared/comm-lib", features = [
   "aws",
   "grpc_clients",
+  "blob-client",
 ] }
 tracing = { workspace = true }
 tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
 chrono = { workspace = true }
 rand = "0.8"
 constant_time_eq = "0.2.2"
 siwe = { workspace = true }
 time = { workspace = true }
 comm-opaque2 = { path = "../../shared/comm-opaque2" }
 grpc_clients = { path = "../../shared/grpc_clients" }
 hyper = { workspace = true }
 hyper-tungstenite = { workspace = true }
 once_cell = { workspace = true }
 hex = { workspace = true }
 tonic-web = { workspace = true }
 serde = { workspace = true, features = ["derive"] }
 serde_json = { workspace = true }
 tunnelbroker_messages = { path = "../../shared/tunnelbroker_messages" }
 identity_search_messages = { path = "../../shared/identity_search_messages" }
 uuid = { workspace = true, features = ["v4"] }
 base64 = { workspace = true }
 regex = { workspace = true }
 tower-http = { workspace = true, features = ["cors"] }
 http = { workspace = true }
 reqwest = { workspace = true, features = ["json", "rustls-tls"] }
 futures = { workspace = true }
 url = { workspace = true }
 tower = { workspace = true }

 [build-dependencies]
 tonic-build = "0.9.1"
diff --git a/services/identity/src/comm_service/blob.rs b/services/identity/src/comm_service/blob.rs
new file mode 100644
index 000000000..d910cd398
--- /dev/null
+++ b/services/identity/src/comm_service/blob.rs
@@ -0,0 +1,58 @@
+use comm_lib::{
+  blob::{
+    client::BlobServiceClient,
+    types::http::{RemoveHoldersRequest, RemoveHoldersResponse},
+  },
+  database::batch_operations::ExponentialBackoffConfig,
+  tools::base64_to_base64url,
+};
+use tracing::{debug, warn};
+
+#[tracing::instrument(skip_all)]
+pub async fn remove_holders_for_devices(
+  blob_client: &BlobServiceClient,
+  device_ids: &[String],
+) -> Result<(), crate::error::Error> {
+  if device_ids.is_empty() {
+    debug!("No holders to remove.");
+    return Ok(());
+  }
+
+  debug!(
+    "Attempting to remove holders for {} devices.",
+    device_ids.len()
+  );
+
+  let retry_config = ExponentialBackoffConfig::default();
+  let mut retry_counter = retry_config.new_counter();
+
+  // holders are prefixed with deviceID in base64url format
+  // to escape forbidden characters
+  let holder_prefixes: Vec<String> = device_ids
+    .iter()
+    .map(|device_id| base64_to_base64url(device_id))
+    .collect();
+
+  let mut request = RemoveHoldersRequest::ByIndexedTags {
+    tags: holder_prefixes,
+  };
+  loop {
+    request = match blob_client.remove_multiple_holders(request.clone()).await {
+      Ok(response) if response.failed_requests.is_empty() => break,
+      Ok(RemoveHoldersResponse { failed_requests }) => {
+        warn!(
+          "Remaining {} holders not removed. Retrying...",
+          failed_requests.len()
+        );
+        RemoveHoldersRequest::from(failed_requests)
+      }
+      Err(err) => {
+        warn!(?err, "Removing holders failed due to error. Retrying...");
+        request
+      }
+    };
+    retry_counter.sleep_and_retry().await?;
+  }
+  debug!("Removed all holders");
+  Ok(())
+}
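The loop above retries until `failed_requests` comes back empty: a partial failure narrows the next request to just the failed subset, a transport error resends the request unchanged, and `sleep_and_retry` enforces backoff and a retry cap between attempts. Below is a minimal, self-contained sketch of that same shape; the mock client and the plain doubling delay are stand-ins for comm-lib's actual client and ExponentialBackoffConfig, and it assumes tokio with the "macros" and "time" features.

// Mock stand-ins for BlobServiceClient and its request/response types;
// only the retry-loop shape is taken from the hunk above.
#[derive(Clone, Debug)]
enum Request {
  ByIndexedTags { tags: Vec<String> },
  Items { holders: Vec<String> },
}

struct Response {
  failed_requests: Vec<String>,
}

struct MockClient {
  transient_failures_left: u32,
}

impl MockClient {
  async fn remove_multiple_holders(
    &mut self,
    _request: Request,
  ) -> Result<Response, &'static str> {
    if self.transient_failures_left > 0 {
      self.transient_failures_left -= 1;
      // Partial failure: report one holder as not yet removed.
      Ok(Response {
        failed_requests: vec!["holder-1".to_string()],
      })
    } else {
      Ok(Response {
        failed_requests: Vec::new(),
      })
    }
  }
}

#[tokio::main]
async fn main() {
  let mut client = MockClient {
    transient_failures_left: 2,
  };
  let mut request = Request::ByIndexedTags {
    tags: vec!["device-tag".to_string()],
  };
  let mut delay = std::time::Duration::from_millis(100);

  loop {
    request = match client.remove_multiple_holders(request.clone()).await {
      // Done: nothing left to remove.
      Ok(response) if response.failed_requests.is_empty() => break,
      // Partial failure: retry only the holders that failed.
      Ok(response) => Request::Items {
        holders: response.failed_requests,
      },
      // Transport-level error: retry the same request unchanged.
      Err(_) => request,
    };
    // Plain doubling delay; the real ExponentialBackoffConfig also caps
    // the number of attempts and surfaces MaxRetriesExceeded when exhausted.
    tokio::time::sleep(delay).await;
    delay *= 2;
  }
  println!("all holders removed");
}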
Retrying..."); + request + } + }; + retry_counter.sleep_and_retry().await?; + } + debug!("Removed all holders"); + Ok(()) +} diff --git a/services/identity/src/error.rs b/services/identity/src/error.rs index e5633fcce..9e18db6b0 100644 --- a/services/identity/src/error.rs +++ b/services/identity/src/error.rs @@ -1,66 +1,68 @@ use comm_lib::aws::DynamoDBError; use comm_lib::database::DBItemError; use tracing::error; #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum Error { #[display(...)] AwsSdk(DynamoDBError), #[display(...)] Attribute(DBItemError), #[display(...)] Transport(tonic::transport::Error), #[display(...)] Status(tonic::Status), #[display(...)] MissingItem, #[display(...)] DeviceList(DeviceListError), #[display(...)] MalformedItem, #[display(...)] Serde(serde_json::Error), #[display(...)] Reqwest(reqwest::Error), #[display(...)] + BlobService(comm_lib::blob::client::BlobServiceError), + #[display(...)] CannotOverwrite, #[display(...)] OneTimeKeyUploadLimitExceeded, #[display(...)] MaxRetriesExceeded, #[display(...)] IllegalState, #[display(...)] InvalidFormat, } #[derive(Debug, derive_more::Display, derive_more::Error)] pub enum DeviceListError { DeviceAlreadyExists, DeviceNotFound, ConcurrentUpdateError, InvalidDeviceListUpdate, InvalidSignature, } pub fn consume_error(result: Result) { match result { Ok(_) => (), Err(e) => { error!("{}", e); } } } impl From for Error { fn from(value: comm_lib::database::Error) -> Self { use comm_lib::database::Error as E; match value { E::AwsSdk(err) => Self::AwsSdk(err), E::Attribute(err) => Self::Attribute(err), E::MaxRetriesExceeded => Self::MaxRetriesExceeded, } } } diff --git a/services/identity/src/main.rs b/services/identity/src/main.rs index a0df620f1..b6aa3ba54 100644 --- a/services/identity/src/main.rs +++ b/services/identity/src/main.rs @@ -1,144 +1,145 @@ use comm_lib::auth::AuthService; use comm_lib::aws; use comm_lib::aws::config::timeout::TimeoutConfig; use comm_lib::aws::config::BehaviorVersion; use config::Command; use database::DatabaseClient; use tonic::transport::Server; use tonic_web::GrpcWebLayer; mod client_service; mod config; pub mod constants; mod cors; mod database; pub mod ddb_utils; mod device_list; pub mod error; mod grpc_services; mod grpc_utils; mod http; mod id; mod keygen; mod log; mod nonce; mod olm; mod regex; mod reserved_users; mod siwe; mod sync_identity_search; mod token; mod websockets; mod comm_service { pub mod backup; + pub mod blob; pub mod tunnelbroker; } use constants::{COMM_SERVICES_USE_JSON_LOGS, IDENTITY_SERVICE_SOCKET_ADDR}; use cors::cors_layer; use keygen::generate_and_persist_keypair; use std::env; use sync_identity_search::sync_index; use tokio::time::Duration; use tracing::{self, info, Level}; use tracing_subscriber::EnvFilter; use client_service::{ClientService, IdentityClientServiceServer}; use grpc_services::authenticated::AuthenticatedService; use grpc_services::protos::auth::identity_client_service_server::IdentityClientServiceServer as AuthServer; use websockets::errors::BoxedError; #[tokio::main] async fn main() -> Result<(), BoxedError> { let filter = EnvFilter::builder() .with_default_directive(Level::INFO.into()) .with_env_var(EnvFilter::DEFAULT_ENV) .from_env_lossy(); let use_json_logs: bool = env::var(COMM_SERVICES_USE_JSON_LOGS) .unwrap_or("false".to_string()) .parse() .unwrap_or_default(); if use_json_logs { let subscriber = tracing_subscriber::fmt() .json() .with_env_filter(filter) .finish(); 
diff --git a/services/identity/src/main.rs b/services/identity/src/main.rs
index a0df620f1..b6aa3ba54 100644
--- a/services/identity/src/main.rs
+++ b/services/identity/src/main.rs
@@ -1,144 +1,145 @@
 use comm_lib::auth::AuthService;
 use comm_lib::aws;
 use comm_lib::aws::config::timeout::TimeoutConfig;
 use comm_lib::aws::config::BehaviorVersion;
 use config::Command;
 use database::DatabaseClient;
 use tonic::transport::Server;
 use tonic_web::GrpcWebLayer;

 mod client_service;
 mod config;
 pub mod constants;
 mod cors;
 mod database;
 pub mod ddb_utils;
 mod device_list;
 pub mod error;
 mod grpc_services;
 mod grpc_utils;
 mod http;
 mod id;
 mod keygen;
 mod log;
 mod nonce;
 mod olm;
 mod regex;
 mod reserved_users;
 mod siwe;
 mod sync_identity_search;
 mod token;
 mod websockets;
 mod comm_service {
   pub mod backup;
+  pub mod blob;
   pub mod tunnelbroker;
 }

 use constants::{COMM_SERVICES_USE_JSON_LOGS, IDENTITY_SERVICE_SOCKET_ADDR};
 use cors::cors_layer;
 use keygen::generate_and_persist_keypair;
 use std::env;
 use sync_identity_search::sync_index;
 use tokio::time::Duration;
 use tracing::{self, info, Level};
 use tracing_subscriber::EnvFilter;

 use client_service::{ClientService, IdentityClientServiceServer};
 use grpc_services::authenticated::AuthenticatedService;
 use grpc_services::protos::auth::identity_client_service_server::IdentityClientServiceServer as AuthServer;
 use websockets::errors::BoxedError;

 #[tokio::main]
 async fn main() -> Result<(), BoxedError> {
   let filter = EnvFilter::builder()
     .with_default_directive(Level::INFO.into())
     .with_env_var(EnvFilter::DEFAULT_ENV)
     .from_env_lossy();

   let use_json_logs: bool = env::var(COMM_SERVICES_USE_JSON_LOGS)
     .unwrap_or("false".to_string())
     .parse()
     .unwrap_or_default();

   if use_json_logs {
     let subscriber = tracing_subscriber::fmt()
       .json()
       .with_env_filter(filter)
       .finish();
     tracing::subscriber::set_global_default(subscriber)?;
   } else {
     let subscriber = tracing_subscriber::fmt().with_env_filter(filter).finish();
     tracing::subscriber::set_global_default(subscriber)?;
   }

   match config::parse_cli_command() {
     Command::Keygen { dir } => {
       generate_and_persist_keypair(dir)?;
     }
     Command::Server => {
       config::load_server_config();
       let addr = IDENTITY_SERVICE_SOCKET_ADDR.parse()?;

       let aws_config = aws::config::defaults(BehaviorVersion::v2024_03_28())
         .timeout_config(
           TimeoutConfig::builder()
             .connect_timeout(Duration::from_secs(60))
             .build(),
         )
         .region("us-east-2")
         .load()
         .await;

       let comm_auth_service =
         AuthService::new(&aws_config, "http://localhost:50054".to_string());
       let database_client = DatabaseClient::new(&aws_config);
       let inner_client_service = ClientService::new(database_client.clone());
       let client_service = IdentityClientServiceServer::with_interceptor(
         inner_client_service,
         grpc_services::shared::version_interceptor,
       );
       let inner_auth_service =
         AuthenticatedService::new(database_client.clone(), comm_auth_service);
       let db_client = database_client.clone();
       let auth_service =
         AuthServer::with_interceptor(inner_auth_service, move |req| {
           grpc_services::authenticated::auth_interceptor(req, &db_client)
             .and_then(grpc_services::shared::version_interceptor)
         });

       info!("Listening to gRPC traffic on {}", addr);

       let grpc_server = Server::builder()
         .accept_http1(true)
         .layer(cors_layer())
         .layer(GrpcWebLayer::new())
         .trace_fn(|_| {
           tracing::info_span!(
             "grpc_request",
             request_id = uuid::Uuid::new_v4().to_string()
           )
         })
         .add_service(client_service)
         .add_service(auth_service)
         .serve(addr);

       let websocket_server = websockets::run_server(database_client);

       return tokio::select! {
         websocket_result = websocket_server => websocket_result,
         grpc_result = grpc_server => { grpc_result.map_err(|e| e.into()) },
       };
     }
     Command::SyncIdentitySearch => {
       let aws_config = aws::config::defaults(BehaviorVersion::v2024_03_28())
         .region("us-east-2")
         .load()
         .await;
       let database_client = DatabaseClient::new(&aws_config);
       let sync_result = sync_index(&database_client).await;

       error::consume_error(sync_result);
     }
   }

   Ok(())
 }
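`Command::Server` ends by racing the gRPC server against the websocket server with `tokio::select!`: whichever future completes first (normally by failing) decides the process result, so a crash in either server brings the whole service down rather than leaving it half-alive. A stripped-down sketch of that shutdown behavior, with hypothetical server functions standing in for the real ones:

use std::time::Duration;

async fn grpc_server() -> Result<(), Box<dyn std::error::Error>> {
  // Pretend the gRPC listener fails after 50ms.
  tokio::time::sleep(Duration::from_millis(50)).await;
  Err("grpc stopped".into())
}

async fn websocket_server() -> Result<(), Box<dyn std::error::Error>> {
  tokio::time::sleep(Duration::from_millis(200)).await;
  Ok(())
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
  // Whichever server future finishes first decides the process exit,
  // mirroring the `tokio::select!` at the end of Command::Server.
  tokio::select! {
    grpc_result = grpc_server() => grpc_result,
    websocket_result = websocket_server() => websocket_result,
  }
}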
diff --git a/shared/comm-lib/src/blob/types.rs b/shared/comm-lib/src/blob/types.rs
index 692a4ba35..cbb59994f 100644
--- a/shared/comm-lib/src/blob/types.rs
+++ b/shared/comm-lib/src/blob/types.rs
@@ -1,227 +1,227 @@
 use derive_more::Constructor;
 use hex::ToHex;
 use serde::{Deserialize, Serialize};
 use sha2::{Digest, Sha256};

 /// This module defines structures for HTTP requests and responses
 /// for the Blob Service. The definitions in this file should remain in sync
 /// with the types and validators defined in the corresponding
 /// JavaScript file at `lib/types/blob-service-types.js`.
 ///
 /// If you edit the definitions in one file,
 /// please make sure to update the corresponding definitions in the other.
 pub mod http {
   use serde::{Deserialize, Serialize};

   pub use super::BlobInfo;

   // Assign multiple holders
   #[derive(Serialize, Deserialize, Debug)]
   #[serde(rename_all = "camelCase")]
   pub struct AssignHoldersRequest {
     pub requests: Vec<BlobInfo>,
   }
   #[derive(Serialize, Deserialize, Debug)]
   #[serde(rename_all = "camelCase")]
   pub struct HolderAssignmentResult {
     #[serde(flatten)]
     pub request: BlobInfo,
     pub success: bool,
     pub data_exists: bool,
     pub holder_already_exists: bool,
   }
   #[derive(Serialize, Deserialize, Debug)]
   #[serde(rename_all = "camelCase")]
   pub struct AssignHoldersResponse {
     pub results: Vec<HolderAssignmentResult>,
   }

   // Remove multiple holders
-  #[derive(Serialize, Deserialize, Debug)]
+  #[derive(Serialize, Deserialize, Debug, Clone)]
   #[serde(untagged)]
   pub enum RemoveHoldersRequest {
     // remove holders with given (hash, holder) pairs
     #[serde(rename_all = "camelCase")]
     Items {
       requests: Vec<BlobInfo>,
       /// If true, the blobs will be deleted instantly
       /// after their last holders are revoked.
       #[serde(default)]
       instant_delete: bool,
     },
     // remove all holders that are indexed by any of given tags
     ByIndexedTags { tags: Vec<String> },
   }
   #[derive(Serialize, Deserialize, Debug)]
   #[serde(rename_all = "camelCase")]
   pub struct RemoveHoldersResponse {
     pub failed_requests: Vec<BlobInfo>,
   }

   // Single holder endpoint types
   #[derive(Serialize, Deserialize, Debug)]
   pub struct AssignHolderRequest {
     pub blob_hash: String,
     pub holder: String,
   }
   #[derive(Serialize, Deserialize, Debug)]
   pub struct AssignHolderResponse {
     pub data_exists: bool,
   }
   #[derive(Serialize, Deserialize, Debug)]
   pub struct RemoveHolderRequest {
     pub blob_hash: String,
     pub holder: String,
     /// If true, the blob will be deleted instantly
     /// after the last holder is revoked.
     #[serde(default)]
     pub instant_delete: bool,
   }

   // impls
   impl From<Vec<BlobInfo>> for RemoveHoldersRequest {
     fn from(requests: Vec<BlobInfo>) -> Self {
       Self::Items {
         requests,
         instant_delete: false,
       }
     }
   }
 }

 /// Blob owning information - stores both blob_hash and holder
 #[derive(Clone, Debug, Constructor, PartialEq, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct BlobInfo {
   pub blob_hash: String,
   pub holder: String,
 }

 impl BlobInfo {
   pub fn from_bytes(data: &[u8]) -> Self {
     Self {
       blob_hash: Sha256::digest(data).encode_hex(),
       holder: uuid::Uuid::new_v4().to_string(),
     }
   }
 }

 #[cfg(feature = "aws")]
 mod db_conversions {
   use super::*;
   use crate::database::{AttributeTryInto, DBItemError, TryFromAttribute};
   use aws_sdk_dynamodb::types::AttributeValue;
   use std::collections::HashMap;

   const BLOB_HASH_DDB_MAP_KEY: &str = "blob_hash";
   const HOLDER_DDB_MAP_KEY: &str = "holder";

   impl From<BlobInfo> for AttributeValue {
     fn from(value: BlobInfo) -> Self {
       let map = HashMap::from([
         (
           BLOB_HASH_DDB_MAP_KEY.to_string(),
           AttributeValue::S(value.blob_hash),
         ),
         (
           HOLDER_DDB_MAP_KEY.to_string(),
           AttributeValue::S(value.holder),
         ),
       ]);
       AttributeValue::M(map)
     }
   }
   impl From<&BlobInfo> for AttributeValue {
     fn from(value: &BlobInfo) -> Self {
       AttributeValue::from(value.to_owned())
     }
   }

   impl TryFromAttribute for BlobInfo {
     fn try_from_attr(
       attribute_name: impl Into<String>,
       attribute: Option<AttributeValue>,
     ) -> Result<Self, DBItemError> {
       let attr_name: String = attribute_name.into();
       let mut inner_map: HashMap<String, AttributeValue> =
         attribute.attr_try_into(&attr_name)?;
       let blob_hash = inner_map
         .remove("blob_hash")
         .attr_try_into(format!("{attr_name}.blob_hash"))?;
       let holder = inner_map
         .remove("holder")
         .attr_try_into(format!("{attr_name}.holder"))?;
       Ok(BlobInfo { blob_hash, holder })
     }
   }
 }

 #[cfg(test)]
 mod serialization_tests {
   use super::http::*;

   mod remove_holders_request {
     use super::*;

     #[test]
     fn serialize_items() {
       let req = RemoveHoldersRequest::Items {
         requests: vec![BlobInfo::new("a".into(), "b".into())],
         instant_delete: false,
       };
       let expected =
         r#"{"requests":[{"blobHash":"a","holder":"b"}],"instantDelete":false}"#;
       assert_eq!(expected, serde_json::to_string(&req).unwrap());
     }

     #[test]
     fn deserialize_items() {
       let json =
         r#"{"requests":[{"blobHash":"a","holder":"b"}],"instantDelete":false}"#;
       let deserialized: RemoveHoldersRequest =
         serde_json::from_str(json).expect("Request JSON payload invalid");
       let expected_items = vec![BlobInfo::new("a".into(), "b".into())];
       let is_matching = matches!(
         deserialized,
         RemoveHoldersRequest::Items {
           requests: items,
           instant_delete: false,
         } if items == expected_items
       );
       assert!(is_matching, "Deserialized request is incorrect");
     }

     #[test]
     fn serialize_tags() {
       let req = RemoveHoldersRequest::ByIndexedTags {
         tags: vec!["foo".into(), "bar".into()],
       };
       let expected = r#"{"tags":["foo","bar"]}"#;
       assert_eq!(expected, serde_json::to_string(&req).unwrap());
     }

     #[test]
     fn deserialize_tags() {
       let json = r#"{"tags":["foo","bar"]}"#;
       let deserialized: RemoveHoldersRequest =
         serde_json::from_str(json).expect("Request JSON payload invalid");
       let expected_tags: Vec<String> = vec!["foo".into(), "bar".into()];
       let is_matching = matches!(
         deserialized,
         RemoveHoldersRequest::ByIndexedTags { tags: actual_tags }
         if actual_tags == expected_tags
       );
       assert!(is_matching, "Deserialized request is incorrect");
     }
   }
 }
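`RemoveHoldersRequest` relies on `#[serde(untagged)]`: no tag field is sent on the wire, and serde picks the first variant whose fields match the JSON shape, which is what the serialization tests above verify against the JS types. A minimal standalone sketch of the same dispatch, using a hypothetical `Request` enum rather than the real types:

use serde::{Deserialize, Serialize};

// Same shape-dispatch trick as RemoveHoldersRequest above: with
// `#[serde(untagged)]`, serde tries variants in order and picks the first
// one whose fields match the JSON, so no explicit tag is ever serialized.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(untagged)]
enum Request {
  Items { requests: Vec<String> },
  ByTags { tags: Vec<String> },
}

fn main() {
  let items: Request = serde_json::from_str(r#"{"requests":["a"]}"#).unwrap();
  assert_eq!(items, Request::Items { requests: vec!["a".into()] });

  let tags: Request = serde_json::from_str(r#"{"tags":["t1"]}"#).unwrap();
  assert_eq!(tags, Request::ByTags { tags: vec!["t1".into()] });

  // Serialization likewise emits only the variant's own fields.
  assert_eq!(serde_json::to_string(&tags).unwrap(), r#"{"tags":["t1"]}"#);
}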
diff --git a/shared/comm-lib/src/tools.rs b/shared/comm-lib/src/tools.rs
index dc493ab77..2945402a8 100644
--- a/shared/comm-lib/src/tools.rs
+++ b/shared/comm-lib/src/tools.rs
@@ -1,279 +1,285 @@
 use rand::{distributions::DistString, CryptoRng, Rng};

 // colon is valid because it is used as a separator
 // in some backup service identifiers
 const VALID_IDENTIFIER_CHARS: &[char] = &['_', '-', '=', ':'];

 /// Checks if the given string is a valid identifier for an entity
 /// (e.g. backup ID, blob hash, blob holder).
 ///
 /// Some popular identifier formats are considered valid, including UUID,
 /// nanoid, base64url. On the other hand, path or url-like identifiers
 /// are not supposed to be valid
 pub fn is_valid_identifier(identifier: &str) -> bool {
   if identifier.is_empty() {
     return false;
   }

   identifier
     .chars()
     .all(|c| c.is_ascii_alphanumeric() || VALID_IDENTIFIER_CHARS.contains(&c))
 }

+/// Converts base64 string to base64url format. See RFC 4648 § 5 for details.
+#[inline]
+pub fn base64_to_base64url(base64_string: &str) -> String {
+  base64_string.replace('/', "_").replace('+', "-")
+}
+
 pub type BoxedError = Box<dyn std::error::Error>;

 /// Defers call of the provided function to when [Defer] goes out of scope.
 /// This can be used for cleanup code that must be run when e.g. the enclosing
 /// function exits either by return or try operator `?`.
 ///
 /// # Example
 /// ```ignore
 /// fn f() {
 ///   let _ = Defer::new(|| println!("cleanup"));
 ///
 ///   // Cleanup will run if function would exit here
 ///   operation_that_can_fail()?;
 ///
 ///   if should_exit_early {
 ///     // Cleanup will run if function would exit here
 ///     return;
 ///   }
 /// }
 /// ```
 pub struct Defer<'s>(Option<Box<dyn FnOnce() + 's>>);

 impl<'s> Defer<'s> {
   pub fn new(f: impl FnOnce() + 's) -> Self {
     Self(Some(Box::new(f)))
   }

   /// Consumes the value, without calling the provided function
   ///
   /// # Example
   /// ```ignore
   /// // Start a "transaction"
   /// operation_that_should_be_reverted();
   /// let revert = Defer::new(|| println!("revert"));
   /// operation_that_can_fail()?;
   /// operation_that_can_fail()?;
   /// operation_that_can_fail()?;
   /// // Now we can "commit" the changes
   /// revert.cancel();
   /// ```
   pub fn cancel(mut self) {
     self.0 = None;
     // Implicit drop
   }
 }

 impl Drop for Defer<'_> {
   fn drop(&mut self) {
     if let Some(f) = self.0.take() {
       f();
     }
   }
 }
 pub trait IntoChunks<T> {
   /// Splits the vec into `num_chunks` chunks and returns an iterator
   /// over these chunks. The chunks do not overlap.
   ///
   /// Chunk size is given by `ceil(vector_length / num_chunks)`.
   /// If vector length is not divisible by `num_chunks`,
   /// the last chunk will have fewer elements.
   ///
   /// If you're looking for chunks of a given size, use [`chunks`] instead.
   ///
   /// # Panics
   ///
   /// Panics if `num_chunks` is 0.
   ///
   /// # Examples
   ///
   /// ```
   /// use comm_lib::tools::IntoChunks;
   ///
   /// let items = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
   /// let mut iter = items.into_n_chunks(3);
   /// assert_eq!(&iter.next().unwrap(), &[1, 2, 3, 4]);
   /// assert_eq!(&iter.next().unwrap(), &[5, 6, 7, 8]);
   /// assert_eq!(&iter.next().unwrap(), &[9, 10]);
   /// assert!(iter.next().is_none());
   /// ```
   ///
   /// [`chunks`]: slice::chunks
   fn into_n_chunks(self, num_chunks: usize) -> impl Iterator<Item = Vec<T>>;
 }

 impl<T> IntoChunks<T> for Vec<T> {
   fn into_n_chunks(self, num_chunks: usize) -> impl Iterator<Item = Vec<T>> {
     struct ChunksIterator<T> {
       pub slice: Vec<T>,
       pub chunk_size: usize,
     }

     impl<T> Iterator for ChunksIterator<T> {
       type Item = Vec<T>;
       fn next(&mut self) -> Option<Vec<T>> {
         let next_size = std::cmp::min(self.slice.len(), self.chunk_size);
         if next_size == 0 {
           None
         } else {
           let next_chunk = self.slice.drain(0..next_size).collect();
           Some(next_chunk)
         }
       }
     }

     assert!(num_chunks > 0, "Number of chunks cannot be 0");
     let len = self.len();
     let rem = len % num_chunks;
     let chunk_size = len / num_chunks + if rem > 0 { 1 } else { 0 };
     ChunksIterator {
       slice: self,
       chunk_size,
     }
   }
 }

 pub fn generate_random_string(
   length: usize,
   rng: &mut (impl Rng + CryptoRng),
 ) -> String {
   rand::distributions::Alphanumeric.sample_string(rng, length)
 }

 #[cfg(test)]
 mod valid_identifier_tests {
   use super::*;

   #[test]
   fn alphanumeric_identifier() {
     assert!(is_valid_identifier("some_identifier_v123"));
   }

   #[test]
   fn alphanumeric_with_colon() {
     assert!(is_valid_identifier("some_identifier:with_colon"));
   }

   #[test]
   fn uuid_is_valid() {
     let example_uuid = "a2b9e4d4-8d1f-4c7f-9c3d-5f4e4e6b1d1d";
     assert!(is_valid_identifier(example_uuid));
   }

   #[test]
   fn base64url_is_valid() {
     let example_base64url = "VGhlIP3-aWNrIGJyb3duIGZveCBqciAxMyBsYXp5IGRvZ_7_";
     assert!(is_valid_identifier(example_base64url))
   }

   #[test]
   fn standard_base64_is_invalid() {
     let example_base64 =
       "VGhlIP3+aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIDEzIGxhenkgZG9n/v8=";
     assert!(!is_valid_identifier(example_base64));
   }

   #[test]
   fn path_is_invalid() {
     assert!(!is_valid_identifier("some/path"));
   }

   #[test]
   fn url_is_invalid() {
     assert!(!is_valid_identifier("https://example.com"));
   }

   #[test]
   fn empty_is_invalid() {
     assert!(!is_valid_identifier(""));
   }
 }

 #[cfg(test)]
 mod defer_tests {
   use super::*;

   #[test]
   fn defer_runs() {
     fn f(a: &mut bool) {
       let _ = Defer::new(|| *a = true);
     }

     let mut v = false;
     f(&mut v);
     assert!(v)
   }

   #[test]
   fn consumed_defer_doesnt_run() {
     fn f(a: &mut bool) {
       let defer = Defer::new(|| *a = true);
       defer.cancel();
     }

     let mut v = false;
     f(&mut v);
     assert!(!v)
   }

   #[test]
   fn defer_runs_on_try() {
     fn f(a: &mut bool) -> Result<(), ()> {
       let _ = Defer::new(|| *a = true);
       Err(())
     }

     let mut v = false;
     let _ = f(&mut v);
     assert!(v)
   }
 }

 #[cfg(test)]
 mod vec_utils_tests {
   use super::*;

   #[test]
   fn test_chunks_without_remainder() {
     let items = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
     let mut iter = items.into_n_chunks(3);
     assert_eq!(&iter.next().unwrap(), &[1, 2, 3, 4]);
     assert_eq!(&iter.next().unwrap(), &[5, 6, 7, 8]);
     assert_eq!(&iter.next().unwrap(), &[9, 10, 11, 12]);
     assert!(iter.next().is_none());
   }

   #[test]
   fn test_chunks_with_remainder() {
     let items = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
     let mut iter = items.into_n_chunks(3);
     assert_eq!(&iter.next().unwrap(), &[1, 2, 3, 4]);
     assert_eq!(&iter.next().unwrap(), &[5, 6, 7, 8]);
     assert_eq!(&iter.next().unwrap(), &[9, 10]);
     assert!(iter.next().is_none());
   }

   #[test]
   fn test_one_chunk() {
     let items: Vec<i32> = vec![1, 2, 3];
     let mut iter = items.into_n_chunks(1);
     assert_eq!(&iter.next().unwrap(), &[1, 2, 3]);
     assert!(iter.next().is_none());
   }

   #[test]
   fn test_empty_vec() {
     let items: Vec<i32> = vec![];
     let mut iter = items.into_n_chunks(2);
     assert!(iter.next().is_none());
   }

   #[test]
   #[should_panic]
   fn into_n_chunks_panics_with_0_chunks() {
     let items = vec![1, 2, 3];
     let _ = items.into_n_chunks(0);
   }
 }
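A quick standalone check of the new `base64_to_base64url` helper's behavior: the standard and URL-safe base64 alphabets (RFC 4648 §4 vs §5) differ only in '/' vs '_' and '+' vs '-', so two character replacements suffice, and the converted string then passes `is_valid_identifier`'s character test, which is why blob.rs converts device IDs before using them as holder prefixes. The function body below is copied from the hunk above; the sample strings are arbitrary.

// Copy of the helper from the tools.rs hunk, for a standalone check.
fn base64_to_base64url(base64_string: &str) -> String {
  base64_string.replace('/', "_").replace('+', "-")
}

fn main() {
  let standard = "VGhlIP3+aWNr/v8=";
  let url_safe = base64_to_base64url(standard);
  assert_eq!(url_safe, "VGhlIP3-aWNr_v8=");
  // '=' padding is left untouched; is_valid_identifier accepts it too,
  // since '=' is in VALID_IDENTIFIER_CHARS.
  println!("{url_safe}");
}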