Skip to content

Commit 320b659

Browse files
authored
Merge branch 'NVIDIA-AI-IOT:main' into main
2 parents 8f1946c + cfef75a commit 320b659

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

examples/owl_predict.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@
3030

3131
parser = argparse.ArgumentParser()
3232
parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg")
33-
parser.add_argument("--prompt", type=str, default="an owl, a glove")
33+
parser.add_argument("--prompt", type=str, default="[an owl, a glove]")
3434
parser.add_argument("--threshold", type=str, default="0.1,0.1")
3535
parser.add_argument("--output", type=str, default="../data/owl_predict_out.jpg")
3636
parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
@@ -45,7 +45,10 @@
4545

4646
thresholds = args.threshold.strip("][()")
4747
thresholds = thresholds.split(',')
48-
thresholds = [float(x) for x in thresholds]
48+
if len(thresholds) == 1:
49+
thresholds = float(thresholds[0])
50+
else:
51+
thresholds = [float(x) for x in thresholds]
4952
print(thresholds)
5053

5154

nanoowl/owl_predictor.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -278,7 +278,7 @@ def decode(self,
278278
) -> OwlDecodeOutput:
279279

280280
if isinstance(threshold, (int, float)):
281-
threshold = [threshold]
281+
threshold = [threshold] * len(text_output.text_embeds) #apply single threshold to all labels
282282

283283
num_input_images = image_output.image_class_embeds.shape[0]
284284

@@ -468,7 +468,7 @@ def predict(self,
468468
if text_encodings is None:
469469
text_encodings = self.encode_text(text)
470470

471-
rois = torch.tensor([[0, 0, image.height, image.width]], dtype=image_tensor.dtype, device=image_tensor.device)
471+
rois = torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device)
472472

473473
image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
474474

0 commit comments

Comments
 (0)