diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index d660292478b19..59e2666ebd95e 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -24,21 +24,26 @@ class Complex_Op traits = []> // one result, all of which must be complex numbers of the same type. class ComplexArithmeticOp traits = []> : Complex_Op]> { - let arguments = (ins Complex:$lhs, Complex:$rhs, DefaultValuedAttr< - Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath); + Elementwise, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let arguments = (ins + Complex:$lhs, + Complex:$rhs, + DefaultValuedAttr:$fastmath, + DefaultValuedAttr:$overflowFlags); let results = (outs Complex:$result); - let assemblyFormat = "$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result)"; + let assemblyFormat = "$lhs `,` $rhs (`fastmath` `` $fastmath^)? (`overflow` `` $overflowFlags^)? attr-dict `:` type($result)"; } // Base class for standard unary operations on complex numbers with a // floating-point element type. These operations take one operand and return // one result; the operand must be a complex number. class ComplexUnaryOp traits = []> : - Complex_Op]> { - let arguments = (ins Complex:$complex, DefaultValuedAttr< - Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath); - let assemblyFormat = "$complex (`fastmath` `` $fastmath^)? attr-dict `:` type($complex)"; + Complex_Op, DeclareOpInterfaceMethods]> { + let arguments = (ins Complex:$complex, + DefaultValuedAttr:$fastmath, + DefaultValuedAttr:$overflowFlags); + let assemblyFormat = "$complex (`fastmath` `` $fastmath^)? (`overflow` `` $overflowFlags^)? attr-dict `:` type($complex)"; } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir index 96f17b2898c83..597bd00feb801 100644 --- a/mlir/test/Dialect/Complex/ops.mlir +++ b/mlir/test/Dialect/Complex/ops.mlir @@ -89,5 +89,11 @@ func.func @ops(%f: f32) { // CHECK: complex.bitcast %[[C]] %i64 = complex.bitcast %complex : complex to i64 + // CHECK: complex.add %[[C]], %[[C]] overflow : complex + %add_intflags = complex.add %complex, %complex overflow : complex + + // CHECK: complex.sub %[[C]], %[[C]] overflow : complex + %sub_intflags = complex.sub %complex, %complex overflow : complex + return }