Skip to content

Commit 41f1071

Browse files
committed
fix tests accordingly
1 parent 61a60ec commit 41f1071

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

test/rules.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,32 +46,32 @@ end
4646

4747
@testset "real input" begin
4848
# even though our rule was define in terms of Wirtinger,
49-
# pushforward result will be real as real (even if seed is Compex)
49+
# pushforward result will be real as real (even if seed is Complex)
5050

51-
x = rand(Float64)
51+
x = 5.0
5252
f, myabs2_pushforward = frule(myabs2, x)
5353
@test f === x^2
5454

5555
Δ = One()
5656
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
5757
@test df === x + x
5858

59-
Δ = rand(Complex{Int64})
59+
Δ = 2.0 + 3.0im
6060
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
61-
@test df === Δ * (x + x)
61+
@test df === + conj(Δ)) * x
6262
end
6363

6464
@testset "complex input" begin
65-
z = rand(Complex{Float64})
65+
z = 5.0 + 7.0im
6666
f, myabs2_pushforward = frule(myabs2, z)
6767
@test f === abs2(z)
6868

6969
df = @inferred myabs2_pushforward(NamedTuple(), One())
7070
@test df === Wirtinger(z', z)
7171

72-
Δ = rand(Complex{Int64})
72+
Δ = 2.0 + 3.0im
7373
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
74-
@test df === Wirtinger* z', Δ * z)
74+
@test df === Wirtinger* conj(z), conj(Δ) * z)
7575
end
7676
end
7777

@@ -134,11 +134,11 @@ end
134134
fx, f_pushforward = res
135135
df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp)
136136

137-
df_dx::Thunk = df(One(), Zero())
138-
df_dp::Thunk = df(Zero(), One())
137+
df_dx = df(One(), Zero())
138+
df_dp = df(Zero(), One())
139139
@test fx == f(x, p) # Check we still get the normal value, right
140-
@test df_dx() isa expected_type_df_dx
141-
@test df_dp() isa expected_type_df_dp
140+
@test df_dx isa expected_type_df_dx
141+
@test df_dp isa expected_type_df_dp
142142

143143

144144
res = rrule(f, x, p)
@@ -147,7 +147,7 @@ end
147147
dself, df_dx, df_dp = f_pullback(One())
148148
@test fx == f(x, p) # Check we still get the normal value, right
149149
@test dself == NO_FIELDS
150-
@test df_dx() isa expected_type_df_dx
151-
@test df_dp() isa expected_type_df_dp
150+
@test df_dx isa expected_type_df_dx
151+
@test df_dp isa expected_type_df_dp
152152
end
153153
end

0 commit comments

Comments
 (0)