@@ -207,9 +207,10 @@ function contraction!(
207
207
alpha:: Number , A:: CuArray , Ainds:: ModeType , opA:: cutensorOperator_t ,
208
208
B:: CuArray , Binds:: ModeType , opB:: cutensorOperator_t ,
209
209
beta:: Number , C:: CuArray , Cinds:: ModeType , opC:: cutensorOperator_t ,
210
- opOut:: cutensorOperator_t ,
210
+ opOut:: cutensorOperator_t ;
211
211
pref:: cutensorWorksizePreference_t = CUTENSOR_WORKSPACE_RECOMMENDED,
212
- algo:: cutensorAlgo_t = CUTENSOR_ALGO_DEFAULT, stream:: CuStream = CuDefaultStream ())
212
+ algo:: cutensorAlgo_t = CUTENSOR_ALGO_DEFAULT, stream:: CuStream = CuDefaultStream (),
213
+ compute_type:: Type = eltype (C), plan:: Union{cutensorContractionPlan_t, Nothing} = nothing )
213
214
214
215
! is_unary (opA) && throw (ArgumentError (" opA must be a unary op!" ))
215
216
! is_unary (opB) && throw (ArgumentError (" opB must be a unary op!" ))
@@ -239,25 +240,73 @@ function contraction!(
239
240
descC, modeC, alignmentRequirementC[],
240
241
descC, modeC, alignmentRequirementC[],
241
242
computeType)
242
-
243
243
find = Ref (cutensorContractionFind_t (ntuple (i-> 0 , Val (64 ))))
244
244
cutensorInitContractionFind (handle (), find, algo)
245
245
246
246
@workspace fallback= 1 << 27 size= @argout (
247
247
cutensorContractionGetWorkspace (handle (), desc, find, pref,
248
248
out (Ref {UInt64} (C_NULL )))
249
249
)[] workspace-> begin
250
- plan = Ref (cutensorContractionPlan_t (ntuple (i-> 0 , Val (640 ))))
251
- cutensorInitContractionPlan (handle (), plan, desc, find, sizeof (workspace))
252
-
253
- cutensorContraction (handle (), plan,
254
- T[alpha], A, B,
255
- T[beta], C, C,
250
+ plan_ref = Ref (cutensorContractionPlan_t (ntuple (i-> 0 , Val (640 ))))
251
+ if isnothing (plan)
252
+ cutensorInitContractionPlan (handle (), plan_ref, desc, find, sizeof (workspace))
253
+ else
254
+ plan_ref = Ref (plan)
255
+ end
256
+ cutensorContraction (handle (), plan_ref,
257
+ T[convert (T, alpha)], A, B,
258
+ T[convert (T, beta)], C, C,
256
259
workspace, sizeof (workspace), stream)
257
260
end
258
261
return C
259
262
end
260
263
264
"""
    plan_contraction(A, Ainds, opA, B, Binds, opB, C, Cinds, opC, opOut;
                     pref=CUTENSOR_WORKSPACE_RECOMMENDED,
                     algo=CUTENSOR_ALGO_DEFAULT,
                     compute_type=eltype(C))

Build a cuTENSOR contraction plan for contracting `A` with `B` into `C` and return it,
without executing the contraction.

The function performs the descriptor, algorithm-finder and plan initialization
calls only; no contraction is launched here.

# Arguments
- `A`, `B`, `C`: the operand arrays. They are used to build tensor descriptors and to
  query alignment requirements.
- `Ainds`, `Binds`, `Cinds`: mode labels of `A`, `B` and `C`; each is collected into a
  `Vector{Cint}` before being handed to cuTENSOR.
- `opA`, `opB`, `opC`: unary operators attached to the descriptors of `A`, `B` and `C`.
- `opOut`: must be a unary operator; it is validated here but not otherwise used when
  building the plan.

# Keywords
- `pref`: workspace size preference passed to `cutensorContractionGetWorkspace`.
- `algo`: algorithm passed to `cutensorInitContractionFind`.
- `compute_type`: Julia type translated with `cutensorComputeType` and used as the
  compute type of the contraction descriptor.

# Returns
The initialized `cutensorContractionPlan_t` value (the dereferenced plan, not the `Ref`).
The plan is initialized with the workspace size that `cutensorContractionGetWorkspace`
reports for `pref`.

# Throws
- `ArgumentError` if any of `opA`, `opB`, `opC` or `opOut` is not a unary operator.
"""
function plan_contraction(
        A::CuArray, Ainds::ModeType, opA::cutensorOperator_t,
        B::CuArray, Binds::ModeType, opB::cutensorOperator_t,
        C::CuArray, Cinds::ModeType, opC::cutensorOperator_t,
        opOut::cutensorOperator_t;
        pref::cutensorWorksizePreference_t=CUTENSOR_WORKSPACE_RECOMMENDED,
        algo::cutensorAlgo_t=CUTENSOR_ALGO_DEFAULT, compute_type::Type=eltype(C))

    # Validate the operators up front, before any library call is made.
    !is_unary(opA) && throw(ArgumentError(" opA must be a unary op!"))
    !is_unary(opB) && throw(ArgumentError(" opB must be a unary op!"))
    !is_unary(opC) && throw(ArgumentError(" opC must be a unary op!"))
    !is_unary(opOut) && throw(ArgumentError(" opOut must be a unary op!"))

    descA = CuTensorDescriptor(A; op = opA)
    descB = CuTensorDescriptor(B; op = opB)
    descC = CuTensorDescriptor(C; op = opC)
    # for now, D must be identical to C (and thus, descD must be identical to descC)
    computeType = cutensorComputeType(compute_type)
    modeA = collect(Cint, Ainds)
    modeB = collect(Cint, Binds)
    modeC = collect(Cint, Cinds)

    # Alignment requirements are out-parameters filled in by cuTENSOR.
    alignmentRequirementA = Ref{UInt32}(0)
    cutensorGetAlignmentRequirement(handle(), A, descA, alignmentRequirementA)
    alignmentRequirementB = Ref{UInt32}(0)
    cutensorGetAlignmentRequirement(handle(), B, descB, alignmentRequirementB)
    alignmentRequirementC = Ref{UInt32}(0)
    cutensorGetAlignmentRequirement(handle(), C, descC, alignmentRequirementC)

    # The C descriptor is passed twice: once as the input C and once as the output D.
    desc = Ref(cutensorContractionDescriptor_t(ntuple(i -> 0, Val(256))))
    cutensorInitContractionDescriptor(handle(),
                                      desc,
                                      descA, modeA, alignmentRequirementA[],
                                      descB, modeB, alignmentRequirementB[],
                                      descC, modeC, alignmentRequirementC[],
                                      descC, modeC, alignmentRequirementC[],
                                      computeType)

    find = Ref(cutensorContractionFind_t(ntuple(i -> 0, Val(64))))
    cutensorInitContractionFind(handle(), find, algo)

    # No workspace is allocated here; only its size is queried so that the plan can be
    # initialized against it.
    plan = Ref(cutensorContractionPlan_t(ntuple(i -> 0, Val(640))))
    workspace_size = Ref{UInt64}(0)
    cutensorContractionGetWorkspace(handle(), desc, find, pref, workspace_size)
    cutensorInitContractionPlan(handle(), plan, desc, find, workspace_size[])
    return plan[]
end
309
+
261
310
function reduction! (
262
311
alpha:: Number , A:: CuArray , Ainds:: ModeType , opA:: cutensorOperator_t ,
263
312
beta:: Number , C:: CuArray , Cinds:: ModeType , opC:: cutensorOperator_t ,
0 commit comments