Skip to content

Commit 11cd9e4

Browse files
committed
tensordot: more doc
1 parent 10d7af9 commit 11cd9e4

File tree

1 file changed

+13
-34
lines changed

1 file changed

+13
-34
lines changed

tltorch/functional/factorized_tensordot.py

Lines changed: 13 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,42 +7,21 @@
77
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
88

99

10-
# def tensor_dot_tucker(tensor, tucker, modes):
11-
# modes_tensor, modes_tucker = _validate_contraction_modes(tl.shape(tensor), tucker.tensor_shape, modes)
12-
# input_order = tensor.ndim
13-
# weight_order = tucker.order
14-
15-
# sorted_modes_tucker = sorted(modes_tucker, reverse=True)
16-
# sorted_modes_tensor = sorted(modes_tensor, reverse=True)
17-
18-
# # Symbol for dimensionality of the core
19-
# rank_sym = [einsum_symbols[i] for i in range(weight_order)]
20-
21-
# # Symbols for tucker weight size
22-
# tucker_sym = [einsum_symbols[i+weight_order] for i in range(weight_order)]
23-
24-
# # Symbolds for input tensor
25-
# tensor_sym = [einsum_symbols[i+2*weight_order] for i in range(tensor.ndim)]
26-
27-
# # Output: input + weights symbols after removing contraction symbols
28-
# output_sym = tensor_sym + tucker_sym
29-
# for m in sorted_modes_tucker:
30-
# output_sym.pop(m+input_order)
31-
# for m in sorted_modes_tensor:
32-
# output_sym.pop(m)
33-
# for i, e in enumerate(modes_tensor):
34-
# tensor_sym[e] = tucker_sym[modes_tucker[i]]
35-
36-
# # Form the actual equation: tensor, core, factors -> output
37-
# eq = ''.join(tensor_sym)
38-
# eq += ',' + ''.join(rank_sym)
39-
# eq += ',' + ','.join(f'{s}{r}' for s,r in zip(tucker_sym,rank_sym))
40-
# eq += '->' + ''.join(output_sym)
10+
def tensor_dot_tucker(tensor, tucker, modes, batched_modes=()):
11+
"""Batched tensor contraction between a dense tensor and a Tucker tensor on specified modes
4112
42-
# return tl.einsum(eq, tensor, tucker.core, *tucker.factors)
43-
13+
Parameters
14+
----------
15+
tensor : DenseTensor
16+
tucker : TuckerTensor
17+
modes : int list or int
18+
modes on which to contract tensor1 and tensor2
19+
batched_modes : int or tuple[int]
4420
45-
def tensor_dot_tucker(tensor, tucker, modes, batched_modes=()):
21+
Returns
22+
-------
23+
contraction : tensor contracted with cp on the specified modes
24+
"""
4625
modes_tensor, modes_tucker = _validate_contraction_modes(
4726
tl.shape(tensor), tucker.tensor_shape, modes)
4827
input_order = tensor.ndim

0 commit comments

Comments
 (0)