Skip to content

Commit 59d323f

Browse files
committed
feat: add mps support
1 parent 912e0e1 commit 59d323f

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

cellseg_models_pytorch/inference/_base_inferer.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,15 @@ def __init__(
162162
except BaseException as e:
163163
print(e)
164164

165-
assert device in ("cuda", "cpu")
165+
assert device in ("cuda", "cpu", "mps")
166166
if device == "cpu":
167167
self.device = torch.device("cpu")
168-
if torch.cuda.is_available() and device == "cuda":
168+
elif torch.cuda.is_available() and device == "cuda":
169169
self.device = torch.device("cuda")
170-
171170
if torch.cuda.device_count() > 1 and n_devices > 1:
172171
self.model = nn.DataParallel(self.model, device_ids=range(n_devices))
172+
elif torch.backends.mps.is_available() and device == "mps":
173+
self.device = torch.device("mps")
173174

174175
self.model.to(self.device)
175176
self.model.eval()
@@ -245,6 +246,14 @@ def infer(self) -> None:
245246
for n, m in zip(names, soft_masks):
246247
self.soft_masks[n] = m
247248

249+
# Quick kludge to add soft type and sem to seg_results
250+
for soft, seg in zip(soft_masks, seg_results):
251+
if "type" in soft.keys():
252+
seg["soft_type"] = soft["type"]
253+
if "sem" in soft.keys():
254+
seg["soft_sem"] = soft["sem"]
255+
256+
# save to cache or disk
248257
if self.save_dir is None:
249258
for n, m in zip(names, seg_results):
250259
self.out_masks[n] = m

0 commit comments

Comments
 (0)