Skip to content

Commit 148398a

Browse files
authored
Add helion.exc.DeviceTensorSubscriptAssignmentNotAllowed (#292)
1 parent adac8a7 commit 148398a

File tree

4 files changed

+45
-1
lines changed

4 files changed

+45
-1
lines changed

docs/api/exceptions.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,10 @@ These exceptions occur when Helion language functions are used incorrectly with
175175
.. autoclass:: CannotReadDeviceVariableOnHost
176176
177177
Raised when attempting to read variables defined inside device loops from host context.
178+
179+
.. autoclass:: DeviceTensorSubscriptAssignmentNotAllowed
180+
181+
Raised when attempting to assign to subscript of device tensor.
178182
```
179183

180184
## Type and Inference Errors

helion/_compiler/device_ir.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,14 @@ def visit_Assign(self, node: ast.Assign) -> None:
758758
raise exc.NonTensorSubscriptAssign(lhs_type, rhs_type)
759759
assert isinstance(target.value, ExtendedAST)
760760
target_origin = target.value._type_info.origin # pyright: ignore[reportOptionalMemberAccess]
761-
assert target_origin.is_host()
761+
if not target_origin.is_host():
762+
# Get the variable name for the error message
763+
var_name = (
764+
target.value.id
765+
if isinstance(target.value, ast.Name)
766+
else str(target.value)
767+
)
768+
raise exc.DeviceTensorSubscriptAssignmentNotAllowed(var_name)
762769
val = self.visit(node.value)
763770
self._assign_subscript(target, val)
764771

helion/exc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,7 @@ class CannotModifyHostVariableOnDevice(BaseError):
312312

313313
class CannotReadDeviceVariableOnHost(BaseError):
314314
message = "Cannot read variable '{0}' defined inside `hl.tile` or `hl.grid` loop from host code."
315+
316+
317+
class DeviceTensorSubscriptAssignmentNotAllowed(BaseError):
318+
message = "Cannot assign to subscript of device tensor '{0}'."

test/test_errors.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,35 @@ def bad_fn(x: torch.Tensor) -> torch.Tensor:
169169
with self.assertRaises(helion.exc.CannotReadDeviceVariableOnHost):
170170
code_and_output(bad_fn, (torch.randn(8, device=DEVICE),))
171171

172+
def test_device_tensor_subscript(self):
173+
@helion.kernel()
174+
def bad_fn(x: torch.Tensor) -> torch.Tensor:
175+
batch = x.size(0)
176+
result = torch.empty_like(x)
177+
for i in hl.tile(batch):
178+
tmp = x[i] * 2
179+
tmp[0] = 1 # This should not be allowed
180+
result[i] = tmp
181+
return result
182+
183+
with self.assertRaises(helion.exc.DeviceTensorSubscriptAssignmentNotAllowed):
184+
code_and_output(bad_fn, (torch.randn(8, device=DEVICE),))
185+
186+
def test_closure_fn(self):
187+
@helion.kernel()
188+
def bad_fn(x: torch.Tensor) -> torch.Tensor:
189+
def closure_fn():
190+
pass
191+
192+
batch = x.size(0)
193+
result = torch.empty_like(x)
194+
for i in hl.tile(batch):
195+
result[i] = x[i] * 2
196+
return result
197+
198+
with self.assertRaises(helion.exc.StatementNotSupported):
199+
code_and_output(bad_fn, (torch.randn(8, device=DEVICE),))
200+
172201

173202
if __name__ == "__main__":
174203
unittest.main()

0 commit comments

Comments
 (0)