Skip to content

Commit 8529a00

Browse files
authored
Merge pull request #8 from dpguthrie/fix/column-name-upper
Fix: Column name for multiple dialects
2 parents f817cc9 + b227a27 commit 8529a00

File tree

7 files changed

+81
-12
lines changed

7 files changed

+81
-12
lines changed

src/config.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,17 @@
88

99
logger = logging.getLogger(__name__)
1010

11+
# SQL dialects accepted for the INPUT_DIALECT setting. Entries are lowercase
# because the configured value is lower-cased before the membership check.
# Kept as a tuple so its repr in the "Invalid dialect" error message is stable.
VALID_DIALECTS = (
    "athena",
    "bigquery",
    "databricks",
    "postgres",
    "redshift",
    "snowflake",
    "spark",
    "trino",
)
21+
1122

1223
@dataclass
1324
class Config:
@@ -57,7 +68,11 @@ def is_valid_field(cls, field_name: str) -> bool:
5768

5869
dialect = os.getenv("INPUT_DIALECT", None)
5970
if dialect is not None:
60-
env_vars["dialect"] = dialect
71+
if dialect.lower() not in VALID_DIALECTS:
72+
raise ValueError(
73+
f"Invalid dialect: {dialect}. Valid dialects are: {VALID_DIALECTS}"
74+
)
75+
env_vars["dialect"] = dialect.lower()
6176

6277
dry_run = os.getenv("INPUT_DRY_RUN", "false").lower() == "true"
6378
env_vars["dry_run"] = dry_run

src/interfaces/lineage.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22

33
if TYPE_CHECKING: # pragma: no cover
44
from src.models.node import Node
5+
from src.config import Config
56

67

78
class LineageServiceProtocol(Protocol):
9+
config: "Config"
10+
811
def get_node_lineage(self, nodes: List["Node"]) -> Set[str]: ...
912

1013
def get_column_lineage(self, node_id: str, column_name: str) -> Set[str]: ...

src/models/column_tracker.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,22 @@ class ColumnTracker:
3030
_tracked_columns: Set[str] = field(default_factory=set)
3131
_impacted_ids: Set[str] = field(default_factory=set)
3232

33+
def _column_name_for_dialect(self, column_name: str) -> str:
    """
    Get the column name for the current dialect.

    The dialect is read from ``self._lineage_service.config.dialect``.
    For Snowflake the name is upper-cased; for every other dialect (or
    when no dialect is set) the name is returned unchanged.

    Args:
        column_name: The original column name

    Returns:
        str: The column name for the current dialect
    """
    dialect = self._lineage_service.config.dialect

    # Compare case-insensitively so a Config built with e.g. "Snowflake"
    # (without going through the lower-casing env loader) still matches.
    # The isinstance guard keeps an unset (None) dialect from raising.
    if isinstance(dialect, str) and dialect.lower() == "snowflake":
        return column_name.upper()

    # TODO: Any other modifications?
    return column_name
48+
3349
def track_node_columns(self, node: "Node") -> Set[str]:
3450
"""
3551
Track columns for a node and identify impacted downstream nodes.
@@ -56,7 +72,7 @@ def track_node_columns(self, node: "Node") -> Set[str]:
5672
)
5773
impacted_ids.update(
5874
self._lineage_service.get_column_lineage(
59-
node.unique_id, column_name
75+
node.unique_id, self._column_name_for_dialect(column_name)
6076
)
6177
)
6278
self._tracked_columns.add(node_column)

src/services/discovery_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ def get_column_lineage(
5656
variables = {
5757
"environmentId": environment_id,
5858
"nodeUniqueId": node_id,
59-
# TODO: This is a hack because Snowflake uppercases everything
60-
"filters": {"columnName": column_name.upper()},
59+
"filters": {"columnName": column_name},
6160
}
6261

6362
lineage = self.config.dbtc_client.metadata.query(

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,10 @@ def mock_dbt_runner() -> DbtRunnerProtocol:
7777

7878

7979
@pytest.fixture
80-
def mock_lineage_service() -> LineageServiceProtocol:
80+
def mock_lineage_service(mock_config: Config) -> LineageServiceProtocol:
8181
"""Create a mock lineage service."""
8282
service = MagicMock(spec=LineageServiceProtocol)
83+
service.config = mock_config
8384

8485
# Setup default return values
8586
service.get_column_lineage.return_value = set()

tests/models/test_column_tracker.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ def test_track_node_columns_new_columns(mock_lineage_service, mock_node):
3535
impacted_ids = tracker.track_node_columns(mock_node)
3636

3737
# Verify the results
38-
expected_tracked_columns = {
39-
"model.my_project.test_model.column1",
40-
"model.my_project.test_model.column2",
41-
}
4238
expected_impacted_ids = {
4339
"model.my_project.downstream_model1",
4440
"model.my_project.downstream_model2",
4541
}
42+
expected_tracked_columns = {
43+
"model.my_project.test_model.column1",
44+
"model.my_project.test_model.column2",
45+
}
4646

4747
assert tracker._tracked_columns == expected_tracked_columns
4848
assert tracker._impacted_ids == expected_impacted_ids
@@ -51,10 +51,10 @@ def test_track_node_columns_new_columns(mock_lineage_service, mock_node):
5151
# Verify lineage service was called correctly
5252
assert mock_lineage_service.get_column_lineage.call_count == 2
5353
mock_lineage_service.get_column_lineage.assert_any_call(
54-
"model.my_project.test_model", "column1"
54+
"model.my_project.test_model", "COLUMN1"
5555
)
5656
mock_lineage_service.get_column_lineage.assert_any_call(
57-
"model.my_project.test_model", "column2"
57+
"model.my_project.test_model", "COLUMN2"
5858
)
5959

6060

@@ -85,7 +85,7 @@ def test_track_node_columns_already_tracked(mock_lineage_service, mock_node):
8585

8686
# Verify lineage service was called only once (for column2)
8787
mock_lineage_service.get_column_lineage.assert_called_once_with(
88-
"model.my_project.test_model", "column2"
88+
"model.my_project.test_model", "COLUMN2"
8989
)
9090

9191

@@ -100,3 +100,18 @@ def test_impacted_ids_property(mock_lineage_service):
100100
assert tracker.impacted_ids == expected_ids
101101
# Ensure we get a copy of the set, not the original
102102
assert tracker.impacted_ids is not tracker._impacted_ids
103+
104+
105+
def test_column_name_for_dialect(mock_lineage_service):
    """Test column name handling for different dialects."""
    tracker = ColumnTracker(mock_lineage_service)

    # Test Snowflake dialect (should uppercase)
    mock_lineage_service.config.dialect = "snowflake"
    assert tracker._column_name_for_dialect("test_column") == "TEST_COLUMN"
    assert tracker._column_name_for_dialect("MixedCase") == "MIXEDCASE"

    # Test other dialect (should return unchanged)
    mock_lineage_service.config.dialect = "bigquery"
    assert tracker._column_name_for_dialect("test_column") == "test_column"
    assert tracker._column_name_for_dialect("MixedCase") == "MixedCase"

tests/test_config.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,23 @@ def test_set_fields_from_dbtc_client_invalid_response(mock_config):
131131
mock_config._set_fields_from_dbtc_client()
132132

133133
assert "An error occurred retrieving your job's data" in str(exc_info.value)
134+
135+
136+
def test_config_invalid_dialect():
    """Test Config creation with an invalid dialect."""
    env_vars = {
        "INPUT_DBT_CLOUD_HOST": "cloud.getdbt.com",
        "INPUT_DBT_CLOUD_SERVICE_TOKEN": "test_token",
        "INPUT_DBT_CLOUD_TOKEN_NAME": "cloud-cli-6d65",
        "INPUT_DBT_CLOUD_TOKEN_VALUE": "test_token_value",
        "INPUT_DBT_CLOUD_ACCOUNT_ID": "43786",
        "INPUT_DBT_CLOUD_JOB_ID": "567183",
        "INPUT_DIALECT": "invalid_dialect",
    }

    # clear=True so only the variables above are visible to Config.from_env
    with patch.dict("os.environ", env_vars, clear=True):
        with pytest.raises(ValueError) as exc_info:
            Config.from_env()

    # Checked after the raises block: statements placed after the raising
    # call inside it would never execute.
    assert "Invalid dialect: invalid_dialect" in str(exc_info.value)
    assert "Valid dialects are:" in str(exc_info.value)

0 commit comments

Comments
 (0)