Skip to content

Commit 82a00db

Browse files
author
Matt Sokoloff
committed
add tests for MEA
1 parent faa0bd6 commit 82a00db

File tree

6 files changed

+171
-67
lines changed

6 files changed

+171
-67
lines changed

labelbox/schema/annotation_import.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ def inputs(self) -> List[Dict[str, Any]]:
3838
"""
3939
Inputs for each individual annotation uploaded.
4040
This should match the ndjson annotations that you have uploaded.
41-
4241
Returns:
4342
Uploaded ndjson.
44-
4543
* This information will expire after 24 hours.
4644
"""
4745
return self._fetch_remote_ndjson(self.input_file_url)
@@ -50,11 +48,9 @@ def inputs(self) -> List[Dict[str, Any]]:
5048
def errors(self) -> List[Dict[str, Any]]:
5149
"""
5250
Errors for each individual annotation uploaded. This is a subset of statuses
53-
5451
Returns:
5552
List of dicts containing error messages. Empty list means there were no errors
5653
See `AnnotationImport.statuses` for more details.
57-
5854
* This information will expire after 24 hours.
5955
"""
6056
self.wait_until_done()
@@ -64,38 +60,32 @@ def errors(self) -> List[Dict[str, Any]]:
6460
def statuses(self) -> List[Dict[str, Any]]:
6561
"""
6662
Status for each individual annotation uploaded.
67-
6863
Returns:
6964
A status for each annotation if the upload is done running.
7065
See below table for more details
71-
7266
.. list-table::
7367
:widths: 15 150
7468
:header-rows: 1
75-
7669
* - Field
7770
- Description
7871
* - uuid
7972
- Specifies the annotation for the status row.
8073
* - dataRow
8174
- JSON object containing the Labelbox data row ID for the annotation.
8275
* - status
83-
- Indicates SUCCESS or FAILURE.
76+
- Indicates SUCCESS or FAILURE.
8477
* - errors
85-
- An array of error messages included when status is FAILURE. Each error has a name, message and optional (key might not exist) additional_info.
86-
78+
- An array of error messages included when status is FAILURE. Each error has a name, message and optional (key might not exist) additional_info.
8779
* This information will expire after 24 hours.
8880
"""
8981
self.wait_until_done()
9082
return self._fetch_remote_ndjson(self.status_file_url)
9183

9284
def wait_until_done(self, sleep_time_seconds: int = 10) -> None:
9385
"""Blocks import job until certain conditions are met.
94-
9586
Blocks until the AnnotationImport.state changes either to
9687
`AnnotationImportState.FINISHED` or `AnnotationImportState.FAILED`,
9788
periodically refreshing object's state.
98-
9989
Args:
10090
sleep_time_seconds (str): a time to block between subsequent API calls
10191
"""
@@ -117,7 +107,6 @@ def __backoff_refresh(self) -> None:
117107
def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
118108
"""
119109
Fetches the remote ndjson file and caches the results.
120-
121110
Args:
122111
url (str): Can be any url pointing to an ndjson file.
123112
Returns:
@@ -132,6 +121,7 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
132121

133122
@classmethod
134123
def _build_import_predictions_query(cls, file_args: str, vars: str):
124+
cls.validate_cls()
135125
query_str = """mutation createAnnotationImportPyApi($parent_id : ID!, $name: String!, $predictionType : PredictionType!, %s) {
136126
createAnnotationImport(data: {
137127
%s : $parent_id
@@ -148,9 +138,17 @@ def _build_import_predictions_query(cls, file_args: str, vars: str):
148138
return query_str
149139

150140
@classmethod
151-
def _from_name(cls, client, parent_id, name: str, raw=False):
152-
query_str = """query
153-
getImportPyApi($parent_id : ID!, $name: String!) {
141+
def validate_cls(cls):
142+
supported_base_classes = {MALPredictionImport, MEAPredictionImport}
143+
if cls not in {MALPredictionImport, MEAPredictionImport}:
144+
raise TypeError(
145+
f"Can't directly use the base AnnotationImport class. Must use one of {supported_base_classes}"
146+
)
147+
148+
@classmethod
149+
def from_name(cls, client, parent_id, name: str, raw=False):
150+
cls.validate_cls()
151+
query_str = """query getImportPyApi($parent_id : ID!, $name: String!) {
154152
annotationImport(
155153
where: {%s: $parent_id, name: $name}){
156154
__typename
@@ -194,10 +192,10 @@ def refresh(self) -> None:
194192
"""Synchronizes values of all fields with the database.
195193
"""
196194
cls = type(self)
197-
res = cls._from_name(self.client,
198-
self.get_parent_id(),
199-
self.name,
200-
raw=True)
195+
res = cls.from_name(self.client,
196+
self.get_parent_id(),
197+
self.name,
198+
raw=True)
201199
self._set_field_values(res)
202200

203201
@classmethod

labelbox/schema/model_run.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ class ModelRun(DbObject):
1616
created_by_id = Field.String("created_by_id", "createdBy")
1717
model_id = Field.String("model_id")
1818

19+
def __init__(self, client, field_values):
20+
field_values['createdBy'] = uuid_to_cuid(field_values['createdBy'])
21+
super().__init__(client, field_values)
22+
1923
def upsert_labels(self, label_ids):
2024

2125
if len(label_ids) < 1:
@@ -32,38 +36,34 @@ def upsert_labels(self, label_ids):
3236
return True
3337

3438
def add_predictions(
35-
self,
36-
name: str,
37-
annotations: Union[str, Path, Iterable[Dict]],
38-
validate: bool = True) -> 'MEAPredictionImport': # type: ignore
39-
""" Uploads annotations to a new Editor project.
39+
self,
40+
name: str,
41+
predictions: Union[str, Path, Iterable[Dict]],
42+
) -> 'MEAPredictionImport': # type: ignore
43+
""" Uploads predictions to a new Editor project.
4044
Args:
4145
name (str): name of the AnnotationImport job
42-
annotations (str or Path or Iterable):
46+
predictions (str or Path or Iterable):
4347
url that is publicly accessible by Labelbox containing an
4448
ndjson file
4549
OR local path to an ndjson file
4650
OR iterable of annotation rows
47-
validate (bool):
48-
Whether or not to validate the payload before uploading.
4951
Returns:
5052
AnnotationImport
5153
"""
52-
kwargs = dict(client=self.client,
53-
model_run_id=self.uid,
54-
name=name,
55-
predictions=annotations)
56-
if isinstance(annotations, str) or isinstance(annotations, Path):
57-
return MEAPredictionImport.create_from_file(**kwargs)
58-
elif isinstance(annotations, Iterable):
59-
return MEAPredictionImport.create_from_objects(**kwargs)
54+
kwargs = dict(client=self.client, model_run_id=self.uid, name=name)
55+
if isinstance(predictions, str) or isinstance(predictions, Path):
56+
return MEAPredictionImport.create_from_file(path=predictions,
57+
**kwargs)
58+
elif isinstance(predictions, Iterable):
59+
return MEAPredictionImport.create_from_objects(
60+
predictions=predictions, **kwargs)
6061
else:
6162
raise ValueError(
62-
f'Invalid annotations given of type: {type(annotations)}')
63+
f'Invalid predictions given of type: {type(predictions)}')
6364

6465
def annotation_groups(self):
65-
query_str = """
66-
query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){
66+
query_str = """query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){
6767
annotationGroups(where: {modelRunId: {id: $modelRunId}}, after: $from, first: $first)
6868
{nodes{%s},pageInfo{endCursor}}
6969
}
@@ -80,9 +80,8 @@ class AnnotationGroup(DbObject):
8080
model_run_id = Field.String("model_run_id")
8181
data_row = Relationship.ToOne("DataRow", False, cache=True)
8282

83-
def __init__(self, client, model_id, field_values):
84-
field_values['labelId'] = uuid_to_cuid(field_values['labelId'])
85-
super().__init__(client, field_values)
83+
def __init__(self, client, model_id, *args, **kwargs):
84+
super().__init__(client, *args, **kwargs)
8685
self.model_id = model_id
8786

8887
@property

tests/integration/bulk_import/conftest.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import uuid
22
import pytest
3+
import time
34

45
from labelbox.schema.labeling_frontend import LabelingFrontend
6+
from labelbox.schema.annotation_import import MALPredictionImport
57

68
IMG_URL = "https://picsum.photos/200/300"
79

@@ -100,7 +102,9 @@ def ontology():
100102

101103

102104
@pytest.fixture
103-
def configured_project(client, project, ontology, dataset):
105+
def configured_project(client, ontology, rand_gen):
106+
project = client.create_project(name=rand_gen(str))
107+
dataset = client.create_dataset(name=rand_gen(str), projects=project)
104108
editor = list(
105109
client.get_labeling_frontends(
106110
where=LabelingFrontend.name == "editor"))[0]
@@ -109,13 +113,14 @@ def configured_project(client, project, ontology, dataset):
109113
dataset.create_data_row(row_data=IMG_URL)
110114
project.datasets.connect(dataset)
111115
yield project
116+
project.delete()
117+
dataset.delete()
112118

113119

114120
@pytest.fixture
115121
def prediction_id_mapping(configured_project):
116122
#Maps tool types to feature schema ids
117123
ontology = configured_project.ontology().normalized
118-
inferences = []
119124
datarows = [d for d in list(configured_project.datasets())[0].data_rows()]
120125
result = {}
121126

@@ -266,11 +271,54 @@ def video_checklist_inference(prediction_id_mapping):
266271

267272

268273
@pytest.fixture
269-
def predictions(polygon_inference, rectangle_inference, line_inference,
270-
entity_inference, segmentation_inference, checklist_inference,
271-
text_inference):
274+
def model_run_predictions(polygon_inference, rectangle_inference,
275+
line_inference):
276+
# Not supporting mask since there isn't a signed url representing a seg mask to upload
277+
return [polygon_inference, rectangle_inference, line_inference]
278+
279+
280+
@pytest.fixture
281+
def object_predictions(polygon_inference, rectangle_inference, line_inference,
282+
entity_inference, segmentation_inference):
272283
return [
273284
polygon_inference, rectangle_inference, line_inference,
274-
entity_inference, segmentation_inference, checklist_inference,
275-
text_inference
285+
entity_inference, segmentation_inference
276286
]
287+
288+
289+
@pytest.fixture
290+
def classification_predictions(checklist_inference, text_inference):
291+
return [checklist_inference, text_inference]
292+
293+
294+
@pytest.fixture
295+
def predictions(object_predictions, classification_predictions):
296+
return object_predictions + classification_predictions
297+
298+
299+
@pytest.fixture
300+
def model_run(client, rand_gen, configured_project, annotation_submit_fn,
301+
model_run_predictions):
302+
configured_project.enable_model_assisted_labeling()
303+
ontology = configured_project.ontology()
304+
305+
upload_task = MALPredictionImport.create_from_objects(
306+
client, configured_project.uid, f'mal-import-{uuid.uuid4()}',
307+
model_run_predictions)
308+
upload_task.wait_until_done()
309+
310+
for data_row_id in {x['dataRow']['id'] for x in model_run_predictions}:
311+
annotation_submit_fn(configured_project.uid, data_row_id)
312+
313+
data = {"name": rand_gen(str), "ontology_id": ontology.uid}
314+
model = client.create_model(data["name"], data["ontology_id"])
315+
name = rand_gen(str)
316+
model_run_s = model.create_model_run(name)
317+
318+
time.sleep(3)
319+
model_run_s.upsert_labels(
320+
[label.uid for label in configured_project.labels()])
321+
time.sleep(3)
322+
323+
yield model_run_s
324+
# TODO: Delete resources when that is possible ..

tests/integration/conftest.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from random import randint
55
from string import ascii_letters
66
from types import SimpleNamespace
7+
import uuid
78
import os
89
import re
910

@@ -16,6 +17,9 @@
1617
from labelbox import LabelingFrontend
1718
from labelbox import Client
1819

20+
from labelbox.schema.ontology import OntologyBuilder, Tool
21+
from labelbox import LabelingFrontend, MALPredictionImport
22+
1923
IMG_URL = "https://picsum.photos/200/300"
2024

2125

@@ -110,8 +114,8 @@ def __init__(self, environ: str) -> None:
110114

111115
self.queries = []
112116

113-
def execute(self, query, params=None, check_naming=True, **kwargs):
114-
if check_naming:
117+
def execute(self, query=None, params=None, check_naming=True, **kwargs):
118+
if check_naming and query is not None:
115119
assert re.match(r"(?:query|mutation) \w+PyApi", query) is not None
116120
self.queries.append((query, params))
117121
return super().execute(query, params, **kwargs)
@@ -241,3 +245,68 @@ def configured_project(project, client, rand_gen):
241245
yield project
242246
dataset.delete()
243247
project.delete()
248+
249+
250+
@pytest.fixture
251+
def annotation_submit_fn(client):
252+
253+
def submit(project_id, data_row_id):
254+
feature_result = client.execute(
255+
"""query featuresPyApi ($project_id : ID!, $datarow_id: ID!
256+
) {project(where: { id: $project_id }) {
257+
featuresForDataRow(where: {dataRow: { id: $datarow_id }}) {id}}}
258+
""", {
259+
"project_id": project_id,
260+
"datarow_id": data_row_id
261+
})
262+
features = feature_result['project']['featuresForDataRow']
263+
feature_ids = [feature['id'] for feature in features]
264+
client.execute(
265+
"""mutation createLabelPyApi ($project_id : ID!,$datarow_id: ID!,$feature_ids: [ID!]!,$time_seconds : Float!) {
266+
createLabelFromFeatures(data: {dataRow: { id: $datarow_id },project: { id: $project_id },
267+
featureIds: $feature_ids,secondsSpent: $time_seconds}) {id}}""",
268+
{
269+
"project_id": project_id,
270+
"datarow_id": data_row_id,
271+
"feature_ids": feature_ids,
272+
"time_seconds": 10
273+
})
274+
275+
return submit
276+
277+
278+
@pytest.fixture
279+
def configured_project_with_label(client, rand_gen, annotation_submit_fn):
280+
project = client.create_project(name=rand_gen(str))
281+
dataset = client.create_dataset(name=rand_gen(str), projects=project)
282+
data_row = dataset.create_data_row(row_data=IMG_URL)
283+
editor = list(
284+
project.client.get_labeling_frontends(
285+
where=LabelingFrontend.name == "editor"))[0]
286+
287+
ontology_builder = OntologyBuilder(tools=[
288+
Tool(tool=Tool.Type.BBOX, name="test-bbox-class"),
289+
])
290+
project.setup(editor, ontology_builder.asdict())
291+
project.enable_model_assisted_labeling()
292+
ontology = ontology_builder.from_project(project)
293+
predictions = [{
294+
"uuid": str(uuid.uuid4()),
295+
"schemaId": ontology.tools[0].feature_schema_id,
296+
"dataRow": {
297+
"id": data_row.uid
298+
},
299+
"bbox": {
300+
"top": 20,
301+
"left": 20,
302+
"height": 50,
303+
"width": 50
304+
}
305+
}]
306+
upload_task = MALPredictionImport.create_from_objects(
307+
client, project.uid, f'mal-import-{uuid.uuid4()}', predictions)
308+
upload_task.wait_until_done()
309+
annotation_submit_fn(project.uid, data_row.uid)
310+
yield project
311+
dataset.delete()
312+
project.delete()

0 commit comments

Comments
 (0)