Skip to content

Commit 63aba36

Browse files
vbrodskyVal Brodsky
authored and committed
[PLT-1614] Support data row / batch for live mmc projects (#1856)
1 parent 61dc169 commit 63aba36

12 files changed

+173
-88
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Datarow payload templates
2+
===============================================================================================
3+
4+
.. automodule:: labelbox.schema.data_row_payload_templates
5+
:members:
6+
:show-inheritance:

libs/labelbox/src/labelbox/client.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import random
77
import time
88
import urllib.parse
9+
import warnings
910
from collections import defaultdict
1011
from datetime import datetime, timezone
1112
from types import MappingProxyType
@@ -637,6 +638,7 @@ def create_project(
637638
}
638639
return self._create_project(_CoreProjectInput(**input))
639640

641+
@overload
640642
def create_model_evaluation_project(
641643
self,
642644
name: str,
@@ -649,7 +651,17 @@ def create_model_evaluation_project(
649651
is_consensus_enabled: Optional[bool] = None,
650652
dataset_id: Optional[str] = None,
651653
dataset_name: Optional[str] = None,
652-
data_row_count: int = 100,
654+
data_row_count: Optional[int] = None,
655+
**kwargs,
656+
) -> Project:
657+
pass
658+
659+
def create_model_evaluation_project(
660+
self,
661+
dataset_id: Optional[str] = None,
662+
dataset_name: Optional[str] = None,
663+
data_row_count: Optional[int] = None,
664+
**kwargs,
653665
) -> Project:
654666
"""
655667
Use this method exclusively to create a chat model evaluation project.
@@ -674,22 +686,39 @@ def create_model_evaluation_project(
674686
>>> client.create_model_evaluation_project(name=project_name, dataset_id="clr00u8j0j0j0", data_row_count=10)
675687
>>> This creates a new project, and adds 10 datarows to the dataset with id "clr00u8j0j0j0" and assigns a batch of the newly created 10 data rows to the project.
676688
689+
>>> client.create_model_evaluation_project(name=project_name)
690+
>>> This creates a new project with no data rows.
677691
678692
"""
679-
if not dataset_id and not dataset_name:
680-
raise ValueError(
681-
"dataset_name or data_set_id must be present and not be an empty string."
682-
)
693+
autogenerate_data_rows = False
694+
dataset_name_or_id = None
695+
append_to_existing_dataset = None
696+
697+
if dataset_id or dataset_name:
698+
autogenerate_data_rows = True
683699

684700
if dataset_id:
685701
append_to_existing_dataset = True
686702
dataset_name_or_id = dataset_id
687-
else:
703+
elif dataset_name:
688704
append_to_existing_dataset = False
689705
dataset_name_or_id = dataset_name
690706

691-
media_type = MediaType.Conversational
692-
editor_task_type = EditorTaskType.ModelChatEvaluation
707+
if autogenerate_data_rows:
708+
kwargs["dataset_name_or_id"] = dataset_name_or_id
709+
kwargs["append_to_existing_dataset"] = append_to_existing_dataset
710+
if data_row_count is None:
711+
data_row_count = 100
712+
if data_row_count < 0:
713+
raise ValueError("data_row_count must be a positive integer.")
714+
kwargs["data_row_count"] = data_row_count
715+
warnings.warn(
716+
"Automatic generation of data rows of live model evaluation projects is deprecated. dataset_name_or_id, append_to_existing_dataset, data_row_count will be removed in a future version.",
717+
DeprecationWarning,
718+
)
719+
720+
kwargs["media_type"] = MediaType.Conversational
721+
kwargs["editor_task_type"] = EditorTaskType.ModelChatEvaluation.value
693722

694723
input = {
695724
"name": name,

libs/labelbox/src/labelbox/data/annotation_types/collection.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
import logging
2-
from concurrent.futures import ThreadPoolExecutor, as_completed
3-
from typing import Callable, Generator, Iterable, Union, Optional
4-
from uuid import uuid4
52
import warnings
3+
from typing import Callable, Generator, Iterable, Union
64

7-
from tqdm import tqdm
8-
9-
from labelbox.schema import ontology
105
from labelbox.orm.model import Entity
11-
from ..ontology import get_classifications, get_tools
6+
from labelbox.schema import ontology
7+
128
from ..generator import PrefetchGenerator
139
from .label import Label
1410

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Dict, List
2+
3+
from pydantic import BaseModel, Field
4+
5+
from labelbox.schema.data_row import DataRowMetadataField
6+
7+
8+
class ModelEvalutationTemlateRowData(BaseModel):
    """
    Row-data payload of a model chat evaluation data row.

    ``type``, ``draft`` and ``version`` are declared ``frozen`` and keep
    their default values; the remaining fields start out empty and are meant
    to be filled in by the caller.

    NOTE: the class name is misspelled ("Evalutation", "Temlate"). It is
    kept unchanged so that existing imports continue to work.
    """

    # Frozen fields cannot be reassigned on an instance.
    type: str = Field(
        default="application/vnd.labelbox.conversational.model-chat-evaluation",
        frozen=True,
    )
    draft: bool = Field(default=True, frozen=True)
    # default_factory builds a fresh container per instance instead of
    # relying on a shared mutable literal being copied.
    rootMessageIds: List[str] = Field(default_factory=list)
    actors: Dict = Field(default_factory=dict)
    version: int = Field(default=2, frozen=True)
    messages: Dict = Field(default_factory=dict)
18+
19+
20+
class ModelEvaluationTemplate(BaseModel):
    """
    Use this class to create a model evaluation data row.

    Every field has a default, so the template can be instantiated without
    arguments and then populated attribute by attribute before being dumped
    to a plain dict.

    Examples:
        >>> data = ModelEvaluationTemplate()
        >>> data.row_data.rootMessageIds = ["root1"]
        >>> vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)]
        >>> data.embeddings = [...]
        >>> data.metadata_fields = [...]
        >>> data.attachments = [...]
        >>> content = data.model_dump()
        >>> task = dataset.create_data_rows([content])
    """

    # default_factory constructs a new row-data object for each template
    # rather than building a single instance at class-definition time.
    row_data: ModelEvalutationTemlateRowData = Field(
        default_factory=ModelEvalutationTemlateRowData
    )
    # Fresh empty list per instance for each optional collection.
    attachments: List[Dict] = Field(default_factory=list)
    embeddings: List[Dict] = Field(default_factory=list)
    metadata_fields: List[DataRowMetadataField] = Field(default_factory=list)

libs/labelbox/src/labelbox/schema/project.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
List,
1414
Optional,
1515
Tuple,
16-
TypeVar,
1716
Union,
1817
get_args,
1918
)
@@ -23,7 +22,6 @@
2322
InvalidQueryError,
2423
LabelboxError,
2524
ProcessingWaitTimeout,
26-
ResourceConflict,
2725
ResourceNotFoundError,
2826
error_message_for_unparsed_graphql_error,
2927
) # type: ignore
@@ -59,7 +57,6 @@
5957
from labelbox.schema.model_config import ModelConfig
6058
from labelbox.schema.ontology_kind import (
6159
EditorTaskType,
62-
OntologyKind,
6360
UploadType,
6461
)
6562
from labelbox.schema.project_model_config import ProjectModelConfig
@@ -577,7 +574,7 @@ def upsert_instructions(self, instructions_file: str) -> None:
577574

578575
if frontend.name != "Editor":
579576
logger.warning(
580-
f"This function has only been tested to work with the Editor front end. Found %s",
577+
"This function has only been tested to work with the Editor front end. Found %s",
581578
frontend.name,
582579
)
583580

@@ -745,7 +742,9 @@ def create_batch(
745742
lbox.exceptions.ValueError if a project is not batch mode, if the project is auto data generation, if the batch exceeds 100k data rows
746743
"""
747744

748-
if self.is_auto_data_generation():
745+
if (
746+
self.is_auto_data_generation() and not self.is_chat_evaluation()
747+
): # NOTE live chat evaluation projects in sdk do not pre-generate data rows, but use batch as all other projects
749748
raise ValueError(
750749
"Cannot create batches for auto data generation projects"
751750
)
@@ -771,7 +770,7 @@ def create_batch(
771770

772771
if row_count > 100_000:
773772
raise ValueError(
774-
f"Batch exceeds max size, break into smaller batches"
773+
"Batch exceeds max size, break into smaller batches"
775774
)
776775
if not row_count:
777776
raise ValueError("You need at least one data row in a batch")
@@ -1039,8 +1038,7 @@ def _create_batch_async(
10391038
task = self._wait_for_task(task_id)
10401039
if task.status != "COMPLETE":
10411040
raise LabelboxError(
1042-
f"Batch was not created successfully: "
1043-
+ json.dumps(task.errors)
1041+
"Batch was not created successfully: " + json.dumps(task.errors)
10441042
)
10451043

10461044
return self.client.get_batch(self.uid, batch_id)
@@ -1262,7 +1260,7 @@ def update_data_row_labeling_priority(
12621260
task = self._wait_for_task(task_id)
12631261
if task.status != "COMPLETE":
12641262
raise LabelboxError(
1265-
f"Priority was not updated successfully: "
1263+
"Priority was not updated successfully: "
12661264
+ json.dumps(task.errors)
12671265
)
12681266
return True
@@ -1442,7 +1440,7 @@ def move_data_rows_to_task_queue(
14421440
task = self._wait_for_task(task_id)
14431441
if task.status != "COMPLETE":
14441442
raise LabelboxError(
1445-
f"Data rows were not moved successfully: "
1443+
"Data rows were not moved successfully: "
14461444
+ json.dumps(task.errors)
14471445
)
14481446

libs/labelbox/tests/integration/conftest.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -631,11 +631,28 @@ def chat_evaluation_ontology(client, rand_gen):
631631

632632

633633
@pytest.fixture
634-
def live_chat_evaluation_project_with_new_dataset(client, rand_gen):
634+
def live_chat_evaluation_project(client, rand_gen):
635635
project_name = f"test-model-evaluation-project-{rand_gen(str)}"
636-
dataset_name = f"test-model-evaluation-dataset-{rand_gen(str)}"
637-
project = client.create_model_evaluation_project(
638-
name=project_name, dataset_name=dataset_name, data_row_count=1
636+
project = client.create_model_evaluation_project(name=project_name)
637+
638+
yield project
639+
640+
project.delete()
641+
642+
643+
@pytest.fixture
644+
def live_chat_evaluation_project_with_batch(
645+
client,
646+
rand_gen,
647+
live_chat_evaluation_project,
648+
offline_conversational_data_row,
649+
):
650+
project_name = f"test-model-evaluation-project-{rand_gen(str)}"
651+
project = client.create_model_evaluation_project(name=project_name)
652+
653+
project.create_batch(
654+
rand_gen(str),
655+
[offline_conversational_data_row.uid], # sample of data row objects
639656
)
640657

641658
yield project

libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
from unittest.mock import patch
2-
31
import pytest
42

53
from labelbox import MediaType
64
from labelbox.schema.ontology_kind import OntologyKind
75

86

97
def test_create_chat_evaluation_ontology_project(
10-
client,
118
chat_evaluation_ontology,
12-
live_chat_evaluation_project_with_new_dataset,
9+
live_chat_evaluation_project,
1310
offline_conversational_data_row,
1411
rand_gen,
1512
):
@@ -28,36 +25,19 @@ def test_create_chat_evaluation_ontology_project(
2825
assert classification.schema_id
2926
assert classification.feature_schema_id
3027

31-
project = live_chat_evaluation_project_with_new_dataset
28+
project = live_chat_evaluation_project
3229
assert project.model_setup_complete is None
3330

3431
project.connect_ontology(ontology)
3532

3633
assert project.labeling_frontend().name == "Editor"
3734
assert project.ontology().name == ontology.name
3835

39-
with pytest.raises(
40-
ValueError,
41-
match="Cannot create batches for auto data generation projects",
42-
):
43-
project.create_batch(
44-
rand_gen(str),
45-
[offline_conversational_data_row.uid], # sample of data row objects
46-
)
47-
48-
with pytest.raises(
49-
ValueError,
50-
match="Cannot create batches for auto data generation projects",
51-
):
52-
with patch(
53-
"labelbox.schema.project.MAX_SYNC_BATCH_ROW_COUNT", new=0
54-
): # force to async
55-
project.create_batch(
56-
rand_gen(str),
57-
[
58-
offline_conversational_data_row.uid
59-
], # sample of data row objects
60-
)
36+
batch = project.create_batch(
37+
rand_gen(str),
38+
[offline_conversational_data_row.uid], # sample of data row objects
39+
)
40+
assert batch
6141

6242

6343
def test_create_chat_evaluation_ontology_project_existing_dataset(

libs/labelbox/tests/integration/test_data_rows.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def test_create_data_row_with_metadata_dict(
405405
row_data=image_url, metadata_fields=make_metadata_fields_dict
406406
)
407407

408-
assert len(list(dataset.data_rows())) == 1
408+
assert len([dr for dr in dataset.data_rows()]) == 1
409409
assert data_row.dataset() == dataset
410410
assert data_row.created_by() == client.get_user()
411411
assert data_row.organization() == client.get_organization()

libs/labelbox/tests/integration/test_labeling_service.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,11 @@ def test_request_labeling_service_moe_offline_project(
4343

4444

4545
def test_request_labeling_service_moe_project(
46-
rand_gen,
47-
live_chat_evaluation_project_with_new_dataset,
46+
live_chat_evaluation_project_with_batch,
4847
chat_evaluation_ontology,
4948
model_config,
5049
):
51-
project = live_chat_evaluation_project_with_new_dataset
50+
project = live_chat_evaluation_project_with_batch
5251
project.connect_ontology(chat_evaluation_ontology)
5352

5453
project.upsert_instructions("tests/integration/media/sample_pdf.pdf")

0 commit comments

Comments
 (0)