diff --git a/services/identity/src/constants.rs b/services/identity/src/constants.rs --- a/services/identity/src/constants.rs +++ b/services/identity/src/constants.rs @@ -73,6 +73,7 @@ pub const WORKFLOWS_IN_PROGRESS_TABLE: &str = "identity-workflows-in-progress"; pub const WORKFLOWS_IN_PROGRESS_PARTITION_KEY: &str = "id"; +pub const WORKFLOWS_IN_PROGRESS_WORKFLOW_ATTRIBUTE: &str = "workflow"; pub const WORKFLOWS_IN_PROGRESS_TABLE_EXPIRATION_TIME_UNIX_ATTRIBUTE: &str = "expirationTimeUnix"; diff --git a/services/identity/src/database.rs b/services/identity/src/database.rs --- a/services/identity/src/database.rs +++ b/services/identity/src/database.rs @@ -57,6 +57,7 @@ pub use grpc_clients::identity::DeviceType; mod device_list; +mod workflows; pub use device_list::{DeviceListRow, DeviceListUpdate, DeviceRow}; use self::device_list::Prekey; diff --git a/services/identity/src/database/workflows.rs b/services/identity/src/database/workflows.rs new file mode 100644 --- /dev/null +++ b/services/identity/src/database/workflows.rs @@ -0,0 +1,88 @@ +use std::collections::HashMap; + +use chrono::{Duration, Utc}; +use comm_lib::{aws::ddb::types::AttributeValue, database::TryFromAttribute}; + +use super::DatabaseClient; +use crate::{ + client_service::WorkflowInProgress, + constants::{ + WORKFLOWS_IN_PROGRESS_PARTITION_KEY, WORKFLOWS_IN_PROGRESS_TABLE, + WORKFLOWS_IN_PROGRESS_TABLE_EXPIRATION_TIME_UNIX_ATTRIBUTE, + WORKFLOWS_IN_PROGRESS_TTL_DURATION, + WORKFLOWS_IN_PROGRESS_WORKFLOW_ATTRIBUTE, + }, + error::Error, + id::generate_uuid, +}; + +type WorkflowID = String; + +impl DatabaseClient { + pub async fn insert_workflow( + &self, + workflow: WorkflowInProgress, + ) -> Result { + let workflow_id = generate_uuid(); + let workflow_expiration_time = + Utc::now() + Duration::seconds(WORKFLOWS_IN_PROGRESS_TTL_DURATION); + let item = HashMap::from([ + ( + WORKFLOWS_IN_PROGRESS_PARTITION_KEY.to_string(), + AttributeValue::S(workflow_id.clone()), + ), + ( + WORKFLOWS_IN_PROGRESS_WORKFLOW_ATTRIBUTE.to_string(), + AttributeValue::S(serde_json::to_string(&workflow)?), + ), + ( + WORKFLOWS_IN_PROGRESS_TABLE_EXPIRATION_TIME_UNIX_ATTRIBUTE.to_string(), + AttributeValue::N(workflow_expiration_time.timestamp().to_string()), + ), + ]); + self + .client + .put_item() + .table_name(WORKFLOWS_IN_PROGRESS_TABLE) + .set_item(Some(item)) + .send() + .await + .map_err(|e| Error::AwsSdk(e.into()))?; + + Ok(workflow_id) + } + + pub async fn get_workflow( + &self, + workflow_id: String, + ) -> Result, Error> { + let get_response = self + .client + .get_item() + .table_name(WORKFLOWS_IN_PROGRESS_TABLE) + .key( + WORKFLOWS_IN_PROGRESS_PARTITION_KEY, + AttributeValue::S(workflow_id), + ) + .send() + .await + .map_err(|e| Error::AwsSdk(e.into()))?; + + let mut workflow_item = get_response.item.unwrap_or_default(); + let raw_workflow = + workflow_item.remove(WORKFLOWS_IN_PROGRESS_WORKFLOW_ATTRIBUTE); + + if raw_workflow.is_none() { + return Ok(None); + } + + let serialized_workflow = String::try_from_attr( + WORKFLOWS_IN_PROGRESS_WORKFLOW_ATTRIBUTE, + raw_workflow, + )?; + + let workflow = serde_json::from_str(&serialized_workflow)?; + + Ok(Some(workflow)) + } +} diff --git a/services/identity/src/error.rs b/services/identity/src/error.rs --- a/services/identity/src/error.rs +++ b/services/identity/src/error.rs @@ -20,6 +20,8 @@ DeviceList(DeviceListError), #[display(...)] MalformedItem, + #[display(...)] + Serde(serde_json::Error), } #[derive(Debug, derive_more::Display, derive_more::Error)]