Commit f693ecd

Tucker_tensordot: Support batching
1 parent 525efb9 commit f693ecd

File tree: 1 file changed, +53 -7 lines

tltorch/functional/factorized_tensordot.py (53 additions, 7 deletions)
@@ -7,32 +7,77 @@
 einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
 
 
-def tensor_dot_tucker(tensor, tucker, modes):
-    modes_tensor, modes_tucker = _validate_contraction_modes(tl.shape(tensor), tucker.tensor_shape, modes)
+# def tensor_dot_tucker(tensor, tucker, modes):
+#     modes_tensor, modes_tucker = _validate_contraction_modes(tl.shape(tensor), tucker.tensor_shape, modes)
+#     input_order = tensor.ndim
+#     weight_order = tucker.order
+
+#     sorted_modes_tucker = sorted(modes_tucker, reverse=True)
+#     sorted_modes_tensor = sorted(modes_tensor, reverse=True)
+
+#     # Symbol for dimensionality of the core
+#     rank_sym = [einsum_symbols[i] for i in range(weight_order)]
+
+#     # Symbols for tucker weight size
+#     tucker_sym = [einsum_symbols[i+weight_order] for i in range(weight_order)]
+
+#     # Symbolds for input tensor
+#     tensor_sym = [einsum_symbols[i+2*weight_order] for i in range(tensor.ndim)]
+
+#     # Output: input + weights symbols after removing contraction symbols
+#     output_sym = tensor_sym + tucker_sym
+#     for m in sorted_modes_tucker:
+#         output_sym.pop(m+input_order)
+#     for m in sorted_modes_tensor:
+#         output_sym.pop(m)
+#     for i, e in enumerate(modes_tensor):
+#         tensor_sym[e] = tucker_sym[modes_tucker[i]]
+
+#     # Form the actual equation: tensor, core, factors -> output
+#     eq = ''.join(tensor_sym)
+#     eq += ',' + ''.join(rank_sym)
+#     eq += ',' + ','.join(f'{s}{r}' for s,r in zip(tucker_sym,rank_sym))
+#     eq += '->' + ''.join(output_sym)
+
+#     return tl.einsum(eq, tensor, tucker.core, *tucker.factors)
+
+
+def tensor_dot_tucker(tensor, tucker, modes, batched_modes):
+    modes_tensor, modes_tucker = _validate_contraction_modes(
+        tl.shape(tensor), tucker.tensor_shape, modes)
     input_order = tensor.ndim
     weight_order = tucker.order
+
+    batched_modes_tensor, batched_modes_tucker = _validate_contraction_modes(
+        tl.shape(tensor), tucker.tensor_shape, batched_modes)
 
-    sorted_modes_tucker = sorted(modes_tucker, reverse=True)
-    sorted_modes_tensor = sorted(modes_tensor, reverse=True)
+    sorted_modes_tucker = sorted(modes_tucker+batched_modes_tucker, reverse=True)
+    sorted_modes_tensor = sorted(modes_tensor+batched_modes_tensor, reverse=True)
 
     # Symbol for dimensionality of the core
     rank_sym = [einsum_symbols[i] for i in range(weight_order)]
 
     # Symbols for tucker weight size
     tucker_sym = [einsum_symbols[i+weight_order] for i in range(weight_order)]
 
-    # Symbolds for input tensor
+    # Symbols for input tensor
     tensor_sym = [einsum_symbols[i+2*weight_order] for i in range(tensor.ndim)]
 
     # Output: input + weights symbols after removing contraction symbols
     output_sym = tensor_sym + tucker_sym
     for m in sorted_modes_tucker:
-        output_sym.pop(m+input_order)
+        if m in modes_tucker: #not batched
+            output_sym.pop(m+input_order)
     for m in sorted_modes_tensor:
+        # It's batched, always remove
         output_sym.pop(m)
+
+    # print(tensor_sym, tucker_sym, modes_tensor, batched_modes_tensor)
     for i, e in enumerate(modes_tensor):
         tensor_sym[e] = tucker_sym[modes_tucker[i]]
-
+    for i, e in enumerate(batched_modes_tensor):
+        tensor_sym[e] = tucker_sym[batched_modes_tucker[i]]
+
     # Form the actual equation: tensor, core, factors -> output
     eq = ''.join(tensor_sym)
     eq += ',' + ''.join(rank_sym)
@@ -42,6 +87,7 @@ def tensor_dot_tucker(tensor, tucker, modes):
     return tl.einsum(eq, tensor, tucker.core, *tucker.factors)
 
 
+
 def tensor_dot_cp(tensor, cp, modes):
     """Contracts a to CP tensors in factorized form
 
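For readers of this commit: the patched tensor_dot_tucker contracts a dense input against a Tucker-factorized weight directly in factorized form, i.e. it never reconstructs the full weight W[i1..iN] = sum over r1..rN of core[r1..rN] * prod_k factors[k][i_k, r_k]. A batched mode is a dimension shared by input and weight: it gets the same einsum symbol on both operands and is kept once in the output instead of being summed over. Below is a minimal, self-contained sketch of the equation-building for one contracted and one batched mode; the MiniTucker class and all sizes are hypothetical stand-ins, not tltorch's API.

import string

import numpy as np


class MiniTucker:
    """Tiny stand-in for a Tucker-factorized tensor: a core plus one factor per mode."""
    def __init__(self, core, factors):
        self.core = core                           # core shape = ranks
        self.factors = factors                     # factors[k]: (tensor_shape[k], ranks[k])
        self.order = core.ndim
        self.tensor_shape = tuple(f.shape[0] for f in factors)


batch, contracted, out_dim = 8, 5, 6               # hypothetical sizes
ranks = (2, 3, 4)
weight = MiniTucker(
    core=np.random.rand(*ranks),
    factors=[np.random.rand(dim, rank)
             for dim, rank in zip((batch, contracted, out_dim), ranks)],
)
x = np.random.rand(batch, contracted)              # mode 0 batched, mode 1 contracted

# Same symbol bookkeeping as the patched function, specialized to this case:
symbols = string.ascii_lowercase
rank_sym = list(symbols[:3])                       # 'a','b','c' -> core dims
tucker_sym = list(symbols[3:6])                    # 'd','e','f' -> full-weight dims
tensor_sym = list(symbols[6:8])                    # 'g','h'     -> input dims
tensor_sym[1] = tucker_sym[1]                      # contracted pair: shared symbol, summed out
tensor_sym[0] = tucker_sym[0]                      # batched pair: shared symbol, kept once
output_sym = [tucker_sym[0], tucker_sym[2]]        # batch dim + uncontracted weight dim

eq = ''.join(tensor_sym)
eq += ',' + ''.join(rank_sym)
eq += ',' + ','.join(f'{s}{r}' for s, r in zip(tucker_sym, rank_sym))
eq += '->' + ''.join(output_sym)
print(eq)                                          # de,abc,da,eb,fc->df

out = np.einsum(eq, x, weight.core, *weight.factors)
print(out.shape)                                   # (8, 6)

The design point the diff encodes: both contracted and batched pairs share one symbol between the input and the corresponding Tucker factor, but only contracted symbols are popped from the output, so einsum sums over them while broadcasting over the batched ones, and the full (8, 5, 6) weight is never materialized.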