
Commit b2969a5

comment out 2nd path again
1 parent c1d834f commit b2969a5

2 files changed, with 16 additions and 164 deletions.

src/bias_act.jl

Lines changed: 10 additions & 161 deletions
@@ -59,16 +59,16 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA
         end
         return Ω, bias_act!_fastback
 
-    # Slower path: can't overwrite x, but can use derivatives_given_output
-    # This case is WRONG and tests fail, but not sure why
-    elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
-        Ω2 = fast_act(σ, x).(x) .+ b
-        @show σ b
-        function bias_act!_back2(Δ)
-            dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
-            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
-        end
-        return Ω2, bias_act!_back2
+    # # Slower path: can't overwrite x, but can use derivatives_given_output
+    # # This case is WRONG and tests fail, but not sure why
+    # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
+    #     Ω2 = fast_act(σ, x).(x) .+ b
+    #     @show σ b
+    #     function bias_act!_back2(Δ)
+    #         dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
+    #         return (NoTangent(), NoTangent(), dx, biasgrad(dx))
+    #     end
+    #     return Ω2, bias_act!_back2
 
     # Fallback path: let AD handle the broadcast
     else
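One plausible reading of the "WRONG" note above: this path's forward pass builds Ω2 = fast_act(σ, x).(x) .+ b, i.e. σ.(x) .+ b, whereas bias_act! elsewhere computes σ.(x .+ b), and only_derivative.(Ω2, σ, x .+ b) treats Ω2 as that output. A minimal sketch of the mismatch, in plain Julia with illustrative values (no NNlib required):

    σ = tanh
    x, b = [0.5f0], [0.25f0]

    Ω_wrong = σ.(x) .+ b   # what the commented-out path computed and returned
    Ω_right = σ.(x .+ b)   # what the rest of bias_act! computes, and what the pullback assumes

    Ω_wrong ≈ Ω_right      # false: both the forward value and dx would come out wrong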
@@ -96,154 +96,3 @@ function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArr
     return x, bias_act!_trivial
 end
 
-
-
-# """
-#     add_act(σ, x, y...)
-#     add_act!(σ, x, y, z...)
-
-# Equivalent to `σ.(x .+ y .+ z)`. The mutating method `add_act!`
-# """
-# add_act(σ::Function, x::AbstractArray, yz::AbstractArray...) = σ.(.+(x, yz...)) # fused
-
-
-# function ChainRulesCore.rrule(::typeof(add_act), σ::F, x::AbstractArray, yz::AbstractArray...) where {F,T,N}
-#     if isconcretetype(Core.Compiler._return_type(
-#         derivatives_given_output, Tuple{T, F, NotaNumber}))
-
-# end
-
-
-# bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) =
-#     # b ? (x .= fast_act(σ, x).(x .+ b)) : (x .= fast_act(σ, x).(x))
-#     (@assert !b "bias=true is not accepted"; (x .= fast_act(σ, x).(x)))
-
-
-# using NNlib, BenchmarkTools
-
-#=
-
-## M1 mac, 1.10
-
-julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);
-
-julia> @btime bias_act!(relu, $w, $b);
-min 19.500 μs, mean 21.375 μs (0 allocations)
-
-julia> @btime relu.($w .+ $b);
-min 17.208 μs, mean 62.826 μs (2 allocations, 390.67 KiB)
-
-julia> @btime bias_act!(tanh, $w, $b);
-min 63.792 μs, mean 65.052 μs (0 allocations)
-
-julia> @btime tanh_fast.($w .+ $b);
-min 63.583 μs, mean 102.004 μs (2 allocations, 390.67 KiB)
-
-julia> using Zygote
-
-julia> @btime gradient((w,b) -> sum(bias_act!(relu, w, b)), $w, $b);
-min 145.166 μs, mean 150.785 μs (51 allocations, 2.18 KiB)
-
-julia> @btime gradient((w,b) -> sum(relu.(w .+ b)), $w, $b);
-min 165.583 μs, mean 314.267 μs (32 allocations, 1.15 MiB)
-
-julia> @btime gradient((w,b) -> sum(bias_act!(tanh, w, b)), $w, $b);
-min 191.917 μs, mean 195.956 μs (51 allocations, 2.18 KiB)
-
-julia> @btime gradient((w,b) -> sum(tanh_fast.(w .+ b)), $w, $b);
-min 209.458 μs, mean 338.652 μs (32 allocations, 1.15 MiB)
-
-
-
-## Cyclops
-
-julia> using CUDA # 10x bigger
-
-julia> cw, cb = CUDA.rand(Float32, 100, 100_00), CUDA.rand(Float32, 100);
-
-julia> @btime CUDA.@sync bias_act!(relu, $cw, $cb);
-22.546 μs (27 allocations: 1.45 KiB)
-
-julia> @btime CUDA.@sync relu.($cw .+ $cb); # faster, that's odd?
-31.282 μs (38 allocations: 1.81 KiB)
-
-julia> @btime CUDA.@sync bias_act!(tanh, $cw, $cb);
-27.030 μs (27 allocations: 1.45 KiB)
-
-julia> @btime CUDA.@sync tanh_fast.($cw .+ $cb);
-36.421 μs (38 allocations: 1.81 KiB)
-
-julia> using Zygote
-
-julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(relu, w, b)), $cw, $cb);
-204.507 μs (382 allocations: 18.15 KiB)
-
-julia> @btime CUDA.@sync gradient((w,b) -> sum(relu.(w .+ b)), $cw, $cb);
-204.458 μs (409 allocations: 19.19 KiB)
-
-julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(tanh, w, b)), $cw, $cb);
-224.545 μs (382 allocations: 18.15 KiB)
-
-julia> @btime CUDA.@sync gradient((w,b) -> sum(tanh_fast.(w .+ b)), $cw, $cb);
-204.793 μs (411 allocations: 19.30 KiB)
-
-
-=#
-
-#=
-
-(jl_fuwIi8) pkg> add https://github.com/mcabbott/NNlib.jl/tree/bias_act_23
-
-julia> using NNlib, Zygote, BenchmarkTools
-
-julia> w, b, x = rand(Float32, 50, 50), rand(Float32, 50), randn(Float32, 50, 100);
-
-julia> @btime bias_act!(relu, $w * $x, $b);
-min 5.243 μs, mean 8.600 μs (2 allocations, 19.61 KiB)
-
-julia> @btime relu.($w * $x .+ $b);
-min 5.160 μs, mean 10.863 μs (4 allocations, 39.22 KiB)
-
-julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
-min 21.042 μs, mean 40.476 μs (43 allocations, 89.83 KiB)
-
-julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
-min 21.542 μs, mean 43.947 μs (41 allocations, 128.91 KiB)
-
-julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
-min 14.708 μs, mean 26.450 μs (28 allocations, 69.41 KiB)
-
-julia> @btime gradient(x -> sum(abs2, x), $x);
-min 1.938 μs, mean 4.160 μs (2 allocations, 19.61 KiB)
-
-
-# Cyclops
-
-julia> @btime bias_act!(relu, $w * $x, $b);
-24.786 μs (2 allocations: 19.61 KiB)
-
-julia> @btime relu.($w * $x .+ $b);
-25.501 μs (4 allocations: 39.22 KiB)
-
-julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
-91.847 μs (43 allocations: 89.83 KiB)
-
-julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
-98.054 μs (41 allocations: 128.91 KiB)
-
-julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
-80.464 μs (28 allocations: 69.41 KiB)
-
-julia> @btime gradient(x -> sum(abs2, x), $x);
-4.604 μs (2 allocations: 19.61 KiB)
-
-julia> @time using CUDA; @time cu(ones(3)) .+ 1;
-
-julia> w, b, x = CUDA.rand(Float32, 1000, 1000), CUDA.rand(Float32, 1000), CUDA.rand(Float32, 1000, 1000);
-
-
-
-=#
-
-
-
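For orientation, the benchmarks above compare bias_act!(σ, x, b) from this branch against the allocating broadcast σ.(x .+ b); the in-place version should return the same values with fewer allocations. A rough sanity-check sketch, assuming this branch's NNlib is loaded (w and b are illustrative):

    using NNlib

    w, b = rand(Float32, 100, 1000), rand(Float32, 100)

    y_ref = relu.(w .+ b)            # allocating baseline used in the benchmarks
    y = bias_act!(relu, copy(w), b)  # fused in-place version; mutates its array argument

    y ≈ y_ref                        # true: same values, fewer allocations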
src/utils.jl

Lines changed: 6 additions & 3 deletions
@@ -119,11 +119,14 @@ unsqueeze(x) = reshape(x, 1, size(x)...)
 
 This does `x .= f.(x, y, z...)`, but works around
 an issue with broadcasting that prevents SIMD in such cases.
-Can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.
+Can perhaps be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.
 
-Not intended for general use. Uses `@inbounds` but does not check sizes!
+Has an `rrule` to avoid mutation within derivatives.
 
-Has an `rrule` to avoid mutation within derivatives. This assumes that `f` has no derivative!
+!!! warning
+    Not intended for general use.
+    Uses `@inbounds` but does not check sizes!
+    Assumes that `f` has no derivative!
 """
 function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}
     bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))
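A small usage sketch of the documented contract (_fast_broadcast! is an internal helper, not public API, so this example is illustrative only):

    using NNlib

    x = Float32[1, 2, 3]
    y = Float32[10, 20, 30]

    NNlib._fast_broadcast!(+, x, y)   # does x .= x .+ y, via the SIMD-friendly loop
    x == Float32[11, 22, 33]          # true: x was mutated in place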
