Skip to content

Commit 0211fad

Browse files
Michael Abbottmcabbott
authored andcommitted
allow, and test, integer x
1 parent 744c6b9 commit 0211fad

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,7 @@ end
108108

109109
function _normp_back_x(x, p, y, Δy)
110110
c = real(Δy) / y
111-
∂x = similar(x)
112-
broadcast!(∂x, x) do xi
111+
∂x = broadcast(x) do xi
113112
a = norm(xi)
114113
∂xi = xi * ((a / y)^(p - 2) * c)
115114
return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
@@ -174,8 +173,7 @@ function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray)
174173
end
175174

176175
function _norm1_back(x, y, Δy)
177-
∂x = similar(x)
178-
∂x .= sign.(x) .* real(Δy)
176+
∂x = sign.(x) .* real(Δy)
179177
return ∂x
180178
end
181179
function _norm1_back!(∂x, x, y, Δy)
@@ -207,8 +205,7 @@ function _norm2_forward(x, Δx, y)
207205
return ∂y
208206
end
209207
function _norm2_back(x, y, Δy)
210-
∂x = similar(x)
211-
∂x .= x .* (real(Δy) * pinv(y))
208+
∂x = x .* (real(Δy) * pinv(y))
212209
return ∂x
213210
end
214211
function _norm2_back!(∂x, x, y, Δy)

test/rulesets/LinearAlgebra/norm.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@
5050
# frule_test(fnorm, (xp, ẋ))
5151
rrule_test(fnorm, ȳ, (xp, x̄))
5252
end
53+
T == Float64 && ndims(x) == 1 && @testset "Integer input" begin
54+
x = [1,2,3]
55+
_, int_back = rrule(fnorm, x)
56+
_, float_back = rrule(fnorm, float(x))
57+
@test unthunk(int_back(1.0)[2]) unthunk(float_back(1.0)[2])
58+
end
5359
end
5460
@testset "norm(x::Array{$T,$(length(sz))})" for
5561
T in (Float64, ComplexF64),
@@ -127,6 +133,12 @@
127133
= rand_tangent(fnorm(x, p))
128134
@test extern(rrule(fnorm, zero(x), p)[2](ȳ)[2]) zero(x)
129135
@test rrule(fnorm, x, p)[2](Zero())[2] isa Zero
136+
T == Float64 && sz == (3,) && @testset "Integer input, p=$p" begin
137+
x = [1,2,3]
138+
_, int_back = rrule(fnorm, x, p)
139+
_, float_back = rrule(fnorm, float(x), p)
140+
@test unthunk(unthunk(int_back(1.0)[2])) unthunk(unthunk(float_back(1.0)[2]))
141+
end
130142
end
131143
@testset "norm($fdual(::Vector{$T}), p)" for
132144
T in (Float64, ComplexF64),

0 commit comments

Comments
 (0)