diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs index 6bcf41c46..889ca750f 100644 --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,107 +1,94 @@ +mod session; + use crate::CONFIG; use futures::future; use futures_util::stream::SplitSink; use futures_util::SinkExt; use futures_util::{StreamExt, TryStreamExt}; use std::net::SocketAddr; use std::{env, io::Error}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::WebSocketStream; use tracing::{debug, error, info}; use tunnelbroker_messages::messages::Messages; use crate::ACTIVE_CONNECTIONS; pub async fn run_server() -> Result<(), Error> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") .unwrap_or_else(|_| format!("127.0.0.1:{}", &CONFIG.http_port)); let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("Listening on: {}", addr); while let Ok((stream, addr)) = listener.accept().await { tokio::spawn(accept_connection(stream, addr)); } Ok(()) } /// Handler for any incoming websocket connections async fn accept_connection(raw_stream: TcpStream, addr: SocketAddr) { debug!("Incoming connection from: {}", addr); let ws_stream = match tokio_tungstenite::accept_async(raw_stream).await { Ok(stream) => stream, Err(e) => { info!( "Failed to establish connection with {}. Reason: {}", addr, e ); return; } }; let (mut outgoing, incoming) = ws_stream.split(); // Create channel for messages to be passed to this connection let (tx, mut rx) = mpsc::unbounded_channel::(); + let session = session::WebsocketSession::new(tx.clone()); let handle_incoming = incoming.try_for_each(|msg| { debug!("Received message from {}", addr); match msg { Message::Text(text) => { - match handle_message_from_device(&text, &tx) { + match session.handle_message_from_device(&text) { Ok(_) => { debug!("Successfully handled message: {}", text) } Err(e) => { error!("Failed to process message: {}", e); } }; } _ => { error!("Invalid message was received"); } } future::ok(()) }); debug!("Polling for messages from: {}", addr); // Poll for messages either being sent to the device (rx) // or messages being received from the device (handle_incoming) tokio::select! { Some(message) = rx.recv() => { handle_message_from_service(message, &mut outgoing).await; }, Ok(_) = handle_incoming => { debug!("Received message from websocket") }, else => { info!("Connection with {} closed.", addr); ACTIVE_CONNECTIONS.remove("test"); } } } -fn handle_message_from_device( - message: &str, - tx: &tokio::sync::mpsc::UnboundedSender, -) -> Result<(), serde_json::Error> { - match serde_json::from_str::(message)? { - Messages::SessionRequest(session_info) => { - ACTIVE_CONNECTIONS.insert(session_info.device_id, tx.clone()); - } - _ => { - debug!("Received invalid request"); - } - } - - Ok(()) -} - async fn handle_message_from_service( incoming_payload: String, outgoing: &mut SplitSink, Message>, ) { if let Err(e) = outgoing.send(Message::Text(incoming_payload)).await { error!("Failed to send message to device: {}", e); } } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs new file mode 100644 index 000000000..53789514f --- /dev/null +++ b/services/tunnelbroker/src/websockets/session.rs @@ -0,0 +1,32 @@ +use tracing::debug; +use tunnelbroker_messages::Messages; + +use crate::ACTIVE_CONNECTIONS; + +pub struct WebsocketSession { + tx: tokio::sync::mpsc::UnboundedSender, +} + +impl WebsocketSession { + pub fn new( + tx: tokio::sync::mpsc::UnboundedSender, + ) -> WebsocketSession { + WebsocketSession { tx } + } + + pub fn handle_message_from_device( + &self, + message: &str, + ) -> Result<(), serde_json::Error> { + match serde_json::from_str::(message)? { + Messages::SessionRequest(session_info) => { + ACTIVE_CONNECTIONS.insert(session_info.device_id, self.tx.clone()); + } + _ => { + debug!("Received invalid request"); + } + } + + Ok(()) + } +}