Skip to content

Commit 61d4a62

Browse files
author
Michael Abbott
committed
widen type of norm pullback, add inplace + tests
1 parent ebc99f7 commit 61d4a62

File tree

2 files changed

+76
-42
lines changed

2 files changed

+76
-42
lines changed

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
2+
# Catch-all fallback: when no more specific `rrule` method applies to a call of
# `f::Function`, log an error showing `f` and `args`, then return `nothing`.
# NOTE(review): this looks like a debugging aid introduced with this change; it
# logs for every function without a rule — confirm it is meant to be kept.
rrule(f::Function, args...; kwargs...) = (@error "no rrule defined!" f args ; nothing)
3+
14
#####
25
##### `norm`
36
#####
@@ -17,20 +20,25 @@ function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
1720
return y, ∂y
1821
end
1922

20-
function rrule(
21-
::typeof(norm),
22-
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
23-
p::Real,
24-
)
23+
function rrule(::typeof(norm), x::AbstractArray, p::Real)
2524
y = LinearAlgebra.norm(x, p)
26-
function norm_pullback(Δy)
25+
function norm_pullback_p(Δy)
2726
∂x = Thunk() do
2827
return if isempty(x) || p == 0
29-
zero.(x) .* (zero(y) * zero(real(Δy)))
28+
InplaceableThunk(
29+
@thunk(zero.(x) .* (zero(y) * zero(real(Δy)))),
30+
# In-place accumulation of an all-zero gradient: `dx` must be returned unchanged.
# Overwriting it with zeros would erase contributions already accumulated in `dx`.
dx -> dx,
31+
)
3032
elseif p == 2
31-
_norm2_back(x, y, Δy)
33+
InplaceableThunk(
34+
@thunk(_norm2_back(x, y, Δy)),
35+
dx -> _norm2_back!(dx, x, y, Δy),
36+
)
3237
elseif p == 1
33-
_norm1_back(x, y, Δy)
38+
InplaceableThunk(
39+
@thunk(_norm1_back(x, y, Δy)),
40+
dx -> _norm1_back!(dx, x, y, Δy),
41+
)
3442
elseif p == Inf
3543
_normInf_back(x, y, Δy)
3644
elseif p == -Inf
@@ -42,24 +50,24 @@ function rrule(
4250
∂p = @thunk _normp_back_p(x, p, y, Δy)
4351
return (NO_FIELDS, ∂x, ∂p)
4452
end
45-
norm_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
46-
return y, norm_pullback
53+
norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero())
54+
return y, norm_pullback_p
4755
end
48-
function rrule(
49-
::typeof(norm),
50-
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
51-
)
56+
function rrule(::typeof(norm), x::AbstractArray)
5257
y = LinearAlgebra.norm(x)
53-
function norm_pullback(Δy)
58+
function norm_pullback_2(Δy)
5459
∂x = if isempty(x)
5560
zero.(x) .* (zero(y) * zero(real(Δy)))
5661
else
57-
_norm2_back(x, y, Δy)
62+
InplaceableThunk(
63+
@thunk(_norm2_back(x, y, Δy)),
64+
dx -> _norm2_back!(dx, x, y, Δy),
65+
)
5866
end
5967
return (NO_FIELDS, ∂x)
6068
end
61-
norm_pullback(::Zero) = (NO_FIELDS, Zero())
62-
return y, norm_pullback
69+
norm_pullback_2(::Zero) = (NO_FIELDS, Zero())
70+
return y, norm_pullback_2
6371
end
6472
function rrule(
6573
::typeof(norm),
@@ -94,11 +102,7 @@ end
94102
##### `normp`
95103
#####
96104

97-
function rrule(
98-
::typeof(LinearAlgebra.normp),
99-
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
100-
p,
101-
)
105+
function rrule(::typeof(LinearAlgebra.normp),x::AbstractArray, p)
102106
y = LinearAlgebra.normp(x, p)
103107
function normp_pullback(Δy)
104108
∂x = @thunk _normp_back_x(x, p, y, Δy)
@@ -135,20 +139,14 @@ end
135139
##### `normMinusInf`/`normInf`
136140
#####
137141

138-
function rrule(
139-
::typeof(LinearAlgebra.normMinusInf),
140-
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
141-
)
142+
function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray)
142143
y = LinearAlgebra.normMinusInf(x)
143144
normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy))
144145
normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero())
145146
return y, normMinusInf_pullback
146147
end
147148

148-
function rrule(
149-
::typeof(LinearAlgebra.normInf),
150-
x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
151-
)
149+
function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray)
152150
y = LinearAlgebra.normInf(x)
153151
normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy))
154152
normInf_pullback(::Zero) = (NO_FIELDS, Zero())
@@ -172,12 +170,12 @@ end
172170
##### `norm1`
173171
#####
174172

175-
function rrule(
176-
::typeof(LinearAlgebra.norm1),
177-
x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
178-
)
173+
function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray)
179174
y = LinearAlgebra.norm1(x)
180-
norm1_pullback(Δy) = (NO_FIELDS, _norm1_back(x, y, Δy))
175+
norm1_pullback(Δy) = (NO_FIELDS, InplaceableThunk(
176+
@thunk(_norm1_back(x, y, Δy)),
177+
dx -> _norm1_back!(dx, x, y, Δy),
178+
))
181179
norm1_pullback(::Zero) = (NO_FIELDS, Zero())
182180
return y, norm1_pullback
183181
end
@@ -187,6 +185,10 @@ function _norm1_back(x, y, Δy)
187185
∂x .= sign.(x) .* real(Δy)
188186
return ∂x
189187
end
188+
# In-place variant of the 1-norm pullback: adds `sign.(x) .* real(Δy)` elementwise
# onto the existing contents of `∂x` (accumulating with `.+=`, not overwriting)
# and returns the mutated `∂x`. The argument `y` is accepted but not used here.
function _norm1_back!(∂x, x, y, Δy)
189+
∂x .+= sign.(x) .* real(Δy)
190+
return ∂x
191+
end
190192

191193
#####
192194
##### `norm2`
@@ -197,12 +199,12 @@ function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x)
197199
return y, _norm2_forward(x, Δx, y)
198200
end
199201

200-
function rrule(
201-
::typeof(LinearAlgebra.norm2),
202-
x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
203-
)
202+
function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray)
204203
y = LinearAlgebra.norm2(x)
205-
norm2_pullback(Δy) = (NO_FIELDS, _norm2_back(x, y, Δy))
204+
norm2_pullback(Δy) = (NO_FIELDS, InplaceableThunk(
205+
@thunk(_norm2_back(x, y, Δy)),
206+
dx -> _norm2_back!(dx, x, y, Δy),
207+
))
206208
norm2_pullback(::Zero) = (NO_FIELDS, Zero())
207209
return y, norm2_pullback
208210
end
@@ -216,6 +218,10 @@ function _norm2_back(x, y, Δy)
216218
∂x .= x .* (real(Δy) * pinv(y))
217219
return ∂x
218220
end
221+
# In-place variant of the 2-norm pullback: adds `x .* (real(Δy) * pinv(y))`
# elementwise onto the existing contents of `∂x` and returns the mutated `∂x`.
# NOTE(review): `pinv(y)` rather than `1 / y` presumably keeps the result finite
# when the norm `y` is zero — confirm.
function _norm2_back!(∂x, x, y, Δy)
222+
∂x .+= x .* (real(Δy) * pinv(y))
223+
return ∂x # must return after mutating
224+
end
219225

220226
#####
221227
##### `normalize`

test/rulesets/LinearAlgebra/norm.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,20 @@
3939
@test extern(rrule(fnorm, zero(x))[2](ȳ)[2]) zero(x)
4040
@test rrule(fnorm, x)[2](Zero())[2] isa Zero
4141
end
42+
ndims(x) > 1 && @testset "non-strided" begin
43+
xp = if x isa Matrix
44+
view(x, [1,2,3], 1:3)
45+
elseif x isa Array{T,3}
46+
PermutedDimsArray(x, (1,2,3))
47+
end
48+
@test !(xp isa StridedArray)
49+
y = fnorm(x)
50+
# ẋ = rand(T, size(xp)) # rand_tangent(xp)
51+
x̄ = rand(T, size(xp)) # rand_tangent(xp)
52+
ȳ = rand_tangent(y)
53+
# frule_test(fnorm, (xp, ẋ))
54+
rrule_test(fnorm, ȳ, (xp, x̄))
55+
end
4256
end
4357
@testset "norm(x::Array{$T,$(length(sz))})" for
4458
T in (Float64, ComplexF64),
@@ -63,6 +77,20 @@
6377
@test extern(rrule(norm, zero(x))[2](ȳ)[2]) zero(x)
6478
@test rrule(norm, x)[2](Zero())[2] isa Zero
6579
end
80+
ndims(x) > 1 && @testset "non-strided" begin
81+
xp = if x isa Matrix
82+
view(x, [1,2,3], 1:3)
83+
elseif x isa Array{T,3}
84+
PermutedDimsArray(x, (1,2,3))
85+
end
86+
@test !(xp isa StridedArray)
87+
y = norm(x)
88+
ẋ = rand(T, size(xp)) # rand_tangent(xp)
89+
x̄ = rand(T, size(xp)) # rand_tangent(xp)
90+
ȳ = rand_tangent(y)
91+
frule_test(norm, (xp, ẋ))
92+
rrule_test(norm, ȳ, (xp, x̄))
93+
end
6694
end
6795
@testset "$fnorm(x::Array{$T,$(length(sz))}, $p) with size $sz" for
6896
fnorm in (norm, LinearAlgebra.normp),

0 commit comments

Comments
 (0)