From d5f5d8ef78b512c972a1812bd7aa98e5cc3c19b8 Mon Sep 17 00:00:00 2001
From: "dan.zafar"
Date: Sat, 9 Nov 2024 15:26:04 -0700
Subject: [PATCH 1/2] proposed dbfs file api with test suite

---
 pyproject.toml                               |   3 +-
 .../labs/ucx/hive_metastore/dbfs_files.py    |  85 ++++++++
 tests/unit/hive_metastore/test_dbfs_files.py | 186 ++++++++++++++++++
 3 files changed, 273 insertions(+), 1 deletion(-)
 create mode 100644 src/databricks/labs/ucx/hive_metastore/dbfs_files.py
 create mode 100644 tests/unit/hive_metastore/test_dbfs_files.py

diff --git a/pyproject.toml b/pyproject.toml
index f7c2bf2214..356bde131d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,7 +49,8 @@ dependencies = ["databricks-sdk~=0.30",
     "databricks-labs-blueprint>=0.9.1,<0.10",
     "PyYAML>=6.0.0,<7.0.0",
     "sqlglot>=25.5.0,<25.30",
-    "astroid>=3.3.1"]
+    "astroid>=3.3.1",
+    "py4j==0.10.9.7"]
 
 [project.optional-dependencies]
 pylsp = [
diff --git a/src/databricks/labs/ucx/hive_metastore/dbfs_files.py b/src/databricks/labs/ucx/hive_metastore/dbfs_files.py
new file mode 100644
index 0000000000..64f6ad0f45
--- /dev/null
+++ b/src/databricks/labs/ucx/hive_metastore/dbfs_files.py
@@ -0,0 +1,85 @@
+import logging
+from dataclasses import dataclass
+from functools import cached_property
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class DbfsFileInfo:
+    path: str
+    name: str
+    is_dir: bool
+    modification_time: int
+
+
+class DbfsFiles:
+
+    def __init__(self, jvm_interface=None):
+        # pylint: disable=import-error,import-outside-toplevel
+        if jvm_interface:
+            self._spark = jvm_interface
+        else:
+            try:
+                from pyspark.sql.session import SparkSession  # type: ignore[import-not-found]
+
+                self._spark = SparkSession.builder.getOrCreate()
+            except Exception as err:
+                logger.error(f"Unable to get SparkSession: {err}")
+                raise err
+
+        # if a test-related jvm_interface is passed in, we don't use py4j's java_import
+        self._java_import = self._noop if jvm_interface else self._default_java_import
+
+    @staticmethod
+    def _noop(*args, **kwargs):
+        pass
+
+    @staticmethod
+    def _default_java_import(jvm, import_path: str):
+        # pylint: disable=import-outside-toplevel
+        from py4j.java_gateway import java_import  # type: ignore[import]
+
+        java_import(jvm, import_path)
+
+    @cached_property
+    def _jvm(self):
+        try:
+            _jvm = self._spark._jvm
+            self._java_import(_jvm, "org.apache.hadoop.fs.FileSystem")
+            self._java_import(_jvm, "org.apache.hadoop.fs.Path")
+            return _jvm
+        except Exception as err:
+            logger.error(f"Cannot create Py4j proxy: {err}")
+            raise err
+
+    @cached_property
+    def _fs(self):
+        try:
+            _jsc = self._spark._jsc  # pylint: disable=protected-access
+            return self._jvm.FileSystem.get(_jsc.hadoopConfiguration())
+        except Exception as err:
+            logger.error(f"Cannot create Py4j file system proxy: {err}")
+            raise err
+
+    class InvalidPathFormatError(ValueError):
+        pass
+
+    def validate_path(self, path: str) -> None:
+        if not path.startswith("dbfs:/"):
+            raise self.InvalidPathFormatError(f"Input path should begin with 'dbfs:/' prefix. Input path: '{path}'")
Input path: '{path}'") + + def list_dir(self, path: str) -> list[DbfsFileInfo]: + self.validate_path(path) + return self._list_dir(path) + + def _list_dir(self, path_str: str) -> list[DbfsFileInfo]: + path = self._jvm.Path(path_str) + statuses = self._fs.listStatus(path) + return [self._file_status_to_dbfs_file_info(status) for status in statuses] + + @staticmethod + def _file_status_to_dbfs_file_info(status): + return DbfsFileInfo( + status.getPath().toString(), status.getPath().getName(), status.isDir(), status.getModificationTime() + ) diff --git a/tests/unit/hive_metastore/test_dbfs_files.py b/tests/unit/hive_metastore/test_dbfs_files.py new file mode 100644 index 0000000000..c80630b755 --- /dev/null +++ b/tests/unit/hive_metastore/test_dbfs_files.py @@ -0,0 +1,186 @@ +from unittest.mock import Mock +import pytest +from databricks.labs.ucx.hive_metastore.dbfs_files import DbfsFiles, DbfsFileInfo + + +@pytest.fixture +def dbfs_files(mocker): + # dbfs_files is dependent on the jvm, so we need to mock that + + mock_spark = mocker.Mock() + mock_jsc = mocker.Mock() + mock_jvm = mocker.Mock() + mock_filesystem = mocker.Mock() + mock_jvm.FileSystem.get.return_value = mock_filesystem + mock_jvm.Path.side_effect = lambda x: x + mock_spark._jsc = mock_jsc # pylint: disable=protected-access + mock_spark._jvm = mock_jvm # pylint: disable=protected-access + + _dbfs_files = DbfsFiles(jvm_interface=mock_spark) + + return _dbfs_files + + +# use trie data structure to mock a backend file system, matching the general behavior of hadoop java library +# provide an elegant API for adding files into the mock filesystem, reducing human error +class TrieNode: + def __init__(self): + self.children = {} + self.is_end_of_path = False + + +class MockFs: + def __init__(self): + self.root = TrieNode() + + def put(self, path: str) -> None: + node = self.root + parts = path.removeprefix("dbfs:/").rstrip("/").split("/") + for part in parts: + if part not in node.children: + node.children[part] = TrieNode() + node = node.children[part] + node.is_end_of_path = True + + @staticmethod + def has_children(node: TrieNode) -> bool: + return bool(node.children) + + def mock_path_component(self, path: str, name: str, node: TrieNode) -> Mock: + if name: + _path = path.rstrip("/") + '/' + name + else: + _path = path.rstrip("/") + name = _path.rsplit('/')[-1] + + mock_status: Mock = Mock() + mock_status.getPath.return_value.toString.return_value = _path + mock_status.getPath.return_value.getName.return_value = name + mock_status.isDir.return_value = self.has_children(node) + mock_status.getModificationTime.return_value = 0 + + return mock_status + + class IllegalArgumentException(ValueError): + pass + + def list_dir(self, path: str) -> list[Mock]: + node: TrieNode = self.root + if path: + parts: list[str] = path.removeprefix("dbfs:/").rstrip("/").split("/") + for part in parts: + if part == '': + continue + if part in node.children: + node = node.children[part] + else: + raise FileNotFoundError(f"'{path}' not found") + + # list_files will return identity if there are no child path components + # note: in the actual api, listing an empty directory just results in an empty list, [], but + # that functionality is not supported in the mock api + if not self.has_children(node): + return [self.mock_path_component(path, '', node)] + + # in typical case, return children + return [self.mock_path_component(path, name, node) for name, node in node.children.items()] + + raise self.IllegalArgumentException("Can not create a Path from an empty 
string") + + +@pytest.fixture +def mock_hadoop_fs(): + return MockFs() + + +def test_mock_hadoop_fs_put(mock_hadoop_fs): + mock_hadoop_fs.put("dbfs:/dir1/dir2/file") + node = mock_hadoop_fs.root + assert "dir1" in node.children + assert "dir2" in node.children["dir1"].children + assert "file" in node.children["dir1"].children["dir2"].children + assert node.children["dir1"].children["dir2"].children["file"].is_end_of_path is True + + +def test_mock_hadoop_fs_listdir(dbfs_files, mock_hadoop_fs): + mock_hadoop_fs.put("dbfs:/test/path_a") + mock_hadoop_fs.put("dbfs:/test/path_a/file_a") + mock_hadoop_fs.put("dbfs:/test/path_a/file_b") + mock_hadoop_fs.put("dbfs:/test/path_b") + + dbfs_files._fs.listStatus.side_effect = mock_hadoop_fs.list_dir # pylint: disable=protected-access + + result = dbfs_files.list_dir("dbfs:/") + assert result == [DbfsFileInfo("dbfs:/test", "test", True, 0)] + + result = dbfs_files.list_dir("dbfs:/test") + assert result == [ + DbfsFileInfo("dbfs:/test/path_a", "path_a", True, 0), + DbfsFileInfo("dbfs:/test/path_b", "path_b", False, 0), + ] + + result = dbfs_files.list_dir("dbfs:/test/path_a") + assert result == [ + DbfsFileInfo("dbfs:/test/path_a/file_a", "file_a", False, 0), + DbfsFileInfo("dbfs:/test/path_a/file_b", "file_b", False, 0), + ] + + # ensure identity is passed back if there are no children + result = dbfs_files.list_dir("dbfs:/test/path_b") + assert result == [DbfsFileInfo("dbfs:/test/path_b", "path_b", False, 0)] + + +def test_mock_hadoop_fs_nonexistent_path(mock_hadoop_fs): + with pytest.raises(FileNotFoundError, match="'dbfs:/nonexistent' not found"): + mock_hadoop_fs.list_dir("dbfs:/nonexistent") + + +def test_mock_hadoop_fs_invalid_path(mock_hadoop_fs): + with pytest.raises(mock_hadoop_fs.IllegalArgumentException, match="Can not create a Path from an empty string"): + mock_hadoop_fs.list_dir("") + + +def test_list_dir(dbfs_files, mock_hadoop_fs): + mock_hadoop_fs.put("dbfs:/test/path_a") + mock_hadoop_fs.put("dbfs:/test/path_a/file_a") + mock_hadoop_fs.put("dbfs:/test/path_b") + + dbfs_files._fs.listStatus.side_effect = mock_hadoop_fs.list_dir # pylint: disable=protected-access + + result = dbfs_files.list_dir("dbfs:/test") + assert result == [ + DbfsFileInfo("dbfs:/test/path_a", "path_a", True, 0), + DbfsFileInfo("dbfs:/test/path_b", "path_b", False, 0), + ] + result = dbfs_files.list_dir("dbfs:/test/path_a") + assert result == [DbfsFileInfo("dbfs:/test/path_a/file_a", "file_a", False, 0)] + bad_path = "dbfs:/test/path_c" + with pytest.raises(FileNotFoundError, match=f"'{bad_path}' not found"): + dbfs_files.list_dir(bad_path) + bad_path = "dbfs:/test/path_c" + with pytest.raises(FileNotFoundError, match=f"'{bad_path}' not found"): + dbfs_files.list_dir(bad_path) + invalid_path = "/dbfs/test/path_a" + with pytest.raises( + DbfsFiles.InvalidPathFormatError, + match=f"Input path should begin with 'dbfs:/' prefix. 
Input path: '{invalid_path}'", + ): + dbfs_files.list_dir(invalid_path) + + +def test_file_status_to_dbfs_file_info(mocker): + # Create a mock status to simulate a file's metadata + mock_status: mocker.Mock = mocker.Mock() + mock_status.getPath.return_value.toString.return_value = "/test/path" + mock_status.getPath.return_value.getName.return_value = "test" + mock_status.isDir.return_value = False + mock_status.getModificationTime.return_value = 1234567890 + + # Convert this mock status to DbfsFileInfo using the method + result = DbfsFiles._file_status_to_dbfs_file_info(mock_status) # pylint: disable=protected-access + + # Assert that the DbfsFileInfo object has the expected values + assert result.path == "/test/path" + assert result.name == "test" + assert not result.is_dir + assert result.modification_time == 1234567890 From 8c80c8e2ad597a20287ab00f52c9630a1cd01588 Mon Sep 17 00:00:00 2001 From: "dan.zafar" Date: Mon, 11 Nov 2024 12:34:21 -0700 Subject: [PATCH 2/2] remove py4j dep --- pyproject.toml | 3 +-- .../labs/ucx/hive_metastore/dbfs_files.py | 26 +++++++------------ tests/unit/hive_metastore/test_dbfs_files.py | 6 ++--- 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 356bde131d..f7c2bf2214 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,8 +49,7 @@ dependencies = ["databricks-sdk~=0.30", "databricks-labs-blueprint>=0.9.1,<0.10", "PyYAML>=6.0.0,<7.0.0", "sqlglot>=25.5.0,<25.30", - "astroid>=3.3.1", - "py4j==0.10.9.7"] + "astroid>=3.3.1"] [project.optional-dependencies] pylsp = [ diff --git a/src/databricks/labs/ucx/hive_metastore/dbfs_files.py b/src/databricks/labs/ucx/hive_metastore/dbfs_files.py index 64f6ad0f45..c4a0382292 100644 --- a/src/databricks/labs/ucx/hive_metastore/dbfs_files.py +++ b/src/databricks/labs/ucx/hive_metastore/dbfs_files.py @@ -28,26 +28,18 @@ def __init__(self, jvm_interface=None): logger.error(f"Unable to get SparkSession: {err}") raise err - # if a test-related jvm_interface is passed in, we don't use py4j's java_import - self._java_import = self._noop if jvm_interface else self._default_java_import - - @staticmethod - def _noop(*args, **kwargs): - pass - - @staticmethod - def _default_java_import(jvm, import_path: str): - # pylint: disable=import-outside-toplevel - from py4j.java_gateway import java_import # type: ignore[import] - - java_import(jvm, import_path) + # if a test-related jvm_interface is passed in, we don't use py4j's modules + if jvm_interface: + self.jvm_filesystem = jvm_interface.jvm.jvm_filesystem + self.jvm_path = jvm_interface.jvm.jvm_path + else: + self.jvm_filesystem = self._jvm.org.apache.hadoop.fs.FileSystem + self.jvm_path = self._jvm.org.apache.hadoop.fs.Path @cached_property def _jvm(self): try: _jvm = self._spark._jvm - self._java_import(_jvm, "org.apache.hadoop.fs.FileSystem") - self._java_import(_jvm, "org.apache.hadoop.fs.Path") return _jvm except Exception as err: logger.error(f"Cannot create Py4j proxy: {err}") @@ -57,7 +49,7 @@ def _jvm(self): def _fs(self): try: _jsc = self._spark._jsc # pylint: disable=protected-access - return self._jvm.FileSystem.get(_jsc.hadoopConfiguration()) + return self.jvm_filesystem.get(_jsc.hadoopConfiguration()) except Exception as err: logger.error(f"Cannot create Py4j file system proxy: {err}") raise err @@ -74,7 +66,7 @@ def list_dir(self, path: str) -> list[DbfsFileInfo]: return self._list_dir(path) def _list_dir(self, path_str: str) -> list[DbfsFileInfo]: - path = self._jvm.Path(path_str) + path = self.jvm_path(path_str) 
         statuses = self._fs.listStatus(path)
         return [self._file_status_to_dbfs_file_info(status) for status in statuses]
 
diff --git a/tests/unit/hive_metastore/test_dbfs_files.py b/tests/unit/hive_metastore/test_dbfs_files.py
index c80630b755..5300e5f900 100644
--- a/tests/unit/hive_metastore/test_dbfs_files.py
+++ b/tests/unit/hive_metastore/test_dbfs_files.py
@@ -11,10 +11,10 @@ def dbfs_files(mocker):
     mock_jsc = mocker.Mock()
     mock_jvm = mocker.Mock()
    mock_filesystem = mocker.Mock()
-    mock_jvm.FileSystem.get.return_value = mock_filesystem
-    mock_jvm.Path.side_effect = lambda x: x
+    mock_jvm.jvm_filesystem.get.return_value = mock_filesystem
+    mock_jvm.jvm_path.side_effect = lambda x: x
     mock_spark._jsc = mock_jsc  # pylint: disable=protected-access
-    mock_spark._jvm = mock_jvm  # pylint: disable=protected-access
+    mock_spark.jvm = mock_jvm
 
     _dbfs_files = DbfsFiles(jvm_interface=mock_spark)
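
Usage sketch (not part of either patch): assuming a Databricks cluster where SparkSession.builder.getOrCreate() returns a live session, the API added above could be driven roughly as follows. The walk() helper and the dbfs:/databricks-datasets path are illustrative assumptions layered on top of the patch's list_dir(), not code proposed by the patch.

    from databricks.labs.ucx.hive_metastore.dbfs_files import DbfsFiles, DbfsFileInfo

    dbfs = DbfsFiles()  # builds the Hadoop FileSystem proxy from the active SparkSession

    def walk(path: str) -> list[DbfsFileInfo]:
        # hypothetical depth-first traversal using the patch's list_dir()
        entries: list[DbfsFileInfo] = []
        for entry in dbfs.list_dir(path):
            entries.append(entry)
            if entry.is_dir:
                entries.extend(walk(entry.path))
        return entries

    for info in walk("dbfs:/databricks-datasets"):
        print(info.path, info.is_dir, info.modification_time)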