
Commit 0996500 (parent 862bca6)

[NNPA] Support bias none in ONNX.Gemm (#1466)

* Support for bias none in Gemm

Signed-off-by: Haruki Imai <imaihal@jp.ibm.com>
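Context: ONNX Gemm computes Y = alpha * A' * B' + beta * C, where the bias C is an optional input; when it is omitted, the operand is an onnx.NoValue of type none. Previously the ONNX-to-ZHigh lowering only matched a 1-D bias, so this commit extends the legality checks and the rewrite pattern to also accept a none bias.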

3 files changed: 48 additions, 12 deletions


src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
(17 additions, 9 deletions)
@@ -41,8 +41,10 @@ bool isValidElementType(Value val) {
 /// detect whether the shapes are exactly the same or not. Hence, return false.
 /// Also, check the ranks of two tensors, they must be in range of (0, 4].
 bool haveSameStaticShape(Value value1, Value value2) {
-  auto valueType1 = value1.getType().cast<ShapedType>();
-  auto valueType2 = value2.getType().cast<ShapedType>();
+  ShapedType valueType1 = value1.getType().cast<ShapedType>();
+  ShapedType valueType2 = value2.getType().cast<ShapedType>();
+  if (!valueType1.hasRank() || !valueType2.hasRank())
+    return false;
   // Different rank, return false.
   if (valueType1.getRank() != valueType2.getRank())
     return false;
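The new hasRank() guards matter because getRank() may only be called on a ranked type; an unranked tensor such as tensor<*xf32> would trip an assertion. A minimal sketch of the guard pattern the diff applies throughout (the helper name is hypothetical, not part of the commit):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Returns true only when `val` has a shaped type whose rank is known
// and at most `maxRank`; non-shaped and unranked values are rejected.
static bool hasKnownRankAtMost(mlir::Value val, int64_t maxRank) {
  auto type = val.getType().dyn_cast<mlir::ShapedType>();
  return type && type.hasRank() && type.getRank() <= maxRank;
}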
@@ -360,48 +362,54 @@ template <>
 bool isSuitableForZDNN<ONNXSoftmaxOp>(ONNXSoftmaxOp op) {
   if (!isValidElementType(op.input()))
     return false;
-  return ((op.axis() == 1 || op.axis() == -1) &&
-          (op.input().getType().cast<ShapedType>().getRank() == 2));
+  ShapedType inputType = op.getType().cast<ShapedType>();
+  return (op.axis() == 1 || op.axis() == -1) && inputType.hasRank() &&
+         (inputType.getRank() == 2);
 }
 
 /// Check legality for ONNXRelu.
 template <>
 bool isSuitableForZDNN<ONNXReluOp>(ONNXReluOp op) {
   if (!isValidElementType(op.X()))
     return false;
-  return (op.X().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType xType = op.X().getType().cast<ShapedType>();
+  return xType.hasRank() && (xType.getRank() <= 4);
 }
 
 /// Check legality for ONNXTanh.
 template <>
 bool isSuitableForZDNN<ONNXTanhOp>(ONNXTanhOp op) {
   if (!isValidElementType(op.input()))
     return false;
-  return (op.input().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType inputType = op.getType().cast<ShapedType>();
+  return inputType.hasRank() && (inputType.getRank() <= 4);
 }
 
 /// Check legality for ONNXSigmoid.
 template <>
 bool isSuitableForZDNN<ONNXSigmoidOp>(ONNXSigmoidOp op) {
   if (!isValidElementType(op.X()))
     return false;
-  return (op.X().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType xType = op.X().getType().cast<ShapedType>();
+  return xType.hasRank() && (xType.getRank() <= 4);
 }
 
 /// Check legality for ONNXLog.
 template <>
 bool isSuitableForZDNN<ONNXLogOp>(ONNXLogOp op) {
   if (!isValidElementType(op.input()))
     return false;
-  return (op.input().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType inputType = op.input().getType().cast<ShapedType>();
+  return inputType.hasRank() && (inputType.getRank() <= 4);
 }
 
 /// Check legality for ONNXExp.
 template <>
 bool isSuitableForZDNN<ONNXExpOp>(ONNXExpOp op) {
   if (!isValidElementType(op.input()))
     return false;
-  return (op.input().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType inputType = op.input().getType().cast<ShapedType>();
+  return inputType.hasRank() && (inputType.getRank() <= 4);
 }
 
 /// Check legality for ONNXMatMul.
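All six unary-op checks above now repeat the same two-step guard (valid element type, then known rank within bounds); the Softmax and Tanh checks read the rank from the op's result type via op.getType(), which for these shape-preserving ops matches the input rank, and the hasRank() guard covers the case where it is still unranked. Under the hypothetical helper sketched earlier, each check would collapse to a one-liner; for example (illustrative refactor only, not commit code):

template <>
bool isSuitableForZDNN<ONNXReluOp>(ONNXReluOp op) {
  // Element type must be supported by zDNN, and the rank must be
  // known and at most 4.
  return isValidElementType(op.X()) && hasKnownRankAtMost(op.X(), 4);
}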

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td
(13 additions, 3 deletions)
@@ -33,7 +33,16 @@ def IsNoneType : Constraint<CPred<"(($_self).getType().isa<NoneType>())">>;
 def IsNotNoneType : Constraint<CPred<"(!($_self).getType().isa<NoneType>())">>;
 
 class HasRankOf<int rank> : Constraint<
-  CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>
+  CPred<"$0.getType().isa<ShapedType>() && "
+        "$0.getType().cast<ShapedType>().hasRank() && "
+        "$0.getType().cast<ShapedType>().getRank() == " # rank>
+>;
+
+def IsBiasNoneOr1D : Constraint<
+  CPred<"$_self.getType().isa<NoneType>() || "
+        " ($_self.getType().isa<ShapedType>() && "
+        "  $_self.getType().cast<ShapedType>().hasRank() && "
+        "  $_self.getType().cast<ShapedType>().getRank() == 1)">
 >;
 
 class VariadicSizeIs<int N> : Constraint<
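mlir-tblgen expands a Constraint's CPred by substituting the bound entity for $_self, so IsBiasNoneOr1D behaves roughly like the following C++ predicate (an illustrative rendering written for this note, not code generated from the build):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

static bool isBiasNoneOr1D(mlir::Value bias) {
  mlir::Type type = bias.getType();
  // An omitted Gemm bias is an onnx.NoValue, whose type is 'none'.
  if (type.isa<mlir::NoneType>())
    return true;
  // Otherwise the bias must be a ranked 1-D tensor.
  auto shaped = type.dyn_cast<mlir::ShapedType>();
  return shaped && shaped.hasRank() && shaped.getRank() == 1;
}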
@@ -536,14 +545,15 @@ def normalizeONNXGemmTransBPattern : Pat<
   (addBenefit 1)
 >;
 
-def replaceONNXGemmBias1DPattern : Pat<
+
+def replaceONNXGemmBiasNoneOr1DPattern : Pat<
   (ONNXGemmOp $a, $b, $c, $_, $_, $_, $_),
   (ZHighUnstickOp
     (ZHighMatMulOp
       (ZHighStickOp $a, (_2DLayoutAttr)),
       (ZHighStickOp $b, (_2DLayoutAttr)),
       (ZHighStickOp $c, (_1DLayoutAttr)))),
-  [(HasRankOf<1> $c)],
+  [(IsBiasNoneOr1D:$c)],
   (addBenefit 0)
 >;
 
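The renamed pattern keeps (addBenefit 0) while the transA/transB normalization patterns above it use (addBenefit 1), so transposes are folded away before the MatMul rewrite fires; and when the bias is none, the Stick of the none value is folded away during canonicalization, which is why the test below feeds onnx.NoValue directly into zhigh.MatMul. In C++ pattern terms the same ordering is expressed through PatternBenefit (a sketch of the generic MLIR API, not code from this commit):

#include "mlir/IR/PatternMatch.h"

struct ReplaceGemmBiasNoneOr1D : public mlir::OpRewritePattern<ONNXGemmOp> {
  // benefit 0: tried only after higher-benefit Gemm patterns fail to match.
  ReplaceGemmBiasNoneOr1D(mlir::MLIRContext *ctx)
      : OpRewritePattern<ONNXGemmOp>(ctx, /*benefit=*/0) {}
  mlir::LogicalResult matchAndRewrite(
      ONNXGemmOp op, mlir::PatternRewriter &rewriter) const override;
};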

test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gemm.mlir
(18 additions, 0 deletions)
@@ -1,5 +1,23 @@
 // RUN: onnx-mlir-opt --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s
 
+func @test_gemm_bias_none(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
+  %bias = "onnx.NoValue"() {value} : () -> none
+  %0 ="onnx.Gemm"(%arg0, %arg1, %bias) {alpha = 1.0 : f32, beta = 1.0 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<10x5xf32>, tensor<5x10xf32>, none) -> tensor<*xf32>
+  "func.return"(%0) : (tensor<*xf32>) -> ()
+
+// CHECK-LABEL: func @test_gemm_bias_none
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x5xf32>, [[PARAM_1_:%.+]]: tensor<5x10xf32>) -> tensor<10x10xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x5xf32>) -> tensor<10x5xf32, #zhigh.encoding<{dataLayout = "2D"}>>
+// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf32, #zhigh.encoding<{dataLayout = "2D"}>>
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<10x5xf32, #zhigh.encoding<{dataLayout = "2D"}>>, tensor<5x10xf32, #zhigh.encoding<{dataLayout = "2D"}>>, none) -> tensor<*xf32>
+// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf32>) -> tensor<10x10xf32>
+// CHECK: return [[VAR_4_]] : tensor<10x10xf32>
+// CHECK: }
+}
+
+// -----
+
 func @test_gemm_bias_1d(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>, %arg2: tensor<10xf32>) -> tensor<*xf32> {
   %0 ="onnx.Gemm"(%arg0, %arg1, %arg2) {alpha = 1.0 : f32, beta = 1.0 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<10x5xf32>, tensor<5x10xf32>, tensor<10xf32>) -> tensor<*xf32>
   "func.return"(%0) : (tensor<*xf32>) -> ()
