Commit a5fbf95

Merge pull request #497 from IanButterworth/ib/task_local
move from bad thread-local to task-local
2 parents a3cdee6 + f30b3dd commit a5fbf95
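
Why "bad thread-local": since Julia 1.7 a task can migrate between threads at any yield point, so threadid() is not a stable index and two tasks can end up writing to the same threadid()-indexed workspace slot. Keying the workspace by task number instead is race-free, because each spawned task owns its index for its whole lifetime. A minimal illustrative sketch of both patterns (sum_threadlocal and sum_tasklocal are hypothetical functions, not code from this repository):

    using Base.Threads: @threads, @spawn, nthreads, threadid

    # UNSAFE (the pattern this PR removes): threadid() may change across
    # yield points, so two iterations can race on one slot of `acc`.
    function sum_threadlocal(xs)
        acc = zeros(eltype(xs), nthreads())
        @threads for x in xs
            acc[threadid()] += x
        end
        return sum(acc)
    end

    # SAFE (the pattern this PR introduces): split the work into at most
    # `ntasks` chunks and key each slot by task number, which is fixed
    # for the lifetime of the spawned task.
    function sum_tasklocal(xs; ntasks::Int = nthreads())
        parts = Iterators.partition(eachindex(xs), cld(length(xs), ntasks))
        acc = zeros(eltype(xs), ntasks)
        @sync for (task_n, part) in enumerate(parts)
            @spawn for i in part
                acc[task_n] += xs[i]  # each task owns slot task_n exclusively
            end
        end
        return sum(acc)
    end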

2 files changed: +71 -50 lines changed


src/impl/conv_im2col.jl

Lines changed: 32 additions & 22 deletions
@@ -24,7 +24,8 @@ function conv_im2col!(
         y::AbstractArray{T,5}, x::AbstractArray{T,5},
         w::AbstractArray{T,5}, cdims::DenseConvDims;
         col::AbstractArray{T,3}=similar(x, im2col_dims(cdims)),
-        alpha::T=T(1), beta::T=T(0)) where {T}
+        alpha::T=T(1), beta::T=T(0),
+        ntasks::Int=nthreads()) where {T}
     check_dims(size(x), size(w), size(y), cdims)

     # COL * W -> Y
@@ -44,16 +45,20 @@ function conv_im2col!(
     N = channels_out(cdims)
     K = prod(kernel_size(cdims))*channels_in(cdims)

-    @threads for batch_idx in 1:size(x,5)
-        # col_slice is a thread-local workspace
-        col_slice = view(col, :, :, threadid())
-
-        im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
-        GC.@preserve col_slice w y begin
-            col_ptr = pointer(col_slice)
-            w_ptr = pointer(w)
-            y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
-            gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
+    parts = Iterators.partition(axes(x, 5), ceil(Int, size(x, 5) / ntasks))
+
+    @sync for (task_n, part) in enumerate(parts)
+        Threads.@spawn begin
+            col_slice = view(col, :, :, task_n)  # col_slice is a task-local workspace
+            for batch_idx in part
+                im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
+                GC.@preserve col_slice w y begin
+                    col_ptr = pointer(col_slice)
+                    w_ptr = pointer(w)
+                    y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
+                    gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
+                end
+            end
         end
     end
     return y
@@ -122,7 +127,8 @@ function ∇conv_data_im2col!(
         dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
         w::AbstractArray{T,5}, cdims::DenseConvDims;
         col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)),
-        alpha::T=T(1), beta::T=T(0)) where {T}
+        alpha::T=T(1), beta::T=T(0),
+        ntasks::Int=nthreads()) where {T}
     check_dims(size(dx), size(w), size(dy), cdims)

     # dY W' -> dX
@@ -144,17 +150,21 @@ function ∇conv_data_im2col!(
     N = prod(kernel_size(cdims))*channels_in(cdims)
     K = channels_out(cdims)

-    @threads for batch_idx in 1:size(dx, 5)
-        # col_slice is a thread-local workspace
-        col_slice = view(col, :, :, threadid())
-
-        GC.@preserve col_slice w dy begin
-            dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
-            w_ptr = pointer(w)
-            col_ptr = pointer(col_slice)
-            gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
+    parts = Iterators.partition(axes(dx, 5), ceil(Int, size(dx, 5) / ntasks))
+
+    @sync for (task_n, part) in enumerate(parts)
+        Threads.@spawn begin
+            col_slice = view(col, :, :, task_n)  # col_slice is a task-local workspace
+            for batch_idx in part
+                GC.@preserve col_slice w dy begin
+                    dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
+                    w_ptr = pointer(w)
+                    col_ptr = pointer(col_slice)
+                    gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
+                end
+                col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
+            end
         end
-        col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
     end
     return dx
 end
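
Note the chunking arithmetic shared by all four loops: the chunk length is ceil(Int, size(x, 5) / ntasks), so the batch axis splits into at most ntasks contiguous chunks, with only the final chunk short when the batch size does not divide evenly. A quick illustration with assumed values (batch of 10, ntasks = 4):

    parts = Iterators.partition(1:10, ceil(Int, 10 / 4))
    foreach(println, parts)   # prints 1:3, 4:6, 7:9, 10:10 — four chunks

Since task_n runs over 1:length(parts) and length(parts) ≤ ntasks, each task indexes a distinct slice view(col, :, :, task_n), assuming the col buffer provides at least ntasks slices along its third dimension.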

src/impl/depthwiseconv_im2col.jl

Lines changed: 39 additions & 28 deletions
@@ -13,7 +13,8 @@ function depthwiseconv_im2col!(
         y::AbstractArray{T,5}, x::AbstractArray{T,5},
         w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
         col::AbstractArray{T,3} = similar(x, im2col_dims(cdims)),
-        alpha::T=T(1), beta::T=T(0)) where T
+        alpha::T=T(1), beta::T=T(0),
+        ntasks::Int=nthreads()) where T
     check_dims(size(x), size(w), size(y), cdims)

     # This functions exactly the same as conv_im2col!(), except that we shard the
@@ -25,21 +26,26 @@ function depthwiseconv_im2col!(
     N = channel_multiplier(cdims)
     K = prod(kernel_size(cdims))

-    dcdims = DenseConvDims(cdims)
-    @threads for batch_idx in 1:size(x)[end]
-        # col_slice is a thread-local workspace
-        col_slice = view(col, :, :, threadid())
+    parts = Iterators.partition(axes(y)[end], ceil(Int, size(y, 5) / ntasks))

-        im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)
+    dcdims = DenseConvDims(cdims)

-        # We do a separate convolution for each channel in x, as we must
-        for c_in in 1:channels_in(cdims)
-            # Walk each pointer forward as we process each input channel
-            GC.@preserve col_slice w y begin
-                col_ptr = pointer(col_slice, (c_in-1)*M*K+1)
-                w_ptr = pointer(w, (c_in-1)*K*N+1)
-                y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
-                gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
+    @sync for (task_n, part) in enumerate(parts)
+        Threads.@spawn begin
+            col_slice = view(col, :, :, task_n)  # col_slice is a task-local workspace
+            for batch_idx in part
+                im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)
+
+                # We do a separate convolution for each channel in x, as we must
+                for c_in in 1:channels_in(cdims)
+                    # Walk each pointer forward as we process each input channel
+                    GC.@preserve col_slice w y begin
+                        col_ptr = pointer(col_slice, (c_in-1)*M*K+1)
+                        w_ptr = pointer(w, (c_in-1)*K*N+1)
+                        y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
+                        gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
+                    end
+                end
             end
         end
     end
@@ -101,28 +107,33 @@ function ∇depthwiseconv_data_im2col!(
         dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
         w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
         col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)),
-        alpha::T=T(1), beta::T=T(0)) where T
+        alpha::T=T(1), beta::T=T(0),
+        ntasks::Int=nthreads()) where T
     check_dims(size(dx), size(w), size(dy), cdims)

     M = prod(output_size(cdims))
     N = prod(kernel_size(cdims))
     K = channel_multiplier(cdims)

-    @threads for batch_idx in 1:size(dx)[end]
-        # col_slice is a thread-local workspace
-        col_slice = view(col, :, :, threadid())
-
-        # We do a separate convolution for each channel in x, as we must
-        for cidx in 1:channels_in(cdims)
-            GC.@preserve col_slice w dy begin
-                # Walk each pointer forward as we process each input channel
-                dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
-                w_ptr = pointer(w, (cidx - 1)*K*N + 1)
-                col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)
-                gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
+    parts = Iterators.partition(axes(dx)[end], ceil(Int, size(dx, 5) / ntasks))
+
+    @sync for (task_n, part) in enumerate(parts)
+        Threads.@spawn begin
+            col_slice = view(col, :, :, task_n)  # col_slice is a task-local workspace
+            for batch_idx in part
+                # We do a separate convolution for each channel in x, as we must
+                for cidx in 1:channels_in(cdims)
+                    GC.@preserve col_slice w dy begin
+                        # Walk each pointer forward as we process each input channel
+                        dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
+                        w_ptr = pointer(w, (cidx - 1)*K*N + 1)
+                        col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)
+                        gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
+                    end
+                end
+                col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
             end
         end
-        col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
     end
     return dx
 end
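
All four methods gain the same ntasks::Int=nthreads() keyword, so callers can cap how many tasks are spawned per call. A hedged usage sketch of the dense forward path (the shapes, the direct internal call, and the qualified helper names are illustrative assumptions, not part of this diff):

    using NNlib
    using Base.Threads: nthreads

    x = rand(Float32, 28, 28, 1, 3, 16)   # (spatial..., C_in, N), internal 5-d layout
    w = rand(Float32, 3, 3, 1, 3, 8)      # (kernel spatial..., C_in, C_out)
    cdims = NNlib.DenseConvDims(x, w)
    y = similar(x, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, 5))

    # Cap the task count; keeping ntasks ≤ nthreads() ensures each task can
    # claim its own slice of the default per-thread col workspace.
    NNlib.conv_im2col!(y, x, w, cdims; ntasks = min(4, nthreads()))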
