diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index 3e2cca9d6..0daf3af10 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -13,7 +13,7 @@ Optional, Tuple, Union, - overload, + get_args, ) from lbox.exceptions import ( @@ -40,7 +40,11 @@ from labelbox.schema.export_task import ExportTask from labelbox.schema.id_type import IdType from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId -from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds +from labelbox.schema.identifiables import ( + DataRowIdentifiers, + GlobalKeys, + UniqueIds, +) from labelbox.schema.labeling_service import ( LabelingService, LabelingServiceStatus, @@ -67,9 +71,7 @@ DataRowPriority = int -LabelingParameterOverrideInput = Tuple[ - Union[DataRow, DataRowIdentifier], DataRowPriority -] +LabelingParameterOverrideInput = Tuple[DataRowIdentifier, DataRowPriority] logger = logging.getLogger(__name__) MAX_SYNC_BATCH_ROW_COUNT = 1_000 @@ -79,23 +81,18 @@ def validate_labeling_parameter_overrides( data: List[LabelingParameterOverrideInput], ) -> None: for idx, row in enumerate(data): - if len(row) < 2: - raise TypeError( - f"Data must be a list of tuples each containing two elements: a DataRow or a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}" - ) data_row_identifier = row[0] priority = row[1] - valid_types = (Entity.DataRow, UniqueId, GlobalKey) - if not isinstance(data_row_identifier, valid_types): + if not isinstance(data_row_identifier, get_args(DataRowIdentifier)): raise TypeError( - f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found {type(data_row_identifier)} for data_row_identifier {data_row_identifier}" + f"Data row identifier should be of type DataRowIdentifier. Found {type(data_row_identifier)}." + ) + if len(row) < 2: + raise TypeError( + f"Data must be a list of tuples each containing two elements: a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}" ) - if not isinstance(priority, int): - if isinstance(data_row_identifier, Entity.DataRow): - id = data_row_identifier.uid - else: - id = data_row_identifier + id = data_row_identifier.key raise TypeError( f"Priority must be an int. Found {type(priority)} for data_row_identifier {id}" ) @@ -1046,57 +1043,6 @@ def _create_batch_async( return self.client.get_batch(self.uid, batch_id) - def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode": - """ - Updates the queueing mode of this project. - - Deprecation notice: This method is deprecated. Going forward, projects must - go through a migration to have the queue mode changed. Users should specify the - queue mode for a project during creation if a non-default mode is desired. - - For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes - - Args: - mode: the specified queue mode - - Returns: the updated queueing mode of this project - - """ - - logger.warning( - "Updating the queue_mode for a project will soon no longer be supported." - ) - - if self.queue_mode == mode: - return mode - - if mode == QueueMode.Batch: - status = "ENABLED" - elif mode == QueueMode.Dataset: - status = "DISABLED" - else: - raise ValueError( - "Must provide either `BATCH` or `DATASET` as a mode" - ) - - query_str = ( - """mutation %s($projectId: ID!, $status: TagSetStatusInput!) { - project(where: {id: $projectId}) { - setTagSetStatus(input: {tagSetStatus: $status}) { - tagSetStatus - } - } - } - """ - % "setTagSetStatusPyApi" - ) - - self.client.execute( - query_str, {"projectId": self.uid, "status": status} - ) - - return mode - def get_label_count(self) -> int: """ Returns: the total number of labels in this project. @@ -1111,46 +1057,6 @@ def get_label_count(self) -> int: res = self.client.execute(query_str, {"projectId": self.uid}) return res["project"]["labelCount"] - def get_queue_mode(self) -> "QueueMode": - """ - Provides the queue mode used for this project. - - Deprecation notice: This method is deprecated and will be removed in - a future version. To obtain the queue mode of a project, simply refer - to the queue_mode attribute of a Project. - - For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes - - Returns: the QueueMode for this project - - """ - - logger.warning( - "Obtaining the queue_mode for a project through this method will soon" - " no longer be supported." - ) - - query_str = ( - """query %s($projectId: ID!) { - project(where: {id: $projectId}) { - tagSetStatus - } - } - """ - % "GetTagSetStatusPyApi" - ) - - status = self.client.execute(query_str, {"projectId": self.uid})[ - "project" - ]["tagSetStatus"] - - if status == "ENABLED": - return QueueMode.Batch - elif status == "DISABLED": - return QueueMode.Dataset - else: - raise ValueError("Status not known") - def add_model_config(self, model_config_id: str) -> str: """Adds a model config to this project. @@ -1243,18 +1149,13 @@ def set_labeling_parameter_overrides( See information on priority here: https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system - >>> project.set_labeling_parameter_overrides([ - >>> (data_row_id1, 2), (data_row_id2, 1)]) - or >>> project.set_labeling_parameter_overrides([ >>> (data_row_gk1, 2), (data_row_gk2, 1)]) Args: data (iterable): An iterable of tuples. Each tuple must contain - either (DataRow, DataRowPriority) - or (DataRowIdentifier, priority) for the new override. + (DataRowIdentifier, priority) for the new override. DataRowIdentifier is an object representing a data row id or a global key. A DataIdentifier object can be a UniqueIds or GlobalKeys class. - NOTE - passing whole DatRow is deprecated. Please use a DataRowIdentifier instead. Priority: * Data will be labeled in priority order. @@ -1283,16 +1184,7 @@ def set_labeling_parameter_overrides( data_rows_with_identifiers = "" for data_row, priority in data: - if isinstance(data_row, DataRow): - data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.uid}", idType: {IdType.DataRowId}}}, priority: {priority}}},' - elif isinstance(data_row, UniqueId) or isinstance( - data_row, GlobalKey - ): - data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.key}", idType: {data_row.id_type}}}, priority: {priority}}},' - else: - raise TypeError( - f"Data row identifier should be be of type DataRow or Data Row Identifier. Found {type(data_row)}." - ) + data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.key}", idType: {data_row.id_type}}}, priority: {priority}}},' query_str = template.substitute( dataWithDataRowIdentifiers=data_rows_with_identifiers @@ -1300,26 +1192,10 @@ def set_labeling_parameter_overrides( res = self.client.execute(query_str, {"projectId": self.uid}) return res["project"]["setLabelingParameterOverrides"]["success"] - @overload def update_data_row_labeling_priority( self, data_rows: DataRowIdentifiers, priority: int, - ) -> bool: - pass - - @overload - def update_data_row_labeling_priority( - self, - data_rows: List[str], - priority: int, - ) -> bool: - pass - - def update_data_row_labeling_priority( - self, - data_rows, - priority: int, ) -> bool: """ Updates labeling parameter overrides to this project in bulk. This method allows up to 1 million data rows to be @@ -1329,7 +1205,7 @@ def update_data_row_labeling_priority( https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system Args: - data_rows: a list of data row ids to update priorities for. This can be a list of strings or a DataRowIdentifiers object + data_rows: data row identifiers object to update priorities. DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class. priority (int): Priority for the new override. See above for more information. @@ -1337,8 +1213,8 @@ def update_data_row_labeling_priority( bool, indicates if the operation was a success. """ - if isinstance(data_rows, list): - data_rows = UniqueIds(data_rows) + if not isinstance(data_rows, get_args(DataRowIdentifiers)): + raise TypeError("data_rows must be a DataRowIdentifiers object") method = "createQueuePriorityUpdateTask" priority_param = "priority" @@ -1481,25 +1357,15 @@ def task_queues(self) -> List[TaskQueue]: for field_values in task_queue_values ] - @overload def move_data_rows_to_task_queue( self, data_row_ids: DataRowIdentifiers, task_queue_id: str ): - pass - - @overload - def move_data_rows_to_task_queue( - self, data_row_ids: List[str], task_queue_id: str - ): - pass - - def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): """ Moves data rows to the specified task queue. Args: - data_row_ids: a list of data row ids to be moved. This can be a list of strings or a DataRowIdentifiers object + data_row_ids: a list of data row ids to be moved. This should be a DataRowIdentifiers object DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class. task_queue_id: the task queue id to be moved to, or None to specify the "Done" queue @@ -1507,8 +1373,9 @@ def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): None if successful, or a raised error on failure """ - if isinstance(data_row_ids, list): - data_row_ids = UniqueIds(data_row_ids) + + if not isinstance(data_row_ids, get_args(DataRowIdentifiers)): + raise TypeError("data_rows must be a DataRowIdentifiers object") method = "createBulkAddRowsToQueueTask" query_str = ( diff --git a/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py index 63423202a..597c529aa 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py @@ -9,6 +9,7 @@ from labelbox import Project, Dataset from labelbox.schema.data_row import DataRow from labelbox.schema.label import Label +from labelbox import UniqueIds IMAGE_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" @@ -128,7 +129,9 @@ def test_with_date_filters( review_queue = next( tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" ) - project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) + project.move_data_rows_to_task_queue( + UniqueIds([data_row.uid]), review_queue.uid + ) export_task = project_export( project, task_name, filters=filters, params=params ) diff --git a/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py b/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py index bd14040de..afa038482 100644 --- a/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py +++ b/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py @@ -23,7 +23,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): data_rows[2].uid, } - data = [(data_rows[0], 4, 2), (data_rows[1], 3)] + data = [(UniqueId(data_rows[0].uid), 4, 2), (UniqueId(data_rows[1].uid), 3)] success = project.set_labeling_parameter_overrides(data) assert success @@ -60,7 +60,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): assert {o.priority for o in updated_overrides} == {2, 3, 4} with pytest.raises(TypeError) as exc_info: - data = [(data_rows[2], "a_string", 3)] + data = [(UniqueId(data_rows[2].uid), "a_string", 3)] project.set_labeling_parameter_overrides(data) assert ( str(exc_info.value) @@ -72,7 +72,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): project.set_labeling_parameter_overrides(data) assert ( str(exc_info.value) - == f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found for data_row_identifier {data_rows[2].uid}" + == "Data row identifier should be of type DataRowIdentifier. Found ." ) @@ -85,13 +85,6 @@ def test_set_labeling_priority(consensus_project_with_batch): assert len(init_labeling_parameter_overrides) == 3 assert {o.priority for o in init_labeling_parameter_overrides} == {5, 5, 5} - data = [data_row.uid for data_row in data_rows] - success = project.update_data_row_labeling_priority(data, 1) - lo = list(project.labeling_parameter_overrides()) - assert success - assert len(lo) == 3 - assert {o.priority for o in lo} == {1, 1, 1} - data = [data_row.uid for data_row in data_rows] success = project.update_data_row_labeling_priority(UniqueIds(data), 2) lo = list(project.labeling_parameter_overrides()) diff --git a/libs/labelbox/tests/integration/test_task_queue.py b/libs/labelbox/tests/integration/test_task_queue.py index 835f67219..0cd66cb62 100644 --- a/libs/labelbox/tests/integration/test_task_queue.py +++ b/libs/labelbox/tests/integration/test_task_queue.py @@ -68,7 +68,9 @@ def test_move_to_task(configured_batch_project_with_label): review_queue = next( tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" ) - project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) + project.move_data_rows_to_task_queue( + UniqueIds([data_row.uid]), review_queue.uid + ) _validate_moved(project, "MANUAL_REVIEW_QUEUE", 1) review_queue = next( diff --git a/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py b/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py index 7f6d29d5a..8e6c38559 100644 --- a/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py +++ b/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py @@ -6,19 +6,6 @@ from labelbox.schema.project import validate_labeling_parameter_overrides -def test_validate_labeling_parameter_overrides_valid_data(): - mock_data_row = MagicMock(spec=DataRow) - mock_data_row.uid = "abc" - data = [(mock_data_row, 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)] - validate_labeling_parameter_overrides(data) - - -def test_validate_labeling_parameter_overrides_invalid_data(): - data = [("abc", 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)] - with pytest.raises(TypeError): - validate_labeling_parameter_overrides(data) - - def test_validate_labeling_parameter_overrides_invalid_priority(): mock_data_row = MagicMock(spec=DataRow) mock_data_row.uid = "abc"