
Commit

Merge main again
christinaflo committed Apr 11, 2023
2 parents d40aa15 + c21e5e7 commit 736f27f
Showing 11 changed files with 264 additions and 102 deletions.
1 change: 1 addition & 0 deletions environment.yml
@@ -27,4 +27,5 @@ dependencies:
- typing-extensions==3.10.0.2
- pytorch_lightning==1.5.10
- wandb==0.12.21
- modelcif==0.7
- git+https://github.com/NVIDIA/dllogger.git
50 changes: 27 additions & 23 deletions notebooks/OpenFold.ipynb
@@ -121,10 +121,11 @@
" %env PATH=/opt/conda/bin:{PATH}\n",
"\n",
" # Install the required versions of all dependencies.\n",
" %shell conda install -y -q conda==4.13.0\n",
" %shell conda install -y -q -c conda-forge -c bioconda \\\n",
" kalign2=2.04 \\\n",
" hhsuite=3.3.0 \\\n",
" python=3.7 \\\n",
" python=3.8 \\\n",
" 2>&1 1>/dev/null\n",
" %shell pip install -q \\\n",
" ml-collections==0.1.0 \\\n",
@@ -180,15 +181,12 @@
" %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n",
" %shell /usr/bin/python3 -m pip install -q ./openfold\n",
"\n",
" if(relax_prediction):\n",
" %shell conda install -y -q -c conda-forge \\\n",
" openmm=7.5.1 \\\n",
" pdbfixer=1.7\n",
" \n",
" # Apply OpenMM patch.\n",
" %shell pushd /opt/conda/lib/python3.7/site-packages/ && \\\n",
" patch -p0 < /content/openfold/lib/openmm.patch && \\\n",
" popd\n",
" %shell conda install -y -q -c conda-forge openmm=7.5.1\n",
" # Apply OpenMM patch.\n",
" %shell pushd /opt/conda/lib/python3.8/site-packages/ && \\\n",
" patch -p0 < /content/openfold/lib/openmm.patch && \\\n",
" popd\n",
" %shell conda install -y -q -c conda-forge pdbfixer=1.7\n",
"\n",
" if(weight_set == 'AlphaFold'):\n",
" %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n",
@@ -222,8 +220,8 @@
"import unittest.mock\n",
"import sys\n",
"\n",
"sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n",
"sys.path.append('/opt/conda/lib/python3.7/site-packages')\n",
"sys.path.insert(0, '/usr/local/lib/python3.8/site-packages/')\n",
"sys.path.append('/opt/conda/lib/python3.8/site-packages')\n",
"\n",
"# Allows us to skip installing these packages\n",
"unnecessary_modules = [\n",
@@ -247,6 +245,14 @@
"import numpy as np\n",
"import py3Dmol\n",
"import torch\n",
"import shutil\n",
"\n",
"# Prevent shell magic being broken by openmm, prevent this cryptic error:\n",
"# \"NotImplementedError: A UTF-8 locale is required. Got ANSI_X3.4-1968\"\n",
"import locale\n",
"def getpreferredencoding(do_setlocale = True):\n",
" return \"UTF-8\"\n",
"locale.getpreferredencoding = getpreferredencoding\n",
"\n",
"# A filthy hack to avoid slow Linear layer initialization\n",
"import openfold.model.primitives\n",
@@ -267,9 +273,8 @@
"from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n",
"from openfold.np import protein\n",
"if(relax_prediction):\n",
" from openfold.np.relax import relax\n",
" from openfold.np.relax import utils\n",
"from openfold.np.relax import relax\n",
"from openfold.np.relax.utils import overwrite_b_factors\n",
"from openfold.utils.import_weights import import_jax_weights_\n",
"from openfold.utils.tensor_utils import tensor_tree_map\n",
"\n",
@@ -571,14 +576,13 @@
" relaxed_pdb, _, _ = amber_relaxer.process(\n",
" prot=unrelaxed_proteins[best_model_name]\n",
" )\n",
"\n",
" # Write out the prediction\n",
" pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
" with open(pred_output_path, 'w') as f:\n",
" f.write(relaxed_pdb)\n",
"\n",
" best_pdb = relaxed_pdb\n",
"\n",
" # Write out the prediction\n",
" pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
" with open(pred_output_path, 'w') as f:\n",
" f.write(best_pdb)\n",
"\n",
" pbar.update(n=1) # Finished AMBER relax.\n",
"\n",
"# Construct multiclass b-factors to indicate confidence bands\n",
@@ -590,7 +594,7 @@
" banded_b_factors.append(idx)\n",
" break\n",
"banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n",
"to_visualize_pdb = utils.overwrite_b_factors(best_pdb, banded_b_factors)\n",
"to_visualize_pdb = overwrite_b_factors(best_pdb, banded_b_factors)\n",
"\n",
"# --- Visualise the prediction & confidence ---\n",
"show_sidechains = True\n",
@@ -688,7 +692,7 @@
"\n",
"\n",
"# --- Download the predictions ---\n",
"!zip -q -r {output_dir}.zip {output_dir}\n",
"shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n",
"files.download(f'{output_dir}.zip')"
],
"execution_count": null,
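Note on the notebook changes above: the final hunk swaps the !zip shell call for shutil.make_archive, which avoids shelling out for the archive step. A minimal sketch of that archive-and-download pattern follows; the output_dir value, and keeping base_name equal to it, are illustrative assumptions rather than the notebook's exact cell.

import shutil
from google.colab import files  # only available inside Colab

output_dir = "prediction"  # assumed directory name for illustration

# Build "<output_dir>.zip" from the directory contents, then hand the archive
# to the browser download helper. Using output_dir as base_name guarantees
# that files.download() looks for the file make_archive just wrote.
shutil.make_archive(base_name=output_dir, format="zip", root_dir=output_dir)
files.download(f"{output_dir}.zip")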
9 changes: 7 additions & 2 deletions openfold/model/triangular_multiplicative_update.py
@@ -392,8 +392,13 @@ def forward(self,
b = mask
b = b * self.sigmoid(self.linear_b_g(z))
b = b * self.linear_b_p(z)

if(is_fp16_enabled()):

# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
a = a / a.std()
b = b / b.std()

if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
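The hunk above adds a rescaling of a and b by their standard deviation before the projection combine, so the subsequent matmul does not overflow in reduced-precision (fp16/bf16) autocast. A standalone sketch of the pattern, not OpenFold's actual _combine_projections (which also applies the edge/start-specific permutations), might look like:

import torch

def combine_safely(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Rescale so activation magnitudes stay well inside the fp16 range
    # before the large matrix product.
    a = a / a.std()
    b = b / b.std()

    if torch.is_autocast_enabled():
        # Step out of the autocast region and run the product in fp32,
        # mirroring the is_fp16_enabled() branch in the diff.
        with torch.cuda.amp.autocast(enabled=False):
            return torch.matmul(a.float(), b.float())
    return torch.matmul(a, b)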
145 changes: 140 additions & 5 deletions openfold/np/protein.py
@@ -23,6 +23,13 @@
from openfold.np import residue_constants
from Bio.PDB import PDBParser
import numpy as np
import modelcif
import modelcif.model
import modelcif.dumper
import modelcif.reference
import modelcif.protocol
import modelcif.alignment
import modelcif.qa_metric


FeatureDict = Mapping[str, np.ndarray]
@@ -87,8 +94,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
Args:
pdb_str: The contents of the pdb file
chain_id: If chain_id is specified (e.g. A), then only that chain is
parsed. Else, all chains are parsed.
chain_id: If None, then the whole pdb file is parsed. If chain_id is specified (e.g. A), then only that chain
is parsed.
Returns:
A new `Protein` parsed from the pdb contents.
@@ -184,7 +191,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
]
groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])

atoms = ['N', 'CA', 'C']
aatype = None
atom_positions = None
@@ -267,7 +274,7 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
"""
out_pdb_lines = []
lines = pdb_str.split('\n')

remark = prot.remark
if(remark is not None):
out_pdb_lines.append(f"REMARK {remark}")
@@ -387,7 +394,7 @@ def to_pdb(prot: Protein) -> str:
0
] # Protein supports only C, N, O, S, this works.
charge = ""

chain_tag = "A"
if(chain_index is not None):
chain_tag = chain_tags[chain_index[i]]
@@ -436,6 +443,134 @@ def to_pdb(prot: Protein) -> str:
return '\n'.join(pdb_lines) + '\n' # Add terminating newline.


def to_modelcif(prot: Protein) -> str:
"""
Converts a `Protein` instance to a ModelCIF string. Chains with identical modelled coordinates
will be treated as the same polymer entity. But note that if chains differ in modelled regions,
no attempt is made at identifying them as a single polymer entity.
Args:
        prot: The protein to convert to ModelCIF.
Returns:
ModelCIF string.
"""

restypes = residue_constants.restypes + ["X"]
atom_types = residue_constants.atom_types

atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
b_factors = prot.b_factors
chain_index = prot.chain_index

n = aatype.shape[0]
if chain_index is None:
chain_index = [0 for i in range(n)]

system = modelcif.System(title='OpenFold prediction')

# Finding chains and creating entities
seqs = {}
seq = []
last_chain_idx = None
for i in range(n):
if last_chain_idx is not None and last_chain_idx != chain_index[i]:
seqs[last_chain_idx] = seq
seq = []
seq.append(restypes[aatype[i]])
last_chain_idx = chain_index[i]
# finally add the last chain
seqs[last_chain_idx] = seq

# now reduce sequences to unique ones (note this won't work if different asyms have different unmodelled regions)
unique_seqs = {}
for chain_idx, seq_list in seqs.items():
seq = "".join(seq_list)
if seq in unique_seqs:
unique_seqs[seq].append(chain_idx)
else:
unique_seqs[seq] = [chain_idx]

# adding 1 entity per unique sequence
entities_map = {}
for key, value in unique_seqs.items():
model_e = modelcif.Entity(key, description='Model subunit')
for chain_idx in value:
entities_map[chain_idx] = model_e

chain_tags = string.ascii_uppercase
asym_unit_map = {}
for chain_idx in set(chain_index):
# Define the model assembly
chain_id = chain_tags[chain_idx]
asym = modelcif.AsymUnit(entities_map[chain_idx], details='Model subunit %s' % chain_id, id=chain_id)
asym_unit_map[chain_idx] = asym
modeled_assembly = modelcif.Assembly(asym_unit_map.values(), name='Modeled assembly')

class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
name = "pLDDT"
software = None
description = "Predicted lddt"

class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT):
name = "pLDDT"
software = None
description = "Global pLDDT, mean of per-residue pLDDTs"

class _MyModel(modelcif.model.AbInitioModel):
def get_atoms(self):
# Add all atom sites.
for i in range(n):
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]
):
if mask < 0.5:
continue
element = atom_name[0] # Protein supports only C, N, O, S, this works.
yield modelcif.model.Atom(
asym_unit=asym_unit_map[chain_index[i]], type_symbol=element,
seq_id=residue_index[i], atom_id=atom_name,
x=pos[0], y=pos[1], z=pos[2],
het=False, biso=b_factor, occupancy=1.00)

def add_scores(self):
# local scores
plddt_per_residue = {}
for i in range(n):
for mask, b_factor in zip(atom_mask[i], b_factors[i]):
if mask < 0.5:
continue
# add 1 per residue, not 1 per atom
if chain_index[i] not in plddt_per_residue:
# first time a chain index is seen: add the key and start the residue dict
plddt_per_residue[chain_index[i]] = {residue_index[i]: b_factor}
if residue_index[i] not in plddt_per_residue[chain_index[i]]:
plddt_per_residue[chain_index[i]][residue_index[i]] = b_factor
plddts = []
for chain_idx in plddt_per_residue:
for residue_idx in plddt_per_residue[chain_idx]:
plddt = plddt_per_residue[chain_idx][residue_idx]
plddts.append(plddt)
self.qa_metrics.append(
_LocalPLDDT(asym_unit_map[chain_idx].residue(residue_idx), plddt))
# global score
self.qa_metrics.append((_GlobalPLDDT(np.mean(plddts))))

# Add the model and modeling protocol to the file and write them out:
model = _MyModel(assembly=modeled_assembly, name='Best scoring model')
model.add_scores()

model_group = modelcif.model.ModelGroup([model], name='All models')
system.model_groups.append(model_group)

fh = io.StringIO()
modelcif.dumper.write(fh, [system])
return fh.getvalue()


def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.
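A minimal usage sketch for the new converter, assuming a predicted PDB file already on disk (prediction.pdb and prediction.cif are hypothetical paths):

from openfold.np import protein

# Load a predicted structure and re-emit it in ModelCIF format.
with open("prediction.pdb") as fh:
    prot = protein.from_pdb_string(fh.read())

cif_str = protein.to_modelcif(prot)
with open("prediction.cif", "w") as fh:
    fh.write(cif_str)

Per the unique_seqs loop above, chains with identical modelled sequences share a single modelcif.Entity, with one AsymUnit per chain, so a homomer yields one polymer entity and several asym units.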
3 changes: 0 additions & 3 deletions openfold/np/relax/amber_minimize.py
@@ -524,9 +524,6 @@ def run_pipeline(
_check_residues_are_well_defined(prot)
pdb_string = clean_protein(prot, checks=checks)

# We keep the input around to restore metadata deleted by the relaxer
input_prot = prot

exclude_residues = exclude_residues or []
exclude_residues = set(exclude_residues)
violations = np.inf
10 changes: 8 additions & 2 deletions openfold/np/relax/relax.py
@@ -57,7 +57,7 @@ def __init__(
self._use_gpu = use_gpu

def process(
self, *, prot: protein.Protein
self, *, prot: protein.Protein, cif_output: bool
) -> Tuple[str, Dict[str, Any], np.ndarray]:
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
out = amber_minimize.run_pipeline(
@@ -89,5 +89,11 @@ def process(
]

min_pdb = protein.add_pdb_headers(prot, min_pdb)
output_str = min_pdb
if cif_output:
# TODO the model cif will be missing some metadata like headers (PARENTs and
# REMARK with some details of the run, like num of recycles)
final_prot = protein.from_pdb_string(min_pdb)
output_str = protein.to_modelcif(final_prot)

return min_pdb, debug_data, violations
return output_str, debug_data, violations
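A sketch of the updated relaxer call; the constructor keywords below follow the AlphaFold-style relaxation settings and should be read as assumptions, since in OpenFold they normally come from config.relax:

from openfold.np import protein
from openfold.np.relax import relax

with open("prediction_unrelaxed.pdb") as fh:  # hypothetical input path
    unrelaxed_protein = protein.from_pdb_string(fh.read())

amber_relaxer = relax.AmberRelaxation(
    max_iterations=0,        # 0 = unlimited minimisation steps (assumed value)
    tolerance=2.39,          # energy tolerance, kcal/mol (assumed value)
    stiffness=10.0,          # restraint stiffness (assumed value)
    exclude_residues=[],
    max_outer_iterations=20,
    use_gpu=False,
)

# With cif_output=True the first return value is a ModelCIF string rather than
# a PDB string; debug_data and violations are unchanged.
struct_str, debug_data, violations = amber_relaxer.process(
    prot=unrelaxed_protein, cif_output=True
)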
12 changes: 8 additions & 4 deletions openfold/utils/script_utils.py
@@ -228,7 +228,7 @@ def prep_output(out, batch, feature_dict, feature_processor, config_preset, mult
return unrelaxed_protein


def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name):
def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name, cif_output):
amber_relaxer = relax.AmberRelaxation(
use_gpu=(model_device != "cpu"),
**config.relax,
@@ -239,18 +239,22 @@ def relax_protein(config, model_device, unrelaxed_protein, output_directory, out
if "cuda" in model_device:
device_no = model_device.split(":")[-1]
os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
# the struct_str will contain either a PDB-format or a ModelCIF format string
struct_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein, cif_output=cif_output)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
relaxation_time = time.perf_counter() - t

logger.info(f"Relaxation time: {relaxation_time}")
update_timings({"relaxation": relaxation_time}, os.path.join(output_directory, "timings.json"))

# Save the relaxed PDB.
suffix = "_relaxed.pdb"
if cif_output:
suffix = "_relaxed.cif"
relaxed_output_path = os.path.join(
output_directory, f'{output_name}_relaxed.pdb'
output_directory, f'{output_name}{suffix}'
)
with open(relaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str)
fp.write(struct_str)

logger.info(f"Relaxed output written to {relaxed_output_path}...")
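At the script level, the new cif_output flag simply threads through relax_protein, which now picks the "_relaxed.cif" suffix and writes whatever string the relaxer returned. A hypothetical driver (the preset name and file paths are placeholders):

from openfold.config import model_config
from openfold.np import protein
from openfold.utils.script_utils import relax_protein

config = model_config("model_1_ptm")  # assumed preset name

with open("outputs/sample_unrelaxed.pdb") as fh:  # hypothetical path
    unrelaxed_protein = protein.from_pdb_string(fh.read())

# Writes outputs/sample_relaxed.cif; with cif_output=False this would write
# outputs/sample_relaxed.pdb exactly as before.
relax_protein(
    config,
    "cuda:0",
    unrelaxed_protein,
    "outputs",
    "sample",
    cif_output=True,
)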
