74
74
PolygonAnnotation ,
75
75
SegmentationAnnotation ,
76
76
)
77
- from .prediction import BoxPrediction , PolygonPrediction
77
+ from .prediction import (
78
+ BoxPrediction ,
79
+ PolygonPrediction ,
80
+ SegmentationPrediction ,
81
+ )
78
82
from .model_run import ModelRun
79
83
from .slice import Slice
80
84
from .upload_response import UploadResponse
@@ -622,7 +626,9 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
622
626
def predict (
623
627
self ,
624
628
model_run_id : str ,
625
- annotations : List [Union [BoxPrediction , PolygonPrediction ]],
629
+ annotations : List [
630
+ Union [BoxPrediction , PolygonPrediction , SegmentationPrediction ]
631
+ ],
626
632
update : bool ,
627
633
batch_size : int = 100 ,
628
634
):
@@ -638,9 +644,26 @@ def predict(
638
644
"predictions_ignored": int,
639
645
}
640
646
"""
647
+ segmentations = [
648
+ ann
649
+ for ann in annotations
650
+ if isinstance (ann , SegmentationPrediction )
651
+ ]
652
+
653
+ other_predictions = [
654
+ ann
655
+ for ann in annotations
656
+ if not isinstance (ann , SegmentationPrediction )
657
+ ]
658
+
659
+ s_batches = [
660
+ segmentations [i : i + batch_size ]
661
+ for i in range (0 , len (segmentations ), batch_size )
662
+ ]
663
+
641
664
batches = [
642
- annotations [i : i + batch_size ]
643
- for i in range (0 , len (annotations ), batch_size )
665
+ other_predictions [i : i + batch_size ]
666
+ for i in range (0 , len (other_predictions ), batch_size )
644
667
]
645
668
646
669
agg_response = {
@@ -669,8 +692,23 @@ def predict(
669
692
PREDICTIONS_IGNORED_KEY
670
693
]
671
694
695
+ for s_batch in s_batches :
696
+ payload = construct_segmentation_payload (s_batch , update )
697
+ response = self ._make_request (
698
+ payload , f"modelRun/{ model_run_id } /predict_segmentation"
699
+ )
700
+ # pbar.update(1)
701
+ if STATUS_CODE_KEY in response :
702
+ agg_response [ERRORS_KEY ] = response
703
+ else :
704
+ agg_response [PREDICTIONS_PROCESSED_KEY ] += response [
705
+ PREDICTIONS_PROCESSED_KEY
706
+ ]
707
+ agg_response [PREDICTIONS_IGNORED_KEY ] += response [
708
+ PREDICTIONS_IGNORED_KEY
709
+ ]
710
+
672
711
return agg_response
673
- # return self._make_request(payload, f"modelRun/{model_run_id}/predict")
674
712
675
713
def commit_model_run (
676
714
self , model_run_id : str , payload : Optional [dict ] = None
0 commit comments