|
| 1 | +from beartype import beartype |
| 2 | +from beartype.typing import Union |
| 3 | + |
1 | 4 | import torch |
2 | 5 | from torch import nn, einsum, Tensor, IntTensor, LongTensor, FloatTensor |
3 | 6 | from torch.nn import Module |
4 | 7 | import torch.nn.functional as F |
5 | 8 |
|
6 | | -from beartype import beartype |
7 | | -from beartype.typing import Union |
8 | 9 | from einops import rearrange |
9 | 10 |
|
| 11 | +from opt_einsum import contract as opt_einsum |
| 12 | + |
10 | 13 | # helpers |
11 | 14 |
|
12 | 15 | def exists(val): |
@@ -108,10 +111,10 @@ def forward( |
108 | 111 |
|
109 | 112 | # main contribution eq (3) |
110 | 113 |
|
111 | | - i_energy = einsum('b d, b d -> b', i @ Ci, i) |
| 114 | + i_energy = opt_einsum('b o, o i, b i -> b', i, Ci, i) |
112 | 115 | i_energy = rearrange(i_energy, '... -> ... 1 1') |
113 | 116 |
|
114 | | - sim = einsum('b n d, b d -> b n', text_enc, i @ Ci) |
| 117 | + sim = opt_einsum('b n o, o i, b i -> b n', text_enc, Ci, i) |
115 | 118 | sim = rearrange(sim, '... -> ... 1') |
116 | 119 |
|
117 | 120 | sigmoid_term = (((sim / i_energy) - beta) / temperature).sigmoid() |
|
0 commit comments