Skip to content

Commit 58a3f4c

Browse files
authored
[PLT-1614] Support data row / batch for live mmc projects (#1856)
1 parent 994b6da commit 58a3f4c

12 files changed

+183
-107
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Datarow payload templates
2+
===============================================================================================
3+
4+
.. automodule:: labelbox.schema.data_row_payload_templates
5+
:members:
6+
:show-inheritance:

libs/labelbox/src/labelbox/client.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sys
88
import time
99
import urllib.parse
10+
import warnings
1011
from collections import defaultdict
1112
from datetime import datetime, timezone
1213
from types import MappingProxyType
@@ -910,11 +911,21 @@ def create_model_evaluation_project(
910911
) -> Project:
911912
pass
912913

914+
@overload
913915
def create_model_evaluation_project(
914916
self,
915917
dataset_id: Optional[str] = None,
916918
dataset_name: Optional[str] = None,
917-
data_row_count: int = 100,
919+
data_row_count: Optional[int] = None,
920+
**kwargs,
921+
) -> Project:
922+
pass
923+
924+
def create_model_evaluation_project(
925+
self,
926+
dataset_id: Optional[str] = None,
927+
dataset_name: Optional[str] = None,
928+
data_row_count: Optional[int] = None,
918929
**kwargs,
919930
) -> Project:
920931
"""
@@ -940,26 +951,38 @@ def create_model_evaluation_project(
940951
>>> client.create_model_evaluation_project(name=project_name, dataset_id="clr00u8j0j0j0", data_row_count=10)
941952
>>> This creates a new project, and adds 100 datarows to the dataset with id "clr00u8j0j0j0" and assigns a batch of the newly created 10 data rows to the project.
942953
954+
>>> client.create_model_evaluation_project(name=project_name)
955+
>>> This creates a new project with no data rows.
943956
944957
"""
945-
if not dataset_id and not dataset_name:
946-
raise ValueError(
947-
"dataset_name or data_set_id must be present and not be an empty string."
948-
)
949-
if data_row_count <= 0:
950-
raise ValueError("data_row_count must be a positive integer.")
958+
autogenerate_data_rows = False
959+
dataset_name_or_id = None
960+
append_to_existing_dataset = None
961+
962+
if dataset_id or dataset_name:
963+
autogenerate_data_rows = True
951964

952965
if dataset_id:
953966
append_to_existing_dataset = True
954967
dataset_name_or_id = dataset_id
955-
else:
968+
elif dataset_name:
956969
append_to_existing_dataset = False
957970
dataset_name_or_id = dataset_name
958971

972+
if autogenerate_data_rows:
973+
kwargs["dataset_name_or_id"] = dataset_name_or_id
974+
kwargs["append_to_existing_dataset"] = append_to_existing_dataset
975+
if data_row_count is None:
976+
data_row_count = 100
977+
if data_row_count < 0:
978+
raise ValueError("data_row_count must be a positive integer.")
979+
kwargs["data_row_count"] = data_row_count
980+
warnings.warn(
981+
"Automatic generation of data rows of live model evaluation projects is deprecated. dataset_name_or_id, append_to_existing_dataset, data_row_count will be removed in a future version.",
982+
DeprecationWarning,
983+
)
984+
959985
kwargs["media_type"] = MediaType.Conversational
960-
kwargs["dataset_name_or_id"] = dataset_name_or_id
961-
kwargs["append_to_existing_dataset"] = append_to_existing_dataset
962-
kwargs["data_row_count"] = data_row_count
963986
kwargs["editor_task_type"] = EditorTaskType.ModelChatEvaluation.value
964987

965988
return self._create_project(**kwargs)

libs/labelbox/src/labelbox/data/annotation_types/collection.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
import logging
2-
from concurrent.futures import ThreadPoolExecutor, as_completed
3-
from typing import Callable, Generator, Iterable, Union, Optional
4-
from uuid import uuid4
52
import warnings
3+
from typing import Callable, Generator, Iterable, Union
64

7-
from tqdm import tqdm
8-
9-
from labelbox.schema import ontology
105
from labelbox.orm.model import Entity
11-
from ..ontology import get_classifications, get_tools
6+
from labelbox.schema import ontology
7+
128
from ..generator import PrefetchGenerator
139
from .label import Label
1410

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Dict, List
2+
3+
from pydantic import BaseModel, Field
4+
5+
from labelbox.schema.data_row import DataRowMetadataField
6+
7+
8+
class ModelEvalutationTemlateRowData(BaseModel):
9+
type: str = Field(
10+
default="application/vnd.labelbox.conversational.model-chat-evaluation",
11+
frozen=True,
12+
)
13+
draft: bool = Field(default=True, frozen=True)
14+
rootMessageIds: List[str] = Field(default=[])
15+
actors: Dict = Field(default={})
16+
version: int = Field(default=2, frozen=True)
17+
messages: Dict = Field(default={})
18+
19+
20+
class ModelEvaluationTemplate(BaseModel):
21+
"""
22+
Use this class to create a model evaluation data row.
23+
24+
Examples:
25+
>>> data = ModelEvaluationTemplate()
26+
>>> data.row_data.rootMessageIds = ["root1"]
27+
>>> vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)]
28+
>>> data.embeddings = [...]
29+
>>> data.metadata_fields = [...]
30+
>>> data.attachments = [...]
31+
>>> content = data.model_dump()
32+
>>> task = dataset.create_data_rows([content])
33+
"""
34+
35+
row_data: ModelEvalutationTemlateRowData = Field(
36+
default=ModelEvalutationTemlateRowData()
37+
)
38+
attachments: List[Dict] = Field(default=[])
39+
embeddings: List[Dict] = Field(default=[])
40+
metadata_fields: List[DataRowMetadataField] = Field(default=[])

libs/labelbox/src/labelbox/schema/project.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import json
22
import logging
3-
from string import Template
43
import time
54
import warnings
65
from collections import namedtuple
76
from datetime import datetime, timezone
87
from pathlib import Path
8+
from string import Template
99
from typing import (
1010
TYPE_CHECKING,
1111
Any,
@@ -14,28 +14,18 @@
1414
List,
1515
Optional,
1616
Tuple,
17-
TypeVar,
1817
Union,
1918
overload,
2019
)
2120
from urllib.parse import urlparse
2221

23-
from labelbox.schema.labeling_service import (
24-
LabelingService,
25-
LabelingServiceStatus,
26-
)
27-
from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard
28-
import requests
29-
30-
from labelbox import parser
3122
from labelbox import utils
32-
from labelbox.exceptions import error_message_for_unparsed_graphql_error
3323
from labelbox.exceptions import (
3424
InvalidQueryError,
3525
LabelboxError,
3626
ProcessingWaitTimeout,
37-
ResourceConflict,
3827
ResourceNotFoundError,
28+
error_message_for_unparsed_graphql_error,
3929
)
4030
from labelbox.orm import query
4131
from labelbox.orm.db_object import DbObject, Deletable, Updateable, experimental
@@ -46,30 +36,33 @@
4636
from labelbox.schema.data_row import DataRow
4737
from labelbox.schema.export_filters import (
4838
ProjectExportFilters,
49-
validate_datetime,
5039
build_filters,
5140
)
5241
from labelbox.schema.export_params import ProjectExportParams
5342
from labelbox.schema.export_task import ExportTask
5443
from labelbox.schema.id_type import IdType
5544
from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId
5645
from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
46+
from labelbox.schema.labeling_service import (
47+
LabelingService,
48+
LabelingServiceStatus,
49+
)
50+
from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard
5751
from labelbox.schema.media_type import MediaType
5852
from labelbox.schema.model_config import ModelConfig
59-
from labelbox.schema.project_model_config import ProjectModelConfig
60-
from labelbox.schema.queue_mode import QueueMode
61-
from labelbox.schema.resource_tag import ResourceTag
62-
from labelbox.schema.task import Task
63-
from labelbox.schema.task_queue import TaskQueue
6453
from labelbox.schema.ontology_kind import (
6554
EditorTaskType,
66-
OntologyKind,
6755
UploadType,
6856
)
57+
from labelbox.schema.project_model_config import ProjectModelConfig
6958
from labelbox.schema.project_overview import (
7059
ProjectOverview,
7160
ProjectOverviewDetailed,
7261
)
62+
from labelbox.schema.queue_mode import QueueMode
63+
from labelbox.schema.resource_tag import ResourceTag
64+
from labelbox.schema.task import Task
65+
from labelbox.schema.task_queue import TaskQueue
7366

7467
if TYPE_CHECKING:
7568
from labelbox import BulkImportRequest
@@ -579,7 +572,7 @@ def upsert_instructions(self, instructions_file: str) -> None:
579572

580573
if frontend.name != "Editor":
581574
logger.warning(
582-
f"This function has only been tested to work with the Editor front end. Found %s",
575+
"This function has only been tested to work with the Editor front end. Found %s",
583576
frontend.name,
584577
)
585578

@@ -788,7 +781,9 @@ def create_batch(
788781
if self.queue_mode != QueueMode.Batch:
789782
raise ValueError("Project must be in batch mode")
790783

791-
if self.is_auto_data_generation():
784+
if (
785+
self.is_auto_data_generation() and not self.is_chat_evaluation()
786+
): # NOTE live chat evaluatiuon projects in sdk do not pre-generate data rows, but use batch as all other projects
792787
raise ValueError(
793788
"Cannot create batches for auto data generation projects"
794789
)
@@ -814,7 +809,7 @@ def create_batch(
814809

815810
if row_count > 100_000:
816811
raise ValueError(
817-
f"Batch exceeds max size, break into smaller batches"
812+
"Batch exceeds max size, break into smaller batches"
818813
)
819814
if not row_count:
820815
raise ValueError("You need at least one data row in a batch")
@@ -1088,8 +1083,7 @@ def _create_batch_async(
10881083
task = self._wait_for_task(task_id)
10891084
if task.status != "COMPLETE":
10901085
raise LabelboxError(
1091-
f"Batch was not created successfully: "
1092-
+ json.dumps(task.errors)
1086+
"Batch was not created successfully: " + json.dumps(task.errors)
10931087
)
10941088

10951089
return self.client.get_batch(self.uid, batch_id)
@@ -1436,7 +1430,7 @@ def update_data_row_labeling_priority(
14361430
task = self._wait_for_task(task_id)
14371431
if task.status != "COMPLETE":
14381432
raise LabelboxError(
1439-
f"Priority was not updated successfully: "
1433+
"Priority was not updated successfully: "
14401434
+ json.dumps(task.errors)
14411435
)
14421436
return True
@@ -1629,7 +1623,7 @@ def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str):
16291623
task = self._wait_for_task(task_id)
16301624
if task.status != "COMPLETE":
16311625
raise LabelboxError(
1632-
f"Data rows were not moved successfully: "
1626+
"Data rows were not moved successfully: "
16331627
+ json.dumps(task.errors)
16341628
)
16351629

libs/labelbox/tests/integration/conftest.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -646,11 +646,28 @@ def chat_evaluation_ontology(client, rand_gen):
646646

647647

648648
@pytest.fixture
649-
def live_chat_evaluation_project_with_new_dataset(client, rand_gen):
649+
def live_chat_evaluation_project(client, rand_gen):
650650
project_name = f"test-model-evaluation-project-{rand_gen(str)}"
651-
dataset_name = f"test-model-evaluation-dataset-{rand_gen(str)}"
652-
project = client.create_model_evaluation_project(
653-
name=project_name, dataset_name=dataset_name, data_row_count=1
651+
project = client.create_model_evaluation_project(name=project_name)
652+
653+
yield project
654+
655+
project.delete()
656+
657+
658+
@pytest.fixture
659+
def live_chat_evaluation_project_with_batch(
660+
client,
661+
rand_gen,
662+
live_chat_evaluation_project,
663+
offline_conversational_data_row,
664+
):
665+
project_name = f"test-model-evaluation-project-{rand_gen(str)}"
666+
project = client.create_model_evaluation_project(name=project_name)
667+
668+
project.create_batch(
669+
rand_gen(str),
670+
[offline_conversational_data_row.uid], # sample of data row objects
654671
)
655672

656673
yield project

libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
import pytest
2-
from unittest.mock import patch
32

43
from labelbox import MediaType
54
from labelbox.schema.ontology_kind import OntologyKind
6-
from labelbox.exceptions import MalformedQueryException
75

86

97
def test_create_chat_evaluation_ontology_project(
10-
client,
118
chat_evaluation_ontology,
12-
live_chat_evaluation_project_with_new_dataset,
9+
live_chat_evaluation_project,
1310
offline_conversational_data_row,
1411
rand_gen,
1512
):
@@ -28,36 +25,19 @@ def test_create_chat_evaluation_ontology_project(
2825
assert classification.schema_id
2926
assert classification.feature_schema_id
3027

31-
project = live_chat_evaluation_project_with_new_dataset
28+
project = live_chat_evaluation_project
3229
assert project.model_setup_complete is None
3330

3431
project.connect_ontology(ontology)
3532

3633
assert project.labeling_frontend().name == "Editor"
3734
assert project.ontology().name == ontology.name
3835

39-
with pytest.raises(
40-
ValueError,
41-
match="Cannot create batches for auto data generation projects",
42-
):
43-
project.create_batch(
44-
rand_gen(str),
45-
[offline_conversational_data_row.uid], # sample of data row objects
46-
)
47-
48-
with pytest.raises(
49-
ValueError,
50-
match="Cannot create batches for auto data generation projects",
51-
):
52-
with patch(
53-
"labelbox.schema.project.MAX_SYNC_BATCH_ROW_COUNT", new=0
54-
): # force to async
55-
project.create_batch(
56-
rand_gen(str),
57-
[
58-
offline_conversational_data_row.uid
59-
], # sample of data row objects
60-
)
36+
batch = project.create_batch(
37+
rand_gen(str),
38+
[offline_conversational_data_row.uid], # sample of data row objects
39+
)
40+
assert batch
6141

6242

6343
def test_create_chat_evaluation_ontology_project_existing_dataset(

libs/labelbox/tests/integration/test_data_rows.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def test_create_data_row_with_metadata_dict(
405405
row_data=image_url, metadata_fields=make_metadata_fields_dict
406406
)
407407

408-
assert len(list(dataset.data_rows())) == 1
408+
assert len([dr for dr in dataset.data_rows()]) == 1
409409
assert data_row.dataset() == dataset
410410
assert data_row.created_by() == client.get_user()
411411
assert data_row.organization() == client.get_organization()

0 commit comments

Comments
 (0)