From 85d123297578c303801aa41de480b459e69058b5 Mon Sep 17 00:00:00 2001
From: Rui Wang
Date: Sat, 6 Aug 2022 15:37:32 +0800
Subject: [PATCH 1/2] Fixed a bug with Windows paths, and slightly modified
 FASTA comment handling so that leading colons are stripped as well

---
 pipeline.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/pipeline.py b/pipeline.py
index 4bb2387..5b08a17 100644
--- a/pipeline.py
+++ b/pipeline.py
@@ -27,6 +27,7 @@ import os
 import os.path
 import pathlib
+import sys
 import typing
 
 from Bio import PDB as PDB
@@ -112,9 +113,9 @@ def fasta2inputs(
     for line in lines:
         if len(line) == 0:
             continue
-        if line.startswith(">") or line .startswith(":"):
+        if line.startswith(">") or line.startswith(":"):
             name = True
-            chain_ids.append(line.strip(">").strip("\n"))
+            chain_ids.append(line[1:].strip("\n"))
         else:
             if name:
                 aastr.append(line.strip("\n").upper())
@@ -129,7 +130,10 @@
     folder_name = path_leaf(fasta_path).split(".")[0]
     output_dir = os.path.join(parent, folder_name)
     os.makedirs(output_dir, exist_ok=True)
-    name_max = os.pathconf(output_dir, 'PC_NAME_MAX') - 4
+    if sys.platform == 'win32':
+        name_max = 100
+    else:
+        name_max = os.pathconf(output_dir, 'PC_NAME_MAX') - 4
 
     for i, (ch, fas) in enumerate(combined):
         fas = fas.replace("Z", "E").replace("B", "D").replace("U", "C")

From 3546a3602829dd3544ee7c01d0646b659b98acc6 Mon Sep 17 00:00:00 2001
From: Rui Wang
Date: Sat, 6 Aug 2022 16:01:47 +0800
Subject: [PATCH 2/2] Added support for Apple Silicon chips, with the MPS
 device optional

---
 omegafold/decode.py                        |  3 ++-
 omegafold/embedders.py                     |  2 +-
 omegafold/geoformer.py                     | 10 ++++----
 omegafold/model.py                         |  5 ----
 omegafold/modules.py                       |  2 +-
 omegafold/utils/__init__.py                |  2 +-
 omegafold/utils/protein_utils/aaframe.py   | 22 ++++++++--------
 omegafold/utils/protein_utils/functions.py | 30 +++++++++++++++-------
 8 files changed, 42 insertions(+), 34 deletions(-)

diff --git a/omegafold/decode.py b/omegafold/decode.py
index 9d5f3bb..a06dfcc 100644
--- a/omegafold/decode.py
+++ b/omegafold/decode.py
@@ -131,7 +131,8 @@ def forward(
 
         # Combine them and take the softmax
         logits = scalar_logits + edge_logits - point_logits
-        logits = torch.masked_fill(logits, ~frames.mask[None, ..., None], -1e8)
+        m = utils.bit_wise_not(frames.mask[None, ..., None])
+        logits = torch.masked_fill(logits, m, -1e8)
         attn_w = torch.softmax(logits, dim=-2)  # (num_res, num_res, n_head)
 
         # get the output
diff --git a/omegafold/embedders.py b/omegafold/embedders.py
index d7c8d18..75179c2 100644
--- a/omegafold/embedders.py
+++ b/omegafold/embedders.py
@@ -257,7 +257,7 @@ def forward(
         Returns:
 
         """
-        atom_mask = rc.restype2atom_mask[fasta].to(self.device)
+        atom_mask = rc.restype2atom_mask[fasta.cpu()].to(self.device)
         prev_beta = utils.create_pseudo_beta(prev_x, atom_mask)
         d = utils.get_norm(prev_beta.unsqueeze(-2) - prev_beta.unsqueeze(-3))
         d = self.dgram(d)
diff --git a/omegafold/geoformer.py b/omegafold/geoformer.py
index 69ba47d..4a63517 100644
--- a/omegafold/geoformer.py
+++ b/omegafold/geoformer.py
@@ -106,13 +106,13 @@ def forward(
 
         """
         node_repr += self.attention_w_edge_bias(node_repr, edge_repr, mask)
-        node_repr_col = utils.normalize(node_repr.transpose(-2, -3))
-        node_repr_col, _ = self.column_attention(
-            node_repr_col,
-            node_repr_col,
+        node_col = utils.normalize(node_repr.transpose(-2, -3).contiguous())
+        node_col, _ = self.column_attention(
+            node_col,
+            node_col,
             bias=utils.mask2bias(mask.T[..., None, None, :])
         )
-        node_repr += node_repr_col.transpose(-2, -3)
+        node_repr += node_col.transpose(-2, -3)
         node_repr += self.node_transition(
             node_repr, subbatch_size=fwd_cfg.subbatch_size
         )
diff --git a/omegafold/model.py b/omegafold/model.py
index eb451cd..15a03cb 100644
--- a/omegafold/model.py
+++ b/omegafold/model.py
@@ -237,11 +237,6 @@ def create_initial_prev_dict(
             [num_res, 14, 3],
             device=self.device, dtype=torch.float
         )
-        prev_x_rot = torch.zeros(
-            [num_res, 8, 9],
-            device=self.device, dtype=torch.float
-        )
-        prev_x_rot[..., [0, 4, ]] = 1
 
         return {
             "prev_node": torch.zeros(
diff --git a/omegafold/modules.py b/omegafold/modules.py
index a2001ca..f437838 100644
--- a/omegafold/modules.py
+++ b/omegafold/modules.py
@@ -549,7 +549,7 @@ def forward(
         ab = torch.einsum(
             '...ikrd,...jkrd->...ijrd',
             *act.split([self.d_edge, self.d_edge], dim=-1)
-        )
+        ).contiguous()
         ab = utils.normalize(ab)
         ab = torch.einsum(
             '...rd,rdc->...rc', ab, self.out_proj_w
diff --git a/omegafold/utils/__init__.py b/omegafold/utils/__init__.py
index 5ff6b3f..cf0b9a7 100644
--- a/omegafold/utils/__init__.py
+++ b/omegafold/utils/__init__.py
@@ -27,9 +27,9 @@
 from omegafold.utils.protein_utils import residue_constants
 from omegafold.utils.protein_utils.aaframe import AAFrame
 from omegafold.utils.protein_utils.functions import (
+    bit_wise_not,
     create_pseudo_beta,
     get_norm,
-    norm_l2,
     robust_normalize,
 )
 from omegafold.utils.torch_utils import (
diff --git a/omegafold/utils/protein_utils/aaframe.py b/omegafold/utils/protein_utils/aaframe.py
index 0414aa3..f2a21b7 100644
--- a/omegafold/utils/protein_utils/aaframe.py
+++ b/omegafold/utils/protein_utils/aaframe.py
@@ -210,9 +210,8 @@ def translation(self, value: torch.Tensor) -> None:
             value: the translation value
 
         """
-        self._translation = value.masked_fill(
-            ~self.mask.unsqueeze(-1).expand_as(value), 0
-        )
+        m = f.bit_wise_not(self.mask.unsqueeze(-1).expand_as(value))
+        self._translation = value.masked_fill(m, 0)
 
     @property
     def rotation(self) -> torch.Tensor:
@@ -234,10 +233,10 @@ def rotation(self, value: torch.Tensor) -> None:
             value: the rotational matrices
 
         """
-        mask = ~self.mask.unsqueeze(-1).unsqueeze(-1).expand_as(value)
-        value = value.masked_fill(mask, 0)
+        mask = f.bit_wise_not(self.mask[..., None, None].expand_as(value))
+        value = value.masked_fill(mask, 0.)
         value = value.masked_fill(
-            mask * torch.eye(3, device=mask.device, dtype=torch.bool), 1
+            mask * torch.eye(3, dtype=torch.bool).to(mask.device), 1
         )
         self._rotation = value
 
@@ -783,7 +782,7 @@ def expand_w_torsion(
 
         # make extra backbone frames
         # This follows the order of ~restypes
-        m = rc.restype_aa_default_frame[fasta].to(self.device)
+        m = rc.restype_aa_default_frame[fasta.cpu()].to(self.device)
         default_frames = AAFrame.from_4x4(
             m, torsion_angles_mask, unit="Angstrom"
         )
@@ -861,17 +860,18 @@ def expanded_to_pos(
 
         assert self._unit == "Angstrom"
 
-        residx2group = rc.restype_atom14_to_aa.to(self.device)
+        fasta = fasta.cpu()
+        residx2group = rc.restype_atom14_to_aa
         residx2group = residx2group[..., :pos_counts]
-        residx2group = residx2group[fasta]
+        residx2group = residx2group[fasta].to(self.device)
         group_mask = F.one_hot(residx2group, num_classes=8)
         group_mask = group_mask[..., :num_classes]
         group_mask = group_mask * frame.mask[..., None, :]
         to_mask = frame.unsqueeze(-2) * group_mask
         map_atoms_to_global = to_mask.sum(-1)
-        lit_pos = rc.restype_atom14_aa_positions.to(self.device)
+        lit_pos = rc.restype_atom14_aa_positions
         lit_pos = lit_pos[..., :pos_counts, :]
-        lit_pos = lit_pos[fasta]
+        lit_pos = lit_pos[fasta].to(self.device)
         pred_pos = map_atoms_to_global.transform(lit_pos)
         # mask = c.restype_atom14_mask[sequence]  # (N, 14)
         # mask |= self.mask[..., None]
diff --git a/omegafold/utils/protein_utils/functions.py b/omegafold/utils/protein_utils/functions.py
index 563159a..ba6a74f 100644
--- a/omegafold/utils/protein_utils/functions.py
+++ b/omegafold/utils/protein_utils/functions.py
@@ -23,20 +23,29 @@
 # =============================================================================
 # Imports
 # =============================================================================
-import functools
-from typing import Union
+import typing
 
 import torch
-from torch import linalg as LA
+
 
 # =============================================================================
 # Functions
 # =============================================================================
-get_norm = functools.partial(LA.norm, ord=2, dim=-1)
+def get_norm(x: torch.Tensor) -> torch.Tensor:
+    """
+    Replacement for torch.linalg.norm, since MPS does not support it yet.
+
+    Args:
+        x: the input tensor
+
+    Returns:
+        the 2-norm of x along the last dimension
+    """
+    return x.norm(p=2, dim=-1)
 
 
 def robust_normalize(
-    x: torch.Tensor, dim: int = -1, p: Union[int, str] = 2
+    x: torch.Tensor, dim: int = -1, p: typing.Union[int, str] = 2
 ) -> torch.Tensor:
     """
     Normalization with a constant small term on the denominator
@@ -50,10 +59,7 @@
         the normalized result
 
     """
-    return x / (LA.norm(x, ord=p, dim=dim, keepdim=True).clamp(4e-5))
-
-
-norm_l2 = functools.partial(LA.norm, ord=2, dim=-1)
+    return x / (x.norm(p=p, dim=dim, keepdim=True).clamp(4e-5))
 
 
 def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
@@ -139,6 +145,12 @@ def create_pseudo_beta(
     return pseudo_beta
 
 
+def bit_wise_not(boolean_tensor: torch.Tensor) -> torch.Tensor:
+    """For MPS devices, which do not yet support bit-wise not"""
+    boolean_tensor = 1 - boolean_tensor.float()
+    return boolean_tensor.bool()
+
+
 # =============================================================================
 # Tests
 # =============================================================================
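
Note on the name_max fallback in PATCH 1/2: os.pathconf() is POSIX-only
(the attribute does not even exist on Windows), hence the sys.platform
guard around it. A minimal standalone sketch of the same idea; the helper
name max_basename_len and the reserve argument are illustrative, not part
of the patch:

    import os
    import sys

    def max_basename_len(directory: str, reserve: int = 4) -> int:
        """Longest usable file name under `directory`, minus `reserve`.

        os.pathconf does not exist on Windows, so fall back to a fixed
        budget there (the patch uses 100, well below NTFS's usual 255).
        """
        if sys.platform == "win32":
            return 100
        return os.pathconf(directory, "PC_NAME_MAX") - reserve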
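
Note on the MPS workarounds in PATCH 2/2: bit_wise_not() and the rewritten
get_norm() are drop-in replacements for two ops the MPS backend could not
run at the time, bit-wise not (~) on boolean tensors and torch.linalg.norm.
A quick sanity check, on CPU, that the workarounds agree with the native
ops (the sample tensors are arbitrary):

    import torch

    mask = torch.tensor([True, False, True])
    x = torch.randn(5, 3)

    # bit_wise_not: float round-trip instead of the unsupported `~`
    assert torch.equal((1 - mask.float()).bool(), ~mask)

    # get_norm: Tensor.norm instead of torch.linalg.norm
    assert torch.allclose(
        x.norm(p=2, dim=-1),
        torch.linalg.norm(x, ord=2, dim=-1),
    )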
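
The series keeps the MPS device optional. A typical device-selection
sketch for callers (this helper is not part of the patch, and
torch.backends.mps.is_available() requires PyTorch >= 1.12):

    import torch

    def pick_device(prefer_mps: bool = True) -> torch.device:
        """Prefer CUDA, then (optionally) MPS, then fall back to CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if prefer_mps and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")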