8
8
from gempy_engine .modules .kernel_constructor .kernel_constructor_interface import yield_evaluation_grad_kernel , yield_evaluation_kernel
9
9
10
10
11
- def generic_evaluator (solver_input : SolverInput , weights : np .ndarray , options : InterpolationOptions ) -> ExportedFields :
11
+ def generic_evaluator (
12
+ solver_input : SolverInput ,
13
+ weights : np .ndarray ,
14
+ options : InterpolationOptions
15
+ ) -> ExportedFields :
12
16
grid_size = solver_input .xyz_to_interpolate .shape [0 ]
13
- matrix_size = grid_size * weights .shape [0 ]
14
- scalar_field : np .ndarray = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
17
+ max_op_size = options .evaluation_chunk_size
18
+ num_weights = weights .shape [0 ]
19
+
20
+
21
+ chunk_size_grid = max (1 , int (max_op_size / num_weights )) # Ensure at least 1 point per chunk
22
+ n_chunks = int (np .ceil (grid_size / chunk_size_grid ))
23
+
24
+ # Pre‑allocate outputs
25
+ scalar_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
15
26
gx_field : Optional [np .ndarray ] = None
16
27
gy_field : Optional [np .ndarray ] = None
17
28
gz_field : Optional [np .ndarray ] = None
18
- gradient = options .compute_scalar_gradient
19
-
20
- # * Chunking the evaluation
21
- max_size = options .evaluation_chunk_size
22
- n_chunks = int (np .ceil (matrix_size / max_size ))
23
- chunk_size = int (np .ceil (grid_size / n_chunks ))
24
- for i in range (n_chunks ): # TODO: It seems the chunking is not properly implemented
25
- slice_array = slice (i * chunk_size , (i + 1 ) * chunk_size )
26
- scalar_field_chunk , gx_field_chunk , gy_field_chunk , gz_field_chunk = _eval_on (
29
+ if options .compute_scalar_gradient :
30
+ gx_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
31
+ gy_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
32
+ if options .number_dimensions == 3 :
33
+ gz_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
34
+
35
+ # Chunked evaluation over grid indices
36
+ for i in range (n_chunks ):
37
+
38
+ start = i * chunk_size_grid
39
+ end = min (grid_size , start + chunk_size_grid ) # Ensure 'end' doesn't exceed grid_size
40
+ slice_array = slice (start , end )
41
+
42
+ # Avoid processing empty slices if start == end
43
+ if start >= end :
44
+ continue
45
+
46
+ sf_chunk , gx_chunk , gy_chunk , gz_chunk = _eval_on (
27
47
solver_input = solver_input ,
28
48
weights = weights ,
29
49
options = options ,
30
50
slice_array = slice_array
31
51
)
32
52
33
- scalar_field [slice_array ] = scalar_field_chunk
34
- if gradient is True :
35
- if i == 0 :
36
- gx_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
37
- gy_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
38
- gz_field = BackendTensor .t .zeros (grid_size , dtype = weights .dtype )
39
-
40
- gx_field [slice_array ] = gx_field_chunk
41
- gy_field [slice_array ] = gy_field_chunk
42
- gz_field [slice_array ] = gz_field_chunk
53
+ scalar_field [slice_array ] = sf_chunk
54
+ if options .compute_scalar_gradient :
55
+ gx_field [slice_array ] = gx_chunk # type: ignore
56
+ gy_field [slice_array ] = gy_chunk # type: ignore
57
+ if gz_field is not None :
58
+ gz_field [slice_array ] = gz_chunk # type: ignore
43
59
44
60
if n_chunks > 5 :
45
61
print (f"Chunking done: { n_chunks } chunks" )
46
62
47
63
return ExportedFields (scalar_field , gx_field , gy_field , gz_field )
48
64
49
65
50
- def _eval_on (solver_input , weights , options , slice_array : slice = None ):
51
- eval_kernel = yield_evaluation_kernel (solver_input , options .kernel_options , slice_array = slice_array )
52
- scalar_field : np .ndarray = (eval_kernel .T @ weights ).reshape (- 1 )
53
- scalar_field [- 50 :]
66
+ def _eval_on (
67
+ solver_input : SolverInput ,
68
+ weights : np .ndarray ,
69
+ options : InterpolationOptions ,
70
+ slice_array : slice
71
+ ):
72
+ eval_kernel = yield_evaluation_kernel (
73
+ solver_input , options .kernel_options , slice_array = slice_array
74
+ )
75
+ scalar_field = (eval_kernel .T @ weights ).reshape (- 1 )
76
+
54
77
gx_field : Optional [np .ndarray ] = None
55
78
gy_field : Optional [np .ndarray ] = None
56
79
gz_field : Optional [np .ndarray ] = None
57
- if options .compute_scalar_gradient is True :
58
- eval_gx_kernel = yield_evaluation_grad_kernel (solver_input , options .kernel_options , axis = 0 , slice_array = slice_array )
59
- eval_gy_kernel = yield_evaluation_grad_kernel (solver_input , options .kernel_options , axis = 1 , slice_array = slice_array )
60
- gx_field = (eval_gx_kernel .T @ weights ).reshape (- 1 )
61
- gy_field = (eval_gy_kernel .T @ weights ).reshape (- 1 )
80
+
81
+ if options .compute_scalar_gradient :
82
+ eval_gx = yield_evaluation_grad_kernel (
83
+ solver_input , options .kernel_options , axis = 0 , slice_array = slice_array
84
+ )
85
+ eval_gy = yield_evaluation_grad_kernel (
86
+ solver_input , options .kernel_options , axis = 1 , slice_array = slice_array
87
+ )
88
+ gx_field = (eval_gx .T @ weights ).reshape (- 1 )
89
+ gy_field = (eval_gy .T @ weights ).reshape (- 1 )
62
90
63
91
if options .number_dimensions == 3 :
64
- eval_gz_kernel = yield_evaluation_grad_kernel (solver_input , options .kernel_options , axis = 2 , slice_array = slice_array )
65
- gz_field = (eval_gz_kernel .T @ weights ).reshape (- 1 )
66
- elif options .number_dimensions == 2 :
67
- gz_field = None
68
- else :
69
- raise ValueError ("Number of dimensions have to be 2 or 3" )
70
- return scalar_field , gx_field , gy_field , gz_field
92
+ eval_gz = yield_evaluation_grad_kernel (
93
+ solver_input , options .kernel_options , axis = 2 , slice_array = slice_array
94
+ )
95
+ gz_field = (eval_gz .T @ weights ).reshape (- 1 )
96
+ elif options .number_dimensions != 2 :
97
+ raise ValueError ("`number_dimensions` must be 2 or 3" )
98
+
99
+ return scalar_field , gx_field , gy_field , gz_field
0 commit comments