Skip to content

Commit 2c90b63

Browse files
committed
Increase evaluation chunk size and consider input size in chunking computations. Previously, the chunking was solely based on the grid size. Now, both the grid size and the size of the input data (weights) are taken into consideration when determining the number of chunks and the size of each chunk. This change improves the efficiency of the program and avoids potential memory issues with large input data.
1 parent ba8bc14 commit 2c90b63

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

gempy_engine/core/data/options/evaluation_options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class EvaluationOptions:
2222
mesh_extraction_masking_options: MeshExtractionMaskingOptions = MeshExtractionMaskingOptions.INTERSECT
2323
mesh_extraction_fancy: bool = True
2424

25-
evaluation_chunk_size: int = 50_000
25+
evaluation_chunk_size: int = 5_000_000
2626

2727
compute_scalar_gradient: bool = False
2828

gempy_engine/modules/evaluator/generic_evaluator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
def generic_evaluator(solver_input: SolverInput, weights: np.ndarray, options: InterpolationOptions) -> ExportedFields:
1212
grid_size = solver_input.xyz_to_interpolate.shape[0]
13+
matrix_size = grid_size * weights.shape[0]
1314
scalar_field: np.ndarray = BackendTensor.t.zeros(grid_size, dtype=weights.dtype)
1415
gx_field: Optional[np.ndarray] = None
1516
gy_field: Optional[np.ndarray] = None
@@ -18,8 +19,8 @@ def generic_evaluator(solver_input: SolverInput, weights: np.ndarray, options: I
1819

1920
# * Chunking the evaluation
2021
max_size = options.evaluation_chunk_size
21-
n_chunks = int(np.ceil(grid_size / max_size))
22-
chunk_size = int(np.ceil(grid_size / n_chunks))
22+
n_chunks = int(np.ceil(matrix_size / max_size))
23+
chunk_size = int(np.ceil(matrix_size / n_chunks))
2324
for i in range(n_chunks):
2425
slice_array = slice(i * chunk_size, (i + 1) * chunk_size)
2526
scalar_field_chunk, gx_field_chunk, gy_field_chunk, gz_field_chunk = _eval_on(

0 commit comments

Comments
 (0)