diff --git a/libs/labelbox/src/labelbox/schema/annotation_import.py b/libs/labelbox/src/labelbox/schema/annotation_import.py index 497ac899d..c12d3b250 100644 --- a/libs/labelbox/src/labelbox/schema/annotation_import.py +++ b/libs/labelbox/src/labelbox/schema/annotation_import.py @@ -587,6 +587,20 @@ def parent_id(self) -> str: """ return self.project().uid + def delete(self) -> None: + """ + Deletes a MALPredictionImport job + """ + + query_string = """ + mutation deleteModelAssistedLabelingPredictionImportPyApi($id: ID!) { + deleteModelAssistedLabelingPredictionImport(where: { id: $id }) { + id + } + } + """ + self.client.execute(query_string, {"id": self.uid}) + @classmethod def create_from_file( cls, client: "labelbox.Client", project_id: str, name: str, path: str diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index 2589e6ab1..1916d3e8c 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -16,6 +16,7 @@ get_args, ) +from labelbox.schema.annotation_import import LabelImport, MALPredictionImport from lbox.exceptions import ( InvalidQueryError, LabelboxError, @@ -710,6 +711,56 @@ def _connect_default_labeling_front_end(self, ontology_as_dict: dict): }, ) + def get_mal_prediction_imports(self) -> PaginatedCollection: + """Returns mal prediction import objects which are used in model-assisted labeling associated with the project. + + Returns: + PaginatedCollection + """ + + id_param = "projectId" + query_str = """ + query getModelAssistedLabelingPredictionImportsPyApi($%s: ID!) { + modelAssistedLabelingPredictionImports(skip: %%d, first: %%d, where: { projectId: $%s }) { %s }} + """ % ( + id_param, + id_param, + query.results_query_part(MALPredictionImport), + ) + + return PaginatedCollection( + self.client, + query_str, + {id_param: self.uid}, + ["modelAssistedLabelingPredictionImports"], + MALPredictionImport, + ) + + def get_label_imports(self) -> PaginatedCollection: + """Returns label import objects associated with the project. + + Returns: + PaginatedCollection + """ + + id_param = "projectId" + query_str = """ + query getLabelImportsPyApi($%s: ID!) { + labelImports(skip: %%d, first: %%d, where: { projectId: $%s }) { %s }} + """ % ( + id_param, + id_param, + query.results_query_part(LabelImport), + ) + + return PaginatedCollection( + self.client, + query_str, + {id_param: self.uid}, + ["labelImports"], + LabelImport, + ) + def create_batch( self, name: str, diff --git a/libs/labelbox/tests/data/annotation_import/test_label_import.py b/libs/labelbox/tests/data/annotation_import/test_label_import.py index 5576025fd..87741c1b7 100644 --- a/libs/labelbox/tests/data/annotation_import/test_label_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_label_import.py @@ -136,6 +136,19 @@ def test_get(client, module_project, annotation_import_test_helpers): annotation_import_test_helpers.check_running_state(label_import, name, url) +def test_get_import_jobs_from_project(client, configured_project): + name = str(uuid.uuid4()) + url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" + label_import = LabelImport.create_from_url( + client=client, project_id=configured_project.uid, name=name, url=url + ) + label_import.wait_until_done() + + label_imports = list(configured_project.get_label_imports()) + assert len(label_imports) == 1 + assert label_imports[0].input_file_url == url + + @pytest.mark.slow def test_wait_till_done(client, module_project, predictions): name = str(uuid.uuid4()) diff --git a/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py b/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py index 3ffd6bfc1..fe1634ba2 100644 --- a/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py @@ -65,3 +65,20 @@ def test_create_with_path_arg( annotation_import_test_helpers.assert_file_content( label_import.input_file_url, object_predictions ) + + +def test_get_mal_import_jobs_from_project(client, configured_project): + name = str(uuid.uuid4()) + url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" + label_import = MALPredictionImport.create( + client=client, id=configured_project.uid, name=name, url=url + ) + label_import.wait_until_done() + + label_imports = list(configured_project.get_mal_prediction_imports()) + assert len(label_imports) == 1 + assert label_imports[0].input_file_url == url + + label_imports[0].delete() + label_imports = list(configured_project.get_mal_prediction_imports()) + assert len(label_imports) == 0