3
3
from dataclasses import dataclass , asdict
4
4
from typing import Optional
5
5
6
+ from pydantic import field_validator
7
+
6
8
from gempy_engine .core .data .kernel_classes .kernel_functions import AvailableKernelFunctions
7
9
from gempy_engine .core .data .kernel_classes .solvers import Solvers
8
10
9
11
10
12
@dataclass (frozen = False )
11
13
class KernelOptions :
12
- range : int | float # TODO: have constructor from RegularGrid
14
+ range : int | float # TODO: have constructor from RegularGrid
13
15
c_o : float # TODO: This should be a property
14
16
uni_degree : int = 1
15
17
i_res : float = 4.
@@ -23,6 +25,24 @@ class KernelOptions:
23
25
optimizing_condition_number : bool = False
24
26
condition_number : Optional [float ] = None
25
27
28
+ @field_validator ('kernel_function' , mode = 'before' )
29
+ @classmethod
30
+ def _deserialize_kernel_function_from_name (cls , value ):
31
+ """
32
+ Ensures that a string input (e.g., "cubic" from JSON)
33
+ is correctly converted to an AvailableKernelFunctions enum member.
34
+ """
35
+ if isinstance (value , str ):
36
+ try :
37
+ return AvailableKernelFunctions [value ] # Lookup enum member by name
38
+ except KeyError :
39
+ # This provides a more specific error if the name doesn't exist
40
+ valid_names = [member .name for member in AvailableKernelFunctions ]
41
+ raise ValueError (f"Invalid kernel function name '{ value } '. Must be one of: { valid_names } " )
42
+ # If it's already an AvailableKernelFunctions member (e.g., during direct model instantiation),
43
+ # or if it's another type that Pydantic's later validation will catch as an error.
44
+ return value
45
+
26
46
@property
27
47
def n_uni_eq (self ):
28
48
if self .uni_degree == 1 :
@@ -66,16 +86,16 @@ def update_options(self, **kwargs):
66
86
def __hash__ (self ):
67
87
# Using a tuple to hash all the values together
68
88
return hash ((
69
- self .range ,
70
- self .c_o ,
71
- self .uni_degree ,
72
- self .i_res ,
73
- self .gi_res ,
74
- self .number_dimensions ,
75
- self .kernel_function ,
76
- self .compute_condition_number ,
89
+ self .range ,
90
+ self .c_o ,
91
+ self .uni_degree ,
92
+ self .i_res ,
93
+ self .gi_res ,
94
+ self .number_dimensions ,
95
+ self .kernel_function ,
96
+ self .compute_condition_number ,
77
97
))
78
-
98
+
79
99
def __repr__ (self ):
80
100
return f"KernelOptions({ ', ' .join (f'{ k } ={ v } ' for k , v in asdict (self ).items ())} )"
81
101
0 commit comments