Skip to content

Commit 2fd8a91

Browse files
YingboMa"Shashi Gowda"
andcommitted
Improve code coverage by removing source code and tests
Co-authored-by: "Shashi Gowda" <gowda@mit.edu> Co-authored-by: "Yingbo Ma" <mayingbo5@gmail.com>
1 parent d94863c commit 2fd8a91

File tree

7 files changed

+40
-169
lines changed

7 files changed

+40
-169
lines changed

src/dual_context.jl

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,8 @@ end
3030
@inline _partials(::Any, x) = Zero()
3131
@inline _partials(::Tag{T}, d::Dual{Tag{T}}) where T = d.partials
3232

33-
#=
34-
function Wirtinger(primal::Partials, conjugate::Union{Number,ChainRulesCore.AbstractDifferential})
35-
return Partials(map(p->Wirtinger(p, conjugate), primal.values))
36-
end
37-
function Wirtinger(primal::Partials, conjugate::Partials)
38-
return Partials(map((p, c)->Wirtinger(p, c), primal.values, conjugate.values))
39-
end
40-
=#
33+
Wirtinger(primal, conjugate) = Wirtinger.(primal, conjugate)
34+
4135
@inline _values(S, xs) = map(x->_value(S, x), xs)
4236
@inline _partialss(S, xs) = map(x->_partials(S, x), xs)
4337

@@ -77,7 +71,7 @@ end
7771

7872
# call frule to see if there is a rule for this call:
7973
if ctx.metadata isa Tag
80-
ctx1 = similarcontext(ctx, metadata=innertag(ctx.metadata))
74+
ctx1 = similarcontext(ctx, metadata=oldertag(ctx.metadata))
8175

8276
# we call frule with an older context because the Dual numbers may
8377
# themselves contain Dual numbers that were created in an older context
@@ -92,7 +86,7 @@ end
9286
# a closure which closes over a Dual number, hence we call
9387
# recurse. Recurse overdubs the calls inside `f` and not `f` itself
9488

95-
return Cassette.recurse(ctx, f, args...)
89+
return Cassette.overdub(ctx, f, args...)
9690
else
9791
# this means there exists an frule for this specific call.
9892
# frule_result is then a tuple (val, pushforward) where val
@@ -108,7 +102,7 @@ end
108102
# we call it with the older context because the partials
109103
# might themselves be Duals from older contexts
110104
if ctx.metadata isa Tag
111-
ctx1 = similarcontext(ctx, metadata=innertag(ctx.metadata))
105+
ctx1 = similarcontext(ctx, metadata=oldertag(ctx.metadata))
112106
∂s = overdub(ctx1, pushforward, Zero(), ps...)
113107
else
114108
∂s = pushforward(Zero(), ps...)
@@ -119,10 +113,10 @@ end
119113
# a tuple, we handle both cases:
120114
return if ∂s isa Tuple
121115
map(val, ∂s) do v, ∂
122-
Dual{typeof(tag)}(v, ∂)
116+
Dual{Tag{T}}(v, ∂)
123117
end
124118
else
125-
Dual{typeof(tag)}(val, ∂s)
119+
Dual{Tag{T}}(val, ∂s)
126120
end
127121
end
128122
end
@@ -145,7 +139,7 @@ end
145139
end
146140
# none of the arguments have the same tag as the context
147141
# try with the parent context
148-
ctx1 = similarcontext(ctx, metadata=innertag(ctx.metadata))
142+
ctx1 = similarcontext(ctx, metadata=oldertag(ctx.metadata))
149143
return overdub(ctx1, f, args...)
150144
else
151145
# call ChainRules.frule to execute `f` and
@@ -156,14 +150,16 @@ end
156150

157151
function dualrun(f, args...)
158152
ctx = dualcontext()
159-
overdub(ctx, f, args...)
153+
return overdub(ctx, f, args...)
160154
end
161155

162156
const BINARY_PREDICATES = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]
163157

164158
for pred in BINARY_PREDICATES
165-
@eval begin
166-
isinteresting(ctx::TaggedCtx, ::typeof($(pred)), x, y) = anydual(x, y)
167-
alternative(ctx::TaggedCtx, ::typeof($(pred)), x, y) = overdub(ctx, () -> $pred(value(x), value(y)))
159+
@eval function alternative(ctx::TaggedCtx, ::typeof($(pred)), x, y)
160+
vx, vy = value(x), value(y)
161+
return isinteresting(ctx, $pred, vx, vy) ?
162+
alternative(ctx, $pred, vx, vy) :
163+
$pred(vx, vy)
168164
end
169165
end

src/dualarray.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function Base.print_array(io::IO, da::DualArray)
2222
Base.println(io)
2323
for i=1:npartials(da)
2424
Base.printstyled(io,"Partials($i):\n", bold=false, color=3)
25-
Base.print_array(ioc, partials.(da, i))
25+
Base.print_array(ioc, getindex.(partials.(da), i))
2626
i !== npartials(da) && Base.println(io)
2727
end
2828
return nothing
@@ -100,8 +100,9 @@ Base.@propagate_inbounds function Base.setindex!(d::DualArray, dual, i::Int...)
100100
dd[ii] = value(dual)
101101

102102
slice_len = length(d)
103+
ps = partials(dual)
103104
for j = 1:npartials(d)
104-
dd[j * slice_len + ii] = partials(dual, j)
105+
dd[j * slice_len + ii] = ps[j]
105106
end
106107
return dual
107108
end

src/dualnumber.jl

Lines changed: 10 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ Dual(::Number1, [1]) --> Dual(1, Partials([1,])) # fail Any
2121

2222
const NANSAFE_MODE_ENABLED = false
2323

24-
const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, RoundingMode)
25-
2624
const UNARY_PREDICATES = Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]
2725

2826
const DEFAULT_CHUNK_THRESHOLD = 12
@@ -59,37 +57,11 @@ chunksize(::Chunk{N}) where {N} = N
5957
# Dual #
6058
########
6159

62-
"""
63-
ForwardDiff2.can_dual(V::Type)
64-
65-
Determines whether the type V is allowed as the scalar type in a
66-
Dual. By default, only `<:Real` types are allowed.
67-
"""
68-
can_dual(::Type{<:Real}) = true
69-
can_dual(::Type) = false
70-
7160
struct Dual{T,V,P} <: Real
7261
value::V
7362
partials::P
7463
end
7564

76-
##############
77-
# Exceptions #
78-
##############
79-
80-
struct DualMismatchError{A,B} <: Exception
81-
a::A
82-
b::B
83-
end
84-
85-
Base.showerror(io::IO, e::DualMismatchError{A,B}) where {A,B} =
86-
print(io, "Cannot determine ordering of Dual tags $(e.a) and $(e.b)")
87-
88-
@noinline function throw_cannot_dual(V::Type)
89-
throw(ArgumentError("Cannot create a dual over scalar type $V." *
90-
" If the type behaves as a scalar, define FowardDiff.can_dual."))
91-
end
92-
9365
################
9466
# Constructors #
9567
################
@@ -98,48 +70,26 @@ end
9870
# intercept calls to dualtag
9971
dualtag() = nothing
10072

101-
@inline function Dual{T}(value::V, partials::P) where {T,V,P}
102-
Q = promote_type(bottomvaluetype(V), eltype(P))
103-
partials′ = convert.(Q, partials)
104-
Dual{T,V,typeof(partials′)}(value, partials′)
105-
end
73+
@inline Dual{T}(value::V, partials::P) where {T,V,P} = Dual{T,V,P}(value, partials)
74+
10675
#@inline Dual{T}(value::V, ::Chunk{N}, p::Val{i}) where {T,V,P,i} = Dual{T}(value, single_seed(Partials{N,V}, p))
10776
@inline Dual(value, partials) = Dual{typeof(dualtag())}(value, partials)
10877

109-
# we define these special cases so that the "constructor <--> convert" pun holds for `Dual`
110-
@inline Dual{T,V,P}(x::Dual{T,V,P}) where {T,V,P} = x
111-
@inline Dual{T,V,P}(x) where {T,V,P} = convert(Dual{T,V,P}, x)
112-
@inline Dual{T,V,P}(x::Number) where {T,V,P} = convert(Dual{T,V,P}, x)
113-
@inline Dual{T,V}(x) where {T,V} = convert(Dual{T,V}, x)
114-
11578
##############################
11679
# Utility/Accessor Functions #
11780
##############################
11881

119-
@inline bottomvaluetype(::Type{Dual{T,V,P}}) where {T,V,P} = bottomvaluetype(V)
120-
@inline bottomvaluetype(::Type{T}) where {T<:Any} = T
12182
@inline value(x) = x
12283
@inline value(d::Dual) = d.value
12384

124-
@inline partials(d::Dual) = d.partials
125-
@inline Base.@propagate_inbounds partials(d::Dual, i) = d.partials[i]
126-
@inline Base.@propagate_inbounds partials(d::Dual, i, j) = partials(d, i).partials[j]
127-
@inline Base.@propagate_inbounds partials(d::Dual, i, j, k...) = partials(partials(d, i, j), k...)
128-
129-
@inline npartials(d::Dual) = length(d.partials)
130-
131-
@inline order(::Type{V}) where {V} = 0
132-
@inline order(::Type{Dual{T,V,P}}) where {T,V,P} = 1 + order(V)
133-
13485
@inline valtype(::V) where {V} = V
13586
@inline valtype(::Type{V}) where {V} = V
13687
@inline valtype(::Dual{T,V}) where {T,V} = V
13788
@inline valtype(::Type{Dual{T,V,P}}) where {T,V,P} = V
13889

139-
@inline tagtype(::V) where {V} = Nothing
140-
@inline tagtype(::Type{V}) where {V} = Nothing
141-
@inline tagtype(::Dual{T}) where {T} = T
142-
@inline tagtype(::Type{Dual{T,V,P}}) where {T,V,P} = T
90+
@inline partials(d::Dual) = d.partials
91+
92+
@inline npartials(d::Dual) = (ps = d.partials) isa Wirtinger ? 1 : length(ps)
14393

14494
#####################
14595
# Generic Functions #
@@ -152,16 +102,16 @@ Base.eps(::Type{D}) where {D<:Dual} = eps(valtype(D))
152102

153103
Base.rtoldefault(::Type{D}) where {D<:Dual} = Base.rtoldefault(valtype(D))
154104

155-
Base.floor(::Type{R}, d::Dual) where {R<:Real} = floor(R, value(d))
105+
Base.floor(::Type{R}, d::Dual) where {R<:Number} = floor(R, value(d))
156106
Base.floor(d::Dual) = floor(value(d))
157107

158-
Base.ceil(::Type{R}, d::Dual) where {R<:Real} = ceil(R, value(d))
108+
Base.ceil(::Type{R}, d::Dual) where {R<:Number} = ceil(R, value(d))
159109
Base.ceil(d::Dual) = ceil(value(d))
160110

161-
Base.trunc(::Type{R}, d::Dual) where {R<:Real} = trunc(R, value(d))
111+
Base.trunc(::Type{R}, d::Dual) where {R<:Number} = trunc(R, value(d))
162112
Base.trunc(d::Dual) = trunc(value(d))
163113

164-
Base.round(::Type{R}, d::Dual) where {R<:Real} = round(R, value(d))
114+
Base.round(::Type{R}, d::Dual) where {R<:Number} = round(R, value(d))
165115
Base.round(d::Dual) = round(value(d))
166116

167117
Base.hash(d::Dual) = hash(value(d))
@@ -205,49 +155,11 @@ end
205155
# Promotion/Conversion #
206156
########################
207157

208-
function Base.promote_rule(::Type{Dual{T1,V1,P1}},
209-
::Type{Dual{T2,V2,P2}}) where {T1,V1,P1,T2,V2,P2}
210-
# V1 and V2 might themselves be Dual types
211-
if T2 T1
212-
Dual{T1,promote_type(V1,Dual{T2,V2,P2}),P1}
213-
else
214-
Dual{T2,promote_type(V2,Dual{T1,V1,P1}),P2}
215-
end
216-
end
217-
218-
function Base.promote_rule(::Type{Dual{T,A,P}},
219-
::Type{Dual{T,B,P}}) where {T,A,B,P}
220-
return Dual{T,promote_type(A, B),P}
221-
end
222-
223-
for R in (Irrational, Real, BigFloat, Bool)
224-
if isconcretetype(R) # issue #322
225-
@eval begin
226-
Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,P}}) where {T,V,P} = Dual{T,promote_type($R, V),P}
227-
Base.promote_rule(::Type{Dual{T,V,P}}, ::Type{$R}) where {T,V,P} = Dual{T,promote_type(V, $R),P}
228-
end
229-
else
230-
@eval begin
231-
Base.promote_rule(::Type{R}, ::Type{Dual{T,V,P}}) where {R<:$R,T,V,P} = Dual{T,promote_type(R, V),P}
232-
Base.promote_rule(::Type{Dual{T,V,P}}, ::Type{R}) where {T,V,P,R<:$R} = Dual{T,promote_type(V, R),P}
233-
end
234-
end
235-
end
236-
237158
Base.convert(::Type{Dual{T,V,P}}, d::Dual{T}) where {T,V,P} = Dual{T}(convert(V, value(d)), convert(P, partials(d)))
238159
Base.convert(::Type{Dual{T,V,P}}, x) where {T,V,P} = Dual{T}(convert(V, x), zero(P))
239160
Base.convert(::Type{Dual{T,V,P}}, x::Number) where {T,V,P} = Dual{T}(convert(V, x), zero(P))
240161
Base.convert(::Type{D}, d::D) where {D<:Dual} = d
241162

242-
Base.float(d::Dual{T}) where {T} = Dual{T}(value(d), map(float, partials(d)))
243-
Base.AbstractFloat(d::Dual{T}) where {T} = Dual{T}(convert(AbstractFloat, value(d)), map(x->convert(AbstractFloat, x), partials(d)))
244-
245-
###################################
246-
# General Mathematical Operations #
247-
###################################
248-
249-
@inline Base.conj(d::Dual) = d
250-
251163
###################
252164
# Pretty Printing #
253165
###################
@@ -258,7 +170,7 @@ function tag_show(t, n=0)
258170
if t isa Nothing
259171
return subscript_num(n)
260172
elseif t isa Tag
261-
tag_show(innertag(t), n+1)
173+
tag_show(oldertag(t), n+1)
262174
else
263175
return "{" * repr(t) * "}" * subscript_num(n)
264176
end

src/tag.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ struct Tag{Parent} end
44
@inline _find_dual(tag::T, l, i, x::Dual{T}, xs...) where {T} = i
55
@inline _find_dual(tag::T, l, i, x, xs...) where {T} = _find_dual(tag, l, i+1, xs...)
66

7-
@inline innertag(::Tag{Tag{T}}) where T = Tag{T}()
8-
@inline innertag(::Tag{T}) where T = nothing
7+
@inline oldertag(::Tag{Tag{T}}) where T = Tag{T}()
8+
@inline oldertag(::Tag{T}) where T = nothing
99

1010
@inline function find_dual(T::TT, xs...) where {TT<:Tag}
1111
_find_dual(T, length(xs), 1, xs...)

test/dualarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using StaticArrays: SVector
1717
dx .= sin.(x)
1818
@test_broken sin.(x) == dx
1919
@test sin.(_x) == value.(dx)
20-
@test cos.(_x) == partials.(dx, 1)
20+
@test cos.(_x) == first.(partials.(dx))
2121
end
2222
end
2323

0 commit comments

Comments
 (0)