Commit a5e551b

Merge pull request #2466 from huggingface/naflex
Initial NaFlex ViT model and training support
2 parents a22366e + a0b5bcc commit a5e551b

20 files changed: +4876, -265 lines

tests/test_models.py

Lines changed: 3 additions & 3 deletions
@@ -56,12 +56,12 @@
     'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
     'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
     'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
-    'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet',
+    'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet', 'naflexvit'
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
 NON_STD_FILTERS = [
-    'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
+    'vit_*', 'naflexvit*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 'swiftformer_*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
     'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',

@@ -81,7 +81,7 @@
 EXCLUDE_FILTERS = ['*enormous*']
 NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*']

-EXCLUDE_JIT_FILTERS = ['hiera_*']
+EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*']

 TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
 TARGET_BWD_SIZE = 128
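
The new entries are shell-style wildcard patterns, so any NaFlex ViT variant is routed to the non-standard test path and skipped by the JIT tests. An illustrative sketch of that matching, not part of the commit (timm resolves such filters with fnmatch-style logic; the model name below is hypothetical):

from fnmatch import fnmatch

NON_STD_FILTERS = ['vit_*', 'naflexvit*', 'tnt_*']   # abbreviated
EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*']

def matches_any(model_name, patterns):
    # True if the model name matches any wildcard pattern in the list.
    return any(fnmatch(model_name, p) for p in patterns)

assert matches_any('naflexvit_base_patch16', NON_STD_FILTERS)      # hypothetical model name
assert matches_any('naflexvit_base_patch16', EXCLUDE_JIT_FILTERS)  # excluded from JIT tests
assert not matches_any('resnet50', EXCLUDE_JIT_FILTERS)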

timm/data/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -8,6 +8,18 @@
 from .imagenet_info import ImageNetInfo, infer_imagenet_subset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
+from .naflex_dataset import NaFlexMapDatasetWrapper, calculate_naflex_batch_size
+from .naflex_loader import create_naflex_loader
+from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size
+from .naflex_transforms import (
+    ResizeToSequence,
+    CenterCropToSequence,
+    RandomCropToSequence,
+    RandomResizedCropToSequence,
+    ResizeKeepRatioToSequence,
+    Patchify,
+    patchify_image,
+)
 from .readers import create_reader
 from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
 from .real_labels import RealLabelsImagenet
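
With these re-exports the NaFlex data utilities become importable directly from timm.data. A minimal import check, assuming this commit is installed; constructor signatures are not shown in this diff, so nothing is instantiated here:

from timm.data import (
    NaFlexMapDatasetWrapper,
    calculate_naflex_batch_size,
    create_naflex_loader,
    NaFlexMixup,
    Patchify,
    patchify_image,
)

# Each re-export resolves back to one of the new naflex_* modules added in this PR.
for obj in (NaFlexMapDatasetWrapper, create_naflex_loader, NaFlexMixup, patchify_image):
    print(f'{obj.__module__}.{obj.__name__}')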

timm/data/loader.py

Lines changed: 5 additions & 1 deletion
@@ -33,6 +33,7 @@ def fast_collate(batch):
     if isinstance(batch[0][0], tuple):
         # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
         # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
+        is_np = isinstance(batch[0][0], np.ndarray)
         inner_tuple_size = len(batch[0][0])
         flattened_batch_size = batch_size * inner_tuple_size
         targets = torch.zeros(flattened_batch_size, dtype=torch.int64)

@@ -41,7 +42,10 @@ def fast_collate(batch):
             assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
             for j in range(inner_tuple_size):
                 targets[i + j * batch_size] = batch[i][1]
-                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
+                if is_np:
+                    tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
+                else:
+                    tensor[i + j * batch_size] += batch[i][0][j]
         return tensor, targets
     elif isinstance(batch[0][0], np.ndarray):
         targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
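
For context on the tuple branch being touched here: fast_collate flattens tuple samples so that element n of every sample's tuple lands in the n-th chunk of torch.split(tensor, batch_size). The new else path lets those tuple elements be torch tensors rather than numpy arrays. An illustrative sketch, not from the commit:

import torch
from timm.data.loader import fast_collate

batch_size, num_views = 4, 2  # e.g. two augmented views per sample
batch = [
    (tuple(torch.full((3, 8, 8), i * num_views + j, dtype=torch.uint8) for j in range(num_views)), i)
    for i in range(batch_size)
]

tensor, targets = fast_collate(batch)      # tensor shape: (batch_size * num_views, 3, 8, 8)
views = torch.split(tensor, batch_size)    # views[j][i] is view j of sample i
assert views[1][0, 0, 0, 0].item() == 1    # sample 0, view 1 was filled with value 1
assert targets.tolist() == [0, 1, 2, 3, 0, 1, 2, 3]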

timm/data/mixup.py

Lines changed: 52 additions & 19 deletions
@@ -229,29 +229,41 @@ def _mix_elem_collate(self, output, batch, half=False):
         num_elem = batch_size // 2 if half else batch_size
         assert len(output) == num_elem
         lam_batch, use_cutmix = self._params_per_elem(num_elem)
+        is_np = isinstance(batch[0][0], np.ndarray)
+
         for i in range(num_elem):
             j = batch_size - i - 1
             lam = lam_batch[i]
             mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix[i]:
                     if not half:
-                        mixed = mixed.copy()
+                        mixed = mixed.copy() if is_np else mixed.clone()
                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
-                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                        output.shape,
+                        lam,
+                        ratio_minmax=self.cutmix_minmax,
+                        correct_lam=self.correct_lam,
+                    )
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                     lam_batch[i] = lam
                 else:
-                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                    np.rint(mixed, out=mixed)
-            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+                    if is_np:
+                        mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                        np.rint(mixed, out=mixed)
+                    else:
+                        mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam)
+                        torch.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte()
         if half:
             lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
         return torch.tensor(lam_batch).unsqueeze(1)

     def _mix_pair_collate(self, output, batch):
         batch_size = len(batch)
         lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        is_np = isinstance(batch[0][0], np.ndarray)
+
         for i in range(batch_size // 2):
             j = batch_size - i - 1
             lam = lam_batch[i]
if lam < 1.:
262274
if use_cutmix[i]:
263275
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
264-
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
265-
patch_i = mixed_i[:, yl:yh, xl:xh].copy()
276+
output.shape,
277+
lam,
278+
ratio_minmax=self.cutmix_minmax,
279+
correct_lam=self.correct_lam,
280+
)
281+
patch_i = mixed_i[:, yl:yh, xl:xh].copy() if is_np else mixed_i[:, yl:yh, xl:xh].clone()
266282
mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
267283
mixed_j[:, yl:yh, xl:xh] = patch_i
268284
lam_batch[i] = lam
269285
else:
270-
mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
271-
mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
272-
mixed_i = mixed_temp
273-
np.rint(mixed_j, out=mixed_j)
274-
np.rint(mixed_i, out=mixed_i)
275-
output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
276-
output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
286+
if is_np:
287+
mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
288+
mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
289+
mixed_i = mixed_temp
290+
np.rint(mixed_j, out=mixed_j)
291+
np.rint(mixed_i, out=mixed_i)
292+
else:
293+
mixed_temp = mixed_i.float() * lam + mixed_j.float() * (1 - lam)
294+
mixed_j = mixed_j.float() * lam + mixed_i.float() * (1 - lam)
295+
mixed_i = mixed_temp
296+
torch.round(mixed_j, out=mixed_j)
297+
torch.round(mixed_i, out=mixed_i)
298+
output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) if is_np else mixed_i.byte()
299+
output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) if is_np else mixed_j.byte()
277300
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
278301
return torch.tensor(lam_batch).unsqueeze(1)
279302

280303
def _mix_batch_collate(self, output, batch):
281304
batch_size = len(batch)
282305
lam, use_cutmix = self._params_per_batch()
306+
is_np = isinstance(batch[0][0], np.ndarray)
307+
283308
if use_cutmix:
284309
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
285-
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
310+
output.shape,
311+
lam,
312+
ratio_minmax=self.cutmix_minmax,
313+
correct_lam=self.correct_lam,
314+
)
286315
for i in range(batch_size):
287316
j = batch_size - i - 1
288317
mixed = batch[i][0]
289318
if lam != 1.:
290319
if use_cutmix:
291-
mixed = mixed.copy() # don't want to modify the original while iterating
320+
mixed = mixed.copy() if is_np else mixed.clone() # don't want to modify the original while iterating
292321
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
293322
else:
294-
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
295-
np.rint(mixed, out=mixed)
296-
output[i] += torch.from_numpy(mixed.astype(np.uint8))
323+
if is_np:
324+
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
325+
np.rint(mixed, out=mixed)
326+
else:
327+
mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam)
328+
torch.round(mixed, out=mixed)
329+
output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte()
297330
return lam
298331

299332
def __call__(self, batch, _=None):
