diff --git a/keyserver/src/database/migration-config.js b/keyserver/src/database/migration-config.js --- a/keyserver/src/database/migration-config.js +++ b/keyserver/src/database/migration-config.js @@ -63,6 +63,17 @@ ); }, ], + [ + 8, + async () => { + await dbQuery( + SQL` + ALTER TABLE users + ADD COLUMN IF NOT EXISTS ethereum_address char(42) DEFAULT NULL; + `, + ); + }, + ], ]); const newDatabaseVersion: number = Math.max(...migrations.keys()); diff --git a/keyserver/src/database/setup-db.js b/keyserver/src/database/setup-db.js --- a/keyserver/src/database/setup-db.js +++ b/keyserver/src/database/setup-db.js @@ -185,6 +185,7 @@ username varchar(${usernameMaxLength}) COLLATE utf8mb4_bin NOT NULL, hash char(60) COLLATE utf8mb4_bin DEFAULT NULL, avatar varchar(191) COLLATE utf8mb4_bin DEFAULT NULL, + ethereum_address char(42) DEFAULT NULL, creation_time bigint(20) NOT NULL ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; diff --git a/keyserver/src/push/rescind.js b/keyserver/src/push/rescind.js --- a/keyserver/src/push/rescind.js +++ b/keyserver/src/push/rescind.js @@ -170,7 +170,6 @@ ): apn.Notification { const notification = new apn.Notification(); notification.contentAvailable = true; - notification.badge = unreadCount; notification.topic = getAPNsNotificationTopic(codeVersion); notification.priority = 5; notification.pushType = 'background'; diff --git a/landing/team.react.js b/landing/team.react.js --- a/landing/team.react.js +++ b/landing/team.react.js @@ -140,6 +140,11 @@ githubHandle="pweglik" imageURL={`${assetsCacheURLPrefix}/przemek.jpg`} /> + ); diff --git a/native/account/logged-out-modal.react.js b/native/account/logged-out-modal.react.js --- a/native/account/logged-out-modal.react.js +++ b/native/account/logged-out-modal.react.js @@ -2,7 +2,6 @@ import _isEqual from 'lodash/fp/isEqual'; import * as React from 'react'; -import { useContext } from 'react'; import { View, StyleSheet, @@ -42,14 +41,12 @@ derivedDimensionsInfoSelector, } from '../selectors/dimensions-selectors'; import { splashStyleSelector } from '../splash'; -import { StaffContext } from '../staff/staff-context'; import type { EventSubscription, KeyboardEvent } from '../types/react-native'; import type { ImageStyle } from '../types/styles'; import { runTiming, ratchetAlongWithKeyboardHeight, } from '../utils/animation-utils'; -import { useStaffCanSee } from '../utils/staff-utils'; import { type StateContainer, type StateChange, @@ -58,6 +55,7 @@ import { splashBackgroundURI } from './background-info'; import LogInPanel from './log-in-panel.react'; import type { LogInState } from './log-in-panel.react'; +import LoggedOutStaffInfo from './logged-out-staff-info.react'; import RegisterPanel from './register-panel.react'; import type { RegisterState } from './register-panel.react'; import SIWEPanel from './siwe-panel.react'; @@ -117,8 +115,6 @@ +splashStyle: ImageStyle, // Redux dispatch functions +dispatch: Dispatch, - +staffUserHasBeenLoggedIn: boolean, - +staffCanSee: boolean, ... }; type State = { @@ -468,14 +464,6 @@ ); } - let staffUserHasBeenLoggedInIndicator = null; - if (this.props.staffCanSee && this.props.staffUserHasBeenLoggedIn) { - staffUserHasBeenLoggedInIndicator = ( - - STAFF HAS BEEN LOGGED IN - - ); - } if (this.state.mode === 'siwe') { panel = ( @@ -504,7 +492,7 @@ const opacityStyle = { opacity: this.buttonOpacity }; buttons = ( - {staffUserHasBeenLoggedInIndicator} + {siweButton} ); }); diff --git a/native/account/logged-out-staff-info.react.js b/native/account/logged-out-staff-info.react.js new file mode 100644 --- /dev/null +++ b/native/account/logged-out-staff-info.react.js @@ -0,0 +1,121 @@ +// @flow + +import * as React from 'react'; +import { Text, View } from 'react-native'; + +import SWMansionIcon from '../components/swmansion-icon.react'; +import { StaffContext } from '../staff/staff-context'; +import { useStyles, useColors } from '../themes/colors'; +import { isStaffRelease, useStaffCanSee } from '../utils/staff-utils'; + +function LoggedOutStaffInfo(): React.Node { + const staffCanSee = useStaffCanSee(); + const { staffUserHasBeenLoggedIn } = React.useContext(StaffContext); + const styles = useStyles(unboundStyles); + const colors = useColors(); + + const checkIcon = React.useMemo( + () => ( + + ), + [colors.vibrantGreenButton], + ); + const crossIcon = React.useMemo( + () => ( + + ), + [colors.vibrantRedButton], + ); + + const isDevBuildStyle = React.useMemo(() => { + return [ + styles.infoText, + __DEV__ ? styles.infoTextTrue : styles.infoTextFalse, + ]; + }, [styles.infoText, styles.infoTextFalse, styles.infoTextTrue]); + + const isStaffReleaseStyle = React.useMemo(() => { + return [ + styles.infoText, + isStaffRelease ? styles.infoTextTrue : styles.infoTextFalse, + ]; + }, [styles.infoText, styles.infoTextFalse, styles.infoTextTrue]); + + const hasStaffUserLoggedInStyle = React.useMemo(() => { + return [ + styles.infoText, + staffUserHasBeenLoggedIn ? styles.infoTextTrue : styles.infoTextFalse, + ]; + }, [ + staffUserHasBeenLoggedIn, + styles.infoText, + styles.infoTextFalse, + styles.infoTextTrue, + ]); + + let loggedOutStaffInfo = null; + if (staffCanSee || staffUserHasBeenLoggedIn) { + loggedOutStaffInfo = ( + + + + {__DEV__ ? checkIcon : crossIcon} + __DEV__ + + + {isStaffRelease ? checkIcon : crossIcon} + isStaffRelease + + + {staffUserHasBeenLoggedIn ? checkIcon : crossIcon} + + staffUserHasBeenLoggedIn + + + + + ); + } + + return loggedOutStaffInfo; +} + +const unboundStyles = { + cell: { + flexDirection: 'row', + alignItems: 'center', + }, + infoBadge: { + backgroundColor: 'codeBackground', + borderRadius: 6, + justifyContent: 'flex-start', + marginBottom: 10, + marginLeft: 40, + marginRight: 40, + marginTop: 10, + padding: 8, + }, + infoText: { + fontFamily: 'OpenSans-Semibold', + fontSize: 14, + lineHeight: 24, + paddingLeft: 4, + textAlign: 'left', + }, + infoTextFalse: { + color: 'vibrantRedButton', + }, + infoTextTrue: { + color: 'vibrantGreenButton', + }, +}; + +export default LoggedOutStaffInfo; diff --git a/native/chat/settings/color-selector-modal.react.js b/native/chat/settings/color-selector-modal.react.js --- a/native/chat/settings/color-selector-modal.react.js +++ b/native/chat/settings/color-selector-modal.react.js @@ -151,8 +151,6 @@ }, colorSelectorContainer: { backgroundColor: 'modalBackground', - borderColor: 'modalForegroundBorder', - borderRadius: 5, borderWidth: 2, flex: 0, marginHorizontal: 15, diff --git a/native/components/modal.react.js b/native/components/modal.react.js --- a/native/components/modal.react.js +++ b/native/components/modal.react.js @@ -48,9 +48,8 @@ }, modal: { backgroundColor: 'modalBackground', - shadowColor: 'gray', - shadowOpacity: 100, - shadowRadius: 6, + borderColor: 'modalForegroundBorder', + borderWidth: 2, borderRadius: 5, flex: 1, justifyContent: 'center', diff --git a/native/data/sqlite-context-provider.js b/native/data/sqlite-context-provider.js --- a/native/data/sqlite-context-provider.js +++ b/native/data/sqlite-context-provider.js @@ -7,6 +7,7 @@ import { setMessageStoreMessages } from 'lib/actions/message-actions.js'; import { setThreadStoreActionType } from 'lib/actions/thread-actions'; +import { isLoggedIn } from 'lib/selectors/user-selectors'; import { logInActionSources } from 'lib/types/account-types'; import { fetchNewCookieFromNativeCredentials } from 'lib/utils/action-utils'; import { getMessageForException } from 'lib/utils/errors'; @@ -31,14 +32,22 @@ const cookie = useSelector(state => state.cookie); const urlPrefix = useSelector(state => state.urlPrefix); const staffCanSee = useStaffCanSee(); + const loggedIn = useSelector(isLoggedIn); React.useEffect(() => { if (storeLoaded || !rehydrateConcluded) { return; } + if (!loggedIn) { + setStoreLoaded(true); + return; + } (async () => { try { - const threads = await commCoreModule.getAllThreads(); + const [threads, messages] = await Promise.all([ + commCoreModule.getAllThreads(), + commCoreModule.getAllMessages(), + ]); const threadInfosFromDB = convertClientDBThreadInfosToRawThreadInfos( threads, ); @@ -46,7 +55,6 @@ type: setThreadStoreActionType, payload: { threadInfos: threadInfosFromDB }, }); - const messages = await commCoreModule.getAllMessages(); dispatch({ type: setMessageStoreMessages, payload: messages, @@ -84,6 +92,7 @@ } })(); }, [ + loggedIn, cookie, dispatch, rehydrateConcluded, diff --git a/native/flow-typed/npm/react-native-reanimated_v2.x.x.js b/native/flow-typed/npm/react-native-reanimated_v2.x.x.js --- a/native/flow-typed/npm/react-native-reanimated_v2.x.x.js +++ b/native/flow-typed/npm/react-native-reanimated_v2.x.x.js @@ -14,6 +14,26 @@ */ declare module 'react-native-reanimated' { + // This was taken from the flow typed library definitions of bottom-tabs_v6 + declare type StyleObj = + | null + | void + | number + | false + | '' + | $ReadOnlyArray + | { [name: string]: any, ... }; + + declare type ViewStyleProp = StyleObj; + declare type TextStyleProp = StyleObj; + + declare type StyleProps = {| + ...ViewStyleProp, + ...TextStyleProp, + +originX?: number, + +originY?: number, + +[key: string]: any, + |}; declare class Node { } @@ -123,6 +143,7 @@ ... }; declare export var EasingNode: EasingModule; + declare type EasingFn = (t: number) => number; declare export type TimingState = { +finished: Value, @@ -213,6 +234,165 @@ ): Node, |}; + declare type LayoutAnimation = {| + +initialValues: StyleProps, + +animations: StyleProps, + +callback?: (finished: boolean) => void, + |}; + + declare type AnimationFunction = (a?: any, b?: any, c?: any) => any; + + declare type EntryAnimationsValues = {| + +targetOriginX: number, + +targetOriginY: number, + +targetWidth: number, + +targetHeight: number, + +targetGlobalOriginX: number, + +targetGlobalOriginY: number, + |}; + + declare type ExitAnimationsValues = {| + +currentOriginX: number, + +currentOriginY: number, + +currentWidth: number, + +currentHeight: number, + +currentGlobalOriginX: number, + +currentGlobalOriginY: number, + |}; + + declare export type EntryExitAnimationFunction = ( + targetValues: EntryAnimationsValues | ExitAnimationsValues, + ) => LayoutAnimation; + + declare type AnimationConfigFunction = ( + targetValues: T, + ) => LayoutAnimation; + + declare type LayoutAnimationsValues = {| + +currentOriginX: number, + +currentOriginY: number, + +currentWidth: number, + +currentHeight: number, + +currentGlobalOriginX: number, + +currentGlobalOriginY: number, + +targetOriginX: number, + +targetOriginY: number, + +targetWidth: number, + +targetHeight: number, + +targetGlobalOriginX: number, + +argetGlobalOriginY: number, + +windowWidth: number, + +windowHeight: number, + |}; + + declare type LayoutAnimationFunction = ( + targetValues: LayoutAnimationsValues, + ) => LayoutAnimation; + + declare type BaseLayoutAnimationConfig = {| + +duration?: number, + +easing?: EasingFn, + +type?: AnimationFunction, + +damping?: number, + +mass?: number, + +stiffness?: number, + +overshootClamping?: number, + +restDisplacementThreshold?: number, + +restSpeedThreshold?: number, + |}; + + declare type BaseBuilderAnimationConfig = {| + ...BaseLayoutAnimationConfig, + rotate?: number | string, + |}; + + declare type LayoutAnimationAndConfig = [ + AnimationFunction, + BaseBuilderAnimationConfig, + ]; + + declare export class BaseAnimationBuilder { + static duration(durationMs: number): BaseAnimationBuilder; + duration(durationMs: number): BaseAnimationBuilder; + + static delay(delayMs: number): BaseAnimationBuilder; + delay(delayMs: number): BaseAnimationBuilder; + + static withCallback( + callback: (finished: boolean) => void, + ): BaseAnimationBuilder; + withCallback(callback: (finished: boolean) => void): BaseAnimationBuilder; + + static getDuration(): number; + getDuration(): number; + + static randomDelay(): BaseAnimationBuilder; + randomDelay(): BaseAnimationBuilder; + + getDelay(): number; + getDelayFunction(): AnimationFunction; + + static build(): EntryExitAnimationFunction | LayoutAnimationFunction; + } + + declare export type ReanimatedAnimationBuilder = + | Class + | BaseAnimationBuilder; + + declare export class ComplexAnimationBuilder extends BaseAnimationBuilder { + static easing(easingFunction: EasingFn): ComplexAnimationBuilder; + easing(easingFunction: EasingFn): ComplexAnimationBuilder; + + static rotate(degree: string): ComplexAnimationBuilder; + rotate(degree: string): ComplexAnimationBuilder; + + static springify(): ComplexAnimationBuilder; + springify(): ComplexAnimationBuilder; + + static damping(damping: number): ComplexAnimationBuilder; + damping(damping: number): ComplexAnimationBuilder; + + static mass(mass: number): ComplexAnimationBuilder; + mass(mass: number): ComplexAnimationBuilder; + + static stiffness(stiffness: number): ComplexAnimationBuilder; + stiffness(stiffness: number): ComplexAnimationBuilder; + + static overshootClamping( + overshootClamping: number, + ): ComplexAnimationBuilder; + overshootClamping(overshootClamping: number): ComplexAnimationBuilder; + + static restDisplacementThreshold( + restDisplacementThreshold: number, + ): ComplexAnimationBuilder; + restDisplacementThreshold( + restDisplacementThreshold: number, + ): ComplexAnimationBuilder; + + static restSpeedThreshold( + restSpeedThreshold: number, + ): ComplexAnimationBuilder; + restSpeedThreshold(restSpeedThreshold: number): ComplexAnimationBuilder; + + static withInitialValues(values: StyleProps): BaseAnimationBuilder; + withInitialValues(values: StyleProps): BaseAnimationBuilder; + + getAnimationAndConfig(): LayoutAnimationAndConfig; + } + + declare export class SlideInDown extends ComplexAnimationBuilder { + static createInstance(): SlideInDown; + + build(): AnimationConfigFunction; + } + + declare export class SlideOutDown extends ComplexAnimationBuilder { + static createInstance(): SlideOutDown; + + build(): AnimationConfigFunction; + } + declare type $SyntheticEvent = { +nativeEvent: $ReadOnly<$Exact>, ... diff --git a/native/profile/build-info.react.js b/native/profile/build-info.react.js --- a/native/profile/build-info.react.js +++ b/native/profile/build-info.react.js @@ -2,13 +2,55 @@ import * as React from 'react'; import { View, Text, ScrollView } from 'react-native'; +import { useSelector } from 'react-redux'; + +import { isStaff } from 'lib/shared/user-utils'; import { persistConfig, codeVersion } from '../redux/persist'; +import { StaffContext } from '../staff/staff-context'; import { useStyles } from '../themes/colors'; +import { isStaffRelease, useStaffCanSee } from '../utils/staff-utils'; // eslint-disable-next-line no-unused-vars function BuildInfo(props: { ... }): React.Node { + const isCurrentUserStaff = useSelector( + state => + state.currentUserInfo && + state.currentUserInfo.id && + isStaff(state.currentUserInfo.id), + ); + const { staffUserHasBeenLoggedIn } = React.useContext(StaffContext); const styles = useStyles(unboundStyles); + const staffCanSee = useStaffCanSee(); + + let staffCanSeeRows; + if (staffCanSee || staffUserHasBeenLoggedIn) { + staffCanSeeRows = ( + <> + + __DEV__ + {__DEV__ ? 'TRUE' : 'FALSE'} + + + Staff Release + {isStaffRelease ? 'TRUE' : 'FALSE'} + + + isCurrentUserStaff + + {isCurrentUserStaff ? 'TRUE' : 'FALSE'} + + + + hasStaffUserLoggedIn + + {staffUserHasBeenLoggedIn ? 'TRUE' : 'FALSE'} + + + + ); + } + return ( State version {persistConfig.version} + {staffCanSeeRows} diff --git a/native/redux/persist.js b/native/redux/persist.js --- a/native/redux/persist.js +++ b/native/redux/persist.js @@ -18,7 +18,6 @@ type MessageStore, messageTypes, type ClientDBMessageStoreOperation, - type RawMessageInfo, } from 'lib/types/message-types'; import { defaultConnectionInfo } from 'lib/types/socket-types'; import { translateRawMessageInfoToClientDBMessageInfo } from 'lib/utils/message-ops-utils'; @@ -394,10 +393,7 @@ +currentAsOf: number, +threads: { +[threadID: string]: PersistedThreadMessageInfo }, }; -type RehydratedMessageStore = $Diff< - MessageStore, - { +messages: { +[id: string]: RawMessageInfo } }, ->; + const messageStoreMessagesBlocklistTransform: Transform = createTransform( (state: MessageStore): PersistedMessageStore => { const { messages, threads, ...messageStoreSansMessages } = state; @@ -411,13 +407,13 @@ } return { ...messageStoreSansMessages, threads: threadsToPersist }; }, - (state: PersistedMessageStore): RehydratedMessageStore => { + (state: PersistedMessageStore): MessageStore => { const { threads: persistedThreads, ...messageStore } = state; const threads = {}; for (const threadID in persistedThreads) { threads[threadID] = { ...persistedThreads[threadID], messageIDs: [] }; } - return { ...messageStore, threads }; + return { ...messageStore, threads, messages: {} }; }, { whitelist: ['messageStore'] }, ); diff --git a/native/redux/redux-setup.js b/native/redux/redux-setup.js --- a/native/redux/redux-setup.js +++ b/native/redux/redux-setup.js @@ -19,10 +19,12 @@ invalidSessionDowngrade, invalidSessionRecovery, } from 'lib/shared/account-utils'; +import { isStaff } from 'lib/shared/user-utils'; import { logInActionSources } from 'lib/types/account-types'; import { defaultEnabledApps } from 'lib/types/enabled-apps'; import { defaultCalendarFilters } from 'lib/types/filter-types'; import type { Dispatch, BaseAction } from 'lib/types/redux-types'; +import { rehydrateActionType } from 'lib/types/redux-types'; import type { SetSessionPayload } from 'lib/types/session-types'; import { defaultConnectionInfo, @@ -45,6 +47,7 @@ import { defaultDeviceCameraInfo } from '../types/camera'; import { defaultConnectivityInfo } from '../types/connectivity'; import { defaultGlobalThemeInfo } from '../types/themes'; +import { isStaffRelease } from '../utils/staff-utils'; import { defaultURLPrefix, natNodeServer, @@ -130,6 +133,53 @@ return action.payload.state; } + // We want to alert staff/developers if there's a difference between the keys + // we expect to see REHYDRATED and the keys that are actually REHYDRATED. + // Context: https://linear.app/comm/issue/ENG-2127/ + if ( + action.type === rehydrateActionType && + (__DEV__ || + isStaffRelease || + (state.currentUserInfo && + state.currentUserInfo.id && + isStaff(state.currentUserInfo.id))) + ) { + // 1. Construct set of keys expected to be REHYDRATED + const defaultKeys = Object.keys(defaultState); + const expectedKeys = defaultKeys.filter( + each => !persistConfig.blacklist.includes(each), + ); + const expectedKeysSet = new Set(expectedKeys); + + // 2. Construct set of keys actually REHYDRATED + const rehydratedKeys = Object.keys(action.payload ?? {}); + const rehydratedKeysSet = new Set(rehydratedKeys); + + // 3. Determine the difference between the two sets + const expectedKeysNotRehydrated = expectedKeys.filter( + each => !rehydratedKeysSet.has(each), + ); + const rehydratedKeysNotExpected = rehydratedKeys.filter( + each => !expectedKeysSet.has(each), + ); + + // 4. Display alerts with the differences between the two sets + if (expectedKeysNotRehydrated.length > 0) { + Alert.alert( + `EXPECTED KEYS NOT REHYDRATED: ${JSON.stringify( + expectedKeysNotRehydrated, + )}`, + ); + } + if (rehydratedKeysNotExpected.length > 0) { + Alert.alert( + `REHYDRATED KEYS NOT EXPECTED: ${JSON.stringify( + rehydratedKeysNotExpected, + )}`, + ); + } + } + if ( (action.type === setNewSessionActionType && invalidSessionDowngrade( diff --git a/native/utils/staff-utils.js b/native/utils/staff-utils.js --- a/native/utils/staff-utils.js +++ b/native/utils/staff-utils.js @@ -16,4 +16,4 @@ return __DEV__ || isStaffRelease || isCurrentUserStaff; } -export { useStaffCanSee }; +export { isStaffRelease, useStaffCanSee }; diff --git a/scripts/get_clang_paths.js b/scripts/get_clang_paths.js --- a/scripts/get_clang_paths.js +++ b/scripts/get_clang_paths.js @@ -6,6 +6,10 @@ extensions: ['h', 'cpp'], excludes: ['_generated'], }, + { + path: 'services/lib/src', + extensions: ['cpp', 'h'], + }, { path: 'services/tunnelbroker/src', extensions: ['cpp', 'h'], diff --git a/services/identity/src/constants.rs b/services/identity/src/constants.rs --- a/services/identity/src/constants.rs +++ b/services/identity/src/constants.rs @@ -11,6 +11,7 @@ pub const USERS_TABLE_REGISTRATION_ATTRIBUTE: &str = "pakeRegistrationData"; pub const USERS_TABLE_USERNAME_ATTRIBUTE: &str = "username"; pub const USERS_TABLE_USER_PUBLIC_KEY_ATTRIBUTE: &str = "userPublicKey"; +pub const USERS_TABLE_DEVICE_ATTRIBUTE: &str = "device"; pub const USERS_TABLE_DEVICES_ATTRIBUTE: &str = "devices"; pub const USERS_TABLE_DEVICES_MAP_ATTRIBUTE_NAME: &str = "deviceID"; pub const USERS_TABLE_WALLET_ADDRESS_ATTRIBUTE: &str = "walletAddress"; diff --git a/services/lib/src/BaseReactor.h b/services/lib/src/BaseReactor.h --- a/services/lib/src/BaseReactor.h +++ b/services/lib/src/BaseReactor.h @@ -22,7 +22,7 @@ virtual void terminate(const grpc::Status &status) = 0; // Validates current values of the reactor's fields. virtual void validate() = 0; - // Should be called when `OnDone` is called. gRPC calls `OnDone` when there + // Should be called when `OnDone` is called. gRPC calls `OnDone` when there // are not going to be more rpc operations. virtual void doneCallback() = 0; // Should be called when `terminate` is called. diff --git a/services/lib/src/GlobalTools.cpp b/services/lib/src/GlobalTools.cpp --- a/services/lib/src/GlobalTools.cpp +++ b/services/lib/src/GlobalTools.cpp @@ -6,12 +6,12 @@ #include #include +#include #include #include #include #include #include -#include namespace comm { namespace network { diff --git a/services/lib/src/client-base-reactors/ClientBidiReactorBase.h b/services/lib/src/client-base-reactors/ClientBidiReactorBase.h --- a/services/lib/src/client-base-reactors/ClientBidiReactorBase.h +++ b/services/lib/src/client-base-reactors/ClientBidiReactorBase.h @@ -2,8 +2,8 @@ #include "BaseReactor.h" -#include #include +#include namespace comm { namespace network { diff --git a/services/lib/src/client-base-reactors/ClientReadReactorBase.h b/services/lib/src/client-base-reactors/ClientReadReactorBase.h --- a/services/lib/src/client-base-reactors/ClientReadReactorBase.h +++ b/services/lib/src/client-base-reactors/ClientReadReactorBase.h @@ -2,8 +2,8 @@ #include "BaseReactor.h" -#include #include +#include namespace comm { namespace network { diff --git a/services/lib/src/client-base-reactors/ClientWriteReactorBase.h b/services/lib/src/client-base-reactors/ClientWriteReactorBase.h --- a/services/lib/src/client-base-reactors/ClientWriteReactorBase.h +++ b/services/lib/src/client-base-reactors/ClientWriteReactorBase.h @@ -2,8 +2,8 @@ #include "BaseReactor.h" -#include #include +#include namespace comm { namespace network { diff --git a/services/lib/src/server-base-reactors/ServerBidiReactorBase.h b/services/lib/src/server-base-reactors/ServerBidiReactorBase.h --- a/services/lib/src/server-base-reactors/ServerBidiReactorBase.h +++ b/services/lib/src/server-base-reactors/ServerBidiReactorBase.h @@ -34,9 +34,15 @@ public BaseReactor { std::shared_ptr statusHolder = std::make_shared(); + + std::atomic ongoingPoolTaskCounter{0}; + Request request; Response response; + void beginPoolTask(); + void finishPoolTask(); + protected: ServerBidiReactorStatus status; bool readingAborted = false; @@ -86,17 +92,20 @@ template void ServerBidiReactorBase::OnDone() { - this->statusHolder->state = ReactorState::DONE; - this->doneCallback(); - // This looks weird but apparently it is okay to do this. More information: - // https://phabricator.ashoat.com/D3246#87890 - delete this; + this->beginPoolTask(); + ThreadPool::getInstance().scheduleWithCallback( + [this]() { + this->statusHolder->state = ReactorState::DONE; + this->doneCallback(); + }, + [this](std::unique_ptr err) { this->finishPoolTask(); }); } template void ServerBidiReactorBase::terminate( ServerBidiReactorStatus status) { this->setStatus(status); + this->beginPoolTask(); ThreadPool::getInstance().scheduleWithCallback( [this]() { this->terminateCallback(); @@ -108,6 +117,7 @@ grpc::Status(grpc::StatusCode::INTERNAL, std::string(*err)))); } if (this->statusHolder->state != ReactorState::RUNNING) { + this->finishPoolTask(); return; } if (this->getStatus().sendLastResponse) { @@ -117,6 +127,7 @@ this->Finish(this->getStatus().status); } this->statusHolder->state = ReactorState::TERMINATED; + this->finishPoolTask(); }); } @@ -142,6 +153,7 @@ this->terminate(ServerBidiReactorStatus(grpc::Status::OK)); return; } + this->beginPoolTask(); ThreadPool::getInstance().scheduleWithCallback( [this]() { this->response = Response(); @@ -158,6 +170,7 @@ this->terminate(ServerBidiReactorStatus( grpc::Status(grpc::StatusCode::INTERNAL, *err))); } + this->finishPoolTask(); }); } @@ -177,6 +190,23 @@ return this->statusHolder; } +template +void ServerBidiReactorBase::beginPoolTask() { + this->ongoingPoolTaskCounter++; +} + +template +void ServerBidiReactorBase::finishPoolTask() { + this->ongoingPoolTaskCounter--; + if (!this->ongoingPoolTaskCounter.load() && + this->statusHolder->state == ReactorState::DONE) { + // This looks weird but apparently it is okay to do this. More + // information: + // https://phab.comm.dev/D3246#87890 + delete this; + } +} + } // namespace reactor } // namespace network } // namespace comm diff --git a/services/lib/src/server-base-reactors/ServerReadReactorBase.h b/services/lib/src/server-base-reactors/ServerReadReactorBase.h --- a/services/lib/src/server-base-reactors/ServerReadReactorBase.h +++ b/services/lib/src/server-base-reactors/ServerReadReactorBase.h @@ -24,8 +24,13 @@ public BaseReactor { std::shared_ptr statusHolder = std::make_shared(); + + std::atomic ongoingPoolTaskCounter{0}; Request request; + void beginPoolTask(); + void finishPoolTask(); + protected: Response *response; @@ -68,6 +73,7 @@ this->terminate(grpc::Status::OK); return; } + this->beginPoolTask(); ThreadPool::getInstance().scheduleWithCallback( [this]() { std::unique_ptr status = this->readRequest(this->request); @@ -81,6 +87,7 @@ if (err != nullptr) { this->terminate(grpc::Status(grpc::StatusCode::INTERNAL, *err)); } + this->finishPoolTask(); }); } @@ -88,7 +95,7 @@ void ServerReadReactorBase::terminate( const grpc::Status &status) { this->statusHolder->setStatus(status); - + this->beginPoolTask(); ThreadPool::getInstance().scheduleWithCallback( [this]() { this->terminateCallback(); @@ -102,21 +109,23 @@ if (!this->statusHolder->getStatus().ok()) { LOG(ERROR) << this->statusHolder->getStatus().error_message(); } - if (this->statusHolder->state != ReactorState::RUNNING) { - return; + if (this->statusHolder->state == ReactorState::RUNNING) { + this->Finish(this->statusHolder->getStatus()); + this->statusHolder->state = ReactorState::TERMINATED; } - this->Finish(this->statusHolder->getStatus()); - this->statusHolder->state = ReactorState::TERMINATED; + this->finishPoolTask(); }); } template void ServerReadReactorBase::OnDone() { - this->statusHolder->state = ReactorState::DONE; - this->doneCallback(); - // This looks weird but apparently it is okay to do this. More information: - // https://phabricator.ashoat.com/D3246#87890 - delete this; + this->beginPoolTask(); + ThreadPool::getInstance().scheduleWithCallback( + [this]() { + this->statusHolder->state = ReactorState::DONE; + this->doneCallback(); + }, + [this](std::unique_ptr err) { this->finishPoolTask(); }); } template @@ -125,6 +134,23 @@ return this->statusHolder; } +template +void ServerReadReactorBase::beginPoolTask() { + this->ongoingPoolTaskCounter++; +} + +template +void ServerReadReactorBase::finishPoolTask() { + this->ongoingPoolTaskCounter--; + if (!this->ongoingPoolTaskCounter.load() && + this->statusHolder->state == ReactorState::DONE) { + // This looks weird but apparently it is okay to do this. More + // information: + // https://phab.comm.dev/D3246#87890 + delete this; + } +} + } // namespace reactor } // namespace network } // namespace comm diff --git a/services/lib/src/server-base-reactors/ServerWriteReactorBase.h b/services/lib/src/server-base-reactors/ServerWriteReactorBase.h --- a/services/lib/src/server-base-reactors/ServerWriteReactorBase.h +++ b/services/lib/src/server-base-reactors/ServerWriteReactorBase.h @@ -24,10 +24,14 @@ public BaseReactor { std::shared_ptr statusHolder = std::make_shared(); + + std::atomic ongoingPoolTaskCounter{0}; Response response; bool initialized = false; void nextWrite(); + void beginPoolTask(); + void finishPoolTask(); protected: // this is a const ref since it's not meant to be modified @@ -64,6 +68,7 @@ void ServerWriteReactorBase::terminate( const grpc::Status &status) { this->statusHolder->setStatus(status); + this->beginPoolTask(); ThreadPool::getInstance().scheduleWithCallback( [this]() { this->terminateCallback(); @@ -77,11 +82,11 @@ if (!this->statusHolder->getStatus().ok()) { LOG(ERROR) << this->statusHolder->getStatus().error_message(); } - if (this->statusHolder->state != ReactorState::RUNNING) { - return; + if (this->statusHolder->state == ReactorState::RUNNING) { + this->Finish(this->statusHolder->getStatus()); + this->statusHolder->state = ReactorState::TERMINATED; } - this->Finish(this->statusHolder->getStatus()); - this->statusHolder->state = ReactorState::TERMINATED; + this->finishPoolTask(); }); } @@ -98,6 +103,7 @@ template void ServerWriteReactorBase::nextWrite() { + this->beginPoolTask(); ThreadPool::getInstance().scheduleWithCallback( [this]() { if (!this->initialized) { @@ -117,6 +123,7 @@ if (err != nullptr) { this->terminate(grpc::Status(grpc::StatusCode::INTERNAL, *err)); } + this->finishPoolTask(); }); } @@ -128,10 +135,10 @@ template void ServerWriteReactorBase::OnDone() { - this->doneCallback(); - // This looks weird but apparently it is okay to do this. More information: - // https://phabricator.ashoat.com/D3246#87890 - delete this; + this->beginPoolTask(); + ThreadPool::getInstance().scheduleWithCallback( + [this]() { this->doneCallback(); }, + [this](std::unique_ptr err) { this->finishPoolTask(); }); } template @@ -149,6 +156,23 @@ this->nextWrite(); } +template +void ServerWriteReactorBase::beginPoolTask() { + this->ongoingPoolTaskCounter++; +} + +template +void ServerWriteReactorBase::finishPoolTask() { + this->ongoingPoolTaskCounter--; + if (!this->ongoingPoolTaskCounter.load() && + this->statusHolder->state == ReactorState::DONE) { + // This looks weird but apparently it is okay to do this. More + // information: + // https://phab.comm.dev/D3246#87890 + delete this; + } +} + } // namespace reactor } // namespace network } // namespace comm