Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit a8f9cb9

Browse files
authored
Merge pull request #605 from JuliaGPU/ksh/storeplan
RFC: provide a function to generate a contraction plan, and let contraction use it
2 parents 8dd5dfe + 6b313fe commit a8f9cb9

File tree

2 files changed

+70
-9
lines changed

2 files changed

+70
-9
lines changed

src/tensor/wrappers.jl

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,10 @@ function contraction!(
207207
alpha::Number, A::CuArray, Ainds::ModeType, opA::cutensorOperator_t,
208208
B::CuArray, Binds::ModeType, opB::cutensorOperator_t,
209209
beta::Number, C::CuArray, Cinds::ModeType, opC::cutensorOperator_t,
210-
opOut::cutensorOperator_t,
210+
opOut::cutensorOperator_t;
211211
pref::cutensorWorksizePreference_t=CUTENSOR_WORKSPACE_RECOMMENDED,
212-
algo::cutensorAlgo_t=CUTENSOR_ALGO_DEFAULT, stream::CuStream=CuDefaultStream())
212+
algo::cutensorAlgo_t=CUTENSOR_ALGO_DEFAULT, stream::CuStream=CuDefaultStream(),
213+
compute_type::Type=eltype(C), plan::Union{cutensorContractionPlan_t, Nothing}=nothing)
213214

214215
!is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
215216
!is_unary(opB) && throw(ArgumentError("opB must be a unary op!"))
@@ -239,25 +240,73 @@ function contraction!(
239240
descC, modeC, alignmentRequirementC[],
240241
descC, modeC, alignmentRequirementC[],
241242
computeType)
242-
243243
find = Ref(cutensorContractionFind_t(ntuple(i->0, Val(64))))
244244
cutensorInitContractionFind(handle(), find, algo)
245245

246246
@workspace fallback=1<<27 size=@argout(
247247
cutensorContractionGetWorkspace(handle(), desc, find, pref,
248248
out(Ref{UInt64}(C_NULL)))
249249
)[] workspace->begin
250-
plan = Ref(cutensorContractionPlan_t(ntuple(i->0, Val(640))))
251-
cutensorInitContractionPlan(handle(), plan, desc, find, sizeof(workspace))
252-
253-
cutensorContraction(handle(), plan,
254-
T[alpha], A, B,
255-
T[beta], C, C,
250+
plan_ref = Ref(cutensorContractionPlan_t(ntuple(i->0, Val(640))))
251+
if isnothing(plan)
252+
cutensorInitContractionPlan(handle(), plan_ref, desc, find, sizeof(workspace))
253+
else
254+
plan_ref = Ref(plan)
255+
end
256+
cutensorContraction(handle(), plan_ref,
257+
T[convert(T, alpha)], A, B,
258+
T[convert(T, beta)], C, C,
256259
workspace, sizeof(workspace), stream)
257260
end
258261
return C
259262
end
260263

264+
function plan_contraction(
265+
A::CuArray, Ainds::ModeType, opA::cutensorOperator_t,
266+
B::CuArray, Binds::ModeType, opB::cutensorOperator_t,
267+
C::CuArray, Cinds::ModeType, opC::cutensorOperator_t,
268+
opOut::cutensorOperator_t;
269+
pref::cutensorWorksizePreference_t=CUTENSOR_WORKSPACE_RECOMMENDED,
270+
algo::cutensorAlgo_t=CUTENSOR_ALGO_DEFAULT, compute_type::Type=eltype(C))
271+
272+
!is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
273+
!is_unary(opB) && throw(ArgumentError("opB must be a unary op!"))
274+
!is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
275+
!is_unary(opOut) && throw(ArgumentError("opOut must be a unary op!"))
276+
descA = CuTensorDescriptor(A; op = opA)
277+
descB = CuTensorDescriptor(B; op = opB)
278+
descC = CuTensorDescriptor(C; op = opC)
279+
# for now, D must be identical to C (and thus, descD must be identical to descC)
280+
computeType = cutensorComputeType(compute_type)
281+
T = sizeof(compute_type) < sizeof(eltype(C)) ? eltype(C) : compute_type
282+
modeA = collect(Cint, Ainds)
283+
modeB = collect(Cint, Binds)
284+
modeC = collect(Cint, Cinds)
285+
286+
alignmentRequirementA = Ref{UInt32}(C_NULL)
287+
cutensorGetAlignmentRequirement(handle(), A, descA, alignmentRequirementA)
288+
alignmentRequirementB = Ref{UInt32}(C_NULL)
289+
cutensorGetAlignmentRequirement(handle(), B, descB, alignmentRequirementB)
290+
alignmentRequirementC = Ref{UInt32}(C_NULL)
291+
cutensorGetAlignmentRequirement(handle(), C, descC, alignmentRequirementC)
292+
desc = Ref(cutensorContractionDescriptor_t(ntuple(i->0, Val(256))))
293+
cutensorInitContractionDescriptor(handle(),
294+
desc,
295+
descA, modeA, alignmentRequirementA[],
296+
descB, modeB, alignmentRequirementB[],
297+
descC, modeC, alignmentRequirementC[],
298+
descC, modeC, alignmentRequirementC[],
299+
computeType)
300+
301+
find = Ref(cutensorContractionFind_t(ntuple(i->0, Val(64))))
302+
cutensorInitContractionFind(handle(), find, algo)
303+
plan = Ref(cutensorContractionPlan_t(ntuple(i->0, Val(640))))
304+
workspace_size = Ref{UInt64}(C_NULL)
305+
cutensorContractionGetWorkspace(handle(), desc, find, pref, workspace_size)
306+
cutensorInitContractionPlan(handle(), plan, desc, find, workspace_size[])
307+
return plan[]
308+
end
309+
261310
function reduction!(
262311
alpha::Number, A::CuArray, Ainds::ModeType, opA::cutensorOperator_t,
263312
beta::Number, C::CuArray, Cinds::ModeType, opC::cutensorOperator_t,

test/tensor.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,18 @@ end
452452
mC = reshape(permutedims(C, ipC), (loA, loB))
453453
@test mC mA * mB
454454

455+
# simple case with plan storage
456+
opA = CUTENSOR.CUTENSOR_OP_IDENTITY
457+
opB = CUTENSOR.CUTENSOR_OP_IDENTITY
458+
opC = CUTENSOR.CUTENSOR_OP_IDENTITY
459+
opOut = CUTENSOR.CUTENSOR_OP_IDENTITY
460+
plan = CUTENSOR.plan_contraction(dA, indsA, opA, dB, indsB, opB, dC, indsC, opC, opOut)
461+
dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB,
462+
0, dC, indsC, opC, opOut, plan=plan)
463+
C = collect(dC)
464+
mC = reshape(permutedims(C, ipC), (loA, loB))
465+
@test mC mA * mB
466+
455467
# with non-trivial α
456468
α = rand(eltyC)
457469
dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB,

0 commit comments

Comments
 (0)