Skip to content

Commit 2ecd820

Browse files
authored
Merge pull request #2228 from JuliaGPU/tb/cutensor
cuTensor fixes
2 parents 2b97ab2 + 533e9c3 commit 2ecd820

File tree

4 files changed

+52
-21
lines changed

4 files changed

+52
-21
lines changed

lib/cutensor/src/operations.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function elementwise_trinary!(
2424
opABC::cutensorOperator_t;
2525
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
2626
algo::cutensorAlgo_t=ALGO_DEFAULT,
27-
compute_type::Union{Type, cutensorComputeDescriptor_t, Nothing}=nothing,
27+
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing,
2828
plan::Union{CuTensorPlan, Nothing}=nothing)
2929

3030
actual_compute_type = if compute_type === nothing
@@ -66,7 +66,7 @@ function plan_elementwise_trinary(
6666
jit::cutensorJitMode_t=JIT_MODE_NONE,
6767
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
6868
algo::cutensorAlgo_t=ALGO_DEFAULT,
69-
compute_type::Union{Type, cutensorComputeDescriptor_t, Nothing}=nothing)
69+
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing)
7070
!is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
7171
!is_unary(opB) && throw(ArgumentError("opB must be a unary op!"))
7272
!is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
@@ -96,7 +96,7 @@ function plan_elementwise_trinary(
9696
descC, modeC, opC,
9797
descD, modeD,
9898
opAB, opABC,
99-
compute_type)
99+
actual_compute_type)
100100

101101
plan_pref = Ref{cutensorPlanPreference_t}()
102102
cutensorCreatePlanPreference(handle(), plan_pref, algo, jit)
@@ -112,7 +112,7 @@ function elementwise_binary!(
112112
@nospecialize(D::DenseCuArray), Dinds::ModeType, opAC::cutensorOperator_t;
113113
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
114114
algo::cutensorAlgo_t=ALGO_DEFAULT,
115-
compute_type::Union{Type, cutensorComputeDescriptor_t, Nothing}=nothing,
115+
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing,
116116
plan::Union{CuTensorPlan, Nothing}=nothing)
117117

118118
actual_compute_type = if compute_type === nothing
@@ -150,7 +150,7 @@ function plan_elementwise_binary(
150150
jit::cutensorJitMode_t=JIT_MODE_NONE,
151151
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
152152
algo::cutensorAlgo_t=ALGO_DEFAULT,
153-
compute_type::Union{Type, cutensorComputeDescriptor_t, Nothing}=eltype(C))
153+
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=eltype(C))
154154
!is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
155155
!is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
156156
!is_binary(opAC) && throw(ArgumentError("opAC must be a binary op!"))
@@ -189,7 +189,7 @@ function permutation!(
189189
@nospecialize(B::DenseCuArray), Binds::ModeType;
190190
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
191191
algo::cutensorAlgo_t=ALGO_DEFAULT,
192-
compute_type::Union{Type, cutensorComputeDescriptor_t, Nothing}=nothing,
192+
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing,
193193
plan::Union{CuTensorPlan, Nothing}=nothing)
194194

195195
actual_compute_type = if compute_type === nothing
@@ -224,8 +224,7 @@ function plan_permutation(
224224
jit::cutensorJitMode_t=JIT_MODE_NONE,
225225
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
226226
algo::cutensorAlgo_t=ALGO_DEFAULT,
227-
compute_type::Union{Type, cutensorComputeDescriptor_t, Nothing}=nothing)
228-
#!is_unary(opPsi) && throw(ArgumentError("opPsi must be a unary op!"))
227+
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing)
229228
descA = CuTensorDescriptor(A)
230229
descB = CuTensorDescriptor(B)
231230

@@ -260,7 +259,7 @@ function contraction!(
260259
jit::cutensorJitMode_t=JIT_MODE_NONE,
261260
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
262261
algo::cutensorAlgo_t=ALGO_DEFAULT,
263-
compute_type::Union{Type, cutensorComputeDescriptor_t, Nothing}=nothing,
262+
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing,
264263
plan::Union{CuTensorPlan, Nothing}=nothing)
265264

266265
actual_compute_type = if compute_type === nothing
@@ -269,7 +268,6 @@ function contraction!(
269268
compute_type
270269
end
271270

272-
# XXX: save these as parameters of the plan?
273271
actual_plan = if plan === nothing
274272
plan_contraction(A, Ainds, opA, B, Binds, opB, C, Cinds, opC, opOut;
275273
jit, workspace, algo, compute_type=actual_compute_type)
@@ -298,7 +296,7 @@ function plan_contraction(
298296
jit::cutensorJitMode_t=JIT_MODE_NONE,
299297
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
300298
algo::cutensorAlgo_t=ALGO_DEFAULT,
301-
compute_type::Union{Type, cutensorComputeDescriptor_t, Nothing}=nothing)
299+
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing)
302300
!is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
303301
!is_unary(opB) && throw(ArgumentError("opB must be a unary op!"))
304302
!is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
@@ -340,7 +338,7 @@ function reduction!(
340338
opReduce::cutensorOperator_t;
341339
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
342340
algo::cutensorAlgo_t=ALGO_DEFAULT,
343-
compute_type::Union{Type, cutensorComputeDescriptor_t, Nothing}=nothing,
341+
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing,
344342
plan::Union{CuTensorPlan, Nothing}=nothing)
345343

346344
actual_compute_type = if compute_type === nothing
@@ -375,7 +373,7 @@ function plan_reduction(
375373
jit::cutensorJitMode_t=JIT_MODE_NONE,
376374
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
377375
algo::cutensorAlgo_t=ALGO_DEFAULT,
378-
compute_type::Union{Type, cutensorComputeDescriptor_t, Nothing}=nothing)
376+
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing)
379377
!is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
380378
!is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
381379
!is_binary(opReduce) && throw(ArgumentError("opReduce must be a binary op!"))

lib/cutensor/src/types.jl

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
## data types
22

3+
export cutensorComputeDescriptorEnum
4+
5+
@enum cutensorComputeDescriptorEnum begin
6+
COMPUTE_DESC_16F = 1
7+
COMPUTE_DESC_32F = 2
8+
COMPUTE_DESC_TF32 = 3
9+
COMPUTE_DESC_3XTF32 = 4
10+
COMPUTE_DESC_64F = 5
11+
end
12+
313
const contraction_compute_types = Dict(
414
# typeA, typeB, typeC => typeCompute
515
(Float16, Float16, Float16) => Float32,
@@ -55,18 +65,38 @@ const reduction_compute_types = Dict(
5565
(ComplexF32, ComplexF32) => Float32,
5666
(ComplexF64, ComplexF64) => Float64)
5767

58-
function Base.cconvert(::Type{cutensorComputeDescriptor_t}, T::DataType)
59-
if T == Float16 || T == ComplexF16
60-
return CUTENSOR_COMPUTE_DESC_16F()
68+
# we have our own enum to represent cutensorComputeDescriptor_t values
69+
function Base.convert(::Type{cutensorComputeDescriptorEnum}, T::DataType)
70+
if T == Float16
71+
return COMPUTE_DESC_16F
6172
elseif T == Float32 || T == ComplexF32
62-
return CUTENSOR_COMPUTE_DESC_32F()
73+
return COMPUTE_DESC_32F
6374
elseif T == Float64 || T == ComplexF64
75+
return COMPUTE_DESC_64F
76+
else
77+
throw(ArgumentError("cutensorComputeDescriptor equivalent for input type $T does not exist!"))
78+
end
79+
end
80+
Base.cconvert(::Type{cutensorComputeDescriptor_t}, T::DataType) =
81+
Base.cconvert(cutensorComputeDescriptor_t, convert(cutensorComputeDescriptorEnum, T))
82+
83+
function Base.cconvert(::Type{cutensorComputeDescriptor_t}, T::cutensorComputeDescriptorEnum)
84+
if T == COMPUTE_DESC_16F
85+
return CUTENSOR_COMPUTE_DESC_16F()
86+
elseif T == COMPUTE_DESC_32F
87+
return CUTENSOR_COMPUTE_DESC_32F()
88+
elseif T == COMPUTE_DESC_TF32
89+
return CUTENSOR_COMPUTE_DESC_TF32()
90+
elseif T == COMPUTE_DESC_3XTF32
91+
return CUTENSOR_COMPUTE_DESC_3XTF32()
92+
elseif T == COMPUTE_DESC_64F
6493
return CUTENSOR_COMPUTE_DESC_64F()
6594
else
66-
throw(ArgumentError("cutensorComputeType equivalent for input type $T does not exist!"))
95+
throw(ArgumentError("cutensorComputeDescriptor equivalent for input enum value $T does not exist!"))
6796
end
6897
end
6998

99+
70100
function Base.convert(::Type{cutensorDataType_t}, T::DataType)
71101
if T == Float16
72102
return CUTENSOR_R_16F

lib/cutensor/test/contractions.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ eltypes = [(Float32, Float32, Float32, Float32),
3131
ipB = invperm(pB)
3232
pC = randperm(NoA + NoB)
3333
ipC = invperm(pC)
34-
compute_rtol = (real(eltyCompute) == Float16 || real(eltyC) == Float16) ? 1e-2 : (real(eltyCompute) == Float32 ? 1e-4 : 1e-6)
34+
compute_rtol = (eltyCompute == Float16 || eltyC == Float16) ? 1e-2 : (eltyCompute == Float32 ? 1e-4 : 1e-6)
3535
dimsA = [dimsoA; dimsc][pA]
3636
indsA = [indsoA; indsc][pA]
3737
dimsB = [dimsc; dimsoB][pB]
@@ -73,9 +73,10 @@ eltypes = [(Float32, Float32, Float32, Float32),
7373
opB = cuTENSOR.OP_IDENTITY
7474
opC = cuTENSOR.OP_IDENTITY
7575
opOut = cuTENSOR.OP_IDENTITY
76-
plan = cuTENSOR.plan_contraction(dA, indsA, opA, dB, indsB, opB, dC, indsC, opC, opOut; compute_type=eltyCompute)
76+
eltypComputeEnum = convert(cutensorComputeDescriptorEnum, eltyCompute)
77+
plan = cuTENSOR.plan_contraction(dA, indsA, opA, dB, indsB, opB, dC, indsC, opC, opOut; compute_type=eltypComputeEnum)
7778
dC = contraction!(1, dA, indsA, opA, dB, indsB, opB,
78-
0, dC, indsC, opC, opOut, plan=plan, compute_type=eltyCompute)
79+
0, dC, indsC, opC, opOut, plan=plan, compute_type=eltypComputeEnum)
7980
C = collect(dC)
8081
mC = reshape(permutedims(C, ipC), (loA, loB))
8182
@test mC ≈ mA * mB rtol=compute_rtol

lib/cutensor/test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ include("contractions.jl")
1818
include("reductions.jl")
1919

2020
# we should have some kernels in the cache after this
21+
if CUDA.runtime_version() >= v"11.8" && capability(device()) >= v"8.0"
2122
@testset "kernel cache" begin
2223
mktempdir() do dir
2324
cd(dir) do
@@ -27,5 +28,6 @@ include("reductions.jl")
2728
end
2829
end
2930
end
31+
end
3032

3133
end

0 commit comments

Comments (0)