[SPARK-52706][SQL] Fix inconsistencies and refactor primitive types in parser #51335

Closed
wants to merge 11 commits into from
@@ -1340,7 +1340,20 @@ collateClause
: COLLATE collationName=multipartIdentifier
;

type
nonTrivialPrimitiveType
: STRING collateClause?
Contributor

Should this be primitiveTypeWithoutParameters?

Contributor Author

It could go there, but I would argue that collation is a parameter as well: its value can be an arbitrary one that is not known at parsing time. If we follow that logic, then INTERVAL should probably move to primitiveTypeWithoutParameters, since it actually has no parameters.

Contributor Author

By "not known at parsing time", I mean that it is an identifier, i.e. an arbitrary value.
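
The distinction discussed in this thread can be sketched with a toy model. This is not Spark's parser: `TypeSpec`, `Trivial`, `Parameterized`, `TypeSpecSketch` and the keyword set are all hypothetical names chosen for illustration. The sketch treats a type as trivial only when the keyword alone determines it, and treats everything that may carry extra tokens (a length, a precision, or a collation identifier) as parameterized.

```scala
import java.util.Locale

// Toy model only; none of these names exist in Spark.
sealed trait TypeSpec
final case class Trivial(name: String) extends TypeSpec
final case class Parameterized(name: String, params: List[String]) extends TypeSpec

object TypeSpecSketch {
  // Assumed subset of keywords that never take extra tokens.
  private val trivialKeywords = Set("BOOLEAN", "INT", "DOUBLE", "DATE", "BINARY")

  def classify(sql: String): TypeSpec =
    sql.trim.split("\\s+").toList match {
      // A lone keyword from the trivial set is fully determined at parse time.
      case kw :: Nil if trivialKeywords(kw.toUpperCase(Locale.ROOT)) =>
        Trivial(kw.toUpperCase(Locale.ROOT))
      // Anything else keeps its trailing tokens as parameters,
      // e.g. the collation identifier after STRING COLLATE.
      case kw :: rest =>
        Parameterized(kw.toUpperCase(Locale.ROOT), rest)
      case Nil =>
        throw new IllegalArgumentException("empty type")
    }
}
```

In this model `STRING COLLATE UNICODE_CI` carries the collation name as a parameter whose value is an arbitrary identifier, while a bare `INTERVAL` ends up parameterized with an empty parameter list, which is the borderline case raised above.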

| (CHARACTER | CHAR) (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)?
| VARCHAR (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)?
| (DECIMAL | DEC | NUMERIC)
(LEFT_PAREN precision=INTEGER_VALUE (COMMA scale=INTEGER_VALUE)? RIGHT_PAREN)?
| INTERVAL
(fromYearMonth=(YEAR | MONTH) (TO to=MONTH)? |
fromDayTime=(DAY | HOUR | MINUTE | SECOND) (TO to=(HOUR | MINUTE | SECOND))?)?
| TIMESTAMP (WITHOUT TIME ZONE)?
| TIME (LEFT_PAREN precision=INTEGER_VALUE RIGHT_PAREN)? (WITHOUT TIME ZONE)?
;

trivialPrimitiveType
: BOOLEAN
| TINYINT | BYTE
| SMALLINT | SHORT
@@ -1349,32 +1362,23 @@ type
| FLOAT | REAL
| DOUBLE
| DATE
| TIME
| TIMESTAMP | TIMESTAMP_NTZ | TIMESTAMP_LTZ
| STRING collateClause?
| CHARACTER | CHAR
| VARCHAR
| TIMESTAMP_LTZ | TIMESTAMP_NTZ
| BINARY
| DECIMAL | DEC | NUMERIC
| VOID
| INTERVAL
| VARIANT
| ARRAY | STRUCT | MAP
| unsupportedType=identifier
;

primitiveType
: nonTrivialPrimitiveType
| trivialPrimitiveType
| unsupportedType=identifier (LEFT_PAREN INTEGER_VALUE(COMMA INTEGER_VALUE)* RIGHT_PAREN)?
Member

What if we get an unsupportedType with a suffix, such as TIME WITH TIME ZONE?

Contributor Author

Yes, this could be a problem: we would probably return a poor error message. Let me think about whether we can cover cases like this as well.

Contributor Author

Actually, the error message stays the same: even before this change we returned a syntax error. The only open question is whether we want to improve the error messages a bit further.
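
A rough sketch of why a suffix falls outside the unsupported-type path. It approximates the `unsupportedType=identifier (LEFT_PAREN INTEGER_VALUE (COMMA INTEGER_VALUE)* RIGHT_PAREN)?` alternative with a regular expression, so it is only an analogy for the ANTLR rule. `UnsupportedTypeSketch`, `diagnose` and the message strings are hypothetical; only the way `dtStr` is assembled follows the diff.

```scala
object UnsupportedTypeSketch {
  // Approximates: identifier (LEFT_PAREN INTEGER_VALUE (COMMA INTEGER_VALUE)* RIGHT_PAREN)?
  private val Shape = """(\w+)\s*(?:\(\s*(\d+(?:\s*,\s*\d+)*)\s*\))?""".r

  // Right: text fits the rule, so a dedicated "unsupported type" message is possible.
  // Left: text does not fit the rule, so only a generic syntax error is available.
  def diagnose(text: String): Either[String, String] = text.trim match {
    case Shape(name, params) =>
      // `params` is null when the optional parenthesized group is absent.
      val ps = Option(params).toList.flatMap(_.split(",").map(_.trim).toList)
      val dtStr = if (ps.nonEmpty) s"$name(${ps.mkString(",")})" else name
      Right(s"unsupported data type: $dtStr")
    case other =>
      Left(s"syntax error near: $other")
  }
}
```

Under these assumptions an input such as `GEOMETRY(1, 2)` fits the rule shape and gets the dedicated message, while `TIME WITH TIME ZONE` does not fit and falls through to the generic syntax error, which matches the behavior described in this thread.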

;

dataType
: complex=ARRAY LT dataType GT #complexDataType
| complex=MAP LT dataType COMMA dataType GT #complexDataType
| complex=STRUCT (LT complexColTypeList? GT | NEQ) #complexDataType
| INTERVAL from=(YEAR | MONTH) (TO to=MONTH)? #yearMonthIntervalDataType
| INTERVAL from=(DAY | HOUR | MINUTE | SECOND)
(TO to=(HOUR | MINUTE | SECOND))? #dayTimeIntervalDataType
| TIME (LEFT_PAREN precision=INTEGER_VALUE RIGHT_PAREN)?
(WITHOUT TIME ZONE)? #timeDataType
| (TIMESTAMP_NTZ | TIMESTAMP WITHOUT TIME ZONE) #timestampNtzDataType
| type (LEFT_PAREN INTEGER_VALUE
(COMMA INTEGER_VALUE)* RIGHT_PAREN)? #primitiveDataType
: complex=ARRAY (LT dataType GT)? #complexDataType
| complex=MAP (LT dataType COMMA dataType GT)? #complexDataType
| complex=STRUCT ((LT complexColTypeList? GT) | NEQ)? #complexDataType
| primitiveType #primitiveDataType
;

qualifiedColTypeWithPositionList
@@ -65,74 +65,89 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
ctx.parts.asScala.map(_.getText).toSeq
}

/**
* Resolve/create the TIME primitive type.
*/
override def visitTimeDataType(ctx: TimeDataTypeContext): DataType = withOrigin(ctx) {
val precision = if (ctx.precision == null) {
TimeType.DEFAULT_PRECISION
} else {
ctx.precision.getText.toInt
}
TimeType(precision)
}

/**
* Create the TIMESTAMP_NTZ primitive type.
*/
override def visitTimestampNtzDataType(ctx: TimestampNtzDataTypeContext): DataType = {
withOrigin(ctx)(TimestampNTZType)
}

/**
* Resolve/create a primitive type.
*/
override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
val typeCtx = ctx.`type`
(typeCtx.start.getType, ctx.INTEGER_VALUE().asScala.toList) match {
case (BOOLEAN, Nil) => BooleanType
case (TINYINT | BYTE, Nil) => ByteType
case (SMALLINT | SHORT, Nil) => ShortType
case (INT | INTEGER, Nil) => IntegerType
case (BIGINT | LONG, Nil) => LongType
case (FLOAT | REAL, Nil) => FloatType
case (DOUBLE, Nil) => DoubleType
case (DATE, Nil) => DateType
case (TIMESTAMP, Nil) => SqlApiConf.get.timestampType
case (TIMESTAMP_LTZ, Nil) => TimestampType
case (STRING, Nil) =>
typeCtx.children.asScala.toSeq match {
case Seq(_) => StringType
case Seq(_, ctx: CollateClauseContext) =>
val collationNameParts = visitCollateClause(ctx).toArray
val collationId = CollationFactory.collationNameToId(
CollationFactory.resolveFullyQualifiedName(collationNameParts))
StringType(collationId)
}
case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt)
case (VARCHAR, length :: Nil) => VarcharType(length.getText.toInt)
case (BINARY, Nil) => BinaryType
case (DECIMAL | DEC | NUMERIC, Nil) => DecimalType.USER_DEFAULT
case (DECIMAL | DEC | NUMERIC, precision :: Nil) =>
DecimalType(precision.getText.toInt, 0)
case (DECIMAL | DEC | NUMERIC, precision :: scale :: Nil) =>
DecimalType(precision.getText.toInt, scale.getText.toInt)
case (VOID, Nil) => NullType
case (INTERVAL, Nil) => CalendarIntervalType
case (VARIANT, Nil) => VariantType
case (CHARACTER | CHAR | VARCHAR, Nil) =>
throw QueryParsingErrors.charTypeMissingLengthError(ctx.`type`.getText, ctx)
case (ARRAY | STRUCT | MAP, Nil) =>
throw QueryParsingErrors.nestedTypeMissingElementTypeError(ctx.`type`.getText, ctx)
case (_, params) =>
val badType = ctx.`type`.getText
val dtStr = if (params.nonEmpty) s"$badType(${params.mkString(",")})" else badType
throw QueryParsingErrors.dataTypeUnsupportedError(dtStr, ctx)
val typeCtx = ctx.primitiveType
if (typeCtx.nonTrivialPrimitiveType != null) {
// This is a primitive type with parameters, e.g. VARCHAR(10), DECIMAL(10, 2), etc.
val currentCtx = typeCtx.nonTrivialPrimitiveType
currentCtx.start.getType match {
case STRING =>
currentCtx.children.asScala.toSeq match {
case Seq(_) => StringType
case Seq(_, ctx: CollateClauseContext) =>
val collationNameParts = visitCollateClause(ctx).toArray
val collationId = CollationFactory.collationNameToId(
CollationFactory.resolveFullyQualifiedName(collationNameParts))
StringType(collationId)
}
case CHARACTER | CHAR =>
if (currentCtx.length == null) {
throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
} else CharType(currentCtx.length.getText.toInt)
case VARCHAR =>
if (currentCtx.length == null) {
throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
} else VarcharType(currentCtx.length.getText.toInt)
case DECIMAL | DEC | NUMERIC =>
if (currentCtx.precision == null) {
DecimalType.USER_DEFAULT
} else if (currentCtx.scale == null) {
DecimalType(currentCtx.precision.getText.toInt, 0)
} else {
DecimalType(currentCtx.precision.getText.toInt, currentCtx.scale.getText.toInt)
}
case INTERVAL =>
if (currentCtx.fromDayTime != null) {
visitDayTimeIntervalDataType(currentCtx)
} else if (currentCtx.fromYearMonth != null) {
visitYearMonthIntervalDataType(currentCtx)
} else {
CalendarIntervalType
}
case TIMESTAMP =>
if (currentCtx.WITHOUT() == null) {
SqlApiConf.get.timestampType
} else TimestampNTZType
case TIME =>
val precision = if (currentCtx.precision == null) {
TimeType.DEFAULT_PRECISION
} else {
currentCtx.precision.getText.toInt
}
TimeType(precision)
}
} else if (typeCtx.trivialPrimitiveType != null) {
// This is a primitive type without parameters, e.g. BOOLEAN, TINYINT, etc.
typeCtx.trivialPrimitiveType.start.getType match {
case BOOLEAN => BooleanType
case TINYINT | BYTE => ByteType
case SMALLINT | SHORT => ShortType
case INT | INTEGER => IntegerType
case BIGINT | LONG => LongType
case FLOAT | REAL => FloatType
case DOUBLE => DoubleType
case DATE => DateType
case TIMESTAMP_LTZ => TimestampType
case TIMESTAMP_NTZ => TimestampNTZType
case BINARY => BinaryType
case VOID => NullType
case VARIANT => VariantType
}
} else {
val badType = typeCtx.unsupportedType.getText
val params = typeCtx.INTEGER_VALUE().asScala.toList
val dtStr =
if (params.nonEmpty) s"$badType(${params.mkString(",")})"
else badType
throw QueryParsingErrors.dataTypeUnsupportedError(dtStr, ctx)
}
}

override def visitYearMonthIntervalDataType(ctx: YearMonthIntervalDataTypeContext): DataType = {
val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
private def visitYearMonthIntervalDataType(ctx: NonTrivialPrimitiveTypeContext): DataType = {
val startStr = ctx.fromYearMonth.getText.toLowerCase(Locale.ROOT)
val start = YearMonthIntervalType.stringToField(startStr)
if (ctx.to != null) {
val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
@@ -146,8 +161,8 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
}
}

override def visitDayTimeIntervalDataType(ctx: DayTimeIntervalDataTypeContext): DataType = {
val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
private def visitDayTimeIntervalDataType(ctx: NonTrivialPrimitiveTypeContext): DataType = {
val startStr = ctx.fromDayTime.getText.toLowerCase(Locale.ROOT)
val start = DayTimeIntervalType.stringToField(startStr)
if (ctx.to != null) {
val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
@@ -165,6 +180,9 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
* Create a complex DataType. Arrays, Maps and Structures are supported.
*/
override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) {
if (ctx.LT() == null && ctx.NEQ() == null) {
throw QueryParsingErrors.nestedTypeMissingElementTypeError(ctx.getText, ctx)
Contributor

Is this a new error?

Contributor Author

No, this comes from the refactoring. Previously, if someone wrote only STRUCT/ARRAY/MAP without parameters, it went down the primitive type path. That is not good practice, so I isolated complex types in a separate context. The only behavior change is the error message for input such as STRUCT(2): it now reports an unsupported primitive type instead of a complex type with a missing element type. One could argue that this is a change, but in my view we need to distinguish between primitive and complex types first, as is common practice in type theory. Primitive types can be used as leaf arguments in complex types, and we do not want a recursive link between the two.
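
The separation described here can be illustrated with a minimal sketch in which the complex type is handled before any primitive lookup and primitives appear only as leaves. This is a toy, not the PR's implementation: `Ty`, `IntTy`, `StringTy`, `ArrayTy`, `ComplexTypeSketch` and the message strings are hypothetical, and only ARRAY is modeled to keep it short.

```scala
import java.util.Locale

// Toy type model; these names do not exist in Spark.
sealed trait Ty
case object IntTy extends Ty
case object StringTy extends Ty
final case class ArrayTy(element: Ty) extends Ty

object ComplexTypeSketch {
  def parse(text: String): Either[String, Ty] = {
    val t = text.trim
    val upper = t.toUpperCase(Locale.ROOT)
    if (upper == "ARRAY") {
      // Complex keyword without an element type: reported by the complex path.
      Left("ARRAY is missing its element type, e.g. ARRAY<INT>")
    } else if (upper.startsWith("ARRAY<") && upper.endsWith(">")) {
      // Recurse into the element type; primitives are the leaves.
      parse(t.substring("ARRAY<".length, t.length - 1)).map(ArrayTy(_))
    } else {
      // Primitive path: never sees a well-formed complex type.
      upper match {
        case "INT" => Right(IntTy)
        case "STRING" => Right(StringTy)
        case other => Left(s"unsupported data type: $other")
      }
    }
  }
}
```

With this layout a bare `ARRAY` yields the missing-element message, while `ARRAY(2)` matches neither complex form and is reported as an unsupported type, mirroring the STRUCT(2) case mentioned above.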

}
ctx.complex.getType match {
case SqlBaseParser.ARRAY =>
ArrayType(typedVisit(ctx.dataType(0)))
@@ -324,7 +324,9 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
ctx)
}

def charTypeMissingLengthError(dataType: String, ctx: PrimitiveDataTypeContext): Throwable = {
def charVarcharTypeMissingLengthError(
dataType: String,
ctx: PrimitiveDataTypeContext): Throwable = {
new ParseException(
errorClass = "DATATYPE_MISSING_SIZE",
messageParameters = Map("type" -> toSQLType(dataType)),
@@ -333,7 +335,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {

def nestedTypeMissingElementTypeError(
dataType: String,
ctx: PrimitiveDataTypeContext): Throwable = {
ctx: ComplexDataTypeContext): Throwable = {
dataType.toUpperCase(Locale.ROOT) match {
case "ARRAY" =>
new ParseException(