Skip to content

Commit 0736823

Browse files
authored
Merge pull request #291 from Labelbox/grant/sdk-label-imports
SDK for label imports [DIAG-604]
2 parents 59014a3 + 6bd9350 commit 0736823

File tree

9 files changed

+306
-43
lines changed

9 files changed

+306
-43
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from labelbox.client import Client
66
from labelbox.schema.model import Model
77
from labelbox.schema.bulk_import_request import BulkImportRequest
8-
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport
8+
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport
99
from labelbox.schema.dataset import Dataset
1010
from labelbox.schema.data_row import DataRow
1111
from labelbox.schema.label import Label

labelbox/schema/annotation_import.py

Lines changed: 158 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def from_name(cls,
266266

267267
@classmethod
268268
def _get_url_mutation(cls) -> str:
269-
return """mutation createMEAPredictionImportPyApi($modelRunId : ID!, $name: String!, $fileUrl: String!) {
269+
return """mutation createMEAPredictionImportByUrlPyApi($modelRunId : ID!, $name: String!, $fileUrl: String!) {
270270
createModelErrorAnalysisPredictionImport(data: {
271271
modelRunId: $modelRunId
272272
name: $name
@@ -276,7 +276,7 @@ def _get_url_mutation(cls) -> str:
276276

277277
@classmethod
278278
def _get_file_mutation(cls) -> str:
279-
return """mutation createMEAPredictionImportPyApi($modelRunId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
279+
return """mutation createMEAPredictionImportByFilePyApi($modelRunId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
280280
createModelErrorAnalysisPredictionImport(data: {
281281
modelRunId: $modelRunId name: $name filePayload: { file: $file, contentLength: $contentLength}
282282
}) {%s}
@@ -330,7 +330,7 @@ def create_from_file(cls, client: "labelbox.Client", project_id: str,
330330
"""
331331
if os.path.exists(path):
332332
with open(path, 'rb') as f:
333-
return cls._create_mea_import_from_bytes(
333+
return cls._create_mal_import_from_bytes(
334334
client, project_id, name, f,
335335
os.stat(path).st_size)
336336
else:
@@ -355,7 +355,7 @@ def create_from_objects(
355355
if not data_str:
356356
raise ValueError('annotations cannot be empty')
357357
data = data_str.encode('utf-8')
358-
return cls._create_mea_import_from_bytes(client, project_id, name, data,
358+
return cls._create_mal_import_from_bytes(client, project_id, name, data,
359359
len(data))
360360

361361
@classmethod
@@ -423,7 +423,7 @@ def from_name(cls,
423423

424424
@classmethod
425425
def _get_url_mutation(cls) -> str:
426-
return """mutation createMALPredictionImportPyApi($projectId : ID!, $name: String!, $fileUrl: String!) {
426+
return """mutation createMALPredictionImportByUrlPyApi($projectId : ID!, $name: String!, $fileUrl: String!) {
427427
createModelAssistedLabelingPredictionImport(data: {
428428
projectId: $projectId
429429
name: $name
@@ -433,14 +433,14 @@ def _get_url_mutation(cls) -> str:
433433

434434
@classmethod
435435
def _get_file_mutation(cls) -> str:
436-
return """mutation createMALPredictionImportPyApi($projectId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
436+
return """mutation createMALPredictionImportByFilePyApi($projectId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
437437
createModelAssistedLabelingPredictionImport(data: {
438438
projectId: $projectId name: $name filePayload: { file: $file, contentLength: $contentLength}
439439
}) {%s}
440440
}""" % query.results_query_part(cls)
441441

442442
@classmethod
443-
def _create_mea_import_from_bytes(
443+
def _create_mal_import_from_bytes(
444444
cls, client: "labelbox.Client", project_id: str, name: str,
445445
bytes_data: BinaryIO, content_len: int) -> "MALPredictionImport":
446446
file_name = f"{project_id}__{name}.ndjson"
@@ -454,3 +454,154 @@ def _create_mea_import_from_bytes(
454454
res = cls._create_from_bytes(client, variables, query_str, file_name,
455455
bytes_data)
456456
return cls(client, res["createModelAssistedLabelingPredictionImport"])
457+
458+
459+
class LabelImport(AnnotationImport):
    """Annotation import job that uploads labels into a project.

    A job can be created from a local ndjson file, from in-memory label
    dictionaries, or from a url, and can be fetched again by name.
    """
    project = Relationship.ToOne("Project", cache=True)

    @property
    def parent_id(self) -> str:
        """
        Uid of the project related to this import. Used to refresh the status
        """
        return self.project().uid

    @classmethod
    def create_from_file(cls, client: "labelbox.Client", project_id: str,
                         name: str, path: str) -> "LabelImport":
        """
        Start a label import job from an ndjson file on disk.

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
            path: Path to ndjson file containing annotations
        Returns:
            LabelImport
        Raises:
            ValueError: if nothing exists at ``path``
        """
        # Guard first so the upload path below stays flat.
        if not os.path.exists(path):
            raise ValueError(f"File {path} is not accessible")

        with open(path, 'rb') as ndjson_file:
            # The content length sent with the upload is the on-disk size.
            size_in_bytes = os.stat(path).st_size
            return cls._create_label_import_from_bytes(client, project_id,
                                                       name, ndjson_file,
                                                       size_in_bytes)

    @classmethod
    def create_from_objects(cls, client: "labelbox.Client", project_id: str,
                            name: str,
                            labels: List[Dict[str, Any]]) -> "LabelImport":
        """
        Start a label import job from label dictionaries held in memory.

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
            labels: List of labels
        Returns:
            LabelImport
        Raises:
            ValueError: if serializing ``labels`` yields an empty string
        """
        serialized = ndjson.dumps(labels)
        if not serialized:
            raise ValueError('labels cannot be empty')

        payload = serialized.encode('utf-8')
        return cls._create_label_import_from_bytes(client, project_id, name,
                                                   payload, len(payload))

    @classmethod
    def create_from_url(cls, client: "labelbox.Client", project_id: str,
                        name: str, url: str) -> "LabelImport":
        """
        Start a label import job from a file hosted at a url.
        The url must point to a file containing label annotations.

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
            url: Url pointing to file to upload
        Returns:
            LabelImport
        Raises:
            ValueError: if the HEAD request to ``url`` gives a falsy response
        """
        # A requests response is falsy when its status is an error code.
        if not requests.head(url):
            raise ValueError(f"Url {url} is not reachable")

        mutation = cls._get_url_mutation()
        variables = {
            "fileUrl": url,
            "projectId": project_id,
            'name': name
        }
        result = client.execute(mutation, params=variables)
        return cls(client, result["createLabelImport"])

    @classmethod
    def from_name(cls,
                  client: "labelbox.Client",
                  project_id: str,
                  name: str,
                  as_json: bool = False) -> "LabelImport":
        """
        Retrieves a label import job.

        Args:
            client: Labelbox Client for executing queries
            project_id: ID used for querying import jobs
            name: Name of the import job.
            as_json: When True, return the raw ``labelImport`` payload from
                the response instead of wrapping it in a LabelImport.
        Returns:
            LabelImport
        Raises:
            labelbox.exceptions.ResourceNotFoundError: if the query returns
                no response
        """
        lookup_query = """query getLabelImportPyApi($projectId : ID!, $name: String!) {
            labelImport(
                where: {projectId: $projectId, name: $name}){
                    %s
                }}""" % query.results_query_part(cls)
        params = {
            "projectId": project_id,
            "name": name,
        }
        result = client.execute(lookup_query, params)
        if result is None:
            raise labelbox.exceptions.ResourceNotFoundError(LabelImport, params)

        payload = result["labelImport"]
        return payload if as_json else cls(client, payload)

    @classmethod
    def _get_url_mutation(cls) -> str:
        """Build the mutation that creates a label import from a file url."""
        result_fields = query.results_query_part(cls)
        return """mutation createLabelImportByUrlPyApi($projectId : ID!, $name: String!, $fileUrl: String!) {
            createLabelImport(data: {
                projectId: $projectId
                name: $name
                fileUrl: $fileUrl
            }) {%s}
        }""" % result_fields

    @classmethod
    def _get_file_mutation(cls) -> str:
        """Build the mutation that creates a label import from an uploaded file."""
        result_fields = query.results_query_part(cls)
        return """mutation createLabelImportByFilePyApi($projectId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
            createLabelImport(data: {
                projectId: $projectId name: $name filePayload: { file: $file, contentLength: $contentLength}
            }) {%s}
        }""" % result_fields

    @classmethod
    def _create_label_import_from_bytes(cls, client: "labelbox.Client",
                                        project_id: str, name: str,
                                        bytes_data: BinaryIO,
                                        content_len: int) -> "LabelImport":
        """
        Upload ``bytes_data`` through the file mutation and wrap the result.

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job
            bytes_data: ndjson content to upload
            content_len: Size of the content, sent as ``contentLength``
        Returns:
            LabelImport
        """
        upload_name = f"{project_id}__{name}.ndjson"
        # "file" is sent as None here; the payload itself is handed to
        # _create_from_bytes separately.
        variables = {
            "file": None,
            "contentLength": content_len,
            "projectId": project_id,
            "name": name
        }
        mutation = cls._get_file_mutation()
        result = cls._create_from_bytes(client, variables, mutation,
                                        upload_name, bytes_data)
        return cls(client, result["createLabelImport"])

tests/integration/mal_and_mea/conftest.py renamed to tests/integration/annotation_import/conftest.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22

33
import pytest
44
import time
5+
import requests
6+
import ndjson
57

8+
from typing import Type
69
from labelbox.schema.labeling_frontend import LabelingFrontend
7-
from labelbox.schema.annotation_import import MALPredictionImport
10+
from labelbox.schema.annotation_import import MALPredictionImport, AnnotationImportState
811

912

1013
@pytest.fixture
@@ -277,6 +280,7 @@ def model_run_predictions(polygon_inference, rectangle_inference,
277280
return [polygon_inference, rectangle_inference, line_inference]
278281

279282

283+
# also used for label imports
280284
@pytest.fixture
281285
def object_predictions(polygon_inference, rectangle_inference, line_inference,
282286
entity_inference, segmentation_inference):
@@ -339,3 +343,25 @@ def model_run_annotation_groups(client, configured_project,
339343
time.sleep(3)
340344
yield model_run
341345
# TODO: Delete resources when that is possible ..
346+
347+
348+
class AnnotationImportTestHelpers:
    """Assertion helpers shared by the annotation import integration tests."""

    @staticmethod
    def assert_file_content(url: str, predictions):
        """Assert that the body served at ``url`` equals the ndjson dump of ``predictions``."""
        downloaded = requests.get(url).text
        assert downloaded == ndjson.dumps(predictions)

    @staticmethod
    def check_running_state(req, name, url=None):
        """
        Assert that ``req`` looks like a freshly started import job.

        Checks the name, the input file url (only when ``url`` is given),
        that no error or status file url is set, and that the state is
        RUNNING.
        """
        assert req.name == name
        # The input url is only compared when the caller supplies one.
        if url is not None:
            assert req.input_file_url == url
        assert req.error_file_url is None
        assert req.status_file_url is None
        expected_state = AnnotationImportState.RUNNING
        assert req.state == expected_state
363+
364+
365+
@pytest.fixture
def annotation_import_test_helpers() -> Type[AnnotationImportTestHelpers]:
    """
    Expose the AnnotationImportTestHelpers class to tests.

    The helpers are all static methods, so the class itself is returned
    rather than an instance; this is what the ``Type[...]`` return
    annotation declares. Tests call the helpers the same way either way,
    e.g. ``annotation_import_test_helpers.assert_file_content(url, data)``.
    """
    return AnnotationImportTestHelpers

tests/integration/mal_and_mea/test_bulk_import_request.py renamed to tests/integration/annotation_import/test_bulk_import_request.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def test_validate_file(client, configured_project):
3939
#Schema ids shouldn't match
4040

4141

42-
def test_create_from_objects(configured_project, predictions):
42+
def test_create_from_objects(configured_project, predictions,
43+
annotation_import_test_helpers):
4344
name = str(uuid.uuid4())
4445

4546
bulk_import_request = configured_project.upload_annotations(
@@ -50,10 +51,12 @@ def test_create_from_objects(configured_project, predictions):
5051
assert bulk_import_request.error_file_url is None
5152
assert bulk_import_request.status_file_url is None
5253
assert bulk_import_request.state == BulkImportRequestState.RUNNING
53-
assert_file_content(bulk_import_request.input_file_url, predictions)
54+
annotation_import_test_helpers.assert_file_content(
55+
bulk_import_request.input_file_url, predictions)
5456

5557

56-
def test_create_from_local_file(tmp_path, predictions, configured_project):
58+
def test_create_from_local_file(tmp_path, predictions, configured_project,
59+
annotation_import_test_helpers):
5760
name = str(uuid.uuid4())
5861
file_name = f"{name}.ndjson"
5962
file_path = tmp_path / file_name
@@ -68,7 +71,8 @@ def test_create_from_local_file(tmp_path, predictions, configured_project):
6871
assert bulk_import_request.error_file_url is None
6972
assert bulk_import_request.status_file_url is None
7073
assert bulk_import_request.state == BulkImportRequestState.RUNNING
71-
assert_file_content(bulk_import_request.input_file_url, predictions)
74+
annotation_import_test_helpers.assert_file_content(
75+
bulk_import_request.input_file_url, predictions)
7276

7377

7478
def test_get(client, configured_project):
@@ -144,11 +148,6 @@ def test_wait_till_done(rectangle_inference, configured_project):
144148
'uuid']
145149

146150

147-
def assert_file_content(url: str, predictions):
148-
response = requests.get(url)
149-
assert response.text == ndjson.dumps(predictions)
150-
151-
152151
def test_project_bulk_import_requests(client, configured_project, predictions):
153152
result = configured_project.bulk_import_requests()
154153
assert len(list(result)) == 0

0 commit comments

Comments
 (0)