 einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
 
 
-def tensor_dot_tucker(tensor, tucker, modes):
-    modes_tensor, modes_tucker = _validate_contraction_modes(tl.shape(tensor), tucker.tensor_shape, modes)
+# def tensor_dot_tucker(tensor, tucker, modes):
+#     modes_tensor, modes_tucker = _validate_contraction_modes(tl.shape(tensor), tucker.tensor_shape, modes)
+#     input_order = tensor.ndim
+#     weight_order = tucker.order
+
+#     sorted_modes_tucker = sorted(modes_tucker, reverse=True)
+#     sorted_modes_tensor = sorted(modes_tensor, reverse=True)
+
+#     # Symbol for dimensionality of the core
+#     rank_sym = [einsum_symbols[i] for i in range(weight_order)]
+
+#     # Symbols for tucker weight size
+#     tucker_sym = [einsum_symbols[i+weight_order] for i in range(weight_order)]
+
+#     # Symbolds for input tensor
+#     tensor_sym = [einsum_symbols[i+2*weight_order] for i in range(tensor.ndim)]
+
+#     # Output: input + weights symbols after removing contraction symbols
+#     output_sym = tensor_sym + tucker_sym
+#     for m in sorted_modes_tucker:
+#         output_sym.pop(m+input_order)
+#     for m in sorted_modes_tensor:
+#         output_sym.pop(m)
+#     for i, e in enumerate(modes_tensor):
+#         tensor_sym[e] = tucker_sym[modes_tucker[i]]
+
+#     # Form the actual equation: tensor, core, factors -> output
+#     eq = ''.join(tensor_sym)
+#     eq += ',' + ''.join(rank_sym)
+#     eq += ',' + ','.join(f'{s}{r}' for s,r in zip(tucker_sym,rank_sym))
+#     eq += '->' + ''.join(output_sym)
+
+#     return tl.einsum(eq, tensor, tucker.core, *tucker.factors)
+
+
+def tensor_dot_tucker(tensor, tucker, modes, batched_modes):
+    modes_tensor, modes_tucker = _validate_contraction_modes(
+        tl.shape(tensor), tucker.tensor_shape, modes)
     input_order = tensor.ndim
     weight_order = tucker.order
+
+    batched_modes_tensor, batched_modes_tucker = _validate_contraction_modes(
+        tl.shape(tensor), tucker.tensor_shape, batched_modes)
 
-    sorted_modes_tucker = sorted(modes_tucker, reverse=True)
-    sorted_modes_tensor = sorted(modes_tensor, reverse=True)
+    sorted_modes_tucker = sorted(modes_tucker + batched_modes_tucker, reverse=True)
+    sorted_modes_tensor = sorted(modes_tensor + batched_modes_tensor, reverse=True)
 
     # Symbol for dimensionality of the core
     rank_sym = [einsum_symbols[i] for i in range(weight_order)]
 
     # Symbols for tucker weight size
     tucker_sym = [einsum_symbols[i+weight_order] for i in range(weight_order)]
 
-    # Symbolds for input tensor
+    # Symbols for input tensor
     tensor_sym = [einsum_symbols[i+2*weight_order] for i in range(tensor.ndim)]
 
     # Output: input + weights symbols after removing contraction symbols
     output_sym = tensor_sym + tucker_sym
     for m in sorted_modes_tucker:
-        output_sym.pop(m+input_order)
+        if m in modes_tucker:  # not batched
+            output_sym.pop(m+input_order)
     for m in sorted_modes_tensor:
+        # It's batched, always remove
         output_sym.pop(m)
+
+    # print(tensor_sym, tucker_sym, modes_tensor, batched_modes_tensor)
     for i, e in enumerate(modes_tensor):
         tensor_sym[e] = tucker_sym[modes_tucker[i]]
-
+    for i, e in enumerate(batched_modes_tensor):
+        tensor_sym[e] = tucker_sym[batched_modes_tucker[i]]
+
     # Form the actual equation: tensor, core, factors -> output
     eq = ''.join(tensor_sym)
     eq += ',' + ''.join(rank_sym)
@@ -42,6 +87,7 @@ def tensor_dot_tucker(tensor, tucker, modes):
     return tl.einsum(eq, tensor, tucker.core, *tucker.factors)
 
 
+
 def tensor_dot_cp(tensor, cp, modes):
     """Contracts a to CP tensors in factorized form
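For reference, the contraction that the new batched tensor_dot_tucker assembles can be checked on a small concrete case with plain NumPy. This sketch is not part of the commit: the shapes, rank values, and subscript letters are illustrative only, whereas the diff builds the equivalent einsum string generically from modes and batched_modes.

import numpy as np

# Hypothetical Tucker-factorized weight of shape (batch, in_dim, out_dim):
# a core of shape (r0, r1, r2) plus one factor per mode mapping full dim -> rank.
batch, in_dim, out_dim = 4, 5, 6
r0, r1, r2 = 2, 3, 3
core = np.random.rand(r0, r1, r2)
factors = [np.random.rand(batch, r0),      # mode 0 (batched)
           np.random.rand(in_dim, r1),     # mode 1 (contracted)
           np.random.rand(out_dim, r2)]    # mode 2 (kept in the output)

x = np.random.rand(batch, in_dim)          # input tensor

# Reference result: reconstruct the dense weight, then contract mode 1
# while batching over mode 0.
weight = np.einsum('pqr,bp,iq,or->bio', core, *factors)
expected = np.einsum('bi,bio->bo', x, weight)

# Same result computed directly in factorized form, mirroring the equation
# the diff builds: input symbols, core (rank) symbols, then one
# (full-dim, rank) pair per factor; the batched symbol 'b' is shared by the
# input and factor 0 and appears once in the output.
direct = np.einsum('bi,pqr,bp,iq,or->bo', x, core, *factors)

assert np.allclose(expected, direct)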