Skip to content

Commit ca4f4fb

Browse files
committed
[BUG] Move backend setting to forward function for consistency
Relocated `change_backend_gempy` call to `run_gempy_forward` to ensure proper backend configuration during execution. Removed redundant imports to clean up the module.
1 parent 03450f4 commit ca4f4fb

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

gempy_probability/modules/forwards/_gempy_fw.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
import torch
2-
from pyro.distributions import Distribution
3-
4-
import gempy_engine
5-
from gempy.modules.data_manipulation import interpolation_input_from_structural_frame
61
import gempy as gp
2+
import gempy_engine
3+
from gempy_engine.core.backend_tensor import BackendTensor
74
from gempy_engine.core.data.interpolation_input import InterpolationInput
85

96

107
def run_gempy_forward(interp_input: InterpolationInput, geo_model: gp.data.GeoModel) -> gp.data.Solutions:
11-
8+
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
129
# compute model
1310
sol = gempy_engine.compute_model(
1411
interpolation_input=interp_input,

tests/test_prob_model/test_prob_factory.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ def test_prob_model_factory() -> None:
5050
gempy_model=geo_model,
5151
validate_serialization=False
5252
)
53-
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
5453

5554
# 1) Define your priors
5655
model_priors = {

0 commit comments

Comments
 (0)