@@ -274,7 +274,7 @@ def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
274
274
def decode (self ,
275
275
image_output : OwlEncodeImageOutput ,
276
276
text_output : OwlEncodeTextOutput ,
277
- threshold : float = 0.1
277
+ thresholds : List [ float ],
278
278
) -> OwlDecodeOutput :
279
279
280
280
num_input_images = image_output .image_class_embeds .shape [0 ]
@@ -290,8 +290,16 @@ def decode(self,
290
290
scores_max = scores_sigmoid .max (dim = - 1 )
291
291
labels = scores_max .indices
292
292
scores = scores_max .values
293
-
294
- mask = (scores > threshold )
293
+ masks = []
294
+ for i , threshold in enumerate (thresholds ):
295
+ label_mask = labels == i
296
+ score_mask = scores > threshold
297
+ obj_mask = torch .logical_and (label_mask ,score_mask )
298
+ masks .append (obj_mask )
299
+
300
+ mask = masks [0 ]
301
+ for mask_t in masks [1 :]:
302
+ mask = torch .logical_or (mask , mask_t )
295
303
296
304
input_indices = torch .arange (0 , num_input_images , dtype = labels .dtype , device = labels .device )
297
305
input_indices = input_indices [:, None ].repeat (1 , self .num_patches )
@@ -447,8 +455,9 @@ def predict(self,
447
455
image : PIL .Image ,
448
456
text : List [str ],
449
457
text_encodings : Optional [OwlEncodeTextOutput ],
458
+ thresholds : List [float ],
450
459
pad_square : bool = True ,
451
- threshold : float = 0.1
460
+
452
461
) -> OwlDecodeOutput :
453
462
454
463
image_tensor = self .image_preprocessor .preprocess_pil_image (image )
@@ -460,5 +469,5 @@ def predict(self,
460
469
461
470
image_encodings = self .encode_rois (image_tensor , rois , pad_square = pad_square )
462
471
463
- return self .decode (image_encodings , text_encodings , threshold )
472
+ return self .decode (image_encodings , text_encodings , thresholds )
464
473
0 commit comments