Skip to content

Commit 3821295

Browse files
authored
[ENH] Add numpy array validator and fix type annotations (#13)
# Improved JSON Serialization and Validation for Numpy Arrays This PR adds proper JSON serialization support for numpy arrays in the Transform class by using Pydantic's Annotated type with a custom validator. The changes include: - Added a new `encoders` module with a `converters.py` file containing a numpy array validator - Updated the `Transform` class to use the new validator for position, rotation, and scale properties - Fixed the `kernel_function` field validator to properly handle JSON schema input - Updated floating point values in `interpolation_options.py` to use explicit decimal notation These changes improve the serialization/deserialization process, particularly when working with JSON data.
2 parents 4e0103f + 6b27ff1 commit 3821295

File tree

5 files changed

+36
-30
lines changed

5 files changed

+36
-30
lines changed

gempy_engine/core/data/encoders/__init__.py

Whitespace-only changes.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import numpy as np
2+
from pydantic import BeforeValidator
3+
4+
numpy_array_short_validator = BeforeValidator(lambda v: np.array(v) if v is not None else None)

gempy_engine/core/data/options/interpolation_options.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def from_args(
106106
# @on
107107

108108
@classmethod
109-
def init_octree_options(cls, range=1.7, c_o=10, refinement: int = 1):
109+
def init_octree_options(cls, range=1.7, c_o=10., refinement: int = 1):
110110
return InterpolationOptions.from_args(
111111
range=range,
112112
c_o=c_o,
@@ -118,7 +118,7 @@ def init_octree_options(cls, range=1.7, c_o=10, refinement: int = 1):
118118
def init_dense_grid_options(cls):
119119
options = InterpolationOptions.from_args(
120120
range=1.7,
121-
c_o=10,
121+
c_o=10.,
122122
mesh_extraction=False,
123123
number_octree_levels=1
124124
)

gempy_engine/core/data/options/kernel_options.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ class KernelOptions:
2525
optimizing_condition_number: bool = False
2626
condition_number: Optional[float] = None
2727

28-
@field_validator('kernel_function', mode='before')
28+
29+
@field_validator('kernel_function', mode='before', json_schema_input_type=str)
2930
@classmethod
3031
def _deserialize_kernel_function_from_name(cls, value):
3132
"""

gempy_engine/core/data/transforms.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import pprint
22
import warnings
3+
from dataclasses import dataclass
34
from enum import Enum, auto
4-
from typing import Optional, Union
5+
from typing import Optional
56

67
import numpy as np
7-
from dataclasses import dataclass
8+
from typing_extensions import Annotated
9+
10+
from .encoders.converters import numpy_array_short_validator
811

912

1013
class TransformOpsOrder(Enum):
@@ -13,16 +16,16 @@ class TransformOpsOrder(Enum):
1316

1417

1518
class GlobalAnisotropy(Enum):
16-
CUBE = auto() # * Transform data to be as close as possible to a cube
17-
NONE = auto() # * Do not transform data
18-
MANUAL = auto() # * Use the user defined transform
19-
19+
CUBE = auto() # * Transform data to be as close as possible to a cube
20+
NONE = auto() # * Do not transform data
21+
MANUAL = auto() # * Use the user defined transform
22+
2023

2124
@dataclass
2225
class Transform:
23-
position: np.ndarray
24-
rotation: np.ndarray
25-
scale: np.ndarray
26+
position: Annotated[np.ndarray, numpy_array_short_validator]
27+
rotation: Annotated[np.ndarray, numpy_array_short_validator]
28+
scale: Annotated[np.ndarray, numpy_array_short_validator]
2629

2730
_is_default_transform: bool = False
2831
_cached_pivot: Optional[np.ndarray] = None
@@ -68,11 +71,10 @@ def from_matrix(cls, matrix: np.ndarray):
6871
])
6972
return cls(position, rotation_degrees, scale)
7073

71-
7274
@property
7375
def cached_pivot(self):
7476
return self._cached_pivot
75-
77+
7678
@cached_pivot.setter
7779
def cached_pivot(self, pivot: np.ndarray):
7880
self._cached_pivot = pivot
@@ -96,7 +98,7 @@ def from_input_points(cls, surface_points: 'gempy.data.SurfacePointsTable', orie
9698

9799
# The scaling factor for each dimension is the inverse of its range
98100
scaling_factors = 1 / range_coord
99-
101+
100102
# ! Be careful with toy models
101103
center: np.ndarray = (max_coord + min_coord) / 2
102104
return cls(
@@ -127,14 +129,14 @@ def apply_anisotropy(self, anisotropy_type: GlobalAnisotropy, anisotropy_limit:
127129
)
128130
else:
129131
raise NotImplementedError
130-
132+
131133
@staticmethod
132134
def _adjust_scale_to_limit_ratio(s, anisotropic_limit=np.array([10, 10, 10])):
133135
# Calculate the ratios
134136
ratios = [
135-
s[0] / s[1], s[0] / s[2],
136-
s[1] / s[0], s[1] / s[2],
137-
s[2] / s[0], s[2] / s[1]
137+
s[0] / s[1], s[0] / s[2],
138+
s[1] / s[0], s[1] / s[2],
139+
s[2] / s[0], s[2] / s[1]
138140
]
139141

140142
# Adjust the scales based on the index of the max ratio
@@ -158,9 +160,9 @@ def _adjust_scale_to_limit_ratio(s, anisotropic_limit=np.array([10, 10, 10])):
158160
@staticmethod
159161
def _max_scale_ratio(s):
160162
ratios = [
161-
s[0] / s[1], s[0] / s[2],
162-
s[1] / s[0], s[1] / s[2],
163-
s[2] / s[0], s[2] / s[1]
163+
s[0] / s[1], s[0] / s[2],
164+
s[1] / s[0], s[1] / s[2],
165+
s[2] / s[0], s[2] / s[1]
164166
]
165167
return max(ratios)
166168

@@ -223,7 +225,7 @@ def apply(self, points: np.ndarray, transform_op_order: TransformOpsOrder = Tran
223225

224226
def scale_points(self, points: np.ndarray):
225227
return points * self.scale
226-
228+
227229
def apply_inverse(self, points: np.ndarray, transform_op_order: TransformOpsOrder = TransformOpsOrder.SRT):
228230
# * NOTE: to compare with legacy we would have to add 0.5 to the coords
229231
assert points.shape[1] == 3
@@ -233,12 +235,11 @@ def apply_inverse(self, points: np.ndarray, transform_op_order: TransformOpsOrde
233235
transformed_points = (inv @ homogeneous_points.T).T
234236
return transformed_points[:, :3]
235237

236-
237238
def apply_with_cached_pivot(self, points: np.ndarray, transform_op_order: TransformOpsOrder = TransformOpsOrder.SRT):
238239
if self._cached_pivot is None:
239240
raise ValueError("A pivot must be set before calling this method")
240241
return self.apply_with_pivot(points, self._cached_pivot, transform_op_order)
241-
242+
242243
def apply_inverse_with_cached_pivot(self, points: np.ndarray, transform_op_order: TransformOpsOrder = TransformOpsOrder.SRT):
243244
if self._cached_pivot is None:
244245
raise ValueError("A pivot must be set before calling this method")
@@ -269,7 +270,7 @@ def apply_with_pivot(self, points: np.ndarray, pivot: np.ndarray,
269270
def apply_inverse_with_pivot(self, points: np.ndarray, pivot: np.ndarray,
270271
transform_op_order: TransformOpsOrder = TransformOpsOrder.SRT):
271272
assert points.shape[1] == 3
272-
273+
273274
# Translation matrices to and from the pivot
274275
T_to_origin = self._translation_matrix(-pivot[0], -pivot[1], -pivot[2])
275276
T_back = self._translation_matrix(*pivot)
@@ -284,10 +285,10 @@ def apply_inverse_with_pivot(self, points: np.ndarray, pivot: np.ndarray,
284285
@staticmethod
285286
def _translation_matrix(tx, ty, tz):
286287
return np.array([
287-
[1, 0, 0, tx],
288-
[0, 1, 0, ty],
289-
[0, 0, 1, tz],
290-
[0, 0, 0, 1]
288+
[1, 0, 0, tx],
289+
[0, 1, 0, ty],
290+
[0, 0, 1, tz],
291+
[0, 0, 0, 1]
291292
])
292293

293294
def transform_gradient(self, gradients: np.ndarray, transform_op_order: TransformOpsOrder = TransformOpsOrder.SRT,

0 commit comments

Comments
 (0)