Skip to content

Commit a4970e7

Browse files
[PLT-773] Allow users to create configs and associate them with projects through SDK. (#1591)
1 parent 9e5c5fd commit a4970e7

File tree

12 files changed

+233
-0
lines changed

12 files changed

+233
-0
lines changed

docs/labelbox/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@ Labelbox Python SDK Documentation
2929
labeling-frontend-options
3030
labeling-parameter-override
3131
model
32+
model-config
3233
model-run
3334
ontology
3435
ontology_kind
3536
organization
3637
pagination
3738
project
39+
project-model-config
3840
quality-mode
3941
resource-tag
4042
review

docs/labelbox/model-config.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Model Config
2+
===============================================================================================
3+
4+
.. automodule:: labelbox.schema.model_config
5+
:members:
6+
:show-inheritance:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Project Model Config
2+
===============================================================================================
3+
4+
.. automodule:: labelbox.schema.project_model_config
5+
:members:
6+
:show-inheritance:

libs/labelbox/src/labelbox/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from labelbox.client import Client
66
from labelbox.schema.project import Project
77
from labelbox.schema.model import Model
8+
from labelbox.schema.model_config import ModelConfig
89
from labelbox.schema.bulk_import_request import BulkImportRequest
910
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport, MEAToMALPredictionImport
1011
from labelbox.schema.dataset import Dataset
@@ -29,6 +30,7 @@
2930
from labelbox.schema.benchmark import Benchmark
3031
from labelbox.schema.iam_integration import IAMIntegration
3132
from labelbox.schema.resource_tag import ResourceTag
33+
from labelbox.schema.project_model_config import ProjectModelConfig
3234
from labelbox.schema.project_resource_tag import ProjectResourceTag
3335
from labelbox.schema.media_type import MediaType
3436
from labelbox.schema.slice import Slice, CatalogSlice, ModelSlice

libs/labelbox/src/labelbox/client.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from labelbox.schema.labeling_frontend import LabelingFrontend
3939
from labelbox.schema.media_type import MediaType, get_media_type_validation_error
4040
from labelbox.schema.model import Model
41+
from labelbox.schema.model_config import ModelConfig
4142
from labelbox.schema.model_run import ModelRun
4243
from labelbox.schema.ontology import Ontology, DeleteFeatureFromOntologyResult
4344
from labelbox.schema.ontology import Tool, Classification, FeatureSchema
@@ -595,6 +596,56 @@ def _create(self, db_object_type, data, extra_params={}):
595596
res = res["create%s" % db_object_type.type_name()]
596597
return db_object_type(self, res)
597598

599+
def create_model_config(self, name: str, model_id: str, inference_params: dict) -> ModelConfig:
600+
""" Creates a new model config with the given params.
601+
Model configs are scoped to organizations, and can be reused between projects.
602+
603+
Args:
604+
name (str): Name of the model config
605+
model_id (str): ID of model to configure
606+
inference_params (dict): JSON of model configuration parameters.
607+
608+
Returns:
609+
str, id of the created model config
610+
"""
611+
612+
query = """mutation CreateModelConfigPyApi($modelId: ID!, $inferenceParams: Json!, $name: String!) {
613+
createModelConfig(input: {modelId: $modelId, inferenceParams: $inferenceParams, name: $name}) {
614+
modelId
615+
inferenceParams
616+
id
617+
name
618+
}
619+
}"""
620+
params = {
621+
"modelId": model_id,
622+
"inferenceParams": inference_params,
623+
"name": name
624+
}
625+
result = self.execute(query, params)
626+
return ModelConfig(self, result['createModelConfig'])
627+
628+
def delete_model_config(self, id: str) -> bool:
629+
""" Deletes an existing model config with the given id
630+
631+
Args:
632+
id (str): ID of existing model config
633+
634+
Returns:
635+
bool, indicates if the operation was a success.
636+
"""
637+
638+
query = """mutation DeleteModelConfigPyApi($id: ID!) {
639+
deleteModelConfig(input: {id: $id}) {
640+
success
641+
}
642+
}"""
643+
params = {
644+
"id": id
645+
}
646+
result = self.execute(query, params)
647+
return result['deleteModelConfig']['success']
648+
598649
def create_dataset(self,
599650
iam_integration=IAMIntegration._DEFAULT,
600651
**kwargs) -> Dataset:

libs/labelbox/src/labelbox/orm/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ class Entity(metaclass=EntityMeta):
362362
Task: Type[labelbox.Task]
363363
AssetAttachment: Type[labelbox.AssetAttachment]
364364
ModelRun: Type[labelbox.ModelRun]
365+
ModelConfig: Type[labelbox.ModelConfig]
365366
Review: Type[labelbox.Review]
366367
User: Type[labelbox.User]
367368
LabelingFrontend: Type[labelbox.LabelingFrontend]
@@ -375,6 +376,7 @@ class Entity(metaclass=EntityMeta):
375376
Invite: Type[labelbox.Invite]
376377
InviteLimit: Type[labelbox.InviteLimit]
377378
ProjectRole: Type[labelbox.ProjectRole]
379+
ProjectModelConfig: Type[labelbox.ProjectModelConfig]
378380
Project: Type[labelbox.Project]
379381
Batch: Type[labelbox.Batch]
380382
CatalogSlice: Type[labelbox.CatalogSlice]
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from labelbox.orm.db_object import DbObject
2+
from labelbox.orm.model import Field
3+
4+
5+
class ModelConfig(DbObject):
6+
""" A ModelConfig represents a set of inference params configured for a model
7+
8+
Attributes:
9+
inference_params (JSON): Dict of inference params
10+
model_id (str): ID of the model to configure
11+
name (str): Name of config
12+
"""
13+
14+
inference_params = Field.Json("inference_params", "inferenceParams")
15+
model_id = Field.String("model_id", "modelId")
16+
name = Field.String("name", "name")

libs/labelbox/src/labelbox/schema/project.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId
3434
from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
3535
from labelbox.schema.media_type import MediaType
36+
from labelbox.schema.project_model_config import ProjectModelConfig
3637
from labelbox.schema.queue_mode import QueueMode
3738
from labelbox.schema.resource_tag import ResourceTag
3839
from labelbox.schema.task import Task
@@ -136,6 +137,28 @@ class Project(DbObject, Updateable, Deletable):
136137
def is_chat_evaluation(self) -> bool:
137138
return self.media_type == MediaType.Conversational and self.editor_task_type == EditorTaskType.ModelChatEvaluation
138139

140+
def project_model_configs(self):
141+
query_str = """query ProjectModelConfigsPyApi($id: ID!) {
142+
project(where: {id : $id}) {
143+
projectModelConfigs {
144+
id
145+
modelConfigId
146+
modelConfig {
147+
id
148+
modelId
149+
inferenceParams
150+
}
151+
projectId
152+
}
153+
}
154+
}"""
155+
data = {"id": self.uid}
156+
res = self.client.execute(query_str, data)
157+
return [
158+
ProjectModelConfig(self.client, projectModelConfig)
159+
for projectModelConfig in res["project"]["projectModelConfigs"]
160+
]
161+
139162
def update(self, **kwargs):
140163
""" Updates this project with the specified attributes
141164
@@ -1226,6 +1249,50 @@ def get_queue_mode(self) -> "QueueMode":
12261249
else:
12271250
raise ValueError("Status not known")
12281251

1252+
def add_model_config(self, model_config_id: str) -> str:
1253+
""" Adds a model config to this project.
1254+
1255+
Args:
1256+
model_config_id (str): ID of a model config to add to this project.
1257+
1258+
Returns:
1259+
str, ID of the project model config association. This is needed for updating and deleting associations.
1260+
"""
1261+
1262+
query = """mutation CreateProjectModelConfigPyApi($projectId: ID!, $modelConfigId: ID!) {
1263+
createProjectModelConfig(input: {projectId: $projectId, modelConfigId: $modelConfigId}) {
1264+
projectModelConfigId
1265+
}
1266+
}"""
1267+
1268+
params = {
1269+
"projectId": self.uid,
1270+
"modelConfigId": model_config_id,
1271+
}
1272+
result = self.client.execute(query, params)
1273+
return result["createProjectModelConfig"]["projectModelConfigId"]
1274+
1275+
def delete_project_model_config(self, project_model_config_id: str) -> bool:
1276+
""" Deletes the association between a model config and this project.
1277+
1278+
Args:
1279+
project_model_config_id (str): ID of a project model config association to delete for this project.
1280+
1281+
Returns:
1282+
bool, indicates if the operation was a success.
1283+
"""
1284+
query = """mutation DeleteProjectModelConfigPyApi($id: ID!) {
1285+
deleteProjectModelConfig(input: {id: $id}) {
1286+
success
1287+
}
1288+
}"""
1289+
1290+
params = {
1291+
"id": project_model_config_id,
1292+
}
1293+
result = self.client.execute(query, params)
1294+
return result["deleteProjectModelConfig"]["success"]
1295+
12291296
def set_labeling_parameter_overrides(
12301297
self, data: List[LabelingParameterOverrideInput]) -> bool:
12311298
""" Adds labeling parameter overrides to this project.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from labelbox.orm.db_object import DbObject, Deletable
2+
from labelbox.orm.model import Field, Relationship
3+
4+
5+
class ProjectModelConfig(DbObject):
6+
""" A ProjectModelConfig represents an association between a project and a single model config.
7+
8+
Attributes:
9+
project_id (str): ID of project to associate
10+
model_config_id (str): ID of the model configuration
11+
model_config (ModelConfig): Configuration for model
12+
"""
13+
14+
project_id = Field.String("project_id", "projectId")
15+
model_config_id = Field.String("model_config_id", "modelConfigId")
16+
model_config = Relationship.ToOne("ModelConfig", False, "model_config")
17+
18+
def delete(self) -> bool:
19+
""" Deletes this association between a model config and this project.
20+
21+
Returns:
22+
bool, indicates if the operation was a success.
23+
"""
24+
query = """mutation DeleteProjectModelConfigPyApi($id: ID!) {
25+
deleteProjectModelConfig(input: {id: $id}) {
26+
success
27+
}
28+
}"""
29+
30+
params = {
31+
"id": self.uid,
32+
}
33+
result = self.client.execute(query, params)
34+
return result["deleteProjectModelConfig"]["success"]

libs/labelbox/tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,12 @@ def consensus_project(client, rand_gen):
428428
project.delete()
429429

430430

431+
@pytest.fixture
432+
def model_config(client, rand_gen, valid_model_id):
433+
model_config = client.create_model_config(name=rand_gen(str), model_id=valid_model_id, inference_params = {"param": "value"})
434+
yield model_config
435+
client.delete_model_config(model_config.uid)
436+
431437
@pytest.fixture
432438
def consensus_project_with_batch(consensus_project, initial_dataset, rand_gen,
433439
image_url):
@@ -1043,3 +1049,7 @@ def embedding(client: Client):
10431049
embedding = client.create_embedding(f"sdk-int-{uuid_str}", 8)
10441050
yield embedding
10451051
embedding.delete()
1052+
1053+
@pytest.fixture
1054+
def valid_model_id():
1055+
return "2c903542-d1da-48fd-9db1-8c62571bd3d2"

0 commit comments

Comments
 (0)