1
- import Base . Broadcast : Broadcasted, Extruded, BroadcastStyle, ArrayStyle
1
+ # broadcasting
2
2
3
- BroadcastStyle ( :: Type{<:CuArray} ) = ArrayStyle {CuArray} ()
3
+ using Base . Broadcast : BroadcastStyle, Broadcasted
4
4
5
- function Base. similar (bc:: Broadcasted{ArrayStyle{CuArray}} , :: Type{T} ) where T
5
+ struct CuArrayStyle{N} <: AbstractGPUArrayStyle{N} end
6
+ CuArrayStyle (:: Val{N} ) where N = CuArrayStyle {N} ()
7
+ CuArrayStyle {M} (:: Val{N} ) where {N,M} = CuArrayStyle {N} ()
8
+
9
+ BroadcastStyle (:: Type{<:CuArray{T,N}} ) where {T,N} = CuArrayStyle {N} ()
10
+
11
+ Base. similar (bc:: Broadcasted{CuArrayStyle{N}} , :: Type{T} ) where {N,T} =
6
12
similar (CuArray{T}, axes (bc))
7
- end
8
13
9
- function Base. similar (bc:: Broadcasted{ArrayStyle{CuArray }} , :: Type{T} , dims... ) where {T}
10
- similar ( CuArray{T}, dims... )
11
- end
14
+ Base. similar (bc:: Broadcasted{CuArrayStyle{N }} , :: Type{T} , dims... ) where {N,T} =
15
+ CuArray {T} (undef , dims... )
16
+
12
17
13
- # replace base functions with libdevice alternatives
14
- # TODO : do this with Cassette.jl
18
+ # # replace base functions with libdevice alternatives
15
19
16
20
cufunc (f) = f
17
21
cufunc (:: Type{T} ) where T = (x... ) -> T (x... ) # broadcasting type ctors isn't GPU compatible
18
22
19
- Broadcast. broadcasted (:: ArrayStyle{CuArray } , f, args... ) =
20
- Broadcasted {ArrayStyle{CuArray }} (cufunc (f), args, nothing )
23
+ Broadcast. broadcasted (:: CuArrayStyle{N } , f, args... ) where {N} =
24
+ Broadcasted {CuArrayStyle{N }} (cufunc (f), args, nothing )
21
25
22
- libdevice = :[
26
+ const libdevice = :[
23
27
cos, cospi, sin, sinpi, tan, acos, asin, atan,
24
28
cosh, sinh, tanh, acosh, asinh, atanh,
25
29
log, log10, log1p, log2, logb, ilogb,
@@ -40,7 +44,8 @@ for f in libdevice
40
44
@eval cufunc (:: typeof (Base.$ f)) = CUDAnative.$ f
41
45
end
42
46
43
- # broadcast ^
47
+ # broadcast ^
48
+
44
49
culiteral_pow (:: typeof (^ ), x:: T , :: Val{0} ) where {T<: Real } = one (x)
45
50
culiteral_pow (:: typeof (^ ), x:: T , :: Val{1} ) where {T<: Real } = x
46
51
culiteral_pow (:: typeof (^ ), x:: T , :: Val{2} ) where {T<: Real } = x * x
0 commit comments