
Commit 7a92caa (parent d6ac5bb)

Add basic image folder style dataset to read directly out of tar files, example in validate.py

File tree: 3 files changed, +60 −4 lines

timm/data/__init__.py (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 from .constants import *
 from .config import resolve_data_config
-from .dataset import Dataset
+from .dataset import Dataset, DatasetTar
 from .transforms import *
 from .loader import create_loader
 from .mixup import mixup_target, FastCollateMixup

timm/data/dataset.py (51 additions, 0 deletions)

@@ -7,6 +7,7 @@
 import os
 import re
 import torch
+import tarfile
 from PIL import Image


@@ -89,3 +90,53 @@ def filenames(self, indices=[], basename=False):
             return [os.path.basename(x[0]) for x in self.imgs]
         else:
             return [x[0] for x in self.imgs]
+
+
+def _extract_tar_info(tarfile):
+    class_to_idx = {}
+    files = []
+    labels = []
+    for ti in tarfile.getmembers():
+        if not ti.isfile():
+            continue
+        dirname, basename = os.path.split(ti.path)
+        label = os.path.basename(dirname)
+        class_to_idx[label] = None
+        ext = os.path.splitext(basename)[1]
+        if ext.lower() in IMG_EXTENSIONS:
+            files.append(ti)
+            labels.append(label)
+    for idx, c in enumerate(sorted(class_to_idx.keys(), key=natural_key)):
+        class_to_idx[c] = idx
+    tarinfo_and_targets = zip(files, [class_to_idx[l] for l in labels])
+    tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
+    return tarinfo_and_targets
+
+
+class DatasetTar(data.Dataset):
+
+    def __init__(self, root, load_bytes=False, transform=None):
+
+        assert os.path.isfile(root)
+        self.root = root
+        with tarfile.open(root) as tf:  # cannot keep this open across processes, reopen later
+            self.imgs = _extract_tar_info(tf)
+        self.tarfile = None  # lazy init in __getitem__
+        self.load_bytes = load_bytes
+        self.transform = transform
+
+    def __getitem__(self, index):
+        if self.tarfile is None:
+            self.tarfile = tarfile.open(self.root)
+        tarinfo, target = self.imgs[index]
+        iob = self.tarfile.extractfile(tarinfo)
+        img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
+        if self.transform is not None:
+            img = self.transform(img)
+        if target is None:
+            target = torch.zeros(1).long()
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
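For context, a minimal usage sketch of the new DatasetTar class. The archive path and the transform pipeline are hypothetical placeholders: DatasetTar just expects a tar file laid out like an image folder (one directory per class) and applies whatever transform it is given to the decoded PIL image.

    from torchvision import transforms
    from timm.data import DatasetTar

    # Hypothetical archive with an ImageFolder-style layout packed into a tar,
    # e.g. n01440764/ILSVRC2012_val_00000293.JPEG, n01443537/..., etc.
    dataset = DatasetTar(
        '/data/imagenet-val.tar',  # hypothetical path
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ]))

    img, target = dataset[0]  # transformed image tensor, integer class index
    print(len(dataset), img.shape, target)

Note the design choice visible in the diff: __init__ only scans the archive members and then discards the handle, because (as the inline comment says) an open tarfile cannot be shared across processes. The handle is reopened lazily in __getitem__, so each DataLoader worker opens its own copy of the archive on first access.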
validate.py (8 additions, 3 deletions)

@@ -14,7 +14,7 @@
 from collections import OrderedDict

 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
-from timm.data import Dataset, create_loader, resolve_data_config
+from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging

 torch.backends.cudnn.benchmark = True

@@ -24,7 +24,7 @@
                     help='path to dataset')
 parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
                     help='model architecture (default: dpn92)')
-parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
                     metavar='N', help='mini-batch size (default: 256)')

@@ -91,9 +91,14 @@ def validate(args):

     criterion = nn.CrossEntropyLoss().cuda()

+    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
+        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
+    else:
+        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)
+
     crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
     loader = create_loader(
-        Dataset(args.data, load_bytes=args.tf_preprocessing),
+        dataset,
         input_size=data_config['input_size'],
         batch_size=args.batch_size,
         use_prefetcher=args.prefetcher,
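
With this change validate.py dispatches on the extension of the data argument, so a validation run can point directly at an archive instead of a directory. A hypothetical invocation (the path is illustrative; the flags shown are the ones defined in the argument parser above):

    python validate.py /data/imagenet-val.tar --model dpn92 -b 128 -j 4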
