@@ -546,7 +546,50 @@ Adapt.adapt_storage(::CuArrayAdaptor{B}, xs::AbstractArray{T,N}) where {T<:Compl
546
546
# Storage rule for half-precision arrays (`Float16` / `BFloat16`): the element type `T`
# is kept as-is. An array for which `isbits(xs)` holds is returned unchanged; any other
# array is passed to the `CuArray{T,N,B}` constructor, with `B` taken from the adaptor.
function Adapt.adapt_storage(
    ::CuArrayAdaptor{B}, xs::AbstractArray{T,N}
) where {T<:Union{Float16,BFloat16},N,B}
    if isbits(xs)
        return xs
    end
    return CuArray{T,N,B}(xs)
end
548
548
549
"""
    cu(A; unified=false)

Adapt `A` for the GPU using an opinionated adaptor that may change the element type `T`
of the arrays it converts:

* `T<:AbstractFloat` becomes `Float32` for performance reasons, except that `Float16`
  and `BFloat16` are left unchanged.
* `T<:Complex{<:AbstractFloat}` becomes `ComplexF32`.
* Any other `isbitstype(T)` is kept, giving a `CuArray{T}`.

`CuArray(A)`, by contrast, never changes the element type.

The conversion goes through Adapt.jl, so it acts inside some wrapper structs.

# Keywords

- `unified::Bool=false`: when `true` the adaptor is parameterised with
  `Mem.UnifiedBuffer`, otherwise with `Mem.DeviceBuffer`.

# Examples

```
julia> cu(ones(3)')
1×3 adjoint(::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}) with eltype Float32:
 1.0  1.0  1.0

julia> cu(zeros(1, 3); unified=true)
1×3 CuArray{Float32, 2, CUDA.Mem.UnifiedBuffer}:
 0.0  0.0  0.0

julia> cu(1:3)
1:3

julia> CuArray(ones(3)')  # ignores Adjoint, preserves Float64
1×3 CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}:
 1.0  1.0  1.0

julia> adapt(CuArray, ones(3)')  # this restores Adjoint wrapper
1×3 adjoint(::CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}) with eltype Float64:
 1.0  1.0  1.0

julia> CuArray(1:3)
3-element CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}:
 1
 2
 3
```
"""
@inline function cu(xs; unified::Bool=false)
    # The buffer type is the adaptor's only type parameter; pick it from the keyword.
    buffer = unified ? Mem.UnifiedBuffer : Mem.DeviceBuffer
    return adapt(CuArrayAdaptor{buffer}(), xs)
end
592
+
550
593
# Enables the literal syntax `cu[a, b, c]`: the indices are gathered into a vector
# literal and handed to `CuArray`.
function Base.getindex(::typeof(cu), xs...)
    elements = [xs...]
    return CuArray(elements)
end
551
594
552
595
0 commit comments