Skip to content

Commit f7fd369

Browse files
author
Matt Sokoloff
committed
add mea import tests
1 parent 82a00db commit f7fd369

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import uuid
2+
import ndjson
3+
import pytest
4+
import requests
5+
6+
from labelbox.schema.annotation_import import AnnotationImportState, MEAPredictionImport
7+
"""
8+
- Here we only want to check that the uploads are calling the validation
9+
- Then with unit tests we can check the types of errors raised
10+
11+
"""
12+
13+
14+
def check_running_state(req, name, url=None):
15+
assert req.name == name
16+
if url is not None:
17+
assert req.input_file_url == url
18+
assert req.error_file_url is None
19+
assert req.status_file_url is None
20+
assert req.state == AnnotationImportState.RUNNING
21+
22+
23+
def test_create_from_url(model_run):
24+
name = str(uuid.uuid4())
25+
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
26+
annotation_import = model_run.add_predictions(name=name, predictions=url)
27+
assert annotation_import.model_run_id == model_run.uid
28+
check_running_state(annotation_import, name, url)
29+
30+
31+
def test_create_from_objects(model_run, object_predictions):
32+
name = str(uuid.uuid4())
33+
34+
annotation_import = model_run.add_predictions(
35+
name=name, predictions=object_predictions)
36+
37+
assert annotation_import.model_run_id == model_run.uid
38+
check_running_state(annotation_import, name)
39+
assert_file_content(annotation_import.input_file_url, object_predictions)
40+
41+
42+
def test_create_from_local_file(tmp_path, model_run, object_predictions):
43+
name = str(uuid.uuid4())
44+
file_name = f"{name}.ndjson"
45+
file_path = tmp_path / file_name
46+
with file_path.open("w") as f:
47+
ndjson.dump(object_predictions, f)
48+
49+
annotation_import = model_run.add_predictions(name=name,
50+
predictions=str(file_path))
51+
52+
assert annotation_import.model_run_id == model_run.uid
53+
check_running_state(annotation_import, name)
54+
assert_file_content(annotation_import.input_file_url, object_predictions)
55+
56+
57+
def test_get(client, model_run):
58+
name = str(uuid.uuid4())
59+
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
60+
model_run.add_predictions(name=name, predictions=url)
61+
62+
annotation_import = MEAPredictionImport.from_name(client,
63+
parent_id=model_run.uid,
64+
name=name)
65+
66+
assert annotation_import.model_run_id == model_run.uid
67+
check_running_state(annotation_import, name, url)
68+
69+
70+
@pytest.mark.slow
71+
def test_wait_till_done(model_run_predictions, model_run):
72+
name = str(uuid.uuid4())
73+
annotation_import = model_run.add_predictions(
74+
name=name, predictions=model_run_predictions)
75+
76+
assert len(annotation_import.inputs) == len(model_run_predictions)
77+
annotation_import.wait_until_done()
78+
assert annotation_import.state == AnnotationImportState.FINISHED
79+
# Check that the status files are being returned as expected
80+
assert len(annotation_import.errors) == 0
81+
assert len(annotation_import.inputs) == len(model_run_predictions)
82+
input_uuids = [
83+
input_annot['uuid'] for input_annot in annotation_import.inputs
84+
]
85+
inference_uuids = [pred['uuid'] for pred in model_run_predictions]
86+
assert set(input_uuids) == set(inference_uuids)
87+
assert len(annotation_import.statuses) == len(model_run_predictions)
88+
for status in annotation_import.statuses:
89+
assert status['status'] == 'SUCCESS'
90+
status_uuids = [
91+
input_annot['uuid'] for input_annot in annotation_import.statuses
92+
]
93+
assert set(input_uuids) == set(status_uuids)
94+
95+
96+
def assert_file_content(url: str, predictions):
97+
response = requests.get(url)
98+
assert response.text == ndjson.dumps(predictions)

0 commit comments

Comments
 (0)