Skip to content

Commit 808e1d6

Browse files
authored
Rename TileIndexProxy to hl.Tile (#171)
This makes the docs more readable when they are auto-generated from docstrings.
1 parent b460e5f commit 808e1d6

File tree

9 files changed

+78
-62
lines changed

9 files changed

+78
-62
lines changed

helion/_compiler/device_ir.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@
4444
from .node_masking import remove_unnecessary_masking
4545
from .roll_reduction import ReductionRoller
4646
from .source_location import current_location
47-
from .tile_index_proxy import CheckForIndexCalls
48-
from .tile_index_proxy import TileIndexProxy
4947
from .type_propagation import CallableType
5048
from .type_propagation import GridIndexType
5149
from .type_propagation import IterType
@@ -58,6 +56,8 @@
5856
from .type_propagation import _eval_binary
5957
from .type_propagation import _eval_compare
6058
from .type_propagation import _eval_unary
59+
from helion.language.tile_proxy import Tile
60+
from helion.language.tile_proxy import _CheckForIndexCalls
6161

6262
if TYPE_CHECKING:
6363
from collections.abc import Callable
@@ -83,7 +83,7 @@ def _get_proxy_slot(
8383
default: object = proxy_tensor.no_default,
8484
transform: Callable[[object], object] = lambda x: x,
8585
) -> object:
86-
if isinstance(obj, torch.Tensor) and not isinstance(obj, TileIndexProxy):
86+
if isinstance(obj, torch.Tensor) and not isinstance(obj, Tile):
8787
tracker = tracer.tensor_tracker
8888
if obj not in tracker:
8989
origin = HostFunction.current().tensor_to_origin[obj]
@@ -473,7 +473,7 @@ def disable_tracing() -> Iterator[proxy_tensor.PythonKeyTracer]:
473473

474474
@staticmethod
475475
def should_become_arg(value: object) -> bool:
476-
if isinstance(value, (TileIndexProxy, torch.SymInt)):
476+
if isinstance(value, (Tile, torch.SymInt)):
477477
return False
478478
if isinstance(value, torch.Tensor):
479479
if (
@@ -584,7 +584,7 @@ def run_subgraph(*args: object) -> list[object]:
584584
tracer=tracer,
585585
)
586586
for name, value in outputs.unflatten().items():
587-
if isinstance(value, TileIndexProxy):
587+
if isinstance(value, Tile):
588588
continue
589589
if name in self.scope:
590590
try:
@@ -803,7 +803,7 @@ def visit_Call(self, node: ast.Call) -> object:
803803
func = self.visit(node.func)
804804

805805
# pyre-ignore[6]
806-
return CheckForIndexCalls.retry_call(func, args, kwargs)
806+
return _CheckForIndexCalls.retry_call(func, args, kwargs)
807807

808808
def visit_Attribute(self, node: ast.Attribute) -> object:
809809
return getattr(self.visit(node.value), node.attr)
@@ -825,7 +825,7 @@ def __init__(self, values: dict[str, object]) -> None:
825825
self.tensor_indices = [
826826
i
827827
for i, v in enumerate(self.flat_values)
828-
if isinstance(v, torch.Tensor) and not isinstance(v, TileIndexProxy)
828+
if isinstance(v, torch.Tensor) and not isinstance(v, Tile)
829829
]
830830

831831
def unflatten(self) -> dict[str, object]:

helion/_compiler/type_propagation.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@
3939
from .output_header import library_imports
4040
from .source_location import SourceLocation
4141
from .source_location import current_location
42-
from .tile_index_proxy import CheckForIndexCalls
43-
from .tile_index_proxy import TileIndexProxy
4442
from .variable_origin import ArgumentOrigin
4543
from .variable_origin import AttributeOrigin
4644
from .variable_origin import BuiltinOrigin
@@ -51,6 +49,8 @@
5149
from .variable_origin import SourceOrigin
5250
from .variable_origin import TensorSizeOrigin
5351
import helion
52+
from helion.language.tile_proxy import Tile
53+
from helion.language.tile_proxy import _CheckForIndexCalls
5454

5555
# pyre-ignore-all-errors[8,15,58]: visit_* overrides
5656
if TYPE_CHECKING:
@@ -623,7 +623,7 @@ def propagate_call(
623623
try:
624624
fn = getattr(self.tensor.fake_value, attr)
625625
output_type = TypeInfo.from_example(
626-
CheckForIndexCalls.retry_call(fn, proxy_args, proxy_kwargs), origin
626+
_CheckForIndexCalls.retry_call(fn, proxy_args, proxy_kwargs), origin
627627
)
628628
except exc.Base:
629629
raise
@@ -775,7 +775,9 @@ def to_proxy(arg: TypeInfo) -> object:
775775
try:
776776
with patch.object(torch.SymInt, "__index__", _raise_shape_specializing):
777777
output_type = TypeInfo.from_example(
778-
CheckForIndexCalls.retry_call(self.value, proxy_args, proxy_kwargs),
778+
_CheckForIndexCalls.retry_call(
779+
self.value, proxy_args, proxy_kwargs
780+
),
779781
origin,
780782
)
781783
output_type.tree_map(warn_wrong_device)
@@ -1006,7 +1008,7 @@ def proxy(self) -> object:
10061008
torch._C._TorchDispatchModeKey.FAKE
10071009
)
10081010
try:
1009-
return TileIndexProxy(self.block_id)
1011+
return Tile(self.block_id)
10101012
finally:
10111013
assert fake_mode is not None
10121014
torch._C._set_dispatch_mode(fake_mode)
@@ -1048,7 +1050,7 @@ def merge(self, other: TypeInfo) -> TypeInfo:
10481050
return super().merge(other)
10491051

10501052
def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
1051-
if isinstance(getattr(TileIndexProxy, attr, None), property):
1053+
if isinstance(getattr(Tile, attr, None), property):
10521054
return TypeInfo.from_example(getattr(self.proxy(), attr), origin)
10531055
return super().propagate_attribute(attr, origin)
10541056

helion/language/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
from .creation_ops import full as full
66
from .creation_ops import zeros as zeros
77
from .device_print import device_print as device_print
8-
from .loops import Tile as Tile
98
from .loops import grid as grid
109
from .loops import tile as tile
1110
from .memory_ops import atomic_add as atomic_add
1211
from .memory_ops import load as load
1312
from .memory_ops import store as store
14-
from .tiles import tile_begin as tile_begin
15-
from .tiles import tile_block_size as tile_block_size
16-
from .tiles import tile_end as tile_end
17-
from .tiles import tile_index as tile_index
13+
from .tile_ops import tile_begin as tile_begin
14+
from .tile_ops import tile_block_size as tile_block_size
15+
from .tile_ops import tile_end as tile_end
16+
from .tile_ops import tile_index as tile_index
17+
from .tile_proxy import Tile as Tile
1818
from .tunable_ops import register_block_size as register_block_size
1919
from .tunable_ops import register_reduction_dim as register_reduction_dim
2020
from .tunable_ops import register_tunable as register_tunable

helion/language/_decorators.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ def unpack(x: object) -> object:
111111

112112

113113
def tiles_as_sizes_prepare_args(*args: object) -> tuple[object, ...]:
114-
from helion._compiler.tile_index_proxy import TileIndexProxy
114+
from helion.language.tile_proxy import Tile
115115

116-
return TileIndexProxy.tiles_to_sizes(args)
116+
return Tile._tiles_to_sizes(args)
117117

118118

119119
def no_op_prepare_args(*args: object) -> tuple[object, ...]:
@@ -273,12 +273,12 @@ def _default_type_function(
273273
def type_prop_with_fake_fn(
274274
*args: object, origin: Origin, **kwargs: object
275275
) -> TypeInfo:
276-
from .._compiler.tile_index_proxy import TileIndexProxy
277276
from .._compiler.type_propagation import TypeInfo
277+
from helion.language.tile_proxy import Tile
278278

279279
args, kwargs = tree_map_only(TypeInfo, _to_proxy, (args, kwargs))
280280
if tiles_as_sizes:
281-
args, kwargs = TileIndexProxy.tiles_to_sizes((args, kwargs))
281+
args, kwargs = Tile._tiles_to_sizes((args, kwargs))
282282
return TypeInfo.from_example(fake_fn(*args, **kwargs), origin)
283283

284284
return type_prop_with_fake_fn

helion/language/_tracing_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from .._compiler.ast_extension import expr_from_string
1313
from .._compiler.compile_environment import CompileEnvironment
1414
from .._compiler.host_function import HostFunction
15-
from .._compiler.tile_index_proxy import TileIndexProxy
1615
from ..exc import NotInsideKernel
1716
from . import _decorators
17+
from helion.language.tile_proxy import Tile
1818

1919
if TYPE_CHECKING:
2020
from .._compiler.inductor_lowering import CodegenState
@@ -96,9 +96,9 @@ def _phi(lhs: object, rhs: object) -> object:
9696

9797
@_decorators.register_fake(_phi)
9898
def _(lhs: object, rhs: object) -> object:
99-
if isinstance(lhs, TileIndexProxy):
100-
assert isinstance(rhs, TileIndexProxy)
101-
assert lhs.block_size_index == rhs.block_size_index
99+
if isinstance(lhs, Tile):
100+
assert isinstance(rhs, Tile)
101+
assert lhs.block_id == rhs.block_id
102102
return lhs
103103
assert isinstance(lhs, torch.Tensor), lhs
104104
assert isinstance(rhs, torch.Tensor), rhs

helion/language/loops.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from .._compiler.ast_extension import LoopType
1616
from .._compiler.ast_extension import expr_from_string
1717
from .._compiler.compile_environment import CompileEnvironment
18-
from .._compiler.tile_index_proxy import TileIndexProxy
1918
from .._compiler.type_propagation import GridIndexType
2019
from .._compiler.type_propagation import IterType
2120
from .._compiler.type_propagation import Origin
@@ -28,15 +27,15 @@
2827
from ..autotuner.config_spec import L2GroupingSpec
2928
from ..autotuner.config_spec import LoopOrderSpec
3029
from . import _decorators
30+
from helion.language.tile_proxy import Tile
3131

3232
if TYPE_CHECKING:
3333
from collections.abc import Sequence
3434

3535
from .._compiler.inductor_lowering import CodegenState
3636

3737

38-
__all__ = ["Tile", "grid", "tile"]
39-
Tile = TileIndexProxy
38+
__all__ = ["grid", "tile"]
4039

4140

4241
@overload
@@ -186,7 +185,7 @@ def _(
186185
proxy_end = _to_proxy(end)
187186
_check_matching(proxy_begin, proxy_end)
188187
if _not_none(block_size):
189-
proxy_block_size = TileIndexProxy.tiles_to_sizes(_to_proxy(block_size))
188+
proxy_block_size = Tile._tiles_to_sizes(_to_proxy(block_size))
190189
_check_matching(proxy_end, proxy_block_size)
191190
else:
192191
proxy_block_size = begin.tree_map(lambda n: None)

helion/language/memory_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ def _(
4444
value: torch.Tensor | torch.SymInt | float,
4545
extra_mask: torch.Tensor | None = None,
4646
) -> tuple[torch.Tensor, list[object], torch.Tensor | torch.SymInt | int | float]:
47-
from helion._compiler.tile_index_proxy import TileIndexProxy
47+
from helion.language.tile_proxy import Tile
4848

4949
if hasattr(value, "dtype") and value.dtype != tensor.dtype:
5050
value = value.to(tensor.dtype)
51-
index = TileIndexProxy.tiles_to_sizes(index)
51+
index = Tile._tiles_to_sizes(index)
5252
return (tensor, index, value, extra_mask)
5353

5454

@@ -147,16 +147,16 @@ def _(
147147
value: torch.Tensor | float,
148148
sem: str = "relaxed",
149149
) -> tuple[torch.Tensor, object, torch.Tensor | float | int, str]:
150-
from helion._compiler.tile_index_proxy import TileIndexProxy
150+
from helion.language.tile_proxy import Tile
151151

152152
valid_sems = {"relaxed", "acquire", "release", "acq_rel"}
153153
if sem not in valid_sems:
154154
raise ValueError(
155155
f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
156156
)
157157

158-
index = TileIndexProxy.prepare_index(index)
159-
index = TileIndexProxy.tiles_to_sizes(index)
158+
index = Tile._prepare_index(index)
159+
index = Tile._tiles_to_sizes(index)
160160

161161
return (target, index, value, sem)
162162

File renamed without changes.

0 commit comments

Comments
 (0)