Skip to content

Fix infinite recursion risk in ImageDataset error handling #2508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions timm/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,26 @@ def __init__(
self.target_transform = target_transform
self._consecutive_errors = 0

def __getitem__(self, index):
img, target = self.reader[index]

def __getitem__(self, index, retry_count=0):
max_retries = min(_ERROR_RETRY * 2, len(self.reader)) # Don't retry more than dataset size
try:
img = img.read() if self.load_bytes else Image.open(img)
img, target = self.reader[index]
try:
img = img.read() if self.load_bytes else Image.open(img)
self._consecutive_errors = 0
except Exception as e:
_logger.warning(f'Failed to load sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
self._consecutive_errors += 1
if retry_count < max_retries and self._consecutive_errors < _ERROR_RETRY:
next_index = (index + 1) % len(self.reader)
return self.__getitem__(next_index, retry_count + 1)
else:
raise RuntimeError(
f'Failed to load any valid samples after {retry_count} attempts. '
f'Last error: {str(e)}')
except Exception as e:
_logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
self._consecutive_errors += 1
if self._consecutive_errors < _ERROR_RETRY:
return self.__getitem__((index + 1) % len(self.reader))
else:
raise e
self._consecutive_errors = 0
_logger.error(f'Error accessing dataset at index {index}: {str(e)}')
raise

if self.input_img_mode and not self.load_bytes:
img = img.convert(self.input_img_mode)
Expand Down
41 changes: 37 additions & 4 deletions timm/data/readers/reader_image_tar.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,53 @@ def __init__(self, root, class_map=''):
class_to_idx = None
if class_map:
class_to_idx = load_class_map(class_map, root)
assert os.path.isfile(root)
assert os.path.isfile(root), f'Root file {root} not found'
self.root = root

# Initialize worker info attributes
self._worker_info = None
self._worker_id = 0
self._num_workers = 1

with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
# Extract tar info without keeping the file open
with tarfile.open(root) as tf:
self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
self.imgs = self.samples
self.tarfile = None # lazy init in __getitem__

def __del__(self):
    """Release the lazily-opened tar handle when the reader is garbage collected.

    ``__getitem__`` opens one tarfile handle per process; it must be closed
    here to avoid leaking a file descriptor. Close errors are reported as
    warnings rather than raised, because exceptions inside ``__del__`` are
    swallowed by the interpreter and may fire during shutdown.
    """
    # getattr guards against partially-constructed instances
    # (e.g. __init__ raised before assigning self.tarfile).
    tf = getattr(self, 'tarfile', None)
    if tf is not None:
        try:
            tf.close()
        except Exception as e:
            # Local import: module-level names may already be torn down
            # when __del__ runs at interpreter shutdown.
            import warnings
            warnings.warn(f'Error closing tarfile {self.root}: {str(e)}')
        finally:
            # Drop the reference so a repeated __del__/close is a no-op
            # instead of operating on an already-closed handle.
            self.tarfile = None

def __getitem__(self, index):
    """Return ``(file_like, target)`` for the sample at ``index``.

    The tar archive is opened lazily, once per process, because an open
    tarfile handle cannot be shared across worker processes. The member's
    bytes are read eagerly and returned as an ``io.BytesIO`` so the result
    does not depend on the shared tar handle's seek position.

    Raises:
        RuntimeError: if the member cannot be extracted or read.
    """
    if self.tarfile is None:
        # Record DataLoader worker identity the first time this process
        # opens the archive. __init__ initializes _worker_info to None,
        # so test the value — hasattr(self, '_worker_info') would always
        # be true and silently skip this branch.
        if self._worker_info is None:
            import torch.utils.data
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                self._worker_info = worker_info
                self._worker_id = worker_info.id
                self._num_workers = worker_info.num_workers
        self.tarfile = tarfile.open(self.root)

    tarinfo, target = self.samples[index]
    try:
        fileobj = self.tarfile.extractfile(tarinfo)
        # extractfile returns None for non-regular members (dirs, links);
        # handle that outside the except so it is not re-wrapped with the
        # wrong "Error reading" message.
        content = fileobj.read() if fileobj is not None else None
        if fileobj is not None:
            fileobj.close()
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f'Error reading {tarinfo.name} from {self.root}: {str(e)}') from e
    if content is None:
        raise RuntimeError(f'Failed to extract file {tarinfo.name} from tar {self.root}')
    return io.BytesIO(content), target

def __len__(self):
    """Return the number of samples indexed from the tar archive."""
    sample_count = len(self.samples)
    return sample_count
Expand Down