diff --git a/native/cpp/CommonCpp/NativeModules/CommRustModule.h b/native/cpp/CommonCpp/NativeModules/CommRustModule.h
--- a/native/cpp/CommonCpp/NativeModules/CommRustModule.h
+++ b/native/cpp/CommonCpp/NativeModules/CommRustModule.h
@@ -65,6 +65,7 @@
       jsi::String identifierType,
       jsi::String identifierValue,
       jsi::String deviceID) override;
+  virtual jsi::Value versionSupported(jsi::Runtime &rt) override;
 
 public:
   CommRustModule(std::shared_ptr<facebook::react::CallInvoker> jsInvoker);
diff --git a/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp b/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp
--- a/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp
+++ b/native/cpp/CommonCpp/NativeModules/CommRustModule.cpp
@@ -247,4 +247,26 @@
       });
 }
 
+jsi::Value CommRustModule::versionSupported(jsi::Runtime &rt) {
+  // Resolves with the bool passed to boolCallback by the Rust
+  // identityVersionSupported task, or rejects with its error text.
+  return createPromiseAsJSIValue(
+      rt, [this](jsi::Runtime &innerRt, std::shared_ptr<Promise> promise) {
+        std::string error;
+        try {
+          auto currentID = RustPromiseManager::instance.addPromise(
+              promise, this->jsInvoker_, innerRt);
+          identityVersionSupported(currentID);
+        } catch (const std::exception &e) {
+          error = e.what();
+        }
+        // Nothing on the Rust side will settle the promise if handing it off
+        // failed, so reject it here rather than leaving JS waiting forever.
+        if (!error.empty()) {
+          this->jsInvoker_->invokeAsync(
+              [error, promise]() { promise->reject(error); });
+        }
+      });
+}
+
 } // namespace comm
diff --git a/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp b/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp
--- a/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp
+++ b/native/cpp/CommonCpp/_generated/rustJSI-generated.cpp
@@ -33,6 +33,9 @@
 static jsi::Value __hostFunction_CommRustModuleSchemaCxxSpecJSI_getOutboundKeysForUserDevice(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
   return static_cast<CommRustModuleSchemaCxxSpecJSI *>(&turboModule)->getOutboundKeysForUserDevice(rt, args[0].asString(rt), args[1].asString(rt), args[2].asString(rt));
 }
+static jsi::Value __hostFunction_CommRustModuleSchemaCxxSpecJSI_versionSupported(jsi::Runtime &rt, TurboModule &turboModule, const jsi::Value* args, size_t count) {
+  return static_cast<CommRustModuleSchemaCxxSpecJSI *>(&turboModule)->versionSupported(rt);
+}
 
 CommRustModuleSchemaCxxSpecJSI::CommRustModuleSchemaCxxSpecJSI(std::shared_ptr<CallInvoker> jsInvoker)
   : TurboModule("CommRustTurboModule", jsInvoker) {
@@ -43,6 +46,7 @@
   methodMap_["updatePassword"] = MethodMetadata {4, __hostFunction_CommRustModuleSchemaCxxSpecJSI_updatePassword};
   methodMap_["deleteUser"] = MethodMetadata {3, __hostFunction_CommRustModuleSchemaCxxSpecJSI_deleteUser};
   methodMap_["getOutboundKeysForUserDevice"] = MethodMetadata {3, __hostFunction_CommRustModuleSchemaCxxSpecJSI_getOutboundKeysForUserDevice};
+  methodMap_["versionSupported"] = MethodMetadata {0, __hostFunction_CommRustModuleSchemaCxxSpecJSI_versionSupported};
 }
 
 
diff --git a/native/cpp/CommonCpp/_generated/rustJSI.h b/native/cpp/CommonCpp/_generated/rustJSI.h
--- a/native/cpp/CommonCpp/_generated/rustJSI.h
+++ b/native/cpp/CommonCpp/_generated/rustJSI.h
@@ -27,6 +27,7 @@
   virtual jsi::Value updatePassword(jsi::Runtime &rt, jsi::String userID, jsi::String deviceID, jsi::String accessToken, jsi::String password) = 0;
   virtual jsi::Value deleteUser(jsi::Runtime &rt, jsi::String userID, jsi::String deviceID, jsi::String accessToken) = 0;
   virtual jsi::Value getOutboundKeysForUserDevice(jsi::Runtime &rt, jsi::String identifierType, jsi::String identifierValue, jsi::String deviceID) = 0;
+  virtual jsi::Value versionSupported(jsi::Runtime &rt) = 0;
 
 };
 
@@ -104,6 +105,14 @@
       return bridging::callFromJs<jsi::Value>(
           rt, &T::getOutboundKeysForUserDevice, jsInvoker_, instance_, std::move(identifierType), std::move(identifierValue), std::move(deviceID));
     }
+    jsi::Value versionSupported(jsi::Runtime &rt) override {
+      static_assert(
+          bridging::getParameterCount(&T::versionSupported) == 1,
+          "Expected versionSupported(...) to have 1 parameters");
+
+      return bridging::callFromJs<jsi::Value>(
+          rt, &T::versionSupported, jsInvoker_, instance_);
+    }
 
   private:
     T *instance_;
diff --git a/native/native_rust_library/RustCallback.h b/native/native_rust_library/RustCallback.h
--- a/native/native_rust_library/RustCallback.h
+++ b/native/native_rust_library/RustCallback.h
@@ -6,5 +6,6 @@
 
 void stringCallback(rust::String error, uint32_t promiseID, rust::String ret);
 void voidCallback(rust::String error, uint32_t promiseID);
+void boolCallback(rust::String error, uint32_t promiseID, bool ret);
 
 } // namespace comm
diff --git a/native/native_rust_library/RustCallback.cpp b/native/native_rust_library/RustCallback.cpp
--- a/native/native_rust_library/RustCallback.cpp
+++ b/native/native_rust_library/RustCallback.cpp
@@ -36,4 +36,19 @@
   }
 }
 
+void boolCallback(rust::String error, uint32_t promiseID, bool ret) {
+  auto it = RustPromiseManager::instance.promises.find(promiseID);
+  if (it == RustPromiseManager::instance.promises.end()) {
+    return;
+  }
+
+  if (error.size()) {
+    RustPromiseManager::instance.rejectPromise(promiseID, std::string(error));
+  } else {
+    folly::dynamic retDyn;
+    retDyn = ret;
+    RustPromiseManager::instance.resolvePromise(promiseID, retDyn);
+  }
+}
+
 } // namespace comm
diff --git a/native/native_rust_library/src/lib.rs b/native/native_rust_library/src/lib.rs
--- a/native/native_rust_library/src/lib.rs
+++ b/native/native_rust_library/src/lib.rs
@@ -1,4 +1,4 @@
-use crate::ffi::{string_callback, void_callback};
+use crate::ffi::{bool_callback, string_callback, void_callback};
 use comm_opaque2::client::{Login, Registration};
 use comm_opaque2::grpc::opaque_error_to_grpc_status as handle_error;
 use grpc_clients::identity::get_unauthenticated_client;
@@ -122,6 +122,9 @@
     #[cxx_name = "identityGenerateNonce"]
     fn generate_nonce(promise_id: u32);
 
+    #[cxx_name = "identityVersionSupported"]
+    fn version_supported(promise_id: u32);
+
     // Argon2
     fn compute_backup_key(password: &str, backup_id: &str) -> Result<[u8; 32]>;
   }
@@ -135,6 +138,10 @@
     #[namespace = "comm"]
     #[cxx_name = "voidCallback"]
     fn void_callback(error: String, promise_id: u32);
+
+    #[namespace = "comm"]
+    #[cxx_name = "boolCallback"]
+    fn bool_callback(error: String, promise_id: u32, ret: bool);
   }
 }
 
@@ -160,6 +167,19 @@
   }
 }
 
+/// Reports `result` to C++ via `bool_callback` for promise `promise_id`:
+/// `Ok(r)` sends `r` with an empty error string (resolve); `Err(e)` sends
+/// `e.to_string()` as the error with `false` as a placeholder value (reject).
+fn handle_bool_result_as_callback<E>(result: Result<bool, E>, promise_id: u32)
+where
+  E: std::fmt::Display,
+{
+  match result {
+    Err(e) => bool_callback(e.to_string(), promise_id, false),
+    Ok(r) => bool_callback("".to_string(), promise_id, r),
+  }
+}
+
 fn generate_nonce(promise_id: u32) {
   RUNTIME.spawn(async move {
     let result = fetch_nonce().await;
@@ -182,6 +202,32 @@
   Ok(nonce)
 }
 
+/// Asks the identity service whether it supports this client's version,
+/// settling JS promise `promise_id` through `bool_callback`.
+fn version_supported(promise_id: u32) {
+  RUNTIME.spawn(async move {
+    let result = version_supported_helper().await;
+    handle_bool_result_as_callback(result, promise_id);
+  });
+}
+
+/// Pings the identity service: `Ok(true)` if the ping succeeds, `Ok(false)` if
+/// it is rejected as an unsupported version, and `Err` for any other failure
+/// (including failing to create the client).
+async fn version_supported_helper() -> Result<bool, Error> {
+  let mut identity_client = get_unauthenticated_client(
+    "http://127.0.0.1:50054",
+    CODE_VERSION,
+    DEVICE_TYPE.as_str_name().to_lowercase(),
+  )
+  .await?;
+  match identity_client.ping(Empty {}).await {
+    Ok(_) => Ok(true),
+    Err(e) if grpc_clients::error::is_version_unsupported(&e) => Ok(false),
+    Err(e) => Err(e.into()),
+  }
+}
+
 struct AuthInfo {
   user_id: String,
   device_id: String,
diff --git a/native/schema/CommRustModuleSchema.js b/native/schema/CommRustModuleSchema.js
--- a/native/schema/CommRustModuleSchema.js
+++ b/native/schema/CommRustModuleSchema.js
@@ -60,6 +60,7 @@
     identifierValue: string,
     deviceID: string,
   ) => Promise<string>;
+  +versionSupported: () => Promise<boolean>;
 }
 
 export default (TurboModuleRegistry.getEnforcing<Spec>(
diff --git a/shared/grpc_clients/src/error.rs b/shared/grpc_clients/src/error.rs
--- a/shared/grpc_clients/src/error.rs
+++ b/shared/grpc_clients/src/error.rs
@@ -18,7 +18,9 @@
   Status::unimplemented("Unsupported version")
 }
 
-pub fn is_version_unsupported(status: Status) -> bool {
+/// Returns `true` if `status` is an `Unimplemented` status whose message is
+/// exactly "Unsupported version".
+pub fn is_version_unsupported(status: &Status) -> bool {
   status.code() == Code::Unimplemented
     && status.message() == "Unsupported version"
 }