diff --git a/rocketpy/plots/monte_carlo_plots.py b/rocketpy/plots/monte_carlo_plots.py index cfb865c5f..c70e53b88 100644 --- a/rocketpy/plots/monte_carlo_plots.py +++ b/rocketpy/plots/monte_carlo_plots.py @@ -1,7 +1,10 @@ +from pathlib import Path + import matplotlib.pyplot as plt import numpy as np from ..tools import generate_monte_carlo_ellipses, import_optional_dependency +from .plot_helpers import show_or_save_plot class _MonteCarloPlots: @@ -147,7 +150,7 @@ def ellipses( else: plt.show() - def all(self, keys=None): + def all(self, keys=None, *, filename=None): """ Plot the histograms of the Monte Carlo simulation results. @@ -156,6 +159,13 @@ def all(self, keys=None): keys : str, list or tuple, optional The keys of the results to be plotted. If None, all results will be plotted. Default is None. + filename : str | None, optional + The path the plot should be saved to, by default None. If provided, + the plot will be saved instead of displayed. When multiple plots are + generated (one per key), the key name will be appended to the filename. + Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, + rgba, svg, svgz, tif, tiff and webp (these are the formats supported + by matplotlib). Returns ------- @@ -173,6 +183,7 @@ def all(self, keys=None): ) else: raise ValueError("The 'keys' argument must be a string, list, or tuple.") + for key in keys: # Create figure with GridSpec fig = plt.figure(figsize=(8, 8)) @@ -194,7 +205,16 @@ def all(self, keys=None): ax1.set_xticks([]) plt.tight_layout() - plt.show() + + # Handle filename for multiple plots + if filename is not None: + # For multiple keys, append the key name to the filename + filepath = Path(filename) + # Use the full key name to avoid collisions between x_impact and y_impact + key_filename = filepath.parent / f"{filepath.stem}_{key}{filepath.suffix}" + show_or_save_plot(str(key_filename)) + else: + show_or_save_plot(filename) def plot_comparison(self, other_monte_carlo): """ diff --git a/tests/unit/test_monte_carlo.py b/tests/unit/test_monte_carlo.py index f168b8bfe..4e908e6e1 100644 --- a/tests/unit/test_monte_carlo.py +++ b/tests/unit/test_monte_carlo.py @@ -24,6 +24,33 @@ def test_stochastic_environment_create_object_with_wind_x(stochastic_environment # TODO: add a new test for the special case of ensemble member +def test_monte_carlo_plots_all_with_filename(monte_carlo_calisto_pre_loaded, tmp_path): + """Tests the all method of the MonteCarlo plots with filename parameter. + + Parameters + ---------- + monte_carlo_calisto_pre_loaded : MonteCarlo + A MonteCarlo object with pre-loaded results, this is a pytest fixture. + tmp_path : Path + Temporary directory path for saving test files. + """ + # Test without filename (should work as before) + result = monte_carlo_calisto_pre_loaded.plots.all() + assert result is None + + # Test with filename - save to temporary directory + filename = tmp_path / "test_monte_carlo_plot.png" + result = monte_carlo_calisto_pre_loaded.plots.all(filename=str(filename)) + assert result is None + + # Test with specific keys and filename + filename_apogee = tmp_path / "test_apogee_plot.png" + result = monte_carlo_calisto_pre_loaded.plots.all( + keys="apogee", filename=str(filename_apogee) + ) + assert result is None + + def test_stochastic_solid_motor_create_object_with_impulse(stochastic_solid_motor): """Tests the stochastic solid motor object by checking if the total impulse can be generated properly. The goal is to check if the create_object()