Skip to content

Commit 60c5d9f

Browse files
authored
Merge pull request #9 from ssmmoo1/main
Add thresholds per object for owl_predictor
2 parents cca8017 + d8fa78b commit 60c5d9f

File tree

3 files changed

+28
-12
lines changed

3 files changed

+28
-12
lines changed

examples/owl_predict.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,25 @@
3030

3131
parser = argparse.ArgumentParser()
3232
parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg")
33-
parser.add_argument("--prompt", type=str, default="")
34-
parser.add_argument("--threshold", type=float, default=0.1)
33+
parser.add_argument("--prompt", type=str, default="an owl, a glove")
34+
parser.add_argument("--threshold", type=str, default="0.1,0.1")
3535
parser.add_argument("--output", type=str, default="../data/owl_predict_out.jpg")
3636
parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
37-
parser.add_argument("--image_encoder_engine", type=str, default="../data/owlvit_image_encoder_patch32.engine")
37+
parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
3838
parser.add_argument("--profile", action="store_true")
3939
parser.add_argument("--num_profiling_runs", type=int, default=30)
4040
args = parser.parse_args()
4141

4242
prompt = args.prompt.strip("][()")
43-
4443
text = prompt.split(',')
45-
4644
print(text)
4745

46+
thresholds = args.threshold.strip("][()")
47+
thresholds = thresholds.split(',')
48+
thresholds = [float(x) for x in thresholds]
49+
print(thresholds)
50+
51+
4852
predictor = OwlPredictor(
4953
args.model,
5054
image_encoder_engine=args.image_encoder_engine
@@ -58,7 +62,7 @@
5862
image=image,
5963
text=text,
6064
text_encodings=text_encodings,
61-
threshold=args.threshold,
65+
threshold=thresholds,
6266
pad_square=False
6367
)
6468

@@ -70,7 +74,7 @@
7074
image=image,
7175
text=text,
7276
text_encodings=text_encodings,
73-
threshold=args.threshold,
77+
threshold=thresholds,
7478
pad_square=False
7579
)
7680
torch.cuda.current_stream().synchronize()

examples/tree_predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
parser.add_argument("--threshold", type=float, default=0.1)
3434
parser.add_argument("--output", type=str, default="../data/tree_predict_out.jpg")
3535
parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
36-
parser.add_argument("--image_encoder_engine", type=str, default="../data/owlvit_image_encoder_patch32.engine")
36+
parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
3737
args = parser.parse_args()
3838

3939
predictor = TreePredictor(

nanoowl/owl_predictor.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,12 @@ def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
274274
def decode(self,
275275
image_output: OwlEncodeImageOutput,
276276
text_output: OwlEncodeTextOutput,
277-
threshold: float = 0.1
277+
threshold: Union[int, float, List[Union[int, float]]] = 0.1,
278278
) -> OwlDecodeOutput:
279279

280+
if isinstance(threshold, (int, float)):
281+
threshold = [threshold]
282+
280283
num_input_images = image_output.image_class_embeds.shape[0]
281284

282285
image_class_embeds = image_output.image_class_embeds
@@ -290,8 +293,16 @@ def decode(self,
290293
scores_max = scores_sigmoid.max(dim=-1)
291294
labels = scores_max.indices
292295
scores = scores_max.values
293-
294-
mask = (scores > threshold)
296+
masks = []
297+
for i, thresh in enumerate(threshold):
298+
label_mask = labels == i
299+
score_mask = scores > thresh
300+
obj_mask = torch.logical_and(label_mask,score_mask)
301+
masks.append(obj_mask)
302+
303+
mask = masks[0]
304+
for mask_t in masks[1:]:
305+
mask = torch.logical_or(mask, mask_t)
295306

296307
input_indices = torch.arange(0, num_input_images, dtype=labels.dtype, device=labels.device)
297308
input_indices = input_indices[:, None].repeat(1, self.num_patches)
@@ -447,8 +458,9 @@ def predict(self,
447458
image: PIL.Image,
448459
text: List[str],
449460
text_encodings: Optional[OwlEncodeTextOutput],
461+
threshold: Union[int, float, List[Union[int, float]]] = 0.1,
450462
pad_square: bool = True,
451-
threshold: float = 0.1
463+
452464
) -> OwlDecodeOutput:
453465

454466
image_tensor = self.image_preprocessor.preprocess_pil_image(image)

0 commit comments

Comments
 (0)