Skip to content

Commit f26397c

Browse files
author
Val Brodsky
committed
Add project validation input
1 parent 29e1567 commit f26397c

File tree

3 files changed

+104
-108
lines changed

3 files changed

+104
-108
lines changed

libs/labelbox/src/labelbox/client.py

Lines changed: 6 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from labelbox.orm.db_object import DbObject
2424
from labelbox.orm.model import Entity, Field
2525
from labelbox.pagination import PaginatedCollection
26+
from labelbox.project_validation import _CoreProjectInput
2627
from labelbox.schema import role
2728
from labelbox.schema.catalog import Catalog
2829
from labelbox.schema.data_row import DataRow
@@ -632,7 +633,8 @@ def create_project(self, **kwargs) -> Project:
632633
kwargs.pop("append_to_existing_dataset", None)
633634
kwargs.pop("data_row_count", None)
634635
kwargs.pop("editor_task_type", None)
635-
return self._create_project(**kwargs)
636+
input = _CoreProjectInput(**kwargs)
637+
return self._create_project(input)
636638

637639
@overload
638640
def create_model_evaluation_project(
@@ -820,103 +822,10 @@ def create_response_creation_project(self, **kwargs) -> Project:
820822

821823
return self._create_project(**kwargs)
822824

823-
def _create_project(self, **kwargs) -> Project:
824-
auto_audit_percentage = kwargs.get("auto_audit_percentage")
825-
auto_audit_number_of_labels = kwargs.get("auto_audit_number_of_labels")
826-
if (
827-
auto_audit_percentage is not None
828-
or auto_audit_number_of_labels is not None
829-
):
830-
raise ValueError(
831-
"quality_modes must be set instead of auto_audit_percentage or auto_audit_number_of_labels."
832-
)
833-
834-
name = kwargs.get("name")
835-
if name is None or not name.strip():
836-
raise ValueError("project name must be a valid string.")
837-
838-
queue_mode = kwargs.get("queue_mode")
839-
if queue_mode is QueueMode.Dataset:
840-
raise ValueError(
841-
"Dataset queue mode is deprecated. Please prefer Batch queue mode."
842-
)
843-
elif queue_mode is QueueMode.Batch:
844-
logger.warning(
845-
"Passing a queue mode of batch is redundant and will soon no longer be supported."
846-
)
847-
848-
media_type = kwargs.get("media_type")
849-
if media_type and MediaType.is_supported(media_type):
850-
media_type_value = media_type.value
851-
elif media_type:
852-
raise TypeError(
853-
f"{media_type} is not a valid media type. Use"
854-
f" any of {MediaType.get_supported_members()}"
855-
" from MediaType. Example: MediaType.Image."
856-
)
857-
else:
858-
logger.warning(
859-
"Creating a project without specifying media_type"
860-
" through this method will soon no longer be supported."
861-
)
862-
media_type_value = None
863-
864-
quality_modes = kwargs.get("quality_modes")
865-
quality_mode = kwargs.get("quality_mode")
866-
if quality_mode:
867-
logger.warning(
868-
"Passing quality_mode is deprecated and will soon no longer be supported. Use quality_modes instead."
869-
)
870-
871-
if quality_modes and quality_mode:
872-
raise ValueError(
873-
"Cannot use both quality_modes and quality_mode at the same time. Use one or the other."
874-
)
875-
876-
if not quality_modes and not quality_mode:
877-
logger.info("Defaulting quality modes to Benchmark and Consensus.")
878-
879-
data = kwargs
880-
data.pop("quality_modes", None)
881-
data.pop("quality_mode", None)
882-
883-
# check if quality_modes is a set, if not, convert to set
884-
quality_modes_set = quality_modes
885-
if quality_modes and not isinstance(quality_modes, set):
886-
quality_modes_set = set(quality_modes)
887-
if quality_mode:
888-
quality_modes_set = {quality_mode}
889-
890-
if (
891-
quality_modes_set is None
892-
or len(quality_modes_set) == 0
893-
or quality_modes_set
894-
== {QualityMode.Benchmark, QualityMode.Consensus}
895-
):
896-
data["auto_audit_number_of_labels"] = (
897-
CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS
898-
)
899-
data["auto_audit_percentage"] = CONSENSUS_AUTO_AUDIT_PERCENTAGE
900-
data["is_benchmark_enabled"] = True
901-
data["is_consensus_enabled"] = True
902-
elif quality_modes_set == {QualityMode.Benchmark}:
903-
data["auto_audit_number_of_labels"] = (
904-
BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS
905-
)
906-
data["auto_audit_percentage"] = BENCHMARK_AUTO_AUDIT_PERCENTAGE
907-
data["is_benchmark_enabled"] = True
908-
elif quality_modes_set == {QualityMode.Consensus}:
909-
data["auto_audit_number_of_labels"] = (
910-
CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS
911-
)
912-
data["auto_audit_percentage"] = CONSENSUS_AUTO_AUDIT_PERCENTAGE
913-
data["is_consensus_enabled"] = True
914-
else:
915-
raise ValueError(
916-
f"{quality_modes_set} is not a valid quality modes set. Allowed values are [Benchmark, Consensus]"
917-
)
825+
def _create_project(self, input: _CoreProjectInput) -> Project:
826+
media_type_value = input.media_type.value
918827

919-
params = {**data}
828+
params = input.model_dump(exclude_none=True)
920829
if media_type_value:
921830
params["media_type"] = media_type_value
922831

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from typing import Optional, Set
2+
3+
from pydantic import BaseModel, ConfigDict, Field, model_validator
4+
5+
from labelbox.schema.media_type import MediaType
6+
from labelbox.schema.quality_mode import (
7+
BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS,
8+
BENCHMARK_AUTO_AUDIT_PERCENTAGE,
9+
CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS,
10+
CONSENSUS_AUTO_AUDIT_PERCENTAGE,
11+
QualityMode,
12+
)
13+
from labelbox.schema.queue_mode import QueueMode
14+
15+
16+
class _CoreProjectInput(BaseModel):
    """Validated keyword arguments for creating a project.

    ``quality_modes`` is translated by :meth:`validate_fields` into the
    ``auto_audit_*`` and ``is_*_enabled`` fields, so callers must not pass
    ``auto_audit_percentage`` or ``auto_audit_number_of_labels`` themselves.
    Unknown keyword arguments are rejected (``extra="forbid"``).
    """

    name: str
    description: Optional[str] = None
    media_type: MediaType
    queue_mode: QueueMode = Field(default=QueueMode.Batch, frozen=True)
    # Derived from quality_modes during validation; rejected when supplied.
    auto_audit_percentage: Optional[float] = None
    auto_audit_number_of_labels: Optional[int] = None
    # Input-only: declared with exclude=True so it is left out of model dumps.
    quality_modes: Optional[Set[QualityMode]] = Field(
        default={QualityMode.Benchmark, QualityMode.Consensus}, exclude=True
    )
    # Derived from quality_modes during validation.
    is_benchmark_enabled: Optional[bool] = None
    is_consensus_enabled: Optional[bool] = None
    dataset_name_or_id: Optional[str] = None
    append_to_existing_dataset: Optional[bool] = None

    model_config = ConfigDict(extra="forbid")

    @model_validator(mode="after")
    def validate_fields(self) -> "_CoreProjectInput":
        """Validate the input and derive the audit settings.

        Returns:
            The same instance, with ``auto_audit_number_of_labels``,
            ``auto_audit_percentage``, ``is_benchmark_enabled`` and
            ``is_consensus_enabled`` populated from ``quality_modes``.

        Raises:
            ValueError: if either ``auto_audit_*`` field was supplied, if
                ``name`` is blank, or if ``quality_modes`` is not one of the
                supported combinations.
        """
        # Either value on its own is already an error: both are derived
        # below and a caller-supplied value would be silently overwritten.
        if (
            self.auto_audit_percentage is not None
            or self.auto_audit_number_of_labels is not None
        ):
            raise ValueError(
                "quality_modes must be set instead of auto_audit_percentage or auto_audit_number_of_labels."
            )

        if not self.name.strip():
            raise ValueError("project name must be a valid string.")

        # An explicit None or an empty set falls back to the same settings
        # as the field default (Benchmark + Consensus).
        if not self.quality_modes or self.quality_modes == {
            QualityMode.Benchmark,
            QualityMode.Consensus,
        }:
            self._set_quality_mode_attributes(
                CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS,
                CONSENSUS_AUTO_AUDIT_PERCENTAGE,
                is_benchmark_enabled=True,
                is_consensus_enabled=True,
            )
        elif self.quality_modes == {QualityMode.Benchmark}:
            self._set_quality_mode_attributes(
                BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS,
                BENCHMARK_AUTO_AUDIT_PERCENTAGE,
                is_benchmark_enabled=True,
            )
        elif self.quality_modes == {QualityMode.Consensus}:
            self._set_quality_mode_attributes(
                CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS,
                CONSENSUS_AUTO_AUDIT_PERCENTAGE,
                is_consensus_enabled=True,
            )
        else:
            raise ValueError(
                f"{self.quality_modes} is not a valid quality modes set. Allowed values are [Benchmark, Consensus]"
            )

        return self

    def _set_quality_mode_attributes(
        self,
        number_of_labels: int,
        percentage: float,
        is_benchmark_enabled: bool = False,
        is_consensus_enabled: bool = False,
    ) -> None:
        """Store the audit settings derived from the selected quality modes.

        Both flags are always written, so a mode that is not passed as
        enabled is stored as ``False``.
        """
        self.auto_audit_number_of_labels = number_of_labels
        self.auto_audit_percentage = percentage
        self.is_benchmark_enabled = is_benchmark_enabled
        self.is_consensus_enabled = is_consensus_enabled

libs/labelbox/tests/integration/test_project.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from lbox.exceptions import InvalidQueryError
88

99
from labelbox import Dataset, LabelingFrontend, Project
10+
from labelbox.schema import media_type
1011
from labelbox.schema.media_type import MediaType
1112
from labelbox.schema.quality_mode import QualityMode
1213
from labelbox.schema.queue_mode import QueueMode
@@ -51,7 +52,7 @@ def data_for_project_test(client, rand_gen):
5152
def _create_project(name: str = None):
5253
if name is None:
5354
name = rand_gen(str)
54-
project = client.create_project(name=name)
55+
project = client.create_project(name=name, media_type=MediaType.Image)
5556
projects.append(project)
5657
return project
5758

@@ -140,10 +141,6 @@ def test_extend_reservations(project):
140141
project.extend_reservations("InvalidQueueType")
141142

142143

143-
@pytest.mark.skipif(
144-
condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem",
145-
reason="new mutation does not work for onprem",
146-
)
147144
def test_attach_instructions(client, project):
148145
with pytest.raises(ValueError) as execinfo:
149146
project.upsert_instructions("tests/integration/media/sample_pdf.pdf")
@@ -248,9 +245,11 @@ def test_media_type(client, project: Project, rand_gen):
248245
assert isinstance(project.media_type, MediaType)
249246

250247
# Update test
251-
project = client.create_project(name=rand_gen(str))
252-
project.update(media_type=MediaType.Image)
253-
assert project.media_type == MediaType.Image
248+
project = client.create_project(
249+
name=rand_gen(str), media_type=MediaType.Image
250+
)
251+
project.update(media_type=MediaType.Text)
252+
assert project.media_type == MediaType.Text
254253
project.delete()
255254

256255
for media_type in MediaType.get_supported_members():
@@ -271,27 +270,33 @@ def test_media_type(client, project: Project, rand_gen):
271270

272271
def test_queue_mode(client, rand_gen):
273272
project = client.create_project(
274-
name=rand_gen(str)
273+
name=rand_gen(str),
274+
media_type=MediaType.Image,
275275
) # defaults to benchmark and consensus
276276
assert project.auto_audit_number_of_labels == 3
277277
assert project.auto_audit_percentage == 0
278278

279279
project = client.create_project(
280-
name=rand_gen(str), quality_modes=[QualityMode.Benchmark]
280+
name=rand_gen(str),
281+
quality_modes=[QualityMode.Benchmark],
282+
media_type=MediaType.Image,
281283
)
282284
assert project.auto_audit_number_of_labels == 1
283285
assert project.auto_audit_percentage == 1
284286

285287
project = client.create_project(
286288
name=rand_gen(str),
287289
quality_modes=[QualityMode.Benchmark, QualityMode.Consensus],
290+
media_type=MediaType.Image,
288291
)
289292
assert project.auto_audit_number_of_labels == 3
290293
assert project.auto_audit_percentage == 0
291294

292295

293296
def test_label_count(client, configured_batch_project_with_label):
294-
project = client.create_project(name="test label count")
297+
project = client.create_project(
298+
name="test label count", media_type=MediaType.Image
299+
)
295300
assert project.get_label_count() == 0
296301
project.delete()
297302

0 commit comments

Comments
 (0)