diff --git a/lib/types/message-types.js b/lib/types/message-types.js
--- a/lib/types/message-types.js
+++ b/lib/types/message-types.js
@@ -452,6 +452,18 @@
   +payload: ClientDBMessageInfo,
 };
 
+export type ClientDBThreadMessageInfo = { // Every field is typed as a string; presumably the element type of `threads` below (that type argument is missing in this copy) - confirm.
+  +id: string,
+  +start_reached: string, // The native ReplaceMessageThreadsOperation parses this property with std::stoi.
+  +last_navigated_to: string, // Parsed natively with std::stoll.
+  +last_pruned: string, // Parsed natively with std::stoll.
+};
+
+export type ClientDBReplaceThreadsOperation = {
+  +type: 'replace_threads', // Same literal as REPLACE_MESSAGE_THREADS_OPERATION in CommCoreModule.cpp.
+  +payload: { +threads: $ReadOnlyArray }, // NOTE(review): $ReadOnlyArray has no type argument in this copy of the patch - looks stripped; restore it from the original before applying.
+};
+
 export type MessageStoreOperation =
   | RemoveMessageOperation
   | ReplaceMessageOperation
@@ -467,7 +479,10 @@
   | ClientDBReplaceMessageOperation
   | RekeyMessageOperation
   | RemoveMessagesForThreadsOperation
-  | RemoveAllMessagesOperation;
+  | RemoveAllMessagesOperation
+  | ClientDBReplaceThreadsOperation
+  | RemoveMessageStoreThreadsOperation // NOTE(review): not declared in the hunks visible here - confirm it is defined elsewhere in this file.
+  | RemoveMessageStoreAllThreadsOperation; // NOTE(review): not declared in the hunks visible here either - confirm.
 
 export const messageTruncationStatus = Object.freeze({
   // EXHAUSTIVE means we've reached the start of the thread. Either the result
diff --git a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
--- a/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
+++ b/native/cpp/CommonCpp/NativeModules/CommCoreModule.cpp
@@ -436,6 +436,10 @@
     "remove_messages_for_threads";
 const std::string REMOVE_ALL_OPERATION = "remove_all";
 
+const std::string REPLACE_MESSAGE_THREADS_OPERATION = "replace_threads"; // Compared against op_type in createMessageStoreOperations below.
+const std::string REMOVE_MESSAGE_THREADS_OPERATION = "remove_threads"; // Compared against op_type; this branch reads the payload object.
+const std::string REMOVE_ALL_MESSAGE_THREADS_OPERATION = "remove_all_threads"; // Compared against op_type before the payload is read.
+
 std::vector>
 createMessageStoreOperations(jsi::Runtime &rt, const jsi::Array &operations) {
 
@@ -449,6 +453,11 @@
       messageStoreOps.push_back(std::make_unique());
       continue;
     }
+    if (op_type == REMOVE_ALL_MESSAGE_THREADS_OPERATION) { // Checked ahead of the "payload" lookup below, so no payload is read for this operation.
+      messageStoreOps.push_back(
+          std::make_unique()); // NOTE(review): template argument missing in this copy - presumably RemoveAllMessageStoreThreadsOperation; confirm.
+      continue;
+    }
 
     auto payload_obj = op.getProperty(rt, "payload").asObject(rt);
     if (op_type == REMOVE_OPERATION) {
@@ -467,6 +476,13 @@
       messageStoreOps.push_back(
           std::make_unique(rt, payload_obj));
+    } else if (op_type == REPLACE_MESSAGE_THREADS_OPERATION) { // "replace_threads": the operation is constructed from (rt, payload_obj).
+      messageStoreOps.push_back(
+          std::make_unique(rt, payload_obj)); // NOTE(review): template argument missing in this copy - presumably ReplaceMessageThreadsOperation; confirm.
+    } else if (op_type == REMOVE_MESSAGE_THREADS_OPERATION) { // "remove_threads": the operation is constructed from (rt, payload_obj).
+      messageStoreOps.push_back(
+          std::make_unique( // NOTE(review): template argument missing in this copy - presumably RemoveMessageStoreThreadsOperation; confirm.
+              rt, payload_obj));
     } else {
       throw std::runtime_error("unsupported operation: " + op_type);
     }
 
diff --git a/native/cpp/CommonCpp/NativeModules/MessageStoreOperations.h b/native/cpp/CommonCpp/NativeModules/MessageStoreOperations.h
--- a/native/cpp/CommonCpp/NativeModules/MessageStoreOperations.h
+++ b/native/cpp/CommonCpp/NativeModules/MessageStoreOperations.h
@@ -159,4 +159,64 @@
   }
 };
 
+class ReplaceMessageThreadsOperation : public MessageStoreOperationBase { // Converts payload.threads into MessageStoreThread values at construction; execute() passes them to replaceMessageStoreThreads().
+public:
+  ReplaceMessageThreadsOperation(jsi::Runtime &rt, const jsi::Object &payload)
+      : msg_threads{} {
+    auto threads = payload.getProperty(rt, "threads").asObject(rt).asArray(rt); // "threads" is read as a JS array.
+    for (size_t idx = 0; idx < threads.size(rt); idx++) {
+      auto thread = threads.getValueAtIndex(rt, idx).asObject(rt);
+
+      auto thread_id = thread.getProperty(rt, "id").asString(rt).utf8(rt);
+      auto start_reached = std::stoi( // Each numeric property arrives as a JS string and is parsed here; std::stoi/std::stoll throw on non-numeric or out-of-range text and nothing in this constructor catches that.
+          thread.getProperty(rt, "start_reached").asString(rt).utf8(rt));
+      auto last_navigated_to = std::stoll(
+          thread.getProperty(rt, "last_navigated_to").asString(rt).utf8(rt));
+      auto last_pruned = std::stoll(
+          thread.getProperty(rt, "last_pruned").asString(rt).utf8(rt));
+
+      MessageStoreThread msg_thread = MessageStoreThread{ // Brace-initialized in the order id, start_reached, last_navigated_to, last_pruned; NOTE(review): MessageStoreThread is declared elsewhere - confirm its member order matches.
+          thread_id, start_reached, last_navigated_to, last_pruned};
+      this->msg_threads.push_back(msg_thread); // Stores a copy of msg_thread in the vector.
+    }
+  }
+
+  virtual void execute() override { // Hands every parsed thread to the query executor in a single call.
+    DatabaseManager::getQueryExecutor().replaceMessageStoreThreads(
+        this->msg_threads);
+  }
+
+private:
+  std::vector msg_threads; // NOTE(review): element type missing in this copy - presumably MessageStoreThread; confirm.
+};
+
+class RemoveAllMessageStoreThreadsOperation : public MessageStoreOperationBase { // Declares no constructor and takes no payload; execute() calls removeAllMessageStoreThreads().
+public:
+  virtual void execute() override {
+    DatabaseManager::getQueryExecutor().removeAllMessageStoreThreads();
+  }
+};
+
+class RemoveMessageStoreThreadsOperation : public MessageStoreOperationBase { // Collects payload.ids at construction; execute() passes them to removeMessageStoreThreads().
+public:
+  RemoveMessageStoreThreadsOperation(
+      jsi::Runtime &rt,
+      const jsi::Object &payload)
+      : thread_ids{} {
+    auto payload_ids = payload.getProperty(rt, "ids").asObject(rt).asArray(rt); // "ids" is read as a JS array.
+    for (size_t idx = 0; idx < payload_ids.size(rt); idx++) {
+      this->thread_ids.push_back( // Each element is read with asString(rt) and stored as UTF-8.
+          payload_ids.getValueAtIndex(rt, idx).asString(rt).utf8(rt));
+    }
+  }
+
+  virtual void execute() override { // Hands the collected ids to the query executor in a single call.
+    DatabaseManager::getQueryExecutor().removeMessageStoreThreads(
+        this->thread_ids);
+  }
+
+private:
+  std::vector thread_ids; // NOTE(review): element type missing in this copy - presumably std::string; confirm.
+};
+
 } // namespace comm