Skip to content

Commit 0fdd61d

Browse files
author
Matt Sokoloff
committed
update annotation import to use latest endpoints
1 parent 9d8ad39 commit 0fdd61d

File tree

5 files changed

+119
-117
lines changed

5 files changed

+119
-117
lines changed

Makefile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ test-local: build
99
-e LABELBOX_TEST_API_KEY_STAGING=${LABELBOX_TEST_API_KEY_LOCAL} \
1010
local/labelbox-python:test pytest $(PATH_TO_TEST) -svvx
1111

12-
1312
test-staging: build
1413
docker run -it -v ${PWD}:/usr/src -w /usr/src \
1514
-e LABELBOX_TEST_ENVIRON="staging" \

labelbox/schema/annotation_import.py

Lines changed: 110 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
from enum import Enum
2-
from labelbox.schema.enums import AnnotationImportState, ImportType
3-
from typing import Any, Dict, List
1+
from typing import Any, Dict, List, Union
42
import functools
53
import os
64
import json
@@ -12,6 +10,7 @@
1210
import requests
1311

1412
import labelbox
13+
from labelbox.schema.enums import AnnotationImportState
1514
from labelbox.orm.db_object import DbObject
1615
from labelbox.orm.model import Field, Relationship
1716
from labelbox.orm import query
@@ -21,13 +20,6 @@
2120

2221

2322
class AnnotationImport(DbObject):
24-
# This class will replace BulkImportRequest.
25-
# Currently this exists for the MEA beta.
26-
# Use BulkImportRequest for now if you are not using MEA.
27-
28-
id_name: str
29-
import_type: ImportType
30-
3123
name = Field.String("name")
3224
state = Field.Enum(AnnotationImportState, "state")
3325
input_file_url = Field.String("input_file_url")
@@ -36,6 +28,10 @@ class AnnotationImport(DbObject):
3628

3729
created_by = Relationship.ToOne("User", False, "created_by")
3830

31+
parent_id: str
32+
_mutation: str
33+
_parent_id_field: str
34+
3935
@property
4036
def inputs(self) -> List[Dict[str, Any]]:
4137
"""
@@ -123,20 +119,12 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
123119
return ndjson.loads(response.text)
124120

125121
@classmethod
126-
def _build_import_predictions_query(cls, file_args: str, vars: str):
127-
raise NotImplementedError("")
128-
129-
@classmethod
130-
def validate_cls(cls):
131-
supported_base_classes = {MALPredictionImport, MEAPredictionImport}
132-
if cls not in {MALPredictionImport, MEAPredictionImport}:
133-
raise TypeError(
134-
f"Can't directly use the base AnnotationImport class. Must use one of {supported_base_classes}"
135-
)
136-
137-
@classmethod
138-
def from_name(cls, client, parent_id, name: str, raw=False):
139-
cls.validate_cls()
122+
def _from_name(cls,
123+
client: "labelbox.Client",
124+
parent_id: str,
125+
name: str,
126+
raw=False
127+
) -> Union["MEAPredictionImport", "MALPredictionImport"]:
140128
query_str = """query getImportPyApi($parent_id : ID!, $name: String!) {
141129
annotationImport(
142130
where: {%s: $parent_id, name: $name}){
@@ -145,7 +133,7 @@ def from_name(cls, client, parent_id, name: str, raw=False):
145133
... on ModelErrorAnalysisPredictionImport {%s}
146134
}}""" % \
147135
(
148-
cls.id_name,
136+
cls._parent_id_field,
149137
query.results_query_part(MALPredictionImport),
150138
query.results_query_part(MEAPredictionImport)
151139
)
@@ -159,19 +147,6 @@ def from_name(cls, client, parent_id, name: str, raw=False):
159147

160148
return cls(client, response['annotationImport'])
161149

162-
@classmethod
163-
def _create_from_url(cls, client, parent_id, name, url):
164-
file_args = "fileUrl : $fileUrl"
165-
query_str = cls._build_import_predictions_query(file_args,
166-
"$fileUrl: String!")
167-
response = client.execute(query_str,
168-
params={
169-
"fileUrl": url,
170-
"parent_id": parent_id,
171-
'name': name
172-
})
173-
return cls(client, response['createAnnotationImport'])
174-
175150
@staticmethod
176151
def _make_file_name(parent_id: str, name: str) -> str:
177152
return f"{parent_id}__{name}.ndjson"
@@ -180,131 +155,160 @@ def refresh(self) -> None:
180155
"""Synchronizes values of all fields with the database.
181156
"""
182157
cls = type(self)
183-
res = cls.from_name(self.client,
184-
self.get_parent_id(),
185-
self.name,
186-
raw=True)
158+
res = cls._from_name(self.client, self.parent_id, self.name, raw=True)
187159
self._set_field_values(res)
188160

189161
@classmethod
190-
def _create_from_bytes(cls, client, parent_id, name, bytes_data,
191-
content_len):
162+
def _create_from_bytes(
163+
cls, client: "labelbox.Client", parent_id: str, name: str,
164+
bytes_data: bytes, content_len: int
165+
) -> Union["MEAPredictionImport", "MALPredictionImport"]:
192166
file_name = cls._make_file_name(parent_id, name)
193-
file_args = """filePayload: {
194-
file: $file,
195-
contentLength: $contentLength
196-
}"""
197-
query_str = cls._build_import_predictions_query(
198-
file_args, "$file: Upload!, $contentLength: Int!")
199167
variables = {
200168
"file": None,
201169
"contentLength": content_len,
202-
"parent_id": parent_id,
170+
"parentId": parent_id,
203171
"name": name
204172
}
173+
query_str = cls._get_file_mutation()
205174
operations = json.dumps({"variables": variables, "query": query_str})
206175
data = {
207176
"operations": operations,
208177
"map": (None, json.dumps({file_name: ["variables.file"]}))
209178
}
210179
file_data = (file_name, bytes_data, NDJSON_MIME_TYPE)
211180
files = {file_name: file_data}
212-
213-
print(data)
214-
breakpoint()
215-
return client.execute(data=data, files=files)
181+
return cls(client,
182+
client.execute(data=data, files=files)[cls._mutation])
216183

217184
@classmethod
218-
def _create_from_objects(cls, client, parent_id, name, predictions):
185+
def _create_from_objects(
186+
cls, client: "labelbox.Client", parent_id: str, name: str,
187+
predictions: List[Dict[str, Any]]
188+
) -> Union["MEAPredictionImport", "MALPredictionImport"]:
219189
data_str = ndjson.dumps(predictions)
220190
if not data_str:
221191
raise ValueError('annotations cannot be empty')
222192
data = data_str.encode('utf-8')
223193
return cls._create_from_bytes(client, parent_id, name, data, len(data))
224194

225195
@classmethod
226-
def _create_from_file(cls, client, parent_id, name, path):
196+
def _create_from_url(
197+
cls, client: "labelbox.Client", parent_id: str, name: str,
198+
url: str) -> Union["MEAPredictionImport", "MALPredictionImport"]:
199+
if requests.head(url):
200+
query_str = cls._get_url_mutation()
201+
return cls(
202+
client,
203+
client.execute(query_str,
204+
params={
205+
"fileUrl": url,
206+
"parentId": parent_id,
207+
'name': name
208+
})[cls._mutation])
209+
else:
210+
raise ValueError(f"Url {url} is not reachable")
211+
212+
@classmethod
213+
def _create_from_file(
214+
cls, client: "labelbox.Client", parent_id: str, name: str,
215+
path: str) -> Union["MEAPredictionImport", "MALPredictionImport"]:
227216
if os.path.exists(path):
228217
with open(path, 'rb') as f:
229218
return cls._create_from_bytes(client, parent_id, name, f,
230219
os.stat(path).st_size)
231-
elif requests.head(path):
232-
return cls._create_from_url(client, parent_id, name, path)
233-
raise ValueError(
234-
f"Path {path} is not accessible locally or on a remote server")
235-
236-
def create_from_objects(*args, **kwargs):
237-
raise NotImplementedError("")
220+
else:
221+
raise ValueError(f"File {path} is not accessible")
238222

239-
def create_from_file(*args, **kwargs):
240-
raise NotImplementedError("")
223+
@classmethod
224+
def _get_url_mutation(cls) -> str:
225+
return """mutation create%sPyApi($parentId : ID!, $name: String!, $fileUrl: String!) {
226+
%s(data: {
227+
%s: $parentId
228+
name: $name
229+
fileUrl: $fileUrl
230+
}) {%s}
231+
}""" % (cls.__class__.__name__, cls._mutation, cls._parent_id_field,
232+
query.results_query_part(cls))
241233

242-
def get_parent_id(*args, **kwargs):
243-
raise NotImplementedError("")
234+
@classmethod
235+
def _get_file_mutation(cls) -> str:
236+
return """mutation create%sPyApi($parentId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
237+
%s(data: { %s : $parentId name: $name filePayload: { file: $file, contentLength: $contentLength}
238+
}) {%s}
239+
}""" % (cls.__class__.__name__, cls._mutation, cls._parent_id_field,
240+
query.results_query_part(cls))
244241

245242

246243
class MEAPredictionImport(AnnotationImport):
247-
id_name = "modelRunId"
248-
import_type = ImportType.MODEL_ERROR_ANALYSIS
249244
model_run_id = Field.String("model_run_id")
245+
_mutation = "createModelErrorAnalysisPredictionImport"
246+
_parent_id_field = "modelRunId"
250247

251-
def get_parent_id(self):
248+
@property
249+
def parent_id(self) -> str:
252250
return self.model_run_id
253251

254252
@classmethod
255-
def create_from_file(cls, client, model_run_id, name, path):
256-
breakpoint()
257-
return cls(client, cls._create_from_file(client=client,
253+
def create_from_file(cls, client: "labelbox.Client", model_run_id: str,
254+
name: str, path: str) -> "MEAPredictionImport":
255+
return cls._create_from_file(client=client,
258256
parent_id=model_run_id,
259257
name=name,
260-
path=path)['createModelErrorAnalysisPredictionImport'])
258+
path=path)
261259

262260
@classmethod
263-
def create_from_objects(cls, client, model_run_id, name, predictions):
264-
return cls(client, cls._create_from_objects(client, model_run_id, name, predictions)['createModelErrorAnalysisPredictionImport'])
261+
def create_from_objects(cls, client: "labelbox.Client", model_run_id: str,
262+
name, predictions) -> "MEAPredictionImport":
263+
return cls._create_from_objects(client, model_run_id, name, predictions)
265264

266265
@classmethod
267-
def _build_import_predictions_query(cls, file_args: str, vars: str):
268-
query_str = """mutation createAnnotationImportPyApi($parent_id : ID!, $name: String!, %s) {
269-
createModelErrorAnalysisPredictionImport(data: {
270-
%s : $parent_id
271-
name: $name
272-
%s
273-
}) {%s}
274-
}""" % (vars, cls.id_name, file_args,query.results_query_part(cls))
275-
return query_str
266+
def create_from_url(cls, client: "labelbox.Client", model_run_id: str,
267+
name: str, url: str) -> "MEAPredictionImport":
268+
return cls._create_from_url(client=client,
269+
parent_id=model_run_id,
270+
name=name,
271+
url=url)
272+
273+
@classmethod
274+
def from_name(
275+
cls, client: "labelbox.Client", model_run_id: str,
276+
name: str) -> Union["MEAPredictionImport", "MALPredictionImport"]:
277+
return cls._from_name(client, model_run_id, name)
276278

277279

278280
class MALPredictionImport(AnnotationImport):
279-
id_name = "projectId"
280-
import_type = ImportType.MODEL_ASSISTED_LABELING
281281
project = Relationship.ToOne("Project", cache=True)
282+
_mutation = "createModelAssistedLabelingPredictionImport"
283+
_parent_id_field = "projectId"
282284

283-
def get_parent_id(self):
285+
@property
286+
def parent_id(self) -> str:
284287
return self.project().uid
285288

286289
@classmethod
287-
def create_from_file(cls, client, project_id, name, path):
288-
return cls(client, cls._create_from_file(client=client,
290+
def create_from_file(cls, client: "labelbox.Client", project_id: str,
291+
name: str, path: str) -> "MALPredictionImport":
292+
return cls._create_from_file(client=client,
289293
parent_id=project_id,
290294
name=name,
291-
path=path)['createModelAssistedLabelingPredictionImport'])
295+
path=path)
292296

293297
@classmethod
294-
def create_from_objects(cls, client, project_id, name, predictions):
295-
return cls(client, cls._create_from_objects(client, project_id, name, predictions)['createModelAssistedLabelingPredictionImport'])
298+
def create_from_objects(cls, client: "labelbox.Client", project_id: str,
299+
name, predictions) -> "MALPredictionImport":
300+
return cls._create_from_objects(client, project_id, name, predictions)
296301

297302
@classmethod
298-
def _build_import_predictions_query(cls, file_args: str, vars: str):
299-
query_str = """mutation createAnnotationImportPyApi($parent_id : ID!, $name: String!, %s) {
300-
createModelAssistedLabelingPredictionImport(data: {
301-
%s : $parent_id
302-
name: $name
303-
%s
304-
}) {%s}
305-
}""" % (vars, cls.id_name, file_args,
306-
query.results_query_part(cls))
307-
return query_str
308-
309-
303+
def create_from_url(cls, client: "labelbox.Client", project_id: str,
304+
name: str, url: str) -> "MALPredictionImport":
305+
return cls._create_from_url(client=client,
306+
parent_id=project_id,
307+
name=name,
308+
url=url)
310309

310+
@classmethod
311+
def from_name(
312+
cls, client: "labelbox.Client", project_id: str,
313+
name: str) -> Union["MEAPredictionImport", "MALPredictionImport"]:
314+
return cls._from_name(client, project_id, name)

labelbox/schema/enums.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,3 @@ class AnnotationImportState(Enum):
4444
RUNNING = "RUNNING"
4545
FAILED = "FAILED"
4646
FINISHED = "FINISHED"
47-
48-
49-
class ImportType(Enum):
50-
MODEL_ERROR_ANALYSIS = "MODEL_ERROR_ANALYSIS"
51-
MODEL_ASSISTED_LABELING = "MODEL_ASSISTED_LABELING"

labelbox/schema/model_run.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Dict, Iterable, Union
22
from pathlib import Path
3+
import os
34

45
from labelbox.pagination import PaginatedCollection
56
from labelbox.schema.annotation_import import MEAPredictionImport
@@ -48,8 +49,12 @@ def add_predictions(
4849
"""
4950
kwargs = dict(client=self.client, model_run_id=self.uid, name=name)
5051
if isinstance(predictions, str) or isinstance(predictions, Path):
51-
return MEAPredictionImport.create_from_file(path=predictions,
52-
**kwargs)
52+
if os.path.exists(predictions):
53+
return MEAPredictionImport.create_from_file(path=predictions,
54+
**kwargs)
55+
else:
56+
return MEAPredictionImport.create_from_url(url=predictions,
57+
**kwargs)
5358
elif isinstance(predictions, Iterable):
5459
return MEAPredictionImport.create_from_objects(
5560
predictions=predictions, **kwargs)

tests/integration/bulk_import/test_mea_annotation_import.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,8 @@ def test_get(client, model_run):
5959
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
6060
model_run.add_predictions(name=name, predictions=url)
6161

62-
annotation_import = MEAPredictionImport.from_name(client,
63-
parent_id=model_run.uid,
64-
name=name)
62+
annotation_import = MEAPredictionImport.from_name(
63+
client, model_run_id=model_run.uid, name=name)
6564

6665
assert annotation_import.model_run_id == model_run.uid
6766
check_running_state(annotation_import, name, url)

0 commit comments

Comments
 (0)