
Commit 37bcaca

add sqlglot dialect test case
1 parent 81c2d04 commit 37bcaca

File tree

3 files changed: 67 additions, 12 deletions
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+
+from sqlglot import exp, parse_one
+from timeplus_connect.tp_superset.sql_parse import TimeplusSqlglotDialect
+
+
+dialect = TimeplusSqlglotDialect()
+
+def validate_identity(sql: str) -> exp.Expression:
+    ast: exp.Expression = parse_one(sql, read=dialect)
+
+    assert ast is not None
+    assert ast.sql(dialect=dialect) == sql
+
+    return ast
+
+def test_simple_sql():
+    validate_identity("CAST(1 AS bool)")
+    validate_identity("SELECT to_string(CHAR(104.1, 101, 108.9, 108.9, 111, 32))")
+    validate_identity("@macro").assert_is(exp.Parameter).this.assert_is(exp.Var)
+    validate_identity("SELECT to_float(like)")
+    validate_identity("SELECT like")
+    validate_identity("SELECT EXTRACT(YEAR FROM to_datetime('2023-02-01'))")
+    validate_identity("extract(haystack, pattern)")
+    validate_identity("SELECT * FROM x LIMIT 1 UNION ALL SELECT * FROM y")
+    validate_identity("SELECT CAST(x AS tuple(string, array(nullable(float64))))")
+    validate_identity("count_if(x)")
+    validate_identity("x = y")
+    validate_identity("x <> y")
+    validate_identity("SELECT * FROM (SELECT a FROM b SAMPLE 0.01)")
+    validate_identity("SELECT * FROM (SELECT a FROM b SAMPLE 1 / 10 OFFSET 1 / 2)")
+    validate_identity("SELECT sum(foo * bar) FROM bla SAMPLE 10000000")
+    validate_identity("CAST(x AS nested(ID uint32, Serial uint32, EventTime DateTime))")
+    validate_identity("CAST(x AS enum('hello' = 1, 'world' = 2))")
+    validate_identity("CAST(x AS enum('hello', 'world'))")
+    validate_identity("CAST(x AS enum('hello' = 1, 'world'))")
+    validate_identity("CAST(x AS enum8('hello' = -123, 'world'))")
+    validate_identity("CAST(x AS fixed_string(1))")
+    validate_identity("CAST(x AS low_cardinality(fixed_string))")
+    validate_identity("SELECT is_nan(1.0)")
+    validate_identity("SELECT start_with('Spider-Man', 'Spi')")
+    validate_identity("SELECT xor(TRUE, FALSE)")
+    validate_identity("CAST(['hello'], 'array(enum8(''hello'' = 1))')")
+    validate_identity("SELECT x, COUNT() FROM y GROUP BY x WITH TOTALS")
+    validate_identity("SELECT INTERVAL t.days DAY")
+    validate_identity("SELECT match('abc', '([a-z]+)')")
+    validate_identity("SELECT window_start, avg(price) AS avg_price FROM tumble(coinbase, 10s) GROUP BY window_start")
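For quick reference outside pytest, the round-trip check that validate_identity performs can be reproduced directly with the same calls the test uses (parse_one with read= and Expression.sql with dialect=); a minimal sketch:

from sqlglot import parse_one
from timeplus_connect.tp_superset.sql_parse import TimeplusSqlglotDialect

dialect = TimeplusSqlglotDialect()

# Parse with the Timeplus dialect and render back; validate_identity asserts
# that the rendered SQL equals the input string.
ast = parse_one("SELECT is_nan(1.0)", read=dialect)
print(ast.sql(dialect=dialect))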

tests/test_requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -17,3 +17,6 @@ pandas
 zstandard
 lz4
 pyjwt[crypto]==2.10.1
+sqlglot
+marshmallow
+apache-superset

timeplus_connect/tp_superset/sql_parse.py

Lines changed: 18 additions & 12 deletions
@@ -208,6 +208,8 @@ class TimeplusSqlglotDialect(Dialect):
         exp.Union: None,
     }
 
+
+
     class Tokenizer(tokens.Tokenizer):
         COMMENTS = ["--", "#", "#!", ("/*", "*/")]
         IDENTIFIERS = ['"', "`"]
@@ -226,11 +228,11 @@ class Tokenizer(tokens.Tokenizer):
             "ENUM8": TokenType.ENUM8,
             "ENUM16": TokenType.ENUM16,
             "FINAL": TokenType.FINAL,
-            "FIXEDSTRING": TokenType.FIXEDSTRING,
+            "FIXED_STRING": TokenType.FIXEDSTRING,
             "FLOAT32": TokenType.FLOAT,
             "FLOAT64": TokenType.DOUBLE,
             "GLOBAL": TokenType.GLOBAL,
-            "LOWCARDINALITY": TokenType.LOWCARDINALITY,
+            "LOW_CARDINALITY": TokenType.LOWCARDINALITY,
             "MAP": TokenType.MAP,
             "NESTED": TokenType.NESTED,
             "SAMPLE": TokenType.TABLE_SAMPLE,
@@ -393,7 +395,7 @@ class Parser(parser.Parser):
 
         AGG_FUNCTIONS_SUFFIXES = [
             "if",
-            "aary",
+            "array",
             "array_if",
             "map",
             "simple_state",
@@ -524,9 +526,9 @@ def _parse_types(
             if isinstance(dtype, exp.DataType) and dtype.args.get("nullable") is not True:
                 # Mark every type as non-nullable which is Timeplus's default, unless it's
                 # already marked as nullable. This marker helps us transpile types from other
-                # dialects to Timeplus, so that we can e.g. produce `CAST(x AS nullabe(String))`
+                # dialects to Timeplus, so that we can e.g. produce `CAST(x AS nullable(String))`
                 # from `CAST(x AS TEXT)`. If there is a `NULL` value in `x`, the former would
-                # fail in Timeplus without the `nullabe` type constructor.
+                # fail in Timeplus without the `nullable` type constructor.
                 dtype.set("nullable", False)
 
             return dtype
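The CAST(x AS TEXT) example in the comment can be reproduced with the same parse/render calls the test file uses; a hedged sketch, assuming "postgres" purely as an illustrative source dialect and without asserting the exact casing of the emitted type:

from sqlglot import parse_one
from timeplus_connect.tp_superset.sql_parse import TimeplusSqlglotDialect

dialect = TimeplusSqlglotDialect()

# A nullable-by-default TEXT coming from another dialect is expected to be
# rendered with the nullable(...) constructor by the Timeplus generator.
print(parse_one("CAST(x AS TEXT)", read="postgres").sql(dialect=dialect))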
@@ -563,7 +565,7 @@ def _parse_assignment(self) -> t.Optional[exp.Expression]:
 
         def _parse_query_parameter(self) -> t.Optional[exp.Expression]:
             """
-            Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier}
+            Parse a placeholder expression like SELECT {abc: uint32} or FROM {table: Identifier}
             """
             index = self._index
 
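A hedged illustration of the placeholder syntax described in the docstring; this commit adds no test for it, so only the parse step is shown:

from sqlglot import parse_one
from timeplus_connect.tp_superset.sql_parse import TimeplusSqlglotDialect

dialect = TimeplusSqlglotDialect()

# Query-parameter placeholder of the form {name: type}.
ast = parse_one("SELECT {abc: uint32}", read=dialect)
print(repr(ast))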

@@ -958,7 +960,7 @@ class Generator(generator.Generator):
             exp.ArgMax: arg_max_or_min_no_count("arg_max"),
             exp.ArgMin: arg_max_or_min_no_count("arg_min"),
             exp.Array: inline_array_sql,
-            exp.CastToStrType: rename_func("cast"),
+            exp.CastToStrType: rename_func("CAST"),
             exp.CountIf: rename_func("count_if"),
             exp.CompressColumnConstraint: lambda self,
             e: f"CODEC({self.expressions(e, key='this', flat=True)})",
@@ -1062,7 +1064,7 @@ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) ->
         def trycast_sql(self, expression: exp.TryCast) -> str:
             dtype = expression.to
             if not dtype.is_type(*self.NON_NULLABLE_TYPES, check_nullable=True):
-                # Casting x into nullabe(T) appears to behave similarly to TRY_CAST(x AS T)
+                # Casting x into nullable(T) appears to behave similarly to TRY_CAST(x AS T)
                 dtype.set("nullable", True)
 
             return super().cast_sql(expression)
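A sketch of the behaviour the comment describes, assuming "duckdb" purely as an illustrative source of TRY_CAST; the exact rendered type name is not asserted here:

from sqlglot import parse_one
from timeplus_connect.tp_superset.sql_parse import TimeplusSqlglotDialect

dialect = TimeplusSqlglotDialect()

# TRY_CAST from another dialect is expected to come out as a plain CAST to a
# nullable(...) type in the Timeplus output.
print(parse_one("TRY_CAST(a AS TEXT)", read="duckdb").sql(dialect=dialect))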
@@ -1110,13 +1112,13 @@ def datatype_sql(self, expression: exp.DataType) -> str:
             else:
                 dtype = super().datatype_sql(expression)
 
-            # This section changes the type to `nullabe(...)` if the following conditions hold:
-            # - It's marked as nullable - this ensures we won't wrap Timeplus types with `nullabe`
+            # This section changes the type to `nullable(...)` if the following conditions hold:
+            # - It's marked as nullable - this ensures we won't wrap Timeplus types with `nullable`
             #   and change their semantics
             # - It's not the key type of a `Map`. This is because Timeplus enforces the following
             #   constraint: "Type of Map key must be a type, that can be represented by integer or
             #   String or FixedString (possibly LowCardinality) or UUID or IPv6"
-            # - It's not a composite type, e.g. `nullabe(array(...))` is not a valid type
+            # - It's not a composite type, e.g. `nullable(array(...))` is not a valid type
             parent = expression.parent
             nullable = expression.args.get("nullable")
             if nullable is True or (
@@ -1128,7 +1130,7 @@ def datatype_sql(self, expression: exp.DataType) -> str:
                 )
                 and not expression.is_type(*self.NON_NULLABLE_TYPES, check_nullable=True)
             ):
-                dtype = f"nullabe({dtype})"
+                dtype = f"nullable({dtype})"
 
             return dtype
 
@@ -1247,3 +1249,7 @@ def is_sql(self, expression: exp.Is) -> str:
                 is_sql = self.wrap(is_sql)
 
             return is_sql
+
+    tokenizer_class = Tokenizer
+    parser_class = Parser
+    generator_class = Generator
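With tokenizer_class, parser_class, and generator_class assigned, the standard sqlglot Dialect helpers resolve to the nested classes defined above; a minimal sketch:

from timeplus_connect.tp_superset.sql_parse import TimeplusSqlglotDialect

d = TimeplusSqlglotDialect()
tokens = d.tokenize("SELECT 1")  # routed through Tokenizer via tokenizer_class
[ast] = d.parse("SELECT 1")      # routed through Parser via parser_class
print(d.generate(ast))           # routed through Generator via generator_class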
