diff --git a/alluxio.py b/alluxio.py
index 8579d44..a76ae35 100644
--- a/alluxio.py
+++ b/alluxio.py
@@ -7,6 +7,7 @@
 from PIL import Image
 from requests.adapters import HTTPAdapter
 from torch.utils.data import Dataset
+from concurrent.futures import ThreadPoolExecutor
 
 
 class AlluxioDataset(Dataset):
@@ -75,6 +76,7 @@ def __init__(
         self.page_size = humanfriendly.parse_size(alluxio_page_size)
         self._logger = _logger
         self.session = self.create_session(concurrency)
+        self.executor = ThreadPoolExecutor(max_workers=concurrency)
 
     def create_session(self, concurrency):
         session = requests.Session()
@@ -84,9 +86,26 @@ def create_session(self, concurrency):
         session.mount("http://", adapter)
         return session
 
-    def read_whole_file(self, file_path):
+    def read_whole_file(self, file_path, page_number=None):
+        """Read every page of ``file_path`` and return the joined bytes.
+
+        When ``page_number`` is given, pages ``0 .. page_number - 1`` are
+        fetched concurrently on ``self.executor`` and empty responses are
+        dropped. When it is omitted, pages are read one at a time until an
+        empty or short (smaller than ``self.page_size``) page is seen.
+        """
         file_id = self.get_file_id(file_path)
         worker_address = self.get_worker_address(file_id)
+
+        if page_number is not None:
+            def fetch_page(index):
+                return self.read_file(worker_address, file_id, index)
+
+            # Executor.map yields results in submission order, so joining
+            # them reassembles the pages in sequence.
+            pages = self.executor.map(fetch_page, range(page_number))
+            return b"".join(page for page in pages if page)
+
         page_index = 0
 
         def page_generator():
diff --git a/benchmark-large-datasets.py b/benchmark-large-datasets.py
index 8c2de41..7c42e0c 100644
--- a/benchmark-large-datasets.py
+++ b/benchmark-large-datasets.py
@@ -89,7 +89,7 @@ def benchmark_data_loading(self):
                 1,  # Only using one thread
                 self._logger,
             )
-            alluxio_rest.read_whole_file(self.alluxio_ufs_path)
+            alluxio_rest.read_whole_file(self.alluxio_ufs_path, 5094)
         else:
             self._logger.debug("Using alluxio FUSE/local dataset")
             self._logger.info(f"Loading dataset from {self.path}")