Make the identity service's allowed CORS origins configurable.

Adds an `--allow-origin-list` flag / `ALLOW_ORIGIN_LIST` environment variable
holding a comma-separated list of origins. When it is set, exactly those
origins are allowed. When it is unset, the previous policy is kept: request
origins are mirrored in dev (`ServerConfig::is_dev`) and only
`cors::DEFAULT_ALLOW_ORIGIN` is allowed elsewhere. A deployment that forgets
the variable therefore never reflects arbitrary origins while
`allow_credentials(true)` is set.

NOTE(review): this patch was rebuilt from a copy whose newlines and generic
arguments had been stripped. Indentation assumes 2-space rustfmt style, and the
blank-line layout in config.rs was inferred from the original hunk line
counts. Verify with `git apply --check` before applying.

diff --git a/services/identity/src/config.rs b/services/identity/src/config.rs
--- a/services/identity/src/config.rs
+++ b/services/identity/src/config.rs
@@ -1,14 +1,17 @@
+use std::{env, fmt, fs, io, path};
+
 use base64::{engine::general_purpose, DecodeError, Engine as _};
 use clap::{Parser, Subcommand};
+use http::HeaderValue;
 use once_cell::sync::Lazy;
-use std::{env, fmt, fs, io, path};
+use tower_http::cors::AllowOrigin;
 use tracing::{error, info};
 
 use crate::constants::{
-  DEFAULT_OPENSEARCH_ENDPOINT, DEFAULT_TUNNELBROKER_ENDPOINT,
-  KEYSERVER_PUBLIC_KEY, LOCALSTACK_ENDPOINT, OPAQUE_SERVER_SETUP,
-  OPENSEARCH_ENDPOINT, SECRETS_DIRECTORY, SECRETS_SETUP_FILE,
-  TUNNELBROKER_GRPC_ENDPOINT,
+  cors::ALLOW_ORIGIN_LIST, DEFAULT_OPENSEARCH_ENDPOINT,
+  DEFAULT_TUNNELBROKER_ENDPOINT, KEYSERVER_PUBLIC_KEY, LOCALSTACK_ENDPOINT,
+  OPAQUE_SERVER_SETUP, OPENSEARCH_ENDPOINT, SECRETS_DIRECTORY,
+  SECRETS_SETUP_FILE, TUNNELBROKER_GRPC_ENDPOINT,
 };
 
 /// Raw CLI arguments, should be only used internally to create ServerConfig
@@ -49,6 +52,12 @@
   #[arg(env = OPENSEARCH_ENDPOINT)]
   #[arg(default_value = DEFAULT_OPENSEARCH_ENDPOINT)]
   opensearch_endpoint: String,
+
+  /// Comma-separated list of origins allowed to make CORS requests.
+  /// When unset, the built-in default CORS policy is used.
+  #[arg(long, global = true)]
+  #[arg(env = ALLOW_ORIGIN_LIST)]
+  allow_origin_list: Option<String>,
 }
 
 #[derive(Subcommand)]
@@ -73,6 +82,9 @@
   pub keyserver_public_key: Option<String>,
   pub tunnelbroker_endpoint: String,
   pub opensearch_endpoint: String,
+  /// Allowed CORS origins from `--allow-origin-list` / `ALLOW_ORIGIN_LIST`;
+  /// `None` when not configured.
+  pub allow_origin: Option<AllowOrigin>,
 }
 
 impl ServerConfig {
@@ -85,36 +97,44 @@
     if let Some(endpoint) = &cli.localstack_endpoint {
       info!("Using Localstack endpoint: {}", endpoint);
     }
-    info!("Using OpenSearch endpoint: {}", cli.opensearch_endpoint);
 
     let mut path_buf = path::PathBuf::new();
     path_buf.push(SECRETS_DIRECTORY);
     path_buf.push(SECRETS_SETUP_FILE);
 
     let server_setup = get_server_setup(path_buf.as_path())?;
 
     let keyserver_public_key = env::var(KEYSERVER_PUBLIC_KEY).ok();
+    let allow_origin = cli
+      .allow_origin_list
+      .as_deref()
+      .map(slice_to_allow_origin)
+      .transpose()?;
 
     Ok(Self {
       localstack_endpoint: cli.localstack_endpoint.clone(),
       tunnelbroker_endpoint: cli.tunnelbroker_endpoint.clone(),
       opensearch_endpoint: cli.opensearch_endpoint.clone(),
       server_setup,
       keyserver_public_key,
+      allow_origin,
     })
   }
 
   pub fn is_dev(&self) -> bool {
     self.localstack_endpoint.is_some()
   }
 }
 
 impl fmt::Debug for ServerConfig {
   fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     f.debug_struct("ServerConfig")
-      .field("server_keypair", &"** redacted **")
-      .field("keyserver_auth_token", &"** redacted **")
       .field("localstack_endpoint", &self.localstack_endpoint)
+      .field("server_setup", &"** redacted **")
+      .field("keyserver_public_key", &self.keyserver_public_key)
+      .field("tunnelbroker_endpoint", &self.tunnelbroker_endpoint)
+      .field("opensearch_endpoint", &self.opensearch_endpoint)
+      .field("allow_origin", &"** redacted **")
       .finish()
   }
 }
@@ -131,6 +151,8 @@
   Json(serde_json::Error),
   #[display(...)]
   Decode(DecodeError),
+  #[display(...)]
+  InvalidHeaderValue(http::header::InvalidHeaderValue),
 }
 
 fn get_server_setup(
@@ -160,3 +182,30 @@
   comm_opaque2::ServerSetup::deserialize(&decoded_server_setup)
     .map_err(Error::Opaque)
 }
+
+/// Parses a comma-separated list of origins, e.g.
+/// `"https://web.comm.app, http://localhost:3000"`, into an [`AllowOrigin`].
+///
+/// Whitespace around each entry is trimmed and empty entries (from a trailing
+/// or doubled comma) are skipped, so an input with no non-empty entries
+/// yields a list that matches no origin.
+///
+/// # Errors
+///
+/// Returns [`Error::InvalidHeaderValue`] if an entry is not a valid header
+/// value.
+///
+/// # Panics
+///
+/// Panics if an entry is `*`: [`AllowOrigin::list`] rejects wildcards.
+fn slice_to_allow_origin(origins: &str) -> Result<AllowOrigin, Error> {
+  let origins = origins
+    .split(',')
+    .map(str::trim)
+    .filter(|origin| !origin.is_empty())
+    .map(|origin| {
+      HeaderValue::from_str(origin).map_err(Error::InvalidHeaderValue)
+    })
+    .collect::<Result<Vec<_>, _>>()?;
+  Ok(AllowOrigin::list(origins))
+}
diff --git a/services/identity/src/constants.rs b/services/identity/src/constants.rs
--- a/services/identity/src/constants.rs
+++ b/services/identity/src/constants.rs
@@ -218,4 +218,7 @@
   ];
   pub const DEFAULT_ALLOW_ORIGIN: [&str; 2] =
     ["https://web.comm.app", "http://localhost:3000"];
+  /// Environment variable holding a comma-separated list of allowed CORS
+  /// origins; takes precedence over the built-in defaults when set.
+  pub const ALLOW_ORIGIN_LIST: &str = "ALLOW_ORIGIN_LIST";
 }
diff --git a/services/identity/src/cors.rs b/services/identity/src/cors.rs
--- a/services/identity/src/cors.rs
+++ b/services/identity/src/cors.rs
@@ -3,17 +3,24 @@
 
 use crate::{config::CONFIG, constants::cors};
 
+/// Builds the CORS layer for the identity service.
+///
+/// Origins configured through `ALLOW_ORIGIN_LIST` always take precedence.
+/// Without them, request origins are mirrored in dev (`CONFIG.is_dev()`) and
+/// only `cors::DEFAULT_ALLOW_ORIGIN` is allowed otherwise.
 pub fn cors_layer() -> CorsLayer {
-  let allow_origin = if CONFIG.is_dev() {
-    AllowOrigin::mirror_request()
-  } else {
-    AllowOrigin::list(
-      cors::DEFAULT_ALLOW_ORIGIN
-        .iter()
-        .cloned()
-        .map(HeaderValue::from_static),
-    )
-  };
+  // Mirroring must stay dev-only: combined with `allow_credentials(true)` it
+  // would let any site make credentialed cross-origin requests.
+  let allow_origin = match CONFIG.allow_origin.clone() {
+    Some(configured) => configured,
+    None if CONFIG.is_dev() => AllowOrigin::mirror_request(),
+    None => AllowOrigin::list(
+      cors::DEFAULT_ALLOW_ORIGIN
+        .iter()
+        .cloned()
+        .map(HeaderValue::from_static),
+    ),
+  };
   CorsLayer::new()
     .allow_origin(allow_origin)
     .allow_credentials(true)