Initial commit
Zeming Lin committed Jun 25, 2024
0 parents commit b42b58b
Showing 53 changed files with 8,812 additions and 0 deletions.
80 changes: 80 additions & 0 deletions LICENSE.md

Large diffs are not rendered by default.

106 changes: 106 additions & 0 deletions README.md
@@ -0,0 +1,106 @@
# ESM3
[ESM3](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model) is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.

ESM3 is a *generative* masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.
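The iterative unmasking described above can be sketched with a toy loop. This is only an illustration under stated assumptions: the function `toy_iterative_unmask` is hypothetical, a uniform random choice stands in for the model's predictions, and the even spread of positions over steps is an assumed schedule, not necessarily the one `.generate()` uses.

```python
import random

MASK = "_"  # the quickstart prompt below uses "_" for masked positions
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def toy_iterative_unmask(prompt: str, num_steps: int, seed: int = 0) -> str:
    """Fill every masked position over at most `num_steps` rounds.

    A uniform random choice stands in for a model's predicted distribution.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    rng = random.Random(seed)
    tokens = list(prompt)
    for step in range(num_steps):
        masked = [i for i, tok in enumerate(tokens) if tok == MASK]
        if not masked:
            break
        # Spread the remaining masked positions evenly over the remaining steps.
        remaining_steps = num_steps - step
        n_to_fill = -(-len(masked) // remaining_steps)  # ceiling division
        for i in rng.sample(masked, n_to_fill):
            tokens[i] = rng.choice(AMINO_ACIDS)
    return "".join(tokens)


completed = toy_iterative_unmask("___DQATSLRILNN____", num_steps=4)
```

Positions that were already specified in the prompt are never touched; only masked positions are sampled, a few per step, until none remain.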

<!--![ESM3 Diagram](_assets/esm3_diagram.png)-->
<img src="_assets/esm3_diagram.png" alt="ESM3 Diagram" width="400" />

The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters.

Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and fastest model in the family.
ESM3-open is available under a [non-commercial license](LICENSE.md).
Visit our [Discussions page](https://github.com/evolutionaryscale/esm/discussions) to get in touch, provide feedback, ask questions or share your experience with ESM3!


## Quickstart for ESM3-open

```
pip install esm
```

In order to download the weights, we require users to accept our non-commercial license.
The weights are stored on HuggingFace Hub under [HuggingFace/EvolutionaryScale/esm3](https://huggingface.co/EvolutionaryScale/esm3).
Please create an account and accept the license.

```py
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig

# login() prompts for a Hugging Face Hub access token. Create one with
# "Read" or "Write" permission and paste it in.
login()

# This will download the model weights and instantiate the model on your machine.
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to("cuda") # or "cpu"

# Generate a completion for a partial Carbonic Anhydrase (2vvb)
prompt = "___________________________________________________DQATSLRILNNGHAFNVEFDDSQDKAVLKGGPLDGTYRLIQFHFHWGSLDGQGSEHTVDKKKYAAELHLVHWNTKYGDFGKAVQQPDGLAVLGIFLKVGSAKPGLQKVVDVLDSIKTKGKSADFTNFDPRGLLPESLDYWTYPGSLTTPP___________________________________________________________"
protein = ESMProtein(sequence=prompt)
# Generate the sequence, then the structure. This will iteratively unmask the sequence track.
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8, temperature=0.7))
# We can show the predicted structure for the generated sequence.
protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))
protein.to_pdb("./generation.pdb")
# Then we can do a round trip design by inverse folding the sequence and recomputing the structure
protein.sequence = None
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8))
protein.structure = None
protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))
protein.to_pdb("./round_tripped.pdb")
```

Congratulations, you just ran a chain of thought with ESM3!
Let's explore some more advanced prompting examples:

[Open examples/generate.ipynb in Colab](https://colab.research.google.com/github/evolutionaryscale/esm/blob/main/examples/generate.ipynb)
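As a small aid for building prompts like the one in the quickstart, the sketch below masks the flanks of a known sequence with `_` characters. The helper `mask_flanks` is hypothetical (it is not part of the `esm` package), and it assumes that `_` marks a masked position, as the quickstart prompt suggests.

```python
MASK = "_"  # assumed mask character, matching the quickstart prompt


def mask_flanks(sequence: str, n_left: int, n_right: int) -> str:
    """Replace the first `n_left` and last `n_right` residues with mask characters."""
    if n_left < 0 or n_right < 0 or n_left + n_right > len(sequence):
        raise ValueError("flank sizes must be non-negative and fit in the sequence")
    core = sequence[n_left : len(sequence) - n_right]
    return MASK * n_left + core + MASK * n_right


prompt = mask_flanks("MKTAYIAKQR", n_left=3, n_right=2)
```

The resulting string keeps the original length, so it could be passed as `ESMProtein(sequence=prompt)` in the same way as the quickstart prompt.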

## Forge: Access to larger ESM3 models
You can apply for beta access to the full family of ESM3 models at [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai).

We encourage users to interact with the Forge API through the python `esm` library instead of the command line.
The python interface enables you to interactively load proteins, build prompts, and inspect generated proteins.
Additionally, users can seamlessly swap between `esm.models.esm3.ESM3` running locally, and
`esm.sdk.forge.ESM3ForgeInferenceClient` connecting to the Forge API.

Once the forge client is released, we'll be able to do something like:
```py
model: ESM3InferenceClient = ESM3ForgeInferenceClient("esm3_sm_open_v1").to("cuda")
...
```
and the exact same code will work.
This will enable seamless access to our large 98B protein language models for protein design work.

## Responsible Development

EvolutionaryScale is a public benefit company. Our mission is to develop artificial intelligence to understand biology for the benefit of human health and society, through partnership with the scientific community, and open, safe, and responsible research. Inspired by the history of our field as well as [new principles and recommendations](https://responsiblebiodesign.ai/), we have created a Responsible Development Framework to guide our work towards our mission with transparency and clarity.

The core tenets of our framework are:

- We will communicate the benefits and risks of our research
- We will proactively and rigorously evaluate the risk of our models before public deployment
- We will adopt risk mitigation strategies and precautionary guardrails
- We will work with stakeholders in government, policy, and civil society to keep them informed

With this in mind, we have performed a variety of mitigations for `esm3-sm-open-v1`, detailed in our [paper](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model).


## License

**The Big Picture:**

1. The EvolutionaryScale AI Model is **only** available under this Community License Agreement for **non-commercial use** by **individuals** or **non-commercial organizations** (including universities, non-profit organizations and research institutes, educational and government bodies).

2. You **may not** use the EvolutionaryScale AI Model or any derivative works of the EvolutionaryScale AI Model or its outputs:

    1. in connection with **any commercial activities**, for example, any activities **by, on behalf of or for a commercial entity** or to develop **any product or service** such as hosting the AI Model behind an API; or

    2. without attribution to EvolutionaryScale and this Community License Agreement; or

    3. to **train** any other **large language model**, any technology for protein representation learning or protein generation or any other AI-powered third party model **similar to EvolutionaryScale’s AI Model**, even for non-commercial usage.

3. You **can publish, share and adapt** the EvolutionaryScale AI Model and its outputs for **non-commercial purposes** in accordance with the Community License Agreement, including the requirement to **restrict** the usage of any reproductions and copies **by, on behalf of or for a commercial entity** or **for any commercial purpose**.


Please refer to our [non-commercial license](LICENSE.md) for details.
Binary file added _assets/esm3_diagram.png
Empty file added esm/__init__.py
Empty file.
70 changes: 70 additions & 0 deletions esm/layers/attention.py
@@ -0,0 +1,70 @@
import functools

import einops
import torch
import torch.nn.functional as F
from torch import nn

from esm.layers.rotary import RotaryEmbedding


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        bias: bool = False,
        qk_layernorm: bool = True,
    ):
        super().__init__()

        self.d_model = d_model
        self.n_heads = n_heads

        self.d_head = self.d_model // self.n_heads
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=bias)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)

        if qk_layernorm:
            self.q_ln = nn.LayerNorm(d_model, bias=bias)
            self.k_ln = nn.LayerNorm(d_model, bias=bias)
        else:
            self.q_ln = nn.Identity()
            self.k_ln = nn.Identity()

        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(self, x, seq_id):
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        query_BLD, key_BLD = self.q_ln(query_BLD), self.k_ln(key_BLD)
        query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)

        n_heads = self.n_heads
        reshaper = functools.partial(
            einops.rearrange, pattern="b s (h d) -> b h s d", h=n_heads
        )

        query_BHLD, key_BHLD, value_BHLD = map(
            reshaper, (query_BLD, key_BLD, value_BLD)
        )

        # Where True, enable participation in attention.
        mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
        mask_BHLL = mask_BLL.unsqueeze(1)

        context_BHLD = F.scaled_dot_product_attention(
            query_BHLD, key_BHLD, value_BHLD, mask_BHLL
        )
        context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)")
        return self.out_proj(context_BLD)
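The attention mask in `forward` compares `seq_id` against itself, so a position can only attend to positions that carry the same sequence id. The following is a minimal pure-Python sketch of that comparison for a single batch element; `same_sequence_mask` is a hypothetical name, and nested lists stand in for the broadcasted tensor comparison.

```python
def same_sequence_mask(seq_id: list[int]) -> list[list[bool]]:
    """Return mask[i][j], True when positions i and j share a sequence id."""
    return [[a == b for b in seq_id] for a in seq_id]


# Two sequences packed into one row: positions 0-1 and positions 2-4.
mask = same_sequence_mask([0, 0, 1, 1, 1])
for row in mask:
    # "x" marks an allowed query/key pair, "." a blocked one.
    print("".join("x" if allowed else "." for allowed in row))
```

The result is block-diagonal: each packed sequence attends within itself and never across the boundary to its neighbor.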
153 changes: 153 additions & 0 deletions esm/layers/blocks.py
@@ -0,0 +1,153 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from esm.layers.attention import MultiHeadAttention
from esm.layers.geom_attention import (
    GeometricReasoningOriginalImpl,
)
from esm.utils.structure.affine3d import Affine3D


def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    # Round the expanded hidden dimension up to the next multiple of 256.
    return int(((expansion_ratio * d_model) + 255) // 256 * 256)


class SwiGLU(nn.Module):
    """
    SwiGLU activation function as an nn.Module, allowing it to be used within nn.Sequential.
    This module splits the input tensor along the last dimension and applies the SiLU (Swish)
    activation function to the first half, then multiplies it by the second half.
    """

    def __init__(self):
        super(SwiGLU, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return F.silu(x1) * x2


def swiglu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool):
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(
            d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=bias
        ),
        SwiGLU(),
        nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=bias),
    )


def gelu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool):
    hidden_dim = int(expansion_ratio * d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(d_model, hidden_dim, bias=bias),
        nn.GELU(),
        nn.Linear(hidden_dim, d_model, bias=bias),
    )


class UnifiedTransformerBlock(nn.Module):
    """
    A unified transformer block that can optionally incorporate geometric attention.

    This class defines a transformer block that can be configured to use geometric attention
    alongside the standard multi-head attention mechanism. It is designed to be a flexible
    component of transformer-based models, allowing for the integration of geometric reasoning.

    Parameters
    ----------
    d_model : int
        The dimensionality of the input and output features of the transformer block.
    n_heads : int
        The number of attention heads in the multi-head attention mechanism.
    use_geom_attn : bool, optional
        Whether to use geometric attention in addition to the standard multi-head attention. Defaults to False.
    use_plain_attn : bool, optional
        Whether to apply the standard multi-head attention. Defaults to True.
    v_heads : int, optional
        The number of heads to use for the geometric attention mechanism, if enabled. Must be specified if `use_geom_attn` is True.
    bias : bool, optional
        Whether the linear layers use a bias term. Defaults to False.
    expansion_ratio : float, optional
        Multiplier applied to `d_model` to size the feed-forward hidden layer. Defaults to 4.0.
    residue_scaling_factor : float, optional
        Divisor applied to each residual branch before it is added to the input. Defaults to 1.
    mask_and_zero_frameless : bool, optional
        Forwarded to the geometric attention module. Defaults to False.
    qk_layernorm : bool, optional
        Whether to apply LayerNorm to the queries and keys. Defaults to True.
    ffn_type : str, optional
        Feed-forward variant, either "swiglu" or "gelu". Defaults to "swiglu".
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        use_geom_attn: bool = False,
        use_plain_attn: bool = True,
        v_heads: int | None = None,
        bias: bool = False,
        expansion_ratio: float = 4.0,
        residue_scaling_factor: float = 1,
        mask_and_zero_frameless: bool = False,
        qk_layernorm: bool = True,
        ffn_type: str = "swiglu",  # swiglu | gelu
    ):
        super().__init__()
        self.use_plain_attn = use_plain_attn
        if self.use_plain_attn:
            self.attn = MultiHeadAttention(
                d_model, n_heads, bias, qk_layernorm=qk_layernorm
            )
        self.use_geom_attn = use_geom_attn
        if self.use_geom_attn:
            if v_heads is None:
                raise ValueError("v_heads must be specified when use_geom_attn is True")
            self.geom_attn = GeometricReasoningOriginalImpl(
                c_s=d_model,
                v_heads=v_heads,
                bias=bias,
                mask_and_zero_frameless=mask_and_zero_frameless,
            )
        if ffn_type == "swiglu":
            self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, bias)
        elif ffn_type == "gelu":
            self.ffn = gelu_ln_ffn(d_model, expansion_ratio, bias)
        else:
            raise ValueError(f"Unknown ffn_type: {ffn_type}")
        self.scaling_factor = residue_scaling_factor

    def forward(
        self,
        x: torch.Tensor,
        sequence_id: torch.Tensor,
        frames: Affine3D,
        frames_mask: torch.Tensor,
        chain_id: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass for the UnifiedTransformerBlock.

        Parameters
        ----------
        x : torch.Tensor[float]
            Input tensor to the transformer block, typically the output from the previous layer.
        sequence_id : torch.Tensor[int]
            Tensor containing sequence IDs for each element in the batch, used for attention masking.
        frames : Affine3D
            Affine3D containing geometric frame information for geometric attention.
        frames_mask : torch.Tensor[bool]
            Boolean mask tensor indicating valid frames for geometric attention.
        chain_id : torch.Tensor[int]
            Tensor containing chain IDs for each element, used for attention masking in geometric attention.

        Returns
        -------
        torch.Tensor[float]
            The output tensor after applying the transformer block operations.
        """
        if self.use_plain_attn:
            r1 = self.attn(x, sequence_id)
            x = x + r1 / self.scaling_factor

        if self.use_geom_attn:
            r2 = self.geom_attn(x, frames, frames_mask, sequence_id, chain_id)
            x = x + r2 / self.scaling_factor

        r3 = self.ffn(x) / self.scaling_factor
        x = x + r3

        return x
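To make the SwiGLU feed-forward sizing and gating in `blocks.py` concrete, here is a small pure-Python sketch. `swiglu_hidden_dim` repeats the arithmetic of `swiglu_correction_fn`; the `silu` and `swiglu` helpers are hypothetical stand-ins that operate on plain lists instead of tensors, and they assume an even-length input.

```python
import math


def swiglu_hidden_dim(expansion_ratio: float, d_model: int) -> int:
    # Same arithmetic as swiglu_correction_fn: round up to a multiple of 256.
    return int(((expansion_ratio * d_model) + 255) // 256 * 256)


def silu(v: float) -> float:
    # SiLU (Swish): v * sigmoid(v).
    return v / (1.0 + math.exp(-v))


def swiglu(x: list[float]) -> list[float]:
    """Gate the second half of `x` with SiLU of the first half."""
    if len(x) % 2 != 0:
        raise ValueError("input length must be even")
    half = len(x) // 2
    gate, value = x[:half], x[half:]
    return [silu(g) * v for g, v in zip(gate, value)]


hidden = swiglu_hidden_dim(4.0, 1536)   # 6144 is already a multiple of 256
rounded = swiglu_hidden_dim(1.0, 260)   # 260 is rounded up to 512, not down to 256
gated = swiglu([0.0, 1.0, 5.0, 7.0])    # gates [0.0, 1.0] applied to values [5.0, 7.0]
```

Because the gate halves the width, the first linear layer in `swiglu_ln_ffn` projects to twice the rounded hidden size, and the second linear layer maps the gated half back to `d_model`.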