Skip to content

Commit 8c63307

Browse files
authored
feat: atlas_search_extractor | 🎉 Initial commit. (#415)
Signed-off-by: mgorsk1 <gorskimariusz13@gmail.com>
1 parent 26a0d0a commit 8c63307

File tree

4 files changed

+408
-1
lines changed

4 files changed

+408
-1
lines changed
Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
# Copyright Contributors to the Amundsen project.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import importlib
5+
import logging
6+
import multiprocessing.pool
7+
from copy import deepcopy
8+
from functools import reduce
9+
from typing import (
10+
Any, Dict, Generator, Iterator, List, Optional, Tuple,
11+
)
12+
13+
from atlasclient.client import Atlas
14+
from pyhocon import ConfigFactory, ConfigTree
15+
16+
from databuilder.extractor.base_extractor import Extractor
17+
18+
LOGGER = logging.getLogger(__name__)
19+
20+
# custom types
21+
type_fields_mapping_spec = Dict[str, List[Tuple[str, str, Any, Any]]]
22+
type_fields_mapping = List[Tuple[str, str, Any, Any]]
23+
24+
# @todo document classes/methods
25+
# @todo write tests
26+
27+
__all__ = ['AtlasSearchDataExtractor']
28+
29+
30+
class AtlasSearchDataExtractorHelpers:
31+
@staticmethod
32+
def _filter_none(input_list: List) -> List:
33+
return list(filter(None, input_list))
34+
35+
@staticmethod
36+
def get_column_names(column_list: List) -> List:
37+
return AtlasSearchDataExtractorHelpers._filter_none(
38+
[c.get('attributes').get('name') for c in column_list if c.get('status').lower() == 'active'])
39+
40+
@staticmethod
41+
def get_column_descriptions(column_list: List) -> List:
42+
return AtlasSearchDataExtractorHelpers._filter_none(
43+
[c.get('attributes').get('description') for c in column_list if c.get('status').lower() == 'active'])
44+
45+
@staticmethod
46+
def get_badges_from_classifications(classifications: List) -> List:
47+
return AtlasSearchDataExtractorHelpers._filter_none(
48+
[c.get('typeName') for c in classifications if c.get('entityStatus', '').lower() == 'active'])
49+
50+
51+
class AtlasSearchDataExtractor(Extractor):
52+
ATLAS_URL_CONFIG_KEY = 'atlas_url'
53+
ATLAS_PORT_CONFIG_KEY = 'atlas_port'
54+
ATLAS_PROTOCOL_CONFIG_KEY = 'atlas_protocol'
55+
ATLAS_VALIDATE_SSL_CONFIG_KEY = 'atlas_validate_ssl'
56+
ATLAS_USERNAME_CONFIG_KEY = 'atlas_auth_user'
57+
ATLAS_PASSWORD_CONFIG_KEY = 'atlas_auth_pw'
58+
ATLAS_SEARCH_CHUNK_SIZE_KEY = 'atlas_search_chunk_size'
59+
ATLAS_DETAILS_CHUNK_SIZE_KEY = 'atlas_details_chunk_size'
60+
ATLAS_TIMEOUT_SECONDS_KEY = 'atlas_timeout_seconds'
61+
ATLAS_MAX_RETRIES_KEY = 'atlas_max_retries'
62+
63+
PROCESS_POOL_SIZE_KEY = 'process_pool_size'
64+
65+
ENTITY_TYPE_KEY = 'entity_type'
66+
67+
DEFAULT_CONFIG = ConfigFactory.from_dict({ATLAS_URL_CONFIG_KEY: "localhost",
68+
ATLAS_PORT_CONFIG_KEY: 21000,
69+
ATLAS_PROTOCOL_CONFIG_KEY: 'http',
70+
ATLAS_VALIDATE_SSL_CONFIG_KEY: False,
71+
ATLAS_SEARCH_CHUNK_SIZE_KEY: 250,
72+
ATLAS_DETAILS_CHUNK_SIZE_KEY: 25,
73+
ATLAS_TIMEOUT_SECONDS_KEY: 120,
74+
ATLAS_MAX_RETRIES_KEY: 2,
75+
PROCESS_POOL_SIZE_KEY: 10})
76+
77+
# @todo fill out below fields for TableESDocument
78+
# tags: List[str],
79+
80+
# es_document field, atlas field path, modification function, default_value
81+
FIELDS_MAPPING_SPEC: type_fields_mapping_spec = {
82+
'Table': [
83+
('database', 'typeName', None, None),
84+
('cluster', 'attributes.qualifiedName', lambda x: x.split('@')[-1], None),
85+
('schema', 'relationshipAttributes.db.displayText', None, None),
86+
('name', 'attributes.name', None, None),
87+
('key', 'attributes.qualifiedName', None, None),
88+
('description', 'attributes.description', None, None),
89+
('last_updated_timestamp', 'updateTime', lambda x: int(x) / 1000, 0),
90+
('total_usage', 'attributes.popularityScore', lambda x: int(x), 0),
91+
('unique_usage', 'attributes.uniqueUsage', lambda x: int(x), 1),
92+
('column_names', 'relationshipAttributes.columns',
93+
lambda x: AtlasSearchDataExtractorHelpers.get_column_names(x), []),
94+
('column_descriptions', 'relationshipAttributes.columns',
95+
lambda x: AtlasSearchDataExtractorHelpers.get_column_descriptions(x), []),
96+
('tags', 'tags', None, []),
97+
('badges', 'classifications',
98+
lambda x: AtlasSearchDataExtractorHelpers.get_badges_from_classifications(x), []),
99+
('display_name', 'attributes.qualifiedName', lambda x: x.split('@')[0], None),
100+
('schema_description', 'attributes.parameters.sourceDescription', None, None),
101+
('programmatic_descriptions', 'attributes.parameters', lambda x: [str(s) for s in list(x.values())], {})
102+
]
103+
}
104+
105+
ENTITY_MODEL_BY_TYPE = {
106+
'Table': 'databuilder.models.table_elasticsearch_document.TableESDocument'
107+
}
108+
109+
REQUIRED_RELATIONSHIPS_BY_TYPE = {
110+
'Table': ['columns']
111+
}
112+
113+
def init(self, conf: ConfigTree) -> None:
114+
self.conf = conf.with_fallback(AtlasSearchDataExtractor.DEFAULT_CONFIG)
115+
self.driver = self._get_driver()
116+
117+
self._extract_iter: Optional[Iterator[Any]] = None
118+
119+
@property
120+
def entity_type(self) -> str:
121+
return self.conf.get(AtlasSearchDataExtractor.ENTITY_TYPE_KEY)
122+
123+
@property
124+
def basic_search_query(self) -> Dict:
125+
query = {
126+
'typeName': self.entity_type,
127+
'excludeDeletedEntities': True,
128+
'query': '*'
129+
}
130+
131+
LOGGER.debug(f'Basic Search Query: {query}')
132+
133+
return query
134+
135+
@property
136+
def dsl_search_query(self) -> Dict:
137+
query = {
138+
'query': f'{self.entity_type} where __state = "ACTIVE"'
139+
}
140+
141+
LOGGER.debug(f'DSL Search Query: {query}')
142+
143+
return query
144+
145+
@property
146+
def model_class(self) -> Any:
147+
model_class = AtlasSearchDataExtractor.ENTITY_MODEL_BY_TYPE.get(self.entity_type)
148+
149+
if model_class:
150+
module_name, class_name = model_class.rsplit(".", 1)
151+
mod = importlib.import_module(module_name)
152+
153+
return getattr(mod, class_name)
154+
155+
@property
156+
def field_mappings(self) -> type_fields_mapping:
157+
return AtlasSearchDataExtractor.FIELDS_MAPPING_SPEC.get(self.entity_type) or []
158+
159+
@property
160+
def search_chunk_size(self) -> int:
161+
return self.conf.get_int(AtlasSearchDataExtractor.ATLAS_SEARCH_CHUNK_SIZE_KEY)
162+
163+
@property
164+
def relationships(self) -> Optional[List[str]]:
165+
return AtlasSearchDataExtractor.REQUIRED_RELATIONSHIPS_BY_TYPE.get(self.entity_type)
166+
167+
def extract(self) -> Any:
168+
if not self._extract_iter:
169+
self._extract_iter = self._get_extract_iter()
170+
171+
try:
172+
return next(self._extract_iter)
173+
except StopIteration:
174+
return None
175+
176+
def get_scope(self) -> str:
177+
return 'extractor.atlas_search_data'
178+
179+
def _get_driver(self) -> Any:
180+
return Atlas(host=self.conf.get_string(AtlasSearchDataExtractor.ATLAS_URL_CONFIG_KEY),
181+
port=self.conf.get_string(AtlasSearchDataExtractor.ATLAS_PORT_CONFIG_KEY),
182+
username=self.conf.get_string(AtlasSearchDataExtractor.ATLAS_USERNAME_CONFIG_KEY),
183+
password=self.conf.get_string(AtlasSearchDataExtractor.ATLAS_PASSWORD_CONFIG_KEY),
184+
protocol=self.conf.get_string(AtlasSearchDataExtractor.ATLAS_PROTOCOL_CONFIG_KEY),
185+
validate_ssl=self.conf.get_bool(AtlasSearchDataExtractor.ATLAS_VALIDATE_SSL_CONFIG_KEY),
186+
timeout=self.conf.get_int(AtlasSearchDataExtractor.ATLAS_TIMEOUT_SECONDS_KEY),
187+
max_retries=self.conf.get_int(AtlasSearchDataExtractor.ATLAS_MAX_RETRIES_KEY))
188+
189+
def _get_approximate_count_of_entities(self) -> int:
190+
try:
191+
# Fetch the table entities based on query terms
192+
count_query = deepcopy(self.basic_search_query)
193+
194+
minimal_parameters = {
195+
'includeClassificationAttributes': False,
196+
'includeSubClassifications': False
197+
}
198+
199+
count_query.update(minimal_parameters)
200+
201+
search_results = self.driver.search_basic.create(data=count_query)
202+
203+
count = search_results._data.get("approximateCount")
204+
except Exception as e:
205+
count = 0
206+
207+
return count
208+
209+
def _get_entity_guids(self, start_offset: int) -> List[str]:
210+
result = []
211+
212+
batch_start = start_offset
213+
batch_end = start_offset + self.search_chunk_size
214+
215+
LOGGER.info(f'Collecting guids for batch: {batch_start}-{batch_end}')
216+
217+
_params = {'offset': str(batch_start), 'limit': str(self.search_chunk_size)}
218+
219+
full_params = deepcopy(self.dsl_search_query)
220+
full_params.update(**_params)
221+
222+
try:
223+
results = self.driver.search_dsl(**full_params)
224+
225+
for hit in results:
226+
for entity in hit.entities:
227+
result.append(entity.guid)
228+
229+
return result
230+
except Exception:
231+
LOGGER.warning(f'Error processing batch: {batch_start}-{batch_end}', exc_info=True)
232+
233+
return []
234+
235+
def _get_entity_details(self, guid_list: List[str]) -> List:
236+
result = []
237+
238+
LOGGER.info(f'Processing guids chunk of size: {len(guid_list)}')
239+
240+
try:
241+
bulk_collection = self.driver.entity_bulk(guid=guid_list)
242+
243+
for collection in bulk_collection:
244+
search_chunk = list(collection.entities_with_relationships(attributes=self.relationships))
245+
246+
result += search_chunk
247+
248+
return result
249+
except Exception:
250+
return []
251+
252+
@staticmethod
253+
def split_list_to_chunks(input_list: List[Any], n: int) -> Generator:
254+
"""Yield successive n-sized chunks from lst."""
255+
for i in range(0, len(input_list), n):
256+
yield input_list[i:i + n]
257+
258+
def _execute_query(self) -> Any:
259+
details_chunk_size = self.conf.get_int(AtlasSearchDataExtractor.ATLAS_DETAILS_CHUNK_SIZE_KEY)
260+
process_pool_size = self.conf.get_int(AtlasSearchDataExtractor.PROCESS_POOL_SIZE_KEY)
261+
262+
guids = []
263+
264+
approximate_count = self._get_approximate_count_of_entities()
265+
266+
LOGGER.info(f'Received count: {approximate_count}')
267+
268+
if approximate_count > 0:
269+
offsets = [i * self.search_chunk_size for i in range(int(approximate_count / self.search_chunk_size) + 1)]
270+
else:
271+
offsets = []
272+
273+
with multiprocessing.pool.ThreadPool(processes=process_pool_size) as pool:
274+
guid_list = pool.map(self._get_entity_guids, offsets, chunksize=1)
275+
276+
for sub_list in guid_list:
277+
guids += sub_list
278+
279+
LOGGER.info(f'Received guids: {len(guids)}')
280+
281+
if guids:
282+
guids_chunks = AtlasSearchDataExtractor.split_list_to_chunks(guids, details_chunk_size)
283+
284+
with multiprocessing.pool.ThreadPool(processes=process_pool_size) as pool:
285+
return_list = pool.map(self._get_entity_details, guids_chunks)
286+
287+
for sub_list in return_list:
288+
for entry in sub_list:
289+
yield entry
290+
291+
def _get_extract_iter(self) -> Iterator[Any]:
292+
for atlas_entity in self._execute_query():
293+
model_dict = dict()
294+
295+
try:
296+
data = atlas_entity.__dict__['_data']
297+
298+
for spec in self.field_mappings:
299+
model_field, atlas_field_path, _transform_spec, default_value = spec
300+
301+
atlas_value = reduce(lambda x, y: x.get(y, dict()), atlas_field_path.split('.'),
302+
data) or default_value
303+
304+
transform_spec = _transform_spec or (lambda x: x)
305+
306+
es_entity_value = transform_spec(atlas_value)
307+
model_dict[model_field] = es_entity_value
308+
309+
yield self.model_class(**model_dict)
310+
except Exception:
311+
LOGGER.warning(f'Error building model object.', exc_info=True)

0 commit comments

Comments
 (0)