Changeset View
Changeset View
Standalone View
Standalone View
services/tunnelbroker/src/websockets/mod.rs
use futures::future; | use futures::future; | ||||
use futures_util::stream::SplitSink; | |||||
use futures_util::SinkExt; | |||||
use futures_util::{StreamExt, TryStreamExt}; | use futures_util::{StreamExt, TryStreamExt}; | ||||
use std::net::SocketAddr; | use std::net::SocketAddr; | ||||
use std::{env, io::Error}; | use std::{env, io::Error}; | ||||
use tokio::net::{TcpListener, TcpStream}; | use tokio::net::{TcpListener, TcpStream}; | ||||
use tokio::sync::mpsc; | use tokio::sync::mpsc; | ||||
use tokio_tungstenite::tungstenite::Message; | use tokio_tungstenite::tungstenite::Message; | ||||
use tokio_tungstenite::WebSocketStream; | |||||
use tracing::{debug, error, info}; | use tracing::{debug, error, info}; | ||||
use tunnelbroker_messages::messages::Messages; | use tunnelbroker_messages::messages::Messages; | ||||
use crate::ACTIVE_CONNECTIONS; | use crate::ACTIVE_CONNECTIONS; | ||||
pub async fn run_server() -> Result<(), Error> { | pub async fn run_server() -> Result<(), Error> { | ||||
let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") | let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") | ||||
.unwrap_or_else(|_| "127.0.0.1:51001".to_string()); | .unwrap_or_else(|_| "127.0.0.1:51001".to_string()); | ||||
Show All 18 Lines | Err(e) => { | ||||
info!( | info!( | ||||
"Failed to establish connection with {}. Reason: {}", | "Failed to establish connection with {}. Reason: {}", | ||||
addr, e | addr, e | ||||
); | ); | ||||
return; | return; | ||||
} | } | ||||
}; | }; | ||||
let (_outgoing, incoming) = ws_stream.split(); | let (mut outgoing, incoming) = ws_stream.split(); | ||||
// Create channel for messages to be passed to this connection | |||||
let (tx, mut rx) = mpsc::unbounded_channel::<String>(); | |||||
let handle_incoming = incoming.try_for_each(|msg| { | let handle_incoming = incoming.try_for_each(|msg| { | ||||
debug!("Received message from {}", addr); | debug!("Received message from {}", addr); | ||||
match msg { | match msg { | ||||
Message::Text(text) => { | Message::Text(text) => { | ||||
match handle_message(&text) { | match handle_message_from_device(&text, &tx) { | ||||
Ok(_) => { | Ok(_) => { | ||||
debug!("Successfully handled message: {}", text) | debug!("Successfully handled message: {}", text) | ||||
} | } | ||||
Err(e) => { | Err(e) => { | ||||
error!("Failed to process message: {}", e); | error!("Failed to process message: {}", e); | ||||
} | } | ||||
}; | }; | ||||
} | } | ||||
_ => { | _ => { | ||||
error!("Invalid message was received"); | error!("Invalid message was received"); | ||||
} | } | ||||
} | } | ||||
future::ok(()) | future::ok(()) | ||||
}); | }); | ||||
// Create channel for messages to be passed to this connection | |||||
let (tx, mut rx) = mpsc::unbounded_channel::<Messages>(); | |||||
// TODO: Use device's public key, once we support the SessionRequest message | |||||
ACTIVE_CONNECTIONS.insert("test".to_string(), tx.clone()); | |||||
debug!("Polling for messages from: {}", addr); | debug!("Polling for messages from: {}", addr); | ||||
// Poll for messages either being sent to the device (rx) | |||||
// or messages being received from the device (handle_incoming) | |||||
tokio::select! { | tokio::select! { | ||||
Some(_) = rx.recv() => { debug!("Received message from channel") }, | Some(message) = rx.recv() => { handle_message_from_service(message, &mut outgoing).await; }, | ||||
Ok(_) = handle_incoming => { debug!("Received message from websocket") }, | Ok(_) = handle_incoming => { debug!("Received message from websocket") }, | ||||
else => { | else => { | ||||
info!("Connection with {} closed.", addr); | info!("Connection with {} closed.", addr); | ||||
ACTIVE_CONNECTIONS.remove("test"); | ACTIVE_CONNECTIONS.remove("test"); | ||||
} | } | ||||
} | } | ||||
} | } | ||||
fn handle_message(message: &str) -> Result<(), serde_json::Error> { | fn handle_message_from_device( | ||||
serde_json::from_str::<Messages>(message)?; | message: &str, | ||||
tx: &tokio::sync::mpsc::UnboundedSender<std::string::String>, | |||||
) -> Result<(), serde_json::Error> { | |||||
match serde_json::from_str::<Messages>(message)? { | |||||
Messages::SessionRequest(session_info) => { | |||||
ACTIVE_CONNECTIONS.insert(session_info.device_id, tx.clone()); | |||||
} | |||||
_ => { | |||||
debug!("Received invalid request"); | |||||
} | |||||
} | |||||
Ok(()) | Ok(()) | ||||
} | } | ||||
async fn handle_message_from_service( | |||||
incoming_payload: String, | |||||
outgoing: &mut SplitSink<WebSocketStream<tokio::net::TcpStream>, Message>, | |||||
) { | |||||
if let Err(e) = outgoing.send(Message::Text(incoming_payload)).await { | |||||
error!("Failed to send message to device: {}", e); | |||||
} | |||||
} |