Skip to content

Commit f885295

Browse files
committed
Matrix(Diagonal)
1 parent 7dd6ede commit f885295

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

src/lib/array.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,9 @@ end
420420
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...)
421421
@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i))
422422

423+
# fix Matrix(Diagonal(param([1,2,3]))) after https://github.com/JuliaLang/julia/pull/44615
424+
(::Type{Matrix})(d::Diagonal{<:Any,<:TrackedArray}) = diagm(0 => d.diag)
425+
423426
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
424427
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
425428
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)

test/tracker.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ end
172172
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
173173

174174
@test gradtest(x -> diagm(0 => x), rand(3))
175+
@test gradtest(x -> Matrix(Diagonal(x)), rand(3))
175176

176177
@test gradtest(W -> inv(log.(W * W)), (5,5))
177178
@test gradtest((A, B) -> A / B , (1,5), (5,5))

0 commit comments

Comments
 (0)