File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -33,6 +33,7 @@ def fast_collate(batch):
33
33
if isinstance (batch [0 ][0 ], tuple ):
34
34
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
35
35
# such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
36
+ is_np = isinstance (batch [0 ][0 ], np .ndarray )
36
37
inner_tuple_size = len (batch [0 ][0 ])
37
38
flattened_batch_size = batch_size * inner_tuple_size
38
39
targets = torch .zeros (flattened_batch_size , dtype = torch .int64 )
@@ -41,7 +42,10 @@ def fast_collate(batch):
41
42
assert len (batch [i ][0 ]) == inner_tuple_size # all input tensor tuples must be same length
42
43
for j in range (inner_tuple_size ):
43
44
targets [i + j * batch_size ] = batch [i ][1 ]
44
- tensor [i + j * batch_size ] += torch .from_numpy (batch [i ][0 ][j ])
45
+ if is_np :
46
+ tensor [i + j * batch_size ] += torch .from_numpy (batch [i ][0 ][j ])
47
+ else :
48
+ tensor [i + j * batch_size ] += batch [i ][0 ][j ]
45
49
return tensor , targets
46
50
elif isinstance (batch [0 ][0 ], np .ndarray ):
47
51
targets = torch .tensor ([b [1 ] for b in batch ], dtype = torch .int64 )
You can’t perform that action at this time.
0 commit comments