@@ -488,96 +488,105 @@ def all_ast_nodes(x, y):
488
488
join_var1 = x + y
489
489
# Constant: LiteralType(1) SourceOrigin(location=<SourceLocation all_ast_nodes.py:115>)
490
490
join_var2 = 1
491
+ # Dict: DictType({'x': LiteralType(0)}) SourceOrigin(location=<SourceLocation all_ast_nodes.py:116>)
492
+ # Constant: LiteralType('x') SourceOrigin(location=<SourceLocation all_ast_nodes.py:116>)
493
+ # Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:116>)
494
+ join_var3 = {'x': 0}
491
495
else:
492
496
# Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='y')
493
497
join_var0 = y
494
- # Constant: LiteralType(None) SourceOrigin(location=<SourceLocation all_ast_nodes.py:118 >)
498
+ # Constant: LiteralType(None) SourceOrigin(location=<SourceLocation all_ast_nodes.py:119 >)
495
499
join_var1 = None
496
- # Constant: LiteralType(2) SourceOrigin(location=<SourceLocation all_ast_nodes.py:119 >)
500
+ # Constant: LiteralType(2) SourceOrigin(location=<SourceLocation all_ast_nodes.py:120 >)
497
501
join_var2 = 2
498
- # List: SequenceType([TensorType([y_size0, x_size1], torch.int32), UnknownType("Can't combine types from control flow: TensorType([y_size0, x_size1], torch.int32) and LiteralType(None)"), SymIntType(u8)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:120>)
502
+ # Dict: DictType({'x': LiteralType(1)}) SourceOrigin(location=<SourceLocation all_ast_nodes.py:121>)
503
+ # Constant: LiteralType('x') SourceOrigin(location=<SourceLocation all_ast_nodes.py:121>)
504
+ # Constant: LiteralType(1) SourceOrigin(location=<SourceLocation all_ast_nodes.py:121>)
505
+ join_var3 = {'x': 1}
506
+ # List: SequenceType([TensorType([y_size0, x_size1], torch.int32), UnknownType("Can't combine types from control flow: TensorType([y_size0, x_size1], torch.int32) and LiteralType(None)"), SymIntType(u8), DictType({'x': SymIntType(u9)})]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:122>)
499
507
# Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='y')
500
- # Name: UnknownType("Can't combine types from control flow: TensorType([y_size0, x_size1], torch.int32) and LiteralType(None)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:118 >)
508
+ # Name: UnknownType("Can't combine types from control flow: TensorType([y_size0, x_size1], torch.int32) and LiteralType(None)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:119 >)
501
509
# Name: SymIntType(u8) SourceOrigin(location=<SourceLocation all_ast_nodes.py:115>)
502
- combined = [join_var0, join_var1, join_var2]
503
- # Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:122>)
510
+ # Name: DictType({'x': SymIntType(u9)}) SourceOrigin(location=<SourceLocation all_ast_nodes.py:121>)
511
+ combined = [join_var0, join_var1, join_var2, join_var3]
512
+ # Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:124>)
504
513
i = 0
505
- # Compare: SymBoolType(Eq(u11 , 1)) SourceOrigin(location=<SourceLocation all_ast_nodes.py:123 >)
506
- # Name: SymIntType(u9 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:122 >)
507
- # Constant: LiteralType(3) SourceOrigin(location=<SourceLocation all_ast_nodes.py:123 >)
514
+ # Compare: SymBoolType(Eq(u12 , 1)) SourceOrigin(location=<SourceLocation all_ast_nodes.py:125 >)
515
+ # Name: SymIntType(u10 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:124 >)
516
+ # Constant: LiteralType(3) SourceOrigin(location=<SourceLocation all_ast_nodes.py:125 >)
508
517
while i < 3:
509
- # BinOp: SymIntType(u13 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:124 >)
510
- # Name: SymIntType(u12 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:122 >)
511
- # Constant: LiteralType(1) SourceOrigin(location=<SourceLocation all_ast_nodes.py:124 >)
518
+ # BinOp: SymIntType(u14 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:126 >)
519
+ # Name: SymIntType(u13 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:124 >)
520
+ # Constant: LiteralType(1) SourceOrigin(location=<SourceLocation all_ast_nodes.py:126 >)
512
521
i = i + 1
513
522
continue
514
523
else:
515
- # Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:127 >)
524
+ # Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:129 >)
516
525
t = 0
517
526
with contextlib.nullcontext():
518
- # Global: UnknownType('ast.Global is not supported') SourceOrigin(location=<SourceLocation all_ast_nodes.py:132 >)
527
+ # Global: UnknownType('ast.Global is not supported') SourceOrigin(location=<SourceLocation all_ast_nodes.py:134 >)
519
528
e3 = 1
520
529
global global0
521
- # Call: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:134 >)
530
+ # Call: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:136 >)
522
531
# Attribute: CallableType(_VariableFunctionsClass.empty_like) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty_like')
523
532
# Name: PythonModuleType(torch) GlobalOrigin(name='torch')
524
533
# Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
525
534
out = torch.empty_like(x)
526
- # Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:135 >)
535
+ # Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137 >)
527
536
v = 0
528
- # Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:136 >)
537
+ # Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:138 >)
529
538
# For: loop_type=GRID
530
539
z = 0
531
- # Call: IterType(SequenceType([TileIndexType(0), TileIndexType(1)])) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137 >)
540
+ # Call: IterType(SequenceType([TileIndexType(0), TileIndexType(1)])) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139 >)
532
541
# Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile')
533
542
# Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
534
- # Call: SequenceType((SymIntType(s17), SymIntType(s27))) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137 >)
535
- # Attribute: TensorAttributeType AttributeOrigin(value=SourceOrigin(location=<SourceLocation all_ast_nodes.py:134 >), key='size')
536
- # Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:134 >)
543
+ # Call: SequenceType((SymIntType(s17), SymIntType(s27))) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139 >)
544
+ # Attribute: TensorAttributeType AttributeOrigin(value=SourceOrigin(location=<SourceLocation all_ast_nodes.py:136 >), key='size')
545
+ # Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:136 >)
537
546
for tile in hl.tile(out.size()):
538
- # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:138 >)
539
- # Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:134 >)
540
- # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137 >)
541
- # BinOp: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:138 >)
542
- # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:138 >)
547
+ # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:140 >)
548
+ # Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:136 >)
549
+ # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139 >)
550
+ # BinOp: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:140 >)
551
+ # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:140 >)
543
552
# Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
544
- # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137 >)
545
- # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:138 >)
553
+ # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139 >)
554
+ # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation all_ast_nodes.py:140 >)
546
555
# Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='y')
547
- # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137 >)
556
+ # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139 >)
548
557
# For: loop_type=HOST
549
558
out[tile] = x[tile] + y[tile]
550
- # Call: LiteralType(range(0, 3)) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139 >)
559
+ # Call: LiteralType(range(0, 3)) SourceOrigin(location=<SourceLocation all_ast_nodes.py:141 >)
551
560
# Name: CallableType(range) BuiltinOrigin(name='range')
552
- # Constant: LiteralType(3) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139 >)
561
+ # Constant: LiteralType(3) SourceOrigin(location=<SourceLocation all_ast_nodes.py:141 >)
553
562
for i in range(3):
554
- # BinOp: SymIntType(u21 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:140 >)
555
- # Name: SymIntType(u20 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:135 >)
556
- # Name: SymIntType(u19 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:139 >)
563
+ # BinOp: SymIntType(u22 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:142 >)
564
+ # Name: SymIntType(u21 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:137 >)
565
+ # Name: SymIntType(u20 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:141 >)
557
566
v = v + i
558
- # BinOp: ChainedUnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:141 >)
559
- # Name: UnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:141 >)
567
+ # BinOp: ChainedUnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:143 >)
568
+ # Name: UnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:143 >)
560
569
# Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
561
570
z = z + x
562
571
break
563
572
else:
564
- # Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:144 >)
573
+ # Constant: LiteralType(0) SourceOrigin(location=<SourceLocation all_ast_nodes.py:146 >)
565
574
t = 0
566
- # List: SequenceType([SymIntType(u22 ), ChainedUnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)")]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:145 >)
567
- # Name: SymIntType(u22 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:140 >)
568
- # Name: ChainedUnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:141 >)
575
+ # List: SequenceType([SymIntType(u23 ), ChainedUnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)")]) SourceOrigin(location=<SourceLocation all_ast_nodes.py:147 >)
576
+ # Name: SymIntType(u23 ) SourceOrigin(location=<SourceLocation all_ast_nodes.py:142 >)
577
+ # Name: ChainedUnknownType("Can't combine types from control flow: LiteralType(0) and TensorType([y_size0, x_size1], torch.int32)") SourceOrigin(location=<SourceLocation all_ast_nodes.py:143 >)
569
578
combined = [v, z]
570
579
return out
571
580
572
581
def root_graph_0():
573
- # File: .../all_ast_nodes.py:138 in all_ast_nodes, code: out[tile] = x[tile] + y[tile]
582
+ # File: .../all_ast_nodes.py:140 in all_ast_nodes, code: out[tile] = x[tile] + y[tile]
574
583
x: "i32[s17, s27]" = helion_language__tracing_ops__host_tensor('x')
575
- block_size_0: "Sym(u16 )" = helion_language__tracing_ops__get_symnode('block_size_0')
576
- block_size_1: "Sym(u17 )" = helion_language__tracing_ops__get_symnode('block_size_1')
577
- load: "i32[u16, u17 ]" = helion_language_memory_ops_load(x, [block_size_0, block_size_1], None); x = None
584
+ block_size_0: "Sym(u17 )" = helion_language__tracing_ops__get_symnode('block_size_0')
585
+ block_size_1: "Sym(u18 )" = helion_language__tracing_ops__get_symnode('block_size_1')
586
+ load: "i32[u17, u18 ]" = helion_language_memory_ops_load(x, [block_size_0, block_size_1], None); x = None
578
587
y: "i32[s17, s27]" = helion_language__tracing_ops__host_tensor('y')
579
- load_1: "i32[u16, u17 ]" = helion_language_memory_ops_load(y, [block_size_0, block_size_1], None); y = None
580
- add: "i32[u16, u17 ]" = torch.ops.aten.add.Tensor(load, load_1); load = load_1 = None
588
+ load_1: "i32[u17, u18 ]" = helion_language_memory_ops_load(y, [block_size_0, block_size_1], None); y = None
589
+ add: "i32[u17, u18 ]" = torch.ops.aten.add.Tensor(load, load_1); load = load_1 = None
581
590
out: "i32[s17, s27]" = helion_language__tracing_ops__host_tensor('out')
582
591
store = helion_language_memory_ops_store(out, [block_size_0, block_size_1], add, None); out = block_size_0 = block_size_1 = add = store = None
583
592
return None""" ,
@@ -795,32 +804,32 @@ def fn(x):
795
804
output ,
796
805
"""\
797
806
def fn(x):
798
- # Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:785 >)
807
+ # Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:794 >)
799
808
# Attribute: CallableType(_VariableFunctionsClass.empty_like) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty_like')
800
809
# Name: PythonModuleType(torch) GlobalOrigin(name='torch')
801
810
# Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
802
811
# For: loop_type=GRID
803
812
out = torch.empty_like(x)
804
- # Call: IterType(SequenceType([TileIndexType(0), TileIndexType(1)])) SourceOrigin(location=<SourceLocation test_type_propagation.py:786 >)
813
+ # Call: IterType(SequenceType([TileIndexType(0), TileIndexType(1)])) SourceOrigin(location=<SourceLocation test_type_propagation.py:795 >)
805
814
# Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile')
806
815
# Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
807
- # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_type_propagation.py:786 >)
816
+ # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_type_propagation.py:795 >)
808
817
# Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size')
809
818
# Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
810
819
for tile in hl.tile(x.size()):
811
- # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:787 >)
812
- # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:785 >)
813
- # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:786 >)
814
- # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:787 >)
815
- # Attribute: TensorAttributeType AttributeOrigin(value=DeviceOrigin(location=<SourceLocation test_type_propagation.py:787 >), key='sin')
816
- # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:787 >)
820
+ # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:796 >)
821
+ # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:794 >)
822
+ # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:795 >)
823
+ # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:796 >)
824
+ # Attribute: TensorAttributeType AttributeOrigin(value=DeviceOrigin(location=<SourceLocation test_type_propagation.py:796 >), key='sin')
825
+ # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:796 >)
817
826
# Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
818
- # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:786 >)
827
+ # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:795 >)
819
828
out[tile] = x[tile].sin()
820
829
return out
821
830
822
831
def root_graph_0():
823
- # File: .../test_type_propagation.py:787 in fn, code: out[tile] = x[tile].sin()
832
+ # File: .../test_type_propagation.py:796 in fn, code: out[tile] = x[tile].sin()
824
833
x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
825
834
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
826
835
block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
0 commit comments