Skip to content

Commit f154a53

Browse files
author
Gareth
authored
Merge pull request #306 from Labelbox/gj/fix-number-type
Gj/fix number type
2 parents dde829d + 4cb3c4b commit f154a53

File tree

3 files changed

+90
-37
lines changed

3 files changed

+90
-37
lines changed

examples/basics/data_row_metadata.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,8 @@
374374
" value=dt,\n",
375375
" ),\n",
376376
" DataRowMetadataField(\n",
377-
" schema_id=mdo.reserved_by_name[\"split\"].uid,\n",
378-
" value=split\n",
377+
" schema_id=split.parent,\n",
378+
" value=split.uid\n",
379379
" ),\n",
380380
" DataRowMetadataField(\n",
381381
" schema_id=mdo.reserved_by_name[\"tag\"].uid,\n",

labelbox/schema/data_row_metadata.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# type: ignore
22
import datetime
33
import warnings
4+
from copy import deepcopy
45
from enum import Enum
56
from itertools import chain
67
from typing import List, Optional, Dict, Union, Callable, Type, Any, Generator
@@ -46,7 +47,9 @@ def id(self):
4647
OptionId: Type[SchemaId] = SchemaId # enum option
4748
Number: Type[float] = float
4849

49-
DataRowMetadataValue = Union[Embedding, DateTime, String, OptionId, Number]
50+
DataRowMetadataValue = Union[Embedding, Number, DateTime, String, OptionId]
51+
# primitives used in uploads
52+
_DataRowMetadataValuePrimitives = Union[str, List, dict, float]
5053

5154

5255
class _CamelCaseMixin(BaseModel):
@@ -59,7 +62,7 @@ class Config:
5962
# Metadata base class
6063
class DataRowMetadataField(_CamelCaseMixin):
6164
schema_id: SchemaId
62-
value: DataRowMetadataValue
65+
value: Any
6366

6467

6568
class DataRowMetadata(_CamelCaseMixin):
@@ -85,7 +88,7 @@ class DataRowMetadataBatchResponse(_CamelCaseMixin):
8588
# Bulk upsert values
8689
class _UpsertDataRowMetadataInput(_CamelCaseMixin):
8790
schema_id: str
88-
value: Union[str, List, dict]
91+
value: Any
8992

9093

9194
# Batch of upsert values for a datarow
@@ -121,28 +124,30 @@ def __init__(self, client):
121124
self._batch_size = 50 # used for uploads and deletes
122125

123126
self._raw_ontology = self._get_ontology()
127+
self._build_ontology()
124128

129+
def _build_ontology(self):
125130
# all fields
126-
self.fields = self._parse_ontology()
131+
self.fields = self._parse_ontology(self._raw_ontology)
127132
self.fields_by_id = self._make_id_index(self.fields)
128133

129134
# reserved fields
130135
self.reserved_fields: List[DataRowMetadataSchema] = [
131136
f for f in self.fields if f.reserved
132137
]
133138
self.reserved_by_id = self._make_id_index(self.reserved_fields)
134-
self.reserved_by_name: Dict[str, DataRowMetadataSchema] = {
135-
f.name: f for f in self.reserved_fields
136-
}
139+
self.reserved_by_name: Dict[
140+
str,
141+
DataRowMetadataSchema] = self._make_name_index(self.reserved_fields)
137142

138143
# custom fields
139144
self.custom_fields: List[DataRowMetadataSchema] = [
140145
f for f in self.fields if not f.reserved
141146
]
142147
self.custom_by_id = self._make_id_index(self.custom_fields)
143-
self.custom_by_name: Dict[str, DataRowMetadataSchema] = {
144-
f.name: f for f in self.custom_fields
145-
}
148+
self.custom_by_name: Dict[
149+
str,
150+
DataRowMetadataSchema] = self._make_name_index(self.custom_fields)
146151

147152
@staticmethod
148153
def _make_name_index(fields: List[DataRowMetadataSchema]):
@@ -151,7 +156,7 @@ def _make_name_index(fields: List[DataRowMetadataSchema]):
151156
if f.options:
152157
index[f.name] = {}
153158
for o in f.options:
154-
index[o.name] = o
159+
index[f.name][o.name] = o
155160
else:
156161
index[f.name] = f
157162
return index
@@ -185,15 +190,17 @@ def _get_ontology(self) -> List[Dict[str, Any]]:
185190
"""
186191
return self._client.execute(query)["customMetadataOntology"]
187192

188-
def _parse_ontology(self) -> List[DataRowMetadataSchema]:
193+
@staticmethod
194+
def _parse_ontology(raw_ontology) -> List[DataRowMetadataSchema]:
189195
fields = []
190-
for schema in self._raw_ontology:
191-
schema["uid"] = schema.pop("id")
196+
copy = deepcopy(raw_ontology)
197+
for schema in copy:
198+
schema["uid"] = schema["id"]
192199
options = None
193200
if schema.get("options"):
194201
options = []
195202
for option in schema["options"]:
196-
option["uid"] = option.pop("id")
203+
option["uid"] = option["id"]
197204
options.append(
198205
DataRowMetadataSchema(**{
199206
**option,
@@ -415,6 +422,8 @@ def _parse_upsert(
415422
parsed = _validate_parse_datetime(metadatum)
416423
elif schema.kind == DataRowMetadataKind.string:
417424
parsed = _validate_parse_text(metadatum)
425+
elif schema.kind == DataRowMetadataKind.number:
426+
parsed = _validate_parse_number(metadatum)
418427
elif schema.kind == DataRowMetadataKind.embedding:
419428
parsed = _validate_parse_embedding(metadatum)
420429
elif schema.kind == DataRowMetadataKind.enum:
@@ -472,6 +481,12 @@ def _validate_parse_embedding(
472481
return [field.dict(by_alias=True)]
473482

474483

484+
def _validate_parse_number(
485+
field: DataRowMetadataField
486+
) -> List[Dict[str, Union[SchemaId, Number]]]:
487+
return [field.dict(by_alias=True)]
488+
489+
475490
def _validate_parse_datetime(
476491
field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]:
477492
# TODO: better validate tzinfo

tests/integration/test_data_row_metadata.py

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from datetime import datetime
21
import time
2+
from datetime import datetime
33

44
import pytest
55

@@ -8,17 +8,29 @@
88
DataRowMetadataOntology
99

1010
FAKE_SCHEMA_ID = "0" * 25
11+
FAKE_DATAROW_ID = "D" * 25
1112
SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal"
1213
TRAIN_SPLIT_ID = "cko8sbscr0003h2dk04w86hof"
1314
TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt"
1415
EMBEDDING_SCHEMA_ID = "ckpyije740000yxdk81pbgjdc"
1516
TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh"
1617
CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb"
1718

19+
FAKE_NUMBER_FIELD = {
20+
"id": FAKE_SCHEMA_ID,
21+
"name": "number",
22+
"kind": 'CustomMetadataNumber',
23+
"reserved": False
24+
}
25+
1826

1927
@pytest.fixture
2028
def mdo(client):
21-
yield client.get_data_row_metadata_ontology()
29+
mdo = client.get_data_row_metadata_ontology()
30+
mdo._raw_ontology = mdo._get_ontology()
31+
mdo._raw_ontology.append(FAKE_NUMBER_FIELD)
32+
mdo._build_ontology()
33+
yield mdo
2234

2335

2436
@pytest.fixture
@@ -67,7 +79,21 @@ def make_metadata(dr_id) -> DataRowMetadata:
6779
def test_get_datarow_metadata_ontology(mdo):
6880
assert len(mdo.fields)
6981
assert len(mdo.reserved_fields)
70-
assert len(mdo.custom_fields) == 0
82+
assert len(mdo.custom_fields) == 1
83+
84+
split = mdo.reserved_by_name["split"]["train"]
85+
86+
assert DataRowMetadata(
87+
data_row_id=FAKE_DATAROW_ID,
88+
fields=[
89+
DataRowMetadataField(
90+
schema_id=mdo.reserved_by_name["captureDateTime"].uid,
91+
value=datetime.utcnow(),
92+
),
93+
DataRowMetadataField(schema_id=split.parent, value=split.uid),
94+
DataRowMetadataField(schema_id=mdo.reserved_by_name["tag"].uid,
95+
value="hello-world"),
96+
])
7197

7298

7399
def test_bulk_upsert_datarow_metadata(datarow, mdo: DataRowMetadataOntology):
@@ -127,7 +153,6 @@ def test_bulk_partial_delete_datarow_metadata(datarow, mdo):
127153

128154

129155
def test_large_bulk_delete_datarow_metadata(big_dataset, mdo):
130-
131156
metadata = []
132157
data_row_ids = [dr.uid for dr in big_dataset.data_rows()]
133158
wait_for_embeddings_svc(data_row_ids, mdo)
@@ -217,23 +242,36 @@ def test_parse_raw_metadata(mdo):
217242
example = {
218243
'dataRowId':
219244
'ckr6kkfx801ui0yrtg9fje8xh',
220-
'fields': [{
221-
'schemaId': 'cko8s9r5v0001h2dk9elqdidh',
222-
'value': 'my-new-message'
223-
}, {
224-
'schemaId': 'cko8sbczn0002h2dkdaxb5kal',
225-
'value': {}
226-
}, {
227-
'schemaId': 'cko8sbscr0003h2dk04w86hof',
228-
'value': {}
229-
}, {
230-
'schemaId': 'cko8sdzv70006h2dk8jg64zvb',
231-
'value': '2021-07-20T21:41:14.606710Z'
232-
}]
245+
'fields': [
246+
{
247+
'schemaId': 'cko8s9r5v0001h2dk9elqdidh',
248+
'value': 'my-new-message'
249+
},
250+
{
251+
'schemaId': 'cko8sbczn0002h2dkdaxb5kal',
252+
'value': {}
253+
},
254+
{
255+
'schemaId': 'cko8sbscr0003h2dk04w86hof',
256+
'value': {}
257+
},
258+
{
259+
'schemaId': 'cko8sdzv70006h2dk8jg64zvb',
260+
'value': '2021-07-20T21:41:14.606710Z'
261+
},
262+
{
263+
'schemaId': FAKE_SCHEMA_ID,
264+
'value': 0.5
265+
},
266+
]
233267
}
234268

235269
parsed = mdo.parse_metadata([example])
236270
assert len(parsed) == 1
237-
row = parsed[0]
238-
assert row.data_row_id == example["dataRowId"]
239-
assert len(row.fields) == 3
271+
for row in parsed:
272+
assert row.data_row_id == example["dataRowId"]
273+
assert len(row.fields) == 4
274+
275+
for row in parsed:
276+
for field in row.fields:
277+
assert mdo._parse_upsert(field)

0 commit comments

Comments
 (0)