Merge pull request HeliXonProtein#18 from RuiWang1998/main
Memory usage reduction
xiwen1995 authored Aug 13, 2022
2 parents 044d013 + b95cbfa commit 61ac2ae
Showing 8 changed files with 295 additions and 151 deletions.
47 changes: 32 additions & 15 deletions README.md
@@ -2,12 +2,20 @@

# OmegaFold: High-resolution de novo Structure Prediction from Primary Sequence

#### This is the first release for paper [High-resolution de novo structure prediction from primary sequence](https://www.biorxiv.org/content/10.1101/2022.07.21.500999v1).
#### This is the release code for the paper [High-resolution de novo structure prediction from primary sequence](https://www.biorxiv.org/content/10.1101/2022.07.21.500999v1).

We will continue to optimize this repository for ease of use, for
instance by reducing the GRAM required to run inference on long proteins
and by releasing possibly stronger models.

## Update Notes

We have optimized (to some extent) the GRAM usage of the OmegaFold model in
our latest release. The model can now run inference on protein sequences as
long as _4096_ residues on an NVIDIA A100 graphics card with 80 GB of
memory, with `--subbatch_size` set to 128, without exceeding 70 GB of memory.
This version of the model is more sensitive to `--subbatch_size`.
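
For example, folding a long sequence on an 80 GB A100 could look like the
command below. The positional arguments follow the usage shown under
Running; placing `--subbatch_size` after them is an assumption of this
sketch rather than a documented requirement, and lower values trade speed
for lower peak GRAM.

```commandline
omegafold INPUT_FILE.fasta OUTPUT_DIRECTORY --subbatch_size 128
```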

## Setup

To prepare the environment to run OmegaFold,
@@ -23,34 +31,43 @@ pip install git+https://github.com/HeliXonProtein/OmegaFold.git
```commandline
git clone https://github.com/HeliXonProtein/OmegaFold
cd OmegaFold
pip install -r requirements.txt
python setup.py install
```

should get you where you want.
Even if this fails, since we use minimal third-party libraries, you can
always just install the latest
[PyTorch](https://pytorch.org) and [biopython](https://biopython.org)
yourself (and that's it!).

The `INPUT_FILE.fasta` should be a normal FASTA file, possibly containing
many sequences, each with a comment line starting with `>` or `:` above the
amino acid sequence itself.
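
For instance, a minimal input with two entries could look like the snippet
below (the sequences are arbitrary placeholders, not meaningful proteins):

```
>example_sequence_1
MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ
:example_sequence_2
GSHMSDNEDNFDGDDFDDVEEDEGLDDLENAE
```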

This command will download the weights
from https://helixon.s3.amazonaws.com/release1.pt
to `~/.cache/omegafold_ckpt/model.pt`
and load the model.
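
If you would rather fetch the weights ahead of time (for example on a
machine without outbound network access at run time), placing the file at
that path by hand should also work; this assumes `curl` is available:

```commandline
mkdir -p ~/.cache/omegafold_ckpt
curl -L https://helixon.s3.amazonaws.com/release1.pt -o ~/.cache/omegafold_ckpt/model.pt
```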

## Running

There should be only one way to use the model:
You can simply run

```commandline
omegafold INPUT_FILE.fasta OUTPUT_DIRECTORY
```

And voila!

The `INPUT_FILE.fasta` should be a normal FASTA file, possibly containing
many sequences, each with a comment line starting with `>` or `:` above the
amino acid sequence itself.
### Alternatively

This command will download the weights
from https://helixon.s3.amazonaws.com/release1.pt
to `~/.cache/omegafold_ckpt/model.pt`
and load the model.
Even if this fails, since we use minimal third-party libraries, you can
always just install the latest
[PyTorch](https://pytorch.org) and [biopython](https://biopython.org)
yourself (and that's it!).
In this case, you could run

```commandline
python main.py INPUT_FILE.fasta OUTPUT_DIRECTORY
```

### Notes on resources

However, since we have implemented sharded execution, it is possible to

5 changes: 2 additions & 3 deletions omegafold/decode.py
@@ -131,9 +131,8 @@ def forward(

# Combine them and take the softmax
logits = scalar_logits + edge_logits - point_logits
m = utils.bit_wise_not(frames.mask[None, ..., None])
logits = torch.masked_fill(logits, m, -1e8)
attn_w = torch.softmax(logits, dim=-2) # (num_res, num_res, n_head)
logits += utils.mask2bias(frames.mask[None, ..., None])
attn_w = modules.softmax(logits, dim=-2, in_place=True)

# get the output
ret_edge = torch.einsum("...qkh,...qkc->...qhc", attn_w, edge_repr)
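
The decode.py change above swaps an explicit `masked_fill` plus an
out-of-place softmax for a pre-added mask bias and an in-place softmax,
which avoids materialising a second `(num_res, num_res, n_head)` tensor.
A minimal sketch of the equivalence; the helpers below are stand-ins written
for illustration, not the repository's `utils.mask2bias` or `modules.softmax`:

```python
import torch


def mask2bias(mask: torch.Tensor, big_neg: float = -1e8) -> torch.Tensor:
    # Valid positions (True) contribute 0; masked positions a large negative
    # number, so their attention weights vanish after the softmax.
    return (~mask).to(torch.float32) * big_neg


def softmax_in_place(logits: torch.Tensor, dim: int) -> torch.Tensor:
    # Same result as torch.softmax, but reuses the logits buffer instead of
    # allocating another tensor of the same (potentially huge) size.
    logits -= logits.amax(dim=dim, keepdim=True)  # numerical stability
    logits.exp_()
    logits /= logits.sum(dim=dim, keepdim=True)
    return logits


# Toy check of the equivalence on a (num_res, num_res, n_head) tensor.
logits = torch.randn(4, 4, 8)
mask = torch.tensor([True, True, False, True])  # per-residue validity
reference = torch.softmax(
    logits.masked_fill(~mask[None, :, None], -1e8), dim=-2
)
fused = softmax_in_place(logits + mask2bias(mask[None, :, None]), dim=-2)
assert torch.allclose(reference, fused, atol=1e-6)
```
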
30 changes: 18 additions & 12 deletions omegafold/embedders.py
@@ -126,14 +126,16 @@ def __init__(self, cfg: argparse.Namespace) -> None:
self.proj_j = nn.Embedding(cfg.alphabet_size, cfg.edge_dim)
self.relpos = RelPosEmbedder(cfg.relpos_len * 2 + 1, cfg.edge_dim)

def forward(self, fasta_sequence: torch.Tensor) -> torch.Tensor:
i = self.proj_i(fasta_sequence).unsqueeze(-2)
j = self.proj_j(fasta_sequence).unsqueeze(-3)
edge_repr = i + j
rel_pos = self.relpos(fasta_sequence.size(-1))
edge_repr += rel_pos
def forward(
self,
fasta_sequence: torch.Tensor,
out: torch.Tensor
) -> torch.Tensor:
out += self.proj_i(fasta_sequence).unsqueeze(-2)
out += self.proj_j(fasta_sequence).unsqueeze(-3)
out += self.relpos(fasta_sequence.size(-1))

return edge_repr
return out


class RoPE(nn.Module):
@@ -241,9 +243,11 @@ def forward(
fasta: torch.Tensor,
prev_node: torch.Tensor,
prev_edge: torch.Tensor,
prev_x: torch.Tensor
prev_x: torch.Tensor,
node_repr: torch.Tensor,
edge_repr: torch.Tensor,
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
"""
"""Recycle the last run
Args:
fasta:
@@ -253,16 +257,18 @@
of shape [num_res, num_res, edge_repr_dim]
prev_x: pseudo beta coordinates from the previous cycle.
of shape [num_res, 3]
node_repr: the node representation to put stuff in
edge_repr: the edge representation to put stuff in
Returns:
"""
atom_mask = rc.restype2atom_mask[fasta.cpu()].to(self.device)
atom_mask = rc.restype2atom_mask[fasta].to(self.device)
prev_beta = utils.create_pseudo_beta(prev_x, atom_mask)
d = utils.get_norm(prev_beta.unsqueeze(-2) - prev_beta.unsqueeze(-3))
d = self.dgram(d)
edge_repr = self.prev_pos_embed(d)
node_repr = self.layernorm_node(prev_node)
node_repr[..., 0, :, :] += self.layernorm_node(prev_node)
edge_repr += self.prev_pos_embed(d)
edge_repr += self.layernorm_edge(prev_edge)

return node_repr, edge_repr
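
The embedders now accumulate their terms directly into buffers handed in by
the caller (`out`, `node_repr`, `edge_repr`) instead of building fresh
tensors and summing them. A rough sketch of that pattern with made-up sizes,
not the repository's actual `EdgeEmbedder`:

```python
import torch
from torch import nn


class AccumulatingEdgeEmbedder(nn.Module):
    """Accumulate pairwise embedding terms into a caller-provided buffer."""

    def __init__(self, alphabet_size: int = 21, edge_dim: int = 32) -> None:
        super().__init__()
        self.proj_i = nn.Embedding(alphabet_size, edge_dim)
        self.proj_j = nn.Embedding(alphabet_size, edge_dim)

    def forward(self, fasta: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
        # Each += adds a broadcastable [L, 1, C] or [1, L, C] term into the
        # pre-allocated [L, L, C] buffer, so no second full-size pair tensor
        # is created just to hold the sum.
        out += self.proj_i(fasta).unsqueeze(-2)
        out += self.proj_j(fasta).unsqueeze(-3)
        return out


fasta = torch.randint(0, 21, (64,))  # toy sequence of length 64
edge_repr = torch.zeros(64, 64, 32)  # buffer that can be reused across cycles
edge_repr = AccumulatingEdgeEmbedder()(fasta, edge_repr)
```

At inference time there are no gradients to track, so the in-place additions
are safe, and only one `[L, L, C]` edge tensor is alive at any point.
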
29 changes: 20 additions & 9 deletions omegafold/geoformer.py
@@ -105,26 +105,37 @@ def forward(
Returns:
"""
node_repr += self.attention_w_edge_bias(node_repr, edge_repr, mask)
node_col = utils.normalize(node_repr.transpose(-2, -3).contiguous())
node_col, _ = self.column_attention(
node_col,
node_col,
bias=utils.mask2bias(mask.T[..., None, None, :])
node_repr += self.attention_w_edge_bias(
node_repr,
edge_repr,
mask,
fwd_cfg=fwd_cfg
)
node_repr += node_col.transpose(-2, -3)
node_repr = self._column_attention(node_repr, mask, fwd_cfg=fwd_cfg)
node_repr += self.node_transition(
node_repr, subbatch_size=fwd_cfg.subbatch_size
node_repr,
subbatch_size=fwd_cfg.subbatch_size
)

edge_repr += self.out_product(node_repr, mask)
for layer in self.geometric_attention:
edge_repr += layer(edge_repr, mask[..., 0, :])
edge_repr += layer(edge_repr, mask[..., 0, :], fwd_cfg=fwd_cfg)

edge_repr += self.edge_transition(edge_repr, fwd_cfg.subbatch_size)

return node_repr, edge_repr

def _column_attention(self, node_repr, mask, fwd_cfg):
node_repr_col = utils.normalize(node_repr.transpose(-2, -3))
node_repr_col = self.column_attention(
node_repr_col,
node_repr_col,
bias=utils.mask2bias(mask.T[..., None, None, :]),
fwd_cfg=fwd_cfg
)
node_repr += node_repr_col.transpose(-2, -3)
return node_repr


class GeoFormer(modules.OFModule):
def __init__(self, cfg: argparse.Namespace):
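
The GeoFormer layer now threads `fwd_cfg` (and with it `subbatch_size`) into
its attention and transition calls. The idea behind `--subbatch_size` is to
process a large tensor in slices so that only one slice's intermediates are
resident at a time; a generic sketch of that pattern, not the repository's
implementation:

```python
import typing

import torch


def subbatched(
    fn: typing.Callable[[torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    subbatch_size: typing.Optional[int],
) -> torch.Tensor:
    # Apply fn to slices of x along dim 0 and stitch the results back
    # together; peak memory is bounded by one slice's intermediates.
    if subbatch_size is None or subbatch_size >= x.shape[0]:
        return fn(x)
    pieces = [
        fn(x[i:i + subbatch_size])
        for i in range(0, x.shape[0], subbatch_size)
    ]
    return torch.cat(pieces, dim=0)


# Toy "transition": a pointwise MLP whose hidden activations dominate memory.
mlp = torch.nn.Sequential(
    torch.nn.Linear(32, 128), torch.nn.ReLU(), torch.nn.Linear(128, 32)
)
edge_repr = torch.randn(256, 256, 32)
with torch.no_grad():  # inference: no activations are kept for backward
    out = subbatched(mlp, edge_repr, subbatch_size=64)
assert out.shape == edge_repr.shape
```

Smaller values of `subbatch_size` lower the peak GRAM further at the cost of
some speed, which is consistent with the note above that this version is more
sensitive to the flag.
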
37 changes: 21 additions & 16 deletions omegafold/model.py
@@ -132,8 +132,6 @@ def __init__(self, cfg: argparse.Namespace) -> None:

def forward(
self,
# fasta: torch.Tensor,
# mask: torch.Tensor,
inputs: _INPUTS,
predict_with_confidence: typing.Optional[bool] = True,
*,
@@ -145,12 +143,9 @@ def forward(
Args:
inputs:
predict_with_confidence: if to choose with confidence
fwd_cfg: the configuration for this forward run
containing just subbatch_size currently
fwd_cfg: forward configuration
Returns:
A dictionary containing the position, the mask of the atoms in
atom14 format, per-residue confidence, and overall confidence
"""
# Preparation before entering the cycles
@@ -168,10 +163,14 @@
p_msa_mask,
fwd_cfg
)
prev_dict['fasta'] = fasta
node_recycle, edge_recycle = self.recycle_embedder(**prev_dict)
node_repr[..., 0, :, :] = node_repr[..., 0, :, :] + node_recycle
edge_repr = edge_repr + edge_recycle
node_recycle, edge_repr = self.recycle_embedder(
fasta=fasta,
prev_node=prev_dict.pop('prev_node'),
prev_edge=prev_dict.pop('prev_edge'),
prev_x=prev_dict.pop('prev_x'),
node_repr=node_repr,
edge_repr=edge_repr,
)

result, prev_dict = self.omega_fold_cycle(
fasta=fasta,
@@ -208,16 +207,22 @@ def deep_sequence_embed(
Args:
fasta: the fasta sequence
mask: the mask indicating the validity of the token
fwd_cfg:
Returns:
"""
edge_repr = self.input_embedder(fasta[..., 0, :])
node_plm, edge_plm = self.omega_plm(fasta, mask, fwd_cfg=fwd_cfg)
node_repr = self.plm_node_embedder(utils.normalize(node_plm))
edge_plm = edge_plm.permute(1, 2, 0)
edge_repr += self.plm_edge_embedder(utils.normalize(edge_plm))
node_repr, edge_repr = self.omega_plm(
fasta, mask, fwd_cfg=fwd_cfg
)
# return node_plm, edge_plm
node_repr = self.plm_node_embedder(
utils.normalize(node_repr, in_place=True)
)
edge_repr = edge_repr.permute(1, 2, 0)
edge_repr = self.plm_edge_embedder(
utils.normalize(edge_repr, in_place=True)
)
edge_repr = self.input_embedder(fasta[..., 0, :], out=edge_repr)

return node_repr, edge_repr

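
model.py also switches several `utils.normalize` calls to `in_place=True`, so
the normalized tensor overwrites its input instead of allocating a copy. A
sketch under the assumption that `normalize` is a plain LayerNorm-style
operation over the last dimension (the actual helper may differ):

```python
import torch


def normalize_in_place(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # LayerNorm-style normalization over the last dimension that reuses x's
    # storage. Fine at inference; in-place ops like these are not
    # autograd-friendly during training.
    mean = x.mean(dim=-1, keepdim=True)
    x -= mean
    var = x.pow(2).mean(dim=-1, keepdim=True)
    x /= torch.sqrt(var + eps)
    return x


x = torch.randn(8, 16)
reference = torch.nn.functional.layer_norm(x.clone(), (16,))
out = normalize_in_place(x)  # x itself now holds the result
assert torch.allclose(reference, out, atol=1e-5)
```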
