Skip to content

Commit c5fcd73

Browse files
vpuri3maleadt
andauthored
Fix broadcast defaulting to Mem.Unified() (#2327)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 9c24777 commit c5fcd73

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

src/broadcast.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ BroadcastStyle(::CUDA.CuArrayStyle{N1, B1},
1616
::CUDA.CuArrayStyle{N2, B2}) where {N1,N2,B1,B2} =
1717
CuArrayStyle{max(N1,N2), Mem.Unified}()
1818

19+
# resolve ambiguity: different N, same buffer
20+
BroadcastStyle(::CUDA.CuArrayStyle{N1, B},
21+
::CUDA.CuArrayStyle{N2, B}) where {N1,N2,B} =
22+
CuArrayStyle{max(N1,N2), B}()
23+
1924
# allocation of output arrays
2025
Base.similar(bc::Broadcasted{CuArrayStyle{N,B}}, ::Type{T}, dims) where {T,N,B} =
2126
similar(CuArray{T,length(dims),B}, dims)

test/base/broadcast.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,10 @@ end
6868
g = cu([1 2]; host=true)
6969
h = f .+ g
7070
@test is_unified(h)
71+
72+
# however, differences in only shape shouldn't change the buffer type
73+
i = cu([1]; device=true)
74+
j = cu([1 2]; device=true)
75+
k = i .+ j
76+
@test !is_unified(k)
7177
end

0 commit comments

Comments
 (0)