@@ -567,14 +567,12 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt
// CHECK-LABEL: @test_matmulinteger
func.func @test_matmulinteger(%arg0: !torch.vtensor<[4,3],ui8>, %arg1: !torch.vtensor<[3,2],ui8>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[4,2],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[4,2],si32>
- // CHECK: %[[LITEM:.+]] = torch.aten.item %arg2
- // CHECK: %[[RITEM:.+]] = torch.aten.item %arg3
- // CHECK: %[[L_SCALE:.+]] = torch.constant.float 1.000000e+00
- // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[L_SCALE]], %[[LITEM]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8>
- // CHECK: %[[R_SCALE:.+]] = torch.constant.float 1.000000e+00
- // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[R_SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
- // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LMAKE]], %[[RMAKE]]
- // CHECK: return %[[MM]]
+ // CHECK: %[[LHS:.*]] = torch.aten.to.dtype %arg0, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[4,3],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,3],si32>
+ // CHECK: %[[RHS:.*]] = torch.aten.to.dtype %arg1, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[3,2],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],si32>
+ // CHECK: %[[LHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[LHS]], %arg2, %{{.+}} : !torch.vtensor<[4,3],si32>, !torch.vtensor<[1],ui8>, !torch.int -> !torch.vtensor<[4,3],si32>
+ // CHECK: %[[RHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[RHS]], %arg3, %{{.+}} : !torch.vtensor<[3,2],si32>, !torch.vtensor<[1],ui8>, !torch.int -> !torch.vtensor<[3,2],si32>
+ // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LHS_MINUS_ZP]], %[[RHS_MINUS_ZP]] : !torch.vtensor<[4,3],si32>, !torch.vtensor<[3,2],si32> -> !torch.vtensor<[4,2],si32>
+ // CHECK: return %[[MM]] : !torch.vtensor<[4,2],si32>
return %0 : !torch.vtensor<[4,2],si32>
}

@@ -583,57 +581,39 @@ func.func @test_matmulinteger(%arg0: !torch.vtensor<[4,3],ui8>, %arg1: !torch.vt
// CHECK-LABEL: @test_matmulinteger_batched
func.func @test_matmulinteger_batched(%arg0: !torch.vtensor<[7,4,3],ui8>, %arg1: !torch.vtensor<[3,2],ui8>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[7,4,2],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[7,4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[7,4,2],si32>
- // CHECK: %[[LITEM:.+]] = torch.aten.item %arg2
- // CHECK: %[[RITEM:.+]] = torch.aten.item %arg3
- // CHECK: %[[L_SCALE:.+]] = torch.constant.float 1.000000e+00
- // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[L_SCALE]], %[[LITEM]] : !torch.vtensor<[7,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[7,4,3],!torch.quint8>
- // CHECK: %[[R_SCALE:.+]] = torch.constant.float 1.000000e+00
- // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[R_SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
- // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LMAKE]], %[[RMAKE]]
- // CHECK: return %[[MM]]
+ // CHECK: %[[LHS:.*]] = torch.aten.to.dtype %arg0, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[7,4,3],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[7,4,3],si32>
+ // CHECK: %[[RHS:.*]] = torch.aten.to.dtype %arg1, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[3,2],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],si32>
+ // CHECK: %[[LHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[LHS]], %arg2, %{{.+}} : !torch.vtensor<[7,4,3],si32>, !torch.vtensor<[1],ui8>, !torch.int -> !torch.vtensor<[7,4,3],si32>
+ // CHECK: %[[RHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[RHS]], %arg3, %{{.+}} : !torch.vtensor<[3,2],si32>, !torch.vtensor<[1],ui8>, !torch.int -> !torch.vtensor<[3,2],si32>
+ // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LHS_MINUS_ZP]], %[[RHS_MINUS_ZP]] : !torch.vtensor<[7,4,3],si32>, !torch.vtensor<[3,2],si32> -> !torch.vtensor<[7,4,2],si32>
+ // CHECK: return %[[MM]] : !torch.vtensor<[7,4,2],si32>
return %0 : !torch.vtensor<[7,4,2],si32>
}

// -----

// CHECK-LABEL: func.func @test_matmulinteger_non_scalar_lhsZp(
- // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[16,2],ui8>,
- // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,768],si8>,
- // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16],ui8>,
- // CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
- func.func @test_matmulinteger_non_scalar_lhsZp(%arg0: !torch.vtensor<[16,2],ui8>, %arg1: !torch.vtensor<[2,768],si8>, %arg2: !torch.vtensor<[16],ui8>, %arg3: !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
- // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[],si8> -> !torch.int
- // CHECK: %[[VAL_5:.*]] = torch.constant.int 6
- // CHECK: %[[VAL_6:.*]] = torch.constant.none
- // CHECK: %[[VAL_7:.*]] = torch.constant.int 0
- // CHECK: %[[VAL_8:.*]] = torch.aten.ones_like %[[VAL_2]], %[[VAL_5]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[16],ui8>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[16],f32>
- // CHECK: %[[VAL_9:.*]] = torch.aten._make_per_channel_quantized_tensor %[[VAL_0]], %[[VAL_8]], %[[VAL_2]], %[[VAL_7]] : !torch.vtensor<[16,2],ui8>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],ui8>, !torch.int -> !torch.vtensor<[16,2],!torch.quint8>
- // CHECK: %[[VAL_10:.*]] = torch.constant.float 1.000000e+00
- // CHECK: %[[VAL_11:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[VAL_1]], %[[VAL_10]], %[[VAL_4]] : !torch.vtensor<[2,768],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,768],!torch.qint8>
- // CHECK: %[[VAL_12:.*]] = torch.aten.matmul %[[VAL_9]], %[[VAL_11]] : !torch.vtensor<[16,2],!torch.quint8>, !torch.vtensor<[2,768],!torch.qint8> -> !torch.vtensor<[16,768],si32>
- // CHECK: return %[[VAL_12]] : !torch.vtensor<[16,768],si32>
- %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[16,2],ui8>, !torch.vtensor<[2,768],si8>, !torch.vtensor<[16],ui8>, !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32>
+ func.func @test_matmulinteger_non_scalar_lhsZp(%arg0: !torch.vtensor<[16,2],ui8>, %arg1: !torch.vtensor<[2,768],si8>, %arg2: !torch.vtensor<[16,1],ui8>, %arg3: !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+ // CHECK: %[[LHS:.*]] = torch.aten.to.dtype %arg0, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[16,2],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[16,2],si32>
+ // CHECK: %[[RHS:.*]] = torch.aten.to.dtype %arg1, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[2,768],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,768],si32>
+ // CHECK: %[[LHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[LHS]], %arg2, %{{.+}} : !torch.vtensor<[16,2],si32>, !torch.vtensor<[16,1],ui8>, !torch.int -> !torch.vtensor<[16,2],si32>
+ // CHECK: %[[RHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[RHS]], %arg3, %{{.+}} : !torch.vtensor<[2,768],si32>, !torch.vtensor<[],si8>, !torch.int -> !torch.vtensor<[2,768],si32>
+ // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LHS_MINUS_ZP]], %[[RHS_MINUS_ZP]] : !torch.vtensor<[16,2],si32>, !torch.vtensor<[2,768],si32> -> !torch.vtensor<[16,768],si32>
+ // CHECK: return %[[MM]] : !torch.vtensor<[16,768],si32>
+ %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[16,2],ui8>, !torch.vtensor<[2,768],si8>, !torch.vtensor<[16,1],ui8>, !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32>
return %0 : !torch.vtensor<[16,768],si32>
}

// -----

// CHECK-LABEL: func.func @test_matmulinteger_non_scalar_rhsZp(
- // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],ui8>,
- // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,768],si8>,
- // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[],ui8>,
- // CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
func.func @test_matmulinteger_non_scalar_rhsZp(%arg0: !torch.vtensor<[?,?],ui8>, %arg1: !torch.vtensor<[2,768],si8>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
- // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[],ui8> -> !torch.int
- // CHECK: %[[VAL_5:.*]] = torch.constant.int 6
- // CHECK: %[[VAL_6:.*]] = torch.constant.none
- // CHECK: %[[VAL_7:.*]] = torch.constant.float 1.000000e+00
- // CHECK: %[[VAL_8:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[VAL_0]], %[[VAL_7]], %[[VAL_4]] : !torch.vtensor<[?,?],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,?],!torch.quint8>
- // CHECK: %[[VAL_9:.*]] = torch.constant.int 1
- // CHECK: %[[VAL_10:.*]] = torch.aten.ones_like %[[VAL_3]], %[[VAL_5]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[768],si8>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[768],f32>
- // CHECK: %[[VAL_11:.*]] = torch.aten._make_per_channel_quantized_tensor %[[VAL_1]], %[[VAL_10]], %[[VAL_3]], %[[VAL_9]] : !torch.vtensor<[2,768],si8>, !torch.vtensor<[768],f32>, !torch.vtensor<[768],si8>, !torch.int -> !torch.vtensor<[2,768],!torch.qint8>
- // CHECK: %[[VAL_12:.*]] = torch.aten.matmul %[[VAL_8]], %[[VAL_11]] : !torch.vtensor<[?,?],!torch.quint8>, !torch.vtensor<[2,768],!torch.qint8> -> !torch.vtensor<[?,768],si32>
- // CHECK: return %[[VAL_12]] : !torch.vtensor<[?,768],si32>
+ // CHECK: %[[LHS:.*]] = torch.aten.to.dtype %arg0, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[?,?],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],si32>
+ // CHECK: %[[RHS:.*]] = torch.aten.to.dtype %arg1, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[2,768],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,768],si32>
+ // CHECK: %[[LHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[LHS]], %arg2, %{{.+}} : !torch.vtensor<[?,?],si32>, !torch.vtensor<[],ui8>, !torch.int -> !torch.vtensor<[?,?],si32>
+ // CHECK: %[[RHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[RHS]], %arg3, %{{.+}} : !torch.vtensor<[2,768],si32>, !torch.vtensor<[768],si8>, !torch.int -> !torch.vtensor<[2,768],si32>
+ // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LHS_MINUS_ZP]], %[[RHS_MINUS_ZP]] : !torch.vtensor<[?,?],si32>, !torch.vtensor<[2,768],si32> -> !torch.vtensor<[?,768],si32>
+ // CHECK: return %[[MM]] : !torch.vtensor<[?,768],si32>
%0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[?,?],ui8>, !torch.vtensor<[2,768],si8>, !torch.vtensor<[],ui8>, !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32>
return %0 : !torch.vtensor<[?,768],si32>
}