Skip to content

Commit c6fe4ae

Browse files
authored
Clean up deprecation warnings in test suite and add parallel workers to CircleCI (#177)
1 parent 5f9799e commit c6fe4ae

File tree

11 files changed

+71
-38
lines changed

11 files changed

+71
-38
lines changed

.circleci/config.yml

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ jobs:
99
docker:
1010
- image: python:3.6-buster
1111
resource_class: small
12+
parallelism: 6
1213
steps:
1314
- checkout # checkout source code to working directory
1415
- run:
@@ -43,7 +44,9 @@ jobs:
4344
name: Pytest Test Cases
4445
command: | # Run test suite, uses NUCLEUS_TEST_API_KEY env variable
4546
mkdir test_results
46-
poetry run coverage run --include=nucleus/* -m pytest -s -v --junitxml=test_results/junit.xml
47+
set -e
48+
TEST_FILES=$(circleci tests glob "tests/**/test_*.py" | circleci tests split --split-by=timings)
49+
poetry run coverage run --include=nucleus/* -m pytest -s -v --junitxml=test_results/junit.xml $TEST_FILES
4750
poetry run coverage report
4851
poetry run coverage html
4952
- store_test_results:

nucleus/__init__.py

Lines changed: 26 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -683,11 +683,16 @@ def dataitem_ref_id(self, dataset_id: str, reference_id: str):
683683

684684
@deprecated("Prefer calling Dataset.predictions_refloc instead.")
685685
@sanitize_string_args
686-
def predictions_ref_id(self, model_run_id: str, ref_id: str):
686+
def predictions_ref_id(
687+
self, model_run_id: str, ref_id: str, dataset_id: Optional[str] = None
688+
):
689+
if dataset_id:
690+
raise RuntimeError(
691+
"Need to pass a dataset id. Or use Dataset.predictions_refloc."
692+
)
687693
# TODO: deprecate ModelRun
688-
return self.make_request(
689-
{}, f"modelRun/{model_run_id}/refloc/{ref_id}", requests.get
690-
)
694+
m_run = self.get_model_run(model_run_id, dataset_id)
695+
return m_run.refloc(ref_id)
691696

692697
@deprecated("Prefer calling Dataset.iloc instead.")
693698
def dataitem_iloc(self, dataset_id: str, i: int):
@@ -720,10 +725,8 @@ def predictions_loc(self, model_run_id: str, dataset_item_id: str):
720725
@deprecated("Prefer calling Dataset.create_slice instead.")
721726
def create_slice(self, dataset_id: str, payload: dict) -> Slice:
722727
# TODO: deprecate in favor of Dataset.create_slice
723-
response = self.make_request(
724-
payload, f"dataset/{dataset_id}/create_slice"
725-
)
726-
return Slice(response[SLICE_ID_KEY], self)
728+
dataset = self.get_dataset(dataset_id)
729+
return dataset.create_slice(payload["name"], payload["reference_ids"])
727730

728731
def get_slice(self, slice_id: str) -> Slice:
729732
# TODO: migrate to Dataset method and deprecate
@@ -839,13 +842,9 @@ def create_custom_index(
839842
self, dataset_id: str, embeddings_urls: list, embedding_dim: int
840843
):
841844
# TODO: deprecate in favor of Dataset.create_custom_index invocation
842-
return self.make_request(
843-
{
844-
EMBEDDINGS_URL_KEY: embeddings_urls,
845-
EMBEDDING_DIMENSION_KEY: embedding_dim,
846-
},
847-
f"indexing/{dataset_id}",
848-
requests_command=requests.post,
845+
dataset = self.get_dataset(dataset_id)
846+
return dataset.create_custom_index(
847+
embeddings_urls=embeddings_urls, embedding_dim=embedding_dim
849848
)
850849

851850
@deprecated("Prefer calling Dataset.delete_custom_index instead.")
@@ -891,6 +890,18 @@ def create_object_index(
891890
requests_command=requests.post,
892891
)
893892

893+
def delete(self, route: str):
894+
return self._connection.delete(route)
895+
896+
def get(self, route: str):
897+
return self._connection.get(route)
898+
899+
def post(self, payload: dict, route: str):
900+
return self._connection.post(payload, route)
901+
902+
def put(self, payload: dict, route: str):
903+
return self._connection.put(payload, route)
904+
894905
# TODO: Fix return type, can be a list as well. Brings on a lot of mypy errors ...
895906
def make_request(
896907
self,

nucleus/dataset.py

Lines changed: 26 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -37,12 +37,15 @@
3737
DATASET_ID_KEY,
3838
DATASET_IS_SCENE_KEY,
3939
DEFAULT_ANNOTATION_UPDATE_MODE,
40+
EMBEDDING_DIMENSION_KEY,
41+
EMBEDDINGS_URL_KEY,
4042
EXPORTED_ROWS,
4143
KEEP_HISTORY_KEY,
4244
MESSAGE_KEY,
4345
NAME_KEY,
4446
REFERENCE_IDS_KEY,
4547
REQUEST_ID_KEY,
48+
SLICE_ID_KEY,
4649
UPDATE_KEY,
4750
)
4851
from .data_transfer_object.dataset_info import DatasetInfo
@@ -61,6 +64,7 @@
6164
construct_taxonomy_payload,
6265
)
6366
from .scene import LidarScene, Scene, check_all_scene_paths_remote
67+
from .slice import Slice
6468
from .upload_response import UploadResponse
6569

6670
# TODO: refactor to reduce this file to under 1000 lines.
@@ -502,7 +506,7 @@ def append(
502506
asynchronous
503507
), "In order to avoid timeouts, you must set asynchronous=True when uploading scenes."
504508

505-
return self.append_scenes(scenes, update, asynchronous)
509+
return self._append_scenes(scenes, update, asynchronous)
506510

507511
check_for_duplicate_reference_ids(dataset_items)
508512

@@ -537,6 +541,14 @@ def append_scenes(
537541
scenes: List[LidarScene],
538542
update: Optional[bool] = False,
539543
asynchronous: Optional[bool] = False,
544+
) -> Union[dict, AsyncJob]:
545+
return self._append_scenes(scenes, update, asynchronous)
546+
547+
def _append_scenes(
548+
self,
549+
scenes: List[LidarScene],
550+
update: Optional[bool] = False,
551+
asynchronous: Optional[bool] = False,
540552
) -> Union[dict, AsyncJob]:
541553
# TODO: make private in favor of Dataset.append invocation
542554
if not self.is_scene:
@@ -682,7 +694,7 @@ def create_slice(
682694
self,
683695
name: str,
684696
reference_ids: List[str],
685-
):
697+
) -> Slice:
686698
"""Creates a :class:`Slice` of dataset items within a dataset.
687699
688700
Parameters:
@@ -692,9 +704,11 @@ def create_slice(
692704
Returns:
693705
:class:`Slice`: The newly constructed slice item.
694706
"""
695-
return self._client.create_slice(
696-
self.id, {NAME_KEY: name, REFERENCE_IDS_KEY: reference_ids}
707+
payload = {NAME_KEY: name, REFERENCE_IDS_KEY: reference_ids}
708+
response = self._client.make_request(
709+
payload, f"dataset/{self.id}/create_slice"
697710
)
711+
return Slice(response[SLICE_ID_KEY], self._client)
698712

699713
@sanitize_string_args
700714
def delete_item(self, reference_id: str) -> dict:
@@ -785,12 +799,15 @@ def create_custom_index(
785799
Returns:
786800
:class:`AsyncJob`: Asynchronous job object to track processing status.
787801
"""
802+
res = self._client.post(
803+
{
804+
EMBEDDINGS_URL_KEY: embeddings_urls,
805+
EMBEDDING_DIMENSION_KEY: embedding_dim,
806+
},
807+
f"indexing/{self.id}",
808+
)
788809
return AsyncJob.from_json(
789-
self._client.create_custom_index(
790-
self.id,
791-
embeddings_urls,
792-
embedding_dim,
793-
),
810+
res,
794811
self._client,
795812
)
796813

nucleus/model_run.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -161,8 +161,8 @@ def refloc(self, reference_id: str):
161161
:param reference_id: reference_id of a dataset item.
162162
:return: List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
163163
"""
164-
response = self._client.predictions_ref_id(
165-
self.model_run_id, reference_id
164+
response = self._client.get(
165+
f"modelRun/{self.model_run_id}/refloc/{reference_id}"
166166
)
167167
return format_prediction_response(response)
168168

nucleus/url_utils.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import urllib.request
2+
from functools import wraps
23

34

45
def sanitize_field(field):
@@ -8,6 +9,7 @@ def sanitize_field(field):
89
def sanitize_string_args(function):
910
"""Helper decorator that ensures that all arguments passed are url-safe."""
1011

12+
@wraps(function)
1113
def sanitized_function(*args, **kwargs):
1214
sanitized_args = []
1315
sanitized_kwargs = {}

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@ Sphinx = "^4.2.0"
5757
sphinx-autobuild = "^2021.3.14"
5858
furo = "^2021.10.9"
5959
sphinx-autoapi = "^1.8.4"
60+
pytest-xdist = "^2.5.0"
6061

6162
[tool.pytest.ini_options]
6263
markers = [

scripts/load_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,7 @@ def upload_annotations(dataset: Dataset):
186186

187187

188188
def upload_predictions(dataset: Dataset):
189-
model = client().add_model(
189+
model = client().create_model(
190190
name="Load test model", reference_id="model_" + str(time.time())
191191
)
192192
run = model.create_run(

tests/test_autocurate.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ def model_run(CLIENT):
3232

3333
assert ERROR_PAYLOAD not in response.json()
3434

35-
model = CLIENT.add_model(
35+
model = CLIENT.create_model(
3636
name=TEST_MODEL_NAME, reference_id="model_" + str(time.time())
3737
)
3838

tests/test_models.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
TEST_MODEL_RUN,
2828
TEST_PREDS,
2929
assert_box_prediction_matches_dict,
30+
get_uuid,
3031
)
3132

3233

@@ -55,12 +56,12 @@ def test_repr(test_object: any):
5556

5657

5758
def test_model_creation_and_listing(CLIENT, dataset):
58-
models_before = CLIENT.models
59-
6059
model_reference = "model_" + str(time.time())
6160
# Creation
62-
model = CLIENT.add_model(TEST_MODEL_NAME, model_reference)
63-
m_run = model.create_run(TEST_MODEL_RUN, dataset, TEST_PREDS)
61+
model_name = TEST_MODEL_NAME + get_uuid()
62+
model = CLIENT.create_model(model_name, model_reference)
63+
model_run = TEST_MODEL_RUN + get_uuid()
64+
m_run = model.create_run(model_run, dataset, TEST_PREDS)
6465
m_run.commit()
6566

6667
assert isinstance(model, Model)
@@ -74,20 +75,18 @@ def test_model_creation_and_listing(CLIENT, dataset):
7475
assert m == model
7576

7677
assert model in ms
77-
assert list(set(ms) - set(models_before))[0] == model
7878

7979
# Delete the model
8080
CLIENT.delete_model(model.id)
8181
ms = CLIENT.models
8282

8383
assert model not in ms
84-
assert ms == models_before
8584

8685

8786
# Until we fully remove the other endpoints (and then migrate those tests) just quickly test the basics of the new ones since they are basically just simple wrappers around the old ones.
8887
def test_new_model_endpoints(CLIENT, dataset: Dataset):
8988
model_reference = "model_" + str(time.time())
90-
model = CLIENT.add_model(TEST_MODEL_NAME, model_reference)
89+
model = CLIENT.create_model(TEST_MODEL_NAME, model_reference)
9190
predictions = [BoxPrediction(**TEST_BOX_PREDICTIONS[0])]
9291

9392
dataset.upload_predictions(model, predictions=predictions)

tests/test_prediction.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ def model_run(CLIENT):
8383
[f"[Pytest] Category Label ${i}" for i in range((len(TEST_IMG_URLS)))],
8484
)
8585

86-
model = CLIENT.add_model(
86+
model = CLIENT.create_model(
8787
name=TEST_MODEL_NAME, reference_id="model_" + str(time.time())
8888
)
8989

tests/test_slice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_slice_create_and_delete_and_list(dataset):
6262

6363
dataset_slices = dataset.slices
6464
assert len(dataset_slices) == 1
65-
assert slc.slice_id == dataset_slices[0]
65+
assert slc.id == dataset_slices[0]
6666

6767
response = slc.info()
6868
assert response["name"] == TEST_SLICE_NAME

0 commit comments

Comments
 (0)