Add getKeyserverKeys to CommRustModule

Adds a getKeyserverKeys(authUserID, authDeviceID, authAccessToken,
keyserverID) method to the CommRustModule TurboModule. The C++ module forwards
the call to the Rust function get_keyserver_keys (cxx name
identityGetKeyserverKeys), which queries the identity service and reports the
JSON-serialized key info, or the error, via handle_string_result_as_callback.

NOTE(review): template/generic arguments on context lines (for example
std::shared_ptr<facebook::react::CallInvoker>) and leading indentation should
be confirmed against the target tree before applying.

diff --git a/native/cpp/CommonCpp/NativeModules/CommRustModule.h b/native/cpp/CommonCpp/NativeModules/CommRustModule.h
--- a/native/cpp/CommonCpp/NativeModules/CommRustModule.h
+++ b/native/cpp/CommonCpp/NativeModules/CommRustModule.h
@@ -68,6 +68,12 @@
       jsi::String userID,
       jsi::String deviceID) override;
   virtual jsi::Value versionSupported(jsi::Runtime &rt) override;
+  virtual jsi::Value getKeyserverKeys(
+      jsi::Runtime &rt,
+      jsi::String authUserID,
+      jsi::String authDeviceID,
+      jsi::String authAccessToken,
+      jsi::String keyserverID) override;
 
 public:
   CommRustModule(std::shared_ptr<facebook::react::CallInvoker> jsInvoker);
diff --git a/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp b/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp
--- a/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp
+++ b/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp
@@ -263,4 +263,41 @@
       });
 }
 
+// Returns a JS promise for the keys of `keyserverID`. The promise is
+// registered with RustPromiseManager and its ID is handed to
+// identityGetKeyserverKeys; an exception thrown while doing so rejects the
+// promise with the exception message.
+jsi::Value CommRustModule::getKeyserverKeys(
+    jsi::Runtime &rt,
+    jsi::String authUserID,
+    jsi::String authDeviceID,
+    jsi::String authAccessToken,
+    jsi::String keyserverID) {
+  auto authUserIDRust = jsiStringToRustString(authUserID, rt);
+  auto authDeviceIDRust = jsiStringToRustString(authDeviceID, rt);
+  auto authAccessTokenRust = jsiStringToRustString(authAccessToken, rt);
+  auto keyserverIDRust = jsiStringToRustString(keyserverID, rt);
+
+  return createPromiseAsJSIValue(
+      rt, [=, this](jsi::Runtime &innerRt, std::shared_ptr<Promise> promise) {
+        std::string error;
+        try {
+          auto currentID = RustPromiseManager::instance.addPromise(
+              promise, this->jsInvoker_, innerRt);
+          identityGetKeyserverKeys(
+              authUserIDRust,
+              authDeviceIDRust,
+              authAccessTokenRust,
+              keyserverIDRust,
+              currentID);
+        } catch (const std::exception &e) {
+          error = e.what();
+        };
+        if (!error.empty()) {
+          this->jsInvoker_->invokeAsync(
+              [error, promise]() { promise->reject(error); });
+        }
+      });
+}
+
 } // namespace comm
diff --git a/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp b/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp
--- a/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp
+++ b/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp
@@ -36,6 +36,9 @@
 static jsi::Value __hostFunction_CommRustModuleSchemaCxxSpecJSI_versionSupported(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
   return static_cast<CommRustModuleSchemaCxxSpecJSI *>(&turboModule)->versionSupported(rt);
 }
+static jsi::Value __hostFunction_CommRustModuleSchemaCxxSpecJSI_getKeyserverKeys(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
+  return static_cast<CommRustModuleSchemaCxxSpecJSI *>(&turboModule)->getKeyserverKeys(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].asString(rt));
+}
 
 CommRustModuleSchemaCxxSpecJSI::CommRustModuleSchemaCxxSpecJSI(std::shared_ptr<CallInvoker> jsInvoker)
   : TurboModule("CommRustTurboModule", jsInvoker) {
@@ -47,6 +50,7 @@
   methodMap_["deleteUser"] = MethodMetadata {3, __hostFunction_CommRustModuleSchemaCxxSpecJSI_deleteUser};
   methodMap_["getOutboundKeysForUserDevice"] = MethodMetadata {5, __hostFunction_CommRustModuleSchemaCxxSpecJSI_getOutboundKeysForUserDevice};
   methodMap_["versionSupported"] = MethodMetadata {0, __hostFunction_CommRustModuleSchemaCxxSpecJSI_versionSupported};
+  methodMap_["getKeyserverKeys"] = MethodMetadata {4, __hostFunction_CommRustModuleSchemaCxxSpecJSI_getKeyserverKeys};
 }
 
 
diff --git a/native/cpp/CommonCpp/_generated/rustJSI.h b/native/cpp/CommonCpp/_generated/rustJSI.h
--- a/native/cpp/CommonCpp/_generated/rustJSI.h
+++ b/native/cpp/CommonCpp/_generated/rustJSI.h
@@ -28,6 +28,7 @@
   virtual jsi::Value deleteUser(jsi::Runtime &rt, jsi::String userID, jsi::String deviceID, jsi::String accessToken) = 0;
   virtual jsi::Value getOutboundKeysForUserDevice(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String userID, jsi::String deviceID) = 0;
   virtual jsi::Value versionSupported(jsi::Runtime &rt) = 0;
+  virtual jsi::Value getKeyserverKeys(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String keyserverID) = 0;
 
 };
 
@@ -113,6 +114,14 @@
       return bridging::callFromJs<jsi::Value>(
           rt, &T::versionSupported, jsInvoker_, instance_);
     }
+    jsi::Value getKeyserverKeys(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String keyserverID) override {
+      static_assert(
+          bridging::getParameterCount(&T::getKeyserverKeys) == 5,
+          "Expected getKeyserverKeys(...) to have 5 parameters");
+
+      return bridging::callFromJs<jsi::Value>(
+          rt, &T::getKeyserverKeys, jsInvoker_, instance_, std::move(authUserID), std::move(authDeviceID), std::move(authAccessToken), std::move(keyserverID));
+    }
 
   private:
     T *instance_;
diff --git a/native/native_rust_library/src/lib.rs b/native/native_rust_library/src/lib.rs
--- a/native/native_rust_library/src/lib.rs
+++ b/native/native_rust_library/src/lib.rs
@@ -3,8 +3,8 @@
 use comm_opaque2::grpc::opaque_error_to_grpc_status as handle_error;
 use ffi::{bool_callback, string_callback, void_callback};
 use grpc_clients::identity::protos::authenticated::{
-  OutboundKeyInfo, OutboundKeysForUserRequest, UpdateUserPasswordFinishRequest,
-  UpdateUserPasswordStartRequest,
+  KeyserverKeysResponse, OutboundKeyInfo, OutboundKeysForUserRequest,
+  UpdateUserPasswordFinishRequest, UpdateUserPasswordStartRequest,
 };
 use grpc_clients::identity::protos::client::{
   DeviceKeyUpload, DeviceType, Empty, IdentityKeyInfo,
@@ -131,6 +131,15 @@
     #[cxx_name = "identityVersionSupported"]
     fn version_supported(promise_id: u32);
 
+    #[cxx_name = "identityGetKeyserverKeys"]
+    fn get_keyserver_keys(
+      user_id: String,
+      device_id: String,
+      access_token: String,
+      keyserver_id: String,
+      promise_id: u32,
+    );
+
     // Argon2
     #[cxx_name = "compute_backup_key"]
     fn compute_backup_key_str(
@@ -308,6 +317,62 @@
   }
 }
 
+/// Spawns a task on `RUNTIME` that requests the keys of keyserver
+/// `keyserver_id` from the identity service, authenticating with `user_id`,
+/// `device_id` and `access_token`. The result is handed to
+/// `handle_string_result_as_callback` together with `promise_id`.
+fn get_keyserver_keys(
+  user_id: String,
+  device_id: String,
+  access_token: String,
+  keyserver_id: String,
+  promise_id: u32,
+) {
+  RUNTIME.spawn(async move {
+    let get_keyserver_keys_request = OutboundKeysForUserRequest {
+      user_id: keyserver_id,
+    };
+    let auth_info = AuthInfo {
+      access_token,
+      user_id,
+      device_id,
+    };
+    let result =
+      get_keyserver_keys_helper(get_keyserver_keys_request, auth_info).await;
+    handle_string_result_as_callback(result, promise_id);
+  });
+}
+
+/// Sends `get_keyserver_keys_request` over an authenticated identity client
+/// and returns the keyserver's key info serialized as a JSON string.
+///
+/// # Errors
+///
+/// Fails if the client cannot be created, the RPC fails, the response has no
+/// `keyserver_info`, or the key info cannot be converted or serialized.
+async fn get_keyserver_keys_helper(
+  get_keyserver_keys_request: OutboundKeysForUserRequest,
+  auth_info: AuthInfo,
+) -> Result<String, Error> {
+  let mut identity_client = get_auth_client(
+    IDENTITY_SOCKET_ADDR,
+    auth_info.user_id,
+    auth_info.device_id,
+    auth_info.access_token,
+    CODE_VERSION,
+    DEVICE_TYPE.as_str_name().to_lowercase(),
+  )
+  .await?;
+  let response = identity_client
+    .get_keyserver_keys(get_keyserver_keys_request)
+    .await?
+    .into_inner();
+
+  let keyserver_keys = OutboundKeyInfoResponse::try_from(response)?;
+
+  Ok(serde_json::to_string(&keyserver_keys)?)
+}
+
 struct AuthInfo {
   user_id: String,
   device_id: String,
@@ -841,6 +906,17 @@
   }
 }
 
+/// Converts the `keyserver_info` of a `KeyserverKeysResponse`, failing with
+/// `Error::MissingResponseData` when the field is absent.
+impl TryFrom<KeyserverKeysResponse> for OutboundKeyInfoResponse {
+  type Error = Error;
+
+  fn try_from(response: KeyserverKeysResponse) -> Result<Self, Error> {
+    let key_info = response.keyserver_info.ok_or(Error::MissingResponseData)?;
+    Self::try_from(key_info)
+  }
+}
+
 fn get_outbound_keys_for_user_device(
   auth_user_id: String,
   auth_device_id: String,
diff --git a/native/schema/CommRustModuleSchema.js b/native/schema/CommRustModuleSchema.js
--- a/native/schema/CommRustModuleSchema.js
+++ b/native/schema/CommRustModuleSchema.js
@@ -63,6 +63,12 @@
     deviceID: string,
   ) => Promise<string>;
   +versionSupported: () => Promise<boolean>;
+  +getKeyserverKeys: (
+    authUserID: string,
+    authDeviceID: string,
+    authAccessToken: string,
+    keyserverID: string,
+  ) => Promise<string>;
 }
 
 export default (TurboModuleRegistry.getEnforcing<Spec>(