diff --git a/services/commtest/tests/identity_tunnelbroker_tests.rs b/services/commtest/tests/identity_tunnelbroker_tests.rs new file mode 100644 index 000000000..bbc88596a --- /dev/null +++ b/services/commtest/tests/identity_tunnelbroker_tests.rs @@ -0,0 +1,110 @@ +mod client { + tonic::include_proto!("identity.client"); +} +mod auth_proto { + tonic::include_proto!("identity.authenticated"); +} +use auth_proto::identity_client_service_client::IdentityClientServiceClient as AuthClient; +use client::identity_client_service_client::IdentityClientServiceClient; +use client::UploadOneTimeKeysRequest; +use commtest::identity::device::create_device; +use futures_util::SinkExt; +use futures_util::StreamExt; +use tokio_tungstenite::{connect_async, tungstenite::Message}; +use tonic::transport::Endpoint; +use tonic::Request; +use tunnelbroker_messages::{ + ConnectionInitializationMessage, DeviceTypes, RefreshKeyRequest, +}; + +#[tokio::test] +async fn test_refresh_keys_request_upon_depletion() { + let device_info = create_device().await; + + let mut identity_client = + IdentityClientServiceClient::connect("http://127.0.0.1:50054") + .await + .expect("Couldn't connect to identitiy service"); + + let upload_request = UploadOneTimeKeysRequest { + user_id: device_info.user_id.clone(), + device_id: device_info.device_id.clone(), + access_token: device_info.access_token.clone(), + content_one_time_pre_keys: vec!["content1".to_string()], + notif_one_time_pre_keys: vec!["notif1".to_string()], + }; + + identity_client + .upload_one_time_keys(upload_request) + .await + .unwrap(); + + // Request outbound keys, which should trigger identity service to ask for more keys + let channel = Endpoint::from_static("http://[::1]:50054") + .connect() + .await + .unwrap(); + + let mut client = + AuthClient::with_interceptor(channel, |mut inter_request: Request<()>| { + let metadata = inter_request.metadata_mut(); + metadata.insert("user_id", device_info.user_id.parse().unwrap()); + metadata.insert("device_id", device_info.device_id.parse().unwrap()); + metadata + .insert("access_token", device_info.access_token.parse().unwrap()); + Ok(inter_request) + }); + + let keyserver_request = auth_proto::OutboundKeysForUserRequest { + user_id: device_info.user_id.clone(), + }; + + println!("Getting keyserver info for user, {}", device_info.user_id); + let first_reponse = client + .get_keyserver_keys(keyserver_request.clone()) + .await + .expect("Second keyserver keys request failed") + .into_inner() + .keyserver_info + .unwrap(); + + // The current threshold is 5, but we only upload two. Should receive request + // from tunnelbroker to refresh keys + // Create session as a keyserver + + let (mut socket, _) = connect_async("ws://localhost:51001") + .await + .expect("Can't connect"); + + let session_request = ConnectionInitializationMessage { + device_id: device_info.device_id.to_string(), + access_token: device_info.access_token.to_string(), + user_id: device_info.user_id.to_string(), + notify_token: None, + device_type: DeviceTypes::Keyserver, + device_app_version: None, + device_os: None, + }; + + let serialized_request = serde_json::to_string(&session_request) + .expect("Failed to serialize connection request"); + + socket + .send(Message::Text(serialized_request)) + .await + .expect("Failed to send message"); + + // Have keyserver receive any websocket messages + if let Some(Ok(response)) = socket.next().await { + // Check that message received by keyserver matches what identity server + // issued + let serialized_response: RefreshKeyRequest = + serde_json::from_str(&response.to_text().unwrap()).unwrap(); + + let expected_response = RefreshKeyRequest { + device_id: device_info.device_id.to_string(), + number_of_keys: 5, + }; + assert_eq!(serialized_response, expected_response); + }; +} diff --git a/services/tunnelbroker/src/config.rs b/services/tunnelbroker/src/config.rs index 8c87a4071..ad0adacc4 100644 --- a/services/tunnelbroker/src/config.rs +++ b/services/tunnelbroker/src/config.rs @@ -1,54 +1,54 @@ use crate::constants; use anyhow::{ensure, Result}; use clap::Parser; use once_cell::sync::Lazy; use tracing::info; #[derive(Parser)] #[command(version, about, long_about = None)] pub struct AppConfig { /// gRPC server listening port #[arg(long, default_value_t = constants::GRPC_SERVER_PORT)] pub grpc_port: u16, /// HTTP server listening port #[arg(long, default_value_t = 51001)] pub http_port: u16, /// AMQP server URI - #[arg(long, default_value_t = String::from("amqp://localhost:5672"))] + #[arg(long, default_value_t = String::from("amqp://comm:comm@localhost:5672"))] pub amqp_uri: String, /// AWS Localstack service URL #[arg(env = "LOCALSTACK_ENDPOINT")] #[arg(long)] pub localstack_endpoint: Option, } /// Stores configuration parsed from command-line arguments /// and environment variables pub static CONFIG: Lazy = Lazy::new(AppConfig::parse); /// Processes the command-line arguments and environment variables. /// Should be called at the beginning of the `main()` function. pub(super) fn parse_cmdline_args() -> Result<()> { // force evaluation of the lazy initialized config let cfg = Lazy::force(&CONFIG); // Perform some additional validation for CLI args ensure!( cfg.grpc_port != cfg.http_port, "gRPC and HTTP ports cannot be the same: {}", cfg.grpc_port ); Ok(()) } /// Provides region/credentials configuration for AWS SDKs pub async fn load_aws_config() -> aws_config::SdkConfig { let mut config_builder = aws_config::from_env(); if let Some(endpoint) = &CONFIG.localstack_endpoint { info!("Using localstack URL: {}", endpoint); config_builder = config_builder.endpoint_url(endpoint); } config_builder.load().await } diff --git a/shared/tunnelbroker_messages/src/messages/mod.rs b/shared/tunnelbroker_messages/src/messages/mod.rs index 788d6a933..98f8b99e0 100644 --- a/shared/tunnelbroker_messages/src/messages/mod.rs +++ b/shared/tunnelbroker_messages/src/messages/mod.rs @@ -1,15 +1,15 @@ // Messages sent between tunnelbroker and a device pub mod keys; pub mod session; pub use keys::*; pub use session::*; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize)] #[serde(untagged)] pub enum Messages { RefreshKeysRequest(RefreshKeyRequest), - SessionRequest(SessionRequest), + SessionRequest(ConnectionInitializationMessage), } diff --git a/shared/tunnelbroker_messages/src/messages/session.rs b/shared/tunnelbroker_messages/src/messages/session.rs index 7c9df423f..eb92c2f30 100644 --- a/shared/tunnelbroker_messages/src/messages/session.rs +++ b/shared/tunnelbroker_messages/src/messages/session.rs @@ -1,68 +1,71 @@ // Messages sent between tunnelbroker and a device use serde::{Deserialize, Serialize}; -/// The workflow when estabilishing a tunnelbroker connection: +/// The workflow when establishing a tunnelbroker connection: /// - Client sends SessionRequest /// - Tunnelbroker validates access_token with identity service /// - Tunnelbroker emits an AMQP message declaring that it has opened a new /// connection with a given device, so that the respective tunnelbroker /// instance can close the existing connection. /// - Tunnelbroker returns a session_id representing that the connection was /// accepted /// - Tunnelbroker will flush all messages related to device from RabbitMQ. /// This must be done first before flushing DynamoDB to prevent duplicated /// messages. /// - Tunnelbroker flushes all messages in DynamoDB /// - Tunnelbroker orders messages by creation date (oldest first), and sends /// messages to device /// - Tunnelbroker then polls for incoming messages from device #[derive(Serialize, Deserialize, Debug, PartialEq)] #[serde(rename_all = "camelCase")] pub enum DeviceTypes { Mobile, Web, Keyserver, } /// Message sent by a client to tunnelbroker to initiate a websocket /// session. Tunnelbroker will then validate the access token with identity /// service before continuing with the request. #[derive(Serialize, Deserialize)] #[serde(tag = "type", rename_all = "camelCase")] -pub struct SessionRequest { +pub struct ConnectionInitializationMessage { + pub user_id: String, pub device_id: String, pub access_token: String, pub notify_token: Option, pub device_type: DeviceTypes, pub device_app_version: Option, pub device_os: Option, } #[derive(Serialize, Deserialize)] pub struct SessionResponse { pub session_id: String, } #[cfg(test)] mod session_tests { use super::*; #[test] fn test_session_deserialization() { let example_payload = r#"{ "type": "sessionRequest", "accessToken": "xkdeifjsld", "deviceId": "foo", + "userId": "alice", "deviceType": "keyserver" }"#; let request = - serde_json::from_str::(example_payload).unwrap(); + serde_json::from_str::(example_payload) + .unwrap(); assert_eq!(request.device_id, "foo"); assert_eq!(request.access_token, "xkdeifjsld"); assert_eq!(request.device_os, None); assert_eq!(request.device_type, DeviceTypes::Keyserver); } }