Skip to content

Commit b938167

Browse files
committed
make thresholds backwards compatible
1 parent: 3acb748 · commit: b938167

File tree

3 files changed

+12 −9 lines changed

3 files changed

+12 −9 lines changed

examples/owl_predict.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
parser = argparse.ArgumentParser()
3232
parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg")
3333
parser.add_argument("--prompt", type=str, default="an owl, a glove")
34-
parser.add_argument("--thresholds", type=str, default="0.1,0.1")
34+
parser.add_argument("--threshold", type=str, default="0.1,0.1")
3535
parser.add_argument("--output", type=str, default="../data/owl_predict_out.jpg")
3636
parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
3737
parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
@@ -43,7 +43,7 @@
4343
text = prompt.split(',')
4444
print(text)
4545

46-
thresholds = args.thresholds.strip("][()")
46+
thresholds = args.threshold.strip("][()")
4747
thresholds = thresholds.split(',')
4848
thresholds = [float(x) for x in thresholds]
4949
print(thresholds)
@@ -62,7 +62,7 @@
6262
image=image,
6363
text=text,
6464
text_encodings=text_encodings,
65-
thresholds=thresholds,
65+
threshold=thresholds,
6666
pad_square=False
6767
)
6868

examples/tree_predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
parser.add_argument("--threshold", type=float, default=0.1)
3434
parser.add_argument("--output", type=str, default="../data/tree_predict_out.jpg")
3535
parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
36-
parser.add_argument("--image_encoder_engine", type=str, default="../data/owlvit_image_encoder_patch32.engine")
36+
parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
3737
args = parser.parse_args()
3838

3939
predictor = TreePredictor(

nanoowl/owl_predictor.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -274,9 +274,12 @@ def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
274274
def decode(self,
275275
image_output: OwlEncodeImageOutput,
276276
text_output: OwlEncodeTextOutput,
277-
thresholds: List[float],
277+
threshold: Union[int, float, List[Union[int, float]]] = 0.1,
278278
) -> OwlDecodeOutput:
279279

280+
if isinstance(threshold, (int, float)):
281+
threshold = [threshold]
282+
280283
num_input_images = image_output.image_class_embeds.shape[0]
281284

282285
image_class_embeds = image_output.image_class_embeds
@@ -291,9 +294,9 @@ def decode(self,
291294
labels = scores_max.indices
292295
scores = scores_max.values
293296
masks = []
294-
for i, threshold in enumerate(thresholds):
297+
for i, thresh in enumerate(threshold):
295298
label_mask = labels == i
296-
score_mask = scores > threshold
299+
score_mask = scores > thresh
297300
obj_mask = torch.logical_and(label_mask,score_mask)
298301
masks.append(obj_mask)
299302

@@ -455,7 +458,7 @@ def predict(self,
455458
image: PIL.Image,
456459
text: List[str],
457460
text_encodings: Optional[OwlEncodeTextOutput],
458-
thresholds: List[float],
461+
threshold: Union[int, float, List[Union[int, float]]] = 0.1,
459462
pad_square: bool = True,
460463

461464
) -> OwlDecodeOutput:
@@ -469,5 +472,5 @@ def predict(self,
469472

470473
image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
471474

472-
return self.decode(image_encodings, text_encodings, thresholds)
475+
return self.decode(image_encodings, text_encodings, threshold)
473476

0 commit comments

Comments (0)