Skip to content

Commit a6bc869

Browse files
authored
Merge pull request #183 from Labelbox/ms/mea-tests
MEA tests
2 parents 138e2bf + f7fd369 commit a6bc869

File tree

7 files changed

+269
-67
lines changed

7 files changed

+269
-67
lines changed

labelbox/schema/annotation_import.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ def inputs(self) -> List[Dict[str, Any]]:
3838
"""
3939
Inputs for each individual annotation uploaded.
4040
This should match the ndjson annotations that you have uploaded.
41-
4241
Returns:
4342
Uploaded ndjson.
44-
4543
* This information will expire after 24 hours.
4644
"""
4745
return self._fetch_remote_ndjson(self.input_file_url)
@@ -50,11 +48,9 @@ def inputs(self) -> List[Dict[str, Any]]:
5048
def errors(self) -> List[Dict[str, Any]]:
5149
"""
5250
Errors for each individual annotation uploaded. This is a subset of statuses
53-
5451
Returns:
5552
List of dicts containing error messages. Empty list means there were no errors
5653
See `AnnotationImport.statuses` for more details.
57-
5854
* This information will expire after 24 hours.
5955
"""
6056
self.wait_until_done()
@@ -64,38 +60,32 @@ def errors(self) -> List[Dict[str, Any]]:
6460
def statuses(self) -> List[Dict[str, Any]]:
6561
"""
6662
Status for each individual annotation uploaded.
67-
6863
Returns:
6964
A status for each annotation if the upload is done running.
7065
See below table for more details
71-
7266
.. list-table::
7367
:widths: 15 150
7468
:header-rows: 1
75-
7669
* - Field
7770
- Description
7871
* - uuid
7972
- Specifies the annotation for the status row.
8073
* - dataRow
8174
- JSON object containing the Labelbox data row ID for the annotation.
8275
* - status
83-
- Indicates SUCCESS or FAILURE.
76+
- Indicates SUCCESS or FAILURE.
8477
* - errors
85-
- An array of error messages included when status is FAILURE. Each error has a name, message and optional (key might not exist) additional_info.
86-
78+
- An array of error messages included when status is FAILURE. Each error has a name, message and optional (key might not exist) additional_info.
8779
* This information will expire after 24 hours.
8880
"""
8981
self.wait_until_done()
9082
return self._fetch_remote_ndjson(self.status_file_url)
9183

9284
def wait_until_done(self, sleep_time_seconds: int = 10) -> None:
9385
"""Blocks import job until certain conditions are met.
94-
9586
Blocks until the AnnotationImport.state changes either to
9687
`AnnotationImportState.FINISHED` or `AnnotationImportState.FAILED`,
9788
periodically refreshing object's state.
98-
9989
Args:
10090
sleep_time_seconds (str): a time to block between subsequent API calls
10191
"""
@@ -117,7 +107,6 @@ def __backoff_refresh(self) -> None:
117107
def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
118108
"""
119109
Fetches the remote ndjson file and caches the results.
120-
121110
Args:
122111
url (str): Can be any url pointing to an ndjson file.
123112
Returns:
@@ -132,6 +121,7 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
132121

133122
@classmethod
134123
def _build_import_predictions_query(cls, file_args: str, vars: str):
124+
cls.validate_cls()
135125
query_str = """mutation createAnnotationImportPyApi($parent_id : ID!, $name: String!, $predictionType : PredictionType!, %s) {
136126
createAnnotationImport(data: {
137127
%s : $parent_id
@@ -148,9 +138,17 @@ def _build_import_predictions_query(cls, file_args: str, vars: str):
148138
return query_str
149139

150140
@classmethod
151-
def _from_name(cls, client, parent_id, name: str, raw=False):
152-
query_str = """query
153-
getImportPyApi($parent_id : ID!, $name: String!) {
141+
def validate_cls(cls):
142+
supported_base_classes = {MALPredictionImport, MEAPredictionImport}
143+
if cls not in {MALPredictionImport, MEAPredictionImport}:
144+
raise TypeError(
145+
f"Can't directly use the base AnnotationImport class. Must use one of {supported_base_classes}"
146+
)
147+
148+
@classmethod
149+
def from_name(cls, client, parent_id, name: str, raw=False):
150+
cls.validate_cls()
151+
query_str = """query getImportPyApi($parent_id : ID!, $name: String!) {
154152
annotationImport(
155153
where: {%s: $parent_id, name: $name}){
156154
__typename
@@ -194,10 +192,10 @@ def refresh(self) -> None:
194192
"""Synchronizes values of all fields with the database.
195193
"""
196194
cls = type(self)
197-
res = cls._from_name(self.client,
198-
self.get_parent_id(),
199-
self.name,
200-
raw=True)
195+
res = cls.from_name(self.client,
196+
self.get_parent_id(),
197+
self.name,
198+
raw=True)
201199
self._set_field_values(res)
202200

203201
@classmethod

labelbox/schema/model_run.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ class ModelRun(DbObject):
1616
created_by_id = Field.String("created_by_id", "createdBy")
1717
model_id = Field.String("model_id")
1818

19+
def __init__(self, client, field_values):
20+
field_values['createdBy'] = uuid_to_cuid(field_values['createdBy'])
21+
super().__init__(client, field_values)
22+
1923
def upsert_labels(self, label_ids):
2024

2125
if len(label_ids) < 1:
@@ -32,38 +36,34 @@ def upsert_labels(self, label_ids):
3236
return True
3337

3438
def add_predictions(
35-
self,
36-
name: str,
37-
annotations: Union[str, Path, Iterable[Dict]],
38-
validate: bool = True) -> 'MEAPredictionImport': # type: ignore
39-
""" Uploads annotations to a new Editor project.
39+
self,
40+
name: str,
41+
predictions: Union[str, Path, Iterable[Dict]],
42+
) -> 'MEAPredictionImport': # type: ignore
43+
""" Uploads predictions to a new Editor project.
4044
Args:
4145
name (str): name of the AnnotationImport job
42-
annotations (str or Path or Iterable):
46+
predictions (str or Path or Iterable):
4347
url that is publicly accessible by Labelbox containing an
4448
ndjson file
4549
OR local path to an ndjson file
4650
OR iterable of annotation rows
47-
validate (bool):
48-
Whether or not to validate the payload before uploading.
4951
Returns:
5052
AnnotationImport
5153
"""
52-
kwargs = dict(client=self.client,
53-
model_run_id=self.uid,
54-
name=name,
55-
predictions=annotations)
56-
if isinstance(annotations, str) or isinstance(annotations, Path):
57-
return MEAPredictionImport.create_from_file(**kwargs)
58-
elif isinstance(annotations, Iterable):
59-
return MEAPredictionImport.create_from_objects(**kwargs)
54+
kwargs = dict(client=self.client, model_run_id=self.uid, name=name)
55+
if isinstance(predictions, str) or isinstance(predictions, Path):
56+
return MEAPredictionImport.create_from_file(path=predictions,
57+
**kwargs)
58+
elif isinstance(predictions, Iterable):
59+
return MEAPredictionImport.create_from_objects(
60+
predictions=predictions, **kwargs)
6061
else:
6162
raise ValueError(
62-
f'Invalid annotations given of type: {type(annotations)}')
63+
f'Invalid predictions given of type: {type(predictions)}')
6364

6465
def annotation_groups(self):
65-
query_str = """
66-
query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){
66+
query_str = """query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){
6767
annotationGroups(where: {modelRunId: {id: $modelRunId}}, after: $from, first: $first)
6868
{nodes{%s},pageInfo{endCursor}}
6969
}
@@ -80,9 +80,8 @@ class AnnotationGroup(DbObject):
8080
model_run_id = Field.String("model_run_id")
8181
data_row = Relationship.ToOne("DataRow", False, cache=True)
8282

83-
def __init__(self, client, model_id, field_values):
84-
field_values['labelId'] = uuid_to_cuid(field_values['labelId'])
85-
super().__init__(client, field_values)
83+
def __init__(self, client, model_id, *args, **kwargs):
84+
super().__init__(client, *args, **kwargs)
8685
self.model_id = model_id
8786

8887
@property

tests/integration/bulk_import/conftest.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import uuid
22
import pytest
3+
import time
34

45
from labelbox.schema.labeling_frontend import LabelingFrontend
6+
from labelbox.schema.annotation_import import MALPredictionImport
57

68
IMG_URL = "https://picsum.photos/200/300"
79

@@ -100,7 +102,9 @@ def ontology():
100102

101103

102104
@pytest.fixture
103-
def configured_project(client, project, ontology, dataset):
105+
def configured_project(client, ontology, rand_gen):
106+
project = client.create_project(name=rand_gen(str))
107+
dataset = client.create_dataset(name=rand_gen(str), projects=project)
104108
editor = list(
105109
client.get_labeling_frontends(
106110
where=LabelingFrontend.name == "editor"))[0]
@@ -109,13 +113,14 @@ def configured_project(client, project, ontology, dataset):
109113
dataset.create_data_row(row_data=IMG_URL)
110114
project.datasets.connect(dataset)
111115
yield project
116+
project.delete()
117+
dataset.delete()
112118

113119

114120
@pytest.fixture
115121
def prediction_id_mapping(configured_project):
116122
#Maps tool types to feature schema ids
117123
ontology = configured_project.ontology().normalized
118-
inferences = []
119124
datarows = [d for d in list(configured_project.datasets())[0].data_rows()]
120125
result = {}
121126

@@ -266,11 +271,54 @@ def video_checklist_inference(prediction_id_mapping):
266271

267272

268273
@pytest.fixture
269-
def predictions(polygon_inference, rectangle_inference, line_inference,
270-
entity_inference, segmentation_inference, checklist_inference,
271-
text_inference):
274+
def model_run_predictions(polygon_inference, rectangle_inference,
275+
line_inference):
276+
# Not supporting mask since there isn't a signed url representing a seg mask to upload
277+
return [polygon_inference, rectangle_inference, line_inference]
278+
279+
280+
@pytest.fixture
281+
def object_predictions(polygon_inference, rectangle_inference, line_inference,
282+
entity_inference, segmentation_inference):
272283
return [
273284
polygon_inference, rectangle_inference, line_inference,
274-
entity_inference, segmentation_inference, checklist_inference,
275-
text_inference
285+
entity_inference, segmentation_inference
276286
]
287+
288+
289+
@pytest.fixture
290+
def classification_predictions(checklist_inference, text_inference):
291+
return [checklist_inference, text_inference]
292+
293+
294+
@pytest.fixture
295+
def predictions(object_predictions, classification_predictions):
296+
return object_predictions + classification_predictions
297+
298+
299+
@pytest.fixture
300+
def model_run(client, rand_gen, configured_project, annotation_submit_fn,
301+
model_run_predictions):
302+
configured_project.enable_model_assisted_labeling()
303+
ontology = configured_project.ontology()
304+
305+
upload_task = MALPredictionImport.create_from_objects(
306+
client, configured_project.uid, f'mal-import-{uuid.uuid4()}',
307+
model_run_predictions)
308+
upload_task.wait_until_done()
309+
310+
for data_row_id in {x['dataRow']['id'] for x in model_run_predictions}:
311+
annotation_submit_fn(configured_project.uid, data_row_id)
312+
313+
data = {"name": rand_gen(str), "ontology_id": ontology.uid}
314+
model = client.create_model(data["name"], data["ontology_id"])
315+
name = rand_gen(str)
316+
model_run_s = model.create_model_run(name)
317+
318+
time.sleep(3)
319+
model_run_s.upsert_labels(
320+
[label.uid for label in configured_project.labels()])
321+
time.sleep(3)
322+
323+
yield model_run_s
324+
# TODO: Delete resources when that is possible ..
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import uuid
2+
import ndjson
3+
import pytest
4+
import requests
5+
6+
from labelbox.schema.annotation_import import AnnotationImportState, MEAPredictionImport
7+
"""
8+
- Here we only want to check that the uploads are calling the validation
9+
- Then with unit tests we can check the types of errors raised
10+
11+
"""
12+
13+
14+
def check_running_state(req, name, url=None):
15+
assert req.name == name
16+
if url is not None:
17+
assert req.input_file_url == url
18+
assert req.error_file_url is None
19+
assert req.status_file_url is None
20+
assert req.state == AnnotationImportState.RUNNING
21+
22+
23+
def test_create_from_url(model_run):
24+
name = str(uuid.uuid4())
25+
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
26+
annotation_import = model_run.add_predictions(name=name, predictions=url)
27+
assert annotation_import.model_run_id == model_run.uid
28+
check_running_state(annotation_import, name, url)
29+
30+
31+
def test_create_from_objects(model_run, object_predictions):
32+
name = str(uuid.uuid4())
33+
34+
annotation_import = model_run.add_predictions(
35+
name=name, predictions=object_predictions)
36+
37+
assert annotation_import.model_run_id == model_run.uid
38+
check_running_state(annotation_import, name)
39+
assert_file_content(annotation_import.input_file_url, object_predictions)
40+
41+
42+
def test_create_from_local_file(tmp_path, model_run, object_predictions):
43+
name = str(uuid.uuid4())
44+
file_name = f"{name}.ndjson"
45+
file_path = tmp_path / file_name
46+
with file_path.open("w") as f:
47+
ndjson.dump(object_predictions, f)
48+
49+
annotation_import = model_run.add_predictions(name=name,
50+
predictions=str(file_path))
51+
52+
assert annotation_import.model_run_id == model_run.uid
53+
check_running_state(annotation_import, name)
54+
assert_file_content(annotation_import.input_file_url, object_predictions)
55+
56+
57+
def test_get(client, model_run):
58+
name = str(uuid.uuid4())
59+
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
60+
model_run.add_predictions(name=name, predictions=url)
61+
62+
annotation_import = MEAPredictionImport.from_name(client,
63+
parent_id=model_run.uid,
64+
name=name)
65+
66+
assert annotation_import.model_run_id == model_run.uid
67+
check_running_state(annotation_import, name, url)
68+
69+
70+
@pytest.mark.slow
71+
def test_wait_till_done(model_run_predictions, model_run):
72+
name = str(uuid.uuid4())
73+
annotation_import = model_run.add_predictions(
74+
name=name, predictions=model_run_predictions)
75+
76+
assert len(annotation_import.inputs) == len(model_run_predictions)
77+
annotation_import.wait_until_done()
78+
assert annotation_import.state == AnnotationImportState.FINISHED
79+
# Check that the status files are being returned as expected
80+
assert len(annotation_import.errors) == 0
81+
assert len(annotation_import.inputs) == len(model_run_predictions)
82+
input_uuids = [
83+
input_annot['uuid'] for input_annot in annotation_import.inputs
84+
]
85+
inference_uuids = [pred['uuid'] for pred in model_run_predictions]
86+
assert set(input_uuids) == set(inference_uuids)
87+
assert len(annotation_import.statuses) == len(model_run_predictions)
88+
for status in annotation_import.statuses:
89+
assert status['status'] == 'SUCCESS'
90+
status_uuids = [
91+
input_annot['uuid'] for input_annot in annotation_import.statuses
92+
]
93+
assert set(input_uuids) == set(status_uuids)
94+
95+
96+
def assert_file_content(url: str, predictions):
97+
response = requests.get(url)
98+
assert response.text == ndjson.dumps(predictions)

0 commit comments

Comments
 (0)