Page Menu
Home
Phorge
Search
Configure Global Search
Log In
Files
F32603308
D7800.1767436855.diff
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Flag For Later
Award Token
Size
10 KB
Referenced Files
None
Subscribers
None
D7800.1767436855.diff
View Options
diff --git a/services/commtest/tests/tunnelbroker_integration_test.rs b/services/commtest/tests/tunnelbroker_integration_test.rs
--- a/services/commtest/tests/tunnelbroker_integration_test.rs
+++ b/services/commtest/tests/tunnelbroker_integration_test.rs
@@ -110,11 +110,11 @@
.expect("Failed to send message");
// Have keyserver receive any websocket messages
- let response = socket.next().await.unwrap().unwrap();
-
- // Check that message received by keyserver matches what identity server
- // issued
- let serialized_response: RefreshKeyRequest =
- serde_json::from_str(&response.to_text().unwrap()).unwrap();
- assert_eq!(serialized_response, refresh_request);
+ if let Some(Ok(response)) = socket.next().await {
+ // Check that message received by keyserver matches what identity server
+ // issued
+ let serialized_response: RefreshKeyRequest =
+ serde_json::from_str(&response.to_text().unwrap()).unwrap();
+ assert_eq!(serialized_response, refresh_request);
+ };
}
diff --git a/services/tunnelbroker/Cargo.lock b/services/tunnelbroker/Cargo.lock
--- a/services/tunnelbroker/Cargo.lock
+++ b/services/tunnelbroker/Cargo.lock
@@ -664,6 +664,12 @@
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
+[[package]]
+name = "convert_case"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e"
+
[[package]]
name = "core-foundation"
version = "0.9.3"
@@ -756,6 +762,19 @@
"parking_lot_core",
]
+[[package]]
+name = "derive_more"
+version = "0.99.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321"
+dependencies = [
+ "convert_case",
+ "proc-macro2",
+ "quote",
+ "rustc_version",
+ "syn 1.0.109",
+]
+
[[package]]
name = "digest"
version = "0.10.6"
@@ -2491,6 +2510,7 @@
"base64 0.20.0",
"clap",
"dashmap",
+ "derive_more",
"env_logger",
"fcm",
"futures",
diff --git a/services/tunnelbroker/Cargo.toml b/services/tunnelbroker/Cargo.toml
--- a/services/tunnelbroker/Cargo.toml
+++ b/services/tunnelbroker/Cargo.toml
@@ -34,6 +34,7 @@
tracing = "0.1"
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
tunnelbroker_messages = { path = "../../shared/tunnelbroker_messages" }
+derive_more = "0.99.17"
[build-dependencies]
tonic-build = "0.8"
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -2,18 +2,12 @@
use crate::database::DatabaseClient;
use crate::CONFIG;
-use futures_util::stream::SplitSink;
-use futures_util::SinkExt;
-use futures_util::{StreamExt, TryStreamExt};
+use futures_util::StreamExt;
use std::net::SocketAddr;
use std::{env, io::Error};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
-use tokio_tungstenite::tungstenite::Message;
-use tokio_tungstenite::WebSocketStream;
-use tracing::{debug, error, info};
-
-use crate::ACTIVE_CONNECTIONS;
+use tracing::{debug, info};
pub async fn run_server(db_client: DatabaseClient) -> Result<(), Error> {
let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR")
@@ -48,49 +42,34 @@
}
};
- let (mut outgoing, incoming) = ws_stream.split();
+ let (outgoing, mut incoming) = ws_stream.split();
// Create channel for messages to be passed to this connection
let (tx, mut rx) = mpsc::unbounded_channel::<String>();
- let session = session::WebsocketSession::new(tx.clone(), db_client.clone());
- let handle_incoming = incoming.try_for_each(|msg| async {
- debug!("Received message from {}", addr);
- match msg {
- Message::Text(text) => {
- match session.handle_message_from_device(&text).await {
- Ok(_) => {
- debug!("Successfully handled message: {}", text)
- }
- Err(e) => {
- error!("Failed to process message: {}", e);
- }
- };
- }
- _ => {
- error!("Invalid message was received");
- }
- }
- Ok(())
- });
+ let mut session = session::WebsocketSession::new(outgoing, db_client.clone());
- debug!("Polling for messages from: {}", addr);
// Poll for messages either being sent to the device (rx)
- // or messages being received from the device (handle_incoming)
- tokio::select! {
- Some(message) = rx.recv() => { handle_message_from_service(message, &mut outgoing).await; },
- Ok(_) = handle_incoming => { debug!("Received message from websocket") },
- else => {
- info!("Connection with {} closed.", addr);
- ACTIVE_CONNECTIONS.remove("test");
+ // or messages being received from the device (incoming)
+ loop {
+ debug!("Polling for messages from: {}", addr);
+ tokio::select! {
+ Some(message) = rx.recv() => { session.send_message_to_device(message).await; },
+ device_message = incoming.next() => {
+ match device_message {
+ Some(Ok(msg)) => session.handle_websocket_frame_from_device(msg, tx.clone()).await,
+ _ => {
+ debug!("Connection to {} closed remotely.", addr);
+ break;
+ }
+ }
+ },
+ else => {
+ debug!("Unhealthy connection for: {}", addr);
+ break;
+ },
}
}
-}
-async fn handle_message_from_service(
- incoming_payload: String,
- outgoing: &mut SplitSink<WebSocketStream<tokio::net::TcpStream>, Message>,
-) {
- if let Err(e) = outgoing.send(Message::Text(incoming_payload)).await {
- error!("Failed to send message to device: {}", e);
- }
+ info!("Unregistering connection to: {}", addr);
+ session.close().await
}
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -1,58 +1,124 @@
-use tracing::debug;
-use tunnelbroker_messages::Messages;
+use derive_more;
+use futures_util::stream::SplitSink;
+use futures_util::SinkExt;
+use tokio::{net::TcpStream, sync::mpsc::UnboundedSender};
+use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
+use tracing::{debug, error};
+use tunnelbroker_messages::{session::DeviceTypes, Messages};
use crate::{
constants::dynamodb::undelivered_messages::CREATED_AT,
database::DatabaseClient, ACTIVE_CONNECTIONS,
};
+pub struct DeviceInfo {
+ pub device_id: String,
+ pub notify_token: Option<String>,
+ pub device_type: DeviceTypes,
+ pub device_app_version: Option<String>,
+ pub device_os: Option<String>,
+}
+
pub struct WebsocketSession {
- tx: tokio::sync::mpsc::UnboundedSender<std::string::String>,
+ tx: SplitSink<WebSocketStream<TcpStream>, Message>,
db_client: DatabaseClient,
+ device_info: Option<DeviceInfo>,
+}
+
+#[derive(Debug, derive_more::Display, derive_more::From)]
+pub enum SessionError {
+ InvalidMessage,
+ SerializationError(serde_json::Error),
}
+fn consume_error<T>(result: Result<T, SessionError>) {
+ if let Err(e) = result {
+ error!("{}", e)
+ }
+}
impl WebsocketSession {
pub fn new(
- tx: tokio::sync::mpsc::UnboundedSender<std::string::String>,
+ tx: SplitSink<WebSocketStream<TcpStream>, Message>,
db_client: DatabaseClient,
) -> WebsocketSession {
- WebsocketSession { tx, db_client }
+ WebsocketSession {
+ tx,
+ db_client,
+ device_info: None,
+ }
+ }
+
+ pub async fn handle_websocket_frame_from_device(
+ &mut self,
+ frame: Message,
+ tx: UnboundedSender<String>,
+ ) {
+ debug!("Received message from device: {}", frame);
+ let result = match frame {
+ Message::Text(payload) => {
+ self.handle_message_from_device(&payload, tx).await
+ }
+ Message::Close(_) => {
+ self.close().await;
+ Ok(())
+ }
+ _ => Err(SessionError::InvalidMessage),
+ };
+ consume_error(result);
}
pub async fn handle_message_from_device(
- &self,
+ &mut self,
message: &str,
- ) -> Result<(), serde_json::Error> {
- match serde_json::from_str::<Messages>(message)? {
- Messages::SessionRequest(session_info) => {
+ tx: UnboundedSender<String>,
+ ) -> Result<(), SessionError> {
+ let serialized_message = serde_json::from_str::<Messages>(message)?;
+
+ match serialized_message {
+ Messages::SessionRequest(mut session_info) => {
// TODO: Authenticate device using auth token
+
+ // Check if session request was already sent
+ if self.device_info.is_some() {
+ return Err(SessionError::InvalidMessage);
+ }
+
+ let device_info = DeviceInfo {
+ device_id: session_info.device_id.clone(),
+ notify_token: session_info.notify_token.take(),
+ device_type: session_info.device_type,
+ device_app_version: session_info.device_app_version.take(),
+ device_os: session_info.device_os.take(),
+ };
+
// Check for persisted messages
let messages = self
.db_client
- .retrieve_messages(&session_info.device_id)
+ .retrieve_messages(&device_info.device_id)
.await
- .expect("Failed to retreive messages");
+ .unwrap_or_else(|e| {
+ error!("Error while retrieving messages: {}", e);
+ Vec::new()
+ });
- ACTIVE_CONNECTIONS
- .insert(session_info.device_id.clone(), self.tx.clone());
+ ACTIVE_CONNECTIONS.insert(device_info.device_id.clone(), tx.clone());
for message in messages {
let payload =
message.get("payload").unwrap().as_s().unwrap().to_string();
- self
- .tx
- .send(payload)
- .expect("Failed to send message to client");
let created_at =
message.get(CREATED_AT).unwrap().as_n().unwrap().to_string();
+ self.send_message_to_device(payload).await;
self
.db_client
- .delete_message(&session_info.device_id, &created_at)
+ .delete_message(&device_info.device_id, &created_at)
.await
.expect("Failed to delete messages");
}
debug!("Flushed messages for device: {}", &session_info.device_id);
+
+ self.device_info = Some(device_info);
}
_ => {
debug!("Received invalid request");
@@ -61,4 +127,21 @@
Ok(())
}
+
+ pub async fn send_message_to_device(&mut self, incoming_payload: String) {
+ if let Err(e) = self.tx.send(Message::Text(incoming_payload)).await {
+ error!("Failed to send message to device: {}", e);
+ }
+ }
+
+ // Release websocket and remove from active connections
+ pub async fn close(&mut self) {
+ if let Some(device_info) = &self.device_info {
+ ACTIVE_CONNECTIONS.remove(&device_info.device_id);
+ }
+
+ if let Err(e) = self.tx.close().await {
+ debug!("Failed to close session: {}", e);
+ }
+ }
}
File Metadata
Details
Attached
Mime Type
text/plain
Expires
Sat, Jan 3, 10:40 AM (9 h, 45 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
5887880
Default Alt Text
D7800.1767436855.diff (10 KB)
Attached To
Mode
D7800: [Tunnelbroker] Refactor connection lifetimes into session object
Attached
Detach File
Event Timeline
Log In to Comment