Skip to content

Commit 0044ea2

Browse files
authored
Fix mixed-buffer/mixed-shape broadcasts. (#2290)
1 parent 2ae7d33 commit 0044ea2

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

src/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ BroadcastStyle(W::Type{<:AnyCuArray{T,N}}) where {T,N} =
1212

1313
# when we are dealing with different buffer styles, we cannot know
1414
# which one is better, so use unified memory
15-
BroadcastStyle(::CUDA.CuArrayStyle{N, B1},
16-
::CUDA.CuArrayStyle{N, B2}) where {N,B1,B2} =
17-
CuArrayStyle{N, Mem.Unified}()
15+
# Different buffer types (`B1` vs `B2`): neither one can be preferred, so the result
# is placed in unified memory. The operands may also differ in rank (e.g. a vector
# broadcast against a matrix); the result takes the larger of the two.
BroadcastStyle(::CUDA.CuArrayStyle{N1, B1},
               ::CUDA.CuArrayStyle{N2, B2}) where {N1,N2,B1,B2} =
    CuArrayStyle{max(N1,N2), Mem.Unified}()

# Same buffer type, different ranks: keep the shared buffer type `B`. Without this
# more specific method the catch-all above would also match this case and needlessly
# move the result to unified memory.
BroadcastStyle(::CUDA.CuArrayStyle{N1, B},
               ::CUDA.CuArrayStyle{N2, B}) where {N1,N2,B} =
    CuArrayStyle{max(N1,N2), B}()
1818

1919
# allocation of output arrays
2020
Base.similar(bc::Broadcasted{CuArrayStyle{N,B}}, ::Type{T}, dims) where {T,N,B} =

test/base/broadcast.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,10 @@ end
6262
d = cu([1]; host=true)
6363
e = c .+ d
6464
@test is_unified(e)
65+
66+
# this should also work with inputs of different dimensionality (vector .+ matrix)
67+
f = cu([1]; device=true)
68+
g = cu([1 2]; host=true)
69+
h = f .+ g
70+
@test is_unified(h)
6571
end

0 commit comments

Comments
 (0)