Skip to content

Commit 81f85b7

Browse files
committed
make tests pass
1 parent e400bd5 commit 81f85b7

File tree

2 files changed

+33
-35
lines changed

2 files changed

+33
-35
lines changed

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ function frule((_, Δx), ::typeof(norm), x)
66
y = norm(x)
77
return y, _norm2_forward(x, Δx, norm(x))
88
end
9+
910
function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
1011
y = norm(x, p)
1112
∂y = if iszero(Δx) || iszero(p)
@@ -20,9 +21,9 @@ end
2021
function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real)
2122
y = LinearAlgebra.norm(x, p)
2223
function norm_pullback_p(Δy)
23-
# ∂x = InplaceableThunk(
24+
∂x = InplaceableThunk(
2425
# out-of-place versions
25-
∂x = @thunk(if isempty(x) || p == 0
26+
@thunk(if isempty(x) || p == 0
2627
zero.(x) .* (zero(y) * zero(real(Δy)))
2728
elseif p == 2
2829
_norm2_back(x, y, Δy)
@@ -35,48 +36,50 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real)
3536
else
3637
_normp_back_x(x, p, y, Δy)
3738
end)
38-
# , # in-place versions -- can be fixed when actually useful?
39-
# dx -> if isempty(x) || p == 0
40-
# dx
41-
# elseif p == 2
42-
# _norm2_back!(dx, x, y, Δy)
43-
# elseif p == 1
44-
# _norm1_back!(dx, x, y, Δy)
45-
# elseif p == Inf
46-
# dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved
47-
# elseif p == -Inf
48-
# dx .+= _normInf_back(x, y, Δy)
49-
# else
50-
# dx .+= _normp_back_x(x, p, y, Δy)
51-
# end
52-
# )
39+
, # in-place versions -- can be fixed when actually useful?
40+
dx -> if isempty(x) || p == 0
41+
dx
42+
elseif p == 2
43+
_norm2_back!(dx, x, y, Δy)
44+
elseif p == 1
45+
_norm1_back!(dx, x, y, Δy)
46+
elseif p == Inf
47+
dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved
48+
elseif p == -Inf
49+
dx .+= _normInf_back(x, y, Δy)
50+
else
51+
dx .+= _normp_back_x(x, p, y, Δy)
52+
end
53+
)
5354
∂p = @thunk _normp_back_p(x, p, y, Δy)
5455
return (NO_FIELDS, ∂x, ∂p)
5556
end
5657
norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero())
5758
return y, norm_pullback_p
5859
end
60+
5961
function rrule(::typeof(norm), x::AbstractArray{<:Number})
6062
y = LinearAlgebra.norm(x)
6163
function norm_pullback_2(Δy)
62-
# ∂x = InplaceableThunk(
63-
∂x = @thunk(if isempty(x)
64+
∂x = InplaceableThunk(
65+
@thunk(if isempty(x)
6466
zero.(x) .* (zero(y) * zero(real(Δy)))
6567
else
6668
_norm2_back(x, y, Δy)
6769
end)
68-
# ,
69-
# dx -> if isempty(x)
70-
# dx
71-
# else
72-
# _norm2_back!(dx, x, y, Δy)
73-
# end
74-
# )
70+
,
71+
dx -> if isempty(x)
72+
dx
73+
else
74+
_norm2_back!(dx, x, y, Δy)
75+
end
76+
)
7577
return (NO_FIELDS, ∂x)
7678
end
7779
norm_pullback_2(::Zero) = (NO_FIELDS, Zero())
7880
return y, norm_pullback_2
7981
end
82+
8083
function rrule(::typeof(norm), x::LinearAlgebra.AdjOrTransAbsVec{<:Number}, p::Real)
8184
y, inner_pullback = rrule(norm, parent(x), p)
8285
function norm_pullback(Δy)
@@ -87,6 +90,7 @@ function rrule(::typeof(norm), x::LinearAlgebra.AdjOrTransAbsVec{<:Number}, p::R
8790
end
8891
return y, norm_pullback
8992
end
93+
9094
function rrule(::typeof(norm), x::Number, p::Real)
9195
y = norm(x, p)
9296
function norm_pullback(Δy)
@@ -126,6 +130,7 @@ function _normp_back_x(x, p, y, Δy)
126130
end
127131
return ∂x
128132
end
133+
129134
function _normp_back_x(x::WithSomeZeros, p, y, Δy) # Diagonal, UpperTriangular, etc.
130135
c = real(Δy) / y
131136
∂x_data = map(parent(x)) do xi
@@ -261,6 +266,7 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real)
261266
normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
262267
return y, normalize_pullback
263268
end
269+
264270
function rrule(::typeof(normalize), x::AbstractVector{<:Number})
265271
nrm = LinearAlgebra.norm2(x)
266272
Ty = typeof(first(x) / nrm)

test/rulesets/LinearAlgebra/norm.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
@eval ChainRulesTestUtils check_thunking_is_appropriate(_) = nothing
2-
31
@testset "norm functions" begin
42
@testset "$fnorm(x::Array{$T,$(length(sz))})" for
53
fnorm in (
@@ -80,8 +78,6 @@ println("starting exported norm T=$T, sz=$sz")
8078
@testset "rrule" begin
8179
test_rrule(norm, x)
8280
x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
83-
# we don't check inference on older julia versions. Improvements to
84-
# inference mean on 1.5+ it works, and that is good enough
8581
test_rrule(norm, MT(x); check_inferred=VERSION>=v"1.5")
8682
end
8783

@@ -130,13 +126,9 @@ println("starting p-norm p=$p, T=$T, sz=$sz")
130126
kwargs = NamedTuple()
131127
end
132128

133-
134129
test_rrule(fnorm, x, p; kwargs...)
135130
x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
136-
test_rrule(fnorm, MT(x), p;
137-
#Don't check inference on old julia, what matters is that works on new
138-
check_inferred=VERSION>=v"1.5", kwargs...
139-
)
131+
test_rrule(fnorm, MT(x), p; kwargs..., check_inferred=VERSION>=v"1.5")
140132
end
141133

142134
= rand_tangent(fnorm(x, p))

0 commit comments

Comments
 (0)