Skip to content

Commit 7bc38a0

Browse files
authored
hvncat: Better handling of 0- and 1-length dims/shape args (#41197)
1 parent 6240d35 commit 7bc38a0

File tree

2 files changed

+116
-15
lines changed

2 files changed

+116
-15
lines changed

base/abstractarray.jl

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2126,28 +2126,34 @@ julia> hvncat(((3, 3), (3, 3), (6,)), true, a, b, c, d, e, f)
21262126
4 = elements in each 4d slice (4,)
21272127
=> shape = ((2, 1, 1), (3, 1), (4,), (4,)) with `rowfirst` = true
21282128
"""
2129-
hvncat(::Tuple{}, ::Bool) = []
2130-
hvncat(::Tuple{}, ::Bool, xs...) = []
2131-
hvncat(::Tuple{Vararg{Any, 1}}, ::Bool, xs...) = vcat(xs...) # methods assume 2+ dimensions
21322129
hvncat(dimsshape::Tuple, row_first::Bool, xs...) = _hvncat(dimsshape, row_first, xs...)
21332130
hvncat(dim::Int, xs...) = _hvncat(dim, true, xs...)
21342131

2135-
_hvncat(::Union{Tuple, Int}, ::Bool) = []
2132+
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool) = _typed_hvncat(Any, dimsshape, row_first)
21362133
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs...) = _typed_hvncat(promote_eltypeof(xs...), dimsshape, row_first, xs...)
21372134
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::T...) where T<:Number = _typed_hvncat(T, dimsshape, row_first, xs...)
21382135
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::Number...) = _typed_hvncat(promote_typeof(xs...), dimsshape, row_first, xs...)
21392136
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::AbstractArray...) = _typed_hvncat(promote_eltype(xs...), dimsshape, row_first, xs...)
21402137
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::AbstractArray{T}...) where T = _typed_hvncat(T, dimsshape, row_first, xs...)
21412138

2142-
typed_hvncat(::Type{T}, ::Tuple{}, ::Bool) where T = Vector{T}()
2143-
typed_hvncat(::Type{T}, ::Tuple{}, ::Bool, xs...) where T = Vector{T}()
2144-
typed_hvncat(T::Type, ::Tuple{Vararg{Any, 1}}, ::Bool, xs...) = typed_vcat(T, xs...) # methods assume 2+ dimensions
21452139
typed_hvncat(T::Type, dimsshape::Tuple, row_first::Bool, xs...) = _typed_hvncat(T, dimsshape, row_first, xs...)
21462140
typed_hvncat(T::Type, dim::Int, xs...) = _typed_hvncat(T, Val(dim), xs...)
21472141

2148-
_typed_hvncat(::Type{T}, ::Tuple{}, ::Bool) where T = Vector{T}()
2149-
_typed_hvncat(::Type{T}, ::Tuple{}, ::Bool, xs...) where T = Vector{T}()
2150-
_typed_hvncat(::Type{T}, ::Tuple{}, ::Bool, xs::Number...) where T = Vector{T}()
2142+
# 1-dimensional hvncat methods
2143+
2144+
_typed_hvncat(::Type, ::Val{0}) = _typed_hvncat_0d_only_one()
2145+
_typed_hvncat(T::Type, ::Val{0}, x) = fill(convert(T, x))
2146+
_typed_hvncat(T::Type, ::Val{0}, x::Number) = fill(convert(T, x))
2147+
_typed_hvncat(T::Type, ::Val{0}, x::AbstractArray) = convert.(T, x)
2148+
_typed_hvncat(::Type, ::Val{0}, ::Any...) = _typed_hvncat_0d_only_one()
2149+
_typed_hvncat(::Type, ::Val{0}, ::Number...) = _typed_hvncat_0d_only_one()
2150+
_typed_hvncat(::Type, ::Val{0}, ::AbstractArray...) = _typed_hvncat_0d_only_one()
2151+
2152+
_typed_hvncat_0d_only_one() =
2153+
throw(ArgumentError("a 0-dimensional array may only contain exactly one element"))
2154+
2155+
_typed_hvncat(::Type{T}, ::Val{N}) where {T, N} = Array{T, N}(undef, ntuple(x -> 0, Val(N)))
2156+
21512157
function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool, xs::Number...) where {T, N}
21522158
A = Array{T, N}(undef, dims...)
21532159
lengtha = length(A) # Necessary to store result because throw blocks are being deoptimized right now, which leads to excessive allocations
@@ -2185,14 +2191,13 @@ function hvncat_fill!(A::Array, row_first::Bool, xs::Tuple)
21852191
end
21862192

21872193
_typed_hvncat(T::Type, dim::Int, ::Bool, xs...) = _typed_hvncat(T, Val(dim), xs...) # catches from _hvncat type promoters
2188-
_typed_hvncat(::Type{T}, ::Val) where T = Vector{T}()
2189-
_typed_hvncat(T::Type, ::Val{N}, xs::Number...) where N = _typed_hvncat(T, (ntuple(x -> 1, N - 1)..., length(xs)), false, xs...)
21902194
function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N}
21912195
# optimization for arrays that can be concatenated by copying them linearly into the destination
21922196
# conditions: the elements must all have 1- or 0-length dimensions above N
21932197
for a as
21942198
ndims(a) <= N || all(x -> size(a, x) == 1, (N + 1):ndims(a)) ||
2195-
return _typed_hvncat(T, (ntuple(x -> 1, N - 1)..., length(as)), false, as...)
2199+
return _typed_hvncat(T, (ntuple(x -> 1, N - 1)..., length(as), 1), false, as...)
2200+
# the extra 1 is to avoid an infinite cycle
21962201
end
21972202

21982203
nd = max(N, ndims(as[1]))
@@ -2246,6 +2251,31 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as...) where {T, N}
22462251
return A
22472252
end
22482253

2254+
2255+
# 0-dimensional cases for balanced and unbalanced hvncat method
2256+
2257+
_typed_hvncat(T::Type, ::Tuple{}, ::Bool, x...) = _typed_hvncat(T, Val(0), x...)
2258+
_typed_hvncat(T::Type, ::Tuple{}, ::Bool, x::Number...) = _typed_hvncat(T, Val(0), x...)
2259+
2260+
2261+
# balanced dimensions hvncat methods
2262+
2263+
_typed_hvncat(T::Type, dims::Tuple{Int}, ::Bool, as...) = _typed_hvncat_1d(T, dims[1], Val(false), as...)
2264+
_typed_hvncat(T::Type, dims::Tuple{Int}, ::Bool, as::Number...) = _typed_hvncat_1d(T, dims[1], Val(false), as...)
2265+
2266+
function _typed_hvncat_1d(::Type{T}, ds::Int, ::Val{row_first}, as...) where {T, row_first}
2267+
lengthas = length(as)
2268+
ds > 0 ||
2269+
throw(ArgumentError("`dimsshape` argument must consist of positive integers"))
2270+
lengthas == ds ||
2271+
throw(ArgumentError("number of elements does not match `dimshape` argument; expected $ds, got $lengthas"))
2272+
if row_first
2273+
return _typed_hvncat(T, Val(2), as...)
2274+
else
2275+
return _typed_hvncat(T, Val(1), as...)
2276+
end
2277+
end
2278+
22492279
function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool, as...) where {T, N}
22502280
d1 = row_first ? 2 : 1
22512281
d2 = row_first ? 1 : 2
@@ -2308,7 +2338,16 @@ function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool,
23082338
return A
23092339
end
23102340

2311-
function _typed_hvncat(::Type{T}, shape::Tuple{Vararg{Tuple, N}}, row_first::Bool, as...) where {T, N}
2341+
2342+
# unbalanced dimensions hvncat methods
2343+
2344+
function _typed_hvncat(T::Type, shape::Tuple{Tuple}, row_first::Bool, xs...)
2345+
length(shape[1]) > 0 ||
2346+
throw(ArgumentError("each level of `shape` argument must have at least one value"))
2347+
return _typed_hvncat_1d(T, shape[1][1], Val(row_first), xs...)
2348+
end
2349+
2350+
function _typed_hvncat(T::Type, shape::NTuple{N, Tuple}, row_first::Bool, as...) where {N}
23122351
d1 = row_first ? 2 : 1
23132352
d2 = row_first ? 1 : 2
23142353
shape = collect(shape) # saves allocations later

test/abstractarray.jl

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1342,6 +1342,7 @@ end
13421342
end
13431343
end
13441344

1345+
using Base: typed_hvncat
13451346
@testset "hvncat" begin
13461347
a = fill(1, (2,3,2,4,5))
13471348
b = fill(2, (1,1,2,4,5))
@@ -1389,7 +1390,68 @@ end
13891390
@test [v v;;; fill(v, 1, 2)] == fill(v, 1, 2, 2)
13901391
end
13911392

1392-
@test_throws BoundsError hvncat(((1, 2), (3,)), false, zeros(Int, 0, 0, 0), 7, 8)
1393+
# 0-dimension behaviors
1394+
# exactly one argument, placed in an array
1395+
# if already an array, copy, with type conversion as necessary
1396+
@test_throws ArgumentError hvncat(0)
1397+
@test hvncat(0, 1) == fill(1)
1398+
@test hvncat(0, [1]) == [1]
1399+
@test_throws ArgumentError hvncat(0, 1, 1)
1400+
@test_throws ArgumentError typed_hvncat(Float64, 0)
1401+
@test typed_hvncat(Float64, 0, 1) == fill(1.0)
1402+
@test typed_hvncat(Float64, 0, [1]) == Float64[1.0]
1403+
@test_throws ArgumentError typed_hvncat(Float64, 0, 1, 1)
1404+
@test_throws ArgumentError hvncat((), true) == []
1405+
@test hvncat((), true, 1) == fill(1)
1406+
@test hvncat((), true, [1]) == [1]
1407+
@test_throws ArgumentError hvncat((), true, 1, 1)
1408+
@test_throws ArgumentError typed_hvncat(Float64, (), true) == Float64[]
1409+
@test typed_hvncat(Float64, (), true, 1) == fill(1.0)
1410+
@test typed_hvncat(Float64, (), true, [1]) == [1.0]
1411+
@test_throws ArgumentError typed_hvncat(Float64, (), true, 1, 1)
1412+
1413+
# 1-dimension behaviors
1414+
# int form
1415+
@test hvncat(1) == []
1416+
@test hvncat(1, 1) == [1]
1417+
@test hvncat(1, [1]) == [1]
1418+
@test hvncat(1, [1 2; 3 4]) == [1 2; 3 4]
1419+
@test hvncat(1, 1, 1) == [1 ; 1]
1420+
@test typed_hvncat(Float64, 1) == Float64[]
1421+
@test typed_hvncat(Float64, 1, 1) == Float64[1.0]
1422+
@test typed_hvncat(Float64, 1, [1]) == Float64[1.0]
1423+
@test typed_hvncat(Float64, 1, 1, 1) == Float64[1.0 ; 1.0]
1424+
# dims form
1425+
@test_throws ArgumentError hvncat((1,), true)
1426+
@test hvncat((2,), true, 1, 1) == [1; 1]
1427+
@test hvncat((2,), true, [1], [1]) == [1; 1]
1428+
@test_throws ArgumentError hvncat((2,), true, 1)
1429+
@test typed_hvncat(Float64, (2,), true, 1, 1) == Float64[1.0; 1.0]
1430+
@test typed_hvncat(Float64, (2,), true, [1], [1]) == Float64[1.0; 1.0]
1431+
@test_throws ArgumentError typed_hvncat(Float64, (2,), true, 1)
1432+
# row_first has no effect with just one dimension of the dims form
1433+
@test hvncat((2,), false, 1, 1) == [1; 1]
1434+
@test typed_hvncat(Float64, (2,), false, 1, 1) == Float64[1.0; 1.0]
1435+
# shape form
1436+
@test hvncat(((2,),), true, 1, 1) == [1 1]
1437+
@test hvncat(((2,),), true, [1], [1]) == [1 1]
1438+
@test_throws ArgumentError hvncat(((2,),), true, 1)
1439+
@test hvncat(((2,),), false, 1, 1) == [1; 1]
1440+
@test hvncat(((2,),), false, [1], [1]) == [1; 1]
1441+
@test typed_hvncat(Float64, ((2,),), true, 1, 1) == Float64[1.0 1.0]
1442+
@test typed_hvncat(Float64, ((2,),), true, [1], [1]) == Float64[1.0 1.0]
1443+
@test_throws ArgumentError typed_hvncat(Float64, ((2,),), true, 1)
1444+
@test typed_hvncat(Float64, ((2,),), false, 1, 1) == Float64[1.0; 1.0]
1445+
@test typed_hvncat(Float64, ((2,),), false, [1], [1]) == Float64[1.0; 1.0]
1446+
1447+
# zero-value behaviors for int form above dimension zero
1448+
# e.g. [;;], [;;;], though that isn't valid syntax
1449+
@test [] == hvncat(1) isa Array{Any, 1}
1450+
@test Array{Any, 2}(undef, 0, 0) == hvncat(2) isa Array{Any, 2}
1451+
@test Array{Any, 3}(undef, 0, 0, 0) == hvncat(3) isa Array{Any, 3}
1452+
@test Int[] == typed_hvncat(Int, 1) isa Array{Int, 1}
1453+
@test Array{Int, 2}(undef, 0, 0) == typed_hvncat(Int, 2) isa Array{Int, 2}
1454+
@test Array{Int, 3}(undef, 0, 0, 0) == typed_hvncat(Int, 3) isa Array{Int, 3}
13931455
end
13941456

13951457
@testset "keepat!" begin

0 commit comments

Comments
 (0)