Skip to content

Commit f30b3dd

Browse files
reduce allocations in partitioning
1 parent f8cb89b commit f30b3dd

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

src/impl/conv_im2col.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ function conv_im2col!(
 45  45       N = channels_out(cdims)
 46  46       K = prod(kernel_size(cdims))*channels_in(cdims)
 47  47
 48     -     parts = collect(Iterators.partition(axes(x, 5), ceil(Int, size(x, 5) / ntasks)))
     48 +     parts = Iterators.partition(axes(x, 5), ceil(Int, size(x, 5) / ntasks))
 49  49
 50     -     @sync for task_n in eachindex(parts)
     50 +     @sync for (task_n, part) in enumerate(parts)
 51  51           Threads.@spawn begin
 52  52               col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
 53     -             for batch_idx in parts[task_n]
     53 +             for batch_idx in part
 54  54                   im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
 55  55                   GC.@preserve col_slice w y begin
 56  56                       col_ptr = pointer(col_slice)
@@ -150,12 +150,12 @@ function ∇conv_data_im2col!(
150 150       N = prod(kernel_size(cdims))*channels_in(cdims)
151 151       K = channels_out(cdims)
152 152
153     -     parts = collect(Iterators.partition(axes(dx, 5), ceil(Int, size(dx, 5) / ntasks)))
    153 +     parts = Iterators.partition(axes(dx, 5), ceil(Int, size(dx, 5) / ntasks))
154 154
155     -     @sync for task_n in eachindex(parts)
    155 +     @sync for (task_n, part) in enumerate(parts)
156 156           Threads.@spawn begin
157 157               col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
158     -             for batch_idx in parts[task_n]
    158 +             for batch_idx in part
159 159                   GC.@preserve col_slice w dy begin
160 160                       dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
161 161                       w_ptr = pointer(w)

src/impl/depthwiseconv_im2col.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ function depthwiseconv_im2col!(
 26  26       N = channel_multiplier(cdims)
 27  27       K = prod(kernel_size(cdims))
 28  28
 29     -     parts = collect(Iterators.partition(axes(y)[end], ceil(Int, size(y, 5) / ntasks)))
     29 +     parts = Iterators.partition(axes(y)[end], ceil(Int, size(y, 5) / ntasks))
 30  30
 31  31       dcdims = DenseConvDims(cdims)
 32  32
 33     -     @sync for task_n in eachindex(parts)
     33 +     @sync for (task_n, part) in enumerate(parts)
 34  34           Threads.@spawn begin
 35  35               col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
 36     -             for batch_idx in parts[task_n]
     36 +             for batch_idx in part
 37  37                   im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)
 38  38
 39  39                   # We do a separate convolution for each channel in x, as we must
@@ -115,12 +115,12 @@ function ∇depthwiseconv_data_im2col!(
115 115       N = prod(kernel_size(cdims))
116 116       K = channel_multiplier(cdims)
117 117
118     -     parts = collect(Iterators.partition(axes(dx)[end], ceil(Int, size(dx, 5) / ntasks)))
    118 +     parts = Iterators.partition(axes(dx)[end], ceil(Int, size(dx, 5) / ntasks))
119 119
120     -     @sync for task_n in eachindex(parts)
    120 +     @sync for (task_n, part) in enumerate(parts)
121 121           Threads.@spawn begin
122 122               col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
123     -             for batch_idx in parts[task_n]
    123 +             for batch_idx in part
124 124                   # We do a separate convolution for each channel in x, as we must
125 125                   for cidx in 1:channels_in(cdims)
126 126                       GC.@preserve col_slice w dy begin

0 commit comments

Comments
 (0)