Skip to content

Commit 4bb28ae

Browse files
committed
[WIP/ENH/CLN] Improved the pattern to deserialize properties
1 parent d00a887 commit 4bb28ae

File tree

3 files changed

+60
-23
lines changed

3 files changed

+60
-23
lines changed

gempy/core/data/encoders/converters.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,30 @@ def validate_numpy_array(v):
1717
return np.array(v) if v is not None else None
1818

1919

20+
def instantiate_if_necessary(data: dict, key: str, type: type) -> None:
21+
"""
22+
Creates instances of the specified type for a dictionary key if the key exists and its
23+
current type does not match the specified type. This function modifies the dictionary
24+
in place by converting the value associated with the key into an instance of the given
25+
type.
26+
27+
This is typically used when a dictionary contains data that needs to be represented as
28+
objects of a specific class type.
29+
30+
Args:
31+
data (dict): The dictionary containing the key-value pair to inspect and possibly
32+
convert.
33+
key (str): The key in the dictionary whose value should be inspected and converted
34+
if necessary.
35+
type (type): The type to which the value of `key` should be converted, if it is not
36+
already an instance of the type.
37+
38+
Returns:
39+
None
40+
"""
41+
if key in data and not isinstance(data[key], type):
42+
data[key] = type(**data[key])
43+
2044
numpy_array_short_validator = BeforeValidator(validate_numpy_array)
2145

2246
# First, create a context variable

gempy/core/data/geo_model.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,26 @@
11
import datetime
2-
2+
import numpy as np
33
import pprint
44
import warnings
5-
from dataclasses import dataclass, field
6-
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, computed_field, model_validator, field_validator
7-
from typing import Sequence, Optional
8-
9-
import numpy as np
5+
from dataclasses import dataclass
6+
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, computed_field, model_validator, ValidationError
7+
from pydantic.functional_validators import ModelWrapValidatorHandler
8+
from typing import Sequence, Optional, Union
109

10+
from gempy_engine.core.data import InterpolationOptions
1111
from gempy_engine.core.data import Solutions
1212
from gempy_engine.core.data.engine_grid import EngineGrid
1313
from gempy_engine.core.data.geophysics_input import GeophysicsInput
14-
from gempy_engine.core.data.raw_arrays_solution import RawArraysSolution
15-
from gempy_engine.core.data import InterpolationOptions
1614
from gempy_engine.core.data.input_data_descriptor import InputDataDescriptor
1715
from gempy_engine.core.data.interpolation_input import InterpolationInput
16+
from gempy_engine.core.data.raw_arrays_solution import RawArraysSolution
1817
from gempy_engine.core.data.transforms import Transform, GlobalAnisotropy
18+
from .encoders.converters import instantiate_if_necessary
1919
from .encoders.json_geomodel_encoder import encode_numpy_array
20-
20+
from .grid import Grid
2121
from .orientations import OrientationsTable
22-
from .surface_points import SurfacePointsTable
2322
from .structural_frame import StructuralFrame
24-
from .grid import Grid
23+
from .surface_points import SurfacePointsTable
2524
from ...modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
2625

2726
"""
@@ -66,14 +65,15 @@ class GeoModel(BaseModel):
6665
)
6766

6867
meta: GeoModelMeta | None = Field(exclude=False) #: Meta-information about the geological model, like its name, creation and modification dates, and owner.
69-
68+
7069
# BUG: Remove None option for structural frame and meta
7170
structural_frame: Optional[StructuralFrame] | None = Field(exclude=False, default=None) #: The structural information of the geological model.
72-
grid: Grid = Field(exclude=False, default=None) #: The general grid used in the geological model.
71+
grid: Grid = Field(exclude=False, default=None) #: The general grid used in the geological model.
7372

7473
# region GemPy engine data types
7574
_interpolation_options: InterpolationOptions #: The interpolation options provided by the user.
76-
# @computed_field(alias="_interpolation_options")
75+
76+
@computed_field(alias="_interpolation_options")
7777
@property
7878
def interpolation_options(self) -> InterpolationOptions:
7979
self._infer_dense_grid_solution()
@@ -94,13 +94,26 @@ def interpolation_options(self, value):
9494
# endregion
9595
_solutions: Solutions = PrivateAttr(init=False, default=None) #: The computed solutions of the geological model.
9696

97-
def __init__(self, **data):
98-
super().__init__(**data)
99-
100-
key = "_interpolation_options"
101-
if key in data and not isinstance(data[key], InterpolationOptions):
102-
data[key] = InterpolationOptions(**data[key])
103-
self._interpolation_options = data.get("_interpolation_options")
97+
@model_validator(mode='wrap')
98+
@classmethod
99+
def deserialize_properties(cls, data: Union["GeoModel", dict], constructor: ModelWrapValidatorHandler["GeoModel"]) -> "GeoModel":
100+
try:
101+
match data:
102+
case GeoModel():
103+
return data
104+
case dict():
105+
instance: GeoModel = constructor(data)
106+
instantiate_if_necessary(
107+
data=data,
108+
key="_interpolation_options",
109+
type=InterpolationOptions
110+
)
111+
instance._interpolation_options = data.get("_interpolation_options")
112+
return instance
113+
case _:
114+
raise ValidationError
115+
except ValidationError:
116+
raise
104117

105118
@classmethod
106119
def from_args(cls, name: str, structural_frame: StructuralFrame, grid: Grid, interpolation_options: InterpolationOptions):

gempy/core/data/grid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,15 @@ def deserialize_properties(cls, data: Union["Grid", dict], constructor: ModelWra
5858
return data
5959
case dict():
6060
grid: Grid = constructor(data)
61-
grid._active_grids = Grid.GridTypes(data["_active_grids"])
61+
grid._active_grids = Grid.GridTypes(data["active_grids"])
6262
grid._update_values()
6363
return grid
6464
case _:
6565
raise ValidationError
6666
except ValidationError:
6767
raise
6868

69-
@computed_field(alias="_active_grids")
69+
@computed_field(alias="active_grids")
7070
@property
7171
def active_grids(self) -> GridTypes:
7272
return self._active_grids

0 commit comments

Comments
 (0)