Skip to content

Commit c3a8032

Browse files
author
Matt Sokoloff
committed
in-line validation
1 parent b42b6bd commit c3a8032

File tree

4 files changed

+82
-21
lines changed

4 files changed

+82
-21
lines changed

labelbox/schema/bulk_import_request.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -781,33 +781,58 @@ def is_valid_location(cls, v):
781781
return v
782782

783783

784-
class MaskFeatures(BaseModel):
785-
instanceURI: str
786-
colorRGB: Union[List[int], Tuple[int, int, int]]
784+
class RLEMaskFeatures(BaseModel):
785+
counts: List[int]
786+
size: List[int]
787787

788+
@validator('counts')
789+
def validate_counts(cls, counts):
790+
if not all([count >= 0 for count in counts]):
791+
raise ValueError("Found negative value for counts. They should all be zero or positive")
792+
return counts
788793

789-
class NDMask(NDBaseTool):
790-
ontology_type: Literal["superpixel"] = "superpixel"
791-
mask: MaskFeatures = pydantic.Field(determinant=True)
794+
@validator('size')
795+
def validate_size(cls, size):
796+
if len(size) != 2:
797+
raise ValueError(f"Mask `size` should have two ints representing height and with. Found : {size}")
798+
if not all([count > 0 for count in size]):
799+
raise ValueError(f"Mask `size` should be a postitive int. Found : {size}" )
800+
return size
792801

793-
@validator('mask')
794-
def is_valid_mask(cls, v):
795-
if isinstance(v, BaseModel):
796-
v = v.dict()
797802

798-
colors = v['colorRGB']
803+
804+
class PNGMaskFeatures(BaseModel):
805+
# base64 encoded png bytes
806+
png: str
807+
808+
809+
810+
class URIMaskFeatures(BaseModel):
811+
instanceURI: str
812+
colorRGB: Union[List[int], Tuple[int, int, int]]
813+
814+
@validator('colorRGB')
815+
def validate_color(cls, colorRGB):
799816
#Does the dtype matter? Can it be a float?
800-
if not isinstance(colors, (tuple, list)):
817+
if not isinstance(colorRGB, (tuple, list)):
801818
raise ValueError(
802-
f"Received color that is not a list or tuple. Found : {colors}")
803-
elif len(colors) != 3:
819+
f"Received color that is not a list or tuple. Found : {colorRGB}")
820+
elif len(colorRGB) != 3:
804821
raise ValueError(
805-
f"Must provide RGB values for segmentation colors. Found : {colors}"
822+
f"Must provide RGB values for segmentation colors. Found : {colorRGB}"
806823
)
807-
elif not all([0 <= color <= 255 for color in colors]):
824+
elif not all([0 <= color <= 255 for color in colorRGB]):
808825
raise ValueError(
809-
f"All rgb colors must be between 0 and 255. Found : {colors}")
810-
return v
826+
f"All rgb colors must be between 0 and 255. Found : {colorRGB}")
827+
return colorRGB
828+
829+
830+
831+
832+
class NDMask(NDBaseTool):
833+
ontology_type: Literal["superpixel"] = "superpixel"
834+
mask: Union[URIMaskFeatures, PNGMaskFeatures, RLEMaskFeatures] = pydantic.Field(determinant=True)
835+
811836

812837

813838
#A union with custom construction logic to improve error messages

tests/integration/annotation_import/conftest.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,30 @@ def segmentation_inference(prediction_id_mapping):
231231
del segmentation['tool']
232232
return segmentation
233233

234+
@pytest.fixture
235+
def segmentation_inference_rle(prediction_id_mapping):
236+
segmentation = prediction_id_mapping['superpixel'].copy()
237+
segmentation.update(
238+
{
239+
'uuid' : str(uuid.uuid4()),
240+
'mask': {
241+
'size': [10,10],
242+
'counts': [1, 0, 10,100]
243+
}})
244+
del segmentation['tool']
245+
return segmentation
246+
247+
@pytest.fixture
248+
def segmentation_inference_png(prediction_id_mapping):
249+
segmentation = prediction_id_mapping['superpixel'].copy()
250+
segmentation.update(
251+
{
252+
'uuid' : str(uuid.uuid4()),
253+
'mask': {
254+
'png': "somedata",
255+
}})
256+
del segmentation['tool']
257+
return segmentation
234258

235259
@pytest.fixture
236260
def checklist_inference(prediction_id_mapping):
@@ -283,10 +307,10 @@ def model_run_predictions(polygon_inference, rectangle_inference,
283307
# also used for label imports
284308
@pytest.fixture
285309
def object_predictions(polygon_inference, rectangle_inference, line_inference,
286-
entity_inference, segmentation_inference):
310+
entity_inference, segmentation_inference, segmentation_inference_rle, segmentation_inference_png):
287311
return [
288312
polygon_inference, rectangle_inference, line_inference,
289-
entity_inference, segmentation_inference
313+
entity_inference, segmentation_inference, segmentation_inference_rle, segmentation_inference_png
290314
]
291315

292316

tests/integration/annotation_import/test_ndjson_validation.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def test_subclassification_construction(rectangle_inference):
2828
(fixture_ref('rectangle_inference'), NDRectangle),
2929
(fixture_ref('line_inference'), NDPolyline),
3030
(fixture_ref('entity_inference'), NDTextEntity),
31-
(fixture_ref('segmentation_inference'), NDMask)])
31+
(fixture_ref('segmentation_inference'), NDMask),
32+
(fixture_ref('segmentation_inference_rle'), NDMask),
33+
(fixture_ref('segmentation_inference_png'), NDMask)])
3234
def test_tool_construction(inference, expected_type):
3335
assert isinstance(NDTool.build(inference), expected_type)
3436

@@ -131,6 +133,14 @@ def test_incorrect_mask(segmentation_inference, configured_project):
131133
with pytest.raises(MALValidationError):
132134
_validate_ndjson([seg], configured_project)
133135

136+
seg['mask'] = {'counts' : [0], 'size' : [0,1]}
137+
with pytest.raises(MALValidationError):
138+
_validate_ndjson([seg], configured_project)
139+
140+
seg['mask'] = {'counts' : [-1], 'size' : [1,1]}
141+
with pytest.raises(MALValidationError):
142+
_validate_ndjson([seg], configured_project)
143+
134144

135145
def test_all_validate_json(configured_project, predictions):
136146
#Predictions contains one of each type of prediction.

tests/integration/test_labeling_frontend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ def test_get_labeling_frontends(client):
1212
where=LabelingFrontend.iframe_url_path ==
1313
target_frontend.iframe_url_path)
1414
for frontend in filtered_frontends:
15+
if frontend.name != 'Editor':
16+
continue
1517
assert target_frontend == frontend
1618

1719

0 commit comments

Comments
 (0)