File tree Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -113,6 +113,10 @@ def __init__(
113113 def parameters (self ):
114114 return [self .concepts ]
115115
116+ @property
117+ def device (self ):
118+ return self .concepts .device
119+
116120 @beartype
117121 def forward (
118122 self ,
@@ -138,6 +142,7 @@ def forward(
138142 inferred_concept_id = self .concept_embed_ids [0 ]
139143
140144 x = self .tokenize (x )
145+ x = x .to (self .device )
141146
142147 superclass_mask = x == self .superclass_embed_id
143148 assert superclass_mask .any (dim = - 1 ).all (), 'superclass embed id must be present for all prompts'
Original file line number Diff line number Diff line change 33setup (
44 name = 'perfusion-pytorch' ,
55 packages = find_packages (exclude = []),
6- version = '0.1.20 ' ,
6+ version = '0.1.21 ' ,
77 license = 'MIT' ,
88 description = 'Perfusion - Pytorch' ,
99 author = 'Phil Wang' ,
You can’t perform that action at this time.
0 commit comments