 import functools
 from operator import getitem
 from typing import TYPE_CHECKING
+from typing import Callable
 from typing import ContextManager
 from typing import NamedTuple
@@ -148,21 +149,33 @@ def convert_arg(arg: Node) -> TensorBox:
         # pyre-ignore[6]
         *map_arg((node.args, node.kwargs), convert_arg),
     )
-    result.realize()
-    if not isinstance(result, TensorBox) or not isinstance(result.data, StorageBox):
-        raise InductorLoweringError(
-            f"Lowering {node.target} returned {type(result)}, expected TensorBox(StorageBox(...)): {result}"
-        )
-    if not isinstance(buffer := result.data.data, ComputedBuffer):
-        raise InductorLoweringError(
-            f"Lowering {node.target} returned buffer type {type(buffer)}, expected ComputedBuffer: {buffer}"
-        )
+    if not isinstance(result, tuple):
+        result = (result,)
+    buffer_name_to_output_index = {}
+    for i, r in enumerate(result):
+        r.realize()
+        if not isinstance(r, TensorBox) or not isinstance(r.data, StorageBox):
+            raise InductorLoweringError(
+                f"Lowering {node.target} returned {type(r)}, expected TensorBox(StorageBox(...)): {r}"
+            )
+        if not isinstance(buffer := r.data.data, ComputedBuffer):
+            raise InductorLoweringError(
+                f"Lowering {node.target} returned buffer type {type(buffer)}, expected ComputedBuffer: {buffer}"
+            )
+        buffer_name_to_output_index[buffer.get_name()] = i
 
     new_buffers = graph_lowering.buffers[prior_buffers:]
-    assert new_buffers[-1] is buffer
+    assert buffer in new_buffers  # pyre-ignore[61]
     nodes = []
     extra_input_names = []
     new_node: torch.fx.Node
+
+    # Explicitly track the mapping from node to Inductor buffer name.
+    # First, map the original input nodes to their names.
+    node_to_buf_name_mapping: dict[torch.fx.Node, str] = dict(
+        zip(node._input_nodes, input_names, strict=True)
+    )
+
     for i, buffer in enumerate(new_buffers):
         if not isinstance(buffer, ComputedBuffer) or not isinstance(
             buffer.data, (Pointwise, Reduction)
@@ -176,29 +189,49 @@ def convert_arg(arg: Node) -> TensorBox:
             new_node.kwargs = {**new_node.kwargs, "_extra_args": [*nodes]}
         else:
             new_node = create_extra_node(node, buffer, [*node._input_nodes, *nodes])
+
+        # Store output index if this buffer corresponds to an output
+        if buffer.get_name() in buffer_name_to_output_index:
+            new_node.meta["output_index"] = buffer_name_to_output_index[
+                buffer.get_name()
+            ]
+
         lowering_cls = (
             PointwiseLowering
             if isinstance(buffer.data, Pointwise)
             else ReductionLowering
         )
         buffer.freeze_layout()
+
+        current_input_nodes = new_node._input_nodes
+        current_input_names = []
+        for inp_node in current_input_nodes:
+            current_input_names.append(node_to_buf_name_mapping[inp_node])
+
         used_input_names = strip_unused_inputs(
             new_node,
             buffer.get_read_names(),
-            dict(
-                zip(
-                    node.all_input_nodes,
-                    [*input_names, *extra_input_names],
-                    strict=True,
-                )
-            ),
+            dict(zip(current_input_nodes, current_input_names, strict=True)),
         )
         new_node.meta["lowering"] = lowering = lowering_cls(buffer, used_input_names)
+        new_node.meta["orig_node"] = node
         if isinstance(lowering, ReductionLowering):
             lowering.add_input_mask(new_node)
         nodes.append(new_node)
         extra_input_names.append(buffer.get_name())
 
+        # Add this node to our mapping for future nodes to reference
+        node_to_buf_name_mapping[new_node] = buffer.get_name()
+
+    # After all nodes are created, build the output_nodes mapping for multi-output operations
+    if len(result) > 1 and nodes:
+        last_node = nodes[-1]  # The last node is the main node
+        output_nodes = {}
+        for n in nodes:
+            if "output_index" in n.meta:
+                output_nodes[n.meta["output_index"]] = n.name
+        last_node.meta["output_nodes"] = output_nodes
+
 
 def strip_unused_inputs(
     node: torch.fx.Node,
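
For a concrete picture of the bookkeeping in the hunk above: a multi-output lowering such as var_mean realizes one ComputedBuffer per tuple element, buffer_name_to_output_index records which buffer fills which tuple slot, and the last FX node carries the slot-to-node-name map in meta["output_nodes"]. A minimal self-contained sketch of that flow, with plain strings and dicts standing in for Inductor buffers and FX nodes (all names hypothetical):

# Sketch only: "buf0"/"buf1" stand in for ComputedBuffer.get_name() values
# and dicts stand in for torch.fx.Node objects.
result = ("buf0", "buf1")  # e.g. (var, mean) from a two-output lowering
buffer_name_to_output_index = {name: i for i, name in enumerate(result)}

# One FX node is created per buffer; a node whose buffer is an output
# remembers its tuple slot in meta["output_index"].
nodes = [
    {"name": "var_mean_extra", "meta": {"output_index": 0}},
    {"name": "var_mean", "meta": {"output_index": 1}},
]

# The last node acts as the main node and carries the full mapping.
output_nodes = {
    n["meta"]["output_index"]: n["name"] for n in nodes if "output_index" in n["meta"]
}
nodes[-1]["meta"]["output_nodes"] = output_nodes
assert output_nodes == {0: "var_mean_extra", 1: "var_mean"}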
@@ -447,14 +480,23 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
             strategy = BlockReductionStrategy(state, self.block_index)
 
         inputs = self.input_fake_tensors(node)
-        if len(inputs) != 1:
-            # TODO(jansel): combine multiple inputs into a single fake value
-            raise NotImplementedError("reductions with >1 input")
+
+        repr_input = None
+        if len(inputs) == 1:
+            repr_input = inputs[0]
+        else:
+            if node.meta["orig_node"].target == torch.ops.aten.var_mean.correction:
+                assert len(inputs) == 2
+                # `inputs[0]` is the original input tensor to var_mean
+                repr_input = inputs[0]
+            else:
+                # TODO(jansel): combine multiple inputs into a single fake value
+                raise NotImplementedError("reductions with >1 input")
 
         # TODO(jansel): find a better way to get dim
         (dim,) = [
             i
-            for i, v in enumerate(inputs[0].shape)
+            for i, v in enumerate(repr_input.shape)
             if TileStrategy.get_block_index(v) == self.block_index
         ]
@@ -463,7 +505,7 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
             output_name,
             reduction.reduction_type,
             dim,
-            inputs[0],
+            repr_input,
             node.meta["val"],
         )
 
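
The two hunks above select a single "representative" fake tensor when a reduction lowering sees more than one input; so far only aten.var_mean.correction is special-cased, with the original tensor assumed to come first. A rough stand-alone sketch of the selection rule (the second list element is a hypothetical stand-in for whatever extra fake value var_mean's lowering carries):

import torch

def pick_repr_input(inputs: list[torch.Tensor], target: object) -> torch.Tensor:
    # Mirrors the diff: single-input reductions use their only input;
    # var_mean is the one multi-input reduction handled so far.
    if len(inputs) == 1:
        return inputs[0]
    if target == torch.ops.aten.var_mean.correction:
        assert len(inputs) == 2
        return inputs[0]  # the original tensor var_mean reduces over
    raise NotImplementedError("reductions with >1 input")

x = torch.randn(4, 8)
repr_input = pick_repr_input([x, torch.empty(4)], torch.ops.aten.var_mean.correction)
assert repr_input.shape == (4, 8)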
@@ -806,6 +848,14 @@ def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> str:
         name = self.cg.lift(
             expr_from_string(self.cg.device_function.user_sympy_expr(expr))
         ).id
+
+        # If the lifted symbol refers to a `tl.constexpr` kernel
+        # argument (for example a tile/block size constant such as
+        # `_BLOCK_SIZE_1`) the resulting Triton value is not a tensor
+        # and therefore does not expose a `.to` method.
+        if name in self.cg.device_function._constexpr_args:
+            return name
+
         return f"{name}.to({triton_type(dtype)})"
 
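
The early return above exists because a tl.constexpr kernel argument is a compile-time Python int inside Triton, not a tensor, so emitting `.to(...)` on it would fail. A toy illustration of the resulting dispatch (the constexpr_args set is assumed to hold the names of constexpr parameters, as the diff suggests):

def index_expr_source(name: str, dtype_str: str, constexpr_args: set[str]) -> str:
    # constexpr arguments (e.g. _BLOCK_SIZE_1) have no .to() method,
    # so they are emitted as-is; everything else gets a cast.
    if name in constexpr_args:
        return name
    return f"{name}.to({dtype_str})"

assert index_expr_source("_BLOCK_SIZE_1", "tl.int32", {"_BLOCK_SIZE_1"}) == "_BLOCK_SIZE_1"
assert index_expr_source("v0", "tl.int32", {"_BLOCK_SIZE_1"}) == "v0.to(tl.int32)"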
@@ -821,11 +871,57 @@ def __init__(self, graph: torch.fx.Graph, cg: GenerateAST) -> None:
         super().__init__(_LazyGraphModule({}, graph), garbage_collect_values=False)
         self.cg = cg
 
+    def _collect_multi_outputs(
+        self, node: Node, last_node_result: object
+    ) -> tuple[object, ...]:
+        """
+        Collect outputs for multi-output operations using metadata.
+        """
+        # Check if this operation has multiple outputs using the new metadata
+        assert "output_nodes" in node.meta
+        output_nodes = node.meta["output_nodes"]
+        outputs = [None] * len(output_nodes)
+        all_nodes = {n.name: n for n in self.module.graph.nodes}  # pyre-ignore[16]
+
+        for idx, node_name in output_nodes.items():
+            if node_name == node.name:
+                # This is the last node
+                outputs[idx] = last_node_result  # pyre-ignore[6]
+            else:
+                # This is an extra node - get its result from env
+                if node_name in all_nodes:
+                    extra_node = all_nodes[node_name]
+                    if extra_node in self.env:
+                        outputs[idx] = self.env[extra_node]
+
+        # Ensure all outputs are found and are ast.Name nodes
+        final_outputs = []
+        for i, result in enumerate(outputs):
+            assert result is not None
+            if not isinstance(result, ast.Name):
+                var_name = self.cg.device_function.new_var(f"{node.name}_output{i}")
+                self.cg.add_statement(
+                    statement_from_string(f"{var_name} = result", result=result)
+                )
+                result = create(ast.Name, id=var_name, ctx=ast.Load())
+            final_outputs.append(result)
+
+        return tuple(final_outputs)
+
     def run_node(self, n: Node) -> object:
         if n.op == "call_function":
             with self._set_current_node(n), n.meta["location"]:
                 lowering: Lowering = n.meta["lowering"]
                 result = lowering.codegen(self, n)
+                n.meta["codegen"] = result
+
+                # Generic handling for operations with multiple outputs
+                if n.kwargs.get("_extra_args"):
+                    # Check if this node has getitem users, indicating multiple outputs
+                    getitem_users = [user for user in n.users if user.target == getitem]
+                    if len(getitem_users) > 0:
+                        return self._collect_multi_outputs(n, result)
+
                 if result is None:
                     return None
                 if not isinstance(result, ast.AST):
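
At interpretation time the main node thus returns a tuple so downstream getitem users can index into it. A stripped-down sketch of the collection step, with a plain dict standing in for the interpreter's env and hypothetical node names:

import ast

def collect_multi_outputs(node_name, last_result, output_nodes, env):
    # output_nodes maps tuple index -> producing node name (stamped into
    # meta["output_nodes"] earlier); env maps node name -> codegen result.
    outputs = []
    for idx in sorted(output_nodes):
        producer = output_nodes[idx]
        outputs.append(last_result if producer == node_name else env[producer])
    return tuple(outputs)

env = {"var_mean_extra": ast.Name(id="v0", ctx=ast.Load())}
out = collect_multi_outputs(
    "var_mean",
    ast.Name(id="v1", ctx=ast.Load()),
    {0: "var_mean_extra", 1: "var_mean"},
    env,
)
assert [name.id for name in out] == ["v0", "v1"]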