Skip to content

Commit b5b6e38

Browse files
authored
[PLT-1677] Added get mal import functions to replace old BulkImportRequest class (#1909)
1 parent 11b4964 commit b5b6e38

File tree

4 files changed

+95
-0
lines changed

4 files changed

+95
-0
lines changed

libs/labelbox/src/labelbox/schema/annotation_import.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,20 @@ def parent_id(self) -> str:
587587
"""
588588
return self.project().uid
589589

590+
def delete(self) -> None:
591+
"""
592+
Deletes a MALPredictionImport job
593+
"""
594+
595+
query_string = """
596+
mutation deleteModelAssistedLabelingPredictionImportPyApi($id: ID!) {
597+
deleteModelAssistedLabelingPredictionImport(where: { id: $id }) {
598+
id
599+
}
600+
}
601+
"""
602+
self.client.execute(query_string, {"id": self.uid})
603+
590604
@classmethod
591605
def create_from_file(
592606
cls, client: "labelbox.Client", project_id: str, name: str, path: str

libs/labelbox/src/labelbox/schema/project.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
get_args,
1717
)
1818

19+
from labelbox.schema.annotation_import import LabelImport, MALPredictionImport
1920
from lbox.exceptions import (
2021
InvalidQueryError,
2122
LabelboxError,
@@ -710,6 +711,56 @@ def _connect_default_labeling_front_end(self, ontology_as_dict: dict):
710711
},
711712
)
712713

714+
def get_mal_prediction_imports(self) -> PaginatedCollection:
715+
"""Returns mal prediction import objects which are used in model-assisted labeling associated with the project.
716+
717+
Returns:
718+
PaginatedCollection
719+
"""
720+
721+
id_param = "projectId"
722+
query_str = """
723+
query getModelAssistedLabelingPredictionImportsPyApi($%s: ID!) {
724+
modelAssistedLabelingPredictionImports(skip: %%d, first: %%d, where: { projectId: $%s }) { %s }}
725+
""" % (
726+
id_param,
727+
id_param,
728+
query.results_query_part(MALPredictionImport),
729+
)
730+
731+
return PaginatedCollection(
732+
self.client,
733+
query_str,
734+
{id_param: self.uid},
735+
["modelAssistedLabelingPredictionImports"],
736+
MALPredictionImport,
737+
)
738+
739+
def get_label_imports(self) -> PaginatedCollection:
740+
"""Returns label import objects associated with the project.
741+
742+
Returns:
743+
PaginatedCollection
744+
"""
745+
746+
id_param = "projectId"
747+
query_str = """
748+
query getLabelImportsPyApi($%s: ID!) {
749+
labelImports(skip: %%d, first: %%d, where: { projectId: $%s }) { %s }}
750+
""" % (
751+
id_param,
752+
id_param,
753+
query.results_query_part(LabelImport),
754+
)
755+
756+
return PaginatedCollection(
757+
self.client,
758+
query_str,
759+
{id_param: self.uid},
760+
["labelImports"],
761+
LabelImport,
762+
)
763+
713764
def create_batch(
714765
self,
715766
name: str,

libs/labelbox/tests/data/annotation_import/test_label_import.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,19 @@ def test_get(client, module_project, annotation_import_test_helpers):
136136
annotation_import_test_helpers.check_running_state(label_import, name, url)
137137

138138

139+
def test_get_import_jobs_from_project(client, configured_project):
140+
name = str(uuid.uuid4())
141+
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
142+
label_import = LabelImport.create_from_url(
143+
client=client, project_id=configured_project.uid, name=name, url=url
144+
)
145+
label_import.wait_until_done()
146+
147+
label_imports = list(configured_project.get_label_imports())
148+
assert len(label_imports) == 1
149+
assert label_imports[0].input_file_url == url
150+
151+
139152
@pytest.mark.slow
140153
def test_wait_till_done(client, module_project, predictions):
141154
name = str(uuid.uuid4())

libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,20 @@ def test_create_with_path_arg(
6565
annotation_import_test_helpers.assert_file_content(
6666
label_import.input_file_url, object_predictions
6767
)
68+
69+
70+
def test_get_mal_import_jobs_from_project(client, configured_project):
71+
name = str(uuid.uuid4())
72+
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
73+
label_import = MALPredictionImport.create(
74+
client=client, id=configured_project.uid, name=name, url=url
75+
)
76+
label_import.wait_until_done()
77+
78+
label_imports = list(configured_project.get_mal_prediction_imports())
79+
assert len(label_imports) == 1
80+
assert label_imports[0].input_file_url == url
81+
82+
label_imports[0].delete()
83+
label_imports = list(configured_project.get_mal_prediction_imports())
84+
assert len(label_imports) == 0

0 commit comments

Comments
 (0)