Skip to content

Commit eb3c292

Browse files
Update test/rules.jl
error ratehr than Assert cleanup Update src/rule_definition_tools.jl Co-Authored-By: Nick Robinson <npr251@gmail.com> Add more complex Wirtinger Scalar Rule Test
1 parent de2bb62 commit eb3c292

File tree

3 files changed

+108
-28
lines changed

3 files changed

+108
-28
lines changed

src/differentials.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,10 @@ const NO_FIELDS = DNE()
241241
"""
242242
differential(𝒟::Type, der)
243243
244-
For some differential (e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
245-
convert it to another differential that is more suited for the domain given by
246-
the type 𝒟.
244+
Converts, if required, a differential object `der`
245+
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
246+
to another differential that is more suited for the domain given by the type 𝒟.
247+
Often this will behave as the identity function on `der`.
247248
"""
248249
function differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger)
249250
return wirtinger_primal(w) + wirtinger_conjugate(w)

src/rule_definition_tools.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ macro scalar_rule(call, maybe_setup, partials...)
183183
if Meta.isexpr(partial, :tuple)
184184
partial
185185
else
186-
@assert length(inputs) == 1
186+
length(inputs) == 1 || error("Invalid use of `@scalar_rule`")
187187
Expr(:tuple, partial)
188188
end
189189
end
@@ -192,7 +192,7 @@ macro scalar_rule(call, maybe_setup, partials...)
192192
# Main body: defining the results of the frule/rrule
193193

194194
# An expression that when evaluated will return the type of the input domain.
195-
# Multiple repetitions of this expression should optimize ot. But if it does not then
195+
# Multiple repetitions of this expression should optimize out. But if it does not then
196196
# may need to move its definition into the body of the `rrule`/`frule`
197197
𝒟 = :(typeof(first(promote($(call.args[2:end]...)))))
198198

test/rules.jl

Lines changed: 102 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,44 +29,123 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
2929
Tuple{typeof(rrule),typeof(cool),String}])
3030
@test cool_methods == only_methods
3131

32-
frx, fr = frule(cool, 1)
32+
frx, cool_pushforward = frule(cool, 1)
3333
@test frx == 2
34-
@test fr(NamedTuple(), 1) == (1,)
35-
rrx, (rr) = rrule(cool, 1)
36-
self, rr1 = rr(1)
34+
@test cool_pushforward(NamedTuple(), 1) == (1,)
35+
rrx, cool_pullback = rrule(cool, 1)
36+
self, rr1 = cool_pullback(1)
3737
@test self == NO_FIELDS
3838
@test rrx == 2
3939
@test rr1 == 1
4040
end
4141

4242

43-
@testset "Wirtinger scalar_rule" begin
43+
@testset "Basic Wirtinger scalar_rule" begin
4444
myabs2(x) = abs2(x)
4545
@scalar_rule(myabs2(x), Wirtinger(x', x))
4646

47-
# real input
48-
x = rand(Float64)
49-
f, pushforward = frule(myabs2, x)
50-
@test f === x^2
47+
@testset "real input" begin
48+
# even though our rule was define in terms of Wirtinger,
49+
# pushforward result will be real as real (even if seed is Compex)
5150

52-
df = @inferred pushforward(NamedTuple(), One())
53-
@test df === (x + x,)
51+
x = rand(Float64)
52+
f, myabs2_pushforward = frule(myabs2, x)
53+
@test f === x^2
5454

55+
Δ = One()
56+
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
57+
@test df === (x + x,)
5558

56-
Δ = rand(Complex{Int64})
57-
df = @inferred pushforward(NamedTuple(), Δ)
58-
@test df ===* (x + x),)
59+
Δ = rand(Complex{Int64})
60+
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
61+
@test df ===* (x + x),)
62+
end
5963

64+
@testset "complex input" begin
65+
z = rand(Complex{Float64})
66+
f, myabs2_pushforward = frule(myabs2, z)
67+
@test f === abs2(z)
6068

61-
# complex input
62-
z = rand(Complex{Float64})
63-
f, pushforward = frule(myabs2, z)
64-
@test f === abs2(z)
69+
df = @inferred myabs2_pushforward(NamedTuple(), One())
70+
@test df === (Wirtinger(z', z),)
71+
72+
Δ = rand(Complex{Int64})
73+
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
74+
@test df === (Wirtinger* z', Δ * z),)
75+
end
76+
end
6577

66-
df = @inferred pushforward(NamedTuple(), One())
67-
@test df === (Wirtinger(z', z),)
6878

69-
Δ = rand(Complex{Int64})
70-
df = @inferred pushforward(NamedTuple(), Δ)
71-
@test df === (Wirtinger* z', Δ * z),)
79+
@testset "Advanced Wirtinger @scalar_rule: abs_to_pow" begin
80+
# This is based on SimeonSchaub excellent example:
81+
# https://gist.github.com/simeonschaub/a6dfcd71336d863b3777093b3b8d9c97
82+
83+
# This is much more complex than the previous case
84+
# as it has many different types
85+
# depending on input, and the output types do not always agree
86+
87+
abs_to_pow(x, p) = abs(x)^p
88+
@scalar_rule(
89+
abs_to_pow(x::Real, p),
90+
(
91+
p == 0 ? Zero() : p * abs_to_pow(x, p-1) * sign(x),
92+
Ω * log(abs(x))
93+
)
94+
)
95+
96+
@scalar_rule(
97+
abs_to_pow(x::Complex, p),
98+
@setup(u = abs(x)),
99+
(
100+
p == 0 ? Zero() : p * u^(p-1) * Wirtinger(x' / 2u, x / 2u),
101+
Ω * log(abs(x))
102+
)
103+
)
104+
105+
106+
f = abs_to_pow
107+
@testset "f($x, $p)" for (x, p) in Iterators.product(
108+
(2, 3.4, -2.1, -10+0im, 2.3-2im),
109+
(0, 1, 2, 4.3, -2.1, 1+.2im)
110+
)
111+
expected_type_df_dx =
112+
if iszero(p)
113+
Zero
114+
elseif typeof(x) <: Complex
115+
Wirtinger
116+
elseif typeof(p) <: Complex
117+
Complex
118+
else
119+
Real
120+
end
121+
122+
expected_type_df_dp =
123+
if typeof(p) <: Real
124+
Real
125+
else
126+
Complex
127+
end
128+
129+
130+
res = frule(f, x, p)
131+
@test res !== nothing # Check the rule was defined
132+
fx, f_pushforward = res
133+
df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp)
134+
135+
df_dx, = df(One(), Zero())
136+
df_dp,= df(Zero(), One())
137+
@test fx == f(x, p) # Check we still get the normal value, right
138+
@test extern(df_dx) isa expected_type_df_dx
139+
@test extern(df_dp) isa expected_type_df_dp
140+
141+
142+
res = rrule(f, x, p)
143+
@test res !== nothing # Check the rule was defined
144+
fx, f_pullback = res
145+
dself, df_dx, df_dp = f_pullback(One())
146+
@test fx == f(x, p) # Check we still get the normal value, right
147+
@test dself == NO_FIELDS
148+
@test extern(df_dx) isa expected_type_df_dx
149+
@test extern(df_dp) isa expected_type_df_dp
150+
end
72151
end

0 commit comments

Comments
 (0)