diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs index 73f457f13..5fac6325e 100644 --- a/services/tunnelbroker/src/main.rs +++ b/services/tunnelbroker/src/main.rs @@ -1,70 +1,73 @@ pub mod amqp; pub mod config; pub mod constants; pub mod database; pub mod error; pub mod grpc; pub mod identity; pub mod notifs; pub mod websockets; use crate::constants::ENV_APNS_CONFIG; use crate::notifs::apns::config::APNsConfig; use crate::notifs::apns::APNsClient; use crate::notifs::NotifClient; use anyhow::{anyhow, Result}; use config::CONFIG; use std::str::FromStr; use tracing::{self, error, info, Level}; use tracing_subscriber::EnvFilter; #[tokio::main] async fn main() -> Result<()> { let filter = EnvFilter::builder() .with_default_directive(Level::INFO.into()) .with_env_var(constants::LOG_LEVEL_ENV_VAR) .from_env_lossy(); let subscriber = tracing_subscriber::fmt().with_env_filter(filter).finish(); tracing::subscriber::set_global_default(subscriber) .expect("Unable to configure tracing"); config::parse_cmdline_args()?; let aws_config = config::load_aws_config().await; let db_client = database::DatabaseClient::new(&aws_config); let amqp_connection = amqp::connect().await; let apns_config = CONFIG.apns_config.clone(); let apns = match apns_config { Some(config) => match APNsClient::new(&config) { Ok(apns_client) => { info!("APNs client created successfully"); Some(apns_client) } Err(err) => { error!("Error creating APNs client: {}", err); None } }, None => { error!("APNs config is missing"); None } }; let notif_client = NotifClient { apns }; let grpc_server = grpc::run_server(db_client.clone(), &amqp_connection); - let websocket_server = - websockets::run_server(db_client.clone(), &amqp_connection); + let websocket_server = websockets::run_server( + db_client.clone(), + &amqp_connection, + notif_client.clone(), + ); tokio::select! { Ok(_) = grpc_server => { Ok(()) }, Ok(_) = websocket_server => { Ok(()) }, else => { tracing::error!("A grpc or websocket server crashed."); Err(anyhow!("A grpc or websocket server crashed.")) } } } diff --git a/services/tunnelbroker/src/notifs/apns/mod.rs b/services/tunnelbroker/src/notifs/apns/mod.rs index aff54eb29..63adac7de 100644 --- a/services/tunnelbroker/src/notifs/apns/mod.rs +++ b/services/tunnelbroker/src/notifs/apns/mod.rs @@ -1,138 +1,138 @@ use crate::notifs::apns::config::APNsConfig; use crate::notifs::apns::error::Error::ResponseError; use crate::notifs::apns::headers::{NotificationHeaders, PushType}; use crate::notifs::apns::response::ErrorBody; use crate::notifs::apns::token::APNsToken; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION}; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::time::Duration; use tracing::debug; pub mod config; pub mod error; -mod headers; +pub(crate) mod headers; mod response; pub mod token; #[derive(Clone)] pub struct APNsClient { http2_client: reqwest::Client, token: APNsToken, is_prod: bool, } #[derive(Serialize, Deserialize)] pub struct APNsNotif { pub device_token: String, pub headers: NotificationHeaders, pub payload: String, } impl APNsClient { pub fn new(config: &APNsConfig) -> Result { let token_ttl = Duration::from_secs(60 * 55); let token = APNsToken::new(config, token_ttl)?; let http2_client = reqwest::Client::builder() .http2_prior_knowledge() .http2_keep_alive_interval(Some(Duration::from_secs(5))) .http2_keep_alive_while_idle(true) .build()?; Ok(APNsClient { http2_client, token, is_prod: config.production, }) } async fn build_headers( &self, notif_headers: NotificationHeaders, ) -> Result { let mut headers = HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, HeaderValue::from_static("application/json"), ); let bearer = self.token.get_bearer().await?; let token = format!("bearer {bearer}"); headers.insert(AUTHORIZATION, HeaderValue::from_str(&token)?); if let Some(apns_topic) = ¬if_headers.apns_topic { headers.insert("apns-topic", HeaderValue::from_str(apns_topic)?); } if let Some(apns_id) = ¬if_headers.apns_id { headers.insert("apns-id", HeaderValue::from_str(apns_id)?); } if let Some(push_type) = ¬if_headers.apns_push_type { let push_type_str = match push_type { PushType::Alert => "alert", PushType::Background => "background", PushType::Location => "location", PushType::Voip => "voip", PushType::Complication => "complication", PushType::FileProvider => "fileprovider", PushType::Mdm => "mdm", PushType::LiveActivity => "live", PushType::PushToTalk => "pushtotalk", }; headers.insert("apns-push-type", HeaderValue::from_static(push_type_str)); } if let Some(expiration) = notif_headers.apns_expiration { headers.insert("apns-expiration", HeaderValue::from(expiration)); } if let Some(priority) = notif_headers.apns_priority { headers.insert("apns-priority", HeaderValue::from(priority)); } if let Some(collapse_id) = ¬if_headers.apns_collapse_id { headers.insert("apns-collapse-id", HeaderValue::from_str(collapse_id)?); } Ok(headers) } fn get_endpoint(&self) -> &'static str { if self.is_prod { return "api.push.apple.com"; } "api.development.push.apple.com" } pub async fn send(&self, notif: APNsNotif) -> Result<(), error::Error> { debug!("Sending notif to {}", notif.device_token); let headers = self.build_headers(notif.headers.clone()).await?; let url = format!( "https://{}/3/device/{}", self.get_endpoint(), notif.device_token ); let response = self .http2_client .post(url) .headers(headers.clone()) .body(notif.payload) .send() .await?; match response.status() { StatusCode::OK => Ok(()), _ => { let error_body: ErrorBody = response.json().await?; Err(ResponseError(error_body)) } } } } diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs index 3a66439c5..c5bf1af8a 100644 --- a/services/tunnelbroker/src/websockets/mod.rs +++ b/services/tunnelbroker/src/websockets/mod.rs @@ -1,314 +1,331 @@ pub mod session; use crate::constants::SOCKET_HEARTBEAT_TIMEOUT; use crate::database::DatabaseClient; +use crate::notifs::NotifClient; use crate::websockets::session::{initialize_amqp, SessionError}; use crate::CONFIG; use futures_util::stream::SplitSink; use futures_util::{SinkExt, StreamExt}; use hyper::upgrade::Upgraded; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use hyper_tungstenite::WebSocketStream; use std::env; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tracing::{debug, error, info}; use tunnelbroker_messages::{ ConnectionInitializationStatus, Heartbeat, MessageSentStatus, MessageToDeviceRequestStatus, }; type BoxedError = Box; pub type ErrorWithStreamHandle = ( session::SessionError, SplitSink, Message>, ); use self::session::WebsocketSession; /// Hyper HTTP service that handles incoming HTTP and websocket connections /// It handles the initial websocket upgrade request and spawns a task to /// handle the websocket connection. /// It also handles regular HTTP requests (currently health check) struct WebsocketService { addr: SocketAddr, channel: lapin::Channel, db_client: DatabaseClient, + notif_client: NotifClient, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = BoxedError; type Future = Pin> + Send>>; // This function is called to check if the service is ready to accept // connections. Since we don't have any state to check, we're always ready. fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let db_client = self.db_client.clone(); let channel = self.channel.clone(); + let notif_client = self.notif_client.clone(); let future = async move { // Check if the request is a websocket upgrade request. if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; // Spawn a task to handle the websocket connection. tokio::spawn(async move { - accept_connection(websocket, addr, db_client, channel).await; + accept_connection(websocket, addr, db_client, channel, notif_client) + .await; }); // Return the response so the spawned future can continue. return Ok(response); } debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); // A simple router for regular HTTP requests let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } pub async fn run_server( db_client: DatabaseClient, amqp_connection: &lapin::Connection, + notif_client: NotifClient, ) -> Result<(), BoxedError> { let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR") .unwrap_or_else(|_| format!("0.0.0.0:{}", &CONFIG.http_port)); let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("WebSocket listening on: {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { let channel = amqp_connection .create_channel() .await .expect("Failed to create AMQP channel"); let connection = http .serve_connection( stream, WebsocketService { channel, db_client: db_client.clone(), addr, + notif_client: notif_client.clone(), }, ) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } async fn send_error_init_response( error: SessionError, mut outgoing: SplitSink, Message>, ) { let error_response = tunnelbroker_messages::ConnectionInitializationResponse { status: ConnectionInitializationStatus::Error(error.to_string()), }; match serde_json::to_string(&error_response) { Ok(serialized_response) => { if let Err(send_error) = outgoing.send(Message::Text(serialized_response)).await { error!("Failed to send init error response: {:?}", send_error); } } Err(ser_error) => { error!("Failed to serialize the error response: {:?}", ser_error); } } } /// Handler for any incoming websocket connections async fn accept_connection( hyper_ws: HyperWebsocket, addr: SocketAddr, db_client: DatabaseClient, amqp_channel: lapin::Channel, + notif_client: NotifClient, ) { debug!("Incoming connection from: {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { info!( "Failed to establish connection with {}. Reason: {}", addr, e ); return; } }; let (outgoing, mut incoming) = ws_stream.split(); // We don't know the identity of the device until it sends the session // request over the websocket connection let mut session = if let Some(Ok(first_msg)) = incoming.next().await { - match initiate_session(outgoing, first_msg, db_client, amqp_channel).await { + match initiate_session( + outgoing, + first_msg, + db_client, + amqp_channel, + notif_client, + ) + .await + { Ok(mut session) => { let response = tunnelbroker_messages::ConnectionInitializationResponse { status: ConnectionInitializationStatus::Success, }; let serialized_response = serde_json::to_string(&response).unwrap(); session .send_message_to_device(Message::Text(serialized_response)) .await; session } Err((err, outgoing)) => { error!("Failed to create session with device"); send_error_init_response(err, outgoing).await; return; } } } else { error!("Failed to create session with device"); send_error_init_response(SessionError::InvalidMessage, outgoing).await; return; }; let mut ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); let mut got_heartbeat_response = true; // Poll for messages either being sent to the device (rx) // or messages being received from the device (incoming) loop { debug!("Polling for messages from: {}", addr); tokio::select! { Some(Ok(delivery)) = session.next_amqp_message() => { if let Ok(message) = std::str::from_utf8(&delivery.data) { session.send_message_to_device(Message::Text(message.to_string())).await; } else { error!("Invalid payload"); } }, device_message = incoming.next() => { let message: Message = match device_message { Some(Ok(msg)) => msg, _ => { debug!("Connection to {} closed remotely.", addr); break; } }; match message { Message::Close(_) => { debug!("Connection to {} closed.", addr); break; } Message::Pong(_) => { debug!("Received Pong message from {}", addr); } Message::Ping(msg) => { debug!("Received Ping message from {}", addr); session.send_message_to_device(Message::Pong(msg)).await; } Message::Text(msg) => { got_heartbeat_response = true; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); let Some(message_status) = session.handle_websocket_frame_from_device(msg).await else { continue; }; let request_status = MessageToDeviceRequestStatus { client_message_ids: vec![message_status] }; if let Ok(response) = serde_json::to_string(&request_status) { session.send_message_to_device(Message::text(response)).await; } else { break; } } _ => { error!("Client sent invalid message type"); let confirmation = MessageToDeviceRequestStatus {client_message_ids: vec![MessageSentStatus::InvalidRequest]}; if let Ok(response) = serde_json::to_string(&confirmation) { session.send_message_to_device(Message::text(response)).await; } else { break; } } } }, _ = &mut ping_timeout => { if !got_heartbeat_response { error!("Connection to {} died", addr); break; } let serialized = serde_json::to_string(&Heartbeat {}).unwrap(); session.send_message_to_device(Message::text(serialized)).await; got_heartbeat_response = false; ping_timeout = Box::pin(tokio::time::sleep(SOCKET_HEARTBEAT_TIMEOUT)); } else => { debug!("Unhealthy connection for: {}", addr); break; }, } } info!("Unregistering connection to: {}", addr); session.close().await } async fn initiate_session( outgoing: SplitSink, Message>, frame: Message, db_client: DatabaseClient, amqp_channel: lapin::Channel, + notif_client: NotifClient, ) -> Result, ErrorWithStreamHandle> { let initialized_session = initialize_amqp(db_client.clone(), frame, &amqp_channel).await; match initialized_session { Ok((device_info, amqp_consumer)) => Ok(WebsocketSession::new( outgoing, db_client, device_info, amqp_channel, amqp_consumer, + notif_client, )), Err(e) => Err((e, outgoing)), } } diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs index 5983eafb9..1e7f2551f 100644 --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -1,404 +1,470 @@ use crate::constants::{ CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY, MAX_RMQ_MSG_PRIORITY, RMQ_CONSUMER_TAG, }; use comm_lib::aws::ddb::error::SdkError; use comm_lib::aws::ddb::operation::put_item::PutItemError; use derive_more; use futures_util::stream::SplitSink; use futures_util::SinkExt; use futures_util::StreamExt; use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use lapin::message::Delivery; use lapin::options::{ BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions, QueueDeclareOptions, QueueDeleteOptions, }; use lapin::types::FieldTable; use lapin::BasicProperties; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tracing::{debug, error, info}; use tunnelbroker_messages::{ message_to_device_request_status::Failure, message_to_device_request_status::MessageSentStatus, session::DeviceTypes, Heartbeat, MessageToDevice, MessageToDeviceRequest, MessageToTunnelbroker, Messages, }; use crate::database::{self, DatabaseClient, MessageToDeviceExt}; use crate::identity; +use crate::notifs::apns::headers::NotificationHeaders; +use crate::notifs::apns::APNsNotif; +use crate::notifs::NotifClient; pub struct DeviceInfo { pub device_id: String, pub notify_token: Option, pub device_type: DeviceTypes, pub device_app_version: Option, pub device_os: Option, pub is_authenticated: bool, } pub struct WebsocketSession { tx: SplitSink, Message>, db_client: DatabaseClient, pub device_info: DeviceInfo, amqp_channel: lapin::Channel, // Stream of messages from AMQP endpoint amqp_consumer: lapin::Consumer, + notif_client: NotifClient, } #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum SessionError { InvalidMessage, SerializationError(serde_json::Error), MessageError(database::MessageErrors), AmqpError(lapin::Error), InternalError, UnauthorizedDevice, PersistenceError(SdkError), DatabaseError(comm_lib::database::Error), + MissingAPNsClient, + MissingDeviceToken, } // Parse a session request and retrieve the device information pub async fn handle_first_message_from_device( message: &str, ) -> Result { let serialized_message = serde_json::from_str::(message)?; match serialized_message { Messages::ConnectionInitializationMessage(mut session_info) => { let device_info = DeviceInfo { device_id: session_info.device_id.clone(), notify_token: session_info.notify_token.take(), device_type: session_info.device_type, device_app_version: session_info.device_app_version.take(), device_os: session_info.device_os.take(), is_authenticated: true, }; // Authenticate device debug!("Authenticating device: {}", &session_info.device_id); let auth_request = identity::verify_user_access_token( &session_info.user_id, &device_info.device_id, &session_info.access_token, ) .await; match auth_request { Err(e) => { error!("Failed to complete request to identity service: {:?}", e); return Err(SessionError::InternalError); } Ok(false) => { info!("Device failed authentication: {}", &session_info.device_id); return Err(SessionError::UnauthorizedDevice); } Ok(true) => { debug!( "Successfully authenticated device: {}", &session_info.device_id ); } } Ok(device_info) } Messages::AnonymousInitializationMessage(session_info) => { debug!( "Starting unauthenticated session with device: {}", &session_info.device_id ); let device_info = DeviceInfo { device_id: session_info.device_id, device_type: session_info.device_type, device_app_version: session_info.device_app_version, device_os: session_info.device_os, is_authenticated: false, notify_token: None, }; Ok(device_info) } _ => { debug!("Received invalid request"); Err(SessionError::InvalidMessage) } } } async fn publish_persisted_messages( db_client: &DatabaseClient, amqp_channel: &lapin::Channel, device_info: &DeviceInfo, ) -> Result<(), SessionError> { let messages = db_client .retrieve_messages(&device_info.device_id) .await .unwrap_or_else(|e| { error!("Error while retrieving messages: {}", e); Vec::new() }); for message in messages { let message_to_device = MessageToDevice::from_hashmap(message)?; let serialized_message = serde_json::to_string(&message_to_device)?; amqp_channel .basic_publish( "", &message_to_device.device_id, BasicPublishOptions::default(), serialized_message.as_bytes(), BasicProperties::default().with_priority(DDB_RMQ_MSG_PRIORITY), ) .await?; } debug!("Flushed messages for device: {}", &device_info.device_id); Ok(()) } pub async fn initialize_amqp( db_client: DatabaseClient, frame: Message, amqp_channel: &lapin::Channel, ) -> Result<(DeviceInfo, lapin::Consumer), SessionError> { let device_info = match frame { Message::Text(payload) => { handle_first_message_from_device(&payload).await? } _ => { error!("Client sent wrong frame type for establishing connection"); return Err(SessionError::InvalidMessage); } }; let mut args = FieldTable::default(); args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into()); amqp_channel .queue_declare(&device_info.device_id, QueueDeclareOptions::default(), args) .await?; publish_persisted_messages(&db_client, amqp_channel, &device_info).await?; let amqp_consumer = amqp_channel .basic_consume( &device_info.device_id, RMQ_CONSUMER_TAG, BasicConsumeOptions::default(), FieldTable::default(), ) .await?; Ok((device_info, amqp_consumer)) } impl WebsocketSession { pub fn new( tx: SplitSink, Message>, db_client: DatabaseClient, device_info: DeviceInfo, amqp_channel: lapin::Channel, amqp_consumer: lapin::Consumer, + notif_client: NotifClient, ) -> Self { Self { tx, db_client, device_info, amqp_channel, amqp_consumer, + notif_client, } } pub async fn handle_message_to_device( &self, message_request: &MessageToDeviceRequest, ) -> Result<(), SessionError> { let message_id = self .db_client .persist_message( &message_request.device_id, &message_request.payload, &message_request.client_message_id, ) .await?; let message_to_device = MessageToDevice { device_id: message_request.device_id.clone(), payload: message_request.payload.clone(), message_id: message_id.clone(), }; let serialized_message = serde_json::to_string(&message_to_device)?; let publish_result = self .amqp_channel .basic_publish( "", &message_request.device_id, BasicPublishOptions::default(), serialized_message.as_bytes(), BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY), ) .await; if let Err(publish_error) = publish_result { self .db_client .delete_message(&self.device_info.device_id, &message_id) .await .expect("Error deleting message"); return Err(SessionError::AmqpError(publish_error)); } Ok(()) } pub async fn handle_message_to_tunnelbroker( &self, message_to_tunnelbroker: &MessageToTunnelbroker, ) -> Result<(), SessionError> { match message_to_tunnelbroker { MessageToTunnelbroker::SetDeviceToken(token) => { self .db_client .set_device_token(&self.device_info.device_id, &token.device_token) .await?; } } Ok(()) } pub async fn handle_websocket_frame_from_device( &mut self, msg: String, ) -> Option { let Ok(serialized_message) = serde_json::from_str::(&msg) else { return Some(MessageSentStatus::SerializationError(msg)); }; match serialized_message { Messages::Heartbeat(Heartbeat {}) => { debug!("Received heartbeat from: {}", self.device_info.device_id); None } Messages::MessageReceiveConfirmation(confirmation) => { for message_id in confirmation.message_ids { if let Err(e) = self .db_client .delete_message(&self.device_info.device_id, &message_id) .await { error!("Failed to delete message: {}:", e); } } None } Messages::MessageToDeviceRequest(message_request) => { // unauthenticated clients cannot send messages if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text message. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received message for {}", message_request.device_id); let result = self.handle_message_to_device(&message_request).await; Some(self.get_message_to_device_status( &message_request.client_message_id, result, )) } Messages::MessageToTunnelbrokerRequest(message_request) => { // unauthenticated clients cannot send messages if !self.device_info.is_authenticated { debug!( "Unauthenticated device {} tried to send text message. Aborting.", self.device_info.device_id ); return Some(MessageSentStatus::Unauthenticated); } debug!("Received message for Tunnelbroker"); let Ok(message_to_tunnelbroker) = serde_json::from_str(&message_request.payload) else { return Some(MessageSentStatus::SerializationError( message_request.payload, )); }; let result = self .handle_message_to_tunnelbroker(&message_to_tunnelbroker) .await; Some(self.get_message_to_device_status( &message_request.client_message_id, result, )) } + Messages::APNsNotif(notif) => { + // unauthenticated clients cannot send notifs + if !self.device_info.is_authenticated { + debug!( + "Unauthenticated device {} tried to send text notif. Aborting.", + self.device_info.device_id + ); + return Some(MessageSentStatus::Unauthenticated); + } + debug!("Received APNs notif for {}", notif.device_id); + + let Ok(headers) = + serde_json::from_str::(¬if.headers) + else { + return Some(MessageSentStatus::SerializationError(notif.headers)); + }; + + let device_token = + match self.db_client.get_device_token(¬if.device_id).await { + Ok(db_token) => { + let Some(token) = db_token else { + return Some(self.get_message_to_device_status( + ¬if.client_message_id, + Err(SessionError::MissingDeviceToken), + )); + }; + token + } + Err(e) => { + return Some(self.get_message_to_device_status( + ¬if.client_message_id, + Err(SessionError::DatabaseError(e)), + )); + } + }; + + let apns_notif = APNsNotif { + device_token, + headers, + payload: notif.payload, + }; + + if let Some(apns) = self.notif_client.apns.clone() { + let response = apns.send(apns_notif).await; + return Some( + self + .get_message_to_device_status(¬if.client_message_id, response), + ); + } + + Some(self.get_message_to_device_status( + ¬if.client_message_id, + Err(SessionError::MissingAPNsClient), + )) + } _ => { error!("Client sent invalid message type"); Some(MessageSentStatus::InvalidRequest) } } } pub async fn next_amqp_message( &mut self, ) -> Option> { self.amqp_consumer.next().await } pub async fn send_message_to_device(&mut self, message: Message) { if let Err(e) = self.tx.send(message).await { error!("Failed to send message to device: {}", e); } } // Release WebSocket and remove from active connections pub async fn close(&mut self) { if let Err(e) = self.tx.close().await { debug!("Failed to close WebSocket session: {}", e); } if let Err(e) = self .amqp_channel .basic_cancel( self.amqp_consumer.tag().as_str(), BasicCancelOptions::default(), ) .await { error!("Failed to cancel consumer: {}", e); } if let Err(e) = self .amqp_channel .queue_delete( self.device_info.device_id.as_str(), QueueDeleteOptions::default(), ) .await { error!("Failed to delete queue: {}", e); } } - pub fn get_message_to_device_status( + pub fn get_message_to_device_status( &mut self, client_message_id: &str, - result: Result<(), SessionError>, - ) -> MessageSentStatus { + result: Result<(), E>, + ) -> MessageSentStatus + where + E: std::error::Error, + { match result { Ok(()) => MessageSentStatus::Success(client_message_id.to_string()), Err(err) => MessageSentStatus::Error(Failure { id: client_message_id.to_string(), error: err.to_string(), }), } } }