diff --git a/services/backup/Cargo.lock b/services/backup/Cargo.lock --- a/services/backup/Cargo.lock +++ b/services/backup/Cargo.lock @@ -747,6 +747,7 @@ name = "backup" version = "0.1.0" dependencies = [ + "actix-multipart", "actix-web", "anyhow", "async-stream", diff --git a/services/backup/Cargo.toml b/services/backup/Cargo.toml --- a/services/backup/Cargo.toml +++ b/services/backup/Cargo.toml @@ -30,6 +30,7 @@ tracing-actix-web = "0.7.3" reqwest = "0.11.18" derive_more = "0.99" +actix-multipart = "0.6" [build-dependencies] tonic-build = "0.8" diff --git a/services/backup/src/error.rs b/services/backup/src/error.rs --- a/services/backup/src/error.rs +++ b/services/backup/src/error.rs @@ -5,7 +5,9 @@ }, HttpResponse, ResponseError, }; +pub use aws_sdk_dynamodb::Error as DynamoDBError; use comm_services_lib::blob::client::BlobServiceError; +use comm_services_lib::database::Error as DBError; use reqwest::StatusCode; use tracing::{error, trace, warn}; @@ -14,6 +16,7 @@ )] pub enum BackupError { BlobError(BlobServiceError), + DB(comm_services_lib::database::Error), } impl From<&BackupError> for actix_web::Error { @@ -41,6 +44,20 @@ error!("Unexpected blob error: {err}"); ErrorInternalServerError("server error") } + BackupError::DB(err) => match err { + DBError::AwsSdk( + err @ (DynamoDBError::InternalServerError(_) + | DynamoDBError::ProvisionedThroughputExceededException(_) + | DynamoDBError::RequestLimitExceeded(_)), + ) => { + warn!("AWS transient error occurred: {err}"); + ErrorServiceUnavailable("please retry") + } + unexpected => { + error!("Received an unexpected DB error: {0:?} - {0}", unexpected); + ErrorInternalServerError("server error") + } + }, } } } diff --git a/services/backup/src/http/handlers/backup.rs b/services/backup/src/http/handlers/backup.rs new file mode 100644 --- /dev/null +++ b/services/backup/src/http/handlers/backup.rs @@ -0,0 +1,146 @@ +use std::{collections::HashSet, convert::Infallible}; + +use actix_web::{ + error::ErrorBadRequest, + 
web::{self, Bytes}, + HttpResponse, +}; +use comm_services_lib::{ + auth::UserIdentity, + blob::{client::BlobServiceClient, types::BlobInfo}, + http::multipart::{get_named_text_field, get_text_field}, +}; +use tokio_stream::{wrappers::ReceiverStream, StreamExt}; +use tracing::{info, instrument, trace, warn}; + +use crate::{ + database::{backup_item::BackupItem, DatabaseClient}, + error::BackupError, +}; + +#[instrument(name = "upload_backup", skip_all, fields(backup_id))] +pub async fn upload( + user: UserIdentity, + blob_client: web::Data, + db_client: web::Data, + mut multipart: actix_multipart::Multipart, +) -> actix_web::Result { + info!("Upload backup request"); + + let backup_id = get_named_text_field("backup_id", &mut multipart).await?; + + tracing::Span::current().record("backup_id", &backup_id); + + let user_keys_blob_info = forward_field_to_blob( + &mut multipart, + &blob_client, + "user_keys_hash", + "user_keys", + ) + .await?; + + let user_data_blob_info = forward_field_to_blob( + &mut multipart, + &blob_client, + "user_data_hash", + "user_data", + ) + .await?; + + let attachments_holders: HashSet = + match get_text_field(&mut multipart).await? { + Some((name, attachments)) => { + if name != "attachments" { + warn!( + name, + "Malformed request: 'attachments' text field expected." 
+ ); + return Err(ErrorBadRequest("Bad request")); + } + + attachments.lines().map(ToString::to_string).collect() + } + None => HashSet::new(), + }; + + let item = BackupItem::new( + user.user_id, + backup_id, + user_keys_blob_info, + user_data_blob_info, + attachments_holders, + ); + + db_client + .put_backup_item(item) + .await + .map_err(BackupError::from)?; + Ok(HttpResponse::Ok().finish()) +} + +#[instrument( + skip_all, + name = "forward_to_blob", + fields(hash_field_name, data_field_name) +)] +async fn forward_field_to_blob( + multipart: &mut actix_multipart::Multipart, + blob_client: &web::Data, + hash_field_name: &str, + data_field_name: &str, +) -> actix_web::Result { + trace!("Reading blob fields: {hash_field_name:?}, {data_field_name:?}"); + + let blob_hash = get_named_text_field(hash_field_name, multipart).await?; + + let Some(mut field) = multipart.try_next().await? else { + warn!("Malformed request: expected a field."); + return Err(ErrorBadRequest("Bad request"))?; + }; + if field.name() != data_field_name { + warn!( + hash_field_name, + "Malformed request: '{data_field_name}' data field expected." + ); + return Err(ErrorBadRequest("Bad request"))?; + } + + let blob_info = BlobInfo { + blob_hash, + holder: uuid::Uuid::new_v4().to_string(), + }; + + // [`actix_multipart::Multipart`] isn't [`std::marker::Send`], and so we cannot pass it to the blob client directly. + // Instead we have to forward it to a channel and create a stream from the receiver. + let (tx, rx) = tokio::sync::mpsc::channel(1); + let receive_promise = async move { + trace!("Receiving blob data"); + // [`actix_multipart::MultipartError`] isn't [`std::marker::Send`] so we return it here, and pass [`Infallible`] + // as the error to the channel + while let Some(chunk) = field.try_next().await?
{ + if let Err(err) = tx.send(Result::::Ok(chunk)).await { + warn!("Error when sending data through a channel: '{err}'"); + // Error here means that the channel has been closed from the blob client side. We don't want to return an error + // here, because `tokio::try_join!` only returns the first error it receives and we want to prioritize the blob + // client error. + break; + } + } + trace!("Finished receiving blob data"); + Result::<(), actix_web::Error>::Ok(()) + }; + + let data_stream = ReceiverStream::new(rx); + let send_promise = async { + blob_client + .simple_put(&blob_info.blob_hash, &blob_info.holder, data_stream) + .await + .map_err(BackupError::from)?; + + Ok(()) + }; + + tokio::try_join!(receive_promise, send_promise)?; + + Ok(blob_info) +} diff --git a/services/backup/src/http/mod.rs b/services/backup/src/http/mod.rs --- a/services/backup/src/http/mod.rs +++ b/services/backup/src/http/mod.rs @@ -1,10 +1,17 @@ use actix_web::{web, App, HttpServer}; use anyhow::Result; -use comm_services_lib::blob::client::BlobServiceClient; +use comm_services_lib::{ + blob::client::BlobServiceClient, + http::auth::get_comm_authentication_middleware, +}; use tracing::info; use crate::{database::DatabaseClient, CONFIG}; +mod handlers { + pub(super) mod backup; +} + pub async fn run_http_server( db_client: DatabaseClient, blob_client: BlobServiceClient, @@ -18,6 +25,8 @@ let blob = web::Data::new(blob_client); HttpServer::new(move || { + let auth_middleware = get_comm_authentication_middleware(); + App::new() .wrap(tracing_actix_web::TracingLogger::default()) .wrap(comm_services_lib::http::cors_config( @@ -26,7 +35,9 @@ .app_data(db.clone()) .app_data(blob.clone()) .service( - web::resource("/hello").route(web::get().to(|| async { "world" })), + web::resource("/backups") + .route(web::post().to(handlers::backup::upload)) + .wrap(auth_middleware), ) }) .bind(("0.0.0.0", CONFIG.http_port))?
diff --git a/services/comm-services-lib/src/database.rs b/services/comm-services-lib/src/database.rs --- a/services/comm-services-lib/src/database.rs +++ b/services/comm-services-lib/src/database.rs @@ -1,5 +1,5 @@ use aws_sdk_dynamodb::types::AttributeValue; -use aws_sdk_dynamodb::Error as DynamoDBError; +pub use aws_sdk_dynamodb::Error as DynamoDBError; use chrono::{DateTime, Utc}; use std::collections::HashSet; use std::fmt::{Display, Formatter}; diff --git a/services/comm-services-lib/src/http/multipart.rs b/services/comm-services-lib/src/http/multipart.rs --- a/services/comm-services-lib/src/http/multipart.rs +++ b/services/comm-services-lib/src/http/multipart.rs @@ -1,6 +1,7 @@ use actix_multipart::{Field, MultipartError}; -use actix_web::error::ParseError; +use actix_web::error::{ErrorBadRequest, ParseError}; use tokio_stream::StreamExt; +use tracing::warn; /// Can be used to get a single field from multipart body with it's data /// converted to a string @@ -40,3 +41,20 @@ Ok(Some((name, text))) } + +pub async fn get_named_text_field( + name: &str, + multipart: &mut actix_multipart::Multipart, +) -> actix_web::Result { + let Some((field_name, backup_id)) = get_text_field(multipart).await? else { + warn!("Malformed request: expected a field."); + return Err(ErrorBadRequest("Bad request")); + }; + + if field_name != name { + warn!(name, "Malformed request: '{name}' text field expected."); + return Err(ErrorBadRequest("Bad request")); + } + + Ok(backup_id) +}