diff --git a/lib/types/identity-service-types.js b/lib/types/identity-service-types.js --- a/lib/types/identity-service-types.js +++ b/lib/types/identity-service-types.js @@ -151,9 +151,11 @@ ) => Promise; +getOutboundKeysForUser: ( userID: string, + selectedDeviceIDs?: $ReadOnlyArray, ) => Promise; +getInboundKeysForUser: ( userID: string, + selectedDeviceIDs?: $ReadOnlyArray, ) => Promise; +uploadOneTimeKeys: (oneTimeKeys: OneTimeKeysResultValues) => Promise; +generateNonce: () => Promise; diff --git a/native/cpp/CommonCpp/NativeModules/CommRustModule.h b/native/cpp/CommonCpp/NativeModules/CommRustModule.h --- a/native/cpp/CommonCpp/NativeModules/CommRustModule.h +++ b/native/cpp/CommonCpp/NativeModules/CommRustModule.h @@ -119,13 +119,15 @@ jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, - jsi::String userID) override; + jsi::String userID, + jsi::Array selectedDeviceIDs) override; virtual jsi::Value getInboundKeysForUser( jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, - jsi::String userID) override; + jsi::String userID, + jsi::Array selectedDeviceIDs) override; virtual jsi::Value versionSupported(jsi::Runtime &rt) override; virtual jsi::Value uploadOneTimeKeys( jsi::Runtime &rt, diff --git a/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp b/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp --- a/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp +++ b/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp @@ -504,11 +504,13 @@ jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, - jsi::String userID) { + jsi::String userID, + jsi::Array selectedDeviceIDs) { auto authUserIDRust = jsiStringToRustString(authUserID, rt); auto authDeviceIDRust = jsiStringToRustString(authDeviceID, rt); auto authAccessTokenRust = jsiStringToRustString(authAccessToken, rt); auto userIDRust = jsiStringToRustString(userID, rt); + auto selectedDeviceIDsRust =
jsiStringArrayToRustVec(selectedDeviceIDs, rt); return createPromiseAsJSIValue( rt, [=, this](jsi::Runtime &innerRt, std::shared_ptr promise) { @@ -521,6 +523,7 @@ authDeviceIDRust, authAccessTokenRust, userIDRust, + selectedDeviceIDsRust, currentID); } catch (const std::exception &e) { error = e.what(); @@ -537,11 +540,13 @@ jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, - jsi::String userID) { + jsi::String userID, + jsi::Array selectedDeviceIDs) { auto authUserIDRust = jsiStringToRustString(authUserID, rt); auto authDeviceIDRust = jsiStringToRustString(authDeviceID, rt); auto authAccessTokenRust = jsiStringToRustString(authAccessToken, rt); auto userIDRust = jsiStringToRustString(userID, rt); + auto selectedDeviceIDsRust = jsiStringArrayToRustVec(selectedDeviceIDs, rt); return createPromiseAsJSIValue( rt, [=, this](jsi::Runtime &innerRt, std::shared_ptr promise) { @@ -554,7 +559,8 @@ authDeviceIDRust, authAccessTokenRust, userIDRust, + selectedDeviceIDsRust, currentID); } catch (const std::exception &e) { error = e.what(); }; diff --git a/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp b/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp --- a/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp +++ b/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp @@ -49,10 +49,10 @@ return static_cast(&turboModule)->logOutSecondaryDevice(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt)); } static jsi::Value __hostFunction_CommRustModuleSchemaCxxSpecJSI_getOutboundKeysForUser(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { - return static_cast(&turboModule)->getOutboundKeysForUser(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].asString(rt)); + return static_cast(&turboModule)->getOutboundKeysForUser(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].asString(rt), args[4].asObject(rt).asArray(rt)); } static
jsi::Value __hostFunction_CommRustModuleSchemaCxxSpecJSI_getInboundKeysForUser(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { - return static_cast(&turboModule)->getInboundKeysForUser(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].asString(rt)); + return static_cast(&turboModule)->getInboundKeysForUser(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt), args[3].asString(rt), args[4].asObject(rt).asArray(rt)); } static jsi::Value __hostFunction_CommRustModuleSchemaCxxSpecJSI_versionSupported(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) { return static_cast(&turboModule)->versionSupported(rt); @@ -114,8 +114,8 @@ methodMap_["logOut"] = MethodMetadata {3, __hostFunction_CommRustModuleSchemaCxxSpecJSI_logOut}; methodMap_["logOutPrimaryDevice"] = MethodMetadata {4, __hostFunction_CommRustModuleSchemaCxxSpecJSI_logOutPrimaryDevice}; methodMap_["logOutSecondaryDevice"] = MethodMetadata {3, __hostFunction_CommRustModuleSchemaCxxSpecJSI_logOutSecondaryDevice}; - methodMap_["getOutboundKeysForUser"] = MethodMetadata {4, __hostFunction_CommRustModuleSchemaCxxSpecJSI_getOutboundKeysForUser}; - methodMap_["getInboundKeysForUser"] = MethodMetadata {4, __hostFunction_CommRustModuleSchemaCxxSpecJSI_getInboundKeysForUser}; + methodMap_["getOutboundKeysForUser"] = MethodMetadata {5, __hostFunction_CommRustModuleSchemaCxxSpecJSI_getOutboundKeysForUser}; + methodMap_["getInboundKeysForUser"] = MethodMetadata {5, __hostFunction_CommRustModuleSchemaCxxSpecJSI_getInboundKeysForUser}; methodMap_["versionSupported"] = MethodMetadata {0, __hostFunction_CommRustModuleSchemaCxxSpecJSI_versionSupported}; methodMap_["uploadOneTimeKeys"] = MethodMetadata {5, __hostFunction_CommRustModuleSchemaCxxSpecJSI_uploadOneTimeKeys}; methodMap_["getKeyserverKeys"] = MethodMetadata {4, __hostFunction_CommRustModuleSchemaCxxSpecJSI_getKeyserverKeys}; diff --git
a/native/cpp/CommonCpp/_generated/rustJSI.h b/native/cpp/CommonCpp/_generated/rustJSI.h --- a/native/cpp/CommonCpp/_generated/rustJSI.h +++ b/native/cpp/CommonCpp/_generated/rustJSI.h @@ -32,8 +32,8 @@ virtual jsi::Value logOut(jsi::Runtime &rt, jsi::String userID, jsi::String deviceID, jsi::String accessToken) = 0; virtual jsi::Value logOutPrimaryDevice(jsi::Runtime &rt, jsi::String userID, jsi::String deviceID, jsi::String accessToken, jsi::String signedDeviceList) = 0; virtual jsi::Value logOutSecondaryDevice(jsi::Runtime &rt, jsi::String userID, jsi::String deviceID, jsi::String accessToken) = 0; - virtual jsi::Value getOutboundKeysForUser(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String userID) = 0; - virtual jsi::Value getInboundKeysForUser(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String userID) = 0; + virtual jsi::Value getOutboundKeysForUser(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String userID, jsi::Array selectedDeviceIDs) = 0; + virtual jsi::Value getInboundKeysForUser(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String userID, jsi::Array selectedDeviceIDs) = 0; virtual jsi::Value versionSupported(jsi::Runtime &rt) = 0; virtual jsi::Value uploadOneTimeKeys(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::Array contentOneTimePreKeys, jsi::Array notifOneTimePreKeys) = 0; virtual jsi::Value getKeyserverKeys(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String keyserverID) = 0; @@ -166,21 +166,21 @@ return bridging::callFromJs( rt, &T::logOutSecondaryDevice, jsInvoker_, instance_, std::move(userID), std::move(deviceID), std::move(accessToken)); } - jsi::Value getOutboundKeysForUser(jsi::Runtime &rt, jsi::String
authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String userID) override { + jsi::Value getOutboundKeysForUser(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String userID, jsi::Array selectedDeviceIDs) override { static_assert( - bridging::getParameterCount(&T::getOutboundKeysForUser) == 5, - "Expected getOutboundKeysForUser(...) to have 5 parameters"); + bridging::getParameterCount(&T::getOutboundKeysForUser) == 6, + "Expected getOutboundKeysForUser(...) to have 6 parameters"); return bridging::callFromJs( - rt, &T::getOutboundKeysForUser, jsInvoker_, instance_, std::move(authUserID), std::move(authDeviceID), std::move(authAccessToken), std::move(userID)); + rt, &T::getOutboundKeysForUser, jsInvoker_, instance_, std::move(authUserID), std::move(authDeviceID), std::move(authAccessToken), std::move(userID), std::move(selectedDeviceIDs)); } - jsi::Value getInboundKeysForUser(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String userID) override { + jsi::Value getInboundKeysForUser(jsi::Runtime &rt, jsi::String authUserID, jsi::String authDeviceID, jsi::String authAccessToken, jsi::String userID, jsi::Array selectedDeviceIDs) override { static_assert( - bridging::getParameterCount(&T::getInboundKeysForUser) == 5, - "Expected getInboundKeysForUser(...) to have 5 parameters"); + bridging::getParameterCount(&T::getInboundKeysForUser) == 6, + "Expected getInboundKeysForUser(...)
to have 6 parameters"); return bridging::callFromJs( - rt, &T::getInboundKeysForUser, jsInvoker_, instance_, std::move(authUserID), std::move(authDeviceID), std::move(authAccessToken), std::move(userID)); + rt, &T::getInboundKeysForUser, jsInvoker_, instance_, std::move(authUserID), std::move(authDeviceID), std::move(authAccessToken), std::move(userID), std::move(selectedDeviceIDs)); } jsi::Value versionSupported(jsi::Runtime &rt) override { static_assert( diff --git a/native/identity-service/identity-service-context-provider.react.js b/native/identity-service/identity-service-context-provider.react.js --- a/native/identity-service/identity-service-context-provider.react.js +++ b/native/identity-service/identity-service-context-provider.react.js @@ -224,6 +224,7 @@ }, getOutboundKeysForUser: async ( targetUserID: string, + selectedDeviceIDs: $ReadOnlyArray = [], ): Promise => { const { deviceID: authDeviceID, @@ -236,6 +237,7 @@ authDeviceID, token, targetUserID, + selectedDeviceIDs, ), ); const resultArray = JSON.parse(result); @@ -291,6 +293,7 @@ }, getInboundKeysForUser: async ( targetUserID: string, + selectedDeviceIDs: $ReadOnlyArray = [], ): Promise => { const { deviceID: authDeviceID, @@ -303,6 +306,7 @@ authDeviceID, token, targetUserID, + selectedDeviceIDs, ), ); const resultArray = JSON.parse(result); diff --git a/native/native_rust_library/src/identity/x3dh.rs b/native/native_rust_library/src/identity/x3dh.rs --- a/native/native_rust_library/src/identity/x3dh.rs +++ b/native/native_rust_library/src/identity/x3dh.rs @@ -24,11 +24,14 @@ auth_device_id: String, auth_access_token: String, user_id: String, + selected_devices: Vec, promise_id: u32, ) { RUNTIME.spawn(async move { - let get_outbound_keys_request_info = - GetOutboundKeysRequestInfo { user_id }; + let get_outbound_keys_request_info = GetOutboundKeysRequestInfo { + user_id, + selected_devices, + }; let auth_info = AuthInfo { access_token: auth_access_token, user_id: auth_user_id, @@ -48,10 +51,14
@@ auth_device_id: String, auth_access_token: String, user_id: String, + selected_devices: Vec, promise_id: u32, ) { RUNTIME.spawn(async move { - let get_inbound_keys_request_info = GetInboundKeysRequestInfo { user_id }; + let get_inbound_keys_request_info = GetInboundKeysRequestInfo { + user_id, + selected_devices, + }; let auth_info = AuthInfo { access_token: auth_access_token, user_id: auth_user_id, @@ -152,10 +159,12 @@ struct GetOutboundKeysRequestInfo { user_id: String, + selected_devices: Vec, } struct GetInboundKeysRequestInfo { user_id: String, + selected_devices: Vec, } // This struct should not be altered without also updating diff --git a/native/native_rust_library/src/lib.rs b/native/native_rust_library/src/lib.rs --- a/native/native_rust_library/src/lib.rs +++ b/native/native_rust_library/src/lib.rs @@ -210,6 +210,7 @@ auth_device_id: String, auth_access_token: String, user_id: String, + selected_devices: Vec, promise_id: u32, ); @@ -219,6 +220,7 @@ auth_device_id: String, auth_access_token: String, user_id: String, + selected_devices: Vec, promise_id: u32, ); diff --git a/native/schema/CommRustModuleSchema.js b/native/schema/CommRustModuleSchema.js --- a/native/schema/CommRustModuleSchema.js +++ b/native/schema/CommRustModuleSchema.js @@ -113,12 +113,14 @@ authDeviceID: string, authAccessToken: string, userID: string, + selectedDeviceIDs: $ReadOnlyArray, ) => Promise; +getInboundKeysForUser: ( authUserID: string, authDeviceID: string, authAccessToken: string, userID: string, + selectedDeviceIDs: $ReadOnlyArray, ) => Promise; +versionSupported: () => Promise; +uploadOneTimeKeys: ( diff --git a/web/grpc/identity-service-client-wrapper.js b/web/grpc/identity-service-client-wrapper.js --- a/web/grpc/identity-service-client-wrapper.js +++ b/web/grpc/identity-service-client-wrapper.js @@ -187,7 +187,11 @@ getOutboundKeysForUser: ( userID: string, - ) => Promise = async (userID: string) => { + selectedDeviceIDs?: $ReadOnlyArray, + ) => Promise = async (
userID: string, + selectedDeviceIDs = [], + ) => { const client = this.authClient; if (!client) { throw new Error('Identity service client is not initialized'); @@ -195,6 +199,7 @@ const request = new IdentityAuthStructs.OutboundKeysForUserRequest(); request.setUserId(userID); + request.setSelectedDevicesList([...selectedDeviceIDs]); const response = await client.getOutboundKeysForUser(request); const devicesMap = response.toObject()?.devicesMap; @@ -253,7 +258,11 @@ getInboundKeysForUser: ( userID: string, - ) => Promise = async (userID: string) => { + selectedDeviceIDs?: $ReadOnlyArray, + ) => Promise = async ( userID: string, + selectedDeviceIDs = [], + ) => { const client = this.authClient; if (!client) { throw new Error('Identity service client is not initialized'); @@ -261,6 +270,7 @@ const request = new IdentityAuthStructs.InboundKeysForUserRequest(); request.setUserId(userID); + request.setSelectedDevicesList([...selectedDeviceIDs]); const response = await client.getInboundKeysForUser(request); const devicesMap = response.toObject()?.devicesMap;