Commit ea23107

Merge pull request #2361 from huggingface/grodino-dataset_trust_remote

Dataset trust remote tweaks

2 parents 9eee47d + 7573096

6 files changed: +40 -21 lines

timm/data/dataset.py
Lines changed: 2 additions & 0 deletions

@@ -103,6 +103,7 @@ def __init__(
             transform=None,
             target_transform=None,
             max_steps=None,
+            **kwargs,
     ):
         assert reader is not None
         if isinstance(reader, str):
@@ -121,6 +122,7 @@ def __init__(
                 input_key=input_key,
                 target_key=target_key,
                 max_steps=max_steps,
+                **kwargs,
             )
         else:
             self.reader = reader
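This small change is what lets new reader options ride along without touching the dataset wrapper's signature. A minimal illustrative sketch of the pass-through pattern, using made-up names rather than the actual timm classes:

# Illustrative sketch (made-up names, not the real timm classes): keyword args
# the wrapper does not recognise are handed straight to the reader factory,
# which is how an option like trust_remote_code can reach the HF readers.

def make_reader(source, root=None, **kwargs):
    # stand-in for a reader factory; just record what it received
    print(f'reader for {source!r}, extra args: {kwargs}')
    return object()


class WrapperDataset:
    def __init__(self, reader, root=None, **kwargs):
        if isinstance(reader, str):
            # extra kwargs (e.g. trust_remote_code=True) pass through untouched
            self.reader = make_reader(reader, root=root, **kwargs)
        else:
            self.reader = reader


ds = WrapperDataset('hfids/some-dataset', trust_remote_code=True)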

timm/data/dataset_factory.py
Lines changed: 24 additions & 19 deletions

@@ -74,34 +74,37 @@ def create_dataset(
         seed: int = 42,
         repeats: int = 0,
         input_img_mode: str = 'RGB',
+        trust_remote_code: bool = False,
         **kwargs,
 ):
     """ Dataset factory method

     In parentheses after each arg are the type of dataset supported for each arg, one of:
-      * folder - default, timm folder (or tar) based ImageDataset
-      * torch - torchvision based datasets
+      * Folder - default, timm folder (or tar) based ImageDataset
+      * Torch - torchvision based datasets
       * HFDS - Hugging Face Datasets
+      * HFIDS - Hugging Face Datasets Iterable (streaming mode, with IterableDataset)
       * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
       * WDS - Webdataset
-      * all - any of the above
+      * All - any of the above

     Args:
-        name: dataset name, empty is okay for folder based datasets
-        root: root folder of dataset (all)
-        split: dataset split (all)
-        search_split: search for split specific child fold from root so one can specify
-            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
-        class_map: specify class -> index mapping via text file or dict (folder)
-        load_bytes: load data, return images as undecoded bytes (folder)
-        download: download dataset if not present and supported (HFDS, TFDS, torch)
-        is_training: create dataset in train mode, this is different from the split.
-            For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS, WDS)
-        batch_size: batch size hint for (TFDS, WDS)
-        seed: seed for iterable datasets (TFDS, WDS)
-        repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
-        input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS)
-        **kwargs: other args to pass to dataset
+        name: Dataset name, empty is okay for folder based datasets
+        root: Root folder of dataset (All)
+        split: Dataset split (All)
+        search_split: Search for split specific child fold from root so one can specify
+            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch)
+        class_map: Specify class -> index mapping via text file or dict (Folder)
+        load_bytes: Load data, return images as undecoded bytes (Folder)
+        download: Download dataset if not present and supported (HFIDS, TFDS, Torch)
+        is_training: Create dataset in train mode, this is different from the split.
+            For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS, WDS, HFIDS)
+        batch_size: Batch size hint for iterable datasets (TFDS, WDS, HFIDS)
+        seed: Seed for iterable datasets (TFDS, WDS, HFIDS)
+        repeats: Dataset repeats per iteration i.e. epoch (TFDS, WDS, HFIDS)
+        input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS, HFIDS)
+        trust_remote_code: Trust remote code in Hugging Face Datasets if True (HFDS, HFIDS)
+        **kwargs: Other args to pass through to underlying Dataset and/or Reader classes

     Returns:
         Dataset object
@@ -162,6 +165,7 @@ def create_dataset(
             split=split,
             class_map=class_map,
             input_img_mode=input_img_mode,
+            trust_remote_code=trust_remote_code,
             **kwargs,
         )
     elif name.startswith('hfids/'):
@@ -177,7 +181,8 @@ def create_dataset(
             repeats=repeats,
             seed=seed,
             input_img_mode=input_img_mode,
-            **kwargs
+            trust_remote_code=trust_remote_code,
+            **kwargs,
         )
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
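For context, a hedged usage sketch of the factory with the new argument; the dataset id and cache path below are placeholders, not values from this PR:

from timm.data import create_dataset

# hypothetical Hugging Face dataset id; only set trust_remote_code=True for
# repositories whose loading code you have reviewed
ds = create_dataset(
    'hfds/some-user/some-dataset',
    root='./data/hf-cache',   # explicit cache dir instead of the hidden HF default
    split='train',
    trust_remote_code=True,   # forwarded to the HFDS reader and on to datasets.load_dataset
)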

timm/data/readers/reader_hfds.py
Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ def __init__(
         self.dataset = datasets.load_dataset(
             name,  # 'name' maps to path arg in hf datasets
             split=split,
-            cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
+            cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path if root set
             trust_remote_code=trust_remote_code
         )
         # leave decode for caller, plus we want easy access to original path names...
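The reader ultimately delegates to datasets.load_dataset; a minimal standalone sketch of that call, assuming a script-based dataset repo (the id and path are placeholders):

import datasets

dataset = datasets.load_dataset(
    'some-user/some-script-based-dataset',  # hypothetical repo that ships a loading script
    split='train',
    cache_dir='./data/hf-cache',            # explicit path, matching the reader's use of root
    trust_remote_code=True,                 # opt in to running the repo's loading script
)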

timm/data/readers/reader_hfids.py
Lines changed: 6 additions & 1 deletion

@@ -44,6 +44,7 @@ def __init__(
             target_img_mode: str = '',
             shuffle_size: Optional[int] = None,
             num_samples: Optional[int] = None,
+            trust_remote_code: bool = False
     ):
         super().__init__()
         self.root = root
@@ -60,7 +61,11 @@ def __init__(
         self.target_key = target_key
         self.target_img_mode = target_img_mode

-        self.builder = datasets.load_dataset_builder(name, cache_dir=root)
+        self.builder = datasets.load_dataset_builder(
+            name,
+            cache_dir=root,
+            trust_remote_code=trust_remote_code,
+        )
         if download:
             self.builder.download_and_prepare()
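The streaming reader goes through datasets.load_dataset_builder instead; a comparable standalone sketch (placeholder id and path), mirroring the if download: branch above:

import datasets

builder = datasets.load_dataset_builder(
    'some-user/some-script-based-dataset',  # hypothetical dataset id
    cache_dir='./data/hf-cache',
    trust_remote_code=True,                 # opt in to running the repo's loading script
)
builder.download_and_prepare()              # as in the reader when download=True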

train.py
Lines changed: 4 additions & 0 deletions

@@ -103,6 +103,8 @@
                    help='Dataset key for input images.')
 group.add_argument('--target-key', default=None, type=str,
                    help='Dataset key for target labels.')
+group.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
+                   help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')

 # Model parameters
 group = parser.add_argument_group('Model parameters')
@@ -653,6 +655,7 @@ def main():
         input_key=args.input_key,
         target_key=args.target_key,
         num_samples=args.train_num_samples,
+        trust_remote_code=args.dataset_trust_remote_code,
     )

     if args.val_split:
@@ -668,6 +671,7 @@ def main():
             input_key=args.input_key,
             target_key=args.target_key,
             num_samples=args.val_num_samples,
+            trust_remote_code=args.dataset_trust_remote_code,
         )

     # setup mixup / cutmix
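The CLI side follows the usual argparse opt-in pattern; a small self-contained sketch (not the full train.py parser) showing that the option defaults to False and only becomes True when explicitly passed:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
                    help="Allow huggingface dataset import to execute code downloaded from the dataset's repo.")

args = parser.parse_args(['--dataset-trust-remote-code'])
print(args.dataset_trust_remote_code)  # True only when the flag is given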

validate.py
Lines changed: 3 additions & 0 deletions

@@ -66,6 +66,8 @@
                     help='Dataset image conversion mode for input images.')
 parser.add_argument('--target-key', default=None, type=str,
                     help='Dataset key for target labels.')
+parser.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
+                    help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')

 parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
@@ -268,6 +270,7 @@ def validate(args):
         input_key=args.input_key,
         input_img_mode=input_img_mode,
         target_key=args.target_key,
+        trust_remote_code=args.dataset_trust_remote_code,
     )

     if args.valid_labels:
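Taken together, a run against a script-based Hugging Face dataset would presumably opt in explicitly, along the lines of: python train.py --data-dir ./data --dataset hfds/<dataset-id> --dataset-trust-remote-code ... (the dataset id and remaining arguments here are placeholders). Without the flag, trust_remote_code stays False and the datasets library will not execute the repository's loading code.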
