diff --git a/services/tunnelbroker/Cargo.lock b/services/tunnelbroker/Cargo.lock --- a/services/tunnelbroker/Cargo.lock +++ b/services/tunnelbroker/Cargo.lock @@ -1071,6 +1071,18 @@ "wasi", ] +[[package]] +name = "grpc_clients" +version = "0.1.0" +dependencies = [ + "derive_more", + "prost", + "tonic 0.9.2", + "tonic-build 0.9.2", + "tracing", + "tracing-subscriber", +] + [[package]] name = "h2" version = "0.3.18" @@ -1202,7 +1214,7 @@ "rustls 0.20.8", "rustls-native-certs", "tokio", - "tokio-rustls", + "tokio-rustls 0.23.4", ] [[package]] @@ -2254,6 +2266,16 @@ "webpki", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls 0.21.1", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.14" @@ -2323,6 +2345,37 @@ "tracing-futures", ] +[[package]] +name = "tonic" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3082666a3a6433f7f511c7192923fa1fe07c69332d3c6a2e6bb040b569199d5a" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64 0.21.0", + "bytes", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-timeout", + "percent-encoding", + "pin-project", + "prost", + "rustls-pemfile", + "tokio", + "tokio-rustls 0.24.1", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tonic-build" version = "0.8.4" @@ -2336,6 +2389,19 @@ "syn 1.0.109", ] +[[package]] +name = "tonic-build" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6fdaae4c2c638bb70fe42803a26fbd6fc6ac8c72f5c59f67ecc2a2dcabf4b07" +dependencies = [ + "prettyplease", + "proc-macro2", + "prost-build", + "quote", + "syn 1.0.109", +] + [[package]] name = "tower" version = "0.4.13" @@ -2476,14 +2542,15 @@ "clap", "derive_more",
"futures-util", + "grpc_clients", "lapin", "once_cell", "prost", "serde_json", "tokio", "tokio-tungstenite", - "tonic", - "tonic-build", + "tonic 0.8.3", + "tonic-build 0.8.4", "tracing", "tracing-subscriber", "tunnelbroker_messages", diff --git a/services/tunnelbroker/Cargo.toml b/services/tunnelbroker/Cargo.toml --- a/services/tunnelbroker/Cargo.toml +++ b/services/tunnelbroker/Cargo.toml @@ -13,6 +13,7 @@ aws-sdk-dynamodb = "0.27" clap = { version = "4.2", features = ["derive", "env"] } futures-util = "0.3" +grpc_clients = { path = "../../shared/grpc_clients" } once_cell = "1.17" prost = "0.11" serde_json = "1.0" diff --git a/services/tunnelbroker/src/config.rs b/services/tunnelbroker/src/config.rs --- a/services/tunnelbroker/src/config.rs +++ b/services/tunnelbroker/src/config.rs @@ -20,6 +20,10 @@ #[arg(env = "LOCALSTACK_ENDPOINT")] #[arg(long)] pub localstack_endpoint: Option, + /// Comm Identity service URL + #[arg(env = "COMM_TUNNELBROKER_IDENTITY_ENDPOINT")] + #[arg(long, default_value_t = String::from("http://localhost:50054"))] + pub identity_endpoint: String, } /// Stores configuration parsed from command-line arguments diff --git a/services/tunnelbroker/src/database/message.rs b/services/tunnelbroker/src/database/message.rs --- a/services/tunnelbroker/src/database/message.rs +++ b/services/tunnelbroker/src/database/message.rs @@ -13,7 +13,7 @@ pub payload: String, } -#[derive(Debug, derive_more::Display)] +#[derive(Debug, derive_more::Display, derive_more::Error)] pub enum MessageErrors { SerializationError, } diff --git a/services/tunnelbroker/src/error.rs b/services/tunnelbroker/src/error.rs new file mode 100644 --- /dev/null +++ b/services/tunnelbroker/src/error.rs @@ -0,0 +1,19 @@ +#[derive( + Debug, derive_more::Display, derive_more::From, derive_more::Error, +)] +pub enum Error { + #[display(...)] + TonicError(tonic::transport::Error), + #[display(...)] + ClientError(grpc_clients::tonic::Status), + #[display(...)] + ServerError(tonic::Status),
#[display(...)] + GrpcClient(grpc_clients::error::Error), + #[display(...)] + SessionError(crate::websockets::session::SessionError), + #[display(...)] + LapinError(lapin::Error), + #[display(...)] + SerdeError(serde_json::Error), +} diff --git a/services/tunnelbroker/src/identity/mod.rs b/services/tunnelbroker/src/identity/mod.rs new file mode 100644 --- /dev/null +++ b/services/tunnelbroker/src/identity/mod.rs @@ -0,0 +1,27 @@ +use client_proto::VerifyUserAccessTokenRequest; +use grpc_clients::identity; +use grpc_clients::tonic::Request; +use identity::get_unauthenticated_client; +use identity::protos::unauthenticated as client_proto; + +use crate::config::CONFIG; +use crate::error::Error; + +/// Asks the Identity service whether `access_token` is valid for the user and device (`device_id` is sent as the signing public key) +pub async fn verify_user_access_token( + user_id: &str, + device_id: &str, + access_token: &str, +) -> Result { + let mut grpc_client = + get_unauthenticated_client(&CONFIG.identity_endpoint).await?; + let message = VerifyUserAccessTokenRequest { + user_id: user_id.to_string(), + signing_public_key: device_id.to_string(), + access_token: access_token.to_string(), + }; + + let request = Request::new(message); + let response = grpc_client.verify_user_access_token(request).await?; + Ok(response.into_inner().token_valid) +} diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs --- a/services/tunnelbroker/src/main.rs +++ b/services/tunnelbroker/src/main.rs @@ -2,7 +2,9 @@ pub mod config; pub mod constants; pub mod database; +pub mod error; pub mod grpc; +pub mod identity; pub mod websockets; use anyhow::{anyhow, Result}; diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,4 +1,4 @@ -mod session; +pub mod session; use crate::database::DatabaseClient; use crate::websockets::session::SessionError; diff --git
a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -7,10 +7,13 @@ use lapin::types::FieldTable; use tokio::net::TcpStream; use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; +use tracing::info; use tracing::{debug, error}; use tunnelbroker_messages::{session::DeviceTypes, Messages}; use crate::database::{self, DatabaseClient, DeviceMessage}; +use crate::error::Error; +use crate::identity; pub struct DeviceInfo { pub device_id: String, @@ -28,12 +31,16 @@ amqp_consumer: lapin::Consumer, } -#[derive(Debug, derive_more::Display, derive_more::From)] +#[derive( + Debug, derive_more::Display, derive_more::From, derive_more::Error, +)] pub enum SessionError { InvalidMessage, SerializationError(serde_json::Error), MessageError(database::MessageErrors), AmqpError(lapin::Error), + InternalError, + UnauthorizedDevice, } pub fn consume_error(result: Result) { @@ -43,9 +50,9 @@ } // Parse a session request and retrieve the device information -pub fn handle_first_message_from_device( +pub async fn handle_first_message_from_device( message: &str, -) -> Result { +) -> Result { let serialized_message = serde_json::from_str::(message)?; match serialized_message { @@ -58,11 +65,37 @@ device_os: session_info.device_os.take(), }; + // Authenticate device with the Identity service; any failure rejects the session + debug!("Authenticating device: {}", &session_info.device_id); + let auth_request = identity::verify_user_access_token( + &session_info.user_id, + &device_info.device_id, + &session_info.access_token, + ) + .await; + + match auth_request { + Err(e) => { + error!("Failed to complete request to identity service: {:?}", e); + return Err(SessionError::InternalError.into()); + } + Ok(false) => { + info!("Device failed authentication: {}", &session_info.device_id); + return Err(SessionError::UnauthorizedDevice.into()); + } + Ok(true) => { + debug!( + "Successfully
authenticated device: {}", + &session_info.device_id + ); + } + } + return Ok(device_info); } _ => { debug!("Received invalid request"); - return Err(SessionError::InvalidMessage); + return Err(SessionError::InvalidMessage.into()); } } } @@ -73,12 +106,14 @@ db_client: DatabaseClient, frame: Message, amqp_channel: &lapin::Channel, - ) -> Result { + ) -> Result { let device_info = match frame { - Message::Text(payload) => handle_first_message_from_device(&payload)?, + Message::Text(payload) => { + handle_first_message_from_device(&payload).await? + } _ => { error!("Client sent wrong frame type for establishing connection"); - return Err(SessionError::InvalidMessage); + return Err(SessionError::InvalidMessage.into()); } };