Page MenuHomePhabricator

D13595.id.diff
No OneTemporary

D13595.id.diff

diff --git a/services/tunnelbroker/src/amqp.rs b/services/tunnelbroker/src/amqp.rs
--- a/services/tunnelbroker/src/amqp.rs
+++ b/services/tunnelbroker/src/amqp.rs
@@ -1,10 +1,13 @@
use comm_lib::database::batch_operations::ExponentialBackoffConfig;
use lapin::{uri::AMQPUri, Connection, ConnectionProperties};
use once_cell::sync::Lazy;
+use std::hash::Hasher;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::{Arc, RwLock};
use std::time::Duration;
-use tracing::info;
+use tracing::{debug, error, info, warn};
-use crate::constants::error_types;
+use crate::constants::{error_types, NUM_AMQP_CHANNELS};
use crate::CONFIG;
static AMQP_URI: Lazy<AMQPUri> = Lazy::new(|| {
@@ -24,45 +27,187 @@
amqp_uri
});
-pub async fn connect() -> Connection {
+/// Opens a single raw AMQP connection to `AMQP_URI`.
+///
+/// Failed attempts are retried according to an `ExponentialBackoffConfig`
+/// (`max_attempts: 8`, `base_duration: 500 ms`); each failure is logged at
+/// `warn`. Once the retry counter gives up, the last `lapin::Error` is
+/// logged at `error` and returned to the caller.
+async fn create_connection() -> Result<Connection, lapin::Error> {
let options = ConnectionProperties::default()
.with_executor(tokio_executor_trait::Tokio::current())
.with_reactor(tokio_reactor_trait::Tokio);
let retry_config = ExponentialBackoffConfig {
- max_attempts: 5,
+ max_attempts: 8,
base_duration: Duration::from_millis(500),
..Default::default()
};
let mut retry_counter = retry_config.new_counter();
tracing::debug!("Attempting to connect to AMQP...");
- let conn_result = loop {
+ loop {
+ // `connect_uri` takes the URI by value, so a fresh clone is needed per attempt
let amqp_uri = Lazy::force(&AMQP_URI).clone();
match lapin::Connection::connect_uri(amqp_uri, options.clone()).await {
- Ok(conn) => break Ok(conn),
+ Ok(conn) => return Ok(conn),
Err(err) => {
let attempt = retry_counter.attempt();
tracing::warn!(attempt, "AMQP connection attempt failed: {err}.");
if retry_counter.sleep_and_retry().await.is_err() {
tracing::error!("Unable to connect to AMQP: {err}");
- break Err(err);
+ return Err(err);
}
}
}
- };
+ }
+}
+
+/// Inner connection that is a direct wrapper over `lapin::Connection`,
+/// together with a fixed-size pool of channels opened on it.
+///
+/// A new instance is built by `AmqpConnection::connect` and again by
+/// `AmqpConnection::reset_conn` on every reconnect; it is held inside
+/// `AmqpConnection`'s `Arc<RwLock<_>>`.
+struct ConnectionInner {
+ conn: lapin::Connection,
+ // channel pool: NUM_AMQP_CHANNELS channels created up front in `new()`
+ // and handed out (cloned) by `get_channel`
+ channels: [lapin::Channel; NUM_AMQP_CHANNELS],
+}
+
+impl ConnectionInner {
+ /// Connects via `create_connection`, installs a connection error logger,
+ /// and opens NUM_AMQP_CHANNELS channels on the new connection.
+ ///
+ /// # Errors
+ /// Returns the first `lapin::Error` from connecting or from creating a
+ /// channel; in that case the partially built connection is dropped.
+ async fn new() -> Result<Self, lapin::Error> {
+ let conn = create_connection().await?;
+ conn.on_error(|err| {
+ // TODO: we should filter out some IOErrors here to avoid spamming alerts
+ error!(errorType = error_types::AMQP_ERROR, "Lapin error: {err:?}");
+ });
+
+ debug!("Creating channels...");
+ let mut channels = Vec::with_capacity(NUM_AMQP_CHANNELS);
+ for idx in 0..NUM_AMQP_CHANNELS {
+ let channel = conn.create_channel().await?;
+ tracing::trace!("Creating channel ID={} at index={}", channel.id(), idx);
+ channels.push(channel);
+ }
+
+ Ok(Self {
+ conn,
+ // cannot fail: the loop above pushes exactly NUM_AMQP_CHANNELS channels
+ channels: channels
+ .try_into()
+ .expect("Channels vec size doesn't match array size"),
+ })
+ }
+
+ /// Returns a clone of the pooled channel selected by hashing `id_hash`
+ /// with `DefaultHasher` modulo NUM_AMQP_CHANNELS.
+ ///
+ /// Currently always returns `Ok`; the `Result` presumably leaves room for
+ /// future failure modes.
+ ///
+ /// NOTE(review): the selected channel is returned as-is; its own state is
+ /// not checked, so a channel closed by the broker while the connection
+ /// stays up would keep being handed out — confirm callers cope with this.
+ pub fn get_channel(
+ &self,
+ id_hash: impl std::hash::Hash,
+ ) -> Result<lapin::Channel, lapin::Error> {
+ // We have channel pool and want to distribute them between connected
+ // devices. Round robin would work too, but by using "hash modulo N"
+ // we make sure the same device will always use the same channel.
+ // Generally this shouldn't matter, but helps avoiding potential issues
+ // with the same queue name being declared by different channels,
+ // in case of reconnection.
+ let mut hasher = std::hash::DefaultHasher::new();
+ id_hash.hash(&mut hasher);
+ let channel_idx: usize = hasher.finish() as usize % NUM_AMQP_CHANNELS;
+
+ let channel = self.channels[channel_idx].clone();
+ let channel_id = channel.id();
+ tracing::trace!(channel_id, channel_idx, "Retrieving AMQP Channel");
+ Ok(channel)
+ }
+
+ /// Whether lapin reports the underlying connection as connected.
+ fn is_connected(&self) -> bool {
+ self.conn.status().connected()
+ }
+
+ /// Borrows the underlying lapin connection (e.g. to inspect its status).
+ fn raw(&self) -> &lapin::Connection {
+ &self.conn
+ }
+}
+
+/// Thread safe connection wrapper that is Clone + Send + Sync
+/// and can be shared wherever needed
+///
+/// Clones share the same underlying connection, channel pool and
+/// reconnect flag through `Arc`, so cloning is cheap.
+#[derive(Clone)]
+pub struct AmqpConnection {
+ // Current connection and its channel pool; the whole `ConnectionInner`
+ // is swapped out when the connection is reset.
+ inner: Arc<RwLock<ConnectionInner>>,
+ // Set while one task is rebuilding the connection, so concurrent callers
+ // wait for it instead of reconnecting in parallel.
+ is_connecting: Arc<AtomicBool>,
+}
- let conn = conn_result.expect("Unable to connect to AMQP. Exiting.");
- conn.on_error(|error| {
- tracing::error!(
- errorType = error_types::AMQP_ERROR,
- "Lapin error: {error:?}"
- );
- });
+impl AmqpConnection {
+    /// Connects to the AMQP endpoint and opens the channel pool.
+    ///
+    /// # Errors
+    /// Returns the `lapin::Error` from `ConnectionInner::new` if connecting
+    /// (after retries) or creating any pool channel fails.
+    pub async fn connect() -> Result<Self, lapin::Error> {
+        let conn = ConnectionInner::new().await?;
+        info!("Connected to AMQP endpoint: {}", &CONFIG.amqp_uri);
+        Ok(Self {
+            inner: Arc::new(RwLock::new(conn)),
+            // Nobody is reconnecting yet; the flag is only raised by `reset_conn`.
+            is_connecting: Arc::new(AtomicBool::new(false)),
+        })
+    }
+
+    /// Returns the pooled channel assigned to `id_hash`, first restoring the
+    /// connection if it is down.
+    ///
+    /// # Errors
+    /// Fails if the connection is down and `reset_conn` cannot restore it.
+    pub async fn channel(
+        &self,
+        id_hash: impl std::hash::Hash,
+    ) -> Result<lapin::Channel, lapin::Error> {
+        if !self.is_connected() {
+            warn!("AMQP disconnected while retrieving channel");
+            self.reset_conn().await?;
+        }
+        self.inner.read().unwrap().get_channel(id_hash)
+    }
+
+    /// Replaces a dead connection with a freshly created one.
+    ///
+    /// Only one task rebuilds at a time: the task that flips `is_connecting`
+    /// from `false` to `true` reconnects; every other caller polls until the
+    /// flag is cleared and then reports whether the connection came back.
+    ///
+    /// The flag is cleared by a drop guard, so it is released on success, on
+    /// error, and if this future is dropped mid-reconnect. Otherwise a failed
+    /// or cancelled reconnect would leave it stuck at `true`, making every
+    /// later caller spin in the wait loop forever and disabling
+    /// `trigger_reconnect`.
+    ///
+    /// # Errors
+    /// Returns the reconnect error (for the rebuilding task) or
+    /// `InvalidConnectionState` (for waiters whose rebuilding task failed).
+    async fn reset_conn(&self) -> Result<(), lapin::Error> {
+        // Clears the `is_connecting` flag when dropped.
+        struct ConnectingFlagGuard<'a>(&'a AtomicBool);
+        impl Drop for ConnectingFlagGuard<'_> {
+            fn drop(&mut self) {
+                self.0.store(false, Ordering::Release);
+            }
+        }
+
+        let acquired = self
+            .is_connecting
+            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
+            .is_ok();
+
+        if !acquired {
+            debug!("Already resetting on other thread");
+            while self.is_connecting() {
+                tokio::time::sleep(Duration::from_millis(10)).await;
+            }
+            if self.is_connected() {
+                return Ok(());
+            }
+            // other thread failed to reset
+            let state = self.inner.read().unwrap().raw().status().state();
+            warn!("Other thread failed to reset. State: {:?}", state);
+            return Err(lapin::Error::InvalidConnectionState(state));
+        }
+
+        let _flag_guard = ConnectingFlagGuard(&self.is_connecting);
+
+        // A concurrent reset may have completed between the caller's
+        // `is_connected()` check and us winning the flag; don't reconnect twice.
+        if self.is_connected() {
+            debug!("AMQP connection already restored, skipping reset");
+            return Ok(());
+        }
+
+        debug!("Resetting connection...");
+        let new_conn = ConnectionInner::new().await?;
+        {
+            // Keep the std write guard in a scope that contains no `.await`.
+            let mut inner = self.inner.write().unwrap();
+            if !inner.is_connected() {
+                *inner = new_conn;
+                info!("AMQP Connection restored.");
+            }
+        }
+        Ok(())
+    }
+
+    /// Starts a background reconnect (via `tokio::spawn`) if the connection
+    /// is down and no reconnect is already in progress. Does not wait for the
+    /// result; failures are only logged.
+    pub fn trigger_reconnect(&self) {
+        if self.is_connected() || self.is_connecting() {
+            return;
+        }
+        let this = self.clone();
+        tokio::spawn(async move {
+            if let Err(err) = this.reset_conn().await {
+                warn!("AMQP background reconnect failed: {:?}", err);
+            }
+        });
+    }
+
+    /// Whether some task is currently rebuilding the connection.
+    fn is_connecting(&self) -> bool {
+        self.is_connecting.load(Ordering::Acquire)
+    }
+
+    /// Whether lapin reports the current connection as connected.
+    fn is_connected(&self) -> bool {
+        self.inner.read().unwrap().is_connected()
+    }
+}
- info!("Connected to AMQP endpoint: {}", &CONFIG.amqp_uri);
- conn
+/// Returns `true` for lapin errors indicating that the channel or the
+/// connection is no longer usable (`InvalidChannelState` /
+/// `InvalidConnectionState`).
+///
+/// Callers use this to decide whether to rebuild AMQP state (or trigger a
+/// reconnect) instead of treating the error as an ordinary failure.
+pub fn is_connection_error(err: &lapin::Error) -> bool {
+ matches!(
+ err,
+ lapin::Error::InvalidChannelState(_)
+ | lapin::Error::InvalidConnectionState(_)
+ )
}
fn from_env(var_name: &str) -> Option<String> {
diff --git a/services/tunnelbroker/src/constants.rs b/services/tunnelbroker/src/constants.rs
--- a/services/tunnelbroker/src/constants.rs
+++ b/services/tunnelbroker/src/constants.rs
@@ -7,6 +7,7 @@
pub const SOCKET_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(3);
+/// Number of channels opened on each AMQP connection and shared between
+/// callers; a caller is mapped onto one of them by hashing an ID
+/// (see `amqp::ConnectionInner::get_channel`).
+pub const NUM_AMQP_CHANNELS: usize = 8;
pub const MAX_RMQ_MSG_PRIORITY: u8 = 10;
pub const DDB_RMQ_MSG_PRIORITY: u8 = 10;
pub const CLIENT_RMQ_MSG_PRIORITY: u8 = 1;
diff --git a/services/tunnelbroker/src/grpc/mod.rs b/services/tunnelbroker/src/grpc/mod.rs
--- a/services/tunnelbroker/src/grpc/mod.rs
+++ b/services/tunnelbroker/src/grpc/mod.rs
@@ -11,13 +11,14 @@
use tracing::debug;
use tunnelbroker_messages::MessageToDevice;
+use crate::amqp::AmqpConnection;
use crate::constants::{CLIENT_RMQ_MSG_PRIORITY, WS_SESSION_CLOSE_AMQP_MSG};
use crate::database::{handle_ddb_error, DatabaseClient};
use crate::{constants, CONFIG};
+/// gRPC service state: database client plus a handle to the shared AMQP
+/// connection.
struct TunnelbrokerGRPC {
client: DatabaseClient,
- amqp_channel: lapin::Channel,
+ // Requests fetch a channel with `amqp.channel("grpc")`; the key is a
+ // constant, so these requests all map to the same pool slot.
+ amqp: AmqpConnection,
}
pub fn handle_amqp_error(error: lapin::Error) -> tonic::Status {
@@ -57,7 +58,10 @@
.map_err(|_| tonic::Status::invalid_argument("Invalid argument"))?;
self
- .amqp_channel
+ .amqp
+ .channel("grpc")
+ .await
+ .map_err(handle_amqp_error)?
.basic_publish(
"",
&message.device_id,
@@ -81,7 +85,10 @@
debug!("Connection close request for device {}", &message.device_id);
self
- .amqp_channel
+ .amqp
+ .channel("grpc")
+ .await
+ .map_err(handle_amqp_error)?
.basic_publish(
"",
&message.device_id,
@@ -122,24 +129,19 @@
pub async fn run_server(
client: DatabaseClient,
- ampq_connection: &lapin::Connection,
+ amqp_connection: &AmqpConnection,
) -> Result<(), tonic::transport::Error> {
let addr = format!("[::]:{}", CONFIG.grpc_port)
.parse()
.expect("Unable to parse gRPC address");
- let amqp_channel = ampq_connection
- .create_channel()
- .await
- .expect("Unable to create amqp channel");
-
tracing::info!("gRPC server listening on {}", &addr);
Server::builder()
.http2_keepalive_interval(Some(constants::GRPC_KEEP_ALIVE_PING_INTERVAL))
.http2_keepalive_timeout(Some(constants::GRPC_KEEP_ALIVE_PING_TIMEOUT))
.add_service(TunnelbrokerServiceServer::new(TunnelbrokerGRPC {
client,
- amqp_channel,
+ amqp: amqp_connection.clone(),
}))
.serve(addr)
.await
diff --git a/services/tunnelbroker/src/main.rs b/services/tunnelbroker/src/main.rs
--- a/services/tunnelbroker/src/main.rs
+++ b/services/tunnelbroker/src/main.rs
@@ -48,7 +48,9 @@
config::parse_cmdline_args()?;
let aws_config = config::load_aws_config().await;
let db_client = database::DatabaseClient::new(&aws_config);
- let amqp_connection = amqp::connect().await;
+ let amqp_connection = amqp::AmqpConnection::connect()
+ .await
+ .expect("Failed to create AMQP connection");
let apns_config = CONFIG.apns_config.clone();
diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -1,9 +1,10 @@
pub mod session;
+use crate::amqp::AmqpConnection;
use crate::constants::{SOCKET_HEARTBEAT_TIMEOUT, WS_SESSION_CLOSE_AMQP_MSG};
use crate::database::DatabaseClient;
use crate::notifs::NotifClient;
-use crate::websockets::session::{initialize_amqp, SessionError};
+use crate::websockets::session::{handle_first_ws_frame, SessionError};
use crate::CONFIG;
use futures_util::stream::SplitSink;
use futures_util::{SinkExt, StreamExt};
@@ -18,7 +19,7 @@
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
-use tracing::{debug, error, info, trace};
+use tracing::{debug, error, info, trace, warn};
use tunnelbroker_messages::{
ConnectionInitializationStatus, DeviceToTunnelbrokerRequestStatus, Heartbeat,
MessageSentStatus,
@@ -39,7 +40,7 @@
/// It also handles regular HTTP requests (currently health check)
struct WebsocketService {
addr: SocketAddr,
- channel: lapin::Channel,
+ // Shared AMQP connection handle, cloned into each spawned websocket task.
+ amqp: AmqpConnection,
db_client: DatabaseClient,
notif_client: NotifClient,
}
@@ -62,7 +63,7 @@
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let addr = self.addr;
let db_client = self.db_client.clone();
- let channel = self.channel.clone();
+ let amqp = self.amqp.clone();
let notif_client = self.notif_client.clone();
let future = async move {
@@ -72,7 +73,7 @@
// Spawn a task to handle the websocket connection.
tokio::spawn(async move {
- accept_connection(websocket, addr, db_client, channel, notif_client)
+ accept_connection(websocket, addr, db_client, amqp, notif_client)
.await;
});
@@ -101,7 +102,7 @@
pub async fn run_server(
db_client: DatabaseClient,
- amqp_connection: &lapin::Connection,
+ amqp_connection: &AmqpConnection,
notif_client: NotifClient,
) -> Result<(), BoxedError> {
let addr = env::var("COMM_TUNNELBROKER_WEBSOCKET_ADDR")
@@ -115,15 +116,12 @@
http.http1_keep_alive(true);
while let Ok((stream, addr)) = listener.accept().await {
- let channel = amqp_connection
- .create_channel()
- .await
- .expect("Failed to create AMQP channel");
+ let amqp = amqp_connection.clone();
let connection = http
.serve_connection(
stream,
WebsocketService {
- channel,
+ amqp,
db_client: db_client.clone(),
addr,
notif_client: notif_client.clone(),
@@ -169,7 +167,7 @@
hyper_ws: HyperWebsocket,
addr: SocketAddr,
db_client: DatabaseClient,
- amqp_channel: lapin::Channel,
+ amqp_connection: AmqpConnection,
notif_client: NotifClient,
) {
debug!("Incoming connection from: {}", addr);
@@ -194,7 +192,7 @@
outgoing,
first_msg,
db_client,
- amqp_channel,
+ amqp_connection,
notif_client,
)
.await
@@ -231,16 +229,31 @@
loop {
trace!("Polling for messages from: {}", addr);
tokio::select! {
- Some(Ok(delivery)) = session.next_amqp_message() => {
- if let Ok(message) = std::str::from_utf8(&delivery.data) {
- if message == WS_SESSION_CLOSE_AMQP_MSG {
- debug!("Connection to {} closed by server.", addr);
+ Some(delivery_result) = session.next_amqp_message() => {
+ match delivery_result {
+ Ok(delivery) => {
+ if let Ok(message) = std::str::from_utf8(&delivery.data) {
+ if message == WS_SESSION_CLOSE_AMQP_MSG {
+ debug!("Connection to {} closed by server.", addr);
+ break;
+ } else {
+ session.send_message_to_device(Message::Text(message.to_string())).await;
+ }
+ } else {
+ error!("Invalid payload");
+ }
+ },
+ Err(ref err) if crate::amqp::is_connection_error(err) => {
+ if let Err(e) = session.reset_failed_amqp().await {
+ warn!("Connection to {} closed due to failed AMQP restoration: {:?}", addr, e);
+ break;
+ }
+ continue;
+ }
+ Err(err) => {
+ warn!("Connection to {} closed due to AMQP error: {:?}", addr, err);
break;
- } else {
- session.send_message_to_device(Message::Text(message.to_string())).await;
}
- } else {
- error!("Invalid payload");
}
},
device_message = incoming.next() => {
@@ -316,21 +329,14 @@
outgoing: SplitSink<WebSocketStream<S>, Message>,
frame: Message,
db_client: DatabaseClient,
- amqp_channel: lapin::Channel,
+ amqp: AmqpConnection,
notif_client: NotifClient,
) -> Result<WebsocketSession<S>, ErrorWithStreamHandle<S>> {
- let initialized_session =
- initialize_amqp(db_client.clone(), frame, &amqp_channel).await;
+ let device_info = match handle_first_ws_frame(frame).await {
+ Ok(info) => info,
+ Err(e) => return Err((e, outgoing)),
+ };
- match initialized_session {
- Ok((device_info, amqp_consumer)) => Ok(WebsocketSession::new(
- outgoing,
- db_client,
- device_info,
- amqp_channel,
- amqp_consumer,
- notif_client,
- )),
- Err(e) => Err((e, outgoing)),
- }
+ WebsocketSession::new(outgoing, db_client, device_info, amqp, notif_client)
+ .await
}
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -1,3 +1,4 @@
+use crate::amqp::{is_connection_error, AmqpConnection};
use crate::constants::{
error_types, CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY,
MAX_RMQ_MSG_PRIORITY, RMQ_CONSUMER_TAG,
@@ -24,7 +25,7 @@
use reqwest::Url;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
-use tracing::{debug, error, info, trace};
+use tracing::{debug, error, info, trace, warn};
use tunnelbroker_messages::bad_device_token::BadDeviceToken;
use tunnelbroker_messages::Platform;
use tunnelbroker_messages::{
@@ -48,6 +49,7 @@
use crate::notifs::{apns, NotifClient, NotifClientType};
use crate::{identity, notifs};
+#[derive(Clone)]
pub struct DeviceInfo {
pub device_id: String,
pub notify_token: Option<String>,
@@ -61,6 +63,7 @@
tx: SplitSink<WebSocketStream<S>, Message>,
db_client: DatabaseClient,
pub device_info: DeviceInfo,
+ amqp: AmqpConnection,
amqp_channel: lapin::Channel,
// Stream of messages from AMQP endpoint
amqp_consumer: lapin::Consumer,
@@ -200,11 +203,9 @@
Ok(())
}
-pub async fn initialize_amqp(
- db_client: DatabaseClient,
+pub async fn handle_first_ws_frame(
frame: Message,
- amqp_channel: &lapin::Channel,
-) -> Result<(DeviceInfo, lapin::Consumer), SessionError> {
+) -> Result<DeviceInfo, SessionError> {
let device_info = match frame {
Message::Text(payload) => {
handle_first_message_from_device(&payload).await?
@@ -214,42 +215,93 @@
return Err(SessionError::InvalidMessage);
}
};
-
- let mut args = FieldTable::default();
- args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into());
- amqp_channel
- .queue_declare(&device_info.device_id, QueueDeclareOptions::default(), args)
- .await?;
-
- publish_persisted_messages(&db_client, amqp_channel, &device_info).await?;
-
- let amqp_consumer = amqp_channel
- .basic_consume(
- &device_info.device_id,
- RMQ_CONSUMER_TAG,
- BasicConsumeOptions::default(),
- FieldTable::default(),
- )
- .await?;
- Ok((device_info, amqp_consumer))
+ Ok(device_info)
}
impl<S: AsyncRead + AsyncWrite + Unpin> WebsocketSession<S> {
- pub fn new(
+ /// Creates a session for `device_info` and sets up its AMQP queue and
+ /// consumer via `init_amqp`.
+ ///
+ /// # Errors
+ /// If AMQP setup fails, returns the `SessionError` paired with the
+ /// websocket sink `tx` (presumably so the caller can still use the
+ /// socket, e.g. to report the failure).
+ pub async fn new(
tx: SplitSink<WebSocketStream<S>, Message>,
db_client: DatabaseClient,
device_info: DeviceInfo,
- amqp_channel: lapin::Channel,
- amqp_consumer: lapin::Consumer,
+ amqp: AmqpConnection,
notif_client: NotifClient,
- ) -> Self {
- Self {
+ ) -> Result<Self, super::ErrorWithStreamHandle<S>> {
+ let (amqp_channel, amqp_consumer) =
+ match Self::init_amqp(&device_info, &db_client, &amqp).await {
+ Ok(consumer) => consumer,
+ Err(err) => return Err((err, tx)),
+ };
+
+ Ok(Self {
tx,
db_client,
device_info,
+ amqp,
amqp_channel,
amqp_consumer,
notif_client,
+ })
+ }
+
+ /// Prepares AMQP delivery for a device. It:
+ /// - picks a pooled channel keyed by `device_id`;
+ /// - declares a queue named after `device_id` with `x-max-priority`;
+ /// - calls `publish_persisted_messages` (presumably re-queuing messages
+ ///   stored in the database);
+ /// - starts a consumer on that queue.
+ ///
+ /// # Errors
+ /// Propagates failures from getting the channel, declaring the queue,
+ /// publishing persisted messages, or starting the consumer. A failure to
+ /// cancel a previous consumer is only logged.
+ async fn init_amqp(
+ device_info: &DeviceInfo,
+ db_client: &DatabaseClient,
+ amqp: &AmqpConnection,
+ ) -> Result<(lapin::Channel, lapin::Consumer), SessionError> {
+ let amqp_channel = amqp.channel(&device_info.device_id).await?;
+ debug!(
+ "Got AMQP Channel Id={} for device '{}'",
+ amqp_channel.id(),
+ device_info.device_id
+ );
+
+ let mut args = FieldTable::default();
+ args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into());
+ amqp_channel
+ .queue_declare(
+ &device_info.device_id,
+ QueueDeclareOptions::default(),
+ args,
+ )
+ .await?;
+
+ publish_persisted_messages(db_client, &amqp_channel, device_info).await?;
+
+ // cancel previous consumer. If not done, Rabbit yells that
+ // "trying to reuse tag" and closes channels.
+ // NOTE(review): pooled channels are shared between devices, yet every
+ // consumer here uses the same RMQ_CONSUMER_TAG. If that tag is a fixed
+ // non-empty string, this cancel may stop ANOTHER device's consumer on
+ // the same channel — confirm, and consider a per-device consumer tag.
+ if let Err(e) = amqp_channel
+ .basic_cancel(RMQ_CONSUMER_TAG, BasicCancelOptions::default())
+ .await
+ {
+ warn!(
+ errorType = error_types::AMQP_ERROR,
+ "Failed to cancel previous consumer: {}", e
+ );
}
+
+ let amqp_consumer = amqp_channel
+ .basic_consume(
+ &device_info.device_id,
+ RMQ_CONSUMER_TAG,
+ BasicConsumeOptions::default(),
+ FieldTable::default(),
+ )
+ .await?;
+ Ok((amqp_channel, amqp_consumer))
+ }
+
+ /// Rebuilds this session's AMQP channel and consumer by re-running
+ /// `init_amqp`, which also calls `publish_persisted_messages` again.
+ ///
+ /// On error the session keeps its previous channel and consumer
+ /// unchanged.
+ ///
+ /// NOTE(review): confirm that re-publishing persisted messages on every
+ /// reset cannot deliver duplicates to the device.
+ pub async fn reset_failed_amqp(&mut self) -> Result<(), SessionError> {
+ debug!(
+ "Resetting failed amqp for session with {}",
+ &self.device_info.device_id
+ );
+
+ let (amqp_channel, amqp_consumer) =
+ Self::init_amqp(&self.device_info, &self.db_client, &self.amqp).await?;
+
+ self.amqp_channel = amqp_channel;
+ self.amqp_consumer = amqp_consumer;
+
+ Ok(())
}
pub async fn handle_message_to_device(
@@ -704,10 +756,15 @@
)
.await
{
- error!(
- errorType = error_types::AMQP_ERROR,
- "Failed to cancel consumer: {}", e
- );
+ if is_connection_error(&e) {
+ warn!("AMQP connection dead when closing WS session.");
+ self.amqp.trigger_reconnect();
+ } else {
+ error!(
+ errorType = error_types::AMQP_ERROR,
+ "Failed to cancel consumer: {}", e
+ );
+ }
}
if let Err(e) = self
@@ -718,10 +775,15 @@
)
.await
{
- error!(
- errorType = error_types::AMQP_ERROR,
- "Failed to delete queue: {}", e
- );
+ if is_connection_error(&e) {
+ warn!("AMQP connection dead when closing WS session.");
+ self.amqp.trigger_reconnect();
+ } else {
+ error!(
+ errorType = error_types::AMQP_ERROR,
+ "Failed to delete queue: {}", e
+ );
+ }
}
}
diff --git a/shared/tunnelbroker_messages/src/messages/session.rs b/shared/tunnelbroker_messages/src/messages/session.rs
--- a/shared/tunnelbroker_messages/src/messages/session.rs
+++ b/shared/tunnelbroker_messages/src/messages/session.rs
@@ -18,7 +18,7 @@
/// messages to device
/// - Tunnelbroker then polls for incoming messages from device
-#[derive(Serialize, Deserialize, Debug, PartialEq)]
+#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum DeviceTypes {
Mobile,

File Metadata

Mime Type
text/plain
Expires
Sat, Oct 5, 5:47 AM (8 h, 23 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
2234120
Default Alt Text
D13595.id.diff (21 KB)

Event Timeline