Skip to content

Commit 076bb6e

Browse files
authored
Add tags to Models (#316)
1 parent 3e799a7 commit 076bb6e

File tree

7 files changed

+132
-5
lines changed

7 files changed

+132
-5
lines changed

CHANGELOG.md

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,30 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
9+
## [0.14.0](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.0) - 2022-06-16
10+
11+
### Added
12+
13+
- Allow creation/deletion of model tags on new and existing models, eg:
14+
```python
15+
# on model creation
16+
model = client.create_model(name="foo_model", reference_id="foo-model-ref", tags=["some tag"])
17+
18+
# on existing models
19+
existing_model = client.models[0]
20+
existing_model.add_tags(['tag a', 'tag b'])
21+
22+
# remove tag
23+
existing_model.remove_tags(['tag a'])
24+
```
25+
826
## [0.13.5](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.13.4) - 2022-06-15
927

1028
### Fixed
1129
- Guard against invalid skeleton indexes in KeypointsAnnotation
1230

31+
1332
## [0.13.4](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.13.4) - 2022-06-09
1433

1534
### Fixed
@@ -37,7 +56,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3756
### Added
3857

3958
- Segmentation functions to Validate API
40-
59+
4160
## [0.12.4](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.12.4) - 2022-06-02
4261

4362
### Fixed

nucleus/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
KEEP_HISTORY_KEY,
9292
MESSAGE_KEY,
9393
MODEL_RUN_ID_KEY,
94+
MODEL_TAGS_KEY,
9495
NAME_KEY,
9596
NUCLEUS_ENDPOINT,
9697
PREDICTIONS_IGNORED_KEY,
@@ -218,6 +219,7 @@ def models(self) -> List[Model]:
218219
reference_id=model["ref_id"],
219220
metadata=model["metadata"] or None,
220221
client=self,
222+
tags=model.get(MODEL_TAGS_KEY, []),
221223
)
222224
for model in model_objects["models"]
223225
]
@@ -484,6 +486,7 @@ def create_model(
484486
reference_id: str,
485487
metadata: Optional[Dict] = None,
486488
bundle_name: Optional[str] = None,
489+
tags: Optional[List[str]] = None,
487490
) -> Model:
488491
"""Adds a :class:`Model` to Nucleus.
489492
@@ -495,13 +498,15 @@ def create_model(
495498
metadata: An arbitrary dictionary of additional data about this model
496499
that can be stored and retrieved. For example, you can store information
497500
about the hyperparameters used in training this model.
501+
bundle_name: Optional name of bundle attached to this model
502+
tags: Optional list of tags to attach to this model
498503
499504
Returns:
500505
:class:`Model`: The newly created model as an object.
501506
"""
502507
response = self.make_request(
503508
construct_model_creation_payload(
504-
name, reference_id, metadata, bundle_name
509+
name, reference_id, metadata, bundle_name, tags
505510
),
506511
"models/add",
507512
)
@@ -516,6 +521,7 @@ def create_model(
516521
metadata=metadata,
517522
bundle_name=bundle_name,
518523
client=self,
524+
tags=tags,
519525
)
520526

521527
def create_launch_model(

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
MESSAGE_KEY = "message"
9191
METADATA_KEY = "metadata"
9292
MODEL_BUNDLE_NAME_KEY = "bundle_name"
93+
MODEL_TAGS_KEY = "tags"
9394
MODEL_ID_KEY = "model_id"
9495
MODEL_RUN_ID_KEY = "model_run_id"
9596
NAME_KEY = "name"

nucleus/model.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import requests
44

5-
from .constants import METADATA_KEY, NAME_KEY, REFERENCE_ID_KEY
5+
from .constants import METADATA_KEY, MODEL_TAGS_KEY, NAME_KEY, REFERENCE_ID_KEY
66
from .dataset import Dataset
77
from .job import AsyncJob
88
from .model_run import ModelRun
@@ -93,13 +93,21 @@ class Model:
9393
"""
9494

9595
def __init__(
96-
self, model_id, name, reference_id, metadata, client, bundle_name=None
96+
self,
97+
model_id,
98+
name,
99+
reference_id,
100+
metadata,
101+
client,
102+
bundle_name=None,
103+
tags: List[str] = None,
97104
):
98105
self.id = model_id
99106
self.name = name
100107
self.reference_id = reference_id
101108
self.metadata = metadata
102109
self.bundle_name = bundle_name
110+
self.tags = tags if tags else []
103111
self._client = client
104112

105113
def __repr__(self):
@@ -213,3 +221,49 @@ def run(self, dataset_id: str, slice_id: Optional[str]) -> str:
213221
)
214222

215223
return response
224+
225+
def add_tags(self, tags: List[str]):
226+
"""Tag the model with custom tag names. ::
227+
228+
import nucleus
229+
client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
230+
model = client.list_models()[0]
231+
232+
model.add_tags(["tag_A", "tag_B"])
233+
234+
Args:
235+
tags: list of tag names
236+
"""
237+
response = self._client.make_request(
238+
{MODEL_TAGS_KEY: tags},
239+
f"model/{self.id}/tag",
240+
requests_command=requests.post,
241+
)
242+
243+
if response.get("msg", False):
244+
self.tags.extend(tags)
245+
246+
return response
247+
248+
def remove_tags(self, tags: List[str]):
249+
"""Remove tag(s) from the model. ::
250+
251+
import nucleus
252+
client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
253+
model = client.list_models()[0]
254+
255+
model.remove_tags(["tag_x"])
256+
257+
Args:
258+
tags: list of tag names to remove
259+
"""
260+
response = self._client.make_request(
261+
{MODEL_TAGS_KEY: tags},
262+
f"model/{self.id}/tag",
263+
requests_command=requests.delete,
264+
)
265+
266+
if response.get("msg", False):
267+
self.tags = list(filter(lambda t: t not in tags, self.tags))
268+
269+
return response

nucleus/payload_constructor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
METADATA_KEY,
1818
MODEL_BUNDLE_NAME_KEY,
1919
MODEL_ID_KEY,
20+
MODEL_TAGS_KEY,
2021
NAME_KEY,
2122
REFERENCE_ID_KEY,
2223
SCENES_KEY,
@@ -127,6 +128,7 @@ def construct_model_creation_payload(
127128
reference_id: str,
128129
metadata: Optional[Dict],
129130
bundle_name: Optional[str],
131+
tags: Optional[List[str]],
130132
) -> dict:
131133
payload = {
132134
NAME_KEY: name,
@@ -136,6 +138,8 @@ def construct_model_creation_payload(
136138

137139
if bundle_name:
138140
payload[MODEL_BUNDLE_NAME_KEY] = bundle_name
141+
if tags:
142+
payload[MODEL_TAGS_KEY] = tags
139143

140144
return payload
141145

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exclude = '''
2121

2222
[tool.poetry]
2323
name = "scale-nucleus"
24-
version = "0.13.5"
24+
version = "0.14.0"
2525
description = "The official Python client library for Nucleus, the Data Platform for AI"
2626
license = "MIT"
2727
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/test_models.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,46 @@ def test_new_model_endpoints(CLIENT, dataset: Dataset):
115115
model, predictions[0].reference_id, predictions[0].annotation_id
116116
)
117117
assert_box_prediction_matches_dict(prediction_loc, TEST_BOX_PREDICTIONS[0])
118+
119+
120+
def test_tag_model(CLIENT, dataset: Dataset):
121+
def testing_model(ref_id):
122+
models_from_backend = list(
123+
filter(lambda m: m.reference_id == ref_id, CLIENT.models)
124+
)
125+
assert len(models_from_backend) == 1
126+
return models_from_backend[0]
127+
128+
model_reference = "model_" + str(time.time())
129+
model = CLIENT.create_model(
130+
TEST_MODEL_NAME, model_reference, tags=["first_tag"]
131+
)
132+
133+
model.add_tags(["single tag"])
134+
model.add_tags(["tag_a", "tag_b"])
135+
136+
backend_model = testing_model(model_reference)
137+
assert sorted(backend_model.tags) == sorted(
138+
["first_tag", "single tag", "tag_a", "tag_b"]
139+
)
140+
141+
model.remove_tags(["tag_a"])
142+
model.remove_tags(["first_tag", "tag_b"])
143+
144+
backend_model = testing_model(model_reference)
145+
assert backend_model.tags == ["single tag"]
146+
147+
148+
def test_remove_invalid_tag_from_model(CLIENT, dataset: Dataset):
149+
150+
model_reference = "model_" + str(time.time())
151+
model = CLIENT.create_model(TEST_MODEL_NAME, model_reference)
152+
153+
model.add_tags(["single tag"])
154+
155+
response = model.remove_tags(["tag_a"])
156+
assert "error" in response
157+
assert (
158+
response["error"]
159+
== "Deleted 0 tags from model. Either the tag(s) did not exist, or you are not the owner of this model project."
160+
)

0 commit comments

Comments
 (0)