Added support for Apple Silicon chips, with MPS devices optional
RuiWang1998 committed Aug 6, 2022
1 parent 85d1232 commit 3546a36
Showing 8 changed files with 42 additions and 34 deletions.
3 changes: 2 additions & 1 deletion omegafold/decode.py
@@ -131,7 +131,8 @@ def forward(

         # Combine them and take the softmax
         logits = scalar_logits + edge_logits - point_logits
-        logits = torch.masked_fill(logits, ~frames.mask[None, ..., None], -1e8)
+        m = utils.bit_wise_not(frames.mask[None, ..., None])
+        logits = torch.masked_fill(logits, m, -1e8)
         attn_w = torch.softmax(logits, dim=-2) # (num_res, num_res, n_head)
 
         # get the output
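Not part of the commit, but a minimal sketch of the pattern behind this change: the MPS backend did not yet accept the bit-wise not operator (~) on boolean tensors, so the mask is inverted arithmetically via the new utils.bit_wise_not helper (added to functions.py later in this diff). The tensors below are toy stand-ins.

import torch

def bit_wise_not(boolean_tensor: torch.Tensor) -> torch.Tensor:
    # 1 - x flips 0 <-> 1; cast back to bool so it can be used as a mask
    return (1 - boolean_tensor.float()).bool()

mask = torch.tensor([True, False, True])
logits = torch.zeros(2, 3)
masked = logits.masked_fill(bit_wise_not(mask), -1e8)  # -1e8 wherever mask is False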
2 changes: 1 addition & 1 deletion omegafold/embedders.py
@@ -257,7 +257,7 @@ def forward(
         Returns:
         """
-        atom_mask = rc.restype2atom_mask[fasta].to(self.device)
+        atom_mask = rc.restype2atom_mask[fasta.cpu()].to(self.device)
         prev_beta = utils.create_pseudo_beta(prev_x, atom_mask)
         d = utils.get_norm(prev_beta.unsqueeze(-2) - prev_beta.unsqueeze(-3))
         d = self.dgram(d)
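A hedged sketch of the indexing workaround (not repository code; `table` stands in for the CPU-resident rc.restype2atom_mask constant): indexing a CPU tensor with an MPS index tensor can fail, so the index is moved to CPU first and only the gathered result is sent to the model device.

import torch

table = torch.randn(21, 14)  # stand-in for rc.restype2atom_mask, kept on CPU
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
fasta = torch.randint(0, 21, (10,), device=device)  # residue indices on the model device
atom_mask = table[fasta.cpu()].to(device)            # index on CPU, then move the result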
10 changes: 5 additions & 5 deletions omegafold/geoformer.py
@@ -106,13 +106,13 @@ def forward(
"""
node_repr += self.attention_w_edge_bias(node_repr, edge_repr, mask)
node_repr_col = utils.normalize(node_repr.transpose(-2, -3))
node_repr_col, _ = self.column_attention(
node_repr_col,
node_repr_col,
node_col = utils.normalize(node_repr.transpose(-2, -3).contiguous())
node_col, _ = self.column_attention(
node_col,
node_col,
bias=utils.mask2bias(mask.T[..., None, None, :])
)
node_repr += node_repr_col.transpose(-2, -3)
node_repr += node_col.transpose(-2, -3)
node_repr += self.node_transition(
node_repr, subbatch_size=fwd_cfg.subbatch_size
)
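The rename to node_col is cosmetic; the functional change is the added .contiguous() call. A transpose only produces a strided view of the tensor, and materializing it into standard memory layout avoids kernels that presumably could not handle non-contiguous inputs on the MPS backend at the time. A sketch, not part of the commit:

import torch

x = torch.randn(4, 5, 8)
col = x.transpose(-2, -3)    # a strided view, not contiguous in memory
print(col.is_contiguous())   # False
col = col.contiguous()       # copy into standard layout before feeding attention
print(col.is_contiguous())   # True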
5 changes: 0 additions & 5 deletions omegafold/model.py
@@ -237,11 +237,6 @@ def create_initial_prev_dict(
             [num_res, 14, 3],
             device=self.device, dtype=torch.float
         )
-        prev_x_rot = torch.zeros(
-            [num_res, 8, 9],
-            device=self.device, dtype=torch.float
-        )
-        prev_x_rot[..., [0, 4, ]] = 1
 
         return {
             "prev_node": torch.zeros(
2 changes: 1 addition & 1 deletion omegafold/modules.py
@@ -549,7 +549,7 @@ def forward(
         ab = torch.einsum(
             '...ikrd,...jkrd->...ijrd',
             *act.split([self.d_edge, self.d_edge], dim=-1)
-        )
+        ).contiguous()
         ab = utils.normalize(ab)
         ab = torch.einsum(
             '...rd,rdc->...rc', ab, self.out_proj_w
2 changes: 1 addition & 1 deletion omegafold/utils/__init__.py
@@ -27,9 +27,9 @@
 from omegafold.utils.protein_utils import residue_constants
 from omegafold.utils.protein_utils.aaframe import AAFrame
 from omegafold.utils.protein_utils.functions import (
+    bit_wise_not,
     create_pseudo_beta,
     get_norm,
-    norm_l2,
     robust_normalize,
 )
 from omegafold.utils.torch_utils import (
22 changes: 11 additions & 11 deletions omegafold/utils/protein_utils/aaframe.py
@@ -210,9 +210,8 @@ def translation(self, value: torch.Tensor) -> None:
             value: the translation value
         """
-        self._translation = value.masked_fill(
-            ~self.mask.unsqueeze(-1).expand_as(value), 0
-        )
+        m = f.bit_wise_not(self.mask.unsqueeze(-1).expand_as(value))
+        self._translation = value.masked_fill(m, 0)
 
     @property
     def rotation(self) -> torch.Tensor:
@@ -234,10 +233,10 @@ def rotation(self, value: torch.Tensor) -> None:
             value: the rotational matrices
         """
-        mask = ~self.mask.unsqueeze(-1).unsqueeze(-1).expand_as(value)
-        value = value.masked_fill(mask, 0)
+        mask = f.bit_wise_not(self.mask[..., None, None].expand_as(value))
+        value = value.masked_fill(mask, 0.)
         value = value.masked_fill(
-            mask * torch.eye(3, device=mask.device, dtype=torch.bool), 1
+            mask * torch.eye(3, dtype=torch.bool).to(mask.device), 1
         )
         self._rotation = value

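The only change to the eye-matrix line is how the constant reaches the device: it is now built on CPU and moved with .to(mask.device) rather than created with the device= keyword. Presumably constructing a boolean identity matrix directly on MPS was the failing step. The same pattern in isolation (a sketch, not repository code):

import torch

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
eye = torch.eye(3, dtype=torch.bool).to(device)  # build the small constant on CPU, then transfer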
@@ -783,7 +782,7 @@ def expand_w_torsion(

         # make extra backbone frames
         # This follows the order of ~restypes
-        m = rc.restype_aa_default_frame[fasta].to(self.device)
+        m = rc.restype_aa_default_frame[fasta.cpu()].to(self.device)
         default_frames = AAFrame.from_4x4(
             m, torsion_angles_mask, unit="Angstrom"
         )
@@ -861,17 +860,18 @@ def expanded_to_pos(

         assert self._unit == "Angstrom"
 
-        residx2group = rc.restype_atom14_to_aa.to(self.device)
+        fasta = fasta.cpu()
+        residx2group = rc.restype_atom14_to_aa
         residx2group = residx2group[..., :pos_counts]
-        residx2group = residx2group[fasta]
+        residx2group = residx2group[fasta].to(self.device)
         group_mask = F.one_hot(residx2group, num_classes=8)
         group_mask = group_mask[..., :num_classes]
         group_mask = group_mask * frame.mask[..., None, :]
         to_mask = frame.unsqueeze(-2) * group_mask
         map_atoms_to_global = to_mask.sum(-1)
-        lit_pos = rc.restype_atom14_aa_positions.to(self.device)
+        lit_pos = rc.restype_atom14_aa_positions
         lit_pos = lit_pos[..., :pos_counts, :]
-        lit_pos = lit_pos[fasta]
+        lit_pos = lit_pos[fasta].to(self.device)
         pred_pos = map_atoms_to_global.transform(lit_pos)
         # mask = c.restype_atom14_mask[sequence] # (N, 14)
         # mask |= self.mask[..., None]
30 changes: 21 additions & 9 deletions omegafold/utils/protein_utils/functions.py
@@ -23,20 +23,29 @@
 # =============================================================================
 # Imports
 # =============================================================================
-import functools
-from typing import Union
+import typing
 
 import torch
 from torch import linalg as LA
 
 
 # =============================================================================
 # Functions
 # =============================================================================
-get_norm = functools.partial(LA.norm, ord=2, dim=-1)
+def get_norm(x: torch.Tensor) -> torch.Tensor:
+    """
+    Replacement for LA.norm since MPS does not support it yet.
+    Args:
+        x:
+    Returns:
+    """
+    return x.norm(p=2, dim=-1)
 
 
 def robust_normalize(
-        x: torch.Tensor, dim: int = -1, p: Union[int, str] = 2
+        x: torch.Tensor, dim: int = -1, p: typing.Union[int, str] = 2
 ) -> torch.Tensor:
     """
     Normalization with a constant small term on the denominator
@@ -50,10 +59,7 @@ def robust_normalize(
         the normalized result
     """
-    return x / (LA.norm(x, ord=p, dim=dim, keepdim=True).clamp(4e-5))
-
-
-norm_l2 = functools.partial(LA.norm, ord=2, dim=-1)
+    return x / (x.norm(p=p, dim=dim, keepdim=True).clamp(4e-5))
 
 
 def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
@@ -139,6 +145,12 @@ def create_pseudo_beta(
     return pseudo_beta
 
 
+def bit_wise_not(boolean_tensor: torch.Tensor) -> torch.Tensor:
+    """For MPS devices that do not yet support bit-wise not"""
+    boolean_tensor = 1 - boolean_tensor.float()
+    return boolean_tensor.bool()
+
+
 # =============================================================================
 # Tests
 # =============================================================================
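The get_norm and robust_normalize rewrites swap torch.linalg.norm, which (per the new docstring) the MPS backend did not yet implement, for the equivalent Tensor.norm call. A quick equivalence check run on CPU (a sketch, not part of the commit):

import torch

v = torch.randn(5, 3)
a = v.norm(p=2, dim=-1)                  # Tensor.norm path used by the new get_norm
b = torch.linalg.norm(v, ord=2, dim=-1)  # the replaced torch.linalg path
print(torch.allclose(a, b))              # True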
