diff --git a/.github/workflows/test-codegen.yml b/.github/workflows/test-codegen.yml index d49dc383be..29df56d3ad 100644 --- a/.github/workflows/test-codegen.yml +++ b/.github/workflows/test-codegen.yml @@ -147,4 +147,4 @@ jobs: name: package - name: Run are-the-types-wrong - run: yarn dlx @arethetypeswrong/cli@latest ./package.tgz --format table + run: yarn dlx @arethetypeswrong/cli@latest ./package.tgz --format table --exclude-entrypoints cli diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 46f9d3475b..8558ad9a43 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -305,7 +305,7 @@ jobs: fail-fast: false matrix: node: ['22.x'] - ts: ['5.3', '5.4', '5.5', '5.6', '5.7', '5.8', 'next'] + ts: ['5.3', '5.4', '5.5', '5.6', '5.7', '5.8', '5.9', 'next'] example: [ { name: 'bundler', moduleResolution: 'Bundler' }, @@ -340,10 +340,11 @@ jobs: run: yarn workspace @examples-type-portability/${{ matrix.example.name }} run test - name: Test type portability with `moduleResolution Node10` + if: matrix.ts != 'next' && !startsWith(matrix.ts, '6.') run: yarn workspace @examples-type-portability/${{ matrix.example.name }} run test --module CommonJS --moduleResolution Node10 --preserveSymLinks --verbatimModuleSyntax false - name: Test type portability with `moduleResolution Node10` and `type module` in `package.json` - if: matrix.example.name == 'nodenext-esm' || matrix.example.name == 'bundler' + if: (matrix.example.name == 'nodenext-esm' || matrix.example.name == 'bundler') && matrix.ts != 'next' && !startsWith(matrix.ts, '6.') run: | npm --workspace=@examples-type-portability/${{ matrix.example.name }} pkg set type=module yarn workspace @examples-type-portability/${{ matrix.example.name }} run test --module ESNext --moduleResolution Node10 --preserveSymLinks --verbatimModuleSyntax false diff --git a/.gitignore b/.gitignore index a8309316f0..78e8787e73 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,8 @@ typesversions .pnp.* *.tgz -tsconfig.vitest-temp.json \ No newline at end of file +tsconfig.vitest-temp.json + +# node version manager files +.node-version +.nvmrc \ No newline at end of file diff --git a/docs/api/createReducer.mdx b/docs/api/createReducer.mdx index c320251074..38ff902425 100644 --- a/docs/api/createReducer.mdx +++ b/docs/api/createReducer.mdx @@ -90,6 +90,14 @@ const counterReducer = createReducer(initialState, (builder) => { [params,examples](docblock://mapBuilders.ts?token=ActionReducerMapBuilder.addCase) +### `builder.addAsyncThunk` + +[summary,remarks](docblock://mapBuilders.ts?token=ActionReducerMapBuilder.addAsyncThunk) + +#### Parameters + +[params,examples](docblock://mapBuilders.ts?token=ActionReducerMapBuilder.addAsyncThunk) + ### `builder.addMatcher` [summary,remarks](docblock://mapBuilders.ts?token=ActionReducerMapBuilder.addMatcher) diff --git a/docs/rtk-query/api/createApi.mdx b/docs/rtk-query/api/createApi.mdx index 2d38b41ffe..81c8c16d20 100644 --- a/docs/rtk-query/api/createApi.mdx +++ b/docs/rtk-query/api/createApi.mdx @@ -482,6 +482,8 @@ By default, this function will take the query arguments, sort object keys where [summary](docblock://query/createApi.ts?token=CreateApiOptions.keepUnusedDataFor) +[examples](docblock://query/createApi.ts?token=CreateApiOptions.keepUnusedDataFor) + ### `refetchOnMountOrArgChange` [summary](docblock://query/createApi.ts?token=CreateApiOptions.refetchOnMountOrArgChange) @@ -684,7 +686,7 @@ Overrides the api-wide definition of `keepUnusedDataFor` for this endpoint only. [summary](docblock://query/createApi.ts?token=CreateApiOptions.keepUnusedDataFor) -[examples](docblock://query/createApi.ts?token=CreateApiOptions.keepUnusedDataFor) +[examples](docblock://query/core/buildMiddleware/cacheCollection.ts?token=CacheCollectionQueryExtraOptions) ### `serializeQueryArgs` diff --git a/docs/rtk-query/usage/cache-behavior.mdx b/docs/rtk-query/usage/cache-behavior.mdx index 84b46a2048..baafd65993 100644 --- a/docs/rtk-query/usage/cache-behavior.mdx +++ b/docs/rtk-query/usage/cache-behavior.mdx @@ -122,7 +122,7 @@ Alternatively, you can dispatch the `initiate` thunk action for an endpoint, pas ```tsx no-transpile title="Force refetch example" import { useDispatch } from 'react-redux' -import { useGetPostsQuery } from './api' +import { api, useGetPostsQuery } from './api' const Component = () => { const dispatch = useDispatch() diff --git a/docs/rtk-query/usage/code-generation.mdx b/docs/rtk-query/usage/code-generation.mdx index 9985d8c9f1..5e84bd914a 100644 --- a/docs/rtk-query/usage/code-generation.mdx +++ b/docs/rtk-query/usage/code-generation.mdx @@ -84,6 +84,17 @@ const api = await generateEndpoints({ }) ``` +#### With Node.js Child process + +```ts no-transpile title="bin/openapi-codegen.ts" +import { exec } from 'node:child_process' + +const cliPath = require.resolve('@rtk-query/codegen-openapi/cli') + +// you can also use esbuild-runner (esr) or ts-node instead of tsx +exec(`tsx ${cliPath} config.ts`) +``` + ### Config file options #### Simple usage diff --git a/errors.json b/errors.json index 8169973800..984b6dd650 100644 --- a/errors.json +++ b/errors.json @@ -41,5 +41,6 @@ "39": "called \\`injectEndpoints\\` to override already-existing endpointName without specifying \\`overrideExisting: true\\`", "40": "maxPages for endpoint '' must be a number greater than 0", "41": "getPreviousPageParam for endpoint '' must be a function if maxPages is used", - "42": "Duplicate middleware references found when creating the store. Ensure that each middleware is only included once." + "42": "Duplicate middleware references found when creating the store. Ensure that each middleware is only included once.", + "43": "`builder.addAsyncThunk` should only be called before calling `builder.addDefaultCase`" } \ No newline at end of file diff --git a/packages/rtk-query-codegen-openapi/package.json b/packages/rtk-query-codegen-openapi/package.json index 1213e5766a..e9bea3327a 100644 --- a/packages/rtk-query-codegen-openapi/package.json +++ b/packages/rtk-query-codegen-openapi/package.json @@ -18,7 +18,8 @@ "types": "./lib/index.d.ts", "default": "./lib/index.js" } - } + }, + "./cli": "./lib/bin/cli.mjs" }, "repository": { "type": "git", @@ -77,7 +78,7 @@ "@apidevtools/swagger-parser": "^10.1.1", "commander": "^6.2.0", "lodash.camelcase": "^4.3.0", - "oazapfts": "^6.1.0", + "oazapfts": "^6.3.0", "prettier": "^3.2.5", "semver": "^7.3.5", "swagger2openapi": "^7.0.4", diff --git a/packages/rtk-query-codegen-openapi/src/generate.ts b/packages/rtk-query-codegen-openapi/src/generate.ts index d128d100b4..c43ed65946 100644 --- a/packages/rtk-query-codegen-openapi/src/generate.ts +++ b/packages/rtk-query-codegen-openapi/src/generate.ts @@ -117,6 +117,7 @@ export async function generateApi( useEnumType = false, mergeReadWriteOnly = false, httpResolverOptions, + useUnknown = false, }: GenerationOptions ) { const v3Doc = (v3DocCache[spec] ??= await getV3Doc(spec, httpResolverOptions)); @@ -125,6 +126,7 @@ export async function generateApi( unionUndefined, useEnumType, mergeReadWriteOnly, + useUnknown, }); // temporary workaround for https://github.com/oazapfts/oazapfts/issues/491 @@ -448,7 +450,11 @@ export async function generateApi( const encodedValue = encodeQueryParams && param.param?.in === 'query' ? factory.createConditionalExpression( - value, + factory.createBinaryExpression( + value, + ts.SyntaxKind.ExclamationEqualsToken, + factory.createNull() + ), undefined, factory.createCallExpression(factory.createIdentifier('encodeURIComponent'), undefined, [ factory.createCallExpression(factory.createIdentifier('String'), undefined, [value]), diff --git a/packages/rtk-query-codegen-openapi/src/types.ts b/packages/rtk-query-codegen-openapi/src/types.ts index 437e058087..64473ba32f 100644 --- a/packages/rtk-query-codegen-openapi/src/types.ts +++ b/packages/rtk-query-codegen-openapi/src/types.ts @@ -38,79 +38,89 @@ export interface CommonOptions { */ schemaFile: string; /** - * defaults to "api" + * @default "api" */ apiImport?: string; /** - * defaults to "enhancedApi" + * @default "enhancedApi" */ exportName?: string; /** - * defaults to "ApiArg" + * @default "ApiArg" */ argSuffix?: string; /** - * defaults to "ApiResponse" + * @default "ApiResponse" */ responseSuffix?: string; /** - * defaults to empty + * @default "" */ operationNameSuffix?: string; /** - * defaults to `false` * `true` will generate hooks for queries and mutations, but no lazyQueries + * @default false */ hooks?: boolean | { queries: boolean; lazyQueries: boolean; mutations: boolean }; /** - * defaults to false * `true` will generate a union type for `undefined` properties like: `{ id?: string | undefined }` instead of `{ id?: string }` + * @default false */ unionUndefined?: boolean; /** - * defaults to false * `true` will result in all generated endpoints having `providesTags`/`invalidatesTags` declarations for the `tags` of their respective operation definition + * @default false * @see https://redux-toolkit.js.org/rtk-query/usage/code-generation for more information */ tag?: boolean; /** - * defaults to false * `true` will add `encodeURIComponent` to the generated path parameters + * @default false */ encodePathParams?: boolean; /** - * defaults to false * `true` will add `encodeURIComponent` to the generated query parameters + * @default false */ encodeQueryParams?: boolean; /** - * defaults to false * `true` will "flatten" the arg so that you can do things like `useGetEntityById(1)` instead of `useGetEntityById({ entityId: 1 })` + * @default false */ flattenArg?: boolean; /** - * default to false * If set to `true`, the default response type will be included in the generated code for all endpoints. + * @default false * @see https://swagger.io/docs/specification/describing-responses/#default */ includeDefault?: boolean; /** - * default to false * `true` will not generate separate types for read-only and write-only properties. + * @default false */ mergeReadWriteOnly?: boolean; /** - * * HTTPResolverOptions object that is passed to the SwaggerParser bundle function. */ httpResolverOptions?: SwaggerParser.HTTPResolverOptions; /** - * defaults to undefined * If present the given file will be used as prettier config when formatting the generated code. If undefined the default prettier config * resolution mechanism will be used. + * @default undefined */ prettierConfigFile?: string; + + /** + * Determines the fallback type for empty schemas. + * + * If set to **`true`**, **`unknown`** will be used + * instead of **`any`** when a schema is empty. + * + * @default false + * @since 2.1.0 + */ + useUnknown?: boolean; } export type TextMatcher = string | RegExp | (string | RegExp)[]; @@ -128,8 +138,8 @@ export interface OutputFileOptions extends Partial { filterEndpoints?: EndpointMatcher; endpointOverrides?: EndpointOverrides[]; /** - * defaults to false * If passed as true it will generate TS enums instead of union of strings + * @default false */ useEnumType?: boolean; } diff --git a/packages/rtk-query-codegen-openapi/test/generateEndpoints.test.ts b/packages/rtk-query-codegen-openapi/test/generateEndpoints.test.ts index 93ed235849..50024e7210 100644 --- a/packages/rtk-query-codegen-openapi/test/generateEndpoints.test.ts +++ b/packages/rtk-query-codegen-openapi/test/generateEndpoints.test.ts @@ -242,7 +242,7 @@ describe('option encodeQueryParams', () => { }); expect(api).toMatch( - /params:\s*{\s*\n\s*status:\s*queryArg\.status\s*\?\s*encodeURIComponent\(\s*String\(queryArg\.status\)\s*\)\s*:\s*undefined\s*,?\s*\n\s*}/s + /params:\s*{\s*\n\s*status:\s*queryArg\.status\s*!=\s*null\s*\?\s*encodeURIComponent\(\s*String\(queryArg\.status\)\s*\)\s*:\s*undefined\s*,?\s*\n\s*}/s ); }); diff --git a/packages/toolkit/package.json b/packages/toolkit/package.json index a17129a82b..5875dca388 100644 --- a/packages/toolkit/package.json +++ b/packages/toolkit/package.json @@ -1,6 +1,6 @@ { "name": "@reduxjs/toolkit", - "version": "2.8.2", + "version": "2.9.0", "description": "The official, opinionated, batteries-included toolset for efficient Redux development", "author": "Mark Erikson ", "license": "MIT", diff --git a/packages/toolkit/src/createSlice.ts b/packages/toolkit/src/createSlice.ts index 226d6cb8c8..bb7a64c8c3 100644 --- a/packages/toolkit/src/createSlice.ts +++ b/packages/toolkit/src/createSlice.ts @@ -23,7 +23,11 @@ import type { ReducerWithInitialState, } from './createReducer' import { createReducer } from './createReducer' -import type { ActionReducerMapBuilder, TypedActionCreator } from './mapBuilders' +import type { + ActionReducerMapBuilder, + AsyncThunkReducers, + TypedActionCreator, +} from './mapBuilders' import { executeReducerBuilderCallback } from './mapBuilders' import type { Id, TypeGuard } from './tsHelpers' import { getOrInsertComputed } from './utils' @@ -300,25 +304,7 @@ type AsyncThunkSliceReducerConfig< ThunkArg extends any, Returned = unknown, ThunkApiConfig extends AsyncThunkConfig = {}, -> = { - pending?: CaseReducer< - State, - ReturnType['pending']> - > - rejected?: CaseReducer< - State, - ReturnType['rejected']> - > - fulfilled?: CaseReducer< - State, - ReturnType['fulfilled']> - > - settled?: CaseReducer< - State, - ReturnType< - AsyncThunk['rejected' | 'fulfilled'] - > - > +> = AsyncThunkReducers & { options?: AsyncThunkOptions } diff --git a/packages/toolkit/src/index.ts b/packages/toolkit/src/index.ts index 65291ab83b..f6ada698d6 100644 --- a/packages/toolkit/src/index.ts +++ b/packages/toolkit/src/index.ts @@ -10,7 +10,7 @@ export { original, isDraft, } from 'immer' -export type { Draft } from 'immer' +export type { Draft, WritableDraft } from 'immer' export { createSelector, createSelectorCreator, @@ -104,6 +104,7 @@ export type { export type { // types ActionReducerMapBuilder, + AsyncThunkReducers, } from './mapBuilders' export { Tuple } from './utils' @@ -126,6 +127,8 @@ export { } from './createAsyncThunk' export type { AsyncThunk, + AsyncThunkConfig, + AsyncThunkDispatchConfig, AsyncThunkOptions, AsyncThunkAction, AsyncThunkPayloadCreatorReturnValue, diff --git a/packages/toolkit/src/mapBuilders.ts b/packages/toolkit/src/mapBuilders.ts index 31ba5b4142..36ca6f5840 100644 --- a/packages/toolkit/src/mapBuilders.ts +++ b/packages/toolkit/src/mapBuilders.ts @@ -5,6 +5,33 @@ import type { ActionMatcherDescriptionCollection, } from './createReducer' import type { TypeGuard } from './tsHelpers' +import type { AsyncThunk, AsyncThunkConfig } from './createAsyncThunk' + +export type AsyncThunkReducers< + State, + ThunkArg extends any, + Returned = unknown, + ThunkApiConfig extends AsyncThunkConfig = {}, +> = { + pending?: CaseReducer< + State, + ReturnType['pending']> + > + rejected?: CaseReducer< + State, + ReturnType['rejected']> + > + fulfilled?: CaseReducer< + State, + ReturnType['fulfilled']> + > + settled?: CaseReducer< + State, + ReturnType< + AsyncThunk['rejected' | 'fulfilled'] + > + > +} export type TypedActionCreator = { (...args: any[]): Action @@ -31,7 +58,7 @@ export interface ActionReducerMapBuilder { /** * Adds a case reducer to handle a single exact action type. * @remarks - * All calls to `builder.addCase` must come before any calls to `builder.addMatcher` or `builder.addDefaultCase`. + * All calls to `builder.addCase` must come before any calls to `builder.addAsyncThunk`, `builder.addMatcher` or `builder.addDefaultCase`. * @param actionCreator - Either a plain action type string, or an action creator generated by [`createAction`](./createAction) that can be used to determine the action type. * @param reducer - The actual case reducer function. */ @@ -40,12 +67,53 @@ export interface ActionReducerMapBuilder { reducer: CaseReducer, ): ActionReducerMapBuilder + /** + * Adds case reducers to handle actions based on a `AsyncThunk` action creator. + * @remarks + * All calls to `builder.addAsyncThunk` must come before after any calls to `builder.addCase` and before any calls to `builder.addMatcher` or `builder.addDefaultCase`. + * @param asyncThunk - The async thunk action creator itself. + * @param reducers - A mapping from each of the `AsyncThunk` action types to the case reducer that should handle those actions. + * @example +```ts no-transpile +import { createAsyncThunk, createReducer } from '@reduxjs/toolkit' + +const fetchUserById = createAsyncThunk('users/fetchUser', async (id) => { + const response = await fetch(`https://reqres.in/api/users/${id}`) + return (await response.json()).data +}) + +const reducer = createReducer(initialState, (builder) => { + builder.addAsyncThunk(fetchUserById, { + pending: (state, action) => { + state.fetchUserById.loading = 'pending' + }, + fulfilled: (state, action) => { + state.fetchUserById.data = action.payload + }, + rejected: (state, action) => { + state.fetchUserById.error = action.error + }, + settled: (state, action) => { + state.fetchUserById.loading = action.meta.requestStatus + }, + }) +}) + */ + addAsyncThunk< + Returned, + ThunkArg, + ThunkApiConfig extends AsyncThunkConfig = {}, + >( + asyncThunk: AsyncThunk, + reducers: AsyncThunkReducers, + ): Omit, 'addCase'> + /** * Allows you to match your incoming actions against your own filter function instead of only the `action.type` property. * @remarks * If multiple matcher reducers match, all of them will be executed in the order * they were defined in - even if a case reducer already matched. - * All calls to `builder.addMatcher` must come after any calls to `builder.addCase` and before any calls to `builder.addDefaultCase`. + * All calls to `builder.addMatcher` must come after any calls to `builder.addCase` and `builder.addAsyncThunk` and before any calls to `builder.addDefaultCase`. * @param matcher - A matcher function. In TypeScript, this should be a [type predicate](https://www.typescriptlang.org/docs/handbook/2/narrowing.html#using-type-predicates) * function * @param reducer - The actual case reducer function. @@ -99,7 +167,7 @@ const reducer = createReducer(initialState, (builder) => { addMatcher( matcher: TypeGuard | ((action: any) => boolean), reducer: CaseReducer, - ): Omit, 'addCase'> + ): Omit, 'addCase' | 'addAsyncThunk'> /** * Adds a "default case" reducer that is executed if no case reducer and no matcher @@ -173,6 +241,35 @@ export function executeReducerBuilderCallback( actionsMap[type] = reducer return builder }, + addAsyncThunk< + Returned, + ThunkArg, + ThunkApiConfig extends AsyncThunkConfig = {}, + >( + asyncThunk: AsyncThunk, + reducers: AsyncThunkReducers, + ) { + if (process.env.NODE_ENV !== 'production') { + // since this uses both action cases and matchers, we can't enforce the order in runtime other than checking for default case + if (defaultCaseReducer) { + throw new Error( + '`builder.addAsyncThunk` should only be called before calling `builder.addDefaultCase`', + ) + } + } + if (reducers.pending) + actionsMap[asyncThunk.pending.type] = reducers.pending + if (reducers.rejected) + actionsMap[asyncThunk.rejected.type] = reducers.rejected + if (reducers.fulfilled) + actionsMap[asyncThunk.fulfilled.type] = reducers.fulfilled + if (reducers.settled) + actionMatchers.push({ + matcher: asyncThunk.settled, + reducer: reducers.settled, + }) + return builder + }, addMatcher( matcher: TypeGuard, reducer: CaseReducer, diff --git a/packages/toolkit/src/query/core/apiState.ts b/packages/toolkit/src/query/core/apiState.ts index 7e1c1420e0..91cccf9ca0 100644 --- a/packages/toolkit/src/query/core/apiState.ts +++ b/packages/toolkit/src/query/core/apiState.ts @@ -148,6 +148,7 @@ export type SubscriptionOptions = { */ refetchOnFocus?: boolean } +export type SubscribersInternal = Map export type Subscribers = { [requestId: string]: SubscriptionOptions } export type QueryKeys = { [K in keyof Definitions]: Definitions[K] extends QueryDefinition< @@ -327,6 +328,8 @@ export type QueryState = { | undefined } +export type SubscriptionInternalState = Map + export type SubscriptionState = { [queryCacheKey: string]: Subscribers | undefined } diff --git a/packages/toolkit/src/query/core/buildInitiate.ts b/packages/toolkit/src/query/core/buildInitiate.ts index f5af4d8667..54ba080373 100644 --- a/packages/toolkit/src/query/core/buildInitiate.ts +++ b/packages/toolkit/src/query/core/buildInitiate.ts @@ -42,6 +42,7 @@ import type { ThunkApiMetaConfig, } from './buildThunks' import type { ApiEndpointQuery } from './module' +import type { InternalMiddlewareState } from './buildMiddleware/types' export type BuildInitiateApiEndpointQuery< Definition extends QueryDefinition, @@ -270,6 +271,7 @@ export function buildInitiate({ mutationThunk, api, context, + internalState, }: { serializeQueryArgs: InternalSerializeQueryArgs queryThunk: QueryThunk @@ -277,20 +279,9 @@ export function buildInitiate({ mutationThunk: MutationThunk api: Api context: ApiContext + internalState: InternalMiddlewareState }) { - const runningQueries: Map< - Dispatch, - Record< - string, - | QueryActionCreatorResult - | InfiniteQueryActionCreatorResult - | undefined - > - > = new Map() - const runningMutations: Map< - Dispatch, - Record | undefined> - > = new Map() + const { runningQueries, runningMutations } = internalState const { unsubscribeQueryResult, diff --git a/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts b/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts index b847dee01a..bffbc32ebe 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts @@ -1,12 +1,12 @@ import type { InternalHandlerBuilder, SubscriptionSelectors } from './types' -import type { SubscriptionState } from '../apiState' +import type { SubscriptionInternalState, SubscriptionState } from '../apiState' import { produceWithPatches } from 'immer' import type { Action } from '@reduxjs/toolkit' -import { countObjectKeys } from '../../utils/countObjectKeys' +import { getOrInsertComputed, createNewMap } from '../../utils/getOrInsert' export const buildBatchedActionsHandler: InternalHandlerBuilder< [actionShouldContinue: boolean, returnValue: SubscriptionSelectors | boolean] -> = ({ api, queryThunk, internalState }) => { +> = ({ api, queryThunk, internalState, mwApi }) => { const subscriptionsPrefix = `${api.reducerPath}/subscriptions` let previousSubscriptions: SubscriptionState = @@ -20,58 +20,63 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< // Actually intentionally mutate the subscriptions state used in the middleware // This is done to speed up perf when loading many components const actuallyMutateSubscriptions = ( - mutableState: SubscriptionState, + currentSubscriptions: SubscriptionInternalState, action: Action, ) => { if (updateSubscriptionOptions.match(action)) { const { queryCacheKey, requestId, options } = action.payload - if (mutableState?.[queryCacheKey]?.[requestId]) { - mutableState[queryCacheKey]![requestId] = options + const sub = currentSubscriptions.get(queryCacheKey) + if (sub?.has(requestId)) { + sub.set(requestId, options) } return true } if (unsubscribeQueryResult.match(action)) { const { queryCacheKey, requestId } = action.payload - if (mutableState[queryCacheKey]) { - delete mutableState[queryCacheKey]![requestId] + const sub = currentSubscriptions.get(queryCacheKey) + if (sub) { + sub.delete(requestId) } return true } if (api.internalActions.removeQueryResult.match(action)) { - delete mutableState[action.payload.queryCacheKey] + currentSubscriptions.delete(action.payload.queryCacheKey) return true } if (queryThunk.pending.match(action)) { const { meta: { arg, requestId }, } = action - const substate = (mutableState[arg.queryCacheKey] ??= {}) - substate[`${requestId}_running`] = {} + const substate = getOrInsertComputed( + currentSubscriptions, + arg.queryCacheKey, + createNewMap, + ) if (arg.subscribe) { - substate[requestId] = - arg.subscriptionOptions ?? substate[requestId] ?? {} + substate.set( + requestId, + arg.subscriptionOptions ?? substate.get(requestId) ?? {}, + ) } return true } let mutated = false - if ( - queryThunk.fulfilled.match(action) || - queryThunk.rejected.match(action) - ) { - const state = mutableState[action.meta.arg.queryCacheKey] || {} - const key = `${action.meta.requestId}_running` - mutated ||= !!state[key] - delete state[key] - } + if (queryThunk.rejected.match(action)) { const { meta: { condition, arg, requestId }, } = action if (condition && arg.subscribe) { - const substate = (mutableState[arg.queryCacheKey] ??= {}) - substate[requestId] = - arg.subscriptionOptions ?? substate[requestId] ?? {} + const substate = getOrInsertComputed( + currentSubscriptions, + arg.queryCacheKey, + createNewMap, + ) + substate.set( + requestId, + arg.subscriptionOptions ?? substate.get(requestId) ?? {}, + ) mutated = true } @@ -83,12 +88,12 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< const getSubscriptions = () => internalState.currentSubscriptions const getSubscriptionCount = (queryCacheKey: string) => { const subscriptions = getSubscriptions() - const subscriptionsForQueryArg = subscriptions[queryCacheKey] ?? {} - return countObjectKeys(subscriptionsForQueryArg) + const subscriptionsForQueryArg = subscriptions.get(queryCacheKey) + return subscriptionsForQueryArg?.size ?? 0 } const isRequestSubscribed = (queryCacheKey: string, requestId: string) => { const subscriptions = getSubscriptions() - return !!subscriptions?.[queryCacheKey]?.[requestId] + return !!subscriptions?.get(queryCacheKey)?.get(requestId) } const subscriptionSelectors: SubscriptionSelectors = { @@ -97,6 +102,21 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< isRequestSubscribed, } + function serializeSubscriptions( + currentSubscriptions: SubscriptionInternalState, + ): SubscriptionState { + // We now use nested Maps for subscriptions, instead of + // plain Records. Stringify this accordingly so we can + // convert it to the shape we need for the store. + return JSON.parse( + JSON.stringify( + Object.fromEntries( + [...currentSubscriptions].map(([k, v]) => [k, Object.fromEntries(v)]), + ), + ), + ) + } + return ( action, mwApi, @@ -106,13 +126,14 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< ] => { if (!previousSubscriptions) { // Initialize it the first time this handler runs - previousSubscriptions = JSON.parse( - JSON.stringify(internalState.currentSubscriptions), + previousSubscriptions = serializeSubscriptions( + internalState.currentSubscriptions, ) } if (api.util.resetApiState.match(action)) { - previousSubscriptions = internalState.currentSubscriptions = {} + previousSubscriptions = {} + internalState.currentSubscriptions.clear() updateSyncTimer = null return [true, false] } @@ -133,6 +154,15 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< let actionShouldContinue = true + // HACK Sneak the test-only polling state back out + if ( + process.env.NODE_ENV === 'test' && + typeof action.type === 'string' && + action.type === `${api.reducerPath}/getPolling` + ) { + return [false, internalState.currentPolls] as any + } + if (didMutate) { if (!updateSyncTimer) { // We only use the subscription state for the Redux DevTools at this point, @@ -142,8 +172,8 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< // In 1.9, it was updated in a microtask, but now we do it at most every 500ms. updateSyncTimer = setTimeout(() => { // Deep clone the current subscription data - const newSubscriptions: SubscriptionState = JSON.parse( - JSON.stringify(internalState.currentSubscriptions), + const newSubscriptions: SubscriptionState = serializeSubscriptions( + internalState.currentSubscriptions, ) // Figure out a smaller diff between original and current const [, patches] = produceWithPatches( diff --git a/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts b/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts index 2f1024d13b..94737d3ea3 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts @@ -1,5 +1,5 @@ import type { QueryDefinition } from '../../endpointDefinitions' -import type { ConfigState, QueryCacheKey } from '../apiState' +import type { ConfigState, QueryCacheKey, QuerySubState } from '../apiState' import { isAnyOf } from '../rtkImports' import type { ApiMiddlewareInternalHandler, @@ -11,16 +11,30 @@ import type { export type ReferenceCacheCollection = never -function isObjectEmpty(obj: Record) { - // Apparently a for..in loop is faster than `Object.keys()` here: - // https://stackoverflow.com/a/59787784/62937 - for (const k in obj) { - // If there is at least one key, it's not empty - return false - } - return true -} - +/** + * @example + * ```ts + * // codeblock-meta title="keepUnusedDataFor example" + * import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react' + * interface Post { + * id: number + * name: string + * } + * type PostsResponse = Post[] + * + * const api = createApi({ + * baseQuery: fetchBaseQuery({ baseUrl: '/' }), + * endpoints: (build) => ({ + * getPosts: build.query({ + * query: () => 'posts', + * // highlight-start + * keepUnusedDataFor: 5 + * // highlight-end + * }) + * }) + * }) + * ``` + */ export type CacheCollectionQueryExtraOptions = { /** * Overrides the api-wide definition of `keepUnusedDataFor` for this endpoint only. _(This value is in seconds.)_ @@ -44,10 +58,14 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ context, internalState, selectors: { selectQueryEntry, selectConfig }, + getRunningQueryThunk, + mwApi, }) => { const { removeQueryResult, unsubscribeQueryResult, cacheEntriesUpserted } = api.internalActions + const runningQueries = internalState.runningQueries.get(mwApi.dispatch)! + const canTriggerUnsubscribe = isAnyOf( unsubscribeQueryResult.match, queryThunk.fulfilled, @@ -56,8 +74,14 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ ) function anySubscriptionsRemainingForKey(queryCacheKey: string) { - const subscriptions = internalState.currentSubscriptions[queryCacheKey] - return !!subscriptions && !isObjectEmpty(subscriptions) + const subscriptions = internalState.currentSubscriptions.get(queryCacheKey) + if (!subscriptions) { + return false + } + + const hasSubscriptions = subscriptions.size > 0 + const isRunning = runningQueries?.[queryCacheKey] !== undefined + return hasSubscriptions || isRunning } const currentRemovalTimeouts: QueryStateMeta = {} @@ -69,6 +93,7 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ ) => { const state = mwApi.getState() const config = selectConfig(state) + if (canTriggerUnsubscribe(action)) { let queryCacheKeys: QueryCacheKey[] @@ -114,18 +139,20 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ const state = api.getState() for (const queryCacheKey of cacheKeys) { const entry = selectQueryEntry(state, queryCacheKey) - handleUnsubscribe(queryCacheKey, entry?.endpointName, api, config) + if (entry?.endpointName) { + handleUnsubscribe(queryCacheKey, entry.endpointName, api, config) + } } } function handleUnsubscribe( queryCacheKey: QueryCacheKey, - endpointName: string | undefined, + endpointName: string, api: SubMiddlewareApi, config: ConfigState, ) { const endpointDefinition = context.endpointDefinitions[ - endpointName! + endpointName ] as QueryDefinition const keepUnusedDataFor = endpointDefinition?.keepUnusedDataFor ?? config.keepUnusedDataFor @@ -151,6 +178,15 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ currentRemovalTimeouts[queryCacheKey] = setTimeout(() => { if (!anySubscriptionsRemainingForKey(queryCacheKey)) { + // Try to abort any running query for this cache key + const entry = selectQueryEntry(api.getState(), queryCacheKey) + + if (entry?.endpointName) { + const runningQuery = api.dispatch( + getRunningQueryThunk(entry.endpointName, entry.originalArgs), + ) + runningQuery?.abort() + } api.dispatch(removeQueryResult({ queryCacheKey })) } delete currentRemovalTimeouts![queryCacheKey] diff --git a/packages/toolkit/src/query/core/buildMiddleware/index.ts b/packages/toolkit/src/query/core/buildMiddleware/index.ts index 1f55b5ef22..4a24e1ec88 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/index.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/index.ts @@ -45,7 +45,7 @@ export function buildMiddleware< ReducerPath extends string, TagTypes extends string, >(input: BuildMiddlewareInput) { - const { reducerPath, queryThunk, api, context } = input + const { reducerPath, queryThunk, api, context, internalState } = input const { apiUid } = context const actions = { @@ -73,10 +73,6 @@ export function buildMiddleware< > = (mwApi) => { let initialized = false - const internalState: InternalMiddlewareState = { - currentSubscriptions: {}, - } - const builderArgs = { ...(input as any as BuildMiddlewareInput< EndpointDefinitions, @@ -86,6 +82,7 @@ export function buildMiddleware< internalState, refetchQuery, isThisApiSliceAction, + mwApi, } const handlers = handlerBuilders.map((build) => build(builderArgs)) diff --git a/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts b/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts index 78d90aa725..ae50030894 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts @@ -17,9 +17,8 @@ import type { SubMiddlewareApi, InternalHandlerBuilder, ApiMiddlewareInternalHandler, - InternalMiddlewareState, } from './types' -import { countObjectKeys } from '../../utils/countObjectKeys' +import { getOrInsertComputed, createNewMap } from '../../utils/getOrInsert' export const buildInvalidationByTagsHandler: InternalHandlerBuilder = ({ reducerPath, @@ -111,11 +110,14 @@ export const buildInvalidationByTagsHandler: InternalHandlerBuilder = ({ const valuesArray = Array.from(toInvalidate.values()) for (const { queryCacheKey } of valuesArray) { const querySubState = state.queries[queryCacheKey] - const subscriptionSubState = - internalState.currentSubscriptions[queryCacheKey] ?? {} + const subscriptionSubState = getOrInsertComputed( + internalState.currentSubscriptions, + queryCacheKey, + createNewMap, + ) if (querySubState) { - if (countObjectKeys(subscriptionSubState) === 0) { + if (subscriptionSubState.size === 0) { mwApi.dispatch( removeQueryResult({ queryCacheKey: queryCacheKey as QueryCacheKey, diff --git a/packages/toolkit/src/query/core/buildMiddleware/polling.ts b/packages/toolkit/src/query/core/buildMiddleware/polling.ts index 94b6258842..70f7b177d0 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/polling.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/polling.ts @@ -2,6 +2,7 @@ import type { QueryCacheKey, QuerySubstateIdentifier, Subscribers, + SubscribersInternal, } from '../apiState' import { QueryStatus } from '../apiState' import type { @@ -20,25 +21,25 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ refetchQuery, internalState, }) => { - const currentPolls: QueryStateMeta<{ - nextPollTimestamp: number - timeout?: TimeoutId - pollingInterval: number - }> = {} + const { currentPolls, currentSubscriptions } = internalState + + // Batching state for polling updates + const pendingPollingUpdates = new Set() + let pollingUpdateTimer: ReturnType | null = null const handler: ApiMiddlewareInternalHandler = (action, mwApi) => { if ( api.internalActions.updateSubscriptionOptions.match(action) || api.internalActions.unsubscribeQueryResult.match(action) ) { - updatePollingInterval(action.payload, mwApi) + schedulePollingUpdate(action.payload.queryCacheKey, mwApi) } if ( queryThunk.pending.match(action) || (queryThunk.rejected.match(action) && action.meta.condition) ) { - updatePollingInterval(action.meta.arg, mwApi) + schedulePollingUpdate(action.meta.arg.queryCacheKey, mwApi) } if ( @@ -50,6 +51,27 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ if (api.util.resetApiState.match(action)) { clearPolls() + // Clear any pending updates + if (pollingUpdateTimer) { + clearTimeout(pollingUpdateTimer) + pollingUpdateTimer = null + } + pendingPollingUpdates.clear() + } + } + + function schedulePollingUpdate(queryCacheKey: string, api: SubMiddlewareApi) { + pendingPollingUpdates.add(queryCacheKey) + + if (!pollingUpdateTimer) { + pollingUpdateTimer = setTimeout(() => { + // Process all pending updates in a single batch + for (const key of pendingPollingUpdates) { + updatePollingInterval({ queryCacheKey: key as any }, api) + } + pendingPollingUpdates.clear() + pollingUpdateTimer = null + }, 0) } } @@ -59,7 +81,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ ) { const state = api.getState()[reducerPath] const querySubState = state.queries[queryCacheKey] - const subscriptions = internalState.currentSubscriptions[queryCacheKey] + const subscriptions = currentSubscriptions.get(queryCacheKey) if (!querySubState || querySubState.status === QueryStatus.uninitialized) return @@ -73,7 +95,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ ) { const state = api.getState()[reducerPath] const querySubState = state.queries[queryCacheKey] - const subscriptions = internalState.currentSubscriptions[queryCacheKey] + const subscriptions = currentSubscriptions.get(queryCacheKey) if (!querySubState || querySubState.status === QueryStatus.uninitialized) return @@ -82,7 +104,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ findLowestPollingInterval(subscriptions) if (!Number.isFinite(lowestPollingInterval)) return - const currentPoll = currentPolls[queryCacheKey] + const currentPoll = currentPolls.get(queryCacheKey) if (currentPoll?.timeout) { clearTimeout(currentPoll.timeout) @@ -91,7 +113,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ const nextPollTimestamp = Date.now() + lowestPollingInterval - currentPolls[queryCacheKey] = { + currentPolls.set(queryCacheKey, { nextPollTimestamp, pollingInterval: lowestPollingInterval, timeout: setTimeout(() => { @@ -100,7 +122,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ } startNextPoll({ queryCacheKey }, api) }, lowestPollingInterval), - } + }) } function updatePollingInterval( @@ -109,7 +131,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ ) { const state = api.getState()[reducerPath] const querySubState = state.queries[queryCacheKey] - const subscriptions = internalState.currentSubscriptions[queryCacheKey] + const subscriptions = currentSubscriptions.get(queryCacheKey) if (!querySubState || querySubState.status === QueryStatus.uninitialized) { return @@ -117,12 +139,21 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ const { lowestPollingInterval } = findLowestPollingInterval(subscriptions) + // HACK add extra data to track how many times this has been called in tests + // yes we're mutating a nonexistent field on a Map here + if (process.env.NODE_ENV === 'test') { + const updateCounters = ((currentPolls as any).pollUpdateCounters ??= {}) + updateCounters[queryCacheKey] ??= 0 + updateCounters[queryCacheKey]++ + } + if (!Number.isFinite(lowestPollingInterval)) { cleanupPollForKey(queryCacheKey) return } - const currentPoll = currentPolls[queryCacheKey] + const currentPoll = currentPolls.get(queryCacheKey) + const nextPollTimestamp = Date.now() + lowestPollingInterval if (!currentPoll || nextPollTimestamp < currentPoll.nextPollTimestamp) { @@ -131,30 +162,33 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ } function cleanupPollForKey(key: string) { - const existingPoll = currentPolls[key] + const existingPoll = currentPolls.get(key) if (existingPoll?.timeout) { clearTimeout(existingPoll.timeout) } - delete currentPolls[key] + currentPolls.delete(key) } function clearPolls() { - for (const key of Object.keys(currentPolls)) { + for (const key of currentPolls.keys()) { cleanupPollForKey(key) } } - function findLowestPollingInterval(subscribers: Subscribers = {}) { + function findLowestPollingInterval( + subscribers: SubscribersInternal = new Map(), + ) { let skipPollingIfUnfocused: boolean | undefined = false let lowestPollingInterval = Number.POSITIVE_INFINITY - for (let key in subscribers) { - if (!!subscribers[key].pollingInterval) { + + for (const entry of subscribers.values()) { + if (!!entry.pollingInterval) { lowestPollingInterval = Math.min( - subscribers[key].pollingInterval!, + entry.pollingInterval!, lowestPollingInterval, ) skipPollingIfUnfocused = - subscribers[key].skipPollingIfUnfocused || skipPollingIfUnfocused + entry.skipPollingIfUnfocused || skipPollingIfUnfocused } } diff --git a/packages/toolkit/src/query/core/buildMiddleware/types.ts b/packages/toolkit/src/query/core/buildMiddleware/types.ts index a95acd2bf9..e3ffe3340e 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/types.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/types.ts @@ -1,6 +1,7 @@ import type { Action, AsyncThunkAction, + Dispatch, Middleware, MiddlewareAPI, ThunkAction, @@ -16,6 +17,7 @@ import type { QueryStatus, QuerySubState, RootState, + SubscriptionInternalState, SubscriptionState, } from '../apiState' import type { @@ -25,18 +27,42 @@ import type { QueryThunkArg, ThunkResult, } from '../buildThunks' -import type { QueryActionCreatorResult } from '../buildInitiate' +import type { + InfiniteQueryActionCreatorResult, + MutationActionCreatorResult, + QueryActionCreatorResult, +} from '../buildInitiate' import type { AllSelectors } from '../buildSelectors' export type QueryStateMeta = Record export type TimeoutId = ReturnType +type QueryPollState = { + nextPollTimestamp: number + timeout?: TimeoutId + pollingInterval: number +} + export interface InternalMiddlewareState { - currentSubscriptions: SubscriptionState + currentSubscriptions: SubscriptionInternalState + currentPolls: Map + runningQueries: Map< + Dispatch, + Record< + string, + | QueryActionCreatorResult + | InfiniteQueryActionCreatorResult + | undefined + > + > + runningMutations: Map< + Dispatch, + Record | undefined> + > } export interface SubscriptionSelectors { - getSubscriptions: () => SubscriptionState + getSubscriptions: () => SubscriptionInternalState getSubscriptionCount: (queryCacheKey: string) => number isRequestSubscribed: (queryCacheKey: string, requestId: string) => boolean } @@ -54,6 +80,11 @@ export interface BuildMiddlewareInput< api: Api assertTagType: AssertTagTypes selectors: AllSelectors + getRunningQueryThunk: ( + endpointName: string, + queryArgs: any, + ) => (dispatch: Dispatch) => QueryActionCreatorResult | undefined + internalState: InternalMiddlewareState } export type SubMiddlewareApi = MiddlewareAPI< @@ -72,6 +103,10 @@ export interface BuildSubMiddlewareInput ): ThunkAction, any, any, UnknownAction> isThisApiSliceAction: (action: Action) => boolean selectors: AllSelectors + mwApi: MiddlewareAPI< + ThunkDispatch, + RootState + > } export type SubMiddlewareBuilder = ( diff --git a/packages/toolkit/src/query/core/buildMiddleware/windowEventHandling.ts b/packages/toolkit/src/query/core/buildMiddleware/windowEventHandling.ts index e3c17b6513..53b5c87230 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/windowEventHandling.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/windowEventHandling.ts @@ -35,23 +35,19 @@ export const buildWindowEventHandler: InternalHandlerBuilder = ({ const subscriptions = internalState.currentSubscriptions context.batch(() => { - for (const queryCacheKey of Object.keys(subscriptions)) { + for (const queryCacheKey of subscriptions.keys()) { const querySubState = queries[queryCacheKey] - const subscriptionSubState = subscriptions[queryCacheKey] + const subscriptionSubState = subscriptions.get(queryCacheKey) if (!subscriptionSubState || !querySubState) continue + const values = [...subscriptionSubState.values()] const shouldRefetch = - Object.values(subscriptionSubState).some( - (sub) => sub[type] === true, - ) || - (Object.values(subscriptionSubState).every( - (sub) => sub[type] === undefined, - ) && - state.config[type]) + values.some((sub) => sub[type] === true) || + (values.every((sub) => sub[type] === undefined) && state.config[type]) if (shouldRefetch) { - if (countObjectKeys(subscriptionSubState) === 0) { + if (subscriptionSubState.size === 0) { api.dispatch( removeQueryResult({ queryCacheKey: queryCacheKey as QueryCacheKey, diff --git a/packages/toolkit/src/query/core/buildSlice.ts b/packages/toolkit/src/query/core/buildSlice.ts index b0f09693ef..d9f720c33a 100644 --- a/packages/toolkit/src/query/core/buildSlice.ts +++ b/packages/toolkit/src/query/core/buildSlice.ts @@ -519,13 +519,12 @@ export function buildSlice({ providedTags as FullTagDescription[] } }, - prepare: - prepareAutoBatched< - Array<{ - queryCacheKey: QueryCacheKey - providedTags: readonly FullTagDescription[] - }> - >(), + prepare: prepareAutoBatched< + Array<{ + queryCacheKey: QueryCacheKey + providedTags: readonly FullTagDescription[] + }> + >(), }, }, extraReducers(builder) { @@ -538,7 +537,9 @@ export function buildSlice({ ) .addMatcher(hasRehydrationInfo, (draft, action) => { const { provided } = extractRehydrationInfo(action)! - for (const [type, incomingTags] of Object.entries(provided)) { + for (const [type, incomingTags] of Object.entries( + provided.tags ?? {}, + )) { for (const [id, cacheKeys] of Object.entries(incomingTags)) { const subscribedQueries = ((draft.tags[type] ??= {})[ id || '__internal_without_id' @@ -549,6 +550,7 @@ export function buildSlice({ if (!alreadySubscribed) { subscribedQueries.push(queryCacheKey) } + draft.keys[queryCacheKey] = provided.keys[queryCacheKey] } } } diff --git a/packages/toolkit/src/query/core/buildThunks.ts b/packages/toolkit/src/query/core/buildThunks.ts index 048c98b9f3..f8e3b33a54 100644 --- a/packages/toolkit/src/query/core/buildThunks.ts +++ b/packages/toolkit/src/query/core/buildThunks.ts @@ -31,6 +31,7 @@ import type { SchemaFailureConverter, SchemaFailureHandler, SchemaFailureInfo, + SchemaType, } from '../endpointDefinitions' import { calculateProvidedBy, @@ -68,7 +69,11 @@ import { isRejectedWithValue, SHOULD_AUTOBATCH, } from './rtkImports' -import { parseWithSchema, NamedSchemaError } from '../standardSchema' +import { + parseWithSchema, + NamedSchemaError, + shouldSkip, +} from '../standardSchema' export type BuildThunksApiEndpointQuery< Definition extends QueryDefinition, @@ -346,7 +351,7 @@ export function buildThunks< selectors: AllSelectors onSchemaFailure: SchemaFailureHandler | undefined catchSchemaFailure: SchemaFailureConverter | undefined - skipSchemaValidation: boolean | undefined + skipSchemaValidation: boolean | SchemaType[] | undefined }) { type State = RootState @@ -505,10 +510,7 @@ export function buildThunks< endpointDefinition try { - let transformResponse = getTransformCallbackForEndpoint( - endpointDefinition, - 'transformResponse', - ) + let transformResponse: TransformCallback = defaultTransformResponse const baseQueryApi = { signal, @@ -569,7 +571,7 @@ export function buildThunks< const { extraOptions, argSchema, rawResponseSchema, responseSchema } = endpointDefinition - if (argSchema && !skipSchemaValidation) { + if (argSchema && !shouldSkip(skipSchemaValidation, 'arg')) { finalQueryArg = await parseWithSchema( argSchema, finalQueryArg, @@ -582,6 +584,13 @@ export function buildThunks< // upsertQueryData relies on this to pass in the user-provided value result = forceQueryFn() } else if (endpointDefinition.query) { + // We should only run `transformResponse` when the endpoint has a `query` method, + // and we're not doing an `upsertQueryData`. + transformResponse = getTransformCallbackForEndpoint( + endpointDefinition, + 'transformResponse', + ) + result = await baseQuery( endpointDefinition.query(finalQueryArg as any), baseQueryApi, @@ -633,7 +642,10 @@ export function buildThunks< let { data } = result - if (rawResponseSchema && !skipSchemaValidation) { + if ( + rawResponseSchema && + !shouldSkip(skipSchemaValidation, 'rawResponse') + ) { data = await parseWithSchema( rawResponseSchema, result.data, @@ -648,7 +660,7 @@ export function buildThunks< finalQueryArg, ) - if (responseSchema && !skipSchemaValidation) { + if (responseSchema && !shouldSkip(skipSchemaValidation, 'response')) { transformedResponse = await parseWithSchema( responseSchema, transformedResponse, @@ -751,7 +763,11 @@ export function buildThunks< finalQueryReturnValue = await executeRequest(arg.originalArgs) } - if (metaSchema && !skipSchemaValidation && finalQueryReturnValue.meta) { + if ( + metaSchema && + !shouldSkip(skipSchemaValidation, 'meta') && + finalQueryReturnValue.meta + ) { finalQueryReturnValue.meta = await parseWithSchema( metaSchema, finalQueryReturnValue.meta, @@ -781,7 +797,10 @@ export function buildThunks< let { value, meta } = caughtError try { - if (rawErrorResponseSchema && !skipSchemaValidation) { + if ( + rawErrorResponseSchema && + !shouldSkip(skipSchemaValidation, 'rawErrorResponse') + ) { value = await parseWithSchema( rawErrorResponseSchema, value, @@ -790,7 +809,7 @@ export function buildThunks< ) } - if (metaSchema && !skipSchemaValidation) { + if (metaSchema && !shouldSkip(skipSchemaValidation, 'meta')) { meta = await parseWithSchema(metaSchema, meta, 'metaSchema', meta) } let transformedErrorResponse = await transformErrorResponse( @@ -798,7 +817,10 @@ export function buildThunks< meta, arg.originalArgs, ) - if (errorResponseSchema && !skipSchemaValidation) { + if ( + errorResponseSchema && + !shouldSkip(skipSchemaValidation, 'errorResponse') + ) { transformedErrorResponse = await parseWithSchema( errorResponseSchema, transformedErrorResponse, diff --git a/packages/toolkit/src/query/core/module.ts b/packages/toolkit/src/query/core/module.ts index 41dafa1e7b..e9bfa08b76 100644 --- a/packages/toolkit/src/query/core/module.ts +++ b/packages/toolkit/src/query/core/module.ts @@ -71,6 +71,7 @@ import type { import { buildThunks } from './buildThunks' import { createSelector as _createSelector } from './rtkImports' import { onFocus, onFocusLost, onOffline, onOnline } from './setupListeners' +import type { InternalMiddlewareState } from './buildMiddleware/types' /** * `ifOlderThan` - (default: `false` | `number`) - _number is value in seconds_ @@ -618,19 +619,12 @@ export const coreModule = ({ }) safeAssign(api.internalActions, sliceActions) - const { middleware, actions: middlewareActions } = buildMiddleware({ - reducerPath, - context, - queryThunk, - mutationThunk, - infiniteQueryThunk, - api, - assertTagType, - selectors, - }) - safeAssign(api.util, middlewareActions) - - safeAssign(api, { reducer: reducer as any, middleware }) + const internalState: InternalMiddlewareState = { + currentSubscriptions: new Map(), + currentPolls: new Map(), + runningQueries: new Map(), + runningMutations: new Map(), + } const { buildInitiateQuery, @@ -647,6 +641,7 @@ export const coreModule = ({ api, serializeQueryArgs: serializeQueryArgs as any, context, + internalState, }) safeAssign(api.util, { @@ -656,6 +651,22 @@ export const coreModule = ({ getRunningQueriesThunk, }) + const { middleware, actions: middlewareActions } = buildMiddleware({ + reducerPath, + context, + queryThunk, + mutationThunk, + infiniteQueryThunk, + api, + assertTagType, + selectors, + getRunningQueryThunk, + internalState, + }) + safeAssign(api.util, middlewareActions) + + safeAssign(api, { reducer: reducer as any, middleware }) + return { name: coreModuleName, injectEndpoint(endpointName, definition) { diff --git a/packages/toolkit/src/query/createApi.ts b/packages/toolkit/src/query/createApi.ts index 0bf4bb4a6d..e650c64f43 100644 --- a/packages/toolkit/src/query/createApi.ts +++ b/packages/toolkit/src/query/createApi.ts @@ -8,6 +8,7 @@ import type { EndpointDefinitions, SchemaFailureConverter, SchemaFailureHandler, + SchemaType, } from './endpointDefinitions' import { DefinitionType, @@ -108,9 +109,9 @@ export interface CreateApiOptions< /** * Defaults to `60` _(this value is in seconds)_. This is how long RTK Query will keep your data cached for **after** the last component unsubscribes. For example, if you query an endpoint, then unmount the component, then mount another component that makes the same request within the given time frame, the most recent value will be served from the cache. * + * @example * ```ts * // codeblock-meta title="keepUnusedDataFor example" - * * import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react' * interface Post { * id: number @@ -122,12 +123,12 @@ export interface CreateApiOptions< * baseQuery: fetchBaseQuery({ baseUrl: '/' }), * endpoints: (build) => ({ * getPosts: build.query({ - * query: () => 'posts', - * // highlight-start - * keepUnusedDataFor: 5 - * // highlight-end + * query: () => 'posts' * }) - * }) + * }), + * // highlight-start + * keepUnusedDataFor: 5 + * // highlight-end * }) * ``` */ @@ -280,6 +281,8 @@ export interface CreateApiOptions< * * If set to `true`, will skip schema validation for all endpoints, unless overridden by the endpoint. * + * Can be overridden for specific schemas by passing an array of schema types to skip. + * * @example * ```ts * // codeblock-meta no-transpile @@ -288,7 +291,7 @@ export interface CreateApiOptions< * * const api = createApi({ * baseQuery: fetchBaseQuery({ baseUrl: '/' }), - * skipSchemaValidation: process.env.NODE_ENV === "test", // skip schema validation in tests, since we'll be mocking the response + * skipSchemaValidation: process.env.NODE_ENV === "test" ? ["response"] : false, // skip schema validation for response in tests, since we'll be mocking the response * endpoints: (build) => ({ * getPost: build.query({ * query: ({ id }) => `/post/${id}`, @@ -298,7 +301,7 @@ export interface CreateApiOptions< * }) * ``` */ - skipSchemaValidation?: boolean + skipSchemaValidation?: boolean | SchemaType[] } export type CreateApi = { diff --git a/packages/toolkit/src/query/endpointDefinitions.ts b/packages/toolkit/src/query/endpointDefinitions.ts index 3e54bb2f42..9b092c9a5a 100644 --- a/packages/toolkit/src/query/endpointDefinitions.ts +++ b/packages/toolkit/src/query/endpointDefinitions.ts @@ -237,12 +237,26 @@ export type EndpointDefinitionWithQueryFn< rawErrorResponseSchema?: never } -type BaseEndpointTypes = { +type BaseEndpointTypes< + QueryArg, + BaseQuery extends BaseQueryFn, + ResultType, + RawResultType, +> = { QueryArg: QueryArg BaseQuery: BaseQuery ResultType: ResultType + RawResultType: RawResultType } +export type SchemaType = + | 'arg' + | 'rawResponse' + | 'response' + | 'rawErrorResponse' + | 'errorResponse' + | 'meta' + interface CommonEndpointDefinition< QueryArg, BaseQuery extends BaseQueryFn, @@ -421,6 +435,8 @@ interface CommonEndpointDefinition< * If set to `true`, will skip schema validation for this endpoint. * Overrides the global setting. * + * Can be overridden for specific schemas by passing an array of schema types to skip. + * * @example * ```ts * // codeblock-meta no-transpile @@ -433,13 +449,13 @@ interface CommonEndpointDefinition< * getPost: build.query({ * query: ({ id }) => `/post/${id}`, * responseSchema: v.object({ id: v.number(), name: v.string() }), - * skipSchemaValidation: process.env.NODE_ENV === "test", // skip schema validation in tests, since we'll be mocking the response + * skipSchemaValidation: process.env.NODE_ENV === "test" ? ["response"] : false, // skip schema validation for response in tests, since we'll be mocking the response * }), * }) * }) * ``` */ - skipSchemaValidation?: boolean + skipSchemaValidation?: boolean | SchemaType[] } export type BaseEndpointDefinition< @@ -519,7 +535,8 @@ type QueryTypes< TagTypes extends string, ResultType, ReducerPath extends string = string, -> = BaseEndpointTypes & { + RawResultType extends BaseQueryResult = BaseQueryResult, +> = BaseEndpointTypes & { /** * The endpoint definition type. To be used with some internal generic types. * @example @@ -547,6 +564,7 @@ export interface QueryExtraOptions< QueryArg, BaseQuery extends BaseQueryFn, ReducerPath extends string = string, + RawResultType extends BaseQueryResult = BaseQueryResult, > extends CacheLifecycleQueryExtraOptions< ResultType, QueryArg, @@ -791,7 +809,14 @@ export interface QueryExtraOptions< /** * All of these are `undefined` at runtime, purely to be used in TypeScript declarations! */ - Types?: QueryTypes + Types?: QueryTypes< + QueryArg, + BaseQuery, + TagTypes, + ResultType, + ReducerPath, + RawResultType + > } export type QueryDefinition< @@ -802,7 +827,14 @@ export type QueryDefinition< ReducerPath extends string = string, RawResultType extends BaseQueryResult = BaseQueryResult, > = BaseEndpointDefinition & - QueryExtraOptions + QueryExtraOptions< + TagTypes, + ResultType, + QueryArg, + BaseQuery, + ReducerPath, + RawResultType + > export type InfiniteQueryTypes< QueryArg, @@ -811,7 +843,8 @@ export type InfiniteQueryTypes< TagTypes extends string, ResultType, ReducerPath extends string = string, -> = BaseEndpointTypes & { + RawResultType extends BaseQueryResult = BaseQueryResult, +> = BaseEndpointTypes & { /** * The endpoint definition type. To be used with some internal generic types. * @example @@ -838,6 +871,7 @@ export interface InfiniteQueryExtraOptions< PageParam, BaseQuery extends BaseQueryFn, ReducerPath extends string = string, + RawResultType extends BaseQueryResult = BaseQueryResult, > extends CacheLifecycleInfiniteQueryExtraOptions< InfiniteData, QueryArg, @@ -987,7 +1021,8 @@ export interface InfiniteQueryExtraOptions< BaseQuery, TagTypes, ResultType, - ReducerPath + ReducerPath, + RawResultType > } @@ -1013,7 +1048,8 @@ export type InfiniteQueryDefinition< QueryArg, PageParam, BaseQuery, - ReducerPath + ReducerPath, + RawResultType > type MutationTypes< @@ -1022,7 +1058,8 @@ type MutationTypes< TagTypes extends string, ResultType, ReducerPath extends string = string, -> = BaseEndpointTypes & { + RawResultType extends BaseQueryResult = BaseQueryResult, +> = BaseEndpointTypes & { /** * The endpoint definition type. To be used with some internal generic types. * @example @@ -1050,6 +1087,7 @@ export interface MutationExtraOptions< QueryArg, BaseQuery extends BaseQueryFn, ReducerPath extends string = string, + RawResultType extends BaseQueryResult = BaseQueryResult, > extends CacheLifecycleMutationExtraOptions< ResultType, QueryArg, @@ -1124,7 +1162,14 @@ export interface MutationExtraOptions< /** * All of these are `undefined` at runtime, purely to be used in TypeScript declarations! */ - Types?: MutationTypes + Types?: MutationTypes< + QueryArg, + BaseQuery, + TagTypes, + ResultType, + ReducerPath, + RawResultType + > } export type MutationDefinition< @@ -1135,7 +1180,14 @@ export type MutationDefinition< ReducerPath extends string = string, RawResultType extends BaseQueryResult = BaseQueryResult, > = BaseEndpointDefinition & - MutationExtraOptions + MutationExtraOptions< + TagTypes, + ResultType, + QueryArg, + BaseQuery, + ReducerPath, + RawResultType + > export type EndpointDefinition< QueryArg, diff --git a/packages/toolkit/src/query/index.ts b/packages/toolkit/src/query/index.ts index 96a06166bf..d76c95efbf 100644 --- a/packages/toolkit/src/query/index.ts +++ b/packages/toolkit/src/query/index.ts @@ -49,6 +49,7 @@ export type { SchemaFailureHandler, SchemaFailureConverter, SchemaFailureInfo, + SchemaType, } from './endpointDefinitions' export { fetchBaseQuery } from './fetchBaseQuery' export type { diff --git a/packages/toolkit/src/query/react/buildHooks.ts b/packages/toolkit/src/query/react/buildHooks.ts index 249df6b071..37b081454d 100644 --- a/packages/toolkit/src/query/react/buildHooks.ts +++ b/packages/toolkit/src/query/react/buildHooks.ts @@ -39,11 +39,7 @@ import type { TSHelpersNoInfer, TSHelpersOverride, } from '@reduxjs/toolkit/query' -import { - defaultSerializeQueryArgs, - QueryStatus, - skipToken, -} from '@reduxjs/toolkit/query' +import { QueryStatus, skipToken } from '@reduxjs/toolkit/query' import type { DependencyList } from 'react' import { useCallback, @@ -1238,10 +1234,10 @@ type UseInfiniteQueryStateBaseResult< * Query is currently in "error" state. */ isError: false - hasNextPage: false - hasPreviousPage: false - isFetchingNextPage: false - isFetchingPreviousPage: false + hasNextPage: boolean + hasPreviousPage: boolean + isFetchingNextPage: boolean + isFetchingPreviousPage: boolean } type UseInfiniteQueryStateDefaultResult< @@ -1644,17 +1640,7 @@ export function buildHooks({ subscriptionSelectorsRef.current = returnedValue as unknown as SubscriptionSelectors } - const stableArg = useStableQueryArgs( - skip ? skipToken : arg, - // Even if the user provided a per-endpoint `serializeQueryArgs` with - // a consistent return value, _here_ we want to use the default behavior - // so we can tell if _anything_ actually changed. Otherwise, we can end up - // with a case where the query args did change but the serialization doesn't, - // and then we never try to initiate a refetch. - defaultSerializeQueryArgs, - context.endpointDefinitions[endpointName], - endpointName, - ) + const stableArg = useStableQueryArgs(skip ? skipToken : arg) const stableSubscriptionOptions = useShallowStableValue({ refetchOnReconnect, refetchOnFocus, @@ -1764,12 +1750,7 @@ export function buildHooks({ QueryDefinition, Definitions > - const stableArg = useStableQueryArgs( - skip ? skipToken : arg, - serializeQueryArgs, - context.endpointDefinitions[endpointName], - endpointName, - ) + const stableArg = useStableQueryArgs(skip ? skipToken : arg) type ApiRootState = Parameters>[0] @@ -2053,17 +2034,7 @@ export function buildHooks({ usePromiseRefUnsubscribeOnUnmount(promiseRef) - const stableArg = useStableQueryArgs( - options.skip ? skipToken : arg, - // Even if the user provided a per-endpoint `serializeQueryArgs` with - // a consistent return value, _here_ we want to use the default behavior - // so we can tell if _anything_ actually changed. Otherwise, we can end up - // with a case where the query args did change but the serialization doesn't, - // and then we never try to initiate a refetch. - defaultSerializeQueryArgs, - context.endpointDefinitions[endpointName], - endpointName, - ) + const stableArg = useStableQueryArgs(options.skip ? skipToken : arg) const refetch = useCallback( () => refetchOrErrorIfUnmounted(promiseRef), diff --git a/packages/toolkit/src/query/react/useSerializedStableValue.ts b/packages/toolkit/src/query/react/useSerializedStableValue.ts index 95bc62af78..5a5d9f936f 100644 --- a/packages/toolkit/src/query/react/useSerializedStableValue.ts +++ b/packages/toolkit/src/query/react/useSerializedStableValue.ts @@ -1,31 +1,17 @@ import { useEffect, useRef, useMemo } from 'react' -import type { SerializeQueryArgs } from '@reduxjs/toolkit/query' -import type { EndpointDefinition } from '@reduxjs/toolkit/query' +import { copyWithStructuralSharing } from '@reduxjs/toolkit/query' -export function useStableQueryArgs( - queryArgs: T, - serialize: SerializeQueryArgs, - endpointDefinition: EndpointDefinition, - endpointName: string, -) { - const incoming = useMemo( - () => ({ - queryArgs, - serialized: - typeof queryArgs == 'object' - ? serialize({ queryArgs, endpointDefinition, endpointName }) - : queryArgs, - }), - [queryArgs, serialize, endpointDefinition, endpointName], +export function useStableQueryArgs(queryArgs: T) { + const cache = useRef(queryArgs) + const copy = useMemo( + () => copyWithStructuralSharing(cache.current, queryArgs), + [queryArgs], ) - const cache = useRef(incoming) useEffect(() => { - if (cache.current.serialized !== incoming.serialized) { - cache.current = incoming + if (cache.current !== copy) { + cache.current = copy } - }, [incoming]) + }, [copy]) - return cache.current.serialized === incoming.serialized - ? cache.current.queryArgs - : queryArgs + return copy } diff --git a/packages/toolkit/src/query/standardSchema.ts b/packages/toolkit/src/query/standardSchema.ts index 75dd22f0d5..2c42a2b44e 100644 --- a/packages/toolkit/src/query/standardSchema.ts +++ b/packages/toolkit/src/query/standardSchema.ts @@ -1,21 +1,30 @@ import type { StandardSchemaV1 } from '@standard-schema/spec' import { SchemaError } from '@standard-schema/utils' +import type { SchemaType } from './endpointDefinitions' export class NamedSchemaError extends SchemaError { constructor( issues: readonly StandardSchemaV1.Issue[], public readonly value: any, - public readonly schemaName: string, + public readonly schemaName: `${SchemaType}Schema`, public readonly _bqMeta: any, ) { super(issues) } } +export const shouldSkip = ( + skipSchemaValidation: boolean | SchemaType[] | undefined, + schemaName: SchemaType, +) => + Array.isArray(skipSchemaValidation) + ? skipSchemaValidation.includes(schemaName) + : !!skipSchemaValidation + export async function parseWithSchema( schema: Schema, data: unknown, - schemaName: string, + schemaName: `${SchemaType}Schema`, bqMeta: any, ): Promise> { const result = await schema['~standard'].validate(data) diff --git a/packages/toolkit/src/query/tests/buildHooks.test.tsx b/packages/toolkit/src/query/tests/buildHooks.test.tsx index c371e5efaa..115a475a9f 100644 --- a/packages/toolkit/src/query/tests/buildHooks.test.tsx +++ b/packages/toolkit/src/query/tests/buildHooks.test.tsx @@ -1068,7 +1068,7 @@ describe('hooks tests', () => { const checkNumSubscriptions = (arg: string, count: number) => { const subscriptions = getSubscriptions() - const cacheKeyEntry = subscriptions[arg] + const cacheKeyEntry = subscriptions.get(arg) if (cacheKeyEntry) { const subscriptionCount = Object.keys(cacheKeyEntry) //getSubscriptionCount(arg) @@ -1190,6 +1190,87 @@ describe('hooks tests', () => { ).toBe(-1) }) + test('query thunk should be aborted when component unmounts and cache entry is removed', async () => { + let abortSignalFromQueryFn: AbortSignal | undefined + + const pokemonApi = createApi({ + baseQuery: fetchBaseQuery({ baseUrl: 'https://pokeapi.co/api/v2/' }), + endpoints: (builder) => ({ + getTest: builder.query({ + async queryFn(arg, { signal }) { + abortSignalFromQueryFn = signal + + // Simulate a long-running request that should be aborted + await new Promise((resolve, reject) => { + const timeout = setTimeout(resolve, 5000) + + signal.addEventListener('abort', () => { + clearTimeout(timeout) + reject(new Error('Aborted')) + }) + }) + + return { data: 'data!' } + }, + keepUnusedDataFor: 0.01, // Very short timeout (10ms) + }), + }), + }) + + const storeRef = setupApiStore(pokemonApi, undefined, { + withoutTestLifecycles: true, + }) + + function TestComponent() { + const { data, isFetching } = pokemonApi.endpoints.getTest.useQuery(1) + + return ( +
+
{String(isFetching)}
+
{data || 'no data'}
+
+ ) + } + + function App() { + const [showComponent, setShowComponent] = useState(true) + + return ( +
+ {showComponent && } + +
+ ) + } + + render(, { wrapper: storeRef.wrapper }) + + // Wait for the query to start + await waitFor(() => + expect(screen.getByTestId('isFetching').textContent).toBe('true'), + ) + + // Verify we have an abort signal + expect(abortSignalFromQueryFn).toBeDefined() + expect(abortSignalFromQueryFn!.aborted).toBe(false) + + // Unmount the component + fireEvent.click(screen.getByTestId('unmount')) + + // Wait for the cache entry to be removed (keepUnusedDataFor: 0.01s = 10ms) + await act(async () => { + await delay(100) + }) + + // The abort signal should now be aborted + expect(abortSignalFromQueryFn!.aborted).toBe(true) + }) + describe('Hook middleware requirements', () => { const consoleErrorSpy = vi .spyOn(console, 'error') @@ -1898,7 +1979,6 @@ describe('hooks tests', () => { const checkNumQueries = (count: number) => { const cacheEntries = Object.keys(storeRef.store.getState().api.queries) const queries = cacheEntries.length - //console.log('queries', queries, storeRef.store.getState().api.queries) expect(queries).toBe(count) } @@ -2282,40 +2362,51 @@ describe('hooks tests', () => { expect(numRequests).toBe(1) }) - test('useInfiniteQuery hook does not fetch when the skip token is set', async () => { - function Pokemon() { - const [value, setValue] = useState(0) - - const { isFetching } = pokemonApi.useGetInfinitePokemonInfiniteQuery( - 'fire', - { - skip: value < 1, - }, - ) - getRenderCount = useRenderCounter() + test.each([ + ['skip token', true], + ['skip option', false], + ])( + 'useInfiniteQuery hook does not fetch when skipped via %s', + async (_, useSkipToken) => { + function Pokemon() { + const [value, setValue] = useState(0) + + const shouldFetch = value > 0 + + const arg = shouldFetch || !useSkipToken ? 'fire' : skipToken + const skip = useSkipToken ? undefined : shouldFetch ? undefined : true + + const { isFetching } = pokemonApi.useGetInfinitePokemonInfiniteQuery( + arg, + { + skip, + }, + ) + getRenderCount = useRenderCounter() - return ( -
-
{String(isFetching)}
- -
- ) - } + return ( +
+
{String(isFetching)}
+ +
+ ) + } - render(, { wrapper: storeRef.wrapper }) - expect(getRenderCount()).toBe(1) + render(, { wrapper: storeRef.wrapper }) + expect(getRenderCount()).toBe(1) - await waitFor(() => - expect(screen.getByTestId('isFetching').textContent).toBe('false'), - ) - fireEvent.click(screen.getByText('Increment value')) - await waitFor(() => - expect(screen.getByTestId('isFetching').textContent).toBe('true'), - ) - expect(getRenderCount()).toBe(2) - }) + await waitFor(() => + expect(screen.getByTestId('isFetching').textContent).toBe('false'), + ) + fireEvent.click(screen.getByText('Increment value')) + await waitFor(() => + expect(screen.getByTestId('isFetching').textContent).toBe('true'), + ) + expect(getRenderCount()).toBe(2) + }, + ) }) describe('useMutation', () => { @@ -3689,7 +3780,7 @@ describe('skip behavior', () => { expect(getSubscriptionCount('getUser(1)')).toBe(0) // also no subscription on `getUser(skipToken)` or similar: - expect(getSubscriptions()).toEqual({}) + expect(getSubscriptions().size).toBe(0) rerender([1]) @@ -3700,7 +3791,7 @@ describe('skip behavior', () => { expect(result.current).toMatchObject({ status: QueryStatus.fulfilled }) await waitMs(1) expect(getSubscriptionCount('getUser(1)')).toBe(1) - expect(getSubscriptions()).not.toEqual({}) + expect(getSubscriptions().size).toBe(1) rerender([skipToken]) @@ -3730,7 +3821,7 @@ describe('skip behavior', () => { expect(getSubscriptionCount('nestedValue')).toBe(0) // also no subscription on `getUser(skipToken)` or similar: - expect(getSubscriptions()).toEqual({}) + expect(getSubscriptions().size).toBe(0) rerender([{ param: { nested: 'nestedValue' } }]) @@ -3742,7 +3833,7 @@ describe('skip behavior', () => { await waitMs(1) expect(getSubscriptionCount('nestedValue')).toBe(1) - expect(getSubscriptions()).not.toEqual({}) + expect(getSubscriptions().size).toBe(1) rerender([skipToken]) diff --git a/packages/toolkit/src/query/tests/buildInitiate.test.tsx b/packages/toolkit/src/query/tests/buildInitiate.test.tsx index 08be5a3558..74d1d61e51 100644 --- a/packages/toolkit/src/query/tests/buildInitiate.test.tsx +++ b/packages/toolkit/src/query/tests/buildInitiate.test.tsx @@ -79,23 +79,11 @@ describe('calling initiate without a cache entry, with subscribe: false still re expect(isRequestSubscribed('increment(undefined)', promise.requestId)).toBe( false, ) - expect( - isRequestSubscribed( - 'increment(undefined)', - `${promise.requestId}_running`, - ), - ).toBe(true) await expect(promise).resolves.toMatchObject({ data: 0, status: 'fulfilled', }) - expect( - isRequestSubscribed( - 'increment(undefined)', - `${promise.requestId}_running`, - ), - ).toBe(false) }) test('rejected query', async () => { @@ -107,16 +95,10 @@ describe('calling initiate without a cache entry, with subscribe: false still re expect(isRequestSubscribed('failing(undefined)', promise.requestId)).toBe( false, ) - expect( - isRequestSubscribed('failing(undefined)', `${promise.requestId}_running`), - ).toBe(true) await expect(promise).resolves.toMatchObject({ status: 'rejected', }) - expect( - isRequestSubscribed('failing(undefined)', `${promise.requestId}_running`), - ).toBe(false) }) }) @@ -173,3 +155,130 @@ describe('calling initiate should have resulting queryCacheKey match baseQuery q ) }) }) + +describe('getRunningQueryThunk with multiple stores', () => { + test('should isolate running queries between different store instances using the same API', async () => { + // Create a shared API instance + const sharedApi = createApi({ + baseQuery: fakeBaseQuery(), + endpoints: (build) => ({ + testQuery: build.query({ + async queryFn(arg) { + // Add delay to ensure queries are running when we check + await new Promise((resolve) => setTimeout(resolve, 50)) + return { data: `result-${arg}` } + }, + }), + }), + }) + + // Create two separate stores using the same API instance + const store1 = setupApiStore(sharedApi, undefined, { + withoutTestLifecycles: true, + }).store + const store2 = setupApiStore(sharedApi, undefined, { + withoutTestLifecycles: true, + }).store + + // Start queries on both stores + const query1Promise = store1.dispatch( + sharedApi.endpoints.testQuery.initiate('arg1'), + ) + const query2Promise = store2.dispatch( + sharedApi.endpoints.testQuery.initiate('arg2'), + ) + + // Verify that getRunningQueryThunk returns the correct query for each store + const runningQuery1 = store1.dispatch( + sharedApi.util.getRunningQueryThunk('testQuery', 'arg1'), + ) + const runningQuery2 = store2.dispatch( + sharedApi.util.getRunningQueryThunk('testQuery', 'arg2'), + ) + + // Each store should only see its own running query + expect(runningQuery1).toBeDefined() + expect(runningQuery2).toBeDefined() + expect(runningQuery1?.requestId).toBe(query1Promise.requestId) + expect(runningQuery2?.requestId).toBe(query2Promise.requestId) + + // Cross-store queries should not be visible + const crossQuery1 = store1.dispatch( + sharedApi.util.getRunningQueryThunk('testQuery', 'arg2'), + ) + const crossQuery2 = store2.dispatch( + sharedApi.util.getRunningQueryThunk('testQuery', 'arg1'), + ) + + expect(crossQuery1).toBeUndefined() + expect(crossQuery2).toBeUndefined() + + // Wait for queries to complete + await Promise.all([query1Promise, query2Promise]) + + // After completion, getRunningQueryThunk should return undefined for both stores + const completedQuery1 = store1.dispatch( + sharedApi.util.getRunningQueryThunk('testQuery', 'arg1'), + ) + const completedQuery2 = store2.dispatch( + sharedApi.util.getRunningQueryThunk('testQuery', 'arg2'), + ) + + expect(completedQuery1).toBeUndefined() + expect(completedQuery2).toBeUndefined() + }) + + test('should handle same query args on different stores independently', async () => { + // Create a shared API instance + const sharedApi = createApi({ + baseQuery: fakeBaseQuery(), + endpoints: (build) => ({ + sameArgQuery: build.query({ + async queryFn(arg) { + await new Promise((resolve) => setTimeout(resolve, 50)) + return { data: `result-${arg}-${Math.random()}` } + }, + }), + }), + }) + + // Create two separate stores + const store1 = setupApiStore(sharedApi, undefined, { + withoutTestLifecycles: true, + }).store + const store2 = setupApiStore(sharedApi, undefined, { + withoutTestLifecycles: true, + }).store + + // Start the same query on both stores + const sameArg = 'shared-arg' + const query1Promise = store1.dispatch( + sharedApi.endpoints.sameArgQuery.initiate(sameArg), + ) + const query2Promise = store2.dispatch( + sharedApi.endpoints.sameArgQuery.initiate(sameArg), + ) + + // Both stores should see their own running query with the same cache key + const runningQuery1 = store1.dispatch( + sharedApi.util.getRunningQueryThunk('sameArgQuery', sameArg), + ) + const runningQuery2 = store2.dispatch( + sharedApi.util.getRunningQueryThunk('sameArgQuery', sameArg), + ) + + expect(runningQuery1).toBeDefined() + expect(runningQuery2).toBeDefined() + expect(runningQuery1?.requestId).toBe(query1Promise.requestId) + expect(runningQuery2?.requestId).toBe(query2Promise.requestId) + + // The request IDs should be different even though the cache key is the same + expect(runningQuery1?.requestId).not.toBe(runningQuery2?.requestId) + + // But the cache keys should be the same + expect(runningQuery1?.queryCacheKey).toBe(runningQuery2?.queryCacheKey) + + // Wait for completion + await Promise.all([query1Promise, query2Promise]) + }) +}) diff --git a/packages/toolkit/src/query/tests/buildMiddleware.test.tsx b/packages/toolkit/src/query/tests/buildMiddleware.test.tsx index 0a03396302..2abac64498 100644 --- a/packages/toolkit/src/query/tests/buildMiddleware.test.tsx +++ b/packages/toolkit/src/query/tests/buildMiddleware.test.tsx @@ -26,11 +26,13 @@ const api = createApi({ providesTags: ['Bread'], }), invalidateFruit: build.mutation({ - query: (fruit?: 'Banana' | 'Bread' | null) => ({ url: `invalidate/fruit/${fruit || ''}` }), + query: (fruit?: 'Banana' | 'Bread' | null) => ({ + url: `invalidate/fruit/${fruit || ''}`, + }), invalidatesTags(result, error, arg) { return [arg] - } - }) + }, + }), }), }) const { getBanana, getBread, invalidateFruit } = api.endpoints @@ -77,9 +79,11 @@ it('invalidates the specified tags', async () => { ) }) -it('invalidates tags correctly when null or undefined are provided as tags', async() =>{ +it('invalidates tags correctly when null or undefined are provided as tags', async () => { await storeRef.store.dispatch(getBanana.initiate(1)) - await storeRef.store.dispatch(api.util.invalidateTags([undefined, null, 'Banana'])) + await storeRef.store.dispatch( + api.util.invalidateTags([undefined, null, 'Banana']), + ) // Slight pause to let the middleware run and such await delay(20) @@ -96,41 +100,116 @@ it('invalidates tags correctly when null or undefined are provided as tags', asy expect(storeRef.store.getState().actions).toMatchSequence(...apiActions) }) - it.each([ - { tags: [undefined, null, 'Bread'] as Parameters['0'] }, - { tags: [undefined, null], }, { tags: [] }] -)('does not invalidate with tags=$tags if no query matches', async ({ tags }) => { - await storeRef.store.dispatch(getBanana.initiate(1)) - await storeRef.store.dispatch(api.util.invalidateTags(tags)) + { + tags: [undefined, null, 'Bread'] as Parameters< + typeof api.util.invalidateTags + >['0'], + }, + { tags: [undefined, null] }, + { tags: [] }, +])( + 'does not invalidate with tags=$tags if no query matches', + async ({ tags }) => { + await storeRef.store.dispatch(getBanana.initiate(1)) + await storeRef.store.dispatch(api.util.invalidateTags(tags)) + + // Slight pause to let the middleware run and such + await delay(20) + + const apiActions = [ + api.internalActions.middlewareRegistered.match, + getBanana.matchPending, + getBanana.matchFulfilled, + api.util.invalidateTags.match, + ] + + expect(storeRef.store.getState().actions).toMatchSequence(...apiActions) + }, +) - // Slight pause to let the middleware run and such - await delay(20) +it.each([ + { mutationArg: 'Bread' as 'Bread' | null | undefined }, + { mutationArg: undefined }, + { mutationArg: null }, +])( + 'does not invalidate queries when a mutation with tags=[$mutationArg] runs and does not match anything', + async ({ mutationArg }) => { + await storeRef.store.dispatch(getBanana.initiate(1)) + await storeRef.store.dispatch(invalidateFruit.initiate(mutationArg)) + + // Slight pause to let the middleware run and such + await delay(20) + + const apiActions = [ + api.internalActions.middlewareRegistered.match, + getBanana.matchPending, + getBanana.matchFulfilled, + invalidateFruit.matchPending, + invalidateFruit.matchFulfilled, + ] + + expect(storeRef.store.getState().actions).toMatchSequence(...apiActions) + }, +) + +it('correctly stringifies subscription state and dispatches subscriptionsUpdated', async () => { + // Create a fresh store for this test to avoid interference + const testStoreRef = setupApiStore( + api, + { + ...actionsReducer, + }, + { withoutListeners: true }, + ) - const apiActions = [ - api.internalActions.middlewareRegistered.match, - getBanana.matchPending, - getBanana.matchFulfilled, - api.util.invalidateTags.match, - ] + // Start multiple subscriptions + const subscription1 = testStoreRef.store.dispatch( + getBanana.initiate(1, { + subscriptionOptions: { pollingInterval: 1000 }, + }), + ) + const subscription2 = testStoreRef.store.dispatch( + getBanana.initiate(2, { + subscriptionOptions: { refetchOnFocus: true }, + }), + ) + const subscription3 = testStoreRef.store.dispatch( + api.endpoints.getBananas.initiate(), + ) - expect(storeRef.store.getState().actions).toMatchSequence(...apiActions) + // Wait for the subscriptions to be established + await Promise.all([subscription1, subscription2, subscription3]) + + // Wait for the subscription sync timer (500ms + buffer) + await delay(600) + + // Check the final subscription state in the store + const finalState = testStoreRef.store.getState() + const subscriptionState = finalState[api.reducerPath].subscriptions + + // Should have subscriptions for getBanana(1), getBanana(2), and getBananas() + expect(subscriptionState).toMatchObject({ + 'getBanana(1)': { + [subscription1.requestId]: { pollingInterval: 1000 }, + }, + 'getBanana(2)': { + [subscription2.requestId]: { refetchOnFocus: true }, + }, + 'getBananas(undefined)': { + [subscription3.requestId]: {}, + }, + }) + + // Verify the subscription entries have the expected structure + expect(Object.keys(subscriptionState)).toHaveLength(3) + expect(subscriptionState['getBanana(1)']?.[subscription1.requestId]).toEqual({ + pollingInterval: 1000, + }) + expect(subscriptionState['getBanana(2)']?.[subscription2.requestId]).toEqual({ + refetchOnFocus: true, + }) + expect( + subscriptionState['getBananas(undefined)']?.[subscription3.requestId], + ).toEqual({}) }) - -it.each([{ mutationArg: 'Bread' as "Bread" | null | undefined }, { mutationArg: undefined }, { mutationArg: null }])('does not invalidate queries when a mutation with tags=[$mutationArg] runs and does not match anything', async ({ mutationArg }) => { - await storeRef.store.dispatch(getBanana.initiate(1)) - await storeRef.store.dispatch(invalidateFruit.initiate(mutationArg)) - - // Slight pause to let the middleware run and such - await delay(20) - - const apiActions = [ - api.internalActions.middlewareRegistered.match, - getBanana.matchPending, - getBanana.matchFulfilled, - invalidateFruit.matchPending, - invalidateFruit.matchFulfilled, - ] - - expect(storeRef.store.getState().actions).toMatchSequence(...apiActions) -}) \ No newline at end of file diff --git a/packages/toolkit/src/query/tests/buildSlice.test.ts b/packages/toolkit/src/query/tests/buildSlice.test.ts index 72ac45b7a2..25456d159c 100644 --- a/packages/toolkit/src/query/tests/buildSlice.test.ts +++ b/packages/toolkit/src/query/tests/buildSlice.test.ts @@ -1,10 +1,15 @@ -import { createSlice } from '@reduxjs/toolkit' +import { createSlice, createAction } from '@reduxjs/toolkit' +import type { CombinedState } from '@reduxjs/toolkit/query' import { createApi } from '@reduxjs/toolkit/query' import { delay } from 'msw' import { setupApiStore } from '../../tests/utils/helpers' let shouldApiResponseSuccess = true +const rehydrateAction = createAction<{ api: CombinedState }>( + 'persist/REHYDRATE', +) + const baseQuery = (args?: any) => ({ data: args }) const api = createApi({ baseQuery, @@ -17,6 +22,12 @@ const api = createApi({ providesTags: (result) => (result?.success ? ['SUCCEED'] : ['FAILED']), }), }), + extractRehydrationInfo(action, { reducerPath }) { + if (rehydrateAction.match(action)) { + return action.payload?.[reducerPath] + } + return undefined + }, }) const { getUser } = api.endpoints @@ -114,6 +125,20 @@ describe('buildSlice', () => { api.util.selectInvalidatedBy(storeRef.store.getState(), ['FAILED']), ).toHaveLength(1) }) + + it('handles extractRehydrationInfo correctly', async () => { + await storeRef.store.dispatch(getUser.initiate(1)) + await storeRef.store.dispatch(getUser.initiate(2)) + + const stateWithUser = storeRef.store.getState() + + storeRef.store.dispatch(api.util.resetApiState()) + + storeRef.store.dispatch(rehydrateAction({ api: stateWithUser.api })) + + const rehydratedState = storeRef.store.getState() + expect(rehydratedState).toEqual(stateWithUser) + }) }) describe('`merge` callback', () => { diff --git a/packages/toolkit/src/query/tests/buildThunks.test.tsx b/packages/toolkit/src/query/tests/buildThunks.test.tsx index 786bce700a..f197afcc90 100644 --- a/packages/toolkit/src/query/tests/buildThunks.test.tsx +++ b/packages/toolkit/src/query/tests/buildThunks.test.tsx @@ -4,71 +4,115 @@ import { renderHook, waitFor } from '@testing-library/react' import { actionsReducer, withProvider } from '../../tests/utils/helpers' import type { BaseQueryApi } from '../baseQueryTypes' -test('handles a non-async baseQuery without error', async () => { - const baseQuery = (args?: any) => ({ data: args }) - const api = createApi({ - baseQuery, - endpoints: (build) => ({ - getUser: build.query({ - query(id) { - return { url: `user/${id}` } - }, +describe('baseline thunk behavior', () => { + test('handles a non-async baseQuery without error', async () => { + const baseQuery = (args?: any) => ({ data: args }) + const api = createApi({ + baseQuery, + endpoints: (build) => ({ + getUser: build.query({ + query(id) { + return { url: `user/${id}` } + }, + }), }), - }), - }) - const { getUser } = api.endpoints - const store = configureStore({ - reducer: { - [api.reducerPath]: api.reducer, - }, - middleware: (gDM) => gDM().concat(api.middleware), - }) - - const promise = store.dispatch(getUser.initiate(1)) - const { data } = await promise + }) + const { getUser } = api.endpoints + const store = configureStore({ + reducer: { + [api.reducerPath]: api.reducer, + }, + middleware: (gDM) => gDM().concat(api.middleware), + }) - expect(data).toEqual({ - url: 'user/1', - }) + const promise = store.dispatch(getUser.initiate(1)) + const { data } = await promise - const storeResult = getUser.select(1)(store.getState()) - expect(storeResult).toEqual({ - data: { + expect(data).toEqual({ url: 'user/1', - }, - endpointName: 'getUser', - isError: false, - isLoading: false, - isSuccess: true, - isUninitialized: false, - originalArgs: 1, - requestId: expect.any(String), - status: 'fulfilled', - startedTimeStamp: expect.any(Number), - fulfilledTimeStamp: expect.any(Number), + }) + + const storeResult = getUser.select(1)(store.getState()) + expect(storeResult).toEqual({ + data: { + url: 'user/1', + }, + endpointName: 'getUser', + isError: false, + isLoading: false, + isSuccess: true, + isUninitialized: false, + originalArgs: 1, + requestId: expect.any(String), + status: 'fulfilled', + startedTimeStamp: expect.any(Number), + fulfilledTimeStamp: expect.any(Number), + }) }) -}) -test('passes the extraArgument property to the baseQueryApi', async () => { - const baseQuery = (_args: any, api: BaseQueryApi) => ({ data: api.extra }) - const api = createApi({ - baseQuery, - endpoints: (build) => ({ - getUser: build.query({ - query: () => '', + test('passes the extraArgument property to the baseQueryApi', async () => { + const baseQuery = (_args: any, api: BaseQueryApi) => ({ data: api.extra }) + const api = createApi({ + baseQuery, + endpoints: (build) => ({ + getUser: build.query({ + query: () => '', + }), }), - }), + }) + const store = configureStore({ + reducer: { + [api.reducerPath]: api.reducer, + }, + middleware: (gDM) => + gDM({ thunk: { extraArgument: 'cakes' } }).concat(api.middleware), + }) + const { getUser } = api.endpoints + const { data } = await store.dispatch(getUser.initiate()) + expect(data).toBe('cakes') }) - const store = configureStore({ - reducer: { - [api.reducerPath]: api.reducer, - }, - middleware: (gDM) => - gDM({ thunk: { extraArgument: 'cakes' } }).concat(api.middleware), + + test('only triggers transformResponse when a query method is actually used', async () => { + const baseQuery = (args?: any) => ({ data: args }) + const transformResponse = vi.fn((response: any) => response) + const api = createApi({ + baseQuery, + endpoints: (build) => ({ + hasQuery: build.query({ + query: (arg) => 'test', + transformResponse, + }), + hasQueryFn: build.query( + // @ts-expect-error + { + queryFn: () => ({ data: 'test' }), + transformResponse, + }, + ), + }), + }) + + const store = configureStore({ + reducer: { + [api.reducerPath]: api.reducer, + }, + middleware: (gDM) => + gDM({ thunk: { extraArgument: 'cakes' } }).concat(api.middleware), + }) + + await store.dispatch(api.util.upsertQueryData('hasQuery', 'a', 'test')) + expect(transformResponse).not.toHaveBeenCalled() + + transformResponse.mockReset() + + await store.dispatch(api.endpoints.hasQuery.initiate('b')) + expect(transformResponse).toHaveBeenCalledTimes(1) + + transformResponse.mockReset() + + await store.dispatch(api.endpoints.hasQueryFn.initiate()) + expect(transformResponse).not.toHaveBeenCalled() }) - const { getUser } = api.endpoints - const { data } = await store.dispatch(getUser.initiate()) - expect(data).toBe('cakes') }) describe('re-triggering behavior on arg change', () => { diff --git a/packages/toolkit/src/query/tests/createApi.test-d.ts b/packages/toolkit/src/query/tests/createApi.test-d.ts index 96bcef5946..ad86da1473 100644 --- a/packages/toolkit/src/query/tests/createApi.test-d.ts +++ b/packages/toolkit/src/query/tests/createApi.test-d.ts @@ -498,6 +498,7 @@ describe('type tests', () => { id: number }>() expectTypeOf(api.endpoints.query.Types.ResultType).toEqualTypeOf() + expectTypeOf(api.endpoints.query.Types.RawResultType).toBeAny() expectTypeOf(api.endpoints.query2.Types.QueryArg).toEqualTypeOf<{ id: number @@ -505,11 +506,15 @@ describe('type tests', () => { expectTypeOf( api.endpoints.query2.Types.ResultType, ).toEqualTypeOf() + expectTypeOf(api.endpoints.query2.Types.RawResultType).toBeAny() expectTypeOf(api.endpoints.query3.Types.QueryArg).toEqualTypeOf() expectTypeOf(api.endpoints.query3.Types.ResultType).toEqualTypeOf< EntityState >() + expectTypeOf(api.endpoints.query3.Types.RawResultType).toEqualTypeOf< + Post[] + >() }) }) }) diff --git a/packages/toolkit/src/query/tests/createApi.test.ts b/packages/toolkit/src/query/tests/createApi.test.ts index 824977a3b7..4ad3a9efe4 100644 --- a/packages/toolkit/src/query/tests/createApi.test.ts +++ b/packages/toolkit/src/query/tests/createApi.test.ts @@ -12,6 +12,7 @@ import type { FetchBaseQueryMeta, OverrideResultType, SchemaFailureConverter, + SchemaType, SerializeQueryArgs, TagTypesFromApi, } from '@reduxjs/toolkit/query' @@ -1227,7 +1228,7 @@ describe('endpoint schemas', () => { value, arg, }: { - schemaName: string + schemaName: `${SchemaType}Schema` value: unknown arg: unknown }) { @@ -1244,30 +1245,49 @@ describe('endpoint schemas', () => { } } + interface SkipApiOptions { + globalSkip?: boolean + endpointSkip?: boolean + useArray?: boolean + globalCatch?: boolean + endpointCatch?: boolean + } + + const apiOptions = ( + type: SchemaType, + { useArray, globalSkip, globalCatch }: SkipApiOptions = {}, + ) => ({ + onSchemaFailure: onSchemaFailureGlobal, + skipSchemaValidation: useArray ? globalSkip && [type] : globalSkip, + catchSchemaFailure: globalCatch ? schemaConverter : undefined, + }) + + const endpointOptions = ( + type: SchemaType, + { useArray, endpointSkip, endpointCatch }: SkipApiOptions = {}, + ) => ({ + onSchemaFailure: onSchemaFailureEndpoint, + skipSchemaValidation: useArray ? endpointSkip && [type] : endpointSkip, + catchSchemaFailure: endpointCatch ? schemaConverter : undefined, + }) + + const skipCases: [string, SkipApiOptions][] = [ + ['globally', { globalSkip: true }], + ['on the endpoint', { endpointSkip: true }], + ['globally (array)', { globalSkip: true, useArray: true }], + ['on the endpoint (array)', { endpointSkip: true, useArray: true }], + ] + describe('argSchema', () => { - const makeApi = ({ - globalSkip, - endpointSkip, - globalCatch, - endpointCatch, - }: { - globalSkip?: boolean - endpointSkip?: boolean - globalCatch?: boolean - endpointCatch?: boolean - } = {}) => + const makeApi = (opts?: SkipApiOptions) => createApi({ baseQuery: fetchBaseQuery({ baseUrl: 'https://example.com' }), - onSchemaFailure: onSchemaFailureGlobal, - skipSchemaValidation: globalSkip, - catchSchemaFailure: globalCatch ? schemaConverter : undefined, + ...apiOptions('arg', opts), endpoints: (build) => ({ query: build.query({ query: ({ id }) => `/post/${id}`, argSchema: v.object({ id: v.number() }), - onSchemaFailure: onSchemaFailureEndpoint, - skipSchemaValidation: endpointSkip, - catchSchemaFailure: endpointCatch ? schemaConverter : undefined, + ...endpointOptions('arg', opts), }), }), }) @@ -1296,22 +1316,9 @@ describe('endpoint schemas', () => { arg: { id: '1' }, }) }) - test('can be skipped globally', async () => { - const api = makeApi({ globalSkip: true }) - - const storeRef = setupApiStore(api, undefined, { - withoutTestLifecycles: true, - }) - - const result = await storeRef.store.dispatch( - // @ts-expect-error - api.endpoints.query.initiate({ id: '1' }), - ) - expect(result?.error).toBeUndefined() - }) - test('can be skipped on the endpoint', async () => { - const api = makeApi({ endpointSkip: true }) + test.each(skipCases)('can be skipped %s', async (_, arg) => { + const api = makeApi(arg) const storeRef = setupApiStore(api, undefined, { withoutTestLifecycles: true, @@ -1387,29 +1394,15 @@ describe('endpoint schemas', () => { }) }) describe('rawResponseSchema', () => { - const makeApi = ({ - globalSkip, - endpointSkip, - globalCatch, - endpointCatch, - }: { - globalSkip?: boolean - endpointSkip?: boolean - globalCatch?: boolean - endpointCatch?: boolean - } = {}) => + const makeApi = (opts?: SkipApiOptions) => createApi({ baseQuery: fetchBaseQuery({ baseUrl: 'https://example.com' }), - onSchemaFailure: onSchemaFailureGlobal, - catchSchemaFailure: globalCatch ? schemaConverter : undefined, - skipSchemaValidation: globalSkip, + ...apiOptions('rawResponse', opts), endpoints: (build) => ({ query: build.query<{ success: boolean }, void>({ query: () => '/success', rawResponseSchema: v.object({ value: v.literal('success!') }), - onSchemaFailure: onSchemaFailureEndpoint, - catchSchemaFailure: endpointCatch ? schemaConverter : undefined, - skipSchemaValidation: endpointSkip, + ...endpointOptions('rawResponse', opts), }), }), }) @@ -1428,8 +1421,8 @@ describe('endpoint schemas', () => { arg: undefined, }) }) - test('can be skipped globally', async () => { - const api = makeApi({ globalSkip: true }) + test.each(skipCases)('can be skipped %s', async (_, arg) => { + const api = makeApi(arg) const storeRef = setupApiStore(api, undefined, { withoutTestLifecycles: true, }) @@ -1488,30 +1481,16 @@ describe('endpoint schemas', () => { }) }) describe('responseSchema', () => { - const makeApi = ({ - globalSkip, - endpointSkip, - globalCatch, - endpointCatch, - }: { - globalSkip?: boolean - endpointSkip?: boolean - globalCatch?: boolean - endpointCatch?: boolean - } = {}) => + const makeApi = (opts?: SkipApiOptions) => createApi({ baseQuery: fetchBaseQuery({ baseUrl: 'https://example.com' }), - onSchemaFailure: onSchemaFailureGlobal, - catchSchemaFailure: globalCatch ? schemaConverter : undefined, - skipSchemaValidation: globalSkip, + ...apiOptions('response', opts), endpoints: (build) => ({ query: build.query<{ success: boolean }, void>({ query: () => '/success', transformResponse: () => ({ success: false }), responseSchema: v.object({ success: v.literal(true) }), - onSchemaFailure: onSchemaFailureEndpoint, - catchSchemaFailure: endpointCatch ? schemaConverter : undefined, - skipSchemaValidation: endpointSkip, + ...endpointOptions('response', opts), }), }), }) @@ -1531,18 +1510,8 @@ describe('endpoint schemas', () => { arg: undefined, }) }) - test('can be skipped globally', async () => { - const api = makeApi({ globalSkip: true }) - const storeRef = setupApiStore(api, undefined, { - withoutTestLifecycles: true, - }) - const result = await storeRef.store.dispatch( - api.endpoints.query.initiate(), - ) - expect(result?.error).toBeUndefined() - }) - test('can be skipped on the endpoint', async () => { - const api = makeApi({ endpointSkip: true }) + test.each(skipCases)('can be skipped %s', async (_, arg) => { + const api = makeApi(arg) const storeRef = setupApiStore(api, undefined, { withoutTestLifecycles: true, }) @@ -1591,22 +1560,10 @@ describe('endpoint schemas', () => { }) }) describe('rawErrorResponseSchema', () => { - const makeApi = ({ - globalSkip, - endpointSkip, - globalCatch, - endpointCatch, - }: { - globalSkip?: boolean - endpointSkip?: boolean - globalCatch?: boolean - endpointCatch?: boolean - } = {}) => + const makeApi = (opts?: SkipApiOptions) => createApi({ baseQuery: fetchBaseQuery({ baseUrl: 'https://example.com' }), - onSchemaFailure: onSchemaFailureGlobal, - catchSchemaFailure: globalCatch ? schemaConverter : undefined, - skipSchemaValidation: globalSkip, + ...apiOptions('rawErrorResponse', opts), endpoints: (build) => ({ query: build.query<{ success: boolean }, void>({ query: () => '/error', @@ -1614,9 +1571,7 @@ describe('endpoint schemas', () => { status: v.pipe(v.number(), v.minValue(400), v.maxValue(499)), data: v.unknown(), }), - onSchemaFailure: onSchemaFailureEndpoint, - catchSchemaFailure: endpointCatch ? schemaConverter : undefined, - skipSchemaValidation: endpointSkip, + ...endpointOptions('rawErrorResponse', opts), }), }), }) @@ -1635,18 +1590,8 @@ describe('endpoint schemas', () => { arg: undefined, }) }) - test('can be skipped globally', async () => { - const api = makeApi({ globalSkip: true }) - const storeRef = setupApiStore(api, undefined, { - withoutTestLifecycles: true, - }) - const result = await storeRef.store.dispatch( - api.endpoints.query.initiate(), - ) - expect(result?.error).not.toEqual(serializedSchemaError) - }) - test('can be skipped on the endpoint', async () => { - const api = makeApi({ endpointSkip: true }) + test.each(skipCases)('can be skipped %s', async (_, arg) => { + const api = makeApi(arg) const storeRef = setupApiStore(api, undefined, { withoutTestLifecycles: true, }) @@ -1695,22 +1640,10 @@ describe('endpoint schemas', () => { }) }) describe('errorResponseSchema', () => { - const makeApi = ({ - globalSkip, - endpointSkip, - globalCatch, - endpointCatch, - }: { - globalSkip?: boolean - endpointSkip?: boolean - globalCatch?: boolean - endpointCatch?: boolean - } = {}) => + const makeApi = (opts?: SkipApiOptions) => createApi({ baseQuery: fetchBaseQuery({ baseUrl: 'https://example.com' }), - onSchemaFailure: onSchemaFailureGlobal, - catchSchemaFailure: globalCatch ? schemaConverter : undefined, - skipSchemaValidation: globalSkip, + ...apiOptions('errorResponse', opts), endpoints: (build) => ({ query: build.query<{ success: boolean }, void>({ query: () => '/error', @@ -1724,9 +1657,7 @@ describe('endpoint schemas', () => { error: v.literal('oh no'), data: v.unknown(), }), - onSchemaFailure: onSchemaFailureEndpoint, - catchSchemaFailure: endpointCatch ? schemaConverter : undefined, - skipSchemaValidation: endpointSkip, + ...endpointOptions('errorResponse', opts), }), }), }) @@ -1749,18 +1680,8 @@ describe('endpoint schemas', () => { arg: undefined, }) }) - test('can be skipped globally', async () => { - const api = makeApi({ globalSkip: true }) - const storeRef = setupApiStore(api, undefined, { - withoutTestLifecycles: true, - }) - const result = await storeRef.store.dispatch( - api.endpoints.query.initiate(), - ) - expect(result?.error).not.toEqual(serializedSchemaError) - }) - test('can be skipped on the endpoint', async () => { - const api = makeApi({ endpointSkip: true }) + test.each(skipCases)('can be skipped %s', async (_, arg) => { + const api = makeApi(arg) const storeRef = setupApiStore(api, undefined, { withoutTestLifecycles: true, }) @@ -1817,22 +1738,10 @@ describe('endpoint schemas', () => { }) }) describe('metaSchema', () => { - const makeApi = ({ - globalSkip, - endpointSkip, - globalCatch, - endpointCatch, - }: { - globalSkip?: boolean - endpointSkip?: boolean - globalCatch?: boolean - endpointCatch?: boolean - } = {}) => + const makeApi = (opts?: SkipApiOptions) => createApi({ baseQuery: fetchBaseQuery({ baseUrl: 'https://example.com' }), - onSchemaFailure: onSchemaFailureGlobal, - catchSchemaFailure: globalCatch ? schemaConverter : undefined, - skipSchemaValidation: globalSkip, + ...apiOptions('meta', opts), endpoints: (build) => ({ query: build.query<{ success: boolean }, void>({ query: () => '/success', @@ -1841,9 +1750,7 @@ describe('endpoint schemas', () => { response: v.instance(Response), timestamp: v.number(), }), - onSchemaFailure: onSchemaFailureEndpoint, - catchSchemaFailure: endpointCatch ? schemaConverter : undefined, - skipSchemaValidation: endpointSkip, + ...endpointOptions('meta', opts), }), }), }) @@ -1865,18 +1772,8 @@ describe('endpoint schemas', () => { arg: undefined, }) }) - test('can be skipped globally', async () => { - const api = makeApi({ globalSkip: true }) - const storeRef = setupApiStore(api, undefined, { - withoutTestLifecycles: true, - }) - const result = await storeRef.store.dispatch( - api.endpoints.query.initiate(), - ) - expect(result?.error).toBeUndefined() - }) - test('can be skipped on the endpoint', async () => { - const api = makeApi({ endpointSkip: true }) + test.each(skipCases)('can be skipped %s', async (_, arg) => { + const api = makeApi(arg) const storeRef = setupApiStore(api, undefined, { withoutTestLifecycles: true, }) diff --git a/packages/toolkit/src/query/tests/infiniteQueries.test.ts b/packages/toolkit/src/query/tests/infiniteQueries.test.ts index 1325d135af..da5e10ae70 100644 --- a/packages/toolkit/src/query/tests/infiniteQueries.test.ts +++ b/packages/toolkit/src/query/tests/infiniteQueries.test.ts @@ -16,6 +16,7 @@ describe('Infinite queries', () => { name: string } + type HitCounter = { page: number; hitCounter: number } let counters: Record = {} let queryCounter = 0 @@ -88,39 +89,41 @@ describe('Infinite queries', () => { }), }) - let hitCounter = 0 - - type HitCounter = { page: number; hitCounter: number } + function createCountersApi() { + let hitCounter = 0 - const countersApi = createApi({ - baseQuery: fakeBaseQuery(), - tagTypes: ['Counter'], - endpoints: (build) => ({ - counters: build.infiniteQuery({ - queryFn({ pageParam }) { - hitCounter++ + const countersApi = createApi({ + baseQuery: fakeBaseQuery(), + tagTypes: ['Counter'], + endpoints: (build) => ({ + counters: build.infiniteQuery({ + queryFn({ pageParam }) { + hitCounter++ - return { data: { page: pageParam, hitCounter } } - }, - infiniteQueryOptions: { - initialPageParam: 0, - getNextPageParam: ( - lastPage, - allPages, - lastPageParam, - allPageParams, - ) => lastPageParam + 1, - }, - providesTags: ['Counter'], - }), - mutation: build.mutation({ - queryFn: async () => { - return { data: null } - }, - invalidatesTags: ['Counter'], + return { data: { page: pageParam, hitCounter } } + }, + infiniteQueryOptions: { + initialPageParam: 0, + getNextPageParam: ( + lastPage, + allPages, + lastPageParam, + allPageParams, + ) => lastPageParam + 1, + }, + providesTags: ['Counter'], + }), + mutation: build.mutation({ + queryFn: async () => { + return { data: null } + }, + invalidatesTags: ['Counter'], + }), }), - }), - }) + }) + + return countersApi + } let storeRef = setupApiStore( pokemonApi, @@ -155,7 +158,6 @@ describe('Infinite queries', () => { counters = {} - hitCounter = 0 queryCounter = 0 }) @@ -411,6 +413,8 @@ describe('Infinite queries', () => { } } + const countersApi = createCountersApi() + const storeRef = setupApiStore( countersApi, { ...actionsReducer }, @@ -465,6 +469,8 @@ describe('Infinite queries', () => { } } + const countersApi = createCountersApi() + const storeRef = setupApiStore( countersApi, { ...actionsReducer }, @@ -528,6 +534,7 @@ describe('Infinite queries', () => { }) test('Refetches on polling', async () => { + const countersApi = createCountersApi() const checkResultData = ( result: InfiniteQueryResult, expectedValues: HitCounter[], diff --git a/packages/toolkit/src/query/tests/optimisticUpserts.test.tsx b/packages/toolkit/src/query/tests/optimisticUpserts.test.tsx index 6d64e3554b..17ee73273b 100644 --- a/packages/toolkit/src/query/tests/optimisticUpserts.test.tsx +++ b/packages/toolkit/src/query/tests/optimisticUpserts.test.tsx @@ -86,7 +86,7 @@ const api = createApi({ // and leave a side effect we can check in the test api.dispatch(postAddedAction(res.data.id)) }, - keepUnusedDataFor: 0.01, + keepUnusedDataFor: 0.1, }), getFolder: build.query({ queryFn: async (args) => { @@ -430,12 +430,12 @@ describe('upsertQueryEntries', () => { expect(selectedData).toBe(posts[0]) }, - { timeout: 50, interval: 5 }, + { timeout: 150, interval: 5 }, ) // The cache data should be removed after the keepUnusedDataFor time, // so wait longer than that - await delay(100) + await delay(300) const stateAfter = storeRef.store.getState() diff --git a/packages/toolkit/src/query/tests/polling.test.tsx b/packages/toolkit/src/query/tests/polling.test.tsx index 425a0bf804..cff9ab7449 100644 --- a/packages/toolkit/src/query/tests/polling.test.tsx +++ b/packages/toolkit/src/query/tests/polling.test.tsx @@ -1,4 +1,5 @@ import { createApi } from '@reduxjs/toolkit/query' +import type { QueryActionCreatorResult } from '@reduxjs/toolkit/query' import { delay } from 'msw' import { setupApiStore } from '../../tests/utils/helpers' import type { SubscriptionSelectors } from '../core/buildMiddleware/types' @@ -29,10 +30,15 @@ beforeEach(() => { ;({ getSubscriptions } = storeRef.store.dispatch( api.internalActions.internal_getRTKQSubscriptions(), ) as unknown as SubscriptionSelectors) + + const currentPolls = storeRef.store.dispatch({ + type: `${api.reducerPath}/getPolling`, + }) as any + ;(currentPolls as any).pollUpdateCounters = {} }) const getSubscribersForQueryCacheKey = (queryCacheKey: string) => - getSubscriptions()[queryCacheKey] || {} + getSubscriptions().get(queryCacheKey) ?? new Map() const createSubscriptionGetter = (queryCacheKey: string) => () => getSubscribersForQueryCacheKey(queryCacheKey) @@ -66,14 +72,14 @@ describe('polling tests', () => { const getSubs = createSubscriptionGetter(queryCacheKey) await delay(1) - expect(Object.keys(getSubs())).toHaveLength(1) - expect(getSubs()[requestId].pollingInterval).toBe(10) + expect(getSubs().size).toBe(1) + expect(getSubs()?.get(requestId)?.pollingInterval).toBe(10) subscription.updateSubscriptionOptions({ pollingInterval: 20 }) await delay(1) - expect(Object.keys(getSubs())).toHaveLength(1) - expect(getSubs()[requestId].pollingInterval).toBe(20) + expect(getSubs().size).toBe(1) + expect(getSubs()?.get(requestId)?.pollingInterval).toBe(20) }) it(`doesn't replace the interval when removing a shared query instance with a poll `, async () => { @@ -95,12 +101,12 @@ describe('polling tests', () => { const getSubs = createSubscriptionGetter(subscriptionOne.queryCacheKey) - expect(Object.keys(getSubs())).toHaveLength(2) + expect(getSubs().size).toBe(2) subscriptionOne.unsubscribe() await delay(1) - expect(Object.keys(getSubs())).toHaveLength(1) + expect(getSubs().size).toBe(1) }) it('uses lowest specified interval when two components are mounted', async () => { @@ -155,7 +161,7 @@ describe('polling tests', () => { const callsWithoutSkip = mockBaseQuery.mock.calls.length expect(callsWithSkip).toBe(1) - expect(callsWithoutSkip).toBeGreaterThan(2) + expect(callsWithoutSkip).toBeGreaterThanOrEqual(2) storeRef.store.dispatch(api.util.resetApiState()) }) @@ -218,8 +224,8 @@ describe('polling tests', () => { const getSubs = createSubscriptionGetter(queryCacheKey) await delay(1) - expect(Object.keys(getSubs())).toHaveLength(1) - expect(getSubs()[requestId].skipPollingIfUnfocused).toBe(false) + expect(getSubs().size).toBe(1) + expect(getSubs().get(requestId)?.skipPollingIfUnfocused).toBe(false) subscription.updateSubscriptionOptions({ pollingInterval: 20, @@ -227,7 +233,54 @@ describe('polling tests', () => { }) await delay(1) - expect(Object.keys(getSubs())).toHaveLength(1) - expect(getSubs()[requestId].skipPollingIfUnfocused).toBe(true) + expect(getSubs().size).toBe(1) + expect(getSubs().get(requestId)?.skipPollingIfUnfocused).toBe(true) + }) + + it('should minimize polling recalculations when adding multiple subscribers', async () => { + // Reset any existing state + const storeRef = setupApiStore(api, undefined, { + withoutTestLifecycles: true, + }) + + const SUBSCRIBER_COUNT = 10 + const subscriptions: QueryActionCreatorResult[] = [] + + // Add 10 subscribers to the same endpoint with polling enabled + for (let i = 0; i < SUBSCRIBER_COUNT; i++) { + const subscription = storeRef.store.dispatch( + getPosts.initiate(1, { + subscriptionOptions: { pollingInterval: 1000 }, + subscribe: true, + }), + ) + subscriptions.push(subscription) + } + + // Wait a bit for all subscriptions to be processed + await Promise.all(subscriptions) + + // Wait for the poll update timer + await delay(25) + + // Get the polling state using the secret "getPolling" action + const currentPolls = storeRef.store.dispatch({ + type: `${api.reducerPath}/getPolling`, + }) as any + + // Get the query cache key for our endpoint + const queryCacheKey = subscriptions[0].queryCacheKey + + // Check the poll update counters + const pollUpdateCounters = currentPolls.pollUpdateCounters || {} + const updateCount = pollUpdateCounters[queryCacheKey] || 0 + + // With batching optimization, this should be much lower than SUBSCRIBER_COUNT + // Ideally 1, but could be slightly higher due to timing + expect(updateCount).toBeGreaterThanOrEqual(1) + expect(updateCount).toBeLessThanOrEqual(2) + + // Clean up subscriptions + subscriptions.forEach((sub) => sub.unsubscribe()) }) }) diff --git a/packages/toolkit/src/query/tests/retry.test.ts b/packages/toolkit/src/query/tests/retry.test.ts index 03365fd0af..b2233021b0 100644 --- a/packages/toolkit/src/query/tests/retry.test.ts +++ b/packages/toolkit/src/query/tests/retry.test.ts @@ -465,4 +465,87 @@ describe('configuration', () => { expect(baseBaseQuery).toHaveBeenCalledOnce() }) + + test('retryCondition receives abort signal and stops retrying when cache entry is removed', async () => { + let capturedSignal: AbortSignal | undefined + let retryAttempts = 0 + + const baseBaseQuery = vi.fn< + Parameters, + ReturnType + >() + + // Always return an error to trigger retries + baseBaseQuery.mockResolvedValue({ error: 'network error' }) + + let retryConditionCalled = false + + const baseQuery = retry(baseBaseQuery, { + retryCondition: (error, args, { attempt, baseQueryApi }) => { + retryConditionCalled = true + retryAttempts = attempt + capturedSignal = baseQueryApi.signal + + // Stop retrying if the signal is aborted + if (baseQueryApi.signal.aborted) { + return false + } + + // Otherwise, retry up to 10 times + return attempt <= 10 + }, + backoff: async () => { + // Short backoff for faster test + await new Promise((resolve) => setTimeout(resolve, 10)) + }, + }) + + const api = createApi({ + baseQuery, + endpoints: (build) => ({ + getTest: build.query({ + query: (id) => ({ url: `test/${id}` }), + keepUnusedDataFor: 0.01, // Very short timeout (10ms) + }), + }), + }) + + const storeRef = setupApiStore(api, undefined, { + withoutTestLifecycles: true, + }) + + // Start the query + const queryPromise = storeRef.store.dispatch( + api.endpoints.getTest.initiate(1), + ) + + // Wait for the first retry to happen so we capture the signal + await loopTimers(2) + + // Verify the retry condition was called and we have a signal + expect(retryConditionCalled).toBe(true) + expect(capturedSignal).toBeDefined() + expect(capturedSignal!.aborted).toBe(false) + + // Unsubscribe to trigger cache removal + queryPromise.unsubscribe() + + // Wait for the cache entry to be removed (keepUnusedDataFor: 0.01s = 10ms) + await vi.advanceTimersByTimeAsync(50) + + // Allow some time for more retries to potentially happen + await loopTimers(3) + + // The signal should now be aborted + expect(capturedSignal!.aborted).toBe(true) + + // We should have stopped retrying early due to the abort signal + // If abort signal wasn't working, we'd see many more retry attempts + expect(retryAttempts).toBeLessThan(10) + + // The base query should have been called at least once (initial attempt) + // but not the full 10+ times it would without abort signal + expect(baseBaseQuery).toHaveBeenCalled() + expect(baseBaseQuery.mock.calls.length).toBeLessThan(10) + }) }) diff --git a/packages/toolkit/src/query/utils/getOrInsert.ts b/packages/toolkit/src/query/utils/getOrInsert.ts index 124da032ea..8ae351e02f 100644 --- a/packages/toolkit/src/query/utils/getOrInsert.ts +++ b/packages/toolkit/src/query/utils/getOrInsert.ts @@ -1,3 +1,7 @@ +// Duplicate some of the utils in `/src/utils` to ensure +// we don't end up dragging in larger chunks of the RTK core +// into the RTKQ bundle + export function getOrInsert( map: WeakMap, key: K, @@ -13,3 +17,25 @@ export function getOrInsert( return map.set(key, value).get(key) as V } + +export function getOrInsertComputed( + map: WeakMap, + key: K, + compute: (key: K) => V, +): V +export function getOrInsertComputed( + map: Map, + key: K, + compute: (key: K) => V, +): V +export function getOrInsertComputed( + map: Map | WeakMap, + key: K, + compute: (key: K) => V, +): V { + if (map.has(key)) return map.get(key) as V + + return map.set(key, compute(key)).get(key) as V +} + +export const createNewMap = () => new Map() diff --git a/packages/toolkit/src/tests/createReducer.test.ts b/packages/toolkit/src/tests/createReducer.test.ts index e57bd47da9..114096c069 100644 --- a/packages/toolkit/src/tests/createReducer.test.ts +++ b/packages/toolkit/src/tests/createReducer.test.ts @@ -7,10 +7,12 @@ import type { } from '@reduxjs/toolkit' import { createAction, + createAsyncThunk, createNextState, createReducer, isPlainObject, } from '@reduxjs/toolkit' +import { waitMs } from './utils/helpers' interface Todo { text: string @@ -39,6 +41,8 @@ type ToggleTodoReducer = CaseReducer< type CreateReducer = typeof createReducer +const addTodoThunk = createAsyncThunk('todos/add', (todo: Todo) => todo) + describe('createReducer', () => { describe('given impure reducers with immer', () => { const addTodo: AddTodoReducer = (state, action) => { @@ -341,24 +345,24 @@ describe('createReducer', () => { expect(reducer(5, decrement(5))).toBe(0) }) test('will throw if the same type is used twice', () => { - expect(() => - createReducer(0, (builder) => + expect(() => { + createReducer(0, (builder) => { builder .addCase(increment, (state, action) => state + action.payload) .addCase(increment, (state, action) => state + action.payload) - .addCase(decrement, (state, action) => state - action.payload), - ), - ).toThrowErrorMatchingInlineSnapshot( + .addCase(decrement, (state, action) => state - action.payload) + }) + }).toThrowErrorMatchingInlineSnapshot( `[Error: \`builder.addCase\` cannot be called with two reducers for the same action type 'increment']`, ) - expect(() => - createReducer(0, (builder) => + expect(() => { + createReducer(0, (builder) => { builder .addCase(increment, (state, action) => state + action.payload) .addCase('increment', (state) => state + 1) - .addCase(decrement, (state, action) => state - action.payload), - ), - ).toThrowErrorMatchingInlineSnapshot( + .addCase(decrement, (state, action) => state - action.payload) + }) + }).toThrowErrorMatchingInlineSnapshot( `[Error: \`builder.addCase\` cannot be called with two reducers for the same action type 'increment']`, ) }) @@ -369,14 +373,14 @@ describe('createReducer', () => { payload, }) customActionCreator.type = '' - expect(() => - createReducer(0, (builder) => + expect(() => { + createReducer(0, (builder) => { builder.addCase( customActionCreator, (state, action) => state + action.payload, - ), - ), - ).toThrowErrorMatchingInlineSnapshot( + ) + }) + }).toThrowErrorMatchingInlineSnapshot( `[Error: \`builder.addCase\` cannot be called with an empty action type]`, ) }) @@ -529,6 +533,56 @@ describe('createReducer', () => { ) }) }) + describe('builder "addAsyncThunk" method', () => { + const initialState = { todos: [] as Todo[], loading: false, errored: false } + test('uses the matching reducer for each action type', () => { + const reducer = createReducer(initialState, (builder) => + builder.addAsyncThunk(addTodoThunk, { + pending(state) { + state.loading = true + }, + fulfilled(state, action) { + state.todos.push(action.payload) + }, + rejected(state) { + state.errored = true + }, + settled(state) { + state.loading = false + }, + }), + ) + const todo: Todo = { text: 'test' } + expect(reducer(undefined, addTodoThunk.pending('test', todo))).toEqual({ + todos: [], + loading: true, + errored: false, + }) + expect( + reducer(undefined, addTodoThunk.fulfilled(todo, 'test', todo)), + ).toEqual({ + todos: [todo], + loading: false, + errored: false, + }) + expect( + reducer(undefined, addTodoThunk.rejected(new Error(), 'test', todo)), + ).toEqual({ + todos: [], + loading: false, + errored: true, + }) + }) + test('calling addAsyncThunk after addDefaultCase should result in an error in development mode', () => { + expect(() => + createReducer(initialState, (builder: any) => + builder.addDefaultCase(() => {}).addAsyncThunk(addTodoThunk, {}), + ), + ).toThrowErrorMatchingInlineSnapshot( + `[Error: \`builder.addAsyncThunk\` should only be called before calling \`builder.addDefaultCase\`]`, + ) + }) + }) }) function behavesLikeReducer(todosReducer: TodosReducer) { diff --git a/packages/toolkit/src/tests/createSlice.test.ts b/packages/toolkit/src/tests/createSlice.test.ts index 248e7c71fe..09f726c113 100644 --- a/packages/toolkit/src/tests/createSlice.test.ts +++ b/packages/toolkit/src/tests/createSlice.test.ts @@ -6,6 +6,7 @@ import { combineSlices, configureStore, createAction, + createAsyncThunk, createSlice, } from '@reduxjs/toolkit' @@ -265,6 +266,58 @@ describe('createSlice', () => { ) }) + test('can be used with addAsyncThunk and async thunks', () => { + const asyncThunk = createAsyncThunk('test', (n: number) => n) + const slice = createSlice({ + name: 'counter', + initialState: { + loading: false, + errored: false, + value: 0, + }, + reducers: {}, + extraReducers: (builder) => + builder.addAsyncThunk(asyncThunk, { + pending(state) { + state.loading = true + }, + fulfilled(state, action) { + state.value = action.payload + }, + rejected(state) { + state.errored = true + }, + settled(state) { + state.loading = false + }, + }), + }) + expect( + slice.reducer(undefined, asyncThunk.pending('requestId', 5)), + ).toEqual({ + loading: true, + errored: false, + value: 0, + }) + expect( + slice.reducer(undefined, asyncThunk.fulfilled(5, 'requestId', 5)), + ).toEqual({ + loading: false, + errored: false, + value: 5, + }) + expect( + slice.reducer( + undefined, + asyncThunk.rejected(new Error(), 'requestId', 5), + ), + ).toEqual({ + loading: false, + errored: true, + value: 0, + }) + }) + test('can be used with addMatcher and type guard functions', () => { const slice = createSlice({ name: 'counter', diff --git a/packages/toolkit/src/tests/mapBuilders.test-d.ts b/packages/toolkit/src/tests/mapBuilders.test-d.ts index 63aece1c18..d2ec028ccf 100644 --- a/packages/toolkit/src/tests/mapBuilders.test-d.ts +++ b/packages/toolkit/src/tests/mapBuilders.test-d.ts @@ -128,21 +128,42 @@ describe('type tests', () => { expectTypeOf(action).toMatchTypeOf() }) - test('addMatcher() should prevent further calls to addCase()', () => { + test('addAsyncThunk() should prevent further calls to addCase() ', () => { + const asyncThunk = createAsyncThunk('test', () => {}) + const b = builder.addAsyncThunk(asyncThunk, { + pending: () => {}, + rejected: () => {}, + fulfilled: () => {}, + settled: () => {}, + }) + + expectTypeOf(b).not.toHaveProperty('addCase') + + expectTypeOf(b.addAsyncThunk).toBeFunction() + + expectTypeOf(b.addMatcher).toBeCallableWith(increment.match, () => {}) + + expectTypeOf(b.addDefaultCase).toBeCallableWith(() => {}) + }) + + test('addMatcher() should prevent further calls to addCase() and addAsyncThunk()', () => { const b = builder.addMatcher(increment.match, () => {}) expectTypeOf(b).not.toHaveProperty('addCase') + expectTypeOf(b).not.toHaveProperty('addAsyncThunk') expectTypeOf(b.addMatcher).toBeCallableWith(increment.match, () => {}) expectTypeOf(b.addDefaultCase).toBeCallableWith(() => {}) }) - test('addDefaultCase() should prevent further calls to addCase(), addMatcher() and addDefaultCase', () => { + test('addDefaultCase() should prevent further calls to addCase(), addAsyncThunk(), addMatcher() and addDefaultCase', () => { const b = builder.addDefaultCase(() => {}) expectTypeOf(b).not.toHaveProperty('addCase') + expectTypeOf(b).not.toHaveProperty('addAsyncThunk') + expectTypeOf(b).not.toHaveProperty('addMatcher') expectTypeOf(b).not.toHaveProperty('addDefaultCase') @@ -188,79 +209,206 @@ describe('type tests', () => { } }>() }) + + builder.addAsyncThunk(thunk, { + pending(_, action) { + expectTypeOf(action).toMatchTypeOf<{ + payload: undefined + meta: { + arg: void + requestId: string + requestStatus: 'pending' + } + }>() + }, + rejected(_, action) { + expectTypeOf(action).toMatchTypeOf<{ + payload: unknown + error: SerializedError + meta: { + arg: void + requestId: string + requestStatus: 'rejected' + aborted: boolean + condition: boolean + rejectedWithValue: boolean + } + }>() + }, + fulfilled(_, action) { + expectTypeOf(action).toMatchTypeOf<{ + payload: 'ret' + meta: { + arg: void + requestId: string + requestStatus: 'fulfilled' + } + }>() + }, + settled(_, action) { + expectTypeOf(action).toMatchTypeOf< + | { + payload: 'ret' + meta: { + arg: void + requestId: string + requestStatus: 'fulfilled' + } + } + | { + payload: unknown + error: SerializedError + meta: { + arg: void + requestId: string + requestStatus: 'rejected' + aborted: boolean + condition: boolean + rejectedWithValue: boolean + } + } + >() + }, + }) }) - }) - test('case 2: `createAsyncThunk` with `meta`', () => { - const thunk = createAsyncThunk< - 'ret', - void, - { - pendingMeta: { startedTimeStamp: number } - fulfilledMeta: { - fulfilledTimeStamp: number - baseQueryMeta: 'meta!' - } - rejectedMeta: { - baseQueryMeta: 'meta!' + test('case 2: `createAsyncThunk` with `meta`', () => { + const thunk = createAsyncThunk< + 'ret', + void, + { + pendingMeta: { startedTimeStamp: number } + fulfilledMeta: { + fulfilledTimeStamp: number + baseQueryMeta: 'meta!' + } + rejectedMeta: { + baseQueryMeta: 'meta!' + } } - } - >( - 'test', - (_, api) => { - return api.fulfillWithValue('ret' as const, { - fulfilledTimeStamp: 5, - baseQueryMeta: 'meta!', - }) - }, - { - getPendingMeta() { - return { startedTimeStamp: 0 } + >( + 'test', + (_, api) => { + return api.fulfillWithValue('ret' as const, { + fulfilledTimeStamp: 5, + baseQueryMeta: 'meta!', + }) }, - }, - ) + { + getPendingMeta() { + return { startedTimeStamp: 0 } + }, + }, + ) - builder.addCase(thunk.pending, (_, action) => { - expectTypeOf(action).toMatchTypeOf<{ - payload: undefined - meta: { - arg: void - requestId: string - requestStatus: 'pending' - startedTimeStamp: number - } - }>() - }) + builder.addCase(thunk.pending, (_, action) => { + expectTypeOf(action).toMatchTypeOf<{ + payload: undefined + meta: { + arg: void + requestId: string + requestStatus: 'pending' + startedTimeStamp: number + } + }>() + }) - builder.addCase(thunk.rejected, (_, action) => { - expectTypeOf(action).toMatchTypeOf<{ - payload: unknown - error: SerializedError - meta: { - arg: void - requestId: string - requestStatus: 'rejected' - aborted: boolean - condition: boolean - rejectedWithValue: boolean - baseQueryMeta?: 'meta!' - } - }>() + builder.addCase(thunk.rejected, (_, action) => { + expectTypeOf(action).toMatchTypeOf<{ + payload: unknown + error: SerializedError + meta: { + arg: void + requestId: string + requestStatus: 'rejected' + aborted: boolean + condition: boolean + rejectedWithValue: boolean + baseQueryMeta?: 'meta!' + } + }>() - if (action.meta.rejectedWithValue) { - expectTypeOf(action.meta.baseQueryMeta).toEqualTypeOf<'meta!'>() - } - }) - builder.addCase(thunk.fulfilled, (_, action) => { - expectTypeOf(action).toMatchTypeOf<{ - payload: 'ret' - meta: { - arg: void - requestId: string - requestStatus: 'fulfilled' - baseQueryMeta: 'meta!' + if (action.meta.rejectedWithValue) { + expectTypeOf(action.meta.baseQueryMeta).toEqualTypeOf<'meta!'>() } - }>() + }) + builder.addCase(thunk.fulfilled, (_, action) => { + expectTypeOf(action).toMatchTypeOf<{ + payload: 'ret' + meta: { + arg: void + requestId: string + requestStatus: 'fulfilled' + baseQueryMeta: 'meta!' + } + }>() + }) + + builder.addAsyncThunk(thunk, { + pending(_, action) { + expectTypeOf(action).toMatchTypeOf<{ + payload: undefined + meta: { + arg: void + requestId: string + requestStatus: 'pending' + startedTimeStamp: number + } + }>() + }, + rejected(_, action) { + expectTypeOf(action).toMatchTypeOf<{ + payload: unknown + error: SerializedError + meta: { + arg: void + requestId: string + requestStatus: 'rejected' + aborted: boolean + condition: boolean + rejectedWithValue: boolean + baseQueryMeta?: 'meta!' + } + }>() + }, + fulfilled(_, action) { + expectTypeOf(action).toMatchTypeOf<{ + payload: 'ret' + meta: { + arg: void + requestId: string + requestStatus: 'fulfilled' + baseQueryMeta: 'meta!' + } + }>() + }, + settled(_, action) { + expectTypeOf(action).toMatchTypeOf< + | { + payload: 'ret' + meta: { + arg: void + requestId: string + requestStatus: 'fulfilled' + baseQueryMeta: 'meta!' + } + } + | { + payload: unknown + error: SerializedError + meta: { + arg: void + requestId: string + requestStatus: 'rejected' + aborted: boolean + condition: boolean + rejectedWithValue: boolean + baseQueryMeta?: 'meta!' + } + } + >() + }, + }) }) }) }) diff --git a/website/docusaurus.config.ts b/website/docusaurus.config.ts index e704d4f79a..bd3911bbe3 100644 --- a/website/docusaurus.config.ts +++ b/website/docusaurus.config.ts @@ -36,6 +36,7 @@ const config: Config = { 'query/endpointDefinitions.ts', 'query/react/index.ts', 'query/react/ApiProvider.tsx', + 'query/core/buildMiddleware/cacheCollection.ts', ], }, }, diff --git a/website/src/css/custom.css b/website/src/css/custom.css index 1120329d0d..a66bd54a94 100644 --- a/website/src/css/custom.css +++ b/website/src/css/custom.css @@ -144,7 +144,7 @@ a.contents__link > code { color: var(--ifm-color-primary-lightest); } -a:visited { +a:visited:not(.menu__link, .table-of-contents__link, .navbar__link:not([target="_blank"])) { color: var(--ifm-color-primary); } .navbar .navbar__inner { @@ -283,4 +283,4 @@ div[class*='announcementBar_'] { /* Intentionally override the theme behavior, so that the course banner image is effectively cropped*/ z-index: calc(var(--ifm-z-index-fixed) -1) !important; -} \ No newline at end of file +} diff --git a/yarn.lock b/yarn.lock index e92552b700..9200516b2f 100644 --- a/yarn.lock +++ b/yarn.lock @@ -8120,7 +8120,7 @@ __metadata: lodash.camelcase: "npm:^4.3.0" msw: "npm:^2.1.5" node-fetch: "npm:^3.3.2" - oazapfts: "npm:^6.1.0" + oazapfts: "npm:^6.3.0" openapi-types: "npm:^9.1.0" prettier: "npm:^3.2.5" pretty-quick: "npm:^4.0.0" @@ -24075,21 +24075,21 @@ __metadata: languageName: node linkType: hard -"oazapfts@npm:^6.1.0": - version: 6.2.2 - resolution: "oazapfts@npm:6.2.2" +"oazapfts@npm:^6.3.0": + version: 6.3.0 + resolution: "oazapfts@npm:6.3.0" dependencies: "@apidevtools/swagger-parser": "npm:^10.1.1" lodash: "npm:^4.17.21" minimist: "npm:^1.2.8" swagger2openapi: "npm:^7.0.8" - tapable: "npm:^2.2.1" - typescript: "npm:^5.8.2" + tapable: "npm:^2.2.2" + typescript: "npm:^5.8.3" peerDependencies: "@oazapfts/runtime": "*" bin: oazapfts: cli.js - checksum: 10/20402b38e657679a04a122217fe8728318e967f7684e9f968cf4c690e094d0e3c17ac248448bb16b4111b30db4f936d7db52ad115c06a6a228b68ef46c82ccfe + checksum: 10/a8f0d04123aca7578093eac91e7206936c2a1d0515eda34bde70b172065f36af6bbda3d15a38c8bcf687ffd2accf0a88117e115c455af2be3f4c4d5fa6c70c94 languageName: node linkType: hard @@ -30897,6 +30897,13 @@ __metadata: languageName: node linkType: hard +"tapable@npm:^2.2.2": + version: 2.2.2 + resolution: "tapable@npm:2.2.2" + checksum: 10/065a0dc44aba1b32020faa1c27c719e8f76e5345347515d8494bf158524f36e9f22ad9eaa5b5494f9d5d67bf0640afdd5698505948c46d720b6b7e69d19349a6 + languageName: node + linkType: hard + "tar@npm:^7.4.3": version: 7.4.3 resolution: "tar@npm:7.4.3" @@ -31761,7 +31768,7 @@ __metadata: languageName: node linkType: hard -"typescript@npm:^5.8.2": +"typescript@npm:^5.8.2, typescript@npm:^5.8.3": version: 5.8.3 resolution: "typescript@npm:5.8.3" bin: @@ -31801,7 +31808,7 @@ __metadata: languageName: node linkType: hard -"typescript@patch:typescript@npm%3A^5.8.2#optional!builtin": +"typescript@patch:typescript@npm%3A^5.8.2#optional!builtin, typescript@patch:typescript@npm%3A^5.8.3#optional!builtin": version: 5.8.3 resolution: "typescript@patch:typescript@npm%3A5.8.3#optional!builtin::version=5.8.3&hash=8c6c40" bin: