@@ -41,8 +41,10 @@ bool isValidElementType(Value val) {
41
41
// / detect whether the shapes are exactly the same or not. Hence, return false.
42
42
// / Also, check the ranks of two tensors, they must be in range of (0, 4].
43
43
bool haveSameStaticShape (Value value1, Value value2) {
44
- auto valueType1 = value1.getType ().cast <ShapedType>();
45
- auto valueType2 = value2.getType ().cast <ShapedType>();
44
+ ShapedType valueType1 = value1.getType ().cast <ShapedType>();
45
+ ShapedType valueType2 = value2.getType ().cast <ShapedType>();
46
+ if (!valueType1.hasRank () || !valueType2.hasRank ())
47
+ return false ;
46
48
// Different rank, return false.
47
49
if (valueType1.getRank () != valueType2.getRank ())
48
50
return false ;
@@ -360,48 +362,54 @@ template <>
360
362
bool isSuitableForZDNN<ONNXSoftmaxOp>(ONNXSoftmaxOp op) {
361
363
if (!isValidElementType (op.input ()))
362
364
return false ;
363
- return ((op.axis () == 1 || op.axis () == -1 ) &&
364
- (op.input ().getType ().cast <ShapedType>().getRank () == 2 ));
365
+ ShapedType inputType = op.getType ().cast <ShapedType>();
366
+ return (op.axis () == 1 || op.axis () == -1 ) && inputType.hasRank () &&
367
+ (inputType.getRank () == 2 );
365
368
}
366
369
367
370
// / Check legality for ONNXRelu.
368
371
template <>
369
372
bool isSuitableForZDNN<ONNXReluOp>(ONNXReluOp op) {
370
373
if (!isValidElementType (op.X ()))
371
374
return false ;
372
- return (op.X ().getType ().cast <ShapedType>().getRank () <= 4 );
375
+ ShapedType xType = op.X ().getType ().cast <ShapedType>();
376
+ return xType.hasRank () && (xType.getRank () <= 4 );
373
377
}
374
378
375
379
// / Check legality for ONNXTanh.
376
380
template <>
377
381
bool isSuitableForZDNN<ONNXTanhOp>(ONNXTanhOp op) {
378
382
if (!isValidElementType (op.input ()))
379
383
return false ;
380
- return (op.input ().getType ().cast <ShapedType>().getRank () <= 4 );
384
+ ShapedType inputType = op.getType ().cast <ShapedType>();
385
+ return inputType.hasRank () && (inputType.getRank () <= 4 );
381
386
}
382
387
383
388
// / Check legality for ONNXSigmoid.
384
389
template <>
385
390
bool isSuitableForZDNN<ONNXSigmoidOp>(ONNXSigmoidOp op) {
386
391
if (!isValidElementType (op.X ()))
387
392
return false ;
388
- return (op.X ().getType ().cast <ShapedType>().getRank () <= 4 );
393
+ ShapedType xType = op.X ().getType ().cast <ShapedType>();
394
+ return xType.hasRank () && (xType.getRank () <= 4 );
389
395
}
390
396
391
397
// / Check legality for ONNXLog.
392
398
template <>
393
399
bool isSuitableForZDNN<ONNXLogOp>(ONNXLogOp op) {
394
400
if (!isValidElementType (op.input ()))
395
401
return false ;
396
- return (op.input ().getType ().cast <ShapedType>().getRank () <= 4 );
402
+ ShapedType inputType = op.input ().getType ().cast <ShapedType>();
403
+ return inputType.hasRank () && (inputType.getRank () <= 4 );
397
404
}
398
405
399
406
// / Check legality for ONNXExp.
400
407
template <>
401
408
bool isSuitableForZDNN<ONNXExpOp>(ONNXExpOp op) {
402
409
if (!isValidElementType (op.input ()))
403
410
return false ;
404
- return (op.input ().getType ().cast <ShapedType>().getRank () <= 4 );
411
+ ShapedType inputType = op.input ().getType ().cast <ShapedType>();
412
+ return inputType.hasRank () && (inputType.getRank () <= 4 );
405
413
}
406
414
407
415
// / Check legality for ONNXMatMul.
0 commit comments