@@ -26,14 +26,14 @@ function depthwiseconv_im2col!(
26
26
N = channel_multiplier (cdims)
27
27
K = prod (kernel_size (cdims))
28
28
29
- parts = collect ( Iterators. partition (axes (y)[end ], ceil (Int, size (y, 5 ) / ntasks) ))
29
+ parts = Iterators. partition (axes (y)[end ], ceil (Int, size (y, 5 ) / ntasks))
30
30
31
31
dcdims = DenseConvDims (cdims)
32
32
33
- @sync for task_n in eachindex (parts)
33
+ @sync for ( task_n, part) in enumerate (parts)
34
34
Threads. @spawn begin
35
35
col_slice = col_slice = view (col, :, :, task_n) # col_slice is a task-local workspace
36
- for batch_idx in parts[task_n]
36
+ for batch_idx in part
37
37
im2col! (col_slice, view (x, :, :, :, :, batch_idx), dcdims)
38
38
39
39
# We do a separate convolution for each channel in x, as we must
@@ -115,12 +115,12 @@ function ∇depthwiseconv_data_im2col!(
115
115
N = prod (kernel_size (cdims))
116
116
K = channel_multiplier (cdims)
117
117
118
- parts = collect ( Iterators. partition (axes (dx)[end ], ceil (Int, size (dx, 5 ) / ntasks) ))
118
+ parts = Iterators. partition (axes (dx)[end ], ceil (Int, size (dx, 5 ) / ntasks))
119
119
120
- @sync for task_n in eachindex (parts)
120
+ @sync for ( task_n, part) in enumerate (parts)
121
121
Threads. @spawn begin
122
122
col_slice = col_slice = view (col, :, :, task_n) # col_slice is a task-local workspace
123
- for batch_idx in parts[task_n]
123
+ for batch_idx in part
124
124
# We do a separate convolution for each channel in x, as we must
125
125
for cidx in 1 : channels_in (cdims)
126
126
GC. @preserve col_slice w dy begin
0 commit comments