diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..fc36cc3 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,80 @@ +# EvolutionaryScale Community License Agreement + +Please read this EvolutionaryScale Community License Agreement (the “**Agreement**”) carefully before using the AI Model (as defined below), which is offered by EvolutionaryScale, PBC (“**ES**”). + +By downloading the AI Model, or otherwise using the AI Model in any manner, You agree that You have read and agree to be bound by the terms of this Agreement. If You are accessing the AI Model on behalf of an organization or entity, You represent and warrant that You are authorized to enter into this Agreement on that organization’s or entity’s behalf and bind them to the terms of this Agreement (in which case, the references to “**You**” and “**Your**” in this Agreement, except for in this sentence, refer to that organization or entity) and that such entity is a non-commercial organization (such as a university, non-profit organization, research institute or educational or governmental body). Use of the AI Model is expressly conditioned upon Your assent to all terms of this Agreement, to the exclusion of all other terms. + +## Definitions. + +In addition to other terms defined elsewhere in this Agreement, the terms below have the following meanings. + +1. “**AI Model**” means the EvolutionaryScale ESM-3 Open Model code and model weights made available at the following link [https://github.com/evolutionaryscale/esm] (the “**GitHub Page**”), as may be updated and amended from time to time, whether in Source or Object form, made available to You pursuant to this Agreement. +2. “**Commercial Entity**” means any entity engaged in any activity intended for or directed toward commercial advantage or monetary compensation, including, without limitation, the development of any product or service intended to be sold or made available for a fee. 
For the purpose of this Agreement, references to a Commercial Entity expressly exclude any universities, non-profit organizations, not-for-profit entities, research institutes and educational and government bodies. +3. “**Contribution**” means any work of authorship, including the original version of the AI Model and any modifications or additions to that AI Model or Derivative Works thereof, that is intentionally submitted to ES for inclusion in the AI Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "**submitted**" means any form of electronic, verbal, or written communication sent to ES or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, ES for the purpose of discussing and improving the AI Model, but excluding Outputs and all communications that are conspicuously marked or otherwise designated in writing by the copyright owner as "**Not a Contribution**." +4. “**Contributor**” means ES and any individual or Legal Entity on behalf of whom a Contribution has been received by ES and subsequently incorporated within the AI Model. +5. “**Derivative Work**” means any work, whether in Source or Object form, that is based on (or derived from) the AI Model and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this Agreement, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the AI Model and Derivative Works thereof. +6. “**Legal Entity**” means the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "**control**" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. +7. “**Non-Commercial Purposes**” means uses not intended for or directed toward commercial advantage or monetary compensation, or the facilitation of development of any product or service to be sold or made available for a fee. For the avoidance of doubt, the provision of Outputs as a service is not a Non-Commercial Purpose. +8. “**Object**” means any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. +9. “**Output**” means any output, including any protein sequence, structure prediction, functional annotation, molecule, descriptions of a molecule, model, sequence, text, and/or image that is elicited directly or indirectly by, or otherwise made available to, You in connection with Your use of the AI Model, including, but not limited to, the use of AI-Powered Technology. +10. “**Output Derivatives**” means any enhancements, modifications and derivative works of Outputs (including, but not limited to, any derivative sequences or molecules). +11. “**Source**” means the preferred form for making modifications, including but not limited to AI Model source code, documentation source, and configuration files. +12. “**Third Party Model**” means any non-human tool, platform and/or other technology powered or made available in connection with the use of generative artificial intelligence or machine learning models that is operated by any third party. +13. “**You**” or “**Your**” means the individual entering into this Agreement or the organization or entity on whose behalf such individual is entering into this Agreement. 
+ +## Intellectual Property Rights and Licenses. + +1. **Copyright License**. Subject to the terms and conditions of this Agreement, each Contributor hereby grants to You a non-exclusive, non-transferable, limited copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the AI Model and such Derivative Works in Source or Object form for Your Non-Commercial Purposes. +2. **Patent License**. Subject to the terms and conditions of this Agreement, each Contributor hereby grants to You a non-exclusive, non-transferable, limited patent license to make, have made, use, import, and otherwise transfer the AI Model for Your Non-Commercial Purposes, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the AI Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the AI Model or a Contribution incorporated within the AI Model constitutes direct or contributory patent infringement, then any patent licenses granted to You under this Agreement for that AI Model shall terminate as of the date such litigation is filed. + +## Licensing Process. + +In connection with Your licensing of the AI Model, ES may collect, from You or automatically through Your use of the AI Model, certain registration information about You, any Legal Entity You may represent, and Your use of the AI Model. The collection of this information and ES’s policies and procedures regarding the collection, use, disclosure and security of information received are described further in ES’s Privacy Policy available at [https://evolutionaryscale.ai/privacy], as may be updated from time to time. + +## Redistribution. 
+ +Subject to Section 5 (Use Restrictions) below, You may reproduce and distribute copies of the AI Model or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form for Your Non-Commercial Purposes, provided that You meet the following conditions: + +1. You must not distribute copies of the AI Model or Derivative Works thereof to, or allow the use of any reproductions or copies thereof by, on behalf of or for, any Commercial Entity; and +2. You must restrict the usage of any copies of the AI Model or Derivative Works to usage for Non-Commercial Purposes; and +3. You must give any other recipients of the AI Model or Derivative Works a copy of this Agreement; and +4. You must cause any modified files to carry prominent notices stating that You changed the files; and +5. You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the AI Model, excluding those notices that do not pertain to any part of the Derivative Works; and +6. If the AI Model includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify this Agreement. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the AI Model, provided that such additional attribution notices cannot be construed as modifying this Agreement. + +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the AI Model otherwise complies with the conditions stated in this Agreement. + +## Use Restrictions. + +**No Commercial Use**.  You may only use the AI Model, Contributions, Derivative Works, Outputs and Output Derivatives (as defined below) for Non-Commercial Purposes. For the avoidance of doubt, structure tokens are also considered Outputs and may only be used for Non-Commercial Purposes. Any commercial use of any of the foregoing, including, without limitation, any use by, on behalf of or for any Commercial Entity or to facilitate the development of any product or service to be sold or made available for a fee, is strictly prohibited under this Agreement. + +**No Use in Drug Development or Discovery**. Without limiting the foregoing, You may not use the AI Model or any Contributions, Derivative Works, Outputs or Output Derivatives in or in connection with: (i) the development (at any stage) or discovery of any drug, medication or pharmaceutical of any kind; (ii) any molecular or biological target, hit or lead identification; (iii) drug candidate selection; or (iv) lead optimization. + +**Use of Outputs**.  
Notwithstanding anything to the contrary in this Agreement, You may not use or provide access to any Outputs or Output Derivatives to train, optimize, improve or otherwise influence the functionality or performance of any: (i) other large language model; (ii) technology for protein structure prediction; or (iii) other Third Party Model that is similar to the AI Model. You may, however, use the Outputs and Output Derivatives to train, optimize, improve or otherwise influence the functionality or performance of the AI Model itself and downstream Derivative Works thereof. + +**Additional Restrictions**.  Your use of the AI Model may also be subject to additional use restrictions communicated to You through the AI Model or otherwise, including those set forth in the ES Acceptable Use Policy available at [https://evolutionaryscale.ai/acceptable-use-policy], as may be updated and amended from time to time (the “**AUP**”), the terms of which are incorporated herein by reference. In the event of any conflict between the terms of this Agreement and the terms of the AUP, the terms that are more restrictive of Your use of the AI Model, Derivative Works, Outputs and Output Derivatives, as applicable, shall govern and control. + +## Submission of Contributions. + +Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the AI Model by You to ES shall be under the terms and conditions of this Agreement, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement You may have executed with ES regarding such Contributions. + +## Trademarks. + +This Agreement does not grant permission to use the trade names, trademarks, service marks, or product names of ES, except as required for reasonable and customary use in describing the origin of the AI Model and reproducing the content of the NOTICE file. + +## Disclaimer of Warranty. 
+ +UNLESS REQUIRED BY APPLICABLE LAW OR OTHERWISE EXPRESSLY AND MUTUALLY AGREED UPON BY YOU AND ES IN WRITING, ES PROVIDES THE AI MODEL (AND EACH CONTRIBUTOR PROVIDES ITS CONTRIBUTIONS) ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, CORRECTNESS, RELIABILITY OR FITNESS FOR A PARTICULAR PURPOSE, ALL OF WHICH ARE HEREBY DISCLAIMED. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE AI MODEL AND ASSUME ANY RISKS ASSOCIATED WITH YOUR EXERCISE OF PERMISSIONS UNDER THIS AGREEMENT. + +## Limitation of Liability. + +In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the AI Model (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +## General. + +1. **Entire Agreement**. This Agreement constitutes the entire agreement between You and ES relating to the subject matter hereof and supersedes all proposals, understandings, or discussions, whether written or oral, relating to the subject matter of this Agreement and all past dealing or industry custom. The failure of either party to enforce its rights under this Agreement at any time for any period shall not be construed as a waiver of such rights. 
ES may amend or modify this Agreement from time to time and will use reasonable efforts to provide You with notice of any material changes that may negatively impact Your use of the AI Model through the GitHub Page or through another means made available to You. No other changes, modifications or waivers to this Agreement will be effective unless in writing and signed by both parties. +2. **Relationship of Parties**. You and ES are independent contractors, and nothing herein shall be deemed to constitute either party as the agent or representative of the other or both parties as joint venturers or partners for any purpose. +3. **Export Control**. You shall comply with the U.S. Foreign Corrupt Practices Act and all applicable export laws, restrictions and regulations of the U.S. Department of Commerce, and any other applicable U.S. and foreign authority. +4. **Assignment**. This Agreement and the rights and obligations herein may not be assigned or transferred, in whole or in part, by You without the prior written consent of ES. Any assignment in violation of this provision is void. ES may freely assign or transfer this Agreement, in whole or in part. This Agreement shall be binding upon, and inure to the benefit of, the successors and permitted assigns of the parties. +5. **Governing Law**. This Agreement shall be governed by and construed under the laws of the State of New York and the United States without regard to conflicts of laws provisions thereof, and without regard to the Uniform Computer Information Transactions Act. +6. **Severability**.  If any provision of this Agreement is held to be invalid, illegal or unenforceable in any respect, that provision shall be limited or eliminated to the minimum extent necessary so that this Agreement otherwise remains in full force and effect and enforceable. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000..a6d33c1 --- /dev/null +++ b/README.md @@ -0,0 +1,106 @@ +# ESM3 +[ESM3](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model) is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks. + +ESM3 is a *generative* masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does. + + +ESM3 Diagram + +The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters. + +Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and fastest model in the family. +ESM3-open is available under a [non-commercial license](LICENSE.md). +Visit our [Discussions page](https://github.com/evolutionaryscale/esm/discussions) to get in touch, provide feedback, ask questions or share your experience with ESM3! + + +## Quickstart for ESM3-open + +``` +pip install esm +``` + +In order to download the weights, we require users to accept our non-commercial license. +The weights are stored on HuggingFace Hub under [HuggingFace/EvolutionaryScale/esm3](https://huggingface.co/EvolutionaryScale/esm3). +Please create an account and accept the license. 
+ +```py +from huggingface_hub import login +from esm.models.esm3 import ESM3 +from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig + +# This will prompt you to get an API key from huggingface hub, make one with +# "Read" or "Write" permission and copy it back here. +login() + +# This will download the model weights and instantiate the model on your machine. +model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to("cuda") # or "cpu" + +# Generate a completion for a partial Carbonic Anhydrase (2vvb) +prompt = "___________________________________________________DQATSLRILNNGHAFNVEFDDSQDKAVLKGGPLDGTYRLIQFHFHWGSLDGQGSEHTVDKKKYAAELHLVHWNTKYGDFGKAVQQPDGLAVLGIFLKVGSAKPGLQKVVDVLDSIKTKGKSADFTNFDPRGLLPESLDYWTYPGSLTTPP___________________________________________________________" +protein = ESMProtein(sequence=prompt) +# Generate the sequence, then the structure. This will iteratively unmask the sequence track. +protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8, temperature=0.7)) +# We can show the predicted structure for the generated sequence. +protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8)) +protein.to_pdb("./generation.pdb") +# Then we can do a round trip design by inverse folding the sequence and recomputing the structure +protein.sequence = None +protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8)) +protein.structure = None +protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8)) +protein.to_pdb("./round_tripped.pdb") +``` + +Congratulations, you just ran a chain of thought with ESM3! 
+Let's explore some more advanced prompting examples: + +[Open examples/generate.ipynb in Colab](https://colab.research.google.com/github/evolutionaryscale/esm/blob/main/examples/generate.ipynb) + +## Forge: Access to larger ESM3 models +You can apply for beta access to the full family of ESM3 models at [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai). + +We encourage users to interact with the Forge API through the Python `esm` library instead of the command line. +The Python interface enables you to interactively load proteins, build prompts, and inspect generated proteins. +Additionally, users can seamlessly swap between `esm.models.esm3.ESM3` running locally, and +`esm.sdk.forge.ESM3ForgeInferenceClient` connecting to the Forge API. + +Once the forge client is released, we'll be able to do something like: +```py +model: ESM3InferenceClient = ESM3ForgeInferenceClient("esm3_sm_open_v1").to("cuda") +... +``` +and the exact same code will work. +This will enable seamless access to our large 98B protein language models for protein design work. + +## Responsible Development + +EvolutionaryScale is a public benefit company. Our mission is to develop artificial intelligence to understand biology for the benefit of human health and society, through partnership with the scientific community, and open, safe, and responsible research. Inspired by the history of our field as well as [new principles and recommendations](https://responsiblebiodesign.ai/), we have created a Responsible Development Framework to guide our work towards our mission with transparency and clarity. 
+ +The core tenets of our framework are: + +- We will communicate the benefits and risks of our research +- We will proactively and rigorously evaluate the risk of our models before public deployment +- We will adopt risk mitigation strategies and precautionary guardrails +- We will work with stakeholders in government, policy, and civil society to keep them informed + +With this in mind, we have performed a variety of mitigations for `esm3-sm-open-v1`, detailed in our [paper](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model). + + +## License + +**The Big Picture:** + +1. The EvolutionaryScale AI Model is **only** available under this Community License Agreement for **non-commercial use** by **individuals** or **non-commercial organizations** (including universities, non-profit organizations and research institutes, educational and government bodies). + +2. You **may not** use the EvolutionaryScale AI Model or any derivative works of the EvolutionaryScale AI Model or its outputs: + + 1. in connection with **any commercial activities**, for example, any activities **by, on behalf of or for a commercial entity** or to develop **any product or service** such as hosting the AI Model behind an API; or + + 2. without attribution to EvolutionaryScale and this Community License Agreement; or + + 3. to **train** any other **large language model**, any technology for protein representation learning or protein generation or any other AI-powered third party model **similar to EvolutionaryScale’s AI Model**, even for non-commercial usage. + +3. You **can publish, share and adapt** the EvolutionaryScale AI Model and its outputs for **non-commercial purposes** in accordance with the Community License Agreement, including the requirement to **restrict** the usage of any reproductions and copies **by, on behalf of or for a commercial entity** or **for any commercial purpose**. 
+ + +Please refer to our [non-commercial license](LICENSE.md) for details. diff --git a/_assets/esm3_diagram.png b/_assets/esm3_diagram.png new file mode 100644 index 0000000..4e6e2d5 Binary files /dev/null and b/_assets/esm3_diagram.png differ diff --git a/esm/__init__.py b/esm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/esm/layers/attention.py b/esm/layers/attention.py new file mode 100644 index 0000000..823aa2c --- /dev/null +++ b/esm/layers/attention.py @@ -0,0 +1,70 @@ +import functools + +import einops +import torch +import torch.nn.functional as F +from torch import nn + +from esm.layers.rotary import RotaryEmbedding + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + d_model: int, + n_heads: int, + bias: bool = False, + qk_layernorm: bool = True, + ): + super().__init__() + + self.d_model = d_model + self.n_heads = n_heads + + self.d_head = self.d_model // self.n_heads + self.layernorm_qkv = nn.Sequential( + nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=bias) + ) + self.out_proj = nn.Linear(d_model, d_model, bias=bias) + + if qk_layernorm: + self.q_ln = nn.LayerNorm(d_model, bias=bias) + self.k_ln = nn.LayerNorm(d_model, bias=bias) + else: + self.q_ln = nn.Identity() + self.k_ln = nn.Identity() + + self.rotary = RotaryEmbedding(d_model // n_heads) + + def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor): + q = q.unflatten(-1, (self.n_heads, self.d_head)) + k = k.unflatten(-1, (self.n_heads, self.d_head)) + q, k = self.rotary(q, k) + q = q.flatten(-2, -1) + k = k.flatten(-2, -1) + return q, k + + def forward(self, x, seq_id): + qkv_BLD3 = self.layernorm_qkv(x) + query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1) + query_BLD, key_BLD = self.q_ln(query_BLD), self.k_ln(key_BLD) + query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD) + + n_heads = self.n_heads + reshaper = functools.partial( + einops.rearrange, pattern="b s (h d) -> b h s d", h=n_heads + ) + + query_BHLD, key_BHLD, 
value_BHLD = map( + reshaper, (query_BLD, key_BLD, value_BLD) + ) + + # Where True, enable participation in attention. + mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2) + mask_BHLL = mask_BLL.unsqueeze(1) + + context_BHLD = F.scaled_dot_product_attention( + query_BHLD, key_BHLD, value_BHLD, mask_BHLL + ) + context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)") + return self.out_proj(context_BLD) diff --git a/esm/layers/blocks.py b/esm/layers/blocks.py new file mode 100644 index 0000000..7d0203f --- /dev/null +++ b/esm/layers/blocks.py @@ -0,0 +1,153 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from esm.layers.attention import MultiHeadAttention +from esm.layers.geom_attention import ( + GeometricReasoningOriginalImpl, +) +from esm.utils.structure.affine3d import Affine3D + + +def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int: + # round the hidden dimension up to the next multiple of 256 after applying the expansion ratio + return int(((expansion_ratio * d_model) + 255) // 256 * 256) + + +class SwiGLU(nn.Module): + """ + SwiGLU activation function as an nn.Module, allowing it to be used within nn.Sequential. + This module splits the input tensor along the last dimension and applies the SiLU (Swish) + activation function to the first half, then multiplies it by the second half. 
+ """ + + def __init__(self): + super(SwiGLU, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + return F.silu(x1) * x2 + + +def swiglu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool): + return nn.Sequential( + nn.LayerNorm(d_model), + nn.Linear( + d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=bias + ), + SwiGLU(), + nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=bias), + ) + + +def gelu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool): + hidden_dim = int(expansion_ratio * d_model) + return nn.Sequential( + nn.LayerNorm(d_model), + nn.Linear(d_model, hidden_dim, bias=bias), + nn.GELU(), + nn.Linear(hidden_dim, d_model, bias=bias), + ) + + +class UnifiedTransformerBlock(nn.Module): + """ + A unified transformer block that can optionally incorporate geometric attention. + + This class defines a transformer block that can be configured to use geometric attention + alongside the standard multi-head attention mechanism. It is designed to be a flexible + component of transformer-based models, allowing for the integration of geometric reasoning. + + Parameters + ---------- + d_model : int + The dimensionality of the input and output features of the transformer block. + n_heads : int + The number of attention heads in the multi-head attention mechanism. + use_plain_attn : bool, optional + Whether to apply the standard multi-head attention. Defaults to True. + use_geom_attn : bool, optional + Whether to use geometric attention in addition to the standard multi-head attention. Defaults to False. + v_heads : int, optional + The number of heads to use for the geometric attention mechanism, if enabled. Must be specified if `use_geom_attn` is True. 
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + use_geom_attn: bool = False, + use_plain_attn: bool = True, + v_heads: int | None = None, + bias: bool = False, + expansion_ratio: float = 4.0, + residue_scaling_factor: float = 1, + mask_and_zero_frameless: bool = False, + qk_layernorm: bool = True, + ffn_type: str = "swiglu", # swiglu | gelu + ): + super().__init__() + self.use_plain_attn = use_plain_attn + if self.use_plain_attn: + self.attn = MultiHeadAttention( + d_model, n_heads, bias, qk_layernorm=qk_layernorm + ) + self.use_geom_attn = use_geom_attn + if self.use_geom_attn: + if v_heads is None: + raise ValueError("v_heads must be specified when use_geom_attn is True") + self.geom_attn = GeometricReasoningOriginalImpl( + c_s=d_model, + v_heads=v_heads, + bias=bias, + mask_and_zero_frameless=mask_and_zero_frameless, + ) + if ffn_type == "swiglu": + self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, bias) + elif ffn_type == "gelu": + self.ffn = gelu_ln_ffn(d_model, expansion_ratio, bias) + else: + raise ValueError(f"Unknown ffn_type: {ffn_type}") + self.scaling_factor = residue_scaling_factor + + def forward( + self, + x: torch.Tensor, + sequence_id: torch.Tensor, + frames: Affine3D, + frames_mask: torch.Tensor, + chain_id: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass for the UnifiedTransformerBlock. + + Parameters + ---------- + x : torch.Tensor[float] + Input tensor to the transformer block, typically the output from the previous layer. + sequence_id : torch.Tensor[int] + Tensor containing sequence IDs for each element in the batch, used for attention masking. + frames : Affine3D + Affine3D containing geometric frame information for geometric attention. + frames_mask : torch.Tensor[bool] + Boolean mask tensor indicating valid frames for geometric attention. + chain_id : torch.Tensor[int] + Tensor containing chain IDs for each element, used for attention masking in geometric attention. 
+ + Returns + ------- + torch.Tensor[float] + The output tensor after applying the transformer block operations. + """ + if self.use_plain_attn: + r1 = self.attn(x, sequence_id) + x = x + r1 / self.scaling_factor + + if self.use_geom_attn: + r2 = self.geom_attn(x, frames, frames_mask, sequence_id, chain_id) + x = x + r2 / self.scaling_factor + + r3 = self.ffn(x) / self.scaling_factor + x = x + r3 + + return x diff --git a/esm/layers/codebook.py b/esm/layers/codebook.py new file mode 100644 index 0000000..886d17c --- /dev/null +++ b/esm/layers/codebook.py @@ -0,0 +1,88 @@ +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + + +class EMACodebook(nn.Module): + def __init__( + self, + n_codes, + embedding_dim, + no_random_restart=True, + restart_thres=1.0, + ema_decay=0.99, + ): + super().__init__() + self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) + self.register_buffer("N", torch.zeros(n_codes)) + self.register_buffer("z_avg", self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + self.no_random_restart = no_random_restart + self.restart_thres = restart_thres + self.freeze_codebook = False + self.ema_decay = ema_decay + + def reset_parameters(self): + # For meta init + pass + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, t, c] + self._need_init = False + flat_inputs = z.view(-1, self.embedding_dim) + y = self._tile(flat_inputs) + + y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: 
[b, t, c] + if self._need_init and self.training and not self.freeze_codebook: + self._init_embeddings(z) + # z is of shape [batch_size, sequence_length, channels] + flat_inputs = z.view(-1, self.embedding_dim) + distances = ( + (flat_inputs**2).sum(dim=1, keepdim=True) + - 2 * flat_inputs @ self.embeddings.t() + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) + ) # [bt, n_codes] + + encoding_indices = torch.argmin(distances, dim=1) + encoding_indices = encoding_indices.view(*z.shape[:2]) # [b, t] + + embeddings = F.embedding(encoding_indices, self.embeddings) # [b, t, c] + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training and not self.freeze_codebook: + assert False, "Not implemented" + embeddings_st = (embeddings - z).detach() + z + + return embeddings_st, encoding_indices, commitment_loss + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, self.embeddings) + return embeddings + + def soft_codebook_lookup(self, weights: torch.Tensor) -> torch.Tensor: + return weights @ self.embeddings diff --git a/esm/layers/ffn.py b/esm/layers/ffn.py new file mode 100644 index 0000000..3e97389 --- /dev/null +++ b/esm/layers/ffn.py @@ -0,0 +1,29 @@ +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +# NOT CURRENTLY USED + + +class SwiGLU(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + x1, x2 = x.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return hidden + + +class FFN(nn.Module): + def __init__(self, in_proj, activation, out_proj) -> None: + super().__init__() + self.in_proj = in_proj + self.activation = activation + self.out_proj = out_proj + + def forward(self, x: Tensor) -> Tensor: + x = self.in_proj(x) + x = self.activation(x) + x = self.out_proj(x) + return x diff --git a/esm/layers/geom_attention.py b/esm/layers/geom_attention.py new file mode 100644 index 0000000..ec577cd --- /dev/null 
+++ b/esm/layers/geom_attention.py @@ -0,0 +1,151 @@ +from math import sqrt + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F + + +class GeometricReasoningOriginalImpl(nn.Module): + def __init__( + self, + c_s: int, + v_heads: int, + num_vector_messages: int = 1, + mask_and_zero_frameless: bool = True, + divide_residual_by_depth: bool = False, + bias: bool = False, + ): + """Approximate implementation: + + ATTN(A, v) := (softmax_j A_ij) v_j + make_rot_vectors(x) := R(i->g) Linear(x).reshape(..., 3) + make_vectors(x) := T(i->g) Linear(x).reshape(..., 3) + + v <- make_rot_vectors(x) + q_dir, k_dir <- make_rot_vectors(x) + q_dist, k_dist <- make_vectors(x) + + A_ij <- dot(q_dir_i, k_dir_j) -||q_dist_i - k_dist_j||^2 + x <- x + Linear(T(g->i) ATTN(A, v)) + """ + super().__init__() + self.c_s = c_s + self.v_heads = v_heads + self.num_vector_messages = num_vector_messages + self.mask_and_zero_frameless = mask_and_zero_frameless + + self.s_norm = nn.LayerNorm(c_s, bias=bias) + dim_proj = ( + 4 * self.v_heads * 3 + self.v_heads * 3 * self.num_vector_messages + ) # 2 x (q, k) * number of heads * (x, y, z) + number of heads * number of vector messages * (x, y, z) + self.proj = nn.Linear(c_s, dim_proj, bias=bias) + channels_out = self.v_heads * 3 * self.num_vector_messages + self.out_proj = nn.Linear(channels_out, c_s, bias=bias) + + # The basic idea is for some attention heads to pay more or less attention to rotation versus distance, + # as well as to control the sharpness of the softmax (i.e., should this head only attend to those residues + # very nearby or should there be shallower dropoff in attention weight?) 
+ self.distance_scale_per_head = nn.Parameter(torch.zeros((self.v_heads))) + self.rotation_scale_per_head = nn.Parameter(torch.zeros((self.v_heads))) + + def forward(self, s, affine, affine_mask, sequence_id, chain_id): + attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2) + attn_bias = attn_bias.unsqueeze(1).float() + attn_bias = attn_bias.masked_fill( + ~affine_mask[:, None, None, :], torch.finfo(attn_bias.dtype).min + ) + chain_id_mask = chain_id.unsqueeze(1) != chain_id.unsqueeze(2) + attn_bias = attn_bias.masked_fill( + chain_id_mask.unsqueeze(1), torch.finfo(s.dtype).min + ) + + ns = self.s_norm(s) + vec_rot, vec_dist = self.proj(ns).split( + [ + self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages, + self.v_heads * 2 * 3, + ], + dim=-1, + ) + + # Rotate the queries and keys for the rotation term. We also rotate the values. + # NOTE(zeming, thayes): Values are only rotated, not translated. We may wish to change + # this in the future. + query_rot, key_rot, value = ( + affine.rot[..., None] + .apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3)) + .split( + [ + self.v_heads, + self.v_heads, + self.v_heads * self.num_vector_messages, + ], + dim=-2, + ) + ) + + # Rotate and translate the queries and keys for the distance term + # NOTE(thayes): a simple speedup would be to apply all rotations together, then + # separately apply the translations. + query_dist, key_dist = ( + affine[..., None] + .apply(rearrange(vec_dist, "... (h c) -> ... 
h c", c=3)) + .chunk(2, dim=-2) + ) + + query_dist = rearrange(query_dist, "b s h d -> b h s 1 d") + key_dist = rearrange(key_dist, "b s h d -> b h 1 s d") + query_rot = rearrange(query_rot, "b s h d -> b h s d") + key_rot = rearrange(key_rot, "b s h d -> b h d s") + value = rearrange( + value, "b s (h m) d -> b h s (m d)", m=self.num_vector_messages + ) + + distance_term = (query_dist - key_dist).norm(dim=-1) / sqrt(3) + rotation_term = query_rot.matmul(key_rot) / sqrt(3) + distance_term_weight = rearrange( + F.softplus(self.distance_scale_per_head), "h -> h 1 1" + ) + rotation_term_weight = rearrange( + F.softplus(self.rotation_scale_per_head), "h -> h 1 1" + ) + + attn_weight = ( + rotation_term * rotation_term_weight - distance_term * distance_term_weight + ) + + if attn_bias is not None: + # we can re-use the attention bias from the transformer layers + # NOTE(thayes): This attention bias is expected to handle two things: + # 1. Masking attention on padding tokens + # 2. Masking cross sequence attention in the case of bin packing + s_q = attn_weight.size(2) + s_k = attn_weight.size(3) + _s_q = max(0, attn_bias.size(2) - s_q) + _s_k = max(0, attn_bias.size(3) - s_k) + attn_bias = attn_bias[:, :, _s_q:, _s_k:] + attn_weight = attn_weight + attn_bias + + attn_weight = torch.softmax(attn_weight, dim=-1) + + attn_out = attn_weight.matmul(value) + + attn_out = ( + affine.rot[..., None] + .invert() + .apply( + rearrange( + attn_out, "b h s (m d) -> b s (h m) d", m=self.num_vector_messages + ) + ) + ) + + attn_out = rearrange( + attn_out, "b s (h m) d -> b s (h m d)", m=self.num_vector_messages + ) + if self.mask_and_zero_frameless: + attn_out = attn_out.masked_fill(~affine_mask[..., None], 0.0) + s = self.out_proj(attn_out) + + return s diff --git a/esm/layers/regression_head.py b/esm/layers/regression_head.py new file mode 100644 index 0000000..a237872 --- /dev/null +++ b/esm/layers/regression_head.py @@ -0,0 +1,24 @@ +import torch.nn as nn + + +def RegressionHead( 
+ d_model: int, + output_dim: int, + hidden_dim: int | None = None, +) -> nn.Module: + """Single-hidden layer MLP for supervised output. + + Args: + d_model: input dimension + output_dim: dimensionality of the output. + hidden_dim: optional dimension of hidden layer, defaults to d_model. + Returns: + output MLP module. + """ + hidden_dim = hidden_dim if hidden_dim is not None else d_model + return nn.Sequential( + nn.Linear(d_model, hidden_dim), + nn.GELU(), + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim, output_dim), + ) diff --git a/esm/layers/rotary.py b/esm/layers/rotary.py new file mode 100644 index 0000000..b820969 --- /dev/null +++ b/esm/layers/rotary.py @@ -0,0 +1,221 @@ +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# NOTE: this implementation is from LLaMA 2: +# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/08639a72e17836184096ae6a7e2766f2a34c3e36/modeling_flash_llama.py#L114 +# Flash attention rotary implementation can be installed like so: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary` + +from typing import Tuple + +import torch +from einops import rearrange, repeat + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + seqlen = x.size(1) + cos = cos[:seqlen] + sin = sin[:seqlen] + cos = repeat(cos, "s d -> s 1 (2 d)") + sin = repeat(sin, "s d -> s 1 (2 d)") + return torch.cat( + [ + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], + ], + dim=-1, + ) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et al.). + A crucial insight from the method is that the queries and keys are + transformed by rotation matrices which depend on the relative positions. + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_; GPT-NeoX was an inspiration. + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). 
+ A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ + + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + scaling_factor=1.0, + pos_idx_in_fp32=True, + device=None, + ): + """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. In most cases this would + be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, + we add this option. + scaling_factor: RotaryEmbedding extended with linear scaling. 
+ """ + super().__init__() + self.dim = dim + self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Generate and save the inverse frequency buffer (non trainable) + self.interleaved = interleaved + self.scale_base = scale_base + self.scaling_factor = scaling_factor + self.device = device + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.reset_parameters() + + def reset_parameters(self): + inv_freq = self._compute_inv_freq(self.device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) + scale = ( + (arange + 0.4 * self.dim) / (1.4 * self.dim) + if self.scale_base is not None + else None + ) + self.register_buffer("scale", scale) + + def _compute_inv_freq(self, device=None): + return 1 / ( + self.base + ** ( + torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) + / self.dim + ) + ) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + t /= self.scaling_factor + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. 
Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. + # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self.inv_freq.to(torch.float32) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + t /= self.scaling_factor + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange( + seqlen, dtype=self.scale.dtype, device=self.scale.device + ) + - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** power.unsqueeze(-1) + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward( + self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + q: (batch, seqlen, nheads, headdim) + k: (batch, seqlen, nheads, headdim) + seqlen_offset: can be used in generation where the qkv being passed in is only the last + token in the batch. 
+ """ + self._update_cos_sin_cache( + q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype + ) + assert self._cos_cached is not None + assert self._sin_cached is not None + if self.scale is None: + return ( + apply_rotary_emb_torch( + q, + self._cos_cached[seqlen_offset:], + self._sin_cached[seqlen_offset:], + self.interleaved, + True, # inplace=True + ), + apply_rotary_emb_torch( + k, + self._cos_cached[seqlen_offset:], + self._sin_cached[seqlen_offset:], + self.interleaved, + True, # inplace=True + ), + ) # type: ignore + else: + assert False diff --git a/esm/layers/structure_proj.py b/esm/layers/structure_proj.py new file mode 100644 index 0000000..a650176 --- /dev/null +++ b/esm/layers/structure_proj.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn + +from esm.utils.constants.physics import ( + BB_COORDINATES, +) +from esm.utils.structure.affine3d import ( + Affine3D, + RotationMatrix, +) + + +class Dim6RotStructureHead(nn.Module): + # Normally, AF2 uses quaternions to specify rotations. There's some evidence that + # other representations are better behaved; the best one according to + # https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhou_On_the_Continuity_of_Rotation_Representations_in_Neural_Networks_CVPR_2019_paper.pdf + # is Gram-Schmidt orthogonalization of 2 vectors, which is implemented here. 
+ def __init__( + self, + input_dim: int, + trans_scale_factor: float = 10, + norm_type: str = "layernorm", + activation_fn: str = "esm_gelu", + predict_torsion_angles: bool = True, + ): + super().__init__() + self.ffn1 = nn.Linear(input_dim, input_dim) + self.activation_fn = nn.GELU() + self.norm = nn.LayerNorm(input_dim) + self.proj = nn.Linear(input_dim, 9 + 7 * 2) + self.trans_scale_factor = trans_scale_factor + self.predict_torsion_angles = predict_torsion_angles + self.bb_local_coords = torch.tensor(BB_COORDINATES).float() + + def forward(self, x, affine, affine_mask, **kwargs): + if affine is None: + rigids = Affine3D.identity( + x.shape[:-1], + dtype=x.dtype, + device=x.device, + requires_grad=self.training, + rotation_type=RotationMatrix, + ) + else: + rigids = affine + + # [*, N] + x = self.ffn1(x) + x = self.activation_fn(x) + x = self.norm(x) + trans, x, y, angles = self.proj(x).split([3, 3, 3, 7 * 2], dim=-1) + trans = trans * self.trans_scale_factor + x = x / (x.norm(dim=-1, keepdim=True) + 1e-5) + y = y / (y.norm(dim=-1, keepdim=True) + 1e-5) + update = Affine3D.from_graham_schmidt(x + trans, trans, y + trans) + rigids = rigids.compose(update.mask(affine_mask)) + affine = rigids.tensor + + # We approximate the positions of the backbone atoms in the global frame by applying the rigid + # transformation to the mean of the backbone atoms in the local frame. 
+ all_bb_coords_local = ( + self.bb_local_coords[None, None, :, :] + .expand(*x.shape[:-1], 3, 3) + .to(x.device) + ) + pred_xyz = rigids[..., None].apply(all_bb_coords_local) + + return affine, pred_xyz diff --git a/esm/layers/transformer_stack.py b/esm/layers/transformer_stack.py new file mode 100644 index 0000000..22d65ae --- /dev/null +++ b/esm/layers/transformer_stack.py @@ -0,0 +1,94 @@ +import math + +import torch +import torch.nn as nn + +from esm.layers.blocks import UnifiedTransformerBlock +from esm.utils.structure.affine3d import Affine3D + + +class TransformerStack(nn.Module): + """ + A stack of transformer blocks used in the ESM-3 model. Each block is a UnifiedTransformerBlock, + which can apply standard multi-head attention, geometric attention, or both. + + Args: + d_model (int): The dimensionality of the input and output feature vectors. + n_heads (int): The number of attention heads. + v_heads (int): The number of voting heads. + n_layers (int): The number of transformer blocks in the stack. + n_layers_geom (int, optional): The number of transformer blocks that use geometric attention. + scale_residue (bool, optional): Whether to scale the residual connections in each transformer block. + mask_and_zero_frameless (bool, optional): Whether to mask and zero frameless positions in the input. 
+ Only applies in the geometric attention blocks, which is conditioned on the structure + """ + + def __init__( + self, + d_model: int, + n_heads: int, + v_heads: int | None, + n_layers: int, + n_layers_geom: int = 1, + scale_residue: bool = True, + mask_and_zero_frameless: bool = False, + bias: bool = False, + qk_layernorm: bool = True, + ffn_type: str = "swiglu", # swiglu | gelu + expansion_ratio: float = 8 / 3, + ): + super().__init__() + self.blocks = nn.ModuleList( + [ + UnifiedTransformerBlock( + d_model, + n_heads, + v_heads=v_heads, + use_geom_attn=i < n_layers_geom, + residue_scaling_factor=( + math.sqrt(n_layers / 36) if scale_residue else 1.0 + ), + expansion_ratio=expansion_ratio, + mask_and_zero_frameless=mask_and_zero_frameless, + bias=bias, + qk_layernorm=qk_layernorm, + ffn_type=ffn_type, + ) + for i in range(n_layers) + ] + ) + self.norm = nn.LayerNorm(d_model, bias=False) + + def forward( + self, + x: torch.Tensor, + sequence_id: torch.Tensor | None = None, + affine: Affine3D | None = None, + affine_mask: torch.Tensor | None = None, + chain_id: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of the TransformerStack. + + Args: + x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, d_model). + sequence_id (torch.Tensor): The sequence ID tensor of shape (batch_size, sequence_length). + affine (Affine3D | None): The affine transformation tensor or None. + affine_mask (torch.Tensor | None): The affine mask tensor or None. + chain_id (torch.Tensor): The protein chain tensor of shape (batch_size, sequence_length). + Only used in geometric attention. + + Returns: + post_norm: The output tensor of shape (batch_size, sequence_length, d_model). + pre_norm: The embedding of shape (batch_size, sequence_length, d_model). 
+ """ + *batch_dims, _ = x.shape + if sequence_id is None: + sequence_id = torch.ones( + size=batch_dims, dtype=torch.int64, device=x.device + ) + if chain_id is None: + chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device) + for block in self.blocks: + x = block(x, sequence_id, affine, affine_mask, chain_id) + return self.norm(x), x diff --git a/esm/models/esm3.py b/esm/models/esm3.py new file mode 100644 index 0000000..76e26cd --- /dev/null +++ b/esm/models/esm3.py @@ -0,0 +1,799 @@ +from __future__ import annotations + +import contextlib +from functools import partial + +import attr +import einops +import torch +import torch.nn as nn +from attr import dataclass + +from esm.layers.regression_head import RegressionHead +from esm.layers.transformer_stack import TransformerStack +from esm.models.function_decoder import FunctionTokenDecoder +from esm.models.vqvae import ( + StructureTokenDecoder, + StructureTokenEncoder, +) +from esm.sdk.api import ( + ESM3InferenceClient, + ESMProtein, + ESMProteinTensor, + ForwardAndSampleOutput, + ForwardConfig, + ForwardOutput, + ForwardTrackData, + GenerationConfig, + ProteinType, + ReturnLogitsConfig, + SamplingConfig, + SamplingTrackConfig, +) +from esm.tokenization import get_model_tokenizers +from esm.utils import encoding +from esm.utils.constants import esm3 as C +from esm.utils.constants.models import ESM3_OPEN_SMALL +from esm.utils.decoding import decode_protein_tensor +from esm.utils.generation import ( + iterative_sampling_raw, + iterative_sampling_tokens, +) +from esm.utils.misc import rbf +from esm.utils.sampling import ( + get_default_sampling_config, + sample_function_logits, + sample_logits, + sample_residue_annotation_logits, +) +from esm.utils.structure.affine3d import ( + build_affine3d_from_coordinates, +) + + +@dataclass +class ESMOutput: + sequence_logits: torch.Tensor + structure_logits: torch.Tensor + secondary_structure_logits: torch.Tensor + sasa_logits: torch.Tensor + 
function_logits: torch.Tensor + residue_logits: torch.Tensor + embeddings: torch.Tensor + + +class EncodeInputs(nn.Module): + """ + Module for encoding input features in the ESM-3 model. + + Args: + d_model (int): The dimensionality of the model's hidden states. + """ + + def __init__(self, d_model: int): + super().__init__() + + # Sequence + self.sequence_embed = nn.Embedding(64, d_model) + # Mandatory information + self.plddt_projection = nn.Linear(16, d_model) + self.structure_per_res_plddt_projection = nn.Linear(16, d_model) + + # Structure + self.structure_tokens_embed = nn.Embedding(4096 + 5, d_model) + + # "Structural" features + self.ss8_embed = nn.Embedding(8 + 3, d_model) + self.sasa_embed = nn.Embedding(16 + 3, d_model) + + # "Functional" features + self.function_embed = nn.ModuleList( + [nn.Embedding(260, d_model // 8, padding_idx=0) for _ in range(8)] + ) + + self.residue_embed = nn.EmbeddingBag(1478, d_model, mode="sum", padding_idx=0) + + def forward( + self, + sequence_tokens: torch.Tensor, + structure_tokens: torch.Tensor, + average_plddt: torch.Tensor, + per_res_plddt: torch.Tensor, + ss8_tokens: torch.Tensor, + sasa_tokens: torch.Tensor, + function_tokens: torch.Tensor, + residue_annotation_tokens: torch.Tensor, + ) -> torch.Tensor: + sequence_embed = self.sequence_embed(sequence_tokens) + + rbf_16_fn = partial(rbf, v_min=0.0, v_max=1.0, n_bins=16) + # the `masked_fill(padding_mask.unsqueeze(2), 0)` for the two below is unnecessary + # as pad tokens never even interact with the "real" tokens (due to sequence_id) + plddt_embed = self.plddt_projection(rbf_16_fn(average_plddt)) + structure_per_res_plddt = self.structure_per_res_plddt_projection( + rbf_16_fn(per_res_plddt) + ) + + # Structure + "structural features" embeds + structure_embed = self.structure_tokens_embed(structure_tokens) + ss8_embed = self.ss8_embed(ss8_tokens) + sasa_embed = self.sasa_embed(sasa_tokens) + + # "Functional" features embeds + function_embed = torch.cat( + [ + 
embed_fn(funcs) + for embed_fn, funcs in zip( + self.function_embed, function_tokens.unbind(-1) + ) + ], + -1, + ) + + # Residue embeds + B, L, N = residue_annotation_tokens.shape + residue_embed = self.residue_embed( + einops.rearrange( + residue_annotation_tokens, "B L N -> (B L) N", B=B, L=L, N=N + ) + ) + residue_embed = einops.rearrange(residue_embed, "(B L) D -> B L D", B=B, L=L) + + return ( + sequence_embed + + plddt_embed + + structure_per_res_plddt + + structure_embed + + ss8_embed + + sasa_embed + + function_embed + + residue_embed + ) + + +class OutputHeads(nn.Module): + def __init__(self, d_model: int): + super().__init__() + self.sequence_head = RegressionHead(d_model, 64) + self.structure_head = RegressionHead(d_model, 4096) + self.ss8_head = RegressionHead(d_model, 8 + 3) + self.sasa_head = RegressionHead(d_model, 16 + 3) + self.function_head = RegressionHead(d_model, 260 * 8) + self.residue_head = RegressionHead(d_model, 1478) + + def forward(self, x: torch.Tensor, embed: torch.Tensor) -> ESMOutput: + sequence_logits = self.sequence_head(x) + structure_logits = self.structure_head(x) + secondary_structure_logits = self.ss8_head(x) + sasa_logits = self.sasa_head(x) + function_logits = self.function_head(x) + function_logits = einops.rearrange( + function_logits, + "... (k v) -> ... k v", + k=8, + ) + + residue_logits = self.residue_head(x) + + return ESMOutput( + sequence_logits=sequence_logits, + structure_logits=structure_logits, + secondary_structure_logits=secondary_structure_logits, + sasa_logits=sasa_logits, + function_logits=function_logits, + residue_logits=residue_logits, + embeddings=embed, + ) + + +class ESM3(nn.Module, ESM3InferenceClient): + """ + ESM3 model implementation. + + Args: + d_model (int): The dimensionality of the input and output feature vectors. + n_heads (int): The number of attention heads in the transformer layers. + v_heads (int): The number of attention heads in the geometric attention layers. 
+ n_layers (int): The number of transformer layers. + """ + + def __init__( + self, + d_model: int, + n_heads: int, + v_heads: int, + n_layers: int, + structure_encoder_name: str, + structure_decoder_name: str, + function_decoder_name: str, + ): + super().__init__() + self.encoder = EncodeInputs(d_model) + self.transformer = TransformerStack( + d_model, + n_heads, + v_heads, + n_layers, + mask_and_zero_frameless=True, + ) + self.output_heads = OutputHeads(d_model) + + self.structure_encoder_name = structure_encoder_name + self.structure_decoder_name = structure_decoder_name + self.function_decoder_name = function_decoder_name + + self.structure_encoder: StructureTokenEncoder | None = None # type: ignore + self.structure_decoder: StructureTokenDecoder | None = None # type: ignore + self.function_decoder: FunctionTokenDecoder | None = None # type: ignore + + self.tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL) + + @classmethod + def from_pretrained( + cls, + model_name: str = ESM3_OPEN_SMALL, + device: torch.device | str = "cpu", + ) -> ESM3: + from esm.pretrained import load_local_model + + if model_name not in [ESM3_OPEN_SMALL]: + raise ValueError(f"Model name {model_name} is not a valid ESM3 model name.") + model: ESM3 = load_local_model(model_name, device=device) # type: ignore + return model + + def get_structure_token_encoder(self) -> StructureTokenEncoder: + if self.structure_encoder is None: + self.structure_encoder = self.load_model(self.structure_encoder_name) # type: ignore + return self.structure_encoder # type: ignore + + def get_structure_token_decoder(self) -> StructureTokenDecoder: + if self.structure_decoder is None: + self.structure_decoder = self.load_model(self.structure_decoder_name) # type: ignore + return self.structure_decoder # type: ignore + + def get_function_token_decoder(self) -> FunctionTokenDecoder: + if self.function_decoder is None: + self.function_decoder = self.load_model(self.function_decoder_name) # type: ignore + return 
self.function_decoder # type: ignore + + def load_model(self, model_name: str): + # Lazy import from pretrained + from esm.pretrained import load_local_model + + return load_local_model(model_name, device=next(self.parameters()).device) + + def forward( + self, + *, + sequence_tokens: torch.Tensor | None = None, + structure_tokens: torch.Tensor | None = None, + ss8_tokens: torch.Tensor | None = None, + sasa_tokens: torch.Tensor | None = None, + function_tokens: torch.Tensor | None = None, + residue_annotation_tokens: torch.Tensor | None = None, + average_plddt: torch.Tensor | None = None, + per_res_plddt: torch.Tensor | None = None, + structure_coords: torch.Tensor | None = None, + chain_id: torch.Tensor | None = None, + sequence_id: torch.Tensor | None = None, + ) -> ESMOutput: + """ + Performs a forward pass through the ESM3 model. Check utils to see how to tokenize inputs from raw data. + + Args: + sequence_tokens (torch.Tensor, optional): The amino acid tokens. + structure_tokens (torch.Tensor, optional): The structure tokens. + ss8_tokens (torch.Tensor, optional): The secondary structure tokens. + sasa_tokens (torch.Tensor, optional): The solvent accessible surface area tokens. + function_tokens (torch.Tensor, optional): The function tokens. + residue_annotation_tokens (torch.Tensor, optional): The residue annotation tokens. + average_plddt (torch.Tensor, optional): The average plddt across the entire sequence. + per_res_plddt (torch.Tensor, optional): The per-residue plddt. To specify exact plddts, use this; + otherwise, use average_plddt. + structure_coords (torch.Tensor, optional): The structure coordinates, in the form of (B, L, 3, 3). + chain_id (torch.Tensor, optional): The chain ID. + sequence_id (torch.Tensor, optional): The sequence ID. + + Returns: + ESMOutput: The output of the ESM3 model. + + Raises: + ValueError: If all token and coordinate inputs are None. 
+ + """ + # Reasonable defaults: + try: + L, device = next( + (x.shape[1], x.device) + for x in [ + sequence_tokens, + structure_tokens, + ss8_tokens, + sasa_tokens, + structure_coords, + function_tokens, + residue_annotation_tokens, + ] + if x is not None + ) + except StopIteration: + raise ValueError("At least one of the inputs must be non-None") + + t = self.tokenizers + defaults = lambda x, tok: ( + torch.full((1, L), tok, dtype=torch.long, device=device) if x is None else x + ) + sequence_tokens = defaults(sequence_tokens, t.sequence.mask_token_id) + ss8_tokens = defaults(ss8_tokens, C.SS8_UNK_TOKEN) + sasa_tokens = defaults(sasa_tokens, C.SASA_UNK_TOKEN) + average_plddt = defaults(average_plddt, 1).float() + per_res_plddt = defaults(per_res_plddt, 0).float() + chain_id = defaults(chain_id, 0) + sequence_id = defaults(sequence_id, 0) + + if residue_annotation_tokens is None: + residue_annotation_tokens = torch.full( + (1, L, 16), C.RESIDUE_PAD_TOKEN, dtype=torch.long, device=device + ) + + if function_tokens is None: + function_tokens = torch.full( + (1, L, 8), C.INTERPRO_PAD_TOKEN, dtype=torch.long, device=device + ) + + if structure_coords is None: + structure_coords = torch.full( + (1, L, 3, 3), float("nan"), dtype=torch.float, device=device + ) + + structure_coords = structure_coords[ + ..., :3, : + ] # In case we pass in an atom14 or atom37 repr + affine, affine_mask = build_affine3d_from_coordinates(structure_coords) + + if structure_tokens is None: + _, structure_tokens = self.get_structure_token_encoder().encode( + structure_coords + ) + assert structure_tokens is not None + structure_tokens = ( + structure_tokens.masked_fill( + (structure_tokens == -1) | ~affine_mask, C.STRUCTURE_MASK_TOKEN + ) + .masked_fill(sequence_tokens == C.SEQUENCE_BOS_TOKEN, C.STRUCTURE_BOS_TOKEN) + .masked_fill(sequence_tokens == C.SEQUENCE_PAD_TOKEN, C.STRUCTURE_PAD_TOKEN) + .masked_fill(sequence_tokens == C.SEQUENCE_EOS_TOKEN, C.STRUCTURE_EOS_TOKEN) + .masked_fill( + 
sequence_tokens == C.SEQUENCE_CHAINBREAK_TOKEN, + C.STRUCTURE_CHAINBREAK_TOKEN, + ) + ) + + x = self.encoder( + sequence_tokens, + structure_tokens, + average_plddt, + per_res_plddt, + ss8_tokens, + sasa_tokens, + function_tokens, + residue_annotation_tokens, + ) + x, embedding = self.transformer(x, sequence_id, affine, affine_mask, chain_id) + return self.output_heads(x, embedding) + + # The following methods are for the ESM3InferenceClient interface + def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType: + if isinstance(input, ESMProtein): + return iterative_sampling_raw(self, input, config) + elif isinstance(input, ESMProteinTensor): + return iterative_sampling_tokens(self, input, config, self.tokenizers) + else: + raise ValueError("Input must be an ESMProtein or ESMProteinTensor") + + def encode(self, input: ESMProtein) -> ESMProteinTensor: + input = attr.evolve(input) # Make a copy + + sequence_tokens = None + structure_tokens = None + secondary_structure_tokens = None + sasa_tokens = None + function_tokens = None + residue_annotation_tokens = None + + coordinates = None + + if input.sequence is not None: + sequence_tokens = encoding.tokenize_sequence( + input.sequence, self.tokenizers.sequence, add_special_tokens=True + ) + if input.secondary_structure is not None: + secondary_structure_tokens = encoding.tokenize_secondary_structure( + input.secondary_structure, + self.tokenizers.secondary_structure, + add_special_tokens=True, + ) + if input.sasa is not None: + sasa_tokens = encoding.tokenize_sasa( + input.sasa, self.tokenizers.sasa, add_special_tokens=True + ) + + # Infer input length + sequence_length = -1 + if sequence_tokens is not None: + sequence_length = len(sequence_tokens) + elif secondary_structure_tokens is not None: + sequence_length = len(secondary_structure_tokens) + elif sasa_tokens is not None: + sequence_length = len(sasa_tokens) + + # Try to infer input length from structure data + if input.coordinates is not None: 
+ coordinates, _, structure_tokens = encoding.tokenize_structure( + input.coordinates, + self.get_structure_token_encoder(), + structure_tokenizer=self.tokenizers.structure, + reference_sequence=input.sequence or "", + add_special_tokens=True, + ) + if sequence_length == -1: + sequence_length = len(structure_tokens) + + if sequence_length == -1: + raise ValueError( + "Cannot infer input length from input data. Please provide one of: sequence, structure, secondary_structure, sasa.\n" + "To condition on sequence length only, use ESM3LocalInferenceClient.get_default_sequence(sequence_length) to generate a default sequence input." + ) + + # Function and Residue annotations + if input.function_annotations is not None: + if input.sequence is None: + reference_sequence = encoding.get_default_sequence(sequence_length - 2) + else: + reference_sequence = input.sequence + ( + function_tokens, + residue_annotation_tokens, + ) = encoding.tokenize_function_annotations( + input.function_annotations, + reference_sequence=reference_sequence, + function_tokenizer=self.tokenizers.function, + residue_annotation_tokenizer=self.tokenizers.residue_annotations, + add_special_tokens=True, + ) + + return ESMProteinTensor( + sequence=sequence_tokens, + structure=structure_tokens, + secondary_structure=secondary_structure_tokens, + sasa=sasa_tokens, + function=function_tokens, + residue_annotations=residue_annotation_tokens, + coordinates=coordinates, + ).to(next(self.parameters()).device) + + def decode( + self, + input: ESMProteinTensor, + ) -> ESMProtein: + return decode_protein_tensor( + input=input, + tokenizers=self.tokenizers, + structure_token_decoder=self.get_structure_token_decoder(), + function_token_decoder=self.get_function_token_decoder(), + ) + + def _forward( + self, input: ESMProteinTensor, config: ForwardConfig = ForwardConfig() + ) -> ForwardOutput: + # Default plddt conditioning for inference. 1s where coordinates are provided. 
+ if input.coordinates is None: + per_res_plddt = None + else: + # 1.0 where at least one atom at a position has fully finite (non-nan) coordinates. + per_res_plddt = input.coordinates.isfinite().all(dim=-1).any(dim=-1).float() + + # `self.eval` is a bound method and therefore always truthy; check the training flag instead. + with torch.no_grad() if not self.training else contextlib.nullcontext(): + output = self.forward( + sequence_tokens=input.sequence, + structure_tokens=input.structure, + ss8_tokens=input.secondary_structure, + sasa_tokens=input.sasa, + function_tokens=input.function, + residue_annotation_tokens=input.residue_annotations, + average_plddt=torch.tensor(1.0, device=input.device), + per_res_plddt=per_res_plddt, + structure_coords=input.coordinates, + chain_id=None, + sequence_id=None, + ) + + if config.return_logits: + logits = ForwardTrackData( + sequence=output.sequence_logits, + structure=output.structure_logits, + secondary_structure=output.secondary_structure_logits, + sasa=output.sasa_logits, + function=output.function_logits, + ) + else: + logits = None + + return ForwardOutput( + logits=logits, + residue_annotation_logits=output.residue_logits, + embeddings=output.embeddings if config.return_embeddings else None, + ) + + def forward_and_sample( + self, input: ESMProteinTensor, sampling_configuration: SamplingConfig + ) -> ForwardAndSampleOutput: + protein_tensor = attr.evolve(input) # Make a copy + + def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None: + return x.clone() if x is not None else None + + device = next(self.parameters()).device + + sampling_config = sampling_configuration + if sampling_config is None: + sampling_config = get_default_sampling_config(self.tokenizers) + + # Initialize default values for missing tracks + default_protein_tensor = ESMProteinTensor.empty( + len(input) - 2, tokenizers=self.tokenizers, device=input.device + ) + for track in attr.fields(ESMProteinTensor): + if getattr(protein_tensor, track.name, None) is None: + setattr( + protein_tensor, + track.name, + getattr(default_protein_tensor, track.name, None), + 
) + + # Preprocessing + sequence_length: int = -1 + for track in [ + "sequence", + "structure", + "secondary_structure", + "sasa", + "function", + "residue_annotations", + ]: + input_tensor: torch.Tensor | None = getattr(protein_tensor, track, None) + if input_tensor is not None: + # Add batch dimension if necessary + if track in ["sequence", "structure", "secondary_structure", "sasa"]: + if len(input_tensor.size()) == 1: + input_tensor = input_tensor.unsqueeze(0) # (L,) -> (1, L) + elif track in ["function", "residue_annotations"]: + if len(input_tensor.size()) == 2: + input_tensor = input_tensor.unsqueeze(0) # (L, O) -> (1, L, O) + + # Check length consistency + if sequence_length == -1: + sequence_length = input_tensor.size(1) + else: + if input_tensor.size(1) != sequence_length: + raise ValueError( + f"Length mismatch for track {track}. Expected {sequence_length}, got {input_tensor.size(1)}" + ) + + # Move input tensor to model device + input_tensor = input_tensor.to(device) + setattr(protein_tensor, track, input_tensor) + + if protein_tensor.coordinates is not None: + coordinates = protein_tensor.coordinates + if len(coordinates.size()) == 3: + coordinates = coordinates.unsqueeze(0) + protein_tensor.coordinates = coordinates.to(device) + sequence_length = coordinates.size(1) + + if sequence_length == -1: + raise ValueError("No input data provided") + + # Forward pass + forward_output = self._forward( + protein_tensor, + ForwardConfig( + ReturnLogitsConfig( + sequence=True, + structure=True, + secondary_structure=True, + sasa=True, + function=True, + residue_annotations=True, + ), + return_embeddings=True, + ), + ) + + # Sampling + tokens_dir = {} + track_sampling_metadata_dir: dict[str, dict | None] = {} + for track in ["sequence", "structure", "secondary_structure", "sasa"]: + config = getattr(sampling_config, track) + if config is None: + tokens_dir[track] = maybe_clone(getattr(input, track)) + continue + sampling_metadata = self._sample_track( + 
logits=getattr(forward_output.logits, track)[0, ...], + tokens=getattr(protein_tensor, track)[0, ...], + sampling_track_config=config, + mask_idx=getattr(self.tokenizers, track).mask_token_id, + ) + tokens_dir[track] = sampling_metadata.pop("sampled_tokens") # (L,) + track_sampling_metadata_dir[track] = sampling_metadata + + # Sample function and residue annotations separately + config = getattr(sampling_config, "function") + if config is None: + tokens_dir["function"] = maybe_clone(getattr(input, "function")) + tokens_dir["residue_annotations"] = maybe_clone( + getattr(input, "residue_annotations") + ) + else: + sampling_metadata = self._sample_function_track( + tokens=getattr(protein_tensor, "function")[0, ...], + logits=getattr(forward_output.logits, "function")[0, ...], + sampling_track_config=config, + ) + tokens_dir["function"] = sampling_metadata.pop("sampled_tokens") # (L, D) + track_sampling_metadata_dir["function"] = sampling_metadata + + sampled_tokens, _ = sample_residue_annotation_logits( + logits=forward_output.residue_annotation_logits[0, ...] 
# type: ignore + ) + tokens_dir["residue_annotations"] = sampled_tokens # (L, MAX_R) + + # Format output + forward_and_sample_output_dir = {} + forward_and_sample_output_dir["protein_tensor"] = ESMProteinTensor(**tokens_dir) + for property in [ + "entropy", + "prob", + "logprob", + "top_prob", + "topk_logprob", + "topk_tokens", + ]: + is_all_none = True + forward_track_data_dir = {} + for track in track_sampling_metadata_dir.keys(): + values = track_sampling_metadata_dir[track] + if values is not None and values.get(property, None) is not None: + forward_track_data_dir[track] = values.get(property, None) + is_all_none = False + if not is_all_none: + forward_and_sample_output_dir[property] = ForwardTrackData( + **forward_track_data_dir + ) + else: + forward_and_sample_output_dir[property] = None + + perres_embed = ( + forward_output.embeddings[0] # type: ignore + if sampling_configuration.return_per_residue_embeddings + else None + ) + mean_embedding = ( + forward_output.embeddings[0].mean(1) # type: ignore + if sampling_configuration.return_per_residue_embeddings + else None + ) + + return ForwardAndSampleOutput( + per_residue_embedding=perres_embed, + mean_embedding=mean_embedding, + **forward_and_sample_output_dir, + ) + + def _sample_track( + self, + logits: torch.Tensor, + tokens: torch.Tensor, + sampling_track_config: SamplingTrackConfig, + mask_idx: int, + ) -> dict[str, torch.Tensor]: + # Sample in all positions + temperature = sampling_track_config.temperature + sampled_tokens = sample_logits( + logits, temperature=temperature, top_p=sampling_track_config.top_p + ) + log_probs = logits.log_softmax(-1) + + # Do not sample at BOS and EOS tokens + sampling_mask = torch.ones_like(tokens, dtype=torch.bool) # (L, ) + sampling_mask[0] = False + sampling_mask[-1] = False + + # Do not sample at special token positions but allow sampling at mask token + special_minus_mask = list(set(sampling_track_config.invalid_ids) - {mask_idx}) + if len(special_minus_mask) > 0: + 
special_tokens = torch.tensor(special_minus_mask, device=tokens.device) + assert special_tokens.numel() > 0 + sampling_mask = sampling_mask & ( + tokens[..., None] != special_tokens[None, :] + ).all(-1) + + # Keep only samples from masked positions (if specified) + if sampling_track_config.only_sample_masked_tokens: + masked_tokens = tokens == mask_idx + sampling_mask = sampling_mask & masked_tokens + sampled_tokens = torch.where(sampling_mask, sampled_tokens, tokens) + + return self._compute_track_metadata( + sampled_tokens, + log_probs, + sampling_mask, + top_k=sampling_track_config.topk_logprobs, + ) + + def _sample_function_track( + self, + tokens: torch.Tensor, + logits: torch.Tensor, + sampling_track_config: SamplingTrackConfig, + ) -> dict[str, torch.Tensor]: + # Do not sample at BOS and EOS tokens + sampling_mask = torch.ones_like(tokens, dtype=torch.bool) + sampling_mask[0] = False + sampling_mask[-1] = False + + sampled_tokens, probs = sample_function_logits( + logits, + self.tokenizers.function, + top_p=sampling_track_config.top_p, + temperature=sampling_track_config.temperature, + ) + + if sampling_track_config.only_sample_masked_tokens: + raise ValueError( + "Sampling only masked tokens is undefined for function tokens." 
+ ) + + sampled_tokens = torch.where(sampling_mask, sampled_tokens, tokens) # (L, D) + + return self._compute_track_metadata( + sampled_tokens, + probs, + sampling_mask, + top_k=sampling_track_config.topk_logprobs, + ) + + @staticmethod + def _compute_track_metadata( + sampled_tokens: torch.Tensor, + log_probs: torch.Tensor, + sampling_mask: torch.Tensor, + top_k: int, + ) -> dict: + probs = torch.exp(log_probs) # (B, L) + entropy = torch.distributions.Categorical(probs=probs).entropy() # (B, L) + + # Only compute probabilities for sampled tokens + sampled_logprob = torch.zeros_like( + sampled_tokens, dtype=torch.float32 + ) # (B, L) + sampled_tokens_valid = sampled_tokens[sampling_mask] + sampled_log_probs_valid = log_probs[sampling_mask, sampled_tokens_valid] + sampled_logprob[sampling_mask] = sampled_log_probs_valid + + # Calculate extra metadata + sampled_prob = torch.exp(sampled_logprob) + top_prob = torch.max(probs, dim=-1).values + topk_logprobs, topk_tokens = torch.topk(log_probs, top_k, dim=-1) + topk_logprobs = None if top_k == 0 else topk_logprobs + topk_tokens = None if top_k == 0 else topk_tokens + + return { + "entropy": entropy, + "sampled_tokens": sampled_tokens, + "prob": sampled_prob, + "logprob": sampled_logprob, + "top_prob": top_prob, + "topk_logprob": topk_logprobs, + "topk_tokens": topk_tokens, + } diff --git a/esm/models/function_decoder.py b/esm/models/function_decoder.py new file mode 100644 index 0000000..7e6d4f6 --- /dev/null +++ b/esm/models/function_decoder.py @@ -0,0 +1,339 @@ +"""Function Token Decoder.""" + +from collections import defaultdict +from dataclasses import dataclass, field + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F + +from esm.layers.regression_head import RegressionHead +from esm.layers.transformer_stack import TransformerStack +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) +from esm.utils.constants import esm3 as 
C +from esm.utils.misc import merge_ranges +from esm.utils.types import FunctionAnnotation + + +@dataclass(frozen=True) +class FunctionTokenDecoderConfig: + """Configures function token decoder.""" + + # Embedding dimension of decoder. + d_model: int = 1024 + # Number of attention heads of decoder. + n_heads: int = 8 + # Number of layers of decoder. + n_layers: int = 3 + # Number of integer values that function tokens may assume. + function_token_vocab_size: int = 260 + # Number of function tokens at each position. + function_token_depth: int = 8 + # Number of InterPro labels that can be decoded. + num_interpro_classes: int = 29026 + # Number of function keywords that can be decoded. + keyword_vocabulary_size: int = 58641 + # List of supported InterPro ids. + interpro_entry_list: str = field( + default_factory=lambda: str(C.data_root() / C.INTERPRO_ENTRY) + ) + # Path to keywords vocabulary. + keyword_vocabulary_path: str = field( + default_factory=lambda: str(C.data_root() / C.KEYWORDS_VOCABULARY) + ) + # Whether to unpack LSH bits into single-bit tokens. + unpack_lsh_bits: bool = True + # The number of special tokens in the function tokenizer vocabulary which come + # before the LSH tokens. + num_special_tokens: int = 4 + # The number of bits per LSH token in the function tokenizer. + bits_per_token: int = 8 + + +class FunctionTokenDecoder(nn.Module): + def __init__(self, config: FunctionTokenDecoderConfig | None = None): + """Constructs function token decoder.""" + super().__init__() + if config is None: + config = FunctionTokenDecoderConfig() + self.config = config + + # Get the supported set of InterPro ids. 
+ df = pd.read_csv(config.interpro_entry_list, sep="\t") + self.interpro_ids = sorted(df.ENTRY_AC) + self.interpro2index = { + interpro_id: i for i, interpro_id in enumerate(self.interpro_ids) + } + assert len(self.interpro_ids) == config.num_interpro_classes + + with open(config.keyword_vocabulary_path, "r") as f: + self.keywords_vocabulary: list[str] = list(f.read().strip().split("\n")) + assert len(self.keywords_vocabulary) == config.keyword_vocabulary_size + + if config.unpack_lsh_bits: + vocab_size = 2 * config.function_token_depth * config.bits_per_token + else: + # Function-token id's re-use the same token ids at each position along the depth + # dimension, despite distinct meanings. The decoder should take this into + # account so create distinct embeddings for tokens at each position. + vocab_size = ( + self.config.function_token_depth * self.config.function_token_vocab_size + ) + + self.embedding = nn.Embedding( + # Function-token id's re-use the same token ids at each position along the + # depth dimension, despite distinct meanings. The decoder should take this + # into account so create distinct embeddings for tokens at each position. + num_embeddings=(vocab_size), + embedding_dim=config.d_model, + ) + self.decoder = TransformerStack( + d_model=config.d_model, + n_heads=config.n_heads, + v_heads=None, + n_layers=config.n_layers, + n_layers_geom=0, + scale_residue=False, + bias=True, + qk_layernorm=False, + ffn_type="gelu", + expansion_ratio=4, + ) + self.heads = nn.ModuleDict( + { + # Binary classification head predicting which keywords are present. + "keyword_logits": RegressionHead( + d_model=config.d_model, + output_dim=config.keyword_vocabulary_size, + hidden_dim=4 * config.d_model, + ), + # Regresses the TF-IDF value of each present keyword. + "keyword_tfidf": RegressionHead( + d_model=config.d_model, + output_dim=config.keyword_vocabulary_size, + hidden_dim=4 * config.d_model, + ), + # Predicts which InterPro annotations are present. 
+ "interpro_logits": RegressionHead( + d_model=config.d_model, + output_dim=config.num_interpro_classes, + hidden_dim=4 * config.d_model, + ), + } + ) + + def forward(self, token_ids: torch.Tensor) -> dict[str, torch.Tensor]: + """Forward pass through function token decoder. + + Args: + token_ids: [batch_size, function_token_depth] batch of function token + ids to decode. + Returns: + Dict with one entry per prediction head ("keyword_logits", "keyword_tfidf", + "interpro_logits"). "interpro_logits" is a binary classification logits + tensor of shape [batch_size, num_interpro_classes]. + """ + assert token_ids.ndim == 2 + assert token_ids.shape[1] == self.config.function_token_depth + batch_size, depth = token_ids.shape + + if self.config.unpack_lsh_bits: + # Shift values into [0, 2^bits_per_token) + lsh_bits = token_ids - self.config.num_special_tokens + # extract each bit. (hob stands for highest-order bit) + bits = torch.concat( + [ + torch.bitwise_and(lsh_bits, 1 << hob).gt(0).to(torch.int32) + for hob in range(self.config.bits_per_token) + ], + dim=1, + ) + assert bits.shape == (batch_size, depth * self.config.bits_per_token) + + # Shift each bit into individual vocabulary ranges, so they get distinct + # embeddings. + vocab_offsets = 2 * torch.arange( + depth * self.config.bits_per_token, device=token_ids.device + ) + inputs = vocab_offsets[None, :] + bits + + # zero-out special tokens, i.e. non LSH tokens. + where_special = token_ids < self.config.num_special_tokens + inputs = torch.where(where_special.any(dim=1, keepdim=True), 0, inputs) + else: + # Apply depth-position offset to use distinct vocabs. See __init__ for + # explanation. 
+ vocab_offsets = self.config.function_token_vocab_size * torch.arange( + self.config.function_token_depth, + device=token_ids.device, + ) + inputs = token_ids + vocab_offsets[None, :] + + embed = self.embedding(inputs) + encoding, _ = self.decoder(embed) + pooled = torch.mean(encoding, dim=1) + + return {name: head(pooled) for name, head in self.heads.items()} + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def decode( + self, + function_token_ids: torch.Tensor, + tokenizer: InterProQuantizedTokenizer, + decode_annotations: bool = True, + annotation_threshold: float = 0.1, + decode_keywords: bool = True, + keywords_threshold: float = 0.5, + annotation_min_length: int | None = 5, + annotation_gap_merge_max: int | None = 3, + ): + """Decodes function tokens into predicted annotations and keywords. + + Args: + function_token_ids: [length, depth] function token ids. NOTE: + without / prefix + tokenizer: function tokenizer. + decode_annotations: whether to decode InterPro annotations. + annotation_threshold: threshold for emitting a function annotation. + decode_keywords: whether to decode function keywords. + keywords_threshold: threshold for emitting a keyword. + annotation_min_length: optional minimum length of predicted annotations for + size filtering. + annotation_gap_merge_max: optional maximum gap across which adjacent + annotations of the same type are merged. + Returns: + Decoder outputs: + - "interpro_logits": [length, num_interpro] predicted InterPro logits. + - "interpro_preds": [length, num_interpro] predicted InterPro labels. + - "interpro_annotations": list[FunctionAnnotation] predicted InterPro + annotations + - "keyword_logits": [length, keyword_vocabulary] binary prediction + logits for keywords. + - "function_keywords": list[FunctionAnnotation] predicted function keyword + ranges. 
+ """ + assert function_token_ids.ndim == 2 + assert function_token_ids.shape[1] == tokenizer.depth + assert self.config.function_token_depth == tokenizer.depth + + outputs = {} + + outputs = self(function_token_ids.to(self.device)) + + # Only decode in positions that have function tokens. + where_decode = torch.all( + (function_token_ids != tokenizer.vocab_to_index[""]) + & (function_token_ids != tokenizer.vocab_to_index[""]) + & (function_token_ids != tokenizer.vocab_to_index[""]), + dim=1, + ) + + # Decode InterPro annotations ranges. + interpro_preds = F.sigmoid(outputs["interpro_logits"]) + interpro_preds = interpro_preds >= annotation_threshold + interpro_preds[~where_decode, :] = False + outputs["interpro_preds"] = interpro_preds + if decode_annotations: + annotations: list[FunctionAnnotation] = [] + preds: np.ndarray = interpro_preds.detach().cpu().numpy() + for position_index, class_index in zip(*preds.nonzero()): + interpro_id = self.interpro_ids[class_index] + annotation = FunctionAnnotation( + label=interpro_id, + start=position_index + 1, # zero-index -> one-index inclusive + end=position_index + 1, # zero-index -> one-index inclusive + ) + annotations.append(annotation) + + annotations = _merge_annotations( + annotations, + merge_gap_max=annotation_gap_merge_max, + ) + + # Drop very small annotations. + if annotation_min_length is not None: + annotations = [ + annotation + for annotation in annotations + if annotation.end - annotation.start + 1 >= annotation_min_length + ] + + outputs["interpro_annotations"] = annotations + + # Decode function keyword ranges. 
+ keyword_logits = outputs["keyword_logits"] + keyword_logits[~where_decode, :] = -torch.inf + if decode_keywords: + keyword_preds = F.sigmoid(keyword_logits) >= keywords_threshold + outputs["function_keywords"] = self._preds_to_keywords( + keyword_preds.detach().cpu().numpy() + ) + + return outputs + + def _preds_to_keywords(self, keyword_preds: np.ndarray) -> list[FunctionAnnotation]: + """Converts thresholded keyword predictions to predicted keywords over the sequence. + + Args: + keyword_preds: [length, keyword_vocab] positional predictions of + function keywords from the keyword prediction head. + Returns: + Non-overlapping keyword annotated ranges along the sequence. Note that indices + will index into the *sequence*, not the function token array which has a + prefix. + """ + assert keyword_preds.ndim == 2 + assert keyword_preds.shape[1] == self.config.keyword_vocabulary_size + + keyword_positions: dict[str, list[range]] = defaultdict(list) + for position, keyword_id in zip(*np.nonzero(keyword_preds)): + keyword = self.keywords_vocabulary[keyword_id] + keyword_positions[keyword].append(range(position, position + 1)) + + annotations: list[FunctionAnnotation] = [] + for keyword, ranges in keyword_positions.items(): + for range_ in merge_ranges(ranges): + annotation = FunctionAnnotation( + label=keyword, + start=range_.start + 1, # zero-index -> one-index + end=range_.stop + 1 - 1, # zero-index excl -> one-index incl + ) + annotations.append(annotation) + + return annotations + + +def _merge_annotations( + annotations: list[FunctionAnnotation], + merge_gap_max: int | None = None, +) -> list[FunctionAnnotation]: + """Merges annotations into non-overlapping segments. + + Args: + annotations: annotations to merge. + merge_gap_max: optionally merge neighboring ranges that are separated by a gap + no larger than this size. + Returns: + non-overlapping annotations with gaps merged. 
+ """ + grouped: dict[str, list[range]] = defaultdict(list) + for a in annotations: + # Convert one-indexed inclusive-inclusive, to range() + grouped[a.label].append(range(a.start, a.end + 1)) + + merged = [] + for label, ranges in grouped.items(): + merged_ranges = merge_ranges(ranges, merge_gap_max=merge_gap_max) + for range_ in merged_ranges: + annotation = FunctionAnnotation( + label=label, + start=range_.start, # range is already one-indexed + end=range_.stop - 1, # exclusive stop -> one-index incl + ) + merged.append(annotation) + return merged diff --git a/esm/models/vqvae.py b/esm/models/vqvae.py new file mode 100644 index 0000000..1e4afa9 --- /dev/null +++ b/esm/models/vqvae.py @@ -0,0 +1,450 @@ +import torch +import torch.nn as nn + +from esm.layers.blocks import UnifiedTransformerBlock +from esm.layers.codebook import EMACodebook +from esm.layers.structure_proj import Dim6RotStructureHead +from esm.layers.transformer_stack import TransformerStack +from esm.utils.constants import esm3 as C +from esm.utils.misc import knn_graph +from esm.utils.structure.affine3d import ( + Affine3D, + build_affine3d_from_coordinates, +) +from esm.utils.structure.predicted_aligned_error import ( + compute_predicted_aligned_error, + compute_tm, +) + + +class RelativePositionEmbedding(nn.Module): + """ + Embedding layer for relative position embeddings. `bins` is the number of positions relative + to the query position that are considered before clipping. For instance, if `bins=10`, then + the relative position embedding will have 21 positions, [-10, 10]. 
+ """ + + def __init__(self, bins, embedding_dim, init_std=0.02): + super().__init__() + self.bins = bins + + self.embedding = torch.nn.Embedding(2 * bins + 2, embedding_dim) + self.embedding.weight.data.normal_(0, init_std) + + def forward(self, query_residue_index, key_residue_index): + """ + Input: + query_residue_index: (B, ) tensor of source indices (dtype=torch.long) + key_residue_index: (B, L) tensor of target indices (dtype=torch.long) + Output: + embeddings: B x L x embedding_dim tensor of embeddings + """ + + assert query_residue_index.dtype == torch.long + assert key_residue_index.dtype == torch.long + assert query_residue_index.ndim == 1 + assert key_residue_index.ndim == 2 + + diff = key_residue_index - query_residue_index.unsqueeze(1) + diff = diff.clamp(-self.bins, self.bins) + diff = diff + self.bins + 1 # add 1 to adjust for padding index + output = self.embedding(diff) + return output + + +class PairwisePredictionHead(nn.Module): + def __init__( + self, + input_dim: int, + downproject_dim: int, + hidden_dim: int, + n_bins: int, + bias: bool = True, + pairwise_state_dim: int = 0, + ): + super().__init__() + self.downproject = nn.Linear(input_dim, downproject_dim, bias=bias) + self.linear1 = nn.Linear( + downproject_dim + pairwise_state_dim, hidden_dim, bias=bias + ) + self.activation_fn = nn.GELU() + self.norm = nn.LayerNorm(hidden_dim) + self.linear2 = nn.Linear(hidden_dim, n_bins, bias=bias) + + def forward(self, x, pairwise: torch.Tensor | None = None): + """ + Args: + x: [B x L x D] + + Output: + [B x L x L x K] + """ + x = self.downproject(x) + # Let x_i be a vector of size (B, D). 
+ # Input is {x_1, ..., x_L} of size (B, L, D) + # Output is 2D where x_ij = cat([x_i * x_j, x_i - x_j]) + q, k = x.chunk(2, dim=-1) + + prod = q[:, None, :, :] * k[:, :, None, :] + diff = q[:, None, :, :] - k[:, :, None, :] + x_2d = [ + prod, + diff, + ] + if pairwise is not None: + x_2d.append(pairwise) + x = torch.cat(x_2d, dim=-1) + x = self.linear1(x) + x = self.activation_fn(x) + x = self.norm(x) + x = self.linear2(x) + return x + + +class RegressionHead(nn.Module): + def __init__(self, embed_dim: int, output_dim: int): + super().__init__() + self.dense = nn.Linear(embed_dim, embed_dim) + self.activation_fn = nn.GELU() + self.norm = nn.LayerNorm(embed_dim) + self.output = nn.Linear(embed_dim, output_dim) + + def forward(self, features): + x = self.dense(features) + x = self.activation_fn(x) + x = self.norm(x) + x = self.output(x) + return x + + +class CategoricalMixture: + def __init__(self, param, bins=50, start=0, end=1): + # All tensors are of shape ..., bins. + self.logits = param + bins = torch.linspace( + start, end, bins + 1, device=self.logits.device, dtype=torch.float32 + ) + self.v_bins = (bins[:-1] + bins[1:]) / 2 + + def log_prob(self, true): + # Shapes are: + # self.probs: ... x bins + # true : ... 
(floating point # for target) + true_index = ( + (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1) + ) + nll = self.logits.log_softmax(-1) + return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1) + + def mean(self): + return ( + self.logits.to(self.v_bins.dtype).softmax(-1) @ self.v_bins.unsqueeze(1) + ).squeeze(-1) + + def median(self): + return self.v_bins[self.logits.max(-1).indices] + + +class GeometricEncoderStack(TransformerStack): + def __init__(self, d_model, n_heads, v_heads, n_layers): + super().__init__(d_model, n_heads, v_heads, 0) + self.blocks = nn.ModuleList( + [ + UnifiedTransformerBlock( + d_model, + n_heads, + v_heads=v_heads, + use_geom_attn=True, + use_plain_attn=False, + expansion_ratio=4, + bias=True, + ) + for i in range(n_layers) + ] + ) + self.norm = nn.Identity() + + +def batched_gather(data, inds, dim=0, no_batch_dims=0): + ranges = [] + for i, s in enumerate(data.shape[:no_batch_dims]): + r = torch.arange(s) + r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) + ranges.append(r) + + remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] + remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds + ranges.extend(remaining_dims) + return data[ranges] + + +def node_gather(s: torch.Tensor, edges: torch.Tensor) -> torch.Tensor: + return batched_gather(s.unsqueeze(-3), edges, -2, no_batch_dims=len(s.shape) - 1) + + +class StructureTokenEncoder(nn.Module): + def __init__(self, d_model, n_heads, v_heads, n_layers, d_out, n_codes): + super().__init__() + # We only support fully-geometric structure token encoders for now... 
+ # setting n_layers_geom to something that's not n_layers won't work because + # sequence ID isn't supported fully in this repo for plain-old transformers + self.transformer = GeometricEncoderStack(d_model, n_heads, v_heads, n_layers) + self.pre_vq_proj = nn.Linear(d_model, d_out) + self.codebook = EMACodebook(n_codes, d_out) + self.relative_positional_embedding = RelativePositionEmbedding( + 32, d_model, init_std=0.02 + ) + self.knn = 16 + + def encode_local_structure( + self, + coords: torch.Tensor, + affine: Affine3D, + attention_mask: torch.Tensor, + sequence_id: torch.Tensor | None, + affine_mask: torch.Tensor, + residue_index: torch.Tensor | None = None, + ): + """This function allows a multi-layered encoder to encode tokens with a local receptive field. The implementation is as follows: + + 1. Starting with (B, L) frames, we find the KNN in structure space. This gives us (B, L, K), where the last dimension is the local + neighborhood of each of the (B, L) residues. + 2. We reshape these frames to (B*L, K), giving a large batch of local neighborhoods. + 3. Pass the (B*L, K) local neighborhoods through a stack of geometric reasoning blocks, effectively getting all-to-all communication between + all frames in the local neighborhood. + 4. This gives (B*L, K, d_model) embeddings, from which we need to get a single embedding per local neighborhood. We do this by simply + taking the embedding corresponding to the query node. This gives us (B*L, d_model) embeddings. + 5. 
Reshape back to (B, L, d_model) embeddings + """ + assert coords.size(-1) == 3 and coords.size(-2) == 3, "need N, CA, C" + with torch.no_grad(): + knn_edges, _ = self.find_knn_edges( + coords, + ~attention_mask, + coord_mask=affine_mask, + sequence_id=sequence_id, + knn=self.knn, + ) + B, L, E = knn_edges.shape + + affine_tensor = affine.tensor # for easier manipulation + T_D = affine_tensor.size(-1) + knn_affine_tensor = node_gather(affine_tensor, knn_edges) + knn_affine_tensor = knn_affine_tensor.view(-1, E, T_D).contiguous() + affine = Affine3D.from_tensor(knn_affine_tensor) + knn_sequence_id = ( + node_gather(sequence_id.unsqueeze(-1), knn_edges).view(-1, E) + if sequence_id is not None + else torch.zeros(L, E, dtype=torch.int64, device=coords.device) + ) + knn_affine_mask = node_gather(affine_mask.unsqueeze(-1), knn_edges).view( + -1, E + ) + knn_chain_id = torch.zeros(L, E, dtype=torch.int64, device=coords.device) + + if residue_index is None: + res_idxs = knn_edges.view(-1, E) + else: + res_idxs = node_gather(residue_index.unsqueeze(-1), knn_edges).view( + -1, E + ) + + z = self.relative_positional_embedding(res_idxs[:, 0], res_idxs) + + z, _ = self.transformer.forward( + x=z, + sequence_id=knn_sequence_id, + affine=affine, + affine_mask=knn_affine_mask, + chain_id=knn_chain_id, + ) + + # Unflatten the output and take the query node embedding, which will always be the first one because + # a node has distance 0 with itself and the KNN are sorted. 
+ z = z.view(B, L, E, -1) + z = z[:, :, 0, :] + + return z + + @staticmethod + def find_knn_edges( + coords, + padding_mask, + coord_mask, + sequence_id: torch.Tensor | None = None, + knn: int | None = None, + ) -> tuple: + assert knn is not None, "Must specify a non-null knn to find_knn_edges" + # Coords are N, CA, C + coords = coords.clone() + coords[~coord_mask] = 0 + + if sequence_id is None: + sequence_id = torch.zeros( + (coords.shape[0], coords.shape[1]), device=coords.device + ).long() + + with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): # type: ignore + ca = coords[..., 1, :] + edges, edge_mask = knn_graph( + ca, + coord_mask, + padding_mask, + sequence_id, + no_knn=knn, + ) + + return edges, edge_mask + + def encode( + self, + coords: torch.Tensor, + attention_mask: torch.Tensor | None = None, + sequence_id: torch.Tensor | None = None, + residue_index: torch.Tensor | None = None, + ): + coords = coords[..., :3, :] + affine, affine_mask = build_affine3d_from_coordinates(coords=coords) + + if attention_mask is None: + attention_mask = torch.ones_like(affine_mask, dtype=torch.bool) + attention_mask = attention_mask.bool() + + if sequence_id is None: + sequence_id = torch.zeros_like(affine_mask, dtype=torch.int64) + + z = self.encode_local_structure( + coords=coords, + affine=affine, + attention_mask=attention_mask, + sequence_id=sequence_id, + affine_mask=affine_mask, + residue_index=residue_index, + ) + + z = z.masked_fill(~affine_mask.unsqueeze(2), 0) + z = self.pre_vq_proj(z) + + z_q, min_encoding_indices, _ = self.codebook(z) + + return z_q, min_encoding_indices + + +class StructureTokenDecoder(nn.Module): + def __init__( + self, + d_model, + n_heads, + n_layers, + ): + super().__init__() + self.decoder_channels = d_model + + self.vqvae_codebook_size = C.VQVAE_CODEBOOK_SIZE + self.special_tokens = C.VQVAE_SPECIAL_TOKENS + self.max_pae_bin = C.VQVAE_MAX_PAE_BIN + + self.embed = nn.Embedding( + self.vqvae_codebook_size + 
len(self.special_tokens), d_model + ) + self.decoder_stack = TransformerStack( + d_model, n_heads, 1, n_layers, scale_residue=False, n_layers_geom=0 + ) + + self.affine_output_projection = Dim6RotStructureHead( + self.decoder_channels, 10, predict_torsion_angles=False + ) + + direction_loss_bins = C.VQVAE_DIRECTION_LOSS_BINS + pae_bins = C.VQVAE_PAE_BINS + self.pairwise_bins = [ + 64, # distogram + direction_loss_bins * 6, # direction bins + pae_bins, # predicted aligned error + ] + self.pairwise_classification_head = PairwisePredictionHead( + self.decoder_channels, + downproject_dim=128, + hidden_dim=128, + n_bins=sum(self.pairwise_bins), + bias=False, + ) + + plddt_bins = C.VQVAE_PLDDT_BINS + self.plddt_head = RegressionHead( + embed_dim=self.decoder_channels, output_dim=plddt_bins + ) + + def decode( + self, + structure_tokens: torch.Tensor, + attention_mask: torch.Tensor | None = None, + sequence_id: torch.Tensor | None = None, + ): + if attention_mask is None: + attention_mask = torch.ones_like(structure_tokens, dtype=torch.bool) + + attention_mask = attention_mask.bool() + if sequence_id is None: + sequence_id = torch.zeros_like(structure_tokens, dtype=torch.int64) + # not supported for now + chain_id = torch.zeros_like(structure_tokens, dtype=torch.int64) + + # check that BOS and EOS are set correctly + assert ( + structure_tokens[:, 0].eq(self.special_tokens["BOS"]).all() + ), "First token in structure_tokens must be BOS token" + assert ( + structure_tokens[ + torch.arange(structure_tokens.shape[0]), attention_mask.sum(1) - 1 + ] + .eq(self.special_tokens["EOS"]) + .all() + ), "Last token in structure_tokens must be EOS token" + assert ( + (structure_tokens < 0).sum() == 0 + ), "All structure tokens set to -1 should be replaced with BOS, EOS, PAD, or MASK tokens by now, but that isn't the case!" + + x = self.embed(structure_tokens) + # !!! 
NOTE: Attention mask is actually unused here so watch out + x, _ = self.decoder_stack.forward( + x, affine=None, affine_mask=None, sequence_id=sequence_id, chain_id=chain_id + ) + + tensor7_affine, bb_pred = self.affine_output_projection( + x, affine=None, affine_mask=torch.zeros_like(attention_mask) + ) + + pae, ptm = None, None + pairwise_logits = self.pairwise_classification_head(x) + _, _, pae_logits = [ + (o if o.numel() > 0 else None) + for o in pairwise_logits.split(self.pairwise_bins, dim=-1) + ] + + special_tokens_mask = structure_tokens >= min(self.special_tokens.values()) + pae = compute_predicted_aligned_error( + pae_logits, # type: ignore + aa_mask=~special_tokens_mask, + sequence_id=sequence_id, + max_bin=self.max_pae_bin, + ) + # This might be broken for chainbreak tokens? We might align to the chainbreak + ptm = compute_tm( + pae_logits, # type: ignore + aa_mask=~special_tokens_mask, + max_bin=self.max_pae_bin, + ) + + plddt_logits = self.plddt_head(x) + plddt_value = CategoricalMixture( + plddt_logits, bins=plddt_logits.shape[-1] + ).mean() + + return dict( + tensor7_affine=tensor7_affine, + bb_pred=bb_pred, + plddt=plddt_value, + ptm=ptm, + predicted_aligned_error=pae, + ) diff --git a/esm/pretrained.py b/esm/pretrained.py new file mode 100644 index 0000000..141ba2f --- /dev/null +++ b/esm/pretrained.py @@ -0,0 +1,95 @@ +from typing import Callable + +import torch +import torch.nn as nn + +from esm.models.esm3 import ESM3 +from esm.models.function_decoder import FunctionTokenDecoder +from esm.models.vqvae import ( + StructureTokenDecoder, + StructureTokenEncoder, +) +from esm.utils.constants.esm3 import data_root +from esm.utils.constants.models import ( + ESM3_FUNCTION_DECODER_V0, + ESM3_OPEN_SMALL, + ESM3_STRUCTURE_DECODER_V0, + ESM3_STRUCTURE_ENCODER_V0, +) + +ModelBuilder = Callable[[torch.device | str], nn.Module] + + +def ESM3_sm_open_v0(device: torch.device | str = "cpu"): + model = ( + ESM3( + d_model=1536, + n_heads=24, + v_heads=256, + 
n_layers=48, + structure_encoder_name=ESM3_STRUCTURE_ENCODER_V0, + structure_decoder_name=ESM3_STRUCTURE_DECODER_V0, + function_decoder_name=ESM3_FUNCTION_DECODER_V0, + ) + .to(device) + .eval() + ) + state_dict = torch.load( + data_root() / "data/weights/esm3_sm_open_v1.pth", map_location=device + ) + model.load_state_dict(state_dict) + return model + + +def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"): + model = ( + StructureTokenEncoder( + d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096 + ) + .to(device) + .eval() + ) + state_dict = torch.load( + data_root() / "data/weights/esm3_structure_encoder_v0.pth", map_location=device + ) + model.load_state_dict(state_dict) + return model + + +def ESM3_structure_decoder_v0(device: torch.device | str = "cpu"): + model = ( + StructureTokenDecoder(d_model=1280, n_heads=20, n_layers=30).to(device).eval() + ) + state_dict = torch.load( + data_root() / "data/weights/esm3_structure_decoder_v0.pth", map_location=device + ) + model.load_state_dict(state_dict) + return model + + +def ESM3_function_decoder_v0(device: torch.device | str = "cpu"): + model = FunctionTokenDecoder().to(device).eval() + state_dict = torch.load( + data_root() / "data/weights/esm3_function_decoder_v0.pth", map_location=device + ) + model.load_state_dict(state_dict) + return model + + +LOCAL_MODEL_REGISTRY: dict[str, ModelBuilder] = { + ESM3_OPEN_SMALL: ESM3_sm_open_v0, + ESM3_STRUCTURE_ENCODER_V0: ESM3_structure_encoder_v0, + ESM3_STRUCTURE_DECODER_V0: ESM3_structure_decoder_v0, + ESM3_FUNCTION_DECODER_V0: ESM3_function_decoder_v0, +} + + +def load_local_model(model_name: str, device: torch.device | str = "cpu") -> nn.Module: + if model_name not in LOCAL_MODEL_REGISTRY: + raise ValueError(f"Model {model_name} not found in local model registry.") + return LOCAL_MODEL_REGISTRY[model_name](device) + + +# Register custom versions of ESM3 for use with the local inference API +def register_local_model(model_name: str, 
model_builder: ModelBuilder) -> None: + LOCAL_MODEL_REGISTRY[model_name] = model_builder diff --git a/esm/sdk/api.py b/esm/sdk/api.py new file mode 100644 index 0000000..98ccea1 --- /dev/null +++ b/esm/sdk/api.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +from abc import ABC +from typing import Sequence, TypeVar + +import attr +import torch +from attr import define + +from esm.tokenization import ( + TokenizerCollectionProtocol, + get_model_tokenizers, +) +from esm.utils import encoding +from esm.utils.constants.models import ESM3_OPEN_SMALL +from esm.utils.structure.protein_chain import ProteinChain +from esm.utils.types import ( + FunctionAnnotation, + PathLike, + PathOrBuffer, +) + + +## Basic Types +@define +class ESMProtein: + # Tracks + sequence: str | None = None + secondary_structure: str | None = None + sasa: list[float | str | None] | None = None + function_annotations: list[FunctionAnnotation] | None = None + coordinates: torch.Tensor | None = None + # Metrics + plddt: torch.Tensor | None = None + ptm: torch.Tensor | None = None + + def __len__(self): + if self.sequence is not None: + return len(self.sequence) + elif self.secondary_structure is not None: + return len(self.secondary_structure) + elif self.sasa is not None: + return len(self.sasa) + elif self.coordinates is not None: + return self.coordinates.size(0) + else: + raise ValueError("No track to determine length from.") + + @classmethod + def from_pdb( + cls, + path: PathOrBuffer, + chain_id: str = "detect", + id: str | None = None, + is_predicted: bool = False, + ) -> ESMProtein: + protein_chain = ProteinChain.from_pdb( + path=path, chain_id=chain_id, id=id, is_predicted=is_predicted + ) + return cls.from_protein_chain(protein_chain) + + @classmethod + def from_protein_chain( + cls, protein_chain: ProteinChain, with_annotations: bool = False + ) -> ESMProtein: + # By default, we don't annotate with DSSP / SASA, which are expensive. 
+ # If mkdssp is installed, we can annotate with a flag. + if with_annotations: + return ESMProtein( + sequence=protein_chain.sequence, + secondary_structure=protein_chain.dssp().tolist(), + sasa=protein_chain.sasa().tolist(), + function_annotations=None, + coordinates=torch.tensor(protein_chain.atom37_positions), + ) + else: + return ESMProtein( + sequence=protein_chain.sequence, + secondary_structure=None, + sasa=None, + function_annotations=None, + coordinates=torch.tensor(protein_chain.atom37_positions), + ) + + def to_pdb(self, pdb_path: PathLike) -> None: + protein_chain = self.to_protein_chain() + protein_chain.to_pdb(pdb_path) + + def to_pdb_string(self) -> str: + protein_chain = self.to_protein_chain() + return protein_chain.to_pdb_string() + + def to_protein_chain(self) -> ProteinChain: + if self.coordinates is None: + raise ValueError("Coordinates are required to convert to a ProteinChain.") + protein_chain = ProteinChain.from_atom37( + atom37_positions=self.coordinates.to("cpu").numpy(), + id=None, + sequence=self.sequence, + chain_id=None, + entity_id=None, + residue_index=None, + insertion_code=None, + ) + return protein_chain + + +@define +class ESMProteinTensor: + sequence: torch.Tensor | None = None + structure: torch.Tensor | None = None + secondary_structure: torch.Tensor | None = None + sasa: torch.Tensor | None = None + function: torch.Tensor | None = None + residue_annotations: torch.Tensor | None = None + coordinates: torch.Tensor | None = None + + def __len__(self) -> int: + if self.sequence is not None: + return self.sequence.size(0) + elif self.structure is not None: + return self.structure.size(0) + elif self.secondary_structure is not None: + return self.secondary_structure.size(0) + elif self.sasa is not None: + return self.sasa.size(0) + elif self.coordinates is not None: + return self.coordinates.size(0) + else: + raise ValueError("No track to determine length from.") + + @property + def device(self) -> str | torch.device: + device_ = 
None + + tracks = [f.name for f in attr.fields(ESMProteinTensor)] + + for track in tracks: + current_track: torch.Tensor | None = getattr(self, track) + if current_track is not None: + if device_ is not None and device_ != current_track.device: + raise ValueError(f"Inconsistent devices for track {track}.") + device_ = getattr(self, track).device + + if device_ is None: + raise ValueError("No track to determine device from.") + + return device_ + + def to(self, device: str | torch.device | None) -> ESMProteinTensor: + if device is None: + return self + + device = torch.device(device) + + def _to(name): + v = getattr(self, name) + if v is not None: + setattr(self, name, v.to(device)) + + for n in [ + "sequence", + "structure", + "secondary_structure", + "sasa", + "function", + "residue_annotations", + "coordinates", + ]: + _to(n) + + return self + + @classmethod + def empty( + cls, + length: int, + tokenizers: TokenizerCollectionProtocol | None = None, + device: torch.device | str = "cpu", + ) -> ESMProteinTensor: + if tokenizers is None: + tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL) + + return ESMProteinTensor( + sequence=encoding.get_default_sequence_tokens( + length, tokenizers.sequence + ).to(device), + structure=encoding.get_default_structure_tokens( + length, tokenizers.structure + ).to(device), + secondary_structure=encoding.get_default_secondary_structure_tokens( + length, tokenizers.secondary_structure + ).to(device), + sasa=encoding.get_default_sasa_tokens(length, tokenizers.sasa).to(device), + function=encoding.get_default_function_tokens( + length, tokenizers.function + ).to(device), + residue_annotations=encoding.get_default_residue_annotation_tokens( + length, tokenizers.residue_annotations + ).to(device), + ) + + +## High Level Endpoint Types +@define +class GenerationConfig: + track: str = "" + invalid_ids: Sequence[int] = [] + schedule: str = "cosine" + num_steps: int = 8 + temperature: float = 1.0 + top_p: float = 1.0 + 
condition_on_coordinates_only: bool = True + + +## Low Level Endpoint Types +@define +class SamplingTrackConfig: + temperature: float = 1.0 + top_p: float = 1.0 + only_sample_masked_tokens: bool = True + invalid_ids: Sequence[int] = [] + topk_logprobs: int = 0 + + +@define +class SamplingConfig: + sequence: SamplingTrackConfig | None = None + structure: SamplingTrackConfig | None = None + secondary_structure: SamplingTrackConfig | None = None + sasa: SamplingTrackConfig | None = None + function: SamplingTrackConfig | None = None + + return_per_residue_embeddings: bool = False + return_mean_embedding: bool = False + + +@define +class ReturnLogitsConfig: + sequence: bool = False + structure: bool = False + secondary_structure: bool = False + sasa: bool = False + function: bool = False + residue_annotations: bool = False + + +@define +class ForwardConfig: + return_logits: ReturnLogitsConfig = ReturnLogitsConfig() + return_embeddings: bool = False + + +@define +class ForwardTrackData: + sequence: torch.Tensor | None = None + structure: torch.Tensor | None = None + secondary_structure: torch.Tensor | None = None + sasa: torch.Tensor | None = None + function: torch.Tensor | None = None + + +@define +class ForwardOutput: + logits: ForwardTrackData | None = None + embeddings: torch.Tensor | None = None + + # Residue annotations is multi-hot, so deserves special treatment + # It's not a categorical distribution, but instead a bernoulli, so + # softmax across the last dimension is _wrong_ + residue_annotation_logits: torch.Tensor | None = None + + +@define +class ForwardAndSampleOutput(ForwardOutput): + protein_tensor: ESMProteinTensor = ESMProteinTensor() + + entropy: ForwardTrackData | None = None + # Probability of sampled token + prob: ForwardTrackData | None = None + logprob: ForwardTrackData | None = None + # Top probability at this position + top_prob: ForwardTrackData | None = None + topk_logprob: ForwardTrackData | None = None + # Which tokens correspond to top 
probability + topk_tokens: ForwardTrackData | None = None + + per_residue_embedding: torch.Tensor | None = None + mean_embedding: torch.Tensor | None = None + + +ProteinType = TypeVar("ProteinType", bound=ESMProteinTensor | ESMProtein) + + +class ESM3InferenceClient(ABC): + def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType: + # This is the easiest and most flexible way to run ESM3. Generate will + # iteratively sample tokens and provide an output with the track specified + # completely filled out, according to the GenerationConfig provided. + # It is a local function wrapping calls for encode -> iterative_sampling -> decode. + # If an ESMProteinTensor is provided, encode and decode are skipped. + raise NotImplementedError + + def encode(self, input: ESMProtein) -> ESMProteinTensor: + # Encode allows for encoding RawRepresentation into TokenizedRepresentation. + # This runs the structure_token_encoder, as well as dealing with PDB => atom37 conversion + raise NotImplementedError + + def decode(self, input: ESMProteinTensor) -> ESMProtein: + # Decode is the inverse of encode, and runs a structure_token_decoder to output coordinates + raise NotImplementedError + + def _forward( + self, input: ESMProteinTensor, config: ForwardConfig = ForwardConfig() + ) -> ForwardOutput: + # Our API generally discourages using raw forwards. + # This is because sending logits can be prohibitively expensive. + # Please use forward_and_sample instead. + raise NotImplementedError + + def forward_and_sample( + self, input: ESMProteinTensor, sampling_configuration: SamplingConfig + ) -> ForwardAndSampleOutput: + # forward_and_sample runs a single model forward, sampling tokens according to `SamplingConfig`. + # This is the way for power users to run ESM3. We hope to design this in a way to enable high throughput + # inference, as well as arbitrary chain-of-thought invocations of ESM3. 
+ raise NotImplementedError diff --git a/esm/tokenization/__init__.py b/esm/tokenization/__init__.py new file mode 100644 index 0000000..9d2e5a9 --- /dev/null +++ b/esm/tokenization/__init__.py @@ -0,0 +1,63 @@ +from dataclasses import dataclass +from typing import Protocol + +from esm.utils.constants.esm3 import VQVAE_SPECIAL_TOKENS +from esm.utils.constants.models import ESM3_OPEN_SMALL + +from .function_tokenizer import InterProQuantizedTokenizer +from .residue_tokenizer import ResidueAnnotationsTokenizer +from .sasa_tokenizer import SASADiscretizingTokenizer +from .sequence_tokenizer import EsmSequenceTokenizer +from .ss_tokenizer import SecondaryStructureTokenizer +from .structure_tokenizer import StructureTokenizer +from .tokenizer_base import EsmTokenizerBase + + +class TokenizerCollectionProtocol(Protocol): + sequence: EsmSequenceTokenizer + structure: StructureTokenizer + secondary_structure: SecondaryStructureTokenizer + sasa: SASADiscretizingTokenizer + function: InterProQuantizedTokenizer + residue_annotations: ResidueAnnotationsTokenizer + + +@dataclass +class TokenizerCollection: + sequence: EsmSequenceTokenizer + structure: StructureTokenizer + secondary_structure: SecondaryStructureTokenizer + sasa: SASADiscretizingTokenizer + function: InterProQuantizedTokenizer + residue_annotations: ResidueAnnotationsTokenizer + + +def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection: + if model == ESM3_OPEN_SMALL: + return TokenizerCollection( + sequence=EsmSequenceTokenizer(), + structure=StructureTokenizer(vq_vae_special_tokens=VQVAE_SPECIAL_TOKENS), + secondary_structure=SecondaryStructureTokenizer(kind="ss8"), + sasa=SASADiscretizingTokenizer(), + function=InterProQuantizedTokenizer(), + residue_annotations=ResidueAnnotationsTokenizer(), + ) + else: + raise ValueError(f"Unknown model: {model}") + + +def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]: + if isinstance(tokenizer, EsmSequenceTokenizer): + return [ + 
tokenizer.mask_token_id, # type: ignore + tokenizer.pad_token_id, # type: ignore + tokenizer.cls_token_id, # type: ignore + tokenizer.eos_token_id, # type: ignore + ] + else: + return [ + tokenizer.mask_token_id, + tokenizer.pad_token_id, + tokenizer.bos_token_id, + tokenizer.eos_token_id, + ] diff --git a/esm/tokenization/function_tokenizer.py b/esm/tokenization/function_tokenizer.py new file mode 100644 index 0000000..a56e728 --- /dev/null +++ b/esm/tokenization/function_tokenizer.py @@ -0,0 +1,404 @@ +"""Tokenizes annotations of protein function.""" + +import re +import string +from functools import cache, cached_property, partial +from typing import Collection + +import numpy as np +import pandas as pd +import scipy.sparse as sp +import torch +import torch.nn.functional as F + +from esm.tokenization.tokenizer_base import EsmTokenizerBase +from esm.utils.constants import esm3 as C +from esm.utils.function import interpro, lsh, tfidf +from esm.utils.misc import stack_variable_length_tensors +from esm.utils.types import FunctionAnnotation + + +class InterProQuantizedTokenizer(EsmTokenizerBase): + """Tokenizer for functional annotations. + + This tokenizer converts InterPro and/or function keywords into a multi-token + representation by hashing TF-IDF vector representations of the text associated with + the function and then applying a locality sensitive hash (LSH). + """ + + def __init__( + self, + depth: int = 8, + lsh_bits_per_token: int = 8, + lsh_path: str | None = None, + keyword_vocabulary_path: str | None = None, + keyword_idf_path: str | None = None, + interpro_entry_path: str | None = None, + interpro2keywords_path: str | None = None, + ): + """Constructs function tokenizer. + + Args: + depth: number of tokens emitted in each position. + lsh_bits_per_token: Number of LSH bits per token. Determines the vocabulary + size. + lsh_path: path to locality sensitive hash (LSH) hyperplanes. 
+ keyword_vocabulary_path: path to CSV containing function keyword vocabulary. + keyword_idf_path: path to IDF values for each keyword. + interpro_entry_path: path to list of InterPro entries in CSV format. + interpro2keywords_path: path to CSV mapping InterPro IDs to function keywords. + """ + self.depth = depth + default = lambda x, d: x if x is not None else C.data_root() / d + + self.keyword_vocabulary_path = default( + keyword_vocabulary_path, C.KEYWORDS_VOCABULARY + ) + self.keyword_idf_path = default(keyword_idf_path, C.KEYWORDS_IDF) + + self._interpro2keywords_path = default( + interpro2keywords_path, C.INTERPRO2KEYWORDS + ) + self.interpro_ = interpro.InterPro( + entries_path=default(interpro_entry_path, C.INTERPRO_ENTRY) + ) + + self.lsh_vocab_size = 1 << lsh_bits_per_token + self._lsh = lsh.LSHTokenized( + lsh_bits_per_token, + len(self.keyword_vocabulary), + self.depth, + default(lsh_path, C.LSH_TABLE_PATHS["8bit"]), + ) + + # This is the offset into the vocabulary where LSH tokens start. 
+ self._lsh_token_vocab_offset = len(self.special_tokens) + 1 # +1 for + + @cached_property + def interpro2keywords(self) -> dict[str, list[str]]: + """Mapping from InterPro ID to function keywords.""" + df = pd.read_csv(self._interpro2keywords_path) + assert "interpro_id" in df.columns and "keywords" in df.columns, df.columns + return dict(zip(df.interpro_id, df.keywords.str.split(","))) + + @cached_property + def interpro_labels(self) -> list[str]: + """The set of supported InterPro labels.""" + return sorted(self.interpro2keywords.keys()) + + @cached_property + def interpro_to_index(self) -> dict[str, int]: + """Mapping from InterPro id to index.""" + return {id: i for i, id in enumerate(self.interpro_labels)} + + @property + def keyword_vocabulary(self) -> list[str]: + """Set of supported keywords.""" + return self._tfidf.vocabulary + + @property + def keyword_to_index(self) -> dict[str, int]: + """Mapping from keywords to index.""" + return self._tfidf.vocab_to_index + + @cached_property + def _tfidf(self) -> tfidf.TFIDFModel: + """Creates TF-IDF model for encoding function keywords.""" + return tfidf.TFIDFModel( + vocabulary_path=self.keyword_vocabulary_path, + idf_path=self.keyword_idf_path, + ) + + @cached_property + def special_tokens(self) -> list[str]: + """List of special tokens which come before cluster tokens in vocab.""" + return ["", "", ""] + + @cached_property + def vocab(self) -> list[str]: + """Vocabulary of function tokens.""" + lsh_tokens = [f"" for i in range(self.lsh_vocab_size)] + return self.special_tokens + [""] + lsh_tokens + + @cached_property + def vocab_to_index(self) -> dict[str, int]: + return {token: token_id for token_id, token in enumerate(self.vocab)} + + def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor: + """Determines where in the sequence are special tokens.""" + where = encoded < len(self.special_tokens) + assert torch.all(torch.all(where, dim=1) | torch.all(~where, dim=1)) + return where[:, 0] + + 
def tokenize( + self, + annotations: list[FunctionAnnotation], + seqlen: int, + p_keyword_dropout: float = 0.0, + ) -> list[str]: + """Encodes range-annotations of protein function as tokens. + + Args: + annotations: Annotated function ranges, either as InterPro ids or keywords. + seqlen: length of sequence. + p_keyword_dropout: Optional probability of dropping out keywords from the + input annotations. + Returns: + Tokenized representation of function annotations as a list of string tokens + of size seqlen. + """ + assert seqlen >= 0 + + if not annotations: + return [""] * seqlen + + # Expand the range annotations into positional annotation sets. + positional_labels: list[set[str]] = [set() for _ in range(seqlen)] + for annotation in annotations: + assert 1 <= annotation.start <= annotation.end <= seqlen, ( + f"Invalid annotation range [{annotation.start}, {annotation.end}] for " + f"sequence length {seqlen}." + ) + for i in range(annotation.start - 1, annotation.end): + positional_labels[i].add(annotation.label) + + if p_keyword_dropout > 0: + keyword_mask = ( + np.random.random(len(self._tfidf.vocabulary)) < p_keyword_dropout + ) + else: + keyword_mask = None + + # Annotations tend to be repetitive over the length of the sequence - cache their + # hashes to speed up tokenization. + hash_fn = cache(partial(self._function_text_hash, keyword_mask=keyword_mask)) + + tokens: list[str] = [] + for labels in positional_labels: + if not labels: + token = "" + else: + lsh_hash = hash_fn(frozenset(labels)) + if lsh_hash is not None: + assert len(lsh_hash) == self.depth + token = "" + else: + token = "" + + tokens.append(token) + + return tokens + + def _function_text_hash( + self, + labels: Collection[str], + keyword_mask: np.ndarray | None = None, + ) -> np.ndarray | None: + """Applies a locality sensitive hash (LSH) to function text. + + Args: + labels: InterPro ids and/or keywords. 
+ keyword_mask: optional boolean array shaped (keyword_vocab_size,) indicating + which keywords to drop before hashing. + Returns: + LSH shaped (depth,) or None if there is no text or keywords to hash. + """ + # Split labels into either InterPro ids or keywords. + interpro_ids = [] + keywords = [] + for label in labels: + match = re.match(r"IPR\d+", label) + if match and match.group() in self.interpro_to_index: + interpro_ids.append(match.group()) + elif label in self._tfidf.vocab_to_index: + keywords.append(label) + else: + raise ValueError(f"Unsupported: {label}") + + # Perform an element-wise maximum over TF-IDF vectors from distinct tags to + # avoid tags getting "washed out" by e.g. 4 very similar tags. Keywords are + # incorporated as another TF-IDF vector. + vec: sp.csr_matrix = self._tfidf.encode(keywords) + for interpro_id in interpro_ids: + interpro_keywords = self.interpro2keywords.get(interpro_id, []) + vec_ = self._tfidf.encode(interpro_keywords) + vec = vec.maximum(vec_) + + if keyword_mask is not None: + vec.data *= 1 - np.take(keyword_mask, vec.indices) + + if vec.sum() == 0: + return None + + return self._lsh(vec)[0, :] + + def encode( + self, tokens: list[str], add_special_tokens: bool = True + ) -> torch.Tensor: + """Encodes string tokens as token-id tensor. + + Args: + tokens: list of individual tokens. e.g. ["", ""] + add_special_tokens: whether to add a single pad token at the start and end + of the sequence to act as the beginning- and end-of-sequence tokens. + Returns: + [length, depth] function tokens. Length will be +2 of input tokens + length when add_special_tokens is True. 
+ """ + token_ids = torch.zeros(size=(len(tokens), self.depth), dtype=torch.int64) + for i, token in enumerate(tokens): + token_ids[i, :] = torch.tensor(self._token2ids(token)) + if add_special_tokens: + token_ids = F.pad( + token_ids, (0, 0, 1, 1), value=self.vocab_to_index[""] + ) + return token_ids + + def lookup_annotation_name(self, annotation: FunctionAnnotation) -> str | None: + return self.interpro_.lookup_name(annotation.label) + + def format_annotation(self, annotation: FunctionAnnotation) -> str: + annotation_name = self.lookup_annotation_name(annotation) + if annotation_name is not None: + return f"{annotation_name} ({annotation.label})" + else: + return annotation.label + + def _token2ids(self, token: str) -> list[int]: + """Converts token into token_id set of length depth.""" + if re.match(r"", token): + lsh_ids = [int(lsh_id) for lsh_id in re.findall(r"\d+", token)] + assert ( + len(lsh_ids) == self.depth + ), f"Expected token to have {self.depth} ids found {lsh_ids}" + return [self._lsh_token_vocab_offset + lsh_id for lsh_id in lsh_ids] + elif token == "" or token in self.special_tokens: + return [self.vocab_to_index[token]] * self.depth + else: + raise ValueError(f"Unknown token: {token}") + + def batch_encode( + self, + token_batch: list[list[str]], + add_special_tokens: bool = True, + ) -> torch.Tensor: + """Encodes batch of function tokens. + + Args: + token_batch: batch of function tokens. + add_special_tokens: whether to add special tokens. + Returns: + [batch_size, max_length, depth] batch of encoded tokens. 
+ """ + encoded = [ + self.encode(tokens, add_special_tokens=add_special_tokens) + for tokens in token_batch + ] + return stack_variable_length_tensors( + encoded, + constant_value=self.vocab_to_index[""], + ) + + def decode(self, encoded: torch.Tensor): + raise NotImplementedError( + "Function token decoding should be handled with " + "util.decoding.decode_function_annotations" + ) + + @property + def mask_token(self) -> str: + return "" + + @property + def mask_token_id(self) -> int: + return self.vocab_to_index[self.mask_token] + + @property + def bos_token(self) -> str: + return "" + + @property + def bos_token_id(self) -> int: + return self.vocab_to_index[self.bos_token] + + @property + def eos_token(self) -> str: + return "" + + @property + def eos_token_id(self) -> int: + return self.vocab_to_index[self.eos_token] + + @property + def pad_token(self) -> str: + return "" + + @property + def pad_token_id(self) -> int: + return self.vocab_to_index[self.pad_token] + + +def _texts_to_keywords(texts: list[str]) -> list[str]: + """Breaks InterPro/GO free-text description set into bag-of-n-grams for n={1,2}. + + Args: + texts: collection of text descriptions, i.e. InterPro/GO names. 
+ Returns: + Collection of terms/n-grams + """ + keywords = [] + for text in texts: + keywords.extend(_keywords_from_text(text)) + return keywords + + +def _keywords_from_text(text: str) -> list[str]: + """Splits text into unigrams and bigrams.""" + elements = text.split(", ") + + terms = [] + for element in elements: + element = _sanitize(element) + words = element.split() + + # Add 1-mers + terms.extend(words) + + # Add 2-mers + for i in range(len(words) - 1): + bigram = words[i] + " " + words[i + 1] + terms.append(bigram) + + return [term for term in terms if len(term) > 1 and term not in _EXCLUDED_TERMS] + + +def _sanitize(text: str) -> str: + text = text.replace("-", " ") + text = text.translate(str.maketrans("", "", string.punctuation)) + text = text.lower() + return text + + +# These terms are omitted from textual representations since they are pervasive and +# unspecific to particular protein function. +_EXCLUDED_TERMS = { + "binding domain", + "biological_process", + "biological process", + "biologicalprocess", + "c", + "cellular_component", + "cellular component", + "cellularcomponent", + "cellular_process", + "cellularprocess", + "cellular process", + "cellularprocess", + "like domain", + "molecular function", + "molecular_function", + "molecularfunction", + "n", +} diff --git a/esm/tokenization/residue_tokenizer.py b/esm/tokenization/residue_tokenizer.py new file mode 100644 index 0000000..a48d048 --- /dev/null +++ b/esm/tokenization/residue_tokenizer.py @@ -0,0 +1,224 @@ +from functools import cached_property +from pathlib import Path +from typing import Any + +import pandas as pd +import torch +import torch.nn.functional as F + +from esm.tokenization.tokenizer_base import EsmTokenizerBase +from esm.utils.constants import esm3 as C + +Sample = dict[str, Any] + + +class ResidueAnnotationsTokenizer(EsmTokenizerBase): + def __init__( + self, + csv_path: str | None = None, + max_annotations: int = 16, + ): + if csv_path is None: + csv_path = 
str(C.data_root() / C.RESID_CSV) + self.csv_path = csv_path + self.max_annotations = max_annotations + + @cached_property + def _description2label(self) -> dict[str, str]: + with Path(self.csv_path).open() as f: # type: ignore + df = pd.read_csv(f) + return dict(zip(df.label, df.label_clean)) + + @cached_property + def _labels(self) -> list[str]: + with Path(self.csv_path).open() as f: # type: ignore + df = pd.read_csv(f) + labels = ( + df.groupby("label_clean")["count"] + .sum() + .sort_values(ascending=False, kind="stable") # type: ignore + .index.tolist() + ) + assert isinstance(labels, list) + return labels # type: ignore + + def _description2id(self, description: str) -> int | None: + label = self._description2label.get(description) + return self._label2id.get(label) # type: ignore + + @cached_property + def _label2id(self) -> dict[str, int]: + offset = len(self.special_tokens) + 1 # +1 for "" + return {label: offset + i for i, label in enumerate(self._labels)} + + @cached_property + def special_tokens(self) -> list[str]: + """List of special tokens which come before cluster tokens in vocab.""" + return ["", "", ""] + + @cached_property + def vocab(self): + annotation_tokens = [f"" for _, id in self._label2id.items()] + return self.special_tokens + [""] + annotation_tokens + + @cached_property + def vocab_to_index(self) -> dict[str, int]: + return {token: token_id for token_id, token in enumerate(self.vocab)} + + @cached_property + def vocabulary(self) -> list[str]: + """Full vocabulary.""" + return [*self.special_tokens, "", *self._labels] + + def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor: + """Determines which positions in the sequence are special tokens.""" + return encoded[:, 0] < len(self.special_tokens) + + def tokenize( + self, sample: Sample | None, sequence: str, fail_on_mismatch: bool = False + ) -> list[str]: + """ + # interpro_site_starts + # interpro_site_ends # should always == interpro_site_starts. but I haven't checked overall.
+ # interpro_site_residues # the residue identity of the specific residue that is annotated. Good for a sanity check that parsing occurred correctly. + # interpro_site_descriptions + # ASSERT (i.e. drop if bad) + # interpro_site_residues matches the residue at that position + # all these lists ^ above are the same length + """ + seqlen = len(sequence) + assert seqlen >= 0 + # None means the sequence is *not annotated*, so every position gets the same placeholder token + if sample is None: + return [""] * seqlen + + if any( + sample.get(field) is None + for field in [ + "interpro_site_descriptions", + "interpro_site_starts", + "interpro_site_ends", + "interpro_site_residues", + ] + ): + return [""] * seqlen + + num_annotations = len(sample["interpro_site_descriptions"]) + if any( + len(sample[field]) != num_annotations + for field in [ + "interpro_site_starts", + "interpro_site_ends", + "interpro_site_residues", + ] + ): + # mismatched length. + return [""] * seqlen + + positional_ids = [set() for _ in range(seqlen)] + for description, start, end, residues in zip( + sample["interpro_site_descriptions"], + sample["interpro_site_starts"], + sample["interpro_site_ends"], + sample["interpro_site_residues"], + ): + try: + start = int(start) + end = int(end) + except (TypeError, ValueError): + continue + + # Start / End are 1-indexed [inclusive, inclusive]. + if start <= 0 or end > seqlen or start > end: + print(f"invalid start/end: ({start}, {end}), len: {seqlen}") + continue + + if len(residues) != (end - start) + 1: + print(f"bad reference residue: {residues}") + continue + + token_id = self._description2id(description) + if token_id is None: + token_id = self.vocab_to_index[""] + + for i, residue in zip(range(start - 1, end), residues): + # If there are any mismatching residues, skip the entire sample.
+ if sequence[i] != residue: + if fail_on_mismatch: + raise ValueError( + f"Residue mismatch at position {i + 1} (1-indexed): {sequence[i]} != {residue}" + ) + return [""] * seqlen + + positional_ids[i].add(token_id) + + tokens = [] + for token_ids in positional_ids: + if token_ids: + token = "" + else: + token = "" + tokens.append(token) + return tokens + + def _token2ids(self, token: str) -> list[int]: + if token.startswith(""): + return [int(token_id) for token_id in token[4:-1].split(",")] + else: + token_id = self.vocab_to_index[token] + return [token_id] + + def encode( + self, tokens: list[str], add_special_tokens: bool = True + ) -> torch.Tensor: + token_ids = torch.full( + size=(len(tokens), self.max_annotations), + dtype=torch.int64, + fill_value=self.vocab_to_index[""], + ) + for i, token in enumerate(tokens): + ids = self._token2ids(token)[: self.max_annotations] + token_ids[i, : len(ids)] = torch.tensor(ids) + + if add_special_tokens: + token_ids = F.pad( + token_ids, (0, 0, 1, 1), value=self.vocab_to_index[""] + ) + return token_ids + + def decode(self, encoded: torch.Tensor) -> list[str]: + raise NotImplementedError( + "Residue annotation decoding should be handled with esm.utils.decoding.decode_residue_annotations" + ) + + @property + def mask_token(self) -> str: + return "" + + @property + def mask_token_id(self) -> int: + return self.vocab_to_index[self.mask_token] + + @property + def bos_token(self) -> str: + return "" + + @property + def bos_token_id(self) -> int: + return self.vocab_to_index[self.bos_token] + + @property + def eos_token(self) -> str: + return "" + + @property + def eos_token_id(self) -> int: + return self.vocab_to_index[self.eos_token] + + @property + def pad_token(self) -> str: + return "" + + @property + def pad_token_id(self) -> int: + return self.vocab_to_index[self.pad_token] diff --git a/esm/tokenization/sasa_tokenizer.py b/esm/tokenization/sasa_tokenizer.py new file mode 100644 index 0000000..4d7221b --- /dev/null +++
b/esm/tokenization/sasa_tokenizer.py @@ -0,0 +1,129 @@ +from functools import cached_property + +import torch + +from esm.tokenization.tokenizer_base import EsmTokenizerBase +from esm.utils.constants import esm3 as C + + +class SASADiscretizingTokenizer(EsmTokenizerBase): + """Tokenizer for Solvent Accessible Surface Area (SASA).""" + + def __init__(self, boundaries: list[float] = C.SASA_DISCRETIZATION_BOUNDARIES): + self._boundaries = sorted(boundaries) + + @cached_property + def special_tokens(self) -> list[str]: + return ["", "", ""] + + @cached_property + def vocab(self) -> list[str]: + """Discrete token vocabulary. + + Returns: + token vocabulary with ranges represented as "". + """ + boundary_strs = ["0"] + [str(b) for b in self._boundaries] + ["inf"] + range_tokens = [ + f"<{low}-{high}>" + for low, high in zip(boundary_strs[:-1], boundary_strs[1:]) + ] + return self.special_tokens + range_tokens + + @cached_property + def midpoints(self) -> list[float]: + """Midpoints of the SASA token ranges.""" + boundaries = [0] + self._boundaries + [self._boundaries[-1] * 2] + midpoint_tokens = [ + (float(high) + float(low)) / 2 + for low, high in zip(boundaries[:-1], boundaries[1:]) + ] + midpoint_tokens = [float("nan"), float("nan"), float("nan")] + midpoint_tokens + return midpoint_tokens + + @cached_property + def vocab_to_index(self) -> dict[str, int]: + """Constructs token -> token id mapping.""" + return {word: i for i, word in enumerate(self.vocab)} + + def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor: + """Determines which positions are special tokens. + + Args: + tokens: [length] + Returns: + [length] tensor, true where special tokens are located in the input. + """ + return tokens < len(self.special_tokens) + + def encode( + self, values: list[float | str], add_special_tokens: bool = True + ) -> torch.Tensor: + """Encodes SASA values as discrete tokens. + + Args: + values: list of either SASA values or individual tokens. 
For example + [1.2, "", 10.3, , 0.] + Returns: + Token ids as tensor. Adds BOS and EOS special tokens. + """ + ids = [] + if add_special_tokens: + ids.append(self.vocab_to_index[""]) # BOS + for value in values: + if isinstance(value, (float, int)): + bucket = torch.bucketize(value, torch.tensor(self._boundaries)) + token_id = len(self.special_tokens) + bucket + elif isinstance(value, str): + token_id = self.vocab_to_index[value] + else: + raise TypeError(value) + ids.append(token_id) + if add_special_tokens: + ids.append(self.vocab_to_index[""]) # EOS + + return torch.tensor(ids, dtype=torch.int64) + + def decode_float(self, encoded: torch.Tensor) -> list[float]: + """Decodes SASA token ids into float values.""" + return [self.midpoints[token_id] for token_id in encoded] + + def decode(self, encoded: torch.Tensor) -> str: + """Decodes SASA token ids.""" + return ",".join(self.vocab[i] for i in encoded) + + def decode_list(self, encoded: torch.Tensor) -> list[str]: + """Decodes SASA token ids.""" + return [self.vocab[i] for i in encoded] + + @property + def mask_token(self) -> str: + return "" + + @property + def mask_token_id(self) -> int: + return self.vocab_to_index[self.mask_token] + + @property + def bos_token(self) -> str: + return "" + + @property + def bos_token_id(self) -> int: + return self.vocab_to_index[self.bos_token] + + @property + def eos_token(self) -> str: + return "" + + @property + def eos_token_id(self) -> int: + return self.vocab_to_index[self.eos_token] + + @property + def pad_token(self) -> str: + return "" + + @property + def pad_token_id(self) -> int: + return self.vocab_to_index[self.pad_token] diff --git a/esm/tokenization/sequence_tokenizer.py b/esm/tokenization/sequence_tokenizer.py new file mode 100644 index 0000000..0926aab --- /dev/null +++ b/esm/tokenization/sequence_tokenizer.py @@ -0,0 +1,68 @@ +from tokenizers import Tokenizer +from tokenizers.models import BPE +from tokenizers.processors import TemplateProcessing +from 
transformers import PreTrainedTokenizerFast + +from esm.tokenization.tokenizer_base import EsmTokenizerBase +from esm.utils.constants import esm3 as C + + +class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase): + """ + Constructs an ESM tokenizer. + """ + + model_input_names = ["sequence_tokens", "attention_mask"] + + def __init__( + self, + unk_token="", + cls_token="", + pad_token="", + mask_token="", + eos_token="", + chainbreak_token="|", + **kwargs, + ): + all_tokens = C.SEQUENCE_VOCAB + token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)} + + # a character-level tokenizer is the same as BPE with no token merges + bpe = BPE(token_to_id, merges=[], unk_token=unk_token) + tokenizer = Tokenizer(bpe) + special_tokens = [cls_token, pad_token, mask_token, eos_token, chainbreak_token] + additional_special_tokens = [chainbreak_token] + + tokenizer.add_special_tokens( + special_tokens, + ) + + # This is where we configure the automatic addition of special tokens when we call + # tokenizer(text, add_special_tokens=True). Note that you can also configure how two + # sequences are merged if you want. + tokenizer.post_processor = TemplateProcessing( # type: ignore + single=" $A ", + special_tokens=[ + ("", tokenizer.token_to_id("")), + ("", tokenizer.token_to_id("")), + ], + ) + super().__init__( + tokenizer_object=tokenizer, + unk_token=unk_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + eos_token=eos_token, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here. 
+ @property + def bos_token(self): + return self.cls_token + + @property + def bos_token_id(self): + return self.cls_token_id diff --git a/esm/tokenization/ss_tokenizer.py b/esm/tokenization/ss_tokenizer.py new file mode 100644 index 0000000..c540103 --- /dev/null +++ b/esm/tokenization/ss_tokenizer.py @@ -0,0 +1,109 @@ +from functools import cached_property +from typing import Sequence + +import torch + +from esm.tokenization.tokenizer_base import EsmTokenizerBase +from esm.utils.constants import esm3 as C + + +class SecondaryStructureTokenizer(EsmTokenizerBase): + """Tokenizer for secondary structure strings.""" + + def __init__(self, kind: str = "ss8"): + assert kind in ("ss8", "ss3") + self.kind = kind + + @property + def special_tokens(self) -> list[str]: + return ["", "", ""] + + @cached_property + def vocab(self): + """Tokenizer vocabulary list.""" + match self.kind: + case "ss8": + nonspecial_tokens = list(C.SSE_8CLASS_VOCAB) # "GHITEBSC" + case "ss3": + nonspecial_tokens = list(C.SSE_3CLASS_VOCAB) # "HEC" + case _: + raise ValueError(self.kind) + + # The non-special token ids match amino acid token ids when possible. + return [*self.special_tokens, *nonspecial_tokens] + + @cached_property + def vocab_to_index(self) -> dict[str, int]: + """Constructs token -> token id mapping.""" + return {word: i for i, word in enumerate(self.vocab)} + + def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor: + """Determines which positions are special tokens. + + Args: + tokens: [length] + Returns: + [length] tensor, true where special tokens are located in the input. + """ + return tokens < len(self.special_tokens) + + def encode( + self, sequence: str | Sequence[str], add_special_tokens: bool = True + ) -> torch.Tensor: + """Encodes a secondary structure string. + + Args: + sequence: secondary structure string, e.g. "GHHIT", or a list of tokens. + Returns: + [sequence_length] token ids representing the secondary structure. Adds cls/eos token ids when add_special_tokens is True.
+ """ + ids = [] + if add_special_tokens: + ids.append(self.vocab_to_index[""]) # cls + for char in sequence: + ids.append(self.vocab_to_index[char]) + if add_special_tokens: + ids.append(self.vocab_to_index[""]) # eos + return torch.tensor(ids, dtype=torch.int64) + + def decode(self, encoded: torch.Tensor) -> str: + """Decodes token ids into secondary structure string. + + Args: + encoded: [length] token id array. + Returns: + Decoded secondary structure string. + """ + return "".join(self.vocab[i] for i in encoded) + + @property + def mask_token(self) -> str: + return "" + + @property + def mask_token_id(self) -> int: + return self.vocab_to_index[self.mask_token] + + @property + def bos_token(self) -> str: + return "" + + @property + def bos_token_id(self) -> int: + return self.vocab_to_index[self.bos_token] + + @property + def eos_token(self) -> str: + return "" + + @property + def eos_token_id(self) -> int: + return self.vocab_to_index[self.eos_token] + + @property + def pad_token(self) -> str: + return "" + + @property + def pad_token_id(self) -> int: + return self.vocab_to_index[self.pad_token] diff --git a/esm/tokenization/structure_tokenizer.py b/esm/tokenization/structure_tokenizer.py new file mode 100644 index 0000000..76b91b2 --- /dev/null +++ b/esm/tokenization/structure_tokenizer.py @@ -0,0 +1,63 @@ +from esm.tokenization.tokenizer_base import EsmTokenizerBase + + +class StructureTokenizer(EsmTokenizerBase): + """A convenience class for accessing special token ids of + the StructureTokenEncoder and StructureTokenDecoder.""" + + def __init__(self, vq_vae_special_tokens: dict[str, int]): + self.vq_vae_special_tokens = vq_vae_special_tokens + + def mask_token(self) -> str: + raise NotImplementedError( + "Structure tokens are defined on 3D coordinates, not strings."
+ ) + + @property + def mask_token_id(self) -> int: + return self.vq_vae_special_tokens["MASK"] + + def bos_token(self) -> str: + raise NotImplementedError( + "Structure tokens are defined on 3D coordinates, not strings." + ) + + @property + def bos_token_id(self) -> int: + return self.vq_vae_special_tokens["BOS"] + + def eos_token(self) -> str: + raise NotImplementedError( + "Structure tokens are defined on 3D coordinates, not strings." + ) + + @property + def eos_token_id(self) -> int: + return self.vq_vae_special_tokens["EOS"] + + def pad_token(self) -> str: + raise NotImplementedError( + "Structure tokens are defined on 3D coordinates, not strings." + ) + + @property + def pad_token_id(self) -> int: + return self.vq_vae_special_tokens["PAD"] + + @property + def chainbreak_token_id(self) -> int: + return self.vq_vae_special_tokens["CHAINBREAK"] + + def encode(self, *args, **kwargs): + raise NotImplementedError( + "The StructureTokenizer class is provided as a convenience for " + "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n" + "Please use them instead." + ) + + def decode(self, *args, **kwargs): + raise NotImplementedError( + "The StructureTokenizer class is provided as a convenience for " + "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n" + "Please use them instead." + ) diff --git a/esm/tokenization/tokenizer_base.py b/esm/tokenization/tokenizer_base.py new file mode 100644 index 0000000..7cbce34 --- /dev/null +++ b/esm/tokenization/tokenizer_base.py @@ -0,0 +1,42 @@ +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class EsmTokenizerBase(Protocol): + def encode(self, *args, **kwargs): + ... + + def decode(self, *args, **kwargs): + ... + + @property + def mask_token(self) -> str: + ... + + @property + def mask_token_id(self) -> int: + ... + + @property + def bos_token(self) -> str: + ... + + @property + def bos_token_id(self) -> int: + ... 
+ + @property + def eos_token(self) -> str: + ... + + @property + def eos_token_id(self) -> int: + ... + + @property + def pad_token(self) -> str: + ... + + @property + def pad_token_id(self) -> int: + ... diff --git a/esm/utils/constants/esm3.py b/esm/utils/constants/esm3.py new file mode 100644 index 0000000..10411c0 --- /dev/null +++ b/esm/utils/constants/esm3.py @@ -0,0 +1,127 @@ +from functools import cache +from pathlib import Path + +from huggingface_hub import snapshot_download + +SEQUENCE_BOS_TOKEN = 0 +SEQUENCE_PAD_TOKEN = 1 +SEQUENCE_EOS_TOKEN = 2 +SEQUENCE_CHAINBREAK_TOKEN = 31 +SEQUENCE_MASK_TOKEN = 32 + +VQVAE_CODEBOOK_SIZE = 4096 +VQVAE_SPECIAL_TOKENS = { + "MASK": VQVAE_CODEBOOK_SIZE, + "EOS": VQVAE_CODEBOOK_SIZE + 1, + "BOS": VQVAE_CODEBOOK_SIZE + 2, + "PAD": VQVAE_CODEBOOK_SIZE + 3, + "CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4, +} +VQVAE_DIRECTION_LOSS_BINS = 16 +VQVAE_PAE_BINS = 64 +VQVAE_MAX_PAE_BIN = 31.0 +VQVAE_PLDDT_BINS = 50 + +STRUCTURE_MASK_TOKEN = VQVAE_SPECIAL_TOKENS["MASK"] +STRUCTURE_BOS_TOKEN = VQVAE_SPECIAL_TOKENS["BOS"] +STRUCTURE_EOS_TOKEN = VQVAE_SPECIAL_TOKENS["EOS"] +STRUCTURE_PAD_TOKEN = VQVAE_SPECIAL_TOKENS["PAD"] +STRUCTURE_CHAINBREAK_TOKEN = VQVAE_SPECIAL_TOKENS["CHAINBREAK"] +STRUCTURE_UNDEFINED_TOKEN = 955 + +SASA_UNK_TOKEN = 2 +SASA_PAD_TOKEN = 0 + +SS8_UNK_TOKEN = 2 +SS8_PAD_TOKEN = 0 + +INTERPRO_PAD_TOKEN = 0 + +RESIDUE_PAD_TOKEN = 0 + +CHAIN_BREAK_STR = "|" + +SEQUENCE_BOS_STR = "" +SEQUENCE_EOS_STR = "" + +MASK_STR_SHORT = "_" +SEQUENCE_MASK_STR = "" +SASA_MASK_STR = "" +SS8_MASK_STR = "" + +# fmt: off +SEQUENCE_VOCAB = [ + "", "", "", "", + "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K", + "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z", + "O", ".", "-", "|", + "", +] +# fmt: on + +SSE_8CLASS_VOCAB = "GHITEBSC" +SSE_3CLASS_VOCAB = "HEC" +SSE_8CLASS_TO_3CLASS_MAP = { + "G": "H", + "H": "H", + "I": "H", + "T": "C", + "E": "E", + "B": "E", + "S": "C", + "C": "C", +} + +SASA_DISCRETIZATION_BOUNDARIES = 
[ + 0.8, + 4.0, + 9.6, + 16.4, + 24.5, + 32.9, + 42.0, + 51.5, + 61.2, + 70.9, + 81.6, + 93.3, + 107.2, + 125.4, + 151.4, +] + +MAX_RESIDUE_ANNOTATIONS = 16 + + +TFIDF_VECTOR_SIZE = 58641 + + +@staticmethod +@cache +def data_root(): + # Try a few default directories + for path in [ + "esm/data", + "esm/data", + ]: + if (p := Path(path)).exists(): + return p.parent + # Fall back to downloading from Hugging Face if no local data directory exists + path = Path(snapshot_download(repo_id="EvolutionaryScale/esm3-sm-open-v1")) + return path + + +INTERPRO_ENTRY = "data/entry_list_safety_29026.list" +INTERPRO_HIERARCHY = "data/ParentChildTreeFile.txt" +INTERPRO2GO = "data/ParentChildTreeFile.txt" +INTERPRO_2ID = "data/tag_dict_4_safety_filtered.json" + +LSH_TABLE_PATHS = { + "8bit": "data/hyperplanes_8bit_58641.npz", +} + +KEYWORDS_VOCABULARY = "data/keyword_vocabulary_safety_filtered_58641.txt" +KEYWORDS_IDF = "data/keyword_idf_safety_filtered_58641.npy" + +RESID_CSV = "data/uniref90_and_mgnify90_residue_annotations_gt_1k_proteins.csv" +INTERPRO2KEYWORDS = "data/interpro_29026_to_keywords_58641.csv" diff --git a/esm/utils/constants/models.py b/esm/utils/constants/models.py new file mode 100644 index 0000000..c72b922 --- /dev/null +++ b/esm/utils/constants/models.py @@ -0,0 +1,5 @@ +# Model names +ESM3_OPEN_SMALL = "esm3_sm_open_v1" +ESM3_STRUCTURE_ENCODER_V0 = "esm3_structure_encoder_v0" +ESM3_STRUCTURE_DECODER_V0 = "esm3_structure_decoder_v0" +ESM3_FUNCTION_DECODER_V0 = "esm3_function_decoder_v0" diff --git a/esm/utils/constants/physics.py b/esm/utils/constants/physics.py new file mode 100644 index 0000000..4130016 --- /dev/null +++ b/esm/utils/constants/physics.py @@ -0,0 +1,5 @@ +BB_COORDINATES = [ + [0.5256, 1.3612, 0.0000], + [0.0000, 0.0000, 0.0000], + [-1.5251, 0.0000, 0.0000], +] diff --git a/esm/utils/decoding.py b/esm/utils/decoding.py new file mode 100644 index 0000000..c65f057 --- /dev/null +++ b/esm/utils/decoding.py @@ -0,0 +1,225 @@ +import warnings + +import attr +import torch +
+from esm.models.function_decoder import FunctionTokenDecoder +from esm.models.vqvae import StructureTokenDecoder +from esm.sdk.api import ESMProtein, ESMProteinTensor +from esm.tokenization import TokenizerCollectionProtocol +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) +from esm.tokenization.residue_tokenizer import ( + ResidueAnnotationsTokenizer, +) +from esm.tokenization.sasa_tokenizer import ( + SASADiscretizingTokenizer, +) +from esm.tokenization.sequence_tokenizer import ( + EsmSequenceTokenizer, +) +from esm.tokenization.ss_tokenizer import ( + SecondaryStructureTokenizer, +) +from esm.tokenization.structure_tokenizer import ( + StructureTokenizer, +) +from esm.tokenization.tokenizer_base import EsmTokenizerBase +from esm.utils.constants import esm3 as C +from esm.utils.function.encode_decode import ( + decode_function_tokens, + decode_residue_annotation_tokens, +) +from esm.utils.structure.protein_chain import ProteinChain +from esm.utils.types import FunctionAnnotation + + +def decode_protein_tensor( + input: ESMProteinTensor, + tokenizers: TokenizerCollectionProtocol, + structure_token_decoder: StructureTokenDecoder, + function_token_decoder: FunctionTokenDecoder, +) -> ESMProtein: + input = attr.evolve(input) # Make a copy + + sequence = None + secondary_structure = None + sasa = None + function_annotations = [] + + coordinates = None + + # If all pad tokens, set to None + for track in attr.fields(ESMProteinTensor): + tokens: torch.Tensor | None = getattr(input, track.name) + if track.name == "coordinates": + continue + if tokens is not None: + tokens = tokens[1:-1] # Remove BOS and EOS tokens + tokens = tokens.flatten() # For multi-track tensors + track_tokenizer = getattr(tokenizers, track.name) + if torch.all(tokens == track_tokenizer.pad_token_id): + setattr(input, track.name, None) + + if input.sequence is not None: + sequence = decode_sequence(input.sequence, tokenizers.sequence) + + plddt, ptm = None, None + 
if input.structure is not None: + # Note: We give priority to the structure tokens over the coordinates when decoding + coordinates, plddt, ptm = decode_structure( + structure_tokens=input.structure, + structure_decoder=structure_token_decoder, + structure_tokenizer=tokenizers.structure, + sequence=sequence, + ) + elif input.coordinates is not None: + coordinates = input.coordinates[1:-1, ...] + + if input.secondary_structure is not None: + secondary_structure = decode_secondary_structure( + input.secondary_structure, tokenizers.secondary_structure + ) + if input.sasa is not None: + sasa = decode_sasa(input.sasa, tokenizers.sasa) + if input.function is not None: + function_track_annotations = decode_function_annotations( + input.function, + function_token_decoder=function_token_decoder, + function_tokenizer=tokenizers.function, + ) + function_annotations.extend(function_track_annotations) + if input.residue_annotations is not None: + residue_annotations = decode_residue_annotations( + input.residue_annotations, tokenizers.residue_annotations + ) + function_annotations.extend(residue_annotations) + + return ESMProtein( + sequence=sequence, + secondary_structure=secondary_structure, + sasa=sasa, # type: ignore + function_annotations=function_annotations if function_annotations else None, + coordinates=coordinates, + plddt=plddt, + ptm=ptm, + ) + + +def _bos_eos_warn(msg: str, tensor: torch.Tensor, tok: EsmTokenizerBase): + if tensor[0] != tok.bos_token_id: + warnings.warn( + f"{msg} does not start with BOS token, token is ignored. BOS={tok.bos_token_id} vs {tensor}" + ) + if tensor[-1] != tok.eos_token_id: + warnings.warn( + f"{msg} does not end with EOS token, token is ignored. 
EOS='{tok.eos_token_id}': {tensor}" + ) + + +def decode_sequence( + sequence_tokens: torch.Tensor, + sequence_tokenizer: EsmSequenceTokenizer, + **kwargs, +) -> str: + _bos_eos_warn("Sequence", sequence_tokens, sequence_tokenizer) + sequence = sequence_tokenizer.decode( + sequence_tokens, + **kwargs, + ) + sequence = sequence.replace(" ", "") + sequence = sequence.replace(sequence_tokenizer.mask_token, C.MASK_STR_SHORT) + sequence = sequence.replace(sequence_tokenizer.cls_token, "") + sequence = sequence.replace(sequence_tokenizer.eos_token, "") + + return sequence + + +def decode_structure( + structure_tokens: torch.Tensor, + structure_decoder: StructureTokenDecoder, + structure_tokenizer: StructureTokenizer, + sequence: str | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + is_singleton = len(structure_tokens.size()) == 1 + if is_singleton: + structure_tokens = structure_tokens.unsqueeze(0) + else: + raise ValueError( + f"Only one structure can be decoded at a time, got structure tokens of shape {structure_tokens.size()}" + ) + _bos_eos_warn("Structure", structure_tokens[0], structure_tokenizer) + + decoder_output = structure_decoder.decode(structure_tokens) + bb_coords: torch.Tensor = decoder_output["bb_pred"][ + 0, 1:-1, ... 
+ ] # Remove BOS and EOS tokens + bb_coords = bb_coords.detach().cpu() + + if "plddt" in decoder_output: + plddt = decoder_output["plddt"][0, 1:-1] + plddt = plddt.detach().cpu() + else: + plddt = None + + if "ptm" in decoder_output: + ptm = decoder_output["ptm"] + else: + ptm = None + + chain = ProteinChain.from_backbone_atom_coordinates(bb_coords, sequence=sequence) + chain = chain.infer_oxygen() + return torch.tensor(chain.atom37_positions), plddt, ptm + + +def decode_secondary_structure( + secondary_structure_tokens: torch.Tensor, + ss_tokenizer: SecondaryStructureTokenizer, +) -> str: + _bos_eos_warn("Secondary structure", secondary_structure_tokens, ss_tokenizer) + secondary_structure_tokens = secondary_structure_tokens[1:-1] + secondary_structure = ss_tokenizer.decode( + secondary_structure_tokens, + ) + return secondary_structure + + +def decode_sasa( + sasa_tokens: torch.Tensor, + sasa_tokenizer: SASADiscretizingTokenizer, +) -> list[float]: + _bos_eos_warn("SASA", sasa_tokens, sasa_tokenizer) + sasa_tokens = sasa_tokens[1:-1] + + return sasa_tokenizer.decode_float(sasa_tokens) + + +def decode_function_annotations( + function_annotation_tokens: torch.Tensor, + function_token_decoder: FunctionTokenDecoder, + function_tokenizer: InterProQuantizedTokenizer, + **kwargs, +) -> list[FunctionAnnotation]: + # No need to check for BOS/EOS as function annotations are not affected + + function_annotations = decode_function_tokens( + function_annotation_tokens, + function_token_decoder=function_token_decoder, + function_tokens_tokenizer=function_tokenizer, + **kwargs, + ) + return function_annotations + + +def decode_residue_annotations( + residue_annotation_tokens: torch.Tensor, + residue_annotation_decoder: ResidueAnnotationsTokenizer, +) -> list[FunctionAnnotation]: + # No need to check for BOS/EOS as function annotations are not affected + + residue_annotations = decode_residue_annotation_tokens( + residue_annotations_token_ids=residue_annotation_tokens, + 
residue_annotations_tokenizer=residue_annotation_decoder, + ) + return residue_annotations diff --git a/esm/utils/encoding.py b/esm/utils/encoding.py new file mode 100644 index 0000000..97555e0 --- /dev/null +++ b/esm/utils/encoding.py @@ -0,0 +1,241 @@ +from typing import Sequence + +import torch +import torch.nn.functional as F + +from esm.models.vqvae import StructureTokenEncoder +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer as EsmFunctionTokenizer, +) +from esm.tokenization.residue_tokenizer import ( + ResidueAnnotationsTokenizer, +) +from esm.tokenization.sasa_tokenizer import ( + SASADiscretizingTokenizer, +) +from esm.tokenization.sequence_tokenizer import ( + EsmSequenceTokenizer, +) +from esm.tokenization.ss_tokenizer import ( + SecondaryStructureTokenizer, +) +from esm.tokenization.structure_tokenizer import ( + StructureTokenizer, +) +from esm.utils.constants import esm3 as C +from esm.utils.function.encode_decode import ( + encode_function_annotations, +) +from esm.utils.structure.protein_chain import ProteinChain +from esm.utils.types import FunctionAnnotation + + +# Raw Defaults +def get_default_sequence(sequence_length: int) -> str: + return C.MASK_STR_SHORT * sequence_length + + +def get_default_secondary_structure(sequence_length: int) -> str: + return C.MASK_STR_SHORT * sequence_length + + +def get_default_sasa(sequence_length: int) -> Sequence[float | str | None]: + return [None] * sequence_length + + +# Tokenization +def tokenize_sequence( + sequence: str, + sequence_tokenizer: EsmSequenceTokenizer, + add_special_tokens: bool = True, +) -> torch.Tensor: + sequence = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token) + sequence_tokens = sequence_tokenizer.encode( + sequence, add_special_tokens=add_special_tokens + ) + sequence_tokens = torch.tensor(sequence_tokens, dtype=torch.int64) + return sequence_tokens + + +def tokenize_structure( + coordinates: torch.Tensor, + structure_encoder: 
StructureTokenEncoder, + structure_tokenizer: StructureTokenizer, + reference_sequence: str = "", + add_special_tokens: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + device = next(structure_encoder.parameters()).device + chain = ProteinChain.from_atom37( + coordinates, sequence=reference_sequence if reference_sequence else None + ) + + # Setup padding + if reference_sequence and len(reference_sequence) != coordinates.size(0): + raise ValueError( + f"Reference sequence length ({len(reference_sequence)}) does not match the number of residues in the coordinates ({coordinates.size(0)})" + ) + + left_pad = 0 + right_pad = 0 + + if add_special_tokens: + left_pad += 1 # Add space for BOS token + right_pad += 1 # Add space for EOS token + + coordinates, plddt, residue_index = chain.to_structure_encoder_inputs() + coordinates = coordinates.to(device) # (1, L, 37, 3) + plddt = plddt.to(device) # (1, L) + residue_index = residue_index.to(device) # (1, L) + _, structure_tokens = structure_encoder.encode( + coordinates, residue_index=residue_index + ) + coordinates = torch.squeeze(coordinates, dim=0) # (L, 37, 3) # type: ignore + plddt = torch.squeeze(plddt, dim=0) # (L,) # type: ignore + structure_tokens = torch.squeeze(structure_tokens, dim=0) # (L,) # type: ignore + + # Add space for BOS and EOS tokens + if add_special_tokens: + coordinates = F.pad( + coordinates, + (0, 0, 0, 0, left_pad, right_pad), + value=torch.inf, + ) + plddt = F.pad(plddt, (left_pad, right_pad), value=0) + structure_tokens = F.pad( + structure_tokens, + (left_pad, right_pad), + value=structure_tokenizer.pad_token_id, + ) + structure_tokens[0] = structure_tokenizer.bos_token_id + structure_tokens[-1] = structure_tokenizer.eos_token_id + return coordinates, plddt, structure_tokens + + +def tokenize_secondary_structure( + secondary_structure: str | Sequence[str], + secondary_structure_tokenizer: SecondaryStructureTokenizer, + add_special_tokens: bool = True, +) -> torch.Tensor: + 
if isinstance(secondary_structure, str): + # Ensure only one char per token + secondary_structure = secondary_structure.replace( + secondary_structure_tokenizer.mask_token, C.MASK_STR_SHORT + ) + + # Input as list of chars + secondary_structure = [char for char in secondary_structure] + + # Use tokenizer's mask token + secondary_structure = [ + secondary_structure_tokenizer.mask_token if char == C.MASK_STR_SHORT else char + for char in secondary_structure + ] + + secondary_structure_tokens = secondary_structure_tokenizer.encode( + secondary_structure, add_special_tokens=add_special_tokens + ) + return secondary_structure_tokens + + +def tokenize_sasa( + sasa: Sequence[float | str | None], + sasa_tokenizer: SASADiscretizingTokenizer, + add_special_tokens: bool = True, +): + sasa_tokens = sasa_tokenizer.encode( + [sasa_tokenizer.mask_token if value is None else value for value in sasa], + add_special_tokens=add_special_tokens, + ) + return sasa_tokens + + +def tokenize_function_annotations( + function_annotations: Sequence[FunctionAnnotation], + reference_sequence: str, + function_tokenizer: EsmFunctionTokenizer, + residue_annotation_tokenizer: ResidueAnnotationsTokenizer, + add_special_tokens: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + function_tokens, residue_annotation_tokens = encode_function_annotations( + sequence=reference_sequence, + function_annotations=function_annotations, + function_tokens_tokenizer=function_tokenizer, + residue_annotations_tokenizer=residue_annotation_tokenizer, + add_special_tokens=add_special_tokens, + ) + return function_tokens, residue_annotation_tokens + + +# Tokenized Defaults +def get_default_sequence_tokens( + sequence_length: int, + sequence_tokenizer: EsmSequenceTokenizer, +) -> torch.Tensor: + return tokenize_sequence( + get_default_sequence(sequence_length), + sequence_tokenizer, + add_special_tokens=True, + ) + + +def get_default_structure_tokens( + sequence_length: int, structure_tokenizer: StructureTokenizer 
+) -> torch.Tensor: + structure_tokens = ( + torch.ones( + (sequence_length + 2,), + dtype=torch.int64, + ) + * structure_tokenizer.pad_token_id + ) + # Always include BOS and EOS tokens + structure_tokens[0] = structure_tokenizer.bos_token_id + structure_tokens[-1] = structure_tokenizer.eos_token_id + return structure_tokens + + +def get_default_secondary_structure_tokens( + sequence_length: int, secondary_structure_tokenizer: SecondaryStructureTokenizer +) -> torch.Tensor: + return tokenize_secondary_structure( + get_default_secondary_structure(sequence_length), + secondary_structure_tokenizer, + add_special_tokens=True, + ) + + +def get_default_sasa_tokens( + sequence_length: int, sasa_tokenizer: SASADiscretizingTokenizer +) -> torch.Tensor: + return tokenize_sasa( + get_default_sasa(sequence_length), sasa_tokenizer, add_special_tokens=True + ) + + +def get_default_function_tokens( + sequence_length: int, function_tokenizer: EsmFunctionTokenizer +) -> torch.Tensor: + function_tokens = ( + torch.ones((sequence_length + 2, function_tokenizer.depth), dtype=torch.int64) + * function_tokenizer.pad_token_id + ) + # Always include BOS and EOS tokens + function_tokens[0] = function_tokenizer.bos_token_id + function_tokens[-1] = function_tokenizer.eos_token_id + return function_tokens + + +def get_default_residue_annotation_tokens( + sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer +) -> torch.Tensor: + residue_annotation_tokens = ( + torch.ones( + (sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS), + dtype=torch.int64, + ) + * residue_annotation_tokenizer.pad_token_id + ) + # Always include BOS and EOS tokens + residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id + residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id + return residue_annotation_tokens diff --git a/esm/utils/function/encode_decode.py b/esm/utils/function/encode_decode.py new file mode 100644 index 0000000..f8d57c7 --- /dev/null +++ 
b/esm/utils/function/encode_decode.py @@ -0,0 +1,187 @@ +import re +from typing import Sequence + +import torch + +from esm.models.function_decoder import ( + FunctionTokenDecoder, + _merge_annotations, +) +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) +from esm.tokenization.residue_tokenizer import ( + ResidueAnnotationsTokenizer, +) +from esm.utils.constants import esm3 as C +from esm.utils.types import FunctionAnnotation + + +def encode_function_annotations( + sequence: str, + function_annotations: Sequence[FunctionAnnotation], + function_tokens_tokenizer: InterProQuantizedTokenizer, + residue_annotations_tokenizer: ResidueAnnotationsTokenizer, + add_special_tokens: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + assert isinstance( + residue_annotations_tokenizer, ResidueAnnotationsTokenizer + ), "residue_annotations_tokenizer must be of type ResidueAnnotationsTokenizer" + + # Split the user's annotations by type + ft_annotations: list[FunctionAnnotation] = [] + ra_annotations: list[FunctionAnnotation] = [] + for fa in function_annotations: + assert ( + 1 <= fa.start <= fa.end <= len(sequence) + ), f"Invalid (start, end) in function annotation {fa}. Indices 1-indexed and [inclusive, inclusive]" + + supported_label = False + + # Is it an InterPro label? + if match := re.match(r"IPR\d+", fa.label): + if match.group() in function_tokens_tokenizer.interpro_to_index: + ft_annotations.append(fa) + supported_label = True + + # Is it a function keyword? + if fa.label in function_tokens_tokenizer._tfidf.vocab_to_index: + ft_annotations.append(fa) + supported_label = True + + # Is it a residue annotation? 
+ if fa.label in residue_annotations_tokenizer._labels: + ra_annotations.append(fa) + supported_label = True + + if not supported_label: + raise ValueError(f"Unknown label in FunctionAnnotation: {fa.label}") + + # Convert function token FunctionAnnotations -> Tensor + function_tokens = function_tokens_tokenizer.tokenize( + annotations=ft_annotations, + seqlen=len(sequence), + ) + function_token_ids = function_tokens_tokenizer.encode( + function_tokens, add_special_tokens=add_special_tokens + ) + + # Convert residue annotation FunctionAnnotations -> Tensor + if ra_annotations: + descriptions, starts, ends = zip( + *[(anot.label, anot.start, anot.end) for anot in ra_annotations] + ) + else: + descriptions = starts = ends = None + ra_tokens = residue_annotations_tokenizer.tokenize( + { + "interpro_site_descriptions": descriptions, + "interpro_site_starts": starts, + "interpro_site_ends": ends, + }, + sequence=sequence, + fail_on_mismatch=True, + ) + residue_annotation_ids = residue_annotations_tokenizer.encode( + ra_tokens, add_special_tokens=add_special_tokens + ) + + return function_token_ids, residue_annotation_ids + + +def decode_function_tokens( + function_token_ids: torch.Tensor, + function_token_decoder: FunctionTokenDecoder, + function_tokens_tokenizer: InterProQuantizedTokenizer, + decoder_annotation_threshold: float = 0.1, + annotation_min_length: int | None = 5, + annotation_gap_merge_max: int | None = 3, +) -> list[FunctionAnnotation]: + """Decodes predicted function tokens into function annotations. + + Merges function keyword and InterPro annotation predictions into a single + set of FunctionAnnotation predictions. + + Args: + function_token_ids: Tensor [length, depth] of + function token ids. + function_token_decoder: decoder that turns function token ids into + function keyword and InterPro annotation predictions. + function_tokens_tokenizer: InterPro annotation tokenizer. + decoder_annotation_threshold: threshold passed to the decoder for emitting 
+ a predicted function annotation. + annotation_min_length, annotation_gap_merge_max: passed through to the decoder. + Returns: + Predicted function annotations merged from both predictions. + """ + assert ( + function_token_ids.ndim == 2 + ), "function_token_ids must be of shape (length, depth)" + + annotations: list[FunctionAnnotation] = [] + + # Function Annotations from predicted function tokens. + decoded = function_token_decoder.decode( + function_token_ids, + tokenizer=function_tokens_tokenizer, + annotation_threshold=decoder_annotation_threshold, + annotation_min_length=annotation_min_length, + annotation_gap_merge_max=annotation_gap_merge_max, + ) + + # Convert predicted InterPro annotation to FunctionAnnotation. + annotations.extend(decoded["function_keywords"]) + for annotation in decoded["interpro_annotations"]: + annotation: FunctionAnnotation + label = function_tokens_tokenizer.format_annotation(annotation) + annotations.append( + FunctionAnnotation(label=label, start=annotation.start, end=annotation.end) + ) + + return annotations + + +def decode_residue_annotation_tokens( + residue_annotations_token_ids: torch.Tensor, + residue_annotations_tokenizer: ResidueAnnotationsTokenizer, + annotation_min_length: int | None = 5, + annotation_gap_merge_max: int | None = 3, +) -> list[FunctionAnnotation]: + """Decodes residue annotation tokens into FunctionAnnotations. + + Args: + residue_annotations_token_ids: Tensor [length, MAX_RESIDUE_ANNOTATIONS] of residue annotation tokens. + residue_annotations_tokenizer: Tokenizer of residue annotations. + annotation_min_length: minimum length an annotation must span to be kept. + annotation_gap_merge_max: maximum gap passed to `_merge_annotations` when merging. + Returns: + Predicted residue annotations. 
+ """ + assert ( + residue_annotations_token_ids.ndim == 2 + ), "logits must be of shape (length, MAX_RESIDUE_ANNOTATIONS)" + + annotations: list[FunctionAnnotation] = [] + + for depth in range(0, C.MAX_RESIDUE_ANNOTATIONS): + token_ids = residue_annotations_token_ids[:, depth] + for loc, vocab_index in torch.nonzero(token_ids).cpu().numpy(): + label = residue_annotations_tokenizer.vocabulary[vocab_index] + if label not in [*residue_annotations_tokenizer.special_tokens, ""]: + annotation = FunctionAnnotation(label=label, start=loc, end=loc) + annotations.append(annotation) + + annotations = _merge_annotations( + annotations, + merge_gap_max=annotation_gap_merge_max, + ) + + # Drop very small annotations. + if annotation_min_length is not None: + annotations = [ + annotation + for annotation in annotations + if annotation.end - annotation.start + 1 >= annotation_min_length + ] + + return annotations diff --git a/esm/utils/function/interpro.py b/esm/utils/function/interpro.py new file mode 100644 index 0000000..ccfb8e0 --- /dev/null +++ b/esm/utils/function/interpro.py @@ -0,0 +1,245 @@ +"""Utilities for interacting with InterPro.""" + +import itertools +import re +from dataclasses import dataclass +from enum import IntEnum, auto +from functools import cached_property +from pathlib import Path + +import networkx as nx +import numpy as np +import pandas as pd + +from esm.utils.constants import esm3 as C + + +def parse_go_terms(text: str) -> list[str]: + """Parses GO terms from a string. + + Args: + text: String containing GO terms. Example: "GO:0008309, GO:1902267" Note that GO + terms have exactly 7 digits. + Returns: + All GO terms found in the string. Example: ['GO:0008309', 'GO:1902267'] + """ + return re.findall(r"GO:(?:\d{7,})", text) + + +def _parse_interpro2go(path: str) -> dict[str, list[str]]: + """Parses InterPro2GO file into map. + + NOTE: this file has a very strange, non-standard format. 
+ + Args: + path: path to InterPro2GO file from: https://www.ebi.ac.uk/GOA/InterPro2GO + Returns: + Mapping from InterPro to list of associated GO terms. + """ + with Path(path).open("r") as f: + text = f.read() + df = pd.Series(text.split("\n"), name="line").to_frame() + df = df[~df.line.str.startswith("!")] + df["interpro_id"] = df.line.apply(lambda line: re.findall(r"IPR\d+", line)) + df["go_ids"] = df.line.apply(parse_go_terms) + df = df[df.go_ids.apply(len).gt(0) & df.interpro_id.apply(len).eq(1)] + df["interpro_id"] = df["interpro_id"].apply(lambda xs: xs[0]) # type: ignore + + # Group all mappings together into a single map. + df = ( + df.groupby("interpro_id")["go_ids"] # type: ignore + .apply(lambda group: list(itertools.chain.from_iterable(group))) + .reset_index() + ) + return dict(zip(df.interpro_id, df.go_ids)) # type: ignore + + +class InterProEntryType(IntEnum): + """InterPro types and representation counts: + + Family 21,942 + Domain 14,053 + Homologous_superfamily 3,446 + Conserved_site 728 + Repeat 374 + Active_site 133 + Binding_site 75 + PTM 17 + """ + + ACTIVE_SITE = 0 + BINDING_SITE = auto() + CONSERVED_SITE = auto() + DOMAIN = auto() + FAMILY = auto() + HOMOLOGOUS_SUPERFAMILY = auto() + PTM = auto() + REPEAT = auto() + UNKNOWN = auto() + + +@dataclass +class InterProEntry: + """Represents an InterPro entry.""" + + id: str # Example: IPR000006 + type: InterProEntryType + name: str # Example: "Metallothionein, vertebrate" + description: str | None = None + + +@dataclass(frozen=True) +class InterProRangeAnnotation: + """Represents an InterPro annotation along a range of residues in a protein.""" + + interpro_accession: str + start_idx: int + end_idx: int + + +class InterPro: + """Convenience class for interacting with InterPro ontology/data.""" + + def __init__( + self, + entries_path: str | None = None, + hierarchy_path: str | None = None, + interpro2go_path: str | None = None, + ): + """Constructs interface to query InterPro entries.""" + default = 
lambda x, d: x if x is not None else d + self.entries_path = default(entries_path, str(C.data_root() / C.INTERPRO_ENTRY)) + self.hierarchy_graph_path = default( + hierarchy_path, str(C.data_root() / C.INTERPRO_HIERARCHY) + ) + self.interpro2go_path = default( + interpro2go_path, str(C.data_root() / C.INTERPRO2GO) + ) + + @cached_property + def interpro2go(self) -> dict[str, list[str]]: + """Reads the InterPro to GO term mapping.""" + assert self.interpro2go_path is not None + return _parse_interpro2go(self.interpro2go_path) + + @cached_property + def entries_frame(self) -> pd.DataFrame: + """Loads full InterPro entry set as a DataFrame. + + Columns are + - "id": str InterPro accession / id. + - "type": InterProEntryType representing the type of annotation. + - "name": Short name of the entry. + """ + with Path(self.entries_path).open("r") as f: + df = pd.read_csv(f, sep="\t") + assert all( + col in df.columns for col in ["ENTRY_AC", "ENTRY_TYPE", "ENTRY_NAME"] + ) + df.rename( + columns={ + "ENTRY_AC": "id", + "ENTRY_TYPE": "type", + "ENTRY_NAME": "name", + }, + inplace=True, + ) + df["type"] = df.type.str.upper().apply( + lambda type_name: InterProEntryType[type_name] + ) + return df + + @cached_property + def entries(self) -> dict[str, InterProEntry]: + """Returns all InterPro entries.""" + return { + row.id: InterProEntry( # type: ignore + id=row.id, # type: ignore + type=row.type, # type: ignore + name=row.name, # type: ignore + ) + for row in self.entries_frame.itertuples() + } + + def lookup_name(self, interpro_id: str) -> str | None: + """Short name / title for an interpro id.""" + if interpro_id not in self.entries: + return None + return self.entries[interpro_id].name + + def lookup_entry_type(self, interpro_id: str) -> InterProEntryType: + """Looks up entry-type for an interpro id.""" + if interpro_id in self.entries: + return self.entries[interpro_id].type + else: + return InterProEntryType.UNKNOWN + + @cached_property + def graph(self) -> nx.DiGraph: + 
"""Reads the InterPro entry hierarchy as a directed graph (edges point from child to parent).""" + graph = nx.DiGraph() + with Path(self.hierarchy_graph_path).open("r") as f: + parents = [] + for line in f: + ipr = line.split("::", maxsplit=1)[0] + ipr_strip = ipr.lstrip("-") + level = (len(ipr) - len(ipr_strip)) // 2 + parents = parents[:level] + graph.add_node(ipr_strip) + if parents: + graph.add_edge(ipr_strip, parents[-1]) + parents.append(ipr_strip) + return graph + + +def parse_interpro_features( + interpro_accessions: list[str], + interpro_starts: list[int], + interpro_ends: list[int], +) -> list[InterProRangeAnnotation]: + """Parses raw InterPro ranges. + + Args: + interpro_accessions: list of InterPro accessions + interpro_starts: list of one-indexed inclusive residue locations where the + annotation from `interpro_accessions` begin. + interpro_ends: list of one-indexed *inclusive* residue locations where the + annotation from `interpro_accessions` end. + Returns: + Collated InterProRangeAnnotations. NOTE that index conversion will convert range + bounds to zero-indexed [inclusive, exclusive) start/end indices. + """ + assert len(interpro_accessions) == len(interpro_starts) == len(interpro_ends) + + # Residue locations from Uniprot/InterPro are [inclusive, inclusive] and 1-indexed. + start_idcs = np.array(interpro_starts).astype(int) + end_idcs = np.array(interpro_ends).astype(int) + + # We want to use Python's convention of [inclusive, exclusive) and 0-indexing. + # InterPro residue indices are [inclusive, inclusive] and 1-indexed. + # The conversion ends up being: + # ```python + # end_idcs += 1 # [inclusive, inclusive] -> [inclusive, exclusive) + # start_idcs -= 1 # 1 -> 0 indexing + # end_idcs -= 1 # 1 -> 0 indexing + # ``` + # Which simply results in: + start_idcs -= 1 + + ranges = [] + for interpro_accession, start_idx, end_idx in zip( + interpro_accessions, start_idcs, end_idcs + ): + # NOTE: Skip unintegrated Interpro labels, for now. 
+ if interpro_accession == "-": + continue + + ranges.append( + InterProRangeAnnotation( + interpro_accession=interpro_accession, + start_idx=start_idx, + end_idx=end_idx, + ) + ) + + return ranges diff --git a/esm/utils/function/lsh.py b/esm/utils/function/lsh.py new file mode 100644 index 0000000..60e7e45 --- /dev/null +++ b/esm/utils/function/lsh.py @@ -0,0 +1,103 @@ +from pathlib import Path + +import numpy as np + +from esm.utils.types import PathLike + + +class LSHTable: + def __init__(self, n_bits: int, dim: int, hyperplanes: np.ndarray | None = None): + if hyperplanes is None: + hyperplanes = np.random.randn(n_bits, dim) + hyperplanes = hyperplanes / np.linalg.norm( + hyperplanes, axis=-1, keepdims=True + ) + else: + assert hyperplanes.shape == (n_bits, dim), ( + hyperplanes.shape, + (n_bits, dim), + ) + assert hyperplanes is not None + self.hyperplanes: np.ndarray = hyperplanes + self.values = 1 << np.arange(n_bits) + + def __call__(self, array, tokenize: bool = True): + similarity = self.hyperplanes @ array.T + bits = np.where(similarity >= 0, 1, 0) + if tokenize: + tokens = bits.T @ self.values + return tokens + else: + return bits.T + + +class LSHTokenized: + def __init__( + self, + n_bits: int, + dim: int, + num_tables: int = 1, + filepath: PathLike | None = None, + allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes + ): + table_hyperplanes = None + if filepath is not None: + filepath = Path(filepath) + if not filepath.exists(): + raise FileNotFoundError(filepath) + table_hyperplanes = np.load(filepath) # type: ignore + for i in range(num_tables): + assert str(i) in table_hyperplanes, f"Missing hyperplane for table {i}" + elif not allow_create_hyperplanes: + raise RuntimeError( + "Not allowed to create hyperplanes but no filepath provided" + ) + + self.tables = [ + LSHTable( + n_bits, + dim, + table_hyperplanes[str(i)] if table_hyperplanes is not None else None, + ) + for i in range(num_tables) + ] + 
+ def write_hyperplanes(self, filepath: PathLike): + hyperplanes: dict[str, np.ndarray] = { # type: ignore + str(i): table.hyperplanes for i, table in enumerate(self.tables) + } + np.savez(filepath, **hyperplanes) + + def __call__(self, array): + tokens = np.stack([table(array) for table in self.tables], 1) + return tokens + + +class LSHBitstream: + def __init__( + self, + n_bits: int, + dim: int, + filepath: PathLike | None = None, + allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes + ): + table_hyperplanes = None + if filepath is not None: + filepath = Path(filepath) + if not filepath.exists(): + raise FileNotFoundError(filepath) + table_hyperplanes = np.load(filepath) + elif not allow_create_hyperplanes: + raise RuntimeError( + "Not allowed to create hyperplanes but no filepath provided" + ) + + self.table = LSHTable( + n_bits, dim, table_hyperplanes if table_hyperplanes is not None else None + ) + + def write_hyperplanes(self, filepath: PathLike): + np.save(filepath, self.table.hyperplanes) + + def __call__(self, array): + return self.table(array, tokenize=False) diff --git a/esm/utils/function/tfidf.py b/esm/utils/function/tfidf.py new file mode 100644 index 0000000..bd9282e --- /dev/null +++ b/esm/utils/function/tfidf.py @@ -0,0 +1,56 @@ +"""Term-Frequency / Inverse Document Frequency (TF-IDF) model.""" + +from collections import Counter +from functools import cached_property + +import numpy as np +from scipy import sparse + + +class TFIDFModel: + """Term-Frequency / Inverse Document Frequency (TF-IDF) model. 
+ Mimics sklearn.feature_extraction.text.TfidfVectorizer with sublinear_tf=True + """ + + def __init__(self, vocabulary_path: str, idf_path: str): + with open(vocabulary_path, "r") as f: + self.vocabulary = f.read().strip().split("\n") + + with open(idf_path, "rb") as f: + self.idf_ = np.load(f) + + assert self.idf_.ndim == 1 + assert ( + len(self.idf_) == len(self.vocabulary) + ), f"IDF size must match vocabulary size, got {len(self.idf_)} and {len(self.vocabulary)}" + + @cached_property + def vocab_to_index(self) -> dict[str, int]: + return {term: index for index, term in enumerate(self.vocabulary)} + + def encode(self, terms: list[str]) -> sparse.csr_matrix: + """Encodes terms as TF-IDF vectors. + + Args: + terms: list of terms to encode. + + Returns: + TF-IDF vector encoded as sparse matrix of shape (1, num_terms) + """ + counter = Counter(filter(self.vocabulary.__contains__, terms)) + indices = [self.vocab_to_index[term] for term in counter] + + tf = np.array([count for term, count in counter.items()]) + idf = np.take(self.idf_, indices) + + values = (1 + np.log(tf)) * idf + values /= np.linalg.norm(values) + + return sparse.csr_matrix( + (values, (np.zeros_like(indices), indices)), + shape=(1, len(self.vocabulary)), + ) + + def decode(self, vec: sparse.csr_matrix) -> list[str]: + """Extract terms from TF-IDF.""" + return [self.vocabulary[i] for i in vec.indices] diff --git a/esm/utils/generation.py b/esm/utils/generation.py new file mode 100644 index 0000000..a6f9b5a --- /dev/null +++ b/esm/utils/generation.py @@ -0,0 +1,185 @@ +from typing import Callable + +import attr +import torch +from tqdm import tqdm + +from esm.sdk.api import ( + ESM3InferenceClient, + ESMProtein, + ESMProteinTensor, + GenerationConfig, + SamplingConfig, + SamplingTrackConfig, +) +from esm.tokenization import ( + EsmTokenizerBase, + TokenizerCollectionProtocol, +) +from esm.utils.constants import esm3 as C +from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY + + +def 
iterative_sampling_raw( + client: ESM3InferenceClient, + input: ESMProtein, + config: GenerationConfig, +): + # Keep structure tokens + input_tokens = client.encode(input) + + output_tokens = client.generate(input_tokens, config) + + raw_protein = client.decode(output_tokens) + + track_to_sample = config.track + + if track_to_sample not in ["function", "residue_annotations"]: + # Function and residue annotation encoding/decoding is lossy + # There is no guarantee that decoding encoded tokens will yield the same input + raw_protein.function_annotations = input.function_annotations + + return raw_protein + + +def iterative_sampling_tokens( + client: ESM3InferenceClient, + input_tokens: ESMProteinTensor, + config: GenerationConfig, + tokenizers: TokenizerCollectionProtocol, +) -> ESMProteinTensor: + track_to_sample = config.track + + # Get all tracks that require sampling + all_tracks = [ + f.name for f in attr.fields(SamplingConfig) if "embedding" not in f.name + ] + + sequence_length = len(input_tokens) + device = input_tokens.device + + # Initialize schedule and masks + decoding_schedule = NOISE_SCHEDULE_REGISTRY[config.schedule] + sampled_tokens = attr.evolve(input_tokens) # Make a copy + + if config.condition_on_coordinates_only and input_tokens.coordinates is not None: + sampled_tokens.structure = None + + sampling_mask = torch.ones( + sequence_length, + dtype=torch.bool, + device=device, + ) + sampling_mask[0] = False + sampling_mask[-1] = False + + get_tokenizer: Callable[[str], EsmTokenizerBase] = lambda s: getattr(tokenizers, s) + if getattr(sampled_tokens, track_to_sample) is None: + if track_to_sample == "function": + dims = (sequence_length, tokenizers.function.depth) + elif track_to_sample == "residue_annotations": + dims = (sequence_length, C.MAX_RESIDUE_ANNOTATIONS) + else: + dims = (sequence_length,) + masked_tokens = torch.full( + dims, + get_tokenizer(track_to_sample).mask_token_id, + dtype=torch.long, + device=device, + ) + if track_to_sample == 
"sequence": + masked_tokens[0] = tokenizers.sequence.cls_token_id # type: ignore + masked_tokens[-1] = tokenizers.sequence.eos_token_id # type: ignore + else: + masked_tokens[0] = get_tokenizer(track_to_sample).bos_token_id + masked_tokens[-1] = get_tokenizer(track_to_sample).eos_token_id + + setattr( + sampled_tokens, + track_to_sample, + masked_tokens, + ) + else: + is_mask: torch.Tensor = ( + getattr(input_tokens, track_to_sample) + == get_tokenizer(track_to_sample).mask_token_id + ) + if not is_mask.any().item(): + raise ValueError(f"Cannot sample {config.track} when input has no masks.") + sampling_mask = sampling_mask & is_mask + + # Decode + + def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None: + return x.clone() if x is not None else None + + L = sequence_length - 2 + positions_sampled = 0 + for t in tqdm(range(config.num_steps)): + # Single step sampling at all positions + track_sample_config = SamplingTrackConfig() + track_sample_config.invalid_ids = config.invalid_ids + track_sample_config.temperature = config.temperature + track_sample_config.top_p = config.top_p + sampling_config = SamplingConfig(**{track_to_sample: track_sample_config}) # type: ignore + + forward_and_sample_output = client.forward_and_sample( + sampled_tokens, sampling_config + ) + new_samples = forward_and_sample_output.protein_tensor + + # Calculate number of tokens to sample + perc_masked = decoding_schedule(torch.tensor((t + 1) / config.num_steps)) + num_to_sample = int((1 - perc_masked) * L) - positions_sampled + positions_sampled += num_to_sample + + # Select tokens based on lowest entropy + if track_to_sample in ["function", "residue_annotations"]: + # TODO: Implement iterative decoding for function and residue_annotations + # TODO: Fix encode/decode of interpro tokens (not yet supported) + sampled_tokens.function = maybe_clone(input_tokens.function) + sampled_tokens.residue_annotations = maybe_clone( + input_tokens.residue_annotations + ) + if track_to_sample in 
track_to_sample: + raise NotImplementedError( + f"Iterative decoding for {track_to_sample} is not supported yet." + ) + continue + + sampling_mask = sampling_mask & ( + getattr(sampled_tokens, track_to_sample) + == get_tokenizer(track_to_sample).mask_token_id + ) + + track_entropy: torch.Tensor = getattr( + forward_and_sample_output.entropy, track_to_sample + ) + track_entropy = track_entropy.masked_fill( + ~sampling_mask, torch.finfo(track_entropy.dtype).max + ) + _, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False) + is_top_k = ~( + torch.arange(sequence_length, device=device)[:, None] != indices[None, :] + ).all(-1) + tokens_to_sample = sampling_mask & is_top_k + + old_track_samples = getattr(sampled_tokens, track_to_sample) + new_track_samples = getattr(new_samples, track_to_sample) + + new_track_samples = torch.where( + tokens_to_sample, new_track_samples, old_track_samples + ) + + setattr(sampled_tokens, track_to_sample, new_track_samples) + + # Do not update tracks that were not sampled (e.g. keep None instead of masks) + for track in all_tracks: + if track != track_to_sample: + setattr( + sampled_tokens, + track, + maybe_clone(getattr(input_tokens, track)), + ) + + return sampled_tokens diff --git a/esm/utils/misc.py b/esm/utils/misc.py new file mode 100644 index 0000000..5e168c2 --- /dev/null +++ b/esm/utils/misc.py @@ -0,0 +1,256 @@ +import math +from typing import ContextManager, Sequence, TypeVar + +import numpy as np +import torch + +MAX_SUPPORTED_DISTANCE = 1e6 + + +TSequence = TypeVar("TSequence", bound=Sequence) + + +def slice_python_object_as_numpy( + obj: TSequence, idx: int | list[int] | slice | np.ndarray +) -> TSequence: + """ + Slice a python object (like a list, string, or tuple) as if it was a numpy object. 
+ + Example: + >>> obj = "ABCDE" + >>> slice_python_object_as_numpy(obj, [1, 3, 4]) + "BDE" + + >>> obj = [1, 2, 3, 4, 5] + >>> slice_python_object_as_numpy(obj, np.arange(5) < 3) + [1, 2, 3] + """ + if isinstance(idx, int): + idx = [idx] + + if isinstance(idx, np.ndarray) and idx.dtype == bool: + sliced_obj = [obj[i] for i in np.where(idx)[0]] + elif isinstance(idx, slice): + sliced_obj = obj[idx] + else: + sliced_obj = [obj[i] for i in idx] + + match obj, sliced_obj: + case str(), list(): + sliced_obj = "".join(sliced_obj) + case _: + sliced_obj = obj.__class__(sliced_obj) # type: ignore + + return sliced_obj # type: ignore + + +def rbf(values, v_min, v_max, n_bins=16): + """ + Returns RBF encodings in a new dimension at the end. + """ + rbf_centers = torch.linspace( + v_min, v_max, n_bins, device=values.device, dtype=values.dtype + ) + rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1]) + rbf_std = (v_max - v_min) / n_bins + z = (values.unsqueeze(-1) - rbf_centers) / rbf_std + return torch.exp(-(z**2)) + + +def batched_gather(data, inds, dim=0, no_batch_dims=0): + ranges = [] + for i, s in enumerate(data.shape[:no_batch_dims]): + r = torch.arange(s) + r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) + ranges.append(r) + + remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] + remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds + ranges.extend(remaining_dims) + return data[ranges] + + +def node_gather(s: torch.Tensor, edges: torch.Tensor) -> torch.Tensor: + return batched_gather(s.unsqueeze(-3), edges, -2, no_batch_dims=len(s.shape) - 1) + + +def knn_graph( + coords: torch.Tensor, + coord_mask: torch.Tensor, + padding_mask: torch.Tensor, + sequence_id: torch.Tensor, + *, + no_knn: int, +): + L = coords.shape[-2] + num_by_dist = min(no_knn, L) + device = coords.device + + coords = coords.nan_to_num() + coord_mask = ~(coord_mask[..., None, :] & coord_mask[..., :, None]) + padding_pairwise_mask = 
padding_mask[..., None, :] | padding_mask[..., :, None] + if sequence_id is not None: + padding_pairwise_mask |= torch.unsqueeze(sequence_id, 1) != torch.unsqueeze( + sequence_id, 2 + ) + dists = (coords.unsqueeze(-2) - coords.unsqueeze(-3)).norm(dim=-1) + arange = torch.arange(L, device=device) + seq_dists = (arange.unsqueeze(-1) - arange.unsqueeze(-2)).abs() + # We only support up to a certain distance; above that, we use sequence distance + # instead. This is so that when a large portion of the structure is masked out, + # the edges are built according to sequence distance. + max_dist = MAX_SUPPORTED_DISTANCE + torch._assert_async((dists[~coord_mask] < max_dist).all()) + struct_then_seq_dist = ( + seq_dists.to(dists.dtype) + .mul(1e2) + .add(max_dist) + .where(coord_mask, dists) + .masked_fill(padding_pairwise_mask, torch.inf) + ) + dists, edges = struct_then_seq_dist.sort(dim=-1, descending=False) + # This is an L x L tensor, where we index by rows first, + # and columns are the edges we should pick. + chosen_edges = edges[..., :num_by_dist] + chosen_mask = dists[..., :num_by_dist].isfinite() + return chosen_edges, chosen_mask + + +def stack_variable_length_tensors( + sequences: Sequence[torch.Tensor], + constant_value: int | float = 0, + dtype: torch.dtype | None = None, +) -> torch.Tensor: + """Automatically stack tensors together, padding variable lengths with the + value in constant_value. Handles an arbitrary number of dimensions. + + Examples: + >>> tensor1, tensor2 = torch.ones([2]), torch.ones([5]) + >>> stack_variable_length_tensors([tensor1, tensor2]) + tensor of shape [2, 5]. First row is [1, 1, 0, 0, 0]. Second row is all ones.
+ + >>> tensor1, tensor2 = torch.ones([2, 4]), torch.ones([5, 3]) + >>> stack_variable_length_tensors([tensor1, tensor2]) + tensor of shape [2, 5, 4] + """ + batch_size = len(sequences) + shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist() + + if dtype is None: + dtype = sequences[0].dtype + device = sequences[0].device + + array = torch.full(shape, constant_value, dtype=dtype, device=device) + for arr, seq in zip(array, sequences): + arrslice = tuple(slice(dim) for dim in seq.shape) + arr[arrslice] = seq + + return array + + +def unbinpack( + tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float +): + """ + Args: + tensor (Tensor): [B, L, ...] + + Returns: + Tensor: [B_unbinpacked, L_unbinpacked, ...] + """ + if sequence_id is None: + return tensor + + unpacked_tensors = [] + num_sequences = sequence_id.max(dim=-1).values + 1 + for batch_idx, (batch_seqid, batch_num_sequences) in enumerate( + zip(sequence_id, num_sequences) + ): + for seqid in range(batch_num_sequences): + mask = batch_seqid == seqid + unpacked = tensor[batch_idx, mask] + unpacked_tensors.append(unpacked) + return stack_variable_length_tensors(unpacked_tensors, pad_value) + + +def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]: + """ + Returns an autocast context manager that disables downcasting by AMP. + + Args: + device_type: The device type ('cpu' or 'cuda') + + Returns: + An autocast context manager with the specified behavior. + """ + if device_type == "cpu": + return torch.amp.autocast(device_type, enabled=False) + elif device_type == "cuda": + return torch.amp.autocast(device_type, dtype=torch.float32) + else: + raise ValueError(f"Unsupported device type: {device_type}") + + +def merge_ranges(ranges: list[range], merge_gap_max: int | None = None) -> list[range]: + """Merge overlapping ranges into sorted, non-overlapping segments. + + Args: + ranges: collection of ranges to merge.
+ merge_gap_max: optionally merge neighboring ranges that are separated by a gap + no larger than this size. + Returns: + non-overlapping ranges merged from the inputs, sorted by position. + """ + ranges = sorted(ranges, key=lambda r: r.start) + merge_gap_max = merge_gap_max if merge_gap_max is not None else 0 + assert merge_gap_max >= 0, f"Invalid merge_gap_max: {merge_gap_max}" + + merged = [] + for r in ranges: + if not merged: + merged.append(r) + else: + last = merged[-1] + if last.stop + merge_gap_max >= r.start: + merged[-1] = range(last.start, max(last.stop, r.stop)) + else: + merged.append(r) + return merged + + +def list_nan_to_none(l: list) -> list: + if l is None: + return None # type: ignore + elif isinstance(l, float): + return None if math.isnan(l) else l # type: ignore + elif isinstance(l, list): + return [list_nan_to_none(x) for x in l] + else: + # Don't go into other structures. + return l + + +def list_none_to_nan(l: list) -> list: + if l is None: + return math.nan # type: ignore + elif isinstance(l, list): + return [list_none_to_nan(x) for x in l] + else: + return l + + +def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None: + if x is None: + return None + if convert_none_to_nan: + x = list_none_to_nan(x) + return torch.tensor(x) + + +def maybe_list(x, convert_nan_to_none: bool = False) -> list | None: + if x is None: + return None + x = x.tolist() + if convert_nan_to_none: + x = list_nan_to_none(x) + return x diff --git a/esm/utils/noise_schedules.py b/esm/utils/noise_schedules.py new file mode 100644 index 0000000..b02052d --- /dev/null +++ b/esm/utils/noise_schedules.py @@ -0,0 +1,34 @@ +import math + +import torch + + +def cosine_schedule(t: torch.Tensor): + # t is a tensor of size (batch_size,) with values between 0 and 1. 
This is the + # schedule used in the MaskGIT paper + return torch.cos(t * math.pi * 0.5) + + +def cubic_schedule(t): + return 1 - t**3 + + +def linear_schedule(t): + return 1 - t + + +def square_root_schedule(t): + return 1 - torch.sqrt(t) + + +def square_schedule(t): + return 1 - t**2 + + +NOISE_SCHEDULE_REGISTRY = { + "cosine": cosine_schedule, + "linear": linear_schedule, + "square_root_schedule": square_root_schedule, + "cubic": cubic_schedule, + "square": square_schedule, +} diff --git a/esm/utils/residue_constants.py b/esm/utils/residue_constants.py new file mode 100644 index 0000000..86ea82e --- /dev/null +++ b/esm/utils/residue_constants.py @@ -0,0 +1,81 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). +atom_types = [ + "N", + "CA", + "C", + "CB", + "O", + "CG", + "CG1", + "CG2", + "OG", + "OG1", + "SG", + "CD", + "CD1", + "CD2", + "ND1", + "ND2", + "OD1", + "OD2", + "SD", + "CE", + "CE1", + "CE2", + "CE3", + "NE", + "NE1", + "NE2", + "OE1", + "OE2", + "CH2", + "NH1", + "NH2", + "OH", + "CZ", + "CZ2", + "CZ3", + "NZ", + "OXT", +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. 
+ +restype_1to3 = { + "A": "ALA", + "R": "ARG", + "N": "ASN", + "D": "ASP", + "C": "CYS", + "Q": "GLN", + "E": "GLU", + "G": "GLY", + "H": "HIS", + "I": "ILE", + "L": "LEU", + "K": "LYS", + "M": "MET", + "F": "PHE", + "P": "PRO", + "S": "SER", + "T": "THR", + "W": "TRP", + "Y": "TYR", + "V": "VAL", +} diff --git a/esm/utils/sampling.py b/esm/utils/sampling.py new file mode 100644 index 0000000..8db88c5 --- /dev/null +++ b/esm/utils/sampling.py @@ -0,0 +1,155 @@ +import attr +import torch +import torch.nn.functional as F + +from esm.sdk.api import ( + SamplingConfig, + SamplingTrackConfig, +) +from esm.tokenization import ( + TokenizerCollection, + get_invalid_tokenizer_ids, +) +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) +from esm.utils.constants.esm3 import MAX_RESIDUE_ANNOTATIONS + + +def get_default_sampling_config(tokenizers: TokenizerCollection) -> SamplingConfig: + tracks = [f.name for f in attr.fields(SamplingConfig)] + sampling_config = SamplingConfig() + for current_track in tracks: + setattr( + sampling_config, + current_track, + SamplingTrackConfig( + invalid_ids=get_invalid_tokenizer_ids( + getattr(tokenizers, current_track) + ), + temperature=1.0, + top_p=1.0, + # TODO: Add different mask and padding tokens for all tracks + # Some tracks have the same pad and mask, which causes ambiguity when sampling + only_sample_masked_tokens=current_track + not in ["secondary_structure", "sasa", "function"], + ), + ) + return sampling_config + + +def sample_logits( + logits: torch.Tensor, + temperature: float | torch.Tensor, + top_p: float | torch.Tensor = 1.0, +): + """Default sampling from logits. + + Args: + logits is shape (..., vocab_size) + temperature is broadcastable to (...) 
+ """ + + if top_p < 1.0: + logits = top_p_logits(logits, top_p=top_p) + + temperature = _tensorize_like(temperature, logits) + + if torch.all(temperature == 0): + ids = logits.argmax(-1) + return ids + + assert not torch.any(temperature == 0), "Partial temperature 0 not supported." + + batch_dims = logits.size()[:-1] + logits = logits.reshape(-1, logits.shape[-1]) + + # Sample from all logits + probs = F.softmax(logits / temperature[..., None], dim=-1) + ids = torch.multinomial(probs, 1).squeeze(1) + + ids = ids.reshape(*batch_dims) + return ids + + +def sample_function_logits( + logits: torch.Tensor, + tokenizer: InterProQuantizedTokenizer, + top_p: float | torch.Tensor = 1.0, + temperature: float | torch.Tensor = 1.0, + p_none_threshold: float = 0.05, +) -> tuple[torch.Tensor, torch.Tensor]: + [L, D, V] = logits.shape + assert D == tokenizer.depth + + if top_p < 1.0: + logits = top_p_logits(logits, top_p=top_p) + + temperature = torch.ones_like(logits[..., 0]) * temperature + + log_p = F.log_softmax(logits / temperature[..., None], dim=-1) # (L, D, V) + + # Choose which positions have no predicted function. 
+ log_p_nones = log_p[..., tokenizer.vocab_to_index[""]] # (L, D) + p_none = torch.exp(log_p_nones).mean(dim=-1) # "Ensemble of predictions" + where_none = p_none > p_none_threshold # (L, ) + + # Set probability of the none token to 0 for all not-none positions + none_index = tokenizer.vocab_to_index[""] + log_p[~where_none, :, none_index] = -torch.inf + + ids = torch.argmax(log_p, dim=-1) # (L, D) + ids[where_none, :] = tokenizer.vocab_to_index[""] + + return ids, log_p + + +def sample_residue_annotation_logits( + logits: torch.Tensor, annotation_threshold: float = 0.5 +) -> tuple[torch.Tensor, torch.Tensor]: + # Take top residue annotations + top_residue_annotations_idx = logits.argsort(dim=-1, descending=True)[ + ..., :MAX_RESIDUE_ANNOTATIONS + ] # (L, MAX_R) + top_residue_annotations_logprobs = torch.gather( + F.logsigmoid(logits), -1, top_residue_annotations_idx + ) # (L, MAX_R) + top_residue_annotations_probs = top_residue_annotations_logprobs.exp() + # Keep only positive predictions + is_negative = top_residue_annotations_probs < annotation_threshold + top_residue_annotations_idx[is_negative] = 0 + + top_residue_annotations_logprobs = top_residue_annotations_logprobs + + return top_residue_annotations_idx, top_residue_annotations_logprobs + + +def top_p_logits( + logits: torch.Tensor, + top_p: float | torch.Tensor, +) -> torch.Tensor: + top_p = _tensorize_like(top_p, logits) + + batch_dims = logits.size()[:-1] + logits = logits.reshape(-1, logits.shape[-1]) + + # Sort logits in descending order and extract the mask for the top_p + sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True) + cumsum_logits = sorted_logits.softmax(-1).cumsum(-1) + top_p_mask = cumsum_logits <= top_p[:, None] + + # Make sure at least one token is sampled + top_p_mask[:, 0] = True + + # Mask out the logits that are not in the top_p + batch_indices_to_mask, _ = torch.where(~top_p_mask) + vocab_indices_to_mask = sorted_indices[~top_p_mask] + logits[batch_indices_to_mask,
vocab_indices_to_mask] = torch.finfo(logits.dtype).min + + return logits.reshape(*batch_dims, -1) + + +def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor): + if isinstance(value, (float, int)): + value = torch.full_like(logits[..., 0], value, dtype=logits.dtype) + return value.to(logits.device).expand_as(logits[..., 0]).reshape(-1) diff --git a/esm/utils/structure/affine3d.py b/esm/utils/structure/affine3d.py new file mode 100644 index 0000000..3792180 --- /dev/null +++ b/esm/utils/structure/affine3d.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +import typing as T +from dataclasses import dataclass + +import torch +from typing_extensions import Self + +from esm.utils.misc import fp32_autocast_context + + +def maybe_compile(func, x: torch.Tensor): + # Sometimes, torch compile seems to give issues for CPU tensors... + return torch.compile(func) if x.device.type == "cuda" else func + + +@T.runtime_checkable +class Rotation(T.Protocol): + @classmethod + def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: + ... + + @classmethod + def random(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: + ... + + def __getitem__(self, idx: T.Any) -> Self: + ... + + @property + def tensor(self) -> torch.Tensor: + # We claim that this should be zero-cost abstraction that returns the raw tensor backing this + # object. The raw tensor should always have exactly 1 more dim than self.shape, which should be + # implemented using reshaping + ... + + @property + def shape(self) -> torch.Size: + # The "shape" of the rotation, as if it was a torch.tensor object + # This means that 1x4 quaternions are treated as size (1,) for example + ... + + def as_matrix(self) -> RotationMatrix: + ... + + def compose(self, other: Self) -> Self: + # To be safe, we force users to explicitly convert between rotation types. + ... 
+ + def convert_compose(self, other: Self) -> Self: + # This function will automatically convert between types of rotations + ... + + def apply(self, p: torch.Tensor) -> torch.Tensor: + # Rotates points by this rotation object + ... + + def invert(self) -> Self: + ... + + @property + def dtype(self) -> torch.dtype: + return self.tensor.dtype + + @property + def device(self) -> torch.device: + return self.tensor.device + + @property + def requires_grad(self) -> bool: + return self.tensor.requires_grad + + @classmethod + def _from_tensor(cls, t: torch.Tensor) -> Self: + # This function exists to simplify the below functions, especially type signatures. + # Its implementation is different from Affine3D.from_tensor and does not + # autodetect rotation types. + return cls(t) # type: ignore + + def to(self, **kwargs) -> Self: + return self._from_tensor(self.tensor.to(**kwargs)) + + def detach(self, *args, **kwargs) -> Self: + return self._from_tensor(self.tensor.detach(**kwargs)) + + def tensor_apply(self, func) -> Self: + # Applies a function to the underlying tensor + return self._from_tensor( + torch.stack([func(x) for x in self.tensor.unbind(dim=-1)], dim=-1) + ) + + +class RotationMatrix(Rotation): + def __init__(self, rots: torch.Tensor): + if rots.shape[-1] == 9: + rots = rots.unflatten(-1, (3, 3)) + assert rots.shape[-1] == 3 + assert rots.shape[-2] == 3 + # Force full precision + self._rots = rots.to(torch.float32) + + @classmethod + def identity(cls, shape, **tensor_kwargs): + rots = torch.eye(3, **tensor_kwargs) + rots = rots.view(*[1 for _ in range(len(shape))], 3, 3) + rots = rots.expand(*shape, -1, -1) + return cls(rots) + + @classmethod + def random(cls, shape, **tensor_kwargs): + v1 = torch.randn((*shape, 3), **tensor_kwargs) + v2 = torch.randn((*shape, 3), **tensor_kwargs) + return cls(_graham_schmidt(v1, v2)) + + def __getitem__(self, idx: T.Any) -> RotationMatrix: + indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx) + return
RotationMatrix(self._rots[indices + (slice(None), slice(None))]) + + @property + def shape(self) -> torch.Size: + return self._rots.shape[:-2] + + def as_matrix(self) -> RotationMatrix: + return self + + def compose(self, other: RotationMatrix) -> RotationMatrix: + with fp32_autocast_context(self._rots.device.type): + return RotationMatrix(self._rots @ other._rots) + + def convert_compose(self, other: Rotation): + return self.compose(other.as_matrix()) + + def apply(self, p: torch.Tensor) -> torch.Tensor: + with fp32_autocast_context(self.device.type): + if self._rots.shape[-3] == 1: + # This is a slight speedup over einsum for batched rotations + return p @ self._rots.transpose(-1, -2).squeeze(-3) + else: + # einsum way faster than bmm! + return torch.einsum("...ij,...j", self._rots, p) + + def invert(self) -> RotationMatrix: + return RotationMatrix(self._rots.transpose(-1, -2)) + + @property + def tensor(self) -> torch.Tensor: + return self._rots.flatten(-2) + + def to_3x3(self) -> torch.Tensor: + return self._rots + + @staticmethod + def from_graham_schmidt( + x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12 + ) -> RotationMatrix: + # A low eps here is necessary for good stability! + return RotationMatrix( + maybe_compile(_graham_schmidt, x_axis)(x_axis, xy_plane, eps) + ) + + +@dataclass(frozen=True) +class Affine3D: + trans: torch.Tensor + rot: Rotation + + def __post_init__(self): + assert self.trans.shape[:-1] == self.rot.shape + + @staticmethod + def identity( + shape_or_affine: T.Union[tuple[int, ...], "Affine3D"], + rotation_type: T.Type[Rotation] = RotationMatrix, + **tensor_kwargs, + ): + # Creates a new identity Affine3D object with a specified shape + # or the same shape as another Affine3D object. 
+ if isinstance(shape_or_affine, Affine3D): + kwargs = {"dtype": shape_or_affine.dtype, "device": shape_or_affine.device} + kwargs.update(tensor_kwargs) + shape = shape_or_affine.shape + rotation_type = type(shape_or_affine.rot) + else: + kwargs = tensor_kwargs + shape = shape_or_affine + return Affine3D( + torch.zeros((*shape, 3), **kwargs), rotation_type.identity(shape, **kwargs) + ) + + @staticmethod + def random( + shape: tuple[int, ...], + std: float = 1, + rotation_type: T.Type[Rotation] = RotationMatrix, + **tensor_kwargs, + ) -> "Affine3D": + return Affine3D( + trans=torch.randn((*shape, 3), **tensor_kwargs).mul(std), + rot=rotation_type.random(shape, **tensor_kwargs), + ) + + def __getitem__(self, idx: T.Any) -> "Affine3D": + indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx) + return Affine3D( + trans=self.trans[indices + (slice(None),)], + rot=self.rot[idx], + ) + + @property + def shape(self) -> torch.Size: + return self.trans.shape[:-1] + + @property + def dtype(self) -> torch.dtype: + return self.trans.dtype + + @property + def device(self) -> torch.device: + return self.trans.device + + @property + def requires_grad(self) -> bool: + return self.trans.requires_grad + + def to(self, **kwargs) -> "Affine3D": + return Affine3D(self.trans.to(**kwargs), self.rot.to(**kwargs)) + + def detach(self, *args, **kwargs) -> "Affine3D": + return Affine3D(self.trans.detach(**kwargs), self.rot.detach(**kwargs)) + + def tensor_apply(self, func) -> "Affine3D": + # Applies a function to the underlying tensor + return self.from_tensor( + torch.stack([func(x) for x in self.tensor.unbind(dim=-1)], dim=-1) + ) + + def as_matrix(self): + return Affine3D(trans=self.trans, rot=self.rot.as_matrix()) + + def compose(self, other: "Affine3D", autoconvert: bool = False): + rot = self.rot + new_rot = (rot.convert_compose if autoconvert else rot.compose)(other.rot) + new_trans = rot.apply(other.trans) + self.trans + return Affine3D(trans=new_trans, rot=new_rot) +
+ def compose_rotation(self, other: Rotation, autoconvert: bool = False): + return Affine3D( + trans=self.trans, + rot=(self.rot.convert_compose if autoconvert else self.rot.compose)(other), + ) + + def scale(self, v: torch.Tensor | float): + return Affine3D(self.trans * v, self.rot) + + def mask(self, mask: torch.Tensor, with_zero=False): + # Returns a transform where True positions in mask are identity + if with_zero: + tensor = self.tensor + return Affine3D.from_tensor( + torch.zeros_like(tensor).where(mask[..., None], tensor) + ) + else: + identity = self.identity( + self.shape, + rotation_type=type(self.rot), + device=self.device, + dtype=self.dtype, + ).tensor + return Affine3D.from_tensor(identity.where(mask[..., None], self.tensor)) + + def apply(self, p: torch.Tensor) -> torch.Tensor: + return self.rot.apply(p) + self.trans + + def invert(self): + inv_rot = self.rot.invert() + return Affine3D(trans=-inv_rot.apply(self.trans), rot=inv_rot) + + @property + def tensor(self) -> torch.Tensor: + return torch.cat([self.rot.tensor, self.trans], dim=-1) + + @staticmethod + def from_tensor(t: torch.Tensor) -> "Affine3D": + match t.shape[-1]: + case 4: + # Assume tensor 4x4 for backward compat with AlphaFold + trans = t[..., :3, 3] + rot = RotationMatrix(t[..., :3, :3]) + case 12: + trans = t[..., -3:] + rot = RotationMatrix(t[..., :-3].unflatten(-1, (3, 3))) + case _: + raise RuntimeError( + f"Cannot detect rotation format from {t.shape[-1] - 3}-d flat vector" + ) + return Affine3D(trans, rot) + + @staticmethod + def from_tensor_pair(t: torch.Tensor, r: torch.Tensor) -> "Affine3D": + return Affine3D(t, RotationMatrix(r)) + + @staticmethod + def from_graham_schmidt( + neg_x_axis: torch.Tensor, + origin: torch.Tensor, + xy_plane: torch.Tensor, + eps: float = 1e-10, + ): + # The arguments of this function are for parity with AlphaFold + x_axis = origin - neg_x_axis + xy_plane = xy_plane - origin + return Affine3D( + trans=origin,
rot=RotationMatrix.from_graham_schmidt(x_axis, xy_plane, eps) + ) + + @staticmethod + def cat(affines: list["Affine3D"], dim: int = 0): + if dim < 0: + dim = len(affines[0].shape) + dim + return Affine3D.from_tensor(torch.cat([x.tensor for x in affines], dim=dim)) + + +def _graham_schmidt(x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12): + # A low eps here is necessary for good stability! + with fp32_autocast_context(x_axis.device.type): + e1 = xy_plane + + denom = torch.sqrt((x_axis**2).sum(dim=-1, keepdim=True) + eps) + x_axis = x_axis / denom + dot = (x_axis * e1).sum(dim=-1, keepdim=True) + e1 = e1 - x_axis * dot + denom = torch.sqrt((e1**2).sum(dim=-1, keepdim=True) + eps) + e1 = e1 / denom + e2 = torch.cross(x_axis, e1, dim=-1) + + rots = torch.stack([x_axis, e1, e2], dim=-1) + + return rots + + +def build_affine3d_from_coordinates( + coords: torch.Tensor, # (N, CA, C). +) -> tuple[Affine3D, torch.Tensor]: + _MAX_SUPPORTED_DISTANCE = 1e6 + coord_mask = torch.all( + torch.all(torch.isfinite(coords) & (coords < _MAX_SUPPORTED_DISTANCE), dim=-1), + dim=-1, + ) + + def atom3_to_backbone_affine(bb_positions: torch.Tensor) -> Affine3D: + N, CA, C = bb_positions.unbind(dim=-2) + return Affine3D.from_graham_schmidt(C, CA, N) + + coords = coords.clone().float() + coords[~coord_mask] = 0 + + # NOTE(thayes): If you have already normalized the coordinates, then + # the black hole affine translations will be zeros and the rotations will be + # the identity. 
+ average_per_n_ca_c = coords.masked_fill(~coord_mask[..., None, None], 0).sum(1) / ( + coord_mask.sum(-1)[..., None, None] + 1e-8 + ) + affine_from_average = atom3_to_backbone_affine( + average_per_n_ca_c.float() + ).as_matrix() + + B, S, _, _ = coords.shape + assert isinstance(B, int) + assert isinstance(S, int) + affine_rot_mats = affine_from_average.rot.tensor[..., None, :].expand(B, S, 9) + affine_trans = affine_from_average.trans[..., None, :].expand(B, S, 3) + + # We use the identity rotation wherever we have no coordinates. This is + # important because otherwise the rotation matrices will be all zeros, which + # will cause collapse in the distance/direction attention mechanism. + identity_rot = RotationMatrix.identity( + (B, S), dtype=torch.float32, device=coords.device, requires_grad=False + ) + affine_rot_mats = affine_rot_mats.where( + coord_mask.any(-1)[..., None, None], identity_rot.tensor + ) + black_hole_affine = Affine3D(affine_trans, RotationMatrix(affine_rot_mats)) + + affine = atom3_to_backbone_affine(coords.float()) + affine = Affine3D.from_tensor( + affine.tensor.where(coord_mask[..., None], black_hole_affine.tensor) + ) + + return affine, coord_mask diff --git a/esm/utils/structure/aligner.py b/esm/utils/structure/aligner.py new file mode 100644 index 0000000..613b69f --- /dev/null +++ b/esm/utils/structure/aligner.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from dataclasses import replace +from typing import TYPE_CHECKING + +import numpy as np +import torch + +from esm.utils.structure.protein_structure import ( + compute_affine_and_rmsd, +) + +if TYPE_CHECKING: + from esm.utils.structure.protein_chain import ProteinChain + + +class Aligner: + def __init__( + self, + mobile: ProteinChain, + target: ProteinChain, + only_use_backbone: bool = False, + use_reflection: bool = False, + ): + """ + Aligns a mobile protein chain against a target protein chain. + + Args: + mobile (ProteinChain): Protein chain to be aligned.
+ target (ProteinChain): Protein chain target. + only_use_backbone (bool): Whether to only use backbone atoms. + use_reflection (bool): Whether to align to target reflection. + """ + # Check proteins must have same number of residues + assert len(mobile) == len(target) + + # Determine overlapping atoms + joint_atom37_mask = mobile.atom37_mask.astype(bool) & target.atom37_mask.astype( + bool + ) + + # Backbone atoms are first sites in atom37 representation + if only_use_backbone: + joint_atom37_mask[:, 3:] = False + + # Extract matching atom positions and convert to batched tensors + mobile_atom_tensor = ( + torch.from_numpy(mobile.atom37_positions).type(torch.double).unsqueeze(0) + ) + target_atom_tensor = ( + torch.from_numpy(target.atom37_positions).type(torch.double).unsqueeze(0) + ) + joint_atom37_mask = ( + torch.from_numpy(joint_atom37_mask).type(torch.bool).unsqueeze(0) + ) + + # If using reflection flip target + if use_reflection: + target_atom_tensor = -target_atom_tensor + + # Compute alignment and rmsd + affine3D, rmsd = compute_affine_and_rmsd( + mobile_atom_tensor, target_atom_tensor, atom_exists_mask=joint_atom37_mask + ) + self._affine3D = affine3D + self._rmsd = rmsd.item() + + @property + def rmsd(self): + return self._rmsd + + def apply(self, mobile: ProteinChain) -> ProteinChain: + """Apply alignment to a protein chain""" + # Extract atom positions and convert to batched tensors + mobile_atom_tensor = ( + torch.from_numpy(mobile.atom37_positions[mobile.atom37_mask]) + .type(torch.float32) + .unsqueeze(0) + ) + + # Transform atom arrays + aligned_atom_tensor = self._affine3D.apply(mobile_atom_tensor).squeeze(0) + + # Rebuild atom37 positions + aligned_atom37_positions = np.full_like(mobile.atom37_positions, np.nan) + aligned_atom37_positions[mobile.atom37_mask] = aligned_atom_tensor + + return replace(mobile, atom37_positions=aligned_atom37_positions) diff --git a/esm/utils/structure/lddt.py b/esm/utils/structure/lddt.py new file mode 100644 index 
0000000..487e930 --- /dev/null +++ b/esm/utils/structure/lddt.py @@ -0,0 +1,98 @@ +import torch +from einops import rearrange + +from esm.utils import residue_constants as RC + + +def compute_lddt( + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + cutoff: float = 15.0, + eps: float = 1e-10, + per_residue: bool = True, +) -> torch.Tensor: + """ + Computes LDDT for a protein. Tensor sizes below include some optional dimensions. Specifically: + Nstates: + all_atom_pred_pos can contain multiple states in the first dimension which corresponds to outputs from different layers of a model (e.g. each IPA block). The return size will be [Nstates x Batch size] if this is included. + Natoms: + LDDT can be computed for all atoms or some atoms. The second to last dimension should contain the *FLATTENED* representation of L x Natoms. If you want to calculate for atom37, e.g., this will be of size (L * 37). If you are only calculating CA LDDT, it will be of size L. + + Args: + all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms x) 3]): Tensor of predicted positions + all_atom_positions (Tensor[float], [B x (L * Natoms x) 3]): Tensor of true positions + all_atom_mask (Tensor[float], [B x (L * Natoms)]): Tensor of masks, indicating whether an atom exists. + cutoff (float): Max distance to score lddt over. + per_residue (bool): Whether to return per-residue or full-protein lddt. 
+ + Returns: + LDDT Tensor: + if per_residue: + Tensor[float], [(Nstates x) B x (L * Natoms)] + else: + Tensor[float], [(Nstates x) B] + """ + n = all_atom_mask.shape[-2] + dmat_true = torch.sqrt( + eps + + torch.sum( + (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :]) + ** 2, + dim=-1, + ) + ) + + dmat_pred = torch.sqrt( + eps + + torch.sum( + (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2, + dim=-1, + ) + ) + dists_to_score = ( + (dmat_true < cutoff) + * all_atom_mask + * rearrange(all_atom_mask, "... a b -> ... b a") + * (1.0 - torch.eye(n, device=all_atom_mask.device)) + ) + + dist_l1 = torch.abs(dmat_true - dmat_pred) + + score = ( + (dist_l1 < 0.5).type(dist_l1.dtype) + + (dist_l1 < 1.0).type(dist_l1.dtype) + + (dist_l1 < 2.0).type(dist_l1.dtype) + + (dist_l1 < 4.0).type(dist_l1.dtype) + ) + score = score * 0.25 + + dims = (-1,) if per_residue else (-2, -1) + norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims)) + score = norm * (eps + torch.sum(dists_to_score * score, dim=dims)) + + return score + + +def compute_lddt_ca( + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + cutoff: float = 15.0, + eps: float = 1e-10, + per_residue: bool = True, +) -> torch.Tensor: + ca_pos = RC.atom_order["CA"] + if all_atom_pred_pos.dim() != 3: + all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] + all_atom_positions = all_atom_positions[..., ca_pos, :] + all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim + + return compute_lddt( + all_atom_pred_pos, + all_atom_positions, + all_atom_mask, + cutoff=cutoff, + eps=eps, + per_residue=per_residue, + ) diff --git a/esm/utils/structure/normalize_coordinates.py b/esm/utils/structure/normalize_coordinates.py new file mode 100644 index 0000000..6b8efd6 --- /dev/null +++ b/esm/utils/structure/normalize_coordinates.py @@ -0,0 +1,82 @@ +from typing import TypeVar + +import numpy as np +import torch +from torch 
import Tensor + +from esm.utils import residue_constants as RC +from esm.utils.structure.affine3d import Affine3D + +ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor) + + +def atom3_to_backbone_frames(bb_positions: torch.Tensor) -> Affine3D: + N, CA, C = bb_positions.unbind(dim=-2) + return Affine3D.from_graham_schmidt(C, CA, N) + + +def index_by_atom_name( + atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2 +) -> ArrayOrTensor: + squeeze = False + if isinstance(atom_names, str): + atom_names = [atom_names] + squeeze = True + indices = [RC.atom_order[atom_name] for atom_name in atom_names] + dim = dim % atom37.ndim + index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim)) + result = atom37[index] # type: ignore + if squeeze: + result = result.squeeze(dim) + return result + + +def get_protein_normalization_frame(coords: Tensor) -> Affine3D: + """Given a set of coordinates for a protein, compute a single frame that can be used to normalize the coordinates. + Specifically, we compute the average position of the N, CA, and C atoms and use those 3 points to construct a frame + using the Gram-Schmidt algorithm. The average CA position is used as the origin of the frame. + + Args: + coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates + + Returns: + Affine3D: tensor of Affine3D frame + """ + bb_coords = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2) + coord_mask = torch.all( + torch.all(torch.isfinite(bb_coords), dim=-1), + dim=-1, + ) + + average_position_per_n_ca_c = bb_coords.masked_fill( + ~coord_mask[..., None, None], 0 + ).sum(-3) / (coord_mask.sum(-1)[..., None, None] + 1e-8) + frame = atom3_to_backbone_frames(average_position_per_n_ca_c.float()) + + return frame + + +def apply_frame_to_coords(coords: Tensor, frame: Affine3D) -> Tensor: + """Given a set of coordinates and a single frame, apply the frame to the coordinates.
+ + Args: + coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates + frame (Affine3D): Affine3D frame + + Returns: + torch.FloatTensor: [L, 37, 3] tensor of transformed coordinates + """ + coords_trans_rot = frame[..., None, None].invert().apply(coords) + + # only transform coordinates with frame that have a valid rotation + valid_frame = frame.trans.norm(dim=-1) > 0 + + is_inf = torch.isinf(coords) + coords = coords_trans_rot.where(valid_frame[..., None, None, None], coords) + coords.masked_fill_(is_inf, torch.inf) + + return coords + + +def normalize_coordinates(coords: Tensor) -> Tensor: + return apply_frame_to_coords(coords, get_protein_normalization_frame(coords)) diff --git a/esm/utils/structure/predicted_aligned_error.py b/esm/utils/structure/predicted_aligned_error.py new file mode 100644 index 0000000..2b999c1 --- /dev/null +++ b/esm/utils/structure/predicted_aligned_error.py @@ -0,0 +1,108 @@ +import torch +import torch.nn.functional as F + +from esm.utils.structure.affine3d import Affine3D + + +def masked_mean( + mask: torch.Tensor, + value: torch.Tensor, + dim: int | None | tuple[int, ...] = None, + eps=1e-10, +) -> torch.Tensor: + """Compute the mean of `value` where only positions where `mask == true` are + counted. 
+ """ + mask = mask.expand(*value.shape) + return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) + + +def _pae_bins( + max_bin: float = 31, num_bins: int = 64, device: torch.device = torch.device("cpu") +): + bins = torch.linspace(0, max_bin, steps=(num_bins - 1), device=device) + step = max_bin / (num_bins - 2) + bin_centers = bins + step / 2 + bin_centers = torch.cat( + [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0 + ) + return bin_centers + + +def _compute_pae_masks(mask: torch.Tensor): + square_mask = (mask.unsqueeze(-1) * mask.unsqueeze(-2)).bool() + return square_mask + + +def compute_predicted_aligned_error( + logits: torch.Tensor, + aa_mask: torch.Tensor, + sequence_id: torch.Tensor | None = None, + max_bin: float = 31, +) -> torch.Tensor: + bins = _pae_bins(max_bin, logits.shape[-1], logits.device) + square_mask = _compute_pae_masks(aa_mask) + min_v = torch.finfo(logits.dtype).min + probs = logits.masked_fill(~square_mask.unsqueeze(-1), min_v).softmax(dim=-1) + + return (probs * bins).sum(dim=-1) + + +@torch.no_grad() +def compute_tm( + logits: torch.Tensor, + aa_mask: torch.Tensor, + max_bin: float = 31.0, +): + square_mask = _compute_pae_masks(aa_mask) + seqlens = aa_mask.sum(-1, keepdim=True) + bins = _pae_bins(max_bin, logits.shape[-1], logits.device) + d0 = 1.24 * (seqlens.clamp_min(19) - 15) ** (1 / 3) - 1.8 + f_d = 1.0 / (1 + (bins / d0.unsqueeze(-1)) ** 2) + + min_v = torch.finfo(logits.dtype).min + probs = logits.masked_fill(~square_mask.unsqueeze(-1), min_v).softmax(dim=-1) + # This is the sum over bins + ptm = (probs * f_d.unsqueeze(-2)).sum(dim=-1) + # This is the mean over residues j + ptm = masked_mean(square_mask, ptm, dim=-1) + # Then we do a max over residues i + return ptm.max(dim=-1).values + + +def tm_loss( + logits: torch.Tensor, + pred_affine: torch.Tensor, + targ_affine: torch.Tensor, + targ_mask: torch.Tensor, + tm_mask: torch.Tensor | None = None, + sequence_id: torch.Tensor | None = None, + 
max_bin: float = 31, +): + pred = Affine3D.from_tensor(pred_affine) + targ = Affine3D.from_tensor(targ_affine) + + def transform(affine: Affine3D): + pts = affine.trans[..., None, :, :] + return affine.invert()[..., None].apply(pts) + + with torch.no_grad(): + sq_diff = (transform(pred) - transform(targ)).square().sum(dim=-1) + + num_bins = logits.shape[-1] + sq_bins = torch.linspace( + 0, max_bin, num_bins - 1, device=logits.device + ).square() + # Gets the bin id by using a sum. + true_bins = (sq_diff[..., None] > sq_bins).sum(dim=-1).long() + + errors = F.cross_entropy(logits.movedim(3, 1), true_bins, reduction="none") + square_mask = _compute_pae_masks(targ_mask) + loss = masked_mean(square_mask, errors, dim=(-1, -2)) + + if tm_mask is not None: + loss = masked_mean(tm_mask, loss, dim=None) + else: + loss = loss.mean() + + return loss diff --git a/esm/utils/structure/protein_chain.py b/esm/utils/structure/protein_chain.py new file mode 100644 index 0000000..ab5ea6d --- /dev/null +++ b/esm/utils/structure/protein_chain.py @@ -0,0 +1,828 @@ +from __future__ import annotations + +import io +from dataclasses import asdict, dataclass, replace +from functools import cached_property +from pathlib import Path +from typing import Sequence, TypeVar, Union + +import biotite.structure as bs +import brotli +import msgpack +import msgpack_numpy +import numpy as np +import torch +from Bio.Data import PDBData +from biotite.application.dssp import DsspApp +from biotite.database import rcsb +from biotite.structure.io.npz import NpzFile +from biotite.structure.io.pdb import PDBFile +from scipy.spatial.distance import pdist, squareform +from torch import Tensor + +from esm.utils import residue_constants as RC +from esm.utils.constants import esm3 as C +from esm.utils.misc import slice_python_object_as_numpy +from esm.utils.structure.affine3d import Affine3D +from esm.utils.structure.aligner import Aligner +from esm.utils.structure.lddt import compute_lddt_ca +from 
esm.utils.structure.normalize_coordinates import ( + apply_frame_to_coords, + get_protein_normalization_frame, + normalize_coordinates, +) + +msgpack_numpy.patch() + +CHAIN_ID_CONST = "A" + + +ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor) +PathLike = Union[str, Path] +PathOrBuffer = Union[PathLike, io.StringIO] + + +def index_by_atom_name( + atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2 +) -> ArrayOrTensor: + squeeze = False + if isinstance(atom_names, str): + atom_names = [atom_names] + squeeze = True + indices = [RC.atom_order[atom_name] for atom_name in atom_names] + dim = dim % atom37.ndim + index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim)) + result = atom37[index] # type: ignore + if squeeze: + result = result.squeeze(dim) + return result + + +def infer_CB(C, N, Ca, L: float = 1.522, A: float = 1.927, D: float = -2.143): + """ + Inspired by a util in trDesign: + https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L92 + + input: 3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral + output: 4th coord + """ + norm = lambda x: x / np.sqrt(np.square(x).sum(-1, keepdims=True) + 1e-8) + with np.errstate(invalid="ignore"): # inf - inf = nan is ok here + vec_bc = N - Ca + vec_ba = N - C + bc = norm(vec_bc) + n = norm(np.cross(vec_ba, bc)) + m = [bc, np.cross(n, bc), n] + d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)] + return Ca + sum([m * d for m, d in zip(m, d)]) + + +class AtomIndexer: + def __init__(self, structure: ProteinChain, property: str, dim: int): + self.structure = structure + self.property = property + self.dim = dim + + def __getitem__(self, atom_names: str | list[str]) -> np.ndarray: + return index_by_atom_name( + getattr(self.structure, self.property), atom_names, self.dim + ) + + +@dataclass +class ProteinChain: + """Dataclass with atom37 representation of a single protein chain.""" + + id: str + sequence: str + 
chain_id: str # author chain id + entity_id: int | None + residue_index: np.ndarray + insertion_code: np.ndarray + atom37_positions: np.ndarray + atom37_mask: np.ndarray + confidence: np.ndarray + + def __post_init__(self): + self.atom37_mask = self.atom37_mask.astype(bool) + assert self.atom37_positions.shape[0] == len(self.sequence), ( + self.atom37_positions.shape, + len(self.sequence), + ) + assert self.atom37_mask.shape[0] == len(self.sequence), ( + self.atom37_mask.shape, + len(self.sequence), + ) + assert self.residue_index.shape[0] == len(self.sequence), ( + self.residue_index.shape, + len(self.sequence), + ) + assert self.insertion_code.shape[0] == len(self.sequence), ( + self.insertion_code.shape, + len(self.sequence), + ) + assert self.confidence.shape[0] == len(self.sequence), ( + self.confidence.shape, + len(self.sequence), + ) + + @cached_property + def atoms(self) -> AtomIndexer: + return AtomIndexer(self, property="atom37_positions", dim=-2) + + @cached_property + def atom_mask(self) -> AtomIndexer: + return AtomIndexer(self, property="atom37_mask", dim=-1) + + @cached_property + def atom_array(self) -> bs.AtomArray: + atoms = [] + for res_name, res_idx, ins_code, positions, mask, conf in zip( + self.sequence, + self.residue_index, + self.insertion_code, + self.atom37_positions, + self.atom37_mask.astype(bool), + self.confidence, + ): + for i, pos in zip(np.where(mask)[0], positions[mask]): + atom = bs.Atom( + coord=pos, + chain_id="A" if self.chain_id is None else self.chain_id, + res_id=res_idx, + ins_code=ins_code, + res_name=RC.restype_1to3.get(res_name, "UNK"), + hetero=False, + atom_name=RC.atom_types[i], + element=RC.atom_types[i][0], + b_factor=conf, + ) + atoms.append(atom) + return bs.array(atoms) + + @cached_property + def residue_index_no_insertions(self) -> np.ndarray: + return self.residue_index + np.cumsum(self.insertion_code != "") + + @cached_property + def atom_array_no_insertions(self) -> bs.AtomArray: + atoms = [] + for res_idx, 
(res_name, positions, mask, conf) in enumerate( + zip( + self.sequence, + self.atom37_positions, + self.atom37_mask.astype(bool), + self.confidence, + ) + ): + for i, pos in zip(np.where(mask)[0], positions[mask]): + atom = bs.Atom( + coord=pos, + # hard coded to as we currently only support single chain structures + chain_id=CHAIN_ID_CONST, + res_id=res_idx + 1, + res_name=RC.restype_1to3.get(res_name, "UNK"), + hetero=False, + atom_name=RC.atom_types[i], + element=RC.atom_types[i][0], + b_factor=conf, + ) + atoms.append(atom) + return bs.array(atoms) + + def __getitem__(self, idx: int | list[int] | slice | np.ndarray): + if isinstance(idx, int): + idx = [idx] + + sequence = slice_python_object_as_numpy(self.sequence, idx) + return replace( + self, + sequence=sequence, + residue_index=self.residue_index[..., idx], + insertion_code=self.insertion_code[..., idx], + atom37_positions=self.atom37_positions[..., idx, :, :], + atom37_mask=self.atom37_mask[..., idx, :], + confidence=self.confidence[..., idx], + ) + + def __len__(self): + return len(self.sequence) + + def cbeta_contacts(self, distance_threshold: float = 8.0) -> np.ndarray: + distance = self.pdist_CB + contacts = (distance < distance_threshold).astype(np.int64) + contacts[np.isnan(distance)] = -1 + contacts = squareform(contacts) + np.fill_diagonal(contacts, -1) + return contacts + + def to_npz(self, path: PathOrBuffer): + f = NpzFile() + f.set_structure(self.atom_array) + f.write(path) + + def to_npz_string(self): + f = NpzFile() + f.set_structure(self.atom_array) + buf = io.BytesIO() + f.write(buf) + return buf.getvalue() + + def to_structure_encoder_inputs( + self, + should_normalize_coordinates: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + coords = torch.tensor(self.atom37_positions, dtype=torch.float32) + plddt = torch.tensor(self.confidence, dtype=torch.float32) + residue_index = torch.tensor(self.residue_index, dtype=torch.long) + + if should_normalize_coordinates: + coords 
= normalize_coordinates(coords) + return coords.unsqueeze(0), plddt.unsqueeze(0), residue_index.unsqueeze(0) + + def to_pdb(self, path: PathOrBuffer, include_insertions: bool = True): + """Dssp works better w/o insertions.""" + f = PDBFile() + if not include_insertions: + f.set_structure(self.atom_array_no_insertions) + else: + f.set_structure(self.atom_array) + f.write(path) + + def to_pdb_string(self, include_insertions: bool = True) -> str: + buf = io.StringIO() + self.to_pdb(buf, include_insertions=include_insertions) + buf.seek(0) + return buf.read() + + def state_dict(self, backbone_only=False): + """This state dict is optimized for storage, so it turns things to fp16 whenever + possible. Note that we also only support int32 residue indices, I'm hoping we don't + need more than 2**32 residues...""" + dct = {k: v for k, v in asdict(self).items()} + for k, v in dct.items(): + if isinstance(v, np.ndarray): + match v.dtype: + case np.int64: + dct[k] = v.astype(np.int32) + case np.float64 | np.float32: + dct[k] = v.astype(np.float16) + case _: + pass + if backbone_only: + dct["atom37_mask"][:, 3:] = False + dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]] + return dct + + def to_blob(self, backbone_only=False) -> bytes: + return brotli.compress(msgpack.dumps(self.state_dict(backbone_only))) + + @classmethod + def from_state_dict(cls, dct): + atom37 = np.full((*dct["atom37_mask"].shape, 3), np.nan) + atom37[dct["atom37_mask"]] = dct["atom37_positions"] + dct["atom37_positions"] = atom37 + dct = { + k: (v.astype(np.float32) if k in ["atom37_positions", "confidence"] else v) + for k, v in dct.items() + } + return cls(**dct) + + @classmethod + def from_blob(cls, input: Path | str | io.BytesIO | bytes): + """NOTE: blob + sparse coding + brotli + fp16 reduces memory + of chains from 52G/1M chains to 20G/1M chains, I think this is a good first + shot at compressing and dumping chains to disk. 
I'm sure there's better ways.""" + match input: + case Path() | str(): + bytes = Path(input).read_bytes() + case io.BytesIO(): + bytes = input.getvalue() + case _: + bytes = input + return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes))) + + def dssp(self): + dssp = DsspApp.annotate_sse(self.atom_array_no_insertions) + full_dssp = np.full(len(self.sequence), "X", dtype=" float | np.ndarray: + """Compute the LDDT between this protein chain and another. + + Arguments: + target (ProteinChain): The other protein chain to compare to. + mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices + target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices + + Returns: + float | np.ndarray: The LDDT score between the two protein chains, either + a single float or per-residue LDDT scores if `per_residue` is True. + """ + + lddt = compute_lddt_ca( + torch.tensor(self.atom37_positions[mobile_inds]).unsqueeze(0), + torch.tensor(target.atom37_positions[target_inds]).unsqueeze(0), + torch.tensor(self.atom37_mask[mobile_inds]).unsqueeze(0), + **kwargs, + ) + return float(lddt) if lddt.numel() == 1 else lddt.numpy().flatten() + + @classmethod + def from_atom37( + cls, + atom37_positions: np.ndarray | torch.Tensor, + *, + id: str | None = None, + sequence: str | None = None, + chain_id: str | None = None, + entity_id: int | None = None, + residue_index: np.ndarray | torch.Tensor | None = None, + insertion_code: np.ndarray | None = None, + confidence: np.ndarray | torch.Tensor | None = None, + ): + if isinstance(atom37_positions, torch.Tensor): + atom37_positions = atom37_positions.cpu().numpy() + if atom37_positions.ndim == 4: + if atom37_positions.shape[0] != 1: + raise ValueError( + f"Cannot handle batched inputs, atom37_positions has shape {atom37_positions.shape}" + ) + atom37_positions = atom37_positions[0] + + assert isinstance(atom37_positions, 
np.ndarray) + seqlen = atom37_positions.shape[0] + + atom_mask = np.isfinite(atom37_positions).all(-1) + + if id is None: + id = "" + + if sequence is None: + sequence = "A" * seqlen + + if chain_id is None: + chain_id = "A" + + if residue_index is None: + residue_index = np.arange(1, seqlen + 1) + elif isinstance(residue_index, torch.Tensor): + residue_index = residue_index.cpu().numpy() + assert isinstance(residue_index, np.ndarray) + if residue_index.ndim == 2: + if residue_index.shape[0] != 1: + raise ValueError( + f"Cannot handle batched inputs, residue_index has shape {residue_index.shape}" + ) + residue_index = residue_index[0] + assert isinstance(residue_index, np.ndarray) + + if insertion_code is None: + insertion_code = np.array(["" for _ in range(seqlen)]) + + if confidence is None: + confidence = np.ones(seqlen, dtype=np.float32) + elif isinstance(confidence, torch.Tensor): + confidence = confidence.cpu().numpy() + assert isinstance(confidence, np.ndarray) + if confidence.ndim == 2: + if confidence.shape[0] != 1: + raise ValueError( + f"Cannot handle batched inputs, confidence has shape {confidence.shape}" + ) + confidence = confidence[0] + assert isinstance(confidence, np.ndarray) + + return cls( + id=id, + sequence=sequence, + chain_id=chain_id, + entity_id=entity_id, + atom37_positions=atom37_positions, + atom37_mask=atom_mask, + residue_index=residue_index, + insertion_code=insertion_code, + confidence=confidence, + ) + + @classmethod + def from_backbone_atom_coordinates( + cls, + backbone_atom_coordinates: np.ndarray | torch.Tensor, + **kwargs, + ): + """Create a ProteinChain from a set of backbone atom coordinates. + + This function simply expands the seqlen x 3 x 3 array of backbone atom + coordinates to a seqlen x 37 x 3 array of all atom coordinates, with the padded + positions set to infinity. This allows us to use from_atom37 to create the + appropriate ProteinChain object with the appropriate atom37_mask. 
+ + This function passes all kwargs to from_atom37. + """ + if isinstance(backbone_atom_coordinates, torch.Tensor): + backbone_atom_coordinates = backbone_atom_coordinates.cpu().numpy() + if backbone_atom_coordinates.ndim == 4: + if backbone_atom_coordinates.shape[0] != 1: + raise ValueError( + f"Cannot handle batched inputs, backbone_atom_coordinates has " + f"shape {backbone_atom_coordinates.shape}" + ) + backbone_atom_coordinates = backbone_atom_coordinates[0] + + assert isinstance(backbone_atom_coordinates, np.ndarray) + assert backbone_atom_coordinates.ndim == 3 + assert backbone_atom_coordinates.shape[-2] == 3 + assert backbone_atom_coordinates.shape[-1] == 3 + + atom37_positions = np.full( + (backbone_atom_coordinates.shape[0], 37, 3), + np.inf, + dtype=backbone_atom_coordinates.dtype, + ) + atom37_positions[:, :3, :] = backbone_atom_coordinates + + return cls.from_atom37( + atom37_positions=atom37_positions, + **kwargs, + ) + + @classmethod + def from_pdb( + cls, + path: PathOrBuffer, + chain_id: str = "detect", + id: str | None = None, + is_predicted: bool = False, + ) -> "ProteinChain": + """Return a ProteinStructure object from an pdb file. + + Args: + path (str | Path | io.TextIO): Path or buffer to read pdb file from. Should be uncompressed. + id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise. + is_predicted (bool): If True, reads b factor as the confidence readout. Default: False. + chain_id (str, optional): Select a chain corresponding to (author) chain id. 
"detect" uses the + first detected chain + """ + + if id is not None: + file_id = id + else: + match path: + case Path() | str(): + file_id = Path(path).with_suffix("").name + case _: + file_id = "null" + + atom_array = PDBFile.read(path).get_structure( + model=1, extra_fields=["b_factor"] + ) + if chain_id == "detect": + chain_id = atom_array.chain_id[0] + atom_array = atom_array[ + bs.filter_amino_acids(atom_array) + & ~atom_array.hetero + & (atom_array.chain_id == chain_id) + ] + + entity_id = 1 # Not supplied in PDBfiles + + sequence = "".join( + ( + r + if len(r := PDBData.protein_letters_3to1.get(monomer[0].res_name, "X")) + == 1 + else "X" + ) + for monomer in bs.residue_iter(atom_array) + ) + num_res = len(sequence) + + atom_positions = np.full( + [num_res, RC.atom_type_num, 3], + np.nan, + dtype=np.float32, + ) + atom_mask = np.full( + [num_res, RC.atom_type_num], + False, + dtype=bool, + ) + residue_index = np.full([num_res], -1, dtype=np.int64) + insertion_code = np.full([num_res], "", dtype=" "ProteinChain": + """A simple converter from bs.AtomArray -> ProteinChain. + Uses PDB file format as intermediate.""" + pdb_file = bs.io.pdb.PDBFile() # pyright: ignore + pdb_file.set_structure(atom_array) + + buf = io.StringIO() + pdb_file.write(buf) + buf.seek(0) + return cls.from_pdb(buf, id=id) + + def get_normalization_frame(self) -> Affine3D: + """Given a set of coordinates, compute a single frame. + Specifically, we compute the average position of the N, CA, and C atoms use those 3 points to construct a frame using the Gram-Schmidt algorithm. The average CA position is used as the origin of the frame. + + Returns: + Affine3D: [] tensor of Affine3D frame + """ + coords = torch.from_numpy(self.atom37_positions) + frame = get_protein_normalization_frame(coords) + + return frame + + def apply_frame(self, frame: Affine3D) -> ProteinChain: + """Given a frame, apply the frame to the protein's coordinates. 
+ + Args: + frame (Affine3D): [] tensor of Affine3D frame + + Returns: + ProteinChain: Transformed protein chain + """ + coords = torch.from_numpy(self.atom37_positions).to(frame.trans.dtype) + coords = apply_frame_to_coords(coords, frame) + atom37_positions = coords.numpy() + return replace(self, atom37_positions=atom37_positions) + + def normalize_coordinates(self) -> ProteinChain: + """Normalize the coordinates of the protein chain.""" + return self.apply_frame(self.get_normalization_frame()) + + def infer_oxygen(self) -> ProteinChain: + """Oxygen position is fixed given N, CA, C atoms. Infer it if not provided.""" + O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32) + N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1) + N = torch.roll(N, -3) + N[..., -1, :] = torch.nan + + # Get the frame defined by the CA-C-N atom + frames = Affine3D.from_graham_schmidt(CA, C, N) + O = frames.apply(O_vector) + atom37_positions = self.atom37_positions.copy() + atom37_mask = self.atom37_mask.copy() + + atom37_positions[:, RC.atom_order["O"]] = O.numpy() + atom37_mask[:, RC.atom_order["O"]] = ~np.isnan( + atom37_positions[:, RC.atom_order["O"]] + ).any(-1) + new_chain = replace( + self, atom37_positions=atom37_positions, atom37_mask=atom37_mask + ) + return new_chain + + @cached_property + def inferred_cbeta(self) -> np.ndarray: + """Infer cbeta positions based on N, C, CA.""" + N, CA, C = np.moveaxis(self.atoms[["N", "CA", "C"]], 1, 0) + # See usage in trDesign codebase. + # https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L140 + CB = infer_CB(C, N, CA, 1.522, 1.927, -2.143) + return CB + + def infer_cbeta(self, infer_cbeta_for_glycine: bool = False) -> ProteinChain: + """Return a new chain with inferred CB atoms at all residues except GLY. + + Args: + infer_cbeta_for_glycine (bool): If True, infers a beta carbon for glycine + residues, even though that residue doesn't have one. 
Default off. + + NOTE: The reason for having this switch in the first place + is that sometimes we want a (inferred) CB coordinate for every residue, + for example for making a pairwise distance matrix, or doing an RMSD + calculation between two designs for a given structural template, w/ + CB atoms. + """ + atom37_positions = self.atom37_positions.copy() + atom37_mask = self.atom37_mask.copy() + + inferred_cbeta_positions = self.inferred_cbeta.copy() # copy so the cached property is not mutated in place + if not infer_cbeta_for_glycine: + inferred_cbeta_positions[np.array(list(self.sequence)) == "G", :] = np.nan + + atom37_positions[:, RC.atom_order["CB"]] = inferred_cbeta_positions + atom37_mask[:, RC.atom_order["CB"]] = ~np.isnan( + atom37_positions[:, RC.atom_order["CB"]] + ).any(-1) + new_chain = replace( + self, atom37_positions=atom37_positions, atom37_mask=atom37_mask + ) + return new_chain + + @cached_property + def pdist_CA(self) -> np.ndarray: + CA = self.atoms["CA"] + pdist_CA = squareform(pdist(CA)) + return pdist_CA + + @cached_property + def pdist_CB(self) -> np.ndarray: + pdist_CB = squareform(pdist(self.inferred_cbeta)) + return pdist_CB + + @classmethod + def as_complex(cls, chains: Sequence[ProteinChain]): + raise RuntimeError( + ".as_complex() has been deprecated in favor of .concat(). " + ".concat() will eventually be deprecated in favor of ProteinComplex..." 
+ ) + + @classmethod + def concat(cls, chains: Sequence[ProteinChain]): + def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray): + full_array = [] + for array in arrays: + full_array.append(array) + full_array.append(sep) + full_array = full_array[:-1] + return np.concatenate(full_array, 0) + + sep_tokens = { + "residue_index": np.array([-1]), + "insertion_code": np.array([""]), + "atom37_positions": np.full([1, 37, 3], np.inf), + "atom37_mask": np.zeros([1, 37]), + "confidence": np.array([0]), + } + + array_args: dict[str, np.ndarray] = { + name: join_arrays([getattr(chain, name) for chain in chains], sep) + for name, sep in sep_tokens.items() + } + + return cls( + id=chains[0].id, + sequence=C.CHAIN_BREAK_STR.join(chain.sequence for chain in chains), + chain_id="A", + entity_id=None, + **array_args, + ) + + def select_residue_indices( + self, indices: list[int | str], ignore_x_mismatch: bool = False + ) -> ProteinChain: + numeric_indices = [ + idx if isinstance(idx, int) else int(idx[1:]) for idx in indices + ] + mask = np.isin(self.residue_index, numeric_indices) + new = self[mask] + mismatches = [] + for aa, idx in zip(new.sequence, indices): + if isinstance(idx, int): + continue + if aa == "X" and ignore_x_mismatch: + continue + if aa != idx[0]: + mismatches.append((aa, idx)) + if mismatches: + mismatch_str = "; ".join( + f"Position {idx[1:]}, Expected: {idx[0]}, Received: {aa}" + for aa, idx in mismatches + ) + raise RuntimeError(mismatch_str) + + return new diff --git a/esm/utils/structure/protein_structure.py b/esm/utils/structure/protein_structure.py new file mode 100644 index 0000000..6bea8ad --- /dev/null +++ b/esm/utils/structure/protein_structure.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +from typing import Tuple, TypeVar + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.cuda.amp import autocast # type: ignore + +from esm.utils import residue_constants +from 
esm.utils.misc import unbinpack +from esm.utils.structure.affine3d import Affine3D + +ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor) + + +def index_by_atom_name( + atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2 +) -> ArrayOrTensor: + squeeze = False + if isinstance(atom_names, str): + atom_names = [atom_names] + squeeze = True + indices = [residue_constants.atom_order[atom_name] for atom_name in atom_names] + dim = dim % atom37.ndim + index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim)) + result = atom37[index] # type: ignore + if squeeze: + result = result.squeeze(dim) + return result + + +def infer_cbeta_from_atom37( + atom37: ArrayOrTensor, L: float = 1.522, A: float = 1.927, D: float = -2.143 +): + """ + Inspired by a util in trDesign: + https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L92 + + input: atom37, (L)ength, (A)ngle, and (D)ihedral + output: 4th coord + """ + N = index_by_atom_name(atom37, "N", dim=-2) + CA = index_by_atom_name(atom37, "CA", dim=-2) + C = index_by_atom_name(atom37, "C", dim=-2) + + if isinstance(atom37, np.ndarray): + + def normalize(x: ArrayOrTensor): + return x / np.linalg.norm(x, axis=-1, keepdims=True) + + cross = np.cross + else: + normalize = F.normalize # type: ignore + cross = torch.cross + + with np.errstate(invalid="ignore"): # inf - inf = nan is ok here + vec_nca = N - CA + vec_nc = N - C + nca = normalize(vec_nca) + n = normalize(cross(vec_nc, nca)) # type: ignore + m = [nca, cross(n, nca), n] + d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)] + return CA + sum([m * d for m, d in zip(m, d)]) + + +@torch.no_grad() +@autocast(enabled=False) +def compute_alignment_tensors( + mobile: torch.Tensor, + target: torch.Tensor, + atom_exists_mask: torch.Tensor | None = None, + sequence_id: torch.Tensor | None = None, +): + """ + Align two batches of structures with support for masking invalid atoms using 
PyTorch. + + Args: + - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3) + - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3) + - atom_exists_mask (torch.Tensor, optional): Mask for whether an atom exists of shape (B, N) + - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking. + + Returns: + - centered_mobile (torch.Tensor): Batch of coordinates of structure centered mobile (B, N, 3) + - centroid_mobile (torch.Tensor): Batch of coordinates of mobile centroid (B, 3) + - centered_target (torch.Tensor): Batch of coordinates of structure centered target (B, N, 3) + - centroid_target (torch.Tensor): Batch of coordinates of target centroid (B, 3) + - rotation_matrix (torch.Tensor): Batch of rotation matrices (B, 3, 3) + - num_valid_atoms (torch.Tensor): Batch of number of valid atoms for alignment (B,) + """ + + # Ensure both batches have the same number of structures, atoms, and dimensions + if sequence_id is not None: + mobile = unbinpack(mobile, sequence_id, pad_value=torch.nan) + target = unbinpack(target, sequence_id, pad_value=torch.nan) + if atom_exists_mask is not None: + atom_exists_mask = unbinpack(atom_exists_mask, sequence_id, pad_value=0) + else: + atom_exists_mask = torch.isfinite(target).all(-1) + + assert mobile.shape == target.shape, "Batch structure shapes do not match!" 
+ + # Number of structures in the batch + batch_size = mobile.shape[0] + + # if [B, Nres, Natom, 3], resize + if mobile.dim() == 4: + mobile = mobile.view(batch_size, -1, 3) + if target.dim() == 4: + target = target.view(batch_size, -1, 3) + if atom_exists_mask is not None and atom_exists_mask.dim() == 3: + atom_exists_mask = atom_exists_mask.view(batch_size, -1) + + # Number of atoms + num_atoms = mobile.shape[1] + + # Apply masks if provided + if atom_exists_mask is not None: + mobile = mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0) + target = target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0) + else: + atom_exists_mask = torch.ones( + batch_size, num_atoms, dtype=torch.bool, device=mobile.device + ) + + num_valid_atoms = atom_exists_mask.sum(dim=-1, keepdim=True) + # Compute centroids for each batch + centroid_mobile = mobile.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1) + centroid_target = target.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1) + + # Handle potential division by zero if all atoms are invalid in a structure + centroid_mobile[num_valid_atoms == 0] = 0 + centroid_target[num_valid_atoms == 0] = 0 + + # Center structures by subtracting centroids + centered_mobile = mobile - centroid_mobile + centered_target = target - centroid_target + + centered_mobile = centered_mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0) + centered_target = centered_target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0) + + # Compute covariance matrix for each batch + covariance_matrix = torch.matmul(centered_mobile.transpose(1, 2), centered_target) + + # Singular Value Decomposition for each batch + u, _, v = torch.svd(covariance_matrix) + + # Calculate rotation matrices for each batch + rotation_matrix = torch.matmul(u, v.transpose(1, 2)) + + return ( + centered_mobile, + centroid_mobile, + centered_target, + centroid_target, + rotation_matrix, + num_valid_atoms, + ) + + +@torch.no_grad() +@autocast(enabled=False) +def 
compute_rmsd_no_alignment( + aligned: torch.Tensor, + target: torch.Tensor, + num_valid_atoms: torch.Tensor, + reduction: str = "batch", +) -> torch.Tensor: + """ + Compute RMSD between two batches of structures without alignment. + + Args: + - aligned (torch.Tensor): Batch of coordinates of the already-aligned mobile structure in shape (B, N, 3) + - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3) + - num_valid_atoms (torch.Tensor): Batch of number of valid atoms for alignment (B,) + - reduction (str): One of "batch", "per_sample", "per_residue". + + Returns: + + If reduction == "batch": + (torch.Tensor): 0-dim, Average Root Mean Square Deviation between the structures over the batch + If reduction == "per_sample": + (torch.Tensor): (B,)-dim, Root Mean Square Deviation between the structures for each sample in the batch + If reduction == "per_residue": + (torch.Tensor): (B, N)-dim, Root Mean Square Deviation between the structures for each residue in the batch + """ + if reduction not in ("per_residue", "per_sample", "batch"): + raise ValueError(f"Unrecognized reduction: '{reduction}'") + # Compute RMSD for each batch + diff = aligned - target + if reduction == "per_residue": + mean_squared_error = diff.square().view(diff.size(0), -1, 9).mean(dim=-1) + else: + mean_squared_error = diff.square().sum(dim=(1, 2)) / ( + num_valid_atoms.squeeze(-1) * 3 + ) + + rmsd = torch.sqrt(mean_squared_error) + if reduction in ("per_sample", "per_residue"): + return rmsd + elif reduction == "batch": + avg_rmsd = rmsd.masked_fill(num_valid_atoms.squeeze(-1) == 0, 0).sum() / ( + (num_valid_atoms > 0).sum() + 1e-8 + ) + return avg_rmsd + else: + raise ValueError(reduction) + + +@torch.no_grad() +@autocast(enabled=False) +def compute_affine_and_rmsd( + mobile: torch.Tensor, + target: torch.Tensor, + atom_exists_mask: torch.Tensor | None = None, + sequence_id: torch.Tensor | None = None, +) -> Tuple[Affine3D, torch.Tensor]: + """ + Compute RMSD between 
two batches of 
structures with support for masking invalid atoms using PyTorch. + + Args: + - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3) + - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3) + - atom_exists_mask (torch.Tensor, optional): Mask for Whether an atom exists of shape (B, N) + - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking. + + Returns: + - affine (Affine3D): Transformation between mobile and target structure + - avg_rmsd (torch.Tensor): Average Root Mean Square Deviation between the structures for each batch + """ + + ( + centered_mobile, + centroid_mobile, + centered_target, + centroid_target, + rotation_matrix, + num_valid_atoms, + ) = compute_alignment_tensors( + mobile=mobile, + target=target, + atom_exists_mask=atom_exists_mask, + sequence_id=sequence_id, + ) + + # Apply rotation to mobile centroid + translation = torch.matmul(-centroid_mobile, rotation_matrix) + centroid_target + affine = Affine3D.from_tensor_pair( + translation, rotation_matrix.unsqueeze(dim=-3).transpose(-2, -1) + ) + + # Apply transformation to centered structure to compute rmsd + rotated_mobile = torch.matmul(centered_mobile, rotation_matrix) + avg_rmsd = compute_rmsd_no_alignment( + rotated_mobile, + centered_target, + num_valid_atoms, + reduction="batch", + ) + + return affine, avg_rmsd diff --git a/esm/utils/types.py b/esm/utils/types.py new file mode 100644 index 0000000..96ba329 --- /dev/null +++ b/esm/utils/types.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import io +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +PathLike = Union[str, Path] +PathOrBuffer = Union[PathLike, io.StringIO] + + +@dataclass +class FunctionAnnotation: + """Represents an annotation of a protein's function over a range of residues. 
+ + Fields: + label (str): An entry in either the function_tokens or residue_annotations tokenizer vocabs + start (int): Start index of this annotation. 1-indexed, inclusive. + end (int): End index of this annotation. 1-indexed, inclusive. + """ + + label: str + start: int + end: int + + def to_tuple(self) -> tuple[str, int, int]: + return self.label, self.start, self.end diff --git a/examples/generate.ipynb b/examples/generate.ipynb new file mode 100644 index 0000000..399f2f7 --- /dev/null +++ b/examples/generate.ipynb @@ -0,0 +1,554 @@ +{ + "cells": [ + { + "attachments": { + "image.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ESM3\n", + "ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.\n", + "\n", + "ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.\n", + "\n", + "\n", + "\n", + "The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters.\n", + "Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and fastest model in the family, trained specifically to be open sourced. ESM3-open is available under a non-commercial license." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%set_env TOKENIZERS_PARALLELISM=false\n", + "!pip install esm\n", + "import numpy as np\n", + "import torch\n", + "!pip install py3Dmol\n", + "import py3Dmol\n", + "from huggingface_hub import login\n", + "\n", + "from esm.utils.structure.protein_chain import ProteinChain\n", + "from esm.models.esm3 import ESM3\n", + "from esm.sdk.api import (\n", + " ESMProtein,\n", + " GenerationConfig,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load `esm3-open-small` on GPU" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "login(token=\"YOUR_TOKEN_HERE\")\n", + "model = ESM3.from_pretrained(\"esm3_sm_open_v1\", device=torch.device(\"cuda\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Let's construct a prompt for ESM3, focusing on the task of scaffolding a motif from a natural protein" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we can use the `ProteinChain` class from the `esm` SDK to grab a protein structure from the PDB" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pdb_id = \"1ITU\" # PDB ID corresponding to Renal Dipeptidase\n", + "chain_id = \"A\" # Chain ID corresponding to Renal Dipeptidase in the PDB structure\n", + "renal_dipep_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n", + "# Alternatively, we could have used ProteinChain.from_pdb() to load a protein structure from a local PDB file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `ProteinChain` class is an object that makes it easy to work with protein structures. 
It contains a `sequence` attribute that contains the amino acid sequence of the protein\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(renal_dipep_chain.sequence)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`ProteinChain` also contains an `atom37_positions` numpy array that contains the atomic coordinates of each of the residues in the protein. \n", + "\n", + "The shape of the array is `(n_residues, 37, 3)` where `n_residues` is the number of residues in the protein and 37 is the number of possible distinct atoms that may be present across all amino acids (e.g. the first three atoms are the N, C-alpha, and C atoms corresponding to the protein backbone). The 3 corresponds to the x, y, and z coordinates of each atom. The atom37 representation of protein structure allows us to use a single format to conveniently represent all amino acids -- **coordinates are only present for the atoms that are present in the amino acid and `nan` otherwise**." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"atom37_positions shape: \", renal_dipep_chain.atom37_positions.shape)\n", + "print(renal_dipep_chain.atom37_positions[:3])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can visualize the protein chain using the `py3Dmol` library" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# First we can create a `py3Dmol` view object\n", + "view = py3Dmol.view(width=500, height=500)\n", + "# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string\n", + "pdb_str = renal_dipep_chain.to_pdb_string()\n", + "# Load the PDB string into the `py3Dmol` view object\n", + "view.addModel(pdb_str, \"pdb\")\n", + "# Set the style of the protein chain\n", + "view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n", + "# Zoom in on the protein chain\n", + "view.zoomTo()\n", + "# Display the protein chain\n", + "view.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, let's try to scaffold a motif from this protein using ESM3 -- we'll prompt the model with the sequence and structure of a helix-coil motif from renal dipeptidase and have the model generate a larger scaffold that includes the motif" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "motif_inds = np.arange(123, 146)\n", + "# `ProteinChain` objects can be indexed like numpy arrays to extract the sequence and atomic coordinates of a subset of residues\n", + "motif_sequence = renal_dipep_chain[motif_inds].sequence\n", + "motif_atom37_positions = renal_dipep_chain[motif_inds].atom37_positions\n", + "print(\"Motif sequence: \", motif_sequence)\n", + "print(\"Motif atom37_positions shape: \", motif_atom37_positions.shape)" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "We can also visualize the motif in the original chain using `py3Dmol`. We'll color the original chain in grey and the motif in blue" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "view = py3Dmol.view(width=500, height=500)\n", + "view.addModel(pdb_str, \"pdb\")\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", + "motif_res_inds = (motif_inds + 1).tolist() # residue indices are 1-indexed in PDB files, so we add 1 to the indices\n", + "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}})\n", + "view.zoomTo()\n", + "view.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can use the `ESMProtein` class to construct a prompt that will instruct ESM3 to scaffold the motif" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompt_length = 200\n", + "# First, we can construct a sequence prompt of all masks\n", + "sequence_prompt = [\"_\"]*prompt_length\n", + "# Then, we can randomly insert the motif sequence into the prompt (we randomly choose 72 here)\n", + "sequence_prompt[72:72+len(motif_sequence)] = list(motif_sequence)\n", + "sequence_prompt = \"\".join(sequence_prompt)\n", + "print(\"Sequence prompt: \", sequence_prompt)\n", + "print(\"Length of sequence prompt: \", len(sequence_prompt))\n", + "\n", + "# Next, we can construct a structure prompt of all nan coordinates\n", + "structure_prompt = torch.full((prompt_length, 37, 3), np.nan)\n", + "# Then, we can insert the motif atomic coordinates into the prompt, starting at index 72\n", + "structure_prompt[72:72+len(motif_atom37_positions)] = torch.tensor(motif_atom37_positions)\n", + "print(\"Structure prompt shape: \", structure_prompt.shape)\n", + "print(\"Indices with structure conditioning: \", 
torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist())\n", + "\n", + "# Finally, we can use the ESMProtein class to compose the sequence and structure prompts into a single prompt that can be passed to ESM3\n", + "protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can use the `generate` method of the model to iteratively sample a protein sequence based on the prompt. Under the hood, the model performs num_steps forward passes and samples a set of tokens at each step until the chosen track being generated is fully unmasked. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use\n", + "sequence_generation_config = GenerationConfig(\n", + " track=\"sequence\", # We want ESM3 to generate tokens for the sequence track\n", + " num_steps=sequence_prompt.count(\"_\") // 2, # We'll use num(mask tokens) // 2 steps to decode the sequence\n", + " temperature=0.5, # We'll use a temperature of 0.5 to control the randomness of the decoding process\n", + ")\n", + "\n", + "# Now, we can use the `generate` method of the model to decode the sequence\n", + "sequence_generation = model.generate(protein_prompt, sequence_generation_config)\n", + "print(\"Sequence Prompt:\\n\\t\", protein_prompt.sequence)\n", + "print(\"Generated sequence:\\n\\t\", sequence_generation.sequence)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also use the `generate` method to predict the structure of the generated sequence by iteratively sampling structure tokens." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "structure_prediction_config = GenerationConfig(\n", + " track=\"structure\", # We want ESM3 to generate tokens for the structure track\n", + " num_steps=len(sequence_generation) // 8,\n", + " temperature=0.7, \n", + ")\n", + "structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)\n", + "structure_prediction = model.generate(structure_prediction_prompt, structure_prediction_config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right, green) alongside the original structure (left, grey) from which the motif was drawn. The motif residues are colored in cyan." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Convert the generated structure back into a ProteinChain object\n", + "structure_prediction_chain = structure_prediction.to_protein_chain()\n", + "# Align the generated structure to the original structure using the motif residues (align returns the aligned chain, so we keep the result)\n", + "motif_inds_in_generation = np.arange(72, 72+len(motif_sequence))\n", + "structure_prediction_chain = structure_prediction_chain.align(renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds)\n", + "crmsd = structure_prediction_chain.rmsd(renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds)\n", + "print(\"cRMSD of the motif in the generated structure vs the original structure: \", crmsd)\n", + "\n", + "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n", + "view.addModel(pdb_str, \"pdb\", viewer=(0, 0))\n", + "view.addModel(structure_prediction_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}}, viewer=(0, 0))\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgreen\"}}, viewer=(0, 1))\n", + 
"view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}}, viewer=(0, 0))\n", + "view.addStyle({\"resi\": (motif_inds_in_generation+1).tolist()}, {\"cartoon\": {\"color\": \"cyan\"}}, viewer=(0, 1))\n", + "view.zoomTo()\n", + "view.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Secondary Structure Editing Example: Helix Shortening" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can try another generation task with ESM3. We'll use the secondary structure track, along with the sequence track, to shorten a helix-coil-helix region (residues 39-111) in a protein structure (colored in blue below)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "helix_shortening_chain = ProteinChain.from_rcsb(\"7XBQ\", \"A\")\n", + "view = py3Dmol.view(width=500, height=500)\n", + "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\")\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", + "helix_region = np.arange(38, 111) # zero-indexed\n", + "view.addStyle({\"resi\": (helix_region + 1).tolist()}, {\"cartoon\": {\"color\":\"lightblue\"}})\n", + "view.zoomTo()\n", + "view.show()\n", + "helix_shortening_ss8 = \"CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC\"\n", + "print(\"Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) \\n\\t\", helix_shortening_ss8)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The helix-coil-helix region in the original protein is 73 residues long. 
We will try to shorten it to 45 residues by prompting the model with partial sequence and secondary structure" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "shortened_region_length = 45\n", + "\n", + "# We'll construct a sequence prompt that masks the (shortened) helix-coil-helix region, but leaves the flanking regions unmasked\n", + "sequence_prompt = helix_shortening_chain.sequence[:helix_region[0]] + \"_\" * shortened_region_length + helix_shortening_chain.sequence[helix_region[-1] + 1:]\n", + "print(\"Sequence prompt:\\n\\t\", sequence_prompt)\n", + "\n", + "# We'll construct a secondary structure prompt that retains the secondary structure of the flanking regions, and shortens the lengths of helices in the helix-coil-helix region\n", + "ss8_prompt = helix_shortening_ss8[:helix_region[0]] + (((shortened_region_length - 3) // 2) * \"H\" + \"C\"*3 + ((shortened_region_length - 3) // 2) * \"H\") + helix_shortening_ss8[helix_region[-1] + 1:]\n", + "print(\"SS8 prompt:\\n\\t\", ss8_prompt)\n", + "print(\"Proposed SS8 for shortened helix-coil-helix region:\\n\\t\", \" \"*helix_region[0] + ss8_prompt[helix_region[0]:helix_region[0]+45])\n", + "\n", + "print(\"\")\n", + "print(\"Original sequence:\\n\\t\", helix_shortening_chain.sequence)\n", + "print(\"Original SS8:\\n\\t\", helix_shortening_ss8)\n", + "print(\"Original SS8 for helix-coil-helix region:\\n\\t\", \" \"*helix_region[0] + helix_shortening_ss8[helix_region[0]:helix_region[-1]+1])\n", + "\n", + "\n", + "# We can again use the ESMProtein class to compose the sequence and secondary structure prompts into a single prompt that can be passed to ESM3\n", + "protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can again use the `generate` method of the model to iteratively decode a protein sequence based on the prompt" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Generating protein sequence...\")\n", + "sequence_generation = model.generate(protein_prompt, GenerationConfig(track=\"sequence\", num_steps=protein_prompt.sequence.count(\"_\") // 2, temperature=0.5))\n", + "print(\"Folding protein...\")\n", + "structure_prediction = model.generate(ESMProtein(sequence=sequence_generation.sequence), GenerationConfig(track=\"structure\", num_steps=len(protein_prompt) // 4, temperature=0))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right) alongside the original structure (left) from which the motif was drawn. The helix-coil-helix region in the original structure is colored in blue and the shortened region in the generated structure is colored in pink." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predicted_chain = structure_prediction.to_protein_chain()\n", + "predicted_chain = predicted_chain.align(helix_shortening_chain, mobile_inds=np.arange(len(predicted_chain) - 120, len(predicted_chain)), target_inds=np.arange(len(helix_shortening_chain) - 120, len(helix_shortening_chain)))\n", + "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n", + "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n", + "view.addModel(predicted_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", + "view.addStyle({\"resi\": (helix_region + 1).tolist()}, {\"cartoon\": {\"color\":\"lightblue\"}},viewer=(0, 0))\n", + "view.addStyle({\"resi\": (np.arange(helix_region[0], helix_region[0] + 45) + 1).tolist()}, {\"cartoon\": {\"color\":\"pink\"}},viewer=(0, 1))\n", + "view.zoomTo()\n", + "view.show()" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "# SASA Editing Example: Exposing a buried helix" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's grab 1LBS from the PDB and visualize it using `py3Dmol`. 1LBS has an alternating alpha-beta sandwich fold, with a buried helix in the center, highlighted in red" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lipase_chain = ProteinChain.from_rcsb(\"1LBS\", \"A\")\n", + "span_start = 105\n", + "span_end = 116\n", + "view = py3Dmol.view(width=500, height=500)\n", + "view.addModel(lipase_chain.to_pdb_string(), \"pdb\")\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", + "view.addStyle({\"resi\": (np.arange(span_start, span_end) + 1).tolist()}, {\"cartoon\": {\"color\":\"red\"}})\n", + "view.zoomTo()\n", + "view.show()\n", + "lipase_ss8 = \"CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can construct a multimodal prompt for ESM3 to instruct it to expose the buried helix as follows:\n", + "1. Prompt with the **structure** of the buried helix highlighted in red -- this will prompt ESM3 to generate a protein that contains that same helix\n", + "2. 
Prompt with high **SASA** values for the residues in the buried helix -- this will prompt ESM3 to expose the helix to the surface of the protein" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "structure_prompt = torch.full((len(lipase_chain), 37, 3), torch.nan)\n", + "structure_prompt[span_start:span_end] = torch.tensor(lipase_chain[span_start:span_end].atom37_positions, dtype=torch.float32) \n", + "\n", + "sasa_prompt = [None]*len(lipase_chain)\n", + "sasa_prompt[span_start:span_end] = [40.0]*(span_end - span_start)\n", + "\n", + "print(\"SASA prompt (just for buried region): \", sasa_prompt[span_start:span_end])\n", + "\n", + "protein_prompt = ESMProtein(sequence=\"_\"*len(lipase_chain), coordinates=structure_prompt, sasa=sasa_prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is a more difficult task, so you may need to sample more generations from ESM3 before you find a solution. We'll sample 32 here and sort the generations by ESM3's predicted TM-score (pTM), highest first. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "generated_proteins = []\n", + "N_SAMPLES = 32\n", + "for i in range(N_SAMPLES):\n", + " print(\"Generating protein sequence...\")\n", + " sequence_generation = model.generate(protein_prompt, GenerationConfig(track=\"sequence\", num_steps=len(protein_prompt) // 8, temperature=0.7))\n", + " print(\"Folding protein...\")\n", + " structure_prediction = model.generate(ESMProtein(sequence=sequence_generation.sequence), GenerationConfig(track=\"structure\", num_steps=len(protein_prompt) // 32))\n", + " generated_proteins.append(structure_prediction)\n", + "\n", + "# Sort generations by ptm\n", + "generated_proteins = sorted(generated_proteins, key=lambda x: x.ptm.item(), reverse=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's visualize the top 4 generations by pTM, alongside with the original protein (on the left)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "N_SAMPLES_TO_SHOW = 4\n", + "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, N_SAMPLES_TO_SHOW+1))\n", + "view.addModel(lipase_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n", + "for i in range(N_SAMPLES_TO_SHOW):\n", + " print(\"PTM of generated protein {}: {:.2f}\".format(i+1, generated_proteins[i].ptm.item()))\n", + " view.addModel(generated_proteins[i].to_protein_chain().to_pdb_string(), \"pdb\", viewer=(0, i+1))\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", + "view.addStyle({\"resi\": (np.arange(span_start, span_end) + 1).tolist()}, {\"cartoon\": {\"color\": \"red\"}})\n", + "view.zoomTo()\n", + "view.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", 
+ "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/invfold.py b/examples/invfold.py new file mode 100644 index 0000000..6fa90d0 --- /dev/null +++ b/examples/invfold.py @@ -0,0 +1,37 @@ +import torch +import torch.nn.functional as F + +from esm.pretrained import ( + ESM3_sm_open_v0, + ESM3_structure_encoder_v0, +) +from esm.tokenization.sequence_tokenizer import ( + EsmSequenceTokenizer, +) +from esm.utils.structure.protein_chain import ProteinChain + +if __name__ == "__main__": + tokenizer = EsmSequenceTokenizer() + encoder = ESM3_structure_encoder_v0("cuda") + model = ESM3_sm_open_v0("cuda") + + chain = ProteinChain.from_pdb("esm/data/1utn.pdb") + coords, plddt, residue_index = chain.to_structure_encoder_inputs() + coords = coords.cuda() + plddt = plddt.cuda() + residue_index = residue_index.cuda() + _, structure_tokens = encoder.encode(coords, residue_index=residue_index) + + # Add BOS/EOS padding + coords = F.pad(coords, (0, 0, 0, 0, 1, 1), value=torch.inf) + plddt = F.pad(plddt, (1, 1), value=0) + structure_tokens = F.pad(structure_tokens, (1, 1), value=0) + structure_tokens[:, 0] = 4098 + structure_tokens[:, -1] = 4097 + + output = model.forward( + structure_coords=coords, per_res_plddt=plddt, structure_tokens=structure_tokens + ) + sequence_tokens = torch.argmax(output.sequence_logits, dim=-1) + sequence = tokenizer.decode(sequence_tokens[0]) + print(sequence) diff --git a/examples/local_client.py b/examples/local_client.py new file mode 100644 index 0000000..e963317 --- /dev/null +++ b/examples/local_client.py @@ -0,0 +1,76 @@ +from esm.models.esm3 import ESM3 +from esm.sdk.api import ( + ESM3InferenceClient, + ESMProtein, + GenerationConfig, + SamplingConfig, + SamplingTrackConfig, +) +from esm.utils.structure.protein_chain import ProteinChain +from esm.utils.types import FunctionAnnotation + + 
+def get_sample_protein() -> ESMProtein: + protein = ProteinChain.from_rcsb("1utn") + protein = ESMProtein.from_protein_chain(protein) + protein.function_annotations = [ + # Peptidase S1A, chymotrypsin family: https://www.ebi.ac.uk/interpro/structure/PDB/1utn/ + FunctionAnnotation(label="peptidase", start=100, end=114), + FunctionAnnotation(label="chymotrypsin", start=190, end=202), + ] + return protein + + +def main(client: ESM3InferenceClient): + # Single step decoding + protein = get_sample_protein() + protein.function_annotations = None + protein = client.encode(protein) + single_step_protein = client.forward_and_sample( + protein, + SamplingConfig(structure=SamplingTrackConfig(topk_logprobs=2)), + ) + single_step_protein.protein_tensor.sequence = protein.sequence + single_step_protein = client.decode(single_step_protein.protein_tensor) + + # Folding + protein = get_sample_protein() + sequence_length = len(protein.sequence) # type: ignore + num_steps = int(sequence_length / 16) + protein.coordinates = None + protein.function_annotations = None + protein.sasa = None + folded_protein = client.generate( + protein, + GenerationConfig(track="structure", schedule="cosine", num_steps=num_steps), + ) + folded_protein.to_pdb("./sample_folded.pdb") + + # Inverse Folding + protein = get_sample_protein() + protein.sequence = None + protein.sasa = None + protein.function_annotations = None + inv_folded_protein = client.generate( + protein, + GenerationConfig(track="sequence", schedule="cosine", num_steps=num_steps), + ) + inv_folded_protein.to_pdb("./sample_inv_folded.pdb") + + # Chain of Thought (Function -> Secondary Structure -> Structure -> Sequence) + cot_protein = get_sample_protein() + cot_protein.sequence = "_" * len(cot_protein.sequence) # type: ignore + cot_protein.coordinates = None + cot_protein.sasa = None + cot_protein_tensor = client.encode(cot_protein) + for cot_track in ["secondary_structure", "structure", "sequence"]: + cot_protein_tensor = 
client.generate( + cot_protein_tensor, + GenerationConfig(track=cot_track, schedule="cosine", num_steps=10), + ) + cot_protein = client.decode(cot_protein_tensor) + cot_protein.to_pdb("./sample_cot.pdb") + + +if __name__ == "__main__": + main(ESM3.from_pretrained("esm3_sm_open_v1")) diff --git a/examples/seqfun_struct.py b/examples/seqfun_struct.py new file mode 100644 index 0000000..90b8bac --- /dev/null +++ b/examples/seqfun_struct.py @@ -0,0 +1,121 @@ +import random + +import torch +import torch.nn.functional as F + +from esm.pretrained import ( + ESM3_function_decoder_v0, + ESM3_sm_open_v0, + ESM3_structure_decoder_v0, +) +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer as EsmFunctionTokenizer, +) +from esm.tokenization.sequence_tokenizer import ( + EsmSequenceTokenizer, +) +from esm.utils.constants.esm3 import ( + SEQUENCE_MASK_TOKEN, +) +from esm.utils.structure.protein_chain import ProteinChain +from esm.utils.types import FunctionAnnotation + + +@torch.no_grad() +def main(): + tokenizer = EsmSequenceTokenizer() + function_tokenizer = EsmFunctionTokenizer() + + model = ESM3_sm_open_v0("cuda") + + # PDB 1UTN + sequence = "MKTFIFLALLGAAVAFPVDDDDKIVGGYTCGANTVPYQVSLNSGYHFCGGSLINSQWVVSAAHCYKSGIQVRLGEDNINVVEGNEQFISASKSIVHPSYNSNTLNNDIMLIKLKSAASLNSRVASISLPTSCASAGTQCLISGWGNTKSSGTSYPDVLKCLKAPILSDSSCKSAYPGQITSNMFCAGYLEGGKDSCQGDSGGPVVCSGKLQGIVSWGSGCAQKNKPGVYTKVCNYVSWIKQTIASN" + tokens = tokenizer.encode(sequence) + + # Calculate the number of tokens to replace, excluding the first and last token + num_to_replace = int((len(tokens) - 2) * 0.75) + + # Randomly select indices to replace, excluding the first and last index + indices_to_replace = random.sample(range(1, len(tokens) - 1), num_to_replace) + + # Replace selected indices with 32 + for idx in indices_to_replace: + tokens[idx] = SEQUENCE_MASK_TOKEN + sequence_tokens = torch.tensor(tokens, dtype=torch.int64) + + function_annotations = [ + # Peptidase S1A, chymotrypsin family + 
FunctionAnnotation(label="peptidase", start=100, end=114), + FunctionAnnotation(label="chymotrypsin", start=190, end=202), + ] + function_tokens = function_tokenizer.tokenize(function_annotations, len(sequence)) + function_tokens = function_tokenizer.encode(function_tokens) + + function_tokens = function_tokens.cuda().unsqueeze(0) + sequence_tokens = sequence_tokens.cuda().unsqueeze(0) + + output = model.forward( + sequence_tokens=sequence_tokens, function_tokens=function_tokens + ) + return sequence, output, sequence_tokens + + +@torch.no_grad() +def decode(sequence, output, sequence_tokens): + # To save on VRAM, we load these in separate functions + decoder = ESM3_structure_decoder_v0("cuda") + function_decoder = ESM3_function_decoder_v0("cuda") + function_tokenizer = EsmFunctionTokenizer() + + structure_tokens = torch.argmax(output.structure_logits, dim=-1) + structure_tokens = ( + structure_tokens.where(sequence_tokens != 0, 4098) # BOS + .where(sequence_tokens != 2, 4097) # EOS + .where(sequence_tokens != 31, 4100) # Chainbreak + ) + + bb_coords = ( + decoder.decode( + structure_tokens, + torch.ones_like(sequence_tokens), + torch.zeros_like(sequence_tokens), + )["bb_pred"] + .detach() + .cpu() + ) + + chain = ProteinChain.from_backbone_atom_coordinates( + bb_coords, sequence="X" + sequence + "X" + ) + chain.infer_oxygen().to_pdb("hello.pdb") + + # Function prediction + p_none_threshold = 0.05 + log_p = F.log_softmax(output.function_logits[:, 1:-1, :], dim=3).squeeze(0) + + # Choose which positions have no predicted function. 
+ log_p_nones = log_p[:, :, function_tokenizer.vocab_to_index["<none>"]] + p_none = torch.exp(log_p_nones).mean(dim=1) # "Ensemble of predictions" + where_none = p_none > p_none_threshold # (length,) + + log_p[~where_none, :, function_tokenizer.vocab_to_index["<none>"]] = -torch.inf + function_token_ids = torch.argmax(log_p, dim=2) + function_token_ids[where_none, :] = function_tokenizer.vocab_to_index["<none>"] + + predicted_function = function_decoder.decode( + function_token_ids, + tokenizer=function_tokenizer, + annotation_threshold=0.1, + annotation_min_length=5, + annotation_gap_merge_max=3, + ) + + print("function prediction:") + print(predicted_function["interpro_preds"].nonzero()) + print(predicted_function["function_keywords"]) + + +if __name__ == "__main__": + sequence, output, sequence_tokens = main() + torch.cuda.empty_cache() + decode(sequence, output, sequence_tokens) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b9b4c67 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,40 @@ +[project] +name = "esm" +version = "3.0.0" +description = "EvolutionaryScale open model repository" +readme = "README.md" +requires-python = ">=3.10" +license = {file = "LICENSE.md"} + +authors = [ + {name = "EvolutionaryScale Team"} +] + +maintainers = [ + {name = "Zeming Lin", email = "zeming+esm@evolutionaryscale.ai" } +] + +classifiers = [ + "Development Status :: 3 - Alpha", + "Topic :: Scientific/Engineering :: Bio-Informatics", + "Programming Language :: Python :: 3.10", +] + +dependencies = [ + "torch>=2.0.0", + "torchvision", + "torchtext", + "transformers", + "ipython", + "einops", + "biotite", + "msgpack-numpy", + "biopython", + "scikit-learn", + "brotli", + "attrs", + "pandas", +] + +[tool.setuptools] +packages = ["esm"]