1
1
# type: ignore
2
2
import datetime
3
3
import warnings
4
+ from copy import deepcopy
4
5
from enum import Enum
5
6
from itertools import chain
6
7
from typing import List , Optional , Dict , Union , Callable , Type , Any , Generator
@@ -46,7 +47,9 @@ def id(self):
46
47
OptionId : Type [SchemaId ] = SchemaId # enum option
47
48
Number : Type [float ] = float
48
49
49
- DataRowMetadataValue = Union [Embedding , DateTime , String , OptionId , Number ]
50
+ DataRowMetadataValue = Union [Embedding , Number , DateTime , String , OptionId ]
51
+ # primitives used in uploads
52
+ _DataRowMetadataValuePrimitives = Union [str , List , dict , float ]
50
53
51
54
52
55
class _CamelCaseMixin (BaseModel ):
@@ -59,7 +62,7 @@ class Config:
59
62
# Metadata base class
60
63
class DataRowMetadataField (_CamelCaseMixin ):
61
64
schema_id : SchemaId
62
- value : DataRowMetadataValue
65
+ value : Any
63
66
64
67
65
68
class DataRowMetadata (_CamelCaseMixin ):
@@ -85,7 +88,7 @@ class DataRowMetadataBatchResponse(_CamelCaseMixin):
85
88
# Bulk upsert values
86
89
class _UpsertDataRowMetadataInput (_CamelCaseMixin ):
87
90
schema_id : str
88
- value : Union [ str , List , dict ]
91
+ value : Any
89
92
90
93
91
94
# Batch of upsert values for a datarow
@@ -121,28 +124,30 @@ def __init__(self, client):
121
124
self ._batch_size = 50 # used for uploads and deletes
122
125
123
126
self ._raw_ontology = self ._get_ontology ()
127
+ self ._build_ontology ()
124
128
129
+ def _build_ontology (self ):
125
130
# all fields
126
- self .fields = self ._parse_ontology ()
131
+ self .fields = self ._parse_ontology (self . _raw_ontology )
127
132
self .fields_by_id = self ._make_id_index (self .fields )
128
133
129
134
# reserved fields
130
135
self .reserved_fields : List [DataRowMetadataSchema ] = [
131
136
f for f in self .fields if f .reserved
132
137
]
133
138
self .reserved_by_id = self ._make_id_index (self .reserved_fields )
134
- self .reserved_by_name : Dict [str , DataRowMetadataSchema ] = {
135
- f . name : f for f in self . reserved_fields
136
- }
139
+ self .reserved_by_name : Dict [
140
+ str ,
141
+ DataRowMetadataSchema ] = self . _make_name_index ( self . reserved_fields )
137
142
138
143
# custom fields
139
144
self .custom_fields : List [DataRowMetadataSchema ] = [
140
145
f for f in self .fields if not f .reserved
141
146
]
142
147
self .custom_by_id = self ._make_id_index (self .custom_fields )
143
- self .custom_by_name : Dict [str , DataRowMetadataSchema ] = {
144
- f . name : f for f in self . custom_fields
145
- }
148
+ self .custom_by_name : Dict [
149
+ str ,
150
+ DataRowMetadataSchema ] = self . _make_name_index ( self . custom_fields )
146
151
147
152
@staticmethod
148
153
def _make_name_index (fields : List [DataRowMetadataSchema ]):
@@ -151,7 +156,7 @@ def _make_name_index(fields: List[DataRowMetadataSchema]):
151
156
if f .options :
152
157
index [f .name ] = {}
153
158
for o in f .options :
154
- index [o .name ] = o
159
+ index [f . name ][ o .name ] = o
155
160
else :
156
161
index [f .name ] = f
157
162
return index
@@ -185,15 +190,17 @@ def _get_ontology(self) -> List[Dict[str, Any]]:
185
190
"""
186
191
return self ._client .execute (query )["customMetadataOntology" ]
187
192
188
- def _parse_ontology (self ) -> List [DataRowMetadataSchema ]:
193
+ @staticmethod
194
+ def _parse_ontology (raw_ontology ) -> List [DataRowMetadataSchema ]:
189
195
fields = []
190
- for schema in self ._raw_ontology :
191
- schema ["uid" ] = schema .pop ("id" )
196
+ copy = deepcopy (raw_ontology )
197
+ for schema in copy :
198
+ schema ["uid" ] = schema ["id" ]
192
199
options = None
193
200
if schema .get ("options" ):
194
201
options = []
195
202
for option in schema ["options" ]:
196
- option ["uid" ] = option . pop ( "id" )
203
+ option ["uid" ] = option [ "id" ]
197
204
options .append (
198
205
DataRowMetadataSchema (** {
199
206
** option ,
@@ -415,6 +422,8 @@ def _parse_upsert(
415
422
parsed = _validate_parse_datetime (metadatum )
416
423
elif schema .kind == DataRowMetadataKind .string :
417
424
parsed = _validate_parse_text (metadatum )
425
+ elif schema .kind == DataRowMetadataKind .number :
426
+ parsed = _validate_parse_number (metadatum )
418
427
elif schema .kind == DataRowMetadataKind .embedding :
419
428
parsed = _validate_parse_embedding (metadatum )
420
429
elif schema .kind == DataRowMetadataKind .enum :
@@ -472,6 +481,12 @@ def _validate_parse_embedding(
472
481
return [field .dict (by_alias = True )]
473
482
474
483
484
+ def _validate_parse_number (
485
+ field : DataRowMetadataField
486
+ ) -> List [Dict [str , Union [SchemaId , Number ]]]:
487
+ return [field .dict (by_alias = True )]
488
+
489
+
475
490
def _validate_parse_datetime (
476
491
field : DataRowMetadataField ) -> List [Dict [str , Union [SchemaId , str ]]]:
477
492
# TODO: better validate tzinfo
0 commit comments