|
| 1 | + |
| 2 | +import pytest |
| 3 | +from numpy import ndarray |
| 4 | + |
| 5 | +from tests.e2e.test_driver import PySQLPytestTestCase |
| 6 | + |
| 7 | + |
| 8 | +class TestComplexTypes(PySQLPytestTestCase): |
| 9 | + @pytest.fixture(scope="class") |
| 10 | + def table_fixture(self): |
| 11 | + """A pytest fixture that creates a table with a complex type, inserts a record, yields, and then drops the table""" |
| 12 | + |
| 13 | + with self.cursor() as cursor: |
| 14 | + # Create the table |
| 15 | + cursor.execute( |
| 16 | + """ |
| 17 | + CREATE TABLE IF NOT EXISTS pysql_test_complex_types_table ( |
| 18 | + array_col ARRAY<STRING>, |
| 19 | + map_col MAP<STRING, INTEGER>, |
| 20 | + struct_col STRUCT<field1: STRING, field2: INTEGER> |
| 21 | + ) |
| 22 | + """ |
| 23 | + ) |
| 24 | + # Insert a record |
| 25 | + cursor.execute( |
| 26 | + """ |
| 27 | + INSERT INTO pysql_test_complex_types_table |
| 28 | + VALUES ( |
| 29 | + ARRAY('a', 'b', 'c'), |
| 30 | + MAP('a', 1, 'b', 2, 'c', 3), |
| 31 | + NAMED_STRUCT('field1', 'a', 'field2', 1) |
| 32 | + ) |
| 33 | + """ |
| 34 | + ) |
| 35 | + yield |
| 36 | + # Clean up the table after the test |
| 37 | + cursor.execute("DROP TABLE IF EXISTS pysql_test_complex_types_table") |
| 38 | + |
| 39 | + @pytest.mark.parametrize( |
| 40 | + "field,expected_type", |
| 41 | + [("array_col", ndarray), ("map_col", list), ("struct_col", dict)], |
| 42 | + ) |
| 43 | + def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture): |
| 44 | + """Confirms the return types of a complex type field when reading as arrow""" |
| 45 | + |
| 46 | + with self.cursor() as cursor: |
| 47 | + result = cursor.execute( |
| 48 | + "SELECT * FROM pysql_test_complex_types_table LIMIT 1" |
| 49 | + ).fetchone() |
| 50 | + |
| 51 | + assert isinstance(result[field], expected_type) |
| 52 | + |
| 53 | + @pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col")]) |
| 54 | + def test_read_complex_types_as_string(self, field, table_fixture): |
| 55 | + """Confirms the return type of a complex type that is returned as a string""" |
| 56 | + with self.cursor( |
| 57 | + extra_params={"_use_arrow_native_complex_types": False} |
| 58 | + ) as cursor: |
| 59 | + result = cursor.execute( |
| 60 | + "SELECT * FROM pysql_test_complex_types_table LIMIT 1" |
| 61 | + ).fetchone() |
| 62 | + |
| 63 | + assert isinstance(result[field], str) |
0 commit comments