Skip to content

Update to ChainRulesCore 0.5.1 #146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 12, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.2.5"
version = "0.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.4"
ChainRulesCore = "0.5.1"
FiniteDifferences = "^0.7"
Reexport = "0.2"
Requires = "0.5.2, 1"
Expand Down
1 change: 0 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ The most important `AbstractDifferential`s when getting started are the ones abo
- `One`, `Zero`: There are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `Zero`) addition.

#### Other `AbstractDifferential`s: don't worry about them right now
- `Wirtinger`: it is complex. The docs need to be better. [Read the links in this issue](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/40).
- `Casted`: it implements broadcasting mechanics. See [#10](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/10)
- `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place.

Expand Down
25 changes: 7 additions & 18 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
@scalar_rule(zero(x), Zero())
@scalar_rule(sign(x), Zero())

@scalar_rule(abs2(x), Wirtinger(x', x))
@scalar_rule(abs2(x), 2x)
@scalar_rule(log(x), inv(x))
@scalar_rule(log10(x), inv(x) / log(oftype(x, 10)))
@scalar_rule(log2(x), inv(x) / log(oftype(x, 2)))
Expand Down Expand Up @@ -50,14 +50,12 @@
@scalar_rule(deg2rad(x), π / oftype(x, 180))
@scalar_rule(rad2deg(x), oftype(x, 180) / π)

@scalar_rule(conj(x), Wirtinger(Zero(), One()))
@scalar_rule(adjoint(x), Wirtinger(Zero(), One()))
@scalar_rule(conj(x::Real), One())
@scalar_rule(adjoint(x::Real), One())
@scalar_rule(transpose(x), One())

@scalar_rule(abs(x::Real), sign(x))
@scalar_rule(abs(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))
@scalar_rule(hypot(x::Real), sign(x))
@scalar_rule(hypot(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))
@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DoesNotExist()))

@scalar_rule(+(x), One())
Expand Down Expand Up @@ -98,20 +96,14 @@
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))))
@scalar_rule(fma(x, y, z), (y, x, One()))
@scalar_rule(muladd(x, y, z), (y, x, One()))
@scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u))
@scalar_rule(angle(x::Real), Zero())
@scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2))
@scalar_rule(real(x::Real), One())
@scalar_rule(imag(x::Complex), Wirtinger(-im//2, im//2))
@scalar_rule(imag(x::Real), Zero())

# product rule requires special care for arguments where `mul` is non-commutative

function frule(::typeof(*), x::Number, y::Number)
function times_pushforward(_, Δx, Δy)
return Δx * y + x * Δy
end
return x * y, times_pushforward
function frule(::typeof(*), x::Number, y::Number, _, Δx, Δy)
return x * y, Δx * y + x * Δy
end

function rrule(::typeof(*), x::Number, y::Number)
Expand All @@ -121,11 +113,8 @@ function rrule(::typeof(*), x::Number, y::Number)
return x * y, times_pullback
end

function frule(::typeof(identity), x)
function identity_pushforward(_, ẏ)
return ẏ
end
return x, identity_pushforward
function frule(::typeof(identity), x, _, ẏ)
return x, ẏ
end

function rrule(::typeof(identity), x)
Expand Down
12 changes: 5 additions & 7 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,17 @@ https://github.com/JuliaLang/julia/issues/22129.
=#
function _cast_diff(f, x)
function element_rule(u)
fu, du = frule(f, u)
fu, extern(du(NamedTuple(), One()))
dself = Zero()
fu, du = frule(f, u, dself, One())
fu, extern(du)
end
results = broadcast(element_rule, x)
return first.(results), last.(results)
end

function frule(::typeof(broadcast), f, x)
function frule(::typeof(broadcast), f, x, _, Δf, Δx)
Ω, ∂x = _cast_diff(f, x)
function broadcast_pushforward(_, Δf, Δx)
return Δx .* ∂x
end
return Ω, broadcast_pushforward
return Ω, Δx .* ∂x
end

function rrule(::typeof(broadcast), f, x)
Expand Down
7 changes: 2 additions & 5 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,8 @@ end
##### `sum`
#####

function frule(::typeof(sum), x)
function sum_pushforward(_, ẋ)
return sum(ẋ)
end
return sum(x), sum_pushforward
function frule(::typeof(sum), x, _, ẋ)
return sum(x), sum(ẋ)
end

function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
Expand Down
16 changes: 5 additions & 11 deletions src/rulesets/LinearAlgebra/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ _zeros(x) = fill!(similar(x), zero(eltype(x)))
##### `BLAS.dot`
#####

frule(::typeof(BLAS.dot), x, y) = frule(dot, x, y)
frule(::typeof(BLAS.dot), x, y, Δself, Δx, Δy) = frule(dot, x, y, Δself, Δx, Δy)

rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y)

Expand All @@ -35,12 +35,9 @@ end
##### `BLAS.nrm2`
#####

function frule(::typeof(BLAS.nrm2), x)
function frule(::typeof(BLAS.nrm2), x, _, Δ)
Ω = BLAS.nrm2(x)
function nrm2_pushforward(_, Δx)
return sum(Δx * cast(@thunk(x * inv(Ω))))
end
return Ω, nrm2_pushforward
return Ω, sum(Δx * cast(@thunk(x * inv(Ω))))
end

function rrule(::typeof(BLAS.nrm2), x)
Expand Down Expand Up @@ -70,11 +67,8 @@ end
##### `BLAS.asum`
#####

function frule(::typeof(BLAS.asum), x)
function asum_pushforward(_, Δx)
return sum(cast(sign, x) * Δx)
end
return BLAS.asum(x), asum_pushforward
function frule(::typeof(BLAS.asum), x, _, Δx)
return BLAS.asum(x), sum(cast(sign, x) * Δx)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return BLAS.asum(x), sum(cast(sign, x) * Δx)
return BLAS.asum(x), sum(sign(x) .* Δx)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is actually no BLAS.asum(x), so I removed it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough, wonder how it got in there,

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I confused myself. BLAS.asum(x) is there, but it is not documented in Julia

help?> BLAS.asum
  asum(n, X, incx)

  Sum of the absolute values of the first n elements of array X with stride incx.

  Examples
  ≡≡≡≡≡≡≡≡≡≡

  julia> BLAS.asum(5, fill(1.0im, 10), 2)
  5.0

  julia> BLAS.asum(2, fill(1.0im, 10), 5)
  2.0

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added back.

end

function rrule(::typeof(BLAS.asum), x)
Expand Down
39 changes: 12 additions & 27 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}
##### `dot`
#####

function frule(::typeof(dot), x, y)
function dot_pushforward(Δself, Δx, Δy)
return sum(Δx .* y) + sum(x .* Δy)
end
return dot(x, y), dot_pushforward
function frule(::typeof(dot), x, y, _, Δx, Δy)
return dot(x, y), sum(Δx .* y) + sum(x .* Δy)
end

function rrule(::typeof(dot), x, y)
Expand All @@ -26,13 +23,10 @@ end
##### `inv`
#####

function frule(::typeof(inv), x::AbstractArray)
function frule(::typeof(inv), x::AbstractArray, _, Δx)
Ω = inv(x)
m = @thunk(-Ω)
function inv_pushforward(_, Δx)
return m * Δx * Ω
end
return Ω, inv_pushforward
return Ω, m * Δx * Ω
end

function rrule(::typeof(inv), x::AbstractArray)
Expand All @@ -48,14 +42,11 @@ end
##### `det`
#####

function frule(::typeof(det), x)
function frule(::typeof(det), x, _, ẋ)
Ω = det(x)
function det_pushforward(_, ẋ)
# TODO Performance optimization: probably there is an efficent
# way to compute this trace without during the full compution within
return Ω * tr(inv(x) * ẋ)
end
return Ω, det_pushforward
# TODO Performance optimization: probably there is an efficent
# way to compute this trace without during the full compution within
return Ω, Ω * tr(inv(x) * ẋ)
end

function rrule(::typeof(det), x)
Expand All @@ -70,12 +61,9 @@ end
##### `logdet`
#####

function frule(::typeof(logdet), x)
function frule(::typeof(logdet), x, _, Δx)
Ω = logdet(x)
function logdet_pushforward(_, Δx)
return tr(inv(x) * Δx)
end
return Ω, logdet_pushforward
return Ω, tr(inv(x) * Δx)
end

function rrule(::typeof(logdet), x)
Expand All @@ -90,11 +78,8 @@ end
##### `trace`
#####

function frule(::typeof(tr), x)
function tr_pushforward(_, Δx)
return tr(Δx)
end
return tr(x), tr_pushforward
function frule(::typeof(tr), x, _, Δx)
return tr(x), tr(Δx)
end

function rrule(::typeof(tr), x)
Expand Down
64 changes: 31 additions & 33 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,37 @@
test_scalar(acscd, 1/x)
test_scalar(acotd, 1/x)
end

@testset "sincos" begin
x, Δx, x̄ = randn(3)
Δz = (randn(), randn())
@testset "Multivariate" begin
@testset "atan2" begin
# https://en.wikipedia.org/wiki/Atan2
x, y = rand(2)
ratan = atan(x, y)
u = x^2 + y^2
datan = y/u - 2x/u

r, ṙ = frule(atan, x, y, Zero(), 1, 2)
@test r === ratan
@test ṙ === datan

r, pullback = rrule(atan, x, y)
@test r === ratan
dself, df1, df2 = pullback(1)
@test dself == NO_FIELDS
@test df1 + 2df2 === datan
end

@testset "sincos" begin
x, Δx, x̄ = randn(3)
Δz = (randn(), randn())

frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
end
end
end # Trig

@testset "math" begin
for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im)
for x in (-0.1, 6.4)
test_scalar(deg2rad, x)
test_scalar(rad2deg, x)

Expand All @@ -72,31 +91,11 @@
test_scalar(exp2, x)
test_scalar(exp10, x)

x isa Real && test_scalar(cbrt, x)
if (x isa Real && x >= 0) || x isa Complex
# this check is needed because these have discontinuities between
# `-10 + im*eps()` and `-10 - im*eps()`
should_test_wirtinger = imag(x) != 0 && real(x) < 0
test_scalar(sqrt, x; test_wirtinger=should_test_wirtinger)
test_scalar(log, x; test_wirtinger=should_test_wirtinger)
test_scalar(log2, x; test_wirtinger=should_test_wirtinger)
test_scalar(log10, x; test_wirtinger=should_test_wirtinger)
test_scalar(log1p, x; test_wirtinger=should_test_wirtinger)
end
end
end

@testset "Unary complex functions" begin
for x in (-4.1, 6.4, 1.0+0.5im, -10.0+1.5im)
test_scalar(real, x)
test_scalar(imag, x)

test_scalar(abs, x)
test_scalar(hypot, x)

test_scalar(angle, x)
test_scalar(abs2, x)
test_scalar(conj, x)
test_scalar(adjoint, x)
test_scalar(abs2, x)

x isa Real && test_scalar(cbrt, x)
end
end

Expand Down Expand Up @@ -152,8 +151,7 @@
_, x̄ = pb(10.5)
@test extern(x̄) == 0

_, pf = frule(sign, 0.0)
ẏ = pf(NamedTuple(), 10.5)
_, ẏ = frule(sign, 0.0, Zero(), 10.5)
@test extern(ẏ) == 0
end
end
Expand Down
4 changes: 1 addition & 3 deletions test/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
end
@testset "frule" begin
x = rand(3, 3)
y, pushforward = frule(broadcast, sin, x)
y, = frule(broadcast, sin, x, Zero(), Zero(), One())
@test y == sin.(x)

ẏ = pushforward(NamedTuple(), NamedTuple(), One())
@test extern(ẏ) == cos.(x)
end
end
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using Test

# For testing purposes we use a lot of
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate,
Zero, One, DoesNotExist, Thunk, AbstractDifferential

Random.seed!(1) # Set seed that all testsets should reset to.
Expand Down
Loading