Skip to content

Commit a61b4e0

Browse files
committed
OffsetArray support for cat/vcat/hcat
1 parent c2fd49e commit a61b4e0

File tree

2 files changed

+65
-21
lines changed

2 files changed

+65
-21
lines changed

base/abstractarray.jl

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,13 +1442,13 @@ function _typed_vcat(::Type{T}, V::AbstractVecOrTuple{AbstractVector}) where T
14421442
for Vk in V
14431443
n += Int(length(Vk))::Int
14441444
end
1445-
a = similar(V[1], T, n)
1446-
pos = 1
1447-
for k=1:Int(length(V))::Int
1445+
a = similar(first(V), T, n)
1446+
pos = first(axes(a, 1))
1447+
for k = eachindex(V)
14481448
Vk = V[k]
1449-
p1 = pos + Int(length(Vk))::Int - 1
1450-
a[pos:p1] = Vk
1451-
pos = p1+1
1449+
n = length(Vk)
1450+
copyto!(a, pos, Vk, first(axes(Vk, 1)), n)
1451+
pos += n
14521452
end
14531453
a
14541454
end
@@ -1459,11 +1459,10 @@ hcat(A::AbstractVecOrMat...) = typed_hcat(promote_eltype(A...), A...)
14591459
hcat(A::AbstractVecOrMat{T}...) where {T} = typed_hcat(T, A...)
14601460

14611461
function _typed_hcat(::Type{T}, A::AbstractVecOrTuple{AbstractVecOrMat}) where T
1462-
nargs = length(A)
1463-
nrows = size(A[1], 1)
1462+
nrows = size(first(A), 1)
14641463
ncols = 0
14651464
dense = true
1466-
for j = 1:nargs
1465+
for j = eachindex(A)
14671466
Aj = A[j]
14681467
if size(Aj, 1) != nrows
14691468
throw(ArgumentError("number of rows of each array must match (got $(map(x->size(x,1), A)))"))
@@ -1472,17 +1471,17 @@ function _typed_hcat(::Type{T}, A::AbstractVecOrTuple{AbstractVecOrMat}) where T
14721471
nd = ndims(Aj)
14731472
ncols += (nd==2 ? size(Aj,2) : 1)
14741473
end
1475-
B = similar(A[1], T, nrows, ncols)
1476-
pos = 1
1474+
B = similar(first(A), T, nrows, ncols)
1475+
pos = first(axes(B, 1))
14771476
if dense
1478-
for k=1:nargs
1477+
for k=eachindex(A)
14791478
Ak = A[k]
14801479
n = length(Ak)
14811480
copyto!(B, pos, Ak, 1, n)
14821481
pos += n
14831482
end
14841483
else
1485-
for k=1:nargs
1484+
for k=eachindex(A)
14861485
Ak = A[k]
14871486
p1 = pos+(isa(Ak,AbstractMatrix) ? size(Ak, 2) : 1)-1
14881487
B[:, pos:p1] = Ak
@@ -1496,17 +1495,16 @@ vcat(A::AbstractVecOrMat...) = typed_vcat(promote_eltype(A...), A...)
14961495
vcat(A::AbstractVecOrMat{T}...) where {T} = typed_vcat(T, A...)
14971496

14981497
function _typed_vcat(::Type{T}, A::AbstractVecOrTuple{AbstractVecOrMat}) where T
1499-
nargs = length(A)
15001498
nrows = sum(a->size(a, 1), A)::Int
1501-
ncols = size(A[1], 2)
1502-
for j = 2:nargs
1499+
ncols = size(first(A), 2)
1500+
for j = first(axes(A))[2:end]
15031501
if size(A[j], 2) != ncols
15041502
throw(ArgumentError("number of columns of each array must match (got $(map(x->size(x,2), A)))"))
15051503
end
15061504
end
1507-
B = similar(A[1], T, nrows, ncols)
1508-
pos = 1
1509-
for k=1:nargs
1505+
B = similar(first(A), T, nrows, ncols)
1506+
pos = first(axes(B, 1))
1507+
for k=eachindex(A)
15101508
Ak = A[k]
15111509
p1 = pos+size(Ak,1)::Int-1
15121510
B[pos:p1, :] = Ak
@@ -1589,7 +1587,7 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
15891587
@inline function _cat_t(dims, ::Type{T}, X...) where {T}
15901588
catdims = dims2cat(dims)
15911589
shape = cat_shape(catdims, map(cat_size, X)::Tuple{Vararg{Union{Int,Dims}}})::Dims
1592-
A = cat_similar(X[1], T, shape)
1590+
A = cat_similar(first(X), T, shape)
15931591
if count(!iszero, catdims)::Int > 1
15941592
fill!(A, zero(T))
15951593
end
@@ -1604,7 +1602,7 @@ function __cat(A, shape::NTuple{M,Int}, catdims, X...) where M
16041602
for x in X
16051603
for i = 1:N
16061604
if concat[i]
1607-
inds[i] = offsets[i] .+ cat_indices(x, i)
1605+
inds[i] = offsets[i] .+ parent(cat_indices(x, i))
16081606
offsets[i] += cat_size(x, i)
16091607
else
16101608
inds[i] = 1:shape[i]

test/abstractarray.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,52 @@ function test_cat(::Type{TestAbstractArray})
616616
@test cat([1], [1], dims=[1, 2]) == I(2)
617617
end
618618

619+
module TestOffsetArraysCats
620+
using Test
621+
isdefined(Main, :OffsetArrays) || @eval Main include(joinpath(@__DIR__, "testhelpers", "OffsetArrays.jl"))
622+
using .Main.OffsetArrays
623+
624+
# `cat`s on OffsetArrays ignore their offsets and treat them as normal list
625+
626+
# 1d
627+
v1 = collect(1:4)
628+
v2 = collect(5:8)
629+
ov1 = OffsetArray(v1, -1)
630+
ov2 = OffsetArray(v2, 1)
631+
@test hcat(ov1, v1, ov2, v2) == hcat(v1, v1, v2, v2)
632+
@test vcat(ov1, v1, ov2, v2) == vcat(v1, v1, v2, v2)
633+
@test hvcat((2, 2), ov1, v2, v1, ov2) == hvcat((2, 2), v1, v2, v1, v2)
634+
# 37628
635+
@test reduce(hcat, (v1, v2)) == hcat(v1, v2)
636+
@test reduce(vcat, (v1, v2)) == vcat(v1, v2)
637+
@test reduce(hcat, OffsetVector([1:2, 1:2],10)) == [1 1;2 2]
638+
639+
# 2d
640+
a1 = reshape(collect(1:6), 2, 3)
641+
a2 = reshape(collect(7:12), 2, 3)
642+
oa1 = OffsetArray(a1, -1, -1)
643+
oa2 = OffsetArray(a2, 1, 1)
644+
@test hcat(oa1, a1, oa2, a2) == hcat(a1, a1, a2, a2)
645+
@test vcat(oa1, a1, oa2, a2) == vcat(a1, a1, a2, a2)
646+
@test hvcat((2, 2), oa1, a2, a1, oa2) == hvcat((2, 2), a1, a2, a1, a2)
647+
648+
# 3d
649+
a1 = reshape(collect(1:12), 2, 3, 2)
650+
a2 = reshape(collect(13:24), 2, 3, 2)
651+
oa1 = OffsetArray(a1, -1, -1, -1)
652+
oa2 = OffsetArray(a2, 1, 1, 1)
653+
@test hcat(oa1, a1, oa2, a2) == hcat(a1, a1, a2, a2)
654+
@test vcat(oa1, a1, oa2, a2) == vcat(a1, a1, a2, a2)
655+
@test hvcat((2, 2), oa1, a2, a1, oa2) == hvcat((2, 2), a1, a2, a1, a2)
656+
# https://github.com/JuliaArrays/OffsetArrays.jl/issues/63
657+
form=OffsetArray(reshape(zeros(Int8,0),0,0,2),0:-1,0:-1,0:1)
658+
exp=OffsetArray(reshape(zeros(Int8,0),0,16,2),0:-1,0:15,0:1)
659+
@test size(hcat(form,exp)) == (0, 16, 2)
660+
# 37493
661+
@test hcat(zeros(2, 1:1, 2), zeros(2, 2:3, 2)) == zeros(2, 3, 2)
662+
@test vcat(zeros(1:1, 2, 2), zeros(2:3, 2, 2)) == zeros(3, 2, 2)
663+
end
664+
619665
function test_ind2sub(::Type{TestAbstractArray})
620666
n = rand(2:5)
621667
dims = tuple(rand(1:5, n)...)

0 commit comments

Comments
 (0)