Skip to content

Commit 2b00acc

Browse files
authored
[ENH] Improve type annotations with short_array_type and field defaults (#14)
# Improved Type Annotations and Serialization Support ## [CLN] Hide temp_interpolation_values from repr output Added `repr=False` to `temp_interpolation_values` to prevent it from being displayed in object representations, improving clarity when debugging or logging. ## [ENH] Introduce short_array_type for consistent type annotations Created a reusable `short_array_type` using `Annotated[np.ndarray, numpy_array_short_validator]` to standardize array validation across the codebase. ## [ENH] Improve CenteredGrid field definitions - Replaced manual field initialization with `field(init=False)` for cleaner dataclass definition - Updated type hints to use the new `short_array_type` and modern Python syntax - Improved type annotations for radius to properly support both float and array types ## [ENH] Enhance FiniteFaultData serialization support - Added proper field definitions with exclusion for non-serializable callable - Implemented error handling for deserialized objects with missing implicit functions - Updated type annotations to use `short_array_type` for consistent validation ## [ENH] Update GeophysicsInput with proper type annotations Replaced BackendTensor references with properly annotated numpy arrays using the validation system, improving type safety and serialization support.
2 parents 8733909 + 4684526 commit 2b00acc

File tree

5 files changed

+35
-25
lines changed

5 files changed

+35
-25
lines changed

gempy_engine/core/data/centered_grid.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
1-
from dataclasses import dataclass
2-
from typing import Sequence, Union
1+
from dataclasses import dataclass, field
32

43
import numpy as np
54

6-
from gempy_engine.core.backend_tensor import BackendTensor
7-
from gempy_engine.core.utils import cast_type_inplace
5+
from .encoders.converters import short_array_type
86

97

108
@dataclass
119
class CenteredGrid:
12-
centers: np.ndarray #: This is just used to calculate xyz to interpolate. Tz is independent
13-
resolution: Sequence[float]
14-
radius: Union[float, Sequence[float]]
10+
centers: short_array_type #: This is just used to calculate xyz to interpolate. Tz is independent
11+
resolution: short_array_type
12+
radius: float | short_array_type
1513

16-
kernel_grid_centers: np.ndarray = None
17-
left_voxel_edges: np.ndarray = None
18-
right_voxel_edges: np.ndarray = None
14+
kernel_grid_centers: np.ndarray = field(init=False)
15+
left_voxel_edges: np.ndarray = field(init=False)
16+
right_voxel_edges: np.ndarray = field(init=False)
1917

2018
def __len__(self):
2119
return self.centers.shape[0] * self.kernel_grid_centers.shape[0]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from typing import Annotated
2+
13
import numpy as np
24
from pydantic import BeforeValidator
35

46
numpy_array_short_validator = BeforeValidator(lambda v: np.array(v) if v is not None else None)
7+
short_array_type = Annotated[np.ndarray, numpy_array_short_validator]
Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from dataclasses import dataclass
2+
from typing import Annotated
23

3-
from ..backend_tensor import BackendTensor
4+
import numpy as np
5+
6+
from .encoders.converters import numpy_array_short_validator
47

58

69
@dataclass
7-
class GeophysicsInput():
8-
tz: BackendTensor.t
9-
densities: BackendTensor.t
10+
class GeophysicsInput:
11+
tz: Annotated[np.ndarray, numpy_array_short_validator]
12+
densities: Annotated[np.ndarray, numpy_array_short_validator]

gempy_engine/core/data/kernel_classes/faults.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,39 @@
11
import dataclasses
2-
from typing import Optional
2+
from typing import Optional, Callable
33

44
import numpy as np
5+
from pydantic import Field
56

6-
from gempy_engine.core.data.transforms import Transform
7+
from ..encoders.converters import short_array_type
8+
from ..transforms import Transform
79

810

911
@dataclasses.dataclass
1012
class FiniteFaultData:
11-
implicit_function: callable
12-
implicit_function_transform: Transform
13-
pivot: np.ndarray
14-
13+
implicit_function: Callable | None = Field(exclude=True, default=None)#, default=None)
14+
implicit_function_transform: Transform = Field()
15+
pivot: short_array_type = Field()
16+
1517
def apply(self, points: np.ndarray) -> np.ndarray:
1618
transformed_points = self.implicit_function_transform.apply_inverse_with_pivot(
1719
points=points,
1820
pivot=self.pivot
1921
)
22+
if self.implicit_function is None:
23+
raise ValueError("No implicit function defined. This can happen after deserializing (loading).")
24+
2025
scalar_block = self.implicit_function(transformed_points)
2126
return scalar_block
2227

2328

2429

2530
@dataclasses.dataclass
2631
class FaultsData:
27-
fault_values_everywhere: np.ndarray = None
28-
fault_values_on_sp: np.ndarray = None
32+
fault_values_everywhere: short_array_type | None = None
33+
fault_values_on_sp: short_array_type | None = None
2934

30-
fault_values_ref: np.ndarray = None
31-
fault_values_rest: np.ndarray = None
35+
fault_values_ref: short_array_type | None = None
36+
fault_values_rest: short_array_type | None = None
3237

3338
# User given data:
3439
thickness: Optional[float] = None

gempy_engine/core/data/options/interpolation_options.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ class CacheMode(enum.Enum):
4242
# region Volatile
4343
temp_interpolation_values: TempInterpolationValues = Field(
4444
default_factory=TempInterpolationValues,
45-
exclude=True
45+
exclude=True,
46+
repr=False
4647
)
4748

4849
# endregion

0 commit comments

Comments
 (0)