diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index fc3d86ca858f0..698afa4860027 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1340,7 +1340,20 @@ collateClause : COLLATE collationName=multipartIdentifier ; -type +nonTrivialPrimitiveType + : STRING collateClause? + | (CHARACTER | CHAR) (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)? + | VARCHAR (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)? + | (DECIMAL | DEC | NUMERIC) + (LEFT_PAREN precision=INTEGER_VALUE (COMMA scale=INTEGER_VALUE)? RIGHT_PAREN)? + | INTERVAL + (fromYearMonth=(YEAR | MONTH) (TO to=MONTH)? | + fromDayTime=(DAY | HOUR | MINUTE | SECOND) (TO to=(HOUR | MINUTE | SECOND))?)? + | TIMESTAMP (WITHOUT TIME ZONE)? + | TIME (LEFT_PAREN precision=INTEGER_VALUE RIGHT_PAREN)? (WITHOUT TIME ZONE)? + ; + +trivialPrimitiveType : BOOLEAN | TINYINT | BYTE | SMALLINT | SHORT @@ -1349,32 +1362,23 @@ type | FLOAT | REAL | DOUBLE | DATE - | TIME - | TIMESTAMP | TIMESTAMP_NTZ | TIMESTAMP_LTZ - | STRING collateClause? - | CHARACTER | CHAR - | VARCHAR + | TIMESTAMP_LTZ | TIMESTAMP_NTZ | BINARY - | DECIMAL | DEC | NUMERIC | VOID - | INTERVAL | VARIANT - | ARRAY | STRUCT | MAP - | unsupportedType=identifier + ; + +primitiveType + : nonTrivialPrimitiveType + | trivialPrimitiveType + | unsupportedType=identifier (LEFT_PAREN INTEGER_VALUE(COMMA INTEGER_VALUE)* RIGHT_PAREN)? ; dataType - : complex=ARRAY LT dataType GT #complexDataType - | complex=MAP LT dataType COMMA dataType GT #complexDataType - | complex=STRUCT (LT complexColTypeList? GT | NEQ) #complexDataType - | INTERVAL from=(YEAR | MONTH) (TO to=MONTH)? #yearMonthIntervalDataType - | INTERVAL from=(DAY | HOUR | MINUTE | SECOND) - (TO to=(HOUR | MINUTE | SECOND))? #dayTimeIntervalDataType - | TIME (LEFT_PAREN precision=INTEGER_VALUE RIGHT_PAREN)? - (WITHOUT TIME ZONE)? #timeDataType - | (TIMESTAMP_NTZ | TIMESTAMP WITHOUT TIME ZONE) #timestampNtzDataType - | type (LEFT_PAREN INTEGER_VALUE - (COMMA INTEGER_VALUE)* RIGHT_PAREN)? #primitiveDataType + : complex=ARRAY (LT dataType GT)? #complexDataType + | complex=MAP (LT dataType COMMA dataType GT)? #complexDataType + | complex=STRUCT ((LT complexColTypeList? GT) | NEQ)? #complexDataType + | primitiveType #primitiveDataType ; qualifiedColTypeWithPositionList diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index e83a987263db4..beb7061a841a8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -65,74 +65,89 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { ctx.parts.asScala.map(_.getText).toSeq } - /** - * Resolve/create the TIME primitive type. - */ - override def visitTimeDataType(ctx: TimeDataTypeContext): DataType = withOrigin(ctx) { - val precision = if (ctx.precision == null) { - TimeType.DEFAULT_PRECISION - } else { - ctx.precision.getText.toInt - } - TimeType(precision) - } - - /** - * Create the TIMESTAMP_NTZ primitive type. - */ - override def visitTimestampNtzDataType(ctx: TimestampNtzDataTypeContext): DataType = { - withOrigin(ctx)(TimestampNTZType) - } - /** * Resolve/create a primitive type. */ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { - val typeCtx = ctx.`type` - (typeCtx.start.getType, ctx.INTEGER_VALUE().asScala.toList) match { - case (BOOLEAN, Nil) => BooleanType - case (TINYINT | BYTE, Nil) => ByteType - case (SMALLINT | SHORT, Nil) => ShortType - case (INT | INTEGER, Nil) => IntegerType - case (BIGINT | LONG, Nil) => LongType - case (FLOAT | REAL, Nil) => FloatType - case (DOUBLE, Nil) => DoubleType - case (DATE, Nil) => DateType - case (TIMESTAMP, Nil) => SqlApiConf.get.timestampType - case (TIMESTAMP_LTZ, Nil) => TimestampType - case (STRING, Nil) => - typeCtx.children.asScala.toSeq match { - case Seq(_) => StringType - case Seq(_, ctx: CollateClauseContext) => - val collationNameParts = visitCollateClause(ctx).toArray - val collationId = CollationFactory.collationNameToId( - CollationFactory.resolveFullyQualifiedName(collationNameParts)) - StringType(collationId) - } - case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt) - case (VARCHAR, length :: Nil) => VarcharType(length.getText.toInt) - case (BINARY, Nil) => BinaryType - case (DECIMAL | DEC | NUMERIC, Nil) => DecimalType.USER_DEFAULT - case (DECIMAL | DEC | NUMERIC, precision :: Nil) => - DecimalType(precision.getText.toInt, 0) - case (DECIMAL | DEC | NUMERIC, precision :: scale :: Nil) => - DecimalType(precision.getText.toInt, scale.getText.toInt) - case (VOID, Nil) => NullType - case (INTERVAL, Nil) => CalendarIntervalType - case (VARIANT, Nil) => VariantType - case (CHARACTER | CHAR | VARCHAR, Nil) => - throw QueryParsingErrors.charTypeMissingLengthError(ctx.`type`.getText, ctx) - case (ARRAY | STRUCT | MAP, Nil) => - throw QueryParsingErrors.nestedTypeMissingElementTypeError(ctx.`type`.getText, ctx) - case (_, params) => - val badType = ctx.`type`.getText - val dtStr = if (params.nonEmpty) s"$badType(${params.mkString(",")})" else badType - throw QueryParsingErrors.dataTypeUnsupportedError(dtStr, ctx) + val typeCtx = ctx.primitiveType + if (typeCtx.nonTrivialPrimitiveType != null) { + // This is a primitive type with parameters, e.g. VARCHAR(10), DECIMAL(10, 2), etc. + val currentCtx = typeCtx.nonTrivialPrimitiveType + currentCtx.start.getType match { + case STRING => + currentCtx.children.asScala.toSeq match { + case Seq(_) => StringType + case Seq(_, ctx: CollateClauseContext) => + val collationNameParts = visitCollateClause(ctx).toArray + val collationId = CollationFactory.collationNameToId( + CollationFactory.resolveFullyQualifiedName(collationNameParts)) + StringType(collationId) + } + case CHARACTER | CHAR => + if (currentCtx.length == null) { + throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx) + } else CharType(currentCtx.length.getText.toInt) + case VARCHAR => + if (currentCtx.length == null) { + throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx) + } else VarcharType(currentCtx.length.getText.toInt) + case DECIMAL | DEC | NUMERIC => + if (currentCtx.precision == null) { + DecimalType.USER_DEFAULT + } else if (currentCtx.scale == null) { + DecimalType(currentCtx.precision.getText.toInt, 0) + } else { + DecimalType(currentCtx.precision.getText.toInt, currentCtx.scale.getText.toInt) + } + case INTERVAL => + if (currentCtx.fromDayTime != null) { + visitDayTimeIntervalDataType(currentCtx) + } else if (currentCtx.fromYearMonth != null) { + visitYearMonthIntervalDataType(currentCtx) + } else { + CalendarIntervalType + } + case TIMESTAMP => + if (currentCtx.WITHOUT() == null) { + SqlApiConf.get.timestampType + } else TimestampNTZType + case TIME => + val precision = if (currentCtx.precision == null) { + TimeType.DEFAULT_PRECISION + } else { + currentCtx.precision.getText.toInt + } + TimeType(precision) + } + } else if (typeCtx.trivialPrimitiveType != null) { + // This is a primitive type without parameters, e.g. BOOLEAN, TINYINT, etc. + typeCtx.trivialPrimitiveType.start.getType match { + case BOOLEAN => BooleanType + case TINYINT | BYTE => ByteType + case SMALLINT | SHORT => ShortType + case INT | INTEGER => IntegerType + case BIGINT | LONG => LongType + case FLOAT | REAL => FloatType + case DOUBLE => DoubleType + case DATE => DateType + case TIMESTAMP_LTZ => TimestampType + case TIMESTAMP_NTZ => TimestampNTZType + case BINARY => BinaryType + case VOID => NullType + case VARIANT => VariantType + } + } else { + val badType = typeCtx.unsupportedType.getText + val params = typeCtx.INTEGER_VALUE().asScala.toList + val dtStr = + if (params.nonEmpty) s"$badType(${params.mkString(",")})" + else badType + throw QueryParsingErrors.dataTypeUnsupportedError(dtStr, ctx) } } - override def visitYearMonthIntervalDataType(ctx: YearMonthIntervalDataTypeContext): DataType = { - val startStr = ctx.from.getText.toLowerCase(Locale.ROOT) + private def visitYearMonthIntervalDataType(ctx: NonTrivialPrimitiveTypeContext): DataType = { + val startStr = ctx.fromYearMonth.getText.toLowerCase(Locale.ROOT) val start = YearMonthIntervalType.stringToField(startStr) if (ctx.to != null) { val endStr = ctx.to.getText.toLowerCase(Locale.ROOT) @@ -146,8 +161,8 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { } } - override def visitDayTimeIntervalDataType(ctx: DayTimeIntervalDataTypeContext): DataType = { - val startStr = ctx.from.getText.toLowerCase(Locale.ROOT) + private def visitDayTimeIntervalDataType(ctx: NonTrivialPrimitiveTypeContext): DataType = { + val startStr = ctx.fromDayTime.getText.toLowerCase(Locale.ROOT) val start = DayTimeIntervalType.stringToField(startStr) if (ctx.to != null) { val endStr = ctx.to.getText.toLowerCase(Locale.ROOT) @@ -165,6 +180,9 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { * Create a complex DataType. Arrays, Maps and Structures are supported. */ override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { + if (ctx.LT() == null && ctx.NEQ() == null) { + throw QueryParsingErrors.nestedTypeMissingElementTypeError(ctx.getText, ctx) + } ctx.complex.getType match { case SqlBaseParser.ARRAY => ArrayType(typedVisit(ctx.dataType(0))) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 12f986b89fd2b..60ccf7a9282cf 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -324,7 +324,9 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { ctx) } - def charTypeMissingLengthError(dataType: String, ctx: PrimitiveDataTypeContext): Throwable = { + def charVarcharTypeMissingLengthError( + dataType: String, + ctx: PrimitiveDataTypeContext): Throwable = { new ParseException( errorClass = "DATATYPE_MISSING_SIZE", messageParameters = Map("type" -> toSQLType(dataType)), @@ -333,7 +335,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { def nestedTypeMissingElementTypeError( dataType: String, - ctx: PrimitiveDataTypeContext): Throwable = { + ctx: ComplexDataTypeContext): Throwable = { dataType.toUpperCase(Locale.ROOT) match { case "ARRAY" => new ParseException(