Skip to content

Commit 563331c

Browse files
[ONNX] Add support for Onnx.QLinearSigmoid op (#4140)
This commit adds the Onnx->Torch lowering for Onnx.[QLinearSigmoid](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftqlinearsigmoid) op. Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent e0e4c0c commit 563331c

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp

Lines changed: 58 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -870,4 +870,62 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
870870
avgpool);
871871
return success();
872872
});
873+
patterns.onOp(
    "QLinearSigmoid", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers the com.microsoft QLinearSigmoid contrib op to Torch ops as:
      //   make_per_tensor_quantized_tensor(x) -> dequantize -> sigmoid(f32)
      //   -> quantize_per_tensor(y_scale, y_zero_point) -> int_repr.
      Location loc = binder.getLoc();
      Torch::ValueTensorType resultType;
      llvm::SmallVector<Value> operands;
      if (binder.tensorOperandsList(operands) ||
          binder.tensorResultType(resultType))
        return failure();

      // QLinearSigmoid takes exactly five operands:
      // X, X_scale, X_zero_point, Y_scale, Y_zero_point. Guard before
      // indexing so a malformed op fails the match instead of crashing.
      if (operands.size() != 5)
        return rewriter.notifyMatchFailure(
            binder.op, "Expected QLinearSigmoid to have 5 operands");

      Value x = operands[0];
      Value xScale, xZp, yScale, yZp;

      // Extract scalar scale/zero-point for the input `x`.
      if (failed(extractPerTensorQuantizationArguments(
              rewriter, loc, /*scale=*/operands[1],
              /*zero_point=*/operands[2], xScale, xZp)))
        return rewriter.notifyMatchFailure(
            binder.op, "Incompatible arguments for per-tensor quantization");

      // Extract scalar scale/zero-point for the output `y`.
      if (failed(extractPerTensorQuantizationArguments(
              rewriter, loc, /*scale=*/operands[3],
              /*zero_point=*/operands[4], yScale, yZp)))
        return rewriter.notifyMatchFailure(
            binder.op, "Incompatible arguments for per-tensor quantization");

      auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType());
      if (!xTy || !xTy.hasSizes())
        return rewriter.notifyMatchFailure(
            binder.op, "Expected input argument `x` to have sizes");

      // Reassemble the raw integer tensor into a quantized tensor, then
      // dequantize it to f32 so aten.sigmoid can operate on it.
      xTy = getQTorchTypeFromTorchIntType(xTy);
      x = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
          loc, xTy, x, xScale, xZp);
      xTy = rewriter.getType<Torch::ValueTensorType>(xTy.getSizes(),
                                                     rewriter.getF32Type());
      // Dequantizing the input tensor `x`.
      x = rewriter.create<Torch::AtenDequantizeSelfOp>(loc, xTy, x);

      // Computing the Sigmoid result in f32.
      auto yTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(), rewriter.getF32Type());
      Value y = rewriter.create<Torch::AtenSigmoidOp>(loc, yTy, x);

      // Quantizing the result of Sigmoid op. The dyn_cast can fail when the
      // result dtype has no quantized Torch counterpart, so check it before
      // dereferencing instead of crashing on a null type.
      yTy = dyn_cast<Torch::ValueTensorType>(
          getQTorchTypeFromTorchIntType(resultType));
      if (!yTy)
        return rewriter.notifyMatchFailure(
            binder.op, "Unsupported quantized result type");
      Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(
              rewriter.getIntegerType(64),
              static_cast<int64_t>(
                  Torch::getScalarTypeForType(yTy.getDtype()))));
      y = rewriter.create<Torch::AtenQuantizePerTensorOp>(loc, yTy, y, yScale,
                                                          yZp, dtyVal);
      // Replace the original op with the raw integer representation so the
      // result matches the declared (plain integer) result type.
      rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
                                                        y);
      return success();
    });
873931
}

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 26 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -3797,3 +3797,29 @@ func.func @test_qlinearglobalavgpool(%arg0: !torch.vtensor<[1,1000,13,13],ui8>,
37973797
// CHECK: return %[[OUT]]
37983798
return %0 : !torch.vtensor<[1,1000,1,1],ui8>
37993799
}
3800+
3801+
// -----
3802+
3803+
// Verifies the onnx.QLinearSigmoid lowering: the ui8 input is reassembled
// into a quantized tensor and dequantized to f32, aten.sigmoid is applied,
// and the result is requantized with the output scale/zero-point (dtype
// constant 13 = quint8) before torch.aten.int_repr yields the ui8 result.
// CHECK-LABEL: @test_qlinear_sigmoid(
3804+
// CHECK-SAME: %[[X:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,?],ui8>,
3805+
// CHECK-SAME: %[[X_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
3806+
// CHECK-SAME: %[[X_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
3807+
// CHECK-SAME: %[[Y_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
3808+
// CHECK-SAME: %[[Y_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>) -> !torch.vtensor<[?,?],ui8>
3809+
func.func @test_qlinear_sigmoid(%arg0: !torch.vtensor<[?,?],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[],f32>, %arg4: !torch.vtensor<[],ui8>) -> !torch.vtensor<[?,?],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64} {
3810+
%0 = torch.operator "onnx.QLinearSigmoid"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[?,?],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[?,?],ui8>
3811+
// CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
3812+
// CHECK-DAG: %[[XSCALE:.+]] = torch.aten.item %[[X_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
3813+
// CHECK-DAG: %[[XZP:.+]] = torch.aten.item %[[X_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
3814+
// CHECK-DAG: %[[EMPTY_0:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
3815+
// CHECK-DAG: %[[YSCALE:.+]] = torch.aten.item %[[Y_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
3816+
// CHECK-DAG: %[[YZP:.+]] = torch.aten.item %[[Y_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
3817+
// CHECK-DAG: %[[X_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[X]], %[[XSCALE]], %[[XZP]] : !torch.vtensor<[?,?],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,?],!torch.quint8>
3818+
// CHECK: %[[X_F32:.+]] = torch.aten.dequantize.self %[[X_QUANT]] : !torch.vtensor<[?,?],!torch.quint8> -> !torch.vtensor<[?,?],f32>
3819+
// CHECK: %[[SIGMOID:.*]] = torch.aten.sigmoid %[[X_F32]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
3820+
// CHECK: %[[DTY:.+]] = torch.constant.int 13
3821+
// CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[SIGMOID]], %[[YSCALE]], %[[YZP]], %[[DTY]] : !torch.vtensor<[?,?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?,?],!torch.quint8>
3822+
// CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[?,?],!torch.quint8> -> !torch.vtensor<[?,?],ui8>
3823+
// CHECK: return %[[OUT]]
3824+
return %0 : !torch.vtensor<[?,?],ui8>
3825+
}

0 commit comments

Comments
 (0)