diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs index 29bb481ef..d32a5b51b 100644 --- a/services/tunnelbroker/src/grpc/mod.rs +++ b/services/tunnelbroker/src/grpc/mod.rs @@ -1,146 +1,148 @@ mod proto { tonic::include_proto!("tunnelbroker"); } use lapin::{options::BasicPublishOptions, BasicProperties}; use proto::tunnelbroker_service_server::{ TunnelbrokerService, TunnelbrokerServiceServer, }; use proto::Empty; use tonic::transport::Server; use tracing::debug; use tunnelbroker_messages::MessageToDevice; +use crate::amqp::AmqpConnection; use crate::constants::{CLIENT_RMQ_MSG_PRIORITY, WS_SESSION_CLOSE_AMQP_MSG}; use crate::database::{handle_ddb_error, DatabaseClient}; use crate::{constants, CONFIG}; struct TunnelbrokerGRPC { client: DatabaseClient, - amqp_channel: lapin::Channel, + amqp: AmqpConnection, } pub fn handle_amqp_error(error: lapin::Error) -> tonic::Status { match error { lapin::Error::SerialisationError(_) | lapin::Error::ParsingError(_) => { tonic::Status::invalid_argument("Invalid argument") } _ => tonic::Status::internal("Internal Error"), } } #[tonic::async_trait] impl TunnelbrokerService for TunnelbrokerGRPC { async fn send_message_to_device( &self, request: tonic::Request, ) -> Result, tonic::Status> { let message = request.into_inner(); debug!("Received message for {}", &message.device_id); let client_message_id = uuid::Uuid::new_v4().to_string(); let message_id = self .client .persist_message(&message.device_id, &message.payload, &client_message_id) .await .map_err(handle_ddb_error)?; let message_to_device = MessageToDevice { device_id: message.device_id.clone(), payload: message.payload, message_id, }; let serialized_message = serde_json::to_string(&message_to_device) .map_err(|_| tonic::Status::invalid_argument("Invalid argument"))?; self - .amqp_channel + .amqp + .new_channel() + .await + .map_err(handle_amqp_error)? .basic_publish( "", &message.device_id, BasicPublishOptions::default(), serialized_message.as_bytes(), BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY), ) .await .map_err(handle_amqp_error)?; let response = tonic::Response::new(Empty {}); Ok(response) } async fn force_close_device_connection( &self, request: tonic::Request, ) -> Result, tonic::Status> { let message = request.into_inner(); debug!("Connection close request for device {}", &message.device_id); self - .amqp_channel + .amqp + .new_channel() + .await + .map_err(handle_amqp_error)? .basic_publish( "", &message.device_id, BasicPublishOptions::default(), WS_SESSION_CLOSE_AMQP_MSG.as_bytes(), BasicProperties::default() // Connection close request should have higher priority .with_priority(CLIENT_RMQ_MSG_PRIORITY + 1) // The message should expire quickly. If the device isn't connected // (there's no consumer), there's no point in keeping this message. .with_expiration("1000".into()), ) .await .map_err(handle_amqp_error)?; let response = tonic::Response::new(Empty {}); Ok(response) } async fn delete_device_data( &self, request: tonic::Request, ) -> Result, tonic::Status> { let message = request.into_inner(); debug!("Deleting {} data", &message.device_id); self .client .remove_device_token(&message.device_id) .await .map_err(|_| tonic::Status::failed_precondition("unexpected error"))?; let response = tonic::Response::new(Empty {}); Ok(response) } } pub async fn run_server( client: DatabaseClient, - ampq_connection: &lapin::Connection, + amqp_connection: &AmqpConnection, ) -> Result<(), tonic::transport::Error> { let addr = format!("[::]:{}", CONFIG.grpc_port) .parse() .expect("Unable to parse gRPC address"); - let amqp_channel = ampq_connection - .create_channel() - .await - .expect("Unable to create amqp channel"); - tracing::info!("gRPC server listening on {}", &addr); Server::builder() .http2_keepalive_interval(Some(constants::GRPC_KEEP_ALIVE_PING_INTERVAL)) .http2_keepalive_timeout(Some(constants::GRPC_KEEP_ALIVE_PING_TIMEOUT)) .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC { client, - amqp_channel, + amqp: amqp_connection.clone(), })) .serve(addr) .await } diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs index f8a1e5689..7022c9469 100644 --- a/services/tunnelbroker/src/main.rs +++ b/services/tunnelbroker/src/main.rs @@ -1,166 +1,168 @@ pub mod amqp; pub mod config; pub mod constants; pub mod database; pub mod error; pub mod grpc; pub mod identity; pub mod notifs; pub mod websockets; use crate::notifs::apns::APNsClient; use crate::notifs::fcm::FCMClient; use crate::notifs::web_push::WebPushClient; use crate::notifs::wns::WNSClient; use crate::notifs::NotifClient; use anyhow::{anyhow, Result}; use config::CONFIG; use constants::{error_types, COMM_SERVICES_USE_JSON_LOGS}; use std::env; use tracing::{self, error, info, Level}; use tracing_subscriber::EnvFilter; #[tokio::main] async fn main() -> Result<()> { let use_json_logs: bool = env::var(COMM_SERVICES_USE_JSON_LOGS) .unwrap_or("false".to_string()) .parse() .unwrap_or_default(); let filter = EnvFilter::builder() .with_default_directive(Level::INFO.into()) .with_env_var(constants::LOG_LEVEL_ENV_VAR) .from_env_lossy(); if use_json_logs { let subscriber = tracing_subscriber::fmt() .json() .with_env_filter(filter) .finish(); tracing::subscriber::set_global_default(subscriber) .expect("Unable to configure tracing"); } else { let subscriber = tracing_subscriber::fmt().with_env_filter(filter).finish(); tracing::subscriber::set_global_default(subscriber) .expect("Unable to configure tracing"); } config::parse_cmdline_args()?; let aws_config = config::load_aws_config().await; let db_client = database::DatabaseClient::new(&aws_config); - let amqp_connection = amqp::connect().await; + let amqp_connection = amqp::AmqpConnection::connect() + .await + .expect("Failed to create AMQP connection"); let apns_config = CONFIG.apns_config.clone(); let apns = match apns_config { Some(config) => match APNsClient::new(&config) { Ok(apns_client) => { info!("APNs client created successfully"); Some(apns_client) } Err(err) => { error!( errorType = error_types::APNS_ERROR, "Error creating APNs client: {}", err ); None } }, None => { error!( errorType = error_types::APNS_ERROR, "APNs config is missing" ); None } }; let fcm_config = CONFIG.fcm_config.clone(); let fcm = match fcm_config { Some(config) => match FCMClient::new(&config) { Ok(fcm_client) => { info!("FCM client created successfully"); Some(fcm_client) } Err(err) => { error!( errorType = error_types::FCM_ERROR, "Error creating FCM client: {}", err ); None } }, None => { error!(errorType = error_types::FCM_ERROR, "FCM config is missing"); None } }; let web_push_config = CONFIG.web_push_config.clone(); let web_push = match web_push_config { Some(config) => match WebPushClient::new(&config) { Ok(web_client) => { info!("Web Push client created successfully"); Some(web_client) } Err(err) => { error!( errorType = error_types::WEB_PUSH_ERROR, "Error creating Web Push client: {}", err ); None } }, None => { error!( errorType = error_types::WEB_PUSH_ERROR, "Web Push config is missing" ); None } }; let wns_config = CONFIG.wns_config.clone(); let wns = match wns_config { Some(config) => match WNSClient::new(&config) { Ok(wns_client) => { info!("WNS client created successfully"); Some(wns_client) } Err(err) => { error!( errorType = error_types::WNS_ERROR, "Error creating WNS client: {}", err ); None } }, None => { error!(errorType = error_types::WNS_ERROR, "WNS config is missing"); None } }; let notif_client = NotifClient { apns, fcm, web_push, wns, }; let grpc_server = grpc::run_server(db_client.clone(), &amqp_connection); let websocket_server = websockets::run_server( db_client.clone(), &amqp_connection, notif_client.clone(), ); tokio::select! { Ok(_) = grpc_server => { Ok(()) }, Ok(_) = websocket_server => { Ok(()) }, else => { tracing::error!(errorType = error_types::SERVER_ERROR, "A grpc or websocket server crashed."); Err(anyhow!("A grpc or websocket server crashed.")) } } } diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs index 3b65c1b9e..c5bad2b91 100644 --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,336 +1,342 @@ pub mod session; +use crate::amqp::AmqpConnection; use crate::constants::{SOCKET_HEARTBEAT_TIMEOUT, WS_SESSION_CLOSE_AMQP_MSG}; use crate::database::DatabaseClient; use crate::notifs::NotifClient; use crate::websockets::session::{initialize_amqp, SessionError}; use crate::CONFIG; use futures_util::stream::SplitSink; use futures_util::{SinkExt, StreamExt}; use hyper::upgrade::Upgraded; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use hyper_tungstenite::WebSocketStream; use std::env; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tracing::{debug, error, info, trace}; use tunnelbroker_messages::{ ConnectionInitializationStatus, DeviceToTunnelbrokerRequestStatus, Heartbeat, MessageSentStatus, }; type BoxedError = Box; pub type ErrorWithStreamHandle = ( session::SessionError, SplitSink, Message>, ); use self::session::WebsocketSession; /// Hyper HTTP service that handles incoming HTTP and websocket connections /// It handles the initial websocket upgrade request and spawns a task to /// handle the websocket connection. /// It also handles regular HTTP requests (currently health check) struct WebsocketService { addr: SocketAddr, - channel: lapin::Channel, + amqp: AmqpConnection, db_client: DatabaseClient, notif_client: NotifClient, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = BoxedError; type Future = Pin> + Send>>; // This function is called to check if the service is ready to accept // connections. Since we don't have any state to check, we're always ready. fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let db_client = self.db_client.clone(); - let channel = self.channel.clone(); + let amqp = self.amqp.clone(); let notif_client = self.notif_client.clone(); let future = async move { // Check if the request is a websocket upgrade request. if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; // Spawn a task to handle the websocket connection. tokio::spawn(async move { - accept_connection(websocket, addr, db_client, channel, notif_client) + accept_connection(websocket, addr, db_client, amqp, notif_client) .await; }); // Return the response so the spawned future can continue. return Ok(response); } debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); // A simple router for regular HTTP requests let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } pub async fn run_server( db_client: DatabaseClient, - amqp_connection: &lapin::Connection, + amqp_connection: &AmqpConnection, notif_client: NotifClient, ) -> Result<(), BoxedError> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") .unwrap_or_else(|_| format!("0.0.0.0:{}", &CONFIG.http_port)); let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("WebSocket listening on: {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { - let channel = amqp_connection - .create_channel() - .await - .expect("Failed to create AMQP channel"); + let amqp = amqp_connection.clone(); let connection = http .serve_connection( stream, WebsocketService { - channel, + amqp, db_client: db_client.clone(), addr, notif_client: notif_client.clone(), }, ) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } async fn send_error_init_response( error: SessionError, mut outgoing: SplitSink, Message>, ) { let error_response = tunnelbroker_messages::ConnectionInitializationResponse { status: ConnectionInitializationStatus::Error(error.to_string()), }; match serde_json::to_string(&error_response) { Ok(serialized_response) => { if let Err(send_error) = outgoing.send(Message::Text(serialized_response)).await { debug!("Failed to send init error response: {:?}", send_error); } } Err(ser_error) => { error!("Failed to serialize the error response: {:?}", ser_error); } } } /// Handler for any incoming websocket connections async fn accept_connection( hyper_ws: HyperWebsocket, addr: SocketAddr, db_client: DatabaseClient, - amqp_channel: lapin::Channel, + amqp_connection: AmqpConnection, notif_client: NotifClient, ) { debug!("Incoming connection from: {}", addr); + let amqp_channel = match amqp_connection.new_channel().await { + Ok(channel) => channel, + Err(err) => { + tracing::warn!("Failed to create AMQP channel for {addr}: {err:?}."); + return; + } + }; + let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { info!( "Failed to establish connection with {}. Reason: {}", addr, e ); return; } }; let (outgoing, mut incoming) = ws_stream.split(); // We don't know the identity of the device until it sends the session // request over the websocket connection let mut session = if let Some(Ok(first_msg)) = incoming.next().await { match initiate_session( outgoing, first_msg, db_client, amqp_channel, notif_client, ) .await { Ok(mut session) => { let response = tunnelbroker_messages::ConnectionInitializationResponse { status: ConnectionInitializationStatus::Success, }; let serialized_response = serde_json::to_string(&response).unwrap(); session .send_message_to_device(Message::Text(serialized_response)) .await; session } Err((err, outgoing)) => { debug!("Failed to create session with device"); send_error_init_response(err, outgoing).await; return; } } } else { debug!("Failed to create session with device"); send_error_init_response(SessionError::InvalidMessage, outgoing).await; return; }; let mut ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); let mut got_heartbeat_response = true; // Poll for messages either being sent to the device (rx) // or messages being received from the device (incoming) loop { trace!("Polling for messages from: {}", addr); tokio::select! { Some(Ok(delivery)) = session.next_amqp_message() => { if let Ok(message) = std::str::from_utf8(&delivery.data) { if message == WS_SESSION_CLOSE_AMQP_MSG { debug!("Connection to {} closed by server.", addr); break; } else { session.send_message_to_device(Message::Text(message.to_string())).await; } } else { error!("Invalid payload"); } }, device_message = incoming.next() => { let message: Message = match device_message { Some(Ok(msg)) => msg, _ => { debug!("Connection to {} closed remotely.", addr); break; } }; match message { Message::Close(_) => { debug!("Connection to {} closed.", addr); break; } Message::Pong(_) => { debug!("Received Pong message from {}", addr); } Message::Ping(msg) => { debug!("Received Ping message from {}", addr); session.send_message_to_device(Message::Pong(msg)).await; } Message::Text(msg) => { got_heartbeat_response = true; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); let Some(message_status) = session.handle_websocket_frame_from_device(msg).await else { continue; }; let request_status = DeviceToTunnelbrokerRequestStatus { client_message_ids: vec![message_status] }; if let Ok(response) = serde_json::to_string(&request_status) { session.send_message_to_device(Message::text(response)).await; } else { break; } } _ => { error!("Client sent invalid message type"); let confirmation = DeviceToTunnelbrokerRequestStatus {client_message_ids: vec![MessageSentStatus::InvalidRequest]}; if let Ok(response) = serde_json::to_string(&confirmation) { session.send_message_to_device(Message::text(response)).await; } else { break; } } } }, _ = &mut ping_timeout => { if !got_heartbeat_response { error!("Connection to {} died", addr); break; } let serialized = serde_json::to_string(&Heartbeat {}).unwrap(); session.send_message_to_device(Message::text(serialized)).await; got_heartbeat_response = false; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); } else => { debug!("Unhealthy connection for: {}", addr); break; }, } } info!("Unregistering connection to: {}", addr); session.close().await } async fn initiate_session( outgoing: SplitSink, Message>, frame: Message, db_client: DatabaseClient, amqp_channel: lapin::Channel, notif_client: NotifClient, ) -> Result, ErrorWithStreamHandle> { let initialized_session = initialize_amqp(db_client.clone(), frame, &amqp_channel).await; match initialized_session { Ok((device_info, amqp_consumer)) => Ok(WebsocketSession::new( outgoing, db_client, device_info, amqp_channel, amqp_consumer, notif_client, )), Err(e) => Err((e, outgoing)), } }