Page MenuHomePhorge

D7800.1767436855.diff
No OneTemporary

Size
10 KB
Referenced Files
None
Subscribers
None

D7800.1767436855.diff

diff --git a/services/commtest/tests/tunnelbroker_integration_test.rs b/services/commtest/tests/tunnelbroker_integration_test.rs
--- a/services/commtest/tests/tunnelbroker_integration_test.rs
+++ b/services/commtest/tests/tunnelbroker_integration_test.rs
@@ -110,11 +110,11 @@
.expect("Failed to send message");
// Have keyserver receive any websocket messages
- let response = socket.next().await.unwrap().unwrap();
-
- // Check that message received by keyserver matches what identity server
- // issued
- let serialized_response: RefreshKeyRequest =
- serde_json::from_str(&response.to_text().unwrap()).unwrap();
- assert_eq!(serialized_response, refresh_request);
+ if let Some(Ok(response)) = socket.next().await {
+ // Check that message received by keyserver matches what identity server
+ // issued
+ let serialized_response: RefreshKeyRequest =
+ serde_json::from_str(&response.to_text().unwrap()).unwrap();
+ assert_eq!(serialized_response, refresh_request);
+ };
}
diff --git a/services/tunnelbroker/Cargo.lock b/services/tunnelbroker/Cargo.lock
--- a/services/tunnelbroker/Cargo.lock
+++ b/services/tunnelbroker/Cargo.lock
@@ -664,6 +664,12 @@
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
+[[package]]
+name = "convert_case"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e"
+
[[package]]
name = "core-foundation"
version = "0.9.3"
@@ -756,6 +762,19 @@
"parking_lot_core",
]
+[[package]]
+name = "derive_more"
+version = "0.99.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321"
+dependencies = [
+ "convert_case",
+ "proc-macro2",
+ "quote",
+ "rustc_version",
+ "syn 1.0.109",
+]
+
[[package]]
name = "digest"
version = "0.10.6"
@@ -2491,6 +2510,7 @@
"base64 0.20.0",
"clap",
"dashmap",
+ "derive_more",
"env_logger",
"fcm",
"futures",
diff --git a/services/tunnelbroker/Cargo.toml b/services/tunnelbroker/Cargo.toml
--- a/services/tunnelbroker/Cargo.toml
+++ b/services/tunnelbroker/Cargo.toml
@@ -34,6 +34,7 @@
tracing = "0.1"
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
tunnelbroker_messages = { path = "../../shared/tunnelbroker_messages" }
+derive_more = "0.99.17"
[build-dependencies]
tonic-build = "0.8"
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -2,18 +2,12 @@
use crate::database::DatabaseClient;
use crate::CONFIG;
-use futures_util::stream::SplitSink;
-use futures_util::SinkExt;
-use futures_util::{StreamExt, TryStreamExt};
+use futures_util::StreamExt;
use std::net::SocketAddr;
use std::{env, io::Error};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
-use tokio_tungstenite::tungstenite::Message;
-use tokio_tungstenite::WebSocketStream;
-use tracing::{debug, error, info};
-
-use crate::ACTIVE_CONNECTIONS;
+use tracing::{debug, info};
pub async fn run_server(db_client: DatabaseClient) -> Result<(), Error> {
let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR")
@@ -48,49 +42,34 @@
}
};
- let (mut outgoing, incoming) = ws_stream.split();
+ let (outgoing, mut incoming) = ws_stream.split();
// Create channel for messages to be passed to this connection
let (tx, mut rx) = mpsc::unbounded_channel::<String>();
- let session = session::WebsocketSession::new(tx.clone(), db_client.clone());
- let handle_incoming = incoming.try_for_each(|msg| async {
- debug!("Received message from {}", addr);
- match msg {
- Message::Text(text) => {
- match session.handle_message_from_device(&text).await {
- Ok(_) => {
- debug!("Successfully handled message: {}", text)
- }
- Err(e) => {
- error!("Failed to process message: {}", e);
- }
- };
- }
- _ => {
- error!("Invalid message was received");
- }
- }
- Ok(())
- });
+ let mut session = session::WebsocketSession::new(outgoing, db_client.clone());
- debug!("Polling for messages from: {}", addr);
// Poll for messages either being sent to the device (rx)
- // or messages being received from the device (handle_incoming)
- tokio::select! {
- Some(message) = rx.recv() => { handle_message_from_service(message, &mut outgoing).await; },
- Ok(_) = handle_incoming => { debug!("Received message from websocket") },
- else => {
- info!("Connection with {} closed.", addr);
- ACTIVE_CONNECTIONS.remove("test");
+ // or messages being received from the device (incoming)
+ loop {
+ debug!("Polling for messages from: {}", addr);
+ tokio::select! {
+ Some(message) = rx.recv() => { session.send_message_to_device(message).await; },
+ device_message = incoming.next() => {
+ match device_message {
+ Some(Ok(msg)) => session.handle_websocket_frame_from_device(msg, tx.clone()).await,
+ _ => {
+ debug!("Connection to {} closed remotely.", addr);
+ break;
+ }
+ }
+ },
+ else => {
+ debug!("Unhealthy connection for: {}", addr);
+ break;
+ },
}
}
-}
-async fn handle_message_from_service(
- incoming_payload: String,
- outgoing: &mut SplitSink<WebSocketStream<tokio::net::TcpStream>, Message>,
-) {
- if let Err(e) = outgoing.send(Message::Text(incoming_payload)).await {
- error!("Failed to send message to device: {}", e);
- }
+ info!("Unregistering connection to: {}", addr);
+ session.close().await
}
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -1,58 +1,124 @@
-use tracing::debug;
-use tunnelbroker_messages::Messages;
+use derive_more;
+use futures_util::stream::SplitSink;
+use futures_util::SinkExt;
+use tokio::{net::TcpStream, sync::mpsc::UnboundedSender};
+use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
+use tracing::{debug, error};
+use tunnelbroker_messages::{session::DeviceTypes, Messages};
use crate::{
constants::dynamodb::undelivered_messages::CREATED_AT,
database::DatabaseClient, ACTIVE_CONNECTIONS,
};
+pub struct DeviceInfo {
+ pub device_id: String,
+ pub notify_token: Option<String>,
+ pub device_type: DeviceTypes,
+ pub device_app_version: Option<String>,
+ pub device_os: Option<String>,
+}
+
pub struct WebsocketSession {
- tx: tokio::sync::mpsc::UnboundedSender<std::string::String>,
+ tx: SplitSink<WebSocketStream<TcpStream>, Message>,
db_client: DatabaseClient,
+ device_info: Option<DeviceInfo>,
+}
+
+#[derive(Debug, derive_more::Display, derive_more::From)]
+pub enum SessionError {
+ InvalidMessage,
+ SerializationError(serde_json::Error),
}
+fn consume_error<T>(result: Result<T, SessionError>) {
+ if let Err(e) = result {
+ error!("{}", e)
+ }
+}
impl WebsocketSession {
pub fn new(
- tx: tokio::sync::mpsc::UnboundedSender<std::string::String>,
+ tx: SplitSink<WebSocketStream<TcpStream>, Message>,
db_client: DatabaseClient,
) -> WebsocketSession {
- WebsocketSession { tx, db_client }
+ WebsocketSession {
+ tx,
+ db_client,
+ device_info: None,
+ }
+ }
+
+ pub async fn handle_websocket_frame_from_device(
+ &mut self,
+ frame: Message,
+ tx: UnboundedSender<String>,
+ ) {
+ debug!("Received message from device: {}", frame);
+ let result = match frame {
+ Message::Text(payload) => {
+ self.handle_message_from_device(&payload, tx).await
+ }
+ Message::Close(_) => {
+ self.close().await;
+ Ok(())
+ }
+ _ => Err(SessionError::InvalidMessage),
+ };
+ consume_error(result);
}
pub async fn handle_message_from_device(
- &self,
+ &mut self,
message: &str,
- ) -> Result<(), serde_json::Error> {
- match serde_json::from_str::<Messages>(message)? {
- Messages::SessionRequest(session_info) => {
+ tx: UnboundedSender<String>,
+ ) -> Result<(), SessionError> {
+ let serialized_message = serde_json::from_str::<Messages>(message)?;
+
+ match serialized_message {
+ Messages::SessionRequest(mut session_info) => {
// TODO: Authenticate device using auth token
+
+ // Check if session request was already sent
+ if self.device_info.is_some() {
+ return Err(SessionError::InvalidMessage);
+ }
+
+ let device_info = DeviceInfo {
+ device_id: session_info.device_id.clone(),
+ notify_token: session_info.notify_token.take(),
+ device_type: session_info.device_type,
+ device_app_version: session_info.device_app_version.take(),
+ device_os: session_info.device_os.take(),
+ };
+
// Check for persisted messages
let messages = self
.db_client
- .retrieve_messages(&session_info.device_id)
+ .retrieve_messages(&device_info.device_id)
.await
- .expect("Failed to retreive messages");
+ .unwrap_or_else(|e| {
+ error!("Error while retrieving messages: {}", e);
+ Vec::new()
+ });
- ACTIVE_CONNECTIONS
- .insert(session_info.device_id.clone(), self.tx.clone());
+ ACTIVE_CONNECTIONS.insert(device_info.device_id.clone(), tx.clone());
for message in messages {
let payload =
message.get("payload").unwrap().as_s().unwrap().to_string();
- self
- .tx
- .send(payload)
- .expect("Failed to send message to client");
let created_at =
message.get(CREATED_AT).unwrap().as_n().unwrap().to_string();
+ self.send_message_to_device(payload).await;
self
.db_client
- .delete_message(&session_info.device_id, &created_at)
+ .delete_message(&device_info.device_id, &created_at)
.await
.expect("Failed to delete messages");
}
debug!("Flushed messages for device: {}", &session_info.device_id);
+
+ self.device_info = Some(device_info);
}
_ => {
debug!("Received invalid request");
@@ -61,4 +127,21 @@
Ok(())
}
+
+ pub async fn send_message_to_device(&mut self, incoming_payload: String) {
+ if let Err(e) = self.tx.send(Message::Text(incoming_payload)).await {
+ error!("Failed to send message to device: {}", e);
+ }
+ }
+
+ // Release websocket and remove from active connections
+ pub async fn close(&mut self) {
+ if let Some(device_info) = &self.device_info {
+ ACTIVE_CONNECTIONS.remove(&device_info.device_id);
+ }
+
+ if let Err(e) = self.tx.close().await {
+ debug!("Failed to close session: {}", e);
+ }
+ }
}

File Metadata

Mime Type
text/plain
Expires
Sat, Jan 3, 10:40 AM (9 h, 45 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
5887880
Default Alt Text
D7800.1767436855.diff (10 KB)

Event Timeline