@@ -171,3 +171,238 @@ def forward(self, x):
     )
     print(m)
     m.operation.verify()
+
+
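+# A scalar default argument (y=2) is captured by torch.export as a
+# ConstantArgument; the importer materializes it as a torch.constant.int
+# feeding the op rather than as an extra function argument.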
+@run
+# CHECK-LABEL: test_single_input_const_argument
+# CHECK: %[[int2:.+]] = torch.constant.int 2
+# CHECK: %[[buffer:.+]] = torch.aten.mul.Scalar %arg0, %[[int2]] : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
+# CHECK: return %[[buffer]] : !torch.vtensor<[3,4],f32>
+def test_single_input_const_argument():
+    class SingleConstantInputModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y=2):  # Single constant input
+            return x * y
+
+    m = fx.export_and_import(
+        SingleConstantInputModule(),
+        torch.randn(3, 4),
+        experimental_support_mutation=True,
+    )
+    print(m)
+    m.operation.verify()
+
+
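+# A non-tensor value returned from forward() becomes a ConstantArgument
+# output: it is materialized as a constant op and returned alongside the
+# tensor result (here as !torch.float).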
+@run
+# CHECK-LABEL: test_single_output_const_argument
+# CHECK: %[[float1:.+]] = torch.constant.float 5.000000e-01
+# CHECK: %[[buffer:.+]] = torch.aten.mul.Scalar %arg0, %[[float1]]
+# CHECK: %[[float2:.+]] = torch.constant.float 5.000000e-01
+# CHECK: return %[[buffer]], %[[float2]] : !torch.vtensor<[3,4],f32>, !torch.float
+def test_single_output_const_argument():
+    class SingleConstantOutputModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.scale = 0.5  # Single constant output
+
+        def forward(self, x):
+            scaled = x * self.scale
+            return scaled, self.scale  # Return tensor + constant
+
+    m = fx.export_and_import(
+        SingleConstantOutputModule(),
+        torch.randn(3, 4),
+        experimental_support_mutation=True,
+    )
+    print(m)
+    m.operation.verify()
+
+
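+# Several scalar defaults at once: each folds into its own constant op
+# feeding the arithmetic, and only the tensor result is returned.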
+@run
+# CHECK-LABEL: test_multiple_input_const_argument
+# CHECK: %[[float2:.+]] = torch.constant.float 2.000000e+00
+# CHECK: %[[buffer0:.+]] = torch.aten.mul.Scalar %arg0, %[[float2]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+# CHECK: %[[float3:.+]] = torch.constant.float 3.000000e+00
+# CHECK: %[[int1:.+]] = torch.constant.int 1
+# CHECK: %[[buffer1:.+]] = torch.aten.add.Scalar %[[buffer0]], %[[float3]], %[[int1]] : !torch.vtensor<[3,4],f32>, !torch.float, !torch.int -> !torch.vtensor<[3,4],f32>
+# CHECK: return %[[buffer1]] : !torch.vtensor<[3,4],f32>
+def test_multiple_input_const_argument():
+    class MultipleConstantInputModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(
+            self, x, scale=2.0, offset=1.0, multiplier=3
+        ):  # Multiple constant inputs
+            return x * scale + offset * multiplier
+
+    m = fx.export_and_import(
+        MultipleConstantInputModule(),
+        torch.randn(3, 4),
+        experimental_support_mutation=True,
+    )
+    print(m)
+    m.operation.verify()
+
+
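+# Mixed constant output types (float, str, int, bool, None) are each
+# materialized as the corresponding torch.constant op and returned together
+# with the tensor result.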
+@run
+# CHECK-LABEL: test_multiple_output_const_argument
+# CHECK: %[[float5:.+]] = torch.constant.float 5.000000e-01
+# CHECK: %[[buffer:.+]] = torch.aten.mul.Scalar %arg0, %[[float5]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+# CHECK: %[[str:.+]] = torch.constant.str "model"
+# CHECK: %[[int42:.+]] = torch.constant.int 42
+# CHECK: %[[true:.+]] = torch.constant.bool true
+# CHECK: %[[none:.+]] = torch.constant.none
+# CHECK: return %[[buffer]], %[[float5]]
+# CHECK-SAME: %[[str]], %[[int42]], %[[true]], %[[none]] : !torch.vtensor<[3,4],f32>, !torch.float, !torch.str, !torch.int, !torch.bool, !torch.none
+def test_multiple_output_const_argument():
+    class MultipleConstantOutputModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.scale = 0.5
+            self.name = "model"
+            self.version = 42
+
+        def forward(self, x):
+            result = x * self.scale
+            # Return tensor + multiple constants
+            return result, self.scale, self.name, self.version, True, None
+
+    m = fx.export_and_import(
+        MultipleConstantOutputModule(),
+        torch.randn(3, 4),
+        experimental_support_mutation=True,
+    )
+    print(m)
+    m.operation.verify()
+
+
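+# Constant inputs and constant outputs combined in one signature. Since
+# torch.export specializes on non-tensor inputs, only the add_bias=True
+# branch appears in the traced graph.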
+@run
+# CHECK-LABEL: test_input_output_const_argument
+# CHECK: %[[float5:.+]] = torch.constant.float 5.000000e-01
+# CHECK: %[[buffer0:.+]] = torch.aten.mul.Scalar %arg0, %[[float5]]
+# CHECK: %[[float2:.+]] = torch.constant.float 2.000000e+00
+# CHECK: %[[buffer1:.+]] = torch.aten.mul.Scalar %[[buffer0]], %[[float2]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+# CHECK: %[[float1:.+]] = torch.constant.float 1.000000e+00
+# CHECK: %[[int1:.+]] = torch.constant.int 1
+# CHECK: %[[buffer2:.+]] = torch.aten.add.Scalar %[[buffer1]], %[[float1]], %[[int1]]
+# CHECK: %[[str:.+]] = torch.constant.str "combined_model"
+# CHECK: %[[true:.+]] = torch.constant.bool true
+# CHECK: %[[none:.+]] = torch.constant.none
+# CHECK: return %[[buffer2]], %[[float5]]
+# CHECK-SAME: %[[str]]
+def test_input_output_const_argument():
+    class CombinedConstantModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.base_scale = 0.5
+            self.model_name = "combined_model"
+
+        def forward(self, x, user_scale=2.0, add_bias=True, bias_value=1.0):
+            if add_bias:
+                result = (x * self.base_scale * user_scale) + bias_value
+            else:
+                result = x * self.base_scale * user_scale
+
+            # Return mix of tensors and constants (both output and input)
+            return (
+                result,  # tensor
+                self.base_scale,  # constantArgument output
+                self.model_name,  # constantArgument output
+                user_scale,  # constantArgument input
+                add_bias,  # constantArgument input
+                bias_value,  # constantArgument input
+                None,  # constantArgument literal (output)
+            )
+
+    m = fx.export_and_import(
+        CombinedConstantModule(), torch.randn(3, 4), experimental_support_mutation=True
+    )
+    print(m)
+    m.operation.verify()
+
+
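+# Edge cases: every ConstantArgument flavor (float, int, str, bool, None) as
+# attribute outputs, default inputs, and inline literals in one return.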
+@run
+# CHECK-LABEL: test_const_argument_edge_cases
+# CHECK: func.func @main(%arg0: !torch.vtensor<[3,4],f32>) ->
+# CHECK-SAME: (!torch.vtensor<[3,4],f32>, !torch.float, !torch.int, !torch.str, !torch.bool, !torch.none, !torch.none, !torch.str, !torch.int, !torch.bool)
+# CHECK: %[[float314:.+]] = torch.constant.float 3.140000e+00
+# CHECK: %[[buffer:.+]] = torch.aten.mul.Scalar %arg0, %[[float314]]
+# CHECK: %[[int42:.+]] = torch.constant.int 42
+# CHECK: %[[string1:.+]] = torch.constant.str "test"
+# CHECK: %[[true:.+]] = torch.constant.bool true
+# CHECK: %[[none:.+]] = torch.constant.none
+# CHECK: %[[string2:.+]] = torch.constant.str "default"
+# CHECK: %[[int0:.+]] = torch.constant.int 0
+# CHECK: %[[false:.+]] = torch.constant.bool false
+# CHECK: return %[[buffer]], %[[float314]]
+# CHECK-SAME: %[[int42]], %[[string1]], %[[true]], %[[none]], %[[none]]
+# CHECK-SAME: %[[string2]], %[[int0]], %[[false]]
+def test_const_argument_edge_cases():
+    class EdgeCaseConstantModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.float_val = 3.14
+            self.int_val = 42
+            self.str_val = "test"
+            self.bool_val = True
+            self.none_val = None
+
+        def forward(self, x, input_none=None, input_str="default"):
+            result = x * self.float_val
+
+            # Return all different ConstantArgument types
+            return (
+                result,  # tensor
+                self.float_val,  # float output constantArgument
+                self.int_val,  # int output constantArgument
+                self.str_val,  # string output constantArgument
+                self.bool_val,  # bool output constantArgument
+                self.none_val,  # None output constantArgument
+                input_none,  # None input constantArgument
+                input_str,  # string input constantArgument
+                0,  # literal int
+                False,  # literal bool
+            )
+
+    m = fx.export_and_import(
+        EdgeCaseConstantModule(), torch.randn(3, 4), experimental_support_mutation=True
+    )
+    print(m)
+    m.operation.verify()
+
+
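+# Real-world case: MultiheadAttention returns (output, None) when
+# need_weights=False, producing a !torch.none ConstantArgument output.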
+@run
+# CHECK-LABEL: test_const_argument_from_multiheadattention_layer
+# CHECK: func.func @main(%arg0: !torch.vtensor<[1,10,64],f32>, %arg1: !torch.vtensor<[1,10,64],f32>, %arg2: !torch.vtensor<[1,10,64],f32>) ->
+# CHECK-SAME: (!torch.vtensor<[1,10,64],f32>, !torch.none)
+# CHECK: %[[int1:.+]] = torch.constant.int 1
+# CHECK: %[[int0:.+]] = torch.constant.int 0
+# CHECK-DAG: %[[buffer:.+]] = torch.aten.transpose.int %arg0, %[[int1]], %[[int0]] : !torch.vtensor<[1,10,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,1,64],f32>
+def test_const_argument_from_multiheadattention_layer():
+    """
+    Test case using an actual MultiheadAttention layer, where a ConstantArgument
+    appears automatically because the layer returns None in place of the
+    attention weights (need_weights=False).
+    """
+
+    class AttentionLikeConstantModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.attn = torch.nn.MultiheadAttention(
+                embed_dim=64, num_heads=1, dropout=0.1, batch_first=True
+            )
+
+        def forward(self, query, key, value, need_weights=False):
+            return self.attn(query, key, value, need_weights=need_weights)
+
+    m = fx.export_and_import(
+        AttentionLikeConstantModule(),
+        torch.randn(1, 10, 64),  # query
+        torch.randn(1, 10, 64),  # key
+        torch.randn(1, 10, 64),  # value
+        experimental_support_mutation=True,
+    )
+    print(m)
+    m.operation.verify()