Skip to content

Commit 0a46a13

Browse files
authored
Docstring for cu (#1493)
1 parent 3f10276 commit 0a46a13

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

src/array.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,50 @@ Adapt.adapt_storage(::CuArrayAdaptor{B}, xs::AbstractArray{T,N}) where {T<:Compl
546546
Adapt.adapt_storage(::CuArrayAdaptor{B}, xs::AbstractArray{T,N}) where {T<:Union{Float16,BFloat16},N,B} =
547547
isbits(xs) ? xs : CuArray{T,N,B}(xs)
548548

549+
"""
550+
cu(A; unified=false)
551+
552+
Opinionated GPU array adaptor, which may alter the element type `T` of arrays:
553+
* For `T<:AbstractFloat`, it makes a `CuArray{Float32}` for performance reasons.
554+
(Except that `Float16` and `BFloat16` element types are not changed.)
555+
* For `T<:Complex{<:AbstractFloat}` it makes a `CuArray{ComplexF32}`.
556+
* For other `isbitstype(T)`, it makes a `CuArray{T}`.
557+
558+
By contrast, `CuArray(A)` never changes the element type.
559+
560+
Uses Adapt.jl to act inside some wrapper structs.
561+
562+
# Examples
563+
564+
```
565+
julia> cu(ones(3)')
566+
1×3 adjoint(::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}) with eltype Float32:
567+
1.0 1.0 1.0
568+
569+
julia> cu(zeros(1, 3); unified=true)
570+
1×3 CuArray{Float32, 2, CUDA.Mem.UnifiedBuffer}:
571+
0.0 0.0 0.0
572+
573+
julia> cu(1:3)
574+
1:3
575+
576+
julia> CuArray(ones(3)') # ignores Adjoint, preserves Float64
577+
1×3 CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}:
578+
1.0 1.0 1.0
579+
580+
julia> adapt(CuArray, ones(3)') # this restores Adjoint wrapper
581+
1×3 adjoint(::CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}) with eltype Float64:
582+
1.0 1.0 1.0
583+
584+
julia> CuArray(1:3)
585+
3-element CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}:
586+
1
587+
2
588+
3
589+
```
590+
"""
549591
@inline cu(xs; unified::Bool=false) = adapt(CuArrayAdaptor{unified ? Mem.UnifiedBuffer : Mem.DeviceBuffer}(), xs)
592+
550593
Base.getindex(::typeof(cu), xs...) = CuArray([xs...])
551594

552595

0 commit comments

Comments
 (0)