|
7 | 7 | einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
8 | 8 |
|
9 | 9 |
|
10 |
| -# def tensor_dot_tucker(tensor, tucker, modes): |
11 |
| -# modes_tensor, modes_tucker = _validate_contraction_modes(tl.shape(tensor), tucker.tensor_shape, modes) |
12 |
| -# input_order = tensor.ndim |
13 |
| -# weight_order = tucker.order |
14 |
| - |
15 |
| -# sorted_modes_tucker = sorted(modes_tucker, reverse=True) |
16 |
| -# sorted_modes_tensor = sorted(modes_tensor, reverse=True) |
17 |
| - |
18 |
| -# # Symbol for dimensionality of the core |
19 |
| -# rank_sym = [einsum_symbols[i] for i in range(weight_order)] |
20 |
| - |
21 |
| -# # Symbols for tucker weight size |
22 |
| -# tucker_sym = [einsum_symbols[i+weight_order] for i in range(weight_order)] |
23 |
| - |
24 |
| -# # Symbolds for input tensor |
25 |
| -# tensor_sym = [einsum_symbols[i+2*weight_order] for i in range(tensor.ndim)] |
26 |
| - |
27 |
| -# # Output: input + weights symbols after removing contraction symbols |
28 |
| -# output_sym = tensor_sym + tucker_sym |
29 |
| -# for m in sorted_modes_tucker: |
30 |
| -# output_sym.pop(m+input_order) |
31 |
| -# for m in sorted_modes_tensor: |
32 |
| -# output_sym.pop(m) |
33 |
| -# for i, e in enumerate(modes_tensor): |
34 |
| -# tensor_sym[e] = tucker_sym[modes_tucker[i]] |
35 |
| - |
36 |
| -# # Form the actual equation: tensor, core, factors -> output |
37 |
| -# eq = ''.join(tensor_sym) |
38 |
| -# eq += ',' + ''.join(rank_sym) |
39 |
| -# eq += ',' + ','.join(f'{s}{r}' for s,r in zip(tucker_sym,rank_sym)) |
40 |
| -# eq += '->' + ''.join(output_sym) |
| 10 | +def tensor_dot_tucker(tensor, tucker, modes, batched_modes=()): |
| 11 | + """Batched tensor contraction between a dense tensor and a Tucker tensor on specified modes |
41 | 12 |
|
42 |
| -# return tl.einsum(eq, tensor, tucker.core, *tucker.factors) |
43 |
| - |
| 13 | + Parameters |
| 14 | + ---------- |
| 15 | + tensor : DenseTensor |
| 16 | + tucker : TuckerTensor |
| 17 | + modes : int list or int |
| 18 | + modes on which to contract tensor1 and tensor2 |
| 19 | + batched_modes : int or tuple[int] |
44 | 20 |
|
45 |
| -def tensor_dot_tucker(tensor, tucker, modes, batched_modes=()): |
| 21 | + Returns |
| 22 | + ------- |
| 23 | + contraction : tensor contracted with cp on the specified modes |
| 24 | + """ |
46 | 25 | modes_tensor, modes_tucker = _validate_contraction_modes(
|
47 | 26 | tl.shape(tensor), tucker.tensor_shape, modes)
|
48 | 27 | input_order = tensor.ndim
|
|
0 commit comments