diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index 04250f63dcd25..d5b72d63813a4 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -287,34 +287,13 @@ class AsmParserImpl : public BaseT { APFloat &result) override { bool isNegative = parser.consumeIf(Token::minus); Token curTok = parser.getToken(); - SMLoc loc = curTok.getLoc(); - - // Check for a floating point value. - if (curTok.is(Token::floatliteral)) { - auto val = curTok.getFloatingPointValue(); - if (!val) - return emitError(loc, "floating point value too large"); - parser.consumeToken(Token::floatliteral); - result = APFloat(isNegative ? -*val : *val); - bool losesInfo; - result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo); - return success(); - } - - // Check for a hexadecimal float value. - if (curTok.is(Token::integer)) { - std::optional apResult; - if (failed(parser.parseFloatFromIntegerLiteral( - apResult, curTok, isNegative, semantics, - APFloat::semanticsSizeInBits(semantics)))) - return failure(); - - result = *apResult; - parser.consumeToken(Token::integer); - return success(); - } - - return emitError(loc, "expected floating point literal"); + std::optional apResult; + if (failed(parser.parseFloatFromLiteral(apResult, curTok, isNegative, + semantics))) + return failure(); + parser.consumeToken(); + result = *apResult; + return success(); } /// Parse a floating point value from the stream. diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index efa65e49abc33..ff616dac9625b 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -424,8 +424,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { if (auto floatType = dyn_cast(type)) { std::optional result; if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative, - floatType.getFloatSemantics(), - floatType.getWidth()))) + floatType.getFloatSemantics()))) return Attribute(); return FloatAttr::get(floatType, *result); } @@ -658,36 +657,11 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, for (const auto &signAndToken : storage) { bool isNegative = signAndToken.first; const Token &token = signAndToken.second; - - // Handle hexadecimal float literals. - if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) { - std::optional result; - if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative, - eltTy.getFloatSemantics(), - eltTy.getWidth()))) - return failure(); - - floatValues.push_back(*result); - continue; - } - - // Check to see if any decimal integers or booleans were parsed. - if (!token.is(Token::floatliteral)) - return p.emitError() - << "expected floating-point elements, but parsed integer"; - - // Build the float values from tokens. - auto val = token.getFloatingPointValue(); - if (!val) - return p.emitError("floating point value too large for attribute"); - - APFloat apVal(isNegative ? -*val : *val); - if (!eltTy.isF64()) { - bool unused; - apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, - &unused); - } - floatValues.push_back(apVal); + std::optional result; + if (failed(p.parseFloatFromLiteral(result, token, isNegative, + eltTy.getFloatSemantics()))) + return failure(); + floatValues.push_back(*result); } return success(); } @@ -905,32 +879,14 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) { ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { bool isNegative = p.consumeIf(Token::minus); - Token token = p.getToken(); - std::optional result; - auto floatType = cast(type); - if (p.consumeIf(Token::integer)) { - // Parse an integer literal as a float. - if (p.parseFloatFromIntegerLiteral(result, token, isNegative, - floatType.getFloatSemantics(), - floatType.getWidth())) - return failure(); - } else if (p.consumeIf(Token::floatliteral)) { - // Parse a floating point literal. - std::optional val = token.getFloatingPointValue(); - if (!val) - return failure(); - result = APFloat(isNegative ? -*val : *val); - if (!type.isF64()) { - bool unused; - result->convert(floatType.getFloatSemantics(), - APFloat::rmNearestTiesToEven, &unused); - } - } else { - return p.emitError("expected integer or floating point literal"); - } - - append(result->bitcastToAPInt()); + std::optional fromIntLit; + if (failed( + p.parseFloatFromLiteral(fromIntLit, token, isNegative, + cast(type).getFloatSemantics()))) + return failure(); + p.consumeToken(); + append(fromIntLit->bitcastToAPInt()); return success(); } diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 8f19487d80fa3..e3db248164672 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -347,34 +347,58 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) { return success(); } +ParseResult Parser::parseFloatFromLiteral(std::optional &result, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics) { + // Check for a floating point value. + if (tok.is(Token::floatliteral)) { + auto val = tok.getFloatingPointValue(); + if (!val) + return emitError(tok.getLoc()) << "floating point value too large"; + + result.emplace(isNegative ? -*val : *val); + bool unused; + result->convert(semantics, APFloat::rmNearestTiesToEven, &unused); + return success(); + } + + // Check for a hexadecimal float value. + if (tok.is(Token::integer)) + return parseFloatFromIntegerLiteral(result, tok, isNegative, semantics); + + return emitError(tok.getLoc()) << "expected floating point literal"; +} + /// Parse a floating point value from an integer literal token. -ParseResult Parser::parseFloatFromIntegerLiteral( - std::optional &result, const Token &tok, bool isNegative, - const llvm::fltSemantics &semantics, size_t typeSizeInBits) { - SMLoc loc = tok.getLoc(); +ParseResult +Parser::parseFloatFromIntegerLiteral(std::optional &result, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics) { StringRef spelling = tok.getSpelling(); bool isHex = spelling.size() > 1 && spelling[1] == 'x'; if (!isHex) { - return emitError(loc, "unexpected decimal integer literal for a " - "floating point value") + return emitError(tok.getLoc(), "unexpected decimal integer literal for a " + "floating point value") .attachNote() << "add a trailing dot to make the literal a float"; } if (isNegative) { - return emitError(loc, "hexadecimal float literal should not have a " - "leading minus"); + return emitError(tok.getLoc(), + "hexadecimal float literal should not have a " + "leading minus"); } APInt intValue; tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue); - if (intValue.getActiveBits() > typeSizeInBits) - return emitError(loc, "hexadecimal float constant out of range for type"); + auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics); + if (intValue.getActiveBits() > typeSizeInBits) { + return emitError(tok.getLoc(), + "hexadecimal float constant out of range for type"); + } APInt truncatedValue(typeSizeInBits, intValue.getNumWords(), intValue.getRawData()); - result.emplace(semantics, truncatedValue); - return success(); } diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index bf91831798056..4979cfc6e69e4 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -16,6 +16,7 @@ namespace mlir { namespace detail { + //===----------------------------------------------------------------------===// // Parser //===----------------------------------------------------------------------===// @@ -151,11 +152,15 @@ class Parser { /// Parse an optional integer value only in decimal format from the stream. OptionalParseResult parseOptionalDecimalInteger(APInt &result); + /// Parse a floating point value from a literal. + ParseResult parseFloatFromLiteral(std::optional &result, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics); + /// Parse a floating point value from an integer literal token. ParseResult parseFloatFromIntegerLiteral(std::optional &result, const Token &tok, bool isNegative, - const llvm::fltSemantics &semantics, - size_t typeSizeInBits); + const llvm::fltSemantics &semantics); /// Returns true if the current token corresponds to a keyword. bool isCurrentTokenAKeyword() const { diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir index 431c7b12b8f5f..5098fe751fd01 100644 --- a/mlir/test/IR/invalid-builtin-attributes.mlir +++ b/mlir/test/IR/invalid-builtin-attributes.mlir @@ -45,7 +45,8 @@ func.func @elementsattr_floattype1() -> () { // ----- func.func @elementsattr_floattype2() -> () { - // expected-error@+1 {{expected floating-point elements, but parsed integer}} + // expected-error@below {{unexpected decimal integer literal for a floating point value}} + // expected-note@below {{add a trailing dot to make the literal a float}} "foo"(){bar = dense<[4]> : tensor<1xf32>} : () -> () } @@ -138,21 +139,22 @@ func.func @float_in_int_tensor() { // ----- func.func @float_in_bool_tensor() { - // expected-error @+1 {{expected integer elements, but parsed floating-point}} + // expected-error@below {{expected integer elements, but parsed floating-point}} "foo"() {bar = dense<[true, 42.0]> : tensor<2xi1>} : () -> () } // ----- func.func @decimal_int_in_float_tensor() { - // expected-error @+1 {{expected floating-point elements, but parsed integer}} + // expected-error@below {{unexpected decimal integer literal for a floating point value}} + // expected-note@below {{add a trailing dot to make the literal a float}} "foo"() {bar = dense<[42, 42.0]> : tensor<2xf32>} : () -> () } // ----- func.func @bool_in_float_tensor() { - // expected-error @+1 {{expected floating-point elements, but parsed integer}} + // expected-error @+1 {{expected floating point literal}} "foo"() {bar = dense<[42.0, true]> : tensor<2xf32>} : () -> () }