Skip to content

Commit 563da71

Browse files
introduce more integration tests
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent ea7ff73 commit 563da71

File tree

5 files changed

+656
-67
lines changed

5 files changed

+656
-67
lines changed

src/databricks/sql/result_set.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,9 +471,9 @@ def __init__(
471471
manifest: Manifest from SEA response (optional)
472472
"""
473473

474-
self.results = None
474+
results_queue = None
475475
if result_data:
476-
self.results = SeaResultSetQueueFactory.build_queue(
476+
results_queue = SeaResultSetQueueFactory.build_queue(
477477
result_data,
478478
manifest,
479479
str(execute_response.command_id.to_sea_statement_id()),
@@ -492,6 +492,7 @@ def __init__(
492492
command_id=execute_response.command_id,
493493
status=execute_response.status,
494494
has_been_closed_server_side=execute_response.has_been_closed_server_side,
495+
results_queue=results_queue,
495496
description=execute_response.description,
496497
is_staging_operation=execute_response.is_staging_operation,
497498
lz4_compressed=execute_response.lz4_compressed,

tests/e2e/test_driver.py

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,17 @@ def test_execute_async__long_running(self):
196196

197197
assert result[0].asDict() == {"count(1)": 0}
198198

199+
@pytest.mark.parametrize(
200+
"extra_params",
201+
[
202+
{},
203+
{
204+
"use_sea": True,
205+
"use_cloud_fetch": False,
206+
"enable_query_result_lz4_compression": False,
207+
},
208+
],
209+
)
199210
def test_execute_async__small_result(self, extra_params):
200211
small_result_query = "SELECT 1"
201212

@@ -352,8 +363,8 @@ def test_create_table_will_return_empty_result_set(self, extra_params):
352363
finally:
353364
cursor.execute("DROP TABLE IF EXISTS {}".format(table_name))
354365

355-
def test_get_tables(self, extra_params):
356-
with self.cursor(extra_params) as cursor:
366+
def test_get_tables(self):
367+
with self.cursor() as cursor:
357368
table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_"))
358369
table_names = [table_name + "_1", table_name + "_2"]
359370

@@ -398,8 +409,8 @@ def test_get_tables(self, extra_params):
398409
for table in table_names:
399410
cursor.execute("DROP TABLE IF EXISTS {}".format(table))
400411

401-
def test_get_columns(self, extra_params):
402-
with self.cursor(extra_params) as cursor:
412+
def test_get_columns(self):
413+
with self.cursor() as cursor:
403414
table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_"))
404415
table_names = [table_name + "_1", table_name + "_2"]
405416

@@ -521,8 +532,8 @@ def test_escape_single_quotes(self, extra_params):
521532
rows = cursor.fetchall()
522533
assert rows[0]["col_1"] == "you're"
523534

524-
def test_get_schemas(self, extra_params):
525-
with self.cursor(extra_params) as cursor:
535+
def test_get_schemas(self):
536+
with self.cursor() as cursor:
526537
database_name = "db_{uuid}".format(uuid=str(uuid4()).replace("-", "_"))
527538
try:
528539
cursor.execute("CREATE DATABASE IF NOT EXISTS {}".format(database_name))
@@ -539,8 +550,8 @@ def test_get_schemas(self, extra_params):
539550
finally:
540551
cursor.execute("DROP DATABASE IF EXISTS {}".format(database_name))
541552

542-
def test_get_catalogs(self, extra_params):
543-
with self.cursor(extra_params) as cursor:
553+
def test_get_catalogs(self):
554+
with self.cursor() as cursor:
544555
cursor.catalogs()
545556
cursor.fetchall()
546557
catalogs_desc = cursor.description
@@ -813,8 +824,21 @@ def test_ssp_passthrough(self):
813824
assert list(cursor.fetchone()) == ["ansi_mode", str(enable_ansi)]
814825

815826
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
816-
def test_timestamps_arrow(self):
817-
with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor:
827+
@pytest.mark.parametrize(
828+
"extra_params",
829+
[
830+
{},
831+
{
832+
"use_sea": True,
833+
"use_cloud_fetch": False,
834+
"enable_query_result_lz4_compression": False,
835+
},
836+
],
837+
)
838+
def test_timestamps_arrow(self, extra_params):
839+
with self.cursor(
840+
{"session_configuration": {"ansi_mode": False}, **extra_params}
841+
) as cursor:
818842
for timestamp, expected in self.timestamp_and_expected_results:
819843
cursor.execute(
820844
"SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp)
@@ -837,8 +861,21 @@ def test_timestamps_arrow(self):
837861
), "timestamp {} did not match {}".format(timestamp, expected)
838862

839863
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
840-
def test_multi_timestamps_arrow(self):
841-
with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor:
864+
@pytest.mark.parametrize(
865+
"extra_params",
866+
[
867+
{},
868+
{
869+
"use_sea": True,
870+
"use_cloud_fetch": False,
871+
"enable_query_result_lz4_compression": False,
872+
},
873+
],
874+
)
875+
def test_multi_timestamps_arrow(self, extra_params):
876+
with self.cursor(
877+
{"session_configuration": {"ansi_mode": False}, **extra_params}
878+
) as cursor:
842879
query, expected = self.multi_query()
843880
expected = [
844881
[self.maybe_add_timezone_to_timestamp(ts) for ts in row]
@@ -855,9 +892,20 @@ def test_multi_timestamps_arrow(self):
855892
assert result == expected
856893

857894
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
858-
def test_timezone_with_timestamp(self):
895+
@pytest.mark.parametrize(
896+
"extra_params",
897+
[
898+
{},
899+
{
900+
"use_sea": True,
901+
"use_cloud_fetch": False,
902+
"enable_query_result_lz4_compression": False,
903+
},
904+
],
905+
)
906+
def test_timezone_with_timestamp(self, extra_params):
859907
if self.should_add_timezone():
860-
with self.cursor() as cursor:
908+
with self.cursor(extra_params) as cursor:
861909
cursor.execute("SET TIME ZONE 'Europe/Amsterdam'")
862910
cursor.execute("select CAST('2022-03-02 12:54:56' as TIMESTAMP)")
863911
amsterdam = pytz.timezone("Europe/Amsterdam")

tests/unit/test_sea_conversion.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
"""
2+
Tests for the conversion module in the SEA backend.
3+
4+
This module contains tests for the SqlType and SqlTypeConverter classes.
5+
"""
6+
7+
import pytest
8+
import datetime
9+
import decimal
10+
from unittest.mock import Mock, patch
11+
12+
from databricks.sql.backend.sea.conversion import SqlType, SqlTypeConverter
13+
14+
15+
class TestSqlType:
16+
"""Test suite for the SqlType class."""
17+
18+
def test_is_numeric(self):
19+
"""Test the is_numeric method."""
20+
assert SqlType.is_numeric(SqlType.BYTE)
21+
assert SqlType.is_numeric(SqlType.SHORT)
22+
assert SqlType.is_numeric(SqlType.INT)
23+
assert SqlType.is_numeric(SqlType.LONG)
24+
assert SqlType.is_numeric(SqlType.FLOAT)
25+
assert SqlType.is_numeric(SqlType.DOUBLE)
26+
assert SqlType.is_numeric(SqlType.DECIMAL)
27+
28+
# Test with uppercase
29+
assert SqlType.is_numeric("INT")
30+
assert SqlType.is_numeric("DECIMAL")
31+
32+
# Test non-numeric types
33+
assert not SqlType.is_numeric(SqlType.STRING)
34+
assert not SqlType.is_numeric(SqlType.BOOLEAN)
35+
assert not SqlType.is_numeric(SqlType.DATE)
36+
37+
def test_is_boolean(self):
38+
"""Test the is_boolean method."""
39+
assert SqlType.is_boolean(SqlType.BOOLEAN)
40+
assert SqlType.is_boolean("BOOLEAN")
41+
42+
# Test non-boolean types
43+
assert not SqlType.is_boolean(SqlType.STRING)
44+
assert not SqlType.is_boolean(SqlType.INT)
45+
46+
def test_is_datetime(self):
47+
"""Test the is_datetime method."""
48+
assert SqlType.is_datetime(SqlType.DATE)
49+
assert SqlType.is_datetime(SqlType.TIMESTAMP)
50+
assert SqlType.is_datetime(SqlType.INTERVAL)
51+
assert SqlType.is_datetime("DATE")
52+
assert SqlType.is_datetime("TIMESTAMP")
53+
54+
# Test non-datetime types
55+
assert not SqlType.is_datetime(SqlType.STRING)
56+
assert not SqlType.is_datetime(SqlType.INT)
57+
58+
def test_is_string(self):
59+
"""Test the is_string method."""
60+
assert SqlType.is_string(SqlType.STRING)
61+
assert SqlType.is_string(SqlType.CHAR)
62+
assert SqlType.is_string("STRING")
63+
assert SqlType.is_string("CHAR")
64+
65+
# Test non-string types
66+
assert not SqlType.is_string(SqlType.INT)
67+
assert not SqlType.is_string(SqlType.BOOLEAN)
68+
69+
def test_is_binary(self):
70+
"""Test the is_binary method."""
71+
assert SqlType.is_binary(SqlType.BINARY)
72+
assert SqlType.is_binary("BINARY")
73+
74+
# Test non-binary types
75+
assert not SqlType.is_binary(SqlType.STRING)
76+
assert not SqlType.is_binary(SqlType.INT)
77+
78+
def test_is_complex(self):
79+
"""Test the is_complex method."""
80+
assert SqlType.is_complex(SqlType.ARRAY)
81+
assert SqlType.is_complex(SqlType.MAP)
82+
assert SqlType.is_complex(SqlType.STRUCT)
83+
assert SqlType.is_complex(SqlType.USER_DEFINED_TYPE)
84+
assert SqlType.is_complex("ARRAY<int>")
85+
assert SqlType.is_complex("MAP<string,int>")
86+
assert SqlType.is_complex("STRUCT<name:string,age:int>")
87+
88+
# Test non-complex types
89+
assert not SqlType.is_complex(SqlType.STRING)
90+
assert not SqlType.is_complex(SqlType.INT)
91+
92+
93+
class TestSqlTypeConverter:
94+
"""Test suite for the SqlTypeConverter class."""
95+
96+
def test_convert_value_null(self):
97+
"""Test converting null values."""
98+
assert SqlTypeConverter.convert_value(None, SqlType.INT) is None
99+
assert SqlTypeConverter.convert_value(None, SqlType.STRING) is None
100+
assert SqlTypeConverter.convert_value(None, SqlType.BOOLEAN) is None
101+
102+
def test_convert_numeric_types(self):
103+
"""Test converting numeric types."""
104+
# Test integer types
105+
assert SqlTypeConverter.convert_value("123", SqlType.BYTE) == 123
106+
assert SqlTypeConverter.convert_value("456", SqlType.SHORT) == 456
107+
assert SqlTypeConverter.convert_value("789", SqlType.INT) == 789
108+
assert SqlTypeConverter.convert_value("1234567890", SqlType.LONG) == 1234567890
109+
110+
# Test floating point types
111+
assert SqlTypeConverter.convert_value("123.45", SqlType.FLOAT) == 123.45
112+
assert SqlTypeConverter.convert_value("678.90", SqlType.DOUBLE) == 678.90
113+
114+
# Test decimal type
115+
decimal_value = SqlTypeConverter.convert_value("123.45", SqlType.DECIMAL)
116+
assert isinstance(decimal_value, decimal.Decimal)
117+
assert decimal_value == decimal.Decimal("123.45")
118+
119+
# Test decimal with precision and scale
120+
decimal_value = SqlTypeConverter.convert_value(
121+
"123.45", SqlType.DECIMAL, precision=5, scale=2
122+
)
123+
assert isinstance(decimal_value, decimal.Decimal)
124+
assert decimal_value == decimal.Decimal("123.45")
125+
126+
# Test invalid numeric input
127+
result = SqlTypeConverter.convert_value("not_a_number", SqlType.INT)
128+
assert result == "not_a_number" # Returns original value on error
129+
130+
def test_convert_boolean_type(self):
131+
"""Test converting boolean types."""
132+
# True values
133+
assert SqlTypeConverter.convert_value("true", SqlType.BOOLEAN) is True
134+
assert SqlTypeConverter.convert_value("True", SqlType.BOOLEAN) is True
135+
assert SqlTypeConverter.convert_value("t", SqlType.BOOLEAN) is True
136+
assert SqlTypeConverter.convert_value("1", SqlType.BOOLEAN) is True
137+
assert SqlTypeConverter.convert_value("yes", SqlType.BOOLEAN) is True
138+
assert SqlTypeConverter.convert_value("y", SqlType.BOOLEAN) is True
139+
140+
# False values
141+
assert SqlTypeConverter.convert_value("false", SqlType.BOOLEAN) is False
142+
assert SqlTypeConverter.convert_value("False", SqlType.BOOLEAN) is False
143+
assert SqlTypeConverter.convert_value("f", SqlType.BOOLEAN) is False
144+
assert SqlTypeConverter.convert_value("0", SqlType.BOOLEAN) is False
145+
assert SqlTypeConverter.convert_value("no", SqlType.BOOLEAN) is False
146+
assert SqlTypeConverter.convert_value("n", SqlType.BOOLEAN) is False
147+
148+
def test_convert_datetime_types(self):
149+
"""Test converting datetime types."""
150+
# Test date type
151+
date_value = SqlTypeConverter.convert_value("2023-01-15", SqlType.DATE)
152+
assert isinstance(date_value, datetime.date)
153+
assert date_value == datetime.date(2023, 1, 15)
154+
155+
# Test timestamp type
156+
timestamp_value = SqlTypeConverter.convert_value(
157+
"2023-01-15T12:30:45", SqlType.TIMESTAMP
158+
)
159+
assert isinstance(timestamp_value, datetime.datetime)
160+
assert timestamp_value.year == 2023
161+
assert timestamp_value.month == 1
162+
assert timestamp_value.day == 15
163+
assert timestamp_value.hour == 12
164+
assert timestamp_value.minute == 30
165+
assert timestamp_value.second == 45
166+
167+
# Test interval type (currently returns as string)
168+
interval_value = SqlTypeConverter.convert_value(
169+
"1 day 2 hours", SqlType.INTERVAL
170+
)
171+
assert interval_value == "1 day 2 hours"
172+
173+
# Test invalid date input
174+
result = SqlTypeConverter.convert_value("not_a_date", SqlType.DATE)
175+
assert result == "not_a_date" # Returns original value on error
176+
177+
def test_convert_string_types(self):
178+
"""Test converting string types."""
179+
# String types don't need conversion, they should be returned as-is
180+
assert (
181+
SqlTypeConverter.convert_value("test string", SqlType.STRING)
182+
== "test string"
183+
)
184+
assert SqlTypeConverter.convert_value("test char", SqlType.CHAR) == "test char"
185+
186+
def test_convert_binary_type(self):
187+
"""Test converting binary type."""
188+
# Test valid hex string
189+
binary_value = SqlTypeConverter.convert_value("48656C6C6F", SqlType.BINARY)
190+
assert isinstance(binary_value, bytes)
191+
assert binary_value == b"Hello"
192+
193+
# Test invalid binary input
194+
result = SqlTypeConverter.convert_value("not_hex", SqlType.BINARY)
195+
assert result == "not_hex" # Returns original value on error
196+
197+
def test_convert_unsupported_type(self):
198+
"""Test converting an unsupported type."""
199+
# Should return the original value
200+
assert SqlTypeConverter.convert_value("test", "unsupported_type") == "test"
201+
202+
# Complex types should return as-is
203+
assert (
204+
SqlTypeConverter.convert_value("complex_value", SqlType.ARRAY)
205+
== "complex_value"
206+
)
207+
assert (
208+
SqlTypeConverter.convert_value("complex_value", SqlType.MAP)
209+
== "complex_value"
210+
)
211+
assert (
212+
SqlTypeConverter.convert_value("complex_value", SqlType.STRUCT)
213+
== "complex_value"
214+
)

0 commit comments

Comments
 (0)