diff --git a/services/identity/Cargo.lock b/services/identity/Cargo.lock
--- a/services/identity/Cargo.lock
+++ b/services/identity/Cargo.lock
@@ -1490,6 +1490,7 @@
  "ed25519-dalek",
  "grpc_clients",
  "hex",
+ "http",
  "moka",
  "once_cell",
  "prost",
@@ -1502,6 +1503,7 @@
  "tonic",
  "tonic-build",
  "tonic-web",
+ "tower-http",
  "tracing",
  "tracing-subscriber",
  "tunnelbroker_messages",
diff --git a/services/identity/Cargo.toml b/services/identity/Cargo.toml
--- a/services/identity/Cargo.toml
+++ b/services/identity/Cargo.toml
@@ -31,6 +31,8 @@
 uuid = { version = "1.3", features = [ "v4" ] }
 base64 = "0.21.2"
 regex = "1"
+tower-http = { version = "0.4", features = ["cors"] }
+http = "0.2"
 
 [build-dependencies]
 tonic-build = "0.9.1"
diff --git a/services/identity/src/config.rs b/services/identity/src/config.rs
--- a/services/identity/src/config.rs
+++ b/services/identity/src/config.rs
@@ -67,6 +67,12 @@
       tunnelbroker_endpoint,
     })
   }
+
+  /// Whether the service is running in local development mode, inferred
+  /// from a LocalStack endpoint being configured.
+  pub fn is_dev(&self) -> bool {
+    self.localstack_endpoint.is_some()
+  }
 }
 
 impl fmt::Debug for Config {
diff --git a/services/identity/src/constants.rs b/services/identity/src/constants.rs
--- a/services/identity/src/constants.rs
+++ b/services/identity/src/constants.rs
@@ -145,3 +145,49 @@
 
 // Minimum supported code versions
 pub const MIN_SUPPORTED_NATIVE_VERSION: u64 = 270;
+
+// Request metadata
+
+/// Metadata keys read from incoming gRPC requests by the version
+/// interceptor (`grpc_services::shared`) and the auth helpers
+/// (`grpc_services::authenticated`).
+pub mod request_metadata {
+  pub const CODE_VERSION: &str = "code_version";
+  pub const DEVICE_TYPE: &str = "device_type";
+  pub const USER_ID: &str = "user_id";
+  pub const DEVICE_ID: &str = "device_id";
+  pub const ACCESS_TOKEN: &str = "access_token";
+}
+
+// CORS
+
+/// Default CORS policy values consumed by `crate::cors::cors_layer`.
+pub mod cors {
+  use std::time::Duration;
+
+  /// How long browsers may cache a preflight response (24 hours).
+  pub const DEFAULT_MAX_AGE: Duration = Duration::from_secs(24 * 60 * 60);
+  /// gRPC status headers exposed to browser JavaScript so gRPC-web
+  /// clients can read the outcome of a call.
+  pub const DEFAULT_EXPOSED_HEADERS: [&str; 3] =
+    ["grpc-status", "grpc-message", "grpc-status-details-bin"];
+  /// Request headers browsers may send. gRPC-web carries request
+  /// metadata as HTTP headers, so every `request_metadata` key is listed.
+  pub const DEFAULT_ALLOW_HEADERS: [&str; 9] = [
+    "x-grpc-web",
+    "content-type",
+    "x-user-agent",
+    "grpc-timeout",
+    super::request_metadata::CODE_VERSION,
+    super::request_metadata::DEVICE_TYPE,
+    super::request_metadata::USER_ID,
+    super::request_metadata::DEVICE_ID,
+    super::request_metadata::ACCESS_TOKEN,
+  ];
+  // NOTE(review): outside dev mode this is the only origin allow-list and
+  // `cors_layer` enables credentials, so `http://localhost:3000` is trusted
+  // in every non-dev deployment too; confirm this is intended.
+  /// Origins accepted when not running in dev mode.
+  pub const DEFAULT_ALLOW_ORIGIN: [&str; 2] =
+    ["https://web.comm.app", "http://localhost:3000"];
+}
diff --git a/services/identity/src/cors.rs b/services/identity/src/cors.rs
new file mode 100644
--- /dev/null
+++ b/services/identity/src/cors.rs
@@ -0,0 +1,49 @@
+use http::{HeaderName, HeaderValue};
+use tower_http::cors::{AllowOrigin, CorsLayer};
+
+use crate::{config::CONFIG, constants::cors};
+
+/// Builds the CORS layer that wraps the gRPC server so browsers can reach
+/// it through gRPC-web.
+///
+/// In dev mode (`Config::is_dev`) any origin is accepted by echoing the
+/// request's `Origin` header back; otherwise only the origins listed in
+/// `constants::cors::DEFAULT_ALLOW_ORIGIN` are allowed. Credentials are
+/// allowed in both cases.
+///
+/// # Panics
+///
+/// `HeaderName::from_static` / `HeaderValue::from_static` panic on invalid
+/// input. Every input comes from the constants in `constants::cors`, so a
+/// panic here means one of them is malformed (e.g. a header name with
+/// uppercase letters).
+pub fn cors_layer() -> CorsLayer {
+  let allow_origin = if CONFIG.is_dev() {
+    AllowOrigin::mirror_request()
+  } else {
+    AllowOrigin::list(
+      cors::DEFAULT_ALLOW_ORIGIN
+        .iter()
+        .copied()
+        .map(HeaderValue::from_static),
+    )
+  };
+  CorsLayer::new()
+    .allow_origin(allow_origin)
+    .allow_credentials(true)
+    .max_age(cors::DEFAULT_MAX_AGE)
+    .expose_headers(
+      cors::DEFAULT_EXPOSED_HEADERS
+        .iter()
+        .copied()
+        .map(HeaderName::from_static)
+        .collect::<Vec<_>>(),
+    )
+    .allow_headers(
+      cors::DEFAULT_ALLOW_HEADERS
+        .iter()
+        .copied()
+        .map(HeaderName::from_static)
+        .collect::<Vec<_>>(),
+    )
+}
diff --git a/services/identity/src/grpc_services/authenticated.rs b/services/identity/src/grpc_services/authenticated.rs
--- a/services/identity/src/grpc_services/authenticated.rs
+++ b/services/identity/src/grpc_services/authenticated.rs
@@ -1,6 +1,6 @@
 use crate::{
-  client_service::handle_db_error, database::DatabaseClient,
-  grpc_services::shared::get_value,
+  client_service::handle_db_error, constants::request_metadata,
+  database::DatabaseClient, grpc_services::shared::get_value,
 };
 
 use tonic::{Request, Response, Status};
@@ -29,9 +29,9 @@
 fn get_auth_info(req: &Request<()>) -> Option<(String, String, String)> {
   debug!("Retrieving auth info for request: {:?}", req);
 
-  let user_id = get_value(req, "user_id")?;
-  let device_id = get_value(req, "device_id")?;
-  let access_token = get_value(req, "access_token")?;
+  let user_id = get_value(req, request_metadata::USER_ID)?;
+  let device_id = get_value(req, request_metadata::DEVICE_ID)?;
+  let access_token = get_value(req, request_metadata::ACCESS_TOKEN)?;
 
   Some((user_id, device_id, access_token))
 }
@@ -70,9 +70,9 @@
 pub fn get_user_and_device_id<T>(
   request: &Request<T>,
 ) -> Result<(String, String), Status> {
-  let user_id = get_value(request, "user_id")
+  let user_id = get_value(request, request_metadata::USER_ID)
     .ok_or_else(|| Status::unauthenticated("Missing user_id field"))?;
-  let device_id = get_value(request, "device_id")
+  let device_id = get_value(request, request_metadata::DEVICE_ID)
     .ok_or_else(|| Status::unauthenticated("Missing device_id field"))?;
 
   Ok((user_id, device_id))
diff --git a/services/identity/src/grpc_services/shared.rs b/services/identity/src/grpc_services/shared.rs
--- a/services/identity/src/grpc_services/shared.rs
+++ b/services/identity/src/grpc_services/shared.rs
@@ -2,7 +2,7 @@
 use tonic::{Request, Status};
 use tracing::debug;
 
-use crate::constants::MIN_SUPPORTED_NATIVE_VERSION;
+use crate::constants::{request_metadata, MIN_SUPPORTED_NATIVE_VERSION};
 
 pub fn version_interceptor(req: Request<()>) -> Result<Request<()>, Status> {
   debug!("Intercepting request to check version: {:?}", req);
@@ -21,8 +21,10 @@
 fn get_version_info(req: &Request<()>) -> Option<(u64, String)> {
   debug!("Retrieving version info for request: {:?}", req);
 
-  let code_version: u64 = get_value(req, "code_version")?.parse().ok()?;
-  let device_type = get_value(req, "device_type")?;
+  let code_version: u64 = get_value(req, request_metadata::CODE_VERSION)?
+    .parse()
+    .ok()?;
+  let device_type = get_value(req, request_metadata::DEVICE_TYPE)?;
 
   Some((code_version, device_type))
 }
diff --git a/services/identity/src/main.rs b/services/identity/src/main.rs
--- a/services/identity/src/main.rs
+++ b/services/identity/src/main.rs
@@ -4,10 +4,12 @@
 use database::DatabaseClient;
 use moka::future::Cache;
 use tonic::transport::Server;
+use tonic_web::GrpcWebLayer;
 
 mod client_service;
 mod config;
 pub mod constants;
+mod cors;
 mod database;
 pub mod ddb_utils;
 pub mod error;
@@ -23,6 +25,7 @@
 
 use config::load_config;
 use constants::{IDENTITY_SERVICE_SOCKET_ADDR, SECRETS_DIRECTORY};
+use cors::cors_layer;
 use keygen::generate_and_persist_keypair;
 use tracing::{self, info, Level};
 use tracing_subscriber::EnvFilter;
@@ -93,7 +96,9 @@
   info!("Listening to gRPC traffic on {}", addr);
   Server::builder()
     .accept_http1(true)
-    .add_service(tonic_web::enable(client_service))
+    .layer(cors_layer())
+    .layer(GrpcWebLayer::new())
+    .add_service(client_service)
     .add_service(auth_service)
     .serve(addr)
     .await?;