Skip to content

Commit d3623b9

Browse files
committed
Faster reductions where dim is specified as a number instead of Val
1 parent c7b81d6 commit d3623b9

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

src/mapreduce.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,23 @@ end
103103
end
104104
end
105105

106-
@inline _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S} =
107-
_mapreduce(f, op, Val(D), nt, sz, a)
106+
@inline function _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S}
107+
# Body of this function is split because constant propagation (at least
108+
# as of Julia 1.2) can't always correctly propagate here and
109+
# as a result the function is not type stable and very slow.
110+
# This makes it at least fast for three dimensions but people should use
111+
# for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
112+
if D == 1
113+
return _mapreduce(f, op, Val(1), nt, sz, a)
114+
elseif D == 2
115+
return _mapreduce(f, op, Val(2), nt, sz, a)
116+
elseif D == 3
117+
return _mapreduce(f, op, Val(3), nt, sz, a)
118+
else
119+
return _mapreduce(f, op, Val(D), nt, sz, a)
120+
end
121+
end
108122

109-
110123
@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{()},
111124
::Size{S}, a::StaticArray) where {S,D}
112125
N = length(S)

0 commit comments

Comments
 (0)