Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 432fe37

Browse files
Improve cross referencing
1 parent d100c52 commit 432fe37

File tree

3 files changed

+55
-27
lines changed

3 files changed

+55
-27
lines changed

docs/source/framework/pytorch_integration/python_api.rst

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,32 @@ Comprehensions.
1313

1414
.. autofunction:: make_autograd
1515

16+
The :func:`define` function provides an implicit compilation caching
17+
functionality which alleviates the need to implement a caching mechanism at
18+
the user-facing level. The question still remains which :class:`~tclib.MappingOptions`
19+
to use to compile. Since this is still an open problem, we provide support
20+
for user-defined functions to specify this behavior. We require a user
21+
of the :func:`define` function to provide a :class:`~tclib.MappingOptions` generator
22+
function whose sole purpose is to determine the options with which to compile
23+
a particular TC def for particular input sizes.
24+
25+
To facilitate usage we provide the following generators:
26+
27+
.. autofunction:: make_naive_options_factory
28+
29+
.. autofunction:: make_load_from_cache_options_factory
30+
31+
.. autofunction:: make_autotuned_options_factory
32+
33+
Custom behavior to select :class:`~tclib.MappingOptions` may be implemented
34+
in addition to the provided defaults. The signature of custom generators must
35+
match:
36+
37+
.. code-block:: python
38+
39+
def some_generator(tc: str, entry_point: str, *inputs: torch.Tensor)
40+
-> MappingOptions:
41+
...
1642
1743
Low-level API
1844
-------------
@@ -31,7 +57,7 @@ generally useful for benchmarking.
3157

3258
.. autofunction:: autotune_and_compile
3359

34-
Additionally the :code:`assert_almost_equal` helper function is useful in
60+
Additionally the :func:`assert_almost_equal` helper function is useful in
3561
performing numerical checks.
3662

3763
.. autofunction:: assert_almost_equal

docs/source/framework/pytorch_integration/writing_layers.rst

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
Writing TC operations
22
=====================
33

4+
.. automodule:: tensor_comprehensions
5+
46
This document focuses on writing TC operations using the high-level API.
57
For examples of using the low-level API, see the Python API documentation.
68

79
To create a CUDA kernel implementing an operation backed by TC, one should:
810

911
1. Create a callable TC object by calling :func:`define`
1012
2. Create input PyTorch Tensors
11-
3. Call the helper object with the input PyTorch Tensors
13+
3. Call the TC object with the input PyTorch Tensors
1214

1315
When running, the backend ensures the TC is compiled and memoized for the
14-
given input tensor sizes (see the documentation for :func:`define` for more detals).
16+
given input tensor sizes (see the documentation for :func:`define` for more details).
1517
Calling the object returned by :func:`define` executes the
1618
corresponding operation and returns a list of outputs.
1719
If the operation has already been compiled, in the following runs, the TC
@@ -23,11 +25,11 @@ Example
2325

2426
The following example demonstrates the steps above.
2527
We use the :func:`make_naive_options_factory` builder function to provide
26-
naive :class:`MappingOptions`. Naive options result in poor performance.
27-
At this time, there is no notion of a default :class:`MappingOptions`.
28+
naive :class:`~tclib.MappingOptions`. Naive options result in poor performance.
29+
At this time, there is no notion of a default :class:`~tclib.MappingOptions`.
2830
Instead one should use the autotuner to perform an evolutionary search
29-
starting from an initial :class:`MappingOptions` object and return a better
30-
:class:`MappingOptions` object for a given TC function and sizes (more on this
31+
starting from an initial :class:`~tclib.MappingOptions` object and return a better
32+
:class:`~tclib.MappingOptions` object for a given TC function and sizes (more on this
3133
below).
3234

3335
.. code-block:: python
@@ -50,19 +52,19 @@ below).
5052
Specifying MappingOptions
5153
-------------------------
5254

53-
There are three ways to construct :class:`MappingOptions` when defining a TC:
55+
There are three ways to construct :class:`~tclib.MappingOptions` when defining a TC:
5456

5557
* **Naive MappingOptions**:
5658

5759
* :code:`naive`: this is provided to create a basic GPU mapping strategy with
5860
3-D tiling by 32x32x32, mapping to 256x256 blocks 32x8 threads. This
5961
should by no means be considered a good baseline but just a point to
6062
get started using TC. Once a correct TC is written, we recommend either
61-
using options loaded from a :class:`MappingOptionsCache` or resulting from
62-
a tuning run. One can also modify a :class:`MappingOptions` object
63+
using options loaded from a :class:`~tclib.MappingOptionsCache` or resulting from
64+
a tuning run. One can also modify a :class:`~tclib.MappingOptions` object
6365
programmatically (see the API documentation).
6466

65-
* **Loading from MappingOptionsCache**: a :class:`MappingOptionsCache` provides
67+
* **Loading from MappingOptionsCache**: a :class:`~tclib.MappingOptionsCache` provides
6668
a simple interface to load the best options from a previous tuning run.
6769

6870
* **Autotuning**: A kernel can be autotuned for fixed input tensor sizes.
@@ -73,7 +75,7 @@ There are three ways to construct :class:`MappingOptions` when defining a TC:
7375
Loading from cache
7476
------------------
7577

76-
Loading the best options from a previously serialized :class:`MappingOptionsCache`
78+
Loading the best options from a previously serialized :class:`~tclib.MappingOptionsCache`
7779
can be achieved by making a factory function with
7880
:func:`make_load_from_cache_options_factory` and passing it as an argument to the
7981
:func:`define` function:
@@ -91,7 +93,7 @@ can be achieved by making a factory function with
9193
torch.randn(G, D, device='cuda'))
9294
Sum, SumSq, O = T.group_normalization(I, gamma, beta)
9395
94-
One can also use the low-level :class:`MappingOptionsCache`.
96+
One can also use the low-level :class:`~tclib.MappingOptionsCache`.
9597

9698
Autotuning
9799
----------
@@ -121,10 +123,10 @@ Tuning can be achieved by making a factory function with
121123
that case, the compilation and evaluation jobs currently in flight will
122124
be flushed, but no new compilation job will be created. Once the jobs in
123125
flight are flushed, saving to cache occurs (if requested) and the best
124-
:class:`MappingOptions` found so far will be returned.
126+
:class:`~tclib.MappingOptions` found so far will be returned.
125127

126128
Tuning behavior can be modified by defining the TC with an optional
127-
:class:`TunerConfig` parameter constructed as such:
129+
:class:`~tclib.TunerConfig` parameter constructed as such:
128130
:code:`tuner_config=tc.TunerConfig().threads(5).generations(3).pop_size(5)`.
129131

130132
.. note::

tensor_comprehensions/__init__.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,9 @@ def autotune(tc: str,
166166
:param inputs: PyTorch Tensors that TC should tune for. The inputs must be
167167
passed in the order they are also passed in the definition of
168168
the TC function.
169-
:param starting_options: :code:`MappingOptions` from which tuning should start.
170-
:param tuner_config: :code:`TunerConfig` to control the behavior of the autotuner.
171-
:param load_from_cache: Get the starting MappingOptions by loading from
169+
:param starting_options: :class:`~tclib.MappingOptions` from which tuning should start.
170+
:param tuner_config: :class:`~tclib.TunerConfig` to control the behavior of the autotuner.
171+
:param load_from_cache: Get the starting :class:`~tclib.MappingOptions` by loading from
172172
:code:`cache_filename`. If loading fails to recover an entry
173173
from the cache file for the given input sizes an assertion error
174174
will trigger.
@@ -256,13 +256,13 @@ def autotune_and_compile(
256256

257257
def make_naive_options_factory() -> (
258258
Callable[[str, str, Iterable[torch.Tensor]], MappingOptions]):
259-
r"""Return a factory that always generates naive :class:`MappingOptions`.
259+
r"""Return a factory that always generates naive :class:`~tclib.MappingOptions`.
260260
261261
For easily getting started with TC and debugging purposes only.
262262
263263
:rtype: a function that takes a string with multiple
264264
TC defs, an entry_point and input PyTorch Tensors and produces a
265-
:class:`MappingOptions`.
265+
:class:`~tclib.MappingOptions`.
266266
"""
267267
def generate(tc: str,
268268
entry_point: str,
@@ -273,12 +273,12 @@ def generate(tc: str,
273273

274274
def make_load_from_cache_options_factory(cache_filename: str) -> (
275275
Callable[[str, str, Iterable[torch.Tensor]], MappingOptions]):
276-
r"""Return a factory that loads :class:`MappingOptions` from a cache file.
276+
r"""Return a factory that loads :class:`~tclib.MappingOptions` from a cache file.
277277
278278
:param cache_filename: the filename
279279
:rtype: a function that takes a string with multiple
280280
TC defs, an entry_point and input PyTorch Tensors and produces a
281-
:class:`MappingOptions`.
281+
:class:`~tclib.MappingOptions`.
282282
"""
283283
def generate(tc: str,
284284
entry_point: str,
@@ -298,14 +298,14 @@ def make_autotuned_options_factory(
298298
load_from_cache: Optional[bool] = False,
299299
store_to_cache: Optional[bool] = False) -> (
300300
Callable[[str, str, Iterable[torch.Tensor]], MappingOptions]):
301-
r"""Return a factory that runs autotuning to determine the best :class:`MappingOptions`.
301+
r"""Return a factory that runs autotuning to determine the best :class:`~tclib.MappingOptions`.
302302
303303
The returned factory just calls the :func:`autotune` function, see
304-
it documentation for more information.
304+
its documentation for more information.
305305
306306
:rtype: a function that takes a string with multiple
307307
TC defs, an entry_point and input PyTorch Tensors and produces a
308-
:class:`MappingOptions`.
308+
:class:`~tclib.MappingOptions`.
309309
"""
310310
def generate(tc: str,
311311
entry_point: str,
@@ -408,7 +408,7 @@ def define(tc: str,
408408
with PyTorch Tensors of new sizes. The returned :class:`TC` helper class is
409409
backed by a compilation cache which memoizes the results of compilation and
410410
avoids spurious recompilations. In order to determine the
411-
:class:`MappingOptions`, used for JIT compiling a particular TC def on
411+
:class:`~tclib.MappingOptions`, used for JIT compiling a particular TC def on
412412
inputs of particular sizes, the :code:`mapping_options_factory`
413413
function is called. We provide the factory builder functions
414414
:func:`make_naive_options_factory`,
@@ -427,7 +427,7 @@ def define(tc: str,
427427
:param tc: a string containing one of more TC defs.
428428
:param mapping_options_factory: a function that takes a string with multiple
429429
TC defs, an entry_point and input PyTorch Tensors and produces a
430-
:class:`MappingOptions`.
430+
:class:`~tclib.MappingOptions`.
431431
:rtype: a Callable helper object with methods corresponding to the TC def
432432
names and backed by a compilation cache.
433433

0 commit comments

Comments
 (0)