Skip to content

Commit c6949e1

Browse files
committed
Refactor code to remove duplicate sections
1 parent f492bb3 commit c6949e1

File tree

1 file changed

+140
-72
lines changed

1 file changed

+140
-72
lines changed

pygad/helper/unique.py

+140-72
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def solve_duplicate_genes_randomly(self,
2525
max_val (int): The maximum value of the range to sample a number randomly.
2626
mutation_by_replacement (bool): Indicates if mutation is performed by replacement.
2727
gene_type (type): The data type of the gene (e.g., int, float).
28-
num_trials (int): The maximum number of attempts to resolve duplicates by changing the gene values.
28+
num_trials (int): The maximum number of attempts to resolve duplicates by changing the gene values. Only works for floating-point gene types.
2929
3030
Returns:
3131
tuple:
@@ -42,53 +42,48 @@ def solve_duplicate_genes_randomly(self,
4242
num_unsolved_duplicates = 0
4343
if len(not_unique_indices) > 0:
4444
for duplicate_index in not_unique_indices:
45-
for trial_index in range(num_trials):
46-
if self.gene_type_single == True:
47-
dtype = gene_type
48-
else:
49-
dtype = gene_type[duplicate_index]
50-
51-
if dtype[0] in pygad.GA.supported_int_types:
52-
temp_val = self.unique_int_gene_from_range(solution=new_solution,
53-
gene_index=duplicate_index,
54-
min_val=min_val,
55-
max_val=max_val,
56-
mutation_by_replacement=mutation_by_replacement,
57-
gene_type=gene_type)
58-
else:
59-
temp_val = numpy.random.uniform(low=min_val,
60-
high=max_val,
61-
size=1)[0]
62-
if mutation_by_replacement:
45+
if self.gene_type_single == True:
46+
dtype = gene_type
47+
else:
48+
dtype = gene_type[duplicate_index]
49+
50+
if dtype[0] in pygad.GA.supported_int_types:
51+
temp_val = self.unique_int_gene_from_range(solution=new_solution,
52+
gene_index=duplicate_index,
53+
min_val=min_val,
54+
max_val=max_val,
55+
mutation_by_replacement=mutation_by_replacement,
56+
gene_type=gene_type)
57+
else:
58+
temp_val = self.unique_float_gene_from_range(solution=new_solution,
59+
gene_index=duplicate_index,
60+
min_val=min_val,
61+
max_val=max_val,
62+
mutation_by_replacement=mutation_by_replacement,
63+
gene_type=gene_type,
64+
num_trials=num_trials)
65+
"""
66+
temp_val = numpy.random.uniform(low=min_val,
67+
high=max_val,
68+
size=1)[0]
69+
if mutation_by_replacement:
6370
pass
64-
else:
71+
else:
6572
temp_val = new_solution[duplicate_index] + temp_val
73+
"""
74+
75+
if temp_val in new_solution:
76+
num_unsolved_duplicates = num_unsolved_duplicates + 1
77+
if not self.suppress_warnings: warnings.warn(f"Failed to find a unique value for gene with index {duplicate_index} whose value is {solution[duplicate_index]}. Consider adding more values in the gene space or use a wider range for initial population or random mutation.")
78+
else:
79+
# Unique gene value found.
80+
new_solution[duplicate_index] = temp_val
81+
82+
# Update the list of duplicate indices after each iteration.
83+
_, unique_gene_indices = numpy.unique(new_solution, return_index=True)
84+
not_unique_indices = set(range(len(solution))) - set(unique_gene_indices)
85+
# self.logger.info("not_unique_indices INSIDE", not_unique_indices)
6686

67-
# Similar to the round_genes() method in the pygad module,
68-
# Create a round_gene() method to round a single gene.
69-
if not dtype[1] is None:
70-
temp_val = numpy.round(dtype[0](temp_val),
71-
dtype[1])
72-
else:
73-
temp_val = dtype[0](temp_val)
74-
75-
if temp_val in new_solution and trial_index == (num_trials - 1):
76-
num_unsolved_duplicates = num_unsolved_duplicates + 1
77-
if not self.suppress_warnings: warnings.warn(f"Failed to find a unique value for gene with index {duplicate_index} whose value is {solution[duplicate_index]}. Consider adding more values in the gene space or use a wider range for initial population or random mutation.")
78-
elif temp_val in new_solution:
79-
# Keep trying in the other remaining trials.
80-
continue
81-
else:
82-
# Unique gene value found.
83-
new_solution[duplicate_index] = temp_val
84-
break
85-
86-
# TODO Move this code outside the loops.
87-
# Update the list of duplicate indices after each iteration.
88-
_, unique_gene_indices = numpy.unique(new_solution, return_index=True)
89-
not_unique_indices = set(range(len(solution))) - set(unique_gene_indices)
90-
# self.logger.info("not_unique_indices INSIDE", not_unique_indices)
91-
9287
return new_solution, not_unique_indices, num_unsolved_duplicates
9388

9489
def solve_duplicate_genes_by_space(self,
@@ -167,14 +162,14 @@ def unique_int_gene_from_range(self,
167162
Args:
168163
solution (list): A solution containing genes, potentially with duplicate values.
169164
gene_index (int): The index of the gene for which to find a unique value.
170-
min_val (int): The minimum value of the range to sample a number randomly.
171-
max_val (int): The maximum value of the range to sample a number randomly.
165+
min_val (int): The minimum value of the range to sample an integer randomly.
166+
max_val (int): The maximum value of the range to sample an integer randomly.
172167
mutation_by_replacement (bool): Indicates if mutation is performed by replacement.
173-
gene_type (type): The data type of the gene (e.g., int, float).
168+
gene_type (type): The data type of the gene (e.g., int, int8, uint16, etc).
174169
step (int, optional): The step size for generating candidate values. Defaults to 1.
175170
176171
Returns:
177-
int: The new value of the gene. If no unique value can be found, the original gene value is returned.
172+
int: The new integer value of the gene. If no unique value can be found, the original gene value is returned.
178173
"""
179174

180175
# The gene_type is of the form [type, precision]
@@ -194,22 +189,86 @@ def unique_int_gene_from_range(self,
194189
else:
195190
all_gene_values = all_gene_values + solution[gene_index]
196191

197-
# After adding solution[gene_index] to the list, we have to change the data type again.
198-
# TODO: The gene data type is converted twine. One above and one here.
199-
all_gene_values = numpy.asarray(all_gene_values,
200-
dtype)
192+
# After adding solution[gene_index] to the list, we have to change the data type again.
193+
all_gene_values = numpy.asarray(all_gene_values,
194+
dtype)
201195

202196
values_to_select_from = list(set(list(all_gene_values)) - set(solution))
203197

204198
if len(values_to_select_from) == 0:
205199
# If there are no values, then keep the current gene value.
206-
if not self.suppress_warnings: warnings.warn("You set 'allow_duplicate_genes=False' but there is no enough values to prevent duplicates.")
207200
selected_value = solution[gene_index]
208201
else:
209202
selected_value = random.choice(values_to_select_from)
203+
204+
selected_value = dtype[0](selected_value)
210205

211206
return selected_value
212207

208+
def unique_float_gene_from_range(self,
209+
solution,
210+
gene_index,
211+
min_val,
212+
max_val,
213+
mutation_by_replacement,
214+
gene_type,
215+
num_trials=10):
216+
217+
"""
218+
Finds a unique floating-point value for a specific gene in a solution.
219+
220+
Args:
221+
solution (list): A solution containing genes, potentially with duplicate values.
222+
gene_index (int): The index of the gene for which to find a unique value.
223+
min_val (int): The minimum value of the range to sample a floating-point number randomly.
224+
max_val (int): The maximum value of the range to sample a floating-point number randomly.
225+
mutation_by_replacement (bool): Indicates if mutation is performed by replacement.
226+
gene_type (type): The data type of the gene (e.g., float, float16, float32, etc).
227+
num_trials (int): The maximum number of attempts to resolve duplicates by changing the gene values.
228+
229+
Returns:
230+
int: The new floating-point value of the gene. If no unique value can be found, the original gene value is returned.
231+
"""
232+
233+
# The gene_type is of the form [type, precision]
234+
dtype = gene_type
235+
236+
for trial_index in range(num_trials):
237+
temp_val = numpy.random.uniform(low=min_val,
238+
high=max_val,
239+
size=1)[0]
240+
241+
# If mutation is by replacement, do not add the current gene value into the list.
242+
# This is to avoid replacing the value by itself again. We are doing nothing in this case.
243+
if mutation_by_replacement:
244+
pass
245+
else:
246+
temp_val = temp_val + solution[gene_index]
247+
248+
if not dtype[1] is None:
249+
# Precision is available and we have to round the number.
250+
# Convert the data type and round the number.
251+
temp_val = numpy.round(dtype[0](temp_val),
252+
dtype[1])
253+
else:
254+
# There is no precision and rounding the number is not needed. The type is [type, None]
255+
# Just convert the data type.
256+
temp_val = dtype[0](temp_val)
257+
258+
if temp_val in solution and trial_index == (num_trials - 1):
259+
# If there are no values, then keep the current gene value.
260+
if not self.suppress_warnings: warnings.warn("You set 'allow_duplicate_genes=False' but cannot find a value to prevent duplicates.")
261+
selected_value = solution[gene_index]
262+
elif temp_val in solution:
263+
# Keep trying in the other remaining trials.
264+
continue
265+
else:
266+
# Unique gene value found.
267+
selected_value = temp_val
268+
break
269+
270+
return selected_value
271+
213272
def unique_genes_by_space(self,
214273
new_solution,
215274
gene_type,
@@ -225,7 +284,7 @@ def unique_genes_by_space(self,
225284
new_solution (list): A solution containing genes with duplicate values.
226285
gene_type (type): The data type of the gene (e.g., int, float).
227286
not_unique_indices (list): The indices of genes with duplicate values.
228-
num_trials (int): The maximum number of attempts to resolve duplicates for each gene.
287+
num_trials (int): The maximum number of attempts to resolve duplicates for each gene. Only works for floating-point numbers.
229288
230289
Returns:
231290
tuple:
@@ -236,22 +295,18 @@ def unique_genes_by_space(self,
236295

237296
num_unsolved_duplicates = 0
238297
for duplicate_index in not_unique_indices:
239-
for trial_index in range(num_trials):
240-
temp_val = self.unique_gene_by_space(solution=new_solution,
241-
gene_idx=duplicate_index,
242-
gene_type=gene_type,
243-
build_initial_pop=build_initial_pop)
244-
245-
if temp_val in new_solution and trial_index == (num_trials - 1):
246-
# self.logger.info("temp_val, duplicate_index", temp_val, duplicate_index, new_solution)
247-
num_unsolved_duplicates = num_unsolved_duplicates + 1
248-
if not self.suppress_warnings: warnings.warn(f"Failed to find a unique value for gene with index {duplicate_index} whose value is {new_solution[duplicate_index]}. Consider adding more values in the gene space or use a wider range for initial population or random mutation.")
249-
elif temp_val in new_solution:
250-
continue
251-
else:
252-
new_solution[duplicate_index] = temp_val
253-
# self.logger.info("SOLVED", duplicate_index)
254-
break
298+
temp_val = self.unique_gene_by_space(solution=new_solution,
299+
gene_idx=duplicate_index,
300+
gene_type=gene_type,
301+
build_initial_pop=build_initial_pop,
302+
num_trials=num_trials)
303+
304+
if temp_val in new_solution:
305+
# self.logger.info("temp_val, duplicate_index", temp_val, duplicate_index, new_solution)
306+
num_unsolved_duplicates = num_unsolved_duplicates + 1
307+
if not self.suppress_warnings: warnings.warn(f"Failed to find a unique value for gene with index {duplicate_index} whose value is {new_solution[duplicate_index]}. Consider adding more values in the gene space or use a wider range for initial population or random mutation.")
308+
else:
309+
new_solution[duplicate_index] = temp_val
255310

256311
# Update the list of duplicate indices after each iteration.
257312
_, unique_gene_indices = numpy.unique(new_solution, return_index=True)
@@ -264,7 +319,8 @@ def unique_gene_by_space(self,
264319
solution,
265320
gene_idx,
266321
gene_type,
267-
build_initial_pop=False):
322+
build_initial_pop=False,
323+
num_trials=10):
268324

269325
"""
270326
Returns a unique value for a specific gene based on its value space to resolve duplicates.
@@ -273,6 +329,7 @@ def unique_gene_by_space(self,
273329
solution (list): A solution containing genes with duplicate values.
274330
gene_idx (int): The index of the gene that has a duplicate value.
275331
gene_type (type): The data type of the gene (e.g., int, float).
332+
num_trials (int): The maximum number of attempts to resolve duplicates for each gene. Only works for floating-point numbers.
276333
277334
Returns:
278335
Any: A unique value for the gene, if one exists; otherwise, the original gene value. """
@@ -320,9 +377,20 @@ def unique_gene_by_space(self,
320377
low = self.random_mutation_min_val
321378
high = self.random_mutation_max_val
322379

380+
"""
323381
value_from_space = numpy.random.uniform(low=low,
324382
high=high,
325383
size=1)[0]
384+
"""
385+
386+
value_from_space = self.unique_float_gene_from_range(solution=solution,
387+
gene_index=gene_idx,
388+
min_val=low,
389+
max_val=high,
390+
mutation_by_replacement=True,
391+
gene_type=dtype,
392+
num_trials=num_trials)
393+
326394

327395
elif type(curr_gene_space) is dict:
328396
if self.gene_type_single == True:

0 commit comments

Comments
 (0)