Subject: tunnelbroker: share an AmqpConnection handle instead of per-server lapin::Channel

Summary of what this patch changes (derived only from the hunks below):

  * grpc/mod.rs: TunnelbrokerGRPC now stores `amqp: AmqpConnection` instead of
    `amqp_channel: lapin::Channel`. Both `basic_publish` call sites (the one
    near old line 57 and the connection-close handler) obtain a channel per
    call with `self.amqp.channel("grpc").await` and map a failure through
    `handle_amqp_error`. A channel error therefore becomes a tonic::Status
    for that request. Previously `run_server` created a single channel up front
    with `.expect("Unable to create amqp channel")`. `run_server` now takes
    `&AmqpConnection` and clones it into the service. The `ampq_connection`
    parameter misspelling is corrected to `amqp_connection`.
  * main.rs: `amqp::connect()` is replaced by `amqp::AmqpConnection::connect()`.
    Its result is unwrapped with `.expect("Failed to create AMQP connection")`.
  * websockets/mod.rs: WebsocketService stores `amqp: AmqpConnection` instead
    of `channel: lapin::Channel`. The accept loop in `run_server` now only
    clones the connection handle. Previously it called `create_channel()` with
    `.expect("Failed to create AMQP channel")` for every accepted TCP stream, so
    a single channel failure panicked the accept loop. `accept_connection` now
    acquires the channel itself via `amqp_connection.channel(addr)`. On failure
    it logs a warning and returns, which ends only that spawned connection task.

Review notes (not established by this patch; confirm against src/amqp.rs):

  * NOTE(review): the body of `AmqpConnection::channel` is not part of this
    diff. If it opens a new lapin channel on every call, each gRPC publish now
    opens its own channel. Consider caching or reusing one. TODO confirm.
  * `.map_err(handle_amqp_error)` only type-checks if `channel(...)` fails with
    `lapin::Error`, the parameter type of `handle_amqp_error`.
  * `channel` is called with both `"grpc"` and a `SocketAddr`. Its argument is
    presumably a generic identifier (e.g. for logging). Verify.
  * `AmqpConnection` is cloned per service and per accepted stream. This
    presumably relies on it being a cheap shared handle. Verify.
  * In `accept_connection` the channel is acquired before `hyper_ws.await`.
    On failure the websocket upgrade is never awaited and the client presumably
    just sees the connection drop. Confirm this is intended.

diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs --- a/services/tunnelbroker/src/grpc/mod.rs +++ b/services/tunnelbroker/src/grpc/mod.rs @@ -11,13 +11,14 @@ use tracing::debug; use tunnelbroker_messages::MessageToDevice; +use crate::amqp::AmqpConnection; use crate::constants::{CLIENT_RMQ_MSG_PRIORITY, WS_SESSION_CLOSE_AMQP_MSG}; use crate::database::{handle_ddb_error, DatabaseClient}; use crate::{constants, CONFIG}; struct TunnelbrokerGRPC { client: DatabaseClient, - amqp_channel: lapin::Channel, + amqp: AmqpConnection, } pub fn handle_amqp_error(error: lapin::Error) -> tonic::Status { @@ -57,7 +58,10 @@ .map_err(|_| tonic::Status::invalid_argument("Invalid argument"))?; self - .amqp_channel + .amqp + .channel("grpc") + .await + .map_err(handle_amqp_error)? .basic_publish( "", &message.device_id, @@ -81,7 +85,10 @@ debug!("Connection close request for device {}", &message.device_id); self - .amqp_channel + .amqp + .channel("grpc") + .await + .map_err(handle_amqp_error)? 
.basic_publish( "", &message.device_id, @@ -122,24 +129,19 @@ pub async fn run_server( client: DatabaseClient, - ampq_connection: &lapin::Connection, + amqp_connection: &AmqpConnection, ) -> Result<(), tonic::transport::Error> { let addr = format!("[::]:{}", CONFIG.grpc_port) .parse() .expect("Unable to parse gRPC address"); - let amqp_channel = ampq_connection - .create_channel() - .await - .expect("Unable to create amqp channel"); - tracing::info!("gRPC server listening on {}", &addr); Server::builder() .http2_keepalive_interval(Some(constants::GRPC_KEEP_ALIVE_PING_INTERVAL)) .http2_keepalive_timeout(Some(constants::GRPC_KEEP_ALIVE_PING_TIMEOUT)) .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC { client, - amqp_channel, + amqp: amqp_connection.clone(), })) .serve(addr) .await diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs --- a/services/tunnelbroker/src/main.rs +++ b/services/tunnelbroker/src/main.rs @@ -48,7 +48,9 @@ config::parse_cmdline_args()?; let aws_config = config::load_aws_config().await; let db_client = database::DatabaseClient::new(&aws_config); - let amqp_connection = amqp::connect().await; + let amqp_connection = amqp::AmqpConnection::connect() + .await + .expect("Failed to create AMQP connection"); let apns_config = CONFIG.apns_config.clone(); diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,5 +1,6 @@ pub mod session; +use crate::amqp::AmqpConnection; use crate::constants::{SOCKET_HEARTBEAT_TIMEOUT, WS_SESSION_CLOSE_AMQP_MSG}; use crate::database::DatabaseClient; use crate::notifs::NotifClient; @@ -39,7 +40,7 @@ /// It also handles regular HTTP requests (currently health check) struct WebsocketService { addr: SocketAddr, - channel: lapin::Channel, + amqp: AmqpConnection, db_client: DatabaseClient, notif_client: NotifClient, } @@ -62,7 +63,7 @@ 
fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let db_client = self.db_client.clone(); - let channel = self.channel.clone(); + let amqp = self.amqp.clone(); let notif_client = self.notif_client.clone(); let future = async move { @@ -72,7 +73,7 @@ // Spawn a task to handle the websocket connection. tokio::spawn(async move { - accept_connection(websocket, addr, db_client, channel, notif_client) + accept_connection(websocket, addr, db_client, amqp, notif_client) .await; }); @@ -101,7 +102,7 @@ pub async fn run_server( db_client: DatabaseClient, - amqp_connection: &lapin::Connection, + amqp_connection: &AmqpConnection, notif_client: NotifClient, ) -> Result<(), BoxedError> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") @@ -115,15 +116,12 @@ http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { - let channel = amqp_connection - .create_channel() - .await - .expect("Failed to create AMQP channel"); + let amqp = amqp_connection.clone(); let connection = http .serve_connection( stream, WebsocketService { - channel, + amqp, db_client: db_client.clone(), addr, notif_client: notif_client.clone(), @@ -169,11 +167,19 @@ hyper_ws: HyperWebsocket, addr: SocketAddr, db_client: DatabaseClient, - amqp_channel: lapin::Channel, + amqp_connection: AmqpConnection, notif_client: NotifClient, ) { debug!("Incoming connection from: {}", addr); + let amqp_channel = match amqp_connection.channel(addr).await { + Ok(channel) => channel, + Err(err) => { + tracing::warn!("Failed to create AMQP channel for {addr}: {err:?}."); + return; + } + }; + let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => {