Commit eccb221

take care of automatically returning embeds with superclass, if superclass id was given to the embed wrapper and it is detected to be in training mode. also allow it to be turned off during forward
1 parent 318497a commit eccb221

2 files changed: +16 -1 lines changed


perfusion_pytorch/embedding.py

Lines changed: 15 additions & 0 deletions
@@ -36,6 +36,9 @@ def __init__(
         self.num_concepts = num_concepts
         self.concepts = nn.Parameter(torch.zeros(num_concepts, dim))

+        self.superclass_embed_id = superclass_embed_id
+        assert not (exists(superclass_embed_id) and num_concepts > 1), 'cannot do multi concept with superclass embed id given'
+
         if exists(superclass_embed_id):
             # author had better results initializing the concept embed to the super class embed, allow for that option

@@ -59,6 +62,7 @@ def forward(
         self,
         x,
         concept_id: Optional[Union[int, Tuple[int, ...]]] = None,
+        return_embed_with_superclass = True
     ):
         concept_masks = tuple(concept_id == x for concept_id in self.concept_embed_ids)

@@ -88,6 +92,17 @@ def forward(
             embeds
         )

+        # if training, and superclass embed id given
+        # also return embeddings with superclass, for deriving superclass_text_enc
+
+        if self.training and exists(self.superclass_embed_id) and return_embed_with_superclass:
+            x = x.masked_fill(concept_masks[0], self.superclass_embed_id)
+
+            with torch.no_grad():
+                superclass_embeds = self.embed(x)
+
+            return embeds, superclass_embeds
+
         return embeds

 @beartype
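
As a reading aid, here is a minimal sketch of what the committed forward branch does, using a plain nn.Embedding as a stand-in for the wrapped text-encoder embedding table. The vocabulary size, token ids, and variable names below are illustrative assumptions; only superclass_embed_id, masked_fill, and return_embed_with_superclass are taken from the diff.

import torch
from torch import nn

# hypothetical sizes and token ids, chosen only for this example
vocab_size, dim = 49408, 768           # assumed CLIP-like vocabulary and embedding dim
concept_id = 49407                     # assumed token id reserved for the new concept
superclass_embed_id = 1929             # assumed token id of the superclass word (e.g. "dog")

embed = nn.Embedding(vocab_size, dim)  # stand-in for the wrapped embedding table

x = torch.randint(0, vocab_size, (1, 77))
x[0, 5] = concept_id                   # pretend the prompt contains the concept token

concept_mask = x == concept_id
embeds = embed(x)                      # the embeddings forward already returned before this commit

# the new behavior: while training, swap the concept token for its superclass token
# and embed again without tracking gradients, so superclass_text_enc can be derived downstream
x_super = x.masked_fill(concept_mask, superclass_embed_id)
with torch.no_grad():
    superclass_embeds = embed(x_super)

# in training mode, forward(...) now returns (embeds, superclass_embeds);
# passing return_embed_with_superclass = False restores the single-tensor return

The wrapper itself performs the swap with concept_masks[0], which is safe because the new assert in __init__ restricts the superclass option to the single-concept case.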

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.2',
+  version = '0.1.4',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
