@@ -464,13 +464,19 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
464
464
if (binder.tensorOperandsList (operands) ||
465
465
binder.tensorResultType (resultType))
466
466
return failure ();
467
+
468
+ if (operands.size () != 8 )
469
+ return rewriter.notifyMatchFailure (
470
+ binder.op , " Unimplemented: expected 8 input operands" );
471
+
467
472
Value a = operands[0 ];
468
473
Value aScale = operands[1 ];
469
474
Value aZp = operands[2 ];
470
475
Value b = operands[3 ];
471
476
Value bScale = operands[4 ];
472
477
Value bZp = operands[5 ];
473
478
Value cScale = operands[6 ];
479
+ Value cZp = operands[7 ];
474
480
475
481
auto check = [](Value v) {
476
482
auto vTy = cast<Torch::ValueTensorType>(v.getType ());
@@ -480,7 +486,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
480
486
return true ;
481
487
};
482
488
if (!check (aScale) || !check (aZp) || !check (bScale) || !check (bZp) ||
483
- !check (cScale))
489
+ !check (cScale) || ! check (cZp) )
484
490
return rewriter.notifyMatchFailure (
485
491
binder.op , " Unsupported per-tensor quantization" );
486
492
@@ -508,19 +514,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
508
514
509
515
aZp = extract (aZp);
510
516
bZp = extract (bZp);
511
-
512
- Value cZp;
513
- if (operands.size () == 8 ) {
514
- cZp = operands[7 ];
515
- if (!check (cZp))
516
- return rewriter.notifyMatchFailure (
517
- binder.op ,
518
- " Unsupported c_zero_point for per-tensor quantization" );
519
- cZp = extract (cZp);
520
- } else {
521
- cZp = rewriter.create <Torch::ConstantIntOp>(
522
- loc, rewriter.getI64IntegerAttr (0 ));
523
- }
517
+ cZp = extract (cZp);
524
518
525
519
aScale = extract (aScale);
526
520
bScale = extract (bScale);
@@ -590,6 +584,10 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
590
584
binder.f32FloatAttr (alpha, " alpha" ))
591
585
return failure ();
592
586
587
+ if (operands.size () != 5 )
588
+ return rewriter.notifyMatchFailure (
589
+ binder.op , " Unimplemented: expected 5 input operands" );
590
+
593
591
Value x = operands[0 ];
594
592
Value xScale = operands[1 ];
595
593
Value xZp = operands[2 ];
@@ -760,6 +758,12 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
760
758
binder.s64IntegerAttr (channelsLast, " channels_last" ))
761
759
return failure ();
762
760
761
+ // TODO: Add support for channels_last attribute.
762
+ if (channelsLast)
763
+ return rewriter.notifyMatchFailure (
764
+ binder.op ,
765
+ " Unimplemented: support not present for channels_last attribute" );
766
+
763
767
Value x = operands[0 ];
764
768
Value xScale, xZp, yScale, yZp;
765
769
@@ -880,6 +884,10 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
880
884
binder.tensorResultType (resultType))
881
885
return failure ();
882
886
887
+ if (operands.size () != 5 )
888
+ return rewriter.notifyMatchFailure (
889
+ binder.op , " Unimplemented: expected 5 input operands" );
890
+
883
891
Value x = operands[0 ];
884
892
Value xScale, xZp, yScale, yZp;
885
893
@@ -946,6 +954,10 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
946
954
binder.op ,
947
955
" Unimplemented: support not present for channels_last attribute" );
948
956
957
+ if (operands.size () != 5 )
958
+ return rewriter.notifyMatchFailure (
959
+ binder.op , " Unimplemented: expected 5 input operands" );
960
+
949
961
Value x = operands[0 ];
950
962
Value xScale, xZp, yScale, yZp;
951
963
0 commit comments