Skip to content

Updated setState to support updateFn to work with both updater function and state #204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docs/reference/classes/store.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ Defined in: [store.ts:28](https://github.com/TanStack/store/blob/main/packages/s

• **TState**

• **TUpdater** *extends* `AnyUpdater` = (`cb`) => `TState`

## Constructors

### new Store()
Expand Down
2 changes: 0 additions & 2 deletions docs/reference/interfaces/storeoptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ Defined in: [store.ts:5](https://github.com/TanStack/store/blob/main/packages/st

• **TState**

• **TUpdater** *extends* `AnyUpdater` = (`cb`) => `TState`

## Properties

### onSubscribe()?
Expand Down
10 changes: 10 additions & 0 deletions examples/react/simple/src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ export const store = new Store({
cats: 0,
})

const countStore = new Store(0);

interface DisplayProps {
animal: 'dogs' | 'cats'
}
Expand All @@ -18,6 +20,11 @@ const Display = ({ animal }: DisplayProps) => {
return <div>{`${animal}: ${count}`}</div>
}

const DisplayCount = () => {
const count = useStore(countStore);
return <div>{`count: ${count}`}</div>
}

const updateState = (animal: 'dogs' | 'cats') => {
store.setState((state) => {
return {
Expand Down Expand Up @@ -47,6 +54,9 @@ function App() {
<Display animal="dogs" />
<Increment animal="cats" />
<Display animal="cats" />
<button onClick={() => countStore.setState((v) => v + 1)}>Update count</button>
<button onClick={() => countStore.setState(0)}>Reset count</button>
<DisplayCount />
</div>
)
}
Expand Down
4 changes: 2 additions & 2 deletions packages/angular-store/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export * from '@tanstack/store'
type NoInfer<T> = [T][T extends any ? 0 : never]

export function injectStore<TState, TSelected = NoInfer<TState>>(
store: Store<TState, any>,
store: Store<TState>,
selector?: (state: NoInfer<TState>) => TSelected,
options?: CreateSignalOptions<TSelected> & { injector?: Injector },
): Signal<TSelected>
Expand All @@ -27,7 +27,7 @@ export function injectStore<TState, TSelected = NoInfer<TState>>(
options?: CreateSignalOptions<TSelected> & { injector?: Injector },
): Signal<TSelected>
export function injectStore<TState, TSelected = NoInfer<TState>>(
store: Store<TState, any> | Derived<TState, any>,
store: Store<TState> | Derived<TState, any>,
selector: (state: NoInfer<TState>) => TSelected = (d) => d as TSelected,
options: CreateSignalOptions<TSelected> & { injector?: Injector } = {
equal: shallow,
Expand Down
4 changes: 2 additions & 2 deletions packages/react-store/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ export * from '@tanstack/store'
export type NoInfer<T> = [T][T extends any ? 0 : never]

export function useStore<TState, TSelected = NoInfer<TState>>(
store: Store<TState, any>,
store: Store<TState>,
selector?: (state: NoInfer<TState>) => TSelected,
): TSelected
export function useStore<TState, TSelected = NoInfer<TState>>(
store: Derived<TState, any>,
selector?: (state: NoInfer<TState>) => TSelected,
): TSelected
export function useStore<TState, TSelected = NoInfer<TState>>(
store: Store<TState, any> | Derived<TState, any>,
store: Store<TState> | Derived<TState, any>,
selector: (state: NoInfer<TState>) => TSelected = (d) => d as any,
): TSelected {
const slice = useSyncExternalStoreWithSelector(
Expand Down
4 changes: 2 additions & 2 deletions packages/solid-store/src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ export * from '@tanstack/store'
export type NoInfer<T> = [T][T extends any ? 0 : never]

export function useStore<TState, TSelected = NoInfer<TState>>(
store: Store<TState, any>,
store: Store<TState>,
selector?: (state: NoInfer<TState>) => TSelected,
): Accessor<TSelected>
export function useStore<TState, TSelected = NoInfer<TState>>(
store: Derived<TState, any>,
selector?: (state: NoInfer<TState>) => TSelected,
): Accessor<TSelected>
export function useStore<TState, TSelected = NoInfer<TState>>(
store: Store<TState, any> | Derived<TState, any>,
store: Store<TState> | Derived<TState, any>,
selector: (state: NoInfer<TState>) => TSelected = (d) => d as any,
): Accessor<TSelected> {
const [slice, setSlice] = createStore({
Expand Down
39 changes: 13 additions & 26 deletions packages/store/src/store.ts
Original file line number Diff line number Diff line change
@@ -1,71 +1,58 @@
import { __flush } from './scheduler'
import { isUpdaterFunction } from './types'
import type { AnyUpdater, Listener, Updater } from './types'
import type { Listener } from './types'

export interface StoreOptions<
TState,
TUpdater extends AnyUpdater = (cb: TState) => TState,
> {
export interface StoreOptions< TState> {
/**
* Replace the default update function with a custom one.
*/
updateFn?: (previous: TState) => (updater: TUpdater) => TState
updateFn?: (previous: TState) => (updater: (prev: TState) => TState) => TState
/**
* Called when a listener subscribes to the store.
*
* @return a function to unsubscribe the listener
*/
onSubscribe?: (
listener: Listener<TState>,
store: Store<TState, TUpdater>,
store: Store<TState>,
) => () => void
/**
* Called after the state has been updated, used to derive other state.
*/
onUpdate?: () => void
}

export class Store<
TState,
TUpdater extends AnyUpdater = (cb: TState) => TState,
> {
export class Store<TState> {
listeners = new Set<Listener<TState>>()
state: TState
prevState: TState
options?: StoreOptions<TState, TUpdater>
options?: StoreOptions<TState>

constructor(initialState: TState, options?: StoreOptions<TState, TUpdater>) {
constructor(initialState: TState, options?: StoreOptions<TState>) {
this.prevState = initialState
this.state = initialState
this.options = options
}

subscribe = (listener: Listener<TState>) => {
this.listeners.add(listener)
const unsub = this.options?.onSubscribe?.(listener, this)
const unsubscribe = this.options?.onSubscribe?.(listener, this)
return () => {
this.listeners.delete(listener)
unsub?.()
unsubscribe?.()
}
}

/**
* Update the store state safely with improved type checking
*/
setState(updater: (prevState: TState) => TState): void
setState(updater: TState): void
setState(updater: TUpdater): void
setState(updater: Updater<TState> | TUpdater): void {
setState(updater: TState | ((prevState: TState) => TState)): void {
this.prevState = this.state

if (this.options?.updateFn) {
this.state = this.options.updateFn(this.prevState)(updater as TUpdater)
if (isUpdaterFunction(updater)) {
this.state = this.options?.updateFn ? this.options.updateFn(this.prevState)(updater) : updater(this.prevState)
} else {
if (isUpdaterFunction(updater)) {
this.state = updater(this.prevState)
} else {
this.state = updater as TState
}
this.state = this.options?.updateFn ? this.options.updateFn(this.prevState)(() => updater) : updater
Copy link
Contributor

@jiji-hoon96 jiji-hoon96 Jun 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When updater is a function, updater: (prev: TState) => TState

this.options.updateFn(this.prevState)(() => updater)

The type of () => updater is () => (prev: TState) => TState, but updateFn expects a parameter of type (prev: TState) => TState.

When updater is a value, updater: TState

this.options.updateFn(this.prevState)(() => updater)

The type of () => updater is () => TState, but updateFn expects a parameter of type (prev: TState) => TState.

It seems like this might cause a type error. Am I missing something?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setState(updater: TState | ((prevState: TState) => TState)): void {
  this.prevState = this.state
  
  if (this.options?.updateFn) {
    // When updateFn is defined: always convert the updater to a function.
    const functionUpdater = isUpdaterFunction(updater) 
      ? updater 
      : (prev: TState) => updater
    this.state = this.options.updateFn(this.prevState)(functionUpdater)
  } else {
    // When updateFn is not defined: handle the updater directly.
    this.state = isUpdaterFunction(updater) 
      ? updater(this.prevState) 
      : updater
  }
}

If I were to make changes, I think this could be a possible improvement. What do you think?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, we need to check both whether updater is a function or not and updateFn is defined or not. Either way is fine: we can check updater’s function type first and then updateFn, or vice versa.

}

// Always run onUpdate, regardless of batching
Expand Down
5 changes: 0 additions & 5 deletions packages/store/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
/**
* @private
*/
export type AnyUpdater = (prev: any) => any

/**
* Type-safe updater that can be either a function or direct value
*/
Expand Down
8 changes: 4 additions & 4 deletions packages/store/tests/scheduler.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {

describe('Scheduler logic', () => {
test('Should build a graph properly', () => {
const count = new Store<any, any>(10)
const count = new Store<any>(10)

const halfCount = new Derived<any, any>({
deps: [count],
Expand All @@ -36,7 +36,7 @@ describe('Scheduler logic', () => {
})

test('should unbuild a graph properly', () => {
const count = new Store<any, any>(10)
const count = new Store<any>(10)

const halfCount = new Derived<any, any>({
deps: [count],
Expand Down Expand Up @@ -99,7 +99,7 @@ describe('Scheduler logic', () => {
})

test('should register graph items in the wrong order properly', () => {
const count = new Store<any, any>(12)
const count = new Store<any>(12)

const double = new Derived<any, any>({
deps: [count],
Expand All @@ -124,7 +124,7 @@ describe('Scheduler logic', () => {
})

test('should register graph items in the right direction order', () => {
const count = new Store<any, any>(12)
const count = new Store<any>(12)

const double = new Derived<any, any>({
deps: [count],
Expand Down
4 changes: 2 additions & 2 deletions packages/svelte-store/src/index.svelte.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ export * from '@tanstack/store'
export type NoInfer<T> = [T][T extends any ? 0 : never]

export function useStore<TState, TSelected = NoInfer<TState>>(
store: Store<TState, any>,
store: Store<TState>,
selector?: (state: NoInfer<TState>) => TSelected,
): { readonly current: TSelected }
export function useStore<TState, TSelected = NoInfer<TState>>(
store: Derived<TState, any>,
selector?: (state: NoInfer<TState>) => TSelected,
): { readonly current: TSelected }
export function useStore<TState, TSelected = NoInfer<TState>>(
store: Store<TState, any> | Derived<TState, any>,
store: Store<TState> | Derived<TState, any>,
selector: (state: NoInfer<TState>) => TSelected = (d) => d as any,
): { readonly current: TSelected } {
let slice = $state(selector(store.state))
Expand Down
4 changes: 2 additions & 2 deletions packages/vue-store/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ export * from '@tanstack/store'
export type NoInfer<T> = [T][T extends any ? 0 : never]

export function useStore<TState, TSelected = NoInfer<TState>>(
store: Store<TState, any>,
store: Store<TState>,
selector?: (state: NoInfer<TState>) => TSelected,
): Readonly<Ref<TSelected>>
export function useStore<TState, TSelected = NoInfer<TState>>(
store: Derived<TState, any>,
selector?: (state: NoInfer<TState>) => TSelected,
): Readonly<Ref<TSelected>>
export function useStore<TState, TSelected = NoInfer<TState>>(
store: Store<TState, any> | Derived<TState, any>,
store: Store<TState> | Derived<TState, any>,
selector: (state: NoInfer<TState>) => TSelected = (d) => d as any,
): Readonly<Ref<TSelected>> {
const slice = ref(selector(store.state)) as Ref<TSelected>
Expand Down