diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -4,7 +4,7 @@ use crate::constants::{SOCKET_HEARTBEAT_TIMEOUT, WS_SESSION_CLOSE_AMQP_MSG}; use crate::database::DatabaseClient; use crate::notifs::NotifClient; -use crate::websockets::session::{initialize_amqp, SessionError}; +use crate::websockets::session::SessionError; use crate::CONFIG; use futures_util::stream::SplitSink; use futures_util::{SinkExt, StreamExt}; @@ -32,7 +32,7 @@ SplitSink, Message>, ); -use self::session::WebsocketSession; +use self::session::{get_device_info_from_frame, WebsocketSession}; /// Hyper HTTP service that handles incoming HTTP and websocket connections /// It handles the initial websocket upgrade request and spawns a task to @@ -172,14 +172,6 @@ ) { debug!("Incoming connection from: {}", addr); - let amqp_channel = match amqp_connection.new_channel().await { - Ok(channel) => channel, - Err(err) => { - tracing::warn!("Failed to create AMQP channel for {addr}: {err:?}."); - return; - } - }; - let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { @@ -200,7 +192,7 @@ outgoing, first_msg, db_client, - amqp_channel, + amqp_connection, notif_client, ) .await @@ -218,7 +210,7 @@ session } Err((err, outgoing)) => { - debug!("Failed to create session with device"); + debug!("Failed to create session with device: {err:?}"); send_error_init_response(err, outgoing).await; return; } @@ -322,21 +314,14 @@ outgoing: SplitSink, Message>, frame: Message, db_client: DatabaseClient, - amqp_channel: lapin::Channel, + amqp: AmqpConnection, notif_client: NotifClient, ) -> Result, ErrorWithStreamHandle> { - let initialized_session = - initialize_amqp(db_client.clone(), frame, &amqp_channel).await; + let device_info = match get_device_info_from_frame(frame).await { + Ok(info) => info, + Err(e) => return Err((e, outgoing)), + }; - match initialized_session { - Ok((device_info, amqp_consumer)) => Ok(WebsocketSession::new( - outgoing, - db_client, - device_info, - amqp_channel, - amqp_consumer, - notif_client, - )), - Err(e) => Err((e, outgoing)), - } + WebsocketSession::new(outgoing, db_client, device_info, amqp, notif_client) + .await } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,3 +1,4 @@ +use crate::amqp::AmqpConnection; use crate::constants::{ error_types, CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY, MAX_RMQ_MSG_PRIORITY, RMQ_CONSUMER_TAG, @@ -200,11 +201,9 @@ Ok(()) } -pub async fn initialize_amqp( - db_client: DatabaseClient, +pub async fn get_device_info_from_frame( frame: Message, - amqp_channel: &lapin::Channel, -) -> Result<(DeviceInfo, lapin::Consumer), SessionError> { +) -> Result { let device_info = match frame { Message::Text(payload) => { handle_first_message_from_device(&payload).await? @@ -214,42 +213,66 @@ return Err(SessionError::InvalidMessage); } }; - - let mut args = FieldTable::default(); - args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into()); - amqp_channel - .queue_declare(&device_info.device_id, QueueDeclareOptions::default(), args) - .await?; - - publish_persisted_messages(&db_client, amqp_channel, &device_info).await?; - - let amqp_consumer = amqp_channel - .basic_consume( - &device_info.device_id, - RMQ_CONSUMER_TAG, - BasicConsumeOptions::default(), - FieldTable::default(), - ) - .await?; - Ok((device_info, amqp_consumer)) + Ok(device_info) } + impl WebsocketSession { - pub fn new( + pub async fn new( tx: SplitSink, Message>, db_client: DatabaseClient, device_info: DeviceInfo, - amqp_channel: lapin::Channel, - amqp_consumer: lapin::Consumer, + amqp: AmqpConnection, notif_client: NotifClient, - ) -> Self { - Self { + ) -> Result> { + let (amqp_channel, amqp_consumer) = + match Self::init_amqp(&device_info, &db_client, &amqp).await { + Ok(consumer) => consumer, + Err(err) => return Err((err, tx)), + }; + + Ok(Self { tx, db_client, device_info, amqp_channel, amqp_consumer, notif_client, - } + }) + } + + async fn init_amqp( + device_info: &DeviceInfo, + db_client: &DatabaseClient, + amqp: &AmqpConnection, + ) -> Result<(lapin::Channel, lapin::Consumer), SessionError> { + let amqp_channel = amqp.new_channel().await?; + debug!( + "Got AMQP Channel Id={} for device '{}'", + amqp_channel.id(), + device_info.device_id + ); + + let mut args = FieldTable::default(); + args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into()); + amqp_channel + .queue_declare( + &device_info.device_id, + QueueDeclareOptions::default(), + args, + ) + .await?; + + publish_persisted_messages(db_client, &amqp_channel, device_info).await?; + + let amqp_consumer = amqp_channel + .basic_consume( + &device_info.device_id, + RMQ_CONSUMER_TAG, + BasicConsumeOptions::default(), + FieldTable::default(), + ) + .await?; + Ok((amqp_channel, amqp_consumer)) } pub async fn handle_message_to_device(