Skip to content

Commit 1a45fbc

Browse files
[ONNX] Add required checks for com.microsoft.QLinear* ops (#4156)
Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent 7590515 commit 1a45fbc

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -464,13 +464,19 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
464464
if (binder.tensorOperandsList(operands) ||
465465
binder.tensorResultType(resultType))
466466
return failure();
467+
468+
if (operands.size() != 8)
469+
return rewriter.notifyMatchFailure(
470+
binder.op, "Unimplemented: expected 8 input operands");
471+
467472
Value a = operands[0];
468473
Value aScale = operands[1];
469474
Value aZp = operands[2];
470475
Value b = operands[3];
471476
Value bScale = operands[4];
472477
Value bZp = operands[5];
473478
Value cScale = operands[6];
479+
Value cZp = operands[7];
474480

475481
auto check = [](Value v) {
476482
auto vTy = cast<Torch::ValueTensorType>(v.getType());
@@ -480,7 +486,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
480486
return true;
481487
};
482488
if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
483-
!check(cScale))
489+
!check(cScale) || !check(cZp))
484490
return rewriter.notifyMatchFailure(
485491
binder.op, "Unsupported per-tensor quantization");
486492

@@ -508,19 +514,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
508514

509515
aZp = extract(aZp);
510516
bZp = extract(bZp);
511-
512-
Value cZp;
513-
if (operands.size() == 8) {
514-
cZp = operands[7];
515-
if (!check(cZp))
516-
return rewriter.notifyMatchFailure(
517-
binder.op,
518-
"Unsupported c_zero_point for per-tensor quantization");
519-
cZp = extract(cZp);
520-
} else {
521-
cZp = rewriter.create<Torch::ConstantIntOp>(
522-
loc, rewriter.getI64IntegerAttr(0));
523-
}
517+
cZp = extract(cZp);
524518

525519
aScale = extract(aScale);
526520
bScale = extract(bScale);
@@ -590,6 +584,10 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
590584
binder.f32FloatAttr(alpha, "alpha"))
591585
return failure();
592586

587+
if (operands.size() != 5)
588+
return rewriter.notifyMatchFailure(
589+
binder.op, "Unimplemented: expected 5 input operands");
590+
593591
Value x = operands[0];
594592
Value xScale = operands[1];
595593
Value xZp = operands[2];
@@ -760,6 +758,12 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
760758
binder.s64IntegerAttr(channelsLast, "channels_last"))
761759
return failure();
762760

761+
// TODO: Add support for channels_last attribute.
762+
if (channelsLast)
763+
return rewriter.notifyMatchFailure(
764+
binder.op,
765+
"Unimplemented: support not present for channels_last attribute");
766+
763767
Value x = operands[0];
764768
Value xScale, xZp, yScale, yZp;
765769

@@ -880,6 +884,10 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
880884
binder.tensorResultType(resultType))
881885
return failure();
882886

887+
if (operands.size() != 5)
888+
return rewriter.notifyMatchFailure(
889+
binder.op, "Unimplemented: expected 5 input operands");
890+
883891
Value x = operands[0];
884892
Value xScale, xZp, yScale, yZp;
885893

@@ -946,6 +954,10 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
946954
binder.op,
947955
"Unimplemented: support not present for channels_last attribute");
948956

957+
if (operands.size() != 5)
958+
return rewriter.notifyMatchFailure(
959+
binder.op, "Unimplemented: expected 5 input operands");
960+
949961
Value x = operands[0];
950962
Value xScale, xZp, yScale, yZp;
951963

0 commit comments

Comments
 (0)