@@ -39,9 +39,12 @@ bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) =
function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
-    if eltype(B) !== Bool
-        b_dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
-        size_b = size(b)
+    biasgrad = if eltype(B) !== Bool
+        # Summing over ndims(x)+1 is a trick to make b_dims type-stable
+        dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
+        _biasgrad(dx) = reshape(sum(dx; dims), size(b))
+    else
+        Returns(NoTangent())
     end

    # Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ
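Not part of the diff, just a sketch of the `dims` trick noted in the comment above, with toy sizes of my own: every entry of `dims` is an `Int` whatever `size(b)` turns out to be, and summing an N-dimensional array over dimension N+1 is a no-op, so the reduction stays type-stable.

julia> x, b = rand(Float32, 3, 4), rand(Float32, 3);   # b broadcasts along the second dim

julia> N = ndims(x);

julia> dims = ntuple(d -> size(b, d) == 1 ? d : N + 1, N)   # always an NTuple{2,Int}
(3, 2)

julia> reshape(sum(ones(Float32, 3, 4); dims), size(b))   # dim 2 is summed, dim 3 is a no-op
3-element Vector{Float32}:
 4.0
 4.0
 4.0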
@@ -52,50 +55,195 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA
            # TODO with e.g. https://github.com/FluxML/Zygote.jl/pull/1340
            # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592
            dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ)
-            db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
-            return (NoTangent(), NoTangent(), dx, db)
+            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
        end
        return Ω, bias_act!_fastback

-    # # Slower path: can't overwrite x, but can use derivatives_given_output
-    # # This case is WRONG and tests fail, but not sure why
-    # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
-    #     Ω2 = fast_act(σ, x).(x) .+ b
-    #     @show σ b
-    #     function bias_act!_back2(Δ)
-    #         dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
-    #         db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
-    #         return (NoTangent(), NoTangent(), dx, db)
-    #     end
-    #     return Ω2, bias_act!_back2
+    # Slower path: can't overwrite x, but can use derivatives_given_output
+    # This case is WRONG and tests fail, but not sure why
+    elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
+        Ω2 = fast_act(σ, x).(x) .+ b
+        @show σ b
+        function bias_act!_back2(Δ)
+            dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
+            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
+        end
+        return Ω2, bias_act!_back2

    # Fallback path: let AD handle the broadcast
    else
        Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b))
        @inline function bias_act!_slowback(Δ)
            _, _, dx = back(Δ)
-            db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
-            return (NoTangent(), NoTangent(), dx, db)
+            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
        end
        return Ω3, bias_act!_slowback
    end
end
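Again not part of the diff, just context: the fast path above is gated on derivatives_given_output being inferrable with NotaNumber standing in for the input, i.e. on σ's derivative being computable from the output alone, which is what makes overwriting x safe there. A quick REPL check of that fact for two of the activations used in the benchmarks below:

julia> using NNlib

julia> Ω = tanh(0.7f0); 1 - Ω^2 ≈ sech(0.7f0)^2   # tanh′ needs only the output Ω
true

julia> relu(-0.3f0) > 0, relu(0.3f0) > 0   # relu′ is just a sign test on the output
(false, true)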

-# Two easy cases
+# Two easy cases with identity
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B}
-    b_dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
-    size_b = size(b)
+    dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
+    biasgrad(dx) = reshape(sum(dx; dims), size(b))
    function bias_act!_idback(Δ)
        dx = unthunk(Δ)
-        db = reshape(sum(dx; dims = b_dims), size_b)
-        return (NoTangent(), NoTangent(), dx, db)
+        return (NoTangent(), NoTangent(), dx, biasgrad(dx))
    end
    return bias_act!(identity, x, b), bias_act!_idback
end
-
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N}
    bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent())
    return x, bias_act!_trivial
end

+
+# """
+#     add_act(σ, x, y...)
+#     add_act!(σ, x, y, z...)
+
+# Equivalent to `σ.(x .+ y .+ z)`. The mutating method `add_act!`
+# """
+# add_act(σ::Function, x::AbstractArray, yz::AbstractArray...) = σ.(.+(x, yz...))  # fused
+
+
+# function ChainRulesCore.rrule(::typeof(add_act), σ::F, x::AbstractArray, yz::AbstractArray...) where {F,T,N}
+#     if isconcretetype(Core.Compiler._return_type(
+#             derivatives_given_output, Tuple{T, F, NotaNumber}))
+
+#     end
+
+
+# bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) =
+#     # b ? (x .= fast_act(σ, x).(x .+ b)) : (x .= fast_act(σ, x).(x))
+#     (@assert !b "bias=true is not accepted"; (x .= fast_act(σ, x).(x)))
+
+
+# using NNlib, BenchmarkTools
+
+#=
+
+## M1 mac, 1.10
+
+julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);
+
+julia> @btime bias_act!(relu, $w, $b);
+  min 19.500 μs, mean 21.375 μs (0 allocations)
+
+julia> @btime relu.($w .+ $b);
+  min 17.208 μs, mean 62.826 μs (2 allocations, 390.67 KiB)
+
+julia> @btime bias_act!(tanh, $w, $b);
+  min 63.792 μs, mean 65.052 μs (0 allocations)
+
+julia> @btime tanh_fast.($w .+ $b);
+  min 63.583 μs, mean 102.004 μs (2 allocations, 390.67 KiB)
+
+julia> using Zygote
+
+julia> @btime gradient((w,b) -> sum(bias_act!(relu, w, b)), $w, $b);
+  min 145.166 μs, mean 150.785 μs (51 allocations, 2.18 KiB)
+
+julia> @btime gradient((w,b) -> sum(relu.(w .+ b)), $w, $b);
+  min 165.583 μs, mean 314.267 μs (32 allocations, 1.15 MiB)
+
+julia> @btime gradient((w,b) -> sum(bias_act!(tanh, w, b)), $w, $b);
+  min 191.917 μs, mean 195.956 μs (51 allocations, 2.18 KiB)
+
+julia> @btime gradient((w,b) -> sum(tanh_fast.(w .+ b)), $w, $b);
+  min 209.458 μs, mean 338.652 μs (32 allocations, 1.15 MiB)
+
+
+
+## Cyclops
+
+julia> using CUDA  # 10x bigger
+
+julia> cw, cb = CUDA.rand(Float32, 100, 100_00), CUDA.rand(Float32, 100);
+
+julia> @btime CUDA.@sync bias_act!(relu, $cw, $cb);
+  22.546 μs (27 allocations: 1.45 KiB)
+
+julia> @btime CUDA.@sync relu.($cw .+ $cb);  # faster, that's odd?
+  31.282 μs (38 allocations: 1.81 KiB)
+
+julia> @btime CUDA.@sync bias_act!(tanh, $cw, $cb);
+  27.030 μs (27 allocations: 1.45 KiB)
+
+julia> @btime CUDA.@sync tanh_fast.($cw .+ $cb);
+  36.421 μs (38 allocations: 1.81 KiB)
+
+julia> using Zygote
+
+julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(relu, w, b)), $cw, $cb);
+  204.507 μs (382 allocations: 18.15 KiB)
+
+julia> @btime CUDA.@sync gradient((w,b) -> sum(relu.(w .+ b)), $cw, $cb);
+  204.458 μs (409 allocations: 19.19 KiB)
+
+julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(tanh, w, b)), $cw, $cb);
+  224.545 μs (382 allocations: 18.15 KiB)
+
+julia> @btime CUDA.@sync gradient((w,b) -> sum(tanh_fast.(w .+ b)), $cw, $cb);
+  204.793 μs (411 allocations: 19.30 KiB)
+
+
+=#
+
+#=
+
+(jl_fuwIi8) pkg> add https://github.com/mcabbott/NNlib.jl/tree/bias_act_23
+
+julia> using NNlib, Zygote, BenchmarkTools
+
+julia> w, b, x = rand(Float32, 50, 50), rand(Float32, 50), randn(Float32, 50, 100);
+
+julia> @btime bias_act!(relu, $w * $x, $b);
+  min 5.243 μs, mean 8.600 μs (2 allocations, 19.61 KiB)
+
+julia> @btime relu.($w * $x .+ $b);
+  min 5.160 μs, mean 10.863 μs (4 allocations, 39.22 KiB)
+
+julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
+  min 21.042 μs, mean 40.476 μs (43 allocations, 89.83 KiB)
+
+julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
+  min 21.542 μs, mean 43.947 μs (41 allocations, 128.91 KiB)
+
+julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
+  min 14.708 μs, mean 26.450 μs (28 allocations, 69.41 KiB)
+
+julia> @btime gradient(x -> sum(abs2, x), $x);
+  min 1.938 μs, mean 4.160 μs (2 allocations, 19.61 KiB)
+
+
+# Cyclops
+
+julia> @btime bias_act!(relu, $w * $x, $b);
+  24.786 μs (2 allocations: 19.61 KiB)
+
+julia> @btime relu.($w * $x .+ $b);
+  25.501 μs (4 allocations: 39.22 KiB)
+
+julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
+  91.847 μs (43 allocations: 89.83 KiB)
+
+julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
+  98.054 μs (41 allocations: 128.91 KiB)
+
+julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
+  80.464 μs (28 allocations: 69.41 KiB)
+
+julia> @btime gradient(x -> sum(abs2, x), $x);
+  4.604 μs (2 allocations: 19.61 KiB)
+
+julia> @time using CUDA; @time cu(ones(3)) .+ 1;
+
+julia> w, b, x = CUDA.rand(Float32, 1000, 1000), CUDA.rand(Float32, 1000), CUDA.rand(Float32, 1000, 1000);
+
+
+
+=#
+
+
+
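A correctness check to go with the timings, left as a sketch rather than a recorded run; it reuses the sizes from the first benchmark block above, and the copy is only there so that bias_act! does not overwrite the original w:

julia> using NNlib, Zygote

julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);

julia> g1 = gradient((w,b) -> sum(bias_act!(relu, copy(w), b)), w, b);

julia> g2 = gradient((w,b) -> sum(relu.(w .+ b)), w, b);

julia> map(≈, g1, g2)   # expect (true, true): the fused and broadcast gradients should agree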