Skip to content

Commit 29b0839

Browse files
committed
make real tests
1 parent 8405b68 commit 29b0839

File tree

1 file changed

+64
-65
lines changed

1 file changed

+64
-65
lines changed

test/rulesets/Base/fastmath_able.jl

Lines changed: 64 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -177,77 +177,76 @@ const FASTABLE_AST = quote
177177
# test_rrule(^, randn(T) + 3, p ⊢ NoTangent())
178178
# end
179179

180-
# Tests for power functions, at values near to zero.
181-
182-
POWERGRADS = [ # (x,p) => (dx,dp)
183-
# Some regular points, sanity checks
184-
(1.0, 2) => (2.0, 0.0),
185-
(2.0, 2) => (4.0, 2.772588722239781),
186-
# At x=0, gradients for x seem clear,
187-
# for p I've just written here what it gives
188-
(0.0, 2) => (0.0, NaN),
189-
(-0.0, 2) => (-0.0, NaN),
190-
(0.0, 1) => (1.0, NaN), # or zero?
191-
(-0.0, 1) => (1.0, NaN),
192-
(0.0, 0) => (0.0, -Inf),
193-
(-0.0, 0) => (0.0, -Inf),
194-
(0.0, -1) => (-Inf, -Inf),
195-
(-0.0, -1) => (-Inf, Inf),
196-
(0.0, -2) => (-Inf, -Inf),
197-
(-0.0, -2) => (Inf, -Inf),
198-
# Non-integer powers:
199-
(0.0, 0.5) => (Inf, NaN),
200-
(0.0, 3.5) => (0.0, NaN),
201-
]
202-
203-
for ((x,p), (gx, gp)) in POWERGRADS # power ^
204-
y = x^p
205-
206-
# Forward
207-
y_f = frule((1,1,1), ^, x, p)[1]
208-
isequal(y, y_f) || println("^ forward value for $x^$p: got $y_f, expected $y")
209-
210-
gx_f = frule((0,1,0), ^, x, p)[1]
211-
gp_f = frule((0,0,1), ^, x, p)[2]
212-
# isequal(gx, gx_f) || println("^ forward `x` gradient for $x^$p: got $gx_f, expected $gx, maybe")
213-
# isequal(gp, gp_f) || println("^ forward `p` gradient for $x^$p: got $gp_f, expected $gp, maybe")
214-
215-
# Reverse
216-
y_r = rrule(^, x, p)[1]
217-
isequal(y, y_r) || println("^ reverse value for $x^$p: got $y_r, expected $y")
218-
219-
gx_r, gp_r = unthunk.(rrule(^, x, p)[2](1))[2:3]
220-
if x === -0.0 && p === 2
221-
@test 0.0 == gx_r # POWERGRADS says -0.0
222-
else
223-
isequal(gx, gx_r) || println("^ reverse `x` gradient for $x^$p: got $gx_r, expected $gx")
224-
end
225-
isequal(gp, gp_r) || println("^ reverse `p` gradient for $x^$p: got $gp_r, expected $gp")
226-
end
227-
228-
for ((x,p), (gx, gp)) in POWERGRADS # literal_pow
229-
p isa Int || continue
230-
x isa Real || continue
231-
232-
y = x^p
180+
# Tests for power functions, at values near to zero.
181+
182+
POWERGRADS = [ # (x,p) => (dx,dp)
183+
# Some regular points, as sanity checks:
184+
(1.0, 2) => (2.0, 0.0),
185+
(2.0, 2) => (4.0, 2.772588722239781),
186+
# At x=0, gradients for x seem clear,
187+
# for p less certain but I think 0 or NaN right?
188+
(0.0, 2) => (0.0, 0.0),
189+
(-0.0, 2) => (-0.0, 0.0),
190+
(0.0, 1) => (1.0, 0.0),
191+
(-0.0, 1) => (1.0, 0.0),
192+
(0.0, 0) => (0.0, NaN),
193+
(-0.0, 0) => (0.0, NaN),
194+
(0.0, -1) => (-Inf, NaN),
195+
(-0.0, -1) => (-Inf, NaN),
196+
(0.0, -2) => (-Inf, NaN),
197+
(-0.0, -2) => (Inf, NaN),
198+
# Non-integer powers:
199+
(0.0, 0.5) => (Inf, 0.0),
200+
(0.0, 3.5) => (0.0, 0.0),
201+
(0.0, -1.5) => (-Inf, NaN),
202+
]
203+
204+
@testset "$x ^ $p" for ((x,p), (∂x, ∂p)) in POWERGRADS
205+
y = x^p
206+
207+
# Forward
208+
y_f = frule((1,1,1), ^, x, p)[1]
209+
@test isequal(y, y_f) # || println("^ forward value for $x^$p: got $y_f, expected $y")
210+
211+
∂x_fwd = frule((0,1,0), ^, x, p)[1]
212+
∂p_fwd = frule((0,0,1), ^, x, p)[2]
213+
# isequal(∂x, ∂x_fwd) || println("^ forward `x` gradient for $y = $x^$p: got $∂x_fwd, expected $∂x, maybe!")
214+
# isequal(∂p, ∂p_fwd) || println("^ forward `p` gradient for $x^$p: got $∂p_fwd, expected $∂p, maybe")
215+
216+
# Reverse
217+
y_r = rrule(^, x, p)[1]
218+
@test isequal(y, y_r) # || println("^ reverse value for $x^$p: got $y_r, expected $y")
219+
220+
∂x_rev, ∂p_rev = unthunk.(rrule(^, x, p)[2](1))[2:3]
221+
if ∂x === -0.0 # happens at at x === -0.0 && p === 2, ignore the sign
222+
@test 0.0 == ∂x_rev
223+
else
224+
@test isequal(∂x, ∂x_rev) # || println("^ reverse `x` gradient for $x^$p: got $∂x_rev, expected $∂x")
225+
end
226+
@test isequal(∂p, ∂p_rev) # || println("^ reverse `p` gradient for $x^$p: got $∂p_rev, expected $∂p")
227+
end
233228

234-
# Forward
235-
y_f = frule((1,1,1,1), Base.literal_pow, ^, x, Val(p))[1]
236-
isequal(y, y_f) || println("literal_pow forward value for $x^$p: got $y_f, expected $y")
229+
@testset "literal_pow $x ^ $p" for ((x,p), (∂x, ∂p)) in POWERGRADS
230+
# p isa Int || continue
231+
# x isa Real || continue
237232

238-
gx_f = frule((0,0,1,0), Base.literal_pow, ^, x, Val(p))[1]
239-
# isequal(gx, gx_f) || println("literal_pow forward `x` gradient for $x^$p: got $gx_f, expected $gx, maybe, y=$y")
233+
y = x^p
240234

241-
# Reverse
242-
y_r = rrule(Base.literal_pow, ^, x, Val(p))[1]
243-
isequal(y, y_r) || println("literal_pow reverse value for $x^$p: got $y_r, expected $y")
235+
# Forward
236+
y_f = frule((1,1,1,1), Base.literal_pow, ^, x, Val(p))[1]
237+
@test isequal(y, y_f) # || println("literal_pow forward value for $x^$p: got $y_f, expected $y")
244238

245-
gx_r = unthunk(rrule(Base.literal_pow, ^, x, Val(p))[2](1))[3]
246-
isequal(gx, gx_r) || println("literal_pow `x` gradient for $x^$p: got $gx_r, expected $gx")
239+
∂x_fwd = frule((0,0,1,0), Base.literal_pow, ^, x, Val(p))[1]
240+
# isequal(∂x, ∂x_fwd) || println("literal_pow forward `x` gradient for $x^$p: got $∂x_fwd, expected $∂x, maybe, y=$y")
247241

248-
# @info "all" x y p gx_f gx_r
249-
end
242+
# Reverse
243+
y_r = rrule(Base.literal_pow, ^, x, Val(p))[1]
244+
@test isequal(y, y_r) # || println("literal_pow reverse value for $x^$p: got $y_r, expected $y")
250245

246+
∂x_rev = unthunk(rrule(Base.literal_pow, ^, x, Val(p))[2](1))[3]
247+
@test isequal(∂x, ∂x_rev) # || println("literal_pow `x` gradient for $x^$p: got $∂x_rev, expected $∂x")
248+
end
249+
end
251250

252251
@testset "sign" begin
253252
@testset "real" begin

0 commit comments

Comments
 (0)