
Commit d232dd8

Author: Gareth

Merge pull request #304 from Labelbox/gj/fix-number-type
Metadata GA changes

2 parents ed04ce0 + 872ac4d, commit d232dd8

3 files changed: +89 -59 lines changed

examples/basics/data_row_metadata.ipynb

Lines changed: 16 additions & 17 deletions
@@ -87,16 +87,14 @@
     "    DataRowMetadata,\n",
     "    DataRowMetadataField,\n",
     "    DeleteDataRowMetadata,\n",
-    "    DataRowMetadataKind\n",
     ")\n",
     "from sklearn.random_projection import GaussianRandomProjection\n",
+    "import tensorflow as tf\n",
     "import seaborn as sns\n",
-    "from datetime import datetime\n",
-    "from pprint import pprint\n",
     "import tensorflow_hub as hub\n",
+    "from datetime import datetime\n",
     "from tqdm.notebook import tqdm\n",
     "import requests\n",
-    "import tensorflow as tf\n",
     "from pprint import pprint"
    ]
   },
@@ -154,7 +152,7 @@
    "outputs": [],
    "source": [
     "# dictionary access with id\n",
-    "pprint(mdo.all_fields_id_index, indent=2)"
+    "pprint(mdo.fields_by_id, indent=2)"
    ]
   },
   {
@@ -167,7 +165,8 @@
    "outputs": [],
    "source": [
     "# access by name\n",
-    "split_field = mdo.reserved_name_index[\"split\"]"
+    "split_field = mdo.reserved_by_name[\"split\"]\n",
+    "train_field = mdo.reserved_by_name[\"split\"][\"train\"]"
    ]
   },
   {
@@ -191,7 +190,7 @@
   },
    "outputs": [],
    "source": [
-    "tag_field = mdo.reserved_name_index[\"tag\"]"
+    "tag_field = mdo.reserved_by_name[\"tag\"]"
    ]
   },
   {
@@ -286,7 +285,7 @@
    "outputs": [],
    "source": [
     "field = DataRowMetadataField(\n",
-    "    schema_id=mdo.reserved_name_index[\"captureDateTime\"].id, # specify the schema id\n",
+    "    schema_id=mdo.reserved_by_name[\"captureDateTime\"].id, # specify the schema id\n",
     "    value=datetime.now(), # typed inputs\n",
     ")\n",
     "# Completed object ready for upload\n",
@@ -356,11 +355,11 @@
     "    # assign datarows a split\n",
     "    rnd = random.random()\n",
     "    if rnd < test:\n",
-    "        split = \"cko8scbz70005h2dkastwhgqt\"\n",
+    "        split = mdo.reserved_by_name[\"split\"][\"test\"]\n",
     "    elif rnd < valid:\n",
-    "        split = \"cko8sc2yr0004h2dk69aj5x63\"\n",
+    "        split = mdo.reserved_by_name[\"split\"][\"valid\"]\n",
     "    else:\n",
-    "        split = \"cko8sbscr0003h2dk04w86hof\"\n",
+    "        split = mdo.reserved_by_name[\"split\"][\"train\"]\n",
     "    \n",
     "    embeddings.append(list(model(processor(response.content), training=False)[0].numpy()))\n",
     "    dt = datetime.utcnow() \n",
@@ -371,15 +370,15 @@
     "        data_row_id=datarow.uid,\n",
     "        fields=[\n",
     "            DataRowMetadataField(\n",
-    "                schema_id=mdo.reserved_name_index[\"captureDateTime\"].id,\n",
+    "                schema_id=mdo.reserved_by_name[\"captureDateTime\"].uid,\n",
     "                value=dt,\n",
     "            ),\n",
     "            DataRowMetadataField(\n",
-    "                schema_id=mdo.reserved_name_index[\"split\"].id,\n",
+    "                schema_id=mdo.reserved_by_name[\"split\"].uid,\n",
     "                value=split\n",
     "            ),\n",
     "            DataRowMetadataField(\n",
-    "                schema_id=mdo.reserved_name_index[\"tag\"].id,\n",
+    "                schema_id=mdo.reserved_by_name[\"tag\"].uid,\n",
     "                value=message\n",
     "            ),\n",
     "        ]\n",
@@ -438,7 +437,7 @@
     "for md, embd in zip(uploads, projected):\n",
     "    md.fields.append(\n",
     "        DataRowMetadataField(\n",
-    "            schema_id=mdo.reserved_name_index[\"embedding\"].id,\n",
+    "            schema_id=mdo.reserved_by_name[\"embedding\"].uid,\n",
     "            value=embd.tolist(), # convert from numpy to list\n",
     "        ),\n",
     "    )"
@@ -568,7 +567,7 @@
     "fields = []\n",
     "# iterate through the fields you want to delete\n",
     "for field in md.fields:\n",
-    "    schema = mdo.all_fields_id_index[field.schema_id]\n",
+    "    schema = mdo.field_by_index[field.schema_id]\n",
     "    fields.append(field.schema_id)\n",
     "\n",
     "deletes = DeleteDataRowMetadata(\n",
@@ -650,4 +649,4 @@
   },
  "nbformat": 4,
  "nbformat_minor": 5
-}
+}
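
The notebook cells above drop hard-coded schema ids (the "cko8..." strings) in favor of name lookups on the renamed indexes. A minimal sketch of the new pattern, assuming an authenticated client and that the notebook's `mdo` comes from `Client.get_data_row_metadata_ontology()` (not shown in this diff):

    import labelbox

    client = labelbox.Client(api_key="<API_KEY>")  # assumes a valid API key
    mdo = client.get_data_row_metadata_ontology()

    # reserved fields resolve by name instead of hard-coded ids
    split_field = mdo.reserved_by_name["split"]
    tag_field = mdo.reserved_by_name["tag"]

    # the notebook's nested reserved_by_name["split"]["train"] lookup implies a
    # nested name index; an explicit equivalent over the enum's options:
    train_uid = {o.name: o.uid for o in split_field.options}["train"]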

labelbox/schema/data_row_metadata.py

Lines changed: 72 additions & 40 deletions
@@ -1,5 +1,6 @@
 # type: ignore
 import datetime
+import warnings
 from enum import Enum
 from itertools import chain
 from typing import List, Optional, Dict, Union, Callable, Type, Any, Generator
@@ -9,8 +10,11 @@
 from labelbox.schema.ontology import SchemaId
 from labelbox.utils import camel_case
 
+_MAX_METADATA_FIELDS = 5
+
 
 class DataRowMetadataKind(Enum):
+    number = "CustomMetadataNumber"
     datetime = "CustomMetadataDateTime"
     enum = "CustomMetadataEnum"
     string = "CustomMetadataString"
@@ -20,13 +24,18 @@ class DataRowMetadataKind(Enum):
 
 # Metadata schema
 class DataRowMetadataSchema(BaseModel):
-    id: SchemaId
+    uid: SchemaId
     name: constr(strip_whitespace=True, min_length=1, max_length=100)
     reserved: bool
     kind: DataRowMetadataKind
     options: Optional[List["DataRowMetadataSchema"]]
     parent: Optional[SchemaId]
 
+    @property
+    def id(self):
+        warnings.warn("`id` is being deprecated in favor of `uid`")
+        return self.uid
+
 
 DataRowMetadataSchema.update_forward_refs()
 
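Turning `id` into a property keeps old callers working while steering them to `uid`. A quick sketch of the expected behavior, reusing `mdo` from the sketch above:

    import warnings

    schema = mdo.reserved_by_name["tag"]
    print(schema.uid)  # new canonical accessor

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        print(schema.id)  # still resolves, through the deprecation shim
    assert "deprecated" in str(caught[-1].message)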
@@ -35,8 +44,9 @@ class DataRowMetadataSchema(BaseModel):
 DateTime: Type[datetime.datetime] = datetime.datetime  # must be in UTC
 String: Type[str] = constr(max_length=500)
 OptionId: Type[SchemaId] = SchemaId  # enum option
+Number: Type[float] = float
 
-DataRowMetadataValue = Union[Embedding, DateTime, String, OptionId]
+DataRowMetadataValue = Union[Embedding, DateTime, String, OptionId, Number]
 
 
 class _CamelCaseMixin(BaseModel):
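
With `Number` added to `DataRowMetadataValue`, float values now validate for fields of the new `number` kind. A hedged sketch, continuing the session above; the `"confidence"` custom field name is made up, and `bulk_upsert` is assumed to be the public upsert entry point:

    from labelbox.schema.data_row_metadata import (
        DataRowMetadata,
        DataRowMetadataField,
    )

    # hypothetical custom field of kind DataRowMetadataKind.number
    number_field = DataRowMetadataField(
        schema_id=mdo.custom_by_name["confidence"].uid,
        value=0.87,  # a plain float, validated against the new Number type
    )
    upload = DataRowMetadata(data_row_id="<DATA_ROW_ID>", fields=[number_field])
    mdo.bulk_upsert([upload])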
@@ -106,44 +116,59 @@ class DataRowMetadataOntology:
     """
 
     def __init__(self, client):
-        self.client = client
-        self._batch_size = 50
 
-        # TODO: consider making these properties to stay in sync with server
+        self._client = client
+        self._batch_size = 50  # used for uploads and deletes
+
         self._raw_ontology = self._get_ontology()
+
         # all fields
-        self.all_fields = self._parse_ontology()
-        self.all_fields_id_index = self._make_id_index(self.all_fields)
+        self.fields = self._parse_ontology()
+        self.fields_by_id = self._make_id_index(self.fields)
+
         # reserved fields
         self.reserved_fields: List[DataRowMetadataSchema] = [
-            f for f in self.all_fields if f.reserved
+            f for f in self.fields if f.reserved
         ]
-        self.reserved_id_index = self._make_id_index(self.reserved_fields)
-        self.reserved_name_index: Dict[str, DataRowMetadataSchema] = {
+        self.reserved_by_id = self._make_id_index(self.reserved_fields)
+        self.reserved_by_name: Dict[str, DataRowMetadataSchema] = {
             f.name: f for f in self.reserved_fields
         }
+
         # custom fields
         self.custom_fields: List[DataRowMetadataSchema] = [
-            f for f in self.all_fields if not f.reserved
+            f for f in self.fields if not f.reserved
         ]
-        self.custom_id_index = self._make_id_index(self.custom_fields)
-        self.custom_name_index: Dict[str, DataRowMetadataSchema] = {
+        self.custom_by_id = self._make_id_index(self.custom_fields)
+        self.custom_by_name: Dict[str, DataRowMetadataSchema] = {
             f.name: f for f in self.custom_fields
         }
 
+    @staticmethod
+    def _make_name_index(fields: List[DataRowMetadataSchema]):
+        index = {}
+        for f in fields:
+            if f.options:
+                index[f.name] = {}
+                for o in f.options:
+                    index[o.name] = o
+            else:
+                index[f.name] = f
+        return index
+
     @staticmethod
     def _make_id_index(
             fields: List[DataRowMetadataSchema]
     ) -> Dict[SchemaId, DataRowMetadataSchema]:
         index = {}
         for f in fields:
-            index[f.id] = f
+            index[f.uid] = f
             if f.options:
                 for o in f.options:
-                    index[o.id] = o
+                    index[o.uid] = o
         return index
 
-    def _get_ontology(self) -> Dict[str, Any]:
+    def _get_ontology(self) -> List[Dict[str, Any]]:
         query = """query GetMetadataOntologyBetaPyApi {
             customMetadataOntology {
                 id
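
`_make_id_index` now keys on `uid` and flattens enum options into the same map, and `_parse_ontology` (below) stamps each option with its parent's `uid`. The practical effect, continuing the session above:

    # any schemaId seen in an export resolves in a single lookup
    split_field = mdo.reserved_by_name["split"]
    assert mdo.fields_by_id[split_field.uid] is split_field

    for option in split_field.options or []:
        # options are indexed too, and point back at their parent enum
        assert mdo.fields_by_id[option.uid].parent == split_field.uid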
@@ -158,21 +183,24 @@ def _get_ontology(self) -> Dict[str, Any]:
             }
         }}
         """
-        return self.client.execute(query)["customMetadataOntology"]
+        return self._client.execute(query)["customMetadataOntology"]
 
     def _parse_ontology(self) -> List[DataRowMetadataSchema]:
         fields = []
         for schema in self._raw_ontology:
+            schema["uid"] = schema.pop("id")
             options = None
             if schema.get("options"):
-                options = [
-                    DataRowMetadataSchema(**{
-                        **option,
-                        **{
-                            "parent": schema["id"]
-                        }
-                    }) for option in schema["options"]
-                ]
+                options = []
+                for option in schema["options"]:
+                    option["uid"] = option.pop("id")
+                    options.append(
+                        DataRowMetadataSchema(**{
+                            **option,
+                            **{
+                                "parent": schema["uid"]
+                            }
+                        }))
             schema["options"] = options
             fields.append(DataRowMetadataSchema(**schema))
 
@@ -184,7 +212,7 @@ def parse_metadata(
                                        Dict]]]]) -> List[DataRowMetadata]:
         """ Parse metadata responses
 
-        >>> mdo.parse_metadata([datarow.metadata])
+        >>> mdo.parse_metadata([metdata])
 
         Args:
             unparsed: An unparsed metadata export
@@ -200,14 +228,14 @@
         for dr in unparsed:
             fields = []
             for f in dr["fields"]:
-                schema = self.all_fields_id_index[f["schemaId"]]
+                schema = self.fields_by_id[f["schemaId"]]
                 if schema.kind == DataRowMetadataKind.enum:
                     continue
                 elif schema.kind == DataRowMetadataKind.option:
                     field = DataRowMetadataField(schema_id=schema.parent,
-                                                 value=schema.id)
+                                                 value=schema.uid)
                 else:
-                    field = DataRowMetadataField(schema_id=schema.id,
+                    field = DataRowMetadataField(schema_id=schema.uid,
                                                  value=f["value"])
 
                 fields.append(field)
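
`parse_metadata` resolves each exported `schemaId` through `fields_by_id`: enum parents are skipped, and option rows are remapped onto the parent schema with the option's `uid` as the value. A sketch of the input shape, inferred from the loop above (a real export may carry more keys):

    tag_field = mdo.reserved_by_name["tag"]
    option = mdo.reserved_by_name["split"].options[0]

    unparsed = [{
        "dataRowId": "<DATA_ROW_ID>",  # key name assumed from the export format
        "fields": [
            {"schemaId": option.uid, "value": None},  # option -> parent schema + option uid
            {"schemaId": tag_field.uid, "value": "a tag"},  # plain field passes through
        ],
    }]
    parsed = mdo.parse_metadata(unparsed)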
@@ -254,7 +282,7 @@ def _batch_upsert(
                 }
             }
         }"""
-        res = self.client.execute(
+        res = self._client.execute(
             query, {"metadata": upserts})['upsertDataRowCustomMetadata']
         return [
             DataRowMetadataBatchResponse(data_row_id=r['dataRowId'],
@@ -265,6 +293,10 @@
 
         items = []
         for m in metadata:
+            if len(m.fields) > _MAX_METADATA_FIELDS:
+                raise ValueError(
+                    f"Cannot upload {len(m.fields)}, the max number is {_MAX_METADATA_FIELDS}"
+                )
             items.append(
                 _UpsertBatchDataRowMetadata(
                     data_row_id=m.data_row_id,
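
The new guard fails fast on the client when a single data row carries more than `_MAX_METADATA_FIELDS` (5) fields, before any GraphQL call goes out. Sketch of the failure mode, continuing the session above and assuming the guard sits on the public upsert path:

    too_many = [
        DataRowMetadataField(schema_id=mdo.reserved_by_name["tag"].uid,
                             value=f"tag-{i}")
        for i in range(6)  # one past the limit of 5
    ]
    try:
        mdo.bulk_upsert(
            [DataRowMetadata(data_row_id="<DATA_ROW_ID>", fields=too_many)])
    except ValueError as e:
        print(e)  # Cannot upload 6, the max number is 5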
@@ -317,7 +349,7 @@ def _batch_delete(
             }
         }
         """
-        res = self.client.execute(
+        res = self._client.execute(
             query, {"deletes": deletes})['deleteDataRowCustomMetadata']
         failures = []
         for dr in res:
@@ -360,7 +392,7 @@ def _bulk_export(_data_row_ids: List[str]) -> List[DataRowMetadata]:
             }
         }
         """
         return self.parse_metadata(
-            self.client.execute(
+            self._client.execute(
                 query,
                 {"dataRowIds": _data_row_ids})['dataRowCustomMetadata'])
 
@@ -373,11 +405,11 @@ def _parse_upsert(
     ) -> List[_UpsertDataRowMetadataInput]:
         """Format for metadata upserts to GQL"""
 
-        if metadatum.schema_id not in self.all_fields_id_index:
+        if metadatum.schema_id not in self.fields_by_id:
             raise ValueError(
                 f"Schema Id `{metadatum.schema_id}` not found in ontology")
 
-        schema = self.all_fields_id_index[metadatum.schema_id]
+        schema = self.fields_by_id[metadatum.schema_id]
 
         if schema.kind == DataRowMetadataKind.datetime:
             parsed = _validate_parse_datetime(metadatum)
@@ -388,7 +420,7 @@
         elif schema.kind == DataRowMetadataKind.enum:
             parsed = _validate_enum_parse(schema, metadatum)
         elif schema.kind == DataRowMetadataKind.option:
-            raise ValueError("An option id should not be as a schema id")
+            raise ValueError("An Option id should not be set as the Schema id")
         else:
             raise ValueError(f"Unknown type: {schema}")
 
@@ -400,16 +432,16 @@ def _validate_delete(self, delete: DeleteDataRowMetadata):
 
         deletes = set()
         for schema_id in delete.fields:
-            if schema_id not in self.all_fields_id_index:
+            if schema_id not in self.fields_by_id:
                 raise ValueError(
                     f"Schema Id `{schema_id}` not found in ontology")
 
-            schema = self.all_fields_id_index[schema_id]
+            schema = self.fields_by_id[schema_id]
             # handle users specifying enums by adding all option enums
             if schema.kind == DataRowMetadataKind.enum:
-                [deletes.add(o.id) for o in schema.options]
+                [deletes.add(o.uid) for o in schema.options]
 
-            deletes.add(schema.id)
+            deletes.add(schema.uid)
 
         return _DeleteBatchDataRowMetadata(
             data_row_id=delete.data_row_id,
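
`_validate_delete` expands an enum parent into all of its option `uid`s, so a caller can delete by the parent's schema id without knowing which option is currently set. A sketch, continuing the session above and assuming `bulk_delete` is the public entry point:

    from labelbox.schema.data_row_metadata import DeleteDataRowMetadata

    delete = DeleteDataRowMetadata(
        data_row_id="<DATA_ROW_ID>",
        fields=[mdo.reserved_by_name["split"].uid],  # option uids are added automatically
    )
    mdo.bulk_delete([delete])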
@@ -458,7 +490,7 @@ def _validate_enum_parse(
         schema: DataRowMetadataSchema,
         field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, dict]]]:
     if schema.options:
-        if field.value not in {o.id for o in schema.options}:
+        if field.value not in {o.uid for o in schema.options}:
             raise ValueError(
                 f"Option `{field.value}` not found for {field.schema_id}")
     else:
