From 82fa0f8e25cb2cc05cc047c2914f0b993257848e Mon Sep 17 00:00:00 2001 From: Ahmed Gad Date: Tue, 7 Jan 2025 12:37:53 -0500 Subject: [PATCH] Create the plot_pareto_front_curve() method to plot the pareto front curve --- pygad/visualize/plot.py | 101 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/pygad/visualize/plot.py b/pygad/visualize/plot.py index 7dffc4b..3ddeaff 100644 --- a/pygad/visualize/plot.py +++ b/pygad/visualize/plot.py @@ -384,3 +384,104 @@ def plot_genes(self, matplotlib.pyplot.show() return fig + + def plot_pareto_front_curve(self, + title="Pareto Front Curve", + xlabel="Objective 1", + ylabel="Objective 2", + linewidth=3, + font_size=14, + label="Pareto Front", + color="#FF6347", + color_fitness="#4169E1", + grid=True, + alpha=0.7, + marker="o", + save_dir=None): + """ + Creates, shows, and returns the pareto front curve. Can only be used with multi-objective problems. + It only works with 2 objectives. + It also works only after completing at least 1 generation. If no generation is completed, an exception is raised. + + Accepts the following: + title: Figure title. + xlabel: Label on the X-axis. + ylabel: Label on the Y-axis. + linewidth: Line width of the plot. Defaults to 3. + font_size: Font size for the labels and title. Defaults to 14. + label: The label used for the legend. + color: Color of the plot. + color_fitness: Color of the fitness points. + grid: Either True or False to control the visibility of the grid. + alpha: The transparency of the pareto front curve. + marker: The marker of the fitness points. + save_dir: Directory to save the figure. + + Returns the figure. + """ + + if self.generations_completed < 1: + self.logger.error("The plot_pareto_front_curve() method can only be called after completing at least 1 generation but ({self.generations_completed}) is completed.") + raise RuntimeError("The plot_pareto_front_curve() method can only be called after completing at least 1 generation but ({self.generations_completed}) is completed.") + + if type(self.best_solutions_fitness[0]) in [list, tuple, numpy.ndarray] and len(self.best_solutions_fitness[0]) > 1: + # Multi-objective optimization problem. + if len(self.best_solutions_fitness[0]) == 2: + # Only 2 objectives. Proceed. + pass + else: + # More than 2 objectives. + self.logger.error(f"The plot_pareto_front_curve() method only supports 2 objectives but there are {self.best_solutions_fitness[0]} objectives.") + raise RuntimeError(f"The plot_pareto_front_curve() method only supports 2 objectives but there are {self.best_solutions_fitness[0]} objectives.") + else: + # Single-objective optimization problem. + self.logger.error("The plot_pareto_front_curve() method only works with multi-objective optimization problems.") + raise RuntimeError("The plot_pareto_front_curve() method only works with multi-objective optimization problems.") + + # Plot the pareto front curve. + remaining_set = list(zip(range(0, self.last_generation_fitness.shape[0]), self.last_generation_fitness)) + dominated_set, non_dominated_set = self.get_non_dominated_set(remaining_set) + + # Extract the fitness values (objective values) of the non-dominated solutions for plotting. + pareto_front_x = [self.last_generation_fitness[item[0]][0] for item in dominated_set] + pareto_front_y = [self.last_generation_fitness[item[0]][1] for item in dominated_set] + + # Sort the Pareto front solutions (optional but can make the plot cleaner) + sorted_pareto_front = sorted(zip(pareto_front_x, pareto_front_y)) + + # Plotting + fig = matplotlib.pyplot.figure() + # First, plot the scatter of all points (population) + all_points_x = [self.last_generation_fitness[i][0] for i in range(self.sol_per_pop)] + all_points_y = [self.last_generation_fitness[i][1] for i in range(self.sol_per_pop)] + matplotlib.pyplot.scatter(all_points_x, + all_points_y, + marker=marker, + color=color_fitness, + label='Fitness', + alpha=1.0) + + # Then, plot the Pareto front as a curve + pareto_front_x_sorted, pareto_front_y_sorted = zip(*sorted_pareto_front) + matplotlib.pyplot.plot(pareto_front_x_sorted, + pareto_front_y_sorted, + marker=marker, + label=label, + alpha=alpha, + color=color, + linewidth=linewidth) + + matplotlib.pyplot.title(title, fontsize=font_size) + matplotlib.pyplot.xlabel(xlabel, fontsize=font_size) + matplotlib.pyplot.ylabel(ylabel, fontsize=font_size) + matplotlib.pyplot.legend() + + matplotlib.pyplot.grid(grid) + + if not save_dir is None: + matplotlib.pyplot.savefig(fname=save_dir, + bbox_inches='tight') + + matplotlib.pyplot.show() + + return fig