Skip to content

Commit

Permalink
Add kurobako benchmark script
Browse files Browse the repository at this point in the history
  • Loading branch information
c-bata committed Oct 20, 2020
1 parent 0a86ba4 commit 517a861
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 3 deletions.
11 changes: 10 additions & 1 deletion benchmark/optuna_solver.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import argparse
import optuna

from cmaes import SepCMA
from kurobako import solver
from kurobako.solver.optuna import OptunaSolverFactory

parser = argparse.ArgumentParser()
parser.add_argument("sampler", choices=["cmaes", "ipop-cmaes", "pycma"])
parser.add_argument("sampler", choices=["cmaes", "sep-cmaes", "ipop-cmaes", "pycma"])
parser.add_argument(
"--loglevel", choices=["debug", "info", "warning", "error"], default="warning"
)
Expand All @@ -26,6 +27,12 @@ def create_cmaes_study(seed):
return optuna.create_study(sampler=sampler)


def create_sep_cmaes_study(seed):
optuna.samplers._cmaes.CMA = SepCMA # monkey patch
sampler = optuna.samplers.CmaEsSampler(seed=seed, warn_independent_sampling=True)
return optuna.create_study(sampler=sampler)


def create_ipop_cmaes_study(seed):
sampler = optuna.samplers.CmaEsSampler(
seed=seed,
Expand All @@ -47,6 +54,8 @@ def create_pycma_study(seed):
if __name__ == "__main__":
if args.sampler == "cmaes":
factory = OptunaSolverFactory(create_cmaes_study)
elif args.sampler == "sep-cmaes":
factory = OptunaSolverFactory(create_sep_cmaes_study)
elif args.sampler == "ipop-cmaes":
factory = OptunaSolverFactory(create_ipop_cmaes_study)
elif args.sampler == "pycma":
Expand Down
5 changes: 3 additions & 2 deletions benchmark/runner.sh
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,19 @@ esac

RANDOM_SOLVER=$($KUROBAKO solver random)
CMAES_SOLVER=$($KUROBAKO solver --name 'cmaes' command python $DIR/optuna_solver.py cmaes)
SEP_CMAES_SOLVER=$($KUROBAKO solver --name 'sep-cmaes' command python $DIR/optuna_solver.py sep-cmaes)
IPOP_CMAES_SOLVER=$($KUROBAKO solver --name 'ipop-cmaes' command python $DIR/optuna_solver.py ipop-cmaes)
PYCMA_SOLVER=$($KUROBAKO solver --name 'pycma' command python $DIR/optuna_solver.py pycma)

if [ $BUDGET -le 500 ]; then
$KUROBAKO studies \
--solvers $RANDOM_SOLVER $IPOP_CMAES_SOLVER $PYCMA_SOLVER $CMAES_SOLVER \
--solvers $RANDOM_SOLVER $IPOP_CMAES_SOLVER $PYCMA_SOLVER $CMAES_SOLVER $SEP_CMAES_SOLVER \
--problems $PROBLEM \
--seed $SEED --repeats $REPEATS --budget $BUDGET \
| $KUROBAKO run --parallelism 4 > $2
else
$KUROBAKO studies \
--solvers $RANDOM_SOLVER $IPOP_CMAES_SOLVER $CMAES_SOLVER \
--solvers $RANDOM_SOLVER $IPOP_CMAES_SOLVER $CMAES_SOLVER $SEP_CMAES_SOLVER \
--problems $PROBLEM \
--seed $SEED --repeats $REPEATS --budget $BUDGET \
| $KUROBAKO run --parallelism 6 > $2
Expand Down
16 changes: 16 additions & 0 deletions cmaes/_sepcma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import sys
import numpy as np

from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
Expand Down Expand Up @@ -175,6 +177,20 @@ def generation(self) -> int:
when multi-variate gaussian distribution is updated."""
return self._g

def __getstate__(self) -> Dict[str, Any]:
attrs = {}
for name in self.__dict__:
# Remove _rng in pickle serialized object.
if name == "_rng":
continue
attrs[name] = getattr(self, name)
return attrs

def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
# Set _rng for unpickled object.
setattr(self, "_rng", np.random.RandomState())

def set_bounds(self, bounds: Optional[np.ndarray]) -> None:
"""Update boundary constraints"""
assert (
Expand Down
3 changes: 3 additions & 0 deletions examples/optuna_sampler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import optuna
from cmaes import SepCMA

optuna.samplers._cmaes.CMA = SepCMA # monkey patch


def objective(trial: optuna.Trial):
Expand Down

0 comments on commit 517a861

Please sign in to comment.