Skip to content

Commit 89d39d8

Browse files
committed
address #15
1 parent 75bdb98 commit 89d39d8

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

perfusion_pytorch/embedding.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ def __init__(
113113
def parameters(self):
114114
return [self.concepts]
115115

116+
@property
117+
def device(self):
118+
return self.concepts.device
119+
116120
@beartype
117121
def forward(
118122
self,
@@ -138,6 +142,7 @@ def forward(
138142
inferred_concept_id = self.concept_embed_ids[0]
139143

140144
x = self.tokenize(x)
145+
x = x.to(self.device)
141146

142147
superclass_mask = x == self.superclass_embed_id
143148
assert superclass_mask.any(dim = -1).all(), 'superclass embed id must be present for all prompts'

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.20',
6+
version = '0.1.21',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)