Skip to content

Commit 6a60326

Browse files
author
gdj0nes
committed
FIX: type conversion
1 parent c2282f9 commit 6a60326

File tree

2 files changed

+38
-14
lines changed

2 files changed

+38
-14
lines changed

labelbox/schema/data_row_metadata.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def id(self):
4646
OptionId: Type[SchemaId] = SchemaId # enum option
4747
Number: Type[float] = float
4848

49-
DataRowMetadataValue = Union[Embedding, DateTime, String, OptionId, Number]
49+
DataRowMetadataValue = Union[Embedding, Number, DateTime, String, OptionId]
50+
# primitives used in uploads
51+
_DataRowMetadataValuePrimitives = Union[str, List, dict, float]
5052

5153

5254
class _CamelCaseMixin(BaseModel):
@@ -84,7 +86,7 @@ class DataRowMetadataBatchResponse(_CamelCaseMixin):
8486
# Bulk upsert values
8587
class _UpsertDataRowMetadataInput(_CamelCaseMixin):
8688
schema_id: str
87-
value: Union[str, List, dict]
89+
value: _DataRowMetadataValuePrimitives
8890

8991

9092
# Batch of upsert values for a datarow
@@ -121,27 +123,24 @@ def __init__(self, client):
121123

122124
self._raw_ontology = self._get_ontology()
123125

126+
def _build_ontology(self):
124127
# all fields
125-
self.fields = self._parse_ontology()
128+
self.fields = self._parse_ontology(self._raw_ontology)
126129
self.fields_by_id = self._make_id_index(self.fields)
127130

128131
# reserved fields
129132
self.reserved_fields: List[DataRowMetadataSchema] = [
130133
f for f in self.fields if f.reserved
131134
]
132135
self.reserved_by_id = self._make_id_index(self.reserved_fields)
133-
self.reserved_by_name: Dict[str, DataRowMetadataSchema] = {
134-
f.name: f for f in self.reserved_fields
135-
}
136+
self.reserved_by_name: Dict[str, DataRowMetadataSchema] = self._make_name_index(self.reserved_fields)
136137

137138
# custom fields
138139
self.custom_fields: List[DataRowMetadataSchema] = [
139140
f for f in self.fields if not f.reserved
140141
]
141142
self.custom_by_id = self._make_id_index(self.custom_fields)
142-
self.custom_by_name: Dict[str, DataRowMetadataSchema] = {
143-
f.name: f for f in self.custom_fields
144-
}
143+
self.custom_by_name: Dict[str, DataRowMetadataSchema] = self._make_name_index(self.custom_fields)
145144

146145
@staticmethod
147146
def _make_name_index(fields: List[DataRowMetadataSchema]):
@@ -184,9 +183,10 @@ def _get_ontology(self) -> List[Dict[str, Any]]:
184183
"""
185184
return self._client.execute(query)["customMetadataOntology"]
186185

187-
def _parse_ontology(self) -> List[DataRowMetadataSchema]:
186+
@staticmethod
187+
def _parse_ontology(raw_ontology) -> List[DataRowMetadataSchema]:
188188
fields = []
189-
for schema in self._raw_ontology:
189+
for schema in raw_ontology:
190190
schema["uid"] = schema.pop("id")
191191
options = None
192192
if schema.get("options"):

tests/integration/test_data_row_metadata.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,34 @@
1515
TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh"
1616
CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb"
1717

18+
FAKE_NUMBER_FIELD = {
19+
"id": FAKE_SCHEMA_ID,
20+
"name": "number",
21+
"kind": 'CustomMetadataNumber',
22+
"reserved": False
23+
}
24+
25+
"""
26+
customMetadataOntology {
27+
id
28+
name
29+
kind
30+
reserved
31+
options {
32+
id
33+
kind
34+
name
35+
reserved
36+
}
37+
}}"""
38+
1839

1940
@pytest.fixture
2041
def mdo(client):
21-
yield client.get_data_row_metadata_ontology()
42+
mdo = client.get_data_row_metadata_ontology()
43+
mdo._raw_ontology.append(FAKE_NUMBER_FIELD)
44+
mdo._build_ontology()
45+
yield mdo
2246

2347

2448
@pytest.fixture
@@ -229,7 +253,7 @@ def test_parse_raw_metadata(mdo):
229253
'schemaId': 'cko8sdzv70006h2dk8jg64zvb',
230254
'value': '2021-07-20T21:41:14.606710Z'
231255
}, {
232-
'schemaId': 'cko8sdzv70006h2dk8jg64zvb',
256+
'schemaId': FAKE_SCHEMA_ID,
233257
'value': 0.5
234258
},
235259
]
@@ -243,4 +267,4 @@ def test_parse_raw_metadata(mdo):
243267

244268
for row in parsed:
245269
for field in row.fields:
246-
assert mdo._parse_upsert(field)
270+
assert mdo._parse_upsert(field)

0 commit comments

Comments
 (0)