diff --git a/keyserver/addons/rust-node-addon/src/identity_client/login_user.rs b/keyserver/addons/rust-node-addon/src/identity_client/login_user.rs --- a/keyserver/addons/rust-node-addon/src/identity_client/login_user.rs +++ b/keyserver/addons/rust-node-addon/src/identity_client/login_user.rs @@ -124,7 +124,12 @@ // Finish PAKE login; send final login request to Identity service let message = response.message().await.map_err(|e| { error!("Received an error from inbound message stream: {}", e); - Error::from_status(Status::GenericFailure) + match e.code() { + Code::NotFound => { + Error::new(Status::InvalidArg, "user not found".to_string()) + } + _ => Error::new(Status::GenericFailure, e.to_string()), + } })?; handle_login_credential_response( message, diff --git a/keyserver/addons/rust-node-addon/src/identity_client/mod.rs b/keyserver/addons/rust-node-addon/src/identity_client/mod.rs --- a/keyserver/addons/rust-node-addon/src/identity_client/mod.rs +++ b/keyserver/addons/rust-node-addon/src/identity_client/mod.rs @@ -46,7 +46,7 @@ use std::env::var; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; -use tonic::{metadata::MetadataValue, transport::Channel, Request}; +use tonic::{metadata::MetadataValue, transport::Channel, Code, Request}; use tracing::{error, instrument}; lazy_static!
{ diff --git a/keyserver/src/responders/user-responders.js b/keyserver/src/responders/user-responders.js --- a/keyserver/src/responders/user-responders.js +++ b/keyserver/src/responders/user-responders.js @@ -430,12 +430,28 @@ handleAsyncPromise( (async () => { const rustAPI = await getRustAPI(); - await rustAPI.loginUserPake( - id, - constIdentityKeys.primaryIdentityPublicKeys.ed25519, - request.password, - signedIdentityKeysBlob, - ); + try { + await rustAPI.loginUserPake( + id, + constIdentityKeys.primaryIdentityPublicKeys.ed25519, + request.password, + signedIdentityKeysBlob, + ); + } catch (e) { + // Only a missing user falls back to registration; rethrow anything + // else so the failure still reaches handleAsyncPromise. + if (e.code === 'InvalidArg' && e.message === 'user not found') { + await rustAPI.registerUser( + id, + constIdentityKeys.primaryIdentityPublicKeys.ed25519, + username, + request.password, + signedIdentityKeysBlob, + ); + } else { + throw e; + } + } })(), ); } diff --git a/services/identity/src/service.rs b/services/identity/src/service.rs --- a/services/identity/src/service.rs +++ b/services/identity/src/service.rs @@ -600,3 +600,20 @@ response: PakeLoginResponseStruct, pake_state: ServerLogin, } + +/// Sends `response` to the client over `tx`, logging the message of an error `Status` first. +/// Fails with `Status::internal("disconnection")` if the send itself fails. +async fn send_to_client( + tx: &tokio::sync::mpsc::Sender>, + response: Result, +) -> Result<(), Status> { + let transport_result = match response { + Ok(message) => tx.send(Ok(message)).await, + Err(status) => { + error!("{}", status.message()); + tx.send(Err(status)).await + } + }; + + transport_result.map_err(|_| Status::internal("disconnection")) +} diff --git a/services/identity/src/service/login.rs b/services/identity/src/service/login.rs --- a/services/identity/src/service/login.rs +++ b/services/identity/src/service/login.rs @@ -32,12 +32,19 @@ )), })), })) => { - let response_and_state = pake_login_start( + let response_and_state = match pake_login_start( client, &pake_credential_request_and_user_id.user_id, &pake_credential_request_and_user_id.pake_credential_request, ) - .await?; + .await + { + Ok(r) => r, + Err(e) => { + send_to_client(&tx, Err(e.clone())).await?; + return Err(e);
+ } + }; let login_response = LoginResponse { data: Some(PakeLoginResponse(response_and_state.response)), }; diff --git a/services/identity/src/service/update.rs b/services/identity/src/service/update.rs --- a/services/identity/src/service/update.rs +++ b/services/identity/src/service/update.rs @@ -18,25 +18,10 @@ use crate::{database::DatabaseClient, pake_grpc}; use super::{ - handle_db_error, pake_login_start, put_token_helper, Status, + handle_db_error, pake_login_start, put_token_helper, send_to_client, Status, UpdateUserRequest, UpdateUserResponse, }; -async fn send_to_client( - tx: &tokio::sync::mpsc::Sender>, - response: Result, -) -> Result<(), Status> { - let transport_result = match response { - Ok(message) => tx.send(Ok(message)).await, - Err(status) => { - error!("{}", status.message()); - tx.send(Err(status)).await - } - }; - - transport_result.map_err(|_| Status::internal("disconnection")) -} - pub(crate) async fn handle_server_update_user_messages( in_stream: Streaming, client: DatabaseClient,