1
- from enum import Enum
2
- from labelbox .schema .enums import AnnotationImportState , ImportType
3
- from typing import Any , Dict , List
1
+ from typing import Any , Dict , List , Union
4
2
import functools
5
3
import os
6
4
import json
12
10
import requests
13
11
14
12
import labelbox
13
+ from labelbox .schema .enums import AnnotationImportState
15
14
from labelbox .orm .db_object import DbObject
16
15
from labelbox .orm .model import Field , Relationship
17
16
from labelbox .orm import query
21
20
22
21
23
22
class AnnotationImport (DbObject ):
24
- # This class will replace BulkImportRequest.
25
- # Currently this exists for the MEA beta.
26
- # Use BulkImportRequest for now if you are not using MEA.
27
-
28
- id_name : str
29
- import_type : ImportType
30
-
31
23
name = Field .String ("name" )
32
24
state = Field .Enum (AnnotationImportState , "state" )
33
25
input_file_url = Field .String ("input_file_url" )
@@ -36,6 +28,10 @@ class AnnotationImport(DbObject):
36
28
37
29
created_by = Relationship .ToOne ("User" , False , "created_by" )
38
30
31
+ parent_id : str
32
+ _mutation : str
33
+ _parent_id_field : str
34
+
39
35
@property
40
36
def inputs (self ) -> List [Dict [str , Any ]]:
41
37
"""
@@ -123,20 +119,12 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
123
119
return ndjson .loads (response .text )
124
120
125
121
@classmethod
126
- def _build_import_predictions_query (cls , file_args : str , vars : str ):
127
- raise NotImplementedError ("" )
128
-
129
- @classmethod
130
- def validate_cls (cls ):
131
- supported_base_classes = {MALPredictionImport , MEAPredictionImport }
132
- if cls not in {MALPredictionImport , MEAPredictionImport }:
133
- raise TypeError (
134
- f"Can't directly use the base AnnotationImport class. Must use one of { supported_base_classes } "
135
- )
136
-
137
- @classmethod
138
- def from_name (cls , client , parent_id , name : str , raw = False ):
139
- cls .validate_cls ()
122
+ def _from_name (cls ,
123
+ client : "labelbox.Client" ,
124
+ parent_id : str ,
125
+ name : str ,
126
+ raw = False
127
+ ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
140
128
query_str = """query getImportPyApi($parent_id : ID!, $name: String!) {
141
129
annotationImport(
142
130
where: {%s: $parent_id, name: $name}){
@@ -145,7 +133,7 @@ def from_name(cls, client, parent_id, name: str, raw=False):
145
133
... on ModelErrorAnalysisPredictionImport {%s}
146
134
}}""" % \
147
135
(
148
- cls .id_name ,
136
+ cls ._parent_id_field ,
149
137
query .results_query_part (MALPredictionImport ),
150
138
query .results_query_part (MEAPredictionImport )
151
139
)
@@ -159,19 +147,6 @@ def from_name(cls, client, parent_id, name: str, raw=False):
159
147
160
148
return cls (client , response ['annotationImport' ])
161
149
162
- @classmethod
163
- def _create_from_url (cls , client , parent_id , name , url ):
164
- file_args = "fileUrl : $fileUrl"
165
- query_str = cls ._build_import_predictions_query (file_args ,
166
- "$fileUrl: String!" )
167
- response = client .execute (query_str ,
168
- params = {
169
- "fileUrl" : url ,
170
- "parent_id" : parent_id ,
171
- 'name' : name
172
- })
173
- return cls (client , response ['createAnnotationImport' ])
174
-
175
150
@staticmethod
176
151
def _make_file_name (parent_id : str , name : str ) -> str :
177
152
return f"{ parent_id } __{ name } .ndjson"
@@ -180,131 +155,160 @@ def refresh(self) -> None:
180
155
"""Synchronizes values of all fields with the database.
181
156
"""
182
157
cls = type (self )
183
- res = cls .from_name (self .client ,
184
- self .get_parent_id (),
185
- self .name ,
186
- raw = True )
158
+ res = cls ._from_name (self .client , self .parent_id , self .name , raw = True )
187
159
self ._set_field_values (res )
188
160
189
161
@classmethod
190
- def _create_from_bytes (cls , client , parent_id , name , bytes_data ,
191
- content_len ):
162
+ def _create_from_bytes (
163
+ cls , client : "labelbox.Client" , parent_id : str , name : str ,
164
+ bytes_data : bytes , content_len : int
165
+ ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
192
166
file_name = cls ._make_file_name (parent_id , name )
193
- file_args = """filePayload: {
194
- file: $file,
195
- contentLength: $contentLength
196
- }"""
197
- query_str = cls ._build_import_predictions_query (
198
- file_args , "$file: Upload!, $contentLength: Int!" )
199
167
variables = {
200
168
"file" : None ,
201
169
"contentLength" : content_len ,
202
- "parent_id " : parent_id ,
170
+ "parentId " : parent_id ,
203
171
"name" : name
204
172
}
173
+ query_str = cls ._get_file_mutation ()
205
174
operations = json .dumps ({"variables" : variables , "query" : query_str })
206
175
data = {
207
176
"operations" : operations ,
208
177
"map" : (None , json .dumps ({file_name : ["variables.file" ]}))
209
178
}
210
179
file_data = (file_name , bytes_data , NDJSON_MIME_TYPE )
211
180
files = {file_name : file_data }
212
-
213
- print (data )
214
- breakpoint ()
215
- return client .execute (data = data , files = files )
181
+ return cls (client ,
182
+ client .execute (data = data , files = files )[cls ._mutation ])
216
183
217
184
@classmethod
218
- def _create_from_objects (cls , client , parent_id , name , predictions ):
185
+ def _create_from_objects (
186
+ cls , client : "labelbox.Client" , parent_id : str , name : str ,
187
+ predictions : List [Dict [str , Any ]]
188
+ ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
219
189
data_str = ndjson .dumps (predictions )
220
190
if not data_str :
221
191
raise ValueError ('annotations cannot be empty' )
222
192
data = data_str .encode ('utf-8' )
223
193
return cls ._create_from_bytes (client , parent_id , name , data , len (data ))
224
194
225
195
@classmethod
226
- def _create_from_file (cls , client , parent_id , name , path ):
196
+ def _create_from_url (
197
+ cls , client : "labelbox.Client" , parent_id : str , name : str ,
198
+ url : str ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
199
+ if requests .head (url ):
200
+ query_str = cls ._get_url_mutation ()
201
+ return cls (
202
+ client ,
203
+ client .execute (query_str ,
204
+ params = {
205
+ "fileUrl" : url ,
206
+ "parentId" : parent_id ,
207
+ 'name' : name
208
+ })[cls ._mutation ])
209
+ else :
210
+ raise ValueError (f"Url { url } is not reachable" )
211
+
212
+ @classmethod
213
+ def _create_from_file (
214
+ cls , client : "labelbox.Client" , parent_id : str , name : str ,
215
+ path : str ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
227
216
if os .path .exists (path ):
228
217
with open (path , 'rb' ) as f :
229
218
return cls ._create_from_bytes (client , parent_id , name , f ,
230
219
os .stat (path ).st_size )
231
- elif requests .head (path ):
232
- return cls ._create_from_url (client , parent_id , name , path )
233
- raise ValueError (
234
- f"Path { path } is not accessible locally or on a remote server" )
235
-
236
- def create_from_objects (* args , ** kwargs ):
237
- raise NotImplementedError ("" )
220
+ else :
221
+ raise ValueError (f"File { path } is not accessible" )
238
222
239
- def create_from_file (* args , ** kwargs ):
240
- raise NotImplementedError ("" )
223
+ @classmethod
224
+ def _get_url_mutation (cls ) -> str :
225
+ return """mutation create%sPyApi($parentId : ID!, $name: String!, $fileUrl: String!) {
226
+ %s(data: {
227
+ %s: $parentId
228
+ name: $name
229
+ fileUrl: $fileUrl
230
+ }) {%s}
231
+ }""" % (cls .__class__ .__name__ , cls ._mutation , cls ._parent_id_field ,
232
+ query .results_query_part (cls ))
241
233
242
- def get_parent_id (* args , ** kwargs ):
243
- raise NotImplementedError ("" )
234
+ @classmethod
235
+ def _get_file_mutation (cls ) -> str :
236
+ return """mutation create%sPyApi($parentId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
237
+ %s(data: { %s : $parentId name: $name filePayload: { file: $file, contentLength: $contentLength}
238
+ }) {%s}
239
+ }""" % (cls .__class__ .__name__ , cls ._mutation , cls ._parent_id_field ,
240
+ query .results_query_part (cls ))
244
241
245
242
246
243
class MEAPredictionImport (AnnotationImport ):
247
- id_name = "modelRunId"
248
- import_type = ImportType .MODEL_ERROR_ANALYSIS
249
244
model_run_id = Field .String ("model_run_id" )
245
+ _mutation = "createModelErrorAnalysisPredictionImport"
246
+ _parent_id_field = "modelRunId"
250
247
251
- def get_parent_id (self ):
248
+ @property
249
+ def parent_id (self ) -> str :
252
250
return self .model_run_id
253
251
254
252
@classmethod
255
- def create_from_file (cls , client , model_run_id , name , path ):
256
- breakpoint ()
257
- return cls ( client , cls ._create_from_file (client = client ,
253
+ def create_from_file (cls , client : "labelbox.Client" , model_run_id : str ,
254
+ name : str , path : str ) -> "MEAPredictionImport" :
255
+ return cls ._create_from_file (client = client ,
258
256
parent_id = model_run_id ,
259
257
name = name ,
260
- path = path )[ 'createModelErrorAnalysisPredictionImport' ])
258
+ path = path )
261
259
262
260
@classmethod
263
- def create_from_objects (cls , client , model_run_id , name , predictions ):
264
- return cls (client , cls ._create_from_objects (client , model_run_id , name , predictions )['createModelErrorAnalysisPredictionImport' ])
261
+ def create_from_objects (cls , client : "labelbox.Client" , model_run_id : str ,
262
+ name , predictions ) -> "MEAPredictionImport" :
263
+ return cls ._create_from_objects (client , model_run_id , name , predictions )
265
264
266
265
@classmethod
267
- def _build_import_predictions_query (cls , file_args : str , vars : str ):
268
- query_str = """mutation createAnnotationImportPyApi($parent_id : ID!, $name: String!, %s) {
269
- createModelErrorAnalysisPredictionImport(data: {
270
- %s : $parent_id
271
- name: $name
272
- %s
273
- }) {%s}
274
- }""" % (vars , cls .id_name , file_args ,query .results_query_part (cls ))
275
- return query_str
266
+ def create_from_url (cls , client : "labelbox.Client" , model_run_id : str ,
267
+ name : str , url : str ) -> "MEAPredictionImport" :
268
+ return cls ._create_from_url (client = client ,
269
+ parent_id = model_run_id ,
270
+ name = name ,
271
+ url = url )
272
+
273
+ @classmethod
274
+ def from_name (
275
+ cls , client : "labelbox.Client" , model_run_id : str ,
276
+ name : str ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
277
+ return cls ._from_name (client , model_run_id , name )
276
278
277
279
278
280
class MALPredictionImport (AnnotationImport ):
279
- id_name = "projectId"
280
- import_type = ImportType .MODEL_ASSISTED_LABELING
281
281
project = Relationship .ToOne ("Project" , cache = True )
282
+ _mutation = "createModelAssistedLabelingPredictionImport"
283
+ _parent_id_field = "projectId"
282
284
283
- def get_parent_id (self ):
285
+ @property
286
+ def parent_id (self ) -> str :
284
287
return self .project ().uid
285
288
286
289
@classmethod
287
- def create_from_file (cls , client , project_id , name , path ):
288
- return cls (client , cls ._create_from_file (client = client ,
290
+ def create_from_file (cls , client : "labelbox.Client" , project_id : str ,
291
+ name : str , path : str ) -> "MALPredictionImport" :
292
+ return cls ._create_from_file (client = client ,
289
293
parent_id = project_id ,
290
294
name = name ,
291
- path = path )[ 'createModelAssistedLabelingPredictionImport' ])
295
+ path = path )
292
296
293
297
@classmethod
294
- def create_from_objects (cls , client , project_id , name , predictions ):
295
- return cls (client , cls ._create_from_objects (client , project_id , name , predictions )['createModelAssistedLabelingPredictionImport' ])
298
+ def create_from_objects (cls , client : "labelbox.Client" , project_id : str ,
299
+ name , predictions ) -> "MALPredictionImport" :
300
+ return cls ._create_from_objects (client , project_id , name , predictions )
296
301
297
302
@classmethod
298
- def _build_import_predictions_query (cls , file_args : str , vars : str ):
299
- query_str = """mutation createAnnotationImportPyApi($parent_id : ID!, $name: String!, %s) {
300
- createModelAssistedLabelingPredictionImport(data: {
301
- %s : $parent_id
302
- name: $name
303
- %s
304
- }) {%s}
305
- }""" % (vars , cls .id_name , file_args ,
306
- query .results_query_part (cls ))
307
- return query_str
308
-
309
-
303
+ def create_from_url (cls , client : "labelbox.Client" , project_id : str ,
304
+ name : str , url : str ) -> "MALPredictionImport" :
305
+ return cls ._create_from_url (client = client ,
306
+ parent_id = project_id ,
307
+ name = name ,
308
+ url = url )
310
309
310
+ @classmethod
311
+ def from_name (
312
+ cls , client : "labelbox.Client" , project_id : str ,
313
+ name : str ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
314
+ return cls ._from_name (client , project_id , name )
0 commit comments