@@ -59,16 +59,16 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA
        end
        return Ω, bias_act!_fastback

-    # Slower path: can't overwrite x, but can use derivatives_given_output
-    # This case is WRONG and tests fail, but not sure why
-    elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
-        Ω2 = fast_act(σ, x).(x) .+ b
-        @show σ b
-        function bias_act!_back2(Δ)
-            dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
-            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
-        end
-        return Ω2, bias_act!_back2
+    # # Slower path: can't overwrite x, but can use derivatives_given_output
+    # # This case is WRONG and tests fail, but not sure why
+    # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
+    #     Ω2 = fast_act(σ, x).(x) .+ b
+    #     @show σ b
+    #     function bias_act!_back2(Δ)
+    #         dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
+    #         return (NoTangent(), NoTangent(), dx, biasgrad(dx))
+    #     end
+    #     return Ω2, bias_act!_back2

    # Fallback path: let AD handle the broadcast
    else
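Two notes on the path being commented out above. First, its gate asks the compiler whether the analytic derivative would infer a concrete return type, and only takes the analytic route if so. Second, the forward line builds `Ω2` as `fast_act(σ, x).(x) .+ b` (activation first, bias added after), while the pullback treats `Ω2` as the output of `σ.(x .+ b)`; that mismatch may well be why the tests fail. A standalone sketch of the gate, calling `ChainRulesCore.derivatives_given_output` directly (`only_derivative` is presumably this PR's wrapper around it; `has_stable_derivative` is an illustrative name, not part of the PR):

using ChainRulesCore, NNlib

# Core.Compiler._return_type reports the type inference would assign to a
# call; a concrete result means the analytic derivative is type-stable,
# so the rrule can use it instead of differentiating the broadcast.
function has_stable_derivative(f::F, ::Type{T}) where {F,T}
    isconcretetype(Core.Compiler._return_type(
        ChainRulesCore.derivatives_given_output, Tuple{T, F, T}))
end

has_stable_derivative(relu, Float32)      # expected true: relu has a scalar rule
has_stable_derivative(x -> x^3, Float32)  # expected false: no rule, inference gives Union{}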
@@ -96,154 +96,3 @@ function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArr
    return x, bias_act!_trivial
end

-
-
-# """
-#     add_act(σ, x, y...)
-#     add_act!(σ, x, y, z...)
-
-# Equivalent to `σ.(x .+ y .+ z)`. The mutating method `add_act!`
-# """
-# add_act(σ::Function, x::AbstractArray, yz::AbstractArray...) = σ.(.+(x, yz...))  # fused
-
-
-# function ChainRulesCore.rrule(::typeof(add_act), σ::F, x::AbstractArray, yz::AbstractArray...) where {F,T,N}
-#     if isconcretetype(Core.Compiler._return_type(
-#             derivatives_given_output, Tuple{T, F, NotaNumber}))
-
-#     end
-
-
-# bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) =
-#     # b ? (x .= fast_act(σ, x).(x .+ b)) : (x .= fast_act(σ, x).(x))
-#     (@assert !b "bias=true is not accepted"; (x .= fast_act(σ, x).(x)))
-
-
-# using NNlib, BenchmarkTools
-
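An aside on the deleted `add_act` draft above: the `# fused` note refers to dot-syntax fusion, where `σ.(.+(x, yz...))` lowers to a single broadcast kernel with no temporary array for the intermediate sums. A minimal illustration (this `add_act` just mirrors the deleted draft; it is not part of NNlib):

# Splatting into .+ still participates in broadcast fusion: the whole
# expression materializes once, with no intermediate x .+ y array.
add_act(σ, x, yz...) = σ.(.+(x, yz...))

x, y, z = rand(Float32, 3), rand(Float32, 3), rand(Float32, 3)
add_act(tanh, x, y, z) ≈ tanh.(x .+ y .+ z)  # true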
-#=
-
-## M1 mac, 1.10
-
-julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);
-
-julia> @btime bias_act!(relu, $w, $b);
-  min 19.500 μs, mean 21.375 μs (0 allocations)
-
-julia> @btime relu.($w .+ $b);
-  min 17.208 μs, mean 62.826 μs (2 allocations, 390.67 KiB)
-
-julia> @btime bias_act!(tanh, $w, $b);
-  min 63.792 μs, mean 65.052 μs (0 allocations)
-
-julia> @btime tanh_fast.($w .+ $b);
-  min 63.583 μs, mean 102.004 μs (2 allocations, 390.67 KiB)
-
-julia> using Zygote
-
-julia> @btime gradient((w,b) -> sum(bias_act!(relu, w, b)), $w, $b);
-  min 145.166 μs, mean 150.785 μs (51 allocations, 2.18 KiB)
-
-julia> @btime gradient((w,b) -> sum(relu.(w .+ b)), $w, $b);
-  min 165.583 μs, mean 314.267 μs (32 allocations, 1.15 MiB)
-
-julia> @btime gradient((w,b) -> sum(bias_act!(tanh, w, b)), $w, $b);
-  min 191.917 μs, mean 195.956 μs (51 allocations, 2.18 KiB)
-
-julia> @btime gradient((w,b) -> sum(tanh_fast.(w .+ b)), $w, $b);
-  min 209.458 μs, mean 338.652 μs (32 allocations, 1.15 MiB)
-
-
-## Cyclops
-
-julia> using CUDA  # 10x bigger
-
-julia> cw, cb = CUDA.rand(Float32, 100, 100_00), CUDA.rand(Float32, 100);
-
-julia> @btime CUDA.@sync bias_act!(relu, $cw, $cb);
-  22.546 μs (27 allocations: 1.45 KiB)
-
-julia> @btime CUDA.@sync relu.($cw .+ $cb);  # faster, that's odd?
-  31.282 μs (38 allocations: 1.81 KiB)
-
-julia> @btime CUDA.@sync bias_act!(tanh, $cw, $cb);
-  27.030 μs (27 allocations: 1.45 KiB)
-
-julia> @btime CUDA.@sync tanh_fast.($cw .+ $cb);
-  36.421 μs (38 allocations: 1.81 KiB)
-
-julia> using Zygote
-
-julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(relu, w, b)), $cw, $cb);
-  204.507 μs (382 allocations: 18.15 KiB)
-
-julia> @btime CUDA.@sync gradient((w,b) -> sum(relu.(w .+ b)), $cw, $cb);
-  204.458 μs (409 allocations: 19.19 KiB)
-
-julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(tanh, w, b)), $cw, $cb);
-  224.545 μs (382 allocations: 18.15 KiB)
-
-julia> @btime CUDA.@sync gradient((w,b) -> sum(tanh_fast.(w .+ b)), $cw, $cb);
-  204.793 μs (411 allocations: 19.30 KiB)
-
-
-=#
-
-#=
-
-(jl_fuwIi8) pkg> add https://github.com/mcabbott/NNlib.jl/tree/bias_act_23
-
-julia> using NNlib, Zygote, BenchmarkTools
-
-julia> w, b, x = rand(Float32, 50, 50), rand(Float32, 50), randn(Float32, 50, 100);
-
-julia> @btime bias_act!(relu, $w * $x, $b);
-  min 5.243 μs, mean 8.600 μs (2 allocations, 19.61 KiB)
-
-julia> @btime relu.($w * $x .+ $b);
-  min 5.160 μs, mean 10.863 μs (4 allocations, 39.22 KiB)
-
-julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
-  min 21.042 μs, mean 40.476 μs (43 allocations, 89.83 KiB)
-
-julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
-  min 21.542 μs, mean 43.947 μs (41 allocations, 128.91 KiB)
-
-julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
-  min 14.708 μs, mean 26.450 μs (28 allocations, 69.41 KiB)
-
-julia> @btime gradient(x -> sum(abs2, x), $x);
-  min 1.938 μs, mean 4.160 μs (2 allocations, 19.61 KiB)
-
-
-# Cyclops
-
-julia> @btime bias_act!(relu, $w * $x, $b);
-  24.786 μs (2 allocations: 19.61 KiB)
-
-julia> @btime relu.($w * $x .+ $b);
-  25.501 μs (4 allocations: 39.22 KiB)
-
-julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
-  91.847 μs (43 allocations: 89.83 KiB)
-
-julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
-  98.054 μs (41 allocations: 128.91 KiB)
-
-julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
-  80.464 μs (28 allocations: 69.41 KiB)
-
-julia> @btime gradient(x -> sum(abs2, x), $x);
-  4.604 μs (2 allocations: 19.61 KiB)
-
-julia> @time using CUDA; @time cu(ones(3)) .+ 1;
-
-julia> w, b, x = CUDA.rand(Float32, 1000, 1000), CUDA.rand(Float32, 1000), CUDA.rand(Float32, 1000, 1000);
-
-=#
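For reference, the pattern both benchmark logs exercise: `bias_act!(σ, x, b)` writes `σ.(x .+ b)` back into `x`, so handing it the freshly allocated result of `w * x` saves the second full-size array that `relu.(w * x .+ b)` would allocate. A quick usage sketch, assuming this branch of NNlib is installed:

using NNlib

w, b, x = rand(Float32, 50, 50), rand(Float32, 50), randn(Float32, 50, 100)

# The w*x product is a fresh buffer, so bias_act! may safely overwrite it.
y = bias_act!(relu, w * x, b)

y ≈ relu.(w * x .+ b)  # true: same values, one fewer full-size allocation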