@@ -157,60 +157,37 @@ def jit(
157
157
"""Sets up ``fun`` for just-in-time compilation with XLA.
158
158
159
159
Args:
160
- fun: Function to be jitted. ``fun`` should be a pure function, as
161
- side-effects may only be executed once.
162
-
163
- The arguments and return value of ``fun`` should be arrays,
164
- scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
165
- Positional arguments indicated by ``static_argnums`` can be anything at
166
- all, provided they are hashable and have an equality operation defined.
167
- Static arguments are included as part of a compilation cache key, which is
168
- why hash and equality operators must be defined.
169
-
170
- JAX keeps a weak reference to ``fun`` for use as a compilation cache key,
171
- so the object ``fun`` must be weakly-referenceable. Most :class:`Callable`
172
- objects will already satisfy this requirement.
173
- in_shardings: Pytree of structure matching that of arguments to ``fun``,
174
- with all actual arguments replaced by resource assignment specifications.
175
- It is also valid to specify a pytree prefix (e.g. one value in place of a
176
- whole subtree), in which case the leaves get broadcast to all values in
177
- that subtree.
178
-
179
- The ``in_shardings`` argument is optional. JAX will infer the shardings
180
- from the input :py:class:`jax.Array`'s and defaults to replicating the input
181
- if the sharding cannot be inferred.
182
-
183
- The valid resource assignment specifications are:
184
- - :py:class:`XLACompatibleSharding`, which will decide how the value
185
- will be partitioned. With this, using a mesh context manager is not
186
- required.
187
- - :py:obj:`None`, will give JAX the freedom to choose whatever sharding
188
- it wants.
189
- For in_shardings, JAX will mark is as replicated but this behavior
190
- can change in the future.
191
- For out_shardings, we will rely on the XLA GSPMD partitioner to
192
- determine the output shardings.
193
-
194
- The size of every dimension has to be a multiple of the total number of
195
- resources assigned to it. This is similar to pjit's in_shardings.
196
- out_shardings: Like ``in_shardings``, but specifies resource
197
- assignment for function outputs. This is similar to pjit's
198
- out_shardings.
199
-
200
- The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
201
- will use GSPMD's sharding propagation to figure out what the sharding of the
202
- output(s) should be.
203
- static_argnums: An optional int or collection of ints that specify which
204
- positional arguments to treat as static (compile-time constant).
205
- Operations that only depend on static arguments will be constant-folded in
206
- Python (during tracing), and so the corresponding argument values can be
207
- any Python object.
160
+ fun: Function to be jitted. ``fun`` should be a pure function.
161
+
162
+ The arguments and return value of ``fun`` should be arrays, scalars, or
163
+ (nested) standard Python containers (tuple/list/dict) thereof. Positional
164
+ arguments indicated by ``static_argnums`` can be any hashable type. Static
165
+ arguments are included as part of a compilation cache key, which is why
166
+ hash and equality operators must be defined. JAX keeps a weak reference to
167
+ ``fun`` for use as a compilation cache key, so the object ``fun`` must be
168
+ weakly-referenceable.
169
+ in_shardings: optional, a :py:class:`Sharding` or pytree with
170
+ :py:class:`Sharding` leaves and structure that is a tree prefix of the
171
+ positional arguments tuple to ``fun``. If provided, the positional
172
+ arguments passed to ``fun`` must have shardings that are compatible with
173
+ ``in_shardings`` or an error is raised, and the compiled computation has
174
+ input shardings corresponding to ``in_shardings``. If not provided, the
175
+ compiled computation's input shardings are inferred from argument
176
+ shardings.
177
+ out_shardings: optional, a :py:class:`Sharding` or pytree with
178
+ :py:class:`Sharding` leaves and structure that is a tree prefix of the
179
+ output of ``fun``. If provided, it has the same effect as applying
180
+ corresponding :py:func:`jax.lax.with_sharding_constraint` calls to the output
181
+ of ``fun``.
182
+ static_argnums: optional, an int or collection of ints that specify which
183
+ positional arguments to treat as static (trace- and compile-time
184
+ constant).
208
185
209
186
Static arguments should be hashable, meaning both ``__hash__`` and
210
- ``__eq__`` are implemented, and immutable. Calling the jitted function
211
- with different values for these constants will trigger recompilation.
212
- Arguments that are not arrays or containers thereof must be marked as
213
- static.
187
+ ``__eq__`` are implemented, and immutable. Otherwise they can be arbitrary
188
+ Python objects. Calling the jitted function with different values for
189
+ these constants will trigger recompilation. Arguments that are not
190
+ array-like or containers thereof must be marked as static.
214
191
215
192
If neither ``static_argnums`` nor ``static_argnames`` is provided, no
216
193
arguments are treated as static. If ``static_argnums`` is not provided but
@@ -221,17 +198,18 @@ def jit(
221
198
provided, ``inspect.signature`` is not used, and only actual
222
199
parameters listed in either ``static_argnums`` or ``static_argnames`` will
223
200
be treated as static.
224
- static_argnames: An optional string or collection of strings specifying
201
+ static_argnames: optional, a string or collection of strings specifying
225
202
which named arguments to treat as static (compile-time constant). See the
226
203
comment on ``static_argnums`` for details. If not
227
204
provided but ``static_argnums`` is set, the default is based on calling
228
205
``inspect.signature(fun)`` to find corresponding named arguments.
229
- donate_argnums: Specify which positional argument buffers are "donated" to
230
- the computation. It is safe to donate argument buffers if you no longer
231
- need them once the computation has finished. In some cases XLA can make
232
- use of donated buffers to reduce the amount of memory needed to perform a
206
+ donate_argnums: optional, collection of integers to specify which positional
207
+ argument buffers can be overwritten by the computation and marked deleted
208
+ in the caller. It is safe to donate argument buffers if you no longer need
209
+ them once the computation has started. In some cases XLA can make use of
210
+ donated buffers to reduce the amount of memory needed to perform a
233
211
computation, for example recycling one of your input buffers to store a
234
- result. You should not reuse buffers that you donate to a computation, JAX
212
+ result. You should not reuse buffers that you donate to a computation; JAX
235
213
will raise an error if you try to. By default, no argument buffers are
236
214
donated.
237
215
@@ -247,15 +225,16 @@ def jit(
247
225
248
226
For more details on buffer donation see the
249
227
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
250
- donate_argnames: An optional string or collection of strings specifying
228
+ donate_argnames: optional, a string or collection of strings specifying
251
229
which named arguments are donated to the computation. See the
252
230
comment on ``donate_argnums`` for details. If not
253
231
provided but ``donate_argnums`` is set, the default is based on calling
254
232
``inspect.signature(fun)`` to find corresponding named arguments.
255
- keep_unused: If `False` (the default), arguments that JAX determines to be
256
- unused by `fun` *may* be dropped from resulting compiled XLA executables.
257
- Such arguments will not be transferred to the device nor provided to the
258
- underlying executable. If `True`, unused arguments will not be pruned.
233
+ keep_unused: optional boolean. If `False` (the default), arguments that JAX
234
+ determines to be unused by `fun` *may* be dropped from resulting compiled
235
+ XLA executables. Such arguments will not be transferred to the device nor
236
+ provided to the underlying executable. If `True`, unused arguments will
237
+ not be pruned.
259
238
device: This is an experimental feature and the API is likely to change.
260
239
Optional, the Device the jitted function will run on. (Available devices
261
240
can be retrieved via :py:func:`jax.devices`.) The default is inherited
@@ -264,9 +243,8 @@ def jit(
264
243
backend: This is an experimental feature and the API is likely to change.
265
244
Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
266
245
``'tpu'``.
267
- inline: Specify whether this function should be inlined into enclosing
268
- jaxprs (rather than being represented as an application of the xla_call
269
- primitive with its own subjaxpr). Default False.
246
+ inline: Optional boolean. Specify whether this function should be inlined
247
+ into enclosing jaxprs. Default False.
270
248
271
249
Returns:
272
250
A wrapped version of ``fun``, set up for just-in-time compilation.
@@ -287,8 +265,8 @@ def jit(
287
265
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
288
266
-0.85743 -0.78232 0.76827 0.59566 ]
289
267
290
- To pass arguments such as ``static_argnames`` when decorating a function, a common
291
- pattern is to use :func:`functools.partial`:
268
+ To pass arguments such as ``static_argnames`` when decorating a function, a
269
+ common pattern is to use :func:`functools.partial`:
292
270
293
271
>>> from functools import partial
294
272
>>>
@@ -2470,10 +2448,10 @@ def device_put(
2470
2448
2471
2449
Args:
2472
2450
x: An array, scalar, or (nested) standard Python container thereof.
2473
- device: The (optional) :py:class:`Device`, `Sharding`, or a (nested)
2474
- `Sharding` in standard Python container (must be a tree prefix of ``x``),
2475
- representing the device(s) to which ``x`` should be transferred. If
2476
- given, then the result is committed to the device(s).
2451
+ device: The (optional) :py:class:`Device`, :py:class:`Sharding`, or a
2452
+ (nested) :py:class:`Sharding` in standard Python container (must be a tree
2453
+ prefix of ``x``), representing the device(s) to which ``x`` should be
2454
+ transferred. If given, then the result is committed to the device(s).
2477
2455
2478
2456
Returns:
2479
2457
A copy of ``x`` that resides on ``device``.
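As a rough illustration of two behaviours the revised docstrings describe (static arguments triggering recompilation, and ``device_put`` committing its result when a sharding is given), here is a minimal sketch. The function name ``scale_by`` and its ``power`` parameter are hypothetical and chosen only for this example; it assumes a JAX version that exposes ``jax.sharding.SingleDeviceSharding`` and runs on whatever the first available device is.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding


# `scale_by` and `power` are hypothetical names used only for illustration.
@partial(jax.jit, static_argnames=("power",))
def scale_by(x, power):
    # `power` is static, so it is a plain Python int while tracing and
    # this Python loop is unrolled at trace time.
    result = jnp.ones_like(x)
    for _ in range(power):
        result = result * x
    return result


x = jnp.arange(4.0)
squared = scale_by(x, power=2)  # traced and compiled for power=2
cubed = scale_by(x, power=3)    # a different static value, so it is recompiled

# Passing a Sharding to device_put commits the result to the device(s)
# that the sharding names; here that is just the first available device.
sharding = SingleDeviceSharding(jax.devices()[0])
x_committed = jax.device_put(x, sharding)
```

Because ``power`` is part of the compilation cache key, it has to be hashable; passing an unhashable value such as a list for a static argument would be rejected.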