diff --git a/timm/data/dataset.py b/timm/data/dataset.py
index 14d484ba9f..4bdd67bd40 100644
--- a/timm/data/dataset.py
+++ b/timm/data/dataset.py
@@ -47,19 +47,26 @@ def __init__(
         self.target_transform = target_transform
         self._consecutive_errors = 0
 
-    def __getitem__(self, index):
-        img, target = self.reader[index]
-
+    def __getitem__(self, index, retry_count=0):
+        max_retries = min(_ERROR_RETRY * 2, len(self.reader))  # Don't retry more than dataset size
         try:
-            img = img.read() if self.load_bytes else Image.open(img)
+            img, target = self.reader[index]
+            try:
+                img = img.read() if self.load_bytes else Image.open(img)
+                self._consecutive_errors = 0
+            except Exception as e:
+                _logger.warning(f'Failed to load sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
+                self._consecutive_errors += 1
+                if retry_count < max_retries and self._consecutive_errors < _ERROR_RETRY:
+                    next_index = (index + 1) % len(self.reader)
+                    return self.__getitem__(next_index, retry_count + 1)
+                else:
+                    raise RuntimeError(
+                        f'Failed to load any valid samples after {retry_count} attempts. '
+                        f'Last error: {str(e)}')
         except Exception as e:
-            _logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
-            self._consecutive_errors += 1
-            if self._consecutive_errors < _ERROR_RETRY:
-                return self.__getitem__((index + 1) % len(self.reader))
-            else:
-                raise e
-        self._consecutive_errors = 0
+            _logger.error(f'Error accessing dataset at index {index}: {str(e)}')
+            raise
 
         if self.input_img_mode and not self.load_bytes:
             img = img.convert(self.input_img_mode)
diff --git a/timm/data/readers/reader_image_tar.py b/timm/data/readers/reader_image_tar.py
index 6051f26dd1..1e5a3035a3 100644
--- a/timm/data/readers/reader_image_tar.py
+++ b/timm/data/readers/reader_image_tar.py
@@ -49,20 +49,53 @@ def __init__(self, root, class_map=''):
         class_to_idx = None
         if class_map:
             class_to_idx = load_class_map(class_map, root)
-        assert os.path.isfile(root)
+        assert os.path.isfile(root), f'Root file {root} not found'
         self.root = root
+
+        # Initialize worker info attributes
+        self._worker_info = None
+        self._worker_id = 0
+        self._num_workers = 1
-        with tarfile.open(root) as tf:  # cannot keep this open across processes, reopen later
+        # Extract tar info without keeping the file open
+        with tarfile.open(root) as tf:
             self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
         self.imgs = self.samples
         self.tarfile = None  # lazy init in __getitem__
+
+    def __del__(self):
+        # Clean up the tarfile when the reader is garbage collected
+        if hasattr(self, 'tarfile') and self.tarfile is not None:
+            try:
+                self.tarfile.close()
+            except Exception as e:
+                import warnings
+                warnings.warn(f'Error closing tarfile {self.root}: {str(e)}')
 
     def __getitem__(self, index):
         if self.tarfile is None:
+            # Only keep one tarfile open per worker process to avoid file descriptor leaks
+            if self._worker_info is None:  # fetch worker info once (hasattr check was a no-op; attrs are set in __init__)
+                import torch.utils.data
+                worker_info = torch.utils.data.get_worker_info()
+                if worker_info is not None:
+                    self._worker_info = worker_info
+                    self._worker_id = worker_info.id
+                    self._num_workers = worker_info.num_workers
+
             self.tarfile = tarfile.open(self.root)
+
         tarinfo, target = self.samples[index]
-        fileobj = self.tarfile.extractfile(tarinfo)
-        return fileobj, target
+        try:
+            import io  # needed for BytesIO below; a module-level import would also work
+            fileobj = self.tarfile.extractfile(tarinfo)
+            if fileobj is None:
+                raise RuntimeError(f'Failed to extract file {tarinfo.name} from tar {self.root}')
+            # Read the file content immediately and close the file object
+            content = fileobj.read()
+            fileobj.close()
+            return io.BytesIO(content), target
+        except Exception as e:
+            raise RuntimeError(f'Error reading {tarinfo.name} from {self.root}: {str(e)}') from e
 
     def __len__(self):
         return len(self.samples)
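
Usage sketch (not part of the diff): how the patched ImageDataset and ReaderImageTar would interact under a multi-worker DataLoader. The tar path, transform, and loader settings below are illustrative assumptions.

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from timm.data.dataset import ImageDataset
from timm.data.readers.reader_image_tar import ReaderImageTar

# Hypothetical tar of class-foldered images; any such archive works.
reader = ReaderImageTar('/data/imagenet_val.tar')
dataset = ImageDataset(
    '/data/imagenet_val.tar',
    reader=reader,
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
)

# Each worker process receives a pickled dataset copy whose reader still has
# tarfile=None, so every worker lazily opens its own handle on first
# __getitem__: one descriptor per worker, no shared (unpicklable) handle.
loader = DataLoader(dataset, batch_size=32, num_workers=4)
for images, targets in loader:
    pass  # corrupt members are retried/skipped by ImageDataset.__getitem__
```

Returning io.BytesIO(content) instead of the live extracted file object keeps each sample self-contained, so decoding can happen after the underlying tar member is closed.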