Skip to content

Commit bd014f4

Browse files
committed
Initial removal of stuff
Fix test as per Lyndon's suggestion Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au> fix type fix left over one
1 parent db5323f commit bd014f4

File tree

6 files changed

+141
-182
lines changed

6 files changed

+141
-182
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ export wirtinger_conjugate, wirtinger_primal, refine_differential
66
export @scalar_rule, @thunk
77
export extern, store!
88
export unthunk
9-
export Wirtinger, Zero, One, DoesNotExist, Thunk, InplaceableThunk
9+
export Wirtinger, Zero, DoesNotExist, Thunk, InplaceableThunk
1010
export NO_FIELDS
1111

1212
include("differentials.jl")

src/differential_arithmetic.jl

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger)
3636
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
3737
end
3838

39-
for T in (:Zero, :DoesNotExist, :One, :AbstractThunk, :Any)
39+
for T in (:Zero, :DoesNotExist, :AbstractThunk, :Any)
4040
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
4141
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b
4242

@@ -47,7 +47,7 @@ end
4747

4848
Base.:+(::Zero, b::Zero) = Zero()
4949
Base.:*(::Zero, ::Zero) = Zero()
50-
for T in (:DoesNotExist, :One, :AbstractThunk, :Any)
50+
for T in (:DoesNotExist, :AbstractThunk, :Any)
5151
@eval Base.:+(::Zero, b::$T) = b
5252
@eval Base.:+(a::$T, ::Zero) = a
5353

@@ -58,7 +58,7 @@ end
5858

5959
Base.:+(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
6060
Base.:*(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
61-
for T in (:One, :AbstractThunk, :Any)
61+
for T in (:AbstractThunk, :Any)
6262
@eval Base.:+(::DoesNotExist, b::$T) = b
6363
@eval Base.:+(a::$T, ::DoesNotExist) = a
6464

@@ -67,17 +67,6 @@ for T in (:One, :AbstractThunk, :Any)
6767
end
6868

6969

70-
Base.:+(a::One, b::One) = extern(a) + extern(b)
71-
Base.:*(::One, ::One) = One()
72-
for T in (:AbstractThunk, :Any)
73-
@eval Base.:+(a::One, b::$T) = extern(a) + b
74-
@eval Base.:+(a::$T, b::One) = a + extern(b)
75-
76-
@eval Base.:*(::One, b::$T) = b
77-
@eval Base.:*(a::$T, ::One) = a
78-
end
79-
80-
8170
Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b)
8271
Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b)
8372
for T in (:Any,)

src/differentials.jl

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,10 @@ The two fields of the returned instance can be accessed generically via the
6363
struct Wirtinger{P,C} <: AbstractDifferential
6464
primal::P
6565
conjugate::C
66-
function Wirtinger(primal::Union{Number,AbstractDifferential},
67-
conjugate::Union{Number,AbstractDifferential})
66+
function Wirtinger(
67+
primal::Union{Number,AbstractDifferential},
68+
conjugate::Union{Number,AbstractDifferential},
69+
)
6870
return new{typeof(primal),typeof(conjugate)}(primal, conjugate)
6971
end
7072
end
@@ -75,10 +77,13 @@ wirtinger_primal(x) = x
7577
wirtinger_conjugate(x::Wirtinger) = x.conjugate
7678
wirtinger_conjugate(::Any) = Zero()
7779

78-
extern(x::Wirtinger) = throw(ArgumentError("`Wirtinger` cannot be converted to an external type."))
80+
function extern(x::Wirtinger)
81+
return throw(ArgumentError("`Wirtinger` cannot be converted to an external type."))
82+
end
7983

80-
Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal),
81-
broadcastable(w.conjugate))
84+
function Base.Broadcast.broadcastable(w::Wirtinger)
85+
return Wirtinger(broadcastable(w.primal), broadcastable(w.conjugate))
86+
end
8287

8388
Base.iterate(x::Wirtinger) = (x, nothing)
8489
Base.iterate(::Wirtinger, ::Any) = nothing
@@ -104,7 +109,6 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero())
104109
Base.iterate(x::Zero) = (x, nothing)
105110
Base.iterate(::Zero, ::Any) = nothing
106111

107-
108112
#####
109113
##### `DoesNotExist`
110114
#####
@@ -127,25 +131,6 @@ Base.Broadcast.broadcastable(::DoesNotExist) = Ref(DoesNotExist())
127131
Base.iterate(x::DoesNotExist) = (x, nothing)
128132
Base.iterate(::DoesNotExist, ::Any) = nothing
129133

130-
#####
131-
##### `One`
132-
#####
133-
134-
"""
135-
One()
136-
The Differential which is the multiplicative identity.
137-
Basically, this represents `1`.
138-
"""
139-
struct One <: AbstractDifferential end
140-
141-
extern(x::One) = true # true is a strong 1.
142-
143-
Base.Broadcast.broadcastable(::One) = Ref(One())
144-
145-
Base.iterate(x::One) = (x, nothing)
146-
Base.iterate(::One, ::Any) = nothing
147-
148-
149134
#####
150135
##### `AbstractThunk
151136
#####

test/differentials.jl

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
@testset "Differentials" begin
1+
@testset "differentials" begin
22
@testset "Wirtinger" begin
33
w = Wirtinger(1+1im, 2+2im)
44
@test wirtinger_primal(w) == 1+1im
55
@test wirtinger_conjugate(w) == 2+2im
66
@test w + w == Wirtinger(2+2im, 4+4im)
77

8-
@test w + One() == w + 1 == w + Thunk(()->1) == Wirtinger(2+1im, 2+2im)
9-
@test w * One() == One() * w == w
8+
@test w + 1 == w + Thunk(()->1) == Wirtinger(2+1im, 2+2im)
109
@test w * 2 == 2 * w == Wirtinger(2 + 2im, 4 + 4im)
1110

1211
# TODO: other + methods stack overflow
@@ -33,22 +32,6 @@
3332
@test broadcastable(z) isa Ref{Zero}
3433
@test conj(z) == z
3534
end
36-
@testset "One" begin
37-
o = One()
38-
@test extern(o) === true
39-
@test o + o == 2
40-
@test o + 1 == 2
41-
@test 1 + o == 2
42-
@test o * o == o
43-
@test o * 1 == 1
44-
@test 1 * o == 1
45-
for x in o
46-
@test x === o
47-
end
48-
@test broadcastable(o) isa Ref{One}
49-
@test conj(o) == o
50-
end
51-
5235
@testset "Thunk" begin
5336
@test @thunk(3) isa Thunk
5437

@@ -71,7 +54,6 @@
7154
@test (@thunk(3))() == 3
7255
@test (@thunk(@thunk(3)))() isa Thunk
7356
end
74-
7557
@testset "erroring thunks should include the source in the backtrack" begin
7658
expected_line = (@__LINE__) + 2 # for testing it is at right palce
7759
try
@@ -109,7 +91,7 @@
10991
@test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4
11092

11193
# For most differentials, in most domains, this does nothing
112-
for der in (DoesNotExist(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0)
94+
for der in (DoesNotExist(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], Zero(), 0.0)
11395
for 𝒟 in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2]))
11496
@test refine_differential(𝒟, der) === der
11597
end

0 commit comments

Comments
 (0)