Commit 4f3a60b

Add ConstantArgument support to fx_import (llvm#4244)
This PR fixes the following issue: [fx_importer NotImplementedError: MultiheadAttention layer with NeedWeight = false](llvm#4158).

Before this fix, the importer raised:

Python Error: NotImplementedError: OutputKind.USER_OUTPUT for <class 'torch.export.graph_signature.ConstantArgument'>: ConstantArgument(name='', value=None)

This happens for an exported MultiheadAttention layer with need_weights=False, which means the layer does not return the attention weights; the second output, attn_output_weights, is therefore None.
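A minimal repro sketch of the failing case (the module name and shapes below are illustrative; it assumes torch and the torch_mlir fx helpers are installed):

```python
import torch
from torch_mlir import fx


class AttentionNoWeights(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(
            embed_dim=64, num_heads=1, batch_first=True
        )

    def forward(self, query, key, value):
        # need_weights=False makes the second output (attn_output_weights) None,
        # which torch.export records as a ConstantArgument user output.
        return self.attn(query, key, value, need_weights=False)


x = torch.randn(1, 10, 64)
# Before this commit the import raised NotImplementedError for the ConstantArgument
# output; with it, the None output is imported and returned as a !torch.none constant.
m = fx.export_and_import(
    AttentionNoWeights(), x, x, x, experimental_support_mutation=True
)
print(m)
```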
1 parent: 68011ea · commit: 4f3a60b

File tree: 2 files changed, +284 −17 lines
python/torch_mlir/extras/fx_importer.py

Lines changed: 49 additions & 17 deletions
@@ -630,6 +630,7 @@ def import_program(
             OutputKind,
             TensorArgument,
             SymIntArgument,
+            ConstantArgument,
         )

         sig = prog.graph_signature
@@ -650,24 +651,35 @@ def import_program(
         constant_tensors: Dict[Node, torch.Tensor] = {}
         parameter_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
         buffer_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
+        constant_output_values: Dict[int, Any] = {}
+        constant_input_values: Dict[Node, Any] = {}

         # Derive user outputs that we preserve. These will be nodes of the
         # producer for the output.
-        user_outputs: List[Node] = []
+        user_outputs: List[Optional[Node]] = []
         user_output_types: List[IrType] = []
-        for output_spec in sig.output_specs:
+        for i, output_spec in enumerate(sig.output_specs):
             kind = output_spec.kind
             arg = output_spec.arg
             if kind == OutputKind.USER_OUTPUT:
-                if not isinstance(arg, (TensorArgument, SymIntArgument)):
+                if not isinstance(
+                    arg, (TensorArgument, SymIntArgument, ConstantArgument)
+                ):
                     raise NotImplementedError(
                         f"OutputKind.USER_OUTPUT for {type(arg)}: {arg}"
                     )
-                output_producer_node = all_producer_nodes[arg.name]
-                user_outputs.append(output_producer_node)
-                user_output_types.append(
-                    self._cc.node_val_to_type(output_producer_node)
-                )
+                if isinstance(arg, (TensorArgument, SymIntArgument)):
+                    output_producer_node = all_producer_nodes[arg.name]
+                    user_outputs.append(output_producer_node)
+                    user_output_types.append(
+                        self._cc.node_val_to_type(output_producer_node)
+                    )
+                elif isinstance(arg, ConstantArgument):
+                    # Constant Outputs don't have a node so we will only store their values
+                    constant_output_values[i] = arg.value
+                    # Placeholder for constant outputs in the node list
+                    user_outputs.append(None)
+                    user_output_types.append(self._cc.value_info_to_type(arg.value))
             elif kind == OutputKind.BUFFER_MUTATION and isinstance(arg, TensorArgument):
                 mutable_buffer_target_producers[output_spec.target] = arg.name

@@ -678,16 +690,22 @@ def import_program(
             arg = input_spec.arg
             if input_spec.kind == InputKind.USER_INPUT:
                 # Set up user input.
-                if not isinstance(arg, (TensorArgument, SymIntArgument)):
+                if not isinstance(
+                    arg, (TensorArgument, SymIntArgument, ConstantArgument)
+                ):
                     raise NotImplementedError(
                         f"InputKind.USER_INPUT for {type(arg)}: {arg}"
                     )
                 placeholder_node = placeholder_nodes[arg.name]
-                mutable = placeholder_node.name in mutated_user_inputs
-                user_inputs.append(placeholder_node)
-                user_input_types.append(
-                    self._cc.node_val_to_type(placeholder_node, mutable=mutable)
-                )
+                if isinstance(arg, (TensorArgument, SymIntArgument)):
+                    mutable = placeholder_node.name in mutated_user_inputs
+                    user_inputs.append(placeholder_node)
+                    user_input_types.append(
+                        self._cc.node_val_to_type(placeholder_node, mutable=mutable)
+                    )
+                elif isinstance(arg, ConstantArgument):
+                    # Constant argument will be handled separately, they are not mutable and do not need function parameters
+                    constant_input_values[placeholder_node] = arg.value
             elif input_spec.kind == InputKind.CONSTANT_TENSOR and isinstance(
                 arg, TensorArgument
             ):
@@ -778,6 +796,9 @@ def import_program(
         for constant_node, constant_tensor in constant_tensors.items():
             node_importer.import_constant(loc, constant_node, constant_tensor)

+        for constant_node, constant_value in constant_input_values.items():
+            node_importer.import_constant(loc, constant_node, constant_value)
+
         # Bind user inputs to IR values.
         for user_input_node, block_arg_value in zip(user_inputs, entry_block.arguments):
             if user_input_node.name in mutated_user_inputs:
@@ -804,7 +825,10 @@ def import_program(
             skip_placeholders_outputs=True,
             import_symbolic_shape_expressions=import_symbolic_shape_expressions,
         )
-        node_importer.return_node_values(loc, user_outputs)
+
+        # Call the return function that handles both nodes and constant values
+        node_importer.return_node_values(loc, user_outputs, constant_output_values)
+
         self.symbol_table.insert(func_op)
         return func_op

@@ -1419,9 +1443,17 @@ def on_produced(value: Value):

         self._on_node_produced[info.store_producer_node] = on_produced

-    def return_node_values(self, loc, nodes: List[Node]):
+    def return_node_values(self, loc, nodes: List[Node], constants: Dict[int, Any]):
+        # This function returns both node values and constant values
         with loc, InsertionPoint(self._b):
-            operands = [self.resolve_node_value(n) for n in nodes]
+            operands = [
+                (
+                    self.resolve_node_value(n)
+                    if isinstance(n, Node)
+                    else self._import_literal(constants[index])
+                )
+                for index, n in enumerate(nodes)
+            ]
             func_dialect.ReturnOp(operands, loc=loc)

     def import_nodes(

test/python/fx_importer/v2.3/mutation_import.py

Lines changed: 235 additions & 0 deletions
@@ -171,3 +171,238 @@ def forward(self, x):
     )
     print(m)
     m.operation.verify()
+
+
+@run
+# CHECK-LABEL: test_single_input_const_argument
+# CHECK: %[[int2:.+]] = torch.constant.int 2
+# CHECK: %[[buffer:.+]] = torch.aten.mul.Scalar %arg0, %[[int2]] : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
+# CHECK: return %[[buffer]] : !torch.vtensor<[3,4],f32>
+def test_single_input_const_argument():
+    class SingleConstantInputModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y=2):  # Single constant input
+            return x * y
+
+    m = fx.export_and_import(
+        SingleConstantInputModule(),
+        torch.randn(3, 4),
+        experimental_support_mutation=True,
+    )
+    print(m)
+    m.operation.verify()
+
+
+@run
+# CHECK-LABEL: test_single_output_const_argument
+# CHECK: %[[float1:.+]] = torch.constant.float 5.000000e-01
+# CHECK: %[[buffer:.+]] = torch.aten.mul.Scalar %arg0, %[[float1]]
+# CHECK: %[[float2:.+]] = torch.constant.float 5.000000e-01
+# CHECK: return %[[buffer]], %[[float2]] : !torch.vtensor<[3,4],f32>, !torch.float
+def test_single_output_const_argument():
+    class SingleConstantOutputModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.scale = 0.5  # Single constant output
+
+        def forward(self, x):
+            scaled = x * self.scale
+            return scaled, self.scale  # Return tensor + constant
+
+    m = fx.export_and_import(
+        SingleConstantOutputModule(),
+        torch.randn(3, 4),
+        experimental_support_mutation=True,
+    )
+    print(m)
+    m.operation.verify()
+
+
+@run
+# CHECK-LABEL: test_multiple_input_const_argument
+# CHECK: %[[float2:.+]] = torch.constant.float 2.000000e+00
+# CHECK: %[[buffer0:.+]] = torch.aten.mul.Scalar %arg0, %[[float2]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+# CHECK: %[[float3:.+]] = torch.constant.float 3.000000e+00
+# CHECK: %[[int1:.+]] = torch.constant.int 1
+# CHECK: %[[buffer1:.+]] = torch.aten.add.Scalar %[[buffer0]], %[[float3]], %[[int1]] : !torch.vtensor<[3,4],f32>, !torch.float, !torch.int -> !torch.vtensor<[3,4],f32>
+# CHECK: return %[[buffer1]] : !torch.vtensor<[3,4],f32>
+def test_multiple_input_const_argument():
+    class MultipleConstantInputModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(
+            self, x, scale=2.0, offset=1.0, multiplier=3
+        ):  # Multiple constant inputs
+            return x * scale + offset * multiplier
+
+    m = fx.export_and_import(
+        MultipleConstantInputModule(),
+        torch.randn(3, 4),
+        experimental_support_mutation=True,
+    )
+    print(m)
+    m.operation.verify()
+
+
+@run
+# CHECK-LABEL: test_multiple_output_const_argument
+# CHECK: %[[float5:.+]] = torch.constant.float 5.000000e-01
+# CHECK: %[[buffer:.+]] = torch.aten.mul.Scalar %arg0, %[[float5]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+# CHECK: %[[str:.+]] = torch.constant.str "model"
+# CHECK: %[[int42:.+]] = torch.constant.int 42
+# CHECK: %[[true:.+]] = torch.constant.bool true
+# CHECK: %[[none:.+]] = torch.constant.none
+# CHECK: return %[[buffer]], %[[float5]]
+# CHECK-SAME: %[[str]], %[[int42]], %[[true]], %[[none]] : !torch.vtensor<[3,4],f32>, !torch.float, !torch.str, !torch.int, !torch.bool, !torch.none
+def test_multiple_output_const_argument():
+    class MultipleConstantOutputModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.scale = 0.5
+            self.name = "model"
+            self.version = 42
+
+        def forward(self, x):
+            result = x * self.scale
+            # Return tensor + multiple constants
+            return result, self.scale, self.name, self.version, True, None
+
+    m = fx.export_and_import(
+        MultipleConstantOutputModule(),
+        torch.randn(3, 4),
+        experimental_support_mutation=True,
+    )
+    print(m)
+    m.operation.verify()
+
+
+@run
+# CHECK-LABEL: test_input_output_const_argument
+# CHECK: %[[float5:.+]] = torch.constant.float 5.000000e-01
+# CHECK: %[[buffer0:.+]] = torch.aten.mul.Scalar %arg0, %[[float5]]
+# CHECK: %[[float2:.+]] = torch.constant.float 2.000000e+00
+# CHECK: %[[buffer1:.+]] = torch.aten.mul.Scalar %[[buffer0]], %[[float2]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+# CHECK: %[[float1:.+]] = torch.constant.float 1.000000e+00
+# CHECK: %[[int1:.+]] = torch.constant.int 1
+# CHECK: %[[buffer2:.+]] = torch.aten.add.Scalar %[[buffer1]], %[[float1]], %[[int1]]
+# CHECK: %[[str:.+]] = torch.constant.str "combined_model"
+# CHECK: %[[true:.+]] = torch.constant.bool true
+# CHECK: %[[none:.+]] = torch.constant.none
+# CHECK: return %[[buffer2]], %[[float5]]
+# CHECK-SAME: %[[str]]
+def test_input_output_const_argument():
+    class CombinedConstantModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.base_scale = 0.5
+            self.model_name = "combined_model"
+
+        def forward(self, x, user_scale=2.0, add_bias=True, bias_value=1.0):
+            if add_bias:
+                result = (x * self.base_scale * user_scale) + bias_value
+            else:
+                result = x * self.base_scale * user_scale
+
+            # Return mix of tensors and constants (both output and input)
+            return (
+                result,  # tensor
+                self.base_scale,  # constantArgument output
+                self.model_name,  # constantArgument output
+                user_scale,  # constantArgument input
+                add_bias,  # constantArgument input
+                bias_value,  # constantArgument input
+                None,  # constantArgument literal (output)
+            )
+
+    m = fx.export_and_import(
+        CombinedConstantModule(), torch.randn(3, 4), experimental_support_mutation=True
+    )
+    print(m)
+    m.operation.verify()
+
+
+@run
+# CHECK-LABEL: test_const_argument_edge_cases
+# CHECK: func.func @main(%arg0: !torch.vtensor<[3,4],f32>) ->
+# CHECK-SAME: (!torch.vtensor<[3,4],f32>, !torch.float, !torch.int, !torch.str, !torch.bool, !torch.none, !torch.none, !torch.str, !torch.int, !torch.bool)
+# CHECK: %[[float314:.+]] = torch.constant.float 3.140000e+00
+# CHECK: %[[buffer:.+]] = torch.aten.mul.Scalar %arg0, %[[float314]]
+# CHECK: %[[int42:.+]] = torch.constant.int 42
+# CHECK: %[[string1:.+]] = torch.constant.str "test"
+# CHECK: %[[true:.+]] = torch.constant.bool true
+# CHECK: %[[none:.+]] = torch.constant.none
+# CHECK: %[[string2:.+]] = torch.constant.str "default"
+# CHECK: %[[int0:.+]] = torch.constant.int 0
+# CHECK: %[[false:.+]] = torch.constant.bool false
+# CHECK: return %[[buffer]], %[[float314]]
+# CHECK-SAME: %[[int42]], %[[string1]], %[[true]], %[[none]], %[[none]]
+# CHECK-SAME: %[[string2]], %[[int0]], %[[false]]
def test_const_argument_edge_cases():
+    class EdgeCaseConstantModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.float_val = 3.14
+            self.int_val = 42
+            self.str_val = "test"
+            self.bool_val = True
+            self.none_val = None
+
+        def forward(self, x, input_none=None, input_str="default"):
+            result = x * self.float_val
+
+            # Return all different ConstantArgument types
+            return (
+                result,  # tensor
+                self.float_val,  # float output constantArgument
+                self.int_val,  # int output constantArgument
+                self.str_val,  # string output constantArgument
+                self.bool_val,  # bool output constantArgument
+                self.none_val,  # None output constantArgument
+                input_none,  # None input constantArgument
+                input_str,  # string input constantArgument
+                0,  # literal int
+                False,  # literal bool
+            )
+
+    m = fx.export_and_import(
+        EdgeCaseConstantModule(), torch.randn(3, 4), experimental_support_mutation=True
+    )
+    print(m)
+    m.operation.verify()
+
+
+@run
+# CHECK-LABEL: test_const_argument_from_multiheadattention_layer
+# CHECK: func.func @main(%arg0: !torch.vtensor<[1,10,64],f32>, %arg1: !torch.vtensor<[1,10,64],f32>, %arg2: !torch.vtensor<[1,10,64],f32>) ->
+# CHECK-SAME: (!torch.vtensor<[1,10,64],f32>, !torch.none)
+# CHECK: %[[int1:.+]] = torch.constant.int 1
+# CHECK: %[[int0:.+]] = torch.constant.int 0
+# CHECK-DAG: %[[buffer:.+]] = torch.aten.transpose.int %arg0, %[[int1]], %[[int0]] : !torch.vtensor<[1,10,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,1,64],f32>
+def test_const_argument_from_multiheadattention_layer():
+    """
+    Test case using actual MultiheadAttention where a constantArgument appears automatically
+    due to returning the attention layer without the weights (need_weights=False)
+    """
+
+    class AttentionLikeConstantModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.attn = torch.nn.MultiheadAttention(
+                embed_dim=64, num_heads=1, dropout=0.1, batch_first=True
+            )
+
+        def forward(self, query, key, value, need_weights=False):
+            return self.attn(query, key, value, need_weights=need_weights)
+
+    m = fx.export_and_import(
+        AttentionLikeConstantModule(),
+        torch.randn(1, 10, 64),  # query
+        torch.randn(1, 10, 64),  # key
+        torch.randn(1, 10, 64),  # value
+        experimental_support_mutation=True,
+    )
+    print(m)
+    m.operation.verify()
