Skip to content

[PLT-1490] Removed data row ids list on some project methods and removed get queue_modes #1852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 23 additions & 156 deletions libs/labelbox/src/labelbox/schema/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Optional,
Tuple,
Union,
overload,
get_args,
)

from lbox.exceptions import (
Expand All @@ -40,7 +40,11 @@
from labelbox.schema.export_task import ExportTask
from labelbox.schema.id_type import IdType
from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId
from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
from labelbox.schema.identifiables import (
DataRowIdentifiers,
GlobalKeys,
UniqueIds,
)
from labelbox.schema.labeling_service import (
LabelingService,
LabelingServiceStatus,
Expand All @@ -67,9 +71,7 @@


DataRowPriority = int
LabelingParameterOverrideInput = Tuple[
Union[DataRow, DataRowIdentifier], DataRowPriority
]
LabelingParameterOverrideInput = Tuple[DataRowIdentifier, DataRowPriority]

logger = logging.getLogger(__name__)
MAX_SYNC_BATCH_ROW_COUNT = 1_000
Expand All @@ -79,23 +81,18 @@ def validate_labeling_parameter_overrides(
data: List[LabelingParameterOverrideInput],
) -> None:
for idx, row in enumerate(data):
if len(row) < 2:
raise TypeError(
f"Data must be a list of tuples each containing two elements: a DataRow or a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}"
)
data_row_identifier = row[0]
priority = row[1]
valid_types = (Entity.DataRow, UniqueId, GlobalKey)
if not isinstance(data_row_identifier, valid_types):
if not isinstance(data_row_identifier, get_args(DataRowIdentifier)):
raise TypeError(
f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found {type(data_row_identifier)} for data_row_identifier {data_row_identifier}"
f"Data row identifier should be of type DataRowIdentifier. Found {type(data_row_identifier)}."
)
if len(row) < 2:
raise TypeError(
f"Data must be a list of tuples each containing two elements: a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}"
)

if not isinstance(priority, int):
if isinstance(data_row_identifier, Entity.DataRow):
id = data_row_identifier.uid
else:
id = data_row_identifier
id = data_row_identifier.key
raise TypeError(
f"Priority must be an int. Found {type(priority)} for data_row_identifier {id}"
)
Expand Down Expand Up @@ -1046,57 +1043,6 @@ def _create_batch_async(

return self.client.get_batch(self.uid, batch_id)

def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode":
"""
Updates the queueing mode of this project.

Deprecation notice: This method is deprecated. Going forward, projects must
go through a migration to have the queue mode changed. Users should specify the
queue mode for a project during creation if a non-default mode is desired.

For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes

Args:
mode: the specified queue mode

Returns: the updated queueing mode of this project

"""

logger.warning(
"Updating the queue_mode for a project will soon no longer be supported."
)

if self.queue_mode == mode:
return mode

if mode == QueueMode.Batch:
status = "ENABLED"
elif mode == QueueMode.Dataset:
status = "DISABLED"
else:
raise ValueError(
"Must provide either `BATCH` or `DATASET` as a mode"
)

query_str = (
"""mutation %s($projectId: ID!, $status: TagSetStatusInput!) {
project(where: {id: $projectId}) {
setTagSetStatus(input: {tagSetStatus: $status}) {
tagSetStatus
}
}
}
"""
% "setTagSetStatusPyApi"
)

self.client.execute(
query_str, {"projectId": self.uid, "status": status}
)

return mode

def get_label_count(self) -> int:
"""
Returns: the total number of labels in this project.
Expand All @@ -1111,46 +1057,6 @@ def get_label_count(self) -> int:
res = self.client.execute(query_str, {"projectId": self.uid})
return res["project"]["labelCount"]

def get_queue_mode(self) -> "QueueMode":
"""
Provides the queue mode used for this project.

Deprecation notice: This method is deprecated and will be removed in
a future version. To obtain the queue mode of a project, simply refer
to the queue_mode attribute of a Project.

For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes

Returns: the QueueMode for this project

"""

logger.warning(
"Obtaining the queue_mode for a project through this method will soon"
" no longer be supported."
)

query_str = (
"""query %s($projectId: ID!) {
project(where: {id: $projectId}) {
tagSetStatus
}
}
"""
% "GetTagSetStatusPyApi"
)

status = self.client.execute(query_str, {"projectId": self.uid})[
"project"
]["tagSetStatus"]

if status == "ENABLED":
return QueueMode.Batch
elif status == "DISABLED":
return QueueMode.Dataset
else:
raise ValueError("Status not known")

def add_model_config(self, model_config_id: str) -> str:
"""Adds a model config to this project.

Expand Down Expand Up @@ -1243,18 +1149,13 @@ def set_labeling_parameter_overrides(
See information on priority here:
https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system

>>> project.set_labeling_parameter_overrides([
>>> (data_row_id1, 2), (data_row_id2, 1)])
or
>>> project.set_labeling_parameter_overrides([
>>> (data_row_gk1, 2), (data_row_gk2, 1)])

Args:
data (iterable): An iterable of tuples. Each tuple must contain
either (DataRow, DataRowPriority<int>)
or (DataRowIdentifier, priority<int>) for the new override.
(DataRowIdentifier, priority<int>) for the new override.
DataRowIdentifier is an object representing a single data row id or global key. A DataRowIdentifier object is either a UniqueId or a GlobalKey instance.
NOTE - passing whole DatRow is deprecated. Please use a DataRowIdentifier instead.

Priority:
* Data will be labeled in priority order.
Expand Down Expand Up @@ -1283,43 +1184,18 @@ def set_labeling_parameter_overrides(

data_rows_with_identifiers = ""
for data_row, priority in data:
if isinstance(data_row, DataRow):
data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.uid}", idType: {IdType.DataRowId}}}, priority: {priority}}},'
elif isinstance(data_row, UniqueId) or isinstance(
data_row, GlobalKey
):
data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.key}", idType: {data_row.id_type}}}, priority: {priority}}},'
else:
raise TypeError(
f"Data row identifier should be be of type DataRow or Data Row Identifier. Found {type(data_row)}."
)
data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.key}", idType: {data_row.id_type}}}, priority: {priority}}},'

query_str = template.substitute(
dataWithDataRowIdentifiers=data_rows_with_identifiers
)
res = self.client.execute(query_str, {"projectId": self.uid})
return res["project"]["setLabelingParameterOverrides"]["success"]

@overload
def update_data_row_labeling_priority(
self,
data_rows: DataRowIdentifiers,
priority: int,
) -> bool:
pass

@overload
def update_data_row_labeling_priority(
self,
data_rows: List[str],
priority: int,
) -> bool:
pass

def update_data_row_labeling_priority(
self,
data_rows,
priority: int,
) -> bool:
"""
Updates labeling parameter overrides to this project in bulk. This method allows up to 1 million data rows to be
Expand All @@ -1329,16 +1205,16 @@ def update_data_row_labeling_priority(
https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system

Args:
data_rows: a list of data row ids to update priorities for. This can be a list of strings or a DataRowIdentifiers object
data_rows: data row identifiers object to update priorities.
A DataRowIdentifiers object is a list of data row ids or global keys; it is either a UniqueIds or a GlobalKeys instance.
priority (int): Priority for the new override. See above for more information.

Returns:
bool, indicates if the operation was a success.
"""

if isinstance(data_rows, list):
data_rows = UniqueIds(data_rows)
if not isinstance(data_rows, get_args(DataRowIdentifiers)):
raise TypeError("data_rows must be a DataRowIdentifiers object")

method = "createQueuePriorityUpdateTask"
priority_param = "priority"
Expand Down Expand Up @@ -1481,34 +1357,25 @@ def task_queues(self) -> List[TaskQueue]:
for field_values in task_queue_values
]

@overload
def move_data_rows_to_task_queue(
self, data_row_ids: DataRowIdentifiers, task_queue_id: str
):
pass

@overload
def move_data_rows_to_task_queue(
self, data_row_ids: List[str], task_queue_id: str
):
pass

def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str):
"""

Moves data rows to the specified task queue.

Args:
data_row_ids: a list of data row ids to be moved. This can be a list of strings or a DataRowIdentifiers object
data_row_ids: a list of data row ids to be moved. This should be a DataRowIdentifiers object
A DataRowIdentifiers object is a list of data row ids or global keys; it is either a UniqueIds or a GlobalKeys instance.
task_queue_id: the task queue id to be moved to, or None to specify the "Done" queue

Returns:
None if successful, or a raised error on failure

"""
if isinstance(data_row_ids, list):
data_row_ids = UniqueIds(data_row_ids)

if not isinstance(data_row_ids, get_args(DataRowIdentifiers)):
raise TypeError("data_rows must be a DataRowIdentifiers object")

method = "createBulkAddRowsToQueueTask"
query_str = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from labelbox import Project, Dataset
from labelbox.schema.data_row import DataRow
from labelbox.schema.label import Label
from labelbox import UniqueIds

IMAGE_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg"

Expand Down Expand Up @@ -128,7 +129,9 @@ def test_with_date_filters(
review_queue = next(
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE"
)
project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid)
project.move_data_rows_to_task_queue(
UniqueIds([data_row.uid]), review_queue.uid
)
export_task = project_export(
project, task_name, filters=filters, params=params
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch):
data_rows[2].uid,
}

data = [(data_rows[0], 4, 2), (data_rows[1], 3)]
data = [(UniqueId(data_rows[0].uid), 4, 2), (UniqueId(data_rows[1].uid), 3)]
success = project.set_labeling_parameter_overrides(data)
assert success

Expand Down Expand Up @@ -60,7 +60,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch):
assert {o.priority for o in updated_overrides} == {2, 3, 4}

with pytest.raises(TypeError) as exc_info:
data = [(data_rows[2], "a_string", 3)]
data = [(UniqueId(data_rows[2].uid), "a_string", 3)]
project.set_labeling_parameter_overrides(data)
assert (
str(exc_info.value)
Expand All @@ -72,7 +72,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch):
project.set_labeling_parameter_overrides(data)
assert (
str(exc_info.value)
== f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found <class 'str'> for data_row_identifier {data_rows[2].uid}"
== "Data row identifier should be of type DataRowIdentifier. Found <class 'str'>."
)


Expand All @@ -85,13 +85,6 @@ def test_set_labeling_priority(consensus_project_with_batch):
assert len(init_labeling_parameter_overrides) == 3
assert {o.priority for o in init_labeling_parameter_overrides} == {5, 5, 5}

data = [data_row.uid for data_row in data_rows]
success = project.update_data_row_labeling_priority(data, 1)
lo = list(project.labeling_parameter_overrides())
assert success
assert len(lo) == 3
assert {o.priority for o in lo} == {1, 1, 1}

data = [data_row.uid for data_row in data_rows]
success = project.update_data_row_labeling_priority(UniqueIds(data), 2)
lo = list(project.labeling_parameter_overrides())
Expand Down
4 changes: 3 additions & 1 deletion libs/labelbox/tests/integration/test_task_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def test_move_to_task(configured_batch_project_with_label):
review_queue = next(
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE"
)
project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid)
project.move_data_rows_to_task_queue(
UniqueIds([data_row.uid]), review_queue.uid
)
_validate_moved(project, "MANUAL_REVIEW_QUEUE", 1)

review_queue = next(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,6 @@
from labelbox.schema.project import validate_labeling_parameter_overrides


def test_validate_labeling_parameter_overrides_valid_data():
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No longer needed, since only DataRowIdentifiers can be used now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No overrides

mock_data_row = MagicMock(spec=DataRow)
mock_data_row.uid = "abc"
data = [(mock_data_row, 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
validate_labeling_parameter_overrides(data)


def test_validate_labeling_parameter_overrides_invalid_data():
data = [("abc", 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
with pytest.raises(TypeError):
validate_labeling_parameter_overrides(data)


def test_validate_labeling_parameter_overrides_invalid_priority():
mock_data_row = MagicMock(spec=DataRow)
mock_data_row.uid = "abc"
Expand Down
Loading