Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 9080140

Browse files
committed
Adapt to broadcast changes in GPUArrays.
1 parent c108dde commit 9080140

File tree

3 files changed

+24
-27
lines changed

3 files changed

+24
-27
lines changed

Manifest.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ version = "5.0.1"
4040

4141
[[CUDAnative]]
4242
deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"]
43-
git-tree-sha1 = "e0c2805c9a7d338823c0d8f574242e284410fa61"
43+
git-tree-sha1 = "f38cf81a0c6c08cb532b7324eb85f3e2d438bed7"
44+
repo-rev = "0f65bc6ebc9c490c068d65436dc42b93c7650fce"
45+
repo-url = "https://github.com/JuliaGPU/CUDAnative.jl.git"
4446
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
45-
version = "2.9.1"
47+
version = "2.10.0"
4648

4749
[[DataStructures]]
4850
deps = ["InteractiveUtils", "OrderedCollections"]
@@ -60,8 +62,8 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
6062

6163
[[GPUArrays]]
6264
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
63-
git-tree-sha1 = "64bf1d9f7634b35e1383955213cafbda3a9ac184"
64-
repo-rev = "a89fa3c84e97488f88449036d6a860db689669f2"
65+
git-tree-sha1 = "1bc994886f5c6fa82e60da92debfc9be79400993"
66+
repo-rev = "cc7bb3030ed18c8e545539c540b675171ff1b73e"
6567
repo-url = "https://github.com/JuliaGPU/GPUArrays.jl.git"
6668
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
6769
version = "2.0.1"

src/broadcast.jl

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
1-
import Base.Broadcast: Broadcasted, Extruded, BroadcastStyle, ArrayStyle
1+
# broadcasting
22

3-
BroadcastStyle(::Type{<:CuArray}) = ArrayStyle{CuArray}()
3+
using Base.Broadcast: BroadcastStyle, Broadcasted
44

5-
function Base.similar(bc::Broadcasted{ArrayStyle{CuArray}}, ::Type{T}) where T
5+
struct CuArrayStyle{N} <: AbstractGPUArrayStyle{N} end
6+
CuArrayStyle(::Val{N}) where N = CuArrayStyle{N}()
7+
CuArrayStyle{M}(::Val{N}) where {N,M} = CuArrayStyle{N}()
8+
9+
BroadcastStyle(::Type{<:CuArray{T,N}}) where {T,N} = CuArrayStyle{N}()
10+
11+
Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}) where {N,T} =
612
similar(CuArray{T}, axes(bc))
7-
end
813

9-
function Base.similar(bc::Broadcasted{ArrayStyle{CuArray}}, ::Type{T}, dims...) where {T}
10-
similar(CuArray{T}, dims...)
11-
end
14+
Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}, dims...) where {N,T} =
15+
CuArray{T}(undef, dims...)
16+
1217

13-
# replace base functions with libdevice alternatives
14-
# TODO: do this with Cassette.jl
18+
## replace base functions with libdevice alternatives
1519

1620
cufunc(f) = f
1721
cufunc(::Type{T}) where T = (x...) -> T(x...) # broadcasting type ctors isn't GPU compatible
1822

19-
Broadcast.broadcasted(::ArrayStyle{CuArray}, f, args...) =
20-
Broadcasted{ArrayStyle{CuArray}}(cufunc(f), args, nothing)
23+
Broadcast.broadcasted(::CuArrayStyle{N}, f, args...) where {N} =
24+
Broadcasted{CuArrayStyle{N}}(cufunc(f), args, nothing)
2125

22-
libdevice = :[
26+
const libdevice = :[
2327
cos, cospi, sin, sinpi, tan, acos, asin, atan,
2428
cosh, sinh, tanh, acosh, asinh, atanh,
2529
log, log10, log1p, log2, logb, ilogb,
@@ -40,7 +44,8 @@ for f in libdevice
4044
@eval cufunc(::typeof(Base.$f)) = CUDAnative.$f
4145
end
4246

43-
#broadcast ^
47+
# broadcast ^
48+
4449
culiteral_pow(::typeof(^), x::T, ::Val{0}) where {T<:Real} = one(x)
4550
culiteral_pow(::typeof(^), x::T, ::Val{1}) where {T<:Real} = x
4651
culiteral_pow(::typeof(^), x::T, ::Val{2}) where {T<:Real} = x * x

test/base.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -170,16 +170,6 @@ end
170170
@test testf(x -> reduce(|, x, init=false), map(t -> t > 0.5, cu(rand(2, 3))))
171171
end
172172

173-
@testset "0D" begin
174-
x = CuArray{Float64}(undef)
175-
x .= 1
176-
@test collect(x)[] == 1
177-
# broken test that throws
178-
# https://github.com/JuliaGPU/GPUArrays.jl/issues/204
179-
@test_throws ErrorException x /= 2
180-
#@test collect(x)[] == 0.5
181-
end
182-
183173
@testset "SubArray" begin
184174
@test testf(rand(5)) do x
185175
y = x[2:4]

0 commit comments

Comments
 (0)