1
1
# type: ignore
2
2
import datetime
3
+ import warnings
3
4
from enum import Enum
4
5
from itertools import chain
5
6
from typing import List , Optional , Dict , Union , Callable , Type , Any , Generator
9
10
from labelbox .schema .ontology import SchemaId
10
11
from labelbox .utils import camel_case
11
12
13
+ _MAX_METADATA_FIELDS = 5
14
+
12
15
13
16
class DataRowMetadataKind (Enum ):
17
+ number = "CustomMetadataNumber"
14
18
datetime = "CustomMetadataDateTime"
15
19
enum = "CustomMetadataEnum"
16
20
string = "CustomMetadataString"
@@ -20,13 +24,18 @@ class DataRowMetadataKind(Enum):
20
24
21
25
# Metadata schema
22
26
class DataRowMetadataSchema(BaseModel):
    """One entry of the data-row metadata ontology.

    Mirrors a server-side custom metadata schema. Enum schemas carry their
    options as child schemas; option schemas point back at their enum via
    ``parent``.
    """
    uid: SchemaId
    # Display name; 1-100 chars with surrounding whitespace stripped.
    name: constr(strip_whitespace=True, min_length=1, max_length=100)
    reserved: bool  # True for Labelbox-reserved (built-in) fields
    kind: DataRowMetadataKind
    # Child option schemas for enum kinds; None for scalar kinds.
    options: Optional[List["DataRowMetadataSchema"]]
    # Set on option schemas only: the uid of the enclosing enum schema.
    parent: Optional[SchemaId]

    @property
    def id(self):
        """Deprecated alias for ``uid``."""
        # Use DeprecationWarning (the bare call defaulted to UserWarning)
        # and stacklevel=2 so the warning points at the caller's line.
        warnings.warn("`id` is being deprecated in favor of `uid`",
                      DeprecationWarning,
                      stacklevel=2)
        return self.uid
30
39
31
40
DataRowMetadataSchema .update_forward_refs ()
32
41
@@ -35,8 +44,9 @@ class DataRowMetadataSchema(BaseModel):
35
44
DateTime : Type [datetime .datetime ] = datetime .datetime # must be in UTC
36
45
String : Type [str ] = constr (max_length = 500 )
37
46
OptionId : Type [SchemaId ] = SchemaId # enum option
47
+ Number : Type [float ] = float
38
48
39
- DataRowMetadataValue = Union [Embedding , DateTime , String , OptionId ]
49
+ DataRowMetadataValue = Union [Embedding , DateTime , String , OptionId , Number ]
40
50
41
51
42
52
class _CamelCaseMixin (BaseModel ):
@@ -106,44 +116,59 @@ class DataRowMetadataOntology:
106
116
"""
107
117
108
118
def __init__(self, client):
    """Fetch the metadata ontology and build the lookup indices.

    Args:
        client: Labelbox client used to query the metadata ontology.
    """
    self._client = client
    self._batch_size = 50  # used for uploads and deletes

    self._raw_ontology = self._get_ontology()

    # All fields, plus an id -> schema index over them.
    self.fields = self._parse_ontology()
    self.fields_by_id = self._make_id_index(self.fields)

    # Split into reserved (built-in) and custom fields in a single pass.
    reserved: List[DataRowMetadataSchema] = []
    custom: List[DataRowMetadataSchema] = []
    for field in self.fields:
        (reserved if field.reserved else custom).append(field)

    # Reserved fields, indexed by id and by name.
    self.reserved_fields: List[DataRowMetadataSchema] = reserved
    self.reserved_by_id = self._make_id_index(reserved)
    self.reserved_by_name: Dict[str, DataRowMetadataSchema] = {
        f.name: f for f in reserved
    }

    # Custom fields, indexed by id and by name.
    self.custom_fields: List[DataRowMetadataSchema] = custom
    self.custom_by_id = self._make_id_index(custom)
    self.custom_by_name: Dict[str, DataRowMetadataSchema] = {
        f.name: f for f in custom
    }
147
@staticmethod
def _make_name_index(fields: List[DataRowMetadataSchema]):
    """Build a name -> schema index.

    Enum schemas map to a nested dict of their options keyed by option
    name; every other schema maps directly to itself.
    """
    index = {}
    for f in fields:
        if f.options:
            # Nest the options under the parent field's name. The original
            # set `index[f.name] = {}` and then wrote options at the TOP
            # level (`index[o.name] = o`), leaving the parent entry as a
            # dead empty dict and letting option names collide with field
            # names.
            index[f.name] = {o.name: o for o in f.options}
        else:
            index[f.name] = f
    return index
134
159
@staticmethod
def _make_id_index(
    fields: List[DataRowMetadataSchema]
) -> Dict[SchemaId, DataRowMetadataSchema]:
    """Index schemas — and any nested option schemas — by their uid."""
    index = {}
    for field in fields:
        # Each field, followed by its options (if any), lands in the index.
        for entry in chain([field], field.options or []):
            index[entry.uid] = entry
    return index
145
170
146
- def _get_ontology (self ) -> Dict [str , Any ]:
171
+ def _get_ontology (self ) -> List [ Dict [str , Any ] ]:
147
172
query = """query GetMetadataOntologyBetaPyApi {
148
173
customMetadataOntology {
149
174
id
@@ -158,21 +183,24 @@ def _get_ontology(self) -> Dict[str, Any]:
158
183
}
159
184
}}
160
185
"""
161
- return self .client .execute (query )["customMetadataOntology" ]
186
+ return self ._client .execute (query )["customMetadataOntology" ]
162
187
163
188
def _parse_ontology (self ) -> List [DataRowMetadataSchema ]:
164
189
fields = []
165
190
for schema in self ._raw_ontology :
191
+ schema ["uid" ] = schema .pop ("id" )
166
192
options = None
167
193
if schema .get ("options" ):
168
- options = [
169
- DataRowMetadataSchema (** {
170
- ** option ,
171
- ** {
172
- "parent" : schema ["id" ]
173
- }
174
- }) for option in schema ["options" ]
175
- ]
194
+ options = []
195
+ for option in schema ["options" ]:
196
+ option ["uid" ] = option .pop ("id" )
197
+ options .append (
198
+ DataRowMetadataSchema (** {
199
+ ** option ,
200
+ ** {
201
+ "parent" : schema ["uid" ]
202
+ }
203
+ }))
176
204
schema ["options" ] = options
177
205
fields .append (DataRowMetadataSchema (** schema ))
178
206
@@ -184,7 +212,7 @@ def parse_metadata(
184
212
Dict ]]]]) -> List [DataRowMetadata ]:
185
213
""" Parse metadata responses
186
214
187
- >>> mdo.parse_metadata([datarow.metadata ])
215
+ >>> mdo.parse_metadata([metdata ])
188
216
189
217
Args:
190
218
unparsed: An unparsed metadata export
@@ -200,14 +228,14 @@ def parse_metadata(
200
228
for dr in unparsed :
201
229
fields = []
202
230
for f in dr ["fields" ]:
203
- schema = self .all_fields_id_index [f ["schemaId" ]]
231
+ schema = self .fields_by_id [f ["schemaId" ]]
204
232
if schema .kind == DataRowMetadataKind .enum :
205
233
continue
206
234
elif schema .kind == DataRowMetadataKind .option :
207
235
field = DataRowMetadataField (schema_id = schema .parent ,
208
- value = schema .id )
236
+ value = schema .uid )
209
237
else :
210
- field = DataRowMetadataField (schema_id = schema .id ,
238
+ field = DataRowMetadataField (schema_id = schema .uid ,
211
239
value = f ["value" ])
212
240
213
241
fields .append (field )
@@ -254,7 +282,7 @@ def _batch_upsert(
254
282
}
255
283
}
256
284
}"""
257
- res = self .client .execute (
285
+ res = self ._client .execute (
258
286
query , {"metadata" : upserts })['upsertDataRowCustomMetadata' ]
259
287
return [
260
288
DataRowMetadataBatchResponse (data_row_id = r ['dataRowId' ],
@@ -265,6 +293,10 @@ def _batch_upsert(
265
293
266
294
items = []
267
295
for m in metadata :
296
+ if len (m .fields ) > _MAX_METADATA_FIELDS :
297
+ raise ValueError (
298
+ f"Cannot upload { len (m .fields )} , the max number is { _MAX_METADATA_FIELDS } "
299
+ )
268
300
items .append (
269
301
_UpsertBatchDataRowMetadata (
270
302
data_row_id = m .data_row_id ,
@@ -317,7 +349,7 @@ def _batch_delete(
317
349
}
318
350
}
319
351
"""
320
- res = self .client .execute (
352
+ res = self ._client .execute (
321
353
query , {"deletes" : deletes })['deleteDataRowCustomMetadata' ]
322
354
failures = []
323
355
for dr in res :
@@ -360,7 +392,7 @@ def _bulk_export(_data_row_ids: List[str]) -> List[DataRowMetadata]:
360
392
}
361
393
"""
362
394
return self .parse_metadata (
363
- self .client .execute (
395
+ self ._client .execute (
364
396
query ,
365
397
{"dataRowIds" : _data_row_ids })['dataRowCustomMetadata' ])
366
398
@@ -373,11 +405,11 @@ def _parse_upsert(
373
405
) -> List [_UpsertDataRowMetadataInput ]:
374
406
"""Format for metadata upserts to GQL"""
375
407
376
- if metadatum .schema_id not in self .all_fields_id_index :
408
+ if metadatum .schema_id not in self .fields_by_id :
377
409
raise ValueError (
378
410
f"Schema Id `{ metadatum .schema_id } ` not found in ontology" )
379
411
380
- schema = self .all_fields_id_index [metadatum .schema_id ]
412
+ schema = self .fields_by_id [metadatum .schema_id ]
381
413
382
414
if schema .kind == DataRowMetadataKind .datetime :
383
415
parsed = _validate_parse_datetime (metadatum )
@@ -388,7 +420,7 @@ def _parse_upsert(
388
420
elif schema .kind == DataRowMetadataKind .enum :
389
421
parsed = _validate_enum_parse (schema , metadatum )
390
422
elif schema .kind == DataRowMetadataKind .option :
391
- raise ValueError ("An option id should not be as a schema id" )
423
+ raise ValueError ("An Option id should not be set as the Schema id" )
392
424
else :
393
425
raise ValueError (f"Unknown type: { schema } " )
394
426
@@ -400,16 +432,16 @@ def _validate_delete(self, delete: DeleteDataRowMetadata):
400
432
401
433
deletes = set ()
402
434
for schema_id in delete .fields :
403
- if schema_id not in self .all_fields_id_index :
435
+ if schema_id not in self .fields_by_id :
404
436
raise ValueError (
405
437
f"Schema Id `{ schema_id } ` not found in ontology" )
406
438
407
- schema = self .all_fields_id_index [schema_id ]
439
+ schema = self .fields_by_id [schema_id ]
408
440
# handle users specifying enums by adding all option enums
409
441
if schema .kind == DataRowMetadataKind .enum :
410
- [deletes .add (o .id ) for o in schema .options ]
442
+ [deletes .add (o .uid ) for o in schema .options ]
411
443
412
- deletes .add (schema .id )
444
+ deletes .add (schema .uid )
413
445
414
446
return _DeleteBatchDataRowMetadata (
415
447
data_row_id = delete .data_row_id ,
@@ -458,7 +490,7 @@ def _validate_enum_parse(
458
490
schema : DataRowMetadataSchema ,
459
491
field : DataRowMetadataField ) -> List [Dict [str , Union [SchemaId , dict ]]]:
460
492
if schema .options :
461
- if field .value not in {o .id for o in schema .options }:
493
+ if field .value not in {o .uid for o in schema .options }:
462
494
raise ValueError (
463
495
f"Option `{ field .value } ` not found for { field .schema_id } " )
464
496
else :
0 commit comments