Skip to content

Commit 38753a6

Browse files
authored
Fix type_info null errors (#294)
1 parent b9e76dc commit 38753a6

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

helion/_compiler/device_ir.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -683,9 +683,9 @@ def visit_Name(self, node: ast.Name) -> object:
683683
return self.scope[node.id]
684684
assert isinstance(node, ExtendedAST)
685685
type_info = node._type_info
686-
assert type_info.origin.is_host() # pyright: ignore[reportOptionalMemberAccess]
686+
assert type_info is not None and type_info.origin.is_host()
687687
try:
688-
return type_info.proxy() # pyright: ignore[reportOptionalMemberAccess]
688+
return type_info.proxy()
689689
except NotImplementedError:
690690
raise exc.CantReadOnDevice(type_info) from None
691691

@@ -757,6 +757,7 @@ def visit_Assign(self, node: ast.Assign) -> None:
757757
):
758758
raise exc.NonTensorSubscriptAssign(lhs_type, rhs_type)
759759
assert isinstance(target.value, ExtendedAST)
760+
assert target.value._type_info is not None
760761
target_origin = target.value._type_info.origin # pyright: ignore[reportOptionalMemberAccess]
761762
if not target_origin.is_host():
762763
# Get the variable name for the error message
@@ -781,7 +782,8 @@ def _assign_subscript(self, target: ast.Subscript, val: object) -> None:
781782
raise exc.NonTensorSubscriptAssign(lhs_type, type(val))
782783

783784
assert isinstance(target.value, ExtendedAST)
784-
target_origin = target.value._type_info.origin # pyright: ignore[reportOptionalMemberAccess]
785+
assert target.value._type_info is not None
786+
target_origin = target.value._type_info.origin
785787
assert target_origin.is_host()
786788

787789
return hl.store(
@@ -811,7 +813,7 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
811813
value = node.value
812814
assert isinstance(value, ExtendedAST)
813815
type_info = value._type_info
814-
if type_info.origin.is_host(): # pyright: ignore[reportOptionalMemberAccess]
816+
if type_info is not None and type_info.origin.is_host():
815817
return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]
816818
return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]
817819

helion/_compiler/generate_ast.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,17 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
191191
for arg_node in iter_node.args:
192192
assert not isinstance(arg_node, ast.Starred)
193193
assert isinstance(arg_node, ExtendedAST)
194-
args.append(arg_node._type_info.proxy()) # pyright: ignore[reportOptionalMemberAccess]
194+
assert arg_node._type_info is not None
195+
args.append(arg_node._type_info.proxy())
195196
for kwarg_node in iter_node.keywords:
196197
assert kwarg_node.arg is not None
197198
assert isinstance(kwarg_node.value, ExtendedAST)
198-
kwargs[kwarg_node.arg] = kwarg_node.value._type_info.proxy() # pyright: ignore[reportOptionalMemberAccess]
199+
assert kwarg_node.value._type_info is not None
200+
kwargs[kwarg_node.arg] = kwarg_node.value._type_info.proxy()
199201
fn_node = iter_node.func
200202
assert isinstance(fn_node, ExtendedAST)
201-
fn = fn_node._type_info.proxy() # pyright: ignore[reportOptionalMemberAccess]
203+
assert fn_node._type_info is not None
204+
fn = fn_node._type_info.proxy()
202205
assert is_api_func(fn)
203206
assert fn._codegen is not None
204207
bound = fn._signature.bind(*args, **kwargs)

0 commit comments

Comments
 (0)