From 07837e3ca122d1c1fa7448797addb3839e8b7a45 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 14 Nov 2024 06:55:21 +0100 Subject: [PATCH] [mlir][Parser][NFC] Make `parseFloatFromIntegerLiteral` a standalone function --- mlir/lib/AsmParser/AsmParserImpl.h | 13 +++--- mlir/lib/AsmParser/AttributeParser.cpp | 24 +++++----- mlir/lib/AsmParser/Parser.cpp | 63 +++++++++++++------------- mlir/lib/AsmParser/Parser.h | 12 ++--- 4 files changed, 57 insertions(+), 55 deletions(-) diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index 04250f63dcd25..1e6cbc0ec51be 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -287,13 +287,13 @@ class AsmParserImpl : public BaseT { APFloat &result) override { bool isNegative = parser.consumeIf(Token::minus); Token curTok = parser.getToken(); - SMLoc loc = curTok.getLoc(); + auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); }; // Check for a floating point value. if (curTok.is(Token::floatliteral)) { auto val = curTok.getFloatingPointValue(); if (!val) - return emitError(loc, "floating point value too large"); + return emitErrorAtTok() << "floating point value too large"; parser.consumeToken(Token::floatliteral); result = APFloat(isNegative ? -*val : *val); bool losesInfo; @@ -303,10 +303,9 @@ class AsmParserImpl : public BaseT { // Check for a hexadecimal float value. if (curTok.is(Token::integer)) { - std::optional apResult; - if (failed(parser.parseFloatFromIntegerLiteral( - apResult, curTok, isNegative, semantics, - APFloat::semanticsSizeInBits(semantics)))) + FailureOr apResult = parseFloatFromIntegerLiteral( + emitErrorAtTok, curTok, isNegative, semantics); + if (failed(apResult)) return failure(); result = *apResult; @@ -314,7 +313,7 @@ class AsmParserImpl : public BaseT { return success(); } - return emitError(loc, "expected floating point literal"); + return emitErrorAtTok() << "expected floating point literal"; } /// Parse a floating point value from the stream. diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index efa65e49abc33..ba9be3b030453 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -422,10 +422,10 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { } if (auto floatType = dyn_cast(type)) { - std::optional result; - if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative, - floatType.getFloatSemantics(), - floatType.getWidth()))) + auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); }; + FailureOr result = parseFloatFromIntegerLiteral( + emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics()); + if (failed(result)) return Attribute(); return FloatAttr::get(floatType, *result); } @@ -661,10 +661,10 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, // Handle hexadecimal float literals. if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) { - std::optional result; - if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative, - eltTy.getFloatSemantics(), - eltTy.getWidth()))) + auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; + FailureOr result = parseFloatFromIntegerLiteral( + emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics()); + if (failed(result)) return failure(); floatValues.push_back(*result); @@ -911,10 +911,12 @@ ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { auto floatType = cast(type); if (p.consumeIf(Token::integer)) { // Parse an integer literal as a float. - if (p.parseFloatFromIntegerLiteral(result, token, isNegative, - floatType.getFloatSemantics(), - floatType.getWidth())) + auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; + FailureOr fromIntLit = parseFloatFromIntegerLiteral( + emitErrorAtTok, token, isNegative, floatType.getFloatSemantics()); + if (failed(fromIntLit)) return failure(); + result = *fromIntLit; } else if (p.consumeIf(Token::floatliteral)) { // Parse a floating point literal. std::optional val = token.getFloatingPointValue(); diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 8f19487d80fa3..ac7eec931b125 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -67,6 +67,38 @@ using namespace mlir; using namespace mlir::detail; +/// Parse a floating point value from an integer literal token. +FailureOr detail::parseFloatFromIntegerLiteral( + function_ref emitError, const Token &tok, + bool isNegative, const llvm::fltSemantics &semantics) { + StringRef spelling = tok.getSpelling(); + bool isHex = spelling.size() > 1 && spelling[1] == 'x'; + if (!isHex) { + auto error = emitError(); + error << "unexpected decimal integer literal for a " + "floating point value"; + error.attachNote() << "add a trailing dot to make the literal a float"; + return failure(); + } + if (isNegative) { + emitError() << "hexadecimal float literal should not have a " + "leading minus"; + return failure(); + } + + APInt intValue; + tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue); + auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics); + if (intValue.getActiveBits() > typeSizeInBits) { + return emitError() << "hexadecimal float constant out of range for type"; + return failure(); + } + + APInt truncatedValue(typeSizeInBits, intValue.getNumWords(), + intValue.getRawData()); + return APFloat(semantics, truncatedValue); +} + //===----------------------------------------------------------------------===// // CodeComplete //===----------------------------------------------------------------------===// @@ -347,37 +379,6 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) { return success(); } -/// Parse a floating point value from an integer literal token. -ParseResult Parser::parseFloatFromIntegerLiteral( - std::optional &result, const Token &tok, bool isNegative, - const llvm::fltSemantics &semantics, size_t typeSizeInBits) { - SMLoc loc = tok.getLoc(); - StringRef spelling = tok.getSpelling(); - bool isHex = spelling.size() > 1 && spelling[1] == 'x'; - if (!isHex) { - return emitError(loc, "unexpected decimal integer literal for a " - "floating point value") - .attachNote() - << "add a trailing dot to make the literal a float"; - } - if (isNegative) { - return emitError(loc, "hexadecimal float literal should not have a " - "leading minus"); - } - - APInt intValue; - tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue); - if (intValue.getActiveBits() > typeSizeInBits) - return emitError(loc, "hexadecimal float constant out of range for type"); - - APInt truncatedValue(typeSizeInBits, intValue.getNumWords(), - intValue.getRawData()); - - result.emplace(semantics, truncatedValue); - - return success(); -} - ParseResult Parser::parseOptionalKeyword(StringRef *keyword) { // Check that the current token is a keyword. if (!isCurrentTokenAKeyword()) diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index bf91831798056..fa29264ffe506 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -16,6 +16,12 @@ namespace mlir { namespace detail { +/// Parse a floating point value from an integer literal token. +FailureOr +parseFloatFromIntegerLiteral(function_ref emitError, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics); + //===----------------------------------------------------------------------===// // Parser //===----------------------------------------------------------------------===// @@ -151,12 +157,6 @@ class Parser { /// Parse an optional integer value only in decimal format from the stream. OptionalParseResult parseOptionalDecimalInteger(APInt &result); - /// Parse a floating point value from an integer literal token. - ParseResult parseFloatFromIntegerLiteral(std::optional &result, - const Token &tok, bool isNegative, - const llvm::fltSemantics &semantics, - size_t typeSizeInBits); - /// Returns true if the current token corresponds to a keyword. bool isCurrentTokenAKeyword() const { return getToken().isAny(Token::bare_identifier, Token::inttype) ||