This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Updates for final paper revision
bletham committed Oct 22, 2020
1 parent bfcc5d9 commit 403bbc1
Showing 22 changed files with 623 additions and 59 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -98,3 +98,4 @@ benchmarks/smac3*/
benchmarks/tmp/
benchmarks/results/*.json
evaluations.hdf5
+benchmarks/nasbench_only108.tfrecord
11 changes: 7 additions & 4 deletions README.md
@@ -4,11 +4,12 @@ Optimization](https://arxiv.org/abs/2001.11659)"

If you find this code useful please cite it as

-@article{Letham2019Re,
+@inproceedings{Letham2020Re,
   author = {Letham, Benjamin and Calandra, Roberto and Rai, Akshara and Bakshy, Eytan},
-  title = {Re-Examining Linear Embeddings for High-dimensional Bayesian Optimization},
-  journal = {arXiv preprint arXiv: 2001.11659},
+  title = {Re-Examining Linear Embeddings for High-Dimensional {B}ayesian Optimization},
+  booktitle = {Advances in Neural Information Processing Systems 33},
   year = {2020},
+  series = {NeurIPS},
}

## Installation
@@ -22,7 +23,7 @@ Some of the baselines require additional packages that cannot be pip-installed.
Detailed instructions can be found inside each file of the `benchmarks/` folder.

## Using ALEBO for optimizing a function
-See `quickstart.ipynb` for a simple example of how to use ALEBO to optimize a function. ALEBO is built using the [Ax platform](https://ax.dev/); see instructions there on how to install via pip. You will need version 0.1.9.
+See `quickstart.ipynb` for a simple example of how to use ALEBO to optimize a function. ALEBO is built using the [Ax platform](https://ax.dev/); see instructions there on how to install via pip. You will need version 0.1.17 or later.
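
For orientation, the pattern the quickstart follows looks roughly like the sketch below. The Branin-in-100-D evaluation function and the choice of active coordinates are illustrative, not taken from the notebook:

```python
import numpy as np
from ax.modelbridge.strategies.alebo import ALEBOStrategy
from ax.service.managed_loop import optimize

def branin_100(parameterization):
    # Branin embedded in 100 dimensions: only x19 and x64 matter here
    # (which coordinates are active is an illustrative choice).
    x1 = parameterization["x19"] * 15.0 - 5.0  # map [0, 1] -> [-5, 10]
    x2 = parameterization["x64"] * 15.0        # map [0, 1] -> [0, 15]
    y = (
        (x2 - 5.1 / (4 * np.pi ** 2) * x1 ** 2 + 5.0 / np.pi * x1 - 6.0) ** 2
        + 10.0 * (1.0 - 1.0 / (8.0 * np.pi)) * np.cos(x1)
        + 10.0
    )
    return {"objective": (y, 0.0)}

parameters = [
    {"name": f"x{i}", "type": "range", "bounds": [0.0, 1.0], "value_type": "float"}
    for i in range(100)
]

best_parameters, values, experiment, model = optimize(
    parameters=parameters,
    evaluation_function=branin_100,
    objective_name="objective",
    minimize=True,
    total_trials=30,
    generation_strategy=ALEBOStrategy(D=100, d=4, init_size=10),
)
```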

## Reproducing the experiments
This repository contains the code required to run the benchmark experiments and generate the figures in the paper. The only exceptions are the DAISY figures, since the simulator is not yet open source.
@@ -49,6 +50,8 @@ All benchmark results are stored in `benchmark/results/` (the json files produce

Executing `figs/fig_5.py` loads these aggregated results and generates the benchmark results figure in the paper.

+A separate script `benchmarks/run_nasbench.py` contains all of the code for running the NASBench experiment.
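
For a sense of the shape of that script, one replicate might look like the following sketch, which assumes the helpers from `benchmarks/nasbench_evaluation.py` shown later in this diff; the strategy settings and trial budget here are illustrative:

```python
from ax.modelbridge.strategies.alebo import ALEBOStrategy
from nasbench_evaluation import evaluate_parameters, get_nasbench_ax_client

# Illustrative embedding dimension and budget; the paper's exact settings
# live in run_nasbench.py.
axc = get_nasbench_ax_client(ALEBOStrategy(D=36, d=12, init_size=10))
for _ in range(50):
    parameters, trial_index = axc.get_next_trial()
    axc.complete_trial(
        trial_index=trial_index, raw_data=evaluate_parameters(parameters)
    )
best_parameters, values = axc.get_best_parameters()
```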

### The ALEBO model and generation code
The actual implementation of the ALEBO method is at: https://github.com/facebook/Ax/blob/master/ax/models/torch/alebo.py

137 changes: 137 additions & 0 deletions benchmarks/ablation_models.py
@@ -0,0 +1,137 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

# This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, MutableMapping, Optional

from ax.models.torch.botorch_defaults import get_and_fit_model
from ax.modelbridge.strategies.alebo import ALEBOStrategy, get_ALEBOInitializer

import torch
from torch import Tensor
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.search_space import SearchSpace
from ax.modelbridge.factory import DEFAULT_TORCH_DEVICE
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.random import RandomModelBridge
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.transforms.centered_unit_x import CenteredUnitX
from ax.modelbridge.transforms.standardize_y import StandardizeY
from botorch.models.gpytorch import GPyTorchModel
from ax.models.torch.alebo import ALEBO


class ALEBO_kernel_ablation(ALEBO):
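    """ALEBO model with the structured ALEBO kernel ablated.

    The GP is fit with the Ax/BoTorch default model (get_and_fit_model)
    in the embedding, rather than with the ALEBO kernel.
    """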

def get_and_fit_model(
self,
Xs: List[Tensor],
Ys: List[Tensor],
Yvars: List[Tensor],
state_dicts: Optional[List[MutableMapping[str, Tensor]]] = None,
) -> GPyTorchModel:
return get_and_fit_model(
Xs=Xs,
Ys=Ys,
Yvars=Yvars,
task_features=[],
fidelity_features=[],
metric_names=[],
state_dict=None,
)


def get_ALEBO_kernel_ablation(
experiment: Experiment,
search_space: SearchSpace,
data: Data,
B: torch.Tensor,
**model_kwargs: Any,
) -> TorchModelBridge:
if search_space is None:
search_space = experiment.search_space
return TorchModelBridge(
experiment=experiment,
search_space=search_space,
data=data,
model=ALEBO_kernel_ablation(B=B, **model_kwargs),
transforms=[CenteredUnitX, StandardizeY],
torch_dtype=B.dtype,
torch_device=B.device,
)


class ALEBOStrategy_kernel_ablation(GenerationStrategy):
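    """Generation strategy for the kernel ablation.

    Mirrors ALEBOStrategy: a random linear-embedding initializer followed
    by BO steps, but with ALEBO_kernel_ablation as the surrogate model.
    """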

def __init__(
self,
D: int,
d: int,
init_size: int,
name: str = "ALEBO",
dtype: torch.dtype = torch.double,
device: torch.device = DEFAULT_TORCH_DEVICE,
random_kwargs: Optional[Dict[str, Any]] = None,
gp_kwargs: Optional[Dict[str, Any]] = None,
gp_gen_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
self.D = D
self.d = d
self.init_size = init_size
self.dtype = dtype
self.device = device
self.random_kwargs = random_kwargs if random_kwargs is not None else {}
self.gp_kwargs = gp_kwargs if gp_kwargs is not None else {}
self.gp_gen_kwargs = gp_gen_kwargs

B = self.gen_projection(d=d, D=D, device=device, dtype=dtype)

self.gp_kwargs.update({"B": B})
self.random_kwargs.update({"B": B.cpu().numpy()})

steps = [
GenerationStep(
model=get_ALEBOInitializer,
num_arms=init_size,
model_kwargs=self.random_kwargs,
),
GenerationStep(
model=get_ALEBO_kernel_ablation,
num_arms=-1,
model_kwargs=self.gp_kwargs,
model_gen_kwargs=gp_gen_kwargs,
),
]
super().__init__(steps=steps, name=name)

    def clone_reset(self) -> "ALEBOStrategy_kernel_ablation":
        """Copy without state."""
return self.__class__(
D=self.D,
d=self.d,
init_size=self.init_size,
name=self.name,
dtype=self.dtype,
device=self.device,
random_kwargs=self.random_kwargs,
gp_kwargs=self.gp_kwargs,
gp_gen_kwargs=self.gp_gen_kwargs,
)

def gen_projection(
self, d: int, D: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
"""Generate the projection matrix B as a (d x D) tensor
"""
B0 = torch.randn(d, D, dtype=dtype, device=device)
B = B0 / torch.sqrt((B0 ** 2).sum(dim=0))
return B


class ALEBOStrategy_projection_ablation(ALEBOStrategy):
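    """ALEBOStrategy with the projection normalization ablated: B is
    returned as a raw Gaussian matrix, without unit-norm columns.
    """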
def gen_projection(
self, d: int, D: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
B0 = torch.randn(d, D, dtype=dtype, device=device)
return B0
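

# Hypothetical usage (the benchmark driver scripts live elsewhere in
# benchmarks/): either ablation strategy can be passed wherever an
# ALEBOStrategy is expected, e.g.
#   strategy = ALEBOStrategy_kernel_ablation(D=100, d=4, init_size=10)
#   strategy = ALEBOStrategy_projection_ablation(D=100, d=4, init_size=10)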
57 changes: 57 additions & 0 deletions benchmarks/compile_benchmark_results.py
@@ -173,7 +173,62 @@ def compile_sensitivity_benchmarks():
json.dump(object_to_json(res), fout)


def compile_ablation_benchmarks():
all_results = {}

for rep in range(100):
with open(f'results/ablation_rep_{rep}.json', 'r') as fin:
res_i = object_from_json(json.load(fin))

all_results = merge_benchmark_results(all_results, res_i)

    problems = [branin_100]
    res = {
        p.name + '_ablation': aggregate_problem_results(runs=all_results[p.name], problem=p)
        for p in problems
    }
# Save
    with open('results/ablation_aggregated_results.json', "w") as fout:
json.dump(object_to_json(res), fout)


def compile_nasbench():
all_res = {}
# TuRBO and CMAES
for method in ['turbo', 'cmaes']:
all_res[method] = []
for rep in range(100):
with open(f'results/nasbench_{method}_rep_{rep}.json', 'r') as fin:
fs, feas = json.load(fin)
# Set infeasible points to nan
fs = np.array(fs)
fs[~np.array(feas)] = np.nan
all_res[method].append(fs)

# Ax methods
for method in ['Sobol', 'ALEBO', 'HeSBO', 'REMBO']:
all_res[method] = []
for rep in range(100):
with open(f'results/nasbench_{method}_rep_{rep}.json', 'r') as fin:
exp = object_from_json(json.load(fin))
# Pull out results and set infeasible points to nan
df = exp.fetch_data().df.sort_values(by='arm_name')
df_obj = df[df['metric_name'] == 'final_test_accuracy'].copy().reset_index(drop=True)
df_con = df[df['metric_name'] == 'final_training_time'].copy().reset_index(drop=True)
infeas = df_con['mean'] > 1800
df_obj.loc[infeas, 'mean'] = np.nan
all_res[method].append(df_obj['mean'].values)

    # Accumulate the running best per replication; np.fmax treats the nan
    # (infeasible) entries as missing rather than propagating them.
    for method, arr in all_res.items():
        all_res[method] = np.fmax.accumulate(np.vstack(arr), axis=1)

    with open('results/nasbench_aggregated_results.json', "w") as fout:
json.dump(object_to_json(all_res), fout)


if __name__ == '__main__':
compile_nasbench()
gc.collect()
compile_hartmann6(D=100)
gc.collect()
compile_hartmann6(D=1000)
@@ -183,3 +238,5 @@ def compile_sensitivity_benchmarks():
compile_sensitivity_benchmarks()
gc.collect()
compile_hartmann6(D=1000, random_subspace=True)
gc.collect()
compile_ablation_benchmarks()
125 changes: 125 additions & 0 deletions benchmarks/nasbench_evaluation.py
@@ -0,0 +1,125 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

# This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

"""
Requires nasbench==1.0 from https://github.com/google-research/nasbench
Also requires the dataset nasbench_only108.tfrecord to be downloaded to this directory.
Creates an evaluation function for neural architecture search.
"""
import numpy as np

from ax.service.ax_client import AxClient

from nasbench.lib.model_spec import ModelSpec
from nasbench import api

nasbench = api.NASBench('nasbench_only108.tfrecord')


def get_spec(adj_indxs, op_indxs):
"""
Construct a NASBench spec from adjacency matrix and op indicators
"""
op_names = ['conv1x1-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3']
ops = ['input']
ops.extend([op_names[i] for i in op_indxs])
ops.append('output')
iu = np.triu_indices(7, k=1)
adj_matrix = np.zeros((7, 7), dtype=np.int32)
adj_matrix[(iu[0][adj_indxs], iu[1][adj_indxs])] = 1
spec = ModelSpec(adj_matrix, ops)
return spec


def evaluate_x(x):
"""
Evaluate NASBench on the model defined by x.
x is a 36-d array.
The first 21 are for the adjacency matrix. Largest entries will have the
corresponding element in the adjacency matrix set to 1, with as many 1s as
possible within the NASBench model space.
The last 15 are for the ops in each of the five NASBench model components.
One-hot encoded for each of the 5 components, 3 options.
"""
assert len(x) == 36
x_adj = x[:21]
x_op = x[-15:]
x_ord = x_adj.argsort()[::-1]
op_indxs = x_op.reshape(3, 5).argmax(axis=0).tolist()
last_good = None
for i in range(1, 22):
model_spec = get_spec(x_ord[:i], op_indxs)
if model_spec.matrix is not None:
# We have a connected graph
# See if it has too many edges
if model_spec.matrix.sum() > 9:
break
last_good = model_spec
    if last_good is None:
        # Could not get a valid spec from this x. Return bad metric values
        # (scalars, to match the feasible branch below).
        return 0.80, 50 * 60
fixed_metrics, computed_metrics = nasbench.get_metrics_from_spec(last_good)
test_acc = [r['final_test_accuracy'] for r in computed_metrics[108]]
train_time = [r['final_training_time'] for r in computed_metrics[108]]
return np.mean(test_acc), np.mean(train_time)
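

# Worked example of the encoding above (illustrative, not part of the
# benchmark): for x = np.random.rand(36), x[:21] ranks the 21
# upper-triangular entries of the 7x7 adjacency matrix, edges are added in
# rank order while the spec stays within NASBench's 9-edge limit, and
# x[-15:].reshape(3, 5).argmax(axis=0) picks one of 3 ops for each of the
# 5 internal vertices:
#   acc, seconds = evaluate_x(np.random.rand(36))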


def evaluate_parameters(parameters):
x = np.array([parameters[f'x{i}'] for i in range(36)])
test_acc, train_time = evaluate_x(x)
return {
'final_test_accuracy': (test_acc, 0.0),
'final_training_time': (train_time, 0.0),
}


def get_nasbench_ax_client(generation_strategy):
# Get parameters
parameters = [
{
"name": f"x{i}",
"type": "range",
"bounds": [0, 1],
"value_type": "float",
"log_scale": False,
} for i in range(36)
]
axc = AxClient(generation_strategy=generation_strategy, verbose_logging=False)
axc.create_experiment(
name="nasbench",
parameters=parameters,
objective_name="final_test_accuracy",
minimize=False,
outcome_constraints=["final_training_time <= 1800"],
)
return axc


class NASBenchRunner:
"""
A runner for non-Ax methods.
Assumes method MINIMIZES.
"""
def __init__(self, max_eval):
# For tracking iterations
self.fs = []
self.feas = []
self.n_eval = 0
self.max_eval = max_eval

def f(self, x):
if self.n_eval >= self.max_eval:
            raise ValueError("Evaluation budget exhausted")
test_acc, train_time = evaluate_x(x)
feas = bool(train_time <= 1800)
if not feas:
val = 0.80 # bad value for infeasible
else:
val = test_acc
self.n_eval += 1
self.fs.append(test_acc) # Store the true, not-negated value
self.feas.append(feas)
return -val # ASSUMES METHOD MINIMIZES
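

# Hypothetical usage with a minimizing optimizer such as TuRBO or CMA-ES
# (the actual drivers live in run_nasbench.py):
#   runner = NASBenchRunner(max_eval=50)
#   ...optimizer minimizes runner.f over [0, 1]^36...
#   runner.fs, runner.feas  # per-evaluation accuracies and feasibility flags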
1 change: 1 addition & 0 deletions benchmarks/results/nasbench_aggregated_results.json

Large diffs are not rendered by default.
