Skip to content

Commit b8df23c

Browse files
committed
tweak jit docstring
1 parent 52f5f70 commit b8df23c

File tree

1 file changed

+50
-72
lines changed

1 file changed

+50
-72
lines changed

jax/_src/api.py

Lines changed: 50 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -157,60 +157,37 @@ def jit(
157157
"""Sets up ``fun`` for just-in-time compilation with XLA.
158158
159159
Args:
160-
fun: Function to be jitted. ``fun`` should be a pure function, as
161-
side-effects may only be executed once.
162-
163-
The arguments and return value of ``fun`` should be arrays,
164-
scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
165-
Positional arguments indicated by ``static_argnums`` can be anything at
166-
all, provided they are hashable and have an equality operation defined.
167-
Static arguments are included as part of a compilation cache key, which is
168-
why hash and equality operators must be defined.
169-
170-
JAX keeps a weak reference to ``fun`` for use as a compilation cache key,
171-
so the object ``fun`` must be weakly-referenceable. Most :class:`Callable`
172-
objects will already satisfy this requirement.
173-
in_shardings: Pytree of structure matching that of arguments to ``fun``,
174-
with all actual arguments replaced by resource assignment specifications.
175-
It is also valid to specify a pytree prefix (e.g. one value in place of a
176-
whole subtree), in which case the leaves get broadcast to all values in
177-
that subtree.
178-
179-
The ``in_shardings`` argument is optional. JAX will infer the shardings
180-
from the input :py:class:`jax.Array`'s and defaults to replicating the input
181-
if the sharding cannot be inferred.
182-
183-
The valid resource assignment specifications are:
184-
- :py:class:`XLACompatibleSharding`, which will decide how the value
185-
will be partitioned. With this, using a mesh context manager is not
186-
required.
187-
- :py:obj:`None`, will give JAX the freedom to choose whatever sharding
188-
it wants.
189-
For in_shardings, JAX will mark is as replicated but this behavior
190-
can change in the future.
191-
For out_shardings, we will rely on the XLA GSPMD partitioner to
192-
determine the output shardings.
193-
194-
The size of every dimension has to be a multiple of the total number of
195-
resources assigned to it. This is similar to pjit's in_shardings.
196-
out_shardings: Like ``in_shardings``, but specifies resource
197-
assignment for function outputs. This is similar to pjit's
198-
out_shardings.
199-
200-
The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
201-
will use GSPMD's sharding propagation to figure out what the sharding of the
202-
output(s) should be.
203-
static_argnums: An optional int or collection of ints that specify which
204-
positional arguments to treat as static (compile-time constant).
205-
Operations that only depend on static arguments will be constant-folded in
206-
Python (during tracing), and so the corresponding argument values can be
207-
any Python object.
160+
fun: Function to be jitted. ``fun`` should be a pure function.
161+
162+
The arguments and return value of ``fun`` should be arrays, scalar, or
163+
(nested) standard Python containers (tuple/list/dict) thereof. Positional
164+
arguments indicated by ``static_argnums`` can be any hashable type. Static
165+
arguments are included as part of a compilation cache key, which is why
166+
hash and equality operators must be defined. JAX keeps a weak reference to
167+
``fun`` for use as a compilation cache key, so the object ``fun`` must be
168+
weakly-referenceable.
169+
in_shardings: optional, a :py:class:`Sharding` or pytree with
170+
:py:class:`Sharding` leaves and structure that is a tree prefix of the
171+
positional arguments tuple to ``fun``. If provided, the positional
172+
arguments passed to ``fun`` must have shardings that are compatible with
173+
``in_shardings`` or an error is raised, and the compiled computation has
174+
input shardings corresponding to ``in_shardings``. If not provided, the
175+
compiled computation's input shardings are inferred from argument
176+
sharings.
177+
out_shardings: optional, a :py:class:`Sharding` or pytree with
178+
:py:class:`Sharding` leaves and structure that is a tree prefix of the
179+
output of ``fun``. If provided, it has the same effect as applying
180+
corresponding :py:func:`jax.lax.with_sharding_constraint`s to the output
181+
of ``fun``.
182+
static_argnums: optional, an int or collection of ints that specify which
183+
positional arguments to treat as static (trace- and compile-time
184+
constant).
208185
209186
Static arguments should be hashable, meaning both ``__hash__`` and
210-
``__eq__`` are implemented, and immutable. Calling the jitted function
211-
with different values for these constants will trigger recompilation.
212-
Arguments that are not arrays or containers thereof must be marked as
213-
static.
187+
``__eq__`` are implemented, and immutable. Otherwise they can be arbitrary
188+
Python objects. Calling the jitted function with different values for
189+
these constants will trigger recompilation. Arguments that are not
190+
array-like or containers thereof must be marked as static.
214191
215192
If neither ``static_argnums`` nor ``static_argnames`` is provided, no
216193
arguments are treated as static. If ``static_argnums`` is not provided but
@@ -221,17 +198,18 @@ def jit(
221198
provided, ``inspect.signature`` is not used, and only actual
222199
parameters listed in either ``static_argnums`` or ``static_argnames`` will
223200
be treated as static.
224-
static_argnames: An optional string or collection of strings specifying
201+
static_argnames: optional, a string or collection of strings specifying
225202
which named arguments to treat as static (compile-time constant). See the
226203
comment on ``static_argnums`` for details. If not
227204
provided but ``static_argnums`` is set, the default is based on calling
228205
``inspect.signature(fun)`` to find corresponding named arguments.
229-
donate_argnums: Specify which positional argument buffers are "donated" to
230-
the computation. It is safe to donate argument buffers if you no longer
231-
need them once the computation has finished. In some cases XLA can make
232-
use of donated buffers to reduce the amount of memory needed to perform a
206+
donate_argnums: optional, collection of integers to specify which positional
207+
argument buffers can be overwritten by the computation and marked deleted
208+
in the caller. It is safe to donate argument buffers if you no longer need
209+
them once the computation has started. In some cases XLA can make use of
210+
donated buffers to reduce the amount of memory needed to perform a
233211
computation, for example recycling one of your input buffers to store a
234-
result. You should not reuse buffers that you donate to a computation, JAX
212+
result. You should not reuse buffers that you donate to a computation; JAX
235213
will raise an error if you try to. By default, no argument buffers are
236214
donated.
237215
@@ -247,15 +225,16 @@ def jit(
247225
248226
For more details on buffer donation see the
249227
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
250-
donate_argnames: An optional string or collection of strings specifying
228+
donate_argnames: optional, a string or collection of strings specifying
251229
which named arguments are donated to the computation. See the
252230
comment on ``donate_argnums`` for details. If not
253231
provided but ``donate_argnums`` is set, the default is based on calling
254232
``inspect.signature(fun)`` to find corresponding named arguments.
255-
keep_unused: If `False` (the default), arguments that JAX determines to be
256-
unused by `fun` *may* be dropped from resulting compiled XLA executables.
257-
Such arguments will not be transferred to the device nor provided to the
258-
underlying executable. If `True`, unused arguments will not be pruned.
233+
keep_unused: optional boolean. If `False` (the default), arguments that JAX
234+
determines to be unused by `fun` *may* be dropped from resulting compiled
235+
XLA executables. Such arguments will not be transferred to the device nor
236+
provided to the underlying executable. If `True`, unused arguments will
237+
not be pruned.
259238
device: This is an experimental feature and the API is likely to change.
260239
Optional, the Device the jitted function will run on. (Available devices
261240
can be retrieved via :py:func:`jax.devices`.) The default is inherited
@@ -264,9 +243,8 @@ def jit(
264243
backend: This is an experimental feature and the API is likely to change.
265244
Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
266245
``'tpu'``.
267-
inline: Specify whether this function should be inlined into enclosing
268-
jaxprs (rather than being represented as an application of the xla_call
269-
primitive with its own subjaxpr). Default False.
246+
inline: Optional boolean. Specify whether this function should be inlined
247+
into enclosing jaxprs. Default False.
270248
271249
Returns:
272250
A wrapped version of ``fun``, set up for just-in-time compilation.
@@ -287,8 +265,8 @@ def jit(
287265
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
288266
-0.85743 -0.78232 0.76827 0.59566 ]
289267
290-
To pass arguments such as ``static_argnames`` when decorating a function, a common
291-
pattern is to use :func:`functools.partial`:
268+
To pass arguments such as ``static_argnames`` when decorating a function, a
269+
common pattern is to use :func:`functools.partial`:
292270
293271
>>> from functools import partial
294272
>>>
@@ -2470,10 +2448,10 @@ def device_put(
24702448
24712449
Args:
24722450
x: An array, scalar, or (nested) standard Python container thereof.
2473-
device: The (optional) :py:class:`Device`, `Sharding`, or a (nested)
2474-
`Sharding` in standard Python container (must be a tree prefix of ``x``),
2475-
representing the device(s) to which ``x`` should be transferred. If
2476-
given, then the result is committed to the device(s).
2451+
device: The (optional) :py:class:`Device`, :py:class:`Sharding`, or a
2452+
(nested) :py:class:`Sharding` in standard Python container (must be a tree
2453+
prefix of ``x``), representing the device(s) to which ``x`` should be
2454+
transferred. If given, then the result is committed to the device(s).
24772455
24782456
Returns:
24792457
A copy of ``x`` that resides on ``device``.

0 commit comments

Comments
 (0)