Commit 92d86f1

Authored by maleadt and ChrisPsa
cuTENSOR plan handling changes. (#2234)
Co-authored-by: Christos Psarras <cpsarras@nvidia.com>
1 parent 6923658 commit 92d86f1

File tree: 8 files changed, +135 -82 lines changed

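In short: the high-level wrappers are renamed and no longer exported (elementwise_binary! and elementwise_trinary! become elementwise_binary_execute! and elementwise_trinary_execute!, permutation! becomes permute!, contraction! becomes contract!, reduction! becomes reduce!), and each operation gains a method that executes a precomputed CuTensorPlan directly, so planning can be separated from execution and plans can be reused.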

lib/cutensor/src/interfaces.jl

Lines changed: 5 additions & 5 deletions
@@ -6,15 +6,15 @@ function Base.:(+)(A::CuTensor, B::CuTensor)
     α = convert(eltype(A), 1.0)
     γ = convert(eltype(B), 1.0)
     C = similar(B)
-    elementwise_binary!(α, A.data, A.inds, CUTENSOR_OP_IDENTITY, γ, B.data, B.inds, CUTENSOR_OP_IDENTITY, C.data, C.inds, CUTENSOR_OP_ADD)
+    elementwise_binary_execute!(α, A.data, A.inds, CUTENSOR_OP_IDENTITY, γ, B.data, B.inds, CUTENSOR_OP_IDENTITY, C.data, C.inds, CUTENSOR_OP_ADD)
     C
 end
 
 function Base.:(-)(A::CuTensor, B::CuTensor)
     α = convert(eltype(A), 1.0)
     γ = convert(eltype(B), -1.0)
     C = similar(B)
-    elementwise_binary!(α, A.data, A.inds, CUTENSOR_OP_IDENTITY, γ, B.data, B.inds, CUTENSOR_OP_IDENTITY, C.data, C.inds, CUTENSOR_OP_ADD)
+    elementwise_binary_execute!(α, A.data, A.inds, CUTENSOR_OP_IDENTITY, γ, B.data, B.inds, CUTENSOR_OP_IDENTITY, C.data, C.inds, CUTENSOR_OP_ADD)
     C
 end
 
@@ -36,16 +36,16 @@ end
 using LinearAlgebra
 
 function LinearAlgebra.axpy!(a, X::CuTensor, Y::CuTensor)
-    elementwise_binary!(a, X.data, X.inds, CUTENSOR_OP_IDENTITY, one(eltype(Y)), Y.data, Y.inds, CUTENSOR_OP_IDENTITY, Y.data, Y.inds, CUTENSOR_OP_ADD)
+    elementwise_binary_execute!(a, X.data, X.inds, CUTENSOR_OP_IDENTITY, one(eltype(Y)), Y.data, Y.inds, CUTENSOR_OP_IDENTITY, Y.data, Y.inds, CUTENSOR_OP_ADD)
     return Y
 end
 
 function LinearAlgebra.axpby!(a, X::CuTensor, b, Y::CuTensor)
-    elementwise_binary!(a, X.data, X.inds, CUTENSOR_OP_IDENTITY, b, Y.data, Y.inds, CUTENSOR_OP_IDENTITY, Y.data, Y.inds, CUTENSOR_OP_ADD)
+    elementwise_binary_execute!(a, X.data, X.inds, CUTENSOR_OP_IDENTITY, b, Y.data, Y.inds, CUTENSOR_OP_IDENTITY, Y.data, Y.inds, CUTENSOR_OP_ADD)
     return Y
 end
 
 function LinearAlgebra.mul!(C::CuTensor, A::CuTensor, B::CuTensor)
-    contraction!(one(eltype(C)), A.data, A.inds, CUTENSOR_OP_IDENTITY, B.data, B.inds, CUTENSOR_OP_IDENTITY, zero(eltype(C)), C.data, C.inds, CUTENSOR_OP_IDENTITY, CUTENSOR_OP_IDENTITY)
+    contract!(one(eltype(C)), A.data, A.inds, CUTENSOR_OP_IDENTITY, B.data, B.inds, CUTENSOR_OP_IDENTITY, zero(eltype(C)), C.data, C.inds, CUTENSOR_OP_IDENTITY, CUTENSOR_OP_IDENTITY)
     return C
 end
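For orientation, these renamed entry points keep the existing CuTensor-level behavior. A rough usage sketch follows; it assumes the CuTensor(data, inds) wrapper with Char mode labels that these methods operate on, and is illustrative rather than part of this diff:

using CUDA, cuTENSOR, LinearAlgebra

# Hypothetical tensors; CuTensor pairs a CuArray with mode labels.
A = CuTensor(CUDA.rand(Float32, 4, 8), ['i', 'j'])
B = CuTensor(CUDA.rand(Float32, 4, 8), ['i', 'j'])
S = A + B            # now lowers to elementwise_binary_execute!
axpy!(2f0, A, B)     # B += 2A, same code path

C = CuTensor(CUDA.zeros(Float32, 4, 4), ['i', 'k'])
D = CuTensor(CUDA.rand(Float32, 8, 4), ['j', 'k'])
mul!(C, A, D)        # contraction over the shared 'j' mode, now via contract!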

lib/cutensor/src/operations.jl

Lines changed: 79 additions & 33 deletions
@@ -1,6 +1,3 @@
-export elementwise_binary!, elementwise_trinary!,
-       permutation!, contraction!, reduction!
-
 const ModeType = AbstractVector{<:Union{Char, Integer}}
 
 # remove the CUTENSOR_ prefix from some common enums,
@@ -13,7 +10,7 @@ const ModeType = AbstractVector{<:Union{Char, Integer}}
 is_unary(op::cutensorOperator_t) = (op ∈ (OP_IDENTITY, OP_SQRT, OP_RELU, OP_CONJ, OP_RCP))
 is_binary(op::cutensorOperator_t) = (op ∈ (OP_ADD, OP_MUL, OP_MAX, OP_MIN))
 
-function elementwise_trinary!(
+function elementwise_trinary_execute!(
     @nospecialize(alpha::Number),
     @nospecialize(A::DenseCuArray), Ainds::ModeType, opA::cutensorOperator_t,
     @nospecialize(beta::Number),
@@ -43,12 +40,7 @@ function elementwise_trinary!(
         plan
     end
 
-    scalar_type = actual_plan.scalar_type
-    cutensorElementwiseTrinaryExecute(handle(), actual_plan,
-                                      Ref{scalar_type}(alpha), A,
-                                      Ref{scalar_type}(beta), B,
-                                      Ref{scalar_type}(gamma), C, D,
-                                      stream())
+    elementwise_trinary_execute!(actual_plan, alpha, A, beta, B, gamma, C, D)
 
     if plan === nothing
         CUDA.unsafe_free!(actual_plan)
@@ -57,6 +49,23 @@ function elementwise_trinary!(
     return D
 end
 
+function elementwise_trinary_execute!(plan::CuTensorPlan,
+                                      @nospecialize(alpha::Number),
+                                      @nospecialize(A::DenseCuArray),
+                                      @nospecialize(beta::Number),
+                                      @nospecialize(B::DenseCuArray),
+                                      @nospecialize(gamma::Number),
+                                      @nospecialize(C::DenseCuArray),
+                                      @nospecialize(D::DenseCuArray))
+    scalar_type = plan.scalar_type
+    cutensorElementwiseTrinaryExecute(handle(), plan,
+                                      Ref{scalar_type}(alpha), A,
+                                      Ref{scalar_type}(beta), B,
+                                      Ref{scalar_type}(gamma), C, D,
+                                      stream())
+    return D
+end
+
 function plan_elementwise_trinary(
     @nospecialize(A::DenseCuArray), Ainds::ModeType, opA::cutensorOperator_t,
     @nospecialize(B::DenseCuArray), Binds::ModeType, opB::cutensorOperator_t,
@@ -104,7 +113,7 @@ function plan_elementwise_trinary(
     CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace)
 end
 
-function elementwise_binary!(
+function elementwise_binary_execute!(
     @nospecialize(alpha::Number),
     @nospecialize(A::DenseCuArray), Ainds::ModeType, opA::cutensorOperator_t,
     @nospecialize(gamma::Number),
@@ -130,11 +139,7 @@ function elementwise_binary!(
         plan
     end
 
-    scalar_type = actual_plan.scalar_type
-    cutensorElementwiseBinaryExecute(handle(), actual_plan,
-                                     Ref{scalar_type}(alpha), A,
-                                     Ref{scalar_type}(gamma), C, D,
-                                     stream())
+    elementwise_binary_execute!(actual_plan, alpha, A, gamma, C, D)
 
     if plan === nothing
         CUDA.unsafe_free!(actual_plan)
@@ -143,6 +148,20 @@ function elementwise_binary!(
     return D
 end
 
+function elementwise_binary_execute!(plan::CuTensorPlan,
+                                     @nospecialize(alpha::Number),
+                                     @nospecialize(A::DenseCuArray),
+                                     @nospecialize(gamma::Number),
+                                     @nospecialize(C::DenseCuArray),
+                                     @nospecialize(D::DenseCuArray))
+    scalar_type = plan.scalar_type
+    cutensorElementwiseBinaryExecute(handle(), plan,
+                                     Ref{scalar_type}(alpha), A,
+                                     Ref{scalar_type}(gamma), C, D,
+                                     stream())
+    return D
+end
+
 function plan_elementwise_binary(
     @nospecialize(A::DenseCuArray), Ainds::ModeType, opA::cutensorOperator_t,
     @nospecialize(C::DenseCuArray), Cinds::ModeType, opC::cutensorOperator_t,
@@ -183,7 +202,7 @@ function plan_elementwise_binary(
     CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace)
 end
 
-function permutation!(
+function permute!(
     @nospecialize(alpha::Number),
     @nospecialize(A::DenseCuArray), Ainds::ModeType, opA::cutensorOperator_t,
     @nospecialize(B::DenseCuArray), Binds::ModeType;
@@ -206,10 +225,7 @@ function permutation!(
         plan
     end
 
-    scalar_type = actual_plan.scalar_type
-    cutensorPermute(handle(), actual_plan,
-                    Ref{scalar_type}(alpha), A, B,
-                    stream())
+    permute!(actual_plan, alpha, A, B)
 
     if plan === nothing
         CUDA.unsafe_free!(actual_plan)
@@ -218,6 +234,17 @@ function permutation!(
     return B
 end
 
+function permute!(plan::CuTensorPlan,
+                  @nospecialize(alpha::Number),
+                  @nospecialize(A::DenseCuArray),
+                  @nospecialize(B::DenseCuArray))
+    scalar_type = plan.scalar_type
+    cutensorPermute(handle(), plan,
+                    Ref{scalar_type}(alpha), A, B,
+                    stream())
+    return B
+end
+
 function plan_permutation(
     @nospecialize(A::DenseCuArray), Ainds::ModeType, opA::cutensorOperator_t,
     @nospecialize(B::DenseCuArray), Binds::ModeType;
@@ -249,7 +276,7 @@ function plan_permutation(
     CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace)
 end
 
-function contraction!(
+function contract!(
     @nospecialize(alpha::Number),
     @nospecialize(A::DenseCuArray), Ainds::ModeType, opA::cutensorOperator_t,
     @nospecialize(B::DenseCuArray), Binds::ModeType, opB::cutensorOperator_t,
@@ -275,11 +302,7 @@ function contraction!(
         plan
     end
 
-    scalar_type = actual_plan.scalar_type
-    cutensorContract(handle(), actual_plan,
-                     Ref{scalar_type}(alpha), A, B,
-                     Ref{scalar_type}(beta), C, C,
-                     actual_plan.workspace, sizeof(actual_plan.workspace), stream())
+    contract!(actual_plan, alpha, A, B, beta, C)
 
     if plan === nothing
         CUDA.unsafe_free!(actual_plan)
@@ -288,6 +311,20 @@ function contraction!(
     return C
 end
 
+function contract!(plan::CuTensorPlan,
+                   @nospecialize(alpha::Number),
+                   @nospecialize(A::DenseCuArray),
+                   @nospecialize(B::DenseCuArray),
+                   @nospecialize(beta::Number),
+                   @nospecialize(C::DenseCuArray))
+    scalar_type = plan.scalar_type
+    cutensorContract(handle(), plan,
+                     Ref{scalar_type}(alpha), A, B,
+                     Ref{scalar_type}(beta), C, C,
+                     plan.workspace, sizeof(plan.workspace), stream())
+    return C
+end
+
 function plan_contraction(
     @nospecialize(A::DenseCuArray), Ainds::ModeType, opA::cutensorOperator_t,
     @nospecialize(B::DenseCuArray), Binds::ModeType, opB::cutensorOperator_t,
@@ -330,7 +367,7 @@ function plan_contraction(
     CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace)
 end
 
-function reduction!(
+function reduce!(
     @nospecialize(alpha::Number),
     @nospecialize(A::DenseCuArray), Ainds::ModeType, opA::cutensorOperator_t,
     @nospecialize(beta::Number),
@@ -353,11 +390,7 @@ function reduction!(
         plan
     end
 
-    scalar_type = actual_plan.scalar_type
-    cutensorReduce(handle(), actual_plan,
-                   Ref{scalar_type}(alpha), A,
-                   Ref{scalar_type}(beta), C, C,
-                   actual_plan.workspace, sizeof(actual_plan.workspace), stream())
+    reduce!(actual_plan, alpha, A, beta, C)
 
     if plan === nothing
         CUDA.unsafe_free!(actual_plan)
@@ -366,6 +399,19 @@ function reduction!(
     return C
 end
 
+function reduce!(plan::CuTensorPlan,
+                 @nospecialize(alpha::Number),
+                 @nospecialize(A::DenseCuArray),
+                 @nospecialize(beta::Number),
+                 @nospecialize(C::DenseCuArray))
+    scalar_type = plan.scalar_type
+    cutensorReduce(handle(), plan,
+                   Ref{scalar_type}(alpha), A,
+                   Ref{scalar_type}(beta), C, C,
+                   plan.workspace, sizeof(plan.workspace), stream())
+    return C
+end
+
 function plan_reduction(
     @nospecialize(A::DenseCuArray), Ainds::ModeType, opA::cutensorOperator_t,
     @nospecialize(C::DenseCuArray), Cinds::ModeType, opC::cutensorOperator_t,
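The split above means a plan can be built once and then executed many times, amortizing the planning cost. A minimal sketch of that pattern, based on the plan_contraction and contract! methods in this file (the sizes, mode labels, and operator choices are illustrative, not part of this diff):

using CUDA, cuTENSOR
using cuTENSOR: plan_contraction, contract!   # no longer exported

dA = CUDA.rand(Float32, 32, 16)               # modes ('i', 'j')
dB = CUDA.rand(Float32, 16, 8)                # modes ('j', 'k')
dC = CUDA.zeros(Float32, 32, 8)               # modes ('i', 'k')
indsA, indsB, indsC = ['i', 'j'], ['j', 'k'], ['i', 'k']
op = cuTENSOR.OP_IDENTITY

# Plan once ...
plan = plan_contraction(dA, indsA, op, dB, indsB, op, dC, indsC, op, op)

# ... and execute it repeatedly with the new plan-accepting method.
for _ in 1:10
    contract!(plan, 1, dA, dB, 0, dC)         # dC = 1*(A contracted with B) + 0*dC
end

# The plan-accepting methods do not free the plan; release it when done,
# as the one-shot wrappers do for their internally created plans.
CUDA.unsafe_free!(plan)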

lib/cutensor/src/types.jl

Lines changed: 0 additions & 2 deletions
@@ -1,7 +1,5 @@
 ## data types
 
-export cutensorComputeDescriptorEnum
-
 @enum cutensorComputeDescriptorEnum begin
     COMPUTE_DESC_16F = 1
     COMPUTE_DESC_32F = 2
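With the export gone, callers qualify the enum through the module, as the updated tests below do. A small illustrative snippet:

using cuTENSOR

# Map a Julia element type to the corresponding compute descriptor,
# which can then be passed as a plan's compute_type.
desc = convert(cuTENSOR.cutensorComputeDescriptorEnum, Float32)
# e.g. cuTENSOR.plan_contraction(...; compute_type=desc)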

lib/cutensor/test/contractions.jl

Lines changed: 13 additions & 12 deletions
@@ -1,5 +1,7 @@
11
@testset "contractions" begin
22

3+
using cuTENSOR: contract!, plan_contraction
4+
35
using LinearAlgebra
46

57
eltypes = [(Float32, Float32, Float32, Float32),
@@ -52,7 +54,7 @@ eltypes = [(Float32, Float32, Float32, Float32),
5254
opB = cuTENSOR.OP_IDENTITY
5355
opC = cuTENSOR.OP_IDENTITY
5456
opOut = cuTENSOR.OP_IDENTITY
55-
dC = contraction!(1, dA, indsA, opA, dB, indsB, opB, 0, dC, indsC, opC, opOut, compute_type=eltyCompute)
57+
dC = contract!(1, dA, indsA, opA, dB, indsB, opB, 0, dC, indsC, opC, opOut, compute_type=eltyCompute)
5658
C = collect(dC)
5759
mC = reshape(permutedims(C, ipC), (loA, loB))
5860
@test mC mA * mB rtol=compute_rtol
@@ -63,7 +65,7 @@ eltypes = [(Float32, Float32, Float32, Float32),
6365
opC = cuTENSOR.OP_IDENTITY
6466
opOut = cuTENSOR.OP_IDENTITY
6567
plan = cuTENSOR.plan_contraction(dA, indsA, opA, dB, indsB, opB, dC, indsC, opC, opOut)
66-
dC = contraction!(1, dA, indsA, opA, dB, indsB, opB, 0, dC, indsC, opC, opOut; plan)
68+
dC = cuTENSOR.contract!(plan, 1, dA, dB, 0, dC)
6769
C = collect(dC)
6870
mC = reshape(permutedims(C, ipC), (loA, loB))
6971
@test mC mA * mB
@@ -73,10 +75,9 @@ eltypes = [(Float32, Float32, Float32, Float32),
7375
opB = cuTENSOR.OP_IDENTITY
7476
opC = cuTENSOR.OP_IDENTITY
7577
opOut = cuTENSOR.OP_IDENTITY
76-
eltypComputeEnum = convert(cutensorComputeDescriptorEnum, eltyCompute)
78+
eltypComputeEnum = convert(cuTENSOR.cutensorComputeDescriptorEnum, eltyCompute)
7779
plan = cuTENSOR.plan_contraction(dA, indsA, opA, dB, indsB, opB, dC, indsC, opC, opOut; compute_type=eltypComputeEnum)
78-
dC = contraction!(1, dA, indsA, opA, dB, indsB, opB,
79-
0, dC, indsC, opC, opOut, plan=plan, compute_type=eltypComputeEnum)
80+
dC = cuTENSOR.contract!(plan, 1, dA, dB, 0, dC)
8081
C = collect(dC)
8182
mC = reshape(permutedims(C, ipC), (loA, loB))
8283
@test mC mA * mB rtol=compute_rtol
@@ -87,14 +88,14 @@ eltypes = [(Float32, Float32, Float32, Float32),
8788
opC = cuTENSOR.OP_IDENTITY
8889
opOut = cuTENSOR.OP_IDENTITY
8990
plan = cuTENSOR.plan_contraction(dA, indsA, opA, dB, indsB, opB, dC, indsC, opC, opOut; jit=cuTENSOR.JIT_MODE_DEFAULT)
90-
dC = contraction!(1, dA, indsA, opA, dB, indsB, opB, 0, dC, indsC, opC, opOut, plan=plan)
91+
dC = cuTENSOR.contract!(plan, 1, dA, dB, 0, dC)
9192
C = collect(dC)
9293
mC = reshape(permutedims(C, ipC), (loA, loB))
9394
@test mC mA * mB
9495

9596
# with non-trivial α
9697
α = rand(eltyCompute)
97-
dC = contraction!(α, dA, indsA, opA, dB, indsB, opB, zero(eltyCompute), dC, indsC, opC, opOut; compute_type=eltyCompute)
98+
dC = contract!(α, dA, indsA, opA, dB, indsB, opB, zero(eltyCompute), dC, indsC, opC, opOut; compute_type=eltyCompute)
9899
C = collect(dC)
99100
mC = reshape(permutedims(C, ipC), (loA, loB))
100101
@test mC α * mA * mB rtol=compute_rtol
@@ -105,7 +106,7 @@ eltypes = [(Float32, Float32, Float32, Float32),
105106
α = rand(eltyCompute)
106107
β = rand(eltyCompute)
107108
copyto!(dC, C)
108-
dD = contraction!(α, dA, indsA, opA, dB, indsB, opB, β, dC, indsC, opC, opOut; compute_type=eltyCompute)
109+
dD = contract!(α, dA, indsA, opA, dB, indsB, opB, β, dC, indsC, opC, opOut; compute_type=eltyCompute)
109110
D = collect(dD)
110111
mC = reshape(permutedims(C, ipC), (loA, loB))
111112
mD = reshape(permutedims(D, ipC), (loA, loB))
@@ -133,7 +134,7 @@ eltypes = [(Float32, Float32, Float32, Float32),
133134
opA = cuTENSOR.OP_CONJ
134135
opB = cuTENSOR.OP_IDENTITY
135136
opOut = cuTENSOR.OP_IDENTITY
136-
dC = contraction!(complex(1.0, 0.0), dA, indsA, opA, dB, indsB, opB,
137+
dC = contract!(complex(1.0, 0.0), dA, indsA, opA, dB, indsB, opB,
137138
0, dC, indsC, opC, opOut; compute_type=eltyCompute)
138139
C = collect(dC)
139140
mC = reshape(permutedims(C, ipC), (loA, loB))
@@ -143,8 +144,8 @@ eltypes = [(Float32, Float32, Float32, Float32),
143144
opA = cuTENSOR.OP_IDENTITY
144145
opB = cuTENSOR.OP_CONJ
145146
opOut = cuTENSOR.OP_IDENTITY
146-
dC = contraction!(complex(1.0, 0.0), dA, indsA, opA, dB, indsB, opB,
147-
complex(0.0, 0.0), dC, indsC, opC, opOut; compute_type=eltyCompute)
147+
dC = contract!(complex(1.0, 0.0), dA, indsA, opA, dB, indsB, opB,
148+
complex(0.0, 0.0), dC, indsC, opC, opOut; compute_type=eltyCompute)
148149
C = collect(dC)
149150
mC = reshape(permutedims(C, ipC), (loA, loB))
150151
@test mC mA*conj(mB) rtol=compute_rtol
@@ -153,7 +154,7 @@ eltypes = [(Float32, Float32, Float32, Float32),
153154
opA = cuTENSOR.OP_CONJ
154155
opB = cuTENSOR.OP_CONJ
155156
opOut = cuTENSOR.OP_IDENTITY
156-
dC = contraction!(one(eltyCompute), dA, indsA, opA, dB, indsB, opB,
157+
dC = contract!(one(eltyCompute), dA, indsA, opA, dB, indsB, opB,
157158
zero(eltyCompute), dC, indsC, opC, opOut; compute_type=eltyCompute)
158159
C = collect(dC)
159160
mC = reshape(permutedims(C, ipC), (loA, loB))
