Skip to content

Commit 6afbfaf

Browse files
mihailom-dbasl3
authored and committed
[SPARK-52706][SQL] Fix inconsistencies and refactor primitive types in parser
### What changes were proposed in this pull request?

This PR changes how the parser treats data types. Primitive types are split into types that can take parameters or modifiers and types that cannot, and the grammar rules are grouped accordingly.

### Why are the changes needed?

The changes are needed for several reasons:

1. The parser context of primitiveDataType keeps growing. This is not good practice, because most of its fields are null for any given type and only take up memory.
2. There are inconsistencies in where each type is handled. TIMESTAMP_NTZ is parsed by a separate rule, but it is also listed among the primitive types.
3. The primitive type rule should contain only primitive types; adding ARRAY, STRUCT, and MAP to the rule just because it is convenient is not good practice.
4. The current structure gives no option for extending types with additional features. For example, we introduced STRING collations, but what if we were to introduce CHAR/VARCHAR with collations? The current structure gives us no way to express a type such as CHAR(5) COLLATE UTF8_BINARY (we could only write CHAR COLLATE UTF8_BINARY (5)).

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

### How was this patch tested?

All existing tests should pass; this is only a code refactoring.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#51335 from mihailom-db/restructure-primitive.

Authored-by: Mihailo Milosevic <mihailo.milosevic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent fdee7cb commit 6afbfaf

File tree

3 files changed

+110
-86
lines changed

3 files changed

+110
-86
lines changed

sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,7 +1340,20 @@ collateClause
13401340
: COLLATE collationName=multipartIdentifier
13411341
;
13421342

1343-
type
1343+
nonTrivialPrimitiveType
1344+
: STRING collateClause?
1345+
| (CHARACTER | CHAR) (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)?
1346+
| VARCHAR (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)?
1347+
| (DECIMAL | DEC | NUMERIC)
1348+
(LEFT_PAREN precision=INTEGER_VALUE (COMMA scale=INTEGER_VALUE)? RIGHT_PAREN)?
1349+
| INTERVAL
1350+
(fromYearMonth=(YEAR | MONTH) (TO to=MONTH)? |
1351+
fromDayTime=(DAY | HOUR | MINUTE | SECOND) (TO to=(HOUR | MINUTE | SECOND))?)?
1352+
| TIMESTAMP (WITHOUT TIME ZONE)?
1353+
| TIME (LEFT_PAREN precision=INTEGER_VALUE RIGHT_PAREN)? (WITHOUT TIME ZONE)?
1354+
;
1355+
1356+
trivialPrimitiveType
13441357
: BOOLEAN
13451358
| TINYINT | BYTE
13461359
| SMALLINT | SHORT
@@ -1349,32 +1362,23 @@ type
13491362
| FLOAT | REAL
13501363
| DOUBLE
13511364
| DATE
1352-
| TIME
1353-
| TIMESTAMP | TIMESTAMP_NTZ | TIMESTAMP_LTZ
1354-
| STRING collateClause?
1355-
| CHARACTER | CHAR
1356-
| VARCHAR
1365+
| TIMESTAMP_LTZ | TIMESTAMP_NTZ
13571366
| BINARY
1358-
| DECIMAL | DEC | NUMERIC
13591367
| VOID
1360-
| INTERVAL
13611368
| VARIANT
1362-
| ARRAY | STRUCT | MAP
1363-
| unsupportedType=identifier
1369+
;
1370+
1371+
primitiveType
1372+
: nonTrivialPrimitiveType
1373+
| trivialPrimitiveType
1374+
| unsupportedType=identifier (LEFT_PAREN INTEGER_VALUE(COMMA INTEGER_VALUE)* RIGHT_PAREN)?
13641375
;
13651376

13661377
dataType
1367-
: complex=ARRAY LT dataType GT #complexDataType
1368-
| complex=MAP LT dataType COMMA dataType GT #complexDataType
1369-
| complex=STRUCT (LT complexColTypeList? GT | NEQ) #complexDataType
1370-
| INTERVAL from=(YEAR | MONTH) (TO to=MONTH)? #yearMonthIntervalDataType
1371-
| INTERVAL from=(DAY | HOUR | MINUTE | SECOND)
1372-
(TO to=(HOUR | MINUTE | SECOND))? #dayTimeIntervalDataType
1373-
| TIME (LEFT_PAREN precision=INTEGER_VALUE RIGHT_PAREN)?
1374-
(WITHOUT TIME ZONE)? #timeDataType
1375-
| (TIMESTAMP_NTZ | TIMESTAMP WITHOUT TIME ZONE) #timestampNtzDataType
1376-
| type (LEFT_PAREN INTEGER_VALUE
1377-
(COMMA INTEGER_VALUE)* RIGHT_PAREN)? #primitiveDataType
1378+
: complex=ARRAY (LT dataType GT)? #complexDataType
1379+
| complex=MAP (LT dataType COMMA dataType GT)? #complexDataType
1380+
| complex=STRUCT ((LT complexColTypeList? GT) | NEQ)? #complexDataType
1381+
| primitiveType #primitiveDataType
13781382
;
13791383

13801384
qualifiedColTypeWithPositionList

sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala

Lines changed: 81 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -65,74 +65,89 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
6565
ctx.parts.asScala.map(_.getText).toSeq
6666
}
6767

68-
/**
69-
* Resolve/create the TIME primitive type.
70-
*/
71-
override def visitTimeDataType(ctx: TimeDataTypeContext): DataType = withOrigin(ctx) {
72-
val precision = if (ctx.precision == null) {
73-
TimeType.DEFAULT_PRECISION
74-
} else {
75-
ctx.precision.getText.toInt
76-
}
77-
TimeType(precision)
78-
}
79-
80-
/**
81-
* Create the TIMESTAMP_NTZ primitive type.
82-
*/
83-
override def visitTimestampNtzDataType(ctx: TimestampNtzDataTypeContext): DataType = {
84-
withOrigin(ctx)(TimestampNTZType)
85-
}
86-
8768
/**
8869
* Resolve/create a primitive type.
8970
*/
9071
override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
91-
val typeCtx = ctx.`type`
92-
(typeCtx.start.getType, ctx.INTEGER_VALUE().asScala.toList) match {
93-
case (BOOLEAN, Nil) => BooleanType
94-
case (TINYINT | BYTE, Nil) => ByteType
95-
case (SMALLINT | SHORT, Nil) => ShortType
96-
case (INT | INTEGER, Nil) => IntegerType
97-
case (BIGINT | LONG, Nil) => LongType
98-
case (FLOAT | REAL, Nil) => FloatType
99-
case (DOUBLE, Nil) => DoubleType
100-
case (DATE, Nil) => DateType
101-
case (TIMESTAMP, Nil) => SqlApiConf.get.timestampType
102-
case (TIMESTAMP_LTZ, Nil) => TimestampType
103-
case (STRING, Nil) =>
104-
typeCtx.children.asScala.toSeq match {
105-
case Seq(_) => StringType
106-
case Seq(_, ctx: CollateClauseContext) =>
107-
val collationNameParts = visitCollateClause(ctx).toArray
108-
val collationId = CollationFactory.collationNameToId(
109-
CollationFactory.resolveFullyQualifiedName(collationNameParts))
110-
StringType(collationId)
111-
}
112-
case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt)
113-
case (VARCHAR, length :: Nil) => VarcharType(length.getText.toInt)
114-
case (BINARY, Nil) => BinaryType
115-
case (DECIMAL | DEC | NUMERIC, Nil) => DecimalType.USER_DEFAULT
116-
case (DECIMAL | DEC | NUMERIC, precision :: Nil) =>
117-
DecimalType(precision.getText.toInt, 0)
118-
case (DECIMAL | DEC | NUMERIC, precision :: scale :: Nil) =>
119-
DecimalType(precision.getText.toInt, scale.getText.toInt)
120-
case (VOID, Nil) => NullType
121-
case (INTERVAL, Nil) => CalendarIntervalType
122-
case (VARIANT, Nil) => VariantType
123-
case (CHARACTER | CHAR | VARCHAR, Nil) =>
124-
throw QueryParsingErrors.charTypeMissingLengthError(ctx.`type`.getText, ctx)
125-
case (ARRAY | STRUCT | MAP, Nil) =>
126-
throw QueryParsingErrors.nestedTypeMissingElementTypeError(ctx.`type`.getText, ctx)
127-
case (_, params) =>
128-
val badType = ctx.`type`.getText
129-
val dtStr = if (params.nonEmpty) s"$badType(${params.mkString(",")})" else badType
130-
throw QueryParsingErrors.dataTypeUnsupportedError(dtStr, ctx)
72+
val typeCtx = ctx.primitiveType
73+
if (typeCtx.nonTrivialPrimitiveType != null) {
74+
// This is a primitive type with parameters, e.g. VARCHAR(10), DECIMAL(10, 2), etc.
75+
val currentCtx = typeCtx.nonTrivialPrimitiveType
76+
currentCtx.start.getType match {
77+
case STRING =>
78+
currentCtx.children.asScala.toSeq match {
79+
case Seq(_) => StringType
80+
case Seq(_, ctx: CollateClauseContext) =>
81+
val collationNameParts = visitCollateClause(ctx).toArray
82+
val collationId = CollationFactory.collationNameToId(
83+
CollationFactory.resolveFullyQualifiedName(collationNameParts))
84+
StringType(collationId)
85+
}
86+
case CHARACTER | CHAR =>
87+
if (currentCtx.length == null) {
88+
throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
89+
} else CharType(currentCtx.length.getText.toInt)
90+
case VARCHAR =>
91+
if (currentCtx.length == null) {
92+
throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
93+
} else VarcharType(currentCtx.length.getText.toInt)
94+
case DECIMAL | DEC | NUMERIC =>
95+
if (currentCtx.precision == null) {
96+
DecimalType.USER_DEFAULT
97+
} else if (currentCtx.scale == null) {
98+
DecimalType(currentCtx.precision.getText.toInt, 0)
99+
} else {
100+
DecimalType(currentCtx.precision.getText.toInt, currentCtx.scale.getText.toInt)
101+
}
102+
case INTERVAL =>
103+
if (currentCtx.fromDayTime != null) {
104+
visitDayTimeIntervalDataType(currentCtx)
105+
} else if (currentCtx.fromYearMonth != null) {
106+
visitYearMonthIntervalDataType(currentCtx)
107+
} else {
108+
CalendarIntervalType
109+
}
110+
case TIMESTAMP =>
111+
if (currentCtx.WITHOUT() == null) {
112+
SqlApiConf.get.timestampType
113+
} else TimestampNTZType
114+
case TIME =>
115+
val precision = if (currentCtx.precision == null) {
116+
TimeType.DEFAULT_PRECISION
117+
} else {
118+
currentCtx.precision.getText.toInt
119+
}
120+
TimeType(precision)
121+
}
122+
} else if (typeCtx.trivialPrimitiveType != null) {
123+
// This is a primitive type without parameters, e.g. BOOLEAN, TINYINT, etc.
124+
typeCtx.trivialPrimitiveType.start.getType match {
125+
case BOOLEAN => BooleanType
126+
case TINYINT | BYTE => ByteType
127+
case SMALLINT | SHORT => ShortType
128+
case INT | INTEGER => IntegerType
129+
case BIGINT | LONG => LongType
130+
case FLOAT | REAL => FloatType
131+
case DOUBLE => DoubleType
132+
case DATE => DateType
133+
case TIMESTAMP_LTZ => TimestampType
134+
case TIMESTAMP_NTZ => TimestampNTZType
135+
case BINARY => BinaryType
136+
case VOID => NullType
137+
case VARIANT => VariantType
138+
}
139+
} else {
140+
val badType = typeCtx.unsupportedType.getText
141+
val params = typeCtx.INTEGER_VALUE().asScala.toList
142+
val dtStr =
143+
if (params.nonEmpty) s"$badType(${params.mkString(",")})"
144+
else badType
145+
throw QueryParsingErrors.dataTypeUnsupportedError(dtStr, ctx)
131146
}
132147
}
133148

134-
override def visitYearMonthIntervalDataType(ctx: YearMonthIntervalDataTypeContext): DataType = {
135-
val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
149+
private def visitYearMonthIntervalDataType(ctx: NonTrivialPrimitiveTypeContext): DataType = {
150+
val startStr = ctx.fromYearMonth.getText.toLowerCase(Locale.ROOT)
136151
val start = YearMonthIntervalType.stringToField(startStr)
137152
if (ctx.to != null) {
138153
val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
@@ -146,8 +161,8 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
146161
}
147162
}
148163

149-
override def visitDayTimeIntervalDataType(ctx: DayTimeIntervalDataTypeContext): DataType = {
150-
val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
164+
private def visitDayTimeIntervalDataType(ctx: NonTrivialPrimitiveTypeContext): DataType = {
165+
val startStr = ctx.fromDayTime.getText.toLowerCase(Locale.ROOT)
151166
val start = DayTimeIntervalType.stringToField(startStr)
152167
if (ctx.to != null) {
153168
val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
@@ -165,6 +180,9 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
165180
* Create a complex DataType. Arrays, Maps and Structures are supported.
166181
*/
167182
override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) {
183+
if (ctx.LT() == null && ctx.NEQ() == null) {
184+
throw QueryParsingErrors.nestedTypeMissingElementTypeError(ctx.getText, ctx)
185+
}
168186
ctx.complex.getType match {
169187
case SqlBaseParser.ARRAY =>
170188
ArrayType(typedVisit(ctx.dataType(0)))

sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,9 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
324324
ctx)
325325
}
326326

327-
def charTypeMissingLengthError(dataType: String, ctx: PrimitiveDataTypeContext): Throwable = {
327+
def charVarcharTypeMissingLengthError(
328+
dataType: String,
329+
ctx: PrimitiveDataTypeContext): Throwable = {
328330
new ParseException(
329331
errorClass = "DATATYPE_MISSING_SIZE",
330332
messageParameters = Map("type" -> toSQLType(dataType)),
@@ -333,7 +335,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
333335

334336
def nestedTypeMissingElementTypeError(
335337
dataType: String,
336-
ctx: PrimitiveDataTypeContext): Throwable = {
338+
ctx: ComplexDataTypeContext): Throwable = {
337339
dataType.toUpperCase(Locale.ROOT) match {
338340
case "ARRAY" =>
339341
new ParseException(

0 commit comments

Comments
 (0)