1
- export elementwise_binary!, elementwise_trinary!,
2
- permutation!, contraction!, reduction!
3
-
4
1
# Type accepted for the `Ainds`/`Binds`/`Cinds` arguments of the functions in this file:
# any vector whose elements are `Char`s or `Integer`s.
const ModeType = AbstractVector{<: Union{Char, Integer} }
5
2
6
3
# remove the CUTENSOR_ prefix from some common enums,
@@ -13,7 +10,7 @@ const ModeType = AbstractVector{<:Union{Char, Integer}}
13
10
"""
    is_unary(op::cutensorOperator_t)

Return `true` when `op` is one of `OP_IDENTITY`, `OP_SQRT`, `OP_RELU`, `OP_CONJ` or
`OP_RCP`, and `false` otherwise.
"""
function is_unary(op::cutensorOperator_t)
    return op in (OP_IDENTITY, OP_SQRT, OP_RELU, OP_CONJ, OP_RCP)
end
14
11
"""
    is_binary(op::cutensorOperator_t)

Return `true` when `op` is one of `OP_ADD`, `OP_MUL`, `OP_MAX` or `OP_MIN`, and `false`
otherwise.
"""
function is_binary(op::cutensorOperator_t)
    return op in (OP_ADD, OP_MUL, OP_MAX, OP_MIN)
end
15
12
16
- function elementwise_trinary ! (
13
+ function elementwise_trinary_execute ! (
17
14
@nospecialize (alpha:: Number ),
18
15
@nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
19
16
@nospecialize (beta:: Number ),
@@ -43,12 +40,7 @@ function elementwise_trinary!(
43
40
plan
44
41
end
45
42
46
- scalar_type = actual_plan. scalar_type
47
- cutensorElementwiseTrinaryExecute (handle (), actual_plan,
48
- Ref {scalar_type} (alpha), A,
49
- Ref {scalar_type} (beta), B,
50
- Ref {scalar_type} (gamma), C, D,
51
- stream ())
43
+ elementwise_trinary_execute! (actual_plan, alpha, A, beta, B, gamma, C, D)
52
44
53
45
if plan === nothing
54
46
CUDA. unsafe_free! (actual_plan)
@@ -57,6 +49,23 @@ function elementwise_trinary!(
57
49
return D
58
50
end
59
51
52
"""
    elementwise_trinary_execute!(plan::CuTensorPlan, alpha, A, beta, B, gamma, C, D)

Call `cutensorElementwiseTrinaryExecute` with an already-constructed `plan`.

The scalars `alpha`, `beta` and `gamma` are wrapped in `Ref`s of `plan.scalar_type` and
passed along with the arrays `A`, `B`, `C` and `D`, using the results of `handle()` and
`stream()`. Returns `D`.
"""
function elementwise_trinary_execute!(plan::CuTensorPlan,
                                      @nospecialize(alpha::Number),
                                      @nospecialize(A::DenseCuArray),
                                      @nospecialize(beta::Number),
                                      @nospecialize(B::DenseCuArray),
                                      @nospecialize(gamma::Number),
                                      @nospecialize(C::DenseCuArray),
                                      @nospecialize(D::DenseCuArray))
    T = plan.scalar_type
    # Evaluate the call arguments in the same left-to-right order as a direct call would.
    hdl = handle()
    alpha_ref = Ref{T}(alpha)
    beta_ref = Ref{T}(beta)
    gamma_ref = Ref{T}(gamma)
    cutensorElementwiseTrinaryExecute(hdl, plan, alpha_ref, A, beta_ref, B,
                                      gamma_ref, C, D, stream())
    return D
end
68
+
60
69
function plan_elementwise_trinary (
61
70
@nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
62
71
@nospecialize (B:: DenseCuArray ), Binds:: ModeType , opB:: cutensorOperator_t ,
@@ -104,7 +113,7 @@ function plan_elementwise_trinary(
104
113
CuTensorPlan (desc[], plan_pref[]; workspacePref= workspace)
105
114
end
106
115
107
- function elementwise_binary ! (
116
+ function elementwise_binary_execute ! (
108
117
@nospecialize (alpha:: Number ),
109
118
@nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
110
119
@nospecialize (gamma:: Number ),
@@ -130,11 +139,7 @@ function elementwise_binary!(
130
139
plan
131
140
end
132
141
133
- scalar_type = actual_plan. scalar_type
134
- cutensorElementwiseBinaryExecute (handle (), actual_plan,
135
- Ref {scalar_type} (alpha), A,
136
- Ref {scalar_type} (gamma), C, D,
137
- stream ())
142
+ elementwise_binary_execute! (actual_plan, alpha, A, gamma, C, D)
138
143
139
144
if plan === nothing
140
145
CUDA. unsafe_free! (actual_plan)
@@ -143,6 +148,20 @@ function elementwise_binary!(
143
148
return D
144
149
end
145
150
151
"""
    elementwise_binary_execute!(plan::CuTensorPlan, alpha, A, gamma, C, D)

Call `cutensorElementwiseBinaryExecute` with an already-constructed `plan`.

The scalars `alpha` and `gamma` are wrapped in `Ref`s of `plan.scalar_type` and passed
along with the arrays `A`, `C` and `D`, using the results of `handle()` and `stream()`.
Returns `D`.
"""
function elementwise_binary_execute!(plan::CuTensorPlan,
                                     @nospecialize(alpha::Number),
                                     @nospecialize(A::DenseCuArray),
                                     @nospecialize(gamma::Number),
                                     @nospecialize(C::DenseCuArray),
                                     @nospecialize(D::DenseCuArray))
    T = plan.scalar_type
    # Evaluate the call arguments in the same left-to-right order as a direct call would.
    hdl = handle()
    alpha_ref = Ref{T}(alpha)
    gamma_ref = Ref{T}(gamma)
    cutensorElementwiseBinaryExecute(hdl, plan, alpha_ref, A, gamma_ref, C, D, stream())
    return D
end
164
+
146
165
function plan_elementwise_binary (
147
166
@nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
148
167
@nospecialize (C:: DenseCuArray ), Cinds:: ModeType , opC:: cutensorOperator_t ,
@@ -183,7 +202,7 @@ function plan_elementwise_binary(
183
202
CuTensorPlan (desc[], plan_pref[]; workspacePref= workspace)
184
203
end
185
204
186
- function permutation ! (
205
+ function permute ! (
187
206
@nospecialize (alpha:: Number ),
188
207
@nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
189
208
@nospecialize (B:: DenseCuArray ), Binds:: ModeType ;
@@ -206,10 +225,7 @@ function permutation!(
206
225
plan
207
226
end
208
227
209
- scalar_type = actual_plan. scalar_type
210
- cutensorPermute (handle (), actual_plan,
211
- Ref {scalar_type} (alpha), A, B,
212
- stream ())
228
+ permute! (actual_plan, alpha, A, B)
213
229
214
230
if plan === nothing
215
231
CUDA. unsafe_free! (actual_plan)
@@ -218,6 +234,17 @@ function permutation!(
218
234
return B
219
235
end
220
236
237
"""
    permute!(plan::CuTensorPlan, alpha, A, B)

Call `cutensorPermute` with an already-constructed `plan`.

The scalar `alpha` is wrapped in a `Ref` of `plan.scalar_type` and passed along with the
arrays `A` and `B`, using the results of `handle()` and `stream()`. Returns `B`.
"""
function permute!(plan::CuTensorPlan,
                  @nospecialize(alpha::Number),
                  @nospecialize(A::DenseCuArray),
                  @nospecialize(B::DenseCuArray))
    T = plan.scalar_type
    # Evaluate the call arguments in the same left-to-right order as a direct call would.
    hdl = handle()
    alpha_ref = Ref{T}(alpha)
    cutensorPermute(hdl, plan, alpha_ref, A, B, stream())
    return B
end
247
+
221
248
function plan_permutation (
222
249
@nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
223
250
@nospecialize (B:: DenseCuArray ), Binds:: ModeType ;
@@ -249,7 +276,7 @@ function plan_permutation(
249
276
CuTensorPlan (desc[], plan_pref[]; workspacePref= workspace)
250
277
end
251
278
252
- function contraction ! (
279
+ function contract ! (
253
280
@nospecialize (alpha:: Number ),
254
281
@nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
255
282
@nospecialize (B:: DenseCuArray ), Binds:: ModeType , opB:: cutensorOperator_t ,
@@ -275,11 +302,7 @@ function contraction!(
275
302
plan
276
303
end
277
304
278
- scalar_type = actual_plan. scalar_type
279
- cutensorContract (handle (), actual_plan,
280
- Ref {scalar_type} (alpha), A, B,
281
- Ref {scalar_type} (beta), C, C,
282
- actual_plan. workspace, sizeof (actual_plan. workspace), stream ())
305
+ contract! (actual_plan, alpha, A, B, beta, C)
283
306
284
307
if plan === nothing
285
308
CUDA. unsafe_free! (actual_plan)
@@ -288,6 +311,20 @@ function contraction!(
288
311
return C
289
312
end
290
313
314
"""
    contract!(plan::CuTensorPlan, alpha, A, B, beta, C)

Call `cutensorContract` with an already-constructed `plan`.

The scalars `alpha` and `beta` are wrapped in `Ref`s of `plan.scalar_type`. `C` is passed
twice in a row to `cutensorContract`, followed by `plan.workspace`, its `sizeof`, and the
result of `stream()`. Returns `C`.
"""
function contract!(plan::CuTensorPlan,
                   @nospecialize(alpha::Number),
                   @nospecialize(A::DenseCuArray),
                   @nospecialize(B::DenseCuArray),
                   @nospecialize(beta::Number),
                   @nospecialize(C::DenseCuArray))
    T = plan.scalar_type
    # Evaluate the call arguments in the same left-to-right order as a direct call would.
    hdl = handle()
    alpha_ref = Ref{T}(alpha)
    beta_ref = Ref{T}(beta)
    workspace = plan.workspace
    cutensorContract(hdl, plan, alpha_ref, A, B, beta_ref, C, C,
                     workspace, sizeof(workspace), stream())
    return C
end
327
+
291
328
function plan_contraction (
292
329
@nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
293
330
@nospecialize (B:: DenseCuArray ), Binds:: ModeType , opB:: cutensorOperator_t ,
@@ -330,7 +367,7 @@ function plan_contraction(
330
367
CuTensorPlan (desc[], plan_pref[]; workspacePref= workspace)
331
368
end
332
369
333
- function reduction ! (
370
+ function reduce ! (
334
371
@nospecialize (alpha:: Number ),
335
372
@nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
336
373
@nospecialize (beta:: Number ),
@@ -353,11 +390,7 @@ function reduction!(
353
390
plan
354
391
end
355
392
356
- scalar_type = actual_plan. scalar_type
357
- cutensorReduce (handle (), actual_plan,
358
- Ref {scalar_type} (alpha), A,
359
- Ref {scalar_type} (beta), C, C,
360
- actual_plan. workspace, sizeof (actual_plan. workspace), stream ())
393
+ reduce! (actual_plan, alpha, A, beta, C)
361
394
362
395
if plan === nothing
363
396
CUDA. unsafe_free! (actual_plan)
@@ -366,6 +399,19 @@ function reduction!(
366
399
return C
367
400
end
368
401
402
"""
    reduce!(plan::CuTensorPlan, alpha, A, beta, C)

Call `cutensorReduce` with an already-constructed `plan`.

The scalars `alpha` and `beta` are wrapped in `Ref`s of `plan.scalar_type`. `C` is passed
twice in a row to `cutensorReduce`, followed by `plan.workspace`, its `sizeof`, and the
result of `stream()`. Returns `C`.
"""
function reduce!(plan::CuTensorPlan,
                 @nospecialize(alpha::Number),
                 @nospecialize(A::DenseCuArray),
                 @nospecialize(beta::Number),
                 @nospecialize(C::DenseCuArray))
    T = plan.scalar_type
    # Evaluate the call arguments in the same left-to-right order as a direct call would.
    hdl = handle()
    alpha_ref = Ref{T}(alpha)
    beta_ref = Ref{T}(beta)
    workspace = plan.workspace
    cutensorReduce(hdl, plan, alpha_ref, A, beta_ref, C, C,
                   workspace, sizeof(workspace), stream())
    return C
end
414
+
369
415
function plan_reduction (
370
416
@nospecialize (A:: DenseCuArray ), Ainds:: ModeType , opA:: cutensorOperator_t ,
371
417
@nospecialize (C:: DenseCuArray ), Cinds:: ModeType , opC:: cutensorOperator_t ,
0 commit comments