Skip to content

Commit a0b5bcc

Browse files
committed
Fix another low-use path where only numpy arrays are supported
1 parent 99a09eb commit a0b5bcc

File tree

1 file changed: +5 additions, −1 deletion

timm/data/loader.py

Lines changed: 5 additions & 1 deletion
Original line | New line | Diff line change

@@ -33,6 +33,7 @@ def fast_collate(batch):
 33  33           if isinstance(batch[0][0], tuple):
 34  34               # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
 35  35               # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
     36  +            is_np = isinstance(batch[0][0], np.ndarray)
 36  37               inner_tuple_size = len(batch[0][0])
 37  38               flattened_batch_size = batch_size * inner_tuple_size
 38  39               targets = torch.zeros(flattened_batch_size, dtype=torch.int64)

@@ -41,7 +42,10 @@ def fast_collate(batch):
 41  42                   assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
 42  43                   for j in range(inner_tuple_size):
 43  44                       targets[i + j * batch_size] = batch[i][1]
 44      -                    tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
     45  +                    if is_np:
     46  +                        tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
     47  +                    else:
     48  +                        tensor[i + j * batch_size] += batch[i][0][j]
 45  49               return tensor, targets
 46  50           elif isinstance(batch[0][0], np.ndarray):
 47  51               targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)

0 commit comments

Comments
 (0)