17
17
18
18
19
19
def test_activator_3_layers_segmentation_function (simple_model_3_layers , simple_grid_3d_more_points_grid ):
20
- interpolation_input = simple_model_3_layers [0 ]
21
- options = simple_model_3_layers [1 ]
22
- data_shape = simple_model_3_layers [2 ].tensors_structure
23
- grid = dataclasses .replace (simple_grid_3d_more_points_grid )
24
- interpolation_input .set_temp_grid (grid )
25
-
26
- ids = np .array ([1 , 2 , 3 , 4 ])
20
+ Z_x , grid , ids_block , interpolation_input = _run_test (
21
+ backend = AvailableBackends .numpy ,
22
+ ids = np .array ([1 , 20 , 3 , 4 ]),
23
+ simple_grid_3d_more_points_grid = simple_grid_3d_more_points_grid ,
24
+ simple_model_3_layers = simple_model_3_layers
25
+ )
27
26
28
- interp_input : SolverInput = input_preprocess ( data_shape , interpolation_input )
29
- weights = _solve_interpolation ( interp_input , options . kernel_options )
27
+ if plot :
28
+ _plot_continious ( grid , ids_block , interpolation_input )
30
29
31
- exported_fields = _evaluate_sys_eq (interp_input , weights , options )
32
- exported_fields .set_structure_values (
33
- reference_sp_position = data_shape .reference_sp_position ,
34
- slice_feature = interpolation_input .slice_feature ,
35
- grid_size = interpolation_input .grid .len_all_grids )
36
30
37
- Z_x : np .ndarray = exported_fields .scalar_field
38
- sasp = exported_fields .scalar_field_at_surface_points
39
- ids = np .array ([1 , 20 , 3 , 4 ])
31
+ def test_activator_3_layers_segmentation_function_II (simple_model_3_layers , simple_grid_3d_more_points_grid ):
32
+ Z_x , grid , ids_block , interpolation_input = _run_test (
33
+ backend = AvailableBackends .numpy ,
34
+ ids = np .array ([1 , 2 , 3 , 4 ]),
35
+ simple_grid_3d_more_points_grid = simple_grid_3d_more_points_grid ,
36
+ simple_model_3_layers = simple_model_3_layers
37
+ )
40
38
41
- print (Z_x , Z_x .shape [0 ])
42
- print (sasp )
39
+ BackendTensor .change_backend_gempy (AvailableBackends .numpy )
43
40
44
- ids_block = activate_formation_block (
45
- exported_fields = exported_fields ,
46
- ids = ids ,
47
- sigmoid_slope = 500 * 4
48
- )[0 , :- 7 ]
41
+ if plot :
42
+ _plot_continious (grid , ids_block , interpolation_input )
49
43
50
- if BackendTensor .engine_backend == AvailableBackends .PYTORCH :
51
- ids_block = ids_block .detach ().numpy ()
52
- Z_x = Z_x .detach ().numpy ()
53
- interpolation_input .surface_points .sp_coords = interpolation_input .surface_points .sp_coords .detach ().numpy ()
54
44
45
+ def test_activator_3_layers_segmentation_function_torch (simple_model_3_layers , simple_grid_3d_more_points_grid ):
46
+ Z_x , grid , ids_block , interpolation_input = _run_test (
47
+ backend = AvailableBackends .PYTORCH ,
48
+ ids = np .array ([1 , 2 , 3 , 4 ]),
49
+ simple_grid_3d_more_points_grid = simple_grid_3d_more_points_grid ,
50
+ simple_model_3_layers = simple_model_3_layers
51
+ )
55
52
if plot :
56
53
_plot_continious (grid , ids_block , interpolation_input )
57
54
58
55
59
- def test_activator_3_layers_segmentation_function_II ( simple_model_3_layers , simple_grid_3d_more_points_grid ):
56
+ def _run_test ( backend , ids , simple_grid_3d_more_points_grid , simple_model_3_layers ):
60
57
interpolation_input = simple_model_3_layers [0 ]
61
58
options = simple_model_3_layers [1 ]
62
59
data_shape = simple_model_3_layers [2 ].tensors_structure
63
60
grid = dataclasses .replace (simple_grid_3d_more_points_grid )
64
61
interpolation_input .set_temp_grid (grid )
65
-
66
62
interp_input : SolverInput = input_preprocess (data_shape , interpolation_input )
67
63
weights = _solve_interpolation (interp_input , options .kernel_options )
68
-
69
64
exported_fields = _evaluate_sys_eq (interp_input , weights , options )
70
65
exported_fields .set_structure_values (
71
66
reference_sp_position = data_shape .reference_sp_position ,
72
67
slice_feature = interpolation_input .slice_feature ,
73
68
grid_size = interpolation_input .grid .len_all_grids )
74
-
75
69
Z_x : np .ndarray = exported_fields .scalar_field
76
70
sasp = exported_fields .scalar_field_at_surface_points
77
- ids = np .array ([1 , 2 , 3 , 4 ])
78
-
79
71
print (Z_x , Z_x .shape [0 ])
80
72
print (sasp )
81
-
82
- BackendTensor .change_backend_gempy (AvailableBackends .numpy )
73
+ BackendTensor .change_backend_gempy (backend )
83
74
ids_block = activate_formation_block (
84
75
exported_fields = exported_fields ,
85
76
ids = ids ,
86
77
sigmoid_slope = 500 * 4
87
78
)[0 , :- 7 ]
88
-
89
- BackendTensor .change_backend_gempy (AvailableBackends .numpy )
90
- if BackendTensor .engine_backend == AvailableBackends .PYTORCH :
91
- ids_block = ids_block .detach ().numpy ()
92
- Z_x = Z_x .detach ().numpy ()
93
- interpolation_input .surface_points .sp_coords = interpolation_input .surface_points .sp_coords .detach ().numpy ()
94
-
95
- if plot :
96
- _plot_continious (grid , ids_block , interpolation_input )
79
+ return Z_x , grid , ids_block , interpolation_input
97
80
98
81
99
82
def _plot_continious (grid , ids_block , interpolation_input ):
@@ -113,5 +96,3 @@ def _plot_continious(grid, ids_block, interpolation_input):
113
96
plt .plot (xyz [:, 0 ], xyz [:, 2 ], "o" )
114
97
plt .colorbar ()
115
98
plt .show ()
116
-
117
-
0 commit comments