Skip to content

Commit d69a92e

Browse files
authored
Updated model to include trained_slice_id (#421)
1 parent e453db6 commit d69a92e

File tree

6 files changed

+93
-5
lines changed

6 files changed

+93
-5
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88

9+
## [0.16.18](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.18) - 2024-02-06
10+
11+
### Added
12+
- Add the ability to add and remove `trained_slice_id` to a model
13+
914
## [0.16.17](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.17) - 2024-01-29
1015

1116
### Fixes

nucleus/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
MESSAGE_KEY,
112112
MODEL_RUN_ID_KEY,
113113
MODEL_TAGS_KEY,
114+
MODEL_TRAINED_SLICE_IDS_KEY,
114115
NAME_KEY,
115116
NUCLEUS_ENDPOINT,
116117
POINTS_KEY,
@@ -247,6 +248,7 @@ def models(self) -> List[Model]:
247248
metadata=model["metadata"] or None,
248249
client=self,
249250
tags=model.get(MODEL_TAGS_KEY, []),
251+
trained_slice_ids=model.get(MODEL_TRAINED_SLICE_IDS_KEY, None),
250252
)
251253
for model in model_objects["models"]
252254
]
@@ -560,6 +562,7 @@ def create_model(
560562
metadata: Optional[Dict] = None,
561563
bundle_name: Optional[str] = None,
562564
tags: Optional[List[str]] = None,
565+
trained_slice_ids: Optional[List[str]] = None,
563566
) -> Model:
564567
"""Adds a :class:`Model` to Nucleus.
565568
@@ -579,7 +582,12 @@ def create_model(
579582
"""
580583
response = self.make_request(
581584
construct_model_creation_payload(
582-
name, reference_id, metadata, bundle_name, tags
585+
name,
586+
reference_id,
587+
metadata,
588+
bundle_name,
589+
tags,
590+
trained_slice_ids,
583591
),
584592
"models/add",
585593
)
@@ -595,6 +603,7 @@ def create_model(
595603
bundle_name=bundle_name,
596604
client=self,
597605
tags=tags,
606+
trained_slice_ids=trained_slice_ids,
598607
)
599608

600609
def create_launch_model(
@@ -603,6 +612,7 @@ def create_launch_model(
603612
reference_id: str,
604613
bundle_args: Dict[str, Any],
605614
metadata: Optional[Dict] = None,
615+
trained_slice_ids: Optional[List[str]] = None,
606616
) -> Model:
607617
"""
608618
Adds a :class:`Model` to Nucleus, as well as a Launch bundle from a given function.
@@ -694,6 +704,7 @@ def create_launch_model(
694704
reference_id,
695705
metadata,
696706
bundle_name,
707+
trained_slice_ids=trained_slice_ids,
697708
)
698709

699710
def create_launch_model_from_dir(
@@ -702,6 +713,7 @@ def create_launch_model_from_dir(
702713
reference_id: str,
703714
bundle_from_dir_args: Dict[str, Any],
704715
metadata: Optional[Dict] = None,
716+
trained_slice_ids: Optional[List[str]] = None,
705717
) -> Model:
706718
"""
707719
Adds a :class:`Model` to Nucleus, as well as a Launch bundle from a directory.
@@ -816,6 +828,7 @@ def create_launch_model_from_dir(
816828
reference_id,
817829
metadata,
818830
bundle_name,
831+
trained_slice_ids=trained_slice_ids,
819832
)
820833

821834
@deprecated(

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@
106106
METADATA_KEY = "metadata"
107107
MODEL_BUNDLE_NAME_KEY = "bundle_name"
108108
MODEL_TAGS_KEY = "tags"
109+
MODEL_TRAINED_SLICE_IDS_KEY = "trained_slice_ids"
109110
MODEL_ID_KEY = "model_id"
110111
MODEL_RUN_ID_KEY = "model_run_id"
111112
MODEL_PREDICTION_ID_KEY = "model_prediction_id"

nucleus/model.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import requests
44

55
from .async_job import AsyncJob
6-
from .constants import METADATA_KEY, MODEL_TAGS_KEY, NAME_KEY, REFERENCE_ID_KEY
6+
from .constants import (
7+
METADATA_KEY,
8+
MODEL_TAGS_KEY,
9+
MODEL_TRAINED_SLICE_IDS_KEY,
10+
NAME_KEY,
11+
REFERENCE_ID_KEY,
12+
)
713
from .dataset import Dataset
814
from .model_run import ModelRun
915
from .prediction import (
@@ -101,6 +107,7 @@ def __init__(
101107
client,
102108
bundle_name=None,
103109
tags=None,
110+
trained_slice_ids=None,
104111
):
105112
self.id = model_id
106113
self.name = name
@@ -109,9 +116,10 @@ def __init__(
109116
self.bundle_name = bundle_name
110117
self.tags = tags if tags else []
111118
self._client = client
119+
self.trained_slice_ids = trained_slice_ids if trained_slice_ids else []
112120

113121
def __repr__(self):
114-
return f"Model(model_id='{self.id}', name='{self.name}', reference_id='{self.reference_id}', metadata={self.metadata}, bundle_name={self.bundle_name}, tags={self.tags}, client={self._client})"
122+
return f"Model(model_id='{self.id}', name='{self.name}', reference_id='{self.reference_id}', metadata={self.metadata}, bundle_name={self.bundle_name}, tags={self.tags}, client={self._client}, trained_slice_ids={self.trained_slice_ids})"
115123

116124
def __eq__(self, other):
117125
return (
@@ -120,6 +128,7 @@ def __eq__(self, other):
120128
and (self.metadata == other.metadata)
121129
and (self._client == other._client)
122130
and (self.bundle_name == other.bundle_name)
131+
and (self.trained_slice_ids == other.trained_slice_ids)
123132
)
124133

125134
def __hash__(self):
@@ -134,6 +143,8 @@ def from_json(cls, payload: dict, client):
134143
reference_id=payload["ref_id"],
135144
metadata=payload["metadata"] or None,
136145
client=client,
146+
tags=payload.get(MODEL_TAGS_KEY, None),
147+
trained_slice_ids=payload.get(MODEL_TRAINED_SLICE_IDS_KEY, None),
137148
)
138149

139150
def create_run(
@@ -242,7 +253,9 @@ def add_tags(self, tags: List[str]):
242253
)
243254

244255
if response.ok:
245-
self.tags.extend(tags)
256+
for tag in tags:
257+
if tag not in self.tags:
258+
self.tags.append(tag)
246259

247260
return response.json()
248261

@@ -269,3 +282,55 @@ def remove_tags(self, tags: List[str]):
269282
self.tags = list(filter(lambda t: t not in tags, self.tags))
270283

271284
return response.json()
285+
286+
def add_trained_slice_ids(self, slice_ids: List[str]):
287+
"""Add trained slice id(s) to the model. ::
288+
289+
import nucleus
290+
client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
291+
model = client.list_models()[0]
292+
293+
model.add_trained_slice_ids(["slc_...", "slc_..."])
294+
295+
Args:
296+
slice_ids: list of trained slice ids
297+
"""
298+
response: requests.Response = self._client.make_request(
299+
{MODEL_TRAINED_SLICE_IDS_KEY: slice_ids},
300+
f"model/{self.id}/trainedSliceId",
301+
requests_command=requests.post,
302+
return_raw_response=True,
303+
)
304+
305+
if response.ok:
306+
for slice_id in slice_ids:
307+
if slice_id not in self.trained_slice_ids:
308+
self.trained_slice_ids.append(slice_id)
309+
310+
return response.json()
311+
312+
def remove_trained_slice_ids(self, slide_ids: List[str]):
313+
"""Remove trained slice id(s) from the model. ::
314+
315+
import nucleus
316+
client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
317+
model = client.list_models()[0]
318+
319+
model.remove_trained_slice_ids(["slc_...", "slc_..."])
320+
321+
Args:
322+
slice_ids: list of trained slice ids to remove
323+
"""
324+
response: requests.Response = self._client.make_request(
325+
{MODEL_TRAINED_SLICE_IDS_KEY: slide_ids},
326+
f"model/{self.id}/trainedSliceId",
327+
requests_command=requests.delete,
328+
return_raw_response=True,
329+
)
330+
331+
if response.ok:
332+
self.trained_slice_ids = list(
333+
filter(lambda t: t not in slide_ids, self.trained_slice_ids)
334+
)
335+
336+
return response.json()

nucleus/payload_constructor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
MODEL_BUNDLE_NAME_KEY,
2020
MODEL_ID_KEY,
2121
MODEL_TAGS_KEY,
22+
MODEL_TRAINED_SLICE_IDS_KEY,
2223
NAME_KEY,
2324
REFERENCE_ID_KEY,
2425
SCENES_KEY,
@@ -137,13 +138,16 @@ def construct_model_creation_payload(
137138
metadata: Optional[Dict],
138139
bundle_name: Optional[str],
139140
tags: Optional[List[str]],
141+
trained_slice_ids: Optional[List[str]],
140142
) -> dict:
141143
payload = {
142144
NAME_KEY: name,
143145
REFERENCE_ID_KEY: reference_id,
144146
METADATA_KEY: metadata if metadata else {},
145147
}
146148

149+
if trained_slice_ids:
150+
payload[MODEL_TRAINED_SLICE_IDS_KEY] = trained_slice_ids
147151
if bundle_name:
148152
payload[MODEL_BUNDLE_NAME_KEY] = bundle_name
149153
if tags:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ ignore = ["E501", "E741", "E731", "F401"] # Easy ignore for getting it running
2525

2626
[tool.poetry]
2727
name = "scale-nucleus"
28-
version = "0.16.17"
28+
version = "0.16.18"
2929
description = "The official Python client library for Nucleus, the Data Platform for AI"
3030
license = "MIT"
3131
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

0 commit comments

Comments
 (0)