Commit

gjoliver authored Aug 6, 2024
1 parent dedc353 commit 8c0e910
Showing 14 changed files with 318 additions and 127 deletions.
2 changes: 1 addition & 1 deletion esm/__init__.py
@@ -1 +1 @@
__version__ = "3.0.1"
__version__ = "3.0.2"
44 changes: 23 additions & 21 deletions esm/models/esm3.py
@@ -22,10 +22,10 @@
ESMProtein,
ESMProteinTensor,
ForwardAndSampleOutput,
ForwardConfig,
ForwardOutput,
ForwardTrackData,
GenerationConfig,
LogitsConfig,
LogitsOutput,
ProteinType,
SamplingConfig,
)
@@ -47,6 +47,7 @@
from esm.utils.sampling import (
_BatchedESMProteinTensor,
get_default_sampling_config,
validate_sampling_config,
)
from esm.utils.structure.affine3d import (
build_affine3d_from_coordinates,
@@ -512,9 +513,9 @@ def decode(
function_token_decoder=self.get_function_decoder(),
)

def _forward(
self, input: ESMProteinTensor, config: ForwardConfig = ForwardConfig()
) -> ForwardOutput:
def logits(
self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
) -> LogitsOutput:
device = torch.device(input.device)
# Default plddt conditioning for inference. 1s where coordinates are provided.
if input.coordinates is None:
@@ -547,26 +548,27 @@ def _forward(
**{k: v.to(device).to(torch.float32) for k, v in vars(output).items()}
)

if config.return_logits:
logits = ForwardTrackData(
sequence=output.sequence_logits,
structure=output.structure_logits,
secondary_structure=output.secondary_structure_logits,
sasa=output.sasa_logits,
function=output.function_logits,
)
else:
logits = None

return ForwardOutput(
logits=logits,
residue_annotation_logits=output.residue_logits,
return LogitsOutput(
logits=ForwardTrackData(
sequence=output.sequence_logits if config.sequence else None,
structure=output.structure_logits if config.structure else None,
secondary_structure=output.secondary_structure_logits
if config.secondary_structure
else None,
sasa=output.sasa_logits if config.sasa else None,
function=output.function_logits if config.function else None,
),
residue_annotation_logits=output.residue_logits
if config.residue_annotations
else None,
embeddings=output.embeddings if config.return_embeddings else None,
)

def forward_and_sample(
self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
) -> ForwardAndSampleOutput:
validate_sampling_config(sampling_configuration, on_invalid="warn")

protein_tensor = attr.evolve(input) # Make a copy

device = next(self.parameters()).device
@@ -594,10 +596,10 @@ def forward_and_sample(
batched_protein = _BatchedESMProteinTensor.from_protein_tensor(protein_tensor)
batched_protein.to(device)

forward_output: ForwardOutput = _batch_forward(self, batched_protein)
logits_output: LogitsOutput = _batch_forward(self, batched_protein)
forward_and_sample_out: ForwardAndSampleOutput = _sample_per_prompt(
batched_protein,
forward_output,
logits_output,
sampling_config,
self.tokenizers,
)
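For context on the change above: the raw-forward entry point `_forward(input, ForwardConfig)` is renamed to `logits(input, LogitsConfig)`, with per-track booleans selecting which logits (and optionally embeddings) to return as a `LogitsOutput`. A minimal usage sketch follows; the model instance, the encoded `ESMProteinTensor`, and the helper name are assumptions, not part of this diff.

```python
# Sketch only: illustrates the renamed logits API from this commit.
from esm.sdk.api import LogitsConfig

def get_track_logits(model, protein_tensor):
    """Request sequence/structure logits plus embeddings via the renamed API.

    `model` is assumed to be an ESM3 instance and `protein_tensor` an
    ESMProteinTensor produced by model.encode(...).
    """
    output = model.logits(
        protein_tensor,
        LogitsConfig(sequence=True, structure=True, return_embeddings=True),
    )
    # Tracks that were not requested come back as None inside ForwardTrackData.
    return output.logits.sequence, output.logits.structure, output.embeddings
```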
35 changes: 2 additions & 33 deletions esm/models/function_decoder.py
@@ -16,7 +16,7 @@
InterProQuantizedTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.misc import merge_ranges
from esm.utils.misc import merge_annotations, merge_ranges
from esm.utils.types import FunctionAnnotation


@@ -252,7 +252,7 @@ def decode(
)
annotations.append(annotation)

annotations = _merge_annotations(
annotations = merge_annotations(
annotations,
merge_gap_max=annotation_gap_merge_max,
)
@@ -308,34 +308,3 @@ def _preds_to_keywords(self, keyword_preds: np.ndarray) -> list[FunctionAnnotati
annotations.append(annotation)

return annotations


def _merge_annotations(
annotations: list[FunctionAnnotation],
merge_gap_max: int | None = None,
) -> list[FunctionAnnotation]:
"""Merges annotations into non-overlapping segments.
Args:
annotations: annotations to merge.
merge_gap_max: optionally merge neighboring ranges that are separated by a gap
no larger than this size.
Returns:
non-overlapping annotations with gaps merged.
"""
grouped: dict[str, list[range]] = defaultdict(list)
for a in annotations:
# Convert one-indexed inclusive-inclusive, to range()
grouped[a.label].append(range(a.start, a.end + 1))

merged = []
for label, ranges in grouped.items():
merged_ranges = merge_ranges(ranges, merge_gap_max=merge_gap_max)
for range_ in merged_ranges:
annotation = FunctionAnnotation(
label=label,
start=range_.start, # one-index inclusive (BOS shifts indexes +1)
end=range_.stop - 1, # one-index exclusive -> one-index inclusive
)
merged.append(annotation)
return merged
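The private `_merge_annotations` helper deleted above is replaced by `merge_annotations` imported from `esm.utils.misc` (see the import change at the top of this file); the call site passes the same `merge_gap_max` keyword. A rough behavioral sketch, assuming the shared helper keeps the semantics documented in the removed docstring; the annotation values are illustrative.

```python
# Sketch only; annotation labels and positions are illustrative.
from esm.utils.misc import merge_annotations
from esm.utils.types import FunctionAnnotation

annotations = [
    FunctionAnnotation(label="kinase", start=10, end=20),
    FunctionAnnotation(label="kinase", start=18, end=30),  # overlaps the first
    FunctionAnnotation(label="kinase", start=33, end=40),  # 2-residue gap (31, 32)
]

# Overlapping ranges with the same label are merged; with merge_gap_max=2 the
# small gap after position 30 is bridged as well.
merged = merge_annotations(annotations, merge_gap_max=2)
# Expected result: a single FunctionAnnotation(label="kinase", start=10, end=40)
```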
79 changes: 53 additions & 26 deletions esm/sdk/api.py
@@ -7,6 +7,7 @@
import torch
from attr import asdict, define

import esm.utils.constants.api as C
from esm.tokenization import (
TokenizerCollectionProtocol,
get_model_tokenizers,
@@ -34,10 +34,18 @@ class ESMProtein(ProteinType):
sasa: list[int | float | None] | None = None
function_annotations: list[FunctionAnnotation] | None = None
coordinates: torch.Tensor | None = None

# Metrics
plddt: torch.Tensor | None = None
ptm: torch.Tensor | None = None


# When calling EvolutionaryScale API, use this flag to disclose any
# sequences that may potentially have concerns.
# Such sequences may not go through standard safety filter for approved users.
# Reach out if interested in using this.
potential_sequence_of_concern: bool = False

def __len__(self):
if self.sequence is not None:
return len(self.sequence)
@@ -122,8 +131,18 @@ class ESMProteinTensor(ProteinType):
residue_annotations: torch.Tensor | None = None
coordinates: torch.Tensor | None = None

# When calling EvolutionaryScale API, use this flag to disclose any
# sequences that may potentially have concerns.
# Such sequences may not go through standard safety filter for approved users.
# Reach out if interested in using this.
potential_sequence_of_concern: bool = False

def _detect_attribute(self, func, msg):
mapped = {k: func(k, v) for k, v in asdict(self).items() if v is not None}
mapped = {
k: func(k, v)
for k, v in asdict(self).items()
if isinstance(v, torch.Tensor)
}
s = set(mapped.values())
if len(s) <= 0:
return None
@@ -144,7 +163,7 @@ def device(self) -> str | torch.device:
def to(self, device_or_dtype: str | torch.device | torch.dtype) -> ESMProteinTensor:
def _to(name):
v = getattr(self, name)
if v is not None:
if v is not None and isinstance(v, torch.Tensor):
setattr(self, name, v.to(device_or_dtype))

for n in attr.fields(ESMProteinTensor):
@@ -213,43 +232,51 @@ class SamplingTrackConfig:

@define
class SamplingConfig:
sequence: SamplingTrackConfig | None = None
structure: SamplingTrackConfig | None = None
secondary_structure: SamplingTrackConfig | None = None
sasa: SamplingTrackConfig | None = None
function: SamplingTrackConfig | None = None
sequence: SamplingTrackConfig | None = attr.field(
default=None, metadata={"max_topk": C.MAX_TOPK_SEQUENCE}
)
structure: SamplingTrackConfig | None = attr.field(
default=None, metadata={"max_topk": C.MAX_TOPK_STRUCTURE}
)
secondary_structure: SamplingTrackConfig | None = attr.field(
default=None, metadata={"max_topk": C.MAX_TOPK_SECONDARY_STRUCTURE}
)
sasa: SamplingTrackConfig | None = attr.field(
default=None, metadata={"max_topk": C.MAX_TOPK_SASA}
)
function: SamplingTrackConfig | None = attr.field(
default=None, metadata={"max_topk": C.MAX_TOPK_FUNCTION}
)

return_per_residue_embeddings: bool = False
return_mean_embedding: bool = False


@define
class ReturnLogitsConfig:
class ForwardTrackData:
sequence: torch.Tensor | None = None
structure: torch.Tensor | None = None
secondary_structure: torch.Tensor | None = None
sasa: torch.Tensor | None = None
function: torch.Tensor | None = None


@define
class LogitsConfig:
# Logits.
sequence: bool = False
structure: bool = False
secondary_structure: bool = False
sasa: bool = False
function: bool = False
residue_annotations: bool = False


@define
class ForwardConfig:
return_logits: ReturnLogitsConfig = ReturnLogitsConfig()
# Embeddings.
return_embeddings: bool = False


@define
class ForwardTrackData:
sequence: torch.Tensor | None = None
structure: torch.Tensor | None = None
secondary_structure: torch.Tensor | None = None
sasa: torch.Tensor | None = None
function: torch.Tensor | None = None


@define
class ForwardOutput:
class LogitsOutput:
logits: ForwardTrackData | None = None
embeddings: torch.Tensor | None = None

@@ -260,7 +287,7 @@ class ForwardOutput:


@define
class ForwardAndSampleOutput(ForwardOutput):
class ForwardAndSampleOutput(LogitsOutput):
protein_tensor: ESMProteinTensor = ESMProteinTensor()

entropy: ForwardTrackData | None = None
@@ -302,9 +329,9 @@ def decode(self, input: ESMProteinTensor) -> ESMProtein:
# Decode is the inverse of encode, and runs a structure_token_decoder to output coordinates
raise NotImplementedError

def _forward(
self, input: ESMProteinTensor, config: ForwardConfig = ForwardConfig()
) -> ForwardOutput:
def logits(
self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
) -> LogitsOutput:
# Our API generally discourages using raw forwards.
# This is because sending logits can be prohibitively expensive.
# Please use forward_and_sample instead.
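Taken together, the api.py changes replace `ForwardConfig`/`ReturnLogitsConfig`/`ForwardOutput` with the flatter `LogitsConfig`/`LogitsOutput` pair, add a `potential_sequence_of_concern` disclosure flag to both protein classes, and narrow the device/dtype helpers to tensor-valued fields. A small construction sketch follows; values are illustrative, and field names not visible in this diff (such as `ESMProtein.sequence` and `ESMProteinTensor.sequence`) are assumed from the surrounding code.

```python
# Sketch only; values are illustrative and some field names are assumed.
import torch
from esm.sdk.api import ESMProtein, ESMProteinTensor, LogitsConfig

# The disclosure flag defaults to False and is only relevant when calling the
# EvolutionaryScale API with sequences that may raise safety concerns.
protein = ESMProtein(sequence="MKTAYIAKQR", potential_sequence_of_concern=False)

# Because ESMProteinTensor now carries a plain bool alongside its tensors, the
# updated _detect_attribute and .to() helpers only touch torch.Tensor fields.
protein_tensor = ESMProteinTensor(sequence=torch.tensor([0, 5, 7, 2, 1]))
protein_tensor = protein_tensor.to("cpu")  # moves tensor fields, skips the bool flag

# LogitsConfig folds the old ReturnLogitsConfig booleans and return_embeddings
# into a single flat config object.
config = LogitsConfig(sequence=True, residue_annotations=True, return_embeddings=True)
```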