diff --git a/native/cpp/CommonCpp/NativeModules/CommRustModule.h b/native/cpp/CommonCpp/NativeModules/CommRustModule.h
--- a/native/cpp/CommonCpp/NativeModules/CommRustModule.h
+++ b/native/cpp/CommonCpp/NativeModules/CommRustModule.h
@@ -80,6 +80,12 @@
       jsi::String authAccessToken,
       jsi::Array contentOneTimePreKeys,
       jsi::Array notifOneTimePreKeys) override;
+  virtual jsi::Value getKeyserverKeys(
+      jsi::Runtime &rt,
+      jsi::String authUserID,
+      jsi::String authDeviceID,
+      jsi::String authAccessToken,
+      jsi::String keyserverID) override;
 
 public:
   CommRustModule(std::shared_ptr jsInvoker);
diff --git a/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp b/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp
--- a/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp
+++ b/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp
@@ -375,4 +375,41 @@
       });
 }
 
+// Converts the auth credentials and keyserver ID to Rust strings and starts
+// identityGetKeyserverKeys under a promise registered with
+// RustPromiseManager. A std::exception thrown while starting the request
+// rejects the promise with e.what() via jsInvoker_.
+jsi::Value CommRustModule::getKeyserverKeys(
+    jsi::Runtime &rt,
+    jsi::String authUserID,
+    jsi::String authDeviceID,
+    jsi::String authAccessToken,
+    jsi::String keyserverID) {
+  auto authUserIDRust = jsiStringToRustString(authUserID, rt);
+  auto authDeviceIDRust = jsiStringToRustString(authDeviceID, rt);
+  auto authAccessTokenRust = jsiStringToRustString(authAccessToken, rt);
+  auto keyserverIDRust = jsiStringToRustString(keyserverID, rt);
+
+  return createPromiseAsJSIValue(
+      rt, [=, this](jsi::Runtime &innerRt, std::shared_ptr<Promise> promise) {
+        std::string error;
+        try {
+          auto currentID = RustPromiseManager::instance.addPromise(
+              promise, this->jsInvoker_, innerRt);
+          identityGetKeyserverKeys(
+              authUserIDRust,
+              authDeviceIDRust,
+              authAccessTokenRust,
+              keyserverIDRust,
+              currentID);
+        } catch (const std::exception &e) {
+          error = e.what();
+        };
+        if (!error.empty()) {
+          this->jsInvoker_->invokeAsync(
+              [error, promise]() { promise->reject(error); });
+        }
+      });
+}
+
 } // namespace comm
diff --git a/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp b/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp
--- a/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp
+++ b/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp
@@ -42,6 +42,9 @@
 static jsi::Value __hostFunction_CommRustModuleSchemaCxxSpecJSI_uploadOneTimeKeys(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
   return static_cast(&turboModule)->uploadOneTimeKeys(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].asObject(rt).asArray(rt), args[4].asObject(rt).asArray(rt));
 }
+static jsi::Value __hostFunction_CommRustModuleSchemaCxxSpecJSI_getKeyserverKeys(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
+  return static_cast<CommRustModuleSchemaCxxSpecJSI *>(&turboModule)->getKeyserverKeys(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].asString(rt));
+}
 
 CommRustModuleSchemaCxxSpecJSI::CommRustModuleSchemaCxxSpecJSI(std::shared_ptr jsInvoker)
   : TurboModule("CommRustTurboModule", jsInvoker) {
@@ -55,6 +58,7 @@
   methodMap_["getInboundKeysForUser"] = MethodMetadata {4, __hostFunction_CommRustModuleSchemaCxxSpecJSI_getInboundKeysForUser};
   methodMap_["versionSupported"] = MethodMetadata {0, __hostFunction_CommRustModuleSchemaCxxSpecJSI_versionSupported};
   methodMap_["uploadOneTimeKeys"] = MethodMetadata {5, __hostFunction_CommRustModuleSchemaCxxSpecJSI_uploadOneTimeKeys};
+  methodMap_["getKeyserverKeys"] = MethodMetadata {4, __hostFunction_CommRustModuleSchemaCxxSpecJSI_getKeyserverKeys};
 }
 
 
diff --git a/native/cpp/CommonCpp/_generated/rustJSI.h b/native/cpp/CommonCpp/_generated/rustJSI.h
--- a/native/cpp/CommonCpp/_generated/rustJSI.h
+++ b/native/cpp/CommonCpp/_generated/rustJSI.h
@@ -30,6 +30,7 @@
   virtual jsi::Value getInboundKeysForUser(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String userID) = 0;
   virtual jsi::Value versionSupported(jsi::Runtime &rt) = 0;
   virtual jsi::Value uploadOneTimeKeys(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::Array contentOneTimePreKeys, jsi::Array notifOneTimePreKeys) = 0;
+  virtual jsi::Value getKeyserverKeys(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String keyserverID) = 0;
 
 };
 
@@ -131,6 +132,14 @@
       return bridging::callFromJs(
           rt, &T::uploadOneTimeKeys, jsInvoker_, instance_, std::move(authUserID), std::move(authDeviceID), std::move(authAccessToken), std::move(contentOneTimePreKeys), std::move(notifOneTimePreKeys));
     }
+    jsi::Value getKeyserverKeys(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String keyserverID) override {
+      static_assert(
+          bridging::getParameterCount(&T::getKeyserverKeys) == 5,
+          "Expected getKeyserverKeys(...) to have 5 parameters");
+
+      return bridging::callFromJs<jsi::Value>(
+          rt, &T::getKeyserverKeys, jsInvoker_, instance_, std::move(authUserID), std::move(authDeviceID), std::move(authAccessToken), std::move(keyserverID));
+    }
 
   private:
     T *instance_;
diff --git a/native/native_rust_library/src/lib.rs b/native/native_rust_library/src/lib.rs
--- a/native/native_rust_library/src/lib.rs
+++ b/native/native_rust_library/src/lib.rs
@@ -3,8 +3,8 @@
 use comm_opaque2::grpc::opaque_error_to_grpc_status as handle_error;
 use ffi::{bool_callback, string_callback, void_callback};
 use grpc_clients::identity::protos::authenticated::{
-  InboundKeyInfo, InboundKeysForUserRequest, OutboundKeyInfo,
-  OutboundKeysForUserRequest, UpdateUserPasswordFinishRequest,
+  InboundKeyInfo, InboundKeysForUserRequest, KeyserverKeysResponse,
+  OutboundKeyInfo, OutboundKeysForUserRequest, UpdateUserPasswordFinishRequest,
   UpdateUserPasswordStartRequest, UploadOneTimeKeysRequest,
 };
 use grpc_clients::identity::protos::unauth::{
@@ -150,6 +150,15 @@
       promise_id: u32,
     );
 
+    #[cxx_name = "identityGetKeyserverKeys"]
+    fn get_keyserver_keys(
+      user_id: String,
+      device_id: String,
+      access_token: String,
+      keyserver_id: String,
+      promise_id: u32,
+    );
+
     // Argon2
     #[cxx_name = "compute_backup_key"]
     fn compute_backup_key_str(
@@ -344,6 +353,57 @@
   }
 }
 
+/// Requests the keys for `keyserver_id` on a task spawned on `RUNTIME`,
+/// authenticating with the given credentials, and hands the result and
+/// `promise_id` to `handle_string_result_as_callback`.
+fn get_keyserver_keys(
+  user_id: String,
+  device_id: String,
+  access_token: String,
+  keyserver_id: String,
+  promise_id: u32,
+) {
+  RUNTIME.spawn(async move {
+    let get_keyserver_keys_request = OutboundKeysForUserRequest {
+      user_id: keyserver_id,
+    };
+    let auth_info = AuthInfo {
+      access_token,
+      user_id,
+      device_id,
+    };
+    let result =
+      get_keyserver_keys_helper(get_keyserver_keys_request, auth_info).await;
+    handle_string_result_as_callback(result, promise_id);
+  });
+}
+
+/// Opens an authenticated identity client, calls `get_keyserver_keys`, and
+/// returns the response converted to `OutboundKeyInfoResponse` as a JSON
+/// string.
+async fn get_keyserver_keys_helper(
+  get_keyserver_keys_request: OutboundKeysForUserRequest,
+  auth_info: AuthInfo,
+) -> Result<String, Error> {
+  let mut identity_client = get_auth_client(
+    IDENTITY_SOCKET_ADDR,
+    auth_info.user_id,
+    auth_info.device_id,
+    auth_info.access_token,
+    CODE_VERSION,
+    DEVICE_TYPE.as_str_name().to_lowercase(),
+  )
+  .await?;
+  let response = identity_client
+    .get_keyserver_keys(get_keyserver_keys_request)
+    .await?
+    .into_inner();
+
+  let keyserver_keys = OutboundKeyInfoResponse::try_from(response)?;
+
+  Ok(serde_json::to_string(&keyserver_keys)?)
+}
+
 struct AuthInfo {
   user_id: String,
   device_id: String,
@@ -894,6 +954,17 @@
   }
 }
 
+/// Converts the `keyserver_info` of a `KeyserverKeysResponse`; a response
+/// without it yields `Error::MissingResponseData`.
+impl TryFrom<KeyserverKeysResponse> for OutboundKeyInfoResponse {
+  type Error = Error;
+
+  fn try_from(response: KeyserverKeysResponse) -> Result<Self, Self::Error> {
+    let key_info = response.keyserver_info.ok_or(Error::MissingResponseData)?;
+    Self::try_from(key_info)
+  }
+}
+
 fn get_outbound_keys_for_user(
   auth_user_id: String,
   auth_device_id: String,
diff --git a/native/schema/CommRustModuleSchema.js b/native/schema/CommRustModuleSchema.js
--- a/native/schema/CommRustModuleSchema.js
+++ b/native/schema/CommRustModuleSchema.js
@@ -75,6 +75,12 @@
     contentOneTimePreKeys: $ReadOnlyArray,
     notifOneTimePreKeys: $ReadOnlyArray,
   ) => Promise;
+  +getKeyserverKeys: (
+    authUserID: string,
+    authDeviceID: string,
+    authAccessToken: string,
+    keyserverID: string,
+  ) => Promise<string>;
 }
 
 export default (TurboModuleRegistry.getEnforcing(