Skip to content

Commit d58e420

Browse files
authored
Remove piracy of diagm (#475)
1 parent 8c7f7ed commit d58e420

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.7.0"
3+
version = "1.7.1"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/tangent_types/thunks.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,17 @@ LinearAlgebra.UpperTriangular(a::AbstractThunk) = UpperTriangular(unthunk(a))
5959
LinearAlgebra.Symmetric(a::AbstractThunk, uplo=:U) = Symmetric(unthunk(a), uplo)
6060
LinearAlgebra.Hermitian(a::AbstractThunk, uplo=:U) = Hermitian(unthunk(a), uplo)
6161

62-
function LinearAlgebra.diagm(kv::Pair{<:Integer,<:AbstractThunk}...)
63-
return diagm((k => unthunk(v) for (k, v) in kv)...)
62+
function LinearAlgebra.diagm(
63+
kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}...
64+
)
65+
return diagm((k => unthunk(v) for (k, v) in (kv, kvs...))...)
6466
end
65-
function LinearAlgebra.diagm(m, n, kv::Pair{<:Integer,<:AbstractThunk}...)
66-
return diagm(m, n, (k => unthunk(v) for (k, v) in kv)...)
67+
function LinearAlgebra.diagm(
68+
m, n, kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}...
69+
)
70+
return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...)
6771
end
72+
6873
LinearAlgebra.tril(a::AbstractThunk) = tril(unthunk(a))
6974
LinearAlgebra.tril(a::AbstractThunk, k) = tril(unthunk(a), k)
7075
LinearAlgebra.triu(a::AbstractThunk) = triu(unthunk(a))

test/tangent_types/thunks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@
138138
if VERSION >= v"1.2"
139139
@test diagm(0 => v) == diagm(0 => tv)
140140
@test diagm(3, 4, 0 => v) == diagm(3, 4, 0 => tv)
141+
# Check against accidential type piracy
142+
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/472
143+
@test Base.which(diagm, Tuple{}()).module != ChainRulesCore
144+
@test Base.which(diagm, Tuple{Int, Int}).module != ChainRulesCore
141145
end
142146
@test tril(a) == tril(t)
143147
@test tril(a, 1) == tril(t, 1)

0 commit comments

Comments
 (0)