diff --git a/lib/components/keyserver-connection-handler.js b/lib/components/keyserver-connection-handler.js --- a/lib/components/keyserver-connection-handler.js +++ b/lib/components/keyserver-connection-handler.js @@ -10,11 +10,14 @@ useLogOut, } from '../actions/user-actions.js'; import { extractKeyserverIDFromID } from '../keyserver-conn/keyserver-call-utils.js'; +import { setSessionRecoveryInProgressActionType } from '../keyserver-conn/keyserver-conn-types.js'; +import { resolveKeyserverSessionInvalidation } from '../keyserver-conn/recovery-utils.js'; import { filterThreadIDsInFilterList } from '../reducers/calendar-filters-reducer.js'; import { connectionSelector, cookieSelector, deviceTokenSelector, + urlPrefixSelector, } from '../selectors/keyserver-selectors.js'; import { isLoggedInToKeyserver } from '../selectors/user-selectors.js'; import { IdentityClientContext } from '../shared/identity-client-context.js'; @@ -23,7 +26,7 @@ import { logInActionSources } from '../types/account-types.js'; import { authoritativeKeyserverID } from '../utils/authoritative-keyserver.js'; import { useDispatchActionPromise } from '../utils/redux-promise-utils.js'; -import { useSelector } from '../utils/redux-utils.js'; +import { useSelector, useDispatch } from '../utils/redux-utils.js'; import { usingCommServicesAccessToken } from '../utils/services-utils.js'; import sleep from '../utils/sleep.js'; @@ -191,12 +194,106 @@ olmSessionCreator, ]); + const sessionRecoveryInProgress = useSelector( + state => + state.keyserverStore.keyserverInfos[keyserverID]?.connection + .sessionRecoveryInProgress, + ); + + const dispatch = useDispatch(); + const urlPrefix = useSelector(urlPrefixSelector(keyserverID)); + const performRecovery = React.useCallback(() => { + invariant( + urlPrefix, + `urlPrefix for ${keyserverID} should be set during performRecovery`, + ); + + setAuthInProgress(true); + + let cancelled = false; + const cancel = () => { + cancelled = true; + setAuthInProgress(false); + }; + + const promise = (async () => { + try { + const sessionChange = await resolveKeyserverSessionInvalidation( + dispatch, + cookie, + urlPrefix, + logInActionSources.cookieInvalidationResolutionAttempt, + keyserverID, + ); + if (cancelled) { + // TODO: cancellation won't work because above call handles Redux + // dispatch directly + throw new Error(CANCELLED_ERROR); + } + if ( + !sessionChange || + sessionChange.cookieInvalidated || + !sessionChange.cookie || + !sessionChange.cookie.startsWith('user=') + ) { + dispatch({ + type: setSessionRecoveryInProgressActionType, + payload: { sessionRecoveryInProgress: false, keyserverID }, + }); + } + } catch (e) { + if (cancelled) { + return; + } + + console.log( + `Error while recovering session with keyserver id ${keyserverID}`, + e, + ); + + dispatch({ + type: setSessionRecoveryInProgressActionType, + payload: { sessionRecoveryInProgress: false, keyserverID }, + }); + } finally { + if (!cancelled) { + setAuthInProgress(false); + } + } + })(); + return [promise, cancel]; + }, [dispatch, cookie, urlPrefix, keyserverID]); + const cancelPendingAuth = React.useRef void>(null); const prevPerformAuth = React.useRef(performAuth); const isUserAuthenticated = useSelector(isLoggedInToKeyserver(keyserverID)); const hasAccessToken = useSelector(state => !!state.commServicesAccessToken); + const cancelPendingRecovery = React.useRef void>(null); + const prevPerformRecovery = React.useRef(performRecovery); + React.useEffect(() => { + if (sessionRecoveryInProgress && isUserAuthenticated) { + cancelPendingAuth.current?.(); + cancelPendingAuth.current = null; + + if (prevPerformRecovery.current !== performRecovery) { + cancelPendingRecovery.current?.(); + cancelPendingRecovery.current = null; + prevPerformRecovery.current = performRecovery; + } + + if (!authInProgress) { + const [, cancel] = performRecovery(); + cancelPendingRecovery.current = cancel; + } + + return; + } + + cancelPendingRecovery.current?.(); + cancelPendingRecovery.current = null; + if (!hasAccessToken) { cancelPendingAuth.current?.(); cancelPendingAuth.current = null; @@ -222,7 +319,14 @@ const [, cancel] = performAuth(); cancelPendingAuth.current = cancel; - }, [authInProgress, hasAccessToken, isUserAuthenticated, performAuth]); + }, [ + sessionRecoveryInProgress, + authInProgress, + performRecovery, + hasAccessToken, + isUserAuthenticated, + performAuth, + ]); return ; } diff --git a/lib/keyserver-conn/call-keyserver-endpoint-provider.react.js b/lib/keyserver-conn/call-keyserver-endpoint-provider.react.js --- a/lib/keyserver-conn/call-keyserver-endpoint-provider.react.js +++ b/lib/keyserver-conn/call-keyserver-endpoint-provider.react.js @@ -5,17 +5,17 @@ import * as React from 'react'; import { createSelector } from 'reselect'; -import { useKeyserverCallInfos } from './keyserver-call-infos.js'; +import { + useKeyserverCallInfos, + type KeyserverCallInfo, +} from './keyserver-call-infos.js'; import { setNewSession, type SingleKeyserverActionFunc, type ActionFunc, + setSessionRecoveryInProgressActionType, } from './keyserver-conn-types.js'; -import { - canResolveKeyserverSessionInvalidation, - resolveKeyserverSessionInvalidation, -} from './recovery-utils.js'; -import { logInActionSources } from '../types/account-types.js'; +import { canResolveKeyserverSessionInvalidation } from './recovery-utils.js'; import type { PlatformDetails } from '../types/device-types.js'; import type { Endpoint, SocketAPIHandler } from '../types/endpoints.js'; import type { Dispatch } from '../types/redux-types.js'; @@ -114,6 +114,7 @@ sessionID, currentUserInfo, isSocketConnected, + sessionRecoveryInProgress, canRecoverSession, lastCommunicatedPlatformDetails, keyserverID, @@ -145,63 +146,20 @@ // just let the caller callSingleKeyserverEndpoint instance continue return Promise.resolve(null); } - const ongoingRecoveryAttempt = - ongoingRecoveryAttemptsRef.current.get(keyserverID); - if (!ongoingRecoveryAttempt) { + if (!sessionRecoveryInProgress) { // Our cookie seems to be valid return Promise.resolve(null); } + const recoveryAttempts = ongoingRecoveryAttemptsRef.current; + let keyserverRecoveryAttempts = recoveryAttempts.get(keyserverID); + if (!keyserverRecoveryAttempts) { + keyserverRecoveryAttempts = { waitingCalls: [] }; + recoveryAttempts.set(keyserverID, keyserverRecoveryAttempts); + } + const ongoingRecoveryAttempts = keyserverRecoveryAttempts; // Wait to run until we get our new cookie return new Promise(r => - ongoingRecoveryAttempt.waitingCalls.push(r), - ); - }; - // These functions are helpers for cookieInvalidationRecovery, defined below - const attemptToResolveInvalidationHelper = async ( - sessionChange: ClientSessionChange, - ) => { - const newAnonymousCookie = sessionChange.cookie; - const newSessionChange = await resolveKeyserverSessionInvalidation( - dispatch, - newAnonymousCookie, - urlPrefix, - logInActionSources.cookieInvalidationResolutionAttempt, - keyserverID, - ); - - return newSessionChange - ? bindCookieAndUtilsIntoCallSingleKeyserverEndpoint({ - ...params, - cookie: newSessionChange.cookie, - sessionID: newSessionChange.sessionID, - currentUserInfo: newSessionChange.currentUserInfo, - }) - : null; - }; - const attemptToResolveInvalidation = ( - sessionChange: ClientSessionChange, - ) => { - return new Promise( - // eslint-disable-next-line no-async-promise-executor - async (resolve, reject) => { - try { - const newCallSingleKeyserverEndpoint = - await attemptToResolveInvalidationHelper(sessionChange); - const ongoingRecoveryAttempt = - ongoingRecoveryAttemptsRef.current.get(keyserverID); - ongoingRecoveryAttemptsRef.current.delete(keyserverID); - const currentWaitingCalls = - ongoingRecoveryAttempt?.waitingCalls ?? []; - - resolve(newCallSingleKeyserverEndpoint); - - for (const func of currentWaitingCalls) { - func(newCallSingleKeyserverEndpoint); - } - } catch (e) { - reject(e); - } - }, + ongoingRecoveryAttempts.waitingCalls.push(r), ); }; // If this function is called, callSingleKeyserverEndpoint got a response @@ -224,15 +182,23 @@ // user to log-in after a cookieInvalidation while logged out return Promise.resolve(null); } - const ongoingRecoveryAttempt = - ongoingRecoveryAttemptsRef.current.get(keyserverID); - if (ongoingRecoveryAttempt) { - return new Promise(r => - ongoingRecoveryAttempt.waitingCalls.push(r), - ); + + const recoveryAttempts = ongoingRecoveryAttemptsRef.current; + let keyserverRecoveryAttempts = recoveryAttempts.get(keyserverID); + if (!keyserverRecoveryAttempts) { + keyserverRecoveryAttempts = { waitingCalls: [] }; + recoveryAttempts.set(keyserverID, keyserverRecoveryAttempts); + } + if (!sessionRecoveryInProgress) { + dispatch({ + type: setSessionRecoveryInProgressActionType, + payload: { sessionRecoveryInProgress: true, keyserverID }, + }); } - ongoingRecoveryAttemptsRef.current.set(keyserverID, { waitingCalls: [] }); - return attemptToResolveInvalidation(sessionChange); + const ongoingRecoveryAttempts = keyserverRecoveryAttempts; + return new Promise(r => + ongoingRecoveryAttempts.waitingCalls.push(r), + ); }; return ( @@ -309,7 +275,7 @@ [bindCookieAndUtilsIntoCallSingleKeyserverEndpoint], ); - // SECTION 3: getBoundSingleKeyserverActionFunc + // SECTION 3: getCallSingleKeyserverEndpoint const dispatch = useDispatch(); const currentUserInfo = useSelector(state => state.currentUserInfo); @@ -349,6 +315,55 @@ ], ); + // SECTION 4: flush waitingCalls when sessionRecoveryInProgress flips to false + + const prevKeyserverCallInfosRef = React.useRef<{ + +[keyserverID: string]: KeyserverCallInfo, + }>(keyserverCallInfos); + React.useEffect(() => { + const sessionRecoveriesConcluded = new Set(); + const prevKeyserverCallInfos = prevKeyserverCallInfosRef.current; + for (const keyserverID in keyserverCallInfos) { + const prevKeyserverCallInfo = prevKeyserverCallInfos[keyserverID]; + if (!prevKeyserverCallInfo) { + continue; + } + const keyserverCallInfo = keyserverCallInfos[keyserverID]; + if ( + !keyserverCallInfo.sessionRecoveryInProgress && + prevKeyserverCallInfo.sessionRecoveryInProgress + ) { + sessionRecoveriesConcluded.add(keyserverID); + } + } + + for (const keyserverID of sessionRecoveriesConcluded) { + const recoveryAttempts = ongoingRecoveryAttemptsRef.current; + const keyserverRecoveryAttempts = recoveryAttempts.get(keyserverID); + if (!keyserverRecoveryAttempts) { + continue; + } + const { waitingCalls } = keyserverRecoveryAttempts; + if (waitingCalls.length === 0) { + continue; + } + + const { cookie } = keyserverCallInfos[keyserverID]; + const hasUserCookie = cookie && cookie.startsWith('user='); + + const boundCallSingleKeyserverEndpoint = hasUserCookie + ? getCallSingleKeyserverEndpoint(keyserverID) + : null; + for (const waitingCall of waitingCalls) { + waitingCall(boundCallSingleKeyserverEndpoint); + } + } + + prevKeyserverCallInfosRef.current = keyserverCallInfos; + }, [keyserverCallInfos, getCallSingleKeyserverEndpoint]); + + // SECTION 5: getBoundSingleKeyserverActionFunc + const createBoundSingleKeyserverActionFuncSelector: CreateBoundSingleKeyserverActionFuncSelector = React.useCallback( actionFunc => @@ -394,7 +409,7 @@ ], ); - // SECTION 4: getBoundKeyserverActionFunc + // SECTION 6: getBoundKeyserverActionFunc const callKeyserverEndpoint = React.useCallback( (