Skip to content

Commit 8a95454

Browse files
authored
Fix bug in merging sequence types (#152)
1 parent c432607 commit 8a95454

File tree

3 files changed

+72
-61
lines changed

3 files changed

+72
-61
lines changed

helion/_compiler/type_propagation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ def populate_symbol_origins(self, origin: Origin) -> None:
12061206
def merge(self, other: TypeInfo) -> TypeInfo:
12071207
if isinstance(other, SequenceType):
12081208
self_elements = self.element_types
1209-
other_elements = self.element_types
1209+
other_elements = other.element_types
12101210
if len(self_elements) == len(other_elements):
12111211
return SequenceType(
12121212
origin=other.origin,
@@ -1246,7 +1246,7 @@ def populate_symbol_origins(self, origin: Origin) -> None:
12461246
def merge(self, other: TypeInfo) -> TypeInfo:
12471247
if isinstance(other, DictType):
12481248
self_elements = self.element_types
1249-
other_elements = self.element_types
1249+
other_elements = other.element_types
12501250
if set(self_elements.keys()) == set(other_elements.keys()):
12511251
return DictType(
12521252
origin=other.origin,
@@ -1298,7 +1298,7 @@ def unpack(self) -> list[TypeInfo]:
12981298
def merge(self, other: TypeInfo) -> TypeInfo:
12991299
if isinstance(other, SliceType):
13001300
self_elements = self.element_types
1301-
other_elements = self.element_types
1301+
other_elements = other.element_types
13021302
return SliceType(
13031303
origin=other.origin,
13041304
element_types=slice(

test/data/all_ast_nodes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,13 @@ def all_ast_nodes(x, y):
113113
join_var0 = x
114114
join_var1 = x + y
115115
join_var2 = 1
116+
join_var3 = {"x": 0}
116117
else:
117118
join_var0 = y
118119
join_var1 = None
119120
join_var2 = 2
120-
combined = [join_var0, join_var1, join_var2]
121+
join_var3 = {"x": 1}
122+
combined = [join_var0, join_var1, join_var2, join_var3]
121123

122124
i = 0
123125
while i < 3:

test/test_type_propagation.py

Lines changed: 66 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -488,96 +488,105 @@ def all_ast_nodes(x, y):
488488
join_var1 = x + y
489489
# Constant: LiteralType(1) SourceOrigin(location=<SourceLocation all_ast_nodes.py:115>)
490490
join_var2 = 1
491+
# Dict: DictType({'x': LiteralType(0)}) SourceOrigin(location=<SourceLocation all_ast_nodes.py:116>)
492+
# Constant: LiteralType('x') SourceOrigin(location=<SourceLocation all_ast_nodes.py:116>)
493+
# Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:116>)
494+
join_var3 = {'x': 0}
491495
else:
492496
# Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='y')
493497
join_var0 = y
494-
# Constant: LiteralType(None) SourceOrigin(location=<SourceLocation all_ast_nodes.py:118>)
498+
# Constant: LiteralType(None) SourceOrigin(location=<SourceLocation all_ast_nodes.py:119>)
495499
join_var1 = None
496-
# Constant: LiteralType(2) SourceOrigin(location=<SourceLocation all_ast_nodes.py:119>)
500+
# Constant: LiteralType(2) SourceOrigin(location=<SourceLocation all_ast_nodes.py:120>)
497501
join_var2 = 2
498-
# List: SequenceType([TensorType([y_size0, x_size1], torch.int32), UnknownType("Can't combine types from control flow: TensorType([y_size0, x_size1], torch.int32) and LiteralType(None)"), SymIntType(u8)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:120>)
502+
# Dict: DictType({'x': LiteralType(1)}) SourceOrigin(location=<SourceLocation all_ast_nodes.py:121>)
503+
# Constant: LiteralType('x') SourceOrigin(location=<SourceLocation all_ast_nodes.py:121>)
504+
# Constant: LiteralType(1) SourceOrigin(location=<SourceLocation all_ast_nodes.py:121>)
505+
join_var3 = {'x': 1}
506+
# List: SequenceType([TensorType([y_size0, x_size1], torch.int32), UnknownType("Can't combine types from control flow: TensorType([y_size0, x_size1], torch.int32) and LiteralType(None)"), SymIntType(u8), DictType({'x': SymIntType(u9)})]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:122>)
499507
# Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='y')
500-
# Name: UnknownType("Can't combine types from control flow: TensorType([y_size0, x_size1], torch.int32) and LiteralType(None)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:118>)
508+
# Name: UnknownType("Can't combine types from control flow: TensorType([y_size0, x_size1], torch.int32) and LiteralType(None)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:119>)
501509
# Name: SymIntType(u8) SourceOrigin(location=<SourceLocation all_ast_nodes.py:115>)
502-
combined = [join_var0, join_var1, join_var2]
503-
# Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:122>)
510+
# Name: DictType({'x': SymIntType(u9)}) SourceOrigin(location=<SourceLocation all_ast_nodes.py:121>)
511+
combined = [join_var0, join_var1, join_var2, join_var3]
512+
# Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:124>)
504513
i = 0
505-
# Compare: SymBoolType(Eq(u11, 1)) SourceOrigin(location=<SourceLocation all_ast_nodes.py:123>)
506-
# Name: SymIntType(u9) SourceOrigin(location=<SourceLocation all_ast_nodes.py:122>)
507-
# Constant: LiteralType(3) SourceOrigin(location=<SourceLocation all_ast_nodes.py:123>)
514+
# Compare: SymBoolType(Eq(u12, 1)) SourceOrigin(location=<SourceLocation all_ast_nodes.py:125>)
515+
# Name: SymIntType(u10) SourceOrigin(location=<SourceLocation all_ast_nodes.py:124>)
516+
# Constant: LiteralType(3) SourceOrigin(location=<SourceLocation all_ast_nodes.py:125>)
508517
while i < 3:
509-
# BinOp: SymIntType(u13) SourceOrigin(location=<SourceLocation all_ast_nodes.py:124>)
510-
# Name: SymIntType(u12) SourceOrigin(location=<SourceLocation all_ast_nodes.py:122>)
511-
# Constant: LiteralType(1) SourceOrigin(location=<SourceLocation all_ast_nodes.py:124>)
518+
# BinOp: SymIntType(u14) SourceOrigin(location=<SourceLocation all_ast_nodes.py:126>)
519+
# Name: SymIntType(u13) SourceOrigin(location=<SourceLocation all_ast_nodes.py:124>)
520+
# Constant: LiteralType(1) SourceOrigin(location=<SourceLocation all_ast_nodes.py:126>)
512521
i = i + 1
513522
continue
514523
else:
515-
# Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:127>)
524+
# Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:129>)
516525
t = 0
517526
with contextlib.nullcontext():
518-
# Global: UnknownType('ast.Global is not supported') SourceOrigin(location=<SourceLocation all_ast_nodes.py:132>)
527+
# Global: UnknownType('ast.Global is not supported') SourceOrigin(location=<SourceLocation all_ast_nodes.py:134>)
519528
e3 = 1
520529
global global0
521-
# Call: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:134>)
530+
# Call: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:136>)
522531
# Attribute: CallableType(_VariableFunctionsClass.empty_like) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty_like')
523532
# Name: PythonModuleType(torch) GlobalOrigin(name='torch')
524533
# Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
525534
out = torch.empty_like(x)
526-
# Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:135>)
535+
# Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137>)
527536
v = 0
528-
# Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:136>)
537+
# Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:138>)
529538
# For: loop_type=GRID
530539
z = 0
531-
# Call: IterType(SequenceType([TileIndexType(0), TileIndexType(1)])) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137>)
540+
# Call: IterType(SequenceType([TileIndexType(0), TileIndexType(1)])) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139>)
532541
# Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile')
533542
# Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
534-
# Call: SequenceType((SymIntType(s17), SymIntType(s27))) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137>)
535-
# Attribute: TensorAttributeType AttributeOrigin(value=SourceOrigin(location=<SourceLocation all_ast_nodes.py:134>), key='size')
536-
# Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:134>)
543+
# Call: SequenceType((SymIntType(s17), SymIntType(s27))) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139>)
544+
# Attribute: TensorAttributeType AttributeOrigin(value=SourceOrigin(location=<SourceLocation all_ast_nodes.py:136>), key='size')
545+
# Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:136>)
537546
for tile in hl.tile(out.size()):
538-
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:138>)
539-
# Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:134>)
540-
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137>)
541-
# BinOp: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:138>)
542-
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:138>)
547+
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:140>)
548+
# Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:136>)
549+
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139>)
550+
# BinOp: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:140>)
551+
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:140>)
543552
# Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
544-
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137>)
545-
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:138>)
553+
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139>)
554+
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:140>)
546555
# Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='y')
547-
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137>)
556+
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139>)
548557
# For: loop_type=HOST
549558
out[tile] = x[tile] + y[tile]
550-
# Call: LiteralType(range(0, 3)) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139>)
559+
# Call: LiteralType(range(0, 3)) SourceOrigin(location=<SourceLocation all_ast_nodes.py:141>)
551560
# Name: CallableType(range) BuiltinOrigin(name='range')
552-
# Constant: LiteralType(3) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139>)
561+
# Constant: LiteralType(3) SourceOrigin(location=<SourceLocation all_ast_nodes.py:141>)
553562
for i in range(3):
554-
# BinOp: SymIntType(u21) SourceOrigin(location=<SourceLocation all_ast_nodes.py:140>)
555-
# Name: SymIntType(u20) SourceOrigin(location=<SourceLocation all_ast_nodes.py:135>)
556-
# Name: SymIntType(u19) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139>)
563+
# BinOp: SymIntType(u22) SourceOrigin(location=<SourceLocation all_ast_nodes.py:142>)
564+
# Name: SymIntType(u21) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137>)
565+
# Name: SymIntType(u20) SourceOrigin(location=<SourceLocation all_ast_nodes.py:141>)
557566
v = v + i
558-
# BinOp: ChainedUnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:141>)
559-
# Name: UnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:141>)
567+
# BinOp: ChainedUnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:143>)
568+
# Name: UnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:143>)
560569
# Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
561570
z = z + x
562571
break
563572
else:
564-
# Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:144>)
573+
# Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:146>)
565574
t = 0
566-
# List: SequenceType([SymIntType(u22), ChainedUnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)")]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:145>)
567-
# Name: SymIntType(u22) SourceOrigin(location=<SourceLocation all_ast_nodes.py:140>)
568-
# Name: ChainedUnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:141>)
575+
# List: SequenceType([SymIntType(u23), ChainedUnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)")]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:147>)
576+
# Name: SymIntType(u23) SourceOrigin(location=<SourceLocation all_ast_nodes.py:142>)
577+
# Name: ChainedUnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:143>)
569578
combined = [v, z]
570579
return out
571580
572581
def root_graph_0():
573-
# File: .../all_ast_nodes.py:138 in all_ast_nodes, code: out[tile] = x[tile] + y[tile]
582+
# File: .../all_ast_nodes.py:140 in all_ast_nodes, code: out[tile] = x[tile] + y[tile]
574583
x: "i32[s17, s27]" = helion_language__tracing_ops__host_tensor('x')
575-
block_size_0: "Sym(u16)" = helion_language__tracing_ops__get_symnode('block_size_0')
576-
block_size_1: "Sym(u17)" = helion_language__tracing_ops__get_symnode('block_size_1')
577-
load: "i32[u16, u17]" = helion_language_memory_ops_load(x, [block_size_0, block_size_1], None); x = None
584+
block_size_0: "Sym(u17)" = helion_language__tracing_ops__get_symnode('block_size_0')
585+
block_size_1: "Sym(u18)" = helion_language__tracing_ops__get_symnode('block_size_1')
586+
load: "i32[u17, u18]" = helion_language_memory_ops_load(x, [block_size_0, block_size_1], None); x = None
578587
y: "i32[s17, s27]" = helion_language__tracing_ops__host_tensor('y')
579-
load_1: "i32[u16, u17]" = helion_language_memory_ops_load(y, [block_size_0, block_size_1], None); y = None
580-
add: "i32[u16, u17]" = torch.ops.aten.add.Tensor(load, load_1); load = load_1 = None
588+
load_1: "i32[u17, u18]" = helion_language_memory_ops_load(y, [block_size_0, block_size_1], None); y = None
589+
add: "i32[u17, u18]" = torch.ops.aten.add.Tensor(load, load_1); load = load_1 = None
581590
out: "i32[s17, s27]" = helion_language__tracing_ops__host_tensor('out')
582591
store = helion_language_memory_ops_store(out, [block_size_0, block_size_1], add, None); out = block_size_0 = block_size_1 = add = store = None
583592
return None""",
@@ -795,32 +804,32 @@ def fn(x):
795804
output,
796805
"""\
797806
def fn(x):
798-
# Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:785>)
807+
# Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:794>)
799808
# Attribute: CallableType(_VariableFunctionsClass.empty_like) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty_like')
800809
# Name: PythonModuleType(torch) GlobalOrigin(name='torch')
801810
# Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
802811
# For: loop_type=GRID
803812
out = torch.empty_like(x)
804-
# Call: IterType(SequenceType([TileIndexType(0), TileIndexType(1)])) SourceOrigin(location=<SourceLocation test_type_propagation.py:786>)
813+
# Call: IterType(SequenceType([TileIndexType(0), TileIndexType(1)])) SourceOrigin(location=<SourceLocation test_type_propagation.py:795>)
805814
# Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile')
806815
# Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
807-
# Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_type_propagation.py:786>)
816+
# Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_type_propagation.py:795>)
808817
# Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size')
809818
# Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
810819
for tile in hl.tile(x.size()):
811-
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:787>)
812-
# Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:785>)
813-
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:786>)
814-
# Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:787>)
815-
# Attribute: TensorAttributeType AttributeOrigin(value=DeviceOrigin(location=<SourceLocation test_type_propagation.py:787>), key='sin')
816-
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:787>)
820+
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:796>)
821+
# Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:794>)
822+
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:795>)
823+
# Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:796>)
824+
# Attribute: TensorAttributeType AttributeOrigin(value=DeviceOrigin(location=<SourceLocation test_type_propagation.py:796>), key='sin')
825+
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:796>)
817826
# Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
818-
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:786>)
827+
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:795>)
819828
out[tile] = x[tile].sin()
820829
return out
821830
822831
def root_graph_0():
823-
# File: .../test_type_propagation.py:787 in fn, code: out[tile] = x[tile].sin()
832+
# File: .../test_type_propagation.py:796 in fn, code: out[tile] = x[tile].sin()
824833
x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
825834
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
826835
block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')

0 commit comments

Comments
 (0)