
Commit 3acb748

add thresholds per object for owl_predictor
1 parent cca8017 commit 3acb748

File tree: 2 files changed, +25 -12 lines

examples/owl_predict.py

Lines changed: 11 additions & 7 deletions
@@ -30,21 +30,25 @@
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg")
-parser.add_argument("--prompt", type=str, default="")
-parser.add_argument("--threshold", type=float, default=0.1)
+parser.add_argument("--prompt", type=str, default="an owl, a glove")
+parser.add_argument("--thresholds", type=str, default="0.1,0.1")
 parser.add_argument("--output", type=str, default="../data/owl_predict_out.jpg")
 parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
-parser.add_argument("--image_encoder_engine", type=str, default="../data/owlvit_image_encoder_patch32.engine")
+parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
 parser.add_argument("--profile", action="store_true")
 parser.add_argument("--num_profiling_runs", type=int, default=30)
 args = parser.parse_args()
 
 prompt = args.prompt.strip("][()")
-
 text = prompt.split(',')
-
 print(text)
 
+thresholds = args.thresholds.strip("][()")
+thresholds = thresholds.split(',')
+thresholds = [float(x) for x in thresholds]
+print(thresholds)
+
+
 predictor = OwlPredictor(
     args.model,
     image_encoder_engine=args.image_encoder_engine

@@ -58,7 +62,7 @@
     image=image,
     text=text,
     text_encodings=text_encodings,
-    threshold=args.threshold,
+    thresholds=thresholds,
     pad_square=False
 )
 

@@ -70,7 +74,7 @@
     image=image,
     text=text,
     text_encodings=text_encodings,
-    threshold=args.threshold,
+    thresholds=thresholds,
     pad_square=False
 )
 torch.cuda.current_stream().synchronize()
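
With this change the example script takes a comma-separated --thresholds string, parsed the same way as --prompt, in place of the single --threshold float; the i-th threshold applies to the i-th prompt. An illustrative invocation (the threshold values here are made up, the paths are the script's defaults): python3 owl_predict.py --prompt="an owl, a glove" --thresholds="0.1,0.3" --image_encoder_engine=../data/owl_image_encoder_patch32.engine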

nanoowl/owl_predictor.py

Lines changed: 14 additions & 5 deletions
@@ -274,7 +274,7 @@ def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
     def decode(self,
             image_output: OwlEncodeImageOutput,
             text_output: OwlEncodeTextOutput,
-            threshold: float = 0.1
+            thresholds: List[float],
         ) -> OwlDecodeOutput:
 
         num_input_images = image_output.image_class_embeds.shape[0]
@@ -290,8 +290,16 @@
         scores_max = scores_sigmoid.max(dim=-1)
         labels = scores_max.indices
         scores = scores_max.values
-
-        mask = (scores > threshold)
+        masks = []
+        for i, threshold in enumerate(thresholds):
+            label_mask = labels == i
+            score_mask = scores > threshold
+            obj_mask = torch.logical_and(label_mask, score_mask)
+            masks.append(obj_mask)
+
+        mask = masks[0]
+        for mask_t in masks[1:]:
+            mask = torch.logical_or(mask, mask_t)
 
         input_indices = torch.arange(0, num_input_images, dtype=labels.dtype, device=labels.device)
         input_indices = input_indices[:, None].repeat(1, self.num_patches)
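
The new decode path builds one boolean mask per prompt (detections whose argmax label is that prompt and whose score clears that prompt's threshold) and unions the masks. A standalone sketch of the same logic on toy tensors, not taken from the repository, with a vectorized equivalent that indexes a threshold tensor by label:

import torch

# Toy stand-ins for decode()'s intermediates: five candidate detections,
# each assigned by argmax to one of two prompts.
labels = torch.tensor([0, 1, 0, 1, 1])
scores = torch.tensor([0.15, 0.05, 0.30, 0.25, 0.12])
thresholds = [0.1, 0.2]  # one threshold per prompt

# Loop form, mirroring the commit: one mask per prompt, then a union.
masks = []
for i, threshold in enumerate(thresholds):
    masks.append(torch.logical_and(labels == i, scores > threshold))
mask = masks[0]
for mask_t in masks[1:]:
    mask = torch.logical_or(mask, mask_t)

# Vectorized equivalent: look up each detection's threshold by its label.
mask_vec = scores > torch.tensor(thresholds)[labels]
assert torch.equal(mask, mask_vec)  # tensor([True, False, True, True, False])

Both forms assume len(thresholds) equals the number of prompts; the loop form additionally assumes the list is non-empty, since it seeds the union with masks[0].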
@@ -447,8 +455,9 @@ def predict(self,
             image: PIL.Image,
             text: List[str],
             text_encodings: Optional[OwlEncodeTextOutput],
+            thresholds: List[float],
             pad_square: bool = True,
-            threshold: float = 0.1
+
         ) -> OwlDecodeOutput:
 
         image_tensor = self.image_preprocessor.preprocess_pil_image(image)
@@ -460,5 +469,5 @@
 
         image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
 
-        return self.decode(image_encodings, text_encodings, threshold)
+        return self.decode(image_encodings, text_encodings, thresholds)
 
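Callers of predict() and decode() must now supply thresholds explicitly; the parameter has no default, so code that relied on threshold=0.1 needs updating. A minimal usage sketch, assuming predict() computes text encodings itself when text_encodings is None (as its Optional type hint suggests) and using the example script's default paths:

import PIL.Image
from nanoowl.owl_predictor import OwlPredictor

predictor = OwlPredictor(
    "google/owlvit-base-patch32",
    image_encoder_engine="../data/owl_image_encoder_patch32.engine"
)

image = PIL.Image.open("../assets/owl_glove_small.jpg")

# One threshold per prompt, in prompt order (values illustrative).
output = predictor.predict(
    image=image,
    text=["an owl", "a glove"],
    text_encodings=None,
    thresholds=[0.1, 0.2],
    pad_square=False
)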