diff --git a/services/identity/src/service.rs b/services/identity/src/service.rs
--- a/services/identity/src/service.rs
+++ b/services/identity/src/service.rs
@@ -1,8 +1,10 @@
 use chrono::Utc;
 use constant_time_eq::constant_time_eq;
 use futures_core::Stream;
+use opaque_ke::ServerRegistration;
 use opaque_ke::{
-  CredentialFinalization, CredentialRequest, ServerLogin,
+  CredentialFinalization, CredentialRequest,
+  RegistrationRequest as PakeRegistrationRequest, ServerLogin,
   ServerLoginStartParameters,
 };
 use rand::rngs::OsRng;
@@ -31,7 +33,11 @@
   pake_login_request::Data::PakeCredentialFinalization,
   pake_login_request::Data::PakeCredentialRequestAndUserId,
   pake_login_response::Data::AccessToken,
-  pake_login_response::Data::PakeCredentialResponse, LoginRequest,
+  pake_login_response::Data::PakeCredentialResponse,
+  registration_request::Data::PakeCredentialFinalization as PakeRegistrationCredentialFinalization,
+  registration_request::Data::PakeRegistrationRequestAndUserId,
+  registration_request::Data::PakeRegistrationUploadAndCredentialRequest,
+  registration_response::Data::PakeRegistrationResponse, LoginRequest,
   LoginResponse,
   PakeCredentialRequestAndUserId as PakeCredentialRequestAndUserIdStruct,
   PakeLoginResponse as PakeLoginResponseStruct, RegistrationRequest,
@@ -63,12 +69,75 @@
     >,
   >;
 
+  /// Drives the server side of the streamed OPAQUE registration handshake.
+  ///
+  /// Only the first step (`PakeRegistrationRequestAndUserId`) is implemented
+  /// so far. Any error response ends the stream.
+  #[instrument(skip(self))]
   async fn register_user(
     &self,
     request: Request<tonic::Streaming<RegistrationRequest>>,
   ) -> Result<Response<Self::RegisterUserStream>, Status> {
-    println!("Got a registration request: {:?}", request);
-    unimplemented!()
+    let mut in_stream = request.into_inner();
+    let (tx, rx) = mpsc::channel(1);
+    let config = self.config.clone();
+    let client = self.client.clone();
+    tokio::spawn(async move {
+      let mut user_id = String::new();
+      let mut device_id = String::new();
+      let mut server_registration: Option<ServerRegistration<Cipher>> = None;
+      let mut server_login: Option<ServerLogin<Cipher>> = None;
+      let mut num_messages_received: u8 = 0;
+      while let Some(message) = in_stream.next().await {
+        let registration_request = match message {
+          Ok(registration_request) => registration_request,
+          Err(e) => {
+            // The inbound stream is broken; no further messages can arrive.
+            error!("Failed to read registration request: {}", e);
+            break;
+          }
+        };
+        let response = match registration_request.data {
+          Some(PakeRegistrationRequestAndUserId(
+            pake_registration_request_and_user_id,
+          )) => {
+            let response = pake_registration_start(
+              config.clone(),
+              &mut OsRng,
+              &pake_registration_request_and_user_id.pake_registration_request,
+              &mut server_registration,
+              num_messages_received,
+            )
+            .await;
+            user_id = pake_registration_request_and_user_id.user_id;
+            device_id = pake_registration_request_and_user_id.device_id;
+            response
+          }
+          // Later protocol steps are not implemented yet; reject them
+          // instead of panicking the spawned handler task.
+          Some(PakeRegistrationUploadAndCredentialRequest(_))
+          | Some(PakeRegistrationCredentialFinalization(_)) => Err(
+            Status::unimplemented("registration step not implemented"),
+          ),
+          None => Err(Status::invalid_argument("missing registration data")),
+        };
+        // An error status ends the RPC, so stop reading once one has been
+        // sent (this also keeps the u8 message counter bounded).
+        let is_err = response.is_err();
+        if let Err(e) = tx.send(response).await {
+          error!("Response was dropped: {}", e);
+          break;
+        }
+        if is_err {
+          break;
+        }
+        num_messages_received += 1;
+      }
+    });
+    let out_stream = ReceiverStream::new(rx);
+    Ok(Response::new(
+      Box::pin(out_stream) as Self::RegisterUserStream
+    ))
   }
 
   type LoginUserStream =
@@ -492,3 +561,54 @@
     }
   }
 }
+
+/// Handles the first message of the registration stream: runs the server side
+/// of the OPAQUE registration start step and stores the resulting state in
+/// `server_registration` for use by later steps.
+///
+/// Returns `Status::aborted` if this is not the first message in the stream or
+/// if the PAKE step fails, and `Status::invalid_argument` if the client's
+/// registration request bytes cannot be deserialized.
+async fn pake_registration_start(
+  config: Config,
+  rng: &mut (impl Rng + CryptoRng),
+  registration_request_bytes: &[u8],
+  server_registration: &mut Option<ServerRegistration<Cipher>>,
+  num_messages_received: u8,
+) -> Result<RegistrationResponse, Status> {
+  if num_messages_received != 0 {
+    error!("Too many messages received in stream, aborting");
+    return Err(Status::aborted("please retry"));
+  }
+  // The request bytes come straight from the client, so a malformed message
+  // must become an error status rather than a panic in the handler task.
+  let registration_request =
+    match PakeRegistrationRequest::deserialize(registration_request_bytes) {
+      Ok(request) => request,
+      Err(e) => {
+        error!("Failed to deserialize PAKE registration request: {}", e);
+        return Err(Status::invalid_argument("invalid message"));
+      }
+    };
+  match ServerRegistration::<Cipher>::start(
+    rng,
+    registration_request,
+    config.server_keypair.public(),
+  ) {
+    Ok(server_registration_start_result) => {
+      *server_registration = Some(server_registration_start_result.state);
+      Ok(RegistrationResponse {
+        data: Some(PakeRegistrationResponse(
+          server_registration_start_result.message.serialize(),
+        )),
+      })
+    }
+    Err(e) => {
+      error!(
+        "Encountered a PAKE protocol error when starting registration: {}",
+        e
+      );
+      Err(Status::aborted("server error"))
+    }
+  }
+}