Skip to content

Commit 8ec4df1

Browse files
authored
Make v and hcat with numbers work. (#379)
1 parent 6481ae2 commit 8ec4df1

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

src/host/base.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,13 @@ function PermutedDimsArrays._copy!(dest::PermutedDimsArray{T,N,<:Any,<:Any,<:Abs
5050
dest .= src
5151
dest
5252
end
53+
54+
## concatenation
55+
56+
# hacky overloads to make simple vcat and hcat with numbers work as expected.
57+
# we can't really make this work in general without Base providing
58+
# a dispatch mechanism for output container type.
59+
@inline Base.vcat(a::Number, b::AbstractGPUArray) =
60+
vcat(fill!(similar(b, typeof(a), (1,size(b)[2:end]...)), a), b)
61+
@inline Base.hcat(a::Number, b::AbstractGPUArray) =
62+
hcat(fill!(similar(b, typeof(a), (size(b)[1:end-1]...,1)), a), b)

test/testsuite/base.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,24 @@ end
115115
end
116116
end
117117

118-
@testset "vcat + hcat" begin
119-
@test compare(vcat, AT, fill(0f0, (10, 10)), rand(Float32, 20, 10))
120-
@test compare(hcat, AT, fill(0f0, (10, 10)), rand(Float32, 10, 10))
121-
122-
@test compare(hcat, AT, rand(Float32, 3, 3), rand(Float32, 3, 3))
123-
@test compare(vcat, AT, rand(Float32, 3, 3), rand(Float32, 3, 3))
124-
@test compare((a,b) -> cat(a, b; dims=4), AT, rand(Float32, 3, 4), rand(Float32, 3, 4))
118+
@testset "cat" begin
119+
@test compare(hcat, AT, rand(3), rand(3))
120+
@test compare(hcat, AT, rand(), rand(1, 3))
121+
@test compare(hcat, AT, rand(1, 3), rand())
122+
@test compare(hcat, AT, rand(3), rand(3, 3))
123+
@test compare(hcat, AT, rand(3, 3), rand(3))
124+
@test compare(hcat, AT, rand(3, 3), rand(3, 3))
125+
#@test compare(hcat, AT, rand(), rand(3, 3))
126+
#@test compare(hcat, AT, rand(3, 3), rand())
127+
128+
@test compare(vcat, AT, rand(3), rand(3))
129+
@test compare(vcat, AT, rand(3, 3), rand(3, 3))
130+
@test compare(vcat, AT, rand(), rand(3))
131+
@test compare(vcat, AT, rand(3), rand())
132+
@test compare(vcat, AT, rand(), rand(3, 3))
133+
#@test compare(vcat, AT, rand(3, 3), rand())
134+
135+
@test compare((a,b) -> cat(a, b; dims=4), AT, rand(3, 4), rand(3, 4))
125136
end
126137

127138
@testset "reshape" begin

0 commit comments

Comments
 (0)