Skip to content

Commit

Permalink
Update genetic_algorithm.py to remove ordinal variables
Browse files Browse the repository at this point in the history
  • Loading branch information
AntGro authored Aug 30, 2023
1 parent 3431140 commit 2bccb4e
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions MCBO/mcbo/optimizers/non_bo/genetic_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,8 @@ def __init__(self,
dtype: torch.dtype = torch.float64,
):

assert search_space.num_nominal + search_space.num_ordinal == search_space.num_dims, \
'Genetic Algorithm currently supports only nominal and ordinal variables'
assert search_space.num_nominal == search_space.num_dims, \
'Genetic Algorithm currently supports only nominal variables'

super(CategoricalGeneticAlgorithm, self).__init__(
search_space=search_space, dtype=dtype, input_constraints=input_constraints
Expand Down Expand Up @@ -336,11 +336,11 @@ def __init__(self,
if self.tr_manager is not None:
self.x_queue.iloc[0:1] = self.search_space.inverse_transform(self.tr_center.unsqueeze(0))

self.map_to_canonical = self.search_space.nominal_dims + self.search_space.ordinal_dims
self.map_to_canonical = self.search_space.nominal_dims
self.map_to_original = [self.map_to_canonical.index(i) for i in range(len(self.map_to_canonical))]

self.lb = self.search_space.nominal_lb + self.search_space.ordinal_lb
self.ub = self.search_space.nominal_ub + self.search_space.ordinal_ub
self.lb = self.search_space.nominal_lb
self.ub = self.search_space.nominal_ub

def get_tr_point_sampler(self) -> Callable[[int], pd.DataFrame]:
"""
Expand Down Expand Up @@ -422,7 +422,8 @@ def method_suggest(self, n_suggestions: int = 1) -> pd.DataFrame:
if n_remaining and len(self.x_queue):
n = min(n_remaining, len(self.x_queue))
x_next.iloc[idx: idx + n] = self.x_queue.iloc[idx: idx + n]
self.x_queue = self.x_queue.drop(self.x_queue.index[[i for i in range(idx, idx + n)]]).reset_index(drop=True)
self.x_queue = self.x_queue.drop(self.x_queue.index[[i for i in range(idx, idx + n)]]).reset_index(
drop=True)

idx += n
n_remaining -= n
Expand All @@ -432,7 +433,8 @@ def method_suggest(self, n_suggestions: int = 1) -> pd.DataFrame:

n = min(n_remaining, len(self.x_queue))
x_next.iloc[idx: idx + n] = self.x_queue.iloc[idx: idx + n]
self.x_queue = self.x_queue.drop(self.x_queue.index[[i for i in range(idx, idx + n)]]).reset_index(drop=True)
self.x_queue = self.x_queue.drop(self.x_queue.index[[i for i in range(idx, idx + n)]]).reset_index(
drop=True)

idx += n
n_remaining -= n
Expand Down Expand Up @@ -615,7 +617,7 @@ def point_sampler(n_points: int) -> pd.DataFrame:
return

def _crossover(self, x1: torch.Tensor, x2: torch.Tensor) -> (torch.Tensor, torch.Tensor):
assert self.search_space.num_ordinal + self.search_space.num_nominal == self.search_space.num_dims, \
assert self.search_space.num_nominal == self.search_space.num_dims, \
'Current crossover can\'t handle permutations'

x1_ = x1.clone()
Expand Down Expand Up @@ -655,7 +657,7 @@ def _crossover(self, x1: torch.Tensor, x2: torch.Tensor) -> (torch.Tensor, torch
return x1_, x2_

def _mutate(self, x: torch.Tensor) -> torch.Tensor:
assert self.search_space.num_ordinal + self.search_space.num_nominal == self.search_space.num_dims, \
assert self.search_space.num_nominal == self.search_space.num_dims, \
'Current mutate can\'t handle permutations'
assert x.ndim == 2, (x.shape, self.map_to_canonical)
x_ = x.clone()[:, self.map_to_canonical]
Expand Down

0 comments on commit 2bccb4e

Please sign in to comment.