Skip to content

Commit c55bc6a

Browse files
committed
use namedtuple for embedding wrapper return, for clarity
1 parent 142ed1a commit c55bc6a

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

perfusion_pytorch/embedding.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,20 @@
22
from torch import nn
33
from torch.nn import Module
44

5+
from collections import namedtuple
6+
57
from beartype import beartype
68
from beartype.typing import Optional, Tuple, Union
79

810
from einops import rearrange
911

12+
# constants
13+
14+
EmbeddingReturn = namedtuple('EmbeddingReturn', [
15+
'embed_with_concept',
16+
'embed_with_superclass'
17+
])
18+
1019
# helper functions
1120

1221
def exists(val):
@@ -63,7 +72,7 @@ def forward(
6372
x,
6473
concept_id: Optional[Union[int, Tuple[int, ...]]] = None,
6574
return_embed_with_superclass = True
66-
):
75+
) -> EmbeddingReturn:
6776
concept_masks = tuple(concept_id == x for concept_id in self.concept_embed_ids)
6877

6978
if exists(concept_id):
@@ -101,9 +110,9 @@ def forward(
101110
with torch.no_grad():
102111
superclass_embeds = self.embed(x)
103112

104-
return embeds, superclass_embeds
113+
return EmbeddingReturn(embeds, superclass_embeds)
105114

106-
return embeds
115+
return EmbeddingReturn(embeds, None)
107116

108117
@beartype
109118
def merge_embedding_wrappers(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.8',
6+
version = '0.1.9',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)