Commit 8b7c380

Allow Users to Specify Trained Slice for Model Runs (#417)
1 parent 3fa9810 commit 8b7c380

5 files changed: +40 −13 lines changed

CHANGELOG.md

Lines changed: 8 additions & 2 deletions

```diff
@@ -5,10 +5,17 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
-## [0.16.12](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.12) - 2023-11-29
+## [0.16.13](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.13) - 2023-12-13
 
 ### Added
+- Added `trained_slice_id` parameter to `dataset.upload_predictions()` to specify the slice ID used to train the model.
+
+### Fixes
+- Fix offset generation for image chips in `dataset.items_and_annotation_chip_generator()`
 
+## [0.16.12](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.12) - 2023-11-29
+
+### Added
 - Added tag support for slices.
 
 Example:
@@ -21,7 +28,6 @@ Example:
 ## [0.16.11](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.11) - 2023-11-22
 
 ### Added
-
 - Added `num_processes` parameter to `dataset.items_and_annotation_chip_generator()` to specify parallel processing.
 - Method to allow for concurrent task fetches for pointcloud data
 
```
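Taken together, the changes let callers record which slice a model was trained on at prediction-upload time. A minimal usage sketch, assuming the client's existing public API (`NucleusClient`, `get_dataset`, `create_model`, `BoxPrediction`); the API key and all IDs below are placeholders:

```python
import nucleus
from nucleus import BoxPrediction

# Placeholders: substitute your own API key, dataset ID, and the
# slc_-prefixed ID of the slice the model was trained on.
client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
dataset = client.get_dataset("ds_sample_dataset_id")
model = client.create_model(name="sample-model", reference_id="sample-model")

predictions = [
    BoxPrediction(
        label="car", x=24, y=16, width=100, height=80,
        reference_id="image_1", confidence=0.9,
    )
]

# New in 0.16.13: tag the run with the slice used for training.
response = dataset.upload_predictions(
    model,
    predictions,
    trained_slice_id="slc_sample_trained_slice_id",
)
```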

nucleus/annotation_uploader.py

Lines changed: 9 additions & 1 deletion

```diff
@@ -57,6 +57,7 @@ def upload(
         update: bool = False,
         remote_files_per_upload_request: int = 20,
         local_files_per_upload_request: int = 10,
+        trained_slice_id: Optional[str] = None,
     ):
         """For more details on parameters and functionality, see dataset.annotate."""
         if local_files_per_upload_request > 10:
@@ -95,6 +96,7 @@ def upload(
                     update,
                     batch_size=remote_files_per_upload_request,
                     segmentation=True,
+                    trained_slice_id=trained_slice_id,
                 )
             )
         if annotations_without_files:
@@ -104,6 +106,7 @@ def upload(
                     update,
                     batch_size=batch_size,
                     segmentation=False,
+                    trained_slice_id=trained_slice_id,
                 )
             )
 
@@ -115,6 +118,7 @@ def make_batched_requests(
         update: bool,
         batch_size: int,
         segmentation: bool,
+        trained_slice_id: Optional[str],
     ):
         batches = [
             annotations[i : i + batch_size]
@@ -125,7 +129,9 @@ def make_batched_requests(
             "Segmentation batches" if segmentation else "Annotation batches"
         )
         for batch in self._client.tqdm_bar(batches, desc=progress_bar_name):
-            payload = construct_annotation_payload(batch, update)
+            payload = construct_annotation_payload(
+                batch, update, trained_slice_id
+            )
             responses.append(
                 self._client.make_request(payload, route=self._route)
             )
@@ -234,9 +240,11 @@ def __init__(
         dataset_id: Optional[str] = None,
         model_id: Optional[str] = None,
         model_run_id: Optional[str] = None,
+        trained_slice_id: Optional[str] = None,
     ):
         super().__init__(dataset_id, client)
         self._client = client
+        self.trained_slice_id = trained_slice_id
         if model_run_id is not None:
             assert model_id is None and dataset_id is None
             self._route = f"modelRun/{model_run_id}/predict"
```

nucleus/constants.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -149,6 +149,7 @@
 TRACK_REFERENCE_ID_KEY = "track_reference_id"
 TRACK_REFERENCE_IDS_KEY = "track_reference_ids"
 TRACKS_KEY = "tracks"
+TRAINED_SLICE_ID_KEY = "trained_slice_id"
 TRUE_POSITIVE_KEY = "true_positive"
 TYPE_KEY = "type"
 UPDATED_ITEMS = "updated_items"
```

nucleus/dataset.py

Lines changed: 18 additions & 10 deletions

```diff
@@ -66,6 +66,7 @@
     SLICE_ID_KEY,
     TRACK_REFERENCE_IDS_KEY,
     TRACKS_KEY,
+    TRAINED_SLICE_ID_KEY,
     UPDATE_KEY,
     VIDEO_URL_KEY,
 )
@@ -1793,6 +1794,7 @@ def upload_predictions(
         batch_size: int = 5000,
         remote_files_per_upload_request: int = 20,
         local_files_per_upload_request: int = 10,
+        trained_slice_id: Optional[str] = None,
     ):
         """Uploads predictions and associates them with an existing :class:`Model`.
 
@@ -1841,19 +1843,20 @@ def upload_predictions(
                 you can try lowering this batch size. This is only relevant for
                 asynchronous=False
             remote_files_per_upload_request: Number of remote files to upload in each
-              request. Segmentations have either local or remote files, if you are
-              getting timeouts while uploading segmentations with remote urls, you
-              should lower this value from its default of 20. This is only relevant for
-              asynchronous=False.
+                request. Segmentations have either local or remote files, if you are
+                getting timeouts while uploading segmentations with remote urls, you
+                should lower this value from its default of 20. This is only relevant for
+                asynchronous=False.
             local_files_per_upload_request: Number of local files to upload in each
-              request. Segmentations have either local or remote files, if you are
-              getting timeouts while uploading segmentations with local files, you
-              should lower this value from its default of 10. The maximum is 10.
-              This is only relevant for asynchronous=False
+                request. Segmentations have either local or remote files, if you are
+                getting timeouts while uploading segmentations with local files, you
+                should lower this value from its default of 10. The maximum is 10.
+                This is only relevant for asynchronous=False
+            trained_slice_id: Nucleus-generated slice ID (starts with ``slc_``) which was used
+                to train the model.
 
         Returns:
             Payload describing the synchronous upload::
-
                 {
                     "dataset_id": str,
                     "model_run_id": str,
@@ -1876,7 +1879,11 @@ def upload_predictions(
                 predictions, self.id, self._client
             )
             response = self._client.make_request(
-                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                payload={
+                    REQUEST_ID_KEY: request_id,
+                    UPDATE_KEY: update,
+                    TRAINED_SLICE_ID_KEY: trained_slice_id,
+                },
                 route=f"dataset/{self.id}/model/{model.id}/uploadPredictions?async=1",
             )
             return AsyncJob.from_json(response, self._client)
@@ -1887,6 +1894,7 @@ def upload_predictions(
                 update=update,
                 remote_files_per_upload_request=remote_files_per_upload_request,
                 local_files_per_upload_request=local_files_per_upload_request,
+                trained_slice_id=trained_slice_id,
             )
 
     def predictions_iloc(self, model, index):
```
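For asynchronous uploads, the new key rides along in the `uploadPredictions?async=1` payload shown above. A hedged sketch reusing the placeholder objects from the earlier example (`asynchronous=True` and `AsyncJob.sleep_until_complete()` are the client's existing API):

```python
# Assuming `dataset`, `model`, and `predictions` from the earlier sketch.
job = dataset.upload_predictions(
    model,
    predictions,
    asynchronous=True,  # routed through uploadPredictions?async=1 above
    trained_slice_id="slc_sample_trained_slice_id",
)
job.sleep_until_complete()  # AsyncJob polls until the upload finishes
```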

nucleus/payload_constructor.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -24,6 +24,7 @@
     SCENES_KEY,
     SEGMENTATIONS_KEY,
     TAXONOMY_NAME_KEY,
+    TRAINED_SLICE_ID_KEY,
     TYPE_KEY,
     UPDATE_KEY,
 )
@@ -76,6 +77,7 @@ def construct_annotation_payload(
         ]
     ],
     update: bool,
+    trained_slice_id: Optional[str],
 ) -> dict:
     annotations = [
         annotation.to_payload()
@@ -92,6 +94,8 @@ def construct_annotation_payload(
     payload[ANNOTATIONS_KEY] = annotations
     if segmentations:
         payload[SEGMENTATIONS_KEY] = segmentations
+    if trained_slice_id:
+        payload[TRAINED_SLICE_ID_KEY] = trained_slice_id
     return payload
 
 
```
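Because the key is attached only when truthy, payloads from callers that omit `trained_slice_id` are unchanged, so older server routes never see an unexpected field. A simplified stand-in for `construct_annotation_payload` (the `build_payload` helper and the literal `"annotations"`/`"update"` keys here are illustrative, not the library's exact payload shape):

```python
from typing import List, Optional

TRAINED_SLICE_ID_KEY = "trained_slice_id"  # mirrors nucleus/constants.py

def build_payload(
    annotations: List[dict], update: bool, trained_slice_id: Optional[str]
) -> dict:
    # The slice ID is attached only when one was actually provided.
    payload = {"annotations": annotations, "update": update}
    if trained_slice_id:
        payload[TRAINED_SLICE_ID_KEY] = trained_slice_id
    return payload

assert TRAINED_SLICE_ID_KEY not in build_payload([], update=False, trained_slice_id=None)
assert build_payload([], False, "slc_123")[TRAINED_SLICE_ID_KEY] == "slc_123"
```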