Commit 6ec6f1d

Add error for using a host tensor directly (#306)

1 parent 25b3ab9

8 files changed, +63 -7 lines changed

docs/api/exceptions.md

Lines changed: 3 additions & 0 deletions
@@ -135,6 +135,9 @@ These exceptions occur when Helion language functions are used incorrectly with
 
     Raised for invalid types in tensor subscripts.
 
+.. autoclass:: HostTensorDirectUsage
+
+    Raised when host tensors are used directly in device code without proper indexing.
 ```
 
 ## Assignment and Variable Errors

helion/_compiler/device_ir.py

Lines changed: 27 additions & 0 deletions
@@ -920,6 +920,7 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
     for graph in device_ir.graphs:
         prepare_graph_lowerings(graph.graph)
     for graph in device_ir.graphs:
+        validate_host_tensor_usage(graph.graph)
         remove_unnecessary_tile_index(graph.graph)
         remove_unnecessary_masking(graph.graph)
     device_ir.build_rolled_reductions()
@@ -949,6 +950,32 @@ def codegen(self, state: CodegenState) -> list[object]:
         return codegen_helper_function_graph_info(self, state)
 
 
+def validate_host_tensor_usage(graph: torch.fx.Graph) -> None:
+    """
+    Validate that scalar _host_tensor ops only flow into allowed operations.
+    This replaces the AST visitor context detection with cleaner FX graph validation.
+    Only checks 0-dimensional tensors (scalars), not regular tensors.
+    Uses decorator metadata to determine which operations allow host tensors.
+    """
+    from ..language._decorators import is_api_func
+    from ..language._tracing_ops import _host_tensor
+
+    for node in graph.find_nodes(op="call_function", target=_host_tensor):
+        scalar_tensor_name = node.args[0]
+        assert isinstance(scalar_tensor_name, str), scalar_tensor_name
+
+        # Check all users of this scalar _host_tensor node
+        for user in node.users:
+            if user.op == "call_function":
+                # Check if this operation allows host tensors via decorator metadata
+                if not (
+                    is_api_func(user.target)
+                    and getattr(user.target, "_allow_host_tensor", False)
+                ):
+                    op_name = getattr(user.target, "__name__", str(user.target))
+                    raise exc.HostTensorDirectUsage(scalar_tensor_name, op_name)
+
+
 def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
     """
     Remove unnecessary tile_index nodes from the graph.
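The new pass relies only on the traced FX graph: it finds every `_host_tensor` node and rejects any consumer whose decorator metadata does not opt in. Below is a self-contained sketch of that users-based validation pattern on a hand-built `torch.fx.Graph`; `_host_tensor_stub` and `_allowed_load` are hypothetical stand-ins for Helion's tracing ops, and the `_allow_host_tensor` attribute mimics the decorator metadata introduced in this commit.

import torch
import torch.fx as fx


def _host_tensor_stub(name: str) -> torch.Tensor:
    """Hypothetical stand-in for Helion's _host_tensor tracing op."""
    raise AssertionError("trace-only placeholder")


def _allowed_load(tensor: object, index: object) -> object:
    """Hypothetical op whose decorator would have set allow_host_tensor=True."""
    raise AssertionError("trace-only placeholder")


_allowed_load._allow_host_tensor = True  # metadata flag the validator reads

# Hand-build a tiny graph: one host scalar, one disallowed and one allowed user.
graph = fx.Graph()
host = graph.call_function(_host_tensor_stub, ("scalar_tensor",))
bad = graph.call_function(torch.add, (host, host))   # direct use -> should be rejected
ok = graph.call_function(_allowed_load, (host, []))  # opted-in consumer -> fine
graph.output((bad, ok))

for node in graph.find_nodes(op="call_function", target=_host_tensor_stub):
    for user in node.users:
        if user.op == "call_function" and not getattr(
            user.target, "_allow_host_tensor", False
        ):
            op_name = getattr(user.target, "__name__", str(user.target))
            print(f"would raise HostTensorDirectUsage({node.args[0]!r}, {op_name!r})")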

helion/exc.py

Lines changed: 7 additions & 0 deletions
@@ -115,6 +115,13 @@ class NotAllowedOnDevice(BaseError):
     message = "The statement {} is not allowed inside the `hl.tile` or `hl.grid` loop."
 
 
+class HostTensorDirectUsage(BaseError):
+    message = (
+        "Direct use of host tensor '{0}' in op '{1}' not allowed inside the `hl.tile` or `hl.grid` loop. "
+        "First load it using {0}[...] or hl.load({0}, ...)."
+    )
+
+
 class ShapeSpecializingCall(BaseError):
     message = "Call would force shape specialization, try `hl.specialize(x)` or `hl.constexpr`."
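The two placeholders are positional so the tensor name can be reused in the fix-it hint. Assuming `BaseError` substitutes its constructor arguments via `str.format` (as the existing `{}`-style messages suggest), the message for the new test in test/test_errors.py below would render roughly as follows:

# Illustration only; assumes str.format-style substitution of the two args.
template = (
    "Direct use of host tensor '{0}' in op '{1}' not allowed inside the "
    "`hl.tile` or `hl.grid` loop. "
    "First load it using {0}[...] or hl.load({0}, ...)."
)
print(template.format("scalar_tensor", "add"))
# -> Direct use of host tensor 'scalar_tensor' in op 'add' not allowed inside
#    the `hl.tile` or `hl.grid` loop. First load it using scalar_tensor[...]
#    or hl.load(scalar_tensor, ...).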

helion/language/_decorators.py

Lines changed: 3 additions & 0 deletions
@@ -77,6 +77,7 @@ class APIFunc(Protocol):
     _prepare_args: Callable[[tuple[object, ...]], tuple[object, ...]]
     _get_masked_value: Callable[[torch.fx.Node], float | bool | None] | None
     _to_device_ir: Callable[..., object] | None
+    _allow_host_tensor: bool
     _signature: inspect.Signature
 
     def __call__(self, *args: object, **kwargs: object) -> object: ...
@@ -126,6 +127,7 @@ def api(
     is_device_only: bool = True,
     tiles_as_sizes: bool = False,
     cache_type: bool = False,
+    allow_host_tensor: bool = False,
     signature: inspect.Signature | None = None,
 ) -> _Decorator:
     def _impl(fn: _C) -> _C:
@@ -181,6 +183,7 @@ def wrapper(*args: object, **kwargs: object) -> object:
         api._fake_fn = None
         api._get_masked_value = None
         api._to_device_ir = None
+        api._allow_host_tensor = allow_host_tensor
         api._signature = signature or inspect.signature(
             cast("Callable[..., object]", fn)
         )
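The new keyword follows the module's existing pattern of stashing per-op metadata as attributes on the returned wrapper, which later passes read back with `getattr`. A generic sketch of that pattern (a toy decorator, not Helion's actual `_decorators.api`):

import functools
from typing import Callable, TypeVar

_F = TypeVar("_F", bound=Callable[..., object])


def api(*, allow_host_tensor: bool = False) -> Callable[[_F], _F]:
    """Toy decorator: records the flag as metadata on the wrapper."""

    def _impl(fn: _F) -> _F:
        @functools.wraps(fn)
        def wrapper(*args: object, **kwargs: object) -> object:
            return fn(*args, **kwargs)

        wrapper._allow_host_tensor = allow_host_tensor  # type: ignore[attr-defined]
        return wrapper  # type: ignore[return-value]

    return _impl


@api(allow_host_tensor=True)
def toy_load(tensor: object, index: object) -> object:
    return tensor


# A later validation pass only needs getattr; it never imports the op module.
assert getattr(toy_load, "_allow_host_tensor", False)

Keeping the flag on the function object lets the FX validation pass in device_ir.py stay decoupled from the individual op definitions.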

helion/language/_tracing_ops.py

Lines changed: 2 additions & 2 deletions
@@ -94,7 +94,7 @@ def _(state: CodegenState) -> None:
 
 # Note we can't DCE phi nodes because there may be a loop carry dependency not captured in the outer graph
 @has_side_effect
-@_decorators.api()
+@_decorators.api(allow_host_tensor=True)
 def _phi(lhs: object, rhs: object) -> object:
     """Combine values from different branches of a control flow."""
     raise AssertionError("this should never be called")
@@ -291,7 +291,7 @@ def _(node: torch.fx.Node) -> float | bool:
     return value
 
 
-@_decorators.api()
+@_decorators.api(allow_host_tensor=True)
 def _new_var(value: _T, /) -> _T:
     """
     Create a shallow copy of a value that is assigned a fresh variable in codegen.

helion/language/memory_ops.py

Lines changed: 3 additions & 3 deletions
@@ -19,7 +19,7 @@
 
 
 @has_side_effect
-@_decorators.api(tiles_as_sizes=True)
+@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
 def store(
     tensor: torch.Tensor,
     index: list[object],
@@ -84,7 +84,7 @@ def _(state: CodegenState) -> ast.AST:
     )
 
 
-@_decorators.api(tiles_as_sizes=True)
+@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
 def load(
     tensor: torch.Tensor, index: list[object], extra_mask: torch.Tensor | None = None
 ) -> torch.Tensor:
@@ -130,7 +130,7 @@ def _(node: torch.fx.Node) -> int:
 
 
 @has_side_effect
-@_decorators.api()
+@_decorators.api(allow_host_tensor=True)
 def atomic_add(
     target: torch.Tensor,
     index: list[object],
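These memory ops legitimately take the tensor itself as an argument, which is why they opt in with allow_host_tensor=True. For orientation, a sketch of their explicit call forms inside a kernel; only the parameters visible in this diff are confirmed, and the position of the value argument of `hl.store` after `index` is an assumption:

import torch
import helion
import helion.language as hl


@helion.kernel()
def copy_kernel(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.shape):
        val = hl.load(x, [tile])    # explicit form of x[tile]
        hl.store(out, [tile], val)  # explicit form of out[tile] = val
    return out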

helion/language/signal_wait.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 
 
 @has_side_effect
-@_decorators.api(tiles_as_sizes=True)
+@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
 def wait(
     signal_pad: torch.Tensor,
     index: list[object],
@@ -158,7 +158,7 @@ def _(state: CodegenState) -> ast.AST:
 
 
 @has_side_effect
-@_decorators.api(tiles_as_sizes=True)
+@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
 def signal(
     signal_pad: torch.Tensor,
     index: list[object],

test/test_errors.py

Lines changed: 16 additions & 0 deletions
@@ -198,6 +198,22 @@ def closure_fn():
         with self.assertRaises(helion.exc.StatementNotSupported):
             code_and_output(bad_fn, (torch.randn(8, device=DEVICE),))
 
+    def test_direct_scalar_tensor_in_device_context(self):
+        """Test that direct scalar tensor usage gives clear error in device code."""
+
+        @helion.kernel()
+        def bad_fn(x: torch.Tensor, scalar_tensor: torch.Tensor) -> torch.Tensor:
+            result = torch.empty_like(x)
+            for tile in hl.tile(x.shape):
+                result[tile] = x[tile] + scalar_tensor  # Error: direct scalar usage
+            return result
+
+        with self.assertRaises(helion.exc.HostTensorDirectUsage):
+            code_and_output(
+                bad_fn,
+                (torch.randn(4, 4, device=DEVICE), torch.tensor(3.0, device=DEVICE)),
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
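For contrast, a variant the new check should accept, following the error message's suggestion to index the 0-d tensor before using it in device code (a sketch, not part of the committed tests):

import torch
import helion
import helion.language as hl


@helion.kernel()
def good_fn(x: torch.Tensor, scalar_tensor: torch.Tensor) -> torch.Tensor:
    result = torch.empty_like(x)
    for tile in hl.tile(x.shape):
        # Load the 0-d host tensor before combining it with device values.
        result[tile] = x[tile] + scalar_tensor[...]
    return result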
