Skip to content

Commit 4e9c3ce

Browse files
authored
[mlir][tosa] Improve invalid operator data types error message (#140756)
The error message on invalid operator data types in the validation pass was not very clear. This commit improves the error message as follows: Current: ``` 'tosa.add' op illegal: operand/result data types not supported ``` Improved: ``` 'tosa.add' op illegal: operation operand/result data types did not align with any profile or extension, got (i1,i1,i1), did you mean (i32,i32,i32)? Otherwise, please refer to the 'supported data types' for 'tosa.add' in the specification. ```
1 parent 758fea0 commit 4e9c3ce

File tree

4 files changed

+68
-7
lines changed

4 files changed

+68
-7
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ class TosaProfileCompliance {
164164
SmallVector<StringRef>
165165
stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
166166

167+
static llvm::SmallString<7> stringifyTypeInfo(const TypeInfo &typeInfo);
168+
167169
private:
168170
template <typename T>
169171
FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,9 +485,52 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
485485
CheckCondition condition = CheckCondition::invalid;
486486
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
487487
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
488+
488489
if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
489-
!maybeProfDef.value().size() && !maybeExtDef.value().size())
490+
!maybeProfDef.value().size() && !maybeExtDef.value().size()) {
491+
std::string message;
492+
llvm::raw_string_ostream os(message);
493+
os << "illegal: operation operand/result data types did not align with any "
494+
"profile or extension, got (";
495+
496+
ProfileInfoDepot depot(op);
497+
SmallVector<TypeInfo> current = depot.getInfo();
498+
for (const auto &typeInfo : llvm::drop_end(current))
499+
os << stringifyTypeInfo(typeInfo) << ",";
500+
os << stringifyTypeInfo(current.back()) << ")";
501+
502+
// avoid polluting the error message output by outputting only
503+
// the best match
504+
const std::string opName = op->getName().getStringRef().str();
505+
int maxMatches = -1;
506+
SmallVector<TypeInfo> bestTypeInfo;
507+
const auto searchBestMatch = [&](auto map) {
508+
for (const auto &complianceInfos : map[opName]) {
509+
for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
510+
const int matches = llvm::count_if(
511+
llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
512+
return isSameTypeInfo(std::get<0>(zipType),
513+
std::get<1>(zipType));
514+
});
515+
if (matches > maxMatches) {
516+
maxMatches = matches;
517+
bestTypeInfo = typeInfos;
518+
}
519+
}
520+
}
521+
};
522+
searchBestMatch(getProfileComplianceMap<Profile>());
523+
searchBestMatch(getProfileComplianceMap<Extension>());
524+
525+
os << ", did you mean (";
526+
for (const auto &typeInfo : llvm::drop_end(bestTypeInfo))
527+
os << stringifyTypeInfo(typeInfo) << ",";
528+
os << stringifyTypeInfo(bestTypeInfo.back()) << ")? ";
529+
os << "Otherwise, please refer to the 'supported data types' for '"
530+
<< opName << "' in the specification.";
531+
op->emitOpError(message);
490532
return failure();
533+
}
491534

492535
return success();
493536
}
@@ -562,3 +605,21 @@ SmallVector<StringRef> TosaProfileCompliance::stringifyProfile(
562605

563606
return debugStrings;
564607
}
608+
609+
llvm::SmallString<7>
610+
TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
611+
if (typeInfo.typeID == mlir::IntegerType::getTypeID()) {
612+
return {"i" + llvm::utostr(typeInfo.bitWidth)};
613+
} else if (typeInfo.typeID == mlir::Float16Type::getTypeID()) {
614+
return {"f16"};
615+
} else if (typeInfo.typeID == mlir::Float32Type::getTypeID()) {
616+
return {"f32"};
617+
} else if (typeInfo.typeID == mlir::BFloat16Type::getTypeID()) {
618+
return {"bf16"};
619+
} else if (typeInfo.typeID == mlir::Float8E4M3FNType::getTypeID()) {
620+
return {"fp8e4m3"};
621+
} else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
622+
return {"fp8e5m2"};
623+
}
624+
llvm_unreachable("unknown type");
625+
}

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,10 +1248,8 @@ void TosaValidation::runOnOperation() {
12481248
return signalPassFailure();
12491249

12501250
if (!allowInvalidOpDatatypeCombinations &&
1251-
failed(profileComp.checkInvalid(op))) {
1252-
op->emitOpError("illegal: operand/result data types not supported");
1251+
failed(profileComp.checkInvalid(op)))
12531252
return signalPassFailure();
1254-
}
12551253

12561254
// Some uses of TOSA rely on the constant operands of particular
12571255
// operations.

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func.func @test_conv2d(%arg0: tensor<*xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2:
3535

3636
func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
3737
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
38-
// expected-error@+1 {{'tosa.conv2d' op illegal: operand/result data types not supported}}
38+
// expected-error@+1 {{'tosa.conv2d' op illegal: operation operand/result data types did not align with any profile or extension, got (i8,i8,i8,i8,i8,i32,i8), did you mean (i8,i8,i32,i8,i8,i32,i32)?}}
3939
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
4040
: (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
4141
return %0 : tensor<1x27x27x16xi8>
@@ -1888,7 +1888,7 @@ func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {
18881888

18891889
// CHECK-LABEL: test_add_i1
18901890
func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
1891-
// expected-error@+1 {{'tosa.add' op illegal: operand/result data types not supported}}
1891+
// expected-error@+1 {{'tosa.add' op illegal: operation operand/result data types did not align with any profile or extension, got (i1,i1,i1), did you mean (i32,i32,i32)? Otherwise, please refer to the 'supported data types' for 'tosa.add' in the specification.}}
18921892
%0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
18931893
return %0 : tensor<13x21x3xi1>
18941894
}
@@ -1897,7 +1897,7 @@ func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) ->
18971897

18981898
// CHECK-LABEL: test_mul_out_i16
18991899
func.func @test_mul_out_i16(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
1900-
// expected-error@+1 {{'tosa.mul' op illegal: operand/result data types not supported}}
1900+
// expected-error@+1 {{'tosa.mul' op illegal: operation operand/result data types did not align with any profile or extension, got (i8,i8,i16), did you mean (i8,i8,i32)?}}
19011901
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
19021902
return %0 : tensor<13x21x3xi16>
19031903
}

0 commit comments

Comments
 (0)