Page Menu
Home
Phabricator
Search
Configure Global Search
Log In
Files
F3245761
D8178.id27874.diff
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
16 KB
Referenced Files
None
Subscribers
None
D8178.id27874.diff
View Options
diff --git a/services/tunnelbroker/Cargo.lock b/services/tunnelbroker/Cargo.lock
--- a/services/tunnelbroker/Cargo.lock
+++ b/services/tunnelbroker/Cargo.lock
@@ -968,19 +968,6 @@
"syn 2.0.15",
]
-[[package]]
-name = "dashmap"
-version = "5.4.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
-dependencies = [
- "cfg-if",
- "hashbrown",
- "lock_api",
- "once_cell",
- "parking_lot_core",
-]
-
[[package]]
name = "derive_more"
version = "0.99.17"
@@ -2969,7 +2956,6 @@
"aws-types",
"base64 0.20.0",
"clap",
- "dashmap",
"derive_more",
"env_logger",
"fcm",
diff --git a/services/tunnelbroker/Cargo.toml b/services/tunnelbroker/Cargo.toml
--- a/services/tunnelbroker/Cargo.toml
+++ b/services/tunnelbroker/Cargo.toml
@@ -15,7 +15,6 @@
aws-types = "0.55"
base64 = "0.20"
clap = { version = "4.2", features = ["derive", "env"] }
-dashmap = "5.4"
env_logger = "0.9"
fcm = "0.9"
futures = "0.3"
diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs
--- a/services/tunnelbroker/src/grpc/mod.rs
+++ b/services/tunnelbroker/src/grpc/mod.rs
@@ -2,18 +2,36 @@
tonic::include_proto!("tunnelbroker");
}
+use lapin::{options::BasicPublishOptions, BasicProperties};
use proto::tunnelbroker_service_server::{
TunnelbrokerService, TunnelbrokerServiceServer,
};
use proto::Empty;
use tonic::transport::Server;
-use tracing::{debug, error};
+use tracing::debug;
use crate::database::{handle_ddb_error, DatabaseClient};
-use crate::{constants, ACTIVE_CONNECTIONS, CONFIG};
+use crate::{constants, CONFIG};
struct TunnelbrokerGRPC {
client: DatabaseClient,
+ amqp_channel: lapin::Channel,
+}
+
+// By setting mandatory to true, we don't wait for a confirmation for an eventual
+// delivery; instead, we get an immediate undelivered error
+const PUBLISH_OPTIONS: BasicPublishOptions = BasicPublishOptions {
+ immediate: false,
+ mandatory: true,
+};
+
+pub fn handle_amqp_error(error: lapin::Error) -> tonic::Status {
+ match error {
+ lapin::Error::SerialisationError(_) | lapin::Error::ParsingError(_) => {
+ tonic::Status::invalid_argument("Invalid argument")
+ }
+ _ => tonic::Status::internal("Internal Error"),
+ }
}
#[tonic::async_trait]
@@ -25,11 +43,19 @@
let message = request.into_inner();
debug!("Received message for {}", &message.device_id);
- if let Some(tx) = ACTIVE_CONNECTIONS.get(&message.device_id) {
- if let Err(_) = tx.send(message.payload) {
- error!("Unable to send message to device: {}", &message.device_id);
- ACTIVE_CONNECTIONS.remove(&message.device_id);
- }
+ if let Ok(confirmation) = self
+ .amqp_channel
+ .basic_publish(
+ "",
+ &message.device_id,
+ PUBLISH_OPTIONS,
+ &message.payload.as_bytes(),
+ BasicProperties::default(),
+ )
+ .await
+ {
+ debug!("Forwarded message: {:?}", &message);
+ confirmation.await.map_err(handle_amqp_error)?;
} else {
self
.client
@@ -45,16 +71,25 @@
pub async fn run_server(
client: DatabaseClient,
+ ampq_connection: &lapin::Connection,
) -> Result<(), tonic::transport::Error> {
let addr = format!("[::1]:{}", CONFIG.grpc_port)
.parse()
.expect("Unable to parse gRPC address");
+ let amqp_channel = ampq_connection
+ .create_channel()
+ .await
+ .expect("Unable to create amqp channel");
+
tracing::info!("Websocket server listening on {}", &addr);
Server::builder()
.http2_keepalive_interval(Some(constants::GRPC_KEEP_ALIVE_PING_INTERVAL))
.http2_keepalive_timeout(Some(constants::GRPC_KEEP_ALIVE_PING_TIMEOUT))
- .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC { client }))
+ .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC {
+ client,
+ amqp_channel,
+ }))
.serve(addr)
.await
}
diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs
--- a/services/tunnelbroker/src/main.rs
+++ b/services/tunnelbroker/src/main.rs
@@ -7,15 +7,9 @@
use anyhow::{anyhow, Result};
use config::CONFIG;
-use dashmap::DashMap;
-use once_cell::sync::Lazy;
-use tokio::sync::mpsc::UnboundedSender;
use tracing::{self, Level};
use tracing_subscriber::EnvFilter;
-pub static ACTIVE_CONNECTIONS: Lazy<DashMap<String, UnboundedSender<String>>> =
- Lazy::new(DashMap::new);
-
#[tokio::main]
async fn main() -> Result<()> {
let filter = EnvFilter::builder()
@@ -32,8 +26,9 @@
let db_client = database::DatabaseClient::new(&aws_config);
let amqp_connection = amqp::connect().await;
- let grpc_server = grpc::run_server(db_client.clone());
- let websocket_server = websockets::run_server(db_client.clone());
+ let grpc_server = grpc::run_server(db_client.clone(), &amqp_connection);
+ let websocket_server =
+ websockets::run_server(db_client.clone(), &amqp_connection);
tokio::select! {
Ok(_) = grpc_server => { Ok(()) },
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -1,15 +1,23 @@
mod session;
use crate::database::DatabaseClient;
+use crate::websockets::session::SessionError;
use crate::CONFIG;
+use futures_util::stream::SplitSink;
use futures_util::StreamExt;
use std::net::SocketAddr;
use std::{env, io::Error};
use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::mpsc;
-use tracing::{debug, info};
+use tokio_tungstenite::tungstenite::Message;
+use tokio_tungstenite::WebSocketStream;
+use tracing::{debug, error, info};
-pub async fn run_server(db_client: DatabaseClient) -> Result<(), Error> {
+use self::session::WebsocketSession;
+
+pub async fn run_server(
+ db_client: DatabaseClient,
+ amqp_connection: &lapin::Connection,
+) -> Result<(), Error> {
let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR")
.unwrap_or_else(|_| format!("127.0.0.1:{}", &CONFIG.http_port));
@@ -17,7 +25,11 @@
info!("Listening on: {}", addr);
while let Ok((stream, addr)) = listener.accept().await {
- tokio::spawn(accept_connection(stream, addr, db_client.clone()));
+ let channel = amqp_connection
+ .create_channel()
+ .await
+ .expect("Unable to create amqp channel");
+ tokio::spawn(accept_connection(stream, addr, db_client.clone(), channel));
}
Ok(())
@@ -28,6 +40,7 @@
raw_stream: TcpStream,
addr: SocketAddr,
db_client: DatabaseClient,
+ amqp_channel: lapin::Channel,
) {
debug!("Incoming connection from: {}", addr);
@@ -43,20 +56,39 @@
};
let (outgoing, mut incoming) = ws_stream.split();
- // Create channel for messages to be passed to this connection
- let (tx, mut rx) = mpsc::unbounded_channel::<String>();
- let mut session = session::WebsocketSession::new(outgoing, db_client.clone());
+ // We don't know the identity of the device until it sends the session
+ // request over the websocket connection
+ let mut session = if let Some(Ok(first_msg)) = incoming.next().await {
+ match initiate_session(outgoing, first_msg, db_client, amqp_channel).await {
+ Ok(session) => session,
+ Err(_) => {
+ error!("Failed to create session with device");
+ return;
+ }
+ }
+ } else {
+ error!("Failed to create session with device");
+ return;
+ };
// Poll for messages either being sent to the device (rx)
// or messages being received from the device (incoming)
loop {
debug!("Polling for messages from: {}", addr);
tokio::select! {
- Some(message) = rx.recv() => { session.send_message_to_device(message).await; },
+ Some(Ok(delivery)) = session.next_amqp_message() => {
+ if let Ok(message) = std::str::from_utf8(&delivery.data) {
+ session.send_message_to_device(message.to_string()).await;
+ } else {
+ error!("Invalid payload");
+ }
+ },
device_message = incoming.next() => {
match device_message {
- Some(Ok(msg)) => session.handle_websocket_frame_from_device(msg, tx.clone()).await,
+ Some(Ok(msg)) => {
+ session::consume_error(session.handle_websocket_frame_from_device(msg).await);
+ }
_ => {
debug!("Connection to {} closed remotely.", addr);
break;
@@ -73,3 +105,26 @@
info!("Unregistering connection to: {}", addr);
session.close().await
}
+
+async fn initiate_session(
+ outgoing: SplitSink<WebSocketStream<TcpStream>, Message>,
+ frame: Message,
+ db_client: DatabaseClient,
+ amqp_channel: lapin::Channel,
+) -> Result<WebsocketSession, session::SessionError> {
+ let mut session = session::WebsocketSession::from_frame(
+ outgoing,
+ db_client.clone(),
+ frame,
+ &amqp_channel,
+ )
+ .await
+ .map_err(|_| {
+ error!("Device failed to send valid connection request.");
+ SessionError::InvalidMessage
+ })?;
+
+ session::consume_error(session.deliver_persisted_messages().await);
+
+ Ok(session)
+}
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -1,15 +1,16 @@
use derive_more;
use futures_util::stream::SplitSink;
use futures_util::SinkExt;
-use tokio::{net::TcpStream, sync::mpsc::UnboundedSender};
+use futures_util::StreamExt;
+use lapin::message::Delivery;
+use lapin::options::{BasicConsumeOptions, QueueDeclareOptions};
+use lapin::types::FieldTable;
+use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
use tracing::{debug, error};
use tunnelbroker_messages::{session::DeviceTypes, Messages};
-use crate::{
- database::{self, DatabaseClient, DeviceMessage},
- ACTIVE_CONNECTIONS,
-};
+use crate::database::{self, DatabaseClient, DeviceMessage};
pub struct DeviceInfo {
pub device_id: String,
@@ -22,7 +23,9 @@
pub struct WebsocketSession {
tx: SplitSink<WebSocketStream<TcpStream>, Message>,
db_client: DatabaseClient,
- device_info: Option<DeviceInfo>,
+ pub device_info: DeviceInfo,
+ // Stream of messages from the AMQP endpoint
+ amqp_consumer: lapin::Consumer,
}
#[derive(Debug, derive_more::Display, derive_more::From)]
@@ -30,101 +33,135 @@
InvalidMessage,
SerializationError(serde_json::Error),
MessageError(database::MessageErrors),
+ AmqpError(lapin::Error),
}
-fn consume_error<T>(result: Result<T, SessionError>) {
+pub fn consume_error<T>(result: Result<T, SessionError>) {
if let Err(e) = result {
error!("{}", e)
}
}
+
+// Parse a session request and retrieve the device information
+pub fn handle_first_message_from_device(
+ message: &str,
+) -> Result<DeviceInfo, SessionError> {
+ let serialized_message = serde_json::from_str::<Messages>(message)?;
+
+ match serialized_message {
+ Messages::SessionRequest(mut session_info) => {
+ let device_info = DeviceInfo {
+ device_id: session_info.device_id.clone(),
+ notify_token: session_info.notify_token.take(),
+ device_type: session_info.device_type,
+ device_app_version: session_info.device_app_version.take(),
+ device_os: session_info.device_os.take(),
+ };
+
+ return Ok(device_info);
+ }
+ _ => {
+ debug!("Received invalid request");
+ return Err(SessionError::InvalidMessage);
+ }
+ }
+}
+
impl WebsocketSession {
- pub fn new(
+ pub async fn from_frame(
tx: SplitSink<WebSocketStream<TcpStream>, Message>,
db_client: DatabaseClient,
- ) -> WebsocketSession {
- WebsocketSession {
+ frame: Message,
+ amqp_channel: &lapin::Channel,
+ ) -> Result<WebsocketSession, SessionError> {
+ let device_info = match frame {
+ Message::Text(payload) => handle_first_message_from_device(&payload)?,
+ _ => {
+ error!("Client sent wrong frame type for establishing connection");
+ return Err(SessionError::InvalidMessage);
+ }
+ };
+
+ // We don't currently have a use case to interact directly with the queue;
+ // however, we need to declare a queue for a given device
+ amqp_channel
+ .queue_declare(
+ &device_info.device_id,
+ QueueDeclareOptions::default(),
+ FieldTable::default(),
+ )
+ .await?;
+
+ let amqp_consumer = amqp_channel
+ .basic_consume(
+ &device_info.device_id,
+ "tunnelbroker",
+ BasicConsumeOptions::default(),
+ FieldTable::default(),
+ )
+ .await?;
+
+ Ok(WebsocketSession {
tx,
db_client,
- device_info: None,
- }
+ device_info,
+ amqp_consumer,
+ })
}
pub async fn handle_websocket_frame_from_device(
&mut self,
frame: Message,
- tx: UnboundedSender<String>,
- ) {
- debug!("Received message from device: {}", frame);
- let result = match frame {
+ ) -> Result<(), SessionError> {
+ match frame {
Message::Text(payload) => {
- self.handle_message_from_device(&payload, tx).await
+ debug!("Received message from device: {}", payload);
+ Ok(())
}
Message::Close(_) => {
self.close().await;
Ok(())
}
_ => Err(SessionError::InvalidMessage),
- };
- consume_error(result);
+ }
+ }
+
+ pub async fn next_amqp_message(
+ &mut self,
+ ) -> Option<Result<Delivery, lapin::Error>> {
+ self.amqp_consumer.next().await
}
- pub async fn handle_message_from_device(
+ pub async fn deliver_persisted_messages(
&mut self,
- message: &str,
- tx: UnboundedSender<String>,
) -> Result<(), SessionError> {
- let serialized_message = serde_json::from_str::<Messages>(message)?;
-
- match serialized_message {
- Messages::SessionRequest(mut session_info) => {
- // TODO: Authenticate device using auth token
-
- // Check if session request was already sent
- if self.device_info.is_some() {
- return Err(SessionError::InvalidMessage);
- }
-
- let device_info = DeviceInfo {
- device_id: session_info.device_id.clone(),
- notify_token: session_info.notify_token.take(),
- device_type: session_info.device_type,
- device_app_version: session_info.device_app_version.take(),
- device_os: session_info.device_os.take(),
- };
-
- // Check for persisted messages
- let messages = self
- .db_client
- .retrieve_messages(&device_info.device_id)
- .await
- .unwrap_or_else(|e| {
- error!("Error while retrieving messages: {}", e);
- Vec::new()
- });
-
- ACTIVE_CONNECTIONS.insert(device_info.device_id.clone(), tx.clone());
-
- for message in messages {
- let device_message = DeviceMessage::from_hashmap(message)?;
- self.send_message_to_device(device_message.payload).await;
- if let Err(e) = self
- .db_client
- .delete_message(&device_info.device_id, &device_message.created_at)
- .await
- {
- error!("Failed to delete message: {}:", e);
- }
- }
-
- debug!("Flushed messages for device: {}", &session_info.device_id);
-
- self.device_info = Some(device_info);
- }
- _ => {
- debug!("Received invalid request");
+ // Check for persisted messages
+ let messages = self
+ .db_client
+ .retrieve_messages(&self.device_info.device_id)
+ .await
+ .unwrap_or_else(|e| {
+ error!("Error while retrieving messages: {}", e);
+ Vec::new()
+ });
+
+ for message in messages {
+ let device_message = DeviceMessage::from_hashmap(message)?;
+ self.send_message_to_device(device_message.payload).await;
+ if let Err(e) = self
+ .db_client
+ .delete_message(&self.device_info.device_id, &device_message.created_at)
+ .await
+ {
+ error!("Failed to delete message: {}:", e);
}
}
+ debug!(
+ "Flushed messages for device: {}",
+ &self.device_info.device_id
+ );
+
Ok(())
}
@@ -136,10 +173,6 @@
// Release websocket and remove from active connections
pub async fn close(&mut self) {
- if let Some(device_info) = &self.device_info {
- ACTIVE_CONNECTIONS.remove(&device_info.device_id);
- }
-
if let Err(e) = self.tx.close().await {
debug!("Failed to close session: {}", e);
}
File Metadata
Details
Attached
Mime Type
text/plain
Expires
Fri, Nov 15, 8:17 PM (20 h, 47 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
2495086
Default Alt Text
D8178.id27874.diff (16 KB)
Attached To
Mode
D8178: [Tunnelbroker] Use rabbitmq for message delivery
Attached
Detach File
Event Timeline
Log In to Comment