Skip to content

Commit 84c82ae

Browse files
committed
[TEST] Reviving activation test for multiple layers
1 parent 88cf224 commit 84c82ae

File tree

3 files changed

+76
-8
lines changed

3 files changed

+76
-8
lines changed

gempy_engine/modules/activator/activator_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def activate_formation_block_from_args(Z_x, ids, scalar_value_at_sp, sigmoid_slo
3333
sigm = bt.t.zeros((1, Z_x.shape[0]), dtype=BackendTensor.dtype_obj)
3434

3535
for i in range(len(ids)):
36-
if LEGACY:=False:
36+
if LEGACY:=True:
3737
sigm += _compute_sigmoid(Z_x, scalar_0_v[i], scalar_1_v[i], drift_0_v[i], drift_1_v[i], ids[i], sigmoid_slope)
3838
else:
3939
sigm += HardSigmoid.apply(Z_x, scalar_0_v[i], scalar_1_v[i])

tests/fixtures/simple_models.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,40 @@ def simple_model_values_block_output(simple_model, simple_grid_3d_more_points_gr
327327

328328
return output
329329

330+
@pytest.fixture(scope="session")
331+
def simple_model_3_layers_output(simple_model_3_layers):
332+
interporlation_input = simple_model_3_layers[0]
333+
options = simple_model_3_layers[1]
334+
data_shape = simple_model_3_layers[2].tensors_structure
335+
ids = np.array([1, 2, 3, 4])
336+
337+
interp_input: SolverInput = input_preprocess(data_shape, interporlation_input)
338+
weights = _solve_interpolation(interp_input, options.kernel_options)
339+
340+
exported_fields = _evaluate_sys_eq(interp_input, weights, options)
341+
342+
exported_fields.set_structure_values(
343+
reference_sp_position=data_shape.reference_sp_position,
344+
slice_feature=interporlation_input.slice_feature,
345+
grid_size=interporlation_input.grid.len_all_grids)
346+
347+
# -----------------
348+
# Export and Masking operations can happen even in parallel
349+
# TODO: [~X] Export block
350+
values_block: np.ndarray = activate_formation_block(exported_fields, ids, sigmoid_slope=50000)
351+
352+
output = InterpOutput(
353+
ScalarFieldOutput(
354+
weights=weights,
355+
grid=interporlation_input.grid,
356+
exported_fields=exported_fields,
357+
values_block=values_block,
358+
stack_relation=interporlation_input.stack_relation,
359+
),
360+
)
361+
362+
return output
363+
330364

331365
@pytest.fixture(scope="session")
332366
def unconformity_complex():

tests/test_common/test_modules/test_activator.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
import dataclasses
12
import os
23
import pytest
34
import matplotlib.pyplot as plt
45
import numpy as np
56

7+
from gempy_engine.API.interp_single._interp_scalar_field import _solve_interpolation, _evaluate_sys_eq
8+
from gempy_engine.API.interp_single._interp_single_feature import input_preprocess
9+
from gempy_engine.core.data.internal_structs import SolverInput
610
from gempy_engine.core.data.interp_output import InterpOutput
711
from gempy_engine.modules.activator.activator_interface import activate_formation_block
812
from gempy_engine.API.interp_single.interp_features import interpolate_single_field
@@ -35,21 +39,51 @@ def test_activator(simple_model_values_block_output):
3539
plt.show()
3640

3741

38-
@pytest.mark.skip(reason="This is unfinished I have to extract the 3 layers values")
39-
def test_activator_3_layers(simple_model_3_layers):
40-
interpolation_input, options, structure = simple_model_3_layers
42+
# @pytest.mark.skip(reason="This is unfinished I have to extract the 3 layers values")
43+
def test_activator_3_layers(simple_model_3_layers, simple_grid_3d_more_points_grid):
44+
interpolation_input = simple_model_3_layers[0]
45+
options = simple_model_3_layers[1]
46+
data_shape = simple_model_3_layers[2].tensors_structure
47+
grid = dataclasses.replace(simple_grid_3d_more_points_grid)
48+
interpolation_input.grid = grid
49+
50+
ids = np.array([1, 2, 3, 4])
4151

42-
res = interpolation_input.grid.regular_grid.resolution
52+
interp_input: SolverInput = input_preprocess(data_shape, interpolation_input)
53+
weights = _solve_interpolation(interp_input, options.kernel_options)
4354

44-
output: InterpOutput = interpolate_single_field(interpolation_input, options, structure)
45-
Z_x = output.exported_fields.scalar_field
55+
exported_fields = _evaluate_sys_eq(interp_input, weights, options)
56+
57+
exported_fields.set_structure_values(
58+
reference_sp_position=data_shape.reference_sp_position,
59+
slice_feature=interpolation_input.slice_feature,
60+
grid_size=interpolation_input.grid.len_all_grids)
61+
62+
Z_x: np.ndarray = exported_fields.scalar_field
63+
sasp = exported_fields.scalar_field_at_surface_points
64+
ids = np.array([1, 2, 3, 4])
65+
66+
print(Z_x, Z_x.shape[0])
67+
print(sasp)
68+
69+
70+
ids_block = activate_formation_block(
71+
exported_fields=exported_fields,
72+
ids= ids,
73+
sigmoid_slope=50000
74+
)[:, :-7]
4675

4776
if plot:
48-
plt.contourf(Z_x.reshape(res)[:, 0, :].T, N=40, cmap="autumn",
77+
plt.contourf(Z_x.reshape(50, 5, 50)[:, 0, :].T, N=40, cmap="autumn",
4978
extent=(.25, .75, .25, .75))
5079

5180
xyz = interpolation_input.surface_points.sp_coords
5281
plt.plot(xyz[:, 0], xyz[:, 2], "o")
5382
plt.colorbar()
5483

5584
plt.show()
85+
86+
plt.contourf(ids_block[0, :-4].reshape(50, 5, 50)[:, 2, :].T, N=40, cmap="viridis")
87+
plt.colorbar()
88+
89+
plt.show()

0 commit comments

Comments
 (0)