diff --git a/services/identity/src/constants.rs b/services/identity/src/constants.rs index 02b5394e5..13612400d 100644 --- a/services/identity/src/constants.rs +++ b/services/identity/src/constants.rs @@ -1,228 +1,232 @@ // Secrets pub const SECRETS_DIRECTORY: &str = "secrets"; pub const SECRETS_SETUP_FILE: &str = "server_setup.txt"; // DynamoDB // User table information, supporting opaque_ke 2.0 and X3DH information // Users can sign in either through username+password or Eth wallet. // // This structure should be aligned with the messages defined in // shared/protos/identity_unauthenticated.proto // // Structure for a user should be: // { // userID: String, // opaqueRegistrationData: Option, // username: Option, // walletAddress: Option, // devices: HashMap // } // // A device is defined as: // { // deviceType: String, # client or keyserver // keyPayload: String, // keyPayloadSignature: String, // identityPreKey: String, // identityPreKeySignature: String, // identityOneTimeKeys: Vec, // notifPreKey: String, // notifPreKeySignature: String, // notifOneTimeKeys: Vec, // socialProof: Option // } // } // // Additional context: // "devices" uses the signing public identity key of the device as a key for the devices map // "keyPayload" is a JSON encoded string containing identity and notif keys (both signature and verification) // if "deviceType" == "keyserver", then the device will not have any notif key information pub const USERS_TABLE: &str = "identity-users"; pub const USERS_TABLE_PARTITION_KEY: &str = "userID"; pub const USERS_TABLE_REGISTRATION_ATTRIBUTE: &str = "opaqueRegistrationData"; pub const USERS_TABLE_USERNAME_ATTRIBUTE: &str = "username"; pub const USERS_TABLE_DEVICES_ATTRIBUTE: &str = "devices"; pub const USERS_TABLE_DEVICES_MAP_DEVICE_TYPE_ATTRIBUTE_NAME: &str = "deviceType"; pub const USERS_TABLE_DEVICES_MAP_KEY_PAYLOAD_ATTRIBUTE_NAME: &str = "keyPayload"; pub const USERS_TABLE_DEVICES_MAP_KEY_PAYLOAD_SIGNATURE_ATTRIBUTE_NAME: &str = "keyPayloadSignature"; pub const USERS_TABLE_DEVICES_MAP_CONTENT_PREKEY_ATTRIBUTE_NAME: &str = "identityPreKey"; pub const USERS_TABLE_DEVICES_MAP_CONTENT_PREKEY_SIGNATURE_ATTRIBUTE_NAME: &str = "identityPreKeySignature"; pub const USERS_TABLE_DEVICES_MAP_CONTENT_ONE_TIME_KEYS_ATTRIBUTE_NAME: &str = "identityOneTimeKeys"; pub const USERS_TABLE_DEVICES_MAP_NOTIF_PREKEY_ATTRIBUTE_NAME: &str = "preKey"; pub const USERS_TABLE_DEVICES_MAP_NOTIF_PREKEY_SIGNATURE_ATTRIBUTE_NAME: &str = "preKeySignature"; pub const USERS_TABLE_DEVICES_MAP_NOTIF_ONE_TIME_KEYS_ATTRIBUTE_NAME: &str = "notifOneTimeKeys"; pub const USERS_TABLE_WALLET_ADDRESS_ATTRIBUTE: &str = "walletAddress"; pub const USERS_TABLE_DEVICES_MAP_SOCIAL_PROOF_ATTRIBUTE_NAME: &str = "socialProof"; pub const USERS_TABLE_DEVICELIST_TIMESTAMP_ATTRIBUTE_NAME: &str = "deviceListTimestamp"; pub const USERS_TABLE_USERNAME_INDEX: &str = "username-index"; pub const USERS_TABLE_WALLET_ADDRESS_INDEX: &str = "walletAddress-index"; pub const ACCESS_TOKEN_TABLE: &str = "identity-tokens"; pub const ACCESS_TOKEN_TABLE_PARTITION_KEY: &str = "userID"; pub const ACCESS_TOKEN_SORT_KEY: &str = "signingPublicKey"; pub const ACCESS_TOKEN_TABLE_CREATED_ATTRIBUTE: &str = "created"; pub const ACCESS_TOKEN_TABLE_AUTH_TYPE_ATTRIBUTE: &str = "authType"; pub const ACCESS_TOKEN_TABLE_VALID_ATTRIBUTE: &str = "valid"; pub const ACCESS_TOKEN_TABLE_TOKEN_ATTRIBUTE: &str = "token"; pub const NONCE_TABLE: &str = "identity-nonces"; pub const NONCE_TABLE_PARTITION_KEY: &str = "nonce"; pub const NONCE_TABLE_CREATED_ATTRIBUTE: &str = "created"; pub const NONCE_TABLE_EXPIRATION_TIME_ATTRIBUTE: &str = "expirationTime"; pub const NONCE_TABLE_EXPIRATION_TIME_UNIX_ATTRIBUTE: &str = "expirationTimeUnix"; // Usernames reserved because they exist in Ashoat's keyserver already pub const RESERVED_USERNAMES_TABLE: &str = "identity-reserved-usernames"; pub const RESERVED_USERNAMES_TABLE_PARTITION_KEY: &str = "username"; pub mod devices_table { /// table name pub const NAME: &str = "identity-devices"; pub const TIMESTAMP_INDEX_NAME: &str = "deviceList-timestamp-index"; /// partition key pub const ATTR_USER_ID: &str = "userID"; /// sort key pub const ATTR_ITEM_ID: &str = "itemID"; // itemID prefixes (one shouldn't be a prefix of the other) pub const DEVICE_ITEM_KEY_PREFIX: &str = "device-"; pub const DEVICE_LIST_KEY_PREFIX: &str = "devicelist-"; // device-specific attrs pub const ATTR_DEVICE_TYPE: &str = "deviceType"; pub const ATTR_DEVICE_KEY_INFO: &str = "deviceKeyInfo"; pub const ATTR_CONTENT_PREKEY: &str = "contentPreKey"; pub const ATTR_NOTIF_PREKEY: &str = "notifPreKey"; // IdentityKeyInfo constants pub const ATTR_KEY_PAYLOAD: &str = "keyPayload"; pub const ATTR_KEY_PAYLOAD_SIGNATURE: &str = "keyPayloadSignature"; pub const ATTR_SOCIAL_PROOF: &str = "socialProof"; // PreKey constants pub const ATTR_PREKEY: &str = "preKey"; pub const ATTR_PREKEY_SIGNATURE: &str = "preKeySignature"; // device-list-specific attrs pub const ATTR_TIMESTAMP: &str = "timestamp"; pub const ATTR_DEVICE_IDS: &str = "deviceIDs"; // migration-specific attrs pub const ATTR_CODE_VERSION: &str = "codeVersion"; pub const ATTR_LOGIN_TIME: &str = "loginTime"; } // One time keys table, which need to exist in their own table to ensure // atomicity of additions and removals pub mod one_time_keys_table { // The `PARTITION_KEY` will contain "notification_${deviceID}" or // "content_${deviceID}" to allow for both key sets to coexist in the same table pub const NAME: &str = "identity-one-time-keys"; pub const PARTITION_KEY: &str = "deviceID"; pub const DEVICE_ID: &str = PARTITION_KEY; pub const SORT_KEY: &str = "oneTimeKey"; pub const ONE_TIME_KEY: &str = SORT_KEY; } // One-time key constants for device info map pub const CONTENT_ONE_TIME_KEY: &str = "contentOneTimeKey"; pub const NOTIF_ONE_TIME_KEY: &str = "notifOneTimeKey"; // Tokio pub const MPSC_CHANNEL_BUFFER_CAPACITY: usize = 1; pub const IDENTITY_SERVICE_SOCKET_ADDR: &str = "[::]:50054"; pub const IDENTITY_SERVICE_WEBSOCKET_ADDR: &str = "[::]:51004"; // Token pub const ACCESS_TOKEN_LENGTH: usize = 512; // Temporary config pub const AUTH_TOKEN: &str = "COMM_IDENTITY_SERVICE_AUTH_TOKEN"; pub const KEYSERVER_PUBLIC_KEY: &str = "KEYSERVER_PUBLIC_KEY"; // Nonce pub const NONCE_LENGTH: usize = 17; pub const NONCE_TTL_DURATION: i64 = 30; +// Identity + +pub const DEFAULT_IDENTITY_ENDPOINT: &str = "http://localhost:50054"; + // LocalStack pub const LOCALSTACK_ENDPOINT: &str = "LOCALSTACK_ENDPOINT"; // OPAQUE Server Setup pub const OPAQUE_SERVER_SETUP: &str = "OPAQUE_SERVER_SETUP"; // Opensearch Domain pub const OPENSEARCH_ENDPOINT: &str = "OPENSEARCH_ENDPOINT"; pub const DEFAULT_OPENSEARCH_ENDPOINT: &str = "identity-search-domain.us-east-2.opensearch.localhost.local stack.cloud:4566"; // Tunnelbroker pub const TUNNELBROKER_GRPC_ENDPOINT: &str = "TUNNELBROKER_GRPC_ENDPOINT"; pub const DEFAULT_TUNNELBROKER_ENDPOINT: &str = "http://localhost:50051"; // X3DH key management // Threshold for requesting more one_time keys pub const ONE_TIME_KEY_MINIMUM_THRESHOLD: usize = 5; // Number of keys to be refreshed when below the threshold pub const ONE_TIME_KEY_REFRESH_NUMBER: u32 = 5; // Minimum supported code versions pub const MIN_SUPPORTED_NATIVE_VERSION: u64 = 270; // Request metadata pub mod request_metadata { pub const CODE_VERSION: &str = "code_version"; pub const DEVICE_TYPE: &str = "device_type"; pub const USER_ID: &str = "user_id"; pub const DEVICE_ID: &str = "device_id"; pub const ACCESS_TOKEN: &str = "access_token"; } // CORS pub mod cors { use std::time::Duration; pub const DEFAULT_MAX_AGE: Duration = Duration::from_secs(24 * 60 * 60); pub const DEFAULT_EXPOSED_HEADERS: [&str; 3] = ["grpc-status", "grpc-message", "grpc-status-details-bin"]; pub const DEFAULT_ALLOW_HEADERS: [&str; 9] = [ "x-grpc-web", "content-type", "x-user-agent", "grpc-timeout", super::request_metadata::CODE_VERSION, super::request_metadata::DEVICE_TYPE, super::request_metadata::USER_ID, super::request_metadata::DEVICE_ID, super::request_metadata::ACCESS_TOKEN, ]; pub const DEFAULT_ALLOW_ORIGIN: [&str; 2] = ["https://web.comm.app", "http://localhost:3000"]; } diff --git a/services/identity/src/websockets/auth.rs b/services/identity/src/websockets/auth.rs new file mode 100644 index 000000000..98db38d36 --- /dev/null +++ b/services/identity/src/websockets/auth.rs @@ -0,0 +1,90 @@ +use client_proto::VerifyUserAccessTokenRequest; +use grpc_clients::identity; +use grpc_clients::tonic::Request; +use identity::get_unauthenticated_client; +use identity::protos::unauthenticated as client_proto; +use serde::{Deserialize, Serialize}; +use tracing::{debug, error}; + +use crate::constants::DEFAULT_IDENTITY_ENDPOINT; +use crate::websockets::errors::WebsocketError; + +#[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "type", rename_all = "camelCase")] +pub struct AuthMessage { + #[serde(rename = "userID")] + pub user_id: String, + #[serde(rename = "deviceID")] + pub device_id: String, + pub access_token: String, +} + +const PLACEHOLDER_CODE_VERSION: u64 = 0; +const DEVICE_TYPE: &str = "service"; + +async fn verify_user_access_token( + user_id: &str, + device_id: &str, + access_token: &str, +) -> Result { + let grpc_client = get_unauthenticated_client( + DEFAULT_IDENTITY_ENDPOINT, + PLACEHOLDER_CODE_VERSION, + DEVICE_TYPE.to_string(), + ) + .await; + + let mut grpc_client = match grpc_client { + Ok(grpc_client) => grpc_client, + Err(e) => { + error!("Failed to get unauthenticated client: {}", e); + return Err(WebsocketError::AuthError); + } + }; + + let message = VerifyUserAccessTokenRequest { + user_id: user_id.to_string(), + device_id: device_id.to_string(), + access_token: access_token.to_string(), + }; + + let request = Request::new(message); + let response = match grpc_client.verify_user_access_token(request).await { + Ok(response) => response, + Err(e) => { + error!("Failed to verify user access token: {}", e); + return Err(WebsocketError::AuthError); + } + }; + + Ok(response.into_inner().token_valid) +} + +pub async fn handle_auth_message(message: &str) -> Result<(), WebsocketError> { + error!("Handling auth message: {}", message); + let auth_message = serde_json::from_str(message.trim()); + + let auth_message: AuthMessage = match auth_message { + Ok(auth_message) => auth_message, + Err(e) => { + error!("Failed to parse auth message: {}", e); + return Err(WebsocketError::InvalidMessage); + } + }; + + let user_id = auth_message.user_id; + let device_id = auth_message.device_id; + let access_token = auth_message.access_token; + + let is_valid_token = + verify_user_access_token(&user_id, &device_id, &access_token).await?; + + if is_valid_token { + debug!("User {} authenticated", user_id); + } else { + debug!("User {} not authenticated", user_id); + return Err(WebsocketError::UnauthorizedDevice); + } + + Ok(()) +} diff --git a/services/identity/src/websockets/errors.rs b/services/identity/src/websockets/errors.rs index 198472e00..80521b25f 100644 --- a/services/identity/src/websockets/errors.rs +++ b/services/identity/src/websockets/errors.rs @@ -1,11 +1,13 @@ pub type BoxedError = Box; #[derive( Debug, derive_more::Display, derive_more::From, derive_more::Error, )] pub enum WebsocketError { InvalidMessage, + UnauthorizedDevice, SendError, SearchError, + AuthError, SerializationError, } diff --git a/services/identity/src/websockets/mod.rs b/services/identity/src/websockets/mod.rs index 01890b315..c0b984261 100644 --- a/services/identity/src/websockets/mod.rs +++ b/services/identity/src/websockets/mod.rs @@ -1,302 +1,331 @@ use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use elastic::client::responses::SearchResponse; use futures::lock::Mutex; use futures_util::stream::SplitSink; use futures_util::{SinkExt, StreamExt}; use hyper::upgrade::Upgraded; use hyper::{Body, Request, Response, StatusCode}; use hyper_tungstenite::tungstenite::Message; use hyper_tungstenite::HyperWebsocket; use hyper_tungstenite::WebSocketStream; use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use tracing::{debug, error, info}; +mod auth; + use crate::config::CONFIG; use crate::constants::IDENTITY_SERVICE_WEBSOCKET_ADDR; pub mod errors; #[derive(Serialize, Deserialize)] struct Query { query: Prefix, } #[derive(Serialize, Deserialize)] struct Prefix { prefix: Username, } #[derive(Serialize, Deserialize)] struct Username { username: String, } #[derive(Serialize, Deserialize)] struct User { #[serde(rename = "userID")] user_id: String, username: String, } struct WebsocketService { addr: SocketAddr, } impl hyper::service::Service> for WebsocketService { type Response = Response; type Error = errors::BoxedError; type Future = Pin> + Send>>; fn poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { let addr = self.addr; let future = async move { tracing::info!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; debug!("Upgraded WebSocket connection from {}", addr); tokio::spawn(async move { accept_connection(websocket, addr).await; }); return Ok(response); } debug!( "Incoming HTTP request on WebSocket port: {} {}", req.method(), req.uri().path() ); let response = match req.uri().path() { "/health" => Response::new(Body::from("OK")), _ => Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not found"))?, }; Ok(response) }; Box::pin(future) } } pub async fn run_server() -> Result<(), errors::BoxedError> { let addr: SocketAddr = IDENTITY_SERVICE_WEBSOCKET_ADDR.parse()?; let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("WebSocket Listening on {}", addr); let mut http = hyper::server::conn::Http::new(); http.http1_only(true); http.http1_keep_alive(true); while let Ok((stream, addr)) = listener.accept().await { let connection = http .serve_connection(stream, WebsocketService { addr }) .with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { error!("Error serving HTTP/WebSocket connection: {:?}", err); } }); } Ok(()) } async fn send_search_request( url: &str, json_body: String, ) -> Result { let client = reqwest::Client::new(); client .post(url) .header(reqwest::header::CONTENT_TYPE, "application/json") .body(json_body) .send() .await } async fn send_error_response( error: errors::WebsocketError, outgoing: Arc, Message>>>, ) { let response_msg = serde_json::json!({ "action": "errorMessage", "error": format!("{}", error) }); match serde_json::to_string(&response_msg) { Ok(serialized_response) => { if let Err(send_error) = outgoing .lock() .await .send(Message::Text(serialized_response)) .await { error!("Failed to send error response: {:?}", send_error); } } Err(serialize_error) => { error!( "Failed to serialize the error response: {:?}", serialize_error ); } } } +async fn close_connection( + outgoing: Arc, Message>>>, +) { + if let Err(e) = outgoing.lock().await.close().await { + error!("Error closing connection: {}", e); + } +} + async fn accept_connection(hyper_ws: HyperWebsocket, addr: SocketAddr) { debug!("Incoming WebSocket connection from {}", addr); let ws_stream = match hyper_ws.await { Ok(stream) => stream, Err(e) => { error!("WebSocket handshake error: {}", e); return; } }; let (outgoing, mut incoming) = ws_stream.split(); let outgoing = Arc::new(Mutex::new(outgoing)); let opensearch_url = format!("https://{}/users/_search/", &CONFIG.opensearch_endpoint); + if let Some(Ok(auth_message)) = incoming.next().await { + match auth_message { + Message::Text(text) => { + if let Err(auth_error) = auth::handle_auth_message(&text).await { + send_error_response(auth_error, outgoing.clone()).await; + close_connection(outgoing).await; + return; + } + } + _ => { + error!("Invalid authentication message from {}", addr); + close_connection(outgoing).await; + return; + } + } + } else { + error!("No authentication message from {}", addr); + close_connection(outgoing).await; + return; + } + while let Some(message) = incoming.next().await { match message { Ok(Message::Close(_)) => { debug!("Connection to {} closed.", addr); break; } Ok(Message::Pong(_)) => { debug!("Received Pong message from {}", addr); } Ok(Message::Ping(msg)) => { debug!("Received Ping message from {}", addr); if let Err(e) = outgoing.lock().await.send(Message::Pong(msg)).await { error!("Error sending message: {}", e); } } Ok(Message::Text(text)) => { let prefix_query = Query { query: Prefix { prefix: Username { username: text.trim().to_string(), }, }, }; let json_body = match serde_json::to_string(&prefix_query) { Ok(json_body) => json_body, Err(e) => { error!("Error serializing prefix query: {}", e); send_error_response( errors::WebsocketError::SerializationError, outgoing.clone(), ) .await; continue; } }; let response = send_search_request(&opensearch_url, json_body).await; let response_text = match response { Ok(response) => match response.text().await { Ok(text) => text, Err(e) => { error!("Error getting response text: {}", e); send_error_response( errors::WebsocketError::SearchError, outgoing.clone(), ) .await; continue; } }, Err(e) => { error!("Error getting search response: {}", e); send_error_response( errors::WebsocketError::SearchError, outgoing.clone(), ) .await; continue; } }; let search_response: SearchResponse = match serde_json::from_str(&response_text) { Ok(search_response) => search_response, Err(e) => { error!("Error deserializing search response: {}", e); send_error_response( errors::WebsocketError::SerializationError, outgoing.clone(), ) .await; continue; } }; let usernames: Vec<&User> = search_response.documents().collect(); let response_msg = serde_json::json!({ "action": "searchResults", "results": usernames }); if let Err(e) = outgoing .lock() .await .send(Message::Text(format!("{}", response_msg.to_string()))) .await { error!("Error sending message: {}", e); send_error_response( errors::WebsocketError::SendError, outgoing.clone(), ) .await; continue; } } Err(e) => { error!("Error in WebSocket message: {}", e); send_error_response( errors::WebsocketError::InvalidMessage, outgoing.clone(), ) .await; continue; } _ => {} } } - if let Err(e) = outgoing.lock().await.close().await { - error!("Failed to close WebSocket connection: {}", e); - }; + close_connection(outgoing).await; }