@@ -162,14 +162,15 @@ def __init__(
162
162
except BaseException as e :
163
163
print (e )
164
164
165
- assert device in ("cuda" , "cpu" )
165
+ assert device in ("cuda" , "cpu" , "mps" )
166
166
if device == "cpu" :
167
167
self .device = torch .device ("cpu" )
168
- if torch .cuda .is_available () and device == "cuda" :
168
+ elif torch .cuda .is_available () and device == "cuda" :
169
169
self .device = torch .device ("cuda" )
170
-
171
170
if torch .cuda .device_count () > 1 and n_devices > 1 :
172
171
self .model = nn .DataParallel (self .model , device_ids = range (n_devices ))
172
+ elif torch .backends .mps .is_available () and device == "mps" :
173
+ self .device = torch .device ("mps" )
173
174
174
175
self .model .to (self .device )
175
176
self .model .eval ()
@@ -245,6 +246,14 @@ def infer(self) -> None:
245
246
for n , m in zip (names , soft_masks ):
246
247
self .soft_masks [n ] = m
247
248
249
+ # Quick kludge to add soft type and sem to seg_results
250
+ for soft , seg in zip (soft_masks , seg_results ):
251
+ if "type" in soft .keys ():
252
+ seg ["soft_type" ] = soft ["type" ]
253
+ if "sem" in soft .keys ():
254
+ seg ["soft_sem" ] = soft ["sem" ]
255
+
256
+ # save to cache or disk
248
257
if self .save_dir is None :
249
258
for n , m in zip (names , seg_results ):
250
259
self .out_masks [n ] = m
0 commit comments