1
1
import pprint
2
2
import warnings
3
+ from dataclasses import dataclass
3
4
from enum import Enum , auto
4
- from typing import Optional , Union
5
+ from typing import Optional
5
6
6
7
import numpy as np
7
- from dataclasses import dataclass
8
+ from typing_extensions import Annotated
9
+
10
+ from .encoders .converters import numpy_array_short_validator
8
11
9
12
10
13
class TransformOpsOrder (Enum ):
@@ -13,16 +16,16 @@ class TransformOpsOrder(Enum):
13
16
14
17
15
18
class GlobalAnisotropy (Enum ):
16
- CUBE = auto () # * Transform data to be as close as possible to a cube
17
- NONE = auto () # * Do not transform data
18
- MANUAL = auto () # * Use the user defined transform
19
-
19
+ CUBE = auto () # * Transform data to be as close as possible to a cube
20
+ NONE = auto () # * Do not transform data
21
+ MANUAL = auto () # * Use the user defined transform
22
+
20
23
21
24
@dataclass
22
25
class Transform :
23
- position : np .ndarray
24
- rotation : np .ndarray
25
- scale : np .ndarray
26
+ position : Annotated [ np .ndarray , numpy_array_short_validator ]
27
+ rotation : Annotated [ np .ndarray , numpy_array_short_validator ]
28
+ scale : Annotated [ np .ndarray , numpy_array_short_validator ]
26
29
27
30
_is_default_transform : bool = False
28
31
_cached_pivot : Optional [np .ndarray ] = None
@@ -68,11 +71,10 @@ def from_matrix(cls, matrix: np.ndarray):
68
71
])
69
72
return cls (position , rotation_degrees , scale )
70
73
71
-
72
74
@property
73
75
def cached_pivot (self ):
74
76
return self ._cached_pivot
75
-
77
+
76
78
@cached_pivot .setter
77
79
def cached_pivot (self , pivot : np .ndarray ):
78
80
self ._cached_pivot = pivot
@@ -96,7 +98,7 @@ def from_input_points(cls, surface_points: 'gempy.data.SurfacePointsTable', orie
96
98
97
99
# The scaling factor for each dimension is the inverse of its range
98
100
scaling_factors = 1 / range_coord
99
-
101
+
100
102
# ! Be careful with toy models
101
103
center : np .ndarray = (max_coord + min_coord ) / 2
102
104
return cls (
@@ -127,14 +129,14 @@ def apply_anisotropy(self, anisotropy_type: GlobalAnisotropy, anisotropy_limit:
127
129
)
128
130
else :
129
131
raise NotImplementedError
130
-
132
+
131
133
@staticmethod
132
134
def _adjust_scale_to_limit_ratio (s , anisotropic_limit = np .array ([10 , 10 , 10 ])):
133
135
# Calculate the ratios
134
136
ratios = [
135
- s [0 ] / s [1 ], s [0 ] / s [2 ],
136
- s [1 ] / s [0 ], s [1 ] / s [2 ],
137
- s [2 ] / s [0 ], s [2 ] / s [1 ]
137
+ s [0 ] / s [1 ], s [0 ] / s [2 ],
138
+ s [1 ] / s [0 ], s [1 ] / s [2 ],
139
+ s [2 ] / s [0 ], s [2 ] / s [1 ]
138
140
]
139
141
140
142
# Adjust the scales based on the index of the max ratio
@@ -158,9 +160,9 @@ def _adjust_scale_to_limit_ratio(s, anisotropic_limit=np.array([10, 10, 10])):
158
160
@staticmethod
159
161
def _max_scale_ratio (s ):
160
162
ratios = [
161
- s [0 ] / s [1 ], s [0 ] / s [2 ],
162
- s [1 ] / s [0 ], s [1 ] / s [2 ],
163
- s [2 ] / s [0 ], s [2 ] / s [1 ]
163
+ s [0 ] / s [1 ], s [0 ] / s [2 ],
164
+ s [1 ] / s [0 ], s [1 ] / s [2 ],
165
+ s [2 ] / s [0 ], s [2 ] / s [1 ]
164
166
]
165
167
return max (ratios )
166
168
@@ -223,7 +225,7 @@ def apply(self, points: np.ndarray, transform_op_order: TransformOpsOrder = Tran
223
225
224
226
def scale_points (self , points : np .ndarray ):
225
227
return points * self .scale
226
-
228
+
227
229
def apply_inverse (self , points : np .ndarray , transform_op_order : TransformOpsOrder = TransformOpsOrder .SRT ):
228
230
# * NOTE: to compare with legacy we would have to add 0.5 to the coords
229
231
assert points .shape [1 ] == 3
@@ -233,12 +235,11 @@ def apply_inverse(self, points: np.ndarray, transform_op_order: TransformOpsOrde
233
235
transformed_points = (inv @ homogeneous_points .T ).T
234
236
return transformed_points [:, :3 ]
235
237
236
-
237
238
def apply_with_cached_pivot (self , points : np .ndarray , transform_op_order : TransformOpsOrder = TransformOpsOrder .SRT ):
238
239
if self ._cached_pivot is None :
239
240
raise ValueError ("A pivot must be set before calling this method" )
240
241
return self .apply_with_pivot (points , self ._cached_pivot , transform_op_order )
241
-
242
+
242
243
def apply_inverse_with_cached_pivot (self , points : np .ndarray , transform_op_order : TransformOpsOrder = TransformOpsOrder .SRT ):
243
244
if self ._cached_pivot is None :
244
245
raise ValueError ("A pivot must be set before calling this method" )
@@ -269,7 +270,7 @@ def apply_with_pivot(self, points: np.ndarray, pivot: np.ndarray,
269
270
def apply_inverse_with_pivot (self , points : np .ndarray , pivot : np .ndarray ,
270
271
transform_op_order : TransformOpsOrder = TransformOpsOrder .SRT ):
271
272
assert points .shape [1 ] == 3
272
-
273
+
273
274
# Translation matrices to and from the pivot
274
275
T_to_origin = self ._translation_matrix (- pivot [0 ], - pivot [1 ], - pivot [2 ])
275
276
T_back = self ._translation_matrix (* pivot )
@@ -284,10 +285,10 @@ def apply_inverse_with_pivot(self, points: np.ndarray, pivot: np.ndarray,
284
285
@staticmethod
285
286
def _translation_matrix (tx , ty , tz ):
286
287
return np .array ([
287
- [1 , 0 , 0 , tx ],
288
- [0 , 1 , 0 , ty ],
289
- [0 , 0 , 1 , tz ],
290
- [0 , 0 , 0 , 1 ]
288
+ [1 , 0 , 0 , tx ],
289
+ [0 , 1 , 0 , ty ],
290
+ [0 , 0 , 1 , tz ],
291
+ [0 , 0 , 0 , 1 ]
291
292
])
292
293
293
294
def transform_gradient (self , gradients : np .ndarray , transform_op_order : TransformOpsOrder = TransformOpsOrder .SRT ,
0 commit comments