Skip to content

Commit d829da4

Browse files
oxinaboxwilltebbutt
authored andcommitted
Remove Casted (#55)
* Remove Casted * bump version * Update Project.toml
1 parent a133468 commit d829da4

File tree

6 files changed

+6
-47
lines changed

6 files changed

+6
-47
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.3.0"
3+
version = "0.4.0-DEV"
44

55
[compat]
66
julia = "^1.0"

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ export frule, rrule
55
export wirtinger_conjugate, wirtinger_primal, refine_differential
66
export @scalar_rule, @thunk
77
export extern, cast, store!
8-
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk
8+
export Wirtinger, Zero, One, DNE, Thunk, InplaceableThunk
99
export NO_FIELDS
1010

1111
include("differentials.jl")

src/differential_arithmetic.jl

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ subtypes, as we know the full set that might be encountered.
77
Thus we can avoid any ambiguities.
88
99
Notice:
10-
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
10+
The precidence goes: (:Wirtinger, :Zero, :DNE, :One, :AbstractThunk, :Any)
1111
Thus each of the @eval loops creating definitions of + and *
1212
defines the combination this type with all types of lower precidence.
1313
This means each eval loops is 1 item smaller than the previous.
@@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger)
3636
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
3737
end
3838

39-
for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
39+
for T in (:Zero, :DNE, :One, :AbstractThunk, :Any)
4040
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
4141
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b
4242

@@ -45,17 +45,6 @@ for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
4545
end
4646

4747

48-
Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value))
49-
Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value))
50-
for T in (:Zero, :DNE, :One, :AbstractThunk, :Any)
51-
@eval Base.:+(a::Casted, b::$T) = Casted(broadcasted(+, a.value, b))
52-
@eval Base.:+(a::$T, b::Casted) = Casted(broadcasted(+, a, b.value))
53-
54-
@eval Base.:*(a::Casted, b::$T) = Casted(broadcasted(*, a.value, b))
55-
@eval Base.:*(a::$T, b::Casted) = Casted(broadcasted(*, a, b.value))
56-
end
57-
58-
5948
Base.:+(::Zero, b::Zero) = Zero()
6049
Base.:*(::Zero, ::Zero) = Zero()
6150
for T in (:DNE, :One, :AbstractThunk, :Any)

src/differentials.jl

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -86,33 +86,6 @@ Base.iterate(::Wirtinger, ::Any) = nothing
8686
# TODO: define `conj` for` `Wirtinger`
8787
Base.conj(x::Wirtinger) = throw(MethodError(conj, x))
8888

89-
90-
#####
91-
##### `Casted`
92-
#####
93-
94-
"""
95-
Casted(v)
96-
97-
This differential wraps another differential (including a number-like type)
98-
to indicate that it should be lazily broadcast.
99-
"""
100-
struct Casted{V} <: AbstractDifferential
101-
value::V
102-
end
103-
104-
cast(x) = Casted(x)
105-
cast(f, args...) = Casted(broadcasted(f, args...))
106-
107-
extern(x::Casted) = materialize(broadcasted(extern, x.value))
108-
109-
Base.Broadcast.broadcastable(x::Casted) = x.value
110-
111-
Base.iterate(x::Casted) = iterate(x.value)
112-
Base.iterate(x::Casted, state) = iterate(x.value, state)
113-
114-
Base.conj(x::Casted) = cast(conj, x.value)
115-
11689
#####
11790
##### `Zero`
11891
#####

src/operations.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,11 @@ so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in-case.
2222
This function is overloadable by using a [`InplaceThunk`](@ref).
2323
See also: [`accumulate`](@ref), [`store!`](@ref).
2424
"""
25-
function accumulate!(Δ, ∂)
26-
return materialize!(Δ, broadcastable(cast(Δ) + ∂))
27-
end
25+
accumulate!(Δ, ∂) = store!(Δ, accumulate(Δ, ∂))
2826

2927
accumulate!::Number, ∂) = accumulate(Δ, ∂)
3028

3129

32-
3330
"""
3431
store!(Δ, ∂)
3532

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ChainRulesCore
44
using LinearAlgebra: Diagonal
55
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
66
Wirtinger, wirtinger_primal, wirtinger_conjugate,
7-
Zero, One, Casted, cast, DNE, Thunk
7+
Zero, One, DNE, Thunk
88
using Base.Broadcast: broadcastable
99

1010
@testset "ChainRulesCore" begin

0 commit comments

Comments
 (0)