Skip to content

Commit 4e0103f

Browse files
authored
[ENH] Convert InterpolationOptions to Pydantic model (#12)
# Getting the first example of InterpolationOptions serializable This PR converts the `InterpolationOptions` class from a dataclass to a Pydantic model to enable serialization. The changes include: - Replaced dataclass with Pydantic's `BaseModel` for `InterpolationOptions` - Converted the constructor to a class method `from_args()` to maintain backward compatibility - Added proper type annotations and Pydantic Field definitions - Fixed a minor type issue in `octree_curvature_threshold` by adding a decimal point - Added proper model configuration for enum handling - Made `TempInterpolationValues` a proper dataclass These changes are the first step toward making the model serializable while maintaining the existing functionality.
2 parents e33d627 + 2a78b46 commit 4e0103f

File tree

17 files changed

+120
-71
lines changed

17 files changed

+120
-71
lines changed

gempy_engine/API/server/main_server_pro.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
# Default interpolation options
2929
range_ = 1
30-
default_interpolation_options: InterpolationOptions = InterpolationOptions(
30+
default_interpolation_options: InterpolationOptions = InterpolationOptions.from_args(
3131
range=range_,
3232
c_o=(range_ ** 2) / 14 / 3,
3333
number_octree_levels=4,

gempy_engine/core/data/options/evaluation_options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class MeshExtractionMaskingOptions(enum.Enum):
1414
class EvaluationOptions:
1515
_number_octree_levels: int = 1
1616
_number_octree_levels_surface: int = 4
17-
octree_curvature_threshold: float = -1 #: Threshold to do octree refinement due to curvature to deal with angular geometries. This curvature assumes that 1 is the maximum curvature of any voxel
17+
octree_curvature_threshold: float = -1. #: Threshold to do octree refinement due to curvature to deal with angular geometries. This curvature assumes that 1 is the maximum curvature of any voxel
1818
octree_error_threshold: float = 1. #: Number of standard deviations to consider a voxel as candidate to refine
1919
octree_min_level: int = 2
2020

gempy_engine/core/data/options/interpolation_options.py

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import enum
22
import warnings
3-
from dataclasses import dataclass, asdict, field
3+
4+
from pydantic import BaseModel, ConfigDict, Field, model_validator, PrivateAttr
45

56
import gempy_engine.config
67
from .evaluation_options import MeshExtractionMaskingOptions, EvaluationOptions
@@ -10,40 +11,50 @@
1011
from ..raw_arrays_solution import RawArraysSolution
1112

1213

13-
@dataclass
14-
class InterpolationOptions:
15-
__slots__ = ['kernel_options', 'evaluation_options', 'temp_interpolation_values', 'debug',
16-
'cache_mode', 'cache_model_name', 'block_solutions_type', 'sigmoid_slope']
17-
14+
class InterpolationOptions(BaseModel):
1815
class CacheMode(enum.Enum):
1916
""" Cache mode for the interpolation"""
2017
NO_CACHE: int = enum.auto() #: No cache at all even during the interpolation computation. This is quite expensive for no good reason.
2118
CACHE = enum.auto()
2219
IN_MEMORY_CACHE = enum.auto()
2320
CLEAR_CACHE = enum.auto()
2421

22+
model_config = ConfigDict(
23+
arbitrary_types_allowed=False,
24+
use_enum_values=False,
25+
json_encoders={
26+
CacheMode: lambda e: e.value,
27+
AvailableKernelFunctions: lambda e: e.name
28+
}
29+
)
30+
2531
# @off
26-
kernel_options: KernelOptions # * This is the compression of the fields above and the way to go in the future
27-
evaluation_options: EvaluationOptions
28-
temp_interpolation_values: TempInterpolationValues
32+
kernel_options: KernelOptions = Field(init=True, exclude=False) # * This is the compression of the fields above and the way to go in the future
33+
evaluation_options: EvaluationOptions = Field(init=True, exclude= False)
2934

3035
debug: bool
3136
cache_mode: CacheMode
3237
cache_model_name: str # : Model name for the cache
33-
3438
block_solutions_type: RawArraysSolution.BlockSolutionType
35-
3639
sigmoid_slope: int
37-
3840
debug_water_tight: bool = False
3941

40-
def __init__(
41-
self,
42+
# region Volatile
43+
temp_interpolation_values: TempInterpolationValues = Field(
44+
default_factory=TempInterpolationValues,
45+
exclude=True
46+
)
47+
48+
# endregion
49+
50+
@classmethod
51+
def from_args(
52+
cls,
4253
range: int | float,
4354
c_o: float,
4455
uni_degree: int = 1,
45-
i_res: float = 4,
46-
gi_res: float = 2, # ! This should be DEP
56+
i_res: float = 4.,
57+
gi_res: float = 2., # ! This should be DEP
4758
number_dimensions: int = 3, # ? This probably too
4859
number_octree_levels: int = 1,
4960
kernel_function: AvailableKernelFunctions = AvailableKernelFunctions.cubic,
@@ -52,7 +63,7 @@ def __init__(
5263
compute_condition_number: bool = False,
5364
):
5465

55-
self.kernel_options = KernelOptions(
66+
kernel_options = KernelOptions(
5667
range=range,
5768
c_o=c_o,
5869
uni_degree=uni_degree,
@@ -63,7 +74,7 @@ def __init__(
6374
compute_condition_number=compute_condition_number
6475
)
6576

66-
self.evaluation_options = EvaluationOptions(
77+
evaluation_options = EvaluationOptions(
6778
_number_octree_levels=number_octree_levels,
6879
_number_octree_levels_surface=4,
6980
mesh_extraction=mesh_extraction,
@@ -73,18 +84,30 @@ def __init__(
7384

7485
)
7586

76-
self.temp_interpolation_values = TempInterpolationValues()
77-
self.debug = gempy_engine.config.DEBUG_MODE
78-
self.cache_mode = InterpolationOptions.CacheMode.IN_MEMORY_CACHE
79-
self.cache_model_name = ""
80-
self.block_solutions_type = RawArraysSolution.BlockSolutionType.OCTREE
81-
self.sigmoid_slope = 5_000_000
87+
temp_interpolation_values = TempInterpolationValues()
88+
debug = gempy_engine.config.DEBUG_MODE
89+
cache_mode = InterpolationOptions.CacheMode.IN_MEMORY_CACHE
90+
cache_model_name = ""
91+
block_solutions_type = RawArraysSolution.BlockSolutionType.OCTREE
92+
sigmoid_slope = 5_000_000
93+
94+
return InterpolationOptions(
95+
kernel_options=kernel_options,
96+
evaluation_options=evaluation_options,
97+
# temp_interpolation_values=temp_interpolation_values,
98+
debug=debug,
99+
cache_mode=cache_mode,
100+
cache_model_name=cache_model_name,
101+
block_solutions_type=block_solutions_type,
102+
sigmoid_slope=sigmoid_slope,
103+
debug_water_tight=False,
104+
)
82105

83106
# @on
84107

85108
@classmethod
86109
def init_octree_options(cls, range=1.7, c_o=10, refinement: int = 1):
87-
return InterpolationOptions(
110+
return InterpolationOptions.from_args(
88111
range=range,
89112
c_o=c_o,
90113
mesh_extraction=True,
@@ -93,7 +116,7 @@ def init_octree_options(cls, range=1.7, c_o=10, refinement: int = 1):
93116

94117
@classmethod
95118
def init_dense_grid_options(cls):
96-
options = InterpolationOptions(
119+
options = InterpolationOptions.from_args(
97120
range=1.7,
98121
c_o=10,
99122
mesh_extraction=False,
@@ -107,17 +130,17 @@ def probabilistic_options(cls):
107130
# TODO: This should have the sigmoid slope different
108131
raise NotImplementedError("Probabilistic interpolation is not yet implemented.")
109132

110-
def __repr__(self):
111-
return f"InterpolationOptions({', '.join(f'{k}={v}' for k, v in asdict(self).items())})"
112-
113-
def _repr_html_(self):
114-
html = f"""
115-
<table>
116-
<tr><td colspan='2' style='text-align:center'><b>InterpolationOptions</b></td></tr>
117-
{''.join(f'<tr><td>{k}</td><td>{v._repr_html_() if isinstance(v, KernelOptions) else v}</td></tr>' for k, v in asdict(self).items())}
118-
</table>
119-
"""
120-
return html
133+
# def __repr__(self):
134+
# return f"InterpolationOptions.from_args({', '.join(f'{k}={v}' for k, v in asdict(self).items())})"
135+
136+
# def _repr_html_(self):
137+
# html = f"""
138+
# <table>
139+
# <tr><td colspan='2' style='text-align:center'><b>InterpolationOptions</b></td></tr>
140+
# {''.join(f'<tr><td>{k}</td><td>{v._repr_html_() if isinstance(v, KernelOptions) else v}</td></tr>' for k, v in asdict(self).items())}
141+
# </table>
142+
# """
143+
# return html
121144

122145
def update_options(self, **kwargs):
123146
"""

gempy_engine/core/data/options/kernel_options.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import warnings
22

33
from dataclasses import dataclass, asdict
4+
from typing import Optional
5+
6+
from pydantic import field_validator
47

58
from gempy_engine.core.data.kernel_classes.kernel_functions import AvailableKernelFunctions
69
from gempy_engine.core.data.kernel_classes.solvers import Solvers
710

811

912
@dataclass(frozen=False)
1013
class KernelOptions:
11-
range: int # TODO: have constructor from RegularGrid
14+
range: int | float # TODO: have constructor from RegularGrid
1215
c_o: float # TODO: This should be a property
1316
uni_degree: int = 1
1417
i_res: float = 4.
@@ -20,7 +23,25 @@ class KernelOptions:
2023

2124
compute_condition_number: bool = False
2225
optimizing_condition_number: bool = False
23-
condition_number: float = None
26+
condition_number: Optional[float] = None
27+
28+
@field_validator('kernel_function', mode='before')
29+
@classmethod
30+
def _deserialize_kernel_function_from_name(cls, value):
31+
"""
32+
Ensures that a string input (e.g., "cubic" from JSON)
33+
is correctly converted to an AvailableKernelFunctions enum member.
34+
"""
35+
if isinstance(value, str):
36+
try:
37+
return AvailableKernelFunctions[value] # Lookup enum member by name
38+
except KeyError:
39+
# This provides a more specific error if the name doesn't exist
40+
valid_names = [member.name for member in AvailableKernelFunctions]
41+
raise ValueError(f"Invalid kernel function name '{value}'. Must be one of: {valid_names}")
42+
# If it's already an AvailableKernelFunctions member (e.g., during direct model instantiation),
43+
# or if it's another type that Pydantic's later validation will catch as an error.
44+
return value
2445

2546
@property
2647
def n_uni_eq(self):
@@ -65,16 +86,16 @@ def update_options(self, **kwargs):
6586
def __hash__(self):
6687
# Using a tuple to hash all the values together
6788
return hash((
68-
self.range,
69-
self.c_o,
70-
self.uni_degree,
71-
self.i_res,
72-
self.gi_res,
73-
self.number_dimensions,
74-
self.kernel_function,
75-
self.compute_condition_number,
89+
self.range,
90+
self.c_o,
91+
self.uni_degree,
92+
self.i_res,
93+
self.gi_res,
94+
self.number_dimensions,
95+
self.kernel_function,
96+
self.compute_condition_number,
7697
))
77-
98+
7899
def __repr__(self):
79100
return f"KernelOptions({', '.join(f'{k}={v}' for k, v in asdict(self).items())})"
80101

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass
15
class TempInterpolationValues:
26
current_octree_level: int = 0 # * Make this a read only property

gempy_engine/modules/kernel_constructor/_pykeops_cov_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def simple_model_2():
4040
ori_i = Orientations(dip_positions, nugget_effect_grad)
4141

4242
range = 5 ** 2
43-
kri = InterpolationOptions(range, 1, 0, i_res=1, gi_res=1,
43+
kri = InterpolationOptions.from_args(range, 1, 0, i_res=1, gi_res=1,
4444
number_dimensions=2)
4545

4646
_ = np.ones(3)

requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
numpy
2+
pydantic
23
python-dotenv

tests/benchmark/one_fault_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def one_fault_model():
9393
range_ = 7 ** 2 # ? Since we are not getting the square root should we also square this?
9494
c_o = 1
9595

96-
options = InterpolationOptions(
96+
options = InterpolationOptions.from_args(
9797
range_, c_o,
9898
uni_degree=1,
9999
number_dimensions=3,

tests/fixtures/complex_geometries.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def one_fault_model():
7777
range_ = 7 ** 2 # ? Since we are not getting the square root should we also square this?
7878
c_o = 1
7979

80-
options = InterpolationOptions(
80+
options = InterpolationOptions.from_args(
8181
range_, c_o,
8282
uni_degree=1,
8383
number_dimensions=3,
@@ -144,7 +144,7 @@ def one_finite_fault_model():
144144
range_ = 7 ** 2 # ? Since we are not getting the square root should we also square this?
145145
c_o = 1
146146

147-
options = InterpolationOptions(
147+
options = InterpolationOptions.from_args(
148148
range_, c_o,
149149
uni_degree=1,
150150
number_dimensions=3,
@@ -211,7 +211,7 @@ def graben_fault_model():
211211
range_ = 7 ** 2 # ? Since we are not getting the square root should we also square this?
212212
c_o = 1
213213

214-
options = InterpolationOptions(
214+
options = InterpolationOptions.from_args(
215215
range_, c_o,
216216
uni_degree=1,
217217
number_dimensions=3,

tests/fixtures/heavy_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def moureze_model_factory(path_to_root: str, pick_every=8, octree_lvls=3, solver
102102
# endregion
103103

104104
# region InterpolationOptions
105-
interpolation_options: InterpolationOptions = InterpolationOptions(
105+
interpolation_options: InterpolationOptions = InterpolationOptions.from_args(
106106
range=100.,
107107
c_o=10.,
108108
number_octree_levels=octree_lvls,

0 commit comments

Comments
 (0)