diff --git a/services/backup/blob_client/src/put_client.rs b/services/backup/blob_client/src/put_client.rs --- a/services/backup/blob_client/src/put_client.rs +++ b/services/backup/blob_client/src/put_client.rs @@ -45,18 +45,14 @@ } fn is_initialized() -> bool { - if CLIENT.lock().expect("access client").tx.is_none() { - return false; - } - if CLIENT.lock().expect("access client").rx.is_none() { - return false; - } - if CLIENT - .lock() - .expect("access client") - .receiver_handle - .is_none() - { + if let Ok(client) = CLIENT.lock() { + if client.tx.is_none() + || client.rx.is_none() + || client.receiver_handle.is_none() + { + return false; + } + } else { return false; } return true; @@ -64,145 +60,174 @@ fn report_error(message: String) { error!("[RUST] Error: {}", message); - ERROR_MESSAGES - .lock() - .expect("access error messages") - .push(message); + match ERROR_MESSAGES.lock() { + Ok(mut error_messages) => error_messages.push(message), + Err(_) => error!("could not access error messages"), + } } fn check_error() -> Result<(), String> { - let errors = ERROR_MESSAGES.lock().expect("access error messages"); - let mut errors_str_value = None; - if !errors.is_empty() { - errors_str_value = Some(errors.join("\n")); + if let Ok(error_messages) = ERROR_MESSAGES.lock() { + if !error_messages.is_empty() { + return Err(error_messages.join("\n")); + } + return Ok(()); + } else { + return Err("could not access error messages".to_string()); } - return match errors_str_value { - Some(value) => Err(value), - None => Ok(()), - }; } pub fn put_client_initialize_cxx() -> Result<(), String> { println!("[RUST] initializing"); assert!(!is_initialized(), "client cannot be initialized twice"); // grpc - let mut grpc_client: Option> = - None; + let mut maybe_grpc_client: Option< + BlobServiceClient, + > = None; RUNTIME.block_on(async { - grpc_client = Some( - BlobServiceClient::connect(BLOB_ADDRESS) - .await - .expect("successfully connect to the blob server"), - ); + maybe_grpc_client = 
BlobServiceClient::connect(BLOB_ADDRESS).await.ok(); }); + if let Some(mut grpc_client) = maybe_grpc_client { + let (request_thread_tx, mut request_thread_rx): ( + mpsc::Sender, + mpsc::Receiver, + ) = mpsc::channel(MPSC_CHANNEL_BUFFER_CAPACITY); - let (request_thread_tx, mut request_thread_rx): ( - mpsc::Sender, - mpsc::Receiver, - ) = mpsc::channel(MPSC_CHANNEL_BUFFER_CAPACITY); - - let outbound = async_stream::stream! { - while let Some(data) = request_thread_rx.recv().await { - println!("[RUST] [transmitter_thread] field index: {}", data.field_index); - println!("[RUST] [transmitter_thread] data: {:?}", data.data); - let request_data: put_request::Data = match data.field_index { - 0 => Holder(String::from_utf8(data.data).expect("Found invalid UTF-8")), - 1 => BlobHash(String::from_utf8(data.data).expect("Found invalid UTF-8")), - 2 => DataChunk(data.data), - _ => panic!("invalid field index value {}", data.field_index) - }; - let request = PutRequest { - data: Some(request_data), - }; - yield request; - } - }; - - // spawn receiver thread - let (response_thread_tx, response_thread_rx): ( - mpsc::Sender, - mpsc::Receiver, - ) = mpsc::channel(MPSC_CHANNEL_BUFFER_CAPACITY); - let receiver_handle = RUNTIME.spawn(async move { - println!("[RUST] [receiver_thread] begin"); - let maybe_response: Option< - tonic::Response>, - > = match grpc_client - .expect("access grpc client") - .put(tonic::Request::new(outbound)) - .await - { - Ok(res) => Some(res), - Err(err) => { - report_error(err.to_string()); - None - } - }; - if maybe_response.is_none() { - return; - } - match maybe_response { - Some(response) => { - let mut inner_response = response.into_inner(); - let mut response_present = true; - while response_present { - response_present = match inner_response.message().await { - Ok(maybe_response_message) => { - let mut result = false; - if let Some(response_message) = maybe_response_message { - println!( - "[RUST] got response: {}", - response_message.data_exists - ); - 
// warning: this will hang if there's more unread responses than - // MPSC_CHANNEL_BUFFER_CAPACITY - // you should then use put_client_blocking_read_cxx in order to dequeue - // the responses in c++ and make room for more - if let Ok(_) = response_thread_tx - .send((response_message.data_exists as i32).to_string()) - .await - { - result = true; - } - } - result + let outbound = async_stream::stream! { + while let Some(data) = request_thread_rx.recv().await { + println!("[RUST] [transmitter_thread] field index: {}", data.field_index); + println!("[RUST] [transmitter_thread] data: {:?}", data.data); + let request_data: Option = match data.field_index { + 0 => { + match String::from_utf8(data.data).ok() { + Some(utf8_data) => Some(Holder(utf8_data)), + None => { + report_error("invalid utf-8".to_string()); + None + }, } - Err(err) => { - report_error(err.to_string()); - false + } + 1 => { + match String::from_utf8(data.data).ok() { + Some(utf8_data) => Some(BlobHash(utf8_data)), + None => { + report_error("invalid utf-8".to_string()); + None + }, } + } + 2 => { + Some(DataChunk(data.data)) + } + _ => { + report_error(format!("invalid field index value {}", data.field_index)); + None + } + }; + if let Some (unpacked_data) = request_data { + let request = PutRequest { + data: Some(unpacked_data), }; + yield request; + } else { + report_error("an error occured, aborting connection".to_string()); + break; } } - unexpected => { - report_error(format!("unexpected result received: {:?}", unexpected)); - } }; - println!("[RUST] [receiver_thread] done"); - }); - CLIENT.lock().expect("access client").tx = Some(request_thread_tx); - CLIENT.lock().expect("access client").receiver_handle = Some(receiver_handle); - CLIENT.lock().expect("access client").rx = Some(response_thread_rx); - println!("[RUST] initialized"); - Ok(()) + // spawn receiver thread + let (response_thread_tx, response_thread_rx): ( + mpsc::Sender, + mpsc::Receiver, + ) = 
mpsc::channel(MPSC_CHANNEL_BUFFER_CAPACITY); + let receiver_handle = RUNTIME.spawn(async move { + println!("[RUST] [receiver_thread] begin"); + let maybe_response: Option< + tonic::Response>, + > = match grpc_client.put(tonic::Request::new(outbound)).await { + Ok(res) => Some(res), + Err(err) => { + report_error(err.to_string()); + None + } + }; + if maybe_response.is_none() { + return; + } + match maybe_response { + Some(response) => { + let mut inner_response = response.into_inner(); + let mut response_present = true; + while response_present { + response_present = match inner_response.message().await { + Ok(maybe_response_message) => { + let mut result = false; + if let Some(response_message) = maybe_response_message { + println!( + "[RUST] got response: {}", + response_message.data_exists + ); + // warning: this will hang if there's more unread responses than + // MPSC_CHANNEL_BUFFER_CAPACITY + // you should then use put_client_blocking_read_cxx in order to dequeue + // the responses in c++ and make room for more + if let Ok(_) = response_thread_tx + .send((response_message.data_exists as i32).to_string()) + .await + { + result = true; + } + } + result + } + Err(err) => { + report_error(err.to_string()); + false + } + }; + } + } + unexpected => { + report_error(format!("unexpected result received: {:?}", unexpected)); + } + }; + println!("[RUST] [receiver_thread] done"); + }); + + if let Ok(mut client) = CLIENT.lock() { + client.tx = Some(request_thread_tx); + client.receiver_handle = Some(receiver_handle); + client.rx = Some(response_thread_rx); + println!("[RUST] initialized"); + return Ok(()); + } + return Err("could not access client".to_string()); + } + Err("could not successfully connect to the blob server".to_string()) } pub fn put_client_blocking_read_cxx() -> Result { let mut response: Option = None; check_error()?; RUNTIME.block_on(async { - let mut rx: mpsc::Receiver = CLIENT - .lock() - .expect("access client") - .rx - .take() - .expect("access
client's receiver"); - if let Some(data) = rx.recv().await { - println!("received data {}", data); - response = Some(data); + let mut rx = match CLIENT.lock().map(|mut client| client.rx.take()) { // release the lock before awaiting + Ok(Some(rx)) => rx, + Ok(None) => return report_error("couldn't access client's receiver".to_string()), + Err(_) => return report_error("couldn't access client".to_string()), + }; + if let Some(data) = rx.recv().await { + println!("received data {}", data); + response = Some(data); + } else { + report_error( + "couldn't receive data via client's receiver".to_string(), + ); + } + match CLIENT.lock() { + Ok(mut client) => client.rx = Some(rx), + Err(_) => report_error("couldn't access client".to_string()), } - CLIENT.lock().expect("access client").rx = Some(rx); }); response.ok_or("response not received properly".to_string()) } @@ -219,18 +244,24 @@ println!("[RUST] [put_client_process] data string: {:?}", data_bytes); RUNTIME.block_on(async { - CLIENT - .lock() - .expect("access client") - .tx - .as_ref() - .expect("access client's transmitter") - .send(PutRequestData { - field_index, - data: data_bytes, - }) - .await - .expect("send data to receiver"); + if let Ok(client) = CLIENT.lock() { + if let Some(tx) = client.tx.as_ref() { // borrow: tx must stay in CLIENT for later writes + if let Ok(_) = tx + .send(PutRequestData { + field_index, + data: data_bytes, + }) + .await + { + } else { + report_error("send data to receiver failed".to_string()); + } + } else { + report_error("couldn't access client's transmitter".to_string()); + } + } else { + report_error("couldn't access client".to_string()); + } }); println!("[RUST] [put_client_process] end"); Ok(()) @@ -241,17 +272,20 @@ pub fn put_client_terminate_cxx() -> Result<(), String> { println!("[RUST] put_client_terminating"); check_error()?; - if let Some(receiver_handle) = - CLIENT.lock().expect("access client").receiver_handle.take() - { - if let Some(tx) = CLIENT.lock().expect("access client").tx.take() { - drop(tx); - } - RUNTIME.block_on(async { - if 
receiver_handle.await.is_err() { - report_error("wait for receiver handle failed".to_string()); + + if let Ok(mut client) = CLIENT.lock() { + if 
let Some(receiver_handle) = client.receiver_handle.take() { + if let Some(tx) = client.tx.take() { + drop(tx); } + RUNTIME.block_on(async { + if receiver_handle.await.is_err() { + report_error("wait for receiver handle failed".to_string()); } - }); + }); + } + } else { + report_error("couldn't access client".to_string()); } assert!(