Skip to content

Commit 3fe1968

Browse files
committed
just do all einsum
1 parent 53aa6b0 commit 3fe1968

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

perfusion_pytorch/perfusion.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
from beartype import beartype
2+
from beartype.typing import Union
3+
14
import torch
25
from torch import nn, einsum, Tensor, IntTensor, LongTensor, FloatTensor
36
from torch.nn import Module
47
import torch.nn.functional as F
58

6-
from beartype import beartype
7-
from beartype.typing import Union
89
from einops import rearrange
910

11+
from opt_einsum import contract as opt_einsum
12+
1013
# helpers
1114

1215
def exists(val):
@@ -108,10 +111,10 @@ def forward(
108111

109112
# main contribution eq (3)
110113

111-
i_energy = einsum('b d, b d -> b', i @ Ci, i)
114+
i_energy = opt_einsum('b o, o i, b i -> b', i, Ci, i)
112115
i_energy = rearrange(i_energy, '... -> ... 1 1')
113116

114-
sim = einsum('b n d, b d -> b n', text_enc, i @ Ci)
117+
sim = opt_einsum('b n o, o i, b i -> b n', text_enc, Ci, i)
115118
sim = rearrange(sim, '... -> ... 1')
116119

117120
sigmoid_term = (((sim / i_energy) - beta) / temperature).sigmoid()

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.4',
6+
version = '0.0.5',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',
@@ -19,6 +19,7 @@
1919
install_requires=[
2020
'beartype',
2121
'einops>=0.6.1',
22+
'opt-einsum',
2223
'torch>=2.0'
2324
],
2425
classifiers=[

0 commit comments

Comments
 (0)