Skip to content

Commit 299c1bf

Browse files
committed
Work around type-instability in broadcast with 32-bit axes.
1 parent ac0f7d7 commit 299c1bf

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/broadcast.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,14 @@ BroadcastStyle(::CUDA.CuArrayStyle{N, B1},
1919
# allocation of output arrays
2020
# Allocate the destination array for a broadcast whose style is `CuArrayStyle{N,B}`:
# the result is a `CuArray` with element type `T`, one dimension per entry of `dims`,
# and the style's `B` parameter carried over unchanged.
function Base.similar(
    bc::Broadcasted{CuArrayStyle{N,B}}, ::Type{T}, dims
) where {T,N,B}
    return similar(CuArray{T,length(dims),B}, dims)
end
22+
23+
# Base.Broadcast can't handle Int32 axes
24+
# XXX: not using a quirk, as constprop/irinterpret is crucial here
25+
# XXX: 1.11 uses to_index instead of CartesianIndex
26+
# Translate the broadcast index `I` into an index for `arg`: the per-dimension work is
# done by `_newindex` on `axes(arg)` and the components of `I`, and the resulting tuple
# is wrapped back into a `CartesianIndex`.
Base.@propagate_inbounds function Broadcast.newindex(
    arg::AnyCuDeviceArray, I::CartesianIndex
)
    components = _newindex(axes(arg), I.I)
    return CartesianIndex(components)
end
27+
# Integer-index variant: treat `I` as a one-element index tuple, let `_newindex` resolve
# it against `axes(arg)`, and wrap the result in a `CartesianIndex`.
Base.@propagate_inbounds function Broadcast.newindex(arg::AnyCuDeviceArray, I::Integer)
    components = _newindex(axes(arg), (I,))
    return CartesianIndex(components)
end
28+
# Recursive helper behind `Broadcast.newindex`: walks the axes tuple `ax` and the index
# tuple `I` in lockstep and returns one index component per axis.
# XXX: upstream this?
Base.@propagate_inbounds function _newindex(ax::Tuple, I::Tuple)
    axis, idx = ax[1], I[1]
    # A length-1 axis always maps to its own first entry; otherwise the incoming index
    # is kept. `promote` gives both `ifelse` arms the same type.
    head = ifelse(length(axis) == 1, promote(axis[1], idx)...)
    rest = _newindex(Base.tail(ax), Base.tail(I))
    return (head, rest...)
end

# No axes left: any remaining index components are dropped.
Base.@propagate_inbounds _newindex(::Tuple{}, ::Tuple) = ()

# No index components left: each remaining axis contributes its first entry.
Base.@propagate_inbounds function _newindex(ax::Tuple, ::Tuple{})
    axis = ax[1]
    return (axis[1], _newindex(Base.tail(ax), ())...)
end

# Both tuples exhausted (also disambiguates the two methods above).
Base.@propagate_inbounds _newindex(::Tuple{}, ::Tuple{}) = ()

src/device/array.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ end
4343
# One-dimensional `CuDeviceArray`; the element type `T` and the parameter `A` stay free.
const CuDeviceVector = CuDeviceArray{T,1,A} where {T,A}
4444
# Two-dimensional `CuDeviceArray`; the element type `T` and the parameter `A` stay free.
const CuDeviceMatrix = CuDeviceArray{T,2,A} where {T,A}
4545

46+
# anything that's (secretly) backed by a CuDeviceArray
47+
# Either a plain `CuDeviceArray{T,N}` or a `WrappedArray` parameterized on `CuDeviceArray`.
# NOTE(review): `WrappedArray` is defined outside this view (presumably Adapt.jl's union
# of array wrappers) -- confirm which wrappers it covers.
const AnyCuDeviceArray{T,N} = Union{CuDeviceArray{T,N}, WrappedArray{T,N,CuDeviceArray,CuDeviceArray{T,N}}}
48+
# One-dimensional `AnyCuDeviceArray`.
const AnyCuDeviceVector{T} = AnyCuDeviceArray{T,1}
49+
# Two-dimensional `AnyCuDeviceArray`.
const AnyCuDeviceMatrix{T} = AnyCuDeviceArray{T,2}
50+
# Either an `AnyCuDeviceVector{T}` or an `AnyCuDeviceMatrix{T}`.
const AnyCuDeviceVecOrMat{T} = Union{AnyCuDeviceVector{T}, AnyCuDeviceMatrix{T}}
51+
4652

4753
## array interface
4854

0 commit comments

Comments
 (0)