File tree Expand file tree Collapse file tree 2 files changed +13
-4
lines changed Expand file tree Collapse file tree 2 files changed +13
-4
lines changed Original file line number Diff line number Diff line change 22from torch import nn
33from torch .nn import Module
44
5+ from collections import namedtuple
6+
57from beartype import beartype
68from beartype .typing import Optional , Tuple , Union
79
810from einops import rearrange
911
12+ # constants
13+
14+ EmbeddingReturn = namedtuple ('EmbeddingReturn' , [
15+ 'embed_with_concept' ,
16+ 'embed_with_superclass'
17+ ])
18+
1019# helper functions
1120
1221def exists (val ):
@@ -63,7 +72,7 @@ def forward(
6372 x ,
6473 concept_id : Optional [Union [int , Tuple [int , ...]]] = None ,
6574 return_embed_with_superclass = True
66- ):
75+ ) -> EmbeddingReturn :
6776 concept_masks = tuple (concept_id == x for concept_id in self .concept_embed_ids )
6877
6978 if exists (concept_id ):
@@ -101,9 +110,9 @@ def forward(
101110 with torch .no_grad ():
102111 superclass_embeds = self .embed (x )
103112
104- return embeds , superclass_embeds
113+ return EmbeddingReturn ( embeds , superclass_embeds )
105114
106- return embeds
115+ return EmbeddingReturn ( embeds , None )
107116
108117@beartype
109118def merge_embedding_wrappers (
Original file line number Diff line number Diff line change 33setup (
44 name = 'perfusion-pytorch' ,
55 packages = find_packages (exclude = []),
6- version = '0.1.8 ' ,
6+ version = '0.1.9 ' ,
77 license = 'MIT' ,
88 description = 'Perfusion - Pytorch' ,
99 author = 'Phil Wang' ,
You can’t perform that action at this time.
0 commit comments