diff --git a/services/commtest/tests/tunnelbroker_integration_test.rs b/services/commtest/tests/tunnelbroker_integration_test.rs
--- a/services/commtest/tests/tunnelbroker_integration_test.rs
+++ b/services/commtest/tests/tunnelbroker_integration_test.rs
@@ -45,11 +45,70 @@
     payload,
   };
   let grpc_message = tonic::Request::new(request);
+
+  tunnelbroker_client
+    .send_message_to_device(grpc_message)
+    .await
+    .unwrap();
+
+  // Have keyserver receive any websocket messages
+  let response = socket.next().await.unwrap().unwrap();
+
+  // Check that message received by keyserver matches what identity server
+  // issued
+  let serialized_response: RefreshKeyRequest =
+    serde_json::from_str(&response.to_text().unwrap()).unwrap();
+  assert_eq!(serialized_response, refresh_request);
+}
+
+/// Test that a message to an offline device gets pushed to dynamodb
+/// then recalled once a device connects
+#[tokio::test]
+async fn persist_messages() {
+  // Send request for keyserver to refresh keys (identity service)
+  let mut tunnelbroker_client =
+    TunnelbrokerServiceClient::connect("http://localhost:50051")
+      .await
+      .unwrap();
+
+  let refresh_request = messages::RefreshKeyRequest {
+    device_id: "bar".to_string(),
+    number_of_keys: 5,
+  };
+
+  let payload = serde_json::to_string(&refresh_request).unwrap();
+  let request = MessageToDevice {
+    device_id: "bar".to_string(),
+    payload,
+  };
+  let grpc_message = tonic::Request::new(request);
   tunnelbroker_client
     .send_message_to_device(grpc_message)
     .await
     .unwrap();
 
+  // Give tunnelbroker 50 milliseconds to persist the undelivered message
+  use std::{thread, time};
+  let persist_delay = time::Duration::from_millis(50);
+  thread::sleep(persist_delay);
+
+  // Create session as a keyserver
+  let (mut socket, _) = connect_async("ws://localhost:51001")
+    .await
+    .expect("Can't connect");
+
+  let session_request = r#"{
+    "type": "sessionRequest",
+    "accessToken": "xkdexfjsld",
+    "deviceId": "bar",
+    "deviceType": "keyserver"
+  }"#;
+
+  socket
+    .send(Message::Text(session_request.to_string()))
+    .await
+    .expect("Failed to send message");
+
   // Have keyserver receive any websocket messages
   let response = socket.next().await.unwrap().unwrap();
 
diff --git a/services/tunnelbroker/src/database.rs b/services/tunnelbroker/src/database.rs
--- a/services/tunnelbroker/src/database.rs
+++ b/services/tunnelbroker/src/database.rs
@@ -1,12 +1,52 @@
 use aws_config::SdkConfig;
-use aws_sdk_dynamodb::Client;
+use aws_sdk_dynamodb::error::SdkError;
+use aws_sdk_dynamodb::operation::delete_item::{
+  DeleteItemError, DeleteItemOutput,
+};
+use aws_sdk_dynamodb::operation::put_item::{PutItemError, PutItemOutput};
+use aws_sdk_dynamodb::operation::query::QueryError;
+use aws_sdk_dynamodb::{types::AttributeValue, Client};
+use std::collections::HashMap;
 use std::sync::Arc;
+use std::time::{SystemTime, UNIX_EPOCH};
+use tracing::{debug, error};
+
+use crate::constants::dynamodb::undelivered_messages::{
+  PARTITION_KEY, PAYLOAD, SORT_KEY, TABLE_NAME,
+};
 
 #[derive(Clone)]
 pub struct DatabaseClient {
   client: Arc<Client>,
 }
 
+/// Current time as nanoseconds since the Unix epoch.
+///
+/// Used as the undelivered-message sort key. PutItem replaces any existing
+/// item with the same partition and sort key, so second-level precision
+/// would let two messages queued for one device in the same second
+/// overwrite each other.
+pub fn unix_timestamp_nanos() -> u128 {
+  SystemTime::now()
+    .duration_since(UNIX_EPOCH)
+    .expect("System time is misconfigured")
+    .as_nanos()
+}
+
+/// Maps a DynamoDB SDK error to a gRPC status: timeouts and service errors
+/// become `unavailable`; anything else is logged as unexpected.
+pub fn handle_ddb_error<E>(db_error: SdkError<E>) -> tonic::Status {
+  match db_error {
+    SdkError::TimeoutError(_) | SdkError::ServiceError(_) => {
+      tonic::Status::unavailable("please retry")
+    }
+    e => {
+      error!("Encountered an unexpected error: {}", e);
+      tonic::Status::failed_precondition("unexpected error")
+    }
+  }
+}
+
 impl DatabaseClient {
   pub fn new(aws_config: &SdkConfig) -> Self {
     let client = Client::new(aws_config);
@@ -15,4 +55,96 @@
       client: Arc::new(client),
     }
   }
+
+  /// Stores `payload` as an undelivered message for `device_id`, keyed by
+  /// the device ID and the current time in nanoseconds.
+  pub async fn persist_message(
+    &self,
+    device_id: &str,
+    payload: &str,
+  ) -> Result<PutItemOutput, SdkError<PutItemError>> {
+    let device_av = AttributeValue::S(device_id.to_string());
+    let payload_av = AttributeValue::S(payload.to_string());
+    let created_av = AttributeValue::N(unix_timestamp_nanos().to_string());
+
+    let request = self
+      .client
+      .put_item()
+      .table_name(TABLE_NAME)
+      .item(PARTITION_KEY, device_av)
+      .item(SORT_KEY, created_av)
+      .item(PAYLOAD, payload_av);
+
+    debug!("Persisting message to device: {}", &device_id);
+
+    request.send().await
+  }
+
+  /// Returns every undelivered message stored for `device_id`, reading all
+  /// result pages with strongly consistent reads.
+  pub async fn retrieve_messages(
+    &self,
+    device_id: &str,
+  ) -> Result<Vec<HashMap<String, AttributeValue>>, SdkError<QueryError>> {
+    debug!("Retrieving messages for device: {}", device_id);
+
+    let mut messages = Vec::new();
+    let mut exclusive_start_key = None;
+
+    // A single Query call returns at most 1 MB of items, so keep following
+    // LastEvaluatedKey until the final page has been read.
+    loop {
+      let response = self
+        .client
+        .query()
+        .table_name(TABLE_NAME)
+        .key_condition_expression(format!("{} = :u", PARTITION_KEY))
+        .expression_attribute_values(
+          ":u",
+          AttributeValue::S(device_id.to_string()),
+        )
+        .consistent_read(true)
+        .set_exclusive_start_key(exclusive_start_key)
+        .send()
+        .await?;
+
+      messages.extend(response.items.unwrap_or_default());
+      exclusive_start_key = response.last_evaluated_key;
+      if exclusive_start_key.is_none() {
+        break;
+      }
+    }
+
+    debug!("Retrieved {} messages for {}", messages.len(), device_id);
+    Ok(messages)
+  }
+
+  /// Deletes the undelivered message for `device_id` whose sort key equals
+  /// `created_at`.
+  pub async fn delete_message(
+    &self,
+    device_id: &str,
+    created_at: &str,
+  ) -> Result<DeleteItemOutput, SdkError<DeleteItemError>> {
+    debug!("Deleting message for device: {}", device_id);
+
+    let key = HashMap::from([
+      (
+        PARTITION_KEY.to_string(),
+        AttributeValue::S(device_id.to_string()),
+      ),
+      (
+        SORT_KEY.to_string(),
+        AttributeValue::N(created_at.to_string()),
+      ),
+    ]);
+
+    self
+      .client
+      .delete_item()
+      .table_name(TABLE_NAME)
+      .set_key(Some(key))
+      .send()
+      .await
+  }
 }
diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs
--- a/services/tunnelbroker/src/grpc/mod.rs
+++ b/services/tunnelbroker/src/grpc/mod.rs
@@ -7,13 +7,14 @@
 };
 use proto::Empty;
 use tonic::transport::Server;
-use tonic::Status;
-use tracing::debug;
+use tracing::{debug, error};
 
+use crate::database::{handle_ddb_error, DatabaseClient};
 use crate::{constants, ACTIVE_CONNECTIONS, CONFIG};
 
-#[derive(Debug, Default)]
-struct TunnelbrokerGRPC {}
+struct TunnelbrokerGRPC {
+  client: DatabaseClient,
+}
 
 #[tonic::async_trait]
 impl TunnelbrokerService for TunnelbrokerGRPC {
@@ -24,18 +25,44 @@
     let message = request.into_inner();
 
     debug!("Received message for {}", &message.device_id);
-    // TODO: Persist messages for inactive connections
-    let tx = ACTIVE_CONNECTIONS
-      .get(&message.device_id)
-      .ok_or(Status::unavailable("Device does not exist"))?;
-    tx.send(message.payload).expect("Unable to send message");
+    // Clone the sender out of the map so the read guard returned by `get` is
+    // released before `remove` below; holding a DashMap reference while
+    // mutating the same map can deadlock.
+    let sender = ACTIVE_CONNECTIONS
+      .get(&message.device_id)
+      .map(|tx| tx.value().clone());
+
+    // Payload that could not be handed to a live connection, if any.
+    let undelivered_payload = match sender {
+      Some(tx) => match tx.send(message.payload) {
+        Ok(()) => None,
+        Err(send_error) => {
+          // The connection is gone: drop it and keep the message instead of
+          // losing it.
+          error!("Unable to send message to device: {}", &message.device_id);
+          ACTIVE_CONNECTIONS.remove(&message.device_id);
+          Some(send_error.0)
+        }
+      },
+      None => Some(message.payload),
+    };
+
+    if let Some(payload) = undelivered_payload {
+      self
+        .client
+        .persist_message(&message.device_id, &payload)
+        .await
+        .map_err(handle_ddb_error)?;
+    }
 
     let response = tonic::Response::new(Empty {});
     Ok(response)
   }
 }
 
-pub async fn run_server() -> Result<(), tonic::transport::Error> {
+pub async fn run_server(
+  client: DatabaseClient,
+) -> Result<(), tonic::transport::Error> {
   let addr = format!("[::1]:{}", CONFIG.grpc_port)
     .parse()
     .expect("Unable to parse gRPC address");
@@ -44,7 +54,7 @@
   Server::builder()
     .http2_keepalive_interval(Some(constants::GRPC_KEEP_ALIVE_PING_INTERVAL))
     .http2_keepalive_timeout(Some(constants::GRPC_KEEP_ALIVE_PING_TIMEOUT))
-    .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC::default()))
+    .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC { client }))
     .serve(addr)
     .await
 }
diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs
--- a/services/tunnelbroker/src/main.rs
+++ b/services/tunnelbroker/src/main.rs
@@ -28,10 +28,10 @@
 
   config::parse_cmdline_args()?;
   let aws_config = config::load_aws_config().await;
-  let _db_client = database::DatabaseClient::new(&aws_config);
+  let db_client = database::DatabaseClient::new(&aws_config);
 
-  let grpc_server = grpc::run_server();
-  let websocket_server = websockets::run_server();
+  let grpc_server = grpc::run_server(db_client.clone());
+  let websocket_server = websockets::run_server(db_client.clone());
 
   tokio::select! {
     Ok(_) = grpc_server => { Ok(()) },
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -1,7 +1,7 @@
 mod session;
 
+use crate::database::DatabaseClient;
 use crate::CONFIG;
-use futures::future;
 use futures_util::stream::SplitSink;
 use futures_util::SinkExt;
 use futures_util::{StreamExt, TryStreamExt};
@@ -12,11 +12,10 @@
 use tokio_tungstenite::tungstenite::Message;
 use tokio_tungstenite::WebSocketStream;
 use tracing::{debug, error, info};
-use tunnelbroker_messages::messages::Messages;
 
 use crate::ACTIVE_CONNECTIONS;
 
-pub async fn run_server() -> Result<(), Error> {
+pub async fn run_server(db_client: DatabaseClient) -> Result<(), Error> {
   let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR")
     .unwrap_or_else(|_| format!("127.0.0.1:{}", &CONFIG.http_port));
 
@@ -24,14 +23,18 @@
   info!("Listening on: {}", addr);
 
   while let Ok((stream, addr)) = listener.accept().await {
-    tokio::spawn(accept_connection(stream, addr));
+    tokio::spawn(accept_connection(stream, addr, db_client.clone()));
   }
 
   Ok(())
 }
 
 /// Handler for any incoming websocket connections
-async fn accept_connection(raw_stream: TcpStream, addr: SocketAddr) {
+async fn accept_connection(
+  raw_stream: TcpStream,
+  addr: SocketAddr,
+  db_client: DatabaseClient,
+) {
   debug!("Incoming connection from: {}", addr);
 
   let ws_stream = match tokio_tungstenite::accept_async(raw_stream).await {
@@ -49,12 +52,12 @@
   // Create channel for messages to be passed to this connection
   let (tx, mut rx) = mpsc::unbounded_channel::<String>();
 
-  let session = session::WebsocketSession::new(tx.clone());
-  let handle_incoming = incoming.try_for_each(|msg| {
+  let session = session::WebsocketSession::new(tx.clone(), db_client.clone());
+  let handle_incoming = incoming.try_for_each(|msg| async {
     debug!("Received message from {}", addr);
     match msg {
       Message::Text(text) => {
-        match session.handle_message_from_device(&text) {
+        match session.handle_message_from_device(&text).await {
           Ok(_) => {
             debug!("Successfully handled message: {}", text)
           }
@@ -67,8 +70,7 @@
         error!("Invalid message was received");
       }
     }
-
-    future::ok(())
+    Ok(())
   });
 
   debug!("Polling for messages from: {}", addr);
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -1,26 +1,81 @@
-use tracing::debug;
+use tracing::{debug, error};
 use tunnelbroker_messages::Messages;
 
-use crate::ACTIVE_CONNECTIONS;
+use crate::{
+  constants::dynamodb::undelivered_messages::{CREATED_AT, PAYLOAD},
+  database::DatabaseClient,
+  ACTIVE_CONNECTIONS,
+};
 
 pub struct WebsocketSession {
   tx: tokio::sync::mpsc::UnboundedSender<String>,
+  db_client: DatabaseClient,
 }
 
 impl WebsocketSession {
   pub fn new(
     tx: tokio::sync::mpsc::UnboundedSender<String>,
+    db_client: DatabaseClient,
   ) -> WebsocketSession {
-    WebsocketSession { tx }
+    WebsocketSession { tx, db_client }
   }
 
-  pub fn handle_message_from_device(
+  pub async fn handle_message_from_device(
     &self,
     message: &str,
   ) -> Result<(), serde_json::Error> {
     match serde_json::from_str::<Messages>(message)? {
       Messages::SessionRequest(session_info) => {
-        ACTIVE_CONNECTIONS.insert(session_info.device_id, self.tx.clone());
+        // TODO: Authenticate device using auth token
+        // Check for persisted messages. A database failure is logged instead
+        // of panicking so the device can still receive live messages.
+        let messages = match self
+          .db_client
+          .retrieve_messages(&session_info.device_id)
+          .await
+        {
+          Ok(messages) => messages,
+          Err(e) => {
+            error!("Failed to retrieve messages: {}", e);
+            Vec::new()
+          }
+        };
+
+        // NOTE: a message persisted after the query above but before this
+        // insert is only flushed on the device's next session request.
+        ACTIVE_CONNECTIONS
+          .insert(session_info.device_id.clone(), self.tx.clone());
+
+        for message in messages {
+          let (payload, created_at) = match (
+            message.get(PAYLOAD).and_then(|av| av.as_s().ok()),
+            message.get(CREATED_AT).and_then(|av| av.as_n().ok()),
+          ) {
+            (Some(payload), Some(created_at)) => (payload, created_at),
+            _ => {
+              error!("Skipping malformed persisted message");
+              continue;
+            }
+          };
+
+          // If the connection's channel is closed, stop flushing and keep
+          // the remaining messages persisted for the next session.
+          if self.tx.send(payload.to_string()).is_err() {
+            error!("Failed to send message to client");
+            break;
+          }
+
+          // A failed delete only risks redelivery, so log and continue.
+          if let Err(e) = self
+            .db_client
+            .delete_message(&session_info.device_id, created_at)
+            .await
+          {
+            error!("Failed to delete message: {}", e);
+          }
+        }
+
+        debug!("Flushed messages for device: {}", &session_info.device_id);
       }
       _ => {
         debug!("Received invalid request");