Skip to content

Commit 8ab38b0

Browse files
authored
Merge pull request #355 from DataRecce/feature/drc-510-support-select-api
[Feature] Add the select nodes method
2 parents f81b21c + bfa18f2 commit 8ab38b0

File tree

2 files changed

+158
-1
lines changed

2 files changed

+158
-1
lines changed

recce/adapter/dbt_adapter/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import uuid
55
from contextlib import contextmanager
66
from dataclasses import dataclass, fields
7-
from typing import Callable, Dict, List, Optional, Tuple, Iterator, Any
7+
from typing import Callable, Dict, List, Optional, Tuple, Iterator, Any, Set
88

99
import agate
1010
import dbt.adapters.factory
@@ -134,6 +134,7 @@ class DbtArgs:
134134
profile: Optional[str] = None,
135135
target_path: Optional[str] = None,
136136
project_only_flags: Optional[Dict[str, Any]] = None
137+
which: Optional[str] = 'run'
137138

138139

139140
@dataclass
@@ -172,6 +173,7 @@ def load(cls, artifacts: ArtifactsRoot = None, no_artifacts=False, **kwargs):
172173
profiles_dir=profiles_dir,
173174
profile=profile_name,
174175
project_only_flags={},
176+
which='recce'
175177
)
176178
set_from_args(args, args)
177179

@@ -584,6 +586,16 @@ def create_relation(self, model, base=False):
584586

585587
return self.adapter.Relation.create_from(self.runtime_config, node)
586588

589+
def select_nodes(self, select: str) -> Set[str]:
590+
from dbt.graph import SelectionCriteria, NodeSelector
591+
from dbt.compilation import Compiler
592+
593+
compiler = Compiler(self.runtime_config)
594+
graph = compiler.compile(self.manifest)
595+
selector = NodeSelector(graph, self.manifest)
596+
spec = SelectionCriteria.from_single_spec(select)
597+
return selector.get_selected(spec)
598+
587599
def export_artifacts(self) -> ArtifactsRoot:
588600
artifacts = ArtifactsRoot()
589601
target_path = self.runtime_config.target_path
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import os
2+
import textwrap
3+
import uuid
4+
5+
import pytest
6+
from dbt.contracts.graph.nodes import ModelNode
7+
8+
from recce.adapter.dbt_adapter import DbtAdapter, as_manifest, load_manifest
9+
from recce.core import RecceContext, set_default_context
10+
11+
12+
class DbtTestHelper:
13+
14+
def __init__(self):
15+
schema_prefix = "schema_" + uuid.uuid4().hex
16+
self.base_schema = f"{schema_prefix}_base"
17+
self.curr_schema = f"{schema_prefix}_curr"
18+
19+
current_dir = os.path.dirname(os.path.abspath(__file__))
20+
project_dir = os.path.join(current_dir, '.')
21+
profiles_dir = project_dir
22+
manifest_path = os.path.join(project_dir, 'manifest.json')
23+
24+
dbt_adapter = DbtAdapter.load(
25+
no_artifacts=True,
26+
project_dir=project_dir,
27+
profiles_dir=profiles_dir,
28+
)
29+
30+
context = RecceContext()
31+
context.adapter_type = 'dbt'
32+
context.adapter = dbt_adapter
33+
context.schema_prefix = schema_prefix
34+
self.adapter = dbt_adapter
35+
self.context = context
36+
self.curr_manifest = as_manifest(load_manifest(manifest_path))
37+
self.base_manifest = as_manifest(load_manifest(manifest_path))
38+
self.context = context
39+
40+
self.adapter.execute(f"CREATE schema IF NOT EXISTS {self.base_schema}")
41+
self.adapter.execute(f"CREATE schema IF NOT EXISTS {self.curr_schema}")
42+
self.adapter.set_artifacts(self.base_manifest, self.curr_manifest)
43+
44+
def create_model(self, model_name, base_csv, curr_csv):
45+
package_name = "recce_test"
46+
47+
def _add_model_to_manifest(base):
48+
if base:
49+
schema = self.base_schema
50+
manifest = self.base_manifest
51+
else:
52+
schema = self.curr_schema
53+
manifest = self.curr_manifest
54+
55+
node = ModelNode.from_dict({
56+
"resource_type": "model",
57+
"name": model_name,
58+
"package_name": package_name,
59+
"path": "",
60+
"original_file_path": "",
61+
"unique_id": f"model.{package_name}.{model_name}",
62+
"fqn": [
63+
package_name,
64+
model_name,
65+
],
66+
"schema": schema,
67+
"alias": model_name,
68+
"checksum": {
69+
"name": "sha256",
70+
"checksum": ""
71+
},
72+
"raw_code": 'dummy',
73+
"config": {
74+
"materialized": "table",
75+
"tags": ["test_tag"],
76+
},
77+
"tags": ["test_tag"],
78+
})
79+
manifest.add_node_nofile(node)
80+
81+
_add_model_to_manifest(base=True)
82+
_add_model_to_manifest(base=False)
83+
84+
import pandas as pd
85+
from io import StringIO
86+
df_base = pd.read_csv(StringIO(textwrap.dedent(base_csv)))
87+
df_curr = pd.read_csv(StringIO(textwrap.dedent(curr_csv)))
88+
dbt_adapter = self.adapter
89+
with dbt_adapter.connection_named('create model'):
90+
dbt_adapter.execute(f"CREATE TABLE {self.base_schema}.{model_name} AS SELECT * FROM df_base")
91+
dbt_adapter.execute(f"CREATE TABLE {self.curr_schema}.{model_name} AS SELECT * FROM df_curr")
92+
self.adapter.set_artifacts(self.base_manifest, self.curr_manifest)
93+
94+
def remove_model(self, model_name):
95+
dbt_adapter = self.adapter
96+
with dbt_adapter.connection_named('cleanup'):
97+
dbt_adapter.execute(f"DROP TABLE IF EXISTS {self.base_schema}.{model_name}")
98+
dbt_adapter.execute(f"DROP TABLE IF EXISTS {self.curr_schema}.{model_name} ")
99+
100+
def cleanup(self):
101+
dbt_adapter = self.adapter
102+
with dbt_adapter.connection_named('cleanup'):
103+
dbt_adapter.execute(f"DROP SCHEMA IF EXISTS {self.base_schema} CASCADE")
104+
dbt_adapter.execute(f"DROP SCHEMA IF EXISTS {self.curr_schema} CASCADE")
105+
106+
107+
@pytest.fixture
108+
def helper():
109+
helper = DbtTestHelper()
110+
context = helper.context
111+
set_default_context(context)
112+
yield helper
113+
helper.cleanup()
114+
115+
116+
def test_select(helper):
117+
csv_data_curr = """
118+
customer_id,name,age
119+
1,Alice,30
120+
2,Bob,25
121+
3,Charlie,35
122+
"""
123+
124+
csv_data_base = """
125+
customer_id,name,age
126+
1,Alice,35
127+
2,Bob,25
128+
3,Charlie,35
129+
"""
130+
131+
helper.create_model("customers_1", csv_data_base, csv_data_curr)
132+
helper.create_model("customers_2", csv_data_base, csv_data_curr)
133+
adapter: DbtAdapter = helper.context.adapter
134+
node_ids = adapter.select_nodes('resource_type:model')
135+
assert len(node_ids) == 2
136+
node_ids = adapter.select_nodes('customers_1')
137+
assert len(node_ids) == 1
138+
node_ids = adapter.select_nodes('tag:test_tag')
139+
assert len(node_ids) == 2
140+
node_ids = adapter.select_nodes('tag:test_tag2')
141+
assert len(node_ids) == 0
142+
node_ids = adapter.select_nodes("config.materialized:incremental")
143+
assert len(node_ids) == 0
144+
node_ids = adapter.select_nodes("config.materialized:table")
145+
assert len(node_ids) == 2

0 commit comments

Comments
 (0)