Skip to content

Commit ae80326

Browse files
Gabefire authored and Val Brodsky committed
[PLT-1490] Removed data row ids list on some project methods and removed get queue_modes (#1852)
1 parent aa5c5ee commit ae80326

File tree

5 files changed

+33
-181
lines changed

5 files changed

+33
-181
lines changed

libs/labelbox/src/labelbox/schema/project.py

Lines changed: 23 additions & 156 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
Optional,
1414
Tuple,
1515
Union,
16-
overload,
16+
get_args,
1717
)
1818

1919
from lbox.exceptions import (
@@ -40,7 +40,11 @@
4040
from labelbox.schema.export_task import ExportTask
4141
from labelbox.schema.id_type import IdType
4242
from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId
43-
from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
43+
from labelbox.schema.identifiables import (
44+
DataRowIdentifiers,
45+
GlobalKeys,
46+
UniqueIds,
47+
)
4448
from labelbox.schema.labeling_service import (
4549
LabelingService,
4650
LabelingServiceStatus,
@@ -67,9 +71,7 @@
6771

6872

6973
DataRowPriority = int
70-
LabelingParameterOverrideInput = Tuple[
71-
Union[DataRow, DataRowIdentifier], DataRowPriority
72-
]
74+
LabelingParameterOverrideInput = Tuple[DataRowIdentifier, DataRowPriority]
7375

7476
logger = logging.getLogger(__name__)
7577
MAX_SYNC_BATCH_ROW_COUNT = 1_000
@@ -79,23 +81,18 @@ def validate_labeling_parameter_overrides(
7981
data: List[LabelingParameterOverrideInput],
8082
) -> None:
8183
for idx, row in enumerate(data):
82-
if len(row) < 2:
83-
raise TypeError(
84-
f"Data must be a list of tuples each containing two elements: a DataRow or a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}"
85-
)
8684
data_row_identifier = row[0]
8785
priority = row[1]
88-
valid_types = (Entity.DataRow, UniqueId, GlobalKey)
89-
if not isinstance(data_row_identifier, valid_types):
86+
if not isinstance(data_row_identifier, get_args(DataRowIdentifier)):
9087
raise TypeError(
91-
f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found {type(data_row_identifier)} for data_row_identifier {data_row_identifier}"
88+
f"Data row identifier should be of type DataRowIdentifier. Found {type(data_row_identifier)}."
89+
)
90+
if len(row) < 2:
91+
raise TypeError(
92+
f"Data must be a list of tuples each containing two elements: a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}"
9293
)
93-
9494
if not isinstance(priority, int):
95-
if isinstance(data_row_identifier, Entity.DataRow):
96-
id = data_row_identifier.uid
97-
else:
98-
id = data_row_identifier
95+
id = data_row_identifier.key
9996
raise TypeError(
10097
f"Priority must be an int. Found {type(priority)} for data_row_identifier {id}"
10198
)
@@ -1048,57 +1045,6 @@ def _create_batch_async(
10481045

10491046
return self.client.get_batch(self.uid, batch_id)
10501047

1051-
def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode":
1052-
"""
1053-
Updates the queueing mode of this project.
1054-
1055-
Deprecation notice: This method is deprecated. Going forward, projects must
1056-
go through a migration to have the queue mode changed. Users should specify the
1057-
queue mode for a project during creation if a non-default mode is desired.
1058-
1059-
For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes
1060-
1061-
Args:
1062-
mode: the specified queue mode
1063-
1064-
Returns: the updated queueing mode of this project
1065-
1066-
"""
1067-
1068-
logger.warning(
1069-
"Updating the queue_mode for a project will soon no longer be supported."
1070-
)
1071-
1072-
if self.queue_mode == mode:
1073-
return mode
1074-
1075-
if mode == QueueMode.Batch:
1076-
status = "ENABLED"
1077-
elif mode == QueueMode.Dataset:
1078-
status = "DISABLED"
1079-
else:
1080-
raise ValueError(
1081-
"Must provide either `BATCH` or `DATASET` as a mode"
1082-
)
1083-
1084-
query_str = (
1085-
"""mutation %s($projectId: ID!, $status: TagSetStatusInput!) {
1086-
project(where: {id: $projectId}) {
1087-
setTagSetStatus(input: {tagSetStatus: $status}) {
1088-
tagSetStatus
1089-
}
1090-
}
1091-
}
1092-
"""
1093-
% "setTagSetStatusPyApi"
1094-
)
1095-
1096-
self.client.execute(
1097-
query_str, {"projectId": self.uid, "status": status}
1098-
)
1099-
1100-
return mode
1101-
11021048
def get_label_count(self) -> int:
11031049
"""
11041050
Returns: the total number of labels in this project.
@@ -1113,46 +1059,6 @@ def get_label_count(self) -> int:
11131059
res = self.client.execute(query_str, {"projectId": self.uid})
11141060
return res["project"]["labelCount"]
11151061

1116-
def get_queue_mode(self) -> "QueueMode":
1117-
"""
1118-
Provides the queue mode used for this project.
1119-
1120-
Deprecation notice: This method is deprecated and will be removed in
1121-
a future version. To obtain the queue mode of a project, simply refer
1122-
to the queue_mode attribute of a Project.
1123-
1124-
For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes
1125-
1126-
Returns: the QueueMode for this project
1127-
1128-
"""
1129-
1130-
logger.warning(
1131-
"Obtaining the queue_mode for a project through this method will soon"
1132-
" no longer be supported."
1133-
)
1134-
1135-
query_str = (
1136-
"""query %s($projectId: ID!) {
1137-
project(where: {id: $projectId}) {
1138-
tagSetStatus
1139-
}
1140-
}
1141-
"""
1142-
% "GetTagSetStatusPyApi"
1143-
)
1144-
1145-
status = self.client.execute(query_str, {"projectId": self.uid})[
1146-
"project"
1147-
]["tagSetStatus"]
1148-
1149-
if status == "ENABLED":
1150-
return QueueMode.Batch
1151-
elif status == "DISABLED":
1152-
return QueueMode.Dataset
1153-
else:
1154-
raise ValueError("Status not known")
1155-
11561062
def add_model_config(self, model_config_id: str) -> str:
11571063
"""Adds a model config to this project.
11581064
@@ -1245,18 +1151,13 @@ def set_labeling_parameter_overrides(
12451151
See information on priority here:
12461152
https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system
12471153
1248-
>>> project.set_labeling_parameter_overrides([
1249-
>>> (data_row_id1, 2), (data_row_id2, 1)])
1250-
or
12511154
>>> project.set_labeling_parameter_overrides([
12521155
>>> (data_row_gk1, 2), (data_row_gk2, 1)])
12531156
12541157
Args:
12551158
data (iterable): An iterable of tuples. Each tuple must contain
1256-
either (DataRow, DataRowPriority<int>)
1257-
or (DataRowIdentifier, priority<int>) for the new override.
1159+
(DataRowIdentifier, priority<int>) for the new override.
12581160
DataRowIdentifier is an object representing a data row id or a global key. A DataIdentifier object can be a UniqueIds or GlobalKeys class.
1259-
NOTE - passing whole DatRow is deprecated. Please use a DataRowIdentifier instead.
12601161
12611162
Priority:
12621163
* Data will be labeled in priority order.
@@ -1285,43 +1186,18 @@ def set_labeling_parameter_overrides(
12851186

12861187
data_rows_with_identifiers = ""
12871188
for data_row, priority in data:
1288-
if isinstance(data_row, DataRow):
1289-
data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.uid}", idType: {IdType.DataRowId}}}, priority: {priority}}},'
1290-
elif isinstance(data_row, UniqueId) or isinstance(
1291-
data_row, GlobalKey
1292-
):
1293-
data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.key}", idType: {data_row.id_type}}}, priority: {priority}}},'
1294-
else:
1295-
raise TypeError(
1296-
f"Data row identifier should be be of type DataRow or Data Row Identifier. Found {type(data_row)}."
1297-
)
1189+
data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.key}", idType: {data_row.id_type}}}, priority: {priority}}},'
12981190

12991191
query_str = template.substitute(
13001192
dataWithDataRowIdentifiers=data_rows_with_identifiers
13011193
)
13021194
res = self.client.execute(query_str, {"projectId": self.uid})
13031195
return res["project"]["setLabelingParameterOverrides"]["success"]
13041196

1305-
@overload
13061197
def update_data_row_labeling_priority(
13071198
self,
13081199
data_rows: DataRowIdentifiers,
13091200
priority: int,
1310-
) -> bool:
1311-
pass
1312-
1313-
@overload
1314-
def update_data_row_labeling_priority(
1315-
self,
1316-
data_rows: List[str],
1317-
priority: int,
1318-
) -> bool:
1319-
pass
1320-
1321-
def update_data_row_labeling_priority(
1322-
self,
1323-
data_rows,
1324-
priority: int,
13251201
) -> bool:
13261202
"""
13271203
Updates labeling parameter overrides to this project in bulk. This method allows up to 1 million data rows to be
@@ -1331,16 +1207,16 @@ def update_data_row_labeling_priority(
13311207
https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system
13321208
13331209
Args:
1334-
data_rows: a list of data row ids to update priorities for. This can be a list of strings or a DataRowIdentifiers object
1210+
data_rows: data row identifiers object to update priorities.
13351211
DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class.
13361212
priority (int): Priority for the new override. See above for more information.
13371213
13381214
Returns:
13391215
bool, indicates if the operation was a success.
13401216
"""
13411217

1342-
if isinstance(data_rows, list):
1343-
data_rows = UniqueIds(data_rows)
1218+
if not isinstance(data_rows, get_args(DataRowIdentifiers)):
1219+
raise TypeError("data_rows must be a DataRowIdentifiers object")
13441220

13451221
method = "createQueuePriorityUpdateTask"
13461222
priority_param = "priority"
@@ -1483,34 +1359,25 @@ def task_queues(self) -> List[TaskQueue]:
14831359
for field_values in task_queue_values
14841360
]
14851361

1486-
@overload
14871362
def move_data_rows_to_task_queue(
14881363
self, data_row_ids: DataRowIdentifiers, task_queue_id: str
14891364
):
1490-
pass
1491-
1492-
@overload
1493-
def move_data_rows_to_task_queue(
1494-
self, data_row_ids: List[str], task_queue_id: str
1495-
):
1496-
pass
1497-
1498-
def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str):
14991365
"""
15001366
15011367
Moves data rows to the specified task queue.
15021368
15031369
Args:
1504-
data_row_ids: a list of data row ids to be moved. This can be a list of strings or a DataRowIdentifiers object
1370+
data_row_ids: a list of data row ids to be moved. This should be a DataRowIdentifiers object
15051371
DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class.
15061372
task_queue_id: the task queue id to be moved to, or None to specify the "Done" queue
15071373
15081374
Returns:
15091375
None if successful, or a raised error on failure
15101376
15111377
"""
1512-
if isinstance(data_row_ids, list):
1513-
data_row_ids = UniqueIds(data_row_ids)
1378+
1379+
if not isinstance(data_row_ids, get_args(DataRowIdentifiers)):
1380+
raise TypeError("data_rows must be a DataRowIdentifiers object")
15141381

15151382
method = "createBulkAddRowsToQueueTask"
15161383
query_str = (

libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from labelbox import Project, Dataset
1010
from labelbox.schema.data_row import DataRow
1111
from labelbox.schema.label import Label
12+
from labelbox import UniqueIds
1213

1314
IMAGE_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg"
1415

@@ -128,7 +129,9 @@ def test_with_date_filters(
128129
review_queue = next(
129130
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE"
130131
)
131-
project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid)
132+
project.move_data_rows_to_task_queue(
133+
UniqueIds([data_row.uid]), review_queue.uid
134+
)
132135
export_task = project_export(
133136
project, task_name, filters=filters, params=params
134137
)

libs/labelbox/tests/integration/test_labeling_parameter_overrides.py

Lines changed: 3 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch):
2323
data_rows[2].uid,
2424
}
2525

26-
data = [(data_rows[0], 4, 2), (data_rows[1], 3)]
26+
data = [(UniqueId(data_rows[0].uid), 4, 2), (UniqueId(data_rows[1].uid), 3)]
2727
success = project.set_labeling_parameter_overrides(data)
2828
assert success
2929

@@ -60,7 +60,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch):
6060
assert {o.priority for o in updated_overrides} == {2, 3, 4}
6161

6262
with pytest.raises(TypeError) as exc_info:
63-
data = [(data_rows[2], "a_string", 3)]
63+
data = [(UniqueId(data_rows[2].uid), "a_string", 3)]
6464
project.set_labeling_parameter_overrides(data)
6565
assert (
6666
str(exc_info.value)
@@ -72,7 +72,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch):
7272
project.set_labeling_parameter_overrides(data)
7373
assert (
7474
str(exc_info.value)
75-
== f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found <class 'str'> for data_row_identifier {data_rows[2].uid}"
75+
== "Data row identifier should be of type DataRowIdentifier. Found <class 'str'>."
7676
)
7777

7878

@@ -85,13 +85,6 @@ def test_set_labeling_priority(consensus_project_with_batch):
8585
assert len(init_labeling_parameter_overrides) == 3
8686
assert {o.priority for o in init_labeling_parameter_overrides} == {5, 5, 5}
8787

88-
data = [data_row.uid for data_row in data_rows]
89-
success = project.update_data_row_labeling_priority(data, 1)
90-
lo = list(project.labeling_parameter_overrides())
91-
assert success
92-
assert len(lo) == 3
93-
assert {o.priority for o in lo} == {1, 1, 1}
94-
9588
data = [data_row.uid for data_row in data_rows]
9689
success = project.update_data_row_labeling_priority(UniqueIds(data), 2)
9790
lo = list(project.labeling_parameter_overrides())

libs/labelbox/tests/integration/test_task_queue.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,9 @@ def test_move_to_task(configured_batch_project_with_label):
6868
review_queue = next(
6969
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE"
7070
)
71-
project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid)
71+
project.move_data_rows_to_task_queue(
72+
UniqueIds([data_row.uid]), review_queue.uid
73+
)
7274
_validate_moved(project, "MANUAL_REVIEW_QUEUE", 1)
7375

7476
review_queue = next(

libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py

Lines changed: 0 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -6,19 +6,6 @@
66
from labelbox.schema.project import validate_labeling_parameter_overrides
77

88

9-
def test_validate_labeling_parameter_overrides_valid_data():
10-
mock_data_row = MagicMock(spec=DataRow)
11-
mock_data_row.uid = "abc"
12-
data = [(mock_data_row, 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
13-
validate_labeling_parameter_overrides(data)
14-
15-
16-
def test_validate_labeling_parameter_overrides_invalid_data():
17-
data = [("abc", 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
18-
with pytest.raises(TypeError):
19-
validate_labeling_parameter_overrides(data)
20-
21-
229
def test_validate_labeling_parameter_overrides_invalid_priority():
2310
mock_data_row = MagicMock(spec=DataRow)
2411
mock_data_row.uid = "abc"

0 commit comments

Comments
 (0)