Forked from evolutionaryscale/esm. Commit b42b58b by Zeming Lin (Jun 25, 2024): 53 changed files with 8,812 additions and 0 deletions.
**New file:** `README.md` (+106 lines)
# ESM3

[ESM3](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model) is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.

ESM3 is a *generative* masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.
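To make the iterative unmasking concrete, here is a minimal, self-contained sketch of the kind of loop `.generate()` performs. It is *not* the ESM3 implementation: the toy `score_tokens` model, the vocabulary, and the step schedule are stand-ins for illustration only.

```py
import torch

VOCAB = list("ACDEFGHIKLMNPQRSTVWY")  # toy amino-acid vocabulary
MASK = "_"

def score_tokens(seq: list[str]) -> torch.Tensor:
    # Stand-in for the language model: per-position logits over VOCAB.
    # In ESM3 this would be a transformer forward pass over all tracks.
    return torch.randn(len(seq), len(VOCAB))

def iterative_unmask(prompt: str, num_steps: int = 8, temperature: float = 0.7) -> str:
    seq = list(prompt)
    n_masked = sum(c == MASK for c in seq)
    per_step = max(1, n_masked // num_steps)
    for _ in range(num_steps):
        masked = [i for i, c in enumerate(seq) if c == MASK]
        if not masked:
            break
        logits = score_tokens(seq)
        # Unmask a fraction of the remaining positions at each step.
        for i in masked[:per_step]:
            probs = torch.softmax(logits[i] / temperature, dim=-1)
            seq[i] = VOCAB[torch.multinomial(probs, 1).item()]
    # Fill any leftover masked positions greedily on a final pass.
    logits = score_tokens(seq)
    for i, c in enumerate(seq):
        if c == MASK:
            seq[i] = VOCAB[int(logits[i].argmax())]
    return "".join(seq)

print(iterative_unmask("AC__G__KL", num_steps=3))
```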
<!--![ESM3 Diagram](_assets/esm3_diagram.png)-->
<img src="_assets/esm3_diagram.png" alt="ESM3 Diagram" width="400" />

The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters.

Here we present `esm3-open-small`. With 1.4B parameters, it is the smallest and fastest model in the family.
ESM3-open is available under a [non-commercial license](LICENSE.md).
Visit our [Discussions page](https://github.com/evolutionaryscale/esm/discussions) to get in touch, provide feedback, ask questions, or share your experience with ESM3!

## Quickstart for ESM3-open

```
pip install esm
```

To download the weights, we require users to accept our non-commercial license.
The weights are stored on Hugging Face Hub under [HuggingFace/EvolutionaryScale/esm3](https://huggingface.co/EvolutionaryScale/esm3).
Please create an account and accept the license.

```py
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig

# This will prompt you to get an API key from Hugging Face Hub; make one with
# "Read" or "Write" permission and copy it back here.
login()

# This will download the model weights and instantiate the model on your machine.
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to("cuda")  # or "cpu"

# Generate a completion for a partial Carbonic Anhydrase (2vvb)
prompt = "___________________________________________________DQATSLRILNNGHAFNVEFDDSQDKAVLKGGPLDGTYRLIQFHFHWGSLDGQGSEHTVDKKKYAAELHLVHWNTKYGDFGKAVQQPDGLAVLGIFLKVGSAKPGLQKVVDVLDSIKTKGKSADFTNFDPRGLLPESLDYWTYPGSLTTPP___________________________________________________________"
protein = ESMProtein(sequence=prompt)
# Generate the sequence, then the structure. This will iteratively unmask the sequence track.
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8, temperature=0.7))
# We can show the predicted structure for the generated sequence.
protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))
protein.to_pdb("./generation.pdb")
# Then we can do a round-trip design by inverse folding the sequence and recomputing the structure.
protein.sequence = None
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8))
protein.structure = None
protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))
protein.to_pdb("./round_tripped.pdb")
```

Congratulations, you just ran a chain of thought with ESM3!
Let's explore some more advanced prompting examples:

[Open examples/generate.ipynb in Colab](https://colab.research.google.com/github/evolutionaryscale/esm/blob/main/examples/generate.ipynb)
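As one further example of chained prompting with the API shown above, the sketch below samples several candidate completions at different temperatures and folds each one. It uses only `ESMProtein` and `GenerationConfig` as in the quickstart; the loop structure, temperature values, and output file names are illustrative choices, not part of the library.

```py
# Hedged sketch: sample multiple candidates at different temperatures, then fold each.
# Assumes `model` and `prompt` are defined as in the quickstart above.
from esm.sdk.api import ESMProtein, GenerationConfig

for i, temperature in enumerate([0.5, 0.7, 1.0]):
    candidate = ESMProtein(sequence=prompt)
    # Unmask the sequence track with a different sampling temperature per candidate.
    candidate = model.generate(
        candidate, GenerationConfig(track="sequence", num_steps=8, temperature=temperature)
    )
    # Predict a structure for the candidate and write it out for inspection.
    candidate = model.generate(candidate, GenerationConfig(track="structure", num_steps=8))
    candidate.to_pdb(f"./candidate_{i}.pdb")
```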
## Forge: Access to larger ESM3 models
You can apply for beta access to the full family of ESM3 models at [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai).

We encourage users to interact with the Forge API through the Python `esm` library instead of the command line.
The Python interface enables you to interactively load proteins, build prompts, and inspect generated proteins.
Additionally, users can seamlessly swap between `esm.models.esm3.ESM3` running locally and
`esm.sdk.forge.ESM3ForgeInferenceClient` connecting to the Forge API.

Once the Forge client is released, we'll be able to do something like:
```py
model: ESM3InferenceClient = ESM3ForgeInferenceClient("esm3_sm_open_v1").to("cuda")
...
```
and the exact same code will work.
This will enable seamless access to our large 98B protein language models for protein design work.
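Because both clients are meant to satisfy the same `ESM3InferenceClient` interface, prompting code can be written once against that type. The helper below is an illustrative sketch of that pattern, assuming only the `generate()` API shown above; it is not part of the library.

```py
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig

def design_round_trip(model: ESM3InferenceClient, sequence_prompt: str) -> ESMProtein:
    """Sequence -> structure round trip; works with any ESM3InferenceClient backend."""
    protein = ESMProtein(sequence=sequence_prompt)
    protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8))
    protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))
    return protein

# The same call works whether `model` is a local ESM3 instance or a Forge client:
# designed = design_round_trip(model, prompt)
```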
## Responsible Development

EvolutionaryScale is a public benefit company. Our mission is to develop artificial intelligence to understand biology for the benefit of human health and society, through partnership with the scientific community, and open, safe, and responsible research. Inspired by the history of our field as well as [new principles and recommendations](https://responsiblebiodesign.ai/), we have created a Responsible Development Framework to guide our work towards our mission with transparency and clarity.

The core tenets of our framework are:

- We will communicate the benefits and risks of our research
- We will proactively and rigorously evaluate the risk of our models before public deployment
- We will adopt risk mitigation strategies and precautionary guardrails
- We will work with stakeholders in government, policy, and civil society to keep them informed

With this in mind, we have performed a variety of mitigations for `esm3-sm-open-v1`, detailed in our [paper](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model).

## License

**The Big Picture:**

1. The EvolutionaryScale AI Model is **only** available under this Community License Agreement for **non-commercial use** by **individuals** or **non-commercial organizations** (including universities, non-profit organizations and research institutes, educational and government bodies).

2. You **may not** use the EvolutionaryScale AI Model or any derivative works of the EvolutionaryScale AI Model or its outputs:

    1. in connection with **any commercial activities**, for example, any activities **by, on behalf of or for a commercial entity** or to develop **any product or service** such as hosting the AI Model behind an API; or

    2. without attribution to EvolutionaryScale and this Community License Agreement; or

    3. to **train** any other **large language model**, any technology for protein representation learning or protein generation or any other AI-powered third party model **similar to EvolutionaryScale’s AI Model**, even for non-commercial usage.

3. You **can publish, share and adapt** the EvolutionaryScale AI Model and its outputs for **non-commercial purposes** in accordance with the Community License Agreement, including the requirement to **restrict** the usage of any reproductions and copies **by, on behalf of or for a commercial entity** or **for any commercial purpose**.

Please refer to our [non-commercial license](LICENSE.md) for details.
**New file:** `esm/layers/attention.py` (+70 lines)
import functools

import einops
import torch
import torch.nn.functional as F
from torch import nn

from esm.layers.rotary import RotaryEmbedding


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        bias: bool = False,
        qk_layernorm: bool = True,
    ):
        super().__init__()

        self.d_model = d_model
        self.n_heads = n_heads

        self.d_head = self.d_model // self.n_heads
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=bias)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)

        if qk_layernorm:
            self.q_ln = nn.LayerNorm(d_model, bias=bias)
            self.k_ln = nn.LayerNorm(d_model, bias=bias)
        else:
            self.q_ln = nn.Identity()
            self.k_ln = nn.Identity()

        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(self, x, seq_id):
        # Tensor name suffixes: B=batch, L=sequence length, D=d_model, H=heads.
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        query_BLD, key_BLD = self.q_ln(query_BLD), self.k_ln(key_BLD)
        query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)

        n_heads = self.n_heads
        reshaper = functools.partial(
            einops.rearrange, pattern="b s (h d) -> b h s d", h=n_heads
        )

        query_BHLD, key_BHLD, value_BHLD = map(
            reshaper, (query_BLD, key_BLD, value_BLD)
        )

        # Where True, enable participation in attention: positions attend only
        # to positions that share the same seq_id.
        mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
        mask_BHLL = mask_BLL.unsqueeze(1)

        context_BHLD = F.scaled_dot_product_attention(
            query_BHLD, key_BHLD, value_BHLD, mask_BHLL
        )
        context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)")
        return self.out_proj(context_BLD)
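A minimal smoke test of the attention module, assuming the installed `esm` package (and a recent PyTorch) matches the file above; the shapes and values below are illustrative only.

```py
import torch
from esm.layers.attention import MultiHeadAttention

attn = MultiHeadAttention(d_model=64, n_heads=4)

x = torch.randn(2, 8, 64)  # (batch, length, d_model)
# Two packed sequences in the second batch row: positions with different
# seq_id values cannot attend to each other.
seq_id = torch.tensor([[0] * 8, [0] * 4 + [1] * 4])

out = attn(x, seq_id)
print(out.shape)  # torch.Size([2, 8, 64])
```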
**New file:** `esm/layers/blocks.py` (+153 lines)
import torch
import torch.nn as nn
import torch.nn.functional as F

from esm.layers.attention import MultiHeadAttention
from esm.layers.geom_attention import (
    GeometricReasoningOriginalImpl,
)
from esm.utils.structure.affine3d import Affine3D


def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    # Round the expanded hidden dimension up to the nearest multiple of 256,
    # e.g. expansion_ratio=4.0, d_model=1000 -> 4096.
    return int(((expansion_ratio * d_model) + 255) // 256 * 256)


class SwiGLU(nn.Module):
    """
    SwiGLU activation function as an nn.Module, allowing it to be used within nn.Sequential.
    This module splits the input tensor along the last dimension and applies the SiLU (Swish)
    activation function to the first half, then multiplies it by the second half.
    """

    def __init__(self):
        super(SwiGLU, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return F.silu(x1) * x2


def swiglu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool):
    # The first projection produces twice the corrected hidden width;
    # SwiGLU splits it in half, so the second projection sees the corrected width.
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(
            d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=bias
        ),
        SwiGLU(),
        nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=bias),
    )


def gelu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool):
    hidden_dim = int(expansion_ratio * d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(d_model, hidden_dim, bias=bias),
        nn.GELU(),
        nn.Linear(hidden_dim, d_model, bias=bias),
    )


class UnifiedTransformerBlock(nn.Module):
    """
    A unified transformer block that can optionally incorporate geometric attention.

    This class defines a transformer block that can be configured to use geometric attention
    alongside the standard multi-head attention mechanism. It is designed to be a flexible
    component of transformer-based models, allowing for the integration of geometric reasoning.

    Parameters
    ----------
    d_model : int
        The dimensionality of the input and output features of the transformer block.
    n_heads : int
        The number of attention heads in the multi-head attention mechanism.
    use_geom_attn : bool, optional
        Whether to use geometric attention in addition to the standard multi-head attention. Defaults to False.
    v_heads : int, optional
        The number of heads to use for the geometric attention mechanism, if enabled. Must be specified if `use_geom_attn` is True.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        use_geom_attn: bool = False,
        use_plain_attn: bool = True,
        v_heads: int | None = None,
        bias: bool = False,
        expansion_ratio: float = 4.0,
        residue_scaling_factor: float = 1,
        mask_and_zero_frameless: bool = False,
        qk_layernorm: bool = True,
        ffn_type: str = "swiglu",  # swiglu | gelu
    ):
        super().__init__()
        self.use_plain_attn = use_plain_attn
        if self.use_plain_attn:
            self.attn = MultiHeadAttention(
                d_model, n_heads, bias, qk_layernorm=qk_layernorm
            )
        self.use_geom_attn = use_geom_attn
        if self.use_geom_attn:
            if v_heads is None:
                raise ValueError("v_heads must be specified when use_geom_attn is True")
            self.geom_attn = GeometricReasoningOriginalImpl(
                c_s=d_model,
                v_heads=v_heads,
                bias=bias,
                mask_and_zero_frameless=mask_and_zero_frameless,
            )
        if ffn_type == "swiglu":
            self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, bias)
        elif ffn_type == "gelu":
            self.ffn = gelu_ln_ffn(d_model, expansion_ratio, bias)
        else:
            raise ValueError(f"Unknown ffn_type: {ffn_type}")
        self.scaling_factor = residue_scaling_factor

    def forward(
        self,
        x: torch.Tensor,
        sequence_id: torch.Tensor,
        frames: Affine3D,
        frames_mask: torch.Tensor,
        chain_id: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass for the UnifiedTransformerBlock.

        Parameters
        ----------
        x : torch.Tensor[float]
            Input tensor to the transformer block, typically the output from the previous layer.
        sequence_id : torch.Tensor[int]
            Tensor containing sequence IDs for each element in the batch, used for attention masking.
        frames : Affine3D
            Affine3D containing geometric frame information for geometric attention.
        frames_mask : torch.Tensor[bool]
            Boolean mask tensor indicating valid frames for geometric attention.
        chain_id : torch.Tensor[int]
            Tensor containing chain IDs for each element, used for attention masking in geometric attention.

        Returns
        -------
        torch.Tensor[float]
            The output tensor after applying the transformer block operations.
        """
        if self.use_plain_attn:
            r1 = self.attn(x, sequence_id)
            x = x + r1 / self.scaling_factor

        if self.use_geom_attn:
            r2 = self.geom_attn(x, frames, frames_mask, sequence_id, chain_id)
            x = x + r2 / self.scaling_factor

        r3 = self.ffn(x) / self.scaling_factor
        x = x + r3

        return x
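And a hedged smoke test of the block without geometric attention, assuming the installed `esm` package matches the file above. When `use_geom_attn=False` the geometric inputs are never consumed, so `None` placeholders are passed purely for illustration.

```py
import torch
from esm.layers.blocks import UnifiedTransformerBlock

block = UnifiedTransformerBlock(d_model=64, n_heads=4, use_geom_attn=False)

x = torch.randn(2, 8, 64)                          # (batch, length, d_model)
sequence_id = torch.zeros(2, 8, dtype=torch.long)  # one sequence per batch row

# frames, frames_mask and chain_id are only used by the geometric-attention
# branch, which is disabled here, so None stands in for them in this sketch.
out = block(x, sequence_id, frames=None, frames_mask=None, chain_id=None)
print(out.shape)  # torch.Size([2, 8, 64])
```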