diff --git a/services/tunnelbroker/Cargo.lock b/services/tunnelbroker/Cargo.lock
--- a/services/tunnelbroker/Cargo.lock
+++ b/services/tunnelbroker/Cargo.lock
@@ -968,19 +968,6 @@
  "syn 2.0.15",
 ]
 
-[[package]]
-name = "dashmap"
-version = "5.4.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
-dependencies = [
- "cfg-if",
- "hashbrown",
- "lock_api",
- "once_cell",
- "parking_lot_core",
-]
-
 [[package]]
 name = "derive_more"
 version = "0.99.17"
@@ -2969,7 +2956,6 @@
  "aws-types",
  "base64 0.20.0",
  "clap",
- "dashmap",
  "derive_more",
  "env_logger",
  "fcm",
diff --git a/services/tunnelbroker/Cargo.toml b/services/tunnelbroker/Cargo.toml
--- a/services/tunnelbroker/Cargo.toml
+++ b/services/tunnelbroker/Cargo.toml
@@ -15,7 +15,6 @@
 aws-types = "0.55"
 base64 = "0.20"
 clap = { version = "4.2", features = ["derive", "env"] }
-dashmap = "5.4"
 env_logger = "0.9"
 fcm = "0.9"
 futures = "0.3"
diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs
--- a/services/tunnelbroker/src/grpc/mod.rs
+++ b/services/tunnelbroker/src/grpc/mod.rs
@@ -2,20 +2,36 @@
   tonic::include_proto!("tunnelbroker");
 }
 
+use lapin::{
+  options::{BasicPublishOptions, ConfirmSelectOptions},
+  publisher_confirm::Confirmation,
+  BasicProperties,
+};
 use proto::tunnelbroker_service_server::{
   TunnelbrokerService, TunnelbrokerServiceServer,
 };
 use proto::Empty;
 use tonic::transport::Server;
 use tracing::{debug, error};
 
 use crate::database::{handle_ddb_error, DatabaseClient};
-use crate::{constants, ACTIVE_CONNECTIONS, CONFIG};
+use crate::{constants, CONFIG};
 
 struct TunnelbrokerGRPC {
   client: DatabaseClient,
+  // Publisher confirms are enabled on this channel (see `run_server`)
+  amqp_channel: lapin::Channel,
 }
 
+// With `mandatory` set, the broker hands back any message it cannot route to
+// a queue (no queue has been declared for the device) instead of silently
+// dropping it. The returned message is only reported through publisher
+// confirms, so the channel must be in confirm mode for this to be observable.
+const PUBLISH_OPTIONS: BasicPublishOptions = BasicPublishOptions {
+  immediate: false,
+  mandatory: true,
+};
+
 #[tonic::async_trait]
 impl TunnelbrokerService for TunnelbrokerGRPC {
   async fn send_message_to_device(
@@ -25,11 +41,32 @@
     let message = request.into_inner();
 
     debug!("Received message for {}", &message.device_id);
-    if let Some(tx) = ACTIVE_CONNECTIONS.get(&message.device_id) {
-      if let Err(_) = tx.send(message.payload) {
-        error!("Unable to send message to device: {}", &message.device_id);
-        ACTIVE_CONNECTIONS.remove(&message.device_id);
-      }
+    let delivered = match self
+      .amqp_channel
+      .basic_publish(
+        "",
+        &message.device_id,
+        PUBLISH_OPTIONS,
+        message.payload.as_bytes(),
+        BasicProperties::default(),
+      )
+      .await
+    {
+      // An `Ack` without a returned message means the broker routed the
+      // message to the device's queue. A returned (unroutable) message, a
+      // `Nack`, or a failed confirmation means it was not delivered, so it
+      // falls through to being persisted.
+      Ok(confirmation) => {
+        matches!(confirmation.await, Ok(Confirmation::Ack(None)))
+      }
+      Err(e) => {
+        error!("Failed to publish message for {}: {}", &message.device_id, e);
+        false
+      }
+    };
+
+    if delivered {
+      debug!("Forwarded message: {:?}", &message);
     } else {
       self
         .client
@@ -45,16 +82,31 @@
 
 pub async fn run_server(
   client: DatabaseClient,
+  amqp_connection: &lapin::Connection,
 ) -> Result<(), tonic::transport::Error> {
   let addr = format!("[::1]:{}", CONFIG.grpc_port)
     .parse()
     .expect("Unable to parse gRPC address");
 
-  tracing::info!("Websocket server listening on {}", &addr);
+  let amqp_channel = amqp_connection
+    .create_channel()
+    .await
+    .expect("Unable to create amqp channel");
+  // Confirm mode lets `send_message_to_device` learn whether a message was
+  // routed to a device queue or returned as unroutable.
+  amqp_channel
+    .confirm_select(ConfirmSelectOptions::default())
+    .await
+    .expect("Unable to enable publisher confirms on amqp channel");
+
+  tracing::info!("gRPC server listening on {}", &addr);
   Server::builder()
     .http2_keepalive_interval(Some(constants::GRPC_KEEP_ALIVE_PING_INTERVAL))
     .http2_keepalive_timeout(Some(constants::GRPC_KEEP_ALIVE_PING_TIMEOUT))
-    .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC { client }))
+    .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC {
+      client,
+      amqp_channel,
+    }))
     .serve(addr)
     .await
 }
diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs
--- a/services/tunnelbroker/src/main.rs
+++ b/services/tunnelbroker/src/main.rs
@@ -7,15 +7,9 @@
 
 use anyhow::{anyhow, Result};
 use config::CONFIG;
-use dashmap::DashMap;
-use once_cell::sync::Lazy;
-use tokio::sync::mpsc::UnboundedSender;
 use tracing::{self, Level};
 use tracing_subscriber::EnvFilter;
 
-pub static ACTIVE_CONNECTIONS: Lazy<DashMap<String, UnboundedSender<String>>> =
-  Lazy::new(DashMap::new);
-
 #[tokio::main]
 async fn main() -> Result<()> {
   let filter = EnvFilter::builder()
@@ -32,8 +26,9 @@
   let db_client = database::DatabaseClient::new(&aws_config);
   let amqp_connection = amqp::connect().await;
 
-  let grpc_server = grpc::run_server(db_client.clone());
-  let websocket_server = websockets::run_server(db_client.clone());
+  let grpc_server = grpc::run_server(db_client.clone(), &amqp_connection);
+  let websocket_server =
+    websockets::run_server(db_client.clone(), &amqp_connection);
 
   tokio::select! {
     Ok(_) = grpc_server => { Ok(()) },
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -3,13 +3,17 @@
 use crate::database::DatabaseClient;
 use crate::CONFIG;
 use futures_util::StreamExt;
+use lapin::options::{BasicAckOptions, BasicConsumeOptions, QueueDeclareOptions};
+use lapin::types::FieldTable;
 use std::net::SocketAddr;
 use std::{env, io::Error};
 use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::mpsc;
-use tracing::{debug, info};
+use tracing::{debug, error, info};
 
-pub async fn run_server(db_client: DatabaseClient) -> Result<(), Error> {
+pub async fn run_server(
+  db_client: DatabaseClient,
+  amqp_connection: &lapin::Connection,
+) -> Result<(), Error> {
   let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR")
     .unwrap_or_else(|_| format!("127.0.0.1:{}", &CONFIG.http_port));
 
@@ -17,7 +21,11 @@
   info!("Listening on: {}", addr);
 
   while let Ok((stream, addr)) = listener.accept().await {
-    tokio::spawn(accept_connection(stream, addr, db_client.clone()));
+    let channel = amqp_connection
+      .create_channel()
+      .await
+      .expect("Unable to create amqp channel");
+    tokio::spawn(accept_connection(stream, addr, db_client.clone(), channel));
   }
 
   Ok(())
@@ -28,6 +36,7 @@
   raw_stream: TcpStream,
   addr: SocketAddr,
   db_client: DatabaseClient,
+  amqp_channel: lapin::Channel,
 ) {
   debug!("Incoming connection from: {}", addr);
 
@@ -43,20 +52,88 @@
   };
 
   let (outgoing, mut incoming) = ws_stream.split();
-  // Create channel for messages to be passed to this connection
-  let (tx, mut rx) = mpsc::unbounded_channel::<String>();
 
-  let mut session = session::WebsocketSession::new(outgoing, db_client.clone());
+  // We don't know the identity of the device until it sends the session
+  // request over the websocket connection
+  let mut session = if let Some(Ok(first_msg)) = incoming.next().await {
+    if let Ok(session) = session::WebsocketSession::from_frame(
+      outgoing,
+      db_client.clone(),
+      first_msg,
+    ) {
+      // TODO: Authenticate device
+      session
+    } else {
+      error!("Device failed to send valid connection request.");
+      return;
+    }
+  } else {
+    error!("Device closed connection before sending first message");
+    return;
+  };
+
+  // Declare the device's queue and start consuming from it before flushing
+  // persisted messages. Otherwise a message published in between would find
+  // no queue, be persisted, and not be delivered until the next connection.
+  // NOTE(review): the queue uses default options (not exclusive, not
+  // auto-delete), so it outlives this connection and keeps accepting
+  // messages for the device while it is offline; confirm this is intended.
+  if let Err(e) = amqp_channel
+    .queue_declare(
+      &session.device_info.device_id,
+      QueueDeclareOptions::default(),
+      FieldTable::default(),
+    )
+    .await
+  {
+    error!(
+      "Failed to create amqp queue for device: {}: {}",
+      &session.device_info.device_id, e
+    );
+    return;
+  }
+
+  let mut amqp_consumer = match amqp_channel
+    .basic_consume(
+      &session.device_info.device_id,
+      "tunnelbroker",
+      BasicConsumeOptions::default(),
+      FieldTable::default(),
+    )
+    .await
+  {
+    Ok(consumer) => consumer,
+    Err(e) => {
+      error!("Failed to create amqp consumer: {}", e);
+      return;
+    }
+  };
+
+  session::consume_error(session.deliver_persisted_messages().await);
 
-  // Poll for messages either being sent to the device (rx)
+  // Poll for messages either being sent to the device (amqp_consumer)
   // or messages being received from the device (incoming)
   loop {
     debug!("Polling for messages from: {}", addr);
     tokio::select! {
-      Some(message) = rx.recv() => { session.send_message_to_device(message).await; },
+      Some(Ok(delivery)) = amqp_consumer.next() => {
+        if let Ok(message) = std::str::from_utf8(&delivery.data) {
+          session.send_message_to_device(message.to_string()).await;
+        } else {
+          error!("Invalid payload");
+        }
+        // `BasicConsumeOptions::default()` leaves `no_ack` off, so every
+        // delivery must be acked; unacked deliveries are requeued and
+        // redelivered once the channel closes.
+        if let Err(e) = delivery.ack(BasicAckOptions::default()).await {
+          error!("Failed to acknowledge amqp delivery: {}", e);
+        }
+      },
       device_message = incoming.next() => {
         match device_message {
-          Some(Ok(msg)) => session.handle_websocket_frame_from_device(msg, tx.clone()).await,
+          Some(Ok(msg)) => {
+            session::consume_error(session.handle_websocket_frame_from_device(msg).await);
+          }
           _ => {
             debug!("Connection to {} closed remotely.", addr);
             break;
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -1,15 +1,12 @@
 use derive_more;
 use futures_util::stream::SplitSink;
 use futures_util::SinkExt;
-use tokio::{net::TcpStream, sync::mpsc::UnboundedSender};
+use tokio::net::TcpStream;
 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
 use tracing::{debug, error};
 use tunnelbroker_messages::{session::DeviceTypes, Messages};
 
-use crate::{
-  database::{self, DatabaseClient, DeviceMessage},
-  ACTIVE_CONNECTIONS,
-};
+use crate::database::{self, DatabaseClient, DeviceMessage};
 
 pub struct DeviceInfo {
   pub device_id: String,
@@ -22,7 +19,7 @@
 pub struct WebsocketSession {
   tx: SplitSink<WebSocketStream<TcpStream>, Message>,
   db_client: DatabaseClient,
-  device_info: Option<DeviceInfo>,
+  pub device_info: DeviceInfo,
 }
 
 #[derive(Debug, derive_more::Display, derive_more::From)]
@@ -32,99 +29,105 @@
   MessageError(database::MessageErrors),
 }
 
-fn consume_error<T>(result: Result<T, SessionError>) {
+pub fn consume_error<T>(result: Result<T, SessionError>) {
   if let Err(e) = result {
     error!("{}", e)
   }
 }
+
+// Parse a session request and retrieve the device information
+pub fn handle_first_message_from_device(
+  message: &str,
+) -> Result<DeviceInfo, SessionError> {
+  let serialized_message = serde_json::from_str::<Messages>(message)?;
+
+  match serialized_message {
+    Messages::SessionRequest(mut session_info) => {
+      let device_info = DeviceInfo {
+        device_id: session_info.device_id.clone(),
+        notify_token: session_info.notify_token.take(),
+        device_type: session_info.device_type,
+        device_app_version: session_info.device_app_version.take(),
+        device_os: session_info.device_os.take(),
+      };
+
+      Ok(device_info)
+    }
+    _ => {
+      debug!("Received invalid request");
+      Err(SessionError::InvalidMessage)
+    }
+  }
+}
+
 impl WebsocketSession {
-  pub fn new(
+  pub fn from_frame(
     tx: SplitSink<WebSocketStream<TcpStream>, Message>,
     db_client: DatabaseClient,
-  ) -> WebsocketSession {
-    WebsocketSession {
+    frame: Message,
+  ) -> Result<WebsocketSession, SessionError> {
+    let device_info = match frame {
+      Message::Text(payload) => handle_first_message_from_device(&payload)?,
+      _ => {
+        error!("Client sent wrong frame type for establishing connection");
+        return Err(SessionError::InvalidMessage);
+      }
+    };
+
+    Ok(WebsocketSession {
       tx,
       db_client,
-      device_info: None,
-    }
+      device_info,
+    })
   }
 
   pub async fn handle_websocket_frame_from_device(
     &mut self,
     frame: Message,
-    tx: UnboundedSender<String>,
-  ) {
-    debug!("Received message from device: {}", frame);
-    let result = match frame {
+  ) -> Result<(), SessionError> {
+    match frame {
       Message::Text(payload) => {
-        self.handle_message_from_device(&payload, tx).await
+        debug!("Received message from device: {}", payload);
+        Ok(())
       }
       Message::Close(_) => {
         self.close().await;
         Ok(())
       }
       _ => Err(SessionError::InvalidMessage),
-    };
-    consume_error(result);
+    }
   }
 
-  pub async fn handle_message_from_device(
+  pub async fn deliver_persisted_messages(
     &mut self,
-    message: &str,
-    tx: UnboundedSender<String>,
   ) -> Result<(), SessionError> {
-    let serialized_message = serde_json::from_str::<Messages>(message)?;
-
-    match serialized_message {
-      Messages::SessionRequest(mut session_info) => {
-        // TODO: Authenticate device using auth token
-
-        // Check if session request was already sent
-        if self.device_info.is_some() {
-          return Err(SessionError::InvalidMessage);
-        }
-
-        let device_info = DeviceInfo {
-          device_id: session_info.device_id.clone(),
-          notify_token: session_info.notify_token.take(),
-          device_type: session_info.device_type,
-          device_app_version: session_info.device_app_version.take(),
-          device_os: session_info.device_os.take(),
-        };
-
-        // Check for persisted messages
-        let messages = self
-          .db_client
-          .retrieve_messages(&device_info.device_id)
-          .await
-          .unwrap_or_else(|e| {
-            error!("Error while retrieving messages: {}", e);
-            Vec::new()
-          });
-
-        ACTIVE_CONNECTIONS.insert(device_info.device_id.clone(), tx.clone());
-
-        for message in messages {
-          let device_message = DeviceMessage::from_hashmap(message)?;
-          self.send_message_to_device(device_message.payload).await;
-          if let Err(e) = self
-            .db_client
-            .delete_message(&device_info.device_id, &device_message.created_at)
-            .await
-          {
-            error!("Failed to delete message: {}:", e);
-          }
-        }
-
-        debug!("Flushed messages for device: {}", &session_info.device_id);
-
-        self.device_info = Some(device_info);
-      }
-      _ => {
-        debug!("Received invalid request");
+    // Check for persisted messages
+    let messages = self
+      .db_client
+      .retrieve_messages(&self.device_info.device_id)
+      .await
+      .unwrap_or_else(|e| {
+        error!("Error while retrieving messages: {}", e);
+        Vec::new()
+      });
+
+    for message in messages {
+      let device_message = DeviceMessage::from_hashmap(message)?;
+      self.send_message_to_device(device_message.payload).await;
+      if let Err(e) = self
+        .db_client
+        .delete_message(&self.device_info.device_id, &device_message.created_at)
+        .await
+      {
+        error!("Failed to delete message: {}", e);
       }
     }
 
+    debug!(
+      "Flushed messages for device: {}",
+      &self.device_info.device_id
+    );
+
     Ok(())
   }
 
@@ -136,10 +139,6 @@
 
-  // Release websocket and remove from active connections
+  // Release the websocket connection
   pub async fn close(&mut self) {
-    if let Some(device_info) = &self.device_info {
-      ACTIVE_CONNECTIONS.remove(&device_info.device_id);
-    }
-
     if let Err(e) = self.tx.close().await {
       debug!("Failed to close session: {}", e);
     }