diff --git a/services/identity/Cargo.lock b/services/identity/Cargo.lock
--- a/services/identity/Cargo.lock
+++ b/services/identity/Cargo.lock
@@ -1697,6 +1697,7 @@
  "tracing",
  "tracing-subscriber",
  "tunnelbroker_messages",
+ "url",
  "uuid",
 ]
 
diff --git a/services/identity/Cargo.toml b/services/identity/Cargo.toml
--- a/services/identity/Cargo.toml
+++ b/services/identity/Cargo.toml
@@ -41,6 +41,7 @@
 http = "0.2"
 reqwest = { version = "0.11", features = ["json"] }
 futures = "0.3.30"
+url = "2.5"
 
 [build-dependencies]
 tonic-build = "0.9.1"
diff --git a/services/identity/src/config.rs b/services/identity/src/config.rs
--- a/services/identity/src/config.rs
+++ b/services/identity/src/config.rs
@@ -6,6 +6,7 @@
 use once_cell::sync::Lazy;
 use tower_http::cors::AllowOrigin;
 use tracing::{error, info};
+use url::Url;
 
 use crate::constants::{
   cors::ALLOW_ORIGIN_LIST, DEFAULT_OPENSEARCH_ENDPOINT,
@@ -147,6 +148,23 @@
   Decode(DecodeError),
   #[display(...)]
   InvalidHeaderValue(http::header::InvalidHeaderValue),
+  #[display(...)]
+  InvalidOrigin(InvalidOriginError),
+}
+
+/// Reason a configured CORS origin was rejected by `validate_origin`.
+#[derive(Debug, derive_more::Display)]
+pub enum InvalidOriginError {
+  /// The scheme is neither `http` nor `https`.
+  InvalidScheme,
+  /// The URL has no host.
+  MissingHost,
+  /// No explicit port, or only the scheme's default port, was given.
+  MissingPort,
+  /// The text is not exactly the serialized origin (`scheme://host:port`).
+  NonCanonical,
+  /// The text could not be parsed as a URL.
+  ParseError,
 }
 
 fn get_server_setup(
@@ -180,8 +198,134 @@
 fn slice_to_allow_origin(origins: &str) -> Result<AllowOrigin, Error> {
   let allow_origin_result: Result<Vec<HeaderValue>, Error> = origins
     .split(',')
-    .map(|s| HeaderValue::from_str(s.trim()).map_err(Error::InvalidHeaderValue))
+    .map(|s| {
+      // Validate exactly the trimmed text that is stored and later
+      // compared against request `Origin` headers.
+      let origin = s.trim();
+      validate_origin(origin)?;
+      HeaderValue::from_str(origin).map_err(Error::InvalidHeaderValue)
+    })
     .collect();
   let allow_origin_list = allow_origin_result?;
   Ok(AllowOrigin::list(allow_origin_list))
 }
+
+/// Checks that `origin_str` can be used as an allowed CORS origin.
+///
+/// The string must be exactly the ASCII serialization of an `http` or
+/// `https` origin with an explicit, non-default port, e.g.
+/// `http://localhost:3000`. Surrounding whitespace is not stripped here.
+///
+/// # Errors
+///
+/// Returns `Error::InvalidOrigin` describing the first check that fails.
+fn validate_origin(origin_str: &str) -> Result<(), Error> {
+  let Ok(url) = Url::parse(origin_str) else {
+    return Err(Error::InvalidOrigin(InvalidOriginError::ParseError));
+  };
+  if !matches!(url.scheme(), "http" | "https") {
+    return Err(Error::InvalidOrigin(InvalidOriginError::InvalidScheme));
+  }
+  if url.host_str().is_none() {
+    return Err(Error::InvalidOrigin(InvalidOriginError::MissingHost));
+  }
+  // `Url::port` is `None` both when no port is written and when the
+  // port equals the scheme's default (e.g. `https://example.com:443`).
+  if url.port().is_none() {
+    return Err(Error::InvalidOrigin(InvalidOriginError::MissingPort));
+  }
+  // `AllowOrigin::list` matches request `Origin` values by exact
+  // comparison, so anything but the serialized origin (trailing slash,
+  // path, query, userinfo, upper-case host, ...) could never match.
+  if url.origin().ascii_serialization() != origin_str {
+    return Err(Error::InvalidOrigin(InvalidOriginError::NonCanonical));
+  }
+  Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+  use super::{slice_to_allow_origin, validate_origin};
+
+  #[test]
+  fn test_valid_origin() {
+    let valid_origin = "http://localhost:3000";
+    assert!(
+      validate_origin(valid_origin).is_ok(),
+      "Expected a valid origin, but got an invalid one"
+    );
+  }
+
+  #[test]
+  fn test_invalid_origin_missing_scheme() {
+    let invalid_origin = "localhost:3000";
+    assert!(
+      validate_origin(invalid_origin).is_err(),
+      "Expected an invalid origin (missing scheme), but got a valid one"
+    );
+  }
+
+  #[test]
+  fn test_invalid_origin_missing_host() {
+    let invalid_origin = "http://:3000";
+    assert!(
+      validate_origin(invalid_origin).is_err(),
+      "Expected an invalid origin (missing host), but got a valid one"
+    );
+  }
+
+  #[test]
+  fn test_invalid_origin_missing_port() {
+    // We require that the port always be specified in origins
+    let invalid_origin = "http://localhost";
+    assert!(
+      validate_origin(invalid_origin).is_err(),
+      "Expected an invalid origin (missing port), but got a valid one"
+    );
+  }
+
+  #[test]
+  fn test_invalid_origin_invalid_scheme() {
+    // We only allow http and https origins
+    let invalid_origin = "ftp://example.com";
+    assert!(
+      validate_origin(invalid_origin).is_err(),
+      "Expected an invalid origin (invalid scheme), but got a valid one"
+    );
+  }
+
+  #[test]
+  fn test_invalid_origin_explicit_default_port() {
+    // `Url::port` reports a scheme's default port as absent
+    let invalid_origin = "https://example.com:443";
+    assert!(
+      validate_origin(invalid_origin).is_err(),
+      "Expected an invalid origin (default port), but got a valid one"
+    );
+  }
+
+  #[test]
+  fn test_invalid_origin_not_serialized() {
+    // Browsers never send these forms in the `Origin` header
+    for invalid_origin in [
+      "http://localhost:3000/",
+      "http://localhost:3000/path",
+      "http://user@localhost:3000",
+      "HTTP://LOCALHOST:3000",
+    ] {
+      assert!(
+        validate_origin(invalid_origin).is_err(),
+        "Expected {invalid_origin} to be rejected as non-canonical"
+      );
+    }
+  }
+
+  #[test]
+  fn test_origin_list_trims_entries() {
+    let origins = "http://localhost:3000, https://example.com:8443";
+    assert!(
+      slice_to_allow_origin(origins).is_ok(),
+      "Expected a comma-separated list with spaces to be accepted"
+    );
+  }
+}
diff --git a/services/identity/src/constants.rs b/services/identity/src/constants.rs
--- a/services/identity/src/constants.rs
+++ b/services/identity/src/constants.rs
@@ -167,8 +167,7 @@
 
 pub const OPENSEARCH_ENDPOINT: &str = "OPENSEARCH_ENDPOINT";
 pub const DEFAULT_OPENSEARCH_ENDPOINT: &str =
-  "identity-search-domain.us-east-2.opensearch.localhost.local
-stack.cloud:4566";
+  "identity-search-domain.us-east-2.opensearch.localhost.localstack.cloud:4566";
 
 pub const IDENTITY_SEARCH_INDEX: &str = "users";
 pub const IDENTITY_SEARCH_RESULT_SIZE: u32 = 20;