Skip to content

Commit ce7a16e

Browse files
committed
some more improvements
- avoid assuming the tuples have the same eltype - coerce the inputs to simpler representations before printing, if unambiguous - update a couple more error messages with the same text - use string builder, instead of repeated concatenation
1 parent 74a626c commit ce7a16e

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

base/indices.jl

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,22 +106,33 @@ IndexStyle(::IndexStyle, ::IndexStyle) = IndexCartesian()
106106

107107
promote_shape(::Tuple{}, ::Tuple{}) = ()
108108

109-
# Consistent error message for promote_shape mismatch, hiding implementation details like
109+
# Consistent error message for promote_shape mismatch, hiding type details like
110110
# OneTo. When b ≡ nothing, it is omitted; i can be supplied for an index.
111-
function throw_promote_shape_mismatch(a::Tuple{T,Vararg{T}},
112-
b::Union{Nothing,Tuple{T,Vararg{T}}},
113-
i = nothing) where {T}
114-
_has_axes = T <: AbstractUnitRange
115-
_normalize(d) = map(x -> _has_axes ? (firstindex(x):lastindex(x)) : x, d)
116-
_things = _has_axes ? "axes" : "size"
117-
msg = "a has $(_things) $(_normalize(a))"
111+
function throw_promote_shape_mismatch(a::Tuple, b::Union{Nothing,Tuple}, i = nothing)
112+
if a isa Tuple{Vararg{Base.OneTo}} && (b === nothing || b isa Tuple{Vararg{Base.OneTo}})
113+
a = map(lastindex, a)::Dims
114+
b === nothing || (b = map(lastindex, b)::Dims)
115+
end
116+
_has_axes = !(a isa Dims && (b === nothing || b isa Dims))
117+
if _has_axes
118+
_normalize(d) = map(x -> firstindex(x):lastindex(x), d)
119+
a = _normalize(a)
120+
b === nothing || (b = _normalize(b))
121+
_things = "axes "
122+
else
123+
_things = "size "
124+
end
125+
msg = IOBuffer()
126+
print(msg, "a has ", _things)
127+
print(msg, a)
118128
if b nothing
119-
msg *= ", b has $(_things) $(_normalize(b))"
129+
print(msg, ", b has ", _things)
130+
print(msg, b)
120131
end
121132
if i nothing
122-
msg *= ", mismatch at $(i)"
133+
print(msg, ", mismatch at dim ", i)
123134
end
124-
throw(DimensionMismatch(msg))
135+
throw(DimensionMismatch(String(take!(msg))))
125136
end
126137

127138
function promote_shape(a::Tuple{Int,}, b::Tuple{Int,})

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ function Base.muladd(A::AbstractMatrix, y::AbstractVecOrMat, z::Union{Number, Ab
184184
end
185185
for d in ndims(Ay)+1:ndims(z)
186186
# Similar error to what Ay + z would give, to match (Any,Any,Any) method:
187-
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
187+
size(z,d) > 1 && throw(DimensionMismatch(string("z has dims ",
188188
axes(z), ", must have singleton at dim ", d)))
189189
end
190190
Ay .+ z
@@ -197,7 +197,7 @@ function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z::Union{Number, Ab
197197
end
198198
for d in 3:ndims(z)
199199
# Similar error to (u*v) + z:
200-
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
200+
size(z,d) > 1 && throw(DimensionMismatch(string("z has dims ",
201201
axes(z), ", must have singleton at dim ", d)))
202202
end
203203
(u .* v) .+ z

stdlib/LinearAlgebra/src/structuredbroadcast.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,8 @@ end
251251
# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check
252252
function map(f, A::StructuredMatrix, Bs::StructuredMatrix...)
253253
sz = size(A)
254-
all(map(B->size(B)==sz, Bs)) || throw(DimensionMismatch("dimensions must match"))
254+
for B in Bs
255+
size(B) == sz || Base.throw_promote_shape_mismatch(sz, size(B))
256+
end
255257
return f.(A, Bs...)
256258
end

stdlib/LinearAlgebra/test/structuredbroadcast.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ end
142142
@test map!(*, Z, X, Y) == broadcast(*, fX, fY)
143143
end
144144
end
145+
# these would be valid for broadcast, but not for map
146+
@test_throws DimensionMismatch map(+, D, Diagonal(rand(1)))
147+
@test_throws DimensionMismatch map(+, D, Diagonal(rand(1)), D)
148+
@test_throws DimensionMismatch map(+, D, D, Diagonal(rand(1)))
149+
@test_throws DimensionMismatch map(+, Diagonal(rand(1)), D, D)
145150
end
146151

147152
@testset "Issue #33397" begin

0 commit comments

Comments
 (0)