Skip to content

Commit 82fa0f8

Browse files
committed
Create the plot_pareto_front_curve() method to plot the pareto front curve
1 parent 67ff07f commit 82fa0f8

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed

pygad/visualize/plot.py

+101
Original file line numberDiff line numberDiff line change
@@ -384,3 +384,104 @@ def plot_genes(self,
384384
matplotlib.pyplot.show()
385385

386386
return fig
387+
388+
def plot_pareto_front_curve(self,
389+
title="Pareto Front Curve",
390+
xlabel="Objective 1",
391+
ylabel="Objective 2",
392+
linewidth=3,
393+
font_size=14,
394+
label="Pareto Front",
395+
color="#FF6347",
396+
color_fitness="#4169E1",
397+
grid=True,
398+
alpha=0.7,
399+
marker="o",
400+
save_dir=None):
401+
"""
402+
Creates, shows, and returns the pareto front curve. Can only be used with multi-objective problems.
403+
It only works with 2 objectives.
404+
It also works only after completing at least 1 generation. If no generation is completed, an exception is raised.
405+
406+
Accepts the following:
407+
title: Figure title.
408+
xlabel: Label on the X-axis.
409+
ylabel: Label on the Y-axis.
410+
linewidth: Line width of the plot. Defaults to 3.
411+
font_size: Font size for the labels and title. Defaults to 14.
412+
label: The label used for the legend.
413+
color: Color of the plot.
414+
color_fitness: Color of the fitness points.
415+
grid: Either True or False to control the visibility of the grid.
416+
alpha: The transparency of the pareto front curve.
417+
marker: The marker of the fitness points.
418+
save_dir: Directory to save the figure.
419+
420+
Returns the figure.
421+
"""
422+
423+
if self.generations_completed < 1:
424+
self.logger.error("The plot_pareto_front_curve() method can only be called after completing at least 1 generation but ({self.generations_completed}) is completed.")
425+
raise RuntimeError("The plot_pareto_front_curve() method can only be called after completing at least 1 generation but ({self.generations_completed}) is completed.")
426+
427+
if type(self.best_solutions_fitness[0]) in [list, tuple, numpy.ndarray] and len(self.best_solutions_fitness[0]) > 1:
428+
# Multi-objective optimization problem.
429+
if len(self.best_solutions_fitness[0]) == 2:
430+
# Only 2 objectives. Proceed.
431+
pass
432+
else:
433+
# More than 2 objectives.
434+
self.logger.error(f"The plot_pareto_front_curve() method only supports 2 objectives but there are {self.best_solutions_fitness[0]} objectives.")
435+
raise RuntimeError(f"The plot_pareto_front_curve() method only supports 2 objectives but there are {self.best_solutions_fitness[0]} objectives.")
436+
else:
437+
# Single-objective optimization problem.
438+
self.logger.error("The plot_pareto_front_curve() method only works with multi-objective optimization problems.")
439+
raise RuntimeError("The plot_pareto_front_curve() method only works with multi-objective optimization problems.")
440+
441+
# Plot the pareto front curve.
442+
remaining_set = list(zip(range(0, self.last_generation_fitness.shape[0]), self.last_generation_fitness))
443+
dominated_set, non_dominated_set = self.get_non_dominated_set(remaining_set)
444+
445+
# Extract the fitness values (objective values) of the non-dominated solutions for plotting.
446+
pareto_front_x = [self.last_generation_fitness[item[0]][0] for item in dominated_set]
447+
pareto_front_y = [self.last_generation_fitness[item[0]][1] for item in dominated_set]
448+
449+
# Sort the Pareto front solutions (optional but can make the plot cleaner)
450+
sorted_pareto_front = sorted(zip(pareto_front_x, pareto_front_y))
451+
452+
# Plotting
453+
fig = matplotlib.pyplot.figure()
454+
# First, plot the scatter of all points (population)
455+
all_points_x = [self.last_generation_fitness[i][0] for i in range(self.sol_per_pop)]
456+
all_points_y = [self.last_generation_fitness[i][1] for i in range(self.sol_per_pop)]
457+
matplotlib.pyplot.scatter(all_points_x,
458+
all_points_y,
459+
marker=marker,
460+
color=color_fitness,
461+
label='Fitness',
462+
alpha=1.0)
463+
464+
# Then, plot the Pareto front as a curve
465+
pareto_front_x_sorted, pareto_front_y_sorted = zip(*sorted_pareto_front)
466+
matplotlib.pyplot.plot(pareto_front_x_sorted,
467+
pareto_front_y_sorted,
468+
marker=marker,
469+
label=label,
470+
alpha=alpha,
471+
color=color,
472+
linewidth=linewidth)
473+
474+
matplotlib.pyplot.title(title, fontsize=font_size)
475+
matplotlib.pyplot.xlabel(xlabel, fontsize=font_size)
476+
matplotlib.pyplot.ylabel(ylabel, fontsize=font_size)
477+
matplotlib.pyplot.legend()
478+
479+
matplotlib.pyplot.grid(grid)
480+
481+
if not save_dir is None:
482+
matplotlib.pyplot.savefig(fname=save_dir,
483+
bbox_inches='tight')
484+
485+
matplotlib.pyplot.show()
486+
487+
return fig

0 commit comments

Comments
 (0)