
Commit 2ac91f0

more
1 parent 27328c2 commit 2ac91f0

3 files changed: 76 additions & 77 deletions

src/extra_rules.jl

Lines changed: 9 additions & 8 deletions
@@ -150,6 +150,7 @@ end
 # @ChainRulesCore.non_differentiable Base.rem(a::Integer, b::Type) # now in CR 1.18
 
 ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent()
+ChainRulesCore.canonicalize(::NoTangent) = NoTangent()
 
 # # Skip AD'ing through the axis computation
 # function ChainRules.rrule(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
@@ -200,12 +201,13 @@ function ChainRules.rrule(::typeof(map), ::typeof(+), A::AbstractVector, B::Abst
     map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ)
 end
 
-function ChainRules.rrule(AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {T,S,N}
-    # We're leaving these in the eltype that the cotangent vector already has.
-    # There isn't really a good reason to believe we should convert to the
-    # original array type, so don't unless explicitly requested.
-    AT(x), Δ->(NoTangent(), Δ)
-end
+# https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/array.jl#L7
+# function ChainRules.rrule(AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {T,S,N}
+#     # We're leaving these in the eltype that the cotangent vector already has.
+#     # There isn't really a good reason to believe we should convert to the
+#     # original array type, so don't unless explicitly requested.
+#     AT(x), Δ->(NoTangent(), Δ)
+# end
 
 # WARNING: Method definition rrule(Type{var"#s260"} where var"#s260"<:(Array{T, N} where N where T), UndefInitializer, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/kkDLd/src/rulesets/Base/array.jl:5 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:209.
 # function ChainRules.rrule(AT::Type{<:Array}, undef::UndefInitializer, args...)
@@ -254,10 +256,9 @@ function ChainRules.frule(_, ::Type{Vector{T}}, undef::UndefInitializer, dims::I
     Vector{T}(undef, dims...), zeros(T, dims...)
 end
 
-@ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer)
+# @ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer) CR#558
 @ChainRules.non_differentiable Base.throw(err)
 @ChainRules.non_differentiable Core.Compiler.return_type(args...)
-ChainRulesCore.canonicalize(::NoTangent) = NoTangent()
 
 # Disable thunking at higher order (TODO: These should go into ChainRulesCore)
 function ChainRulesCore.rrule(::Type{Thunk}, thnk)
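
Side note, not part of the diff: the relocated canonicalize method is an identity on NoTangent, mirroring the ZeroTangent method it now sits beside, so both degenerate tangents pass through canonicalization unchanged. A minimal check, assuming only ChainRulesCore is loaded:

    using ChainRulesCore
    ChainRulesCore.canonicalize(ZeroTangent())  # ZeroTangent() (pre-existing method)
    ChainRulesCore.canonicalize(NoTangent())    # NoTangent()  (the method moved up in this commit)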

src/stage1/broadcast.jl

Lines changed: 24 additions & 33 deletions
@@ -33,12 +33,12 @@ end
 
 using ChainRulesCore: derivatives_given_output
 
-# _print(s) = nothing
-_print(s) = printstyled(s, "\n"; color=:magenta)
+_print(s) = nothing
+# _print(s) = printstyled(s, "\n"; color=:magenta)
 
 # Broadcast over one element is just map
 function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
-    _print("path 0")
+    _print("path 0, order $N")
     ∂⃖ₙ(map, f, a)
 end
 
@@ -47,8 +47,8 @@ end
 function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
     T = Broadcast.combine_eltypes(f, args)
     TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...})
-    if eltype(T) == Bool
-        # Trivial case: non-differentiable output
+    if T === Bool
+        # Trivial case: non-differentiable output, e.g. `x .> 0`
         _print("path 1")
         back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2)
         return f.(args...), back_1
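
To see what the path-1 branch short-circuits (a sketch, not code from this commit): a broadcast whose combined eltype is Bool is treated as non-differentiable, so the pullback hands back a ZeroTangent for every argument. Using gradient as the test file below does:

    using Diffractor: gradient   # assumed import; the tests call it unqualified
    gradient(x -> sum(x .> 2), [1.0, 2.0, 3.0])  # (ZeroTangent(),)
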
@@ -160,16 +160,14 @@ end
 
 # For certain cheap operations we can easily allow fused broadcast:
 
-(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = split_bc_plus(args...)
-(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = split_bc_plus(arg) # ambiguity
-function split_bc_plus(xs...) where {F}
+(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = lazy_bc_plus(args...)
+(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = lazy_bc_plus(arg) # ambiguity
+function lazy_bc_plus(xs...) where {F}
     broadcasted(+, xs...), Δraw -> let Δ = unthunk(Δraw)
         _print("broadcast +")
         (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...)
     end
 end
-Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) =
-    mapreduce(eltype, promote_type, bc.args) # needed to hit fast path
 
 (::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ)
 
@@ -182,24 +180,22 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y)
 end
 
 using LinearAlgebra: dot
+const Numeric{T<:Number} = Union{T, AbstractArray{T}}
 
-function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y) # should this be vararg, or will laziness handle it?
+function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric)
     broadcasted(*, x, y), Δraw -> let Δ = unthunk(Δraw)
         _print("broadcast *")
         dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δ) : unbroadcast(x, Δ .* conj.(y))
         dy = eltype(y)==Bool ? NoTangent() : y isa Number ? dot(x, Δ) : unbroadcast(y, Δ .* conj.(x))
         # When x is an array but a smaller one, instead of dot you may be able to use mapreduce()
-        # Will things like this work? Ref([1,2]) .* [1,2,3]
         (NoTangent(), NoTangent(), dx, dy)
     end
 end
-# Alternative to `x isa Number` etc above... but not quite right!
-# (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y::Number) = rrule_via_ad(DiffractorRuleConfig(), *, x, y)
 
 function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x, ::Val{2})
     _print("broadcast ^2")
     broadcasted(*, x, x), Δ -> begin
-        dx = unbroadcast(x, 2 .* Δ .* conj.(x))
+        dx = unbroadcast(x, 2 .* unthunk(Δ) .* conj.(x))
         (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent())
     end
 end
@@ -208,30 +204,25 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::type
     x^2, Δ -> (NoTangent(), NoTangent(), NoTangent(), 2 * Δ * conj(x), NoTangent())
 end
 
-# function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x, y) # not obvious whether this is better than automatic
-#     broadcasted(/, x, y), Δ -> let Δun = unthunk(Δ)
-#         _print("broadcast /")
-#         dx = unbroadcast(x, Δ ./ conj.(y))
-#         dy = unbroadcast(y, .-Δ .* conj.(res ./ y))
-#         (NoTangent(), NoTangent(), dx, dy)
-#     end
-# end
-function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x, y::Number)
+function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::Numeric, y::Number)
     _print("simple /")
     z, back = ∂⃖{1}()(/, x, y)
-    z, Δ -> begin
-        _, dx, dy = back(Δ)
-        (NoTangent(), NoTangent(), dx, dy) # maybe there should be a funciton for this? Use for conj, identity too
+    z, dz -> begin
+        _, dx, dy = back(dz)
+        (NoTangent(), NoTangent(), dx, dy)
     end
 end
 
+(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x) = x, identity_pullback
+(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x::Array) = x, identity_pullback # ambiguity
+identity_pullback(Δ) = (NoTangent(), NoTangent(), Δ)
+
+(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{Real}) = x, identity_pullback
+(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array{Real}) = x, identity_pullback
 (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x) =
     broadcasted(conj, x), Δ -> (NoTangent(), conj(unthunk(Δ)))
-(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{Real}) =
-    x, Δ -> (NoTangent(), Δ)
-
-(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x) =
-    x, Δ -> (NoTangent(), Δ)
+(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array) =
+    broadcasted(conj, x), Δ -> (NoTangent(), conj(unthunk(Δ)))
 
 # All broadcasts use `unbroadcast` to reduce to correct shape:
 
@@ -244,7 +235,7 @@ function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
         ProjectTo(x)(sum(dx; dims))
     end
 end
-unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::NoTangent) = NoTangent()
+unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx
 
 unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx)))
 function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
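
For orientation (a sketch under assumptions, not code from this commit): unbroadcast reduces a cotangent back to the shape of one broadcast input by summing over the dimensions that broadcasting expanded, then reprojecting with ProjectTo, as the array method above shows; the new dx::AbstractZero method simply passes ZeroTangent()/NoTangent() through unchanged instead of collapsing both to NoTangent(). A self-contained, hypothetical rendering of the array case:

    using ChainRulesCore: ProjectTo

    # unbroadcast_sketch is illustrative only; the real method lives in this file.
    function unbroadcast_sketch(x::AbstractArray, dx::AbstractArray)
        if length(x) == length(dx)
            ProjectTo(x)(dx)  # same element count: only re-project (eltype, structure)
        else
            # sum out each dim where x has length 1 (or is absent) but dx does not
            dims = ntuple(d -> size(x, d) == 1 ? d : ndims(dx) + 1, ndims(dx))
            ProjectTo(x)(sum(dx; dims))
        end
    end

    unbroadcast_sketch([1.0, 2.0], ones(2, 3))  # sums over the 2nd dim: [3.0, 3.0]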

test/runtests.jl

Lines changed: 43 additions & 36 deletions
@@ -215,42 +215,49 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
 @test delta45 ≈ 1.0
 
 # Broadcasting
-@test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output
-@test gradient(x -> sum(sqrt.(atan.(x, x'))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3
-@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],)
-
-@test_broken gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # stores pullback
-exp_log(x) = exp(log(x))
-@test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
-@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])
-@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4)
-@test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure
-
-@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays
-@test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
-@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
-
-@test unthunk.(gradient(x -> sum(x ./ 4), [1,2,3])) == ([0.25, 0.25, 0.25],)
-@test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule
-@test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule
-@test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule
-
-@test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool output
-@test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) == (ZeroTangent(),)
-@test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (ZeroTangent(), ZeroTangent())
-@test gradient(x -> sum(x .+ [1,2,3]), true) == (NoTangent(),) # Bool input
-@test gradient(x -> sum(x ./ [1,2,3]), [true false]) == (NoTangent(),)
-@test gradient(x -> sum(x .* [1,2,3]'), (true, false)) == (NoTangent(),)
-
-tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), [3,4,5]')
-@test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0)
-@test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4]
-@test tup_adj[2] isa Adjoint
-@test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal
-
-@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3]) # path 0, MethodError: no method matching Diffractor.Jet(::Int64, ::Float64, ::Tuple{Float64, Float64})
-@test_broken gradient(x -> sum(gradient(x -> sum(x' .* x), x)[1]), [1,2,3]) == ([6,6,6],) # Control flow support not fully implemented yet for higher-order reverse mode
-@test_broken gradient(x -> sum(gradient(x -> sum(x' ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516]
+@testset "broadcast" begin
+    @test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output
+    @test gradient(x -> sum(sqrt.(atan.(x, x'))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3
+    @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],)
+
+    @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # stores pullback
+    exp_log(x) = exp(log(x))
+    @test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
+    @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])
+    @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4)
+    @test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure
+
+    @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays
+    @test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
+    @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
+    @test gradient(x -> sum(sum, (x,) .* x'), [1,2,3])[1] ≈ [12, 12, 12] # must not take the * fast path
+
+    @test unthunk.(gradient(x -> sum(x ./ 4), [1,2,3])) == ([0.25, 0.25, 0.25],)
+    @test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule
+    @test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule
+    @test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule
+
+    @test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),)
+    @test gradient(x -> sum([1,2,3]' .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),)
+    @test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),)
+
+    @test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool output
+    @test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) == (ZeroTangent(),)
+    @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (ZeroTangent(), ZeroTangent())
+    @test gradient(x -> sum(x .+ [1,2,3]), true) == (NoTangent(),) # Bool input
+    @test gradient(x -> sum(x ./ [1,2,3]), [true false]) == (NoTangent(),)
+    @test_broken gradient(x -> sum(x .* [1,2,3]'), (true, false)) == (NoTangent(),) # Cannot `convert` an object of type NoTangent to an object of type ZeroTangent
+
+    tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), [3,4,5]')
+    @test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0)
+    @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4]
+    @test tup_adj[2] isa Adjoint
+    @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal
+
+    @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3]) # path 0, MethodError: no method matching Diffractor.Jet(::Int64, ::Float64, ::Tuple{Float64, Float64})
+    @test_broken gradient(x -> sum(gradient(x -> sum(x' .* x), x)[1]), [1,2,3]) == ([6,6,6],) # Control flow support not fully implemented yet for higher-order reverse mode
+    @test_broken gradient(x -> sum(gradient(x -> sum(x' ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516]
+end
 
 # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
 #include("pinn.jl")
