Skip to content

Commit 3d2618e

Browse files
authored
Merge pull request #146 from JuliaDiff/myb/fuse_frule
Update to ChainRulesCore 0.5.1
2 parents bc600ce + c1a4451 commit 3d2618e

File tree

11 files changed

+71
-142
lines changed

11 files changed

+71
-142
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.2.5"
3+
version = "0.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
13-
ChainRulesCore = "0.4"
13+
ChainRulesCore = "0.5.1"
1414
FiniteDifferences = "^0.7"
1515
Reexport = "0.2"
1616
Requires = "0.5.2, 1"

docs/src/index.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,6 @@ The most important `AbstractDifferential`s when getting started are the ones abo
233233
- `One`, `Zero`: There are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `Zero`) addition.
234234

235235
#### Other `AbstractDifferential`s: don't worry about them right now
236-
- `Wirtinger`: it is complex. The docs need to be better. [Read the links in this issue](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/40).
237236
- `Casted`: it implements broadcasting mechanics. See [#10](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/10)
238237
- `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place.
239238

src/rulesets/Base/base.jl

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
@scalar_rule(zero(x), Zero())
33
@scalar_rule(sign(x), Zero())
44

5-
@scalar_rule(abs2(x), Wirtinger(x', x))
5+
@scalar_rule(abs2(x), 2x)
66
@scalar_rule(log(x), inv(x))
77
@scalar_rule(log10(x), inv(x) / log(oftype(x, 10)))
88
@scalar_rule(log2(x), inv(x) / log(oftype(x, 2)))
@@ -50,14 +50,12 @@
5050
@scalar_rule(deg2rad(x), π / oftype(x, 180))
5151
@scalar_rule(rad2deg(x), oftype(x, 180) / π)
5252

53-
@scalar_rule(conj(x), Wirtinger(Zero(), One()))
54-
@scalar_rule(adjoint(x), Wirtinger(Zero(), One()))
53+
@scalar_rule(conj(x::Real), One())
54+
@scalar_rule(adjoint(x::Real), One())
5555
@scalar_rule(transpose(x), One())
5656

5757
@scalar_rule(abs(x::Real), sign(x))
58-
@scalar_rule(abs(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))
5958
@scalar_rule(hypot(x::Real), sign(x))
60-
@scalar_rule(hypot(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))
6159
@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DoesNotExist()))
6260

6361
@scalar_rule(+(x), One())
@@ -98,20 +96,14 @@
9896
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))))
9997
@scalar_rule(fma(x, y, z), (y, x, One()))
10098
@scalar_rule(muladd(x, y, z), (y, x, One()))
101-
@scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u))
10299
@scalar_rule(angle(x::Real), Zero())
103-
@scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2))
104100
@scalar_rule(real(x::Real), One())
105-
@scalar_rule(imag(x::Complex), Wirtinger(-im//2, im//2))
106101
@scalar_rule(imag(x::Real), Zero())
107102

108103
# product rule requires special care for arguments where `mul` is non-commutative
109104

110-
function frule(::typeof(*), x::Number, y::Number)
111-
function times_pushforward(_, Δx, Δy)
112-
return Δx * y + x * Δy
113-
end
114-
return x * y, times_pushforward
105+
function frule(::typeof(*), x::Number, y::Number, _, Δx, Δy)
106+
return x * y, Δx * y + x * Δy
115107
end
116108

117109
function rrule(::typeof(*), x::Number, y::Number)
@@ -121,11 +113,8 @@ function rrule(::typeof(*), x::Number, y::Number)
121113
return x * y, times_pullback
122114
end
123115

124-
function frule(::typeof(identity), x)
125-
function identity_pushforward(_, ẏ)
126-
return ẏ
127-
end
128-
return x, identity_pushforward
116+
function frule(::typeof(identity), x, _, ẏ)
117+
return x, ẏ
129118
end
130119

131120
function rrule(::typeof(identity), x)

src/rulesets/Base/broadcast.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,17 @@ https://github.com/JuliaLang/julia/issues/22129.
66
=#
77
function _cast_diff(f, x)
88
function element_rule(u)
9-
fu, du = frule(f, u)
10-
fu, extern(du(NamedTuple(), One()))
9+
dself = Zero()
10+
fu, du = frule(f, u, dself, One())
11+
fu, extern(du)
1112
end
1213
results = broadcast(element_rule, x)
1314
return first.(results), last.(results)
1415
end
1516

16-
function frule(::typeof(broadcast), f, x)
17+
function frule(::typeof(broadcast), f, x, _, Δf, Δx)
1718
Ω, ∂x = _cast_diff(f, x)
18-
function broadcast_pushforward(_, Δf, Δx)
19-
return Δx .* ∂x
20-
end
21-
return Ω, broadcast_pushforward
19+
return Ω, Δx .* ∂x
2220
end
2321

2422
function rrule(::typeof(broadcast), f, x)

src/rulesets/Base/mapreduce.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,8 @@ end
5050
##### `sum`
5151
#####
5252

53-
function frule(::typeof(sum), x)
54-
function sum_pushforward(_, ẋ)
55-
return sum(ẋ)
56-
end
57-
return sum(x), sum_pushforward
53+
function frule(::typeof(sum), x, _, ẋ)
54+
return sum(x), sum(ẋ)
5855
end
5956

6057
function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)

src/rulesets/LinearAlgebra/blas.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ _zeros(x) = fill!(similar(x), zero(eltype(x)))
1111
##### `BLAS.dot`
1212
#####
1313

14-
frule(::typeof(BLAS.dot), x, y) = frule(dot, x, y)
14+
frule(::typeof(BLAS.dot), x, y, Δself, Δx, Δy) = frule(dot, x, y, Δself, Δx, Δy)
1515

1616
rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y)
1717

@@ -35,12 +35,9 @@ end
3535
##### `BLAS.nrm2`
3636
#####
3737

38-
function frule(::typeof(BLAS.nrm2), x)
38+
function frule(::typeof(BLAS.nrm2), x, _, Δx)
3939
Ω = BLAS.nrm2(x)
40-
function nrm2_pushforward(_, Δx)
41-
return sum(Δx * cast(@thunk(x * inv(Ω))))
42-
end
43-
return Ω, nrm2_pushforward
40+
return Ω, sum(Δx .* @thunk(x * inv(Ω)))
4441
end
4542

4643
function rrule(::typeof(BLAS.nrm2), x)
@@ -70,16 +67,13 @@ end
7067
##### `BLAS.asum`
7168
#####
7269

73-
function frule(::typeof(BLAS.asum), x)
74-
function asum_pushforward(_, Δx)
75-
return sum(cast(sign, x) * Δx)
76-
end
77-
return BLAS.asum(x), asum_pushforward
70+
function frule(::typeof(BLAS.asum), x, _, Δx)
71+
return BLAS.asum(x), sum(sign.(x) .* Δx)
7872
end
7973

8074
function rrule(::typeof(BLAS.asum), x)
8175
function asum_pullback(ΔΩ)
82-
return (NO_FIELDS, @thunk(ΔΩ * cast(sign, x)))
76+
return (NO_FIELDS, @thunk(ΔΩ * sign.(x)))
8377
end
8478
return BLAS.asum(x), asum_pullback
8579
end

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}
88
##### `dot`
99
#####
1010

11-
function frule(::typeof(dot), x, y)
12-
function dot_pushforward(Δself, Δx, Δy)
13-
return sum(Δx .* y) + sum(x .* Δy)
14-
end
15-
return dot(x, y), dot_pushforward
11+
function frule(::typeof(dot), x, y, _, Δx, Δy)
12+
return dot(x, y), sum(Δx .* y) + sum(x .* Δy)
1613
end
1714

1815
function rrule(::typeof(dot), x, y)
@@ -26,20 +23,15 @@ end
2623
##### `inv`
2724
#####
2825

29-
function frule(::typeof(inv), x::AbstractArray)
26+
function frule(::typeof(inv), x::AbstractArray, _, Δx)
3027
Ω = inv(x)
31-
m = @thunk(-Ω)
32-
function inv_pushforward(_, Δx)
33-
return m * Δx * Ω
34-
end
35-
return Ω, inv_pushforward
28+
return Ω, -Ω * Δx * Ω
3629
end
3730

3831
function rrule(::typeof(inv), x::AbstractArray)
3932
Ω = inv(x)
40-
m = @thunk(-Ω')
4133
function inv_pullback(ΔΩ)
42-
return NO_FIELDS, m * ΔΩ * Ω'
34+
return NO_FIELDS, -Ω' * ΔΩ * Ω'
4335
end
4436
return Ω, inv_pullback
4537
end
@@ -48,14 +40,11 @@ end
4840
##### `det`
4941
#####
5042

51-
function frule(::typeof(det), x)
43+
function frule(::typeof(det), x, _, ẋ)
5244
Ω = det(x)
53-
function det_pushforward(_, ẋ)
54-
# TODO Performance optimization: probably there is an efficient
55-
# way to compute this trace without doing the full computation within it
return Ω * tr(inv(x) * ẋ)
57-
end
58-
return Ω, det_pushforward
45+
# TODO Performance optimization: probably there is an efficient
46+
# way to compute this trace without doing the full computation within it
47+
return Ω, Ω * tr(inv(x) * ẋ)
5948
end
6049

6150
function rrule(::typeof(det), x)
@@ -70,12 +59,9 @@ end
7059
##### `logdet`
7160
#####
7261

73-
function frule(::typeof(logdet), x)
62+
function frule(::typeof(logdet), x, _, Δx)
7463
Ω = logdet(x)
75-
function logdet_pushforward(_, Δx)
76-
return tr(inv(x) * Δx)
77-
end
78-
return Ω, logdet_pushforward
64+
return Ω, tr(inv(x) * Δx)
7965
end
8066

8167
function rrule(::typeof(logdet), x)
@@ -90,11 +76,8 @@ end
9076
##### `trace`
9177
#####
9278

93-
function frule(::typeof(tr), x)
94-
function tr_pushforward(_, Δx)
95-
return tr(Δx)
96-
end
97-
return tr(x), tr_pushforward
79+
function frule(::typeof(tr), x, _, Δx)
80+
return tr(x), tr(Δx)
9881
end
9982

10083
function rrule(::typeof(tr), x)

test/rulesets/Base/base.jl

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,19 @@
5151
test_scalar(acscd, 1/x)
5252
test_scalar(acotd, 1/x)
5353
end
54-
55-
@testset "sincos" begin
56-
x, Δx, x̄ = randn(3)
57-
Δz = (randn(), randn())
54+
@testset "Multivariate" begin
55+
@testset "sincos" begin
56+
x, Δx, x̄ = randn(3)
57+
Δz = (randn(), randn())
5858

59-
frule_test(sincos, (x, Δx))
60-
rrule_test(sincos, Δz, (x, x̄))
59+
frule_test(sincos, (x, Δx))
60+
rrule_test(sincos, Δz, (x, x̄))
61+
end
6162
end
6263
end # Trig
6364

6465
@testset "math" begin
65-
for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im)
66+
for x in (-0.1, 6.4)
6667
test_scalar(deg2rad, x)
6768
test_scalar(rad2deg, x)
6869

@@ -72,22 +73,20 @@
7273
test_scalar(exp2, x)
7374
test_scalar(exp10, x)
7475

75-
x isa Real && test_scalar(cbrt, x)
76-
if (x isa Real && x >= 0) || x isa Complex
77-
# this check is needed because these have discontinuities between
78-
# `-10 + im*eps()` and `-10 - im*eps()`
79-
should_test_wirtinger = imag(x) != 0 && real(x) < 0
80-
test_scalar(sqrt, x; test_wirtinger=should_test_wirtinger)
81-
test_scalar(log, x; test_wirtinger=should_test_wirtinger)
82-
test_scalar(log2, x; test_wirtinger=should_test_wirtinger)
83-
test_scalar(log10, x; test_wirtinger=should_test_wirtinger)
84-
test_scalar(log1p, x; test_wirtinger=should_test_wirtinger)
76+
test_scalar(cbrt, x)
77+
78+
if x >= 0
79+
test_scalar(sqrt, x)
80+
test_scalar(log, x)
81+
test_scalar(log2, x)
82+
test_scalar(log10, x)
83+
test_scalar(log1p, x)
8584
end
8685
end
8786
end
8887

8988
@testset "Unary complex functions" begin
90-
for x in (-4.1, 6.4, 1.0+0.5im, -10.0+1.5im)
89+
for x in (-4.1, 6.4)
9190
test_scalar(real, x)
9291
test_scalar(imag, x)
9392

@@ -97,6 +96,7 @@
9796
test_scalar(angle, x)
9897
test_scalar(abs2, x)
9998
test_scalar(conj, x)
99+
test_scalar(adjoint, x)
100100
end
101101
end
102102

@@ -152,8 +152,7 @@
152152
_, x̄ = pb(10.5)
153153
@test extern(x̄) == 0
154154

155-
_, pf = frule(sign, 0.0)
156-
ẏ = pf(NamedTuple(), 10.5)
155+
_, ẏ = frule(sign, 0.0, Zero(), 10.5)
157156
@test extern(ẏ) == 0
158157
end
159158
end

test/rulesets/Base/broadcast.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@
2222
end
2323
@testset "frule" begin
2424
x = rand(3, 3)
25-
y, pushforward = frule(broadcast, sin, x)
25+
y, ẏ = frule(broadcast, sin, x, Zero(), Zero(), One())
2626
@test y == sin.(x)
27-
28-
ẏ = pushforward(NamedTuple(), NamedTuple(), One())
2927
@test extern(ẏ) == cos.(x)
3028
end
3129
end

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ using Test
1111

1212
# For testing purposes we use a lot of
1313
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
14-
Wirtinger, wirtinger_primal, wirtinger_conjugate,
1514
Zero, One, DoesNotExist, Thunk, AbstractDifferential
1615

1716
Random.seed!(1) # Set seed that all testsets should reset to.

0 commit comments

Comments
 (0)