From e7f18bd8d7074fab998bae3fa6e522246649da1c Mon Sep 17 00:00:00 2001 From: a00663044 Date: Sun, 10 Mar 2024 22:17:58 +0000 Subject: [PATCH] [ANTBO]: add option to optimize from table of candidates. --- AntBO/bo/botask.py | 4 +- AntBO/bo/localbo_cat.py | 185 +++++++----------- AntBO/bo/localbo_utils.py | 396 ++++++++++++++------------------------ AntBO/bo/main.py | 33 +++- AntBO/bo/optimizer.py | 245 ++++++++++++++--------- AntBO/bo/utils.py | 91 ++++++--- AntBO/demo.py | 11 +- AntBO/environment.yaml | 4 +- AntBO/task/base.py | 17 ++ AntBO/task/tools.py | 79 +++++++- 10 files changed, 569 insertions(+), 496 deletions(-) diff --git a/AntBO/bo/botask.py b/AntBO/bo/botask.py index 8c86d5ea..72704bae 100644 --- a/AntBO/bo/botask.py +++ b/AntBO/bo/botask.py @@ -1,6 +1,6 @@ import numpy as np from bo.base import TestFunction -from task.tools import Absolut, Manual +from task.tools import Absolut, Manual, TableFilling import torch class BOTask(TestFunction): @@ -28,6 +28,8 @@ def __init__(self, self.fbox = Absolut(self.bbox) elif self.bbox['tool'] == 'manual': self.fbox = Manual(self.bbox) + elif self.bbox['tool'] == 'table_filling': + self.fbox = TableFilling(self.bbox) else: assert 0,f"{self.bbox['tool']} Not Implemented" diff --git a/AntBO/bo/localbo_cat.py b/AntBO/bo/localbo_cat.py index 331ef7f0..2aa8ecb9 100644 --- a/AntBO/bo/localbo_cat.py +++ b/AntBO/bo/localbo_cat.py @@ -1,16 +1,17 @@ -import math +import os import sys from copy import deepcopy +from typing import Optional import gpytorch +import math import numpy as np import torch -import sklearn from sklearn.preprocessing import power_transform -from torch.distributions import Normal from bo.gp import train_gp from bo.localbo_utils import random_sample_within_discrete_tr_ordinal +from bo.utils import update_table_of_candidates def hebo_transform(X): @@ -87,7 +88,6 @@ def __init__( # assert len(lb) == len(ub) # assert np.all(ub > lb) assert max_evals > 0 and isinstance(max_evals, int) - assert n_init > 0 and 
isinstance(n_init, int) assert batch_size > 0 and isinstance(batch_size, int) assert isinstance(verbose, bool) and isinstance(use_ard, bool) assert max_cholesky_size >= 0 and isinstance(batch_size, int) @@ -116,17 +116,18 @@ def __init__( self.n_training_steps = n_training_steps self.cdr_constraints = cdr_constraints self.normalise = normalise + if self.search_strategy != 'local': + self.kwargs['noise_variance'] = None self.acq = acq self.kernel_type = kernel_type - if self.kernel_type in ['rbfBERT']: + if self.kernel_type in ['rbfBERT', 'rbf-pca-BERT', 'cosine-BERT', 'cosine-pca-BERT']: self.BERT_model = self.kwargs['BERT_model'] self.BERT_tokeniser = self.kwargs['BERT_tokeniser'] self.BERT_batchsize = self.kwargs['BERT_batchsize'] - self.use_pca = self.kwargs['use_pca'] self.antigen = self.kwargs['antigen'] else: self.BERT_model, self.BERT_tokeniser, self.BERT_batchsize = None, None, None - self.use_pca, self.antigen = None, None + self.antigen = None # Hyperparameters self.mean = np.zeros((0, 1)) @@ -137,8 +138,14 @@ def __init__( # Tolerances and counters self.n_cand = kwargs['n_cand'] if 'n_cand' in kwargs.keys() else min(100 * self.dim, 5000) self.tr_multiplier = kwargs['multiplier'] if 'multiplier' in kwargs.keys() else 1.5 - self.failtol = kwargs['failtol'] if 'failtol' in kwargs.keys() else 40 - self.succtol = kwargs['succtol'] if 'succtol' in kwargs.keys() else 3 + if os.getenv("ANTBO_DEBUG", False): + failtol = 5 + succtol = 3 + else: + succtol = kwargs['succtol'] if 'succtol' in kwargs.keys() else 3 + failtol = kwargs['failtol'] if 'failtol' in kwargs.keys() else 40 + self.succtol = succtol + self.failtol = failtol self.n_evals = 0 # Trust region sizes @@ -148,8 +155,16 @@ def __init__( # Trust region sizes (in terms of Hamming distance) of the discrete variables. 
self.length_min_discrete = kwargs['length_min_discrete'] if 'length_min_discrete' in kwargs.keys() else 1 - self.length_max_discrete = kwargs['length_max_discrete'] if 'length_max_discrete' in kwargs.keys() else 30 - self.length_init_discrete = kwargs['length_init_discrete'] if 'length_init_discrete' in kwargs.keys() else 20 + if os.getenv("ANTBO_DEBUG", False): + lmd = 10 + else: + lmd = kwargs['length_max_discrete'] if 'length_max_discrete' in kwargs.keys() else min(30, self.dim) + self.length_max_discrete = lmd + if os.getenv("ANTBO_DEBUG", False): + lid = lmd + else: + lid = kwargs['length_init_discrete'] if 'length_init_discrete' in kwargs.keys() else min(20, self.dim) + self.length_init_discrete = lid # Save the full history self.X = np.zeros((0, self.dim)) @@ -194,13 +209,12 @@ def _adjust_length(self, fX_next): # Ditto for shrinking. self.length_discrete = int(self.length_discrete / self.tr_multiplier) # self.length = max(self.length / 1.5, self.length_min) - print("Shrink", self.length, self.length_discrete) - def _create_and_select_candidates(self, X, fX, length, n_training_steps, hypers, return_acq=False, num_samples=51, - warmup_steps=100, thinning=1): + def _create_and_select_candidates(self, X, fX, length, n_training_steps, hypers, + table_of_candidates: Optional[np.ndarray], return_acq=False): # assert X.min() >= 0.0 and X.max() <= 1.0 # Figure out what device we are running on - if self.search_strategy in ['global', 'glocal', 'batch_local']: + if self.search_strategy == 'glocal': fX = hebo_transform(fX) else: fX = (fX - fX.mean()) / (fX.std() + 1e-8) @@ -220,7 +234,8 @@ def _create_and_select_candidates(self, X, fX, length, n_training_steps, hypers, if self.kernel_type == 'ssk': assert self.kwargs['alphabet_size'] is not None gp = train_gp( - train_x=X_torch, train_y=y_torch, use_ard=self.use_ard, num_steps=n_training_steps, hypers=hypers, + train_x=X_torch, train_y=y_torch, use_ard=self.use_ard, + num_steps=n_training_steps, hypers=hypers, 
kern=self.kernel_type, noise_variance=self.kwargs['noise_variance'] if 'noise_variance' in self.kwargs else None, @@ -229,12 +244,7 @@ def _create_and_select_candidates(self, X, fX, length, n_training_steps, hypers, BERT_tokeniser=self.BERT_tokeniser, BERT_batchsize=self.BERT_batchsize, antigen=self.antigen, - use_pca=self.use_pca, - search_strategy=self.search_strategy, - acq=self.acq, - num_samples=num_samples, - warmup_steps=warmup_steps, - thinning=thinning, + search_strategy=self.search_strategy ) # Save state dict hypers = gp.state_dict() @@ -242,23 +252,14 @@ def _create_and_select_candidates(self, X, fX, length, n_training_steps, hypers, # mu, sigma = np.median(fX), fX.std() # sigma = 1.0 if sigma < 1e-6 else sigma # fX = (deepcopy(fX) - mu) / sigma - from .localbo_utils import local_search, glocal_search, blocal_search + from bo.localbo_utils import local_search, glocal_search, local_table_search - if self.search_strategy in ['glocal', 'batch_local']: - search = glocal_search if self.search_strategy == 'glocal' else blocal_search - kwargs = {'kernel_type': self.kernel_type, 'alphabet_size': self.kwargs['alphabet_size'], 'biased': True} - elif self.search_strategy == 'global': + if table_of_candidates is not None: + search = local_table_search + elif self.search_strategy == 'glocal': search = glocal_search - kwargs = {'kernel_type': self.kernel_type, 'alphabet_size': self.kwargs['alphabet_size'], 'biased': False} - elif self.search_strategy == 'local': - search = local_search - kwargs = {} - elif self.search_strategy == 'local-no-hamming': - search = local_search - kwargs = {} - length = self.dim else: - raise ValueError(f"Unknown search strategy: {self.search_strategy}") + search = local_search x_center = X[fX.argmin().item(), :][None, :] @@ -286,58 +287,16 @@ def thompson(n_cand=5000): y_cand[indbest, :] = np.inf return X_next, y_next - def _mace(X, augmented=False, eps=1e-4, maximise=True, kappa=2.0): - """MACE with option to augment""" - - if not 
isinstance(X, torch.Tensor): - X = torch.tensor(X, dtype=torch.float32).to(device) - if X.dim() == 1: - X = X.reshape(1, -1) - gauss = Normal(torch.zeros(1).to(device), torch.ones(1).to(device)) - - preds = gp(X.to(device)) - - # use in-fill criterion - tau = gp.likelihood(gp(torch.tensor(x_center[0].reshape(1, -1), dtype=torch.float32).to(device))).mean - mean, std = preds.mean, preds.stddev - std = std.clamp(min=torch.finfo(std.dtype).eps) - lcb_min_opt = (mean - kappa * std) - lcb_max_opt = - 1.0 * lcb_min_opt - normed = (tau - eps - mean) / std - log_phi = gauss.log_prob(normed) - Phi = gauss.cdf(normed) - PI = Phi - EI = std * (Phi * normed + log_phi.exp()) - - logEIapp = mean.log() - 0.5 * normed ** 2 - (normed ** 2 - 1).log() - logPIapp = -0.5 * normed ** 2 - torch.log(-1 * normed) - torch.log(torch.sqrt(torch.tensor(2 * np.pi))) - use_app = ~((normed > -6) & torch.isfinite(EI.log()) & torch.isfinite(PI.log())).reshape(-1) - - out = torch.zeros(X.shape[0], 3).to(device) - out[:, 0] = lcb_max_opt.reshape(-1) - out[:, 1][use_app] = logEIapp[use_app].reshape(-1) - out[:, 2][use_app] = logPIapp[use_app].reshape(-1) - out[:, 1][~use_app] = EI[~use_app].log().reshape(-1) - out[:, 2][~use_app] = PI[~use_app].log().reshape(-1) - - if augmented: - sigma_n = gp.likelihood.noise - out *= (1. 
- torch.sqrt(sigma_n.clone().detach()) / torch.sqrt(sigma_n + std ** 2)) - - if not maximise: - out *= -1.0 - - return out - def _ei(X, augmented=True): """Expected improvement (with option to enable augmented EI""" + from torch.distributions import Normal if not isinstance(X, torch.Tensor): X = torch.tensor(X, dtype=torch.float32) if X.dim() == 1: X = X.reshape(1, -1) gauss = Normal(torch.zeros(1).to(device), torch.ones(1).to(device)) # flip for minimization problems - if self.kernel_type in ['rbfBERT']: + if self.kernel_type in ['rbfBERT', 'rbf-pca-BERT', 'cosine-BERT', 'cosine-pca-BERT']: from bo.utils import BERTFeatures from einops import rearrange bert = BERTFeatures(self.BERT_model, self.BERT_tokeniser) @@ -345,9 +304,10 @@ def _ei(X, augmented=True): x_reprsn = rearrange(x_reprsn, 'b l d -> b (l d)') x_center_reprsn = bert.compute_features(torch.tensor(x_center[0].reshape(1, -1))) x_center_reprsn = rearrange(x_center_reprsn, 'b l d -> b (l d)') - if self.use_pca: - pca = load(f"{self.antigen}_pca.joblib") - scaler = load(f"{self.antigen}_scaler.joblib") + if self.kernel_type in ['rbf-pca-BERT', 'cosine-pca-BERT']: + from joblib import load + pca = load(f"/nfs/aiml/asif/CDRdata/pca/{self.antigen}_pca.joblib") + scaler = load(f"/nfs/aiml/asif/CDRdata/pca/{self.antigen}_scaler.joblib") x_reprsn = torch.from_numpy(pca.transform(scaler.transform(x_reprsn.cpu().numpy()))) x_center_reprsn = torch.from_numpy(pca.transform(scaler.transform(x_center_reprsn.cpu().numpy()))) del bert @@ -375,7 +335,7 @@ def _ei(X, augmented=True): return ei - def _ucb(X, beta=2.): + def _ucb(X, beta=5.): """Upper confidence bound""" if not isinstance(X, torch.Tensor): X = torch.tensor(X, dtype=torch.float32) @@ -385,38 +345,21 @@ def _ucb(X, beta=2.): preds = gp.likelihood(gp(X)) mean, std = preds.mean, preds.stddev - return -(mean + beta * std) - - if self.acq in ['ei', 'ucb', 'eiucb', 'mace', 'imace']: + return mean + beta * std + if self.acq in ['ei', 'ucb']: if self.batch_size == 
1: # Sequential setting if self.acq == 'ei': X_next, acq_next = search(x_center=x_center[0], f=_ei, config=self.config, max_hamming_dist=length, n_restart=3, batch_size=self.batch_size, cdr_constraints=self.cdr_constraints, seed=self.seed, dtype=self.dtype, - device=self.device, **kwargs) - elif self.acq == 'eiucb': - X_next, acq_next = search(x_center=x_center[0], f=_ei, f2=_ucb, config=self.config, - max_hamming_dist=length, n_restart=3, batch_size=self.batch_size, - cdr_constraints=self.cdr_constraints, seed=self.seed, dtype=self.dtype, - device=self.device, **kwargs) - elif self.acq == 'mace': - X_next, acq_next = search(x_center=x_center[0], f=_mace, config=self.config, - max_hamming_dist=length, n_obj=3, batch_size=self.batch_size, - cdr_constraints=self.cdr_constraints, seed=self.seed, dtype=self.dtype, - device=self.device, **kwargs) - elif self.acq == 'imace': - print("USING IMACE YES") - X_next, acq_next = search(x_center=x_center[0], f=_imace, config=self.config, - max_hamming_dist=length, n_obj=3, batch_size=self.batch_size, - cdr_constraints=self.cdr_constraints, seed=self.seed, dtype=self.dtype, - device=self.device, **kwargs) + device=self.device, table_of_candidates=table_of_candidates) else: X_next, acq_next = search(x_center=x_center[0], f=_ucb, config=self.config, max_hamming_dist=length, n_restart=3, batch_size=self.batch_size, cdr_constraints=self.cdr_constraints, seed=self.seed, dtype=self.dtype, - device=self.device, **kwargs) + device=self.device, table_of_candidates=table_of_candidates) else: # batch setting: for these, we use the fantasised points {x, y} X_next = torch.tensor([], dtype=torch.float32) @@ -425,21 +368,30 @@ def _ucb(X, beta=2.): if self.acq == 'ei': x_next, acq = search(x_center=x_center[0], f=_ei, config=self.config, max_hamming_dist=length, n_restart=3, batch_size=1, cdr_constraints=self.cdr_constraints, - seed=self.seed, dtype=self.dtype, device=self.device, **kwargs) + seed=self.seed, dtype=self.dtype, 
device=self.device, + table_of_candidates=table_of_candidates) else: x_next, acq = search(x_center=x_center[0], f=_ucb, config=self.config, max_hamming_dist=length, n_restart=3, batch_size=1, cdr_constraints=self.cdr_constraints, - seed=self.seed, dtype=self.dtype, device=self.device, **kwargs) + seed=self.seed, dtype=self.dtype, device=self.device, + table_of_candidates=table_of_candidates) + + table_of_candidates = update_table_of_candidates( + original_table=table_of_candidates, + observed_candidates=x_next, + check_candidates_in_table=True + ) - x_next = torch.tensor(x_next, dtype=torch.float32) + x_next = torch.tensor(x_next, dtype=torch.float32, device=self.device) # The fantasy point is filled by the posterior mean of the Gaussian process. - if self.kernel_type in ['rbfBERT']: + if self.kernel_type in ['rbfBERT', 'rbf-pca-BERT', 'cosine-BERT', 'cosine-pca-BERT']: from bo.utils import BERTFeatures from einops import rearrange bert = BERTFeatures(self.BERT_model, self.BERT_tokeniser) x_next_reprsn = bert.compute_features(x_next) x_next_reprsn = rearrange(x_next_reprsn, 'b l d -> b (l d)') - if self.use_pca: + if self.kernel_type in ['rbf-pca-BERT', 'cosine-pca-BERT']: + from joblib import load pca = load(f"{self.antigen}_pca.joblib") scaler = load(f"{self.antigen}_scaler.joblib") x_next_reprsn = torch.from_numpy( @@ -449,8 +401,8 @@ def _ucb(X, beta=2.): else: y_next = gp(x_next).mean.detach() with gpytorch.settings.max_cholesky_size(self.max_cholesky_size): - X_torch = torch.cat((X_torch, x_next), dim=0) - y_torch = torch.cat((y_torch, y_next), dim=0) + X_torch = torch.cat((X_torch, x_next), dim=0).to(device=self.device) + y_torch = torch.cat((y_torch, y_next), dim=0).to(device=self.device) gp = train_gp( train_x=X_torch, train_y=y_torch, use_ard=self.use_ard, num_steps=n_training_steps, kern=self.kernel_type, @@ -462,12 +414,7 @@ def _ucb(X, beta=2.): BERT_tokeniser=self.BERT_tokeniser, BERT_batchsize=self.BERT_batchsize, antigen=self.antigen, - 
use_pca=self.use_pca, - search_strategy=self.search_strategy, - acq=self.acq, - num_samples=num_samples, - warmup_steps=warmup_steps, - thinning=thinning, + search_strategy=self.search_strategy ) X_next = torch.cat((X_next, x_next), dim=0) acq_next = np.hstack((acq_next, acq)) @@ -476,7 +423,7 @@ def _ucb(X, beta=2.): X_next, acq_next = thompson() else: raise ValueError('Unknown acquisition function choice %s' % self.acq) - # print(f'{self.acq} Next X, ', X_next) + del X_torch, y_torch X_next = np.array(X_next) if return_acq: diff --git a/AntBO/bo/localbo_utils.py b/AntBO/bo/localbo_utils.py index b59e1771..94f6167a 100644 --- a/AntBO/bo/localbo_utils.py +++ b/AntBO/bo/localbo_utils.py @@ -1,11 +1,18 @@ import logging -from itertools import groupby -from collections import Callable import random +import re +from collections import Callable from copy import deepcopy +from itertools import groupby + +import scipy +from pymoo.operators.crossover.sbx import SBX +from pymoo.operators.mutation.pm import PolynomialMutation +from pymoo.operators.repair.rounding import RoundingRepair from bo.kernels import * -import re + +# from Bio.SeqUtils.ProtParam import ProteinAnalysis COUNT_AA = 5 AA = 'ACDEFGHIKLMNPQRSTVWY' @@ -14,10 +21,7 @@ N_glycosylation_pattern = 'N[^P][ST][^P]' -# from Bio.SeqUtils.ProtParam import ProteinAnalysis - - -def check_cdr_constraints_all(x, x_center_local=None, hamming=None, config=None): +def check_cdr_constraints_all(x): # Constraints on CDR3 sequence x_to_seq = ''.join(idx_to_AA[int(aa)] for aa in x) # prot = ProteinAnalysis(x_to_seq) @@ -40,16 +44,14 @@ def check_cdr_constraints_all(x, x_center_local=None, hamming=None, config=None) else: c3 = True - if x_center_local is not None: - # 1 if met (True) - c4 = compute_hamming_dist_ordinal(x_center_local, x, config) <= hamming - # Return 0 if True - return int(not (c1)), int(not (c2)), int(not (c3)), int(not (c4)) - + # stability = prot.instability_index() + # if stability>40: + # return False + # If 
constraint is satisfied return 0 for pymoo return int(not (c1)), int(not (c2)), int(not (c3)) -def check_cdr_constraints(x) -> bool: +def check_cdr_constraints(x): constr = check_cdr_constraints_all(x) return not np.any(constr) @@ -108,6 +110,38 @@ def latin_hypercube(n_pts, dim): return X +def space_fill_table_sample(n_pts: int, table_of_candidates: np.ndarray) -> np.ndarray: + """ + Sample points from a table of candidates + + Args: + n_pts: number of points to sample + table_of_candidates: 2d array from which to sample points + + Returns: + samples: 2d array with shape (n_pts, n_dim) taken from table_of_candidates + """ + selected_inds = set() + candidates = np.zeros((n_pts, table_of_candidates.shape[-1])) + ind = np.random.randint(0, len(table_of_candidates)) + selected_inds.add(ind) + candidates[0] = table_of_candidates[ind] # sample first point at random + i = 1 + for i in range(1, min(n_pts, 100)): # sample the first 100 points with a space-filling strategy + # compute distance among table_of_candidates and already_selected candidates + distances = scipy.spatial.distance.cdist(table_of_candidates, candidates[:i], metric="hamming") + distances[distances == 0] = -np.inf # penalize already selected points + mean_dist = distances.mean(-1) + max_mean_dist = mean_dist.max() + # sample among best + ind = np.random.choice(np.arange(len(table_of_candidates))[mean_dist == max_mean_dist]) + selected_inds.add(ind) + candidates[i] = table_of_candidates[ind] + remaining_inds = [ind for ind in range(len(table_of_candidates)) if ind not in selected_inds] + candidates[i:] = table_of_candidates[np.random.choice(remaining_inds, max(0, n_pts - i), replace=False)] + return deepcopy(candidates) + + def compute_hamming_dist(x1, x2, categorical_dims, normalize=False): """ Compute the hamming distance of two one-hot encoded strings. 
@@ -167,55 +201,7 @@ def sample_neighbour_ordinal(x, n_categories): return x_pert -def sample_neighbour_ordinal_constrained(x, n_categories): - """Same as above, but the variables are represented ordinally.""" - - x_pert = deepcopy(x) - n_categories = deepcopy(n_categories) - # Chooose a variable to modify - choice = random.randint(0, len(n_categories) - 1) - # Obtain the current value. - curr_val = x[choice] - options = [i for i in range(n_categories[choice]) if i != curr_val] - random.shuffle(options) - i = 0 - x_pert[choice] = options[i] - - while np.logical_not(check_cdr_constraints(x_pert)) and i < (len(n_categories) - 1): - i += 1 - x_pert[choice] = options[i] - return x_pert - - -def neighbourhood_init(x_center_local, config, pop_size): - pop = np.array([sample_neighbour_ordinal_constrained(x_center_local, config) for _ in range((pop_size))]) - pop[0] = x_center_local - return pop - - -def random_sample_within_discrete_tr(x_center, max_hamming_dist, categorical_dims, - mode='ordinal'): - """Randomly sample a point within the discrete trust region""" - if max_hamming_dist < 1: # Normalised hamming distance is used - bit_change = int(max_hamming_dist * len(categorical_dims)) - else: # Hamming distance is not normalized - max_hamming_dist = min(max_hamming_dist, len(categorical_dims)) - bit_change = int(max_hamming_dist) - - x_pert = deepcopy(x_center) - # Randomly sample n bits to change. 
- modified_bits = random.sample(range(len(categorical_dims)), bit_change) - for bit in modified_bits: - n_values = len(categorical_dims[bit]) - # Change this value - selected_value = random.choice(range(n_values)) - # Change to one-hot encoding - substitute_values = np.array([1 if i == selected_value else 0 for i in range(n_values)]) - x_pert[categorical_dims[bit]] = substitute_values - return x_pert - - -def random_sample_within_discrete_tr_ordinal(x_center, max_hamming_dist, n_categories): +def random_sample_within_discrete_tr_ordinal(x_center: np.ndarray, max_hamming_dist, n_categories) -> np.ndarray: """Same as above, but here we assume a ordinal representation of the categorical variables.""" # random.seed(random.randint(0, 1e6)) if max_hamming_dist < 1: @@ -231,7 +217,6 @@ def random_sample_within_discrete_tr_ordinal(x_center, max_hamming_dist, n_categ from pymoo.algorithms.moo.nsga2 import NSGA2 -from pymoo.factory import get_mutation, get_crossover, get_termination from pymoo.optimize import minimize from pymoo.core.problem import Problem import numpy as np @@ -244,19 +229,12 @@ class CDRH3Prob(Problem): A solution is considered as feasible of all constraint violations are less than zero.""" def __init__(self, f_acq: Callable, n_var=11, n_obj=1, n_constr=3, xl=0, xu=19, cdr_constraints=True, - device=torch.device('cpu'), dtype=torch.float32, f2_acq=None, f3_acq=None): - if f2_acq is not None: - n_obj += 1 - if f3_acq is not None: - n_obj += 1 + device=torch.device('cpu'), dtype=torch.float32): super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr if cdr_constraints else 0, xl=xl, xu=xu) - - self.f2_acq = f2_acq - self.f3_acq = f3_acq self.f_acq = f_acq self.cdr_constraints = cdr_constraints self.device = device @@ -264,8 +242,9 @@ def __init__(self, f_acq: Callable, n_var=11, n_obj=1, n_constr=3, xl=0, xu=19, def _evaluate(self, x, out, *args, **kwargs): with torch.no_grad(): + X = torch.from_numpy(x).to(device=self.device, dtype=self.dtype) # Switch 
from max to min problem - acq_x = -1.0 * self.f_acq(x).detach().cpu().numpy() + acq_x = -1.0 * self.f_acq(X).detach().cpu().numpy() out["F"] = acq_x if self.cdr_constraints: @@ -273,38 +252,7 @@ def _evaluate(self, x, out, *args, **kwargs): out["G"] = np.column_stack([c1, c2, c3]) -class CDRH3ProbHamming(CDRH3Prob): - """CDRH3 Problem For pymoo. - Maximise f_acq but taking the negative in _evaluate to perform minimisation overall. - A solution is considered as feasible of all constraint violations are less than zero.""" - - def __init__(self, max_hamming_distance=0, x_center_local=None, config=None, **kwargs): - super().__init__(**kwargs) - self.hamming = max_hamming_distance - self.x_center_local = x_center_local - self.config = config - - def _evaluate(self, x, out, *args, **kwargs): - # Always 1 Objective - with torch.no_grad(): - # Switch from max to min problem - acq_x = -1.0 * self.f_acq(x).detach().cpu().numpy() - if self.f2_acq is not None: - acq2_x = -1.0 * self.f2_acq(x).detach().cpu().numpy() - acq_x = np.column_stack([acq_x, acq2_x]) - if self.f3_acq is not None: - acq3_x = -1.0 * self.f3_acq(x).detach().cpu().numpy() - acq_x = np.column_stack([acq_x, acq3_x]) - out["F"] = acq_x - - if self.cdr_constraints: - c1, c2, c3, c4 = zip(*list( - map(lambda seq: check_cdr_constraints_all(seq, x_center_local=self.x_center_local, hamming=self.hamming, - config=self.config), x))) - out["G"] = np.column_stack([c1, c2, c3, c4]) - - -def get_pop(seq_len, pop_size, x_center_local, seed=0): +def get_biased_pop(seq_len, pop_size, x_center_local, seed=0): eng = SobolEngine(seq_len, scramble=True, seed=seed) sobol_samp = eng.draw(pop_size) # sobol_samp = sobol_samp * (space.opt_ub - space.opt_lb) + space.opt_lb @@ -319,107 +267,84 @@ def test_f(x): return torch.from_numpy(np.random.randint(0, 20, x.shape[0])) -def glocal_search(x_center, - f: Callable, - config, - max_hamming_dist, - cdr_constraints: bool = False, - n_restart: int = 1, - n_obj=1, - batch_size: int = 1, - 
seed=0, - seq_len=11, - dtype=torch.float32, - device=torch.device('cpu'), - pop_size: int = 200, - eliminate_duplicates=True, - biased=True, - f2: Callable = None, - f3: Callable = None, - **kwargs): +def local_table_search(x_center: np.ndarray, + f: Callable, + config: np.ndarray, + table_of_candidates: np.ndarray, + batch_size: int, + max_hamming_dist, + dtype=torch.float32, + device=torch.device('cpu'), + max_batch_size=5000, **kwargs): """ - Global & Glocal search algorithm - :param n_restart: number of restarts - :param config: - :param x0: the initial point to start the search - :param x_center: the center of the trust region. In this case, this should be the optimum encountered so far. - :param f: the function handle to evaluate x on (the acquisition function, in this case) - :param max_hamming_dist: maximum Hamming distance from x_center - :param step: number of maximum local search steps the algorithm is allowed to take. - :return: + Search strategy: + 1. Compute filtr of valid points around center + 2. 
Iteratively jump to a new point that is still in the valid area + + Args: + x_center: 1d array corresponding to center of search space in transformed space + config: the config for the categorical variables (number of categoories per dim) + + Returns: + x_candidates: 2d ndarray acquisition points + f_x_candidates: 1d array function value of these points """ - if kwargs['kernel_type'] == 'ssk': - pop_size = 20 + assert x_center.ndim == 1, x_center.shape + assert table_of_candidates.ndim == 2, table_of_candidates.shape + assert batch_size == 1, "Methods is designed to output only one candidate for now" + n_candidates = 0 + hamming_dists = scipy.spatial.distance.cdist(table_of_candidates, x_center.reshape(1, -1), metric="hamming") + hamming_dists *= x_center.shape[-1] # denormalize + hamming_dists = hamming_dists.flatten() - x_center_local = deepcopy(x_center) - if biased: - # True, Do neighbourhood sampling - init_pop = neighbourhood_init(x_center_local, config, pop_size) - else: - # False, Do Global sampling - init_pop = get_pop(seq_len, pop_size, x_center_local, seed=seed) + # entry `i` contains the number of points in the table of candidates that are at distance `i` of the center + n_cand_per_dist = np.array([(hamming_dists == i).sum() for i in range(table_of_candidates.shape[-1])]) + max_hamming_dist = max(max_hamming_dist, np.argmax(np.cumsum(n_cand_per_dist) > 0)) - if f2 is not None or f3 is not None or n_obj > 1: - eliminate_duplicates = False + table_of_candidates = table_of_candidates[hamming_dists <= max_hamming_dist] - algorithm = NSGA2(pop_size=pop_size, - n_offsprings=pop_size, - sampling=init_pop, - # crossover=get_crossover("int_sbx", eta=15, prob=0.0, prob_per_variable=0.0), - crossover=get_crossover("int_sbx", eta=15, prob=0.9, prob_per_variable=1.0 / seq_len), - mutation=get_mutation("int_pm", eta=20), - eliminate_duplicates=eliminate_duplicates) - problem = CDRH3Prob(n_obj=n_obj, n_var=seq_len, f_acq=f, f2_acq=f2, f3_acq=f3, 
cdr_constraints=cdr_constraints, - device=device, dtype=dtype) - termination = get_termination("n_gen", seq_len * kwargs['alphabet_size']) - - res = minimize(problem, - algorithm, - termination, - seed=seed, - verbose=False) + fmax = -np.inf - # Make sure to filter any that are not satisfied - if res.G.ndim == 1: - G = res.G[None, :] - X = res.X[None, :] - F = res.F[None, :] - else: - G = res.G - X = res.X - F = res.F - # Remove constraint violated - X = X[~G.any(1)] - # Turn back to maximise problem - fX = -1.0 * F - # Remove constraint violated - fX = fX[~G.any(1)] - - if not eliminate_duplicates: - # Remove duplocates - X, idX = np.unique(X, axis=0, return_index=True) - # Remove duplocates - fX = fX[idX] - - if X.ndim == 1: - X = X[None, :] - fX = fX[None, :] - - if f2 is not None or f3 is not None or n_obj > 1: - # Sample from pareto front - idx = np.random.randint(0, X.shape[0], batch_size) - X_next = X[idx] - acq_next = np.array(fX).flatten()[idx] - else: - # Selects top batchsize from list - top_idices = np.argpartition(np.array(fX).flatten(), -batch_size)[-batch_size:] - X_next = np.array([x for i, x in enumerate(X) if i in top_idices]) - acq_next = np.array(fX).flatten()[top_idices] + current_center = deepcopy(x_center) + + for _ in range(10): + if len(table_of_candidates) == 0: + break - return X_next, acq_next + # gather points from closest to farthest + hamming_dists = scipy.spatial.distance.cdist(table_of_candidates, current_center.reshape(1, -1), + metric="hamming") + hamming_dists *= current_center.shape[-1] + hamming_dists = hamming_dists.flatten() + + cand_filtr = np.zeros(len(table_of_candidates)).astype(bool) + dist_to_current_center = 1 + while cand_filtr.sum() < max_batch_size and dist_to_current_center <= table_of_candidates.shape[-1]: + new_filtr = hamming_dists == dist_to_current_center + n_ones = max_batch_size - cand_filtr.sum() + n_zeros = new_filtr.sum() - n_ones + if cand_filtr.sum() + new_filtr.sum() > max_batch_size: + 
new_filtr[new_filtr] = np.random.permutation( + np.concatenate([np.zeros(n_zeros), np.ones(n_ones)]).astype(bool)) + cand_filtr = cand_filtr + new_filtr + dist_to_current_center += 1 + + if cand_filtr.sum() == 0: # no more candidates to evaluate + break + + # evaluate acquisition function + acq_x = f(table_of_candidates[cand_filtr]).detach().cpu().numpy().flatten() + if acq_x.max() > fmax: # new best + fmax = acq_x.max() + current_center = table_of_candidates[cand_filtr][np.argmax(acq_x)] + + # remove already evaluated points from table of candidates + table_of_candidates = table_of_candidates[~cand_filtr] + + return current_center.reshape(1, -1), np.array([fmax]) -def blocal_search(x_center, +def glocal_search(x_center, f: Callable, config, max_hamming_dist, @@ -428,16 +353,12 @@ def blocal_search(x_center, batch_size: int = 1, seed=0, seq_len=11, - n_obj=1, dtype=torch.float32, device=torch.device('cpu'), pop_size: int = 200, - eliminate_duplicates=True, - f2: Callable = None, - f3: Callable = None, - **kwargs): + table_of_candidates=None): """ - Batch Local search algorithm + Local search algorithm :param n_restart: number of restarts :param config: :param x0: the initial point to start the search @@ -445,31 +366,26 @@ def blocal_search(x_center, :param f: the function handle to evaluate x on (the acquisition function, in this case) :param max_hamming_dist: maximum Hamming distance from x_center :param step: number of maximum local search steps the algorithm is allowed to take. 
+ :param table_of_candidates: search within a table of candidates (not supported for now) :return: """ - if kwargs['kernel_type'] == 'ssk': - pop_size = 20 - - if f2 is not None or f3 is not None or n_obj > 1: - eliminate_duplicates = False - + assert table_of_candidates is None x_center_local = deepcopy(x_center) - init_pop = neighbourhood_init(x_center_local, config, pop_size) + init_pop = get_biased_pop(seq_len, pop_size, x_center_local, seed=seed) + crossover = SBX(prob=0.9, prob_var=1.0 / seq_len, eta=15, repair=RoundingRepair(), vtype=float) + mutation = PolynomialMutation(prob=1.0, eta=20, repair=RoundingRepair()) algorithm = NSGA2(pop_size=pop_size, - n_offsprings=pop_size, + n_offsprings=200, sampling=init_pop, - crossover=get_crossover("int_sbx", eta=15, prob=0.9, prob_per_variable=1.0 / seq_len), - # crossover=get_crossover("int_sbx", eta = 15, prob = 0.0, prob_per_variable=0.0), - mutation=get_mutation("int_pm", eta=20), - eliminate_duplicates=eliminate_duplicates) - problem = CDRH3ProbHamming(n_obj=n_obj, x_center_local=x_center_local, n_constr=4, config=config, - max_hamming_distance=max_hamming_dist, n_var=seq_len, f_acq=f, f2_acq=f2, f3_acq=f3, - cdr_constraints=cdr_constraints, device=device, dtype=dtype) - termination = get_termination("n_gen", seq_len * kwargs['alphabet_size']) - - res = minimize(problem, - algorithm, - termination, + crossover=crossover, + mutation=mutation, + eliminate_duplicates=False) + problem = CDRH3Prob(n_var=seq_len, f_acq=f, cdr_constraints=cdr_constraints, device=device, dtype=dtype) + termination = ("n_gen", 11 * 20) + + res = minimize(problem=problem, + algorithm=algorithm, + termination=termination, seed=seed, verbose=False) @@ -485,31 +401,15 @@ def blocal_search(x_center, # Remove constraint violated X = X[~G.any(1)] + # Remove duplocates + X = np.unique(X, axis=0) + # Turn back to maximise problem fX = -1.0 * F - # Remove constraint violated - fX = fX[~G.any(1)] - - if X.ndim == 1: - X = X[None, :] - fX = fX[None, 
:] - - if not eliminate_duplicates: - # Remove duplocates - X, idX = np.unique(X, axis=0, return_index=True) - # Remove duplocates - fX = fX[idX] - - if f2 is not None or f3 is not None or n_obj > 1: - idx = np.random.randint(0, X.shape[0], batch_size) - X_next = X[idx] - acq_next = np.array(fX).flatten()[idx] - else: - # Selects top batchsize from list - top_idices = np.argpartition(np.array(fX).flatten(), -batch_size)[-batch_size:] - X_next = np.array([x for i, x in enumerate(X) if i in top_idices]) - acq_next = np.array(fX).flatten()[top_idices] - return X_next, acq_next + + # Selects top batchsize from list + top_idices = np.argpartition(np.array(fX).flatten(), -batch_size)[-batch_size:] + return np.array([x for i, x in enumerate(X) if i in top_idices]), np.array(fX).flatten()[top_idices] def local_search(x_center, @@ -523,7 +423,7 @@ def local_search(x_center, dtype=torch.float32, device=torch.device('cpu'), step: int = 200, - **kwargs): + table_of_candidates=None): """ Local search algorithm :param n_restart: number of restarts @@ -533,8 +433,10 @@ def local_search(x_center, :param f: the function handle to evaluate x on (the acquisition function, in this case) :param max_hamming_dist: maximum Hamming distance from x_center :param step: number of maximum local search steps the algorithm is allowed to take. 
+ :param table_of_candidates: search within a table of candidates (not supported for now) :return: """ + assert table_of_candidates is None def _ls(hamming): """One restart of local search""" diff --git a/AntBO/bo/main.py b/AntBO/bo/main.py index c3417b48..c8413b33 100644 --- a/AntBO/bo/main.py +++ b/AntBO/bo/main.py @@ -4,9 +4,11 @@ from pathlib import Path from typing import Optional, Set, Any, Dict + ROOT_PROJECT = str(Path(os.path.realpath(__file__)).parent.parent) sys.path.insert(0, ROOT_PROJECT) +from task import BaseTool from utilities.misc_utils import log from bo.custom_init import get_initial_dataset_path, InitialBODataset, get_top_cut_ratio_per_cat, get_n_per_cat from bo.botask import BOTask as CDRBO @@ -53,6 +55,11 @@ def __init__(self, config: Dict[str, Any], cdr_constraints: bool, seed: int): """ self.config = config + if self.config["tabular_search_csv"] is not None: + print( + f"Tabular BO setting: will select antibodies among available ones from: {config['tabular_search_csv']}" + ) + self.table_of_aas_inds = self.get_table_of_aas_inds(tabular_search_csv=self.config["tabular_search_csv"]) self.seed = seed self.cdr_constraints = cdr_constraints # Sanity checks @@ -83,7 +90,7 @@ def __init__(self, config: Dict[str, Any], cdr_constraints: bool, seed: int): if not os.path.exists(self.path): os.makedirs(self.path) - print(self.path) + print(f"Results of this run will be saved in {self.path}") self.res = pd.DataFrame(np.nan, index=np.arange(int(self.config['max_iters'] * self.config['batch_size'])), columns=['Index', 'LastValue', 'BestValue', 'Time', 'LastProtein', 'BestProtein']) @@ -91,15 +98,20 @@ def __init__(self, config: Dict[str, Any], cdr_constraints: bool, seed: int): self.nm_AAs = 20 self.n_categories = np.array([self.nm_AAs] * self.config['seq_len']) self.start_itern = 0 - self.f_obj = CDRBO(self.config['device'], self.n_categories, self.config['seq_len'], self.config['bbox'], False) + self.f_obj = CDRBO( + device=self.config['device'], 
n_categories=self.n_categories, + seq_len=self.config['seq_len'], bbox=self.config['bbox'], normalise=False + ) @staticmethod def get_path(save_path: str, antigen: str, kernel_type: str, seed: int, cdr_constraints: int, seq_len: int, search_strategy: str, - custom_init_dataset_path: Optional[str] = None): + custom_init_dataset_path: Optional[str] = None, tabular_search_csv: Optional[str] = None): path: str = f"{save_path}/BO_{kernel_type}/antigen_{antigen}" \ f"_kernel_{kernel_type}_search-strat_{search_strategy}_seed_{seed}" \ f"_cdr_constraint_{bool(cdr_constraints)}_seqlen_{seq_len}" + if tabular_search_csv is not None: + path += f"_tabsearch-{os.path.basename(tabular_search_csv)[:-4]}" if custom_init_dataset_path: custom_init_id = os.path.basename(os.path.dirname(custom_init_dataset_path)) custom_init_id_seed = os.path.basename(os.path.dirname(os.path.dirname(custom_init_dataset_path))) @@ -116,7 +128,8 @@ def path(self) -> str: seed=self.seed, cdr_constraints=self.cdr_constraints, seq_len=self.config['seq_len'], - custom_init_dataset_path=self.custom_initial_dataset_path + custom_init_dataset_path=self.custom_initial_dataset_path, + tabular_search_csv=self.config["tabular_search_csv"] ) @property @@ -166,7 +179,7 @@ def load(self): if os.path.exists(res_path): self.res = pd.read_csv(res_path, usecols=['Index', 'LastValue', 'BestValue', 'Time', 'LastProtein', 'BestProtein']) - self.start_itern = len(self.res) - self.res['Index'].isna().sum() // self.config['batch_size'] + self.start_itern = (len(self.res) - self.res['Index'].isna().sum()) // self.config['batch_size'] print(f"-- Resume -- Already observed {optim.casmopolitan.n_evals}") return optim @@ -229,6 +242,7 @@ def run(self): kernel_type=self.config['kernel_type'], noise_variance=float(self.config['noise_variance']), alphabet_size=self.nm_AAs, + table_of_candidates=self.table_of_aas_inds, **kwargs ) @@ -243,7 +257,7 @@ def run(self): for itern in range(self.start_itern, self.config['max_iters']): start = 
time.time() - x_next = optim.suggest(self.config['batch_size']) + x_next = optim.suggest(n_suggestions=self.config['batch_size']) if self.custom_initial_dataset and len(optim.casmopolitan.fX) < self.config['n_init']: # observe the custom initial points instead of the suggested ones n_random = min(x_next.shape[0], self.config['n_init'] - len(optim.casmopolitan.fX)) @@ -264,6 +278,13 @@ def log(self, message: str, end: Optional[str] = None): header=f"BOExp - {self.config['bbox']['antigen']} - {self.config['kernel_type']} - seed {self.seed}", end=end) + def get_table_of_aas_inds(self, tabular_search_csv: str) -> np.ndarray: + """ Return array of antibodies where each row corresponds to an antibody given by the indices of its AAs """ + data = pd.read_csv(tabular_search_csv, index_col=None).values + assert data.shape[-1] == 1 + arr = np.array([list(c for c in x) for x in data.flatten()]) + return BaseTool().convert_array_aas_to_idx(arr) + from bo.utils import get_config diff --git a/AntBO/bo/optimizer.py b/AntBO/bo/optimizer.py index 3dcb0277..e5f91609 100644 --- a/AntBO/bo/optimizer.py +++ b/AntBO/bo/optimizer.py @@ -1,24 +1,29 @@ +import os from copy import deepcopy -from itertools import groupby +from typing import Optional + import numpy as np +import scipy.spatial.distance import scipy.stats as ss -from bo.localbo_cat import CASMOPOLITANCat -from bo.localbo_utils import from_unit_cube, latin_hypercube, to_unit_cube, ordinal2onehot, onehot2ordinal,\ - random_sample_within_discrete_tr_ordinal -from bo.localbo_utils import check_cdr_constraints import torch -import logging from gpytorch.utils.errors import NotPSDError, NanError + +from bo.localbo_cat import CASMOPOLITANCat +from bo.localbo_utils import from_unit_cube, latin_hypercube, onehot2ordinal, \ + random_sample_within_discrete_tr_ordinal, check_cdr_constraints, space_fill_table_sample +from bo.utils import update_table_of_candidates from utilities.constraint_utils import check_constraint_satisfaction_batch COUNT_AA =
5 + def order_stats(X): _, idx, cnt = np.unique(X, return_inverse=True, return_counts=True) obs = np.cumsum(cnt) # Need to do it this way due to ties o_stats = obs[idx] return o_stats + def copula_standardize(X): X = np.nan_to_num(np.asarray(X)) # Replace inf by something large assert X.ndim == 1 and np.all(np.isfinite(X)) @@ -31,30 +36,33 @@ def copula_standardize(X): class Optimizer: def __init__(self, - config: np.ndarray, + config, min_cuda, - batch_size: int =1, normalise: bool = False, cdr_constraints: bool = False, n_init: int = None, wrap_discrete: bool = True, guided_restart: bool = True, + table_of_candidates: Optional[np.ndarray] = None, **kwargs): """Build wrapper class to use an optimizer in benchmark. - Parameters - ---------- - config: list. e.g. [2, 3, 4, 5] -- denotes there are 4 categorical variables, with numbers of categories - being 2, 3, 4, and 5 respectively. - guided_restart: whether to fit an auxiliary GP over the best points encountered in all previous restarts, and - sample the points with maximum variance for the next restart. - global_bo: whether to use the global version of the discrete GP without local modelling + Args: + config: list. e.g. [2, 3, 4, 5] -- denotes there are 4 categorical variables, with numbers of categories + being 2, 3, 4, and 5 respectively. + guided_restart: whether to fit an auxiliary GP over the best points encountered in all previous restarts, and + sample the points with maximum variance for the next restart. + global_bo: whether to use the global version of the discrete GP without local modelling + table_of_candidates: if not None, the suggestions should be taken from this list of candidates given as a + 2d array of aas indices. + """ # Maps the input order. 
self.config = config.astype(int) self.true_dim = len(config) self.kwargs = kwargs + self.table_of_candidates = table_of_candidates if self.kwargs['kernel_type'] == 'ssk': assert 'alphabet_size' != None and self.kwargs['alphabet_size'] != None # Number of one hot dimensions @@ -75,11 +83,12 @@ def __init__(self, dim=self.true_dim, n_init=n_init if n_init is not None else 2 * self.true_dim + 1, max_evals=self.max_evals, - cdr_constraints = self.cdr_constraints, - normalise = normalise, + cdr_constraints=self.cdr_constraints, + normalise=normalise, batch_size=1, # We need to update this later verbose=False, config=self.config, + table_of_candidates=self.table_of_candidates, **kwargs ) @@ -92,7 +101,7 @@ def __init__(self, def restart(self): from bo.gp import train_gp - if self.guided_restart and len(self.casmopolitan._fX) and self.kwargs['search_strategy'] in ['local', 'batch_local']: + if self.guided_restart and len(self.casmopolitan._fX) and self.kwargs['search_strategy'] == 'local': best_idx = self.casmopolitan._fX.argmin() # Obtain the best X and fX within each restart (bo._fX and bo._X get erased at each restart, # but bo.X and bo.fX always store the full history @@ -100,8 +109,10 @@ def restart(self): self.best_fX_each_restart = deepcopy(self.casmopolitan._fX[best_idx]) self.best_X_each_restart = deepcopy(self.casmopolitan._X[best_idx]) else: - self.best_fX_each_restart = np.vstack((self.best_fX_each_restart, deepcopy(self.casmopolitan._fX[best_idx]))) - self.best_X_each_restart = np.vstack((self.best_X_each_restart, deepcopy(self.casmopolitan._X[best_idx]))) + self.best_fX_each_restart = np.vstack( + (self.best_fX_each_restart, deepcopy(self.casmopolitan._fX[best_idx]))) + self.best_X_each_restart = np.vstack( + (self.best_X_each_restart, deepcopy(self.casmopolitan._X[best_idx]))) X_tr_torch = torch.tensor(self.best_X_each_restart, dtype=torch.float32).reshape(-1, self.true_dim) fX_tr_torch = torch.tensor(self.best_fX_each_restart, dtype=torch.float32).view(-1) 
@@ -109,19 +120,26 @@ def restart(self): # Train the auxiliary self.auxiliary_gp = train_gp(X_tr_torch, fX_tr_torch, False, 300, ) # Generate random points in a Thompson-style sampling - X_init, itern = [], 0 - while(itern 0) + table_of_candidates = table_of_candidates[filtr] + # sample + n_sample = self.casmopolitan.n_init - 1 + self.X_init[1:] = space_fill_table_sample(n_pts=n_sample, table_of_candidates=table_of_candidates) + else: + for i in range(1, self.casmopolitan.n_init): + # Randomly sample within the initial trust region length around the centre + candidate = random_sample_within_discrete_tr_ordinal( + x_center=centre, + max_hamming_dist=self.casmopolitan.length_init_discrete, + n_categories=self.config + ) + self.X_init[i] = deepcopy(candidate) + self.X_init = torch.tensor(self.X_init).to(torch.float32) self.casmopolitan._restart() self.casmopolitan._X = np.zeros((0, self.casmopolitan.dim)) self.casmopolitan._fX = np.zeros((0, 1)) @@ -151,75 +190,99 @@ def restart(self): self.casmopolitan._restart() self.casmopolitan._X = np.zeros((0, self.casmopolitan.dim)) self.casmopolitan._fX = np.zeros((0, 1)) - # Sample Initial Points with frequency criterion - self.X_init, itern = [], 0 - while(itern 0: X_next[:n_init] = deepcopy(self.X_init[:n_init, :]) - if X_next.ndim==1: + if X_next.ndim == 1: X_next = X_next[None, :] self.X_init = self.X_init[n_init:, :] # Remove these pending points + table_of_candidates = update_table_of_candidates( + original_table=table_of_candidates, + observed_candidates=X_next[:n_init], + check_candidates_in_table=True + ) # Get remaining points from TuRBO n_adapt = n_suggestions - n_init + if os.getenv("ANTBO_DEBUG", False): + n_training_steps = 10 + else: + n_training_steps = 500 if n_adapt > 0: if len(self.casmopolitan._X) > 0: # Use random points if we can't fit a GP X = deepcopy(self.casmopolitan._X) - #fX = deepcopy(self.casmopolitan._fX).ravel() - if self.kwargs['search_strategy'] in ['local']: + # fX = 
deepcopy(self.casmopolitan._fX).ravel() + if self.kwargs['search_strategy'] == 'local': fX = copula_standardize(deepcopy(self.casmopolitan._fX).ravel()) # Use Copula else: - fX = deepcopy(self.casmopolitan._fX).ravel() # No need to use Copula as no GP predictions in here. - - # try: - if True: - X_next[-n_adapt:, :] = self.casmopolitan._create_and_select_candidates(X, fX, - length=self.casmopolitan.length_discrete, - n_training_steps=500, - hypers={})[-n_adapt:, :] - # except (ValueError, NanError, NotPSDError): - # except: - else: + fX = deepcopy(self.casmopolitan._fX).ravel() # No need to use Copula as no GP predictions in here. + + try: + X_next[-n_adapt:, :] = self.casmopolitan._create_and_select_candidates( + X=X, fX=fX, + length=self.casmopolitan.length_discrete, + n_training_steps=n_training_steps, + hypers={}, + table_of_candidates=table_of_candidates + )[-n_adapt:, :] + except (ValueError, NanError, NotPSDError): print(f"Acquisition Failure with Kernel {self.casmopolitan.kernel_type}") - # if self.casmopolitan.kernel_type == 'ssk': - # print(f"Trying with kernel {self.casmopolitan.kernel_type}") - # self.casmopolitan.kernel_type = 'transformed_overlap' - # X_next[-n_adapt:, :] = self.casmopolitan._create_and_select_candidates(X, fX, - # length=self.casmopolitan.length_discrete, - # n_training_steps=500, - # hypers={})[-n_adapt:, :] - # self.casmopolitan.kernel_type = 'ssk' - if self.casmopolitan.kernel_type in ['rbfBERT', 'cosineBERT']: + if self.casmopolitan.kernel_type == 'ssk': + print(f"Trying with kernel {self.casmopolitan.kernel_type}") + self.casmopolitan.kernel_type = 'transformed_overlap' + X_next[-n_adapt:, :] = self.casmopolitan._create_and_select_candidates( + X=X, fX=fX, + length=self.casmopolitan.length_discrete, + n_training_steps=n_training_steps, + hypers={}, + table_of_candidates=table_of_candidates + )[-n_adapt:, :] + self.casmopolitan.kernel_type = 'ssk' + elif self.casmopolitan.kernel_type in ['rbfBERT', 'rbf-pca-BERT', 'cosine-BERT', 
'cosine-pca-BERT']: + assert self.table_of_candidates is None, "not supported when given a table of candidates" print("Random Acquisition") X_random_next, j = [], 0 while (j < n_adapt): X_next_j = latin_hypercube(1, self.dim) X_next_j = from_unit_cube(X_next_j, self.lb, self.ub) if self.wrap_discrete: - X_next_j = self.warp_discrete(X_next_j, ) + X_next_j = self.warp_discrete(X_next_j) X_next_j = onehot2ordinal(X_next_j, self.cat_dims) if self.cdr_constraints: if not check_cdr_constraints(X_next_j[0]): @@ -229,6 +292,7 @@ def suggest(self, n_suggestions=1): X_random_next = np.stack(X_random_next, 0) X_next[-n_adapt:, :] = X_random_next else: + assert self.table_of_candidates is None, "not supported when given a table of candidates" print('Resorting to Random Search') # Create the initial population. Last column stores the fitness X_next = np.random.randint(low=0, high=20, size=(n_suggestions, 11)) @@ -255,14 +319,10 @@ def observe(self, X, y): Parameters ---------- - X : list of dict-like - Places where the objective function has already been evaluated. - Each suggestion is a dictionary where each key corresponds to a - parameter being optimized. 
- y : array-like, shape (n,) - Corresponding values where objective has been evaluated + X : array-like, shape (n, d) + y : tensor of shape (n,) """ - assert len(X) == len(y) + assert len(X) == len(y), (len(X), len(y)) # XX = torch.cat([ordinal2onehot(x, self.n_categories) for x in X]).reshape(len(X), -1) XX = X yy = np.array(y.detach().cpu())[:, None] @@ -273,18 +333,29 @@ def observe(self, X, y): if len(self.casmopolitan._fX) >= self.casmopolitan.n_init > 0: self.casmopolitan._adjust_length(yy) - self.casmopolitan.n_evals += len(y) + self.casmopolitan.n_evals += len(X) # self.batch_size self.casmopolitan._X = np.vstack((self.casmopolitan._X, deepcopy(XX))) self.casmopolitan._fX = np.vstack((self.casmopolitan._fX, deepcopy(yy.reshape(-1, 1)))) self.casmopolitan.X = np.vstack((self.casmopolitan.X, deepcopy(XX))) self.casmopolitan.fX = np.vstack((self.casmopolitan.fX, deepcopy(yy.reshape(-1, 1)))) + if self.table_of_candidates is not None: + print(f"Get {len(self.table_of_candidates)} candidate points" + f" in search table before observing new points") + self.table_of_candidates = update_table_of_candidates( + original_table=self.table_of_candidates, + observed_candidates=XX, + check_candidates_in_table=True + ) + print(f"Get {len(self.table_of_candidates)} candidate points" + f" in search table after observing new points") if self.kwargs['search_strategy'] in ['local', 'batch_local']: # Check for a restart - if self.casmopolitan.length <= self.casmopolitan.length_min or self.casmopolitan.length_discrete <= self.casmopolitan.length_min_discrete: + if (self.casmopolitan.length <= self.casmopolitan.length_min or + self.casmopolitan.length_discrete <= self.casmopolitan.length_min_discrete): self.restart() - def warp_discrete(self, X, ): + def warp_discrete(self, X): X_ = np.copy(X) # Process the integer dimensions diff --git a/AntBO/bo/utils.py b/AntBO/bo/utils.py index 2c52612e..a2bc690d 100644 --- a/AntBO/bo/utils.py +++ b/AntBO/bo/utils.py @@ -1,3 +1,10 @@ +import 
pickle +from typing import Any, Optional + +import numpy as np +import os + + def spearman(pred, target) -> float: """Compute the spearman correlation coefficient between prediction and target""" from scipy import stats @@ -32,10 +39,6 @@ def get_dim_info(n_categories): offset += cat return dim_info -from typing import Any, List, Optional -import os -import pickle - def save_w_pickle(obj: Any, path: str, filename: Optional[str] = None) -> None: """ Save object obj in file exp_path/filename.pkl """ @@ -62,6 +65,7 @@ def load_w_pickle(path: str, filename: Optional[str] = None) -> Any: print(path, filename) raise + import yaml @@ -70,17 +74,21 @@ def get_config(config): return yaml.safe_load(f) -import os from einops import rearrange + def batch_iterator(data1, step=8): size = len(data1) for i in range(0, size, step): - yield data1[i:min(i+step, size)] + yield data1[i:min(i + step, size)] + import torch + + class BERTFeatures: """Compute BERT Features""" + def __init__(self, model, tokeniser): AAs = 'ACDEFGHIKLMNPQRSTVWY' self.AA_to_idx = {aa: i for i, aa in enumerate(AAs)} @@ -98,21 +106,22 @@ def compute_features(self, x1): reprsn1 = self.model(input_ids=input_ids1, attention_mask=attention_mask1)[0] return reprsn1.to(inp_device) -if __name__=='__main__': - bert_config = { 'datapath': '/nfs/aiml/asif/CDRdata', - 'path': '/nfs/aiml/asif/ProtBERT', - 'modelname': 'prot_bert_bfd', - 'use_cuda': True, - 'batch_size': 256 + +if __name__ == '__main__': + bert_config = {'datapath': '/nfs/aiml/asif/CDRdata', + 'path': '/nfs/aiml/asif/ProtBERT', + 'modelname': 'OutputFinetuneBERTprot_bert_bfd', + 'use_cuda': True, + 'batch_size': 256 } - device_ids = [2,3] + device_ids = [2, 3] import os import glob + import numpy as np + os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(str(id) for id in device_ids) os.environ["TOKENIZERS_PARALLELISM"] = "false" - from transformers import pipeline, \ - AutoTokenizer, \ - Trainer, \ + from transformers import AutoTokenizer, \ AutoModel device = 
torch.device("cuda" if torch.cuda.is_available() and bert_config['use_cuda'] else "cpu") @@ -120,13 +129,15 @@ def compute_features(self, x1): model = AutoModel.from_pretrained(f"{bert_config['path']}/{bert_config['modelname']}").to(device) bert_features = BERTFeatures(model, tokeniser) - #antigens = ['1ADQ_A', '1FBI_X', '1HOD_C', '1NSN_S', '1OB1_C', '1WEJ_F', '2YPV_A', '3RAJ_A', '3VRL_C'] - antigens = [antigen.strip().split()[1] for antigen in open(f"/nfs/aiml/asif/CDRdata/antigens.txt", 'r') if antigen != '\n'] + antigens = ['1ADQ_A', '1FBI_X', '1H0D_C', '1NSN_S', '1OB1_C', '1WEJ_F', '2YPV_A', '3RAJ_A', '3VRL_C', '2DD8_S', + '1S78_B', '2JEL_P'] + # antigens = [antigen.strip().split()[1] for antigen in open(f"/nfs/aiml/asif/CDRdata/antigens.txt", 'r') if antigen != '\n'] from sklearn.preprocessing import StandardScaler from sklearn.decomposition import PCA - from joblib import dump, load + from joblib import dump import pandas as pd + for antigen in antigens: print(f"PCA for antigen {antigen}") try: @@ -143,22 +154,48 @@ def compute_features(self, x1): except pd.errors.ParserError as err: print(f"{filenames[i]} causes an error {err}") continue + except: + continue + + if len(filenames) != 0: reprsns = [] for seq_batch in batch_iterator(sequences, bert_config['batch_size']): seq_batch = torch.tensor([[bert_features.AA_to_idx[aa] for aa in seq] for seq in seq_batch]).to(device) seq_reprsn = bert_features.compute_features(seq_batch) seq_reprsn = rearrange(seq_reprsn, 'b l d -> b (l d)') - reprsns.append(seq_reprsn) + reprsns.append(seq_reprsn.cpu().numpy()) if len(reprsns) == 1000: break - - reprsns = torch.cat(reprsns, 0).cpu().numpy() + reprsns = np.concatenate(reprsns, 0) scaler = StandardScaler() - scaler.fit(scaled_reprsns) + scaler.fit(reprsns) scaled_reprsns = scaler.transform(reprsns) pca = PCA(n_components=100) pca.fit(scaled_reprsns) - dump(pca, f"{bert_config['datapath']}/{antigen}_pca.joblib") - dump(scaler, 
f"{bert_config['datapath']}/{antigen}_scaler.joblib") - except: - continue + results_path = f"{bert_config['datapath']}/finetune_pca" + if not os.path.exists(results_path): + os.makedirs(results_path) + dump(pca, f"{results_path}/{antigen}_pca.joblib") + dump(scaler, f"{results_path}/{antigen}_scaler.joblib") + + +def update_table_of_candidates(original_table: np.ndarray, observed_candidates: np.ndarray, + check_candidates_in_table: bool) -> np.ndarray: + """ Update the table of candidates, removing the newly observed candidates from the table + + Args: + original_table: table of candidates before observation + observed_candidates: new observed points + check_candidates_in_table: whether the observed candidates should be in the original_table or not + + Returns: + Updated original_table + """ + if observed_candidates.ndim == 1: + observed_candidates = observed_candidates.reshape(1, -1) + for candidate in observed_candidates: + filtr = np.all(original_table == candidate.reshape(1, -1), axis=1) + if not np.any(filtr) and check_candidates_in_table: + raise RuntimeError(f"New point {candidate} is not in the table of candidates.") + original_table = original_table[~filtr] + return original_table diff --git a/AntBO/demo.py b/AntBO/demo.py index 212ec2dd..8589222f 100644 --- a/AntBO/demo.py +++ b/AntBO/demo.py @@ -25,6 +25,11 @@ help='Number of antibodies suggested at each step (default: 1)') parser.add_argument('--pre_evals_csv', type=str, help='Path to csv file containing the binding energy of already evaluated antibody sequences.') + parser.add_argument('--tabular_search_csv', type=str, + help='Path to csv file containing the set of eligible antibodies with their pre-computed ' + 'binding energy (to test optimisation in a controlled scenario).') + parser.add_argument('--path_to_eval_csv', type=str, default="./table_of_evals.csv", + help='If the black-box evaluations are provided by filling a table, path to this table.') parser.add_argument('--cuda_id', type=int, 
default=0, help='ID of the cuda device to use.') parser.add_argument('--seed', type=int, nargs="+", default=[42], help='Seed for reproducibility.') parser.add_argument('--absolut_path', type=str, help='Path to Absolut! (if Absolut is needed.)') @@ -41,6 +46,7 @@ # -------- Create config config = { 'pre_evals': args.pre_evals_csv, + 'tabular_search_csv': args.tabular_search_csv, 'acq': 'ei', 'ard': True, 'n_init': n_init, @@ -63,8 +69,9 @@ # 'antigen': args.antigen # }, 'bbox': { - 'tool': 'manual', - 'antigen': args.antigen + 'tool': 'table_filling', + 'antigen': args.antigen, + 'path_to_eval_csv': args.path_to_eval_csv }, } diff --git a/AntBO/environment.yaml b/AntBO/environment.yaml index 419aa21d..93d4b3da 100644 --- a/AntBO/environment.yaml +++ b/AntBO/environment.yaml @@ -119,7 +119,7 @@ dependencies: - readline=8.1=h46c0cb4_0 - requests=2.26.0=pyhd3eb1b0_0 - scikit-learn=0.24.2=py39ha9443f7_0 - - scipy=1.7.1=py39h292c36d_2 + - scipy=1.12.0 - setuptools=58.0.4=py39hf3d152e_0 - simanneal - sip=4.19.13=py39h2531618_0 @@ -153,7 +153,7 @@ dependencies: - geneticalgorithm==1.0.2 - huggingface-hub==0.2.1 - packaging==21.3 - - pymoo==0.5.0 + - pymoo==0.6.1 - regex==2021.11.10 - sacremoses==0.0.46 - terminalplot==0.3.0 diff --git a/AntBO/task/base.py b/AntBO/task/base.py index bef6801b..ec45a1fd 100644 --- a/AntBO/task/base.py +++ b/AntBO/task/base.py @@ -1,4 +1,6 @@ from abc import ABC +from typing import Any, Dict, Optional + import numpy as np @@ -35,6 +37,21 @@ def __init__(self): self.AA_to_idx = {aa: i for i, aa in enumerate(AA)} self.idx_to_AA = {value: key for key, value in self.AA_to_idx.items()} + @staticmethod + def convert_array(arr: np.ndarray, conversion_dic: Dict[Any, Any], end_type: Optional) -> np.ndarray: + new_arr = np.copy(arr).astype(object) + for k, v in conversion_dic.items(): + new_arr[new_arr == k] = v + if end_type is not None: + new_arr = new_arr.astype(end_type) + return new_arr + + def convert_array_idx_to_aas(self, idx: np.ndarray) -> 
np.ndarray: + return self.convert_array(arr=idx, conversion_dic=self.idx_to_AA, end_type=None) + + def convert_array_aas_to_idx(self, aas: np.ndarray) -> np.ndarray: + return self.convert_array(arr=aas, conversion_dic=self.AA_to_idx, end_type=int) + def Energy(self, x): ''' x: categorical vector diff --git a/AntBO/task/tools.py b/AntBO/task/tools.py index 8ed3d23b..6578c38d 100644 --- a/AntBO/task/tools.py +++ b/AntBO/task/tools.py @@ -2,6 +2,7 @@ import __main__ import os import subprocess +import time import numpy as np import pandas as pd @@ -10,6 +11,20 @@ from task.base import BaseTool +def custom_input(message: str, default: str) -> str: + """Wrapper around `input` function. + + Replace empty string by `default` + """ + if os.getenv("ANTBO_DEBUG", False): + val = default + else: + val = input(message) + if val == "": + val = default + return val + + ############################ # Black Box Tools ############################ @@ -156,7 +171,7 @@ def visualise(self, antigen, video_length=2): spectrum count, green_yellow_red color blue, ligands set cartoon_fancy_helices, 1 - + mset 1 x{num_frames} util.mroll 1, {num_frames}, 1 set ray_trace_frames, 1 @@ -216,17 +231,71 @@ def Energy(self, x): energies = [] for i in range(len(sequences)): - energy1 = float(input(f"[{self.antigen}] Write energy for {sequences[i]}:")) - energy2 = float(input(f"[{self.antigen}] Confirm energy for {sequences[i]}:")) + default = np.random.randn() + energy1 = float(custom_input(message=f"[{self.antigen}] Write energy for {sequences[i]}:", default=default)) + energy2 = float( + custom_input(message=f"[{self.antigen}] Confirm energy for {sequences[i]}:", default=default)) while energy1 != energy2: print("Mismatch, pleaser enter energies again") - energy1 = float(input(f"[{self.antigen}] Write energy for {sequences[i]}:")) - energy2 = float(input(f"[{self.antigen}] Confirm energy for {sequences[i]}:")) + energy1 = float(custom_input(message=f"[{self.antigen}] Write energy for 
{sequences[i]}:", + default=default)) + energy2 = float(custom_input(message=f"[{self.antigen}] Confirm energy for {sequences[i]}:", + default=default)) energies.append(energy1) return np.array(energies), sequences +class TableFilling(BaseTool): + """ Get results by reading a csv """ + + def __init__(self, config): + BaseTool.__init__(self) + ''' + config: dictionary of parameters for BO that includes: + - antigen: PDB ID of antigen + - path_to_eval_csv: str + ''' + for key in ['antigen']: + assert key in config, f"\"{key}\" is not defined in config" + self.config = config + self.antigen = self.config["antigen"] + self.path_to_eval_csv = self.config["path_to_eval_csv"] + + def Energy(self, x): + ''' + x: categorical vector (num_Seq x Length) + ''' + x = x.astype('int32') + if len(x.shape) == 1: + x = x.reshape(1, -1) + + sequences = [] + for i, seq in enumerate(x): + seq2char = ''.join(self.idx_to_AA[aa] for aa in seq) + sequences.append(seq2char) + + # save sequences + dirname = os.path.dirname(self.path_to_eval_csv) + os.makedirs(dirname, exist_ok=True) + to_eval = pd.DataFrame([[self.antigen, seq, None, 0] for seq in sequences], + columns=["Antigen", "Antibody", "Eval", "Validate (0/1)"]) + to_eval.to_csv(self.path_to_eval_csv, index=False) + print(f"Saved candidates to evaluate in {os.path.abspath(self.path_to_eval_csv)}") + + # Poll the csv until every row is validated, then read the energies from the "Eval" column + + while True: + table_of_results = pd.read_csv(self.path_to_eval_csv, index_col=None) + if np.all(table_of_results["Validate (0/1)"].values): + energies = table_of_results["Eval"].values + break + else: + time.sleep(5) + + return np.array(energies), sequences + + ########################################### # Docking and Visualisation Tool ###########################################