@@ -683,9 +683,9 @@ def visit_Name(self, node: ast.Name) -> object:
683
683
return self .scope [node .id ]
684
684
assert isinstance (node , ExtendedAST )
685
685
type_info = node ._type_info
686
- assert type_info .origin .is_host () # pyright: ignore[reportOptionalMemberAccess]
686
+ assert type_info is not None and type_info .origin .is_host ()
687
687
try :
688
- return type_info .proxy () # pyright: ignore[reportOptionalMemberAccess]
688
+ return type_info .proxy ()
689
689
except NotImplementedError :
690
690
raise exc .CantReadOnDevice (type_info ) from None
691
691
@@ -757,6 +757,7 @@ def visit_Assign(self, node: ast.Assign) -> None:
757
757
):
758
758
raise exc .NonTensorSubscriptAssign (lhs_type , rhs_type )
759
759
assert isinstance (target .value , ExtendedAST )
760
+ assert target .value ._type_info is not None
760
761
target_origin = target .value ._type_info .origin # pyright: ignore[reportOptionalMemberAccess]
761
762
if not target_origin .is_host ():
762
763
# Get the variable name for the error message
@@ -781,7 +782,8 @@ def _assign_subscript(self, target: ast.Subscript, val: object) -> None:
781
782
raise exc .NonTensorSubscriptAssign (lhs_type , type (val ))
782
783
783
784
assert isinstance (target .value , ExtendedAST )
784
- target_origin = target .value ._type_info .origin # pyright: ignore[reportOptionalMemberAccess]
785
+ assert target .value ._type_info is not None
786
+ target_origin = target .value ._type_info .origin
785
787
assert target_origin .is_host ()
786
788
787
789
return hl .store (
@@ -811,7 +813,7 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
811
813
value = node .value
812
814
assert isinstance (value , ExtendedAST )
813
815
type_info = value ._type_info
814
- if type_info .origin .is_host (): # pyright: ignore[reportOptionalMemberAccess]
816
+ if type_info is not None and type_info .origin .is_host ():
815
817
return hl .load (self .visit (value ), self ._subscript_slice_proxy (node .slice )) # pyright: ignore[reportArgumentType]
816
818
return hl .subscript (self .visit (value ), self ._subscript_slice_proxy (node .slice )) # pyright: ignore[reportArgumentType]
817
819
0 commit comments