Skip to content

Commit 08191e2

Browse files
committed
Fuse frule
1 parent 313f65f commit 08191e2

File tree

9 files changed

+59
-74
lines changed

9 files changed

+59
-74
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
13-
ChainRulesCore = "0.4"
13+
Reexport = "0.2"
14+
Requires = "0.5.2"
15+
ChainRulesCore = "0.5"
1416
FiniteDifferences = "^0.7"
1517
Reexport = "0.2"
1618
Requires = "0.5.2, 1"

src/rulesets/Base/base.jl

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,8 @@
100100

101101
# product rule requires special care for arguments where `mul` is non-commutative
102102

103-
function frule(::typeof(*), x::Number, y::Number)
104-
function times_pushforward(_, Δx, Δy)
105-
return Δx * y + x * Δy
106-
end
107-
return x * y, times_pushforward
103+
function frule(::typeof(*), x::Number, y::Number, _, Δx, Δy)
104+
return x * y, Δx * y + x * Δy
108105
end
109106

110107
function rrule(::typeof(*), x::Number, y::Number)
@@ -114,11 +111,8 @@ function rrule(::typeof(*), x::Number, y::Number)
114111
return x * y, times_pullback
115112
end
116113

117-
function frule(::typeof(identity), x)
118-
function identity_pushforward(_, ẏ)
119-
return
120-
end
121-
return x, identity_pushforward
114+
function frule(::typeof(identity), x, _, ẏ)
115+
return x, ẏ
122116
end
123117

124118
function rrule(::typeof(identity), x)

src/rulesets/Base/broadcast.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,17 @@ https://github.com/JuliaLang/julia/issues/22129.
66
=#
77
function _cast_diff(f, x)
88
function element_rule(u)
9-
fu, du = frule(f, u)
10-
fu, extern(du(NamedTuple(), One()))
9+
dself = Zero()
10+
fu, du = frule(f, u, dself, One())
11+
fu, extern(du)
1112
end
1213
results = broadcast(element_rule, x)
1314
return first.(results), last.(results)
1415
end
1516

16-
function frule(::typeof(broadcast), f, x)
17+
function frule(::typeof(broadcast), f, x, _, Δf, Δx)
1718
Ω, ∂x = _cast_diff(f, x)
18-
function broadcast_pushforward(_, Δf, Δx)
19-
return Δx .* ∂x
20-
end
21-
return Ω, broadcast_pushforward
19+
return Ω, Δx .* ∂x
2220
end
2321

2422
function rrule(::typeof(broadcast), f, x)

src/rulesets/Base/mapreduce.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,8 @@ end
5050
##### `sum`
5151
#####
5252

53-
function frule(::typeof(sum), x)
54-
function sum_pushforward(_, ẋ)
55-
return sum(ẋ)
56-
end
57-
return sum(x), sum_pushforward
53+
function frule(::typeof(sum), x, _, ẋ)
54+
return sum(x), sum(ẋ)
5855
end
5956

6057
function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)

src/rulesets/LinearAlgebra/blas.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ _zeros(x) = fill!(similar(x), zero(eltype(x)))
1111
##### `BLAS.dot`
1212
#####
1313

14-
frule(::typeof(BLAS.dot), x, y) = frule(dot, x, y)
14+
frule(::typeof(BLAS.dot), x, y, Δself, Δx, Δy) = frule(dot, x, y, Δself, Δx, Δy)
1515

1616
rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y)
1717

@@ -35,12 +35,9 @@ end
3535
##### `BLAS.nrm2`
3636
#####
3737

38-
function frule(::typeof(BLAS.nrm2), x)
38+
function frule(::typeof(BLAS.nrm2), x, _, Δ)
3939
Ω = BLAS.nrm2(x)
40-
function nrm2_pushforward(_, Δx)
41-
return sum(Δx * cast(@thunk(x * inv(Ω))))
42-
end
43-
return Ω, nrm2_pushforward
40+
return Ω, sum(Δx * cast(@thunk(x * inv(Ω))))
4441
end
4542

4643
function rrule(::typeof(BLAS.nrm2), x)
@@ -70,11 +67,8 @@ end
7067
##### `BLAS.asum`
7168
#####
7269

73-
function frule(::typeof(BLAS.asum), x)
74-
function asum_pushforward(_, Δx)
75-
return sum(cast(sign, x) * Δx)
76-
end
77-
return BLAS.asum(x), asum_pushforward
70+
function frule(::typeof(BLAS.asum), x, _, Δx)
71+
return BLAS.asum(x), sum(cast(sign, x) * Δx)
7872
end
7973

8074
function rrule(::typeof(BLAS.asum), x)

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}
88
##### `dot`
99
#####
1010

11-
function frule(::typeof(dot), x, y)
12-
function dot_pushforward(Δself, Δx, Δy)
13-
return sum(Δx .* y) + sum(x .* Δy)
14-
end
15-
return dot(x, y), dot_pushforward
11+
function frule(::typeof(dot), x, y, _, Δx, Δy)
12+
return dot(x, y), sum(Δx .* y) + sum(x .* Δy)
1613
end
1714

1815
function rrule(::typeof(dot), x, y)
@@ -26,13 +23,10 @@ end
2623
##### `inv`
2724
#####
2825

29-
function frule(::typeof(inv), x::AbstractArray)
26+
function frule(::typeof(inv), x::AbstractArray, _, Δx)
3027
Ω = inv(x)
3128
m = @thunk(-Ω)
32-
function inv_pushforward(_, Δx)
33-
return m * Δx * Ω
34-
end
35-
return Ω, inv_pushforward
29+
return Ω, m * Δx * Ω
3630
end
3731

3832
function rrule(::typeof(inv), x::AbstractArray)
@@ -48,14 +42,11 @@ end
4842
##### `det`
4943
#####
5044

51-
function frule(::typeof(det), x)
45+
function frule(::typeof(det), x, _, ẋ)
5246
Ω = det(x)
53-
function det_pushforward(_, ẋ)
54-
# TODO Performance optimization: probably there is an efficent
55-
# way to compute this trace without during the full compution within
56-
return Ω * tr(inv(x) * ẋ)
57-
end
58-
return Ω, det_pushforward
47+
# TODO Performance optimization: probably there is an efficent
48+
# way to compute this trace without during the full compution within
49+
return Ω, Ω * tr(inv(x) * ẋ)
5950
end
6051

6152
function rrule(::typeof(det), x)
@@ -70,12 +61,9 @@ end
7061
##### `logdet`
7162
#####
7263

73-
function frule(::typeof(logdet), x)
64+
function frule(::typeof(logdet), x, _, Δx)
7465
Ω = logdet(x)
75-
function logdet_pushforward(_, Δx)
76-
return tr(inv(x) * Δx)
77-
end
78-
return Ω, logdet_pushforward
66+
return Ω, tr(inv(x) * Δx)
7967
end
8068

8169
function rrule(::typeof(logdet), x)
@@ -90,11 +78,8 @@ end
9078
##### `trace`
9179
#####
9280

93-
function frule(::typeof(tr), x)
94-
function tr_pushforward(_, Δx)
95-
return tr(Δx)
96-
end
97-
return tr(x), tr_pushforward
81+
function frule(::typeof(tr), x, _, Δx)
82+
return tr(x), tr(Δx)
9883
end
9984

10085
function rrule(::typeof(tr), x)

test/rulesets/Base/base.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,32 @@
5151
test_scalar(acscd, 1/x)
5252
test_scalar(acotd, 1/x)
5353
end
54-
55-
@testset "sincos" begin
56-
x, Δx, x̄ = randn(3)
57-
Δz = (randn(), randn())
58-
59-
frule_test(sincos, (x, Δx))
60-
rrule_test(sincos, Δz, (x, x̄))
54+
@testset "Multivariate" begin
55+
@testset "atan2" begin
56+
# https://en.wikipedia.org/wiki/Atan2
57+
x, y = rand(2)
58+
ratan = atan(x, y)
59+
u = x^2 + y^2
60+
datan = y/u - 2x/u
61+
62+
r, ṙ = frule(atan, x, y, Zero(), 1, 2)
63+
@test r === ratan
64+
@test=== datan
65+
66+
r, pullback = rrule(atan, x, y)
67+
@test r === ratan
68+
dself, df1, df2 = pullback(1)
69+
@test dself == NO_FIELDS
70+
@test df1 + 2df2 === datan
71+
end
72+
73+
@testset "sincos" begin
74+
x, Δx, x̄ = randn(3)
75+
Δz = (randn(), randn())
76+
77+
frule_test(sincos, (x, Δx))
78+
rrule_test(sincos, Δz, (x, x̄))
79+
end
6180
end
6281
end # Trig
6382

@@ -128,8 +147,7 @@
128147
_, x̄ = pb(10.5)
129148
@test extern(x̄) == 0
130149

131-
_, pf = frule(sign, 0.0)
132-
= pf(NamedTuple(), 10.5)
150+
_, ẏ = frule(sign, 0.0, Zero(), 10.5)
133151
@test extern(ẏ) == 0
134152
end
135153
end

test/rulesets/Base/broadcast.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@
2222
end
2323
@testset "frule" begin
2424
x = rand(3, 3)
25-
y, pushforward = frule(broadcast, sin, x)
25+
y, = frule(broadcast, sin, x, Zero(), Zero(), One())
2626
@test y == sin.(x)
27-
28-
= pushforward(NamedTuple(), NamedTuple(), One())
2927
@test extern(ẏ) == cos.(x)
3028
end
3129
end

test/test_util.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm
6969
Ω, dΩ_ad = ChainRules.frule(f, xs..., dself, ẋs...)
7070
dΩ_ad = dΩ_ad .+ 0
7171
@test f(xs...) == Ω
72-
dΩ_ad = pushforward(NamedTuple(), ẋs...)
7372

7473
# Correctness testing via finite differencing.
7574
dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs))

0 commit comments

Comments
 (0)