diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs --- a/services/tunnelbroker/src/websockets/session.rs +++ b/services/tunnelbroker/src/websockets/session.rs @@ -22,7 +22,8 @@ use tunnelbroker_messages::{ message_to_device_request_status::Failure, message_to_device_request_status::MessageSentStatus, session::DeviceTypes, - Heartbeat, MessageToDevice, MessageToDeviceRequest, Messages, + Heartbeat, MessageToDevice, MessageToDeviceRequest, MessageToTunnelbroker, + Messages, }; use crate::database::{self, DatabaseClient, MessageToDeviceExt}; @@ -57,6 +58,7 @@ InternalError, UnauthorizedDevice, PersistenceError(SdkError), + DatabaseError(comm_lib::database::Error), } // Parse a session request and retrieve the device information @@ -252,12 +254,32 @@ Ok(()) } + /// Handles a message addressed to Tunnelbroker itself (not to a device). + /// + /// For `SetDeviceToken`, persists the token for this session's device. + /// Database errors are converted into `SessionError` and returned via `?`. + pub async fn handle_message_to_tunnelbroker( + &self, + message_to_tunnelbroker: &MessageToTunnelbroker, + ) -> Result<(), SessionError> { + match message_to_tunnelbroker { + MessageToTunnelbroker::SetDeviceToken(token) => { + self + .db_client + .set_device_token(&self.device_info.device_id, &token.device_token) + .await?; + } + } + + Ok(()) + } + pub async fn handle_websocket_frame_from_device( &mut self, msg: String, ) -> Option { let Ok(serialized_message) = serde_json::from_str::(&msg) else { - return Option::from(MessageSentStatus::SerializationError(msg)); + return Some(MessageSentStatus::SerializationError(msg)); }; match serialized_message { @@ -285,19 +307,46 @@ "Unauthenticated device {} tried to send text message.
Aborting.", self.device_info.device_id ); - return Option::from(MessageSentStatus::Unauthenticated); + return Some(MessageSentStatus::Unauthenticated); } debug!("Received message for {}", message_request.device_id); let result = self.handle_message_to_device(&message_request).await; - Option::from(self.get_message_to_device_status( + Some(self.get_message_to_device_status( + &message_request.client_message_id, + result, + )) + } + Messages::MessageToTunnelbrokerRequest(message_request) => { + // unauthenticated clients cannot send messages to Tunnelbroker + if !self.device_info.is_authenticated { + debug!( + "Unauthenticated device {} tried to send message to Tunnelbroker. Aborting.", + self.device_info.device_id + ); + return Some(MessageSentStatus::Unauthenticated); + } + debug!("Received message for Tunnelbroker"); + + let Ok(message_to_tunnelbroker) = + serde_json::from_str(&message_request.payload) + else { + return Some(MessageSentStatus::SerializationError( + message_request.payload, + )); + }; + + let result = self + .handle_message_to_tunnelbroker(&message_to_tunnelbroker) + .await; + Some(self.get_message_to_device_status( &message_request.client_message_id, result, )) } _ => { error!("Client sent invalid message type"); - Option::from(MessageSentStatus::InvalidRequest) + Some(MessageSentStatus::InvalidRequest) } } }