@@ -65,74 +65,89 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
65
65
ctx.parts.asScala.map(_.getText).toSeq
66
66
}
67
67
68
- /**
69
- * Resolve/create the TIME primitive type.
70
- */
71
- override def visitTimeDataType (ctx : TimeDataTypeContext ): DataType = withOrigin(ctx) {
72
- val precision = if (ctx.precision == null ) {
73
- TimeType .DEFAULT_PRECISION
74
- } else {
75
- ctx.precision.getText.toInt
76
- }
77
- TimeType (precision)
78
- }
79
-
80
- /**
81
- * Create the TIMESTAMP_NTZ primitive type.
82
- */
83
- override def visitTimestampNtzDataType (ctx : TimestampNtzDataTypeContext ): DataType = {
84
- withOrigin(ctx)(TimestampNTZType )
85
- }
86
-
87
68
/**
88
69
* Resolve/create a primitive type.
89
70
*/
90
71
override def visitPrimitiveDataType (ctx : PrimitiveDataTypeContext ): DataType = withOrigin(ctx) {
91
- val typeCtx = ctx.`type`
92
- (typeCtx.start.getType, ctx.INTEGER_VALUE ().asScala.toList) match {
93
- case (BOOLEAN , Nil ) => BooleanType
94
- case (TINYINT | BYTE , Nil ) => ByteType
95
- case (SMALLINT | SHORT , Nil ) => ShortType
96
- case (INT | INTEGER , Nil ) => IntegerType
97
- case (BIGINT | LONG , Nil ) => LongType
98
- case (FLOAT | REAL , Nil ) => FloatType
99
- case (DOUBLE , Nil ) => DoubleType
100
- case (DATE , Nil ) => DateType
101
- case (TIMESTAMP , Nil ) => SqlApiConf .get.timestampType
102
- case (TIMESTAMP_LTZ , Nil ) => TimestampType
103
- case (STRING , Nil ) =>
104
- typeCtx.children.asScala.toSeq match {
105
- case Seq (_) => StringType
106
- case Seq (_, ctx : CollateClauseContext ) =>
107
- val collationNameParts = visitCollateClause(ctx).toArray
108
- val collationId = CollationFactory .collationNameToId(
109
- CollationFactory .resolveFullyQualifiedName(collationNameParts))
110
- StringType (collationId)
111
- }
112
- case (CHARACTER | CHAR , length :: Nil ) => CharType (length.getText.toInt)
113
- case (VARCHAR , length :: Nil ) => VarcharType (length.getText.toInt)
114
- case (BINARY , Nil ) => BinaryType
115
- case (DECIMAL | DEC | NUMERIC , Nil ) => DecimalType .USER_DEFAULT
116
- case (DECIMAL | DEC | NUMERIC , precision :: Nil ) =>
117
- DecimalType (precision.getText.toInt, 0 )
118
- case (DECIMAL | DEC | NUMERIC , precision :: scale :: Nil ) =>
119
- DecimalType (precision.getText.toInt, scale.getText.toInt)
120
- case (VOID , Nil ) => NullType
121
- case (INTERVAL , Nil ) => CalendarIntervalType
122
- case (VARIANT , Nil ) => VariantType
123
- case (CHARACTER | CHAR | VARCHAR , Nil ) =>
124
- throw QueryParsingErrors .charTypeMissingLengthError(ctx.`type`.getText, ctx)
125
- case (ARRAY | STRUCT | MAP , Nil ) =>
126
- throw QueryParsingErrors .nestedTypeMissingElementTypeError(ctx.`type`.getText, ctx)
127
- case (_, params) =>
128
- val badType = ctx.`type`.getText
129
- val dtStr = if (params.nonEmpty) s " $badType( ${params.mkString(" ," )}) " else badType
130
- throw QueryParsingErrors .dataTypeUnsupportedError(dtStr, ctx)
72
+ val typeCtx = ctx.primitiveType
73
+ if (typeCtx.nonTrivialPrimitiveType != null ) {
74
+ // This is a primitive type with parameters, e.g. VARCHAR(10), DECIMAL(10, 2), etc.
75
+ val currentCtx = typeCtx.nonTrivialPrimitiveType
76
+ currentCtx.start.getType match {
77
+ case STRING =>
78
+ currentCtx.children.asScala.toSeq match {
79
+ case Seq (_) => StringType
80
+ case Seq (_, ctx : CollateClauseContext ) =>
81
+ val collationNameParts = visitCollateClause(ctx).toArray
82
+ val collationId = CollationFactory .collationNameToId(
83
+ CollationFactory .resolveFullyQualifiedName(collationNameParts))
84
+ StringType (collationId)
85
+ }
86
+ case CHARACTER | CHAR =>
87
+ if (currentCtx.length == null ) {
88
+ throw QueryParsingErrors .charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
89
+ } else CharType (currentCtx.length.getText.toInt)
90
+ case VARCHAR =>
91
+ if (currentCtx.length == null ) {
92
+ throw QueryParsingErrors .charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
93
+ } else VarcharType (currentCtx.length.getText.toInt)
94
+ case DECIMAL | DEC | NUMERIC =>
95
+ if (currentCtx.precision == null ) {
96
+ DecimalType .USER_DEFAULT
97
+ } else if (currentCtx.scale == null ) {
98
+ DecimalType (currentCtx.precision.getText.toInt, 0 )
99
+ } else {
100
+ DecimalType (currentCtx.precision.getText.toInt, currentCtx.scale.getText.toInt)
101
+ }
102
+ case INTERVAL =>
103
+ if (currentCtx.fromDayTime != null ) {
104
+ visitDayTimeIntervalDataType(currentCtx)
105
+ } else if (currentCtx.fromYearMonth != null ) {
106
+ visitYearMonthIntervalDataType(currentCtx)
107
+ } else {
108
+ CalendarIntervalType
109
+ }
110
+ case TIMESTAMP =>
111
+ if (currentCtx.WITHOUT () == null ) {
112
+ SqlApiConf .get.timestampType
113
+ } else TimestampNTZType
114
+ case TIME =>
115
+ val precision = if (currentCtx.precision == null ) {
116
+ TimeType .DEFAULT_PRECISION
117
+ } else {
118
+ currentCtx.precision.getText.toInt
119
+ }
120
+ TimeType (precision)
121
+ }
122
+ } else if (typeCtx.trivialPrimitiveType != null ) {
123
+ // This is a primitive type without parameters, e.g. BOOLEAN, TINYINT, etc.
124
+ typeCtx.trivialPrimitiveType.start.getType match {
125
+ case BOOLEAN => BooleanType
126
+ case TINYINT | BYTE => ByteType
127
+ case SMALLINT | SHORT => ShortType
128
+ case INT | INTEGER => IntegerType
129
+ case BIGINT | LONG => LongType
130
+ case FLOAT | REAL => FloatType
131
+ case DOUBLE => DoubleType
132
+ case DATE => DateType
133
+ case TIMESTAMP_LTZ => TimestampType
134
+ case TIMESTAMP_NTZ => TimestampNTZType
135
+ case BINARY => BinaryType
136
+ case VOID => NullType
137
+ case VARIANT => VariantType
138
+ }
139
+ } else {
140
+ val badType = typeCtx.unsupportedType.getText
141
+ val params = typeCtx.INTEGER_VALUE ().asScala.toList
142
+ val dtStr =
143
+ if (params.nonEmpty) s " $badType( ${params.mkString(" ," )}) "
144
+ else badType
145
+ throw QueryParsingErrors .dataTypeUnsupportedError(dtStr, ctx)
131
146
}
132
147
}
133
148
134
- override def visitYearMonthIntervalDataType (ctx : YearMonthIntervalDataTypeContext ): DataType = {
135
- val startStr = ctx.from .getText.toLowerCase(Locale .ROOT )
149
+ private def visitYearMonthIntervalDataType (ctx : NonTrivialPrimitiveTypeContext ): DataType = {
150
+ val startStr = ctx.fromYearMonth .getText.toLowerCase(Locale .ROOT )
136
151
val start = YearMonthIntervalType .stringToField(startStr)
137
152
if (ctx.to != null ) {
138
153
val endStr = ctx.to.getText.toLowerCase(Locale .ROOT )
@@ -146,8 +161,8 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
146
161
}
147
162
}
148
163
149
- override def visitDayTimeIntervalDataType (ctx : DayTimeIntervalDataTypeContext ): DataType = {
150
- val startStr = ctx.from .getText.toLowerCase(Locale .ROOT )
164
+ private def visitDayTimeIntervalDataType (ctx : NonTrivialPrimitiveTypeContext ): DataType = {
165
+ val startStr = ctx.fromDayTime .getText.toLowerCase(Locale .ROOT )
151
166
val start = DayTimeIntervalType .stringToField(startStr)
152
167
if (ctx.to != null ) {
153
168
val endStr = ctx.to.getText.toLowerCase(Locale .ROOT )
@@ -165,6 +180,9 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
165
180
* Create a complex DataType. Arrays, Maps and Structures are supported.
166
181
*/
167
182
override def visitComplexDataType (ctx : ComplexDataTypeContext ): DataType = withOrigin(ctx) {
183
+ if (ctx.LT () == null && ctx.NEQ () == null ) {
184
+ throw QueryParsingErrors .nestedTypeMissingElementTypeError(ctx.getText, ctx)
185
+ }
168
186
ctx.complex.getType match {
169
187
case SqlBaseParser .ARRAY =>
170
188
ArrayType (typedVisit(ctx.dataType(0 )))
0 commit comments