Skip to content

[MLIR] Testing arith-to-emitc conversions using opaque types #137936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ bool isIntegerIndexOrOpaqueType(Type type);
/// Determines whether \p type is a valid floating-point type in EmitC.
bool isSupportedFloatType(mlir::Type type);

/// Determines whether \p type is a valid floating-point or opaque type in
/// EmitC.
bool isFloatOrOpaqueType(mlir::Type type);

/// Determines whether \p type is a valid integer or opaque type in
/// EmitC.
bool isIntegerOrOpaqueType(mlir::Type type);

/// Determines whether \p type is a emitc.size_t/ssize_t type.
bool isPointerWideType(mlir::Type type);

Expand Down
84 changes: 54 additions & 30 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {

/// Insert a cast operation to type \p ty if \p val does not have this type.
Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
assert(emitc::isSupportedEmitCType(val.getType()));
return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
}

Expand Down Expand Up @@ -273,7 +274,8 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = adaptor.getLhs().getType();
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
if (!type || !(emitc::isIntegerOrOpaqueType(type) ||
emitc::isPointerWideType(type))) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
Expand Down Expand Up @@ -328,7 +330,7 @@ class CastConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type opReturnType = this->getTypeConverter()->convertType(op.getType());
if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
if (!opReturnType || !(emitc::isIntegerOrOpaqueType(opReturnType) ||
emitc::isPointerWideType(opReturnType)))
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
Expand All @@ -339,7 +341,7 @@ class CastConversion : public OpConversionPattern<ArithOp> {
}

Type operandType = adaptor.getIn().getType();
if (!operandType || !(isa<IntegerType>(operandType) ||
if (!operandType || !(emitc::isIntegerOrOpaqueType(operandType) ||
emitc::isPointerWideType(operandType)))
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
Expand Down Expand Up @@ -433,16 +435,17 @@ class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
if (!newRetTy)
return rewriter.notifyMatchFailure(uiBinOp,
"converting result type failed");
if (!isa<IntegerType>(newRetTy)) {

if (!emitc::isIntegerOrOpaqueType(newRetTy)) {
return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
}
Type unsignedType =
adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true);
if (!unsignedType)
return rewriter.notifyMatchFailure(uiBinOp,
"converting result type failed");
Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
Value lhsAdapted = adaptValueType(adaptor.getLhs(), rewriter, unsignedType);
Value rhsAdapted = adaptValueType(adaptor.getRhs(), rewriter, unsignedType);

auto newDivOp =
rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
Expand All @@ -463,7 +466,8 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
if (!type || !(emitc::isIntegerOrOpaqueType(type) ||
emitc::isPointerWideType(type))) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
Expand Down Expand Up @@ -506,7 +510,7 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType>(type)) {
if (!type || !emitc::isIntegerOrOpaqueType(type)) {
return rewriter.notifyMatchFailure(
op,
"expected integer type, vector/tensor support not yet implemented");
Expand Down Expand Up @@ -546,7 +550,9 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
bool retIsOpaque = isa_and_nonnull<emitc::OpaqueType>(type);
if (!type || (!retIsOpaque && !(isa<IntegerType>(type) ||
emitc::isPointerWideType(type)))) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
Expand All @@ -572,21 +578,33 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
sizeOfCall.getResult(0));
} else {
} else if (!retIsOpaque) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could invert the condition and swap the branches.

width = rewriter.create<emitc::ConstantOp>(
op.getLoc(), rhsType,
rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
} else {
width = rewriter.create<emitc::ConstantOp>(
op.getLoc(), rhsType,
emitc::OpaqueAttr::get(rhsType.getContext(),
"opaque_shift_bitwidth"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does opaque_shift_bitwidth come from?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If opaque types are used, the bitwidth, which is needed for the shiftOp, can't be determined. So the opaque attribute serves as a reference point for where to enter the bitwidth of the type later on.

}

Value excessCheck = rewriter.create<emitc::CmpOp>(
op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);

// Any concrete value is a valid refinement of poison.
Value poison = rewriter.create<emitc::ConstantOp>(
op.getLoc(), arithmeticType,
(isa<IntegerType>(arithmeticType)
? rewriter.getIntegerAttr(arithmeticType, 0)
: rewriter.getIndexAttr(0)));
Value poison;
if (retIsOpaque) {
poison = rewriter.create<emitc::ConstantOp>(
op.getLoc(), arithmeticType,
emitc::OpaqueAttr::get(rhsType.getContext(), "opaque_shift_poison"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, where is this defined and why is it needed?

} else {
poison = rewriter.create<emitc::ConstantOp>(
op.getLoc(), arithmeticType,
(isa<IntegerType>(arithmeticType)
? rewriter.getIntegerAttr(arithmeticType, 0)
: rewriter.getIndexAttr(0)));
}

emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
op.getLoc(), arithmeticType, /*do_not_inline=*/false);
Expand Down Expand Up @@ -663,19 +681,23 @@ class FtoICastOpConversion : public OpConversionPattern<CastOp> {
if (!dstType)
return rewriter.notifyMatchFailure(castOp, "type conversion failed");

Type actualResultType = dstType;

// Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
// truncated to 0, whereas a boolean conversion would return true.
if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");

// Convert to unsigned if it's the "ui" variant
// Signless is interpreted as signed, so no need to cast for "si"
Type actualResultType = dstType;
if (isa<arith::FPToUIOp>(castOp)) {
actualResultType =
rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
/*isSigned=*/false);
bool dstIsOpaque = isa<emitc::OpaqueType>(dstType);
if (!dstIsOpaque) {
if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");

// Convert to unsigned if it's the "ui" variant
// Signless is interpreted as signed, so no need to cast for "si"
if (isa<arith::FPToUIOp>(castOp)) {
actualResultType =
rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
/*isSigned=*/false);
}
}

Value result = rewriter.create<emitc::CastOp>(
Expand All @@ -702,7 +724,9 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
ConversionPatternRewriter &rewriter) const override {
// Vectors in particular are not supported
Type operandType = adaptor.getIn().getType();
if (!emitc::isSupportedIntegerType(operandType))
bool opIsOpaque = isa<emitc::OpaqueType>(operandType);

if (!(opIsOpaque || emitc::isSupportedIntegerType(operandType)))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast source type");

Expand All @@ -717,7 +741,7 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
// Convert to unsigned if it's the "ui" variant
// Signless is interpreted as signed, so no need to cast for "si"
Type actualOperandType = operandType;
if (isa<arith::UIToFPOp>(castOp)) {
if (!opIsOpaque && isa<arith::UIToFPOp>(castOp)) {
actualOperandType =
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
/*isSigned=*/false);
Expand Down Expand Up @@ -745,7 +769,7 @@ class FpCastOpConversion : public OpConversionPattern<CastOp> {
ConversionPatternRewriter &rewriter) const override {
// Vectors in particular are not supported.
Type operandType = adaptor.getIn().getType();
if (!emitc::isSupportedFloatType(operandType))
if (!emitc::isFloatOrOpaqueType(operandType))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast source type");
if (auto roundingModeOp =
Expand All @@ -759,7 +783,7 @@ class FpCastOpConversion : public OpConversionPattern<CastOp> {
if (!dstType)
return rewriter.notifyMatchFailure(castOp, "type conversion failed");

if (!emitc::isSupportedFloatType(dstType))
if (!emitc::isFloatOrOpaqueType(dstType))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");

Expand Down
35 changes: 34 additions & 1 deletion mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,42 @@ namespace {
struct ConvertArithToEmitC
: public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
void runOnOperation() override;

/// Applies conversion to opaque types for f80 and i80 types, both unsupported
/// in emitc. Used to test the pass with opaque types.
void populateOpaqueTypeConversions(TypeConverter &converter);
};
} // namespace

void ConvertArithToEmitC::populateOpaqueTypeConversions(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why these types should be unconditionally legalized, and why only for bitwidth 80?

TypeConverter &converter) {
converter.addConversion([](Type type) -> std::optional<Type> {
if (type.isF80())
return emitc::OpaqueType::get(type.getContext(), "f80");
if (type.isInteger() && type.getIntOrFloatBitWidth() == 80)
return emitc::OpaqueType::get(type.getContext(), "i80");
return type;
});

converter.addTypeAttributeConversion(
[](Type type,
Attribute attrToConvert) -> TypeConverter::AttributeConversionResult {
if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attrToConvert)) {
if (floatAttr.getType().isF80()) {
return emitc::OpaqueAttr::get(type.getContext(), "f80");
}
return {};
}
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attrToConvert)) {
if (intAttr.getType().isInteger() &&
intAttr.getType().getIntOrFloatBitWidth() == 80) {
return emitc::OpaqueAttr::get(type.getContext(), "i80");
}
}
return {};
});
}

void ConvertArithToEmitC::runOnOperation() {
ConversionTarget target(getContext());

Expand All @@ -42,8 +75,8 @@ void ConvertArithToEmitC::runOnOperation() {
RewritePatternSet patterns(&getContext());

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });

populateOpaqueTypeConversions(typeConverter);
populateArithToEmitCPatterns(typeConverter, patterns);

if (failed(
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ bool mlir::emitc::isSupportedFloatType(Type type) {
return false;
}

bool mlir::emitc::isIntegerOrOpaqueType(Type type) {
return isa<emitc::OpaqueType>(type) || isSupportedIntegerType(type);
}

bool mlir::emitc::isFloatOrOpaqueType(Type type) {
return isa<emitc::OpaqueType>(type) || isSupportedFloatType(type);
}

bool mlir::emitc::isPointerWideType(Type type) {
return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
type);
Expand Down
16 changes: 0 additions & 16 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@ func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
return %t: vector<5xi32>
}

// -----
func.func @arith_cast_f80(%arg0: f80) -> i32 {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : f80 to i32
return %t: i32
}

// -----

func.func @arith_cast_f128(%arg0: f128) -> i32 {
Expand All @@ -29,15 +22,6 @@ func.func @arith_cast_f128(%arg0: f128) -> i32 {
return %t: i32
}


// -----

func.func @arith_cast_to_f80(%arg0: i32) -> f80 {
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
%t = arith.sitofp %arg0 : i32 to f80
return %t: f80
}

// -----

func.func @arith_cast_to_f128(%arg0: i32) -> f128 {
Expand Down
71 changes: 71 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -771,3 +771,74 @@ func.func @arith_truncf(%arg0: f64) -> f16 {

return %truncd1 : f16
}

// -----

func.func @float_opaque_conversion(%arg0: f80, %arg1: f80) {
// CHECK-LABEL: float_opaque_conversion
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f80, %[[Arg1:[^ ]*]]: f80)

// CHECK-DAG: [[arg1_cast:[^ ]*]] = builtin.unrealized_conversion_cast %[[Arg1]] : f80 to !emitc.opaque<"f80">
// CHECK-DAG: [[arg0_cast:[^ ]*]] = builtin.unrealized_conversion_cast %[[Arg0]] : f80 to !emitc.opaque<"f80">
// CHECK: "emitc.constant"() <{value = #emitc.opaque<"f80">}> : () -> !emitc.opaque<"f80">
%10 = arith.constant 0.0 : f80
// CHECK: emitc.add [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"f80">, !emitc.opaque<"f80">) -> !emitc.opaque<"f80">
%2 = arith.addf %arg0, %arg1 : f80
// CHECK: [[EQ:[^ ]*]] = emitc.cmp eq, [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"f80">, !emitc.opaque<"f80">) -> i1
// CHECK: [[NotNaNArg0:[^ ]*]] = emitc.cmp eq, [[arg0_cast]], [[arg0_cast]] : (!emitc.opaque<"f80">, !emitc.opaque<"f80">) -> i1
// CHECK: [[NotNaNArg1:[^ ]*]] = emitc.cmp eq, [[arg1_cast]], [[arg1_cast]] : (!emitc.opaque<"f80">, !emitc.opaque<"f80">) -> i1
// CHECK: [[Ordered:[^ ]*]] = emitc.logical_and [[NotNaNArg0]], [[NotNaNArg1]] : i1, i1
// CHECK: emitc.logical_and [[Ordered]], [[EQ]] : i1, i1
%11 = arith.cmpf oeq, %arg0, %arg1 : f80
// CHECK: emitc.unary_minus [[arg0_cast]] : (!emitc.opaque<"f80">) -> !emitc.opaque<"f80">
%12 = arith.negf %arg0 : f80
// CHECK: [[V0:[^ ]*]] = emitc.cast [[arg0_cast]] : !emitc.opaque<"f80"> to ui32
// CHECK: [[V1:[^ ]*]] = emitc.cast [[V0]] : ui32 to i32
%7 = arith.fptoui %arg0 : f80 to i32
// CHECK: emitc.cast [[V1]] : i32 to !emitc.opaque<"f80">
%8 = arith.sitofp %7 : i32 to f80
// CHECK: [[trunc:[^ ]*]] = emitc.cast [[arg0_cast]] : !emitc.opaque<"f80"> to f32
%13 = arith.truncf %arg0 : f80 to f32
// CHECK: emitc.cast [[trunc]] : f32 to !emitc.opaque<"f80">
%15 = arith.extf %13 : f32 to f80
return
}

// -----

func.func @int_opaque_conversion(%arg0: i80, %arg1: i80, %arg2: i1) {
// CHECK-LABEL: int_opaque_conversion
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i80, %[[Arg1:[^ ]*]]: i80, %[[Arg2:[^ ]*]]: i1)

// CHECK-DAG: [[arg1_cast:[^ ]*]] = builtin.unrealized_conversion_cast %[[Arg1]] : i80 to !emitc.opaque<"i80">
// CHECK-DAG: [[arg0_cast:[^ ]*]] = builtin.unrealized_conversion_cast %[[Arg0]] : i80 to !emitc.opaque<"i80">
// CHECK: "emitc.constant"() <{value = #emitc.opaque<"i80">}> : () -> !emitc.opaque<"i80">
%10 = arith.constant 0 : i80
// CHECK: emitc.div [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> !emitc.opaque<"i80">
%3 = arith.divui %arg0, %arg1 : i80
// CHECK: emitc.add [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> !emitc.opaque<"i80">
%2 = arith.addi %arg0, %arg1 : i80
// CHECK: emitc.bitwise_and [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> !emitc.opaque<"i80">
%14 = arith.andi %arg0, %arg1 : i80
// CHECK: [[Bitwidth:[^ ]*]] = "emitc.constant"() <{value = #emitc.opaque<"opaque_shift_bitwidth">}> : () -> !emitc.opaque<"i80">
// CHECK: [[LT:[^ ]*]] = emitc.cmp lt, [[arg1_cast]], [[Bitwidth]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> i1
// CHECK: [[Poison:[^ ]*]] = "emitc.constant"() <{value = #emitc.opaque<"opaque_shift_poison">}> : () -> !emitc.opaque<"i80">
// CHECK: [[Exp:[^ ]*]] = emitc.expression : !emitc.opaque<"i80"> {
// CHECK: [[LShift:[^ ]*]] = emitc.bitwise_left_shift [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> !emitc.opaque<"i80">
// CHECK: emitc.conditional [[LT]], [[LShift]], [[Poison]] : !emitc.opaque<"i80">
// CHECK: emitc.yield {{.*}} : !emitc.opaque<"i80">
// CHECK: }
%12 = arith.shli %arg0, %arg1 : i80
// CHECK: emitc.cmp eq, [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> i1
%11 = arith.cmpi eq, %arg0, %arg1 : i80
// CHECK: emitc.conditional %[[Arg2]], [[arg0_cast]], [[arg1_cast]] : !emitc.opaque<"i80">
%13 = arith.select %arg2, %arg0, %arg1 : i80
// CHECK: [[V0:[^ ]*]] = emitc.cast [[arg0_cast]] : !emitc.opaque<"i80"> to ui8
// CHECK: emitc.cast [[V0]] : ui8 to i8
%15 = arith.trunci %arg0 : i80 to i8
// CHECK: [[V1:[^ ]*]] = emitc.cast [[arg0_cast]] : !emitc.opaque<"i80"> to f32
%9 = arith.uitofp %arg0 : i80 to f32
// CHECK: emitc.cast [[V1]] : f32 to !emitc.opaque<"i80">
%6 = arith.fptosi %9 : f32 to i80
return
}