Skip to content

Commit 3ffd391

Browse files
author
Michael Abbott
committed
allow, and test, integer x
1 parent e3f31e1 commit 3ffd391

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,7 @@ end
108108

109109
function _normp_back_x(x, p, y, Δy)
110110
c = real(Δy) / y
111-
∂x = similar(x)
112-
broadcast!(∂x, x) do xi
111+
∂x = broadcast(x) do xi
113112
a = norm(xi)
114113
∂xi = xi * ((a / y)^(p - 2) * c)
115114
return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
@@ -174,8 +173,7 @@ function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray)
174173
end
175174

176175
function _norm1_back(x, y, Δy)
177-
∂x = similar(x)
178-
∂x .= sign.(x) .* real(Δy)
176+
∂x = sign.(x) .* real(Δy)
179177
return ∂x
180178
end
181179
function _norm1_back!(∂x, x, y, Δy)
@@ -207,8 +205,7 @@ function _norm2_forward(x, Δx, y)
207205
return ∂y
208206
end
209207
function _norm2_back(x, y, Δy)
210-
∂x = similar(x)
211-
∂x .= x .* (real(Δy) * pinv(y))
208+
∂x = x .* (real(Δy) * pinv(y))
212209
return ∂x
213210
end
214211
function _norm2_back!(∂x, x, y, Δy)

test/rulesets/LinearAlgebra/norm.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@
5353
# frule_test(fnorm, (xp, ẋ))
5454
rrule_test(fnorm, ȳ, (xp, x̄))
5555
end
56+
T == Float64 && ndims(x) == 1 && @testset "Integer input" begin
57+
x = [1,2,3]
58+
_, int_back = rrule(fnorm, x)
59+
_, float_back = rrule(fnorm, float(x))
60+
@test unthunk(int_back(1.0)[2]) unthunk(float_back(1.0)[2])
61+
end
5662
end
5763
@testset "norm(x::Array{$T,$(length(sz))})" for
5864
T in (Float64, ComplexF64),
@@ -127,6 +133,12 @@
127133
end
128134
@test extern(rrule(fnorm, zero(x), p)[2](ȳ)[2]) zero(x)
129135
@test rrule(fnorm, x, p)[2](Zero())[2] isa Zero
136+
T == Float64 && sz == (3,) && @testset "Integer input, p=$p" begin
137+
x = [1,2,3]
138+
_, int_back = rrule(fnorm, x, p)
139+
_, float_back = rrule(fnorm, float(x), p)
140+
@test unthunk(unthunk(int_back(1.0)[2])) unthunk(unthunk(float_back(1.0)[2]))
141+
end
130142
end
131143
@testset "norm($fdual(::Vector{$T}), p)" for
132144
T in (Float64, ComplexF64),

0 commit comments

Comments
 (0)