Skip to content

Commit ee4cece

Browse files
author
Val Brodsky
committed
Create model chat evaluation project
1 parent 3ee25d1 commit ee4cece

File tree

5 files changed

+85
-193
lines changed

5 files changed

+85
-193
lines changed

libs/labelbox/src/labelbox/client.py

Lines changed: 63 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from labelbox.adv_client import AdvClient
2222
from labelbox.orm import query
2323
from labelbox.orm.db_object import DbObject
24-
from labelbox.orm.model import Entity
24+
from labelbox.orm.model import Entity, Field
2525
from labelbox.pagination import PaginatedCollection
2626
from labelbox.schema import role
2727
from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy
@@ -52,7 +52,8 @@
5252
from labelbox.schema.slice import CatalogSlice, ModelSlice
5353
from labelbox.schema.task import Task
5454
from labelbox.schema.user import User
55-
from labelbox.schema.editor_task_type import EditorTaskType
55+
from labelbox.schema.ontology_kind import (OntologyKind, EditorTaskTypeMapper,
56+
EditorTaskType)
5657

5758
logger = logging.getLogger(__name__)
5859

@@ -564,7 +565,7 @@ def get_labeling_frontends(self, where=None) -> List[LabelingFrontend]:
564565
"""
565566
return self._get_all(Entity.LabelingFrontend, where)
566567

567-
def _create(self, db_object_type, data):
568+
def _create(self, db_object_type, data, extra_params={}):
568569
""" Creates an object on the server. Attribute values are
569570
passed as keyword arguments:
570571
@@ -580,12 +581,14 @@ def _create(self, db_object_type, data):
580581
"""
581582
# Convert string attribute names to Field or Relationship objects.
582583
# Also convert Labelbox object values to their UIDs.
584+
583585
data = {
584586
db_object_type.attribute(attr) if isinstance(attr, str) else attr:
585587
value.uid if isinstance(value, DbObject) else value
586588
for attr, value in data.items()
587589
}
588590

591+
data = {**data, **extra_params}
589592
query_string, params = query.create(db_object_type, data)
590593
res = self.execute(query_string, params)
591594
res = res["create%s" % db_object_type.type_name()]
@@ -700,18 +703,26 @@ def create_project(self, **kwargs) -> Project:
700703
)
701704

702705
media_type = kwargs.get("media_type")
703-
if media_type:
704-
if MediaType.is_supported(media_type):
705-
media_type = media_type.value
706-
else:
707-
raise TypeError(f"{media_type} is not a valid media type. Use"
708-
f" any of {MediaType.get_supported_members()}"
709-
" from MediaType. Example: MediaType.Image.")
706+
if media_type and MediaType.is_supported(media_type):
707+
media_type_value = media_type.value
708+
elif media_type:
709+
raise TypeError(f"{media_type} is not a valid media type. Use"
710+
f" any of {MediaType.get_supported_members()}"
711+
" from MediaType. Example: MediaType.Image.")
710712
else:
711713
logger.warning(
712714
"Creating a project without specifying media_type"
713715
" through this method will soon no longer be supported.")
714716

717+
ontology_kind = kwargs.pop("ontology_kind", None)
718+
if ontology_kind and OntologyKind.is_supported(ontology_kind):
719+
editor_task_type_value = EditorTaskTypeMapper.to_editor_task_type(
720+
ontology_kind, media_type).value
721+
elif ontology_kind:
722+
raise OntologyKind.get_ontology_kind_validation_error(ontology_kind)
723+
else:
724+
editor_task_type_value = None
725+
715726
quality_mode = kwargs.get("quality_mode")
716727
if not quality_mode:
717728
logger.info("Defaulting quality mode to Benchmark.")
@@ -729,16 +740,34 @@ def create_project(self, **kwargs) -> Project:
729740
else:
730741
raise ValueError(f"{quality_mode} is not a valid quality mode.")
731742

732-
return self._create(Entity.Project, {
733-
**data,
734-
**({
735-
"media_type": media_type
736-
} if media_type else {})
737-
})
743+
params = {**data}
744+
if media_type_value:
745+
params["media_type"] = media_type_value
746+
if editor_task_type_value:
747+
params["editor_task_type"] = editor_task_type_value
748+
749+
extra_params = {
750+
Field.String("dataset_name_or_id"):
751+
params.pop("dataset_name_or_id", None),
752+
Field.Boolean("append_to_existing_dataset"):
753+
params.pop("append_to_existing_dataset", None),
754+
Field.Int("data_row_count"):
755+
params.pop("data_row_count", None),
756+
}
757+
extra_params = {k: v for k, v in extra_params.items() if v is not None}
738758

739-
def create_model_chat_project(self, **kwargs) -> Project:
740-
kwargs["media_type"] = media_type.MediaType.Conversational
741-
kwargs["editor_task_type"] = editor_task_type.EditorTaskType.ModelChatEvaluation
759+
return self._create(Entity.Project, params, extra_params)
760+
761+
def create_model_evalution_project(self,
762+
dataset_name_or_id: str,
763+
append_to_existing_dataset: bool = False,
764+
data_row_count: int = 100,
765+
**kwargs) -> Project:
766+
kwargs["media_type"] = MediaType.Conversational
767+
kwargs["ontology_kind"] = OntologyKind.ModelEvaluation
768+
kwargs["dataset_name_or_id"] = dataset_name_or_id
769+
kwargs["append_to_existing_dataset"] = append_to_existing_dataset
770+
kwargs["data_row_count"] = data_row_count
742771

743772
return self.create_project(**kwargs)
744773

@@ -950,7 +979,7 @@ def create_ontology_from_feature_schemas(
950979
name,
951980
feature_schema_ids,
952981
media_type: MediaType = None,
953-
editor_task_type: EditorTaskType = None) -> Ontology:
982+
ontology_kind: OntologyKind = None) -> Ontology:
954983
"""
955984
Creates an ontology from a list of feature schema ids
956985
@@ -988,10 +1017,11 @@ def create_ontology_from_feature_schemas(
9881017
"Neither `tool` or `classification` found in the normalized feature schema"
9891018
)
9901019
normalized = {'tools': tools, 'classifications': classifications}
1020+
9911021
return self.create_ontology(name=name,
9921022
normalized=normalized,
9931023
media_type=media_type,
994-
editor_task_type=editor_task_type)
1024+
ontology_kind=ontology_kind)
9951025

9961026
def delete_unused_feature_schema(self, feature_schema_id: str) -> None:
9971027
"""
@@ -1182,7 +1212,7 @@ def create_ontology(self,
11821212
name,
11831213
normalized,
11841214
media_type: MediaType = None,
1185-
editor_task_type: EditorTaskType = None) -> Ontology:
1215+
ontology_kind: OntologyKind = None) -> Ontology:
11861216
"""
11871217
Creates an ontology from normalized data
11881218
>>> normalized = {"tools" : [{'tool': 'polygon', 'name': 'cat', 'color': 'black'}], "classifications" : []}
@@ -1206,16 +1236,17 @@ def create_ontology(self,
12061236

12071237
if media_type:
12081238
if MediaType.is_supported(media_type):
1209-
media_type = media_type.value
1239+
media_type_value = media_type.value
12101240
else:
12111241
raise get_media_type_validation_error(media_type)
12121242

1213-
if editor_task_type:
1214-
if EditorTaskType.is_supported(editor_task_type):
1215-
editor_task_type = editor_task_type.value
1216-
else:
1217-
raise EditorTaskType.get_editor_task_type_validation_error(
1218-
editor_task_type)
1243+
if ontology_kind and OntologyKind.is_supported(ontology_kind):
1244+
editor_task_type_value = EditorTaskTypeMapper.to_editor_task_type(
1245+
ontology_kind, media_type).value
1246+
elif ontology_kind:
1247+
raise OntologyKind.get_ontology_kind_validation_error(ontology_kind)
1248+
else:
1249+
editor_task_type_value = None
12191250

12201251
query_str = """mutation upsertRootSchemaNodePyApi($data: UpsertOntologyInput!){
12211252
upsertOntology(data: $data){ %s }
@@ -1224,68 +1255,15 @@ def create_ontology(self,
12241255
'data': {
12251256
'name': name,
12261257
'normalized': json.dumps(normalized),
1227-
'mediaType': media_type
1258+
'mediaType': media_type_value
12281259
}
12291260
}
1230-
if editor_task_type:
1231-
params['data']['editorTaskType'] = editor_task_type
1261+
if editor_task_type_value:
1262+
params['data']['editorTaskType'] = editor_task_type_value
12321263

12331264
res = self.execute(query_str, params)
12341265
return Entity.Ontology(self, res['upsertOntology'])
12351266

1236-
def create_model_chat_evaluation_ontology(self, name, normalized):
1237-
"""
1238-
Creates a model chat evalutation ontology from normalized data
1239-
>>> normalized = {"tools" : [{'tool': 'message-single-selection', 'name': 'model output single selection', 'color': '#ff0000',},
1240-
{'tool': 'message-multi-selection', 'name': 'model output multi selection', 'color': '#00ff00',},
1241-
{'tool': 'message-ranking', 'name': 'model output multi ranking', 'color': '#0000ff',}]
1242-
}
1243-
>>> ontology = client.create_ontology("ontology-name", normalized)
1244-
1245-
Or use the ontology builder
1246-
>>> ontology_builder = OntologyBuilder(tools=[
1247-
Tool(tool=Tool.Type.MESSAGE_SINGLE_SELECTION,
1248-
name="model output single selection"),
1249-
Tool(tool=Tool.Type.MESSAGE_MULTI_SELECTION,
1250-
name="model output multi selection"),
1251-
Tool(tool=Tool.Type.MESSAGE_RANKING,
1252-
name="model output multi ranking"),
1253-
],)
1254-
1255-
>>> ontology = client.create_model_chat_evaluation_ontology("Multi-chat ontology", ontology_builder.asdict())
1256-
1257-
Args:
1258-
name (str): Name of the ontology
1259-
normalized (dict): A normalized ontology payload. See above for details.
1260-
Returns:
1261-
The created Ontology
1262-
"""
1263-
1264-
return self.create_ontology(
1265-
name=name,
1266-
normalized=normalized,
1267-
media_type=MediaType.Conversational,
1268-
editor_task_type=EditorTaskType.ModelChatEvaluation)
1269-
1270-
def create_model_chat_evaluation_ontology_from_feature_schemas(
1271-
self, name, feature_schema_ids):
1272-
"""
1273-
Creates an ontology from a list of feature schema ids
1274-
1275-
Args:
1276-
name (str): Name of the ontology
1277-
feature_schema_ids (List[str]): List of feature schema ids corresponding to
1278-
top level tools and classifications to include in the ontology
1279-
Returns:
1280-
The created Ontology
1281-
"""
1282-
1283-
return self.create_ontology_from_feature_schemas(
1284-
name=name,
1285-
feature_schema_ids=feature_schema_ids,
1286-
media_type=MediaType.Conversational,
1287-
editor_task_type=EditorTaskType.ModelChatEvaluation)
1288-
12891267
def create_feature_schema(self, normalized):
12901268
"""
12911269
Creates a feature schema from normalized data.

libs/labelbox/src/labelbox/schema/editor_task_type.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

libs/labelbox/src/labelbox/schema/project.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from labelbox.schema.resource_tag import ResourceTag
3838
from labelbox.schema.task import Task
3939
from labelbox.schema.task_queue import TaskQueue
40+
from labelbox.schema.ontology_kind import (EditorTaskType)
4041

4142
if TYPE_CHECKING:
4243
from labelbox import BulkImportRequest
@@ -115,6 +116,7 @@ class Project(DbObject, Updateable, Deletable):
115116
auto_audit_percentage = Field.Float("auto_audit_percentage")
116117
# Bind data_type and allowedMediaTYpe using the GraphQL type MediaType
117118
media_type = Field.Enum(MediaType, "media_type", "allowedMediaType")
119+
editor_task_type = Field.Enum(EditorTaskType, "editor_task_type")
118120

119121
# Relationships
120122
created_by = Relationship.ToOne("User", False, "created_by")

libs/labelbox/tests/integration/conftest.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from labelbox.schema.queue_mode import QueueMode
2929
from labelbox.schema.user import User
3030
from labelbox import Client
31+
from labelbox.schema.ontology_kind import OntologyKind
3132

3233

3334
@pytest.fixture
@@ -340,8 +341,8 @@ def _upload_invalid_data_rows_for_dataset(dataset: Dataset):
340341

341342

342343
@pytest.fixture
343-
def model_chat_evaluation_ontology(client, rand_gen):
344-
ontology_name = f"test-model-chat-evaluation-ontology-{rand_gen(str)}"
344+
def chat_evaluation_ontology(client, rand_gen):
345+
ontology_name = f"test-chat-evaluation-ontology-{rand_gen(str)}"
345346
ontology_builder = OntologyBuilder(tools=[
346347
Tool(tool=Tool.Type.MESSAGE_SINGLE_SELECTION,
347348
name="model output single selection"),
@@ -350,14 +351,29 @@ def model_chat_evaluation_ontology(client, rand_gen):
350351
Tool(tool=Tool.Type.MESSAGE_RANKING, name="model output multi ranking"),
351352
],)
352353

353-
ontology = client.create_model_chat_evaluation_ontology(
354-
ontology_name, ontology_builder.asdict())
354+
ontology = client.create_ontology(
355+
ontology_name,
356+
ontology_builder.asdict(),
357+
media_type=MediaType.Conversational,
358+
ontology_kind=OntologyKind.ModelEvaluation)
355359

356360
yield ontology
357361

358362
client.delete_unused_ontology(ontology.uid)
359363

360364

365+
@pytest.fixture
366+
def chat_evaluation_project(client, rand_gen):
367+
project_name = f"test-model-evaluation-project-{rand_gen(str)}"
368+
dataset_name_or_id = f"test-model-evaluation-dataset-{rand_gen(str)}"
369+
project = client.create_model_evalution_project(
370+
name=project_name, dataset_name_or_id=dataset_name_or_id)
371+
372+
yield project
373+
374+
project.delete()
375+
376+
361377
def pytest_configure():
362378
pytest.report = defaultdict(int)
363379

0 commit comments

Comments
 (0)