Skip to content

Commit 9e4cb76

Browse files
oxinaboxwilltebbutt
authored andcommitted
Removes rules depending on casted (#120)
* Remove used of Casted * bump version * Update Project.toml
1 parent 296723c commit 9e4cb76

File tree

7 files changed

+19
-29
lines changed

7 files changed

+19
-29
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.2.1"
3+
version = "0.2.2-DEV"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
13-
ChainRulesCore = "^0.3"
13+
ChainRulesCore = "0.3, 0.4"
1414
FiniteDifferences = "^0.7"
1515
julia = "^1.0"
1616

src/rulesets/Base/broadcast.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ end
1616
function frule(::typeof(broadcast), f, x)
1717
Ω, ∂x = _cast_diff(f, x)
1818
function broadcast_pushforward(_, Δf, Δx)
19-
return Δx * cast(∂x)
19+
return Δx .* ∂x
2020
end
2121
return Ω, broadcast_pushforward
2222
end
2323

2424
function rrule(::typeof(broadcast), f, x)
2525
values, derivs = _cast_diff(f, x)
2626
function broadcast_pullback(ΔΩ)
27-
return (NO_FIELDS, DNE(), @thunk(ΔΩ * cast(derivs)))
27+
return (NO_FIELDS, DNE(), @thunk(ΔΩ .* derivs))
2828
end
2929
return values, broadcast_pullback
3030
end

src/rulesets/Base/mapreduce.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959

6060
function rrule(::typeof(sum), x)
6161
function sum_pullback(ȳ)
62-
return (NO_FIELDS, cast(ȳ))
62+
return (NO_FIELDS, @thunk(fill(ȳ, size(x))))
6363
end
6464
return sum(x), sum_pullback
6565
end

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}
1010

1111
function frule(::typeof(dot), x, y)
1212
function dot_pushforward(Δself, Δx, Δy)
13-
return sum(Δx * cast(y)) + sum(cast(x) * Δy)
13+
return sum(Δx .* y) + sum(x .* Δy)
1414
end
1515
return dot(x, y), dot_pushforward
1616
end
1717

1818
function rrule(::typeof(dot), x, y)
1919
function dot_pullback(ΔΩ)
20-
return (NO_FIELDS, ΔΩ * cast(y), cast(x) * ΔΩ,)
20+
return (NO_FIELDS, @thunk(ΔΩ .* y), @thunk(x .* ΔΩ))
2121
end
2222
return dot(x, y), dot_pullback
2323
end

test/rulesets/Base/base.jl

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@
136136

137137
= rand(3, 5)
138138
(ds, dx, dy) = pullback(z̄)
139-
139+
140140
@test ds === NO_FIELDS
141141

142142
@test extern(dx) == extern(accumulate(zeros(3, 2), dx))
@@ -147,22 +147,13 @@
147147
end
148148

149149
@testset "hypot(x, y)" begin
150-
x, y = rand(2)
151-
h, pushforward = frule(hypot, x, y)
152-
dxy(x, y) = pushforward(NamedTuple(), x, y)
153-
154-
@test extern(dxy(One(), Zero())) === x / h
155-
@test extern(dxy(Zero(), One())) === y / h
156-
157-
cx, cy = cast((One(), Zero())), cast((Zero(), One()))
158-
dx, dy = extern(dxy(cx, cy))
159-
@test dx === x / h
160-
@test dy === y / h
161-
162-
cx, cy = cast((rand(), Zero())), cast((Zero(), rand()))
163-
dx, dy = extern(dxy(cx, cy))
164-
@test dx === x / h * cx.value[1]
165-
@test dy === y / h * cy.value[2]
150+
rng = MersenneTwister(123456)
151+
x, Δx, x̄ = randn(rng, 3)
152+
y, Δy, ȳ = randn(rng, 3)
153+
Δz = randn(rng)
154+
155+
frule_test(hypot, (x, Δx), (y, Δy))
156+
rrule_test(hypot, Δz, (x, x̄), (y, ȳ))
166157
end
167158

168159
@testset "identity" begin

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ using Statistics
1010
using Test
1111

1212
# For testing purposes we use a lot of
13-
using ChainRulesCore: cast, extern, accumulate, accumulate!, store!, @scalar_rule,
13+
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
1414
Wirtinger, wirtinger_primal, wirtinger_conjugate,
15-
Zero, One, Casted, DNE, Thunk, AbstractDifferential
15+
Zero, One, DNE, Thunk, AbstractDifferential
1616

1717
include("test_util.jl")
1818

test/test_util.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,11 @@ end
193193
function Base.isapprox(ad::Wirtinger, fd; kwargs...)
194194
error("Finite differencing with Wirtinger rules not implemented")
195195
end
196-
function Base.isapprox(d_ad::Casted, d_fd; kwargs...)
197-
return all(isapprox.(extern(d_ad), d_fd; kwargs...))
198-
end
196+
199197
function Base.isapprox(d_ad::DNE, d_fd; kwargs...)
200198
error("Tried to differentiate w.r.t. a DNE")
201199
end
200+
202201
function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...)
203202
return isapprox(extern(d_ad), d_fd; kwargs...)
204203
end

0 commit comments

Comments
 (0)