Skip to content

Commit 0fb74e5

Browse files
authored
[ENH] Add nugget optimization module with per-group optimization capability (#1036)
# Description

Refactored the nugget optimization functionality into a dedicated module for better organization and reusability. The PR extracts the optimization logic from `compute_API.py` into a new `optimize_nuggets` module with improved structure and a cleaner implementation. It also adds a new API function, `optimize_nuggets`, that allows optimizing specific structural groups rather than the entire model.

The PR also includes code reorganization: it moves the `_LoadDEMArtificial` class from `create_topography.py` to `topography.py`, where it is more logically placed, and fixes import paths throughout the codebase.

Relates to #optimization-refactoring

# Checklist

- [x] My code uses type hinting for function and method arguments and return values.
- [x] I have created tests which cover my code.
- [x] The test code either
  1. demonstrates at least one valuable use case (e.g. integration tests) or
  2. verifies that outputs are as expected for given inputs (e.g. unit tests).
- [x] New tests pass locally with my changes.
2 parents c115fc1 + f401319 commit 0fb74e5

File tree

17 files changed

+562
-196
lines changed

17 files changed

+562
-196
lines changed

gempy/API/compute_API.py

Lines changed: 23 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
import dotenv
2+
import numpy as np
23
import os
3-
44
from typing import Optional
55

6-
import numpy as np
7-
86
import gempy_engine
9-
from gempy_engine.core.backend_tensor import BackendTensor
107
from gempy.API.gp2_gp3_compatibility.gp3_to_gp2_input import gempy3_to_gempy2
118
from gempy_engine.config import AvailableBackends
9+
from gempy_engine.core.backend_tensor import BackendTensor
1210
from gempy_engine.core.data import Solutions
13-
from gempy_engine.core.data.interpolation_input import InterpolationInput
1411
from .grid_API import set_custom_grid
12+
from ..core.data import StructuralGroup
1513
from ..core.data.gempy_engine_config import GemPyEngineConfig
1614
from ..core.data.geo_model import GeoModel
17-
from ..modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
15+
from ..modules.data_manipulation import interpolation_input_from_structural_frame
16+
from ..modules.optimize_nuggets import nugget_optimizer
1817
from ..optional_dependencies import require_gempy_legacy
1918

2019
dotenv.load_dotenv()
@@ -92,91 +91,29 @@ def compute_model_at(gempy_model: GeoModel, at: np.ndarray,
9291
return sol.raw_arrays.custom
9392

9493

95-
def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10,
96-
convergence_criteria: float = 1e5):
94+
def optimize_nuggets(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10,
95+
convergence_criteria: float = 1e5, only_groups:list[StructuralGroup] | None = None) -> GeoModel:
96+
"""
97+
Optimize the nuggets of the interpolation input of the provided model.
98+
"""
99+
97100
if engine_config.backend != AvailableBackends.PYTORCH:
98101
raise ValueError(f'Only PyTorch backend is supported for optimization. Received {engine_config.backend}')
99-
100-
BackendTensor.change_backend_gempy(
101-
engine_backend=engine_config.backend,
102-
use_gpu=engine_config.use_gpu,
103-
dtype=engine_config.dtype
104-
)
105-
106-
import torch
107-
from gempy_engine.core.data.continue_epoch import ContinueEpoch
108-
interpolation_input: InterpolationInput = interpolation_input_from_structural_frame(geo_model)
109-
110-
geo_model.taped_interpolation_input = interpolation_input
111-
112-
nugget_effect_scalar: torch.Tensor = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar
113-
114-
optimizer = torch.optim.Adam(
115-
params=[nugget_effect_scalar],
116-
lr=0.01,
102+
103+
geo_model = nugget_optimizer(
104+
target_cond_num=convergence_criteria,
105+
engine_cfg=engine_config,
106+
model=geo_model,
107+
max_epochs=max_epochs,
108+
only_groups=only_groups
117109
)
118110

119-
# Optimization loop
120-
geo_model.interpolation_options.kernel_options.optimizing_condition_number = True
121-
122-
def _check_convergence_criterion(conditional_number: float, condition_number_old: float, conditional_number_target: float = 1e5):
123-
reached_conditional_target = conditional_number < conditional_number_target
124-
if reached_conditional_target == False and epoch > 10:
125-
condition_number_change = torch.abs(conditional_number - condition_number_old) / condition_number_old
126-
if condition_number_change < 0.01:
127-
reached_conditional_target = True
128-
return reached_conditional_target
129-
130-
previous_condition_number = 0
131-
for epoch in range(max_epochs):
132-
optimizer.zero_grad()
133-
try:
134-
# geo_model.taped_interpolation_input.grid = geo_model.interpolation_input_copy.grid
135-
136-
gempy_engine.compute_model(
137-
interpolation_input=geo_model.taped_interpolation_input,
138-
options=geo_model.interpolation_options,
139-
data_descriptor=geo_model.input_data_descriptor,
140-
geophysics_input=geo_model.geophysics_input,
141-
)
142-
except ContinueEpoch:
143-
# Get absolute values of gradients
144-
grad_magnitudes = torch.abs(nugget_effect_scalar.grad)
145-
146-
# Get indices of the 10 largest gradients
147-
grad_magnitudes.size
148-
149-
# * This ignores 90 percent of the gradients
150-
# To int
151-
n_values = int(grad_magnitudes.size()[0] * 0.9)
152-
_, indices = torch.topk(grad_magnitudes, n_values, largest=False)
153-
154-
# Zero out gradients that are not in the top 10
155-
mask = torch.ones_like(nugget_effect_scalar.grad)
156-
mask[indices] = 0
157-
nugget_effect_scalar.grad *= mask
158-
159-
# Update the vector
160-
optimizer.step()
161-
nugget_effect_scalar.data = nugget_effect_scalar.data.clamp_(min=1e-7) # Replace negative values with 0
162-
163-
# optimizer.zero_grad()
164-
# Monitor progress
165-
if epoch % 1 == 0:
166-
# print(f"Epoch {epoch}: Condition Number = {condition_number.item()}")
167-
print(f"Epoch {epoch}")
168-
169-
if _check_convergence_criterion(
170-
conditional_number=geo_model.interpolation_options.kernel_options.condition_number,
171-
condition_number_old=previous_condition_number,
172-
conditional_number_target=convergence_criteria,
173-
):
174-
break
175-
previous_condition_number = geo_model.interpolation_options.kernel_options.condition_number
176-
continue
177-
178-
geo_model.interpolation_options.kernel_options.optimizing_condition_number = False
111+
return geo_model
179112

113+
def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10,
114+
convergence_criteria: float = 1e5):
115+
116+
optimize_nuggets(geo_model, engine_config, max_epochs, convergence_criteria)
180117
geo_model.solutions = gempy_engine.compute_model(
181118
interpolation_input=geo_model.taped_interpolation_input,
182119
options=geo_model.interpolation_options,

gempy/core/data/geo_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515
from gempy_engine.core.data.interpolation_input import InterpolationInput
1616
from gempy_engine.core.data.raw_arrays_solution import RawArraysSolution
1717
from gempy_engine.core.data.transforms import Transform, GlobalAnisotropy
18-
from gempy_engine.modules.geophysics.gravity_gradient import calculate_gravity_gradient
1918
from .encoders.converters import instantiate_if_necessary
2019
from .encoders.json_geomodel_encoder import encode_numpy_array
2120
from .grid import Grid
2221
from .orientations import OrientationsTable
2322
from .structural_frame import StructuralFrame
2423
from .surface_points import SurfacePointsTable
25-
from ...modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
24+
from ...modules.data_manipulation import interpolation_input_from_structural_frame
2625

2726

2827
"""
@@ -319,6 +318,7 @@ def deserialize_properties(cls, data: Union["GeoModel", dict], constructor: Mode
319318
# * Reset geophysics if necessary
320319
centered_grid = instance.grid.centered_grid
321320
if centered_grid is not None and instance.geophysics_input is not None:
321+
from gempy_engine.modules.geophysics.gravity_gradient import calculate_gravity_gradient
322322
instance.geophysics_input.tz = calculate_gravity_gradient(centered_grid)
323323

324324
return instance

gempy/core/data/grid_modules/topography.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
import numpy as np
88

99
from .regular_grid import RegularGrid
10-
from ....modules.grids.create_topography import _LoadDEMArtificial
1110

12-
from ....optional_dependencies import require_skimage
11+
from ....optional_dependencies import require_skimage, require_scipy
1312
from dataclasses import field, dataclass
1413
from ..encoders.converters import short_array_type
1514

@@ -256,3 +255,107 @@ def load(self, path):
256255

257256
def load_from_saved(self, *args, **kwargs):
258257
self.load(*args, **kwargs)
258+
259+
260+
class _LoadDEMArtificial: # * Cannot think of a good reason to be a class
261+
262+
def __init__(self, grid=None, fd=2.0, extent=None, resolution=None, d_z=None):
263+
"""Class to create a random topography based on a fractal grid algorithm.
264+
265+
Args:
266+
fd: fractal dimension, defaults to 2.0
267+
d_z: maximum height difference. If none, last 20% of the model in z direction
268+
extent: extent in xy direction. If none, geo_model.grid.extent
269+
resolution: desired resolution of the topography array. If none, geo_model.grid.resolution
270+
"""
271+
self.values_2d = np.array([])
272+
self.resolution = grid.resolution[:2] if resolution is None else resolution
273+
274+
assert all(np.asarray(self.resolution) >= 2), 'The regular grid needs to be at least of size 2 on all directions.'
275+
self.extent = grid.extent if extent is None else extent
276+
277+
if d_z is None:
278+
self.d_z = np.array([self.extent[5] - (self.extent[5] - self.extent[4]) * 1 / 5, self.extent[5]])
279+
print(self.d_z)
280+
else:
281+
self.d_z = d_z
282+
283+
topo = self.fractalGrid(fd, n=self.resolution.max())
284+
topo = np.interp(topo, (topo.min(), topo.max()), self.d_z)
285+
286+
self.dem_zval = topo[:self.resolution[0], :self.resolution[1]] # crop fractal grid with resolution
287+
self.create_topo_array()
288+
289+
@staticmethod
290+
def fractalGrid(fd, n=256):
291+
"""
292+
Modified after https://github.com/samthiele/pycompass/blob/master/examples/3_Synthetic%20Examples.ipynb
293+
294+
Generate isotropic fractal surface image using
295+
spectral synthesis method [1, p.]
296+
References:
297+
1. Yuval Fisher, Michael McGuire,
298+
The Science of Fractal Images, 1988
299+
300+
(cf. http://shortrecipes.blogspot.com.au/2008/11/python-isotropic-fractal-surface.html)
301+
**Arguments**:
302+
-fd = the fractal dimension
303+
-N = the size of the fractal surface/image
304+
305+
"""
306+
h = 1 - (fd - 2)
307+
# X = np.zeros((N, N), complex)
308+
a = np.zeros((n, n), complex)
309+
powerr = -(h + 1.0) / 2.0
310+
311+
for i in range(int(n / 2) + 1):
312+
for j in range(int(n / 2) + 1):
313+
phase = 2 * np.pi * np.random.rand()
314+
315+
if i != 0 or j != 0:
316+
rad = (i * i + j * j) ** powerr * np.random.normal()
317+
else:
318+
rad = 0.0
319+
320+
a[i, j] = complex(rad * np.cos(phase), rad * np.sin(phase))
321+
322+
if i == 0:
323+
i0 = 0
324+
else:
325+
i0 = n - i
326+
327+
if j == 0:
328+
j0 = 0
329+
else:
330+
j0 = n - j
331+
332+
a[i0, j0] = complex(rad * np.cos(phase), -rad * np.sin(phase))
333+
334+
a.imag[int(n / 2)][0] = 0.0
335+
a.imag[0, int(n / 2)] = 0.0
336+
a.imag[int(n / 2)][int(n / 2)] = 0.0
337+
338+
for i in range(1, int(n / 2)):
339+
for j in range(1, int(n / 2)):
340+
phase = 2 * np.pi * np.random.rand()
341+
rad = (i * i + j * j) ** powerr * np.random.normal()
342+
a[i, n - j] = complex(rad * np.cos(phase), rad * np.sin(phase))
343+
a[n - i, j] = complex(rad * np.cos(phase), -rad * np.sin(phase))
344+
345+
scipy = require_scipy()
346+
itemp = scipy.fftpack.ifft2(a)
347+
itemp = itemp - itemp.min()
348+
349+
return itemp.real / itemp.real.max()
350+
351+
def create_topo_array(self):
352+
"""for masking the lith block"""
353+
x = np.linspace(self.extent[0], self.extent[1], self.resolution[0])
354+
y = np.linspace(self.extent[2], self.extent[3], self.resolution[1])
355+
self.x = x
356+
self.y = y
357+
xx, yy = np.meshgrid(x, y, indexing='ij')
358+
self.values_2d = np.dstack([xx, yy, self.dem_zval])
359+
360+
def get_values(self):
361+
return self.values_2d
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from ._engine_factory import interpolation_input_from_structural_frame

gempy/modules/data_manipulation/manipulate_points.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import numpy as np
44

5-
from gempy.core.data import GeoModel, StructuralFrame, SurfacePointsTable, StructuralElement, OrientationsTable
6-
from gempy.core.data.orientations import DEFAULT_ORI_NUGGET
7-
from gempy.core.data.surface_points import DEFAULT_SP_NUGGET
5+
from ...core.data import GeoModel, StructuralFrame, SurfacePointsTable, StructuralElement, OrientationsTable
6+
from ...core.data.orientations import DEFAULT_ORI_NUGGET
7+
from ...core.data.surface_points import DEFAULT_SP_NUGGET
88

99

1010
def add_surface_points(

0 commit comments

Comments
 (0)