Skip to content

Commit c4622a7

Browse files
author
Diego Ardila
committed
Passes local tests
1 parent ef23e95 commit c4622a7

File tree

5 files changed

+33
-20
lines changed

5 files changed

+33
-20
lines changed

nucleus/annotation.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,6 @@ def to_payload(self) -> dict:
174174
}
175175

176176

177-
# TODO: Add Generic type for 2D point
178-
179-
180177
@dataclass
181178
class Point:
182179
x: float
@@ -206,9 +203,15 @@ def __post_init__(self):
206203
if not hasattr(self.vertices[0], X_KEY) or not hasattr(
207204
self.vertices[0], "to_payload"
208205
):
209-
raise ValueError(
210-
"Use the Point object, not a dictionary for vertices"
211-
)
206+
try:
207+
self.vertices = [
208+
Point(x=vertex[X_KEY], y=vertex[Y_KEY])
209+
for vertex in self.vertices
210+
]
211+
except KeyError as ke:
212+
raise ValueError(
213+
"Use a point object to pass in vertices. For example, vertices=[nucleus.Point(x=1, y=2)]"
214+
) from ke
212215

213216
@classmethod
214217
def from_json(cls, payload: dict):

nucleus/prediction.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Dict, Optional, List, Any
1+
from typing import Dict, Optional, List
22
from .annotation import (
33
BoxAnnotation,
4+
Point,
45
PolygonAnnotation,
56
Segment,
67
SegmentationAnnotation,
@@ -101,7 +102,7 @@ class PolygonPrediction(PolygonAnnotation):
101102
def __init__(
102103
self,
103104
label: str,
104-
vertices: List[Any],
105+
vertices: List[Point],
105106
reference_id: Optional[str] = None,
106107
item_id: Optional[str] = None,
107108
confidence: Optional[float] = None,
@@ -133,7 +134,9 @@ def from_json(cls, payload: dict):
133134
geometry = payload.get(GEOMETRY_KEY, {})
134135
return cls(
135136
label=payload.get(LABEL_KEY, 0),
136-
vertices=geometry.get(VERTICES_KEY, []),
137+
vertices=[
138+
Point.from_json(_) for _ in geometry.get(VERTICES_KEY, [])
139+
],
137140
reference_id=payload.get(REFERENCE_ID_KEY, None),
138141
item_id=payload.get(DATASET_ITEM_ID_KEY, None),
139142
confidence=payload.get(CONFIDENCE_KEY, None),

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exclude = '''
2121

2222
[tool.poetry]
2323
name = "scale-nucleus"
24-
version = "0.1.8"
24+
version = "0.1.10"
2525
description = "The official Python client library for Nucleus, the Data Platform for AI"
2626
license = "MIT"
2727
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]
@@ -36,7 +36,7 @@ python = "^3.6.2"
3636
requests = "^2.23.0"
3737
tqdm = "^4.41.0"
3838
dataclasses = { version = "^0.7", python = "^3.6.1, <3.7" }
39-
aiohttp = "^3.7.4"
39+
aiohttp = "3.7.4.post0"
4040

4141
[tool.poetry.dev-dependencies]
4242
poetry = "^1.1.5"

tests/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def reference_id_from_url(url):
112112
{
113113
**TEST_POLYGON_ANNOTATIONS[i],
114114
"confidence": 0.10 * i,
115-
"class_pdf": TEST_POLYGON_MODEL_PDF,
115+
"class_pdf": None,
116116
}
117117
for i in range(len(TEST_POLYGON_ANNOTATIONS))
118118
]

tests/test_prediction.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
DatasetItem,
2323
Segment,
2424
ModelRun,
25+
Point,
2526
)
2627
from nucleus.constants import ERROR_PAYLOAD
2728

@@ -89,7 +90,7 @@ def test_box_pred_upload(model_run):
8990

9091

9192
def test_polygon_pred_upload(model_run):
92-
prediction = PolygonPrediction(**TEST_POLYGON_PREDICTIONS[0])
93+
prediction = PolygonPrediction.from_json(TEST_POLYGON_PREDICTIONS[0])
9394
response = model_run.predict(annotations=[prediction])
9495

9596
assert response["model_run_id"] == model_run.model_run_id
@@ -189,7 +190,7 @@ def test_box_pred_upload_ignore(model_run):
189190

190191

191192
def test_polygon_pred_upload_update(model_run):
192-
prediction = PolygonPrediction(**TEST_POLYGON_PREDICTIONS[0])
193+
prediction = PolygonPrediction.from_json(TEST_POLYGON_PREDICTIONS[0])
193194
response = model_run.predict(annotations=[prediction])
194195

195196
assert response["predictions_processed"] == 1
@@ -203,7 +204,7 @@ def test_polygon_pred_upload_update(model_run):
203204
"reference_id"
204205
]
205206

206-
prediction_update = PolygonPrediction(**prediction_update_params)
207+
prediction_update = PolygonPrediction.from_json(prediction_update_params)
207208
response = model_run.predict(annotations=[prediction_update], update=True)
208209

209210
assert response["predictions_processed"] == 1
@@ -217,7 +218,7 @@ def test_polygon_pred_upload_update(model_run):
217218

218219

219220
def test_polygon_pred_upload_ignore(model_run):
220-
prediction = PolygonPrediction(**TEST_POLYGON_PREDICTIONS[0])
221+
prediction = PolygonPrediction.from_json(TEST_POLYGON_PREDICTIONS[0])
221222
response = model_run.predict(annotations=[prediction])
222223

223224
assert response["predictions_processed"] == 1
@@ -231,7 +232,7 @@ def test_polygon_pred_upload_ignore(model_run):
231232
"reference_id"
232233
]
233234

234-
prediction_update = PolygonPrediction(**prediction_update_params)
235+
prediction_update = PolygonPrediction.from_json(prediction_update_params)
235236
# Default behavior is ignore.
236237
response = model_run.predict(annotations=[prediction_update])
237238

@@ -249,7 +250,9 @@ def test_mixed_pred_upload(model_run):
249250
prediction_semseg = SegmentationPrediction.from_json(
250251
TEST_SEGMENTATION_PREDICTIONS[0]
251252
)
252-
prediction_polygon = PolygonPrediction(**TEST_POLYGON_PREDICTIONS[0])
253+
prediction_polygon = PolygonPrediction.from_json(
254+
TEST_POLYGON_PREDICTIONS[0]
255+
)
253256
prediction_bbox = BoxPrediction(**TEST_BOX_PREDICTIONS[0])
254257
response = model_run.predict(
255258
annotations=[prediction_semseg, prediction_polygon, prediction_bbox]
@@ -276,7 +279,9 @@ def test_mixed_pred_upload_async(model_run: ModelRun):
276279
prediction_semseg = SegmentationPrediction.from_json(
277280
TEST_SEGMENTATION_PREDICTIONS[0]
278281
)
279-
prediction_polygon = PolygonPrediction(**TEST_POLYGON_PREDICTIONS[0])
282+
prediction_polygon = PolygonPrediction.from_json(
283+
TEST_POLYGON_PREDICTIONS[0]
284+
)
280285
prediction_bbox = BoxPrediction(**TEST_BOX_PREDICTIONS[0])
281286
job: AsyncJob = model_run.predict(
282287
annotations=[prediction_semseg, prediction_polygon, prediction_bbox],
@@ -310,7 +315,9 @@ def test_mixed_pred_upload_async_with_error(model_run: ModelRun):
310315
prediction_semseg = SegmentationPrediction.from_json(
311316
TEST_SEGMENTATION_PREDICTIONS[0]
312317
)
313-
prediction_polygon = PolygonPrediction(**TEST_POLYGON_PREDICTIONS[0])
318+
prediction_polygon = PolygonPrediction.from_json(
319+
TEST_POLYGON_PREDICTIONS[0]
320+
)
314321
prediction_bbox = BoxPrediction(**TEST_BOX_PREDICTIONS[0])
315322
prediction_bbox.reference_id = "fake_garbage"
316323

0 commit comments

Comments
 (0)