Page Menu
Home
Phabricator
Search
Configure Global Search
Log In
Files
F3509550
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
23 KB
Referenced Files
None
Subscribers
None
View Options
diff --git a/services/tunnelbroker/src/notifs/apns/mod.rs b/services/tunnelbroker/src/notifs/apns/mod.rs
index 8833a2eba..ccf39b6bb 100644
--- a/services/tunnelbroker/src/notifs/apns/mod.rs
+++ b/services/tunnelbroker/src/notifs/apns/mod.rs
@@ -1,137 +1,137 @@
use crate::notifs::apns::config::APNsConfig;
use crate::notifs::apns::error::Error::ResponseError;
use crate::notifs::apns::headers::{NotificationHeaders, PushType};
use crate::notifs::apns::response::ErrorBody;
use crate::notifs::apns::token::APNsToken;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::debug;
pub mod config;
pub mod error;
pub(crate) mod headers;
-mod response;
+pub mod response;
pub mod token;
/// HTTP/2 client for the Apple Push Notification service (APNs).
///
/// Construct with [`APNsClient::new`] and deliver notifications with
/// [`APNsClient::send`].
#[derive(Clone)]
pub struct APNsClient {
  /// `true` targets `api.push.apple.com`; `false` targets the development
  /// host `api.development.push.apple.com`.
  is_prod: bool,
  /// Provider token whose bearer value is sent in the `authorization`
  /// header of every request.
  token: APNsToken,
  /// Underlying `reqwest` client, configured for HTTP/2 in `new`.
  http2_client: reqwest::Client,
}
/// One APNs push request: the target device token, the APNs request
/// headers, and the request body.
///
/// Field order is also the serde serialization order.
#[derive(Serialize, Deserialize)]
pub struct APNsNotif {
/// Device token; `APNsClient::send` appends it to the `/3/device/` path.
pub device_token: String,
/// Optional APNs headers, converted to HTTP headers by `build_headers`.
pub headers: NotificationHeaders,
/// Sent verbatim as the request body (with `content-type: application/json`).
pub payload: String,
}
impl APNsClient {
pub fn new(config: &APNsConfig) -> Result<Self, error::Error> {
let token_ttl = Duration::from_secs(60 * 55);
let token = APNsToken::new(config, token_ttl)?;
let http2_client = reqwest::Client::builder()
.http2_prior_knowledge()
.http2_keep_alive_interval(Some(Duration::from_secs(5)))
.http2_keep_alive_while_idle(true)
.build()?;
Ok(APNsClient {
http2_client,
token,
is_prod: config.production,
})
}
async fn build_headers(
&self,
notif_headers: NotificationHeaders,
) -> Result<HeaderMap, error::Error> {
let mut headers = HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
let bearer = self.token.get_bearer().await?;
let token = format!("bearer {bearer}");
headers.insert(AUTHORIZATION, HeaderValue::from_str(&token)?);
if let Some(apns_topic) = ¬if_headers.apns_topic {
headers.insert("apns-topic", HeaderValue::from_str(apns_topic)?);
}
if let Some(apns_id) = ¬if_headers.apns_id {
headers.insert("apns-id", HeaderValue::from_str(apns_id)?);
}
if let Some(push_type) = ¬if_headers.apns_push_type {
let push_type_str = match push_type {
PushType::Alert => "alert",
PushType::Background => "background",
PushType::Location => "location",
PushType::Voip => "voip",
PushType::Complication => "complication",
PushType::FileProvider => "fileprovider",
PushType::Mdm => "mdm",
PushType::LiveActivity => "live",
PushType::PushToTalk => "pushtotalk",
};
headers.insert("apns-push-type", HeaderValue::from_static(push_type_str));
}
if let Some(expiration) = notif_headers.apns_expiration {
headers.insert("apns-expiration", HeaderValue::from(expiration));
}
if let Some(priority) = notif_headers.apns_priority {
headers.insert("apns-priority", HeaderValue::from(priority));
}
if let Some(collapse_id) = ¬if_headers.apns_collapse_id {
headers.insert("apns-collapse-id", HeaderValue::from_str(collapse_id)?);
}
Ok(headers)
}
fn get_endpoint(&self) -> &'static str {
if self.is_prod {
return "api.push.apple.com";
}
"api.development.push.apple.com"
}
pub async fn send(&self, notif: APNsNotif) -> Result<(), error::Error> {
debug!("Sending APNs notif to {}", notif.device_token);
let headers = self.build_headers(notif.headers.clone()).await?;
let url = format!(
"https://{}/3/device/{}",
self.get_endpoint(),
notif.device_token
);
let response = self
.http2_client
.post(url)
.headers(headers.clone())
.body(notif.payload)
.send()
.await?;
match response.status() {
StatusCode::OK => Ok(()),
_ => {
let error_body: ErrorBody = response.json().await?;
Err(ResponseError(error_body))
}
}
}
}
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
index d759ae879..29bbdb20d 100644
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -1,586 +1,630 @@
use crate::constants::{
CLIENT_RMQ_MSG_PRIORITY, DDB_RMQ_MSG_PRIORITY, MAX_RMQ_MSG_PRIORITY,
RMQ_CONSUMER_TAG,
};
use comm_lib::aws::ddb::error::SdkError;
use comm_lib::aws::ddb::operation::put_item::PutItemError;
use derive_more;
use futures_util::stream::SplitSink;
use futures_util::SinkExt;
use futures_util::StreamExt;
use hyper_tungstenite::{tungstenite::Message, WebSocketStream};
use lapin::message::Delivery;
use lapin::options::{
BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions,
QueueDeclareOptions, QueueDeleteOptions,
};
use lapin::types::FieldTable;
use lapin::BasicProperties;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tracing::{debug, error, info, trace};
+use tunnelbroker_messages::bad_device_token::BadDeviceToken;
use tunnelbroker_messages::{
message_to_device_request_status::Failure,
message_to_device_request_status::MessageSentStatus, session::DeviceTypes,
DeviceToTunnelbrokerMessage, Heartbeat, MessageToDevice,
MessageToDeviceRequest, MessageToTunnelbroker,
};
+use crate::notifs::apns::response::ErrorReason;
+
use crate::database::{self, DatabaseClient, MessageToDeviceExt};
use crate::identity;
+use crate::notifs::apns::error::Error;
use crate::notifs::apns::headers::NotificationHeaders;
use crate::notifs::apns::APNsNotif;
use crate::notifs::fcm::firebase_message::{
AndroidConfig, AndroidMessagePriority, FCMMessage,
};
use crate::notifs::web_push::WebPushNotif;
use crate::notifs::NotifClient;
/// Identity and client metadata for the device behind a WebSocket session.
pub struct DeviceInfo {
  pub device_id: String,
  pub device_type: DeviceTypes,
  pub device_os: Option<String>,
  pub device_app_version: Option<String>,
  /// Push token sent with an authenticated connection request. Anonymous
  /// sessions leave it `None` (see `handle_first_message_from_device`).
  pub notify_token: Option<String>,
  /// Whether the device proved its identity with an access token. Message
  /// and notification requests from unauthenticated devices are refused.
  pub is_authenticated: bool,
}
/// State for one device's WebSocket connection: the outgoing socket half,
/// the device's identity, and the database, AMQP, and push clients that the
/// session handlers use.
///
/// Keep the field order stable: it is also the drop order.
pub struct WebsocketSession<S> {
/// Sending half of the device's WebSocket (see `send_message_to_device`).
tx: SplitSink<WebSocketStream<S>, Message>,
db_client: DatabaseClient,
pub device_info: DeviceInfo,
/// Channel used to declare, publish to, and delete device queues.
amqp_channel: lapin::Channel,
/// Stream of messages from the AMQP endpoint.
amqp_consumer: lapin::Consumer,
/// APNs / FCM / Web Push clients. Each may be absent, which handlers report
/// as `MissingAPNsClient` etc.
notif_client: NotifClient,
}
/// Errors produced while setting up or serving a device's WebSocket session.
///
/// The derives provide `Display`, `std::error::Error`, and `From`
/// conversions for the wrapped error types. This is what lets `?` lift serde
/// and AMQP errors (among others) into `SessionError` throughout this file.
#[derive(
  Debug, derive_more::Display, derive_more::From, derive_more::Error,
)]
pub enum SessionError {
  // Protocol / authorization failures.
  /// A frame had the wrong type or carried an unexpected message.
  InvalidMessage,
  /// The identity service rejected the device's access token.
  UnauthorizedDevice,
  /// A server-side failure, e.g. the identity service request failed.
  InternalError,

  // Wrapped lower-level errors.
  SerializationError(serde_json::Error),
  MessageError(database::MessageErrors),
  AmqpError(lapin::Error),
  PersistenceError(SdkError<PutItemError>),
  DatabaseError(comm_lib::database::Error),

  // Push-notification failures.
  /// No APNs client is configured.
  MissingAPNsClient,
  /// No FCM client is configured.
  MissingFCMClient,
  /// No Web Push client is configured.
  MissingWebPushClient,
  /// No push token is stored for the target device.
  MissingDeviceToken,
  /// The stored push token for the target device is flagged invalid.
  InvalidDeviceToken,
}
/// Parses the first frame of a connection and returns the device's info.
///
/// * `ConnectionInitializationMessage`: the access token is checked with
///   the identity service, and the result has `is_authenticated: true`.
/// * `AnonymousInitializationMessage`: no verification is done; the result
///   is unauthenticated and has no `notify_token`.
///
/// # Errors
/// * `SerializationError` if `message` is not a valid
///   `DeviceToTunnelbrokerMessage`.
/// * `InternalError` if the identity service request fails.
/// * `UnauthorizedDevice` if the identity service rejects the token.
/// * `InvalidMessage` for any other message type.
pub async fn handle_first_message_from_device(
  message: &str,
) -> Result<DeviceInfo, SessionError> {
  let parsed = serde_json::from_str::<DeviceToTunnelbrokerMessage>(message)?;

  // Anonymous and invalid requests return early; only authenticated
  // initialization falls through to token verification.
  let session_info = match parsed {
    DeviceToTunnelbrokerMessage::ConnectionInitializationMessage(info) => info,
    DeviceToTunnelbrokerMessage::AnonymousInitializationMessage(info) => {
      debug!(
        "Starting unauthenticated session with device: {}",
        &info.device_id
      );
      return Ok(DeviceInfo {
        device_id: info.device_id,
        device_type: info.device_type,
        device_app_version: info.device_app_version,
        device_os: info.device_os,
        is_authenticated: false,
        notify_token: None,
      });
    }
    _ => {
      debug!("Received invalid request");
      return Err(SessionError::InvalidMessage);
    }
  };

  debug!("Authenticating device: {}", &session_info.device_id);
  let verified = identity::verify_user_access_token(
    &session_info.user_id,
    &session_info.device_id,
    &session_info.access_token,
  )
  .await;

  match verified {
    Err(e) => {
      error!("Failed to complete request to identity service: {:?}", e);
      Err(SessionError::InternalError)
    }
    Ok(false) => {
      info!("Device failed authentication: {}", &session_info.device_id);
      Err(SessionError::UnauthorizedDevice)
    }
    Ok(true) => {
      debug!(
        "Successfully authenticated device: {}",
        &session_info.device_id
      );
      Ok(DeviceInfo {
        device_id: session_info.device_id,
        notify_token: session_info.notify_token,
        device_type: session_info.device_type,
        device_app_version: session_info.device_app_version,
        device_os: session_info.device_os,
        is_authenticated: true,
      })
    }
  }
}
/// Re-publishes every message stored in the database for `device_info`'s
/// device onto that device's AMQP queue, at `DDB_RMQ_MSG_PRIORITY`.
///
/// # Errors
/// Fails on a malformed stored message, a serialization error, or an AMQP
/// publish error. A failed lookup is not an error; see below.
async fn publish_persisted_messages(
  db_client: &DatabaseClient,
  amqp_channel: &lapin::Channel,
  device_info: &DeviceInfo,
) -> Result<(), SessionError> {
  // A failed lookup is logged and treated as "nothing to flush" so the
  // session can still start.
  let stored = match db_client.retrieve_messages(&device_info.device_id).await
  {
    Ok(stored) => stored,
    Err(e) => {
      error!("Error while retrieving messages: {}", e);
      Vec::new()
    }
  };

  for raw in stored {
    let message = MessageToDevice::from_hashmap(raw)?;
    let body = serde_json::to_string(&message)?;
    let properties =
      BasicProperties::default().with_priority(DDB_RMQ_MSG_PRIORITY);
    amqp_channel
      .basic_publish(
        "",
        &message.device_id,
        BasicPublishOptions::default(),
        body.as_bytes(),
        properties,
      )
      .await?;
  }

  debug!("Flushed messages for device: {}", &device_info.device_id);
  Ok(())
}
/// Completes the session handshake for a freshly opened connection.
///
/// The first frame must be a text frame holding a session-initialization
/// message (see `handle_first_message_from_device`). On success this:
/// 1. declares the device's queue with `x-max-priority` set to
///    `MAX_RMQ_MSG_PRIORITY`;
/// 2. re-publishes any persisted messages onto it;
/// 3. starts consuming from it under `RMQ_CONSUMER_TAG`.
///
/// # Errors
/// Returns `InvalidMessage` for a non-text frame. Otherwise propagates any
/// error from authentication, queue declaration, persisted-message
/// publishing, or consumer creation.
pub async fn initialize_amqp(
  db_client: DatabaseClient,
  frame: Message,
  amqp_channel: &lapin::Channel,
) -> Result<(DeviceInfo, lapin::Consumer), SessionError> {
  let Message::Text(payload) = frame else {
    error!("Client sent wrong frame type for establishing connection");
    return Err(SessionError::InvalidMessage);
  };
  let device_info = handle_first_message_from_device(&payload).await?;
  let queue_name = &device_info.device_id;

  let mut queue_args = FieldTable::default();
  queue_args.insert("x-max-priority".into(), MAX_RMQ_MSG_PRIORITY.into());
  amqp_channel
    .queue_declare(queue_name, QueueDeclareOptions::default(), queue_args)
    .await?;

  publish_persisted_messages(&db_client, amqp_channel, &device_info).await?;

  let consumer = amqp_channel
    .basic_consume(
      queue_name,
      RMQ_CONSUMER_TAG,
      BasicConsumeOptions::default(),
      FieldTable::default(),
    )
    .await?;
  Ok((device_info, consumer))
}
/// Per-connection handlers for a device's WebSocket session.
///
/// NOTE(review): `¬if` throughout this listing looks like `&notif` mangled by
/// HTML-entity decoding (`&not` rendered as `¬`). The real source presumably
/// reads `&notif`; confirm before copying code from here.
impl<S: AsyncRead + AsyncWrite + Unpin> WebsocketSession<S> {
/// Assembles a session from an already-initialized connection and clients.
pub fn new(
tx: SplitSink<WebSocketStream<S>, Message>,
db_client: DatabaseClient,
device_info: DeviceInfo,
amqp_channel: lapin::Channel,
amqp_consumer: lapin::Consumer,
notif_client: NotifClient,
) -> Self {
Self {
tx,
db_client,
device_info,
amqp_channel,
amqp_consumer,
notif_client,
}
}
/// Persists `message_request` for its target device, then publishes it to
/// that device's AMQP queue at `CLIENT_RMQ_MSG_PRIORITY`.
///
/// If publishing fails, the persisted copy is deleted and the AMQP error is
/// returned as `SessionError::AmqpError`.
pub async fn handle_message_to_device(
&self,
message_request: &MessageToDeviceRequest,
) -> Result<(), SessionError> {
let message_id = self
.db_client
.persist_message(
&message_request.device_id,
&message_request.payload,
&message_request.client_message_id,
)
.await?;
let message_to_device = MessageToDevice {
device_id: message_request.device_id.clone(),
payload: message_request.payload.clone(),
message_id: message_id.clone(),
};
let serialized_message = serde_json::to_string(&message_to_device)?;
let publish_result = self
.amqp_channel
.basic_publish(
"",
&message_request.device_id,
BasicPublishOptions::default(),
serialized_message.as_bytes(),
BasicProperties::default().with_priority(CLIENT_RMQ_MSG_PRIORITY),
)
.await;
if let Err(publish_error) = publish_result {
// Roll back the persisted copy so `publish_persisted_messages` does not
// re-publish it when the recipient next connects.
// NOTE(review): the message was persisted under
// `message_request.device_id`, but this deletes under the sender's
// `self.device_info.device_id`, so the rollback likely misses the stored
// row. This should probably be `&message_request.device_id`. Also,
// `expect` panics if the delete fails; consider logging instead.
self
.db_client
.delete_message(&self.device_info.device_id, &message_id)
.await
.expect("Error deleting message");
return Err(SessionError::AmqpError(publish_error));
}
Ok(())
}
/// Handles a request addressed to Tunnelbroker itself. `SetDeviceToken`
/// stores the token for this session's device.
pub async fn handle_message_to_tunnelbroker(
&self,
message_to_tunnelbroker: &MessageToTunnelbroker,
) -> Result<(), SessionError> {
match message_to_tunnelbroker {
MessageToTunnelbroker::SetDeviceToken(token) => {
self
.db_client
.set_device_token(&self.device_info.device_id, &token.device_token)
.await?;
}
}
Ok(())
}
/// Parses one text frame from the device and dispatches it by type.
///
/// Returns `None` for heartbeats and receive confirmations (nothing to
/// report). Every other frame yields a `MessageSentStatus`. Message and
/// notification requests from unauthenticated devices get
/// `MessageSentStatus::Unauthenticated`; unparsable frames get
/// `SerializationError`.
pub async fn handle_websocket_frame_from_device(
&mut self,
msg: String,
) -> Option<MessageSentStatus> {
let Ok(serialized_message) =
serde_json::from_str::<DeviceToTunnelbrokerMessage>(&msg)
else {
return Some(MessageSentStatus::SerializationError(msg));
};
match serialized_message {
DeviceToTunnelbrokerMessage::Heartbeat(Heartbeat {}) => {
trace!("Received heartbeat from: {}", self.device_info.device_id);
None
}
DeviceToTunnelbrokerMessage::MessageReceiveConfirmation(confirmation) => {
// Best-effort: a failed delete is logged and the remaining IDs are
// still processed.
for message_id in confirmation.message_ids {
if let Err(e) = self
.db_client
.delete_message(&self.device_info.device_id, &message_id)
.await
{
error!("Failed to delete message: {}:", e);
}
}
None
}
DeviceToTunnelbrokerMessage::MessageToDeviceRequest(message_request) => {
// unauthenticated clients cannot send messages
if !self.device_info.is_authenticated {
debug!(
"Unauthenticated device {} tried to send text message. Aborting.",
self.device_info.device_id
);
return Some(MessageSentStatus::Unauthenticated);
}
debug!("Received message for {}", message_request.device_id);
let result = self.handle_message_to_device(&message_request).await;
Some(self.get_message_to_device_status(
&message_request.client_message_id,
result,
))
}
DeviceToTunnelbrokerMessage::MessageToTunnelbrokerRequest(
message_request,
) => {
// unauthenticated clients cannot send messages
if !self.device_info.is_authenticated {
debug!(
"Unauthenticated device {} tried to send text message. Aborting.",
self.device_info.device_id
);
return Some(MessageSentStatus::Unauthenticated);
}
debug!("Received message for Tunnelbroker");
// The payload is itself JSON and is parsed a second time here.
let Ok(message_to_tunnelbroker) =
serde_json::from_str(&message_request.payload)
else {
return Some(MessageSentStatus::SerializationError(
message_request.payload,
));
};
let result = self
.handle_message_to_tunnelbroker(&message_to_tunnelbroker)
.await;
Some(self.get_message_to_device_status(
&message_request.client_message_id,
result,
))
}
DeviceToTunnelbrokerMessage::APNsNotif(notif) => {
// unauthenticated clients cannot send notifs
if !self.device_info.is_authenticated {
debug!(
"Unauthenticated device {} tried to send text notif. Aborting.",
self.device_info.device_id
);
return Some(MessageSentStatus::Unauthenticated);
}
debug!("Received APNs notif for {}", notif.device_id);
let Ok(headers) =
serde_json::from_str::<NotificationHeaders>(¬if.headers)
else {
return Some(MessageSentStatus::SerializationError(notif.headers));
};
- let device_token = match self.get_device_token(notif.device_id).await {
- Ok(token) => token,
- Err(e) => {
- return Some(
- self
- .get_message_to_device_status(¬if.client_message_id, Err(e)),
- )
- }
- };
+ let device_token =
+ match self.get_device_token(notif.device_id.clone()).await {
+ Ok(token) => token,
+ Err(e) => {
+ return Some(self.get_message_to_device_status(
+ ¬if.client_message_id,
+ Err(e),
+ ))
+ }
+ };
let apns_notif = APNsNotif {
- device_token,
+ device_token: device_token.clone(),
headers,
payload: notif.payload,
};
if let Some(apns) = self.notif_client.apns.clone() {
let response = apns.send(apns_notif).await;
+ if let Err(Error::ResponseError(body)) = &response {
+ if matches!(
+ body.reason,
+ ErrorReason::BadDeviceToken
+ | ErrorReason::Unregistered
+ | ErrorReason::ExpiredToken
+ ) {
+ if let Err(e) = self
+ .invalidate_device_token(notif.device_id, device_token)
+ .await
+ {
+ error!("Error invalidating device token: {:?}", e);
+ };
+ }
+ }
return Some(
self
.get_message_to_device_status(¬if.client_message_id, response),
);
}
Some(self.get_message_to_device_status(
¬if.client_message_id,
Err(SessionError::MissingAPNsClient),
))
}
DeviceToTunnelbrokerMessage::FCMNotif(notif) => {
// unauthenticated clients cannot send notifs
if !self.device_info.is_authenticated {
debug!(
"Unauthenticated device {} tried to send text notif. Aborting.",
self.device_info.device_id
);
return Some(MessageSentStatus::Unauthenticated);
}
debug!("Received FCM notif for {}", notif.device_id);
let Some(priority) = AndroidMessagePriority::from_str(¬if.priority)
else {
return Some(MessageSentStatus::SerializationError(notif.priority));
};
let Ok(data) = serde_json::from_str(¬if.data) else {
return Some(MessageSentStatus::SerializationError(notif.data));
};
let device_token = match self.get_device_token(notif.device_id).await {
Ok(token) => token,
Err(e) => {
return Some(
self
.get_message_to_device_status(¬if.client_message_id, Err(e)),
)
}
};
let fcm_message = FCMMessage {
data,
token: device_token.to_string(),
android: AndroidConfig { priority },
};
if let Some(fcm) = self.notif_client.fcm.clone() {
let response = fcm.send(fcm_message).await;
return Some(
self
.get_message_to_device_status(¬if.client_message_id, response),
);
}
Some(self.get_message_to_device_status(
¬if.client_message_id,
Err(SessionError::MissingFCMClient),
))
}
DeviceToTunnelbrokerMessage::WebPushNotif(notif) => {
// unauthenticated clients cannot send notifs
if !self.device_info.is_authenticated {
debug!(
"Unauthenticated device {} tried to send web push notif. Aborting.",
self.device_info.device_id
);
return Some(MessageSentStatus::Unauthenticated);
}
debug!("Received WebPush notif for {}", notif.device_id);
// Unlike the APNs and FCM branches, client availability is checked
// before the device-token lookup.
let Some(web_push_client) = self.notif_client.web_push.clone() else {
return Some(self.get_message_to_device_status(
¬if.client_message_id,
Err(SessionError::MissingWebPushClient),
));
};
let device_token = match self.get_device_token(notif.device_id).await {
Ok(token) => token,
Err(e) => {
return Some(
self
.get_message_to_device_status(¬if.client_message_id, Err(e)),
)
}
};
let web_push_notif = WebPushNotif {
device_token,
payload: notif.payload,
};
let result = web_push_client.send(web_push_notif).await;
Some(
self.get_message_to_device_status(¬if.client_message_id, result),
)
}
_ => {
error!("Client sent invalid message type");
Some(MessageSentStatus::InvalidRequest)
}
}
}
/// Waits for the next delivery from this device's AMQP consumer. Returns
/// `None` once the consumer stream has ended.
pub async fn next_amqp_message(
&mut self,
) -> Option<Result<Delivery, lapin::Error>> {
self.amqp_consumer.next().await
}
/// Sends `message` over the WebSocket. A send failure is logged, not
/// returned.
pub async fn send_message_to_device(&mut self, message: Message) {
if let Err(e) = self.tx.send(message).await {
error!("Failed to send message to device: {}", e);
}
}
/// Tears the session down in three steps:
/// 1. closes the WebSocket sink;
/// 2. cancels the AMQP consumer;
/// 3. deletes the device's queue.
///
/// Each failure is logged, and the remaining steps still run.
pub async fn close(&mut self) {
if let Err(e) = self.tx.close().await {
debug!("Failed to close WebSocket session: {}", e);
}
if let Err(e) = self
.amqp_channel
.basic_cancel(
self.amqp_consumer.tag().as_str(),
BasicCancelOptions::default(),
)
.await
{
error!("Failed to cancel consumer: {}", e);
}
if let Err(e) = self
.amqp_channel
.queue_delete(
self.device_info.device_id.as_str(),
QueueDeleteOptions::default(),
)
.await
{
error!("Failed to delete queue: {}", e);
}
}
/// Converts a handler result into the status reported for
/// `client_message_id`. `Ok` becomes `Success`; `Err` becomes `Error`,
/// carrying the error's `Display` text. Session state is not read.
pub fn get_message_to_device_status<E>(
&mut self,
client_message_id: &str,
result: Result<(), E>,
) -> MessageSentStatus
where
E: std::error::Error,
{
match result {
Ok(()) => MessageSentStatus::Success(client_message_id.to_string()),
Err(err) => MessageSentStatus::Error(Failure {
id: client_message_id.to_string(),
error: err.to_string(),
}),
}
}
/// Looks up the stored push token for `device_id`.
///
/// # Errors
/// * `DatabaseError` if the lookup fails.
/// * `MissingDeviceToken` if no token is stored.
/// * `InvalidDeviceToken` if the stored token is flagged invalid.
async fn get_device_token(
&self,
device_id: String,
) -> Result<String, SessionError> {
let db_token = self
.db_client
.get_device_token(&device_id)
.await
.map_err(SessionError::DatabaseError)?;
match db_token {
Some(token) => {
if token.token_invalid {
Err(SessionError::InvalidDeviceToken)
} else {
Ok(token.device_token)
}
}
None => Err(SessionError::MissingDeviceToken),
}
}
+
+ async fn invalidate_device_token(
+ &self,
+ device_id: String,
+ invalidated_token: String,
+ ) -> Result<(), SessionError> {
+ let bad_device_token_message = BadDeviceToken { invalidated_token };
+ let payload = serde_json::to_string(&bad_device_token_message)?;
+ let message_request = MessageToDeviceRequest {
+ client_message_id: uuid::Uuid::new_v4().to_string(),
+ device_id: device_id.to_string(),
+ payload,
+ };
+
+ self.handle_message_to_device(&message_request).await?;
+
+ self
+ .db_client
+ .mark_device_token_as_invalid(&device_id)
+ .await
+ .map_err(SessionError::DatabaseError)?;
+
+ Ok(())
+ }
}
File Metadata
Details
Attached
Mime Type
text/x-diff
Expires
Mon, Dec 23, 6:11 AM (23 h, 40 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
2690448
Default Alt Text
(23 KB)
Attached To
Mode
rCOMM Comm
Attached
Detach File
Event Timeline
Log In to Comment