Skip to content

Commit 14d3971

Browse files
Move fixtures & test cases
1 parent 49b03c1 commit 14d3971

File tree

8 files changed

+86
-70
lines changed

8 files changed

+86
-70
lines changed

tests/integration/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def get_project_invites(client, project_id):
8585
query_str, {id_param: project_id},
8686
['project', 'invites', 'nodes'],
8787
Invite,
88-
cursor_path=['project', 'invites', 'nextCursor'],
88+
cursor_path=['project',
89+
'invites', 'nextCursor'],
8990
experimental=True)
9091

9192

@@ -134,7 +135,9 @@ def client(environ: str):
134135

135136

136137
@pytest.fixture(scope="session")
137-
def image_url(client):
138+
def image_url(client, environ: str):
139+
if environ == Environ.LOCAL:
140+
return IMG_URL
138141
return client.upload_data(requests.get(IMG_URL).content, sign=True)
139142

140143

tests/integration/bulk_import/conftest.py renamed to tests/integration/mal_and_mea/conftest.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def configured_project(client, ontology, rand_gen, image_url):
120120

121121
@pytest.fixture
122122
def prediction_id_mapping(configured_project):
123-
#Maps tool types to feature schema ids
123+
# Maps tool types to feature schema ids
124124
ontology = configured_project.ontology().normalized
125125
result = {}
126126

@@ -179,7 +179,7 @@ def rectangle_inference(prediction_id_mapping):
179179
"schemaId":
180180
rectangle['tool']['classifications'][0]['options'][0]
181181
['featureSchemaId']
182-
}
182+
}
183183
}]
184184
})
185185
del rectangle['tool']
@@ -297,11 +297,23 @@ def predictions(object_predictions, classification_predictions):
297297

298298

299299
@pytest.fixture
300-
def model_run(client, rand_gen, configured_project, annotation_submit_fn,
301-
model_run_predictions):
302-
configured_project.enable_model_assisted_labeling()
300+
def model(client, rand_gen, configured_project):
303301
ontology = configured_project.ontology()
304302

303+
data = {"name": rand_gen(str), "ontology_id": ontology.uid}
304+
return client.create_model(data["name"], data["ontology_id"])
305+
306+
307+
@pytest.fixture
308+
def model_run(rand_gen, model):
309+
name = rand_gen(str)
310+
return model.create_model_run(name)
311+
312+
313+
@pytest.fixture
314+
def model_run_annotation_groups(client, configured_project, annotation_submit_fn, model_run_predictions, model_run):
315+
configured_project.enable_model_assisted_labeling()
316+
305317
upload_task = MALPredictionImport.create_from_objects(
306318
client, configured_project.uid, f'mal-import-{uuid.uuid4()}',
307319
model_run_predictions)
@@ -310,15 +322,10 @@ def model_run(client, rand_gen, configured_project, annotation_submit_fn,
310322
for data_row_id in {x['dataRow']['id'] for x in model_run_predictions}:
311323
annotation_submit_fn(configured_project.uid, data_row_id)
312324

313-
data = {"name": rand_gen(str), "ontology_id": ontology.uid}
314-
model = client.create_model(data["name"], data["ontology_id"])
315-
name = rand_gen(str)
316-
model_run_s = model.create_model_run(name)
317-
318325
time.sleep(3)
319326
labels = configured_project.export_labels(download=True)
320-
model_run_s.upsert_labels([label['ID'] for label in labels])
327+
model_run.upsert_labels([label['ID'] for label in labels])
321328
time.sleep(3)
322329

323-
yield model_run_s
330+
yield model_run
324331
# TODO: Delete resources when that is possible ..

tests/integration/bulk_import/test_mea_annotation_import.py renamed to tests/integration/mal_and_mea/test_mea_annotation_import.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,56 +20,57 @@ def check_running_state(req, name, url=None):
2020
assert req.state == AnnotationImportState.RUNNING
2121

2222

23-
def test_create_from_url(model_run):
23+
def test_create_from_url(model_run_annotation_groups):
2424
name = str(uuid.uuid4())
2525
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
26-
annotation_import = model_run.add_predictions(name=name, predictions=url)
27-
assert annotation_import.model_run_id == model_run.uid
26+
annotation_import = model_run_annotation_groups.add_predictions(
27+
name=name, predictions=url)
28+
assert annotation_import.model_run_id == model_run_annotation_groups.uid
2829
check_running_state(annotation_import, name, url)
2930

3031

31-
def test_create_from_objects(model_run, object_predictions):
32+
def test_create_from_objects(model_run_annotation_groups, object_predictions):
3233
name = str(uuid.uuid4())
3334

34-
annotation_import = model_run.add_predictions(
35+
annotation_import = model_run_annotation_groups.add_predictions(
3536
name=name, predictions=object_predictions)
3637

37-
assert annotation_import.model_run_id == model_run.uid
38+
assert annotation_import.model_run_id == model_run_annotation_groups.uid
3839
check_running_state(annotation_import, name)
3940
assert_file_content(annotation_import.input_file_url, object_predictions)
4041

4142

42-
def test_create_from_local_file(tmp_path, model_run, object_predictions):
43+
def test_create_from_local_file(tmp_path, model_run_annotation_groups, object_predictions):
4344
name = str(uuid.uuid4())
4445
file_name = f"{name}.ndjson"
4546
file_path = tmp_path / file_name
4647
with file_path.open("w") as f:
4748
ndjson.dump(object_predictions, f)
4849

49-
annotation_import = model_run.add_predictions(name=name,
50-
predictions=str(file_path))
50+
annotation_import = model_run_annotation_groups.add_predictions(name=name,
51+
predictions=str(file_path))
5152

52-
assert annotation_import.model_run_id == model_run.uid
53+
assert annotation_import.model_run_id == model_run_annotation_groups.uid
5354
check_running_state(annotation_import, name)
5455
assert_file_content(annotation_import.input_file_url, object_predictions)
5556

5657

57-
def test_get(client, model_run):
58+
def test_get(client, model_run_annotation_groups):
5859
name = str(uuid.uuid4())
5960
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
60-
model_run.add_predictions(name=name, predictions=url)
61+
model_run_annotation_groups.add_predictions(name=name, predictions=url)
6162

6263
annotation_import = MEAPredictionImport.from_name(
63-
client, model_run_id=model_run.uid, name=name)
64+
client, model_run_id=model_run_annotation_groups.uid, name=name)
6465

65-
assert annotation_import.model_run_id == model_run.uid
66+
assert annotation_import.model_run_id == model_run_annotation_groups.uid
6667
check_running_state(annotation_import, name, url)
6768

6869

6970
@pytest.mark.slow
70-
def test_wait_till_done(model_run_predictions, model_run):
71+
def test_wait_till_done(model_run_predictions, model_run_annotation_groups):
7172
name = str(uuid.uuid4())
72-
annotation_import = model_run.add_predictions(
73+
annotation_import = model_run_annotation_groups.add_predictions(
7374
name=name, predictions=model_run_predictions)
7475

7576
assert len(annotation_import.inputs) == len(model_run_predictions)

tests/integration/test_model.py renamed to tests/integration/mal_and_mea/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_model(client, configured_project, rand_gen):
2020
assert model.name == data["name"]
2121

2222

23-
def test_model_delete(client):
23+
def test_model_delete(client, model):
2424
before = list(client.get_models())
2525

2626
model = before[0]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import time
2+
3+
4+
def test_model_run(client, configured_project_with_label, rand_gen):
5+
project = configured_project_with_label
6+
ontology = project.ontology()
7+
data = {"name": rand_gen(str), "ontology_id": ontology.uid}
8+
model = client.create_model(data["name"], data["ontology_id"])
9+
10+
name = rand_gen(str)
11+
model_run = model.create_model_run(name)
12+
assert model_run.name == name
13+
assert model_run.model_id == model.uid
14+
assert model_run.created_by_id == client.get_user().uid
15+
16+
label = project.export_labels(download=True)[0]
17+
model_run.upsert_labels([label['ID']])
18+
time.sleep(3)
19+
20+
annotation_group = next(model_run.annotation_groups())
21+
assert annotation_group.label_id == label['ID']
22+
assert annotation_group.model_run_id == model_run.uid
23+
assert annotation_group.data_row().uid == next(
24+
next(project.datasets()).data_rows()).uid
25+
26+
27+
def test_model_run_delete(client, model_run):
28+
models_before = list(client.get_models())
29+
model_before = models_before[0]
30+
before = list(model_before.model_runs())
31+
32+
model_run = before[0]
33+
model_run.delete_model_run()
34+
35+
models_after = list(client.get_models())
36+
model_after = models_after[0]
37+
after = list(model_after.model_runs())
38+
39+
assert len(before) == len(after) + 1
40+
41+
42+
def test_model_run_delete(client, model_run_annotation_groups):
43+
# TODO
44+
pass

tests/integration/test_model_run.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

0 commit comments

Comments
 (0)