
Commit c2d4373

yueshengys authored and jax authors committed

Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton, empty core.token object everywhere. After this change, tokens can be created and threaded in and out of computations to build up dependencies.

Also update ordered side effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210

1 parent 9c9e805 commit c2d4373

File tree

8 files changed: +123 -89 lines


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -24,6 +24,10 @@ Remember to align the itemized text with the first line of an item within a list
     to non-parallel computations, as we already do async dispatch for parallel
     computations. You can recover the old behavior by setting
     `jax.config.update('jax_cpu_enable_async_dispatch', False)`.
+  * `core.Token` now is a non-trivial class which wraps a `jax.Array`. It could
+    be created and threaded in and out of computations to build up dependency.
+    The singleton object `core.token` has been removed, users now should create
+    and use fresh `core.Token` objects instead.
 
 * Deprecations & Removals
   * Pallas now exclusively uses XLA for compiling kernels on GPU. The old
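For illustration, here is a minimal sketch of the new user-facing pattern, adapted from the updated test in tests/api_test.py further down. The jitted pass-through function `noop` and the `from jax._src import core` import are assumptions modeled on that test, not part of this commit:

import jax
import jax.numpy as jnp
from jax._src import core  # import path modeled on the updated tests

@jax.jit
def noop(arr, token):
  # The token is threaded through the jitted computation unchanged.
  return arr, token

arr = jnp.ones(10)
token = jax.lax.create_token()   # a fresh core.Token; the core.token singleton is gone
_, out_token = noop(arr, token)
assert isinstance(out_token, core.Token)
out_token.block_until_ready()    # delegates to the wrapped jax.Array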

jax/_src/api.py

Lines changed: 2 additions & 0 deletions
@@ -2454,6 +2454,8 @@ def _infer_src_sharding(src, x) -> Sharding | None:
 def _check_sharding(x, s):
   if isinstance(s, Sharding):
     aval = shaped_abstractify(x)
+    if isinstance(aval, core.AbstractToken):
+      aval = core.token_shaped_array
     if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
       pjit.pjit_check_aval_sharding(
           (s,), (aval,), None, "device_put args", allow_uneven_sharding=False)

jax/_src/array.py

Lines changed: 22 additions & 1 deletion
@@ -952,6 +952,13 @@ def _array_shard_arg(x, sharding):
 pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
 
 
+def _token_shard_arg(x, sharding):
+  return _array_shard_arg(x._buf, sharding)
+
+
+pxla.shard_arg_handlers[core.Token] = _token_shard_arg
+
+
 def _array_global_result_handler(global_aval, out_sharding, committed):
   if global_aval.dtype == dtypes.float0:
     return lambda _: np.zeros(global_aval.shape, dtypes.float0)  # type: ignore
@@ -963,7 +970,21 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
   )
 pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
 pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler
-pxla.global_result_handlers[core.AbstractToken] = lambda *_: lambda *_: core.token
+
+
+def _token_global_result_handler(global_aval, out_sharding, committed):
+  array_handler = _array_global_result_handler(
+      core.token_shaped_array, out_sharding, committed
+  )
+
+  def wrapper(*args, **kwargs):
+    out_buf = array_handler(*args, **kwargs)
+    return core.Token(out_buf)
+
+  return wrapper
+
+
+pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler
 
 
 # Only used for Arrays that come out of pmap.

jax/_src/core.py

Lines changed: 75 additions & 67 deletions
@@ -1635,6 +1635,70 @@ def shape(self):
            "UnshapedArray instances to ever be produced.")
     raise TypeError(msg)
 
+def _canonicalize_dimension(dim: DimSize) -> DimSize:
+  # Dimensions are most commonly integral (by far), so we check that first.
+  try:
+    return operator.index(dim)
+  except TypeError as e:
+    type_error = e
+  if isinstance(dim, Tracer) and config.dynamic_shapes.value:
+    if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
+                               or isinstance(dim.dtype, bint))):
+      raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
+    return dim
+  elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
+        type(dim._aval.dtype) is bint and not dim._aval.shape):
+    return dim
+  elif is_dim(dim):
+    return dim
+  else:
+    raise type_error
+
+def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
+  """Canonicalizes and checks for errors in a user-provided shape value.
+
+  Args:
+    shape: a Python value that represents a shape.
+
+  Returns:
+    A tuple of canonical dimension values.
+  """
+  try:
+    return tuple(unsafe_map(_canonicalize_dimension, shape))
+  except TypeError:
+    pass
+  raise _invalid_shape_error(shape, context)
+
+def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
+  """Canonicalizes and checks for errors in a user-provided shape dimension value.
+
+  Args:
+    f: a Python value that represents a dimension.
+
+  Returns:
+    A canonical dimension value.
+  """
+  return canonicalize_shape((d,), context)[0]
+
+def _invalid_shape_error(shape: Shape, context: str=""):
+  if config.dynamic_shapes.value:
+    msg = ("Shapes must be 1D sequences of integer scalars, "
+           f"got {shape}")
+  else:
+    msg = ("Shapes must be 1D sequences of concrete values of integer type, "
+           f"got {shape}.")
+  if context:
+    msg += f" {context}."
+  if not config.dynamic_shapes.value and any(
+      isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
+      and not isinstance(get_aval(x), ConcreteArray) for x in shape):
+    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
+            "smaller subfunctions.")
+  for x in shape:
+    if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
+      msg += x._origin_msg()
+
+  return TypeError(msg)
 
 class ShapedArray(UnshapedArray):
   __slots__ = ['shape', 'named_shape']
@@ -1960,9 +2024,18 @@ def str_short(self, short_dtypes=False): return 'Tok'
   def at_least_vspace(self): return self
 abstract_token: AbstractToken = AbstractToken()
 
+# Singleton shaped array used by all abstract tokens when shape/dtype is needed.
+token_shaped_array: ShapedArray = ShapedArray((0,), np.dtype(np.bool_))
+
 # Concrete token object
-class Token: pass
-token: Token = Token()
+class Token:
+  # The underlying data wrapped by the token, could be used to threaded in and
+  # out of computations to build up data dependency.
+  _buf: Array
+  def __init__(self, buf):
+    self._buf = buf
+  def block_until_ready(self):
+    self._buf.block_until_ready()
 pytype_aval_mappings[Token] = lambda _: abstract_token
 
 
@@ -2121,71 +2194,6 @@ def dimension_as_value(d: DimSize):
   if hasattr(d, "dimension_as_value"): return d.dimension_as_value()
   return operator.index(d)
 
-def _canonicalize_dimension(dim: DimSize) -> DimSize:
-  # Dimensions are most commonly integral (by far), so we check that first.
-  try:
-    return operator.index(dim)
-  except TypeError as e:
-    type_error = e
-  if isinstance(dim, Tracer) and config.dynamic_shapes.value:
-    if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
-                               or isinstance(dim.dtype, bint))):
-      raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
-    return dim
-  elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
-        type(dim._aval.dtype) is bint and not dim._aval.shape):
-    return dim
-  elif is_dim(dim):
-    return dim
-  else:
-    raise type_error
-
-def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
-  """Canonicalizes and checks for errors in a user-provided shape value.
-
-  Args:
-    shape: a Python value that represents a shape.
-
-  Returns:
-    A tuple of canonical dimension values.
-  """
-  try:
-    return tuple(unsafe_map(_canonicalize_dimension, shape))
-  except TypeError:
-    pass
-  raise _invalid_shape_error(shape, context)
-
-def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
-  """Canonicalizes and checks for errors in a user-provided shape dimension value.
-
-  Args:
-    f: a Python value that represents a dimension.
-
-  Returns:
-    A canonical dimension value.
-  """
-  return canonicalize_shape((d,), context)[0]
-
-def _invalid_shape_error(shape: Shape, context: str=""):
-  if config.dynamic_shapes.value:
-    msg = ("Shapes must be 1D sequences of integer scalars, "
-           f"got {shape}")
-  else:
-    msg = ("Shapes must be 1D sequences of concrete values of integer type, "
-           f"got {shape}.")
-  if context:
-    msg += f" {context}."
-  if not config.dynamic_shapes.value and any(
-      isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
-      and not isinstance(get_aval(x), ConcreteArray) for x in shape):
-    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
-            "smaller subfunctions.")
-  for x in shape:
-    if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
-      msg += x._origin_msg()
-
-  return TypeError(msg)
-
 class SomeTracer:
   __slots__ = ()
   def __repr__(self): return "[dynamic]"
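As the diff above shows, the new core.Token is a thin wrapper around a placeholder jax.Array whose aval follows token_shaped_array (shape (0,), dtype bool). A small illustrative sketch of the wrapper semantics follows; constructing a Token by hand like this only demonstrates what the class does and is not meant as a recommended public API, and the `jax._src.core` import path is an assumption:

import jax.numpy as jnp
from jax._src import core  # internal module path, assumed here for illustration

# The wrapped buffer matches core.token_shaped_array: an empty boolean array.
# Its data carries no information; the array exists only so the token can
# participate in sharding, dispatch, and data-dependency tracking.
buf = jnp.zeros((0,), dtype=jnp.bool_)
tok = core.Token(buf)      # stores the buffer in tok._buf
tok.block_until_ready()    # delegates to the underlying jax.Array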

jax/_src/dispatch.py

Lines changed: 7 additions & 6 deletions
@@ -107,7 +107,7 @@ class RuntimeTokenSet(threading.local):
 
   # For each ordered effect, the token returned by the last dispatched
   # computation, sharded over the devices in that computation.
-  current_tokens: dict[core.Effect, jax.Array]
+  current_tokens: dict[core.Effect, core.Token]
 
   # For each device, the runtime token returned by the last dispatched
   # computation on that device.
@@ -117,11 +117,12 @@ def __init__(self):
     self.current_tokens = {}
     self.output_runtime_tokens = {}
 
-  def get_token_input(self, eff: core.Effect,
-                      devices: list[Device]) -> jax.Array:
+  def get_token_input(
+      self, eff: core.Effect, devices: list[Device]
+  ) -> core.Token:
     tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
 
-    if isinstance(tok, jax.Array):
+    if isinstance(tok, core.Token):
       # The order of devices may change, so we need to reshard if necessary.
       # TODO(yueshengys): This might still be buggy in a multi-process SPMD
      # scenario. Revise the logic later. A distributed shutdown barrier inside
@@ -131,11 +132,11 @@ def get_token_input(self, eff: core.Effect,
     # We only use replicated sharding for the first time when the token for the
     # order effect hasn't been created.
     s = jax.sharding.GSPMDSharding.get_replicated(devices)
-    sharded_tok = pxla.shard_args([s], [tok])[0]
+    sharded_tok = core.Token(pxla.shard_args([s], [tok])[0])
     self.current_tokens[eff] = sharded_tok
     return sharded_tok
 
-  def set_token_result(self, eff: core.Effect, token: jax.Array):
+  def set_token_result(self, eff: core.Effect, token: core.Token):
     self.current_tokens[eff] = token
 
   def set_output_runtime_token(self, device: Device, token: RuntimeToken):
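For context, ordered side effects are what exercise this RuntimeTokenSet plumbing: each dispatch of a computation with an ordered effect consumes the previous core.Token via get_token_input and publishes a fresh one via set_token_result. A hedged example of user code that goes through this path; jax.debug.print with ordered=True is an existing JAX API, while the token handling itself stays internal:

import jax

@jax.jit
def f(x):
  # An ordered effect: callbacks from successive calls must run in program
  # order, which dispatch enforces by threading a per-effect core.Token
  # between executions of the compiled computation.
  jax.debug.print("x = {}", x, ordered=True)
  return x + 1

f(1)
f(2)  # the runtime token from the first call sequences this callback after it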

jax/_src/interpreters/pxla.py

Lines changed: 7 additions & 11 deletions
@@ -131,13 +131,6 @@ def get_addressable_devices_for_shard_arg(
 def _get_replicated_slices(num_addressable_devices: int):
   return ((slice(None),),) * num_addressable_devices
 
-def _shard_token(x, sharding):
-  devices = get_addressable_devices_for_shard_arg(sharding)
-  indices = _get_replicated_slices(len(devices))
-  zeros = np.zeros((), dtype=np.dtype(np.bool_))
-  aval = api_util.shaped_abstractify(zeros)
-  return batched_device_put(aval, sharding, [zeros for _ in indices], devices)
-shard_arg_handlers[core.Token] = _shard_token
 
 def _masked_array_error(x, sharding):
   raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
@@ -1148,8 +1141,9 @@ def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
   def _add_tokens_to_inputs(self, input_bufs):
     if self.ordered_effects:
       tokens = [
-          dispatch.runtime_tokens.get_token_input(eff, self._local_devices)
-          for eff in self.ordered_effects]
+          dispatch.runtime_tokens.get_token_input(eff, self._local_devices)._buf
+          for eff in self.ordered_effects
+      ]
       input_bufs = [*tokens, *input_bufs]
     return input_bufs
 
@@ -1163,7 +1157,7 @@ def _handle_token_bufs(self, token_bufs, sharded_token):
     for eff, token_buf in zip(self.ordered_effects, token_bufs):
       assert len(token_buf) > 0
       if len(token_buf) == 1:
-        dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
+        dispatch.runtime_tokens.set_token_result(eff, core.Token(token_buf[0]))
       else:
         token_devices = []
         for token in token_buf:
@@ -1173,7 +1167,9 @@ def _handle_token_bufs(self, token_bufs, sharded_token):
         global_token_array = jax.make_array_from_single_device_arrays(
             (0,), s, token_buf
         )
-        dispatch.runtime_tokens.set_token_result(eff, global_token_array)
+        dispatch.runtime_tokens.set_token_result(
+            eff, core.Token(global_token_array)
+        )
 
   @profiler.annotate_function
   def __call__(self, *args):

jax/core.py

Lines changed: 0 additions & 1 deletion
@@ -148,7 +148,6 @@
   subst_axis_names_var as subst_axis_names_var,
   substitute_vars_in_output_ty as substitute_vars_in_output_ty,
   thread_local_state as thread_local_state,
-  token as token,
   trace_state_clean as trace_state_clean,
   traverse_jaxpr_params as traverse_jaxpr_params,
   typecheck as typecheck,

tests/api_test.py

Lines changed: 6 additions & 3 deletions
@@ -673,8 +673,12 @@ def noop(arr, token):
 
     arr = jnp.ones(10)
     token = jax.lax.create_token()
+    _, out_token = noop(arr, token)
 
-    self.assertEqual(token, noop(arr, token)[1])
+    self.assertIsInstance(token, core.Token)
+    self.assertIsInstance(out_token, core.Token)
+    # Different token objects.
+    self.assertIsNot(token, out_token)
 
   def test_jit_bad_input(self):
     def f(x):
@@ -1226,7 +1230,6 @@ def f(x, y, *args, **kwargs):
     for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
       self.assertNotIn(s, hlo_str)
 
-
   @parameterized.parameters([0, 2, [(0, 2)]])
   def test_jit_lower_arg_info_static_argnums(self, static_argnums):
     def f(x, y, *args, **kwargs):
@@ -3732,7 +3735,7 @@ def test_jit_returning_token(self):
     self.assertIsInstance(x, core.Token)
 
   def test_jit_capturing_token(self):
-    tok = core.token
+    tok = jax.lax.create_token()
    _, y = jax.jit(lambda x: (x + 2, tok))(7)
     self.assertIsInstance(y, core.Token)
 