diff --git a/services/commtest/Cargo.lock b/services/commtest/Cargo.lock
--- a/services/commtest/Cargo.lock
+++ b/services/commtest/Cargo.lock
@@ -165,11 +165,13 @@
  "num_cpus",
  "openssl",
  "prost",
+ "serde_json",
  "sha2",
  "tokio",
  "tokio-tungstenite",
  "tonic",
  "tonic-build",
+ "tunnelbroker_messages",
  "url",
 ]
 
@@ -851,6 +853,12 @@
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70"
 
+[[package]]
+name = "ryu"
+version = "1.0.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
+
 [[package]]
 name = "semver"
 version = "1.0.16"
@@ -859,9 +867,34 @@
 
 [[package]]
 name = "serde"
-version = "1.0.151"
+version = "1.0.160"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c"
+dependencies = [
+ "serde_derive",
+]
+
+[[package]]
+name = "serde_derive"
+version = "1.0.160"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.15",
+]
+
+[[package]]
+name = "serde_json"
+version = "1.0.96"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "97fed41fc1a24994d044e6db6935e69511a1153b52c15eb42493b26fa87feba0"
+checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
+dependencies = [
+ "itoa",
+ "ryu",
+ "serde",
+]
 
 [[package]]
 name = "sha1"
@@ -1221,6 +1254,14 @@
  "utf-8",
 ]
 
+[[package]]
+name = "tunnelbroker_messages"
+version = "0.1.0"
+dependencies = [
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "typenum"
 version = "1.16.0"
diff --git a/services/commtest/Cargo.toml b/services/commtest/Cargo.toml
--- a/services/commtest/Cargo.toml
+++ b/services/commtest/Cargo.toml
@@ -20,8 +20,10 @@
 sha2 = "0.10.2"
 hex = "0.4.3"
 tokio-tungstenite = "0.18.0"
+tunnelbroker_messages = { path = "../../shared/tunnelbroker_messages" }
 url = "2.3.1"
 futures-util = "0.3.28"
+serde_json = "1.0.96"
 
 [build-dependencies]
 tonic-build = "0.8"
diff --git a/services/commtest/tests/tunnelbroker_integration_test.rs b/services/commtest/tests/tunnelbroker_integration_test.rs
--- a/services/commtest/tests/tunnelbroker_integration_test.rs
+++ b/services/commtest/tests/tunnelbroker_integration_test.rs
@@ -1,8 +1,16 @@
 use futures_util::SinkExt;
 use tokio_tungstenite::{connect_async, tungstenite::Message};
+mod proto {
+  tonic::include_proto!("tunnelbroker");
+}
+use futures_util::StreamExt;
+use proto::tunnelbroker_service_client::TunnelbrokerServiceClient;
+use proto::MessageToDevice;
+use tunnelbroker_messages::RefreshKeyRequest;
 
 #[tokio::test]
-async fn open_websocket_connection() {
+async fn send_refresh_request() {
+  // Create session as a keyserver
   let (mut socket, _) = connect_async("ws://localhost:51001")
     .await
     .expect("Can't connect");
@@ -18,4 +26,35 @@
     .send(Message::Text(session_request.to_string()))
     .await
     .expect("Failed to send message");
+
+  // Send request for keyserver to refresh keys (identity service)
+  let mut tunnelbroker_client =
+    TunnelbrokerServiceClient::connect("http://localhost:50051")
+      .await
+      .unwrap();
+
+  let refresh_request = RefreshKeyRequest {
+    device_id: "foo".to_string(),
+    number_of_keys: 5,
+  };
+
+  let payload = serde_json::to_string(&refresh_request).unwrap();
+  let request = MessageToDevice {
+    device_id: "foo".to_string(),
+    payload,
+  };
+  let grpc_message = tonic::Request::new(request);
+  tunnelbroker_client
+    .send_message_to_device(grpc_message)
+    .await
+    .unwrap();
+
+  // Have keyserver receive any websocket messages
+  let response = socket.next().await.unwrap().unwrap();
+
+  // Check that message received by keyserver matches what identity server
+  // issued
+  let received_request: RefreshKeyRequest =
+    serde_json::from_str(response.to_text().unwrap()).unwrap();
+  assert_eq!(received_request, refresh_request);
 }
diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs
--- a/services/tunnelbroker/src/grpc/mod.rs
+++ b/services/tunnelbroker/src/grpc/mod.rs
@@ -5,9 +5,12 @@
 use proto::tunnelbroker_service_server::{
   TunnelbrokerService, TunnelbrokerServiceServer,
 };
+use proto::Empty;
 use tonic::transport::Server;
+use tonic::Status;
+use tracing::debug;
 
-use crate::constants;
+use crate::{constants, ACTIVE_CONNECTIONS};
 
 #[derive(Debug, Default)]
 struct TunnelbrokerGRPC {}
@@ -16,9 +19,23 @@
 impl TunnelbrokerService for TunnelbrokerGRPC {
   async fn send_message_to_device(
     &self,
-    _request: tonic::Request<proto::MessageToDevice>,
+    request: tonic::Request<proto::MessageToDevice>,
   ) -> Result<tonic::Response<proto::Empty>, tonic::Status> {
-    unimplemented!()
+    let message = request.into_inner();
+
+    debug!("Received message for {}", &message.device_id);
+    // TODO: Persist messages for inactive connections
+    let tx = ACTIVE_CONNECTIONS
+      .get(&message.device_id)
+      .ok_or_else(|| Status::unavailable("Device does not exist"))?;
+    // A send error means the device's websocket task has exited and dropped
+    // its receiver; report that to the caller rather than panicking inside
+    // the gRPC handler.
+    tx.send(message.payload)
+      .map_err(|_| Status::unavailable("Device connection closed"))?;
+
+    let response = tonic::Response::new(Empty {});
+    Ok(response)
   }
 }
 
diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs
--- a/services/tunnelbroker/src/main.rs
+++ b/services/tunnelbroker/src/main.rs
@@ -7,11 +7,9 @@
 use tokio::sync::mpsc::UnboundedSender;
 use tracing::{self, Level};
 use tracing_subscriber::EnvFilter;
-use tunnelbroker_messages::Messages;
 
-pub static ACTIVE_CONNECTIONS: Lazy<
-  DashMap<String, UnboundedSender<Messages>>,
-> = Lazy::new(DashMap::new);
+pub static ACTIVE_CONNECTIONS: Lazy<DashMap<String, UnboundedSender<String>>> =
+  Lazy::new(DashMap::new);
 
 #[tokio::main]
 async fn main() -> Result<(), io::Error> {
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -1,10 +1,13 @@
 use futures::future;
+use futures_util::stream::SplitSink;
+use futures_util::SinkExt;
 use futures_util::{StreamExt, TryStreamExt};
 use std::net::SocketAddr;
 use std::{env, io::Error};
 use tokio::net::{TcpListener, TcpStream};
 use tokio::sync::mpsc;
 use tokio_tungstenite::tungstenite::Message;
+use tokio_tungstenite::WebSocketStream;
 use tracing::{debug, error, info};
 use tunnelbroker_messages::messages::Messages;
 
@@ -39,13 +42,15 @@
     }
   };
 
-  let (_outgoing, incoming) = ws_stream.split();
+  let (mut outgoing, incoming) = ws_stream.split();
+  // Create channel for messages to be passed to this connection
+  let (tx, mut rx) = mpsc::unbounded_channel::<String>();
 
   let handle_incoming = incoming.try_for_each(|msg| {
     debug!("Received message from {}", addr);
     match msg {
       Message::Text(text) => {
-        match handle_message(&text) {
+        match handle_message_from_device(&text, &tx) {
           Ok(_) => {
             debug!("Successfully handled message: {}", text)
           }
@@ -62,14 +67,15 @@
     future::ok(())
   });
 
-  // Create channel for messages to be passed to this connection
-  let (tx, mut rx) = mpsc::unbounded_channel::<Messages>();
-  // TODO: Use device's public key, once we support the SessionRequest message
-  ACTIVE_CONNECTIONS.insert("test".to_string(), tx.clone());
-
   debug!("Polling for messages from: {}", addr);
+  // Poll for messages either being sent to the device (rx)
+  // or messages being received from the device (handle_incoming)
   tokio::select! {
-    Some(_) = rx.recv() => { debug!("Received message from channel") },
+    // Forward every queued payload, not just the first one: this branch only
+    // completes once the channel is closed.
+    _ = forward_messages_to_device(&mut rx, &mut outgoing) => {
+      debug!("Message channel for {} closed", addr)
+    },
     Ok(_) = handle_incoming => { debug!("Received message from websocket") },
     else => {
       info!("Connection with {} closed.", addr);
@@ -78,8 +84,41 @@
   }
 }
 
-fn handle_message(message: &str) -> Result<(), serde_json::Error> {
-  serde_json::from_str::<Messages>(message)?;
+/// Parses a text frame received from a device. A `SessionRequest` registers
+/// `tx` under the device's id so that gRPC callers can reach this connection.
+fn handle_message_from_device(
+  message: &str,
+  tx: &tokio::sync::mpsc::UnboundedSender<String>,
+) -> Result<(), serde_json::Error> {
+  match serde_json::from_str::<Messages>(message)? {
+    Messages::SessionRequest(session_info) => {
+      ACTIVE_CONNECTIONS.insert(session_info.device_id, tx.clone());
+    }
+    _ => {
+      debug!("Received invalid request");
+    }
+  }
   Ok(())
 }
+
+/// Forwards every payload queued for this device to its websocket, in order,
+/// until all senders for the channel have been dropped.
+async fn forward_messages_to_device(
+  rx: &mut mpsc::UnboundedReceiver<String>,
+  outgoing: &mut SplitSink<WebSocketStream<TcpStream>, Message>,
+) {
+  while let Some(message) = rx.recv().await {
+    handle_message_from_service(message, outgoing).await;
+  }
+}
+
+/// Writes one payload to the device's websocket, logging send failures.
+async fn handle_message_from_service(
+  incoming_payload: String,
+  outgoing: &mut SplitSink<WebSocketStream<TcpStream>, Message>,
+) {
+  if let Err(e) = outgoing.send(Message::Text(incoming_payload)).await {
+    error!("Failed to send message to device: {}", e);
+  }
+}
 
diff --git a/shared/tunnelbroker_messages/src/messages/keys.rs b/shared/tunnelbroker_messages/src/messages/keys.rs
--- a/shared/tunnelbroker_messages/src/messages/keys.rs
+++ b/shared/tunnelbroker_messages/src/messages/keys.rs
@@ -2,7 +2,7 @@
 
 use serde::{Deserialize, Serialize};
 
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, PartialEq, Debug)]
 #[serde(tag = "type", rename_all = "camelCase")]
 pub struct RefreshKeyRequest {
   pub device_id: String,