Skip to content

Commit ff394d6

Browse files
committed
add sqlglot dialect test case
1 parent 81c2d04 commit ff394d6

File tree

6 files changed

+66
-15
lines changed

6 files changed

+66
-15
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
2+
from sqlglot import exp, parse_one

# The dialect lives in ``tp_sqlglot.dialect``; the former ``tp_superset.sql_parse``
# module was renamed in this change, so the old import path no longer resolves.
from timeplus_connect.tp_sqlglot.dialect import TimeplusSqlglotDialect


# One dialect instance shared by every parse/generate call in this test module.
dialect = TimeplusSqlglotDialect()
7+
8+
def validate_identity(sql: str) -> exp.Expression:
    """Check that ``sql`` survives a parse/generate round trip unchanged.

    The statement is parsed with the module-level ``dialect`` and rendered back
    with the same dialect; the rendered text must equal the input exactly.

    Args:
        sql: The SQL text to round-trip.

    Returns:
        The parsed expression tree, so callers can make further assertions on it.
    """
    parsed: exp.Expression = parse_one(sql, read=dialect)
    assert parsed is not None

    regenerated = parsed.sql(dialect=dialect)
    assert regenerated == sql

    return parsed
15+
16+
def test_simple_sql():
    """Round-trip a set of statements through the dialect via ``validate_identity``.

    Statements are checked in the listed order. The ``@macro`` case additionally
    asserts the shape of the parsed tree (a ``Parameter`` wrapping a ``Var``).
    """
    statements = (
        "CAST(1 AS bool)",
        "SELECT to_string(CHAR(104.1, 101, 108.9, 108.9, 111, 32))",
        "@macro",
        "SELECT to_float(like)",
        "SELECT like",
        "SELECT EXTRACT(YEAR FROM to_datetime('2023-02-01'))",
        "extract(haystack, pattern)",
        "SELECT * FROM x LIMIT 1 UNION ALL SELECT * FROM y",
        "SELECT CAST(x AS tuple(string, array(nullable(float64))))",
        "count_if(x)",
        "x = y",
        "x <> y",
        "SELECT * FROM (SELECT a FROM b SAMPLE 0.01)",
        "SELECT * FROM (SELECT a FROM b SAMPLE 1 / 10 OFFSET 1 / 2)",
        "SELECT sum(foo * bar) FROM bla SAMPLE 10000000",
        "CAST(x AS nested(ID uint32, Serial uint32, EventTime DateTime))",
        "CAST(x AS enum('hello' = 1, 'world' = 2))",
        "CAST(x AS enum('hello', 'world'))",
        "CAST(x AS enum('hello' = 1, 'world'))",
        "CAST(x AS enum8('hello' = -123, 'world'))",
        "CAST(x AS fixed_string(1))",
        "CAST(x AS low_cardinality(fixed_string))",
        "SELECT is_nan(1.0)",
        "SELECT start_with('Spider-Man', 'Spi')",
        "SELECT xor(TRUE, FALSE)",
        "CAST(['hello'], 'array(enum8(''hello'' = 1))')",
        "SELECT x, COUNT() FROM y GROUP BY x WITH TOTALS",
        "SELECT INTERVAL t.days DAY",
        "SELECT match('abc', '([a-z]+)')",
        "SELECT window_start, avg(price) AS avg_price FROM tumble(coinbase, 10s) GROUP BY window_start",
    )

    for statement in statements:
        tree = validate_identity(statement)
        if statement == "@macro":
            # Only this case also pins the node types of the parsed result.
            tree.assert_is(exp.Parameter).this.assert_is(exp.Var)

tests/test_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ pandas
1717
zstandard
1818
lz4
1919
pyjwt[crypto]==2.10.1
20+
sqlglot
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import timeplus_connect.tp_sqlglot.dialect

timeplus_connect/tp_superset/sql_parse.py renamed to timeplus_connect/tp_sqlglot/dialect.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,11 @@ class Tokenizer(tokens.Tokenizer):
226226
"ENUM8": TokenType.ENUM8,
227227
"ENUM16": TokenType.ENUM16,
228228
"FINAL": TokenType.FINAL,
229-
"FIXEDSTRING": TokenType.FIXEDSTRING,
229+
"FIXED_STRING": TokenType.FIXEDSTRING,
230230
"FLOAT32": TokenType.FLOAT,
231231
"FLOAT64": TokenType.DOUBLE,
232232
"GLOBAL": TokenType.GLOBAL,
233-
"LOWCARDINALITY": TokenType.LOWCARDINALITY,
233+
"LOW_CARDINALITY": TokenType.LOWCARDINALITY,
234234
"MAP": TokenType.MAP,
235235
"NESTED": TokenType.NESTED,
236236
"SAMPLE": TokenType.TABLE_SAMPLE,
@@ -393,7 +393,7 @@ class Parser(parser.Parser):
393393

394394
AGG_FUNCTIONS_SUFFIXES = [
395395
"if",
396-
"aary",
396+
"array",
397397
"array_if",
398398
"map",
399399
"simple_state",
@@ -524,9 +524,9 @@ def _parse_types(
524524
if isinstance(dtype, exp.DataType) and dtype.args.get("nullable") is not True:
525525
# Mark every type as non-nullable which is Timeplus's default, unless it's
526526
# already marked as nullable. This marker helps us transpile types from other
527-
# dialects to Timeplus, so that we can e.g. produce `CAST(x AS nullabe(String))`
527+
# dialects to Timeplus, so that we can e.g. produce `CAST(x AS nullable(String))`
528528
# from `CAST(x AS TEXT)`. If there is a `NULL` value in `x`, the former would
529-
# fail in Timeplus without the `nullabe` type constructor.
529+
# fail in Timeplus without the `nullable` type constructor.
530530
dtype.set("nullable", False)
531531

532532
return dtype
@@ -563,7 +563,7 @@ def _parse_assignment(self) -> t.Optional[exp.Expression]:
563563

564564
def _parse_query_parameter(self) -> t.Optional[exp.Expression]:
565565
"""
566-
Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier}
566+
Parse a placeholder expression like SELECT {abc: uint32} or FROM {table: Identifier}
567567
"""
568568
index = self._index
569569

@@ -958,7 +958,7 @@ class Generator(generator.Generator):
958958
exp.ArgMax: arg_max_or_min_no_count("arg_max"),
959959
exp.ArgMin: arg_max_or_min_no_count("arg_min"),
960960
exp.Array: inline_array_sql,
961-
exp.CastToStrType: rename_func("cast"),
961+
exp.CastToStrType: rename_func("CAST"),
962962
exp.CountIf: rename_func("count_if"),
963963
exp.CompressColumnConstraint: lambda self,
964964
e: f"CODEC({self.expressions(e, key='this', flat=True)})",
@@ -1062,7 +1062,7 @@ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) ->
10621062
def trycast_sql(self, expression: exp.TryCast) -> str:
    """Render a TRY_CAST expression through the regular cast generator.

    Unless the target type matches one of ``self.NON_NULLABLE_TYPES``, the
    target type is flagged as nullable before delegating. Note that this
    mutates ``expression.to`` in place.

    Args:
        expression: The TRY_CAST node to generate SQL for.

    Returns:
        The SQL produced by the parent class's ``cast_sql``.
    """
    target_type = expression.to
    exempt = target_type.is_type(*self.NON_NULLABLE_TYPES, check_nullable=True)
    if not exempt:
        # Casting x into nullable(T) appears to behave similarly to TRY_CAST(x AS T)
        target_type.set("nullable", True)

    return super().cast_sql(expression)
@@ -1110,13 +1110,13 @@ def datatype_sql(self, expression: exp.DataType) -> str:
11101110
else:
11111111
dtype = super().datatype_sql(expression)
11121112

1113-
# This section changes the type to `nullabe(...)` if the following conditions hold:
1114-
# - It's marked as nullable - this ensures we won't wrap Timeplus types with `nullabe`
1113+
# This section changes the type to `nullable(...)` if the following conditions hold:
1114+
# - It's marked as nullable - this ensures we won't wrap Timeplus types with `nullable`
11151115
# and change their semantics
11161116
# - It's not the key type of a `Map`. This is because Timeplus enforces the following
11171117
# constraint: "Type of Map key must be a type, that can be represented by integer or
11181118
# String or FixedString (possibly LowCardinality) or UUID or IPv6"
1119-
# - It's not a composite type, e.g. `nullabe(array(...))` is not a valid type
1119+
# - It's not a composite type, e.g. `nullable(array(...))` is not a valid type
11201120
parent = expression.parent
11211121
nullable = expression.args.get("nullable")
11221122
if nullable is True or (
@@ -1128,7 +1128,7 @@ def datatype_sql(self, expression: exp.DataType) -> str:
11281128
)
11291129
and not expression.is_type(*self.NON_NULLABLE_TYPES, check_nullable=True)
11301130
):
1131-
dtype = f"nullabe({dtype})"
1131+
dtype = f"nullable({dtype})"
11321132

11331133
return dtype
11341134

@@ -1247,3 +1247,7 @@ def is_sql(self, expression: exp.Is) -> str:
12471247
is_sql = self.wrap(is_sql)
12481248

12491249
return is_sql
1250+
1251+
tokenizer_class = Tokenizer
1252+
parser_class = Parser
1253+
generator_class = Generator
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
import timeplus_connect.tp_superset.db_engine_spec
2-
import timeplus_connect.tp_superset.sql_parse

timeplus_connect/tp_superset/db_engine_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from datetime import datetime
33
from typing import Any
44

5-
# pylint:disable=E0401
5+
# pylint:disable=E0401, E0611
66
from marshmallow import fields, Schema
77
from marshmallow.validate import Range
88
from superset.db_engine_specs.base import BaseEngineSpec
@@ -13,7 +13,7 @@
1313
from sqlalchemy.engine.url import URL
1414

1515

16-
from timeplus_connect.tp_superset.sql_parse import TimeplusSqlglotDialect
16+
from timeplus_connect.tp_sqlglot.dialect import TimeplusSqlglotDialect
1717

1818

1919
logger = logging.getLogger(__name__)

0 commit comments

Comments
 (0)