|
46 | 46 |
|
47 | 47 | @testset "real input" begin
|
48 | 48 | # even though our rule was define in terms of Wirtinger,
|
49 |
| - # pushforward result will be real as real (even if seed is Compex) |
| 49 | + # pushforward result will be real as real (even if seed is Complex) |
50 | 50 |
|
51 |
| - x = rand(Float64) |
| 51 | + x = 5.0 |
52 | 52 | f, myabs2_pushforward = frule(myabs2, x)
|
53 | 53 | @test f === x^2
|
54 | 54 |
|
55 | 55 | Δ = One()
|
56 | 56 | df = @inferred myabs2_pushforward(NamedTuple(), Δ)
|
57 | 57 | @test df === x + x
|
58 | 58 |
|
59 |
| - Δ = rand(Complex{Int64}) |
| 59 | + Δ = 2.0 + 3.0im |
60 | 60 | df = @inferred myabs2_pushforward(NamedTuple(), Δ)
|
61 |
| - @test df === Δ * (x + x) |
| 61 | + @test df === (Δ + conj(Δ)) * x |
62 | 62 | end
|
63 | 63 |
|
64 | 64 | @testset "complex input" begin
|
65 |
| - z = rand(Complex{Float64}) |
| 65 | + z = 5.0 + 7.0im |
66 | 66 | f, myabs2_pushforward = frule(myabs2, z)
|
67 | 67 | @test f === abs2(z)
|
68 | 68 |
|
69 | 69 | df = @inferred myabs2_pushforward(NamedTuple(), One())
|
70 | 70 | @test df === Wirtinger(z', z)
|
71 | 71 |
|
72 |
| - Δ = rand(Complex{Int64}) |
| 72 | + Δ = 2.0 + 3.0im |
73 | 73 | df = @inferred myabs2_pushforward(NamedTuple(), Δ)
|
74 |
| - @test df === Wirtinger(Δ * z', Δ * z) |
| 74 | + @test df === Wirtinger(Δ * conj(z), conj(Δ) * z) |
75 | 75 | end
|
76 | 76 | end
|
77 | 77 |
|
@@ -134,11 +134,11 @@ end
|
134 | 134 | fx, f_pushforward = res
|
135 | 135 | df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp)
|
136 | 136 |
|
137 |
| - df_dx::Thunk = df(One(), Zero()) |
138 |
| - df_dp::Thunk = df(Zero(), One()) |
| 137 | + df_dx = df(One(), Zero()) |
| 138 | + df_dp = df(Zero(), One()) |
139 | 139 | @test fx == f(x, p) # Check we still get the normal value, right
|
140 |
| - @test df_dx() isa expected_type_df_dx |
141 |
| - @test df_dp() isa expected_type_df_dp |
| 140 | + @test df_dx isa expected_type_df_dx |
| 141 | + @test df_dp isa expected_type_df_dp |
142 | 142 |
|
143 | 143 |
|
144 | 144 | res = rrule(f, x, p)
|
|
147 | 147 | dself, df_dx, df_dp = f_pullback(One())
|
148 | 148 | @test fx == f(x, p) # Check we still get the normal value, right
|
149 | 149 | @test dself == NO_FIELDS
|
150 |
| - @test df_dx() isa expected_type_df_dx |
151 |
| - @test df_dp() isa expected_type_df_dp |
| 150 | + @test df_dx isa expected_type_df_dx |
| 151 | + @test df_dp isa expected_type_df_dp |
152 | 152 | end
|
153 | 153 | end
|
0 commit comments