"""
Column name mappings between different backend protocols.

This module provides mappings between column names returned by different backends
to ensure a consistent interface for metadata operations.
"""

from enum import Enum


class MetadataOp(Enum):
    """Metadata operations whose result columns may need renaming."""

    CATALOGS = "catalogs"
    SCHEMAS = "schemas"
    TABLES = "tables"
    COLUMNS = "columns"


# Mappings from backend column names to standard (JDBC-style) column names.
# ref: CATALOG_COLUMNS in JDBC: https://github.com/databricks/databricks-jdbc/blob/e3d0d8dad683146a3afc3d501ddf0864ba086309/src/main/java/com/databricks/jdbc/common/MetadataResultConstants.java#L219
CATALOG_COLUMNS = {
    "catalog": "TABLE_CAT",
}

# ref: SCHEMA_COLUMNS in JDBC: https://github.com/databricks/databricks-jdbc/blob/e3d0d8dad683146a3afc3d501ddf0864ba086309/src/main/java/com/databricks/jdbc/common/MetadataResultConstants.java#L221
SCHEMA_COLUMNS = {
    "databaseName": "TABLE_SCHEM",
    "catalogName": "TABLE_CATALOG",
}

# ref: TABLE_COLUMNS in JDBC: https://github.com/databricks/databricks-jdbc/blob/e3d0d8dad683146a3afc3d501ddf0864ba086309/src/main/java/com/databricks/jdbc/common/MetadataResultConstants.java#L224
TABLE_COLUMNS = {
    "catalogName": "TABLE_CAT",
    "namespace": "TABLE_SCHEM",
    "tableName": "TABLE_NAME",
    "tableType": "TABLE_TYPE",
    "remarks": "REMARKS",
    "TYPE_CATALOG_COLUMN": "TYPE_CAT",
    "TYPE_SCHEMA_COLUMN": "TYPE_SCHEM",
    "TYPE_NAME": "TYPE_NAME",
    "SELF_REFERENCING_COLUMN_NAME": "SELF_REFERENCING_COL_NAME",
    "REF_GENERATION_COLUMN": "REF_GENERATION",
}

# ref: COLUMN_COLUMNS in JDBC: https://github.com/databricks/databricks-jdbc/blob/e3d0d8dad683146a3afc3d501ddf0864ba086309/src/main/java/com/databricks/jdbc/common/MetadataResultConstants.java#L192
# TYPE_NAME is not included because it is a duplicate target for columnType, and COLUMN_DEF is known to be returned by Thrift.
# TODO: check if TYPE_NAME is to be returned / also used by Thrift.
COLUMN_COLUMNS = {
    "catalogName": "TABLE_CAT",
    "namespace": "TABLE_SCHEM",
    "tableName": "TABLE_NAME",
    "col_name": "COLUMN_NAME",
    "dataType": "DATA_TYPE",
    "columnSize": "COLUMN_SIZE",
    "bufferLength": "BUFFER_LENGTH",
    "decimalDigits": "DECIMAL_DIGITS",
    "radix": "NUM_PREC_RADIX",
    "Nullable": "NULLABLE",
    "remarks": "REMARKS",
    "columnType": "COLUMN_DEF",
    "SQLDataType": "SQL_DATA_TYPE",
    "SQLDatetimeSub": "SQL_DATETIME_SUB",
    "CharOctetLength": "CHAR_OCTET_LENGTH",
    "ordinalPosition": "ORDINAL_POSITION",
    "isNullable": "IS_NULLABLE",
    "ScopeCatalog": "SCOPE_CATALOG",
    "ScopeSchema": "SCOPE_SCHEMA",
    "ScopeTable": "SCOPE_TABLE",
    "SourceDataType": "SOURCE_DATA_TYPE",
    "isAutoIncrement": "IS_AUTOINCREMENT",
    "isGenerated": "IS_GENERATEDCOLUMN",
}

# Dispatch table: one column-name mapping per metadata operation.
# Replaces a manual if/elif chain; dict.get keeps the "unknown operation
# is a no-op" behaviour (any non-member key simply returns None).
_OPERATION_MAPPINGS = {
    MetadataOp.CATALOGS: CATALOG_COLUMNS,
    MetadataOp.SCHEMAS: SCHEMA_COLUMNS,
    MetadataOp.TABLES: TABLE_COLUMNS,
    MetadataOp.COLUMNS: COLUMN_COLUMNS,
}


def normalise_metadata_result(result_set, operation: MetadataOp):
    """
    Normalise column names in a result set based on the operation type.

    The result set's ``description`` is rewritten in place: any column whose
    name (element 0 of its DB-API description tuple) appears in the mapping
    for *operation* is renamed to its standard name. All other columns, and
    all non-name fields of every column descriptor, are left untouched.

    Args:
        result_set: The result set object to normalise. Must expose a
            DB-API-style ``description`` sequence of tuples.
        operation: The metadata operation (from MetadataOp enum). An
            operation with no registered mapping leaves the result set
            unchanged.
    """
    mapping = _OPERATION_MAPPINGS.get(operation)
    if mapping is None:
        return

    # Rebuild the description, swapping in the standard name where mapped.
    result_set.description = [
        (mapping[col_desc[0]],) + col_desc[1:] if col_desc[0] in mapping else col_desc
        for col_desc in result_set.description
    ]
"""
Tests for the column mapping module.
"""

import pytest
from unittest.mock import MagicMock
from enum import Enum

from databricks.sql.backend.column_mapping import (
    normalise_metadata_result,
    MetadataOp,
)


def _make_result(column_names):
    """Build a mock result set whose description lists the given column names."""
    result = MagicMock()
    result.description = [
        (name, "string", None, None, None, None, True) for name in column_names
    ]
    return result


def _names(result):
    """Extract just the column names from a result set's description."""
    return [col_desc[0] for col_desc in result.description]


class TestColumnMapping:
    """Tests for the column mapping module."""

    def test_normalize_metadata_result_catalogs(self):
        """Catalog column names are renamed to their standard forms."""
        result = _make_result(["catalog", "other_column"])

        normalise_metadata_result(result, MetadataOp.CATALOGS)

        assert _names(result) == ["TABLE_CAT", "other_column"]

    def test_normalize_metadata_result_schemas(self):
        """Schema column names are renamed to their standard forms."""
        result = _make_result(["databaseName", "catalogName", "other_column"])

        normalise_metadata_result(result, MetadataOp.SCHEMAS)

        assert _names(result) == ["TABLE_SCHEM", "TABLE_CATALOG", "other_column"]

    def test_normalize_metadata_result_tables(self):
        """Table column names are renamed to their standard forms."""
        result = _make_result(
            [
                "catalogName",
                "namespace",
                "tableName",
                "tableType",
                "remarks",
                "TYPE_CATALOG_COLUMN",
                "TYPE_SCHEMA_COLUMN",
                "TYPE_NAME",
                "SELF_REFERENCING_COLUMN_NAME",
                "REF_GENERATION_COLUMN",
                "other_column",
            ]
        )

        normalise_metadata_result(result, MetadataOp.TABLES)

        assert _names(result) == [
            "TABLE_CAT",
            "TABLE_SCHEM",
            "TABLE_NAME",
            "TABLE_TYPE",
            "REMARKS",
            "TYPE_CAT",
            "TYPE_SCHEM",
            "TYPE_NAME",
            "SELF_REFERENCING_COL_NAME",
            "REF_GENERATION",
            "other_column",
        ]

    def test_normalize_metadata_result_columns(self):
        """Column metadata column names are renamed to their standard forms."""
        result = _make_result(
            [
                "catalogName",
                "namespace",
                "tableName",
                "col_name",
                "dataType",
                "columnSize",
                "bufferLength",
                "decimalDigits",
                "radix",
                "Nullable",
                "remarks",
                "columnType",
                "SQLDataType",
                "SQLDatetimeSub",
                "CharOctetLength",
                "ordinalPosition",
                "isNullable",
                "ScopeCatalog",
                "ScopeSchema",
                "ScopeTable",
                "SourceDataType",
                "isAutoIncrement",
                "isGenerated",
                "other_column",
            ]
        )

        normalise_metadata_result(result, MetadataOp.COLUMNS)

        assert _names(result) == [
            "TABLE_CAT",
            "TABLE_SCHEM",
            "TABLE_NAME",
            "COLUMN_NAME",
            "DATA_TYPE",
            "COLUMN_SIZE",
            "BUFFER_LENGTH",
            "DECIMAL_DIGITS",
            "NUM_PREC_RADIX",
            "NULLABLE",
            "REMARKS",
            "COLUMN_DEF",
            "SQL_DATA_TYPE",
            "SQL_DATETIME_SUB",
            "CHAR_OCTET_LENGTH",
            "ORDINAL_POSITION",
            "IS_NULLABLE",
            "SCOPE_CATALOG",
            "SCOPE_SCHEMA",
            "SCOPE_TABLE",
            "SOURCE_DATA_TYPE",
            "IS_AUTOINCREMENT",
            "IS_GENERATEDCOLUMN",
            "other_column",
        ]

    def test_normalize_metadata_result_unknown_operation(self):
        """An operation with no registered mapping leaves the result untouched."""
        result = _make_result(["column1", "column2"])
        original_description = result.description.copy()

        # A separate enum, deliberately not a MetadataOp member.
        class _UnknownOp(Enum):
            UNKNOWN = "unknown"

        normalise_metadata_result(result, _UnknownOp.UNKNOWN)

        assert result.description == original_description

    def test_normalize_metadata_result_preserves_other_fields(self):
        """Renaming touches only the name; every other descriptor field survives."""
        result = MagicMock()
        result.description = [
            (
                "catalog",
                "string",
                "display_size",
                "internal_size",
                "precision",
                "scale",
                True,
            ),
        ]

        normalise_metadata_result(result, MetadataOp.CATALOGS)

        assert result.description[0] == (
            "TABLE_CAT",
            "string",
            "display_size",
            "internal_size",
            "precision",
            "scale",
            True,
        )
"execute_command", return_value=mock_result_set ) as mock_execute: @@ -707,6 +721,10 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): enforce_embedded_schema_correctness=False, ) + # Verify that column normalization was applied + assert result.description[0][0] == "TABLE_SCHEM" + assert result.description[1][0] == "TABLE_CATALOG" + # Case 2: With catalog and schema names result = sea_client.get_schemas( session_id=sea_session_id, @@ -746,6 +764,14 @@ def test_get_tables(self, sea_client, sea_session_id, mock_cursor): from databricks.sql.backend.sea.result_set import SeaResultSet mock_result_set = Mock(spec=SeaResultSet) + # Add description attribute to the mock result set + mock_result_set.description = [ + ("catalogName", "string", None, None, None, None, True), + ("namespace", "string", None, None, None, None, True), + ("tableName", "string", None, None, None, None, True), + ("tableType", "string", None, None, None, None, True), + ("remarks", "string", None, None, None, None, True), + ] with patch.object( sea_client, "execute_command", return_value=mock_result_set @@ -778,6 +804,13 @@ def test_get_tables(self, sea_client, sea_session_id, mock_cursor): ) mock_filter.assert_called_with(mock_result_set, None) + # Verify that column normalization was applied + assert result.description[0][0] == "TABLE_CAT" + assert result.description[1][0] == "TABLE_SCHEM" + assert result.description[2][0] == "TABLE_NAME" + assert result.description[3][0] == "TABLE_TYPE" + assert result.description[4][0] == "REMARKS" + # Case 2: With all parameters table_types = ["TABLE", "VIEW"] result = sea_client.get_tables( @@ -831,6 +864,19 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor): """Test the get_columns method with various parameter combinations.""" # Mock the execute_command method mock_result_set = Mock() + # Add description attribute to the mock result set + mock_result_set.description = [ + ("catalogName", "string", None, None, None, 
None, True), + ("namespace", "string", None, None, None, None, True), + ("tableName", "string", None, None, None, None, True), + ("col_name", "string", None, None, None, None, True), + ("columnType", "string", None, None, None, None, True), + ("dataType", "string", None, None, None, None, True), + ("Nullable", "string", None, None, None, None, True), + ("isNullable", "string", None, None, None, None, True), + ("ordinalPosition", "string", None, None, None, None, True), + ] + with patch.object( sea_client, "execute_command", return_value=mock_result_set ) as mock_execute: @@ -856,6 +902,17 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor): enforce_embedded_schema_correctness=False, ) + # Verify that column normalization was applied + assert result.description[0][0] == "TABLE_CAT" + assert result.description[1][0] == "TABLE_SCHEM" + assert result.description[2][0] == "TABLE_NAME" + assert result.description[3][0] == "COLUMN_NAME" + assert result.description[4][0] == "COLUMN_DEF" + assert result.description[5][0] == "DATA_TYPE" + assert result.description[6][0] == "NULLABLE" + assert result.description[7][0] == "IS_NULLABLE" + assert result.description[8][0] == "ORDINAL_POSITION" + # Case 2: With all parameters result = sea_client.get_columns( session_id=sea_session_id,