Skip to content

Commit 6525b45

Browse files
committed
[BUG] Fixing some refactoring issues
1 parent 4b8a5dd commit 6525b45

File tree

14 files changed

+36
-33
lines changed

14 files changed

+36
-33
lines changed

gempy_engine/API/dual_contouring/multi_scalar_dual_contouring.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,16 @@ def _mask_generation(octree_leaves, masking_option: MeshExtractionMaskingOptions
110110
mask_matrix[i] = BackendTensor.t.ones(grid_size // 8, dtype=bool)
111111
case MeshExtractionMaskingOptions.DISJOINT, _:
112112
raise NotImplementedError("Disjoint is not supported yet. Not even sure if there is anything to support")
113-
# case (DualContouringMaskingOptions.DISJOINT | DualContouringMaskingOptions.INTERSECT, StackRelationType.FAULT):
113+
# case (MeshExtractionMaskingOptions.DISJOINT | MeshExtractionMaskingOptions.INTERSECT, StackRelationType.FAULT):
114114
# mask_matrix[i] = np.ones(grid_size//8, dtype=bool)
115-
# case DualContouringMaskingOptions.DISJOINT, StackRelationType.ERODE | StackRelationType.BASEMENT:
115+
# case MeshExtractionMaskingOptions.DISJOINT, StackRelationType.ERODE | StackRelationType.BASEMENT:
116116
# mask_scalar = all_scalar_fields_outputs[i - 1].squeezed_mask_array.reshape((1, -1, 8)).sum(-1, bool)[0]
117117
# if MaskBuffer.previous_mask is None:
118118
# mask = mask_scalar
119119
# else:
120120
# mask = (MaskBuffer.previous_mask ^ mask_scalar) * mask_scalar
121121
# MaskBuffer.previous_mask = mask
122-
# case DualContouringMaskingOptions.DISJOINT, StackRelationType.ONLAP:
122+
# case MeshExtractionMaskingOptions.DISJOINT, StackRelationType.ONLAP:
123123
# raise NotImplementedError("Onlap is not supported yet")
124124
# return octree_leaves.outputs_corners[n_scalar_field].squeezed_mask_array.reshape((1, -1, 8)).sum(-1, bool)[0]
125125
case MeshExtractionMaskingOptions.INTERSECT, StackRelationType.ERODE:

gempy_engine/API/interp_single/_octree_generation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import copy
22
from typing import List
33

4-
from ...config import COMPUTE_GRADIENTS
4+
from ...core.backend_tensor import BackendTensor
5+
from ...config import COMPUTE_GRADIENTS, AvailableBackends
56
from ...core.data.regular_grid import RegularGrid
67
from ...core.data.options import InterpolationOptions
78
from ...core.data.octree_level import OctreeLevel
@@ -17,7 +18,7 @@
1718

1819
def interpolate_on_octree(interpolation_input: InterpolationInput, options: InterpolationOptions,
1920
data_shape: InputDataDescriptor) -> OctreeLevel:
20-
if COMPUTE_GRADIENTS is False:
21+
if BackendTensor.engine_backend is not AvailableBackends.PYTORCH and COMPUTE_GRADIENTS is False:
2122
interpolation_input = copy.deepcopy(interpolation_input)
2223

2324
# * Interpolate - centers

gempy_engine/API/model/model_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import copy
22
from typing import List, Optional
33

4-
from ...config import COMPUTE_GRADIENTS
4+
from ...core.backend_tensor import BackendTensor
5+
from ...config import COMPUTE_GRADIENTS, AvailableBackends
56
from ...core.data.interp_output import InterpOutput
67
from ...core.data.geophysics_input import GeophysicsInput
78
from ...modules.geophysics.fw_gravity import compute_gravity
@@ -24,7 +25,8 @@ def compute_model(interpolation_input: InterpolationInput, options: Interpolatio
2425

2526
# TODO: Make sure if this works with TF
2627
# ! If we inline this it seems the deepcopy does not work
27-
if COMPUTE_GRADIENTS is False:
28+
29+
if BackendTensor.engine_backend is AvailableBackends.PYTORCH and COMPUTE_GRADIENTS is False:
2830
interpolation_input = copy.deepcopy(interpolation_input)
2931

3032
output: list[OctreeLevel] = interpolate_n_octree_levels(

gempy_engine/API/server/main_server_pro.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def compute_gempy_model(gempy_input: GemPyInput):
6464
FANCY_TRIANGULATION = True
6565
if FANCY_TRIANGULATION:
6666
default_interpolation_options.mesh_extraction_fancy = True
67-
# default_interpolation_options.mesh_extraction_masking_options = DualContouringMaskingOptions.RAW # * To Date only raw making is supported
67+
# default_interpolation_options.mesh_extraction_masking_options = MeshExtractionMaskingOptions.RAW # * To Date only raw making is supported
6868
# endregion
6969

7070
default_interpolation_options.mesh_extraction_masking_options = MeshExtractionMaskingOptions.RAW # * To Date only raw making is supported

gempy_engine/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class AvailableBackends(Flag):
2020
DEFAULT_TENSOR_DTYPE = 'float64'
2121
LINE_PROFILER_ENABLED = False
2222
SET_RAW_ARRAYS_IN_SOLUTION = True
23-
COMPUTE_GRADIENTS = True
23+
COMPUTE_GRADIENTS = False
2424

2525
is_numpy_installed = find_spec("numpy") is not None
2626
is_tensorflow_installed = find_spec("tensorflow") is not None

gempy_engine/core/data/options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def update_options(self, **kwargs):
169169
current_octree_level (int, optional): Current octree level. Default is 0.
170170
compute_scalar_gradient (bool, optional): Whether to compute the scalar gradient. Default is False.
171171
dual_contouring (bool, optional): Whether to use dual contouring. Default is True.
172-
mesh_extraction_masking_options (DualContouringMaskingOptions, optional): Options for dual contouring masking.
172+
mesh_extraction_masking_options (MeshExtractionMaskingOptions, optional): Options for dual contouring masking.
173173
dual_contouring_fancy (bool, optional): Fancy version of dual contouring. Default is True.
174174
debug (bool, optional): Debug mode status. Default is derived from config.
175175
debug_water_tight (bool, optional): Debug mode for water-tight conditions. Default is False.

gempy_engine/modules/activator/activator_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def activate_formation_block(exported_fields: ExportedFields, ids: np.ndarray, s
1111
Z_x: np.ndarray = exported_fields.scalar_field_everywhere
1212
scalar_value_at_sp: np.ndarray = exported_fields.scalar_field_at_surface_points
1313

14-
if LEGACY := False:
14+
if LEGACY := True:
1515
sigm = activate_formation_block_from_args(Z_x, ids, scalar_value_at_sp, sigmoid_slope)
1616
else:
1717
sigm = activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp, sigmoid_slope)

tests/benchmark/one_fault_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from gempy_engine.core.data.stacks_structure import StacksStructure
1818
from gempy_engine.core.data.interpolation_input import InterpolationInput
1919
from gempy_engine.core.data.kernel_classes.kernel_functions import AvailableKernelFunctions
20-
from gempy_engine.core.data.options import DualContouringMaskingOptions
20+
from gempy_engine.core.data.options import MeshExtractionMaskingOptions
2121
from gempy_engine.core.data.solutions import Solutions
2222

2323

@@ -26,7 +26,7 @@ def my_func():
2626

2727
options.compute_scalar_gradient = False
2828
options.dual_contouring = False
29-
options.mesh_extraction_masking_options = DualContouringMaskingOptions.RAW
29+
options.mesh_extraction_masking_options = MeshExtractionMaskingOptions.RAW
3030

3131
options.number_octree_levels = 8
3232
solutions: Solutions = compute_model(interpolation_input, options, structure)

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from tests.fixtures.heavy_models import *
1717

1818
pykeops_enabled = False
19-
backend = AvailableBackends.PYTORCH
19+
backend = AvailableBackends.numpy
2020
use_gpu = False
2121
plot_pyvista = False # ! Set here if you want to plot the results
2222

tests/fixtures/heavy_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def moureze_model_factory(path_to_root: str, pick_every=8, octree_lvls=3, solver
111111
# TODO: Add solver parameter
112112
interpolation_options.kernel_options.kernel_solver = solver
113113

114-
from gempy_engine.core.data.options import DualContouringMaskingOptions
115-
interpolation_options.mesh_extraction_masking_options = DualContouringMaskingOptions.RAW
114+
from gempy_engine.core.data.options import MeshExtractionMaskingOptions
115+
interpolation_options.mesh_extraction_masking_options = MeshExtractionMaskingOptions.RAW
116116

117117
# endregion
118118
# region InputDataDescriptor

0 commit comments

Comments
 (0)