@@ -8,28 +8,31 @@ const RCR = RuleConfig{>:HasReverseMode}
@inline only_derivative(y, f::F, x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x)))

# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
- # is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
+ # is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
struct NotaNumber <: Real end

"""
    bias_act!(σ, x, b)

- This is equivalent to `σ.(x .+ b)`, but faster because
- it will overwrite `x` to save memory (when possible) and
- replace `sigmoid` & `tanh` with `sigmoid_fast` & `tanh_fast`.
+ This is equivalent to `x .= σ.(x .+ b)`, also replacing `sigmoid` & `tanh`
+ with `sigmoid_fast` & `tanh_fast`.
+ It will only overwrite `x` when `x isa StridedArray{<:AbstractFloat}`.

- The best case requires `x isa StridedArray{<:AbstractFloat}`,
- and that the activation has a method of `derivatives_given_output`
- which does not need the input at all (such as `relu`, `tanh`).
+ When used within a gradient, it will overwrite only when `σ` has
+ a method of `derivatives_given_output` which does not need the input at all.
+ Such methods are defined by e.g. `@scalar_rule relu(x) Ω > 0` where the derivative
+ contains only `Ω` (the output), not `x`.

!!! warning
    This is not safe to use if `x` is still needed for the gradient
    of some other function. Incorrect use will give silently wrong answers.
+     It is intended mainly for Flux layers, in which the previous operation is
+     known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer.
"""
bias_act!(σ::Function, x::AbstractArray, b) = fast_act(σ, x).(x .+ b)  # fallback

bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) =
-     fast_broadcast_plus!(fast_act(σ, x), x, b)  # hand-written version below.
+     _fast_broadcast!(fast_act(σ, x)∘(+), x, b)  # works around a SIMD bug

bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) =
    (@assert !b "bias=true is not accepted; layer constructors should guarantee this"; x)
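A minimal sketch (not from the PR) of the inference test that the `NotaNumber` comment above and the commented-out `add_act` rrule further down describe; the helper name `only_needs_output` is hypothetical, and `NNlib.NotaNumber` is the internal type defined in this hunk:

    using NNlib, ChainRulesCore

    # Hypothetical helper: if `derivatives_given_output` still infers a concrete type when
    # the primal input is the method-less `NotaNumber`, the rule cannot be reading `x`,
    # only the output `Ω`; otherwise inference returns `Union{}` and this gives `false`.
    only_needs_output(f, Ω::T) where {T<:Number} =
        isconcretetype(Core.Compiler._return_type(
            ChainRulesCore.derivatives_given_output, Tuple{T, typeof(f), NNlib.NotaNumber}))

    only_needs_output(relu, 1f0)   # expected `true`, since per the docstring relu's rule uses only `Ω > 0`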
@@ -89,197 +92,10 @@ function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArr
    end
    return bias_act!(identity, x, b), bias_act!_idback
end
+
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N}
    bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent())
    return x, bias_act!_trivial
end

- """
99
- NNlib.fast_broadcast_plus!(f, x, b)
100
-
101
- This is equivalent to `x .= f.(x .+ b)`, but works around
102
- an issue with broadcasting that prevents SIMD in such cases.
103
-
104
- That can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.
105
-
106
- Also has an `rrule` to prevent mutation inside 2nd-order AD.
107
-
108
- !!! warning
109
- Does not allow for derivatives with respect to `f`.
110
- """
111
- function fast_broadcast_plus! (f:: F , x:: Array{<:AbstractFloat} , b) where {F<: Function }
112
- if b === false
113
- @simd ivdep for I in eachindex (x)
114
- @inbounds x[I] = f (x[I])
115
- end
116
- else
117
- xplus = Broadcast. instantiate (Broadcast. broadcasted (+ , x, b))
118
- @simd ivdep for I in eachindex (xplus)
119
- @inbounds x[I] = f (xplus[I])
120
- end
121
- end
122
- return x
123
- end
124
- function fast_broadcast_plus! (f:: F , x:: StridedArray{<:AbstractFloat} , b) where {F<: Function }
125
- # CuArray has its own broadcasting.
126
- x .= f .(x .+ b)
127
- return x
128
- end
129
- function fast_broadcast_plus! (f:: F , x:: AbstractArray , b) where {F<: Function }
130
- # Don't try to write into weird arrays
131
- return f .(x .+ b)
132
- end
133
-
134
- function rrule (cfg:: RCR , :: typeof (fast_broadcast_plus!), f:: F , x:: AbstractArray{T,N} , b:: B ) where {F,T,N,B}
135
- rrule_via_ad (cfg, broadcast, (x,b) -> f .(x .+ b), x, b)
136
- end
137
-
138
-
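The hunk above replaces this hand-written helper with `_fast_broadcast!(fast_act(σ, x)∘(+), x, b)`. A hedged sketch of why that composition gives the same values, assuming the replacement `_fast_broadcast!(f, x, b)` computes `x .= f.(x, b)` in the same SIMD-friendly way this removed helper did; the test values are illustrative only:

    using NNlib

    σ = NNlib.fast_act(tanh, ones(Float32, 2))       # expected to pick `tanh_fast` for CPU float arrays
    @assert (σ ∘ (+))(1f0, 2f0) == σ(3f0)            # ∘ applies + first, then σ

    x, b = rand(Float32, 4, 3), rand(Float32, 4)
    y = NNlib._fast_broadcast!(σ ∘ (+), copy(x), b)  # the call used in the hunk above
    @assert y ≈ σ.(x .+ b)                           # same values as the fused broadcast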
- # """
- #     add_act(σ, x, y...)
- #     add_act!(σ, x, y, z...)
-
- # Equivalent to `σ.(x .+ y .+ z)`. The mutating method `add_act!`
- # """
- # add_act(σ::Function, x::AbstractArray, yz::AbstractArray...) = σ.(.+(x, yz...)) # fused
-
-
- # function ChainRulesCore.rrule(::typeof(add_act), σ::F, x::AbstractArray, yz::AbstractArray...) where {F,T,N}
- #     if isconcretetype(Core.Compiler._return_type(
- #             derivatives_given_output, Tuple{T, F, NotaNumber}))
-
- #     end
-
-
- # bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) =
- #     # b ? (x .= fast_act(σ, x).(x .+ b)) : (x .= fast_act(σ, x).(x))
- #     (@assert !b "bias=true is not accepted"; (x .= fast_act(σ, x).(x)))
-
-
- # using NNlib, BenchmarkTools
-
- #=
-
- ## M1 mac, 1.10
-
- julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);
-
- julia> @btime bias_act!(relu, $w, $b);
-   min 19.500 μs, mean 21.375 μs (0 allocations)
-
- julia> @btime relu.($w .+ $b);
-   min 17.208 μs, mean 62.826 μs (2 allocations, 390.67 KiB)
-
- julia> @btime bias_act!(tanh, $w, $b);
-   min 63.792 μs, mean 65.052 μs (0 allocations)
-
- julia> @btime tanh_fast.($w .+ $b);
-   min 63.583 μs, mean 102.004 μs (2 allocations, 390.67 KiB)
-
- julia> using Zygote
-
- julia> @btime gradient((w,b) -> sum(bias_act!(relu, w, b)), $w, $b);
-   min 145.166 μs, mean 150.785 μs (51 allocations, 2.18 KiB)
-
- julia> @btime gradient((w,b) -> sum(relu.(w .+ b)), $w, $b);
-   min 165.583 μs, mean 314.267 μs (32 allocations, 1.15 MiB)
-
- julia> @btime gradient((w,b) -> sum(bias_act!(tanh, w, b)), $w, $b);
-   min 191.917 μs, mean 195.956 μs (51 allocations, 2.18 KiB)
-
- julia> @btime gradient((w,b) -> sum(tanh_fast.(w .+ b)), $w, $b);
-   min 209.458 μs, mean 338.652 μs (32 allocations, 1.15 MiB)
-
-
-
- ## Cyclops
-
- julia> using CUDA # 10x bigger
-
- julia> cw, cb = CUDA.rand(Float32, 100, 100_00), CUDA.rand(Float32, 100);
-
- julia> @btime CUDA.@sync bias_act!(relu, $cw, $cb);
-   22.546 μs (27 allocations: 1.45 KiB)
-
- julia> @btime CUDA.@sync relu.($cw .+ $cb); # faster, that's odd?
-   31.282 μs (38 allocations: 1.81 KiB)
-
- julia> @btime CUDA.@sync bias_act!(tanh, $cw, $cb);
-   27.030 μs (27 allocations: 1.45 KiB)
-
- julia> @btime CUDA.@sync tanh_fast.($cw .+ $cb);
-   36.421 μs (38 allocations: 1.81 KiB)
-
- julia> using Zygote
-
- julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(relu, w, b)), $cw, $cb);
-   204.507 μs (382 allocations: 18.15 KiB)
-
- julia> @btime CUDA.@sync gradient((w,b) -> sum(relu.(w .+ b)), $cw, $cb);
-   204.458 μs (409 allocations: 19.19 KiB)
-
- julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(tanh, w, b)), $cw, $cb);
-   224.545 μs (382 allocations: 18.15 KiB)
-
- julia> @btime CUDA.@sync gradient((w,b) -> sum(tanh_fast.(w .+ b)), $cw, $cb);
-   204.793 μs (411 allocations: 19.30 KiB)
-
-
- =#
-
- #=
-
- (jl_fuwIi8) pkg> add https://github.com/mcabbott/NNlib.jl/tree/bias_act_23
-
- julia> using NNlib, Zygote, BenchmarkTools
-
- julia> w, b, x = rand(Float32, 50, 50), rand(Float32, 50), randn(Float32, 50, 100);
-
- julia> @btime bias_act!(relu, $w * $x, $b);
-   min 5.243 μs, mean 8.600 μs (2 allocations, 19.61 KiB)
-
- julia> @btime relu.($w * $x .+ $b);
-   min 5.160 μs, mean 10.863 μs (4 allocations, 39.22 KiB)
-
- julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
-   min 21.042 μs, mean 40.476 μs (43 allocations, 89.83 KiB)
-
- julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
-   min 21.542 μs, mean 43.947 μs (41 allocations, 128.91 KiB)
-
- julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
-   min 14.708 μs, mean 26.450 μs (28 allocations, 69.41 KiB)
-
- julia> @btime gradient(x -> sum(abs2, x), $x);
-   min 1.938 μs, mean 4.160 μs (2 allocations, 19.61 KiB)
-
-
- # Cyclops
-
- julia> @btime bias_act!(relu, $w * $x, $b);
-   24.786 μs (2 allocations: 19.61 KiB)
-
- julia> @btime relu.($w * $x .+ $b);
-   25.501 μs (4 allocations: 39.22 KiB)
-
- julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
-   91.847 μs (43 allocations: 89.83 KiB)
-
- julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
-   98.054 μs (41 allocations: 128.91 KiB)
-
- julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
-   80.464 μs (28 allocations: 69.41 KiB)
-
- julia> @btime gradient(x -> sum(abs2, x), $x);
-   4.604 μs (2 allocations: 19.61 KiB)
-
- julia> @time using CUDA; @time cu(ones(3)) .+ 1;
-
- julia> w, b, x = CUDA.rand(Float32, 1000, 1000), CUDA.rand(Float32, 1000), CUDA.rand(Float32, 1000, 1000);
-
-
-
- =#
-
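For reference, a usage sketch exercising only what the `bias_act!` docstring above promises (not part of the PR; array sizes are arbitrary):

    using NNlib

    x, b = randn(Float32, 4, 3), randn(Float32, 4)
    y = bias_act!(relu, copy(x), b)                    # may overwrite the (copied) strided input
    @assert y ≈ relu.(x .+ b)                          # the documented equivalence
    @assert bias_act!(identity, copy(x), false) == x   # identity with bias=false is a no-op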