@@ -13,7 +13,8 @@ function depthwiseconv_im2col!(
        y::AbstractArray{T,5}, x::AbstractArray{T,5},
        w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
        col::AbstractArray{T,3} = similar(x, im2col_dims(cdims)),
-       alpha::T = T(1), beta::T = T(0)) where T
+       alpha::T = T(1), beta::T = T(0),
+       ntasks::Int = nthreads()) where T
    check_dims(size(x), size(w), size(y), cdims)

    # This functions exactly the same as conv_im2col!(), except that we shard the
@@ -25,21 +26,26 @@ function depthwiseconv_im2col!(
    N = channel_multiplier(cdims)
    K = prod(kernel_size(cdims))

-   dcdims = DenseConvDims(cdims)
-   @threads for batch_idx in 1:size(x)[end]
-       # col_slice is a thread-local workspace
-       col_slice = view(col, :, :, threadid())
+   parts = Iterators.partition(axes(y)[end], ceil(Int, size(y, 5) / ntasks))

-       im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)
+   dcdims = DenseConvDims(cdims)

-       # We do a separate convolution for each channel in x, as we must
-       for c_in in 1:channels_in(cdims)
-           # Walk each pointer forward as we process each input channel
-           GC.@preserve col_slice w y begin
-               col_ptr = pointer(col_slice, (c_in-1)*M*K+1)
-               w_ptr = pointer(w, (c_in-1)*K*N+1)
-               y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
-               gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
+   @sync for (task_n, part) in enumerate(parts)
+       Threads.@spawn begin
+           col_slice = view(col, :, :, task_n)  # col_slice is a task-local workspace
+           for batch_idx in part
+               im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)
+
+               # We do a separate convolution for each channel in x, as we must
+               for c_in in 1:channels_in(cdims)
+                   # Walk each pointer forward as we process each input channel
+                   GC.@preserve col_slice w y begin
+                       col_ptr = pointer(col_slice, (c_in-1)*M*K+1)
+                       w_ptr = pointer(w, (c_in-1)*K*N+1)
+                       y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
+                       gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
+                   end
+               end
            end
        end
    end
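
For reference, the `parts` line above splits the batch axis into at most `ntasks` contiguous chunks, so each spawned task owns its own run of batches. A minimal sketch of that arithmetic with made-up sizes (illustrative only, not taken from this diff):

    # 10 batches shared across 4 tasks.
    ntasks     = 4
    batch_axis = 1:10                                      # stands in for axes(y)[end]
    chunk      = ceil(Int, length(batch_axis) / ntasks)    # == 3
    parts      = Iterators.partition(batch_axis, chunk)    # yields 1:3, 4:6, 7:9, 10:10
    # At most `ntasks` chunks are produced, which is what lets each task
    # index its own slice of `col` by its chunk number.
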
@@ -101,28 +107,33 @@ function ∇depthwiseconv_data_im2col!(
        dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
        w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
        col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)),
-       alpha::T = T(1), beta::T = T(0)) where T
+       alpha::T = T(1), beta::T = T(0),
+       ntasks::Int = nthreads()) where T
    check_dims(size(dx), size(w), size(dy), cdims)

    M = prod(output_size(cdims))
    N = prod(kernel_size(cdims))
    K = channel_multiplier(cdims)

-   @threads for batch_idx in 1:size(dx)[end]
-       # col_slice is a thread-local workspace
-       col_slice = view(col, :, :, threadid())
-
-       # We do a separate convolution for each channel in x, as we must
-       for cidx in 1:channels_in(cdims)
-           GC.@preserve col_slice w dy begin
-               # Walk each pointer forward as we process each input channel
-               dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
-               w_ptr = pointer(w, (cidx - 1)*K*N + 1)
-               col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)
-               gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
+   parts = Iterators.partition(axes(dx)[end], ceil(Int, size(dx, 5) / ntasks))
+
+   @sync for (task_n, part) in enumerate(parts)
+       Threads.@spawn begin
+           col_slice = view(col, :, :, task_n)  # col_slice is a task-local workspace
+           for batch_idx in part
+               # We do a separate convolution for each channel in x, as we must
+               for cidx in 1:channels_in(cdims)
+                   GC.@preserve col_slice w dy begin
+                       # Walk each pointer forward as we process each input channel
+                       dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
+                       w_ptr = pointer(w, (cidx - 1)*K*N + 1)
+                       col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)
+                       gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
+                   end
+               end
+               col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
            end
        end
-       col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
    end
    return dx
end
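
Both hunks rely on the same task-local-workspace pattern: the third dimension of `col` is indexed by the task number, so a task only ever writes its own slice and its own range of batches. A self-contained sketch of that pattern (the buffer shape and the `slice .= b` body are stand-ins, not from this diff):

    using Base.Threads: nthreads

    ntasks   = nthreads()
    nbatches = 32
    buf      = zeros(Float32, 8, ntasks)     # one workspace column per task, like `col` above
    parts    = Iterators.partition(1:nbatches, ceil(Int, nbatches / ntasks))

    @sync for (task_n, part) in enumerate(parts)
        Threads.@spawn begin
            slice = view(buf, :, task_n)     # task-local workspace; no two tasks share a slice
            for b in part
                slice .= b                   # stand-in for the per-batch im2col!/gemm! work
            end
        end
    end
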