Skip to content

Commit 366a3b3

Browse files
authored
Make imports relative (#310)
1 parent 8f5068c commit 366a3b3

18 files changed

+76
-77
lines changed

examples/matmul_layernorm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import torch
4+
import torch.nn.functional as F
45

56
import helion
67
from helion._testing import run_example
@@ -37,8 +38,6 @@ def matmul_layernorm(
3738
def matmul_layernorm_pytorch(
3839
x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
3940
) -> torch.Tensor:
40-
import torch.nn.functional as F
41-
4241
matmul_out = torch.matmul(x, y)
4342

4443
ln_out = F.layer_norm(

helion/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from .runtime import Kernel
1212
from .runtime import kernel
1313
from .runtime import kernel as jit # alias
14-
from helion.runtime.settings import Settings
15-
from helion.runtime.settings import set_default_settings
14+
from .runtime.settings import Settings
15+
from .runtime.settings import set_default_settings
1616

1717
__all__ = [
1818
"Config",

helion/_compiler/device_ir.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from ..language._decorators import args_to_proxies
3333
from ..language._decorators import get_device_func_replacement
3434
from ..language._tracing_ops import _new_var
35+
from ..language.tile_proxy import Tile
36+
from ..language.tile_proxy import _CheckForIndexCalls
3537
from .ast_extension import ExtendedAST
3638
from .ast_extension import LoopType
3739
from .ast_extension import NodeVisitor
@@ -58,8 +60,6 @@
5860
from .type_propagation import _eval_binary
5961
from .type_propagation import _eval_compare
6062
from .type_propagation import _eval_unary
61-
from helion.language.tile_proxy import Tile
62-
from helion.language.tile_proxy import _CheckForIndexCalls
6363

6464
if TYPE_CHECKING:
6565
from collections.abc import Callable

helion/_compiler/lift_closures.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
from torch._dynamo.utils import make_cell
77

8-
from helion import exc
9-
from helion._compiler.host_function import HostFunction
10-
from helion._compiler.variable_origin import ClosureOrigin
11-
from helion._compiler.variable_origin import Origin
8+
from .. import exc
9+
from .host_function import HostFunction
10+
from .variable_origin import ClosureOrigin
11+
from .variable_origin import Origin
1212

1313

1414
class CaptureGlobals(dict[str, object]):

helion/_compiler/node_masking.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
from torch.fx.experimental import proxy_tensor
1717
from torch.utils._sympy.value_ranges import ValueRanges
1818

19-
from helion.language._tracing_ops import _for_loop
20-
from helion.language._tracing_ops import _if
21-
from helion.language._tracing_ops import _mask_to
22-
from helion.language._tracing_ops import _phi
19+
from ..language._tracing_ops import _for_loop
20+
from ..language._tracing_ops import _if
21+
from ..language._tracing_ops import _mask_to
22+
from ..language._tracing_ops import _phi
2323

2424
if TYPE_CHECKING:
25-
from helion._compiler.inductor_lowering import InductorLowering
25+
from .inductor_lowering import InductorLowering
2626

2727
ValueRangesAny = ValueRanges[Any]
2828

@@ -49,7 +49,7 @@ def apply_masking(
4949
if user.args[1] == other:
5050
assert user.args[0] is node
5151
return user # reuse existing mask_to node
52-
from helion._compiler.inductor_lowering import APIFuncLowering
52+
from .inductor_lowering import APIFuncLowering
5353

5454
# If we reach here, we need to create a new mask_to node
5555
with node.graph.inserting_before(base_node):
@@ -79,9 +79,9 @@ def cached_masked_value(
7979
return node.meta["masked_value"]
8080

8181
if node.op == "placeholder":
82-
from helion._compiler.device_ir import DeviceIR
83-
from helion._compiler.device_ir import ForLoopGraphInfo
84-
from helion._compiler.device_ir import NodeArgsGraphInfo
82+
from .device_ir import DeviceIR
83+
from .device_ir import ForLoopGraphInfo
84+
from .device_ir import NodeArgsGraphInfo
8585

8686
"""
8787
We are inside a for loop or an if statement, which is represented as a subgraph.
@@ -128,7 +128,7 @@ def getitem_masked_value(
128128
Retrieve the masked value for a node that is a getitem operation.
129129
This handles loop outputs, since the `_for` node has multiple outputs.
130130
"""
131-
from helion._compiler.device_ir import DeviceIR
131+
from .device_ir import DeviceIR
132132

133133
assert not getitem_node.kwargs, "getitem kwargs not supported"
134134
node, index = getitem_node.args

helion/_compiler/roll_reduction.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,21 @@
66
import torch
77
from torch.fx import map_arg
88

9-
from helion._compiler.compile_environment import CompileEnvironment
10-
from helion._compiler.inductor_lowering import APIFuncLowering
11-
from helion._compiler.inductor_lowering import ReductionLowering
12-
from helion._compiler.inductor_lowering import aten_lowering_dispatch
13-
from helion.language._tracing_ops import _for_loop
14-
from helion.language._tracing_ops import _get_symnode
15-
from helion.language._tracing_ops import _host_tensor
16-
from helion.language._tracing_ops import _if
17-
from helion.language.memory_ops import store
18-
from helion.language.reduce_ops import _reduce
9+
from ..language._tracing_ops import _for_loop
10+
from ..language._tracing_ops import _get_symnode
11+
from ..language._tracing_ops import _host_tensor
12+
from ..language._tracing_ops import _if
13+
from ..language.memory_ops import store
14+
from ..language.reduce_ops import _reduce
15+
from .compile_environment import CompileEnvironment
16+
from .inductor_lowering import APIFuncLowering
17+
from .inductor_lowering import ReductionLowering
18+
from .inductor_lowering import aten_lowering_dispatch
1919

2020
if TYPE_CHECKING:
21-
from helion._compiler.compile_environment import BlockSizeInfo
22-
from helion._compiler.device_ir import DeviceIR
23-
from helion._compiler.device_ir import RolledReductionInfo
21+
from .compile_environment import BlockSizeInfo
22+
from .device_ir import DeviceIR
23+
from .device_ir import RolledReductionInfo
2424

2525
_duplicate_ops: tuple[object, ...] = (
2626
_host_tensor,

helion/_compiler/tile_dispatch.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,28 @@
44
import operator
55
from typing import TYPE_CHECKING
66

7-
from helion._compiler.compile_environment import CompileEnvironment
8-
from helion._compiler.device_function import DeviceFunction
9-
from helion._compiler.device_ir import ForLoopGraphInfo
10-
from helion._compiler.device_ir import ReductionLoopGraphInfo
11-
from helion._compiler.host_function import HostFunction
12-
from helion._compiler.reduction_strategy import LoopedReductionStrategy
13-
from helion._compiler.reduction_strategy import PersistentReductionStrategy
14-
from helion._compiler.reduction_strategy import ReductionStrategy
15-
from helion._compiler.tile_strategy import CompactedShape
16-
from helion._compiler.tile_strategy import DeviceLoopState
17-
from helion._compiler.tile_strategy import FlattenedTileStrategy
18-
from helion._compiler.tile_strategy import NDTileStrategy
19-
from helion._compiler.tile_strategy import TileStrategy
7+
from .compile_environment import CompileEnvironment
8+
from .device_function import DeviceFunction
9+
from .device_ir import ForLoopGraphInfo
10+
from .device_ir import ReductionLoopGraphInfo
11+
from .host_function import HostFunction
12+
from .reduction_strategy import LoopedReductionStrategy
13+
from .reduction_strategy import PersistentReductionStrategy
14+
from .reduction_strategy import ReductionStrategy
15+
from .tile_strategy import CompactedShape
16+
from .tile_strategy import DeviceLoopState
17+
from .tile_strategy import FlattenedTileStrategy
18+
from .tile_strategy import NDTileStrategy
19+
from .tile_strategy import TileStrategy
2020

2121
if TYPE_CHECKING:
2222
from collections.abc import Sequence
2323

2424
import sympy
2525
import torch
2626

27-
from helion import Config
28-
from helion._compiler.inductor_lowering import CodegenState
27+
from .. import Config
28+
from .inductor_lowering import CodegenState
2929

3030
SymIntLike = torch.SymInt | int
3131
ShapeLike = Sequence[SymIntLike]

helion/_compiler/type_propagation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from ..autotuner.config_spec import BlockSizeSpec
2828
from ..language._decorators import get_device_func_replacement
2929
from ..language._decorators import is_api_func
30+
from ..language.tile_proxy import Tile
31+
from ..language.tile_proxy import _CheckForIndexCalls
3032
from .ast_extension import ExtendedAST
3133
from .ast_extension import LoopType
3234
from .ast_extension import create
@@ -50,8 +52,6 @@
5052
from .variable_origin import SourceOrigin
5153
from .variable_origin import TensorSizeOrigin
5254
import helion
53-
from helion.language.tile_proxy import Tile
54-
from helion.language.tile_proxy import _CheckForIndexCalls
5555

5656
if TYPE_CHECKING:
5757
from collections.abc import Callable

helion/autotuner/config_generation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
from typing import TYPE_CHECKING
88
from typing import cast
99

10-
from helion._compat import warps_to_threads
11-
from helion.autotuner.config_fragment import Category
12-
from helion.autotuner.config_fragment import ConfigSpecFragment
13-
from helion.autotuner.config_fragment import PowerOfTwoFragment
10+
from .._compat import warps_to_threads
11+
from .config_fragment import Category
12+
from .config_fragment import ConfigSpecFragment
13+
from .config_fragment import PowerOfTwoFragment
1414

1515
if TYPE_CHECKING:
16-
from helion import Config
17-
from helion.autotuner import ConfigSpec
16+
from .. import Config
17+
from . import ConfigSpec
1818

1919
FlatConfig = list[object]
2020

helion/autotuner/config_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from collections.abc import Callable
2929
from collections.abc import Sequence
3030

31-
from helion.runtime.config import IndexingLiteral
32-
from helion.runtime.config import PidTypeLiteral
31+
from ..runtime.config import IndexingLiteral
32+
from ..runtime.config import PidTypeLiteral
3333

3434
DEFAULT_NUM_WARPS = 4
3535
DEFAULT_NUM_STAGES = 3

0 commit comments

Comments
 (0)