@@ -29,44 +29,123 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
29
29
Tuple{typeof (rrule),typeof (cool),String}])
30
30
@test cool_methods == only_methods
31
31
32
- frx, fr = frule (cool, 1 )
32
+ frx, cool_pushforward = frule (cool, 1 )
33
33
@test frx == 2
34
- @test fr (NamedTuple (), 1 ) == (1 ,)
35
- rrx, (rr) = rrule (cool, 1 )
36
- self, rr1 = rr (1 )
34
+ @test cool_pushforward (NamedTuple (), 1 ) == (1 ,)
35
+ rrx, cool_pullback = rrule (cool, 1 )
36
+ self, rr1 = cool_pullback (1 )
37
37
@test self == NO_FIELDS
38
38
@test rrx == 2
39
39
@test rr1 == 1
40
40
end
41
41
42
42
43
- @testset " Wirtinger scalar_rule" begin
43
+ @testset " Basic Wirtinger scalar_rule" begin
44
44
myabs2 (x) = abs2 (x)
45
45
@scalar_rule (myabs2 (x), Wirtinger (x' , x))
46
46
47
- # real input
48
- x = rand (Float64)
49
- f, pushforward = frule (myabs2, x)
50
- @test f === x^ 2
47
+ @testset " real input" begin
48
+ # even though our rule was define in terms of Wirtinger,
49
+ # pushforward result will be real as real (even if seed is Compex)
51
50
52
- df = @inferred pushforward (NamedTuple (), One ())
53
- @test df === (x + x,)
51
+ x = rand (Float64)
52
+ f, myabs2_pushforward = frule (myabs2, x)
53
+ @test f === x^ 2
54
54
55
+ Δ = One ()
56
+ df = @inferred myabs2_pushforward (NamedTuple (), Δ)
57
+ @test df === (x + x,)
55
58
56
- Δ = rand (Complex{Int64})
57
- df = @inferred pushforward (NamedTuple (), Δ)
58
- @test df === (Δ * (x + x),)
59
+ Δ = rand (Complex{Int64})
60
+ df = @inferred myabs2_pushforward (NamedTuple (), Δ)
61
+ @test df === (Δ * (x + x),)
62
+ end
59
63
64
+ @testset " complex input" begin
65
+ z = rand (Complex{Float64})
66
+ f, myabs2_pushforward = frule (myabs2, z)
67
+ @test f === abs2 (z)
60
68
61
- # complex input
62
- z = rand (Complex{Float64})
63
- f, pushforward = frule (myabs2, z)
64
- @test f === abs2 (z)
69
+ df = @inferred myabs2_pushforward (NamedTuple (), One ())
70
+ @test df === (Wirtinger (z' , z),)
71
+
72
+ Δ = rand (Complex{Int64})
73
+ df = @inferred myabs2_pushforward (NamedTuple (), Δ)
74
+ @test df === (Wirtinger (Δ * z' , Δ * z),)
75
+ end
76
+ end
65
77
66
- df = @inferred pushforward (NamedTuple (), One ())
67
- @test df === (Wirtinger (z' , z),)
68
78
69
- Δ = rand (Complex{Int64})
70
- df = @inferred pushforward (NamedTuple (), Δ)
71
- @test df === (Wirtinger (Δ * z' , Δ * z),)
79
+ @testset " Advanced Wirtinger @scalar_rule: abs_to_pow" begin
80
+ # This is based on SimeonSchaub excellent example:
81
+ # https://gist.github.com/simeonschaub/a6dfcd71336d863b3777093b3b8d9c97
82
+
83
+ # This is much more complex than the previous case
84
+ # as it has many different types
85
+ # depending on input, and the output types do not always agree
86
+
87
+ abs_to_pow (x, p) = abs (x)^ p
88
+ @scalar_rule (
89
+ abs_to_pow (x:: Real , p),
90
+ (
91
+ p == 0 ? Zero () : p * abs_to_pow (x, p- 1 ) * sign (x),
92
+ Ω * log (abs (x))
93
+ )
94
+ )
95
+
96
+ @scalar_rule (
97
+ abs_to_pow (x:: Complex , p),
98
+ @setup (u = abs (x)),
99
+ (
100
+ p == 0 ? Zero () : p * u^ (p- 1 ) * Wirtinger (x' / 2 u, x / 2 u),
101
+ Ω * log (abs (x))
102
+ )
103
+ )
104
+
105
+
106
+ f = abs_to_pow
107
+ @testset " f($x , $p )" for (x, p) in Iterators. product (
108
+ (2 , 3.4 , - 2.1 , - 10 + 0im , 2.3 - 2im ),
109
+ (0 , 1 , 2 , 4.3 , - 2.1 , 1 + .2im )
110
+ )
111
+ expected_type_df_dx =
112
+ if iszero (p)
113
+ Zero
114
+ elseif typeof (x) <: Complex
115
+ Wirtinger
116
+ elseif typeof (p) <: Complex
117
+ Complex
118
+ else
119
+ Real
120
+ end
121
+
122
+ expected_type_df_dp =
123
+ if typeof (p) <: Real
124
+ Real
125
+ else
126
+ Complex
127
+ end
128
+
129
+
130
+ res = frule (f, x, p)
131
+ @test res != = nothing # Check the rule was defined
132
+ fx, f_pushforward = res
133
+ df (Δx, Δp) = f_pushforward (NamedTuple (), Δx, Δp)
134
+
135
+ df_dx, = df (One (), Zero ())
136
+ df_dp,= df (Zero (), One ())
137
+ @test fx == f (x, p) # Check we still get the normal value, right
138
+ @test extern (df_dx) isa expected_type_df_dx
139
+ @test extern (df_dp) isa expected_type_df_dp
140
+
141
+
142
+ res = rrule (f, x, p)
143
+ @test res != = nothing # Check the rule was defined
144
+ fx, f_pullback = res
145
+ dself, df_dx, df_dp = f_pullback (One ())
146
+ @test fx == f (x, p) # Check we still get the normal value, right
147
+ @test dself == NO_FIELDS
148
+ @test extern (df_dx) isa expected_type_df_dx
149
+ @test extern (df_dp) isa expected_type_df_dp
150
+ end
72
151
end
0 commit comments