Page MenuHomePhabricator

D8178.id27874.diff
No OneTemporary

D8178.id27874.diff

diff --git a/services/tunnelbroker/Cargo.lock b/services/tunnelbroker/Cargo.lock
--- a/services/tunnelbroker/Cargo.lock
+++ b/services/tunnelbroker/Cargo.lock
@@ -968,19 +968,6 @@
"syn 2.0.15",
]
-[[package]]
-name = "dashmap"
-version = "5.4.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
-dependencies = [
- "cfg-if",
- "hashbrown",
- "lock_api",
- "once_cell",
- "parking_lot_core",
-]
-
[[package]]
name = "derive_more"
version = "0.99.17"
@@ -2969,7 +2956,6 @@
"aws-types",
"base64 0.20.0",
"clap",
- "dashmap",
"derive_more",
"env_logger",
"fcm",
diff --git a/services/tunnelbroker/Cargo.toml b/services/tunnelbroker/Cargo.toml
--- a/services/tunnelbroker/Cargo.toml
+++ b/services/tunnelbroker/Cargo.toml
@@ -15,7 +15,6 @@
aws-types = "0.55"
base64 = "0.20"
clap = { version = "4.2", features = ["derive", "env"] }
-dashmap = "5.4"
env_logger = "0.9"
fcm = "0.9"
futures = "0.3"
diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs
--- a/services/tunnelbroker/src/grpc/mod.rs
+++ b/services/tunnelbroker/src/grpc/mod.rs
@@ -2,18 +2,36 @@
tonic::include_proto!("tunnelbroker");
}
+use lapin::{options::BasicPublishOptions, BasicProperties};
use proto::tunnelbroker_service_server::{
TunnelbrokerService, TunnelbrokerServiceServer,
};
use proto::Empty;
use tonic::transport::Server;
-use tracing::{debug, error};
+use tracing::debug;
use crate::database::{handle_ddb_error, DatabaseClient};
-use crate::{constants, ACTIVE_CONNECTIONS, CONFIG};
+use crate::{constants, CONFIG};
+/// State shared by the tunnelbroker gRPC service handlers.
struct TunnelbrokerGRPC {
+ /// Database client; used when a message could not be published over AMQP.
client: DatabaseClient,
+ /// Channel used to publish messages to queues keyed by device id.
+ amqp_channel: lapin::Channel,
+}
+
+/// Options passed to `basic_publish` when forwarding a message to the queue
+/// named after the target device id.
+///
+/// NOTE(review): `mandatory: true` does NOT make `basic_publish` fail when no
+/// queue is bound to the routing key. The broker hands an unroutable message
+/// back asynchronously (`basic.return`). Unless publisher confirms are enabled
+/// on the channel (`confirm_select`), awaiting the returned `PublisherConfirm`
+/// just yields `Confirmation::NotRequested`. To detect undelivered messages,
+/// enable confirms on the channel and treat `Confirmation::Ack(Some(_))`
+/// (message returned) or `Confirmation::Nack(_)` as "not delivered".
+const PUBLISH_OPTIONS: BasicPublishOptions = BasicPublishOptions {
+  immediate: false,
+  mandatory: true,
+};
+
+/// Convert an AMQP (`lapin`) error into the gRPC status returned to the caller.
+///
+/// Serialisation and parsing errors map to `invalid_argument`. Any other AMQP
+/// failure maps to an opaque `internal` status. The original error is logged
+/// first, because its details are otherwise lost once it is reduced to a
+/// generic status message.
+pub fn handle_amqp_error(error: lapin::Error) -> tonic::Status {
+  tracing::error!("AMQP error: {}", error);
+  match error {
+    lapin::Error::SerialisationError(_) | lapin::Error::ParsingError(_) => {
+      tonic::Status::invalid_argument("Invalid argument")
+    }
+    _ => tonic::Status::internal("Internal Error"),
+  }
}
#[tonic::async_trait]
@@ -25,11 +43,19 @@
let message = request.into_inner();
debug!("Received message for {}", &message.device_id);
- if let Some(tx) = ACTIVE_CONNECTIONS.get(&message.device_id) {
- if let Err(_) = tx.send(message.payload) {
- error!("Unable to send message to device: {}", &message.device_id);
- ACTIVE_CONNECTIONS.remove(&message.device_id);
- }
+ if let Ok(confirmation) = self
+ .amqp_channel
+ .basic_publish(
+ "",
+ &message.device_id,
+ PUBLISH_OPTIONS,
+ &message.payload.as_bytes(),
+ BasicProperties::default(),
+ )
+ .await
+ {
+ debug!("Forwarded message: {:?}", &message);
+ confirmation.await.map_err(handle_amqp_error)?;
} else {
self
.client
@@ -45,16 +71,25 @@
+/// Serve the tunnelbroker gRPC API on `[::1]:<CONFIG.grpc_port>`.
+///
+/// One AMQP channel is opened on `amqp_connection` and stored in the service,
+/// which uses it to publish messages to device queues.
+///
+/// # Panics
+/// Panics if the gRPC address cannot be parsed or the AMQP channel cannot be
+/// created.
pub async fn run_server(
client: DatabaseClient,
+  amqp_connection: &lapin::Connection,
) -> Result<(), tonic::transport::Error> {
let addr = format!("[::1]:{}", CONFIG.grpc_port)
.parse()
.expect("Unable to parse gRPC address");
+  let amqp_channel = amqp_connection
+    .create_channel()
+    .await
+    .expect("Unable to create amqp channel");
+
-  tracing::info!("Websocket server listening on {}", &addr);
+  tracing::info!("gRPC server listening on {}", &addr);
Server::builder()
.http2_keepalive_interval(Some(constants::GRPC_KEEP_ALIVE_PING_INTERVAL))
.http2_keepalive_timeout(Some(constants::GRPC_KEEP_ALIVE_PING_TIMEOUT))
- .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC { client }))
+    .add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC {
+      client,
+      amqp_channel,
+    }))
.serve(addr)
.await
}
diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs
--- a/services/tunnelbroker/src/main.rs
+++ b/services/tunnelbroker/src/main.rs
@@ -7,15 +7,9 @@
use anyhow::{anyhow, Result};
use config::CONFIG;
-use dashmap::DashMap;
-use once_cell::sync::Lazy;
-use tokio::sync::mpsc::UnboundedSender;
use tracing::{self, Level};
use tracing_subscriber::EnvFilter;
-pub static ACTIVE_CONNECTIONS: Lazy<DashMap<String, UnboundedSender<String>>> =
- Lazy::new(DashMap::new);
-
#[tokio::main]
async fn main() -> Result<()> {
let filter = EnvFilter::builder()
@@ -32,8 +26,9 @@
let db_client = database::DatabaseClient::new(&aws_config);
let amqp_connection = amqp::connect().await;
- let grpc_server = grpc::run_server(db_client.clone());
- let websocket_server = websockets::run_server(db_client.clone());
+ let grpc_server = grpc::run_server(db_client.clone(), &amqp_connection);
+ let websocket_server =
+ websockets::run_server(db_client.clone(), &amqp_connection);
tokio::select! {
Ok(_) = grpc_server => { Ok(()) },
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -1,15 +1,23 @@
mod session;
use crate::database::DatabaseClient;
+use crate::websockets::session::SessionError;
use crate::CONFIG;
+use futures_util::stream::SplitSink;
use futures_util::StreamExt;
use std::net::SocketAddr;
use std::{env, io::Error};
use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::mpsc;
-use tracing::{debug, info};
+use tokio_tungstenite::tungstenite::Message;
+use tokio_tungstenite::WebSocketStream;
+use tracing::{debug, error, info};
-pub async fn run_server(db_client: DatabaseClient) -> Result<(), Error> {
+use self::session::WebsocketSession;
+
+pub async fn run_server(
+ db_client: DatabaseClient,
+ amqp_connection: &lapin::Connection,
+) -> Result<(), Error> {
let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR")
.unwrap_or_else(|_| format!("127.0.0.1:{}", &CONFIG.http_port));
@@ -17,7 +25,11 @@
info!("Listening on: {}", addr);
while let Ok((stream, addr)) = listener.accept().await {
- tokio::spawn(accept_connection(stream, addr, db_client.clone()));
+ let channel = amqp_connection
+ .create_channel()
+ .await
+ .expect("Unable to create amqp channel");
+ tokio::spawn(accept_connection(stream, addr, db_client.clone(), channel));
}
Ok(())
@@ -28,6 +40,7 @@
raw_stream: TcpStream,
addr: SocketAddr,
db_client: DatabaseClient,
+ amqp_channel: lapin::Channel,
) {
debug!("Incoming connection from: {}", addr);
@@ -43,20 +56,39 @@
};
let (outgoing, mut incoming) = ws_stream.split();
- // Create channel for messages to be passed to this connection
- let (tx, mut rx) = mpsc::unbounded_channel::<String>();
- let mut session = session::WebsocketSession::new(outgoing, db_client.clone());
+ // We don't know the identity of the device until it sends the session
+ // request over the websocket connection
+ let mut session = if let Some(Ok(first_msg)) = incoming.next().await {
+ match initiate_session(outgoing, first_msg, db_client, amqp_channel).await {
+ Ok(session) => session,
+ Err(_) => {
+ error!("Failed to create session with device");
+ return;
+ }
+ }
+ } else {
+ error!("Failed to create session with device");
+ return;
+ };
// Poll for messages either being sent to the device (rx)
// or messages being received from the device (incoming)
loop {
debug!("Polling for messages from: {}", addr);
tokio::select! {
- Some(message) = rx.recv() => { session.send_message_to_device(message).await; },
+ Some(Ok(delivery)) = session.next_amqp_message() => {
+ if let Ok(message) = std::str::from_utf8(&delivery.data) {
+ session.send_message_to_device(message.to_string()).await;
+ } else {
+ error!("Invalid payload");
+ }
+ },
device_message = incoming.next() => {
match device_message {
- Some(Ok(msg)) => session.handle_websocket_frame_from_device(msg, tx.clone()).await,
+ Some(Ok(msg)) => {
+ session::consume_error(session.handle_websocket_frame_from_device(msg).await);
+ }
_ => {
debug!("Connection to {} closed remotely.", addr);
break;
@@ -73,3 +105,26 @@
info!("Unregistering connection to: {}", addr);
session.close().await
}
+
+/// Create a [`WebsocketSession`] from the device's first websocket frame and
+/// flush the messages persisted for that device.
+///
+/// The error from [`session::WebsocketSession::from_frame`] is logged and
+/// returned unchanged. This keeps AMQP setup failures distinguishable from an
+/// invalid connection request. Failures while delivering persisted messages
+/// are only logged; the session is still returned.
+async fn initiate_session(
+  outgoing: SplitSink<WebSocketStream<TcpStream>, Message>,
+  frame: Message,
+  db_client: DatabaseClient,
+  amqp_channel: lapin::Channel,
+) -> Result<WebsocketSession, SessionError> {
+  let mut session = session::WebsocketSession::from_frame(
+    outgoing,
+    db_client,
+    frame,
+    &amqp_channel,
+  )
+  .await
+  .map_err(|e| {
+    match &e {
+      SessionError::AmqpError(_) => {
+        error!("Failed to set up AMQP consumer for device: {}", e)
+      }
+      _ => error!("Device failed to send valid connection request: {}", e),
+    }
+    e
+  })?;
+
+  session::consume_error(session.deliver_persisted_messages().await);
+
+  Ok(session)
+}
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -1,15 +1,16 @@
use derive_more;
use futures_util::stream::SplitSink;
use futures_util::SinkExt;
-use tokio::{net::TcpStream, sync::mpsc::UnboundedSender};
+use futures_util::StreamExt;
+use lapin::message::Delivery;
+use lapin::options::{BasicConsumeOptions, QueueDeclareOptions};
+use lapin::types::FieldTable;
+use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
use tracing::{debug, error};
use tunnelbroker_messages::{session::DeviceTypes, Messages};
-use crate::{
- database::{self, DatabaseClient, DeviceMessage},
- ACTIVE_CONNECTIONS,
-};
+use crate::database::{self, DatabaseClient, DeviceMessage};
pub struct DeviceInfo {
pub device_id: String,
@@ -22,7 +23,9 @@
+/// A websocket connection to a device that has identified itself with a
+/// session request.
pub struct WebsocketSession {
+ /// Write half of the device's websocket connection.
tx: SplitSink<WebSocketStream<TcpStream>, Message>,
db_client: DatabaseClient,
- device_info: Option<DeviceInfo>,
+ /// Identity of the device, parsed from its first websocket frame.
+ pub device_info: DeviceInfo,
+ // Consumer on the AMQP queue named after this device's id; its deliveries
+ // are forwarded to the device.
+ amqp_consumer: lapin::Consumer,
}
#[derive(Debug, derive_more::Display, derive_more::From)]
@@ -30,101 +33,135 @@
InvalidMessage,
SerializationError(serde_json::Error),
MessageError(database::MessageErrors),
+ AmqpError(lapin::Error),
}
-fn consume_error<T>(result: Result<T, SessionError>) {
+/// Log the error carried by `result`, if any, and discard the result.
+pub fn consume_error<T>(result: Result<T, SessionError>) {
if let Err(e) = result {
error!("{}", e)
}
}
+
+/// Parse the device's session request and extract its [`DeviceInfo`].
+///
+/// # Errors
+/// Returns [`SessionError::SerializationError`] if `message` is not valid JSON
+/// for [`Messages`]. Returns [`SessionError::InvalidMessage`] if it parses to
+/// anything other than a session request.
+pub fn handle_first_message_from_device(
+  message: &str,
+) -> Result<DeviceInfo, SessionError> {
+  let serialized_message = serde_json::from_str::<Messages>(message)?;
+
+  match serialized_message {
+    Messages::SessionRequest(session_info) => {
+      // TODO: Authenticate device using auth token
+      // `session_info` is owned, so its fields are moved rather than cloned.
+      Ok(DeviceInfo {
+        device_id: session_info.device_id,
+        notify_token: session_info.notify_token,
+        device_type: session_info.device_type,
+        device_app_version: session_info.device_app_version,
+        device_os: session_info.device_os,
+      })
+    }
+    _ => {
+      debug!("Received invalid request");
+      Err(SessionError::InvalidMessage)
+    }
+  }
+}
+
impl WebsocketSession {
- pub fn new(
+  /// Build a session from the device's first websocket frame.
+  ///
+  /// The frame must be a text frame carrying a session request. A queue named
+  /// after the device id is declared on `amqp_channel`, and a consumer is
+  /// started on it. The session forwards that consumer's deliveries to the
+  /// device.
+  ///
+  /// # Errors
+  /// - [`SessionError::InvalidMessage`] for a non-text frame.
+  /// - The error from [`handle_first_message_from_device`] for an invalid
+  ///   request.
+  /// - [`SessionError::AmqpError`] if declaring the queue or starting the
+  ///   consumer fails.
+  pub async fn from_frame(
tx: SplitSink<WebSocketStream<TcpStream>, Message>,
db_client: DatabaseClient,
- ) -> WebsocketSession {
- WebsocketSession {
+    frame: Message,
+    amqp_channel: &lapin::Channel,
+  ) -> Result<WebsocketSession, SessionError> {
+    let device_info = match frame {
+      Message::Text(payload) => handle_first_message_from_device(&payload)?,
+      _ => {
+        error!("Client sent wrong frame type for establishing connection");
+        return Err(SessionError::InvalidMessage);
+      }
+    };
+
+    // We don't currently have a use case to interact directly with the queue,
+    // however, we need to declare a queue for a given device
+    amqp_channel
+      .queue_declare(
+        &device_info.device_id,
+        QueueDeclareOptions::default(),
+        FieldTable::default(),
+      )
+      .await?;
+
+    // Consume with `no_ack: true`, so the broker considers a message delivered
+    // as soon as it is handed to this consumer. No code in this session acks
+    // deliveries. With manual acknowledgement (the default), every forwarded
+    // message would stay unacked on the channel and be redelivered once the
+    // channel closes.
+    // Trade-off: a message whose websocket write fails is not redelivered.
+    let amqp_consumer = amqp_channel
+      .basic_consume(
+        &device_info.device_id,
+        "tunnelbroker",
+        BasicConsumeOptions {
+          no_ack: true,
+          ..BasicConsumeOptions::default()
+        },
+        FieldTable::default(),
+      )
+      .await?;
+
+    Ok(WebsocketSession {
tx,
db_client,
- device_info: None,
- }
+      device_info,
+      amqp_consumer,
+    })
}
+ /// Handle a websocket frame received from the device once the session exists.
+ ///
+ /// Text frames are only logged at debug level, and a close frame closes the
+ /// session; both return `Ok(())`. Every other frame type returns
+ /// `SessionError::InvalidMessage`.
+ ///
+ /// NOTE(review): the catch-all arm also covers Ping/Pong frames. Routine
+ /// keep-alive traffic is therefore reported as an error to callers that log
+ /// it. Confirm that is intended.
pub async fn handle_websocket_frame_from_device(
&mut self,
frame: Message,
- tx: UnboundedSender<String>,
- ) {
- debug!("Received message from device: {}", frame);
- let result = match frame {
+ ) -> Result<(), SessionError> {
+ match frame {
Message::Text(payload) => {
- self.handle_message_from_device(&payload, tx).await
+ debug!("Received message from device: {}", payload);
+ Ok(())
}
Message::Close(_) => {
self.close().await;
Ok(())
}
_ => Err(SessionError::InvalidMessage),
- };
- consume_error(result);
+ }
+ }
+
+  /// Wait for the next delivery from this device's AMQP consumer.
+  ///
+  /// Yields `None` once the consumer stream has terminated.
+  pub async fn next_amqp_message(
+    &mut self,
+  ) -> Option<Result<Delivery, lapin::Error>> {
+    let consumer = &mut self.amqp_consumer;
+    StreamExt::next(consumer).await
}
- pub async fn handle_message_from_device(
+ /// Send every message persisted in the database for this device, deleting
+ /// each one after it has been sent.
+ ///
+ /// Failures to retrieve or delete messages are logged, not returned.
+ ///
+ /// # Errors
+ /// Returns early with the error from `DeviceMessage::from_hashmap` if a
+ /// stored message cannot be decoded. Messages not yet processed are then
+ /// left in the database.
+ ///
+ /// NOTE(review): the outcome of `send_message_to_device` is not checked here.
+ /// A message is therefore deleted even if writing it to the websocket
+ /// failed. Confirm this is acceptable.
+ pub async fn deliver_persisted_messages(
&mut self,
- message: &str,
- tx: UnboundedSender<String>,
) -> Result<(), SessionError> {
- let serialized_message = serde_json::from_str::<Messages>(message)?;
-
- match serialized_message {
- Messages::SessionRequest(mut session_info) => {
- // TODO: Authenticate device using auth token
-
- // Check if session request was already sent
- if self.device_info.is_some() {
- return Err(SessionError::InvalidMessage);
- }
-
- let device_info = DeviceInfo {
- device_id: session_info.device_id.clone(),
- notify_token: session_info.notify_token.take(),
- device_type: session_info.device_type,
- device_app_version: session_info.device_app_version.take(),
- device_os: session_info.device_os.take(),
- };
-
- // Check for persisted messages
- let messages = self
- .db_client
- .retrieve_messages(&device_info.device_id)
- .await
- .unwrap_or_else(|e| {
- error!("Error while retrieving messages: {}", e);
- Vec::new()
- });
-
- ACTIVE_CONNECTIONS.insert(device_info.device_id.clone(), tx.clone());
-
- for message in messages {
- let device_message = DeviceMessage::from_hashmap(message)?;
- self.send_message_to_device(device_message.payload).await;
- if let Err(e) = self
- .db_client
- .delete_message(&device_info.device_id, &device_message.created_at)
- .await
- {
- error!("Failed to delete message: {}:", e);
- }
- }
-
- debug!("Flushed messages for device: {}", &session_info.device_id);
-
- self.device_info = Some(device_info);
- }
- _ => {
- debug!("Received invalid request");
+ // Check for persisted messages
+ let messages = self
+ .db_client
+ .retrieve_messages(&self.device_info.device_id)
+ .await
+ .unwrap_or_else(|e| {
+ error!("Error while retrieving messages: {}", e);
+ Vec::new()
+ });
+
+ for message in messages {
+ let device_message = DeviceMessage::from_hashmap(message)?;
+ self.send_message_to_device(device_message.payload).await;
+ if let Err(e) = self
+ .db_client
+ .delete_message(&self.device_info.device_id, &device_message.created_at)
+ .await
+ {
+ error!("Failed to delete message: {}:", e);
}
}
+ debug!(
+ "Flushed messages for device: {}",
+ &self.device_info.device_id
+ );
+
Ok(())
}
@@ -136,10 +173,6 @@
// Release websocket and remove from active connections
pub async fn close(&mut self) {
- if let Some(device_info) = &self.device_info {
- ACTIVE_CONNECTIONS.remove(&device_info.device_id);
- }
-
if let Err(e) = self.tx.close().await {
debug!("Failed to close session: {}", e);
}

File Metadata

Mime Type
text/plain
Expires
Fri, Nov 15, 8:17 PM (20 h, 47 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
2495086
Default Alt Text
D8178.id27874.diff (16 KB)

Event Timeline