diff --git a/Project.toml b/Project.toml index 5af68933..280e0b8d 100644 --- a/Project.toml +++ b/Project.toml @@ -8,13 +8,14 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] -ChainRules = "1.5" -ChainRulesCore = "1.2" +ChainRules = "1.17" +ChainRulesCore = "1.11" Combinatorics = "1" StaticArrays = "1" StatsBase = "0.33" diff --git a/src/extra_rules.jl b/src/extra_rules.jl index 4d22c172..164ded8d 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -137,6 +137,7 @@ struct NonDiffOdd{N, O, P}; end # This should not happen (::NonDiffEven{N, O, O})(Δ...) where {N, O} = error() +# WARNING: Method definition rrule(typeof(Core.apply_type), Any, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/kkDLd/src/rulesets/Core/core.jl:10 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:140. @Base.pure function ChainRulesCore.rrule(::typeof(Core.apply_type), head, args...) Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}() end @@ -145,17 +146,8 @@ function ChainRulesCore.rrule(::typeof(Core.tuple), args...) Core.tuple(args...), Δ->Core.tuple(NoTangent(), Δ...) end -# TODO: What to do about these integer rules -@ChainRulesCore.non_differentiable Base.rem(a::Integer, b::Type) - ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent() - -# Skip AD'ing through the axis computation -function ChainRules.rrule(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted) - return Base.Broadcast.instantiate(bc), Δ->begin - Core.tuple(NoTangent(), Δ) - end -end +ChainRulesCore.canonicalize(::NoTangent) = NoTangent() using StaticArrays @@ -199,20 +191,6 @@ function ChainRules.rrule(::typeof(map), ::typeof(+), A::AbstractVector, B::Abst map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ) end -function ChainRules.rrule(AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {T,S,N} - # We're leaving these in the eltype that the cotangent vector already has. - # There isn't really a good reason to believe we should convert to the - # original array type, so don't unless explicitly requested. - AT(x), Δ->(NoTangent(), Δ) -end - -function ChainRules.rrule(AT::Type{<:Array}, undef::UndefInitializer, args...) - # We're leaving these in the eltype that the cotangent vector already has. - # There isn't really a good reason to believe we should convert to the - # original array type, so don't unless explicitly requested. - AT(undef, args...), Δ->(NoTangent(), NoTangent(), ntuple(_->NoTangent(), length(args))...) -end - function unzip_tuple(t::Tuple) map(x->x[1], t), map(x->x[2], t) end @@ -252,10 +230,8 @@ function ChainRules.frule(_, ::Type{Vector{T}}, undef::UndefInitializer, dims::I Vector{T}(undef, dims...), zeros(T, dims...) end -@ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer) @ChainRules.non_differentiable Base.throw(err) @ChainRules.non_differentiable Core.Compiler.return_type(args...) -ChainRulesCore.canonicalize(::NoTangent) = NoTangent() # Disable thunking at higher order (TODO: These should go into ChainRulesCore) function ChainRulesCore.rrule(::Type{Thunk}, thnk) @@ -266,3 +242,17 @@ end function ChainRulesCore.rrule(::Type{InplaceableThunk}, add!!, val) val, Δ->(NoTangent(), NoTangent(), Δ) end + +# ERROR: ArgumentError: Tangent for the primal Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}} should be backed by a AbstractDict type, not by NamedTuple{(:data,), Tuple{ChainRulesCore.ZeroTangent}}. +ChainRulesCore._backing_error(::Type{<:Base.Pairs{Symbol}}, ::Type{<:NamedTuple}, _) = nothing # solves that! + +# Rather than have a rule for broadcasted 3-arg *, just send it to the efficient path: +ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number, z::Number) = ((y*z, x*z, x*y),) +function ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number, z::Number, w::Number) + xy = x*y + zw = z*w + ((y*zw, x*zw, xy*w, xy*z),) +end + +# Fixes @test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;] +(project::ProjectTo{<:AbstractArray})(th::InplaceableThunk) = project(unthunk(th)) diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index 88fc0b39..78964184 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -29,45 +29,188 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)}, return r end +# Reverse mode broadcast rules + +using ChainRulesCore: derivatives_given_output + # Broadcast over one element is just map -function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N} - ∂⃖ₙ(map, f, a) +# function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N} +# ∂⃖ₙ(map, f, a) +# end + +(::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ) + +(::∂⃖{1})(::typeof(broadcasted), f::F, args...) where {F} = split_bc_rule(f, args...) +# (::∂⃖{1})(::typeof(broadcasted), f::F, arg::Array) where {F} = split_bc_rule(f, arg) # ambiguity +function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N} + T = Broadcast.combine_eltypes(f, args) + TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...}) + if T === Bool + # Trivial case: non-differentiable output, e.g. `x .> 0` + back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2) + return f.(args...), back_1 + elseif T <: Number && isconcretetype(TΔ) + # Fast path: just broadcast, and use arguments & result to find derivatives. + ys = f.(args...) + function back_2_one(dys) # For f.(x) we do not need StructArrays / unzip at all + delta = broadcast(unthunk(dys), ys, args...) do dy, y, a + das = only(derivatives_given_output(y, f, a)) + dy * conj(only(das)) # possibly this * should be made nan-safe. + end + (NoTangent(), NoTangent(), unbroadcast(only(args), delta)) + end + function back_2_many(dys) + deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as... + das = only(derivatives_given_output(y, f, as...)) + map(da -> dy * conj(da), das) + end + dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of splitcast? + (NoTangent(), NoTangent(), dargs...) + end + return ys, N==1 ? back_2_one : back_2_many + else + # Slow path: collect all the pullbacks & apply them later. + # (Since broadcast makes no guarantee about order of calls, and un-fusing + # can change the number of calls, this does not bother to try to reverse.) + ys3, backs = tuplecast(∂⃖{1}(), f, args...) + function back_3(dys) + deltas = tuplecast(backs, unthunk(dys)) do back, dy # could be map, sizes match + map(unthunk, back(dy)) + end + dargs = map(unbroadcast, args, Base.tail(deltas)) + (NoTangent(), sum(first(deltas)), dargs...) + end + back_3(::AbstractZero) = (NoTangent(), map(Returns(ZeroTangent()), args)...) + return ys3, back_3 + end end -# The below is from Zygote: TODO: DO we want to do something better here? +# Don't run broadcasting on scalars +function split_bc_rule(f::F, args::Number...) where {F} + z, back = ∂⃖{1}()(f, args...) + z, dz -> (NoTangent(), back(dz)...) +end -accum_sum(xs::Nothing; dims = :) = NoTangent() -accum_sum(xs::AbstractArray{Nothing}; dims = :) = NoTangent() -accum_sum(xs::AbstractArray{<:Number}; dims = :) = sum(xs, dims = dims) -accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims = :) = sum(xs, dims = dims) -accum_sum(xs::Number; dims = :) = xs +split_bc_rule(::typeof(identity), x) = x, Δ -> (NoTangent(), NoTangent(), Δ) +split_bc_rule(::typeof(identity), x::Number) = x, Δ -> (NoTangent(), NoTangent(), Δ) -# https://github.com/FluxML/Zygote.jl/issues/594 -function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArray, region) - Base.reducedim_initarray(A, region, NoTangent(), Union{Nothing,eltype(A)}) +# Skip AD'ing through the axis computation +function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted) + uninstantiate(Δ) = Core.tuple(NoTangent(), Δ) + return Base.Broadcast.instantiate(bc), uninstantiate end -trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) +using StructArrays + +function tuplecast(f::F, args...) where {F} + T = Broadcast.combine_eltypes(f, args) + if isconcretetype(T) + T <: Tuple || throw(ArgumentError("tuplecast(f, args) only works on functions returning a tuple.")) + end + bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...)) + StructArrays.components(StructArray(bc)) +end -unbroadcast(x::AbstractArray, x̄) = - size(x) == size(x̄) ? x̄ : - length(x) == length(x̄) ? trim(x, x̄) : - trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄))))) +# For certain cheap operations we can easily allow fused broadcast: +const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted} -unbroadcast(x::Number, x̄) = accum_sum(x̄) -unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),) -unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args::NumericOrBroadcast...) = lazy_bc_plus(args...) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args::Number) = split_bc_rule(+, args...) +function lazy_bc_plus(xs...) where {F} + broadcasted(+, xs...), Δraw -> let Δ = unthunk(Δraw) + (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...) + end +end -unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent() +(::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = split_bc_rule(-, x, y) +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast) + broadcasted(-, x, y), Δraw -> let Δ = unthunk(Δraw) + (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)) + end +end -const Numeric = Union{Number, AbstractArray{<:Number, N} where N} +using LinearAlgebra: dot -function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(+), xs::Numeric...) - broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = split_bc_rule(*, x, y) +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast) + broadcasted(*, x, y), Δraw -> let Δ = unthunk(Δraw) + (NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ)) + end end +_back_star(x, y, Δ) = unbroadcast(x, Δ .* conj.(y)) +_back_star(x::Number, y, Δ) = dot(y, Δ) +_back_star(x::Bool, y, Δ) = NoTangent() -ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y, - Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2}) + broadcasted(*, x, x), Δ -> begin + dx = unbroadcast(x, 2 .* unthunk(Δ) .* conj.(x)) + (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent()) + end +end +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2}) + x^2, Δ -> (NoTangent(), NoTangent(), NoTangent(), 2 * Δ * conj(x), NoTangent()) +end + +(::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = split_bc_rule(/, x, y) +function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number) + z = broadcast(/, x, y) + z, Δth -> let Δ = unthunk(Δth) + dx = unbroadcast(x, Δ ./ conj.(y)) + dy = -dot(z, Δ) / (conj(y)) # the reason to be eager is to allow dot here + (NoTangent(), NoTangent(), dx, dy) + end +end + +(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x) = split_bc_rule(identity, x) +# (::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x::Array) = split_bc_rule(identity, x) # ambiguity + +(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{Real}) = split_bc_rule(identity, x) +# (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array{Real}) = split_bc_rule(identity, x) # ambiguity +(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x) = + broadcasted(conj, x), Δ -> (NoTangent(), conj(unthunk(Δ))) +(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array) = + broadcasted(conj, x), Δ -> (NoTangent(), conj(unthunk(Δ))) + +# Reverse mode broadcasting uses `unbroadcast` to reduce to correct shape: +function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx) + N = ndims(dx) + if length(x) == length(dx) + ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors + else + dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) # hack to get type-stable `dims` + ProjectTo(x)(sum(dx; dims)) # ideally this sum might be thunked? + end +end +unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx + +unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx))) +function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N} + val = if length(x) == length(dx) + dx + else + sum(dx; dims=2:ndims(dx)) + end + ProjectTo(x)(NTuple{length(x)}(val)) # Tangent +end -ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y, - z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end \ No newline at end of file +unbroadcast(f::Function, df) = sum(df) +unbroadcast(x::Number, dx) = ProjectTo(x)(sum(dx)) +unbroadcast(x::Base.RefValue, dx) = ProjectTo(x)(Ref(sum(dx))) + +unbroadcast(::Bool, dx) = NoTangent() +unbroadcast(::AbstractArray{Bool}, dx) = NoTangent() +unbroadcast(::AbstractArray{Bool}, ::NoTangent) = NoTangent() # ambiguity +unbroadcast(::Val, dx) = NoTangent() + +function unbroadcast(x, dx) + p = ProjectTo(x) + if dx isa AbstractZero || p isa ProjectTo{<:AbstractZero} + return NoTangent() + end + b = Broadcast.broadcastable(x) + if b isa Ref # then x is scalar under broadcast + return p(sum(dx)) + else + error("don't know how to handle broadcast gradient for x::$(typeof(x))") + end +end diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 2f8f48fd..036c4710 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -49,6 +49,8 @@ end Base.getindex(o::OpticBundle, i::Int) = i == 1 ? o.x : i == 2 ? o.clos : throw(BoundsError(o, i)) +Base.lastindex(o::OpticBundle) = 2 + Base.iterate(o::OpticBundle) = (o.x, nothing) Base.iterate(o::OpticBundle, ::Nothing) = (o.clos, missing) Base.iterate(o::OpticBundle, ::Missing) = nothing diff --git a/test/runtests.jl b/test/runtests.jl index 51335369..d752fc11 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Diffractor -using Diffractor: var"'", ∂⃖, DiffractorRuleConfig +using Diffractor: ∂⃖, DiffractorRuleConfig using ChainRules using ChainRulesCore using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad @@ -8,6 +8,11 @@ using LinearAlgebra using Test +const fwd = Diffractor.PrimeDerivativeFwd +const bwd = Diffractor.PrimeDerivativeBack + +@testset verbose=true "Unit tests" begin # from before broadcasting PR + # Unit tests function tup2(f) a, b = ∂⃖{2}()(f, 1) @@ -43,7 +48,7 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent() # Minimal 2-nd order forward smoke test @test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), - Diffractor.TangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) + Diffractor.TangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == cos(1.0) function simple_control_flow(b, x) if b @@ -86,7 +91,7 @@ isa_control_flow(::Type{T}, x) where {T} = isa(x, T) ? x : T(x) let var"'" = Diffractor.PrimeDerivativeBack # Integration tests @test @inferred(sin'(1.0)) == cos(1.0) - @test @inferred(sin''(1.0)) == -sin(1.0) + @test @inferred(sin''(1.0)) == -sin(1.0) broken=true @test sin'''(1.0) == -cos(1.0) @test sin''''(1.0) == sin(1.0) @test sin'''''(1.0) == cos(1.0) @@ -100,10 +105,15 @@ let var"'" = Diffractor.PrimeDerivativeBack # Higher order mixed mode tests complicated_2sin(x) = (x = map(sin, Diffractor.xfill(x, 2)); x[1] + x[2]) - @test @inferred(complicated_2sin'(1.0)) == 2sin'(1.0) - @test @inferred(complicated_2sin''(1.0)) == 2sin''(1.0) - @test @inferred(complicated_2sin'''(1.0)) == 2sin'''(1.0) - @test @inferred(complicated_2sin''''(1.0)) == 2sin''''(1.0) + @test @inferred(complicated_2sin'(1.0)) == 2sin'(1.0) broken=true # inference broken on Julia nightly without PR68 + @test @inferred(complicated_2sin''(1.0)) == 2sin''(1.0) broken=true + @test @inferred(complicated_2sin'''(1.0)) == 2sin'''(1.0) broken=true + @test @inferred(complicated_2sin''''(1.0)) == 2sin''''(1.0) broken=true + + @test complicated_2sin'(1.0) == 2sin'(1.0) + @test complicated_2sin''(1.0) == 2sin''(1.0) + @test complicated_2sin'''(1.0) == 2sin'''(1.0) + @test complicated_2sin''''(1.0) == 2sin''''(1.0) # Control flow cases @test @inferred((x->simple_control_flow(true, x))'(1.0)) == sin'(1.0) @@ -149,9 +159,6 @@ end # Regression tests @test gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1] == [2.0, 2.0, 2.0] -const fwd = Diffractor.PrimeDerivativeFwd -const bwd = Diffractor.PrimeDerivativeBack - function f_broadcast(a) l = a / 2.0 * [[0. 1. 1.]; [1. 0. 1.]; [1. 1. 0.]] return sum(l) @@ -161,7 +168,7 @@ end # Make sure that there's no infinite recursion in kwarg calls g_kw(;x=1.0) = sin(x) f_kw(x) = g_kw(;x) -@test bwd(f_kw)(1.0) == bwd(sin)(1.0) +@test bwd(f_kw)(1.0) == bwd(sin)(1.0) broken=true # involves [3] +(a::Tangent{Core.Box, ...}, b::Tangent{Core.Box, ...} function f_crit_edge(a, b, c, x) # A function with two critical edges. This used to trigger an issue where @@ -214,5 +221,59 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test z45 ≈ 2.0 @test delta45 ≈ 1.0 + +end # @testset + +@testset verbose=true "broadcast" begin + @test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output + @test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3 + @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) + + @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # stores pullback + exp_log(x) = exp(log(x)) + @test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],) + @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75]) + @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4) + @test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure + + @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays + @test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 + @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 + @test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] ≈ [12, 12, 12] # must not take the * fast path + + @test unthunk.(gradient(x -> sum(x ./ 4), [1,2,3])) == ([0.25, 0.25, 0.25],) + @test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule + @test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule + @test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule + + @test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),) + @test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),) + @test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),) + + @test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool output + @test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) == (ZeroTangent(),) + @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (ZeroTangent(), ZeroTangent()) + @test gradient(x -> sum(x .+ [1,2,3]), true) == (NoTangent(),) # Bool input + @test gradient(x -> sum(x ./ [1,2,3]), [true false]) == (NoTangent(),) + @test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) == (NoTangent(),) + + tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5])) + @test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0) + @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4] + @test tup_adj[2] isa Transpose + @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal + + @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] ≈ exp.(1:3) # path 0, order 2, MethodError: no method matching Diffractor.Jet(::Int64, ::Float64, ::Tuple{Float64, Float64}) + @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3.0])[1] ≈ exp.(1:3) # path 0, order 2, MethodError: no method matching length(::ChainRulesCore.InplaceableThunk + @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # Control flow support not fully implemented yet for higher-order reverse mode + @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516] + + @test (@inferred Diffractor.split_bc_rule(<, [1,2,3], [4 5]); true) # path 1, bool + @test (@inferred Diffractor.split_bc_rule(+, [1,2,3]); true) # path 2, derivatives_given_output + @test (@inferred Diffractor.split_bc_rule(+, [1,2,3], [4 5]); true) # path 2 vararg, tuplecast + @test (@inferred Diffractor.split_bc_rule(exp_log, [1,2,3]); true) # path 3 generic + +end + # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) #include("pinn.jl")