@@ -24,7 +24,7 @@ function elementwise_trinary!(
24
24
opABC:: cutensorOperator_t ;
25
25
workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
26
26
algo:: cutensorAlgo_t = ALGO_DEFAULT,
27
- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing ,
27
+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing ,
28
28
plan:: Union{CuTensorPlan, Nothing} = nothing )
29
29
30
30
actual_compute_type = if compute_type === nothing
@@ -66,7 +66,7 @@ function plan_elementwise_trinary(
66
66
jit:: cutensorJitMode_t = JIT_MODE_NONE,
67
67
workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
68
68
algo:: cutensorAlgo_t = ALGO_DEFAULT,
69
- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing )
69
+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing )
70
70
! is_unary (opA) && throw (ArgumentError (" opA must be a unary op!" ))
71
71
! is_unary (opB) && throw (ArgumentError (" opB must be a unary op!" ))
72
72
! is_unary (opC) && throw (ArgumentError (" opC must be a unary op!" ))
@@ -96,7 +96,7 @@ function plan_elementwise_trinary(
96
96
descC, modeC, opC,
97
97
descD, modeD,
98
98
opAB, opABC,
99
- compute_type )
99
+ actual_compute_type )
100
100
101
101
plan_pref = Ref {cutensorPlanPreference_t} ()
102
102
cutensorCreatePlanPreference (handle (), plan_pref, algo, jit)
@@ -112,7 +112,7 @@ function elementwise_binary!(
112
112
@nospecialize (D:: DenseCuArray ), Dinds:: ModeType , opAC:: cutensorOperator_t ;
113
113
workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
114
114
algo:: cutensorAlgo_t = ALGO_DEFAULT,
115
- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing ,
115
+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing ,
116
116
plan:: Union{CuTensorPlan, Nothing} = nothing )
117
117
118
118
actual_compute_type = if compute_type === nothing
@@ -150,7 +150,7 @@ function plan_elementwise_binary(
150
150
jit:: cutensorJitMode_t = JIT_MODE_NONE,
151
151
workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
152
152
algo:: cutensorAlgo_t = ALGO_DEFAULT,
153
- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = eltype (C))
153
+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = eltype (C))
154
154
! is_unary (opA) && throw (ArgumentError (" opA must be a unary op!" ))
155
155
! is_unary (opC) && throw (ArgumentError (" opC must be a unary op!" ))
156
156
! is_binary (opAC) && throw (ArgumentError (" opAC must be a binary op!" ))
@@ -189,7 +189,7 @@ function permutation!(
189
189
@nospecialize (B:: DenseCuArray ), Binds:: ModeType ;
190
190
workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
191
191
algo:: cutensorAlgo_t = ALGO_DEFAULT,
192
- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing ,
192
+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing ,
193
193
plan:: Union{CuTensorPlan, Nothing} = nothing )
194
194
195
195
actual_compute_type = if compute_type === nothing
@@ -224,8 +224,7 @@ function plan_permutation(
224
224
jit:: cutensorJitMode_t = JIT_MODE_NONE,
225
225
workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
226
226
algo:: cutensorAlgo_t = ALGO_DEFAULT,
227
- compute_type:: Union{Type, cutensorComputeDescriptor_t, Nothing} = nothing )
228
- # !is_unary(opPsi) && throw(ArgumentError("opPsi must be a unary op!"))
227
+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum, Nothing} = nothing )
229
228
descA = CuTensorDescriptor (A)
230
229
descB = CuTensorDescriptor (B)
231
230
@@ -260,7 +259,7 @@ function contraction!(
260
259
jit:: cutensorJitMode_t = JIT_MODE_NONE,
261
260
workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
262
261
algo:: cutensorAlgo_t = ALGO_DEFAULT,
263
- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing ,
262
+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing ,
264
263
plan:: Union{CuTensorPlan, Nothing} = nothing )
265
264
266
265
actual_compute_type = if compute_type === nothing
@@ -269,7 +268,6 @@ function contraction!(
269
268
compute_type
270
269
end
271
270
272
- # XXX : save these as parameters of the plan?
273
271
actual_plan = if plan === nothing
274
272
plan_contraction (A, Ainds, opA, B, Binds, opB, C, Cinds, opC, opOut;
275
273
jit, workspace, algo, compute_type= actual_compute_type)
@@ -298,7 +296,7 @@ function plan_contraction(
298
296
jit:: cutensorJitMode_t = JIT_MODE_NONE,
299
297
workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
300
298
algo:: cutensorAlgo_t = ALGO_DEFAULT,
301
- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing )
299
+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing )
302
300
! is_unary (opA) && throw (ArgumentError (" opA must be a unary op!" ))
303
301
! is_unary (opB) && throw (ArgumentError (" opB must be a unary op!" ))
304
302
! is_unary (opC) && throw (ArgumentError (" opC must be a unary op!" ))
@@ -340,7 +338,7 @@ function reduction!(
340
338
opReduce:: cutensorOperator_t ;
341
339
workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
342
340
algo:: cutensorAlgo_t = ALGO_DEFAULT,
343
- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing ,
341
+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing ,
344
342
plan:: Union{CuTensorPlan, Nothing} = nothing )
345
343
346
344
actual_compute_type = if compute_type === nothing
@@ -375,7 +373,7 @@ function plan_reduction(
375
373
jit:: cutensorJitMode_t = JIT_MODE_NONE,
376
374
workspace:: cutensorWorksizePreference_t = WORKSPACE_DEFAULT,
377
375
algo:: cutensorAlgo_t = ALGO_DEFAULT,
378
- compute_type:: Union{Type, cutensorComputeDescriptor_t , Nothing} = nothing )
376
+ compute_type:: Union{DataType, cutensorComputeDescriptorEnum , Nothing} = nothing )
379
377
! is_unary (opA) && throw (ArgumentError (" opA must be a unary op!" ))
380
378
! is_unary (opC) && throw (ArgumentError (" opC must be a unary op!" ))
381
379
! is_binary (opReduce) && throw (ArgumentError (" opReduce must be a binary op!" ))
0 commit comments