
map_chunked implementation #99


Open · wants to merge 3 commits into master
38 changes: 36 additions & 2 deletions discoverx/explorer.py
@@ -1,7 +1,8 @@
import concurrent.futures
import copy
import re
-from typing import Optional, List
+import more_itertools
+from typing import Optional, List, Callable
from discoverx import logging
from discoverx.common import helper
from discoverx.discovery import Discovery
@@ -165,7 +166,7 @@ def scan(
discover.scan(rules=rules, sample_size=sample_size, what_if=what_if)
return discover

-    def map(self, f) -> list[any]:
+    def map(self, f: Callable) -> list[any]:
"""Runs a function for each table in the data explorer

Args:
@@ -197,6 +198,39 @@ def map(self, f) -> list[any]:

return res

def map_chunked(self, f: Callable, tables_per_chunk: int, **kwargs) -> list[any]:
Collaborator
Suggested change
-    def map_chunked(self, f: Callable, tables_per_chunk: int, **kwargs) -> list[any]:
+    def map_chunked(self, f: Callable, tables_per_chunk: int, **kwargs) -> list[Any]:

any is a function, not a type
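
For a quick illustration of the distinction (a minimal sketch, not part of this PR):

from typing import Any

def tail(items: list[Any]) -> list[Any]:  # Any is the typing wildcard type
    return items[1:]

print(any([False, True]))  # any() is the builtin truth-test function -> True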

"""Runs a function for each table in the data explorer

Args:
f (function): The function to run. The function should accept either a list of TableInfo objects as input and return a list of any object as output.

Returns:
list[any]: A list of the results of running the function for each table
"""
res = []
table_list = self._info_fetcher.get_tables_info(
self._catalogs,
self._schemas,
self._tables,
self._having_columns,
self._with_tags,
)
with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_concurrency) as executor:
# Submit tasks to the thread pool
futures = [
executor.submit(f, table_chunk, **kwargs) for table_chunk in more_itertools.chunked(table_list, tables_per_chunk)
]

# Process completed tasks
for future in concurrent.futures.as_completed(futures):
result = future.result()
if result is not None:
res.extend(result)

logger.debug("Finished lakehouse map_chunked task")

return res


class DataExplorerActions:
def __init__(
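For illustration, a hypothetical call site for the new method (collect_names is an invented example function; spark and info_fetcher are assumed to be set up as in the tests below):

# Hypothetical usage sketch for map_chunked.
def collect_names(table_chunk):
    # Receives one chunk: a list of TableInfo objects.
    # Returns a list, which map_chunked flattens into the combined result.
    return [f"{t.schema}.{t.table}" for t in table_chunk]

data_explorer = DataExplorer("*.default.*", spark, info_fetcher)
names = data_explorer.map_chunked(collect_names, tables_per_chunk=10)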
1 change: 1 addition & 0 deletions setup.py
@@ -34,6 +34,7 @@
"delta-spark>=2.2.0",
"pandas<2.0.0", # From 2.0.0 onwards, pandas does not support iteritems() anymore, spark.createDataFrame will fail
"numpy<1.24", # From 1.24 onwards, module 'numpy' has no attribute 'bool'.
"more_itertools",
Collaborator
Create an LPP ticket for this; otherwise, re-implement it as a single function. Don't add a whole library for the sake of one function.
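
If the dependency is dropped, the replacement is small. A minimal sketch of the one behavior map_chunked needs (the name chunked mirrors more_itertools.chunked; this helper does not exist in the codebase):

from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def chunked(iterable: Iterable[T], n: int) -> Iterator[List[T]]:
    # Yield successive lists of up to n items; the last chunk may be shorter.
    it = iter(iterable)
    while chunk := list(islice(it, n)):
        yield chunk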

]

TEST_REQUIREMENTS = [
32 changes: 32 additions & 0 deletions tests/unit/explorer_test.py
@@ -75,6 +75,38 @@ def test_map(spark, info_fetcher):
assert result[0].tags == None


def test_map_chunked_1(spark, info_fetcher):
data_explorer = DataExplorer("*.default.tb_1", spark, info_fetcher)
result = data_explorer.map_chunked(lambda table_info: table_info, 10)
assert len(result) == 1
assert result[0].table == "tb_1"
assert result[0].schema == "default"
assert result[0].catalog == None
assert result[0].tags == None


def test_map_chunked_2(spark, info_fetcher):
data_explorer = DataExplorer("*.default.*", spark, info_fetcher)
result = data_explorer.map_chunked(lambda table_info: table_info, 10)
assert len(result) == 3
for res in result:
assert res.table in ["tb_1", "tb_2", "tb_all_types"]
if res.table == "tb_1":
assert res.schema == "default"
assert res.catalog == None
assert res.tags == None
elif res.table == "tb_2":
assert res.schema == "default"
assert res.catalog == None
assert res.tags == None
else:
assert res.schema == "default"
assert res.catalog == "hive_metastore"
assert res.tags == None
result2 = data_explorer.map_chunked(lambda table_info: table_info, 2)
# as_completed gives no ordering guarantee across chunks, so compare order-insensitively
assert sorted(result2, key=lambda t: t.table) == sorted(result, key=lambda t: t.table)


def test_map_with_tags(spark, info_fetcher):
data_explorer = DataExplorer("*.default.tb_1", spark, info_fetcher).with_tags()
result = data_explorer.map(lambda table_info: table_info)