diff --git a/services/tunnelbroker/src/cxx_bridge.rs b/services/tunnelbroker/src/cxx_bridge.rs
--- a/services/tunnelbroker/src/cxx_bridge.rs
+++ b/services/tunnelbroker/src/cxx_bridge.rs
@@ -57,5 +57,11 @@
       notifyToken: &str,
     ) -> NewSessionResult;
     pub fn getSessionItem(sessionID: &str) -> Result<SessionItem>;
+    /// Updates the online status of the session item identified by
+    /// `sessionID`. Returns `Err` when the C++ implementation throws.
+    pub fn updateSessionItemIsOnline(
+      sessionID: &str,
+      isOnline: bool,
+    ) -> Result<()>;
   }
 }
diff --git a/services/tunnelbroker/src/libcpp/Tunnelbroker.h b/services/tunnelbroker/src/libcpp/Tunnelbroker.h
--- a/services/tunnelbroker/src/libcpp/Tunnelbroker.h
+++ b/services/tunnelbroker/src/libcpp/Tunnelbroker.h
@@ -16,3 +16,4 @@
     rust::Str deviceOS,
     rust::Str notifyToken);
 SessionItem getSessionItem(rust::Str sessionID);
+void updateSessionItemIsOnline(rust::Str sessionID, bool isOnline);
diff --git a/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp b/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp
--- a/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp
+++ b/services/tunnelbroker/src/libcpp/Tunnelbroker.cpp
@@ -170,3 +170,10 @@
       .deviceOS = sessionItem->getDeviceOs(),
       .isOnline = sessionItem->getIsOnline()};
 }
+
+// Rust-facing wrapper: converts the `rust::Str` session ID to `std::string`
+// and forwards the call to `DatabaseManager::updateSessionItemIsOnline`.
+void updateSessionItemIsOnline(rust::Str sessionID, bool isOnline) {
+  comm::network::database::DatabaseManager::getInstance()
+      .updateSessionItemIsOnline(std::string{sessionID}, isOnline);
+}
diff --git a/services/tunnelbroker/src/server/mod.rs b/services/tunnelbroker/src/server/mod.rs
--- a/services/tunnelbroker/src/server/mod.rs
+++ b/services/tunnelbroker/src/server/mod.rs
@@ -1,13 +1,16 @@
 use super::constants;
 use super::cxx_bridge::ffi::{
-  getSessionItem, newSessionHandler, sessionSignatureHandler, GRPCStatusCodes,
+  getSessionItem, newSessionHandler, sessionSignatureHandler,
+  updateSessionItemIsOnline, GRPCStatusCodes,
 };
 use anyhow::Result;
 use futures::Stream;
 use std::pin::Pin;
 use tokio::sync::mpsc;
+use tokio::time::{sleep, Duration};
 use tokio_stream::wrappers::ReceiverStream;
 use tonic::{transport::Server, Request, Response, Status, Streaming};
+use tracing::debug;
 use tunnelbroker::tunnelbroker_service_server::{
   TunnelbrokerService, TunnelbrokerServiceServer,
 };
@@ -98,7 +101,56 @@
       Err(err) => return Err(Status::unauthenticated(err.what())),
     };
 
-    let (_tx, rx) = mpsc::channel(constants::GRPC_TX_QUEUE_SIZE);
+    let (tx, rx) = mpsc::channel(constants::GRPC_TX_QUEUE_SIZE);
+
+    // Writes `payload` to the output stream on behalf of any Tokio task. If the
+    // send fails, the session is marked offline in the database before the
+    // error is returned.
+    //
+    // Returns the database error text if that update fails, otherwise the
+    // channel's send error text.
+    //
+    // NOTE(review): `updateSessionItemIsOnline` is a synchronous FFI call; if
+    // it performs network I/O it blocks the executor thread. Consider
+    // `tokio::task::spawn_blocking` (TODO confirm).
+    async fn tx_writer<T>(
+      session_id: &str,
+      channel: &tokio::sync::mpsc::Sender<T>,
+      payload: T,
+    ) -> Result<(), String> {
+      match channel.send(payload).await {
+        Ok(()) => Ok(()),
+        Err(send_err) => {
+          updateSessionItemIsOnline(session_id, false)
+            .map_err(|db_err| db_err.what().to_string())?;
+          Err(send_err.to_string())
+        }
+      }
+    }
+
+    if let Err(err) = updateSessionItemIsOnline(&session_id, true) {
+      return Err(Status::internal(err.what()));
+    }
+
+    // Spawn a detached Tokio task that pings the client on a fixed interval.
+    // The loop ends on the first failed write; `tx_writer` has by then already
+    // attempted to mark the session offline.
+    tokio::spawn({
+      let session_id = session_id.clone();
+      let tx = tx.clone();
+      async move {
+        loop {
+          sleep(Duration::from_millis(constants::GRPC_PING_INTERVAL_MS)).await;
+          let ping = Ok(tunnelbroker::MessageToClient {
+            data: Some(tunnelbroker::message_to_client::Data::Ping(())),
+          });
+          if let Err(err) = tx_writer(&session_id, &tx, ping).await {
+            debug!("Failed to write ping to a channel: {}", err);
+            break;
+          }
+        }
+      }
+    });
 
     let output_stream = ReceiverStream::new(rx);
     Ok(Response::new(