Skip to content

Commit 83555cb

Browse files
committed
sdk for label imports
1 parent f75d1e3 commit 83555cb

File tree

9 files changed

+279
-37
lines changed

9 files changed

+279
-37
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from labelbox.client import Client
66
from labelbox.schema.model import Model
77
from labelbox.schema.bulk_import_request import BulkImportRequest
8-
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport
8+
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport
99
from labelbox.schema.dataset import Dataset
1010
from labelbox.schema.data_row import DataRow
1111
from labelbox.schema.label import Label

labelbox/schema/annotation_import.py

Lines changed: 155 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def create_from_file(cls, client: "labelbox.Client", project_id: str,
330330
"""
331331
if os.path.exists(path):
332332
with open(path, 'rb') as f:
333-
return cls._create_mea_import_from_bytes(
333+
return cls._create_mal_import_from_bytes(
334334
client, project_id, name, f,
335335
os.stat(path).st_size)
336336
else:
@@ -355,7 +355,7 @@ def create_from_objects(
355355
if not data_str:
356356
raise ValueError('annotations cannot be empty')
357357
data = data_str.encode('utf-8')
358-
return cls._create_mea_import_from_bytes(client, project_id, name, data,
358+
return cls._create_mal_import_from_bytes(client, project_id, name, data,
359359
len(data))
360360

361361
@classmethod
@@ -440,7 +440,7 @@ def _get_file_mutation(cls) -> str:
440440
}""" % query.results_query_part(cls)
441441

442442
@classmethod
443-
def _create_mea_import_from_bytes(
443+
def _create_mal_import_from_bytes(
444444
cls, client: "labelbox.Client", project_id: str, name: str,
445445
bytes_data: BinaryIO, content_len: int) -> "MALPredictionImport":
446446
file_name = f"{project_id}__{name}.ndjson"
@@ -454,3 +454,155 @@ def _create_mea_import_from_bytes(
454454
res = cls._create_from_bytes(client, variables, query_str, file_name,
455455
bytes_data)
456456
return cls(client, res["createModelAssistedLabelingPredictionImport"])
457+
458+
459+
class LabelImport(AnnotationImport):
    """Annotation import job that uploads labels into a project.

    Jobs are created through the ``createLabelImport`` mutation, either from
    a local ndjson file, from in-memory objects, or from a url, and can be
    looked up again by project id and name with ``from_name``.
    """
    project = Relationship.ToOne("Project", cache=True)

    @property
    def parent_id(self) -> str:
        """
        Identifier of the project this import belongs to. Used to refresh the status
        """
        return self.project().uid

    @classmethod
    def create_from_file(cls, client: "labelbox.Client", project_id: str,
                         name: str, path: str) -> "LabelImport":
        """
        Create a label import job from a file of annotations

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
            path: Path to ndjson file containing annotations
        Returns:
            LabelImport
        Raises:
            ValueError: If `path` does not point to an existing regular file
        """
        # isfile (rather than exists) so that a directory path is reported
        # through the same ValueError as a missing file instead of failing
        # later inside open().
        if not os.path.isfile(path):
            raise ValueError(f"File {path} is not accessible")
        with open(path, 'rb') as f:
            return cls._create_label_import_from_bytes(
                client, project_id, name, f,
                os.stat(path).st_size)

    @classmethod
    def create_from_objects(
            cls, client: "labelbox.Client", project_id: str, name: str,
            labels: List[Dict[str, Any]]) -> "LabelImport":
        """
        Create a label import job from an in memory dictionary

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
            labels: List of labels
        Returns:
            LabelImport
        Raises:
            ValueError: If serializing `labels` to ndjson yields an empty payload
        """
        data_str = ndjson.dumps(labels)
        if not data_str:
            raise ValueError('labels cannot be empty')
        data = data_str.encode('utf-8')
        return cls._create_label_import_from_bytes(client, project_id, name,
                                                   data, len(data))

    @classmethod
    def create_from_url(cls, client: "labelbox.Client", project_id: str,
                        name: str, url: str) -> "LabelImport":
        """
        Create a label annotation import job from a url
        The url must point to a file containing label annotations.

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
            url: Url pointing to file to upload
        Returns:
            LabelImport
        Raises:
            ValueError: If the HEAD request to `url` returns an error status
        """
        # A requests.Response is falsy for 4xx/5xx status codes.
        if not requests.head(url):
            raise ValueError(f"Url {url} is not reachable")
        query_str = cls._get_url_mutation()
        response = client.execute(query_str,
                                  params={
                                      "fileUrl": url,
                                      "projectId": project_id,
                                      'name': name
                                  })
        return cls(client, response["createLabelImport"])

    @classmethod
    def from_name(cls,
                  client: "labelbox.Client",
                  project_id: str,
                  name: str,
                  as_json: bool = False) -> "LabelImport":
        """
        Retrieves a label import job.

        Args:
            client: Labelbox Client for executing queries
            project_id: ID used for querying import jobs
            name: Name of the import job.
            as_json: When True, return the raw ``labelImport`` payload of the
                query response instead of wrapping it in a LabelImport
        Returns:
            LabelImport (or the raw response payload when `as_json` is True)
        Raises:
            labelbox.exceptions.ResourceNotFoundError: If the query returns
                no response
        """
        query_str = """query getLabelImportPyApi($projectId : ID!, $name: String!) {
            labelImport(
                where: {projectId: $projectId, name: $name}){
                    %s
                }}""" % query.results_query_part(cls)
        params = {
            "projectId": project_id,
            "name": name,
        }
        response = client.execute(query_str, params)
        if response is None:
            raise labelbox.exceptions.ResourceNotFoundError(
                LabelImport, params)
        response = response["labelImport"]
        if as_json:
            return response
        return cls(client, response)

    @classmethod
    def _get_url_mutation(cls) -> str:
        """Build the ``createLabelImport`` mutation that takes a file url."""
        return """mutation createLabelImportPyApi($projectId : ID!, $name: String!, $fileUrl: String!) {
            createLabelImport(data: {
                projectId: $projectId
                name: $name
                fileUrl: $fileUrl
            }) {%s}
        }""" % query.results_query_part(cls)

    @classmethod
    def _get_file_mutation(cls) -> str:
        """Build the ``createLabelImport`` mutation that takes a file upload."""
        return """mutation createLabelImportPyApi($projectId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
            createLabelImport(data: {
                projectId: $projectId name: $name filePayload: { file: $file, contentLength: $contentLength}
            }) {%s}
        }""" % query.results_query_part(cls)

    @classmethod
    def _create_label_import_from_bytes(
            cls, client: "labelbox.Client", project_id: str, name: str,
            bytes_data: BinaryIO, content_len: int) -> "LabelImport":
        """Upload `bytes_data` as ``{project_id}__{name}.ndjson`` and wrap the created import."""
        file_name = f"{project_id}__{name}.ndjson"
        # "file" is sent as None here; the payload itself is handed to
        # _create_from_bytes separately.
        variables = {
            "file": None,
            "contentLength": content_len,
            "projectId": project_id,
            "name": name
        }
        query_str = cls._get_file_mutation()
        res = cls._create_from_bytes(client, variables, query_str, file_name,
                                     bytes_data)
        return cls(client, res["createLabelImport"])

tests/integration/mal_and_mea/conftest.py renamed to tests/integration/annotation_import/conftest.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22

33
import pytest
44
import time
5+
import requests
6+
import ndjson
57

8+
from typing import Type
69
from labelbox.schema.labeling_frontend import LabelingFrontend
7-
from labelbox.schema.annotation_import import MALPredictionImport
10+
from labelbox.schema.annotation_import import MALPredictionImport, AnnotationImportState
811

912

1013
@pytest.fixture
@@ -277,6 +280,7 @@ def model_run_predictions(polygon_inference, rectangle_inference,
277280
return [polygon_inference, rectangle_inference, line_inference]
278281

279282

283+
# also used for label imports
280284
@pytest.fixture
281285
def object_predictions(polygon_inference, rectangle_inference, line_inference,
282286
entity_inference, segmentation_inference):
@@ -339,3 +343,25 @@ def model_run_annotation_groups(client, configured_project,
339343
time.sleep(3)
340344
yield model_run
341345
# TODO: Delete resources when that is possible ..
346+
347+
348+
class AnnotationImportTestHelpers:
    """Assertion helpers shared by the annotation import integration tests."""

    @staticmethod
    def assert_file_content(url: str, predictions):
        """Assert that the body served at `url` equals `predictions` dumped as ndjson."""
        downloaded = requests.get(url).text
        expected = ndjson.dumps(predictions)
        assert downloaded == expected

    @staticmethod
    def check_running_state(req, name, url=None):
        """Assert that `req` is a freshly created import that is still running.

        `req.input_file_url` is only compared when `url` is given.
        """
        assert req.name == name
        if url is not None:
            assert req.input_file_url == url
        # No error or status file is expected while the import is running.
        assert req.error_file_url is None
        assert req.status_file_url is None
        assert req.state == AnnotationImportState.RUNNING
362+
363+
364+
@pytest.fixture
def annotation_import_test_helpers() -> AnnotationImportTestHelpers:
    """Provide an AnnotationImportTestHelpers instance to tests."""
    # The fixture returns an instance, so the annotation is the class itself
    # rather than Type[...].
    return AnnotationImportTestHelpers()
367+

tests/integration/mal_and_mea/test_bulk_import_request.py renamed to tests/integration/annotation_import/test_bulk_import_request.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_validate_file(client, configured_project):
3939
#Schema ids shouldn't match
4040

4141

42-
def test_create_from_objects(configured_project, predictions):
42+
def test_create_from_objects(configured_project, predictions, annotation_import_test_helpers):
4343
name = str(uuid.uuid4())
4444

4545
bulk_import_request = configured_project.upload_annotations(
@@ -50,7 +50,7 @@ def test_create_from_objects(configured_project, predictions):
5050
assert bulk_import_request.error_file_url is None
5151
assert bulk_import_request.status_file_url is None
5252
assert bulk_import_request.state == BulkImportRequestState.RUNNING
53-
assert_file_content(bulk_import_request.input_file_url, predictions)
53+
annotation_import_test_helpers.assert_file_content(bulk_import_request.input_file_url, predictions)
5454

5555

5656
def test_create_from_local_file(tmp_path, predictions, configured_project):
@@ -68,7 +68,7 @@ def test_create_from_local_file(tmp_path, predictions, configured_project):
6868
assert bulk_import_request.error_file_url is None
6969
assert bulk_import_request.status_file_url is None
7070
assert bulk_import_request.state == BulkImportRequestState.RUNNING
71-
assert_file_content(bulk_import_request.input_file_url, predictions)
71+
annotation_import_test_helpers.assert_file_content(bulk_import_request.input_file_url, predictions)
7272

7373

7474
def test_get(client, configured_project):
@@ -144,11 +144,6 @@ def test_wait_till_done(rectangle_inference, configured_project):
144144
'uuid']
145145

146146

147-
def assert_file_content(url: str, predictions):
148-
response = requests.get(url)
149-
assert response.text == ndjson.dumps(predictions)
150-
151-
152147
def test_project_bulk_import_requests(client, configured_project, predictions):
153148
result = configured_project.bulk_import_requests()
154149
assert len(list(result)) == 0
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import uuid
2+
import ndjson
3+
import pytest
4+
import requests
5+
6+
from labelbox.schema.annotation_import import AnnotationImportState, LabelImport
7+
"""
8+
- Here we only want to check that the uploads are calling the validation
9+
- Then with unit tests we can check the types of errors raised
10+
11+
"""
12+
13+
14+
def test_create_from_url(client, project, annotation_import_test_helpers):
    """A label import created from a url belongs to the project and starts out running."""
    url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
    name = str(uuid.uuid4())
    label_import = LabelImport.create_from_url(
        client=client,
        project_id=project.uid,
        name=name,
        url=url,
    )
    assert label_import.parent_id == project.uid
    annotation_import_test_helpers.check_running_state(label_import, name, url)
20+
21+
22+
def test_create_from_objects(client, project, object_predictions,
                             annotation_import_test_helpers):
    """A label import created from objects is running and its input file matches the objects."""
    name = str(uuid.uuid4())

    label_import = LabelImport.create_from_objects(
        client=client,
        project_id=project.uid,
        name=name,
        labels=object_predictions,
    )

    assert label_import.parent_id == project.uid
    annotation_import_test_helpers.check_running_state(label_import, name)
    annotation_import_test_helpers.assert_file_content(
        label_import.input_file_url, object_predictions)
29+
30+
31+
# TODO: add me when we add this ability
32+
# def test_create_from_local_file(client, tmp_path, project,
33+
# object_predictions, annotation_import_test_helpers):
34+
# name = str(uuid.uuid4())
35+
# file_name = f"{name}.ndjson"
36+
# file_path = tmp_path / file_name
37+
# with file_path.open("w") as f:
38+
# ndjson.dump(object_predictions, f)
39+
40+
# label_import = LabelImport.create_from_file(client=client, project_id=project.uid, name=name, path=str(file_path))
41+
42+
# assert label_import.parent_id == project.uid
43+
# annotation_import_test_helpers.check_running_state(label_import, name)
44+
# annotation_import_test_helpers.assert_file_content(label_import.input_file_url, object_predictions)
45+
46+
47+
def test_get(client, project, annotation_import_test_helpers):
    """Create a label import from a url and check its parent project and running state."""
    # NOTE(review): this test never calls LabelImport.from_name, so retrieval
    # itself is not exercised here -- confirm whether that is intended.
    url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
    name = str(uuid.uuid4())

    label_import = LabelImport.create_from_url(
        client=client,
        project_id=project.uid,
        name=name,
        url=url,
    )

    assert label_import.parent_id == project.uid
    annotation_import_test_helpers.check_running_state(label_import, name, url)
55+
56+
57+
@pytest.mark.slow
def test_wait_till_done(client, project, model_run_predictions):
    """After waiting, inputs and statuses cover exactly the uploaded prediction uuids."""
    name = str(uuid.uuid4())
    label_import = LabelImport.create_from_objects(
        client=client,
        project_id=project.uid,
        name=name,
        labels=model_run_predictions,
    )

    expected_count = len(model_run_predictions)
    assert len(label_import.inputs) == expected_count
    label_import.wait_until_done()
    # TODO(grant): some of this is commented out
    # TODO(grant): since the pipeline is not complete, you will get a failed status

    # assert label_import.state == AnnotationImportState.FINISHED
    # # Check that the status files are being returned as expected
    # assert len(label_import.errors) == 0
    assert len(label_import.inputs) == expected_count

    input_uuids = {entry['uuid'] for entry in label_import.inputs}
    inference_uuids = {pred['uuid'] for pred in model_run_predictions}
    assert input_uuids == inference_uuids

    assert len(label_import.statuses) == expected_count
    # for status in label_import.statuses:
    #     assert status['status'] == 'SUCCESS'
    status_uuids = {entry['uuid'] for entry in label_import.statuses}
    assert input_uuids == status_uuids
83+

0 commit comments

Comments
 (0)