Skip to content

Commit d93acd7

Browse files
committed
[HOTFIX!] Previous chunking fix was breaking the interpolation
1 parent bd6354d commit d93acd7

File tree

1 file changed

+67
-38
lines changed

1 file changed

+67
-38
lines changed

gempy_engine/modules/evaluator/generic_evaluator.py

Lines changed: 67 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,63 +8,92 @@
88
from gempy_engine.modules.kernel_constructor.kernel_constructor_interface import yield_evaluation_grad_kernel, yield_evaluation_kernel
99

1010

11-
def generic_evaluator(solver_input: SolverInput, weights: np.ndarray, options: InterpolationOptions) -> ExportedFields:
11+
def generic_evaluator(
12+
solver_input: SolverInput,
13+
weights: np.ndarray,
14+
options: InterpolationOptions
15+
) -> ExportedFields:
1216
grid_size = solver_input.xyz_to_interpolate.shape[0]
13-
matrix_size = grid_size * weights.shape[0]
14-
scalar_field: np.ndarray = BackendTensor.t.zeros(grid_size, dtype=weights.dtype)
17+
max_op_size = options.evaluation_chunk_size
18+
num_weights = weights.shape[0]
19+
20+
21+
chunk_size_grid = max(1, int(max_op_size / num_weights)) # Ensure at least 1 point per chunk
22+
n_chunks = int(np.ceil(grid_size / chunk_size_grid))
23+
24+
# Pre‑allocate outputs
25+
scalar_field = BackendTensor.t.zeros(grid_size, dtype=weights.dtype)
1526
gx_field: Optional[np.ndarray] = None
1627
gy_field: Optional[np.ndarray] = None
1728
gz_field: Optional[np.ndarray] = None
18-
gradient = options.compute_scalar_gradient
19-
20-
# * Chunking the evaluation
21-
max_size = options.evaluation_chunk_size
22-
n_chunks = int(np.ceil(matrix_size / max_size))
23-
chunk_size = int(np.ceil(grid_size / n_chunks))
24-
for i in range(n_chunks): # TODO: It seems the chunking is not properly implemented
25-
slice_array = slice(i * chunk_size, (i + 1) * chunk_size)
26-
scalar_field_chunk, gx_field_chunk, gy_field_chunk, gz_field_chunk = _eval_on(
29+
if options.compute_scalar_gradient:
30+
gx_field = BackendTensor.t.zeros(grid_size, dtype=weights.dtype)
31+
gy_field = BackendTensor.t.zeros(grid_size, dtype=weights.dtype)
32+
if options.number_dimensions == 3:
33+
gz_field = BackendTensor.t.zeros(grid_size, dtype=weights.dtype)
34+
35+
# Chunked evaluation over grid indices
36+
for i in range(n_chunks):
37+
38+
start = i * chunk_size_grid
39+
end = min(grid_size, start + chunk_size_grid) # Ensure 'end' doesn't exceed grid_size
40+
slice_array = slice(start, end)
41+
42+
# Avoid processing empty slices if start == end
43+
if start >= end:
44+
continue
45+
46+
sf_chunk, gx_chunk, gy_chunk, gz_chunk = _eval_on(
2747
solver_input=solver_input,
2848
weights=weights,
2949
options=options,
3050
slice_array=slice_array
3151
)
3252

33-
scalar_field[slice_array] = scalar_field_chunk
34-
if gradient is True:
35-
if i == 0:
36-
gx_field = BackendTensor.t.zeros(grid_size, dtype=weights.dtype)
37-
gy_field = BackendTensor.t.zeros(grid_size, dtype=weights.dtype)
38-
gz_field = BackendTensor.t.zeros(grid_size, dtype=weights.dtype)
39-
40-
gx_field[slice_array] = gx_field_chunk
41-
gy_field[slice_array] = gy_field_chunk
42-
gz_field[slice_array] = gz_field_chunk
53+
scalar_field[slice_array] = sf_chunk
54+
if options.compute_scalar_gradient:
55+
gx_field[slice_array] = gx_chunk # type: ignore
56+
gy_field[slice_array] = gy_chunk # type: ignore
57+
if gz_field is not None:
58+
gz_field[slice_array] = gz_chunk # type: ignore
4359

4460
if n_chunks > 5:
4561
print(f"Chunking done: {n_chunks} chunks")
4662

4763
return ExportedFields(scalar_field, gx_field, gy_field, gz_field)
4864

4965

50-
def _eval_on(solver_input, weights, options, slice_array: slice = None):
51-
eval_kernel = yield_evaluation_kernel(solver_input, options.kernel_options, slice_array=slice_array)
52-
scalar_field: np.ndarray = (eval_kernel.T @ weights).reshape(-1)
53-
scalar_field[-50:]
66+
def _eval_on(
67+
solver_input: SolverInput,
68+
weights: np.ndarray,
69+
options: InterpolationOptions,
70+
slice_array: slice
71+
):
72+
eval_kernel = yield_evaluation_kernel(
73+
solver_input, options.kernel_options, slice_array=slice_array
74+
)
75+
scalar_field = (eval_kernel.T @ weights).reshape(-1)
76+
5477
gx_field: Optional[np.ndarray] = None
5578
gy_field: Optional[np.ndarray] = None
5679
gz_field: Optional[np.ndarray] = None
57-
if options.compute_scalar_gradient is True:
58-
eval_gx_kernel = yield_evaluation_grad_kernel(solver_input, options.kernel_options, axis=0, slice_array=slice_array)
59-
eval_gy_kernel = yield_evaluation_grad_kernel(solver_input, options.kernel_options, axis=1, slice_array=slice_array)
60-
gx_field = (eval_gx_kernel.T @ weights).reshape(-1)
61-
gy_field = (eval_gy_kernel.T @ weights).reshape(-1)
80+
81+
if options.compute_scalar_gradient:
82+
eval_gx = yield_evaluation_grad_kernel(
83+
solver_input, options.kernel_options, axis=0, slice_array=slice_array
84+
)
85+
eval_gy = yield_evaluation_grad_kernel(
86+
solver_input, options.kernel_options, axis=1, slice_array=slice_array
87+
)
88+
gx_field = (eval_gx.T @ weights).reshape(-1)
89+
gy_field = (eval_gy.T @ weights).reshape(-1)
6290

6391
if options.number_dimensions == 3:
64-
eval_gz_kernel = yield_evaluation_grad_kernel(solver_input, options.kernel_options, axis=2, slice_array=slice_array)
65-
gz_field = (eval_gz_kernel.T @ weights).reshape(-1)
66-
elif options.number_dimensions == 2:
67-
gz_field = None
68-
else:
69-
raise ValueError("Number of dimensions have to be 2 or 3")
70-
return scalar_field, gx_field, gy_field, gz_field
92+
eval_gz = yield_evaluation_grad_kernel(
93+
solver_input, options.kernel_options, axis=2, slice_array=slice_array
94+
)
95+
gz_field = (eval_gz.T @ weights).reshape(-1)
96+
elif options.number_dimensions != 2:
97+
raise ValueError("`number_dimensions` must be 2 or 3")
98+
99+
return scalar_field, gx_field, gy_field, gz_field

0 commit comments

Comments
 (0)