Skip to content

Commit 2364f58

Browse files
move from bad thread-local to task-local
1 parent e5cff84 commit 2364f58

File tree

2 files changed: +69 additions, -49 deletions

src/impl/conv_im2col.jl

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ function conv_im2col!(
2424
y::AbstractArray{T,5}, x::AbstractArray{T,5},
2525
w::AbstractArray{T,5}, cdims::DenseConvDims;
2626
col::AbstractArray{T,3}=similar(x, im2col_dims(cdims)),
27-
alpha::T=T(1), beta::T=T(0)) where {T}
27+
alpha::T=T(1), beta::T=T(0),
28+
ntasks::Int=nthreads()) where {T}
2829
check_dims(size(x), size(w), size(y), cdims)
2930

3031
# COL * W -> Y
@@ -44,16 +45,20 @@ function conv_im2col!(
4445
N = channels_out(cdims)
4546
K = prod(kernel_size(cdims))*channels_in(cdims)
4647

47-
@threads for batch_idx in 1:size(x,5)
48-
# col_slice is a thread-local workspace
49-
col_slice = view(col, :, :, threadid())
50-
51-
im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
52-
GC.@preserve col_slice w y begin
53-
col_ptr = pointer(col_slice)
54-
w_ptr = pointer(w)
55-
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
56-
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
48+
parts = Iterators.partition(axes(x, 5), ntasks)
49+
50+
@sync for task_n in 1:ntasks
51+
Threads.@spawn begin
52+
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
53+
for batch_idx in parts[task_n]
54+
im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
55+
GC.@preserve col_slice w y begin
56+
col_ptr = pointer(col_slice)
57+
w_ptr = pointer(w)
58+
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
59+
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
60+
end
61+
end
5762
end
5863
end
5964
return y
@@ -122,7 +127,8 @@ function ∇conv_data_im2col!(
122127
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
123128
w::AbstractArray{T,5}, cdims::DenseConvDims;
124129
col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)),
125-
alpha::T=T(1), beta::T=T(0)) where {T}
130+
alpha::T=T(1), beta::T=T(0),
131+
ntasks::Int=nthreads()) where {T}
126132
check_dims(size(dx), size(w), size(dy), cdims)
127133

128134
# dY W' -> dX
@@ -144,17 +150,21 @@ function ∇conv_data_im2col!(
144150
N = prod(kernel_size(cdims))*channels_in(cdims)
145151
K = channels_out(cdims)
146152

147-
@threads for batch_idx in 1:size(dx, 5)
148-
# col_slice is a thread-local workspace
149-
col_slice = view(col, :, :, threadid())
150-
151-
GC.@preserve col_slice w dy begin
152-
dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
153-
w_ptr = pointer(w)
154-
col_ptr = pointer(col_slice)
155-
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
153+
parts = Iterators.partition(axes(dx, 5), ntasks)
154+
155+
@sync for task_n in 1:ntasks
156+
Threads.@spawn begin
157+
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
158+
for batch_idx in parts[task_n]
159+
GC.@preserve col_slice w dy begin
160+
dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
161+
w_ptr = pointer(w)
162+
col_ptr = pointer(col_slice)
163+
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
164+
end
165+
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
166+
end
156167
end
157-
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
158168
end
159169
return dx
160170
end

src/impl/depthwiseconv_im2col.jl

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ function depthwiseconv_im2col!(
1313
y::AbstractArray{T,5}, x::AbstractArray{T,5},
1414
w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
1515
col::AbstractArray{T,3} = similar(x, im2col_dims(cdims)),
16-
alpha::T=T(1), beta::T=T(0)) where T
16+
alpha::T=T(1), beta::T=T(0),
17+
ntasks::Int=nthreads()) where T
1718
check_dims(size(x), size(w), size(y), cdims)
1819

1920
# This functions exactly the same as conv_im2col!(), except that we shard the
@@ -25,21 +26,26 @@ function depthwiseconv_im2col!(
2526
N = channel_multiplier(cdims)
2627
K = prod(kernel_size(cdims))
2728

28-
dcdims = DenseConvDims(cdims)
29-
@threads for batch_idx in 1:size(x)[end]
30-
# col_slice is a thread-local workspace
31-
col_slice = view(col, :, :, threadid())
29+
parts = Iterators.partition(axes(y)[end], ntasks)
3230

33-
im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)
31+
dcdims = DenseConvDims(cdims)
3432

35-
# We do a separate convolution for each channel in x, as we must
36-
for c_in in 1:channels_in(cdims)
37-
# Walk each pointer forward as we process each input channel
38-
GC.@preserve col_slice w y begin
39-
col_ptr = pointer(col_slice, (c_in-1)*M*K+1)
40-
w_ptr = pointer(w, (c_in-1)*K*N+1)
41-
y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
42-
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
33+
@sync for task_n in 1:ntasks
34+
Threads.@spawn begin
35+
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
36+
for batch_idx in parts[task_n]
37+
im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)
38+
39+
# We do a separate convolution for each channel in x, as we must
40+
for c_in in 1:channels_in(cdims)
41+
# Walk each pointer forward as we process each input channel
42+
GC.@preserve col_slice w y begin
43+
col_ptr = pointer(col_slice, (c_in-1)*M*K+1)
44+
w_ptr = pointer(w, (c_in-1)*K*N+1)
45+
y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
46+
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
47+
end
48+
end
4349
end
4450
end
4551
end
@@ -108,21 +114,25 @@ function ∇depthwiseconv_data_im2col!(
108114
N = prod(kernel_size(cdims))
109115
K = channel_multiplier(cdims)
110116

111-
@threads for batch_idx in 1:size(dx)[end]
112-
# col_slice is a thread-local workspace
113-
col_slice = view(col, :, :, threadid())
114-
115-
# We do a separate convolution for each channel in x, as we must
116-
for cidx in 1:channels_in(cdims)
117-
GC.@preserve col_slice w dy begin
118-
# Walk each pointer forward as we process each input channel
119-
dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
120-
w_ptr = pointer(w, (cidx - 1)*K*N + 1)
121-
col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)
122-
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
117+
parts = Iterators.partition(axes(dx)[end], ntasks)
118+
119+
@sync for task_n in 1:ntasks
120+
Threads.@spawn begin
121+
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
122+
for batch_idx in parts[task_n]
123+
# We do a separate convolution for each channel in x, as we must
124+
for cidx in 1:channels_in(cdims)
125+
GC.@preserve col_slice w dy begin
126+
# Walk each pointer forward as we process each input channel
127+
dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
128+
w_ptr = pointer(w, (cidx - 1)*K*N + 1)
129+
col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)
130+
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
131+
end
132+
end
133+
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
123134
end
124135
end
125-
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
126136
end
127137
return dx
128138
end

Comments (0)