Changeset View
Changeset View
Standalone View
Standalone View
services/tunnelbroker/src/websockets/session.rs
use tracing::debug; | use derive_more; | ||||
use tunnelbroker_messages::Messages; | use futures_util::stream::SplitSink; | ||||
use futures_util::SinkExt; | |||||
use tokio::{net::TcpStream, sync::mpsc::UnboundedSender}; | |||||
use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; | |||||
use tracing::{debug, error}; | |||||
use tunnelbroker_messages::{session::DeviceTypes, Messages}; | |||||
use crate::{ | use crate::{ | ||||
constants::dynamodb::undelivered_messages::CREATED_AT, | constants::dynamodb::undelivered_messages::CREATED_AT, | ||||
database::DatabaseClient, ACTIVE_CONNECTIONS, | database::DatabaseClient, ACTIVE_CONNECTIONS, | ||||
}; | }; | ||||
pub struct DeviceInfo { | |||||
pub device_id: String, | |||||
pub notify_token: Option<String>, | |||||
pub device_type: DeviceTypes, | |||||
pub device_app_version: Option<String>, | |||||
pub device_os: Option<String>, | |||||
} | |||||
pub struct WebsocketSession { | pub struct WebsocketSession { | ||||
tx: tokio::sync::mpsc::UnboundedSender<std::string::String>, | tx: SplitSink<WebSocketStream<TcpStream>, Message>, | ||||
db_client: DatabaseClient, | db_client: DatabaseClient, | ||||
device_info: Option<DeviceInfo>, | |||||
} | |||||
#[derive(Debug, derive_more::Display, derive_more::From)] | |||||
pub enum SessionError { | |||||
InvalidMessage, | |||||
SerializationError(serde_json::Error), | |||||
} | } | ||||
fn consume_error<T>(result: Result<T, SessionError>) { | |||||
if let Err(e) = result { | |||||
error!("{}", e) | |||||
} | |||||
} | |||||
impl WebsocketSession { | impl WebsocketSession { | ||||
pub fn new( | pub fn new( | ||||
tx: tokio::sync::mpsc::UnboundedSender<std::string::String>, | tx: SplitSink<WebSocketStream<TcpStream>, Message>, | ||||
db_client: DatabaseClient, | db_client: DatabaseClient, | ||||
) -> WebsocketSession { | ) -> WebsocketSession { | ||||
WebsocketSession { tx, db_client } | WebsocketSession { | ||||
tx, | |||||
db_client, | |||||
device_info: None, | |||||
} | |||||
} | |||||
pub async fn handle_websocket_frame_from_device( | |||||
&mut self, | |||||
frame: Message, | |||||
tx: UnboundedSender<String>, | |||||
) { | |||||
debug!("Received message from device: {}", frame); | |||||
let result = match frame { | |||||
Message::Text(payload) => { | |||||
self.handle_message_from_device(&payload, tx).await | |||||
} | |||||
Message::Close(_) => { | |||||
self.close().await; | |||||
Ok(()) | |||||
} | |||||
_ => Err(SessionError::InvalidMessage), | |||||
}; | |||||
consume_error(result); | |||||
} | } | ||||
pub async fn handle_message_from_device( | pub async fn handle_message_from_device( | ||||
&self, | &mut self, | ||||
message: &str, | message: &str, | ||||
) -> Result<(), serde_json::Error> { | tx: UnboundedSender<String>, | ||||
match serde_json::from_str::<Messages>(message)? { | ) -> Result<(), SessionError> { | ||||
Messages::SessionRequest(session_info) => { | let serialized_message = serde_json::from_str::<Messages>(message)?; | ||||
match serialized_message { | |||||
Messages::SessionRequest(mut session_info) => { | |||||
// TODO: Authenticate device using auth token | // TODO: Authenticate device using auth token | ||||
// Check if session request was already sent | |||||
if self.device_info.is_some() { | |||||
return Err(SessionError::InvalidMessage); | |||||
} | |||||
let device_info = DeviceInfo { | |||||
device_id: session_info.device_id.clone(), | |||||
notify_token: session_info.notify_token.take(), | |||||
device_type: session_info.device_type, | |||||
device_app_version: session_info.device_app_version.take(), | |||||
device_os: session_info.device_os.take(), | |||||
}; | |||||
// Check for persisted messages | // Check for persisted messages | ||||
let messages = self | let messages = self | ||||
.db_client | .db_client | ||||
.retrieve_messages(&session_info.device_id) | .retrieve_messages(&device_info.device_id) | ||||
.await | .await | ||||
.expect("Failed to retreive messages"); | .unwrap_or_else(|e| { | ||||
error!("Error while retrieving messages: {}", e); | |||||
Vec::new() | |||||
}); | |||||
ACTIVE_CONNECTIONS | ACTIVE_CONNECTIONS.insert(device_info.device_id.clone(), tx.clone()); | ||||
.insert(session_info.device_id.clone(), self.tx.clone()); | |||||
for message in messages { | for message in messages { | ||||
let payload = | let payload = | ||||
message.get("payload").unwrap().as_s().unwrap().to_string(); | message.get("payload").unwrap().as_s().unwrap().to_string(); | ||||
self | |||||
.tx | |||||
.send(payload) | |||||
.expect("Failed to send message to client"); | |||||
let created_at = | let created_at = | ||||
message.get(CREATED_AT).unwrap().as_n().unwrap().to_string(); | message.get(CREATED_AT).unwrap().as_n().unwrap().to_string(); | ||||
self.send_message_to_device(payload).await; | |||||
self | self | ||||
.db_client | .db_client | ||||
.delete_message(&session_info.device_id, &created_at) | .delete_message(&device_info.device_id, &created_at) | ||||
.await | .await | ||||
.expect("Failed to delete messages"); | .expect("Failed to delete messages"); | ||||
} | } | ||||
debug!("Flushed messages for device: {}", &session_info.device_id); | debug!("Flushed messages for device: {}", &session_info.device_id); | ||||
self.device_info = Some(device_info); | |||||
} | } | ||||
_ => { | _ => { | ||||
debug!("Received invalid request"); | debug!("Received invalid request"); | ||||
} | } | ||||
} | } | ||||
Ok(()) | Ok(()) | ||||
} | } | ||||
pub async fn send_message_to_device(&mut self, incoming_payload: String) { | |||||
if let Err(e) = self.tx.send(Message::Text(incoming_payload)).await { | |||||
error!("Failed to send message to device: {}", e); | |||||
} | |||||
} | |||||
// Release websocket and remove from active connections | |||||
pub async fn close(&mut self) { | |||||
if let Some(device_info) = &self.device_info { | |||||
ACTIVE_CONNECTIONS.remove(&device_info.device_id); | |||||
} | |||||
if let Err(e) = self.tx.close().await { | |||||
debug!("Failed to close session: {}", e); | |||||
} | |||||
} | |||||
} | } |