Skip to content

Commit a3dc00b

Browse files
committed
Support of NSGA-II
1 parent 6aeb685 commit a3dc00b

File tree

1 file changed

+78
-21
lines changed

1 file changed

+78
-21
lines changed

pygad/visualize/plot.py

Lines changed: 78 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy
66
import warnings
77
import matplotlib.pyplot
8+
import pygad
89

910
class Plot:
1011
def plot_result(self,
@@ -14,7 +15,7 @@ def plot_result(self,
1415
linewidth=3,
1516
font_size=14,
1617
plot_type="plot",
17-
color="#3870FF",
18+
color="#64f20c",
1819
save_dir=None):
1920

2021
if not self.suppress_warnings:
@@ -30,14 +31,15 @@ def plot_result(self,
3031
save_dir=save_dir)
3132

3233
def plot_fitness(self,
33-
title="PyGAD - Generation vs. Fitness",
34-
xlabel="Generation",
35-
ylabel="Fitness",
36-
linewidth=3,
37-
font_size=14,
38-
plot_type="plot",
39-
color="#3870FF",
40-
save_dir=None):
34+
title="PyGAD - Generation vs. Fitness",
35+
xlabel="Generation",
36+
ylabel="Fitness",
37+
linewidth=3,
38+
font_size=14,
39+
plot_type="plot",
40+
color="#64f20c",
41+
label=None,
42+
save_dir=None):
4143

4244
"""
4345
Creates, shows, and returns a figure that summarizes how the fitness value evolved by generation. Can only be called after completing at least 1 generation. If no generation is completed, an exception is raised.
@@ -47,9 +49,10 @@ def plot_fitness(self,
4749
xlabel: Label on the X-axis.
4850
ylabel: Label on the Y-axis.
4951
linewidth: Line width of the plot. Defaults to 3.
50-
font_size: Font size for the labels and title. Defaults to 14.
52+
font_size: Font size for the labels and title. Defaults to 14. Can be a list/tuple/numpy.ndarray if the problem is multi-objective optimization.
5153
plot_type: Type of the plot which can be either "plot" (default), "scatter", or "bar".
52-
color: Color of the plot which defaults to "#3870FF".
54+
color: Color of the plot which defaults to "#64f20c". Can be a list/tuple/numpy.ndarray if the problem is multi-objective optimization.
55+
label: The label used for the legend in the figures of multi-objective problems. It is not used for single-objective problems.
5356
save_dir: Directory to save the figure.
5457
5558
Returns the figure.
@@ -60,15 +63,69 @@ def plot_fitness(self,
6063
raise RuntimeError("The plot_fitness() (i.e. plot_result()) method can only be called after completing at least 1 generation but ({self.generations_completed}) is completed.")
6164

6265
fig = matplotlib.pyplot.figure()
63-
if plot_type == "plot":
64-
matplotlib.pyplot.plot(self.best_solutions_fitness, linewidth=linewidth, color=color)
65-
elif plot_type == "scatter":
66-
matplotlib.pyplot.scatter(range(len(self.best_solutions_fitness)), self.best_solutions_fitness, linewidth=linewidth, color=color)
67-
elif plot_type == "bar":
68-
matplotlib.pyplot.bar(range(len(self.best_solutions_fitness)), self.best_solutions_fitness, linewidth=linewidth, color=color)
66+
if len(self.best_solutions_fitness[0]) > 1:
67+
# Multi-objective optimization problem.
68+
if type(linewidth) in pygad.GA.supported_int_float_types:
69+
linewidth = [linewidth]
70+
linewidth.extend([linewidth[0]]*len(self.best_solutions_fitness[0]))
71+
elif type(linewidth) in [list, tuple, numpy.ndarray]:
72+
pass
73+
74+
if type(color) is str:
75+
color = [color]
76+
color.extend([None]*len(self.best_solutions_fitness[0]))
77+
elif type(color) in [list, tuple, numpy.ndarray]:
78+
pass
79+
80+
if label is None:
81+
label = [None]*len(self.best_solutions_fitness[0])
82+
83+
# Loop through each objective to plot its fitness.
84+
for objective_idx in range(len(self.best_solutions_fitness[0])):
85+
# Return the color, line width, and label of the current plot.
86+
current_color = color[objective_idx]
87+
current_linewidth = linewidth[objective_idx]
88+
current_label = label[objective_idx]
89+
# Return the fitness values for the current objective function across all best solutions acorss all generations.
90+
fitness = numpy.array(self.best_solutions_fitness)[:, objective_idx]
91+
if plot_type == "plot":
92+
matplotlib.pyplot.plot(fitness,
93+
linewidth=current_linewidth,
94+
color=current_color,
95+
label=current_label)
96+
elif plot_type == "scatter":
97+
matplotlib.pyplot.scatter(range(len(fitness)),
98+
fitness,
99+
linewidth=current_linewidth,
100+
color=current_color,
101+
label=current_label)
102+
elif plot_type == "bar":
103+
matplotlib.pyplot.bar(range(len(fitness)),
104+
fitness,
105+
linewidth=current_linewidth,
106+
color=current_color,
107+
label=current_label)
108+
else:
109+
# Single-objective optimization problem.
110+
if plot_type == "plot":
111+
matplotlib.pyplot.plot(self.best_solutions_fitness,
112+
linewidth=linewidth,
113+
color=color)
114+
elif plot_type == "scatter":
115+
matplotlib.pyplot.scatter(range(len(self.best_solutions_fitness)),
116+
self.best_solutions_fitness,
117+
linewidth=linewidth,
118+
color=color)
119+
elif plot_type == "bar":
120+
matplotlib.pyplot.bar(range(len(self.best_solutions_fitness)),
121+
self.best_solutions_fitness,
122+
linewidth=linewidth,
123+
color=color)
69124
matplotlib.pyplot.title(title, fontsize=font_size)
70125
matplotlib.pyplot.xlabel(xlabel, fontsize=font_size)
71126
matplotlib.pyplot.ylabel(ylabel, fontsize=font_size)
127+
# Create a legend out of the labels.
128+
matplotlib.pyplot.legend()
72129

73130
if not save_dir is None:
74131
matplotlib.pyplot.savefig(fname=save_dir,
@@ -84,7 +141,7 @@ def plot_new_solution_rate(self,
84141
linewidth=3,
85142
font_size=14,
86143
plot_type="plot",
87-
color="#3870FF",
144+
color="#64f20c",
88145
save_dir=None):
89146

90147
"""
@@ -97,7 +154,7 @@ def plot_new_solution_rate(self,
97154
linewidth: Line width of the plot. Defaults to 3.
98155
font_size: Font size for the labels and title. Defaults to 14.
99156
plot_type: Type of the plot which can be either "plot" (default), "scatter", or "bar".
100-
color: Color of the plot which defaults to "#3870FF".
157+
color: Color of the plot which defaults to "#64f20c".
101158
save_dir: Directory to save the figure.
102159
103160
Returns the figure.
@@ -154,7 +211,7 @@ def plot_genes(self,
154211
font_size=14,
155212
plot_type="plot",
156213
graph_type="plot",
157-
fill_color="#3870FF",
214+
fill_color="#64f20c",
158215
color="black",
159216
solutions="all",
160217
save_dir=None):
@@ -172,7 +229,7 @@ def plot_genes(self,
172229
font_size: Font size for the labels and title. Defaults to 14.
173230
plot_type: Type of the plot which can be either "plot" (default), "scatter", or "bar".
174231
graph_type: Type of the graph which can be either "plot" (default), "boxplot", or "histogram".
175-
fill_color: Fill color of the graph which defaults to "#3870FF". This has no effect if graph_type="plot".
232+
fill_color: Fill color of the graph which defaults to "#64f20c". This has no effect if graph_type="plot".
176233
color: Color of the plot which defaults to "black".
177234
solutions: Defaults to "all" which means use all solutions. If "best" then only the best solutions are used.
178235
save_dir: Directory to save the figure.

0 commit comments

Comments
 (0)