Skip to content

Commit 278b8cd

Browse files
align expected types with databricks sdk
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent: 04a1936 · commit: 278b8cd

File tree

3 files changed

+56
-82
lines changed

3 files changed

+56
-82
lines changed

src/databricks/sql/conversion.py renamed to src/databricks/sql/backend/sea/conversion.py

Lines changed: 35 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -15,89 +15,75 @@
1515

1616

1717
class SqlType:
18-
"""SQL type constants for improved maintainability."""
18+
"""
19+
SQL type constants
20+
"""
1921

2022
# Numeric types
21-
TINYINT = "tinyint"
22-
SMALLINT = "smallint"
23+
BYTE = "byte"
24+
SHORT = "short"
2325
INT = "int"
24-
INTEGER = "integer"
25-
BIGINT = "bigint"
26+
LONG = "long"
2627
FLOAT = "float"
27-
REAL = "real"
2828
DOUBLE = "double"
2929
DECIMAL = "decimal"
30-
NUMERIC = "numeric"
3130

32-
# Boolean types
31+
# Boolean type
3332
BOOLEAN = "boolean"
34-
BIT = "bit"
3533

3634
# Date/Time types
3735
DATE = "date"
38-
TIME = "time"
3936
TIMESTAMP = "timestamp"
40-
TIMESTAMP_NTZ = "timestamp_ntz"
41-
TIMESTAMP_LTZ = "timestamp_ltz"
42-
TIMESTAMP_TZ = "timestamp_tz"
37+
INTERVAL = "interval"
4338

4439
# String types
4540
CHAR = "char"
46-
VARCHAR = "varchar"
4741
STRING = "string"
48-
TEXT = "text"
4942

50-
# Binary types
43+
# Binary type
5144
BINARY = "binary"
52-
VARBINARY = "varbinary"
5345

5446
# Complex types
5547
ARRAY = "array"
5648
MAP = "map"
5749
STRUCT = "struct"
5850

51+
# Other types
52+
NULL = "null"
53+
USER_DEFINED_TYPE = "user_defined_type"
54+
5955
@classmethod
6056
def is_numeric(cls, sql_type: str) -> bool:
6157
"""Check if the SQL type is a numeric type."""
6258
return sql_type.lower() in (
63-
cls.TINYINT,
64-
cls.SMALLINT,
59+
cls.BYTE,
60+
cls.SHORT,
6561
cls.INT,
66-
cls.INTEGER,
67-
cls.BIGINT,
62+
cls.LONG,
6863
cls.FLOAT,
69-
cls.REAL,
7064
cls.DOUBLE,
7165
cls.DECIMAL,
72-
cls.NUMERIC,
7366
)
7467

7568
@classmethod
7669
def is_boolean(cls, sql_type: str) -> bool:
7770
"""Check if the SQL type is a boolean type."""
78-
return sql_type.lower() in (cls.BOOLEAN, cls.BIT)
71+
return sql_type.lower() == cls.BOOLEAN
7972

8073
@classmethod
8174
def is_datetime(cls, sql_type: str) -> bool:
8275
"""Check if the SQL type is a date/time type."""
83-
return sql_type.lower() in (
84-
cls.DATE,
85-
cls.TIME,
86-
cls.TIMESTAMP,
87-
cls.TIMESTAMP_NTZ,
88-
cls.TIMESTAMP_LTZ,
89-
cls.TIMESTAMP_TZ,
90-
)
76+
return sql_type.lower() in (cls.DATE, cls.TIMESTAMP, cls.INTERVAL)
9177

9278
@classmethod
9379
def is_string(cls, sql_type: str) -> bool:
9480
"""Check if the SQL type is a string type."""
95-
return sql_type.lower() in (cls.CHAR, cls.VARCHAR, cls.STRING, cls.TEXT)
81+
return sql_type.lower() in (cls.CHAR, cls.STRING)
9682

9783
@classmethod
9884
def is_binary(cls, sql_type: str) -> bool:
9985
"""Check if the SQL type is a binary type."""
100-
return sql_type.lower() in (cls.BINARY, cls.VARBINARY)
86+
return sql_type.lower() == cls.BINARY
10187

10288
@classmethod
10389
def is_complex(cls, sql_type: str) -> bool:
@@ -107,25 +93,25 @@ def is_complex(cls, sql_type: str) -> bool:
10793
sql_type.startswith(cls.ARRAY)
10894
or sql_type.startswith(cls.MAP)
10995
or sql_type.startswith(cls.STRUCT)
96+
or sql_type == cls.USER_DEFINED_TYPE
11097
)
11198

11299

113100
class SqlTypeConverter:
114101
"""
115102
Utility class for converting SQL types to Python types.
116-
Based on the JDBC ConverterHelper implementation.
103+
Based on the types supported by the Databricks SDK.
117104
"""
118105

119106
# SQL type to conversion function mapping
107+
# TODO: complex types
120108
TYPE_MAPPING: Dict[str, Callable] = {
121109
# Numeric types
122-
SqlType.TINYINT: lambda v: int(v),
123-
SqlType.SMALLINT: lambda v: int(v),
110+
SqlType.BYTE: lambda v: int(v),
111+
SqlType.SHORT: lambda v: int(v),
124112
SqlType.INT: lambda v: int(v),
125-
SqlType.INTEGER: lambda v: int(v),
126-
SqlType.BIGINT: lambda v: int(v),
113+
SqlType.LONG: lambda v: int(v),
127114
SqlType.FLOAT: lambda v: float(v),
128-
SqlType.REAL: lambda v: float(v),
129115
SqlType.DOUBLE: lambda v: float(v),
130116
SqlType.DECIMAL: lambda v, p=None, s=None: (
131117
decimal.Decimal(v).quantize(
@@ -134,31 +120,21 @@ class SqlTypeConverter:
134120
if p is not None and s is not None
135121
else decimal.Decimal(v)
136122
),
137-
SqlType.NUMERIC: lambda v, p=None, s=None: (
138-
decimal.Decimal(v).quantize(
139-
decimal.Decimal(f'0.{"0" * s}'), context=decimal.Context(prec=p)
140-
)
141-
if p is not None and s is not None
142-
else decimal.Decimal(v)
143-
),
144-
# Boolean types
123+
# Boolean type
145124
SqlType.BOOLEAN: lambda v: v.lower() in ("true", "t", "1", "yes", "y"),
146-
SqlType.BIT: lambda v: v.lower() in ("true", "t", "1", "yes", "y"),
147125
# Date/Time types
148126
SqlType.DATE: lambda v: datetime.date.fromisoformat(v),
149-
SqlType.TIME: lambda v: datetime.time.fromisoformat(v),
150127
SqlType.TIMESTAMP: lambda v: parser.parse(v),
151-
SqlType.TIMESTAMP_NTZ: lambda v: parser.parse(v).replace(tzinfo=None),
152-
SqlType.TIMESTAMP_LTZ: lambda v: parser.parse(v).astimezone(tz=None),
153-
SqlType.TIMESTAMP_TZ: lambda v: parser.parse(v),
128+
SqlType.INTERVAL: lambda v: v, # Keep as string for now
154129
# String types - no conversion needed
155130
SqlType.CHAR: lambda v: v,
156-
SqlType.VARCHAR: lambda v: v,
157131
SqlType.STRING: lambda v: v,
158-
SqlType.TEXT: lambda v: v,
159-
# Binary types
132+
# Binary type
160133
SqlType.BINARY: lambda v: bytes.fromhex(v),
161-
SqlType.VARBINARY: lambda v: bytes.fromhex(v),
134+
# Other types
135+
SqlType.NULL: lambda v: None,
136+
# Complex types and user-defined types return as-is
137+
SqlType.USER_DEFINED_TYPE: lambda v: v,
162138
}
163139

164140
@staticmethod
@@ -180,6 +156,7 @@ def convert_value(
180156
Returns:
181157
The converted value in the appropriate Python type
182158
"""
159+
183160
if value is None:
184161
return None
185162

@@ -190,7 +167,7 @@ def convert_value(
190167

191168
converter_func = SqlTypeConverter.TYPE_MAPPING[sql_type]
192169
try:
193-
if sql_type in (SqlType.DECIMAL, SqlType.NUMERIC):
170+
if sql_type == SqlType.DECIMAL:
194171
return converter_func(value, precision, scale)
195172
else:
196173
return converter_func(value)

src/databricks/sql/result_set.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
from databricks.sql.backend.sea.backend import SeaDatabricksClient
88
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
9-
from databricks.sql.conversion import SqlTypeConverter
9+
from databricks.sql.backend.sea.conversion import SqlTypeConverter
1010

1111
try:
1212
import pyarrow

tests/unit/test_type_conversion.py renamed to tests/unit/test_conversion.py

Lines changed: 20 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from datetime import date, datetime, time
77
from decimal import Decimal
88

9-
from databricks.sql.conversion import SqlType, SqlTypeConverter
9+
from databricks.sql.backend.sea.conversion import SqlType, SqlTypeConverter
1010

1111

1212
class TestSqlType(unittest.TestCase):
@@ -15,48 +15,40 @@ class TestSqlType(unittest.TestCase):
1515
def test_is_numeric(self):
1616
"""Test the is_numeric method."""
1717
self.assertTrue(SqlType.is_numeric(SqlType.INT))
18-
self.assertTrue(SqlType.is_numeric(SqlType.TINYINT))
19-
self.assertTrue(SqlType.is_numeric(SqlType.SMALLINT))
20-
self.assertTrue(SqlType.is_numeric(SqlType.BIGINT))
18+
self.assertTrue(SqlType.is_numeric(SqlType.BYTE))
19+
self.assertTrue(SqlType.is_numeric(SqlType.SHORT))
20+
self.assertTrue(SqlType.is_numeric(SqlType.LONG))
2121
self.assertTrue(SqlType.is_numeric(SqlType.FLOAT))
2222
self.assertTrue(SqlType.is_numeric(SqlType.DOUBLE))
2323
self.assertTrue(SqlType.is_numeric(SqlType.DECIMAL))
24-
self.assertTrue(SqlType.is_numeric(SqlType.NUMERIC))
2524
self.assertFalse(SqlType.is_numeric(SqlType.BOOLEAN))
2625
self.assertFalse(SqlType.is_numeric(SqlType.STRING))
2726
self.assertFalse(SqlType.is_numeric(SqlType.DATE))
2827

2928
def test_is_boolean(self):
3029
"""Test the is_boolean method."""
3130
self.assertTrue(SqlType.is_boolean(SqlType.BOOLEAN))
32-
self.assertTrue(SqlType.is_boolean(SqlType.BIT))
3331
self.assertFalse(SqlType.is_boolean(SqlType.INT))
3432
self.assertFalse(SqlType.is_boolean(SqlType.STRING))
3533

3634
def test_is_datetime(self):
3735
"""Test the is_datetime method."""
3836
self.assertTrue(SqlType.is_datetime(SqlType.DATE))
39-
self.assertTrue(SqlType.is_datetime(SqlType.TIME))
4037
self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP))
41-
self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP_NTZ))
42-
self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP_LTZ))
43-
self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP_TZ))
38+
self.assertTrue(SqlType.is_datetime(SqlType.INTERVAL))
4439
self.assertFalse(SqlType.is_datetime(SqlType.INT))
4540
self.assertFalse(SqlType.is_datetime(SqlType.STRING))
4641

4742
def test_is_string(self):
4843
"""Test the is_string method."""
4944
self.assertTrue(SqlType.is_string(SqlType.CHAR))
50-
self.assertTrue(SqlType.is_string(SqlType.VARCHAR))
5145
self.assertTrue(SqlType.is_string(SqlType.STRING))
52-
self.assertTrue(SqlType.is_string(SqlType.TEXT))
5346
self.assertFalse(SqlType.is_string(SqlType.INT))
5447
self.assertFalse(SqlType.is_string(SqlType.DATE))
5548

5649
def test_is_binary(self):
5750
"""Test the is_binary method."""
5851
self.assertTrue(SqlType.is_binary(SqlType.BINARY))
59-
self.assertTrue(SqlType.is_binary(SqlType.VARBINARY))
6052
self.assertFalse(SqlType.is_binary(SqlType.INT))
6153
self.assertFalse(SqlType.is_binary(SqlType.STRING))
6254

@@ -75,9 +67,9 @@ class TestSqlTypeConverter(unittest.TestCase):
7567
def test_numeric_conversions(self):
7668
"""Test numeric type conversions."""
7769
self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.INT), 123)
78-
self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.TINYINT), 123)
79-
self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.SMALLINT), 123)
80-
self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.BIGINT), 123)
70+
self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.BYTE), 123)
71+
self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.SHORT), 123)
72+
self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.LONG), 123)
8173
self.assertEqual(
8274
SqlTypeConverter.convert_value("123.45", SqlType.FLOAT), 123.45
8375
)
@@ -113,9 +105,6 @@ def test_datetime_conversions(self):
113105
SqlTypeConverter.convert_value("2023-01-15", SqlType.DATE),
114106
date(2023, 1, 15),
115107
)
116-
self.assertEqual(
117-
SqlTypeConverter.convert_value("14:30:45", SqlType.TIME), time(14, 30, 45)
118-
)
119108
self.assertEqual(
120109
SqlTypeConverter.convert_value("2023-01-15 14:30:45", SqlType.TIMESTAMP),
121110
datetime(2023, 1, 15, 14, 30, 45),
@@ -124,15 +113,19 @@ def test_datetime_conversions(self):
124113
def test_string_conversions(self):
125114
"""Test string type conversions."""
126115
self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.STRING), "test")
116+
self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.CHAR), "test")
117+
118+
def test_binary_conversions(self):
119+
"""Test binary type conversions."""
120+
hex_str = "68656c6c6f" # "hello" in hex
121+
expected_bytes = b"hello"
122+
127123
self.assertEqual(
128-
SqlTypeConverter.convert_value("test", SqlType.VARCHAR), "test"
124+
SqlTypeConverter.convert_value(hex_str, SqlType.BINARY), expected_bytes
129125
)
130-
self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.CHAR), "test")
131-
self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.TEXT), "test")
132126

133127
def test_error_handling(self):
134128
"""Test error handling in conversions."""
135-
# Test invalid conversions - should return original value
136129
self.assertEqual(SqlTypeConverter.convert_value("abc", SqlType.INT), "abc")
137130
self.assertEqual(SqlTypeConverter.convert_value("abc", SqlType.FLOAT), "abc")
138131
self.assertEqual(SqlTypeConverter.convert_value("abc", SqlType.DECIMAL), "abc")
@@ -155,6 +148,10 @@ def test_complex_type_handling(self):
155148
self.assertEqual(
156149
SqlTypeConverter.convert_value('{"a": 1}', "struct<a:int>"), '{"a": 1}'
157150
)
151+
self.assertEqual(
152+
SqlTypeConverter.convert_value('{"a": 1}', SqlType.USER_DEFINED_TYPE),
153+
'{"a": 1}',
154+
)
158155

159156

160157
if __name__ == "__main__":

0 commit comments

Comments
 (0)