diff --git a/services/tunnelbroker/src/websockets/mod.rs b/services/tunnelbroker/src/websockets/mod.rs
--- a/services/tunnelbroker/src/websockets/mod.rs
+++ b/services/tunnelbroker/src/websockets/mod.rs
@@ -7,6 +7,7 @@
 use futures_util::StreamExt;
 use std::net::SocketAddr;
 use std::{env, io::Error};
+use tokio::io::{AsyncRead, AsyncWrite};
 use tokio::net::{TcpListener, TcpStream};
 use tokio_tungstenite::tungstenite::Message;
 use tokio_tungstenite::WebSocketStream;
@@ -106,12 +107,12 @@
   session.close().await
 }
 
-async fn initiate_session(
-  outgoing: SplitSink<WebSocketStream<TcpStream>, Message>,
+async fn initiate_session<S: AsyncRead + AsyncWrite + Unpin>(
+  outgoing: SplitSink<WebSocketStream<S>, Message>,
   frame: Message,
   db_client: DatabaseClient,
   amqp_channel: lapin::Channel,
-) -> Result<WebsocketSession, session::SessionError> {
+) -> Result<WebsocketSession<S>, session::SessionError> {
   let mut session = session::WebsocketSession::from_frame(
     outgoing,
     db_client.clone(),
diff --git a/services/tunnelbroker/src/websockets/session.rs b/services/tunnelbroker/src/websockets/session.rs
--- a/services/tunnelbroker/src/websockets/session.rs
+++ b/services/tunnelbroker/src/websockets/session.rs
@@ -5,7 +5,8 @@
 use lapin::message::Delivery;
 use lapin::options::{BasicConsumeOptions, QueueDeclareOptions};
 use lapin::types::FieldTable;
-use tokio::net::TcpStream;
+use tokio::io::AsyncRead;
+use tokio::io::AsyncWrite;
 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
 use tracing::{debug, error};
 use tunnelbroker_messages::{session::DeviceTypes, Messages};
@@ -20,8 +21,8 @@
   pub device_os: Option<String>,
 }
 
-pub struct WebsocketSession {
-  tx: SplitSink<WebSocketStream<TcpStream>, Message>,
+pub struct WebsocketSession<S> {
+  tx: SplitSink<WebSocketStream<S>, Message>,
   db_client: DatabaseClient,
   pub device_info: DeviceInfo,
   // Stream of messages from AMQP endpoint
@@ -58,22 +59,22 @@
         device_os: session_info.device_os.take(),
       };
 
-      return Ok(device_info);
+      Ok(device_info)
     }
     _ => {
       debug!("Received invalid request");
-      return Err(SessionError::InvalidMessage);
+      Err(SessionError::InvalidMessage)
     }
   }
 }
 
-impl WebsocketSession {
+impl<S: AsyncRead + AsyncWrite + Unpin> WebsocketSession<S> {
   pub async fn from_frame(
-    tx: SplitSink<WebSocketStream<TcpStream>, Message>,
+    tx: SplitSink<WebSocketStream<S>, Message>,
     db_client: DatabaseClient,
     frame: Message,
     amqp_channel: &lapin::Channel,
-  ) -> Result<WebsocketSession, SessionError> {
+  ) -> Result<WebsocketSession<S>, SessionError> {
     let device_info = match frame {
       Message::Text(payload) => handle_first_message_from_device(&payload)?,
       _ => {