Skip to content

Commit 393e204

Browse files
committed
[ENH] Introduce GemPyPyroModel type for enhanced type clarity
Refactored the code to define and utilize the `GemPyPyroModel` type alias for improved type clarity across the codebase. Updated method signatures and relevant imports to align with the new type definition.
1 parent ca4f4fb commit 393e204

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

gempy_probability/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .modules.model_definition.prob_model_factory import make_gempy_pyro_model
1+
from .modules.model_definition.prob_model_factory import make_gempy_pyro_model, GemPyPyroModel
22
from .modules import likelihoods
33

44
from ._version import __version__

gempy_probability/modules/model_definition/prob_model_factory.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
import pyro
22
import torch
3-
from typing import Callable, Dict
43
from pyro.distributions import Distribution
4+
from typing import Callable, Dict
55

6-
from gempy_probability.modules.forwards import run_gempy_forward
76
import gempy as gp
7+
from gempy_probability.modules.forwards import run_gempy_forward
8+
9+
GemPyPyroModel = Callable[[gp.data.GeoModel, torch.Tensor], None]
810

911

1012
def make_gempy_pyro_model(
13+
*,
1114
priors: Dict[str, Distribution],
1215
set_interp_input_fn: Callable[
1316
[Dict[str, torch.Tensor], gp.data.GeoModel],
1417
gp.data.InterpolationInput
1518
],
1619
likelihood_fn: Callable[[gp.data.Solutions], Distribution],
1720
obs_name: str = "obs"
18-
) -> Callable[[gp.data.GeoModel, torch.Tensor], None]:
21+
) -> GemPyPyroModel:
1922
"""
2023
Factory to produce a Pyro model for GemPy forward simulations.
2124

tests/test_prob_model/test_prob_factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_prob_model_factory() -> None:
5959
)
6060
}
6161

62-
pyro_gravity_model = gpp.make_gempy_pyro_model(
62+
pyro_gravity_model: gpp.GemPyPyroModel = gpp.make_gempy_pyro_model(
6363
priors=model_priors,
6464
set_interp_input_fn=modify_z_for_surface_point1,
6565
likelihood_fn=gpp.likelihoods.thickness_likelihood,
@@ -91,7 +91,7 @@ def modify_z_for_surface_point1(
9191
return interp_input
9292

9393

94-
def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
94+
def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
9595
y_obs_list: torch.Tensor) -> None:
9696
# Run prior sampling and visualization
9797
from pyro.infer import Predictive

0 commit comments

Comments
 (0)