Skip to content

Commit ebd828c

Browse files
authored
Converted TypedDict to BaseModel for client.send_to_annotate_from_catelog method (#1608)
1 parent 6936d35 commit ebd828c

File tree

2 files changed

+33
-30
lines changed

2 files changed

+33
-30
lines changed

libs/labelbox/src/labelbox/client.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2002,7 +2002,7 @@ def send_to_annotate_from_catalog(self, destination_project_id: str,
20022002
task_queue_id: Optional[str],
20032003
batch_name: str,
20042004
data_rows: Union[DataRowIds, GlobalKeys],
2005-
params: SendToAnnotateFromCatalogParams):
2005+
params: Dict[str, Any]):
20062006
"""
20072007
Sends data rows from catalog to a specified project for annotation.
20082008
@@ -2033,6 +2033,8 @@ def send_to_annotate_from_catalog(self, destination_project_id: str,
20332033
20342034
"""
20352035

2036+
validated_params = SendToAnnotateFromCatalogParams(**params)
2037+
20362038
mutation_str = """mutation SendToAnnotateFromCatalogPyApi($input: SendToAnnotateFromCatalogInput!) {
20372039
sendToAnnotateFromCatalog(input: $input) {
20382040
taskId
@@ -2044,26 +2046,14 @@ def send_to_annotate_from_catalog(self, destination_project_id: str,
20442046
task_queue_id)
20452047
data_rows_query = self.build_catalog_query(data_rows)
20462048

2047-
source_model_run_id = params.get("source_model_run_id", None)
2048-
predictions_ontology_mapping = params.get(
2049-
"predictions_ontology_mapping", None)
20502049
predictions_input = build_predictions_input(
2051-
predictions_ontology_mapping,
2052-
source_model_run_id) if source_model_run_id else None
2050+
validated_params.predictions_ontology_mapping,
2051+
validated_params.source_model_run_id
2052+
) if validated_params.source_model_run_id else None
20532053

2054-
source_project_id = params.get("source_project_id", None)
2055-
annotations_ontology_mapping = params.get(
2056-
"annotations_ontology_mapping", None)
20572054
annotations_input = build_annotations_input(
2058-
annotations_ontology_mapping,
2059-
source_project_id) if source_project_id else None
2060-
2061-
batch_priority = params.get("batch_priority", 5)
2062-
exclude_data_rows_in_project = params.get(
2063-
"exclude_data_rows_in_project", False)
2064-
override_existing_annotations_rule = params.get(
2065-
"override_existing_annotations_rule",
2066-
ConflictResolutionStrategy.KeepExisting)
2055+
validated_params.annotations_ontology_mapping, validated_params.
2056+
source_project_id) if validated_params.source_project_id else None
20672057

20682058
res = self.execute(
20692059
mutation_str, {
@@ -2072,18 +2062,18 @@ def send_to_annotate_from_catalog(self, destination_project_id: str,
20722062
destination_project_id,
20732063
"batchInput": {
20742064
"batchName": batch_name,
2075-
"batchPriority": batch_priority
2065+
"batchPriority": validated_params.batch_priority
20762066
},
20772067
"destinationTaskQueue":
20782068
destination_task_queue,
20792069
"excludeDataRowsInProject":
2080-
exclude_data_rows_in_project,
2070+
validated_params.exclude_data_rows_in_project,
20812071
"annotationsInput":
20822072
annotations_input,
20832073
"predictionsInput":
20842074
predictions_input,
20852075
"conflictLabelsResolutionStrategy":
2086-
override_existing_annotations_rule,
2076+
validated_params.override_existing_annotations_rule,
20872077
"searchQuery": {
20882078
"scope": None,
20892079
"query": [data_rows_query]

libs/labelbox/src/labelbox/schema/send_to_annotate_params.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
from typing import Optional, Dict
44

55
from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy
6+
from labelbox import pydantic_compat
67

78
if sys.version_info >= (3, 8):
89
from typing import TypedDict
910
else:
1011
from typing_extensions import TypedDict
1112

1213

13-
class SendToAnnotateFromCatalogParams(TypedDict):
14+
class SendToAnnotateFromCatalogParams(pydantic_compat.BaseModel):
1415
"""
1516
Extra parameters for sending data rows to a project through catalog. At least one of source_model_run_id or
1617
source_project_id must be provided.
@@ -30,14 +31,26 @@ class SendToAnnotateFromCatalogParams(TypedDict):
3031
:param batch_priority: Optional[int] - The priority of the batch. Defaults to 5.
3132
"""
3233

33-
source_model_run_id: Optional[str]
34-
predictions_ontology_mapping: Optional[Dict[str, str]]
35-
source_project_id: Optional[str]
36-
annotations_ontology_mapping: Optional[Dict[str, str]]
37-
exclude_data_rows_in_project: Optional[bool]
38-
override_existing_annotations_rule: Optional[ConflictResolutionStrategy]
39-
batch_priority: Optional[int]
40-
34+
source_model_run_id: Optional[str] = None
35+
source_project_id: Optional[str] = None
36+
predictions_ontology_mapping: Optional[Dict[str, str]] = {}
37+
annotations_ontology_mapping: Optional[Dict[str, str]] = {}
38+
exclude_data_rows_in_project: Optional[bool] = False
39+
override_existing_annotations_rule: Optional[
40+
ConflictResolutionStrategy] = ConflictResolutionStrategy.KeepExisting
41+
batch_priority: Optional[int] = 5
42+
43+
@pydantic_compat.root_validator
44+
def check_project_id_or_model_run_id(cls, values):
45+
if not values.get("source_model_run_id") and not values.get("source_project_id"):
46+
raise ValueError(
47+
'Either source_project_id or source_model_id are required'
48+
)
49+
if values.get("source_model_run_id") and values.get("source_project_id"):
50+
raise ValueError(
51+
'Provide only a source_project_id or source_model_id not both'
52+
)
53+
return values
4154

4255
class SendToAnnotateFromModelParams(TypedDict):
4356
"""

0 commit comments

Comments
 (0)