Skip to content

Commit 3c74ca0

Browse files
fix(parsing): parse extra field types
1 parent 0a33c47 commit 3c74ca0

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

src/codex/_models.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,14 +208,18 @@ def construct( # pyright: ignore[reportIncompatibleMethodOverride]
208208
else:
209209
fields_values[name] = field_get_default(field)
210210

211+
extra_field_type = _get_extra_fields_type(__cls)
212+
211213
_extra = {}
212214
for key, value in values.items():
213215
if key not in model_fields:
216+
parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value
217+
214218
if PYDANTIC_V2:
215-
_extra[key] = value
219+
_extra[key] = parsed
216220
else:
217221
_fields_set.add(key)
218-
fields_values[key] = value
222+
fields_values[key] = parsed
219223

220224
object.__setattr__(m, "__dict__", fields_values)
221225

@@ -370,6 +374,23 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
370374
return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None))
371375

372376

377+
def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None:
378+
if not PYDANTIC_V2:
379+
# TODO
380+
return None
381+
382+
schema = cls.__pydantic_core_schema__
383+
if schema["type"] == "model":
384+
fields = schema["schema"]
385+
if fields["type"] == "model-fields":
386+
extras = fields.get("extras_schema")
387+
if extras and "cls" in extras:
388+
# mypy can't narrow the type
389+
return extras["cls"] # type: ignore[no-any-return]
390+
391+
return None
392+
393+
373394
def is_basemodel(type_: type) -> bool:
374395
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
375396
if is_union(type_):

tests/test_models.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, Dict, List, Union, Optional, cast
2+
from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, cast
33
from datetime import datetime, timezone
44
from typing_extensions import Literal, Annotated, TypeAliasType
55

@@ -934,3 +934,30 @@ class Type2(BaseModel):
934934
)
935935
assert isinstance(model, Type1)
936936
assert isinstance(model.value, InnerType2)
937+
938+
939+
@pytest.mark.skipif(not PYDANTIC_V2, reason="this is only supported in pydantic v2 for now")
940+
def test_extra_properties() -> None:
941+
class Item(BaseModel):
942+
prop: int
943+
944+
class Model(BaseModel):
945+
__pydantic_extra__: Dict[str, Item] = Field(init=False) # pyright: ignore[reportIncompatibleVariableOverride]
946+
947+
other: str
948+
949+
if TYPE_CHECKING:
950+
951+
def __getattr__(self, attr: str) -> Item: ...
952+
953+
model = construct_type(
954+
type_=Model,
955+
value={
956+
"a": {"prop": 1},
957+
"other": "foo",
958+
},
959+
)
960+
assert isinstance(model, Model)
961+
assert model.a.prop == 1
962+
assert isinstance(model.a, Item)
963+
assert model.other == "foo"

0 commit comments

Comments
 (0)