Skip to content

Commit c7f0f0a

Browse files
committed
pytests working
1 parent 7e13c1e commit c7f0f0a

File tree

7 files changed

+66
-75
lines changed

7 files changed

+66
-75
lines changed

nucleus/__init__.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -572,28 +572,24 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
572572
def predict(
573573
self,
574574
model_run_id: str,
575-
payload: Dict[str, List[Union[BoxPrediction, PolygonPrediction]]],
575+
annotations: List[Union[BoxPrediction, PolygonPrediction]],
576+
update: bool,
576577
batch_size: int = 100,
577578
):
578579
"""
579580
Uploads model outputs as predictions for a model_run. Returns info about the upload.
580-
:param payload:
581-
{
582-
"annotations": List[Union[Box2DPrediction, Polygon2DPrediction]],
583-
}
581+
:param annotations: List[Union[BoxPrediction, PolygonPrediction]],
582+
:param update: bool
584583
:return:
585584
{
586585
"dataset_id": str,
587586
"model_run_id": str,
588587
"annotations_processed: int,
589588
}
590589
"""
591-
predictions: List[Union[BoxPrediction, PolygonPrediction]] = payload[
592-
ANNOTATIONS_KEY
593-
]
594590
batches = [
595-
predictions[i : i + batch_size]
596-
for i in range(0, len(predictions), batch_size)
591+
annotations[i : i + batch_size]
592+
for i in range(0, len(annotations), batch_size)
597593
]
598594

599595
agg_response = {
@@ -604,7 +600,11 @@ def predict(
604600
tqdm_batches = self.tqdm_bar(batches)
605601

606602
for batch in tqdm_batches:
607-
batch_payload = {ANNOTATIONS_KEY: batch}
603+
batch_payload = construct_box_predictions_payload(
604+
annotations,
605+
update,
606+
)
607+
print(batch_payload)
608608
response = self._make_request(
609609
batch_payload, f"modelRun/{model_run_id}/predict"
610610
)
@@ -701,7 +701,7 @@ def predictions_ref_id(self, model_run_id: str, ref_id: str):
701701
:param reference_id: reference_id of a dataset item.
702702
:return:
703703
{
704-
"annotations": List[Box2DPrediction],
704+
"annotations": List[BoxPrediction],
705705
}
706706
"""
707707
return self._make_request(
@@ -726,7 +726,7 @@ def predictions_iloc(self, model_run_id: str, i: int):
726726
:param i: absolute number of Dataset Item for a dataset corresponding to the model run.
727727
:return:
728728
{
729-
"annotations": List[Box2DPrediction],
729+
"annotations": List[BoxPrediction],
730730
}
731731
"""
732732
return self._make_request(
@@ -755,7 +755,7 @@ def predictions_loc(self, model_run_id: str, dataset_item_id: str):
755755
:param dataset_item_id: dataset_item_id of a dataset item.
756756
:return:
757757
{
758-
"annotations": List[Box2DPrediction],
758+
"annotations": List[BoxPrediction],
759759
}
760760
"""
761761
return self._make_request(

nucleus/model_run.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Optional, List, Dict, Any, Union
2-
from .constants import ANNOTATIONS_KEY
2+
from .constants import ANNOTATIONS_KEY, DEFAULT_ANNOTATION_UPDATE_MODE
33
from .prediction import BoxPrediction, PolygonPrediction
44
from .payload_constructor import construct_box_predictions_payload
55

@@ -61,30 +61,26 @@ def commit(self, payload: Optional[dict] = None) -> dict:
6161
return self._client.commit_model_run(self.model_run_id, payload)
6262

6363
def predict(
64-
self, annotations: List[Union[BoxPrediction, PolygonPrediction]]
64+
self,
65+
annotations: List[Union[BoxPrediction, PolygonPrediction]],
66+
update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
6567
) -> dict:
6668
"""
6769
Uploads model outputs as predictions for a model_run. Returns info about the upload.
6870
:param annotations: List[Union[BoxPrediction, PolygonPrediction]],
6971
:return:
7072
{
71-
"dataset_id": str,
7273
"model_run_id": str,
73-
"annotations_processed: int,
74+
"predictions_processed: int,
7475
}
7576
"""
76-
payload: Dict[str, List[Any]] = construct_box_predictions_payload(
77-
annotations
78-
)
79-
return self._client.predict(self.model_run_id, payload)
77+
return self._client.predict(self.model_run_id, annotations, update)
8078

8179
def iloc(self, i: int):
8280
"""
8381
Returns Model Run Info For Dataset Item by its number.
8482
:param i: absolute number of Dataset Item for a dataset corresponding to the model run.
85-
:return:
86-
{
87-
"annotations": List[Union[BoxPrediction, PolygonPrediction]],
83+
:return: List[Union[BoxPrediction, PolygonPrediction]],
8884
}
8985
"""
9086
response = self._client.predictions_iloc(self.model_run_id, i)
@@ -94,10 +90,7 @@ def refloc(self, reference_id: str):
9490
"""
9591
Returns Model Run Info For Dataset Item by its reference_id.
9692
:param reference_id: reference_id of a dataset item.
97-
:return:
98-
{
99-
"annotations": List[Union[BoxPrediction, PolygonPrediction]],
100-
}
93+
:return: List[Union[BoxPrediction, PolygonPrediction]],
10194
"""
10295
response = self._client.predictions_ref_id(
10396
self.model_run_id, reference_id

nucleus/payload_constructor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,13 @@ def construct_annotation_payload(
4141

4242
def construct_box_predictions_payload(
4343
box_predictions: List[Union[BoxPrediction, PolygonPrediction]],
44+
update: bool,
4445
) -> dict:
4546
predictions = []
4647
for prediction in box_predictions:
4748
predictions.append(prediction.to_payload())
4849

49-
return {ANNOTATIONS_KEY: predictions}
50+
return {ANNOTATIONS_KEY: predictions, ANNOTATION_UPDATE_KEY: update}
5051

5152

5253
def construct_model_creation_payload(

nucleus/prediction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(
3737

3838
def to_payload(self) -> dict:
3939
payload = super().to_payload()
40-
if self.confidence:
40+
if self.confidence is not None:
4141
payload[CONFIDENCE_KEY] = self.confidence
4242

4343
return payload
@@ -78,7 +78,7 @@ def __init__(
7878

7979
def to_payload(self) -> dict:
8080
payload = super().to_payload()
81-
if self.confidence:
81+
if self.confidence is not None:
8282
payload[CONFIDENCE_KEY] = self.confidence
8383

8484
return payload

tests/test_annotation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ def test_polygon_gt_upload(dataset):
5454
response = dataset.refloc(annotation.reference_id)['annotations']
5555
assert len(response) == 1
5656
response_annotation = response[0]
57-
print(response_annotation)
5857
assert_polygon_annotation_matches_dict(response_annotation, TEST_POLYGON_ANNOTATIONS[0])
5958

6059

tests/test_dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def test_slice_create_and_delete(dataset):
114114
assert len(response["reference_ids"]) == 2
115115
for item in ds_items[:2]:
116116
assert item.reference_id in response["reference_ids"]
117-
print(response)
118117

119118

120119
def test_slice_append(dataset):

tests/test_prediction.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
TEST_DATASET_NAME,
55
TEST_MODEL_NAME,
66
TEST_MODEL_REFERENCE,
7+
TEST_MODEL_RUN,
78
TEST_IMG_URLS,
89
TEST_BOX_PREDICTIONS,
910
TEST_POLYGON_PREDICTIONS,
@@ -16,7 +17,7 @@
1617
from nucleus.constants import ERROR_PAYLOAD
1718

1819
@pytest.fixture()
19-
def dataset(CLIENT):
20+
def model_run(CLIENT):
2021
ds = CLIENT.create_dataset(TEST_DATASET_NAME)
2122
ds_items = []
2223
for url in TEST_IMG_URLS:
@@ -30,46 +31,48 @@ def dataset(CLIENT):
3031

3132
model = CLIENT.add_model(
3233
name=TEST_MODEL_NAME,
33-
reference_id=TEST_MODEL_NAME
34+
reference_id=TEST_MODEL_REFERENCE
3435
)
3536

37+
run = model.create_run(
38+
name=TEST_MODEL_RUN,
39+
dataset=ds,
40+
predictions=[])
3641

37-
yield ds
42+
yield run
3843

3944
response = CLIENT.delete_dataset(ds.id)
4045
assert response == {}
46+
response = CLIENT.delete_model(model.id)
47+
assert response == {}
4148

42-
43-
def test_box_pred_upload(dataset):
49+
def test_box_pred_upload(model_run):
4450
prediction = BoxPrediction(**TEST_BOX_PREDICTIONS[0])
45-
response = dataset.annotate(predictions=[prediction])
51+
response = model_run.predict(annotations=[prediction])
4652

47-
assert response['dataset_id'] == dataset.id
53+
assert response['model_run_id'] == model_run.model_run_id
4854
assert response['predictions_processed'] == 1
4955

50-
response = dataset.refloc(prediction.reference_id)['predictions']
56+
response = model_run.refloc(prediction.reference_id)
5157
assert len(response) == 1
52-
response_prediction = response[0]
53-
assert_box_prediction_matches_dict(response_prediction, TEST_BOX_PREDICTIONS[0])
58+
assert_box_prediction_matches_dict(response[0], TEST_BOX_PREDICTIONS[0])
5459

5560

56-
def test_polygon_pred_upload(dataset):
61+
def test_polygon_pred_upload(model_run):
5762
prediction = PolygonPrediction(**TEST_POLYGON_PREDICTIONS[0])
58-
response = dataset.annotate(predictions=[prediction])
63+
response = model_run.predict(annotations=[prediction])
5964

60-
assert response['dataset_id'] == dataset.id
65+
assert response['model_run_id'] == model_run.model_run_id
6166
assert response['predictions_processed'] == 1
6267

63-
response = dataset.refloc(prediction.reference_id)['predictions']
68+
response = model_run.refloc(prediction.reference_id)
6469
assert len(response) == 1
65-
response_prediction = response[0]
66-
print(response_prediction)
67-
assert_polygon_prediction_matches_dict(response_prediction, TEST_POLYGON_PREDICTIONS[0])
70+
assert_polygon_prediction_matches_dict(response[0], TEST_POLYGON_PREDICTIONS[0])
6871

6972

70-
def test_box_pred_upload_update(dataset):
73+
def test_box_pred_upload_update(model_run):
7174
prediction = BoxPrediction(**TEST_BOX_PREDICTIONS[0])
72-
response = dataset.annotate(predictions=[prediction])
75+
response = model_run.predict(annotations=[prediction])
7376

7477
assert response['predictions_processed'] == 1
7578

@@ -79,19 +82,18 @@ def test_box_pred_upload_update(dataset):
7982
prediction_update_params['reference_id'] = TEST_BOX_PREDICTIONS[0]['reference_id']
8083

8184
prediction_update = BoxPrediction(**prediction_update_params)
82-
response = dataset.annotate(predictions=[prediction_update], update=True)
85+
response = model_run.predict(annotations=[prediction_update], update=True)
8386

8487
assert response['predictions_processed'] == 1
8588

86-
response = dataset.refloc(prediction.reference_id)['predictions']
89+
response = model_run.refloc(prediction.reference_id)
8790
assert len(response) == 1
88-
response_prediction = response[0]
89-
assert_box_prediction_matches_dict(response_prediction, prediction_update_params)
91+
assert_box_prediction_matches_dict(response[0], prediction_update_params)
9092

9193

92-
def test_box_pred_upload_ignore(dataset):
94+
def test_box_pred_upload_ignore(model_run):
9395
prediction = BoxPrediction(**TEST_BOX_PREDICTIONS[0])
94-
response = dataset.annotate(predictions=[prediction])
96+
response = model_run.predict(annotations=[prediction])
9597

9698
assert response['predictions_processed'] == 1
9799

@@ -101,19 +103,18 @@ def test_box_pred_upload_ignore(dataset):
101103
prediction_update_params['reference_id'] = TEST_BOX_PREDICTIONS[0]['reference_id']
102104
prediction_update = BoxPrediction(**prediction_update_params)
103105
# Default behavior is ignore.
104-
response = dataset.annotate(predictions=[prediction_update])
106+
response = model_run.predict(annotations=[prediction_update])
105107

106108
assert response['predictions_processed'] == 1
107109

108-
response = dataset.refloc(prediction.reference_id)['predictions']
110+
response = model_run.refloc(prediction.reference_id)
109111
assert len(response) == 1
110-
response_prediction = response[0]
111-
assert_box_prediction_matches_dict(response_prediction, TEST_BOX_PREDICTIONS[0])
112+
assert_box_prediction_matches_dict(response[0], TEST_BOX_PREDICTIONS[0])
112113

113114

114-
def test_polygon_pred_upload_update(dataset):
115+
def test_polygon_pred_upload_update(model_run):
115116
prediction = PolygonPrediction(**TEST_POLYGON_PREDICTIONS[0])
116-
response = dataset.annotate(predictions=[prediction])
117+
response = model_run.predict(annotations=[prediction])
117118

118119
assert response['predictions_processed'] == 1
119120

@@ -123,19 +124,18 @@ def test_polygon_pred_upload_update(dataset):
123124
prediction_update_params['reference_id'] = TEST_POLYGON_PREDICTIONS[0]['reference_id']
124125

125126
prediction_update = PolygonPrediction(**prediction_update_params)
126-
response = dataset.annotate(predictions=[prediction_update], update=True)
127+
response = model_run.predict(annotations=[prediction_update], update=True)
127128

128129
assert response['predictions_processed'] == 1
129130

130-
response = dataset.refloc(prediction.reference_id)['predictions']
131+
response = model_run.refloc(prediction.reference_id)
131132
assert len(response) == 1
132-
response_prediction = response[0]
133-
assert_polygon_prediction_matches_dict(response_prediction, prediction_update_params)
133+
assert_polygon_prediction_matches_dict(response[0], prediction_update_params)
134134

135135

136-
def test_polygon_pred_upload_ignore(dataset):
136+
def test_polygon_pred_upload_ignore(model_run):
137137
prediction = PolygonPrediction(**TEST_POLYGON_PREDICTIONS[0])
138-
response = dataset.annotate(predictions=[prediction])
138+
response = model_run.predict(annotations=[prediction])
139139

140140
assert response['predictions_processed'] == 1
141141

@@ -146,11 +146,10 @@ def test_polygon_pred_upload_ignore(dataset):
146146

147147
prediction_update = PolygonPrediction(**prediction_update_params)
148148
# Default behavior is ignore.
149-
response = dataset.annotate(predictions=[prediction_update])
149+
response = model_run.predict(annotations=[prediction_update])
150150

151151
assert response['predictions_processed'] == 1
152152

153-
response = dataset.refloc(prediction.reference_id)['predictions']
153+
response = model_run.refloc(prediction.reference_id)
154154
assert len(response) == 1
155-
response_prediction = response[0]
156-
assert_polygon_prediction_matches_dict(response_prediction, TEST_POLYGON_PREDICTIONS[0])
155+
assert_polygon_prediction_matches_dict(response[0], TEST_POLYGON_PREDICTIONS[0])

0 commit comments

Comments
 (0)