Commit 55759ff

Add helion.exc.CannotModifyHostVariableOnDevice and helion.exc.CannotReadDeviceVariableOnHost (#290)
1 parent 47878bf commit 55759ff

8 files changed: +92 -7 lines changed

docs/api/exceptions.md

Lines changed: 8 additions & 0 deletions

@@ -167,6 +167,14 @@ These exceptions occur when Helion language functions are used incorrectly with
 .. autoclass:: UndefinedVariable
 
    Raised when referencing undefined variables.
+
+.. autoclass:: CannotModifyHostVariableOnDevice
+
+   Raised when modifying host variables inside device loops without subscript assignment.
+
+.. autoclass:: CannotReadDeviceVariableOnHost
+
+   Raised when attempting to read variables defined inside device loops from host context.
 ```
 
 ## Type and Inference Errors
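
A minimal hypothetical kernel, mirroring the tests added in this commit, shows where each new exception fires (the first offending statement raises first):

```python
import torch
import helion
import helion.language as hl

@helion.kernel()
def bad(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)  # host variable
    for tile in hl.tile(x.size(0)):
        out = x[tile] * 2  # CannotModifyHostVariableOnDevice: rebinds a host name
        tmp = x[tile] + 1  # tmp now has device origin
    return tmp  # CannotReadDeviceVariableOnHost: device value read on host
```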

helion/_compiler/type_propagation.py

Lines changed: 19 additions & 1 deletion

@@ -123,6 +123,12 @@ def get(self, name: str) -> TypeInfo:
             return self.variables[name]
         return self.parent.get(name)
 
+    def maybe_get(self, name: str) -> TypeInfo | None:
+        try:
+            return self.get(name)
+        except exc.UndefinedVariable:
+            return None
+
     def set(self, name: str, type_info: TypeInfo) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         self.variables[name] = type_info
 
@@ -1545,6 +1551,13 @@ def _compare(self, op: ast.cmpop, left: TypeInfo, right: TypeInfo) -> TypeInfo:
 
     def _assign(self, lhs: ast.AST, rhs: TypeInfo) -> None:
         if isinstance(lhs, ast.Name):
+            # Check if we're trying to modify a host variable inside a device loop
+            if (
+                (existing_type := self.scope.maybe_get(lhs.id)) is not None
+                and existing_type.origin.is_host()
+                and rhs.origin.is_device()
+            ):
+                raise exc.CannotModifyHostVariableOnDevice(lhs.id) from None
             return self.scope.set(lhs.id, rhs)
         if isinstance(lhs, ast.Starred):
             try:
@@ -1679,7 +1692,10 @@ def visit_Dict(self, node: ast.Dict) -> TypeInfo:
         return DictType(element_types=element_types, origin=self.origin())
 
     def visit_Name(self, node: ast.Name) -> TypeInfo:
-        return self.scope.get(node.id)
+        result = self.scope.get(node.id)
+        if self.device_loop_depth == 0 and result.origin.is_device():
+            raise exc.CannotReadDeviceVariableOnHost(node.id)
+        return result
 
     visit_Starred: _VisitMethod = generic_visit  # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride]
 
@@ -1996,6 +2012,8 @@ def visit_Try(self, node: ast.Try) -> TypeInfo:
     def _not_on_device_statement(self, node: ast.AST) -> TypeInfo:
         if self.device_loop_depth:
             raise exc.NotAllowedOnDevice(type(node).__name__)
+        for child_node in ast.iter_child_nodes(node):
+            self.visit(child_node)
         return NoType(origin=self.origin())
 
     visit_ExceptHandler: _VisitMethod = _not_on_device_statement  # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride]
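
Both checks key off each value's origin (host vs. device) and the current `device_loop_depth`. A standalone sketch of the same logic, using simplified stand-ins for Helion's `Origin`/`TypeInfo` (the class and function names below are hypothetical):

```python
# Simplified stand-in for Helion's origin tracking; names are hypothetical.
class Origin:
    def __init__(self, device: bool) -> None:
        self.device = device

    def is_device(self) -> bool:
        return self.device

    def is_host(self) -> bool:
        return not self.device


def check_assign(existing: "Origin | None", rhs: Origin) -> None:
    # Mirrors _assign: rebinding an existing host variable to a
    # device-produced value is rejected; subscript stores are unaffected.
    if existing is not None and existing.is_host() and rhs.is_device():
        raise RuntimeError("CannotModifyHostVariableOnDevice")


def check_read(device_loop_depth: int, origin: Origin) -> None:
    # Mirrors visit_Name: a device-origin name may not be read at
    # host level (device_loop_depth == 0).
    if device_loop_depth == 0 and origin.is_device():
        raise RuntimeError("CannotReadDeviceVariableOnHost")
```

`maybe_get` exists so `_assign` can probe the scope without tripping `UndefinedVariable` for names being bound for the first time.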

helion/exc.py

Lines changed: 8 additions & 0 deletions

@@ -304,3 +304,11 @@ class TypeInferenceError(BaseError):
 
 class NotAllowedInHelperFunction(BaseError):
     message = "This operation is not allowed inside helper functions. It requires kernel context."
+
+
+class CannotModifyHostVariableOnDevice(BaseError):
+    message = "Cannot modify host variable '{0}' inside `hl.tile` or `hl.grid` loop without subscript assignment. Use '{0}[tile] = ...' instead."
+
+
+class CannotReadDeviceVariableOnHost(BaseError):
+    message = "Cannot read variable '{0}' defined inside `hl.tile` or `hl.grid` loop from host code."
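
The `{0}` placeholders suggest these templates are rendered with `str.format` over the constructor arguments; assuming that (the `BaseError` plumbing is not shown in this diff), the user-facing text would look like:

```python
# Hypothetical rendering of the new template via str.format;
# assumes BaseError formats `message` with its constructor args.
template = (
    "Cannot modify host variable '{0}' inside `hl.tile` or `hl.grid` loop "
    "without subscript assignment. Use '{0}[tile] = ...' instead."
)
print(template.format("result"))
# Cannot modify host variable 'result' inside `hl.tile` or `hl.grid` loop
# without subscript assignment. Use 'result[tile] = ...' instead.
```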

helion/language/loops.py

Lines changed: 2 additions & 2 deletions

@@ -193,8 +193,8 @@ def tile_info_example(x: torch.Tensor) -> torch.Tensor:
     * tile(begin, end, block_size) iterates begin to end-1, fixed block_size
     * tile(end, block_size=block_size) iterates 0 to end-1, fixed block_size
 
-    Block sizes can be registered for autotuning explicitly with :func:`~helion.language.register_block_size`.
-    And passed in to as ``block_size`` argument if one needs two loops to use the same block size. Passing
+    Block sizes can be registered for autotuning explicitly with :func:`~helion.language.register_block_size`
+    and passed as the ``block_size`` argument if one needs two loops to use the same block size. Passing
     ``block_size=None`` is equivalent to calling register_block_size.
 
     Use ``tile`` in most cases. Use ``grid`` when you need explicit control over the launch grid.
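
As the corrected sentence describes, one registered block size can be shared by two loops. A hedged sketch using the documented `register_block_size` and `block_size=` APIs (the kernel itself is hypothetical):

```python
import torch
import helion
import helion.language as hl

@helion.kernel()
def two_pass(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    block = hl.register_block_size(x.size(0))  # one autotuned block size
    for tile in hl.tile(x.size(0), block_size=block):
        out[tile] = x[tile] + 1
    for tile in hl.tile(x.size(0), block_size=block):  # same block size reused
        out[tile] = out[tile] * 2
    return out
```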

test/data/all_ast_nodes.py

Lines changed: 2 additions & 2 deletions

@@ -135,8 +135,8 @@ def all_ast_nodes(x, y):
     i = 3
     t = 0
 
-    with contextlib.nullcontext():
-        e3 = 1
+    # with contextlib.nullcontext():
+    #     e3 = 1
 
     global global0
 
test/test_errors.py

Lines changed: 44 additions & 0 deletions

@@ -125,6 +125,50 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         with self.assertRaises(helion.exc.InvalidDeviceForLoop):
             code_and_output(fn, (torch.randn(8, device=DEVICE),))
 
+    def test_return_inside_grid_loop(self):
+        """Test that return statement inside hl.grid loop raises proper error."""
+
+        @helion.kernel()
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            batch = x.size(0)
+            out = x.new_empty(batch)
+            for tile_batch in hl.grid(batch):
+                if x[tile_batch] > 0:
+                    return out  # This should not be allowed
+                out[tile_batch] = x[tile_batch] * 2
+            return out
+
+        with self.assertRaises(helion.exc.NotAllowedOnDevice):
+            code_and_output(fn, (torch.randn(8, device=DEVICE),))
+
+    def test_assign_without_subscript1(self):
+        """Test that modifying host variables inside device loops raises proper error."""
+
+        @helion.kernel()
+        def bad_fn(x: torch.Tensor) -> torch.Tensor:
+            batch = x.size(0)
+            result = torch.empty_like(x)
+            for tile_batch in hl.tile(batch):
+                # shouldn't be able to modify host variables on device
+                result = x[tile_batch] * 2
+            return result
+
+        with self.assertRaises(helion.exc.CannotModifyHostVariableOnDevice):
+            code_and_output(bad_fn, (torch.randn(8, device=DEVICE),))
+
+    def test_assign_without_subscript2(self):
+        """Test that reading device variables from host context raises proper error."""
+
+        @helion.kernel()
+        def bad_fn(x: torch.Tensor) -> torch.Tensor:
+            batch = x.size(0)
+            for tile_batch in hl.tile(batch):
+                result = x[tile_batch] * 2
+            return result  # shouldn't be able to read device variable here
+
+        with self.assertRaises(helion.exc.CannotReadDeviceVariableOnHost):
+            code_and_output(bad_fn, (torch.randn(8, device=DEVICE),))
+
 
 if __name__ == "__main__":
     unittest.main()
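
For contrast, a sketch of the form both new checks accept, following the error message's own suggestion (not part of the commit): allocate `result` on the host and write through a subscript.

```python
import torch
import helion
import helion.language as hl

@helion.kernel()
def good_fn(x: torch.Tensor) -> torch.Tensor:
    batch = x.size(0)
    result = torch.empty_like(x)  # allocated on the host
    for tile_batch in hl.tile(batch):
        result[tile_batch] = x[tile_batch] * 2  # subscript assignment is allowed
    return result  # host variable, safe to read on the host
```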

test/test_reductions.expected

Lines changed: 1 addition & 0 deletions

@@ -108,6 +108,7 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o
         # UnaryOp: LiteralType(-1) DeviceOrigin(location=<SourceLocation test_reductions.py:55>)
         # Constant: LiteralType(1) DeviceOrigin(location=<SourceLocation test_reductions.py:55>)
         out[tile_n] = fn(x[tile_n, :], dim=-1)
+    # Name: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:49>)
     return out
 
 def root_graph_0():

test/test_type_propagation.expected

Lines changed: 8 additions & 2 deletions

@@ -33,6 +33,7 @@ def add(x, y):
         # Name: TensorType([y_size0, x_size1], torch.int32) GetItemOrigin(value=SourceOrigin(location=<SourceLocation basic_kernels.py:10>), key=1)
         # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation basic_kernels.py:12>)
         out[tile] = x[tile] + y[tile]
+    # Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation basic_kernels.py:11>)
     return out
 
 def root_graph_0():
@@ -383,8 +384,6 @@ def all_ast_nodes(x, y):
     i = 3
     # Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:136>)
     t = 0
-    with contextlib.nullcontext():
-        e3 = 1
     global global0
     # Call: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:143>)
     # Attribute: CallableType(_VariableFunctionsClass.empty_like) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty_like')
@@ -437,6 +436,7 @@ def all_ast_nodes(x, y):
     # Name: SymIntType(u16) SourceOrigin(location=<SourceLocation all_ast_nodes.py:149>)
     # Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:150>)
     combined = [v, z]
+    # Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:143>)
     return out
 
 def root_graph_0():
@@ -488,6 +488,7 @@ def hl_full_usage(x: torch.Tensor):
         # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation basic_kernels.py:39>)
         # Name: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation basic_kernels.py:42>)
         out[tile] = tmp
+    # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation basic_kernels.py:38>)
     return out
 
 def root_graph_0():
@@ -545,6 +546,7 @@ def hl_zeros_usage(x: torch.Tensor):
         # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation basic_kernels.py:28>)
         # Name: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation basic_kernels.py:31>)
         out[tile] = tmp
+    # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation basic_kernels.py:27>)
     return out
 
 def root_graph_0():
@@ -640,6 +642,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor):
         # Name: TileIndexType(1) SourceOrigin(location=<SourceLocation matmul.py:19>)
         # Name: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=<SourceLocation matmul.py:22>)
         out[tile_m, tile_n] = acc
+    # Name: TensorType([512, 512], torch.float32) SourceOrigin(location=<SourceLocation matmul.py:16>)
     return out
 
 def for_loop_0(arg0_1: "f32[u0, u1]"):
@@ -697,6 +700,7 @@ def fn(x):
         # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
         # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:77>)
         out[tile] = x[tile].sin()
+    # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:76>)
     return out
 
 def root_graph_0():
@@ -747,6 +751,7 @@ def pointwise_device_loop(x: torch.Tensor):
         # Name: TileIndexType(1) DeviceOrigin(location=<SourceLocation basic_kernels.py:52>)
         # Constant: LiteralType(1) DeviceOrigin(location=<SourceLocation basic_kernels.py:53>)
         out[tile_n, tile_m] = torch.sigmoid(x[tile_n, tile_m] + 1)
+    # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation basic_kernels.py:49>)
     return out
 
 def for_loop_0():
@@ -805,6 +810,7 @@ def torch_ops_pointwise(x, y):
         # Name: TensorType([y_size0], torch.int32) ArgumentOrigin(name='y')
         # Name: SequenceType([TileIndexType(0)]) SourceOrigin(location=<SourceLocation basic_kernels.py:20>)
         out[tile] = torch.sigmoid(torch.add(torch.sin(x[tile]), torch.cos(y[tile])))
+    # Name: TensorType([x_size0], torch.int32) SourceOrigin(location=<SourceLocation basic_kernels.py:19>)
     return out
 
 def root_graph_0():
