
Commit 93ac38d

Case sensitive/insensitive table validation (#3580)
Closes #3568. Adds a case-sensitive flag for metadata comparison, so column name case can be considered or ignored.
1 parent 34b3d86 commit 93ac38d

File tree: 4 files changed (+63 −12 lines)


src/databricks/labs/ucx/recon/base.py

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,6 @@
 import dataclasses
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from dataclasses import dataclass


@@ -82,7 +83,9 @@ def as_dict(self):

 class TableMetadataRetriever(ABC):
     @abstractmethod
-    def get_metadata(self, entity: TableIdentifier) -> TableMetadata:
+    def get_metadata(
+        self, entity: TableIdentifier, *, column_name_transformer: Callable[[str], str] = str
+    ) -> TableMetadata:
         """
         Get metadata for a given table
         """

src/databricks/labs/ucx/recon/metadata_retriever.py

Lines changed: 10 additions & 8 deletions
@@ -1,4 +1,4 @@
-from collections.abc import Iterator
+from collections.abc import Iterator, Callable

 from databricks.labs.lsql.backends import SqlBackend
 from databricks.labs.lsql.core import Row
@@ -10,7 +10,9 @@ class DatabricksTableMetadataRetriever(TableMetadataRetriever):
     def __init__(self, sql_backend: SqlBackend):
         self._sql_backend = sql_backend

-    def get_metadata(self, entity: TableIdentifier) -> TableMetadata:
+    def get_metadata(
+        self, entity: TableIdentifier, *, column_name_transformer: Callable[[str], str] = str
+    ) -> TableMetadata:
         """
         This method retrieves the metadata for a given table. It takes a TableIdentifier object as input,
         which represents the table for which the metadata is to be retrieved.
@@ -24,11 +26,11 @@ def get_metadata(self, entity: TableIdentifier) -> TableMetadata:
         # Partition information are typically prefixed with a # symbol,
         # so any column name starting with # is excluded from the final set of column metadata.
         # The column metadata objects are sorted by column name to ensure a consistent order.
-        columns = {
-            ColumnMetadata(str(row["col_name"]), str(row["data_type"]))
-            for row in query_result
-            if not str(row["col_name"]).startswith("#")
-        }
+        columns = set()
+        for row in query_result:
+            if str(row["col_name"]).startswith("#"):
+                continue
+            columns.add(ColumnMetadata(column_name_transformer(str(row["col_name"])), str(row["data_type"])))
         return TableMetadata(entity, sorted(columns, key=lambda x: x.name))

     @classmethod
@@ -38,7 +40,7 @@ def _build_metadata_query(cls, entity: TableIdentifier) -> str:

         query = f"""
             SELECT
-                LOWER(column_name) AS col_name,
+                column_name AS col_name,
                 full_data_type AS data_type
             FROM
                 {entity.catalog_escaped}.information_schema.columns
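
With LOWER() removed from the information-schema query, the retriever now returns column names exactly as stored and leaves case folding to the caller-supplied transformer. A standalone sketch of the new filter-and-transform loop, using made-up sample rows in place of a real query result:

from collections.abc import Callable

# Made-up sample rows standing in for the SQL query result; the real code wraps
# each kept column in a ColumnMetadata object.
sample_rows = [
    {"col_name": "OrderId", "data_type": "int"},
    {"col_name": "# Partition Information", "data_type": ""},
    {"col_name": "Region", "data_type": "string"},
]

def collect_columns(rows, column_name_transformer: Callable[[str], str] = str) -> list[tuple[str, str]]:
    columns = set()
    for row in rows:
        if str(row["col_name"]).startswith("#"):
            continue  # partition-information rows are excluded, as in the retriever
        columns.add((column_name_transformer(str(row["col_name"])), str(row["data_type"])))
    return sorted(columns)

assert collect_columns(sample_rows, str.lower) == [("orderid", "int"), ("region", "string")]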

src/databricks/labs/ucx/recon/schema_comparator.py

Lines changed: 15 additions & 3 deletions
@@ -1,3 +1,5 @@
+from collections.abc import Callable
+
 from .base import (
     SchemaComparator,
     SchemaComparisonEntry,
@@ -9,8 +11,14 @@


 class StandardSchemaComparator(SchemaComparator):
-    def __init__(self, metadata_retriever: TableMetadataRetriever):
+    def __init__(self, metadata_retriever: TableMetadataRetriever, *, case_sensitive: bool = False):
         self._metadata_retriever = metadata_retriever
+        self._case_sensitive = case_sensitive
+
+    def _column_name_transformer(self) -> Callable[[str], str]:
+        if self._case_sensitive:
+            return lambda _: _
+        return str.lower

     def compare_schema(self, source: TableIdentifier, target: TableIdentifier) -> SchemaComparisonResult:
         """
@@ -26,8 +34,12 @@ def compare_schema(self, source: TableIdentifier, target: TableIdentifier) -> SchemaComparisonResult:
         return SchemaComparisonResult(is_matching, comparison_result)

     def _eval_schema_diffs(self, source: TableIdentifier, target: TableIdentifier) -> list[SchemaComparisonEntry]:
-        source_metadata = self._metadata_retriever.get_metadata(source)
-        target_metadata = self._metadata_retriever.get_metadata(target)
+        source_metadata = self._metadata_retriever.get_metadata(
+            source, column_name_transformer=self._column_name_transformer()
+        )
+        target_metadata = self._metadata_retriever.get_metadata(
+            target, column_name_transformer=self._column_name_transformer()
+        )
         # Combine the sets of column names for both the source and target tables
         # to create a set of all unique column names from both tables.
         source_column_names = {column.name for column in source_metadata.columns}
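
From the caller's side, case handling is chosen once at construction time. A rough usage sketch, assuming the import paths follow the file layout above and using lsql's MockBackend as a placeholder for a real SqlBackend:

from databricks.labs.lsql.backends import MockBackend

from databricks.labs.ucx.recon.base import TableIdentifier
from databricks.labs.ucx.recon.metadata_retriever import DatabricksTableMetadataRetriever
from databricks.labs.ucx.recon.schema_comparator import StandardSchemaComparator

sql_backend = MockBackend()  # placeholder; point this at a real warehouse backend in practice
metadata_retriever = DatabricksTableMetadataRetriever(sql_backend)

# Default behaviour (case_sensitive=False): "OrderId" and "orderid" count as the same column.
lenient = StandardSchemaComparator(metadata_retriever)

# Opt in to exact-case matching with the new keyword-only flag.
strict = StandardSchemaComparator(metadata_retriever, case_sensitive=True)

source = TableIdentifier("hive_metastore", "db1", "table1")
target = TableIdentifier("catalog1", "schema1", "table1")
print(strict.compare_schema(source, target).is_matching)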

tests/unit/recon/test_schema_comparator.py

Lines changed: 34 additions & 0 deletions
@@ -1,3 +1,4 @@
+import pytest
 from databricks.labs.lsql.backends import MockBackend

 from databricks.labs.ucx.recon.base import TableIdentifier, SchemaComparisonResult, SchemaComparisonEntry
@@ -100,3 +101,36 @@ def test_schema_comparison_failure(metadata_row_factory):
     schema_comparator = StandardSchemaComparator(metadata_retriever)
     actual_comparison_result = schema_comparator.compare_schema(source, target)
     assert actual_comparison_result == expected_comparison_result
+
+
+@pytest.mark.parametrize(
+    "source_column, target_column, case_sensitive, expected_pass",
+    [
+        ("column1", "columnx", True, False),
+        ("column1", "column1", True, True),
+        ("column1", "Column1", True, False),
+        ("column1", "Column1", False, True),
+        ("CoLuMn1", "cOlUmN1", True, False),
+        ("CoLuMn1", "cOlUmN1", False, True),
+    ],
+)
+def test_schema_comparison_case(metadata_row_factory, source_column, target_column, case_sensitive, expected_pass):
+    source = TableIdentifier("hive_metastore", "db1", "table1")
+    target = TableIdentifier("catalog1", "schema1", "table1")
+    sql_backend = MockBackend(
+        rows={
+            "DESCRIBE TABLE": metadata_row_factory[
+                (source_column, "int"),
+                ("column2", "string"),
+            ],
+            f"{target.catalog_escaped}\\.information_schema\\.columns": metadata_row_factory[
+                (target_column, "int"),
+                ("column2", "string"),
+            ],
+        }
+    )
+
+    metadata_retriever = DatabricksTableMetadataRetriever(sql_backend)
+    schema_comparator = StandardSchemaComparator(metadata_retriever, case_sensitive=case_sensitive)
+    actual_comparison_result = schema_comparator.compare_schema(source, target)
+    assert actual_comparison_result.is_matching == expected_pass
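
The metadata_row_factory fixture used above is defined elsewhere in the test suite and is not part of this diff; a plausible stand-in, assuming lsql's MockBackend.rows row-factory helper, could look like:

import pytest
from databricks.labs.lsql.backends import MockBackend

# Hypothetical stand-in for the fixture referenced above; the real one lives in the
# suite's conftest. MockBackend.rows builds Row objects with the given column names.
@pytest.fixture
def metadata_row_factory():
    yield MockBackend.rows("col_name", "data_type")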
