diff --git a/services/tunnelbroker/src/constants.rs b/services/tunnelbroker/src/constants.rs index e80b50e7e..7251938bb 100644 --- a/services/tunnelbroker/src/constants.rs +++ b/services/tunnelbroker/src/constants.rs @@ -1,60 +1,61 @@ use tokio::time::Duration; pub const GRPC_TX_QUEUE_SIZE: usize = 32; pub const GRPC_SERVER_PORT: u16 = 50051; pub const GRPC_KEEP_ALIVE_PING_INTERVAL: Duration = Duration::from_secs(3); pub const GRPC_KEEP_ALIVE_PING_TIMEOUT: Duration = Duration::from_secs(10); pub const SOCKET_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(3); pub const MAX_RMQ_MSG_PRIORITY: u8 = 10; pub const DDB_RMQ_MSG_PRIORITY: u8 = 10; pub const CLIENT_RMQ_MSG_PRIORITY: u8 = 1; pub const RMQ_CONSUMER_TAG: &str = "tunnelbroker"; +pub const WS_SESSION_CLOSE_AMQP_MSG: &str = "SessionClose"; pub const ENV_APNS_CONFIG: &str = "APNS_CONFIG"; pub const ENV_FCM_CONFIG: &str = "FCM_CONFIG"; pub const ENV_WEB_PUSH_CONFIG: &str = "WEB_PUSH_CONFIG"; pub const ENV_WNS_CONFIG: &str = "WNS_CONFIG"; pub const LOG_LEVEL_ENV_VAR: &str = tracing_subscriber::filter::EnvFilter::DEFAULT_ENV; pub const FCM_ACCESS_TOKEN_GENERATION_THRESHOLD: u64 = 5 * 60; pub mod dynamodb { // This table holds messages which could not be immediately delivered to // a device. // // - (primary key) = (deviceID: Partition Key, createdAt: Sort Key) // - deviceID: The public key of a device's olm identity key // - payload: Message to be delivered. See shared/tunnelbroker_messages. // - messageID = [createdAt]#[clientMessageID] // - createdAd: UNIX timestamp of when the item was inserted. // Timestamp is needed to order the messages correctly to the device. // Timestamp format is ISO 8601 to handle lexicographical sorting. // - clientMessageID: Message ID generated on client using UUID Version 4. pub mod undelivered_messages { pub const TABLE_NAME: &str = "tunnelbroker-undelivered-messages"; pub const PARTITION_KEY: &str = "deviceID"; pub const DEVICE_ID: &str = "deviceID"; pub const PAYLOAD: &str = "payload"; pub const MESSAGE_ID: &str = "messageID"; pub const SORT_KEY: &str = "messageID"; } // This table holds a device token associated with a device. // // - (primary key) = (deviceID: Partition Key) // - deviceID: The public key of a device's olm identity key. // - deviceToken: Token to push services uploaded by device. // - tokenInvalid: Information is token is invalid. pub mod device_tokens { pub const TABLE_NAME: &str = "tunnelbroker-device-tokens"; pub const PARTITION_KEY: &str = "deviceID"; pub const DEVICE_ID: &str = "deviceID"; pub const DEVICE_TOKEN: &str = "deviceToken"; pub const TOKEN_INVALID: &str = "tokenInvalid"; pub const PLATFORM: &str = "platform"; pub const DEVICE_TOKEN_INDEX_NAME: &str = "deviceToken-index"; } } diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs index 4f6ec64e7..624618573 100644 --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,331 +1,336 @@ pub mod session; -use crate::constants::SOCKET_HEARTBEAT_TIMEOUT; +use crate::constants::{SOCKET_HEARTBEAT_TIMEOUT, WS_SESSION_CLOSE_AMQP_MSG}; use crate::database::DatabaseClient; use crate::notifs::NotifClient; use crate::websockets::session::{initialize_amqp, SessionError}; use crate::CONFIG; use futures_util::stream::SplitSink; use futures_util::{SinkExt, StreamExt}; use hyper::upgrade::Upgraded; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use hyper_tungstenite::WebSocketStream; use std::env; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tracing::{debug, error, info, trace}; use tunnelbroker_messages::{ ConnectionInitializationStatus, DeviceToTunnelbrokerRequestStatus, Heartbeat, MessageSentStatus, }; type BoxedError = Box; pub type ErrorWithStreamHandle = ( session::SessionError, SplitSink, Message>, ); use self::session::WebsocketSession; /// Hyper HTTP service that handles incoming HTTP and websocket connections /// It handles the initial websocket upgrade request and spawns a task to /// handle the websocket connection. /// It also handles regular HTTP requests (currently health check) struct WebsocketService { addr: SocketAddr, channel: lapin::Channel, db_client: DatabaseClient, notif_client: NotifClient, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = BoxedError; type Future = Pin> + Send>>; // This function is called to check if the service is ready to accept // connections. Since we don't have any state to check, we're always ready. fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let db_client = self.db_client.clone(); let channel = self.channel.clone(); let notif_client = self.notif_client.clone(); let future = async move { // Check if the request is a websocket upgrade request. if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; // Spawn a task to handle the websocket connection. tokio::spawn(async move { accept_connection(websocket, addr, db_client, channel, notif_client) .await; }); // Return the response so the spawned future can continue. return Ok(response); } debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); // A simple router for regular HTTP requests let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } pub async fn run_server( db_client: DatabaseClient, amqp_connection: &lapin::Connection, notif_client: NotifClient, ) -> Result<(), BoxedError> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") .unwrap_or_else(|_| format!("0.0.0.0:{}", &CONFIG.http_port)); let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("WebSocket listening on: {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { let channel = amqp_connection .create_channel() .await .expect("Failed to create AMQP channel"); let connection = http .serve_connection( stream, WebsocketService { channel, db_client: db_client.clone(), addr, notif_client: notif_client.clone(), }, ) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } async fn send_error_init_response( error: SessionError, mut outgoing: SplitSink, Message>, ) { let error_response = tunnelbroker_messages::ConnectionInitializationResponse { status: ConnectionInitializationStatus::Error(error.to_string()), }; match serde_json::to_string(&error_response) { Ok(serialized_response) => { if let Err(send_error) = outgoing.send(Message::Text(serialized_response)).await { error!("Failed to send init error response: {:?}", send_error); } } Err(ser_error) => { error!("Failed to serialize the error response: {:?}", ser_error); } } } /// Handler for any incoming websocket connections async fn accept_connection( hyper_ws: HyperWebsocket, addr: SocketAddr, db_client: DatabaseClient, amqp_channel: lapin::Channel, notif_client: NotifClient, ) { debug!("Incoming connection from: {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { info!( "Failed to establish connection with {}. Reason: {}", addr, e ); return; } }; let (outgoing, mut incoming) = ws_stream.split(); // We don't know the identity of the device until it sends the session // request over the websocket connection let mut session = if let Some(Ok(first_msg)) = incoming.next().await { match initiate_session( outgoing, first_msg, db_client, amqp_channel, notif_client, ) .await { Ok(mut session) => { let response = tunnelbroker_messages::ConnectionInitializationResponse { status: ConnectionInitializationStatus::Success, }; let serialized_response = serde_json::to_string(&response).unwrap(); session .send_message_to_device(Message::Text(serialized_response)) .await; session } Err((err, outgoing)) => { error!("Failed to create session with device"); send_error_init_response(err, outgoing).await; return; } } } else { error!("Failed to create session with device"); send_error_init_response(SessionError::InvalidMessage, outgoing).await; return; }; let mut ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); let mut got_heartbeat_response = true; // Poll for messages either being sent to the device (rx) // or messages being received from the device (incoming) loop { trace!("Polling for messages from: {}", addr); tokio::select! { Some(Ok(delivery)) = session.next_amqp_message() => { if let Ok(message) = std::str::from_utf8(&delivery.data) { - session.send_message_to_device(Message::Text(message.to_string())).await; + if message == WS_SESSION_CLOSE_AMQP_MSG { + debug!("Connection to {} closed by server.", addr); + break; + } else { + session.send_message_to_device(Message::Text(message.to_string())).await; + } } else { error!("Invalid payload"); } }, device_message = incoming.next() => { let message: Message = match device_message { Some(Ok(msg)) => msg, _ => { debug!("Connection to {} closed remotely.", addr); break; } }; match message { Message::Close(_) => { debug!("Connection to {} closed.", addr); break; } Message::Pong(_) => { debug!("Received Pong message from {}", addr); } Message::Ping(msg) => { debug!("Received Ping message from {}", addr); session.send_message_to_device(Message::Pong(msg)).await; } Message::Text(msg) => { got_heartbeat_response = true; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); let Some(message_status) = session.handle_websocket_frame_from_device(msg).await else { continue; }; let request_status = DeviceToTunnelbrokerRequestStatus { client_message_ids: vec![message_status] }; if let Ok(response) = serde_json::to_string(&request_status) { session.send_message_to_device(Message::text(response)).await; } else { break; } } _ => { error!("Client sent invalid message type"); let confirmation = DeviceToTunnelbrokerRequestStatus {client_message_ids: vec![MessageSentStatus::InvalidRequest]}; if let Ok(response) = serde_json::to_string(&confirmation) { session.send_message_to_device(Message::text(response)).await; } else { break; } } } }, _ = &mut ping_timeout => { if !got_heartbeat_response { error!("Connection to {} died", addr); break; } let serialized = serde_json::to_string(&Heartbeat {}).unwrap(); session.send_message_to_device(Message::text(serialized)).await; got_heartbeat_response = false; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); } else => { debug!("Unhealthy connection for: {}", addr); break; }, } } info!("Unregistering connection to: {}", addr); session.close().await } async fn initiate_session( outgoing: SplitSink, Message>, frame: Message, db_client: DatabaseClient, amqp_channel: lapin::Channel, notif_client: NotifClient, ) -> Result, ErrorWithStreamHandle> { let initialized_session = initialize_amqp(db_client.clone(), frame, &amqp_channel).await; match initialized_session { Ok((device_info, amqp_consumer)) => Ok(WebsocketSession::new( outgoing, db_client, device_info, amqp_channel, amqp_consumer, notif_client, )), Err(e) => Err((e, outgoing)), } }