|
1 | 1 | import dotenv
|
| 2 | +import numpy as np |
2 | 3 | import os
|
3 |
| - |
4 | 4 | from typing import Optional
|
5 | 5 |
|
6 |
| -import numpy as np |
7 |
| - |
8 | 6 | import gempy_engine
|
9 |
| -from gempy_engine.core.backend_tensor import BackendTensor |
10 | 7 | from gempy.API.gp2_gp3_compatibility.gp3_to_gp2_input import gempy3_to_gempy2
|
11 | 8 | from gempy_engine.config import AvailableBackends
|
| 9 | +from gempy_engine.core.backend_tensor import BackendTensor |
12 | 10 | from gempy_engine.core.data import Solutions
|
13 |
| -from gempy_engine.core.data.interpolation_input import InterpolationInput |
14 | 11 | from .grid_API import set_custom_grid
|
| 12 | +from ..core.data import StructuralGroup |
15 | 13 | from ..core.data.gempy_engine_config import GemPyEngineConfig
|
16 | 14 | from ..core.data.geo_model import GeoModel
|
17 |
| -from ..modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame |
| 15 | +from ..modules.data_manipulation import interpolation_input_from_structural_frame |
| 16 | +from ..modules.optimize_nuggets import nugget_optimizer |
18 | 17 | from ..optional_dependencies import require_gempy_legacy
|
19 | 18 |
|
20 | 19 | dotenv.load_dotenv()
|
@@ -92,91 +91,29 @@ def compute_model_at(gempy_model: GeoModel, at: np.ndarray,
|
92 | 91 | return sol.raw_arrays.custom
|
93 | 92 |
|
94 | 93 |
|
95 |
| -def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10, |
96 |
| - convergence_criteria: float = 1e5): |
| 94 | +def optimize_nuggets(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10, |
| 95 | + convergence_criteria: float = 1e5, only_groups:list[StructuralGroup] | None = None) -> GeoModel: |
| 96 | + """ |
| 97 | + Optimize the nuggets of the interpolation input of the provided model. |
| 98 | + """ |
| 99 | + |
97 | 100 | if engine_config.backend != AvailableBackends.PYTORCH:
|
98 | 101 | raise ValueError(f'Only PyTorch backend is supported for optimization. Received {engine_config.backend}')
|
99 |
| - |
100 |
| - BackendTensor.change_backend_gempy( |
101 |
| - engine_backend=engine_config.backend, |
102 |
| - use_gpu=engine_config.use_gpu, |
103 |
| - dtype=engine_config.dtype |
104 |
| - ) |
105 |
| - |
106 |
| - import torch |
107 |
| - from gempy_engine.core.data.continue_epoch import ContinueEpoch |
108 |
| - interpolation_input: InterpolationInput = interpolation_input_from_structural_frame(geo_model) |
109 |
| - |
110 |
| - geo_model.taped_interpolation_input = interpolation_input |
111 |
| - |
112 |
| - nugget_effect_scalar: torch.Tensor = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar |
113 |
| - |
114 |
| - optimizer = torch.optim.Adam( |
115 |
| - params=[nugget_effect_scalar], |
116 |
| - lr=0.01, |
| 102 | + |
| 103 | + geo_model = nugget_optimizer( |
| 104 | + target_cond_num=convergence_criteria, |
| 105 | + engine_cfg=engine_config, |
| 106 | + model=geo_model, |
| 107 | + max_epochs=max_epochs, |
| 108 | + only_groups=only_groups |
117 | 109 | )
|
118 | 110 |
|
119 |
| - # Optimization loop |
120 |
| - geo_model.interpolation_options.kernel_options.optimizing_condition_number = True |
121 |
| - |
122 |
| - def _check_convergence_criterion(conditional_number: float, condition_number_old: float, conditional_number_target: float = 1e5): |
123 |
| - reached_conditional_target = conditional_number < conditional_number_target |
124 |
| - if reached_conditional_target == False and epoch > 10: |
125 |
| - condition_number_change = torch.abs(conditional_number - condition_number_old) / condition_number_old |
126 |
| - if condition_number_change < 0.01: |
127 |
| - reached_conditional_target = True |
128 |
| - return reached_conditional_target |
129 |
| - |
130 |
| - previous_condition_number = 0 |
131 |
| - for epoch in range(max_epochs): |
132 |
| - optimizer.zero_grad() |
133 |
| - try: |
134 |
| - # geo_model.taped_interpolation_input.grid = geo_model.interpolation_input_copy.grid |
135 |
| - |
136 |
| - gempy_engine.compute_model( |
137 |
| - interpolation_input=geo_model.taped_interpolation_input, |
138 |
| - options=geo_model.interpolation_options, |
139 |
| - data_descriptor=geo_model.input_data_descriptor, |
140 |
| - geophysics_input=geo_model.geophysics_input, |
141 |
| - ) |
142 |
| - except ContinueEpoch: |
143 |
| - # Get absolute values of gradients |
144 |
| - grad_magnitudes = torch.abs(nugget_effect_scalar.grad) |
145 |
| - |
146 |
| - # Get indices of the 10 largest gradients |
147 |
| - grad_magnitudes.size |
148 |
| - |
149 |
| - # * This ignores 90 percent of the gradients |
150 |
| - # To int |
151 |
| - n_values = int(grad_magnitudes.size()[0] * 0.9) |
152 |
| - _, indices = torch.topk(grad_magnitudes, n_values, largest=False) |
153 |
| - |
154 |
| - # Zero out gradients that are not in the top 10 |
155 |
| - mask = torch.ones_like(nugget_effect_scalar.grad) |
156 |
| - mask[indices] = 0 |
157 |
| - nugget_effect_scalar.grad *= mask |
158 |
| - |
159 |
| - # Update the vector |
160 |
| - optimizer.step() |
161 |
| - nugget_effect_scalar.data = nugget_effect_scalar.data.clamp_(min=1e-7) # Replace negative values with 0 |
162 |
| - |
163 |
| - # optimizer.zero_grad() |
164 |
| - # Monitor progress |
165 |
| - if epoch % 1 == 0: |
166 |
| - # print(f"Epoch {epoch}: Condition Number = {condition_number.item()}") |
167 |
| - print(f"Epoch {epoch}") |
168 |
| - |
169 |
| - if _check_convergence_criterion( |
170 |
| - conditional_number=geo_model.interpolation_options.kernel_options.condition_number, |
171 |
| - condition_number_old=previous_condition_number, |
172 |
| - conditional_number_target=convergence_criteria, |
173 |
| - ): |
174 |
| - break |
175 |
| - previous_condition_number = geo_model.interpolation_options.kernel_options.condition_number |
176 |
| - continue |
177 |
| - |
178 |
| - geo_model.interpolation_options.kernel_options.optimizing_condition_number = False |
| 111 | + return geo_model |
179 | 112 |
|
| 113 | +def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10, |
| 114 | + convergence_criteria: float = 1e5): |
| 115 | + |
| 116 | + optimize_nuggets(geo_model, engine_config, max_epochs, convergence_criteria) |
180 | 117 | geo_model.solutions = gempy_engine.compute_model(
|
181 | 118 | interpolation_input=geo_model.taped_interpolation_input,
|
182 | 119 | options=geo_model.interpolation_options,
|
|
0 commit comments