@@ -103,10 +103,23 @@ end
103
103
end
104
104
end
105
105
106
- @inline _mapreduce (f, op, D:: Int , nt:: NamedTuple , sz:: Size{S} , a:: StaticArray ) where {S} =
107
- _mapreduce (f, op, Val (D), nt, sz, a)
106
+ @inline function _mapreduce (f, op, D:: Int , nt:: NamedTuple , sz:: Size{S} , a:: StaticArray ) where {S}
107
+ # Body of this function is split because constant propagation (at least
108
+ # as of Julia 1.2) can't always correctly propagate here and
109
+ # as a result the function is not type stable and very slow.
110
+ # This makes it at least fast for three dimensions but people should use
111
+ # for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
112
+ if D == 1
113
+ return _mapreduce (f, op, Val (1 ), nt, sz, a)
114
+ elseif D == 2
115
+ return _mapreduce (f, op, Val (2 ), nt, sz, a)
116
+ elseif D == 3
117
+ return _mapreduce (f, op, Val (3 ), nt, sz, a)
118
+ else
119
+ return _mapreduce (f, op, Val (D), nt, sz, a)
120
+ end
121
+ end
108
122
109
-
110
123
@generated function _mapreduce (f, op, dims:: Val{D} , nt:: NamedTuple{()} ,
111
124
:: Size{S} , a:: StaticArray ) where {S,D}
112
125
N = length (S)
0 commit comments