Skip to content

Update to ChainRulesCore 0.5.1 #146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.2.5"
version = "0.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.4"
ChainRulesCore = "0.5.1"
FiniteDifferences = "^0.7"
Reexport = "0.2"
Requires = "0.5.2, 1"
Expand Down
1 change: 0 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ The most important `AbstractDifferential`s when getting started are the ones abo
- `One`, `Zero`: There are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `Zero`) addition.

#### Other `AbstractDifferential`s: don't worry about them right now
- `Wirtinger`: it is complex. The docs need to be better. [Read the links in this issue](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/40).
- `Casted`: it implements broadcasting mechanics. See [#10](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/10)
- `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place.

Expand Down
25 changes: 7 additions & 18 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
@scalar_rule(zero(x), Zero())
@scalar_rule(sign(x), Zero())

@scalar_rule(abs2(x), Wirtinger(x', x))
@scalar_rule(abs2(x), 2x)
@scalar_rule(log(x), inv(x))
@scalar_rule(log10(x), inv(x) / log(oftype(x, 10)))
@scalar_rule(log2(x), inv(x) / log(oftype(x, 2)))
Expand Down Expand Up @@ -50,14 +50,12 @@
@scalar_rule(deg2rad(x), π / oftype(x, 180))
@scalar_rule(rad2deg(x), oftype(x, 180) / π)

@scalar_rule(conj(x), Wirtinger(Zero(), One()))
@scalar_rule(adjoint(x), Wirtinger(Zero(), One()))
@scalar_rule(conj(x::Real), One())
@scalar_rule(adjoint(x::Real), One())
@scalar_rule(transpose(x), One())

@scalar_rule(abs(x::Real), sign(x))
@scalar_rule(abs(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))
@scalar_rule(hypot(x::Real), sign(x))
@scalar_rule(hypot(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))
@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DoesNotExist()))

@scalar_rule(+(x), One())
Expand Down Expand Up @@ -98,20 +96,14 @@
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))))
@scalar_rule(fma(x, y, z), (y, x, One()))
@scalar_rule(muladd(x, y, z), (y, x, One()))
@scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u))
@scalar_rule(angle(x::Real), Zero())
@scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2))
@scalar_rule(real(x::Real), One())
@scalar_rule(imag(x::Complex), Wirtinger(-im//2, im//2))
@scalar_rule(imag(x::Real), Zero())

# product rule requires special care for arguments where `mul` is non-commutative

function frule(::typeof(*), x::Number, y::Number)
function times_pushforward(_, Δx, Δy)
return Δx * y + x * Δy
end
return x * y, times_pushforward
function frule(::typeof(*), x::Number, y::Number, _, Δx, Δy)
return x * y, Δx * y + x * Δy
end

function rrule(::typeof(*), x::Number, y::Number)
Expand All @@ -121,11 +113,8 @@ function rrule(::typeof(*), x::Number, y::Number)
return x * y, times_pullback
end

function frule(::typeof(identity), x)
function identity_pushforward(_, ẏ)
return ẏ
end
return x, identity_pushforward
function frule(::typeof(identity), x, _, ẏ)
return x, ẏ
end

function rrule(::typeof(identity), x)
Expand Down
12 changes: 5 additions & 7 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,17 @@ https://github.com/JuliaLang/julia/issues/22129.
=#
function _cast_diff(f, x)
function element_rule(u)
fu, du = frule(f, u)
fu, extern(du(NamedTuple(), One()))
dself = Zero()
fu, du = frule(f, u, dself, One())
fu, extern(du)
end
results = broadcast(element_rule, x)
return first.(results), last.(results)
end

function frule(::typeof(broadcast), f, x)
function frule(::typeof(broadcast), f, x, _, Δf, Δx)
Ω, ∂x = _cast_diff(f, x)
function broadcast_pushforward(_, Δf, Δx)
return Δx .* ∂x
end
return Ω, broadcast_pushforward
return Ω, Δx .* ∂x
end

function rrule(::typeof(broadcast), f, x)
Expand Down
7 changes: 2 additions & 5 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,8 @@ end
##### `sum`
#####

function frule(::typeof(sum), x)
function sum_pushforward(_, ẋ)
return sum(ẋ)
end
return sum(x), sum_pushforward
function frule(::typeof(sum), x, _, ẋ)
return sum(x), sum(ẋ)
end

function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
Expand Down
18 changes: 6 additions & 12 deletions src/rulesets/LinearAlgebra/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ _zeros(x) = fill!(similar(x), zero(eltype(x)))
##### `BLAS.dot`
#####

frule(::typeof(BLAS.dot), x, y) = frule(dot, x, y)
frule(::typeof(BLAS.dot), x, y, Δself, Δx, Δy) = frule(dot, x, y, Δself, Δx, Δy)

rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y)

Expand All @@ -35,12 +35,9 @@ end
##### `BLAS.nrm2`
#####

function frule(::typeof(BLAS.nrm2), x)
function frule(::typeof(BLAS.nrm2), x, _, Δx)
Ω = BLAS.nrm2(x)
function nrm2_pushforward(_, Δx)
return sum(Δx * cast(@thunk(x * inv(Ω))))
end
return Ω, nrm2_pushforward
return Ω, sum(Δx .* @thunk(x * inv(Ω)))
end

function rrule(::typeof(BLAS.nrm2), x)
Expand Down Expand Up @@ -70,16 +67,13 @@ end
##### `BLAS.asum`
#####

function frule(::typeof(BLAS.asum), x)
function asum_pushforward(_, Δx)
return sum(cast(sign, x) * Δx)
end
return BLAS.asum(x), asum_pushforward
function frule(::typeof(BLAS.asum), x, _, Δx)
return BLAS.asum(x), sum(sign.(x) .* Δx)
end

function rrule(::typeof(BLAS.asum), x)
function asum_pullback(ΔΩ)
return (NO_FIELDS, @thunk(ΔΩ * cast(sign, x)))
return (NO_FIELDS, @thunk(ΔΩ * sign.(x)))
end
return BLAS.asum(x), asum_pullback
end
Expand Down
43 changes: 13 additions & 30 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}
##### `dot`
#####

function frule(::typeof(dot), x, y)
function dot_pushforward(Δself, Δx, Δy)
return sum(Δx .* y) + sum(x .* Δy)
end
return dot(x, y), dot_pushforward
function frule(::typeof(dot), x, y, _, Δx, Δy)
return dot(x, y), sum(Δx .* y) + sum(x .* Δy)
end

function rrule(::typeof(dot), x, y)
Expand All @@ -26,20 +23,15 @@ end
##### `inv`
#####

function frule(::typeof(inv), x::AbstractArray)
function frule(::typeof(inv), x::AbstractArray, _, Δx)
Ω = inv(x)
m = @thunk(-Ω)
function inv_pushforward(_, Δx)
return m * Δx * Ω
end
return Ω, inv_pushforward
return Ω, -Ω * Δx * Ω
end

function rrule(::typeof(inv), x::AbstractArray)
Ω = inv(x)
m = @thunk(-Ω')
function inv_pullback(ΔΩ)
return NO_FIELDS, m * ΔΩ * Ω'
return NO_FIELDS, -Ω' * ΔΩ * Ω'
end
return Ω, inv_pullback
end
Expand All @@ -48,14 +40,11 @@ end
##### `det`
#####

function frule(::typeof(det), x)
function frule(::typeof(det), x, _, ẋ)
Ω = det(x)
function det_pushforward(_, ẋ)
# TODO Performance optimization: probably there is an efficent
# way to compute this trace without during the full compution within
return Ω * tr(inv(x) * ẋ)
end
return Ω, det_pushforward
# TODO Performance optimization: probably there is an efficent
# way to compute this trace without during the full compution within
return Ω, Ω * tr(inv(x) * ẋ)
end

function rrule(::typeof(det), x)
Expand All @@ -70,12 +59,9 @@ end
##### `logdet`
#####

function frule(::typeof(logdet), x)
function frule(::typeof(logdet), x, _, Δx)
Ω = logdet(x)
function logdet_pushforward(_, Δx)
return tr(inv(x) * Δx)
end
return Ω, logdet_pushforward
return Ω, tr(inv(x) * Δx)
end

function rrule(::typeof(logdet), x)
Expand All @@ -90,11 +76,8 @@ end
##### `trace`
#####

function frule(::typeof(tr), x)
function tr_pushforward(_, Δx)
return tr(Δx)
end
return tr(x), tr_pushforward
function frule(::typeof(tr), x, _, Δx)
return tr(x), tr(Δx)
end

function rrule(::typeof(tr), x)
Expand Down
39 changes: 19 additions & 20 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,19 @@
test_scalar(acscd, 1/x)
test_scalar(acotd, 1/x)
end

@testset "sincos" begin
x, Δx, x̄ = randn(3)
Δz = (randn(), randn())
@testset "Multivariate" begin
@testset "sincos" begin
x, Δx, x̄ = randn(3)
Δz = (randn(), randn())

frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
end
end
end # Trig

@testset "math" begin
for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im)
for x in (-0.1, 6.4)
test_scalar(deg2rad, x)
test_scalar(rad2deg, x)

Expand All @@ -72,22 +73,20 @@
test_scalar(exp2, x)
test_scalar(exp10, x)

x isa Real && test_scalar(cbrt, x)
if (x isa Real && x >= 0) || x isa Complex
# this check is needed because these have discontinuities between
# `-10 + im*eps()` and `-10 - im*eps()`
should_test_wirtinger = imag(x) != 0 && real(x) < 0
test_scalar(sqrt, x; test_wirtinger=should_test_wirtinger)
test_scalar(log, x; test_wirtinger=should_test_wirtinger)
test_scalar(log2, x; test_wirtinger=should_test_wirtinger)
test_scalar(log10, x; test_wirtinger=should_test_wirtinger)
test_scalar(log1p, x; test_wirtinger=should_test_wirtinger)
test_scalar(cbrt, x)

if x >= 0
test_scalar(sqrt, x)
test_scalar(log, x)
test_scalar(log2, x)
test_scalar(log10, x)
test_scalar(log1p, x)
end
end
end

@testset "Unary complex functions" begin
for x in (-4.1, 6.4, 1.0+0.5im, -10.0+1.5im)
for x in (-4.1, 6.4)
test_scalar(real, x)
test_scalar(imag, x)

Expand All @@ -97,6 +96,7 @@
test_scalar(angle, x)
test_scalar(abs2, x)
test_scalar(conj, x)
test_scalar(adjoint, x)
end
end

Expand Down Expand Up @@ -152,8 +152,7 @@
_, x̄ = pb(10.5)
@test extern(x̄) == 0

_, pf = frule(sign, 0.0)
ẏ = pf(NamedTuple(), 10.5)
_, ẏ = frule(sign, 0.0, Zero(), 10.5)
@test extern(ẏ) == 0
end
end
Expand Down
4 changes: 1 addition & 3 deletions test/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
end
@testset "frule" begin
x = rand(3, 3)
y, pushforward = frule(broadcast, sin, x)
y, = frule(broadcast, sin, x, Zero(), Zero(), One())
@test y == sin.(x)

ẏ = pushforward(NamedTuple(), NamedTuple(), One())
@test extern(ẏ) == cos.(x)
end
end
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using Test

# For testing purposes we use a lot of
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate,
Zero, One, DoesNotExist, Thunk, AbstractDifferential

Random.seed!(1) # Set seed that all testsets should reset to.
Expand Down
Loading