From 093da29f4a33c9baca81c03cb7482d34a5283f69 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 14 Nov 2024 06:55:21 +0100 Subject: [PATCH 1/4] [mlir][Parser][NFC] Make `parseFloatFromIntegerLiteral` a standalone function --- mlir/lib/AsmParser/AsmParserImpl.h | 13 +++--- mlir/lib/AsmParser/AttributeParser.cpp | 24 +++++----- mlir/lib/AsmParser/Parser.cpp | 63 +++++++++++++------------- mlir/lib/AsmParser/Parser.h | 12 ++--- 4 files changed, 57 insertions(+), 55 deletions(-) diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index 04250f63dcd25..1e6cbc0ec51be 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -287,13 +287,13 @@ class AsmParserImpl : public BaseT { APFloat &result) override { bool isNegative = parser.consumeIf(Token::minus); Token curTok = parser.getToken(); - SMLoc loc = curTok.getLoc(); + auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); }; // Check for a floating point value. if (curTok.is(Token::floatliteral)) { auto val = curTok.getFloatingPointValue(); if (!val) - return emitError(loc, "floating point value too large"); + return emitErrorAtTok() << "floating point value too large"; parser.consumeToken(Token::floatliteral); result = APFloat(isNegative ? -*val : *val); bool losesInfo; @@ -303,10 +303,9 @@ class AsmParserImpl : public BaseT { // Check for a hexadecimal float value. if (curTok.is(Token::integer)) { - std::optional apResult; - if (failed(parser.parseFloatFromIntegerLiteral( - apResult, curTok, isNegative, semantics, - APFloat::semanticsSizeInBits(semantics)))) + FailureOr apResult = parseFloatFromIntegerLiteral( + emitErrorAtTok, curTok, isNegative, semantics); + if (failed(apResult)) return failure(); result = *apResult; @@ -314,7 +313,7 @@ class AsmParserImpl : public BaseT { return success(); } - return emitError(loc, "expected floating point literal"); + return emitErrorAtTok() << "expected floating point literal"; } /// Parse a floating point value from the stream. diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index efa65e49abc33..ba9be3b030453 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -422,10 +422,10 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { } if (auto floatType = dyn_cast(type)) { - std::optional result; - if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative, - floatType.getFloatSemantics(), - floatType.getWidth()))) + auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); }; + FailureOr result = parseFloatFromIntegerLiteral( + emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics()); + if (failed(result)) return Attribute(); return FloatAttr::get(floatType, *result); } @@ -661,10 +661,10 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, // Handle hexadecimal float literals. if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) { - std::optional result; - if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative, - eltTy.getFloatSemantics(), - eltTy.getWidth()))) + auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; + FailureOr result = parseFloatFromIntegerLiteral( + emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics()); + if (failed(result)) return failure(); floatValues.push_back(*result); @@ -911,10 +911,12 @@ ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { auto floatType = cast(type); if (p.consumeIf(Token::integer)) { // Parse an integer literal as a float. - if (p.parseFloatFromIntegerLiteral(result, token, isNegative, - floatType.getFloatSemantics(), - floatType.getWidth())) + auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; + FailureOr fromIntLit = parseFloatFromIntegerLiteral( + emitErrorAtTok, token, isNegative, floatType.getFloatSemantics()); + if (failed(fromIntLit)) return failure(); + result = *fromIntLit; } else if (p.consumeIf(Token::floatliteral)) { // Parse a floating point literal. std::optional val = token.getFloatingPointValue(); diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 8f19487d80fa3..ac7eec931b125 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -67,6 +67,38 @@ using namespace mlir; using namespace mlir::detail; +/// Parse a floating point value from an integer literal token. +FailureOr detail::parseFloatFromIntegerLiteral( + function_ref emitError, const Token &tok, + bool isNegative, const llvm::fltSemantics &semantics) { + StringRef spelling = tok.getSpelling(); + bool isHex = spelling.size() > 1 && spelling[1] == 'x'; + if (!isHex) { + auto error = emitError(); + error << "unexpected decimal integer literal for a " + "floating point value"; + error.attachNote() << "add a trailing dot to make the literal a float"; + return failure(); + } + if (isNegative) { + emitError() << "hexadecimal float literal should not have a " + "leading minus"; + return failure(); + } + + APInt intValue; + tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue); + auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics); + if (intValue.getActiveBits() > typeSizeInBits) { + return emitError() << "hexadecimal float constant out of range for type"; + return failure(); + } + + APInt truncatedValue(typeSizeInBits, intValue.getNumWords(), + intValue.getRawData()); + return APFloat(semantics, truncatedValue); +} + //===----------------------------------------------------------------------===// // CodeComplete //===----------------------------------------------------------------------===// @@ -347,37 +379,6 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) { return success(); } -/// Parse a floating point value from an integer literal token. -ParseResult Parser::parseFloatFromIntegerLiteral( - std::optional &result, const Token &tok, bool isNegative, - const llvm::fltSemantics &semantics, size_t typeSizeInBits) { - SMLoc loc = tok.getLoc(); - StringRef spelling = tok.getSpelling(); - bool isHex = spelling.size() > 1 && spelling[1] == 'x'; - if (!isHex) { - return emitError(loc, "unexpected decimal integer literal for a " - "floating point value") - .attachNote() - << "add a trailing dot to make the literal a float"; - } - if (isNegative) { - return emitError(loc, "hexadecimal float literal should not have a " - "leading minus"); - } - - APInt intValue; - tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue); - if (intValue.getActiveBits() > typeSizeInBits) - return emitError(loc, "hexadecimal float constant out of range for type"); - - APInt truncatedValue(typeSizeInBits, intValue.getNumWords(), - intValue.getRawData()); - - result.emplace(semantics, truncatedValue); - - return success(); -} - ParseResult Parser::parseOptionalKeyword(StringRef *keyword) { // Check that the current token is a keyword. if (!isCurrentTokenAKeyword()) diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index bf91831798056..fa29264ffe506 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -16,6 +16,12 @@ namespace mlir { namespace detail { +/// Parse a floating point value from an integer literal token. +FailureOr +parseFloatFromIntegerLiteral(function_ref emitError, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics); + //===----------------------------------------------------------------------===// // Parser //===----------------------------------------------------------------------===// @@ -151,12 +157,6 @@ class Parser { /// Parse an optional integer value only in decimal format from the stream. OptionalParseResult parseOptionalDecimalInteger(APInt &result); - /// Parse a floating point value from an integer literal token. - ParseResult parseFloatFromIntegerLiteral(std::optional &result, - const Token &tok, bool isNegative, - const llvm::fltSemantics &semantics, - size_t typeSizeInBits); - /// Returns true if the current token corresponds to a keyword. bool isCurrentTokenAKeyword() const { return getToken().isAny(Token::bare_identifier, Token::inttype) || From 94bcf6f9dd075c314dfc6d2bb5dbde00366f52e0 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 14 Nov 2024 07:43:08 +0100 Subject: [PATCH 2/4] [mlir][Parser] Deduplicate fp parsing functionality --- mlir/lib/AsmParser/AsmParserImpl.h | 33 ++------- mlir/lib/AsmParser/AttributeParser.cpp | 71 ++++---------------- mlir/lib/AsmParser/Parser.cpp | 23 +++++++ mlir/lib/AsmParser/Parser.h | 6 ++ mlir/test/IR/invalid-builtin-attributes.mlir | 10 +-- 5 files changed, 56 insertions(+), 87 deletions(-) diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index 1e6cbc0ec51be..bbd70d5980f8f 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -288,32 +288,13 @@ class AsmParserImpl : public BaseT { bool isNegative = parser.consumeIf(Token::minus); Token curTok = parser.getToken(); auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); }; - - // Check for a floating point value. - if (curTok.is(Token::floatliteral)) { - auto val = curTok.getFloatingPointValue(); - if (!val) - return emitErrorAtTok() << "floating point value too large"; - parser.consumeToken(Token::floatliteral); - result = APFloat(isNegative ? -*val : *val); - bool losesInfo; - result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo); - return success(); - } - - // Check for a hexadecimal float value. - if (curTok.is(Token::integer)) { - FailureOr apResult = parseFloatFromIntegerLiteral( - emitErrorAtTok, curTok, isNegative, semantics); - if (failed(apResult)) - return failure(); - - result = *apResult; - parser.consumeToken(Token::integer); - return success(); - } - - return emitErrorAtTok() << "expected floating point literal"; + FailureOr apResult = + parseFloatFromLiteral(emitErrorAtTok, curTok, isNegative, semantics); + if (failed(apResult)) + return failure(); + parser.consumeToken(); + result = *apResult; + return success(); } /// Parse a floating point value from the stream. diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index ba9be3b030453..9ebada076cd04 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -658,36 +658,12 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, for (const auto &signAndToken : storage) { bool isNegative = signAndToken.first; const Token &token = signAndToken.second; - - // Handle hexadecimal float literals. - if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) { - auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; - FailureOr result = parseFloatFromIntegerLiteral( - emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics()); - if (failed(result)) - return failure(); - - floatValues.push_back(*result); - continue; - } - - // Check to see if any decimal integers or booleans were parsed. - if (!token.is(Token::floatliteral)) - return p.emitError() - << "expected floating-point elements, but parsed integer"; - - // Build the float values from tokens. - auto val = token.getFloatingPointValue(); - if (!val) - return p.emitError("floating point value too large for attribute"); - - APFloat apVal(isNegative ? -*val : *val); - if (!eltTy.isF64()) { - bool unused; - apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, - &unused); - } - floatValues.push_back(apVal); + auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; + FailureOr result = parseFloatFromLiteral( + emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics()); + if (failed(result)) + return failure(); + floatValues.push_back(*result); } return success(); } @@ -905,34 +881,15 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) { ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { bool isNegative = p.consumeIf(Token::minus); - Token token = p.getToken(); - std::optional result; - auto floatType = cast(type); - if (p.consumeIf(Token::integer)) { - // Parse an integer literal as a float. - auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; - FailureOr fromIntLit = parseFloatFromIntegerLiteral( - emitErrorAtTok, token, isNegative, floatType.getFloatSemantics()); - if (failed(fromIntLit)) - return failure(); - result = *fromIntLit; - } else if (p.consumeIf(Token::floatliteral)) { - // Parse a floating point literal. - std::optional val = token.getFloatingPointValue(); - if (!val) - return failure(); - result = APFloat(isNegative ? -*val : *val); - if (!type.isF64()) { - bool unused; - result->convert(floatType.getFloatSemantics(), - APFloat::rmNearestTiesToEven, &unused); - } - } else { - return p.emitError("expected integer or floating point literal"); - } - - append(result->bitcastToAPInt()); + auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; + FailureOr fromIntLit = + parseFloatFromLiteral(emitErrorAtTok, token, isNegative, + cast(type).getFloatSemantics()); + if (failed(fromIntLit)) + return failure(); + p.consumeToken(); + append(fromIntLit->bitcastToAPInt()); return success(); } diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index ac7eec931b125..15f3dd7a66c35 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -99,6 +99,29 @@ FailureOr detail::parseFloatFromIntegerLiteral( return APFloat(semantics, truncatedValue); } +FailureOr +detail::parseFloatFromLiteral(function_ref emitError, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics) { + // Check for a floating point value. + if (tok.is(Token::floatliteral)) { + auto val = tok.getFloatingPointValue(); + if (!val) + return emitError() << "floating point value too large"; + + APFloat result(isNegative ? -*val : *val); + bool unused; + result.convert(semantics, APFloat::rmNearestTiesToEven, &unused); + return result; + } + + // Check for a hexadecimal float value. + if (tok.is(Token::integer)) + return parseFloatFromIntegerLiteral(emitError, tok, isNegative, semantics); + + return emitError() << "expected floating point literal"; +} + //===----------------------------------------------------------------------===// // CodeComplete //===----------------------------------------------------------------------===// diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index fa29264ffe506..ab445476a9192 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -22,6 +22,12 @@ parseFloatFromIntegerLiteral(function_ref emitError, const Token &tok, bool isNegative, const llvm::fltSemantics &semantics); +/// Parse a floating point value from a literal. +FailureOr +parseFloatFromLiteral(function_ref emitError, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics); + //===----------------------------------------------------------------------===// // Parser //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir index 431c7b12b8f5f..5098fe751fd01 100644 --- a/mlir/test/IR/invalid-builtin-attributes.mlir +++ b/mlir/test/IR/invalid-builtin-attributes.mlir @@ -45,7 +45,8 @@ func.func @elementsattr_floattype1() -> () { // ----- func.func @elementsattr_floattype2() -> () { - // expected-error@+1 {{expected floating-point elements, but parsed integer}} + // expected-error@below {{unexpected decimal integer literal for a floating point value}} + // expected-note@below {{add a trailing dot to make the literal a float}} "foo"(){bar = dense<[4]> : tensor<1xf32>} : () -> () } @@ -138,21 +139,22 @@ func.func @float_in_int_tensor() { // ----- func.func @float_in_bool_tensor() { - // expected-error @+1 {{expected integer elements, but parsed floating-point}} + // expected-error@below {{expected integer elements, but parsed floating-point}} "foo"() {bar = dense<[true, 42.0]> : tensor<2xi1>} : () -> () } // ----- func.func @decimal_int_in_float_tensor() { - // expected-error @+1 {{expected floating-point elements, but parsed integer}} + // expected-error@below {{unexpected decimal integer literal for a floating point value}} + // expected-note@below {{add a trailing dot to make the literal a float}} "foo"() {bar = dense<[42, 42.0]> : tensor<2xf32>} : () -> () } // ----- func.func @bool_in_float_tensor() { - // expected-error @+1 {{expected floating-point elements, but parsed integer}} + // expected-error @+1 {{expected floating point literal}} "foo"() {bar = dense<[42.0, true]> : tensor<2xf32>} : () -> () } From 6a59007385b335a54739d4b060103384b6c54d20 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 18 Nov 2024 09:06:22 +0100 Subject: [PATCH 3/4] address comments --- mlir/lib/AsmParser/AsmParserImpl.h | 3 +- mlir/lib/AsmParser/AttributeParser.cpp | 14 ++-- mlir/lib/AsmParser/Parser.cpp | 110 ++++++++++++------------- mlir/lib/AsmParser/Parser.h | 20 ++--- 4 files changed, 70 insertions(+), 77 deletions(-) diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index bbd70d5980f8f..d9d49d53a407d 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -287,9 +287,8 @@ class AsmParserImpl : public BaseT { APFloat &result) override { bool isNegative = parser.consumeIf(Token::minus); Token curTok = parser.getToken(); - auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); }; FailureOr apResult = - parseFloatFromLiteral(emitErrorAtTok, curTok, isNegative, semantics); + parser.parseFloatFromLiteral(curTok, isNegative, semantics); if (failed(apResult)) return failure(); parser.consumeToken(); diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index 9ebada076cd04..0df3d492f411e 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -422,9 +422,8 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { } if (auto floatType = dyn_cast(type)) { - auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); }; FailureOr result = parseFloatFromIntegerLiteral( - emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics()); + tok, isNegative, floatType.getFloatSemantics()); if (failed(result)) return Attribute(); return FloatAttr::get(floatType, *result); @@ -658,9 +657,8 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, for (const auto &signAndToken : storage) { bool isNegative = signAndToken.first; const Token &token = signAndToken.second; - auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; - FailureOr result = parseFloatFromLiteral( - emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics()); + FailureOr result = + p.parseFloatFromLiteral(token, isNegative, eltTy.getFloatSemantics()); if (failed(result)) return failure(); floatValues.push_back(*result); @@ -882,10 +880,8 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) { ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { bool isNegative = p.consumeIf(Token::minus); Token token = p.getToken(); - auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; - FailureOr fromIntLit = - parseFloatFromLiteral(emitErrorAtTok, token, isNegative, - cast(type).getFloatSemantics()); + FailureOr fromIntLit = p.parseFloatFromLiteral( + token, isNegative, cast(type).getFloatSemantics()); if (failed(fromIntLit)) return failure(); p.consumeToken(); diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 15f3dd7a66c35..e10b87ad43b13 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -67,61 +67,6 @@ using namespace mlir; using namespace mlir::detail; -/// Parse a floating point value from an integer literal token. -FailureOr detail::parseFloatFromIntegerLiteral( - function_ref emitError, const Token &tok, - bool isNegative, const llvm::fltSemantics &semantics) { - StringRef spelling = tok.getSpelling(); - bool isHex = spelling.size() > 1 && spelling[1] == 'x'; - if (!isHex) { - auto error = emitError(); - error << "unexpected decimal integer literal for a " - "floating point value"; - error.attachNote() << "add a trailing dot to make the literal a float"; - return failure(); - } - if (isNegative) { - emitError() << "hexadecimal float literal should not have a " - "leading minus"; - return failure(); - } - - APInt intValue; - tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue); - auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics); - if (intValue.getActiveBits() > typeSizeInBits) { - return emitError() << "hexadecimal float constant out of range for type"; - return failure(); - } - - APInt truncatedValue(typeSizeInBits, intValue.getNumWords(), - intValue.getRawData()); - return APFloat(semantics, truncatedValue); -} - -FailureOr -detail::parseFloatFromLiteral(function_ref emitError, - const Token &tok, bool isNegative, - const llvm::fltSemantics &semantics) { - // Check for a floating point value. - if (tok.is(Token::floatliteral)) { - auto val = tok.getFloatingPointValue(); - if (!val) - return emitError() << "floating point value too large"; - - APFloat result(isNegative ? -*val : *val); - bool unused; - result.convert(semantics, APFloat::rmNearestTiesToEven, &unused); - return result; - } - - // Check for a hexadecimal float value. - if (tok.is(Token::integer)) - return parseFloatFromIntegerLiteral(emitError, tok, isNegative, semantics); - - return emitError() << "expected floating point literal"; -} - //===----------------------------------------------------------------------===// // CodeComplete //===----------------------------------------------------------------------===// @@ -402,6 +347,61 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) { return success(); } +FailureOr +Parser::parseFloatFromLiteral(const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics) { + // Check for a floating point value. + if (tok.is(Token::floatliteral)) { + auto val = tok.getFloatingPointValue(); + if (!val) + return emitError(tok.getLoc()) << "floating point value too large"; + + APFloat result(isNegative ? -*val : *val); + bool unused; + result.convert(semantics, APFloat::rmNearestTiesToEven, &unused); + return result; + } + + // Check for a hexadecimal float value. + if (tok.is(Token::integer)) + return parseFloatFromIntegerLiteral(tok, isNegative, semantics); + + return emitError(tok.getLoc()) << "expected floating point literal"; +} + +/// Parse a floating point value from an integer literal token. +FailureOr +Parser::parseFloatFromIntegerLiteral(const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics) { + StringRef spelling = tok.getSpelling(); + bool isHex = spelling.size() > 1 && spelling[1] == 'x'; + if (!isHex) { + auto error = emitError(tok.getLoc()); + error << "unexpected decimal integer literal for a " + "floating point value"; + error.attachNote() << "add a trailing dot to make the literal a float"; + return failure(); + } + if (isNegative) { + emitError(tok.getLoc()) << "hexadecimal float literal should not have a " + "leading minus"; + return failure(); + } + + APInt intValue; + tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue); + auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics); + if (intValue.getActiveBits() > typeSizeInBits) { + return emitError(tok.getLoc()) + << "hexadecimal float constant out of range for type"; + return failure(); + } + + APInt truncatedValue(typeSizeInBits, intValue.getNumWords(), + intValue.getRawData()); + return APFloat(semantics, truncatedValue); +} + ParseResult Parser::parseOptionalKeyword(StringRef *keyword) { // Check that the current token is a keyword. if (!isCurrentTokenAKeyword()) diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index ab445476a9192..15c4990de5834 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -16,17 +16,6 @@ namespace mlir { namespace detail { -/// Parse a floating point value from an integer literal token. -FailureOr -parseFloatFromIntegerLiteral(function_ref emitError, - const Token &tok, bool isNegative, - const llvm::fltSemantics &semantics); - -/// Parse a floating point value from a literal. -FailureOr -parseFloatFromLiteral(function_ref emitError, - const Token &tok, bool isNegative, - const llvm::fltSemantics &semantics); //===----------------------------------------------------------------------===// // Parser @@ -163,6 +152,15 @@ class Parser { /// Parse an optional integer value only in decimal format from the stream. OptionalParseResult parseOptionalDecimalInteger(APInt &result); + /// Parse a floating point value from a literal. + FailureOr parseFloatFromLiteral(const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics); + + /// Parse a floating point value from an integer literal token. + FailureOr + parseFloatFromIntegerLiteral(const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics); + /// Returns true if the current token corresponds to a keyword. bool isCurrentTokenAKeyword() const { return getToken().isAny(Token::bare_identifier, Token::inttype) || From 1f81180b91597f0813d46cb2f5b422d7678a2a3c Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 18 Nov 2024 09:16:19 +0100 Subject: [PATCH 4/4] address comments 2 --- mlir/lib/AsmParser/AsmParserImpl.h | 6 ++-- mlir/lib/AsmParser/AttributeParser.cpp | 19 ++++++------ mlir/lib/AsmParser/Parser.cpp | 42 +++++++++++++------------- mlir/lib/AsmParser/Parser.h | 11 ++++--- 4 files changed, 40 insertions(+), 38 deletions(-) diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index d9d49d53a407d..d5b72d63813a4 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -287,9 +287,9 @@ class AsmParserImpl : public BaseT { APFloat &result) override { bool isNegative = parser.consumeIf(Token::minus); Token curTok = parser.getToken(); - FailureOr apResult = - parser.parseFloatFromLiteral(curTok, isNegative, semantics); - if (failed(apResult)) + std::optional apResult; + if (failed(parser.parseFloatFromLiteral(apResult, curTok, isNegative, + semantics))) return failure(); parser.consumeToken(); result = *apResult; diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index 0df3d492f411e..ff616dac9625b 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -422,9 +422,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { } if (auto floatType = dyn_cast(type)) { - FailureOr result = parseFloatFromIntegerLiteral( - tok, isNegative, floatType.getFloatSemantics()); - if (failed(result)) + std::optional result; + if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative, + floatType.getFloatSemantics()))) return Attribute(); return FloatAttr::get(floatType, *result); } @@ -657,9 +657,9 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, for (const auto &signAndToken : storage) { bool isNegative = signAndToken.first; const Token &token = signAndToken.second; - FailureOr result = - p.parseFloatFromLiteral(token, isNegative, eltTy.getFloatSemantics()); - if (failed(result)) + std::optional result; + if (failed(p.parseFloatFromLiteral(result, token, isNegative, + eltTy.getFloatSemantics()))) return failure(); floatValues.push_back(*result); } @@ -880,9 +880,10 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) { ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { bool isNegative = p.consumeIf(Token::minus); Token token = p.getToken(); - FailureOr fromIntLit = p.parseFloatFromLiteral( - token, isNegative, cast(type).getFloatSemantics()); - if (failed(fromIntLit)) + std::optional fromIntLit; + if (failed( + p.parseFloatFromLiteral(fromIntLit, token, isNegative, + cast(type).getFloatSemantics()))) return failure(); p.consumeToken(); append(fromIntLit->bitcastToAPInt()); diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index e10b87ad43b13..e3db248164672 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -347,59 +347,59 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) { return success(); } -FailureOr -Parser::parseFloatFromLiteral(const Token &tok, bool isNegative, - const llvm::fltSemantics &semantics) { +ParseResult Parser::parseFloatFromLiteral(std::optional &result, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics) { // Check for a floating point value. if (tok.is(Token::floatliteral)) { auto val = tok.getFloatingPointValue(); if (!val) return emitError(tok.getLoc()) << "floating point value too large"; - APFloat result(isNegative ? -*val : *val); + result.emplace(isNegative ? -*val : *val); bool unused; - result.convert(semantics, APFloat::rmNearestTiesToEven, &unused); - return result; + result->convert(semantics, APFloat::rmNearestTiesToEven, &unused); + return success(); } // Check for a hexadecimal float value. if (tok.is(Token::integer)) - return parseFloatFromIntegerLiteral(tok, isNegative, semantics); + return parseFloatFromIntegerLiteral(result, tok, isNegative, semantics); return emitError(tok.getLoc()) << "expected floating point literal"; } /// Parse a floating point value from an integer literal token. -FailureOr -Parser::parseFloatFromIntegerLiteral(const Token &tok, bool isNegative, +ParseResult +Parser::parseFloatFromIntegerLiteral(std::optional &result, + const Token &tok, bool isNegative, const llvm::fltSemantics &semantics) { StringRef spelling = tok.getSpelling(); bool isHex = spelling.size() > 1 && spelling[1] == 'x'; if (!isHex) { - auto error = emitError(tok.getLoc()); - error << "unexpected decimal integer literal for a " - "floating point value"; - error.attachNote() << "add a trailing dot to make the literal a float"; - return failure(); + return emitError(tok.getLoc(), "unexpected decimal integer literal for a " + "floating point value") + .attachNote() + << "add a trailing dot to make the literal a float"; } if (isNegative) { - emitError(tok.getLoc()) << "hexadecimal float literal should not have a " - "leading minus"; - return failure(); + return emitError(tok.getLoc(), + "hexadecimal float literal should not have a " + "leading minus"); } APInt intValue; tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue); auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics); if (intValue.getActiveBits() > typeSizeInBits) { - return emitError(tok.getLoc()) - << "hexadecimal float constant out of range for type"; - return failure(); + return emitError(tok.getLoc(), + "hexadecimal float constant out of range for type"); } APInt truncatedValue(typeSizeInBits, intValue.getNumWords(), intValue.getRawData()); - return APFloat(semantics, truncatedValue); + result.emplace(semantics, truncatedValue); + return success(); } ParseResult Parser::parseOptionalKeyword(StringRef *keyword) { diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index 15c4990de5834..4979cfc6e69e4 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -153,13 +153,14 @@ class Parser { OptionalParseResult parseOptionalDecimalInteger(APInt &result); /// Parse a floating point value from a literal. - FailureOr parseFloatFromLiteral(const Token &tok, bool isNegative, - const llvm::fltSemantics &semantics); + ParseResult parseFloatFromLiteral(std::optional &result, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics); /// Parse a floating point value from an integer literal token. - FailureOr - parseFloatFromIntegerLiteral(const Token &tok, bool isNegative, - const llvm::fltSemantics &semantics); + ParseResult parseFloatFromIntegerLiteral(std::optional &result, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics); /// Returns true if the current token corresponds to a keyword. bool isCurrentTokenAKeyword() const {