Skip to content

Commit eaf87f3

Browse files
[ONNX] Add support for Onnx.QLinearAveragePool op (#4142)
This commit adds the Onnx->Torch lowering for [Onnx.QLinearAveragePool](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool) op. --------- Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent 4042304 commit eaf87f3

File tree

3 files changed

+108
-1
lines changed

3 files changed

+108
-1
lines changed

lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,4 +928,85 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
928928
y);
929929
return success();
930930
});
931+
patterns.onOp(
932+
"QLinearAveragePool", 1,
933+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
934+
Location loc = binder.getLoc();
935+
Torch::ValueTensorType resultType;
936+
llvm::SmallVector<Value> operands;
937+
int64_t channelsLast;
938+
if (binder.tensorOperandsList(operands) ||
939+
binder.tensorResultType(resultType) ||
940+
binder.s64IntegerAttr(channelsLast, "channels_last"))
941+
return failure();
942+
943+
// TODO: Add support for channels_last attribute.
944+
if (channelsLast)
945+
return rewriter.notifyMatchFailure(
946+
binder.op,
947+
"Unimplemented: support not present for channels_last attribute");
948+
949+
Value x = operands[0];
950+
Value xScale, xZp, yScale, yZp;
951+
952+
if (failed(extractPerTensorQuantizationArguments(
953+
rewriter, loc, /*scale=*/operands[1],
954+
/*zero_point=*/operands[2], xScale, xZp)))
955+
return rewriter.notifyMatchFailure(
956+
binder.op, "Incompatible arguments for per-tensor quantization");
957+
958+
if (failed(extractPerTensorQuantizationArguments(
959+
rewriter, loc, /*scale=*/operands[3],
960+
/*zero_point=*/operands[4], yScale, yZp)))
961+
return rewriter.notifyMatchFailure(
962+
binder.op, "Incompatible arguments for per-tensor quantization");
963+
964+
auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType());
965+
if (!xTy || !xTy.hasSizes())
966+
return rewriter.notifyMatchFailure(
967+
binder.op, "Expected input argument `x` to have sizes");
968+
969+
xTy = getQTorchTypeFromTorchIntType(xTy);
970+
x = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
971+
loc, xTy, x, xScale, xZp);
972+
xTy = rewriter.getType<Torch::ValueTensorType>(xTy.getSizes(),
973+
rewriter.getF32Type());
974+
// Dequantizing the input tensor `x`.
975+
x = rewriter.create<Torch::AtenDequantizeSelfOp>(loc, xTy, x);
976+
977+
// Creating Onnx.AveragePool op.
978+
llvm::SmallVector<Value> newOperands = {x};
979+
llvm::SmallVector<NamedAttribute> newAttributes;
980+
newAttributes.push_back(rewriter.getNamedAttr(
981+
"name", rewriter.getStringAttr("onnx.AveragePool")));
982+
for (auto namedAttr : binder.op->getAttrDictionary()) {
983+
if (namedAttr.getName().getValue().compare("name") == 0)
984+
continue;
985+
newAttributes.push_back(namedAttr);
986+
}
987+
988+
auto yTy = rewriter.getType<Torch::ValueTensorType>(
989+
resultType.getOptionalSizes(), rewriter.getF32Type());
990+
Value averagePool =
991+
rewriter
992+
.create<Torch::OperatorOp>(binder.getLoc(), yTy, newOperands,
993+
newAttributes,
994+
binder.op->getRegions().size())
995+
.getResult(0);
996+
997+
// Quantizing the result of AveragePool op.
998+
yTy = dyn_cast<Torch::ValueTensorType>(
999+
getQTorchTypeFromTorchIntType(resultType));
1000+
Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
1001+
binder.getLoc(), rewriter.getType<Torch::IntType>(),
1002+
rewriter.getIntegerAttr(
1003+
rewriter.getIntegerType(64),
1004+
static_cast<int64_t>(
1005+
Torch::getScalarTypeForType(yTy.getDtype()))));
1006+
averagePool = rewriter.create<Torch::AtenQuantizePerTensorOp>(
1007+
loc, yTy, averagePool, yScale, yZp, dtyVal);
1008+
rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
1009+
averagePool);
1010+
return success();
1011+
});
9311012
}

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
454454
return success();
455455
});
456456
patterns.onOp(
457-
"AveragePool", 11,
457+
"AveragePool", 1,
458458
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
459459
std::string autoPad;
460460
SmallVector<int64_t> dilations;

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3823,3 +3823,29 @@ func.func @test_qlinear_sigmoid(%arg0: !torch.vtensor<[?,?],ui8>, %arg1: !torch.
38233823
// CHECK: return %[[OUT]]
38243824
return %0 : !torch.vtensor<[?,?],ui8>
38253825
}
3826+
3827+
// -----

// Checks the Onnx->Torch lowering of com.microsoft.QLinearAveragePool:
// the op must decompose into _make_per_tensor_quantized_tensor +
// dequantize.self, a float avg_pool2d, then quantize_per_tensor + int_repr.

// CHECK-LABEL: @test_qlinearAveragePool(
// CHECK-SAME: %[[X:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,128,56,56],ui8>,
// CHECK-SAME: %[[X_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
// CHECK-SAME: %[[X_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
// CHECK-SAME: %[[Y_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
// CHECK-SAME: %[[Y_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,128,28,28],ui8>
func.func @test_qlinearAveragePool(%arg0: !torch.vtensor<[1,128,56,56],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[],f32>, %arg4: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,128,28,28],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64} {
  %0 = torch.operator "onnx.QLinearAveragePool"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.auto_pad = "NOTSET", torch.onnx.ceil_mode = 0 : si64, torch.onnx.count_include_pad = 0 : si64, torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,128,56,56],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,128,28,28],ui8>
  // CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
  // CHECK-DAG: %[[XSCALE:.+]] = torch.aten.item %[[X_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
  // CHECK-DAG: %[[XZP:.+]] = torch.aten.item %[[X_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
  // CHECK-DAG: %[[EMPTY_0:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
  // CHECK-DAG: %[[YSCALE:.+]] = torch.aten.item %[[Y_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
  // CHECK-DAG: %[[YZP:.+]] = torch.aten.item %[[Y_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
  // CHECK-DAG: %[[X_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[X]], %[[XSCALE]], %[[XZP]] : !torch.vtensor<[1,128,56,56],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,128,56,56],!torch.quint8>
  // CHECK: %[[X_F32:.+]] = torch.aten.dequantize.self %[[X_QUANT]] : !torch.vtensor<[1,128,56,56],!torch.quint8> -> !torch.vtensor<[1,128,56,56],f32>
  // CHECK: %[[AVGPOOL:.*]] = torch.aten.avg_pool2d %[[X_F32]], %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : !torch.vtensor<[1,128,56,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,128,28,28],f32>
  // CHECK: %[[DTY:.+]] = torch.constant.int 13
  // CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[AVGPOOL]], %[[YSCALE]], %[[YZP]], %[[DTY]] : !torch.vtensor<[1,128,28,28],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,128,28,28],!torch.quint8>
  // CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[1,128,28,28],!torch.quint8> -> !torch.vtensor<[1,128,28,28],ui8>
  // CHECK: return %[[OUT]]
  return %0 : !torch.vtensor<[1,128,28,28],ui8>
}

0 commit comments

Comments
 (0)