diff --git a/native/cpp/CommonCpp/NativeModules/CommRustModule.h b/native/cpp/CommonCpp/NativeModules/CommRustModule.h
--- a/native/cpp/CommonCpp/NativeModules/CommRustModule.h
+++ b/native/cpp/CommonCpp/NativeModules/CommRustModule.h
@@ -98,6 +98,18 @@
       jsi::String authDeviceID,
       jsi::String authAccessToken,
       jsi::String updatePayload) override;
+  virtual jsi::Value uploadSecondaryDeviceKeysAndLogIn(
+      jsi::Runtime &rt,
+      jsi::String userID,
+      jsi::String challengeResponse,
+      jsi::String keyPayload,
+      jsi::String keyPayloadSignature,
+      jsi::String contentPrekey,
+      jsi::String contentPrekeySignature,
+      jsi::String notifPrekey,
+      jsi::String notifPrekeySignature,
+      jsi::Array contentOneTimeKeys,
+      jsi::Array notifOneTimeKeys) override;
 
 public:
   CommRustModule(std::shared_ptr<facebook::react::CallInvoker> jsInvoker);
diff --git a/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp b/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp
--- a/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp
+++ b/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp
@@ -473,4 +473,62 @@
       });
 }
 
+// Converts all JSI arguments to Rust values, then starts the Rust identity
+// call that uploads this secondary device's keys and logs it in. The promise
+// is registered with RustPromiseManager and its ID is passed to Rust
+// (presumably so the Rust side can settle it later); if the call into Rust
+// throws synchronously, the promise is rejected via jsInvoker_ instead.
+jsi::Value CommRustModule::uploadSecondaryDeviceKeysAndLogIn(
+    jsi::Runtime &rt,
+    jsi::String userID,
+    jsi::String challengeResponse,
+    jsi::String keyPayload,
+    jsi::String keyPayloadSignature,
+    jsi::String contentPrekey,
+    jsi::String contentPrekeySignature,
+    jsi::String notifPrekey,
+    jsi::String notifPrekeySignature,
+    jsi::Array contentOneTimeKeys,
+    jsi::Array notifOneTimeKeys) {
+  auto userIDRust = jsiStringToRustString(userID, rt);
+  auto challengeResponseRust = jsiStringToRustString(challengeResponse, rt);
+  auto keyPayloadRust = jsiStringToRustString(keyPayload, rt);
+  auto keyPayloadSignatureRust = jsiStringToRustString(keyPayloadSignature, rt);
+  auto contentPrekeyRust = jsiStringToRustString(contentPrekey, rt);
+  auto contentPrekeySignatureRust =
+      jsiStringToRustString(contentPrekeySignature, rt);
+  auto notifPrekeyRust = jsiStringToRustString(notifPrekey, rt);
+  auto notifPrekeySignatureRust =
+      jsiStringToRustString(notifPrekeySignature, rt);
+  auto contentOneTimeKeysRust = jsiStringArrayToRustVec(contentOneTimeKeys, rt);
+  auto notifOneTimeKeysRust = jsiStringArrayToRustVec(notifOneTimeKeys, rt);
+
+  return createPromiseAsJSIValue(
+      rt, [=, this](jsi::Runtime &innerRt, std::shared_ptr<Promise> promise) {
+        std::string error;
+        try {
+          auto currentID = RustPromiseManager::instance.addPromise(
+              {promise, this->jsInvoker_, innerRt});
+          identityUploadSecondaryDeviceKeysAndLogIn(
+              userIDRust,
+              challengeResponseRust,
+              keyPayloadRust,
+              keyPayloadSignatureRust,
+              contentPrekeyRust,
+              contentPrekeySignatureRust,
+              notifPrekeyRust,
+              notifPrekeySignatureRust,
+              contentOneTimeKeysRust,
+              notifOneTimeKeysRust,
+              currentID);
+        } catch (const std::exception &e) {
+          error = e.what();
+        }
+        if (!error.empty()) {
+          this->jsInvoker_->invokeAsync(
+              [error, promise]() { promise->reject(error); });
+        }
+      });
+}
+
 } // namespace comm
diff --git a/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp b/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp
--- a/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp
+++ b/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp
@@ -51,6 +51,9 @@
 static jsi::Value __hostFunction_CommRustModuleSchemaCxxSpecJSI_updateDeviceList(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
   return static_cast<CommRustModuleSchemaCxxSpecJSI *>(&turboModule)->updateDeviceList(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].asString(rt));
 }
+static jsi::Value __hostFunction_CommRustModuleSchemaCxxSpecJSI_uploadSecondaryDeviceKeysAndLogIn(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
+  return static_cast<CommRustModuleSchemaCxxSpecJSI *>(&turboModule)->uploadSecondaryDeviceKeysAndLogIn(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].asString(rt), args[4].asString(rt), args[5].asString(rt), args[6].asString(rt), args[7].asString(rt), args[8].asObject(rt).asArray(rt), args[9].asObject(rt).asArray(rt));
+}
 
 CommRustModuleSchemaCxxSpecJSI::CommRustModuleSchemaCxxSpecJSI(std::shared_ptr<CallInvoker> jsInvoker)
   : TurboModule("CommRustTurboModule", jsInvoker) {
@@ -67,6 +70,7 @@
   methodMap_["getKeyserverKeys"] = MethodMetadata {4, __hostFunction_CommRustModuleSchemaCxxSpecJSI_getKeyserverKeys};
   methodMap_["getDeviceListForUser"] = MethodMetadata {5, __hostFunction_CommRustModuleSchemaCxxSpecJSI_getDeviceListForUser};
   methodMap_["updateDeviceList"] = MethodMetadata {4, __hostFunction_CommRustModuleSchemaCxxSpecJSI_updateDeviceList};
+  methodMap_["uploadSecondaryDeviceKeysAndLogIn"] = MethodMetadata {10, __hostFunction_CommRustModuleSchemaCxxSpecJSI_uploadSecondaryDeviceKeysAndLogIn};
 }
 
 
diff --git a/native/cpp/CommonCpp/_generated/rustJSI.h b/native/cpp/CommonCpp/_generated/rustJSI.h
--- a/native/cpp/CommonCpp/_generated/rustJSI.h
+++ b/native/cpp/CommonCpp/_generated/rustJSI.h
@@ -33,6 +33,7 @@
   virtual jsi::Value getKeyserverKeys(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String keyserverID) = 0;
   virtual jsi::Value getDeviceListForUser(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String userID, std::optional<double> sinceTimestamp) = 0;
   virtual jsi::Value updateDeviceList(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String updatePayload) = 0;
+  virtual jsi::Value uploadSecondaryDeviceKeysAndLogIn(jsi::Runtime &rt, jsi::String userID, jsi::String challengeResponse, jsi::String keyPayload, jsi::String keyPayloadSignature, jsi::String contentPrekey, jsi::String contentPrekeySignature, jsi::String notifPrekey, jsi::String notifPrekeySignature, jsi::Array contentOneTimeKeys, jsi::Array notifOneTimeKeys) = 0;
 
 };
 
@@ -158,6 +159,14 @@
       return bridging::callFromJs<jsi::Value>(
           rt, &T::updateDeviceList, jsInvoker_, instance_, std::move(authUserID), std::move(authDeviceID), std::move(authAccessToken), std::move(updatePayload));
     }
+    jsi::Value uploadSecondaryDeviceKeysAndLogIn(jsi::Runtime &rt, jsi::String userID, jsi::String challengeResponse, jsi::String keyPayload, jsi::String keyPayloadSignature, jsi::String contentPrekey, jsi::String contentPrekeySignature, jsi::String notifPrekey, jsi::String notifPrekeySignature, jsi::Array contentOneTimeKeys, jsi::Array notifOneTimeKeys) override {
+      static_assert(
+          bridging::getParameterCount(&T::uploadSecondaryDeviceKeysAndLogIn) == 11,
+          "Expected uploadSecondaryDeviceKeysAndLogIn(...) to have 11 parameters");
+
+      return bridging::callFromJs<jsi::Value>(
+          rt, &T::uploadSecondaryDeviceKeysAndLogIn, jsInvoker_, instance_, std::move(userID), std::move(challengeResponse), std::move(keyPayload), std::move(keyPayloadSignature), std::move(contentPrekey), std::move(contentPrekeySignature), std::move(notifPrekey), std::move(notifPrekeySignature), std::move(contentOneTimeKeys), std::move(notifOneTimeKeys));
+    }
 
   private:
     T *instance_;
diff --git a/native/native_rust_library/src/lib.rs b/native/native_rust_library/src/lib.rs
--- a/native/native_rust_library/src/lib.rs
+++ b/native/native_rust_library/src/lib.rs
@@ -12,9 +12,10 @@
   UploadOneTimeKeysRequest,
 };
 use grpc_clients::identity::protos::unauth::{
-  DeviceKeyUpload, DeviceType, Empty, IdentityKeyInfo,
+  AuthResponse, DeviceKeyUpload, DeviceType, Empty, IdentityKeyInfo,
   OpaqueLoginFinishRequest, OpaqueLoginStartRequest, Prekey,
-  RegistrationFinishRequest, RegistrationStartRequest, WalletLoginRequest,
+  RegistrationFinishRequest, RegistrationStartRequest,
+  SecondaryDeviceKeysUploadRequest, WalletLoginRequest,
 };
 use grpc_clients::identity::{
   get_auth_client, get_unauthenticated_client, REQUEST_METADATA_COOKIE_KEY,
@@ -193,6 +194,21 @@
       promise_id: u32,
     );
 
+    #[cxx_name = "identityUploadSecondaryDeviceKeysAndLogIn"]
+    fn upload_secondary_device_keys_and_log_in(
+      user_id: String,
+      challenge_response: String,
+      key_payload: String,
+      key_payload_signature: String,
+      content_prekey: String,
+      content_prekey_signature: String,
+      notif_prekey: String,
+      notif_prekey_signature: String,
+      content_one_time_keys: Vec<String>,
+      notif_one_time_keys: Vec<String>,
+      promise_id: u32,
+    );
+
     // Argon2
     #[cxx_name = "compute_backup_key"]
     fn compute_backup_key_str(
@@ -525,6 +541,23 @@
   access_token: String,
 }
 
+/// Converts the identity service's `AuthResponse` into the token pair that
+/// the login/registration helpers serialize to JSON. The exhaustive
+/// destructuring (no `..`) turns any new `AuthResponse` field into a compile
+/// error here instead of silently dropping it.
+impl From<AuthResponse> for UserIDAndDeviceAccessToken {
+  fn from(value: AuthResponse) -> Self {
+    let AuthResponse {
+      user_id,
+      access_token,
+    } = value;
+    Self {
+      user_id,
+      access_token,
+    }
+  }
+}
+
 async fn register_user_helper(
   password_user_info: PasswordUserInfo,
 ) -> Result<String, Error> {
@@ -600,10 +633,8 @@
     .register_password_user_finish(finish_request)
     .await?
     .into_inner();
-  let user_id_and_access_token = UserIDAndDeviceAccessToken {
-    user_id: registration_finish_response.user_id,
-    access_token: registration_finish_response.access_token,
-  };
+  let user_id_and_access_token =
+    UserIDAndDeviceAccessToken::from(registration_finish_response);
   Ok(serde_json::to_string(&user_id_and_access_token)?)
 }
 
@@ -712,10 +743,8 @@
     .log_in_password_user_finish(finish_request)
     .await?
     .into_inner();
-  let user_id_and_access_token = UserIDAndDeviceAccessToken {
-    user_id: login_finish_response.user_id,
-    access_token: login_finish_response.access_token,
-  };
+  let user_id_and_access_token =
+    UserIDAndDeviceAccessToken::from(login_finish_response);
   Ok(serde_json::to_string(&user_id_and_access_token)?)
 }
 
@@ -802,10 +831,8 @@
     .await?
     .into_inner();
 
-  let user_id_and_access_token = UserIDAndDeviceAccessToken {
-    user_id: login_response.user_id,
-    access_token: login_response.access_token,
-  };
+  let user_id_and_access_token =
+    UserIDAndDeviceAccessToken::from(login_response);
   Ok(serde_json::to_string(&user_id_and_access_token)?)
 }
 
@@ -1353,6 +1380,84 @@
   Ok(())
 }
 
+/// Builds a `DeviceKeyUpload` from the given identity key payload, prekeys
+/// and one-time keys, then, on `RUNTIME`, uploads it for `user_id` and logs
+/// the device in. The result (a JSON string on success) is delivered to the
+/// JS promise identified by `promise_id` through
+/// `handle_string_result_as_callback`.
+fn upload_secondary_device_keys_and_log_in(
+  user_id: String,
+  challenge_response: String,
+  key_payload: String,
+  key_payload_signature: String,
+  content_prekey: String,
+  content_prekey_signature: String,
+  notif_prekey: String,
+  notif_prekey_signature: String,
+  content_one_time_keys: Vec<String>,
+  notif_one_time_keys: Vec<String>,
+  promise_id: u32,
+) {
+  RUNTIME.spawn(async move {
+    let device_key_upload = DeviceKeyUpload {
+      device_key_info: Some(IdentityKeyInfo {
+        payload: key_payload,
+        payload_signature: key_payload_signature,
+        social_proof: None,
+      }),
+      content_upload: Some(Prekey {
+        prekey: content_prekey,
+        prekey_signature: content_prekey_signature,
+      }),
+      notif_upload: Some(Prekey {
+        prekey: notif_prekey,
+        prekey_signature: notif_prekey_signature,
+      }),
+      one_time_content_prekeys: content_one_time_keys,
+      one_time_notif_prekeys: notif_one_time_keys,
+      device_type: DEVICE_TYPE.into(),
+    };
+
+    let result = upload_secondary_device_keys_and_log_in_helper(
+      user_id,
+      challenge_response,
+      device_key_upload,
+    )
+    .await;
+    handle_string_result_as_callback(result, promise_id);
+  });
+}
+
+/// Calls the identity service's `upload_keys_for_registered_device_and_log_in`
+/// RPC over an unauthenticated client and returns the resulting user ID and
+/// access token serialized as JSON.
+async fn upload_secondary_device_keys_and_log_in_helper(
+  user_id: String,
+  challenge_response: String,
+  device_key_upload: DeviceKeyUpload,
+) -> Result<String, Error> {
+  let mut identity_client = get_unauthenticated_client(
+    IDENTITY_SOCKET_ADDR,
+    CODE_VERSION,
+    DEVICE_TYPE.as_str_name().to_lowercase(),
+  )
+  .await?;
+
+  let request = SecondaryDeviceKeysUploadRequest {
+    user_id,
+    challenge_response,
+    device_key_upload: Some(device_key_upload),
+  };
+
+  let response = identity_client
+    .upload_keys_for_registered_device_and_log_in(request)
+    .await?
+    .into_inner();
+
+  let user_id_and_access_token = UserIDAndDeviceAccessToken::from(response);
+  Ok(serde_json::to_string(&user_id_and_access_token)?)
+}
+
 #[derive(
   Debug, derive_more::Display, derive_more::From, derive_more::Error,
 )]
diff --git a/native/schema/CommRustModuleSchema.js b/native/schema/CommRustModuleSchema.js
--- a/native/schema/CommRustModuleSchema.js
+++ b/native/schema/CommRustModuleSchema.js
@@ -93,6 +93,19 @@
     authAccessToken: string,
     updatePayload: string,
   ) => Promise<void>;
+  // Resolves with the JSON-serialized user ID and access token.
+  +uploadSecondaryDeviceKeysAndLogIn: (
+    userID: string,
+    challengeResponse: string,
+    keyPayload: string,
+    keyPayloadSignature: string,
+    contentPrekey: string,
+    contentPrekeySignature: string,
+    notifPrekey: string,
+    notifPrekeySignature: string,
+    contentOneTimeKeys: $ReadOnlyArray<string>,
+    notifOneTimeKeys: $ReadOnlyArray<string>,
+  ) => Promise<string>;
 }
 
 export default (TurboModuleRegistry.getEnforcing<Spec>(