diff --git a/services/backup/blob_client/src/get_client.rs b/services/backup/blob_client/src/get_client.rs index f23e59661..c38de1ce8 100644 --- a/services/backup/blob_client/src/get_client.rs +++ b/services/backup/blob_client/src/get_client.rs @@ -1,165 +1,182 @@ mod proto { tonic::include_proto!("blob"); } use proto::blob_service_client::BlobServiceClient; use proto::GetRequest; use crate::constants::{BLOB_ADDRESS, MPSC_CHANNEL_BUFFER_CAPACITY}; -use crate::tools::{check_error, report_error}; +use crate::tools::{ + c_char_pointer_to_string, check_error, report_error, string_to_c_char_pointer, +}; use lazy_static::lazy_static; use libc; use libc::c_char; -use std::ffi::CStr; +use std::collections::HashMap; use std::sync::Mutex; use tokio::runtime::Runtime; use tokio::sync::mpsc; use tokio::task::JoinHandle; struct ReadClient { rx: mpsc::Receiver>, rx_handle: JoinHandle<()>, } lazy_static! { - static ref CLIENT: Mutex> = - Mutex::new(None); + // todo: we should consider limiting the clients size, + // if every client is able to allocate up to 4MB data at a time + static ref CLIENTS: Mutex> = + Mutex::new(HashMap::new()); static ref RUNTIME: Runtime = Runtime::new().unwrap(); static ref ERROR_MESSAGES: Mutex> = Mutex::new(Vec::new()); } -fn is_initialized() -> bool { - if let Ok(client) = CLIENT.lock() { - if client.is_some() { - return true; - } +fn is_initialized(holder: &String) -> bool { + if let Ok(clients) = CLIENTS.lock() { + return clients.contains_key(holder); } else { report_error(&ERROR_MESSAGES, "couldn't access client", Some("get")); } false } pub fn get_client_initialize_cxx( holder_char: *const c_char, ) -> Result<(), String> { - if is_initialized() { - get_client_terminate_cxx()?; + let holder = c_char_pointer_to_string(holder_char)?; + if is_initialized(&holder) { + get_client_terminate_cxx(string_to_c_char_pointer(&holder)?)?; } - assert!(!is_initialized(), "client cannot be initialized twice"); - - let holder_cstr: &CStr = unsafe { CStr::from_ptr(holder_char) }; - let holder: String = holder_cstr.to_str().unwrap().to_owned(); + assert!( + !is_initialized(&holder), + "client cannot be initialized twice" + ); // grpc if let Ok(mut grpc_client) = RUNTIME.block_on(async { BlobServiceClient::connect(BLOB_ADDRESS).await }) { // spawn receiver thread let (response_thread_tx, response_thread_rx) = mpsc::channel::>(MPSC_CHANNEL_BUFFER_CAPACITY); + let cloned_holder = holder.clone(); let rx_handle = RUNTIME.spawn(async move { - if let Ok(response) = grpc_client.get(GetRequest { holder }).await { + if let Ok(response) = grpc_client + .get(GetRequest { + holder: cloned_holder, + }) + .await + { let mut inner_response = response.into_inner(); loop { match inner_response.message().await { Ok(maybe_data) => { let mut result = false; if let Some(data) = maybe_data { let data: Vec = data.data_chunk; result = match response_thread_tx.send(data).await { Ok(_) => true, Err(err) => { report_error( &ERROR_MESSAGES, &err.to_string(), Some("get"), ); false } } } if !result { break; } } Err(err) => { report_error(&ERROR_MESSAGES, &err.to_string(), Some("get")); break; } }; } } else { report_error( &ERROR_MESSAGES, "couldn't perform grpc get operation", Some("get"), ); } }); - if let Ok(mut client) = CLIENT.lock() { - *client = Some(ReadClient { + if let Ok(mut clients) = CLIENTS.lock() { + let client = ReadClient { rx_handle, rx: response_thread_rx, - }); + }; + (*clients).insert(holder, client); return Ok(()); } return Err("could not access client".to_string()); } Err("could not successfully connect to the blob server".to_string()) } -pub fn get_client_blocking_read_cxx() -> Result, String> { +pub fn get_client_blocking_read_cxx( + holder_char: *const c_char, +) -> Result, String> { + let holder = c_char_pointer_to_string(holder_char)?; check_error(&ERROR_MESSAGES)?; let response: Option> = RUNTIME.block_on(async { - if let Ok(mut maybe_client) = CLIENT.lock() { - if let Some(mut client) = (*maybe_client).take() { + if let Ok(mut clients) = CLIENTS.lock() { + let maybe_client = clients.get_mut(&holder); + if let Some(client) = maybe_client { let maybe_data = client.rx.recv().await; let response = Some(maybe_data.unwrap_or_else(|| vec![])); - *maybe_client = Some(client); return response; } else { report_error(&ERROR_MESSAGES, "no client present", Some("get")); } } else { report_error(&ERROR_MESSAGES, "couldn't access client", Some("get")); } None }); check_error(&ERROR_MESSAGES)?; response.ok_or("response could not be obtained".to_string()) } -pub fn get_client_terminate_cxx() -> Result<(), String> { +pub fn get_client_terminate_cxx( + holder_char: *const c_char, +) -> Result<(), String> { + let holder = c_char_pointer_to_string(holder_char)?; check_error(&ERROR_MESSAGES)?; - if !is_initialized() { + if !is_initialized(&holder) { check_error(&ERROR_MESSAGES)?; return Ok(()); } - if let Ok(mut maybe_client) = CLIENT.lock() { - if let Some(client) = (*maybe_client).take() { + if let Ok(mut clients) = CLIENTS.lock() { + let maybe_client = clients.remove(&holder); + if let Some(client) = maybe_client { RUNTIME.block_on(async { if client.rx_handle.await.is_err() { report_error( &ERROR_MESSAGES, "wait for receiver handle failed", Some("get"), ); } }); } else { return Err("no client detected".to_string()); } } else { return Err("couldn't access client".to_string()); } assert!( - !is_initialized(), + !is_initialized(&holder), "client transmitter handler released properly" ); check_error(&ERROR_MESSAGES)?; Ok(()) } diff --git a/services/backup/blob_client/src/lib.rs b/services/backup/blob_client/src/lib.rs index 29d04e79e..bf3ec8762 100644 --- a/services/backup/blob_client/src/lib.rs +++ b/services/backup/blob_client/src/lib.rs @@ -1,38 +1,42 @@ mod constants; mod get_client; mod put_client; mod tools; use put_client::{ put_client_blocking_read_cxx, put_client_initialize_cxx, put_client_terminate_cxx, put_client_write_cxx, }; use get_client::{ get_client_blocking_read_cxx, get_client_initialize_cxx, get_client_terminate_cxx, }; #[cxx::bridge] mod ffi { extern "Rust" { unsafe fn put_client_initialize_cxx( holder_char: *const c_char, ) -> Result<()>; unsafe fn put_client_write_cxx( holder_char: *const c_char, field_index: usize, data: *const c_char, ) -> Result<()>; unsafe fn put_client_blocking_read_cxx( holder_char: *const c_char, ) -> Result; unsafe fn put_client_terminate_cxx( holder_char: *const c_char, ) -> Result<()>; unsafe fn get_client_initialize_cxx( holder_char: *const c_char, ) -> Result<()>; - fn get_client_blocking_read_cxx() -> Result>; - fn get_client_terminate_cxx() -> Result<()>; + unsafe fn get_client_blocking_read_cxx( + holder_char: *const c_char, + ) -> Result>; + unsafe fn get_client_terminate_cxx( + holder_char: *const c_char, + ) -> Result<()>; } }