diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs index 88cbcd63f..6b1f150f7 100644 --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,130 +1,131 @@ mod session; use crate::database::DatabaseClient; use crate::websockets::session::SessionError; use crate::CONFIG; use futures_util::stream::SplitSink; use futures_util::StreamExt; use std::net::SocketAddr; use std::{env, io::Error}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::WebSocketStream; use tracing::{debug, error, info}; use self::session::WebsocketSession; pub async fn run_server( db_client: DatabaseClient, amqp_connection: &lapin::Connection, ) -> Result<(), Error> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") .unwrap_or_else(|_| format!("0.0.0.0:{}", &CONFIG.http_port)); let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("WebSocket listening on: {}", addr); while let Ok((stream, addr)) = listener.accept().await { let channel = amqp_connection .create_channel() .await .expect("Unable to create amqp channel"); tokio::spawn(accept_connection(stream, addr, db_client.clone(), channel)); } Ok(()) } /// Handler for any incoming websocket connections async fn accept_connection( raw_stream: TcpStream, addr: SocketAddr, db_client: DatabaseClient, amqp_channel: lapin::Channel, ) { debug!("Incoming connection from: {}", addr); let ws_stream = match tokio_tungstenite::accept_async(raw_stream).await { Ok(stream) => stream, Err(e) => { info!( "Failed to establish connection with {}. Reason: {}", addr, e ); return; } }; let (outgoing, mut incoming) = ws_stream.split(); // We don't know the identity of the device until it sends the session // request over the websocket connection let mut session = if let Some(Ok(first_msg)) = incoming.next().await { match initiate_session(outgoing, first_msg, db_client, amqp_channel).await { Ok(session) => session, Err(_) => { error!("Failed to create session with device"); return; } } } else { error!("Failed to create session with device"); return; }; // Poll for messages either being sent to the device (rx) // or messages being received from the device (incoming) loop { debug!("Polling for messages from: {}", addr); tokio::select! { Some(Ok(delivery)) = session.next_amqp_message() => { if let Ok(message) = std::str::from_utf8(&delivery.data) { session.send_message_to_device(message.to_string()).await; } else { error!("Invalid payload"); } }, device_message = incoming.next() => { match device_message { Some(Ok(msg)) => { session::consume_error(session.handle_websocket_frame_from_device(msg).await); } _ => { debug!("Connection to {} closed remotely.", addr); break; } } }, else => { debug!("Unhealthy connection for: {}", addr); break; }, } } info!("Unregistering connection to: {}", addr); session.close().await } -async fn initiate_session( - outgoing: SplitSink, Message>, +async fn initiate_session( + outgoing: SplitSink, Message>, frame: Message, db_client: DatabaseClient, amqp_channel: lapin::Channel, -) -> Result { +) -> Result, session::SessionError> { let mut session = session::WebsocketSession::from_frame( outgoing, db_client.clone(), frame, &amqp_channel, ) .await .map_err(|_| { error!("Device failed to send valid connection request."); SessionError::InvalidMessage })?; session::consume_error(session.deliver_persisted_messages().await); Ok(session) } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs index fb709477b..3277c952f 100644 --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,172 +1,173 @@ use derive_more; use futures_util::stream::SplitSink; use futures_util::SinkExt; use futures_util::StreamExt; use lapin::message::Delivery; use lapin::options::{BasicConsumeOptions, QueueDeclareOptions}; use lapin::types::FieldTable; -use tokio::net::TcpStream; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; use tracing::{debug, error}; use tunnelbroker_messages::{session::DeviceTypes, Messages}; use crate::database::{self, DatabaseClient, DeviceMessage}; pub struct DeviceInfo { pub device_id: String, pub notify_token: Option, pub device_type: DeviceTypes, pub device_app_version: Option, pub device_os: Option, } -pub struct WebsocketSession { - tx: SplitSink, Message>, +pub struct WebsocketSession { + tx: SplitSink, Message>, db_client: DatabaseClient, pub device_info: DeviceInfo, // Stream of messages from AMQP endpoint amqp_consumer: lapin::Consumer, } #[derive(Debug, derive_more::Display, derive_more::From)] pub enum SessionError { InvalidMessage, SerializationError(serde_json::Error), MessageError(database::MessageErrors), AmqpError(lapin::Error), } pub fn consume_error(result: Result) { if let Err(e) = result { error!("{}", e) } } // Parse a session request and retrieve the device information pub fn handle_first_message_from_device( message: &str, ) -> Result { let serialized_message = serde_json::from_str::(message)?; match serialized_message { Messages::ConnectionInitializationMessage(mut session_info) => { let device_info = DeviceInfo { device_id: session_info.device_id.clone(), notify_token: session_info.notify_token.take(), device_type: session_info.device_type, device_app_version: session_info.device_app_version.take(), device_os: session_info.device_os.take(), }; - return Ok(device_info); + Ok(device_info) } _ => { debug!("Received invalid request"); - return Err(SessionError::InvalidMessage); + Err(SessionError::InvalidMessage) } } } -impl WebsocketSession { +impl WebsocketSession { pub async fn from_frame( - tx: SplitSink, Message>, + tx: SplitSink, Message>, db_client: DatabaseClient, frame: Message, amqp_channel: &lapin::Channel, - ) -> Result { + ) -> Result, SessionError> { let device_info = match frame { Message::Text(payload) => handle_first_message_from_device(&payload)?, _ => { error!("Client sent wrong frame type for establishing connection"); return Err(SessionError::InvalidMessage); } }; // We don't currently have a use case to interact directly with the queue, // however, we need to declare a queue for a given device amqp_channel .queue_declare( &device_info.device_id, QueueDeclareOptions::default(), FieldTable::default(), ) .await?; let amqp_consumer = amqp_channel .basic_consume( &device_info.device_id, "tunnelbroker", BasicConsumeOptions::default(), FieldTable::default(), ) .await?; Ok(WebsocketSession { tx, db_client, device_info, amqp_consumer, }) } pub async fn handle_websocket_frame_from_device( &self, msg: Message, ) -> Result<(), SessionError> { debug!("Received frame: {:?}", msg); Ok(()) } pub async fn next_amqp_message( &mut self, ) -> Option> { self.amqp_consumer.next().await } pub async fn deliver_persisted_messages( &mut self, ) -> Result<(), SessionError> { // Check for persisted messages let messages = self .db_client .retrieve_messages(&self.device_info.device_id) .await .unwrap_or_else(|e| { error!("Error while retrieving messages: {}", e); Vec::new() }); for message in messages { let device_message = DeviceMessage::from_hashmap(message)?; self.send_message_to_device(device_message.payload).await; if let Err(e) = self .db_client .delete_message(&self.device_info.device_id, &device_message.created_at) .await { error!("Failed to delete message: {}:", e); } } debug!( "Flushed messages for device: {}", &self.device_info.device_id ); Ok(()) } pub async fn send_message_to_device(&mut self, incoming_payload: String) { if let Err(e) = self.tx.send(Message::Text(incoming_payload)).await { error!("Failed to send message to device: {}", e); } } // Release websocket and remove from active connections pub async fn close(&mut self) { if let Err(e) = self.tx.close().await { debug!("Failed to close session: {}", e); } } }