Skip to content

Commit e0384be

Browse files
committed
Address code review comments
1 parent 34a1bbe commit e0384be

File tree

4 files changed

+35
-31
lines changed

4 files changed

+35
-31
lines changed

src/rulesets/LinearAlgebra/blas.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ end
3737

3838
function frule(::typeof(BLAS.nrm2), x, _, Δ)
3939
Ω = BLAS.nrm2(x)
40-
return Ω, sum(Δx * cast(@thunk(x * inv))))
40+
return Ω, sum(Δx .* @thunk(x * inv(Ω)))
4141
end
4242

4343
function rrule(::typeof(BLAS.nrm2), x)
@@ -68,12 +68,15 @@ end
6868
#####
6969

7070
function frule(::typeof(BLAS.asum), x, _, Δx)
71-
return BLAS.asum(x), sum(cast(sign, x) * Δx)
71+
return BLAS.asum(x), sum(zip(x, Δx)) do xs
72+
x, Δx = xs
73+
return sign(x) * Δx
74+
end
7275
end
7376

7477
function rrule(::typeof(BLAS.asum), x)
7578
function asum_pullback(ΔΩ)
76-
return (NO_FIELDS, @thunk(ΔΩ * cast(sign, x)))
79+
return (NO_FIELDS, @thunk(ΔΩ * sign.(x)))
7780
end
7881
return BLAS.asum(x), asum_pullback
7982
end

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,13 @@ end
2525

2626
function frule(::typeof(inv), x::AbstractArray, _, Δx)
2727
Ω = inv(x)
28-
m = @thunk(-Ω)
29-
return Ω, m * Δx * Ω
28+
return Ω, -Ω * Δx * Ω
3029
end
3130

3231
function rrule(::typeof(inv), x::AbstractArray)
3332
Ω = inv(x)
34-
m = @thunk(-Ω')
3533
function inv_pullback(ΔΩ)
36-
return NO_FIELDS, m * ΔΩ * Ω'
34+
return NO_FIELDS, -Ω' * ΔΩ * Ω'
3735
end
3836
return Ω, inv_pullback
3937
end

test/rulesets/Base/base.jl

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,24 +52,6 @@
5252
test_scalar(acotd, 1/x)
5353
end
5454
@testset "Multivariate" begin
55-
@testset "atan2" begin
56-
# https://en.wikipedia.org/wiki/Atan2
57-
x, y = rand(2)
58-
ratan = atan(x, y)
59-
u = x^2 + y^2
60-
datan = y/u - 2x/u
61-
62-
r, ṙ = frule(atan, x, y, Zero(), 1, 2)
63-
@test r === ratan
64-
@test=== datan
65-
66-
r, pullback = rrule(atan, x, y)
67-
@test r === ratan
68-
dself, df1, df2 = pullback(1)
69-
@test dself == NO_FIELDS
70-
@test df1 + 2df2 === datan
71-
end
72-
7355
@testset "sincos" begin
7456
x, Δx, x̄ = randn(3)
7557
Δz = (randn(), randn())
@@ -91,11 +73,30 @@
9173
test_scalar(exp2, x)
9274
test_scalar(exp10, x)
9375

76+
test_scalar(cbrt, x)
77+
78+
if x >= 0
79+
test_scalar(sqrt, x)
80+
test_scalar(log, x)
81+
test_scalar(log2, x)
82+
test_scalar(log10, x)
83+
test_scalar(log1p, x)
84+
end
85+
end
86+
end
87+
88+
@testset "Unary complex functions" begin
89+
for x in (-4.1, 6.4)
90+
test_scalar(real, x)
91+
test_scalar(imag, x)
92+
93+
test_scalar(abs, x)
94+
test_scalar(hypot, x)
95+
96+
test_scalar(angle, x)
97+
test_scalar(abs2, x)
9498
test_scalar(conj, x)
9599
test_scalar(adjoint, x)
96-
test_scalar(abs2, x)
97-
98-
x isa Real && test_scalar(cbrt, x)
99100
end
100101
end
101102

test/test_util.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,17 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
2626
@test r_res !== f_res !== nothing # Check the rule was defined
2727
r_fx, prop_rule = r_res
2828
f_fx, f_∂x = f_res
29-
@testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in ((rrule, r_fx, prop_rule(1)), (frule, f_fx, f_∂x))
29+
@testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in (
30+
(rrule, r_fx, prop_rule(1)),
31+
(frule, f_fx, f_∂x)
32+
)
3033
@test fx == f(x) # Check we still get the normal value, right
3134

3235
if rule == rrule
3336
∂self, ∂x = ∂x
3437
@test ∂self === NO_FIELDS
3538
end
36-
@test isapprox(∂x, fdm(f, x);
37-
rtol=rtol, atol=atol, kwargs...)
39+
@test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...)
3840
end
3941
end
4042

0 commit comments

Comments
 (0)