Skip to content

Commit 26d2a09

Browse files
[ONNX] Simplify Onnx.MatMulInteger op lowering (llvm#4163)
This commit modifies the Onnx.MatMulInteger op lowering to fix the Torch->Linalg lowering path of the op. Fixes nod-ai/SHARK-ModelDev#906. Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent 60379d7 commit 26d2a09

File tree

2 files changed

+59
-134
lines changed

2 files changed

+59
-134
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 33 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -551,16 +551,14 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
551551
patterns.onOp(
552552
"MatMulInteger", 10,
553553
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
554+
Location loc = binder.getLoc();
554555
Torch::ValueTensorType resultType;
555556
Value lhs, rhs, lhsZp, rhsZp;
556557
if (binder.tensorOperandAtIndex(lhs, 0) ||
557558
binder.tensorOperandAtIndex(rhs, 1) ||
558559
binder.tensorResultType(resultType))
559560
return failure();
560561

561-
auto lhsTy = dyn_cast<Torch::ValueTensorType>(lhs.getType());
562-
auto rhsTy = dyn_cast<Torch::ValueTensorType>(rhs.getType());
563-
564562
if (binder.tensorOperandAtIndex(lhsZp, 2)) {
565563
lhsZp = rewriter.create<Torch::ConstantIntOp>(
566564
binder.getLoc(), rewriter.getType<Torch::IntType>(),
@@ -573,92 +571,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
573571
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
574572
}
575573

576-
bool isChannelQuantizationForLhs = false;
577-
if (auto zpTy = dyn_cast<Torch::ValueTensorType>(lhsZp.getType())) {
578-
auto lhsZpSize = zpTy.getSizes();
579-
if (lhsZpSize.size() == 0 ||
580-
llvm::all_of(lhsZpSize, [](int64_t d) { return d == 1; })) {
581-
lhsZp = rewriter.create<Torch::AtenItemOp>(
582-
binder.getLoc(), rewriter.getType<Torch::IntType>(), lhsZp);
583-
} else if (lhsZpSize.size() == 1) {
584-
auto lhsSize = lhsTy.getSizes();
585-
if (lhsSize.size() != 2 || lhsSize[0] != lhsZpSize[0])
586-
return failure();
587-
isChannelQuantizationForLhs = true;
588-
} else {
589-
return failure();
590-
}
591-
}
592-
593-
bool isChannelQuantizationForRhs = false;
594-
if (auto zpTy = dyn_cast<Torch::ValueTensorType>(rhsZp.getType())) {
595-
auto rhsZpSize = zpTy.getSizes();
596-
if (rhsZpSize.size() == 0 ||
597-
llvm::all_of(rhsZpSize, [](int64_t d) { return d == 1; })) {
598-
rhsZp = rewriter.create<Torch::AtenItemOp>(
599-
binder.getLoc(), rewriter.getType<Torch::IntType>(), rhsZp);
600-
} else if (rhsZpSize.size() == 1) {
601-
auto rhsSize = rhsTy.getSizes();
602-
if (rhsSize.size() != 2 || rhsSize[1] != rhsZpSize[0])
603-
return failure();
604-
isChannelQuantizationForRhs = true;
605-
} else {
606-
return failure();
607-
}
608-
}
609-
610-
auto lhsQTy = getQTorchTypeFromTorchIntType(lhsTy);
611-
auto rhsQTy = getQTorchTypeFromTorchIntType(rhsTy);
612-
613-
if (!lhsQTy || !rhsQTy)
614-
return rewriter.notifyMatchFailure(binder.op, "failed to get qtype");
615-
616-
Value f32Ty = rewriter.create<Torch::ConstantIntOp>(
617-
binder.getLoc(), rewriter.getI64IntegerAttr(
618-
(int64_t)torch_upstream::ScalarType::Float));
619-
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
620-
621-
if (isChannelQuantizationForLhs) {
622-
Value axis = rewriter.create<Torch::ConstantIntOp>(
623-
binder.getLoc(), rewriter.getType<Torch::IntType>(),
624-
rewriter.getI64IntegerAttr(0));
625-
Torch::ValueTensorType lhsZpTy =
626-
dyn_cast<Torch::ValueTensorType>(lhsZp.getType());
627-
Type scaleTy = lhsZpTy.getWithSizesAndDtype(lhsZpTy.getSizes(),
628-
rewriter.getF32Type());
629-
Value scale = rewriter.create<Torch::AtenOnesLikeOp>(
630-
binder.getLoc(), scaleTy, /*self=*/lhsZp, f32Ty, /*layout=*/none,
631-
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
632-
lhs = rewriter.create<Torch::Aten_MakePerChannelQuantizedTensorOp>(
633-
binder.getLoc(), lhsQTy, lhs, scale, lhsZp, axis);
634-
} else {
635-
Value scale = rewriter.create<Torch::ConstantFloatOp>(
636-
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
637-
rewriter.getF64FloatAttr(1.0));
638-
lhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
639-
binder.getLoc(), lhsQTy, lhs, scale, lhsZp);
640-
}
574+
// This op is lowered as follows:
575+
// lhs = lhs.to(dtype=torch.int32)
576+
// rhs = rhs.to(dtype=torch.int32)
577+
// lhs = lhs - lhsZp
578+
// rhs = rhs - rhsZp
579+
// res = torch.matmul(lhs, rhs)
580+
581+
// Converting lhs and rhs tensor to `si32` type.
582+
lhs = Torch::convertTensorToDtype(
583+
rewriter, loc, lhs,
584+
mlir::IntegerType::get(binder.op->getContext(), 32,
585+
mlir::IntegerType::Signed));
586+
rhs = Torch::convertTensorToDtype(
587+
rewriter, loc, rhs,
588+
mlir::IntegerType::get(binder.op->getContext(), 32,
589+
mlir::IntegerType::Signed));
590+
591+
// Subtracting the zero_point values from lhs and rhs.
592+
Value alpha = rewriter.create<Torch::ConstantIntOp>(
593+
loc, rewriter.getI64IntegerAttr(1));
594+
if (auto lhsZpTy = dyn_cast<Torch::ValueTensorType>(lhsZp.getType()))
595+
lhs = rewriter.create<Torch::AtenSubTensorOp>(loc, lhs.getType(), lhs,
596+
lhsZp, alpha);
597+
else
598+
lhs = rewriter.create<Torch::AtenSubScalarOp>(loc, lhs.getType(), lhs,
599+
lhsZp, alpha);
641600

642-
if (isChannelQuantizationForRhs) {
643-
Value axis = rewriter.create<Torch::ConstantIntOp>(
644-
binder.getLoc(), rewriter.getType<Torch::IntType>(),
645-
rewriter.getI64IntegerAttr(1));
646-
Torch::ValueTensorType rhsZpTy =
647-
dyn_cast<Torch::ValueTensorType>(rhsZp.getType());
648-
Type scaleTy = rhsZpTy.getWithSizesAndDtype(rhsZpTy.getSizes(),
649-
rewriter.getF32Type());
650-
Value scale = rewriter.create<Torch::AtenOnesLikeOp>(
651-
binder.getLoc(), scaleTy, /*self=*/rhsZp, f32Ty, /*layout=*/none,
652-
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
653-
rhs = rewriter.create<Torch::Aten_MakePerChannelQuantizedTensorOp>(
654-
binder.getLoc(), rhsQTy, rhs, scale, rhsZp, axis);
655-
} else {
656-
Value scale = rewriter.create<Torch::ConstantFloatOp>(
657-
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
658-
rewriter.getF64FloatAttr(1.0));
659-
rhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
660-
binder.getLoc(), rhsQTy, rhs, scale, rhsZp);
661-
}
601+
if (auto rhsZpTy = dyn_cast<Torch::ValueTensorType>(rhsZp.getType()))
602+
rhs = rewriter.create<Torch::AtenSubTensorOp>(loc, rhs.getType(), rhs,
603+
rhsZp, alpha);
604+
else
605+
rhs = rewriter.create<Torch::AtenSubScalarOp>(loc, rhs.getType(), rhs,
606+
rhsZp, alpha);
662607

663608
rewriter.replaceOpWithNewOp<Torch::AtenMatmulOp>(binder.op, resultType,
664609
lhs, rhs);

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 26 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -567,14 +567,12 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt
567567
// CHECK-LABEL: @test_matmulinteger
568568
func.func @test_matmulinteger(%arg0: !torch.vtensor<[4,3],ui8>, %arg1: !torch.vtensor<[3,2],ui8>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[4,2],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
569569
%0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[4,2],si32>
570-
// CHECK: %[[LITEM:.+]] = torch.aten.item %arg2
571-
// CHECK: %[[RITEM:.+]] = torch.aten.item %arg3
572-
// CHECK: %[[L_SCALE:.+]] = torch.constant.float 1.000000e+00
573-
// CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[L_SCALE]], %[[LITEM]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8>
574-
// CHECK: %[[R_SCALE:.+]] = torch.constant.float 1.000000e+00
575-
// CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[R_SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
576-
// CHECK: %[[MM:.+]] = torch.aten.matmul %[[LMAKE]], %[[RMAKE]]
577-
// CHECK: return %[[MM]]
570+
// CHECK: %[[LHS:.*]] = torch.aten.to.dtype %arg0, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[4,3],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,3],si32>
571+
// CHECK: %[[RHS:.*]] = torch.aten.to.dtype %arg1, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[3,2],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],si32>
572+
// CHECK: %[[LHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[LHS]], %arg2, %{{.+}} : !torch.vtensor<[4,3],si32>, !torch.vtensor<[1],ui8>, !torch.int -> !torch.vtensor<[4,3],si32>
573+
// CHECK: %[[RHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[RHS]], %arg3, %{{.+}} : !torch.vtensor<[3,2],si32>, !torch.vtensor<[1],ui8>, !torch.int -> !torch.vtensor<[3,2],si32>
574+
// CHECK: %[[MM:.+]] = torch.aten.matmul %[[LHS_MINUS_ZP]], %[[RHS_MINUS_ZP]] : !torch.vtensor<[4,3],si32>, !torch.vtensor<[3,2],si32> -> !torch.vtensor<[4,2],si32>
575+
// CHECK: return %[[MM]] : !torch.vtensor<[4,2],si32>
578576
return %0 : !torch.vtensor<[4,2],si32>
579577
}
580578

@@ -583,57 +581,39 @@ func.func @test_matmulinteger(%arg0: !torch.vtensor<[4,3],ui8>, %arg1: !torch.vt
583581
// CHECK-LABEL: @test_matmulinteger_batched
584582
func.func @test_matmulinteger_batched(%arg0: !torch.vtensor<[7,4,3],ui8>, %arg1: !torch.vtensor<[3,2],ui8>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[7,4,2],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
585583
%0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[7,4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[7,4,2],si32>
586-
// CHECK: %[[LITEM:.+]] = torch.aten.item %arg2
587-
// CHECK: %[[RITEM:.+]] = torch.aten.item %arg3
588-
// CHECK: %[[L_SCALE:.+]] = torch.constant.float 1.000000e+00
589-
// CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[L_SCALE]], %[[LITEM]] : !torch.vtensor<[7,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[7,4,3],!torch.quint8>
590-
// CHECK: %[[R_SCALE:.+]] = torch.constant.float 1.000000e+00
591-
// CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[R_SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
592-
// CHECK: %[[MM:.+]] = torch.aten.matmul %[[LMAKE]], %[[RMAKE]]
593-
// CHECK: return %[[MM]]
584+
// CHECK: %[[LHS:.*]] = torch.aten.to.dtype %arg0, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[7,4,3],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[7,4,3],si32>
585+
// CHECK: %[[RHS:.*]] = torch.aten.to.dtype %arg1, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[3,2],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],si32>
586+
// CHECK: %[[LHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[LHS]], %arg2, %{{.+}} : !torch.vtensor<[7,4,3],si32>, !torch.vtensor<[1],ui8>, !torch.int -> !torch.vtensor<[7,4,3],si32>
587+
// CHECK: %[[RHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[RHS]], %arg3, %{{.+}} : !torch.vtensor<[3,2],si32>, !torch.vtensor<[1],ui8>, !torch.int -> !torch.vtensor<[3,2],si32>
588+
// CHECK: %[[MM:.+]] = torch.aten.matmul %[[LHS_MINUS_ZP]], %[[RHS_MINUS_ZP]] : !torch.vtensor<[7,4,3],si32>, !torch.vtensor<[3,2],si32> -> !torch.vtensor<[7,4,2],si32>
589+
// CHECK: return %[[MM]] : !torch.vtensor<[7,4,2],si32>
594590
return %0 : !torch.vtensor<[7,4,2],si32>
595591
}
596592

597593
// -----
598594

599595
// CHECK-LABEL: func.func @test_matmulinteger_non_scalar_lhsZp(
600-
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[16,2],ui8>,
601-
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,768],si8>,
602-
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16],ui8>,
603-
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
604-
func.func @test_matmulinteger_non_scalar_lhsZp(%arg0: !torch.vtensor<[16, 2],ui8>, %arg1: !torch.vtensor<[2,768],si8>, %arg2: !torch.vtensor<[16],ui8>, %arg3: !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
605-
// CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[],si8> -> !torch.int
606-
// CHECK: %[[VAL_5:.*]] = torch.constant.int 6
607-
// CHECK: %[[VAL_6:.*]] = torch.constant.none
608-
// CHECK: %[[VAL_7:.*]] = torch.constant.int 0
609-
// CHECK: %[[VAL_8:.*]] = torch.aten.ones_like %[[VAL_2]], %[[VAL_5]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[16],ui8>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[16],f32>
610-
// CHECK: %[[VAL_9:.*]] = torch.aten._make_per_channel_quantized_tensor %[[VAL_0]], %[[VAL_8]], %[[VAL_2]], %[[VAL_7]] : !torch.vtensor<[16,2],ui8>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],ui8>, !torch.int -> !torch.vtensor<[16,2],!torch.quint8>
611-
// CHECK: %[[VAL_10:.*]] = torch.constant.float 1.000000e+00
612-
// CHECK: %[[VAL_11:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[VAL_1]], %[[VAL_10]], %[[VAL_4]] : !torch.vtensor<[2,768],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,768],!torch.qint8>
613-
// CHECK: %[[VAL_12:.*]] = torch.aten.matmul %[[VAL_9]], %[[VAL_11]] : !torch.vtensor<[16,2],!torch.quint8>, !torch.vtensor<[2,768],!torch.qint8> -> !torch.vtensor<[16,768],si32>
614-
// CHECK: return %[[VAL_12]] : !torch.vtensor<[16,768],si32>
615-
%0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[16,2],ui8>, !torch.vtensor<[2,768],si8>, !torch.vtensor<[16],ui8>, !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32>
596+
func.func @test_matmulinteger_non_scalar_lhsZp(%arg0: !torch.vtensor<[16, 2],ui8>, %arg1: !torch.vtensor<[2,768],si8>, %arg2: !torch.vtensor<[16,1],ui8>, %arg3: !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
597+
// CHECK: %[[LHS:.*]] = torch.aten.to.dtype %arg0, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[16,2],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[16,2],si32>
598+
// CHECK: %[[RHS:.*]] = torch.aten.to.dtype %arg1, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[2,768],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,768],si32>
599+
// CHECK: %[[LHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[LHS]], %arg2, %{{.+}} : !torch.vtensor<[16,2],si32>, !torch.vtensor<[16,1],ui8>, !torch.int -> !torch.vtensor<[16,2],si32>
600+
// CHECK: %[[RHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[RHS]], %arg3, %{{.+}} : !torch.vtensor<[2,768],si32>, !torch.vtensor<[],si8>, !torch.int -> !torch.vtensor<[2,768],si32>
601+
// CHECK: %[[MM:.+]] = torch.aten.matmul %[[LHS_MINUS_ZP]], %[[RHS_MINUS_ZP]] : !torch.vtensor<[16,2],si32>, !torch.vtensor<[2,768],si32> -> !torch.vtensor<[16,768],si32>
602+
// CHECK: return %[[MM]] : !torch.vtensor<[16,768],si32>
603+
%0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[16,2],ui8>, !torch.vtensor<[2,768],si8>, !torch.vtensor<[16,1],ui8>, !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32>
616604
return %0 : !torch.vtensor<[16,768],si32>
617605
}
618606

619607
// -----
620608

621609
// CHECK-LABEL: func.func @test_matmulinteger_non_scalar_rhsZp(
622-
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],ui8>,
623-
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,768],si8>,
624-
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[],ui8>,
625-
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_met
626610
func.func @test_matmulinteger_non_scalar_rhsZp(%arg0: !torch.vtensor<[?,?],ui8>, %arg1: !torch.vtensor<[2,768],si8>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
627-
// CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[],ui8> -> !torch.int
628-
// CHECK: %[[VAL_5:.*]] = torch.constant.int 6
629-
// CHECK: %[[VAL_6:.*]] = torch.constant.none
630-
// CHECK: %[[VAL_7:.*]] = torch.constant.float 1.000000e+00
631-
// CHECK: %[[VAL_8:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[VAL_0]], %[[VAL_7]], %[[VAL_4]] : !torch.vtensor<[?,?],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,?],!torch.quint8>
632-
// CHECK: %[[VAL_9:.*]] = torch.constant.int 1
633-
// CHECK: %[[VAL_10:.*]] = torch.aten.ones_like %[[VAL_3]], %[[VAL_5]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[768],si8>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[768],f32>
634-
// CHECK: %[[VAL_11:.*]] = torch.aten._make_per_channel_quantized_tensor %[[VAL_1]], %[[VAL_10]], %[[VAL_3]], %[[VAL_9]] : !torch.vtensor<[2,768],si8>, !torch.vtensor<[768],f32>, !torch.vtensor<[768],si8>, !torch.int -> !torch.vtensor<[2,768],!torch.qint8>
635-
// CHECK: %[[VAL_12:.*]] = torch.aten.matmul %[[VAL_8]], %[[VAL_11]] : !torch.vtensor<[?,?],!torch.quint8>, !torch.vtensor<[2,768],!torch.qint8> -> !torch.vtensor<[?,768],si32>
636-
// CHECK: return %[[VAL_12]] : !torch.vtensor<[?,768],si32>
611+
// CHECK: %[[LHS:.*]] = torch.aten.to.dtype %arg0, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[?,?],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],si32>
612+
// CHECK: %[[RHS:.*]] = torch.aten.to.dtype %arg1, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[2,768],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,768],si32>
613+
// CHECK: %[[LHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[LHS]], %arg2, %{{.+}} : !torch.vtensor<[?,?],si32>, !torch.vtensor<[],ui8>, !torch.int -> !torch.vtensor<[?,?],si32>
614+
// CHECK: %[[RHS_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[RHS]], %arg3, %{{.+}} : !torch.vtensor<[2,768],si32>, !torch.vtensor<[768],si8>, !torch.int -> !torch.vtensor<[2,768],si32>
615+
// CHECK: %[[MM:.+]] = torch.aten.matmul %[[LHS_MINUS_ZP]], %[[RHS_MINUS_ZP]] : !torch.vtensor<[?,?],si32>, !torch.vtensor<[2,768],si32> -> !torch.vtensor<[?,768],si32>
616+
// CHECK: return %[[MM]] : !torch.vtensor<[?,768],si32>
637617
%0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[?,?],ui8>, !torch.vtensor<[2,768],si8>, !torch.vtensor<[],ui8>, !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32>
638618
return %0 : !torch.vtensor<[?,768],si32>
639619
}

0 commit comments

Comments
 (0)