Skip to content

Commit c6b728b

Browse files
committed
[CLN] Cleaning the tests
1 parent 3a8fdd6 commit c6b728b

File tree

2 files changed

+33
-48
lines changed

2 files changed

+33
-48
lines changed

gempy_engine/modules/activator/activator_interface.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,11 @@ def soft_segment_unbounded(Z, edges, ids, sigmoid_slope):
124124

125125
# weighted sum by the ids
126126
ids__sum = (membership * ids).sum(dim=-1)
127-
return np.atleast_2d(ids__sum.numpy())
127+
128+
# make it at least 2d
129+
ids__sum = ids__sum[None, :]
130+
131+
return ids__sum
128132

129133

130134
import numpy as np

tests/test_common/test_modules/test_activator_fns.py

Lines changed: 28 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -17,83 +17,66 @@
1717

1818

1919
def test_activator_3_layers_segmentation_function(simple_model_3_layers, simple_grid_3d_more_points_grid):
20-
interpolation_input = simple_model_3_layers[0]
21-
options = simple_model_3_layers[1]
22-
data_shape = simple_model_3_layers[2].tensors_structure
23-
grid = dataclasses.replace(simple_grid_3d_more_points_grid)
24-
interpolation_input.set_temp_grid(grid)
25-
26-
ids = np.array([1, 2, 3, 4])
20+
Z_x, grid, ids_block, interpolation_input = _run_test(
21+
backend=AvailableBackends.numpy,
22+
ids=np.array([1, 20, 3, 4]),
23+
simple_grid_3d_more_points_grid=simple_grid_3d_more_points_grid,
24+
simple_model_3_layers=simple_model_3_layers
25+
)
2726

28-
interp_input: SolverInput = input_preprocess(data_shape, interpolation_input)
29-
weights = _solve_interpolation(interp_input, options.kernel_options)
27+
if plot:
28+
_plot_continious(grid, ids_block, interpolation_input)
3029

31-
exported_fields = _evaluate_sys_eq(interp_input, weights, options)
32-
exported_fields.set_structure_values(
33-
reference_sp_position=data_shape.reference_sp_position,
34-
slice_feature=interpolation_input.slice_feature,
35-
grid_size=interpolation_input.grid.len_all_grids)
3630

37-
Z_x: np.ndarray = exported_fields.scalar_field
38-
sasp = exported_fields.scalar_field_at_surface_points
39-
ids = np.array([1, 20, 3, 4])
31+
def test_activator_3_layers_segmentation_function_II(simple_model_3_layers, simple_grid_3d_more_points_grid):
32+
Z_x, grid, ids_block, interpolation_input = _run_test(
33+
backend=AvailableBackends.numpy,
34+
ids=np.array([1, 2, 3, 4]),
35+
simple_grid_3d_more_points_grid=simple_grid_3d_more_points_grid,
36+
simple_model_3_layers=simple_model_3_layers
37+
)
4038

41-
print(Z_x, Z_x.shape[0])
42-
print(sasp)
39+
BackendTensor.change_backend_gempy(AvailableBackends.numpy)
4340

44-
ids_block = activate_formation_block(
45-
exported_fields=exported_fields,
46-
ids=ids,
47-
sigmoid_slope = 500*4
48-
)[0, :-7]
41+
if plot:
42+
_plot_continious(grid, ids_block, interpolation_input)
4943

50-
if BackendTensor.engine_backend == AvailableBackends.PYTORCH:
51-
ids_block = ids_block.detach().numpy()
52-
Z_x = Z_x.detach().numpy()
53-
interpolation_input.surface_points.sp_coords = interpolation_input.surface_points.sp_coords.detach().numpy()
5444

45+
def test_activator_3_layers_segmentation_function_torch(simple_model_3_layers, simple_grid_3d_more_points_grid):
46+
Z_x, grid, ids_block, interpolation_input = _run_test(
47+
backend=AvailableBackends.PYTORCH,
48+
ids=np.array([1, 2, 3, 4]),
49+
simple_grid_3d_more_points_grid=simple_grid_3d_more_points_grid,
50+
simple_model_3_layers=simple_model_3_layers
51+
)
5552
if plot:
5653
_plot_continious(grid, ids_block, interpolation_input)
5754

5855

59-
def test_activator_3_layers_segmentation_function_II(simple_model_3_layers, simple_grid_3d_more_points_grid):
56+
def _run_test(backend, ids, simple_grid_3d_more_points_grid, simple_model_3_layers):
6057
interpolation_input = simple_model_3_layers[0]
6158
options = simple_model_3_layers[1]
6259
data_shape = simple_model_3_layers[2].tensors_structure
6360
grid = dataclasses.replace(simple_grid_3d_more_points_grid)
6461
interpolation_input.set_temp_grid(grid)
65-
6662
interp_input: SolverInput = input_preprocess(data_shape, interpolation_input)
6763
weights = _solve_interpolation(interp_input, options.kernel_options)
68-
6964
exported_fields = _evaluate_sys_eq(interp_input, weights, options)
7065
exported_fields.set_structure_values(
7166
reference_sp_position=data_shape.reference_sp_position,
7267
slice_feature=interpolation_input.slice_feature,
7368
grid_size=interpolation_input.grid.len_all_grids)
74-
7569
Z_x: np.ndarray = exported_fields.scalar_field
7670
sasp = exported_fields.scalar_field_at_surface_points
77-
ids = np.array([1, 2, 3, 4])
78-
7971
print(Z_x, Z_x.shape[0])
8072
print(sasp)
81-
82-
BackendTensor.change_backend_gempy(AvailableBackends.numpy)
73+
BackendTensor.change_backend_gempy(backend)
8374
ids_block = activate_formation_block(
8475
exported_fields=exported_fields,
8576
ids=ids,
8677
sigmoid_slope=500 * 4
8778
)[0, :-7]
88-
89-
BackendTensor.change_backend_gempy(AvailableBackends.numpy)
90-
if BackendTensor.engine_backend == AvailableBackends.PYTORCH:
91-
ids_block = ids_block.detach().numpy()
92-
Z_x = Z_x.detach().numpy()
93-
interpolation_input.surface_points.sp_coords = interpolation_input.surface_points.sp_coords.detach().numpy()
94-
95-
if plot:
96-
_plot_continious(grid, ids_block, interpolation_input)
79+
return Z_x, grid, ids_block, interpolation_input
9780

9881

9982
def _plot_continious(grid, ids_block, interpolation_input):
@@ -113,5 +96,3 @@ def _plot_continious(grid, ids_block, interpolation_input):
11396
plt.plot(xyz[:, 0], xyz[:, 2], "o")
11497
plt.colorbar()
11598
plt.show()
116-
117-

0 commit comments

Comments
 (0)