Skip to content

Commit 33a55a3

Browse files
committed
restore mapreduce
1 parent 2809636 commit 33a55a3

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/layers/basic.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,11 @@ end
258258
@functor Maxout
259259

260260
function (mo::Maxout)(input::AbstractArray)
261-
outs = map(lay -> lay(input), mo.over)
262-
return max.(outs...)
261+
# outs = map(lay -> lay(input), mo.over)
262+
# return max.(outs...)
263+
# Perhaps surprisingly, pairwise max broadcast is often faster,
264+
# even with Zygote. See #698 and #1794
265+
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
263266
end
264267

265268
trainable(mo::Maxout) = mo.over

0 commit comments

Comments
 (0)