Skip to content

Commit

Permalink
release 0.0.1
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 5, 2021
1 parent 71644ed commit 5598a24
Show file tree
Hide file tree
Showing 9 changed files with 606 additions and 11 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# pytorch models
.pt

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recursive-include tr_rosetta_pytorch models.tar.gz
386 changes: 386 additions & 0 deletions T1001.a3m

Large diffs are not rendered by default.

12 changes: 10 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
setup(
name = 'tr-rosetta-pytorch',
packages = find_packages(),
include_package_data = True,
entry_points={
'console_scripts': [
'tr_rosetta = tr_rosetta_pytorch.cli:predict',
],
},
version = '0.0.1',
license='MIT',
description = 'trRosetta - Pytorch',
Expand All @@ -15,8 +21,10 @@
'protein design'
],
install_requires=[
'torch>=1.6',
'einops>=0.3'
'einops>=0.3',
'fire',
'numpy',
'torch>=1.6'
],
classifiers=[
'Development Status :: 4 - Beta',
Expand Down
2 changes: 1 addition & 1 deletion tr_rosetta_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from tr_rosetta_pytorch.tr_rosetta_pytorch import trRosetta, trDesign
from tr_rosetta_pytorch.tr_rosetta_pytorch import trRosettaNetwork
53 changes: 53 additions & 0 deletions tr_rosetta_pytorch/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import fire
import torch
import tarfile
import numpy as np
from pathlib import Path

from tr_rosetta_pytorch.tr_rosetta_pytorch import trRosettaNetwork
from tr_rosetta_pytorch.utils import preprocess, d

# paths

# Directory containing this module; the bundled model archive lives under it.
CURRENT_PATH = Path(__file__).parent
DEFAULT_MODEL_PATH = CURRENT_PATH / 'models'
MODEL_PATH = DEFAULT_MODEL_PATH / 'models.tar.gz'
MODEL_FILES = [*Path(DEFAULT_MODEL_PATH).glob('*.pt')]

# extract model files if not extracted
# NOTE: this runs once at import time, only when no '*.pt' file is present yet.

if not MODEL_FILES:
    # The context manager closes the archive even if extraction raises.
    with tarfile.open(str(MODEL_PATH)) as tar:
        tar.extractall(DEFAULT_MODEL_PATH)

    # Refresh the listing so the constant reflects the freshly extracted files.
    MODEL_FILES = [*Path(DEFAULT_MODEL_PATH).glob('*.pt')]

# prediction function

@torch.no_grad()
def get_ensembled_predictions(input_file, output_file=None, model_dir=DEFAULT_MODEL_PATH):
    """Run every checkpoint in `model_dir` on one MSA and save the averaged predictions.

    Args:
        input_file: path to the alignment file handed to `preprocess`.
        output_file: destination `.npz` path. When None, the input path with
            its last suffix replaced by `.npz` is used.
        model_dir: directory searched (non-recursively) for `*.pt` state dicts.

    Raises:
        FileNotFoundError: if `model_dir` contains no `*.pt` file.

    Side effects:
        Writes a compressed `.npz` archive with the keys 'theta', 'phi',
        'dist' and 'omega', and prints the output location.
    """
    # Fail fast, before the comparatively expensive preprocessing step.
    model_files = [*Path(model_dir).glob('*.pt')]
    if not model_files:
        # The original raised a bare string, which is itself a TypeError in Python 3.
        raise FileNotFoundError(f'No model files can be found in {model_dir}')

    if output_file is None:
        input_path = Path(input_file)
        output_file = f'{input_path.parents[0] / input_path.stem}.npz'

    device = d()
    # Device placement and eval mode only need to be set once, not per checkpoint.
    net = trRosettaNetwork().to(device).eval()
    i = preprocess(input_file)

    outputs = []
    for model_file in model_files:
        net.load_state_dict(torch.load(model_file, map_location=torch.device(device)))
        outputs.append(net(i))

    # zip(*outputs) regroups per output head; average each head over the ensemble.
    averaged_outputs = [
        torch.stack(model_output).mean(dim=0).cpu().numpy()
        for model_output in zip(*outputs)
    ]

    # Key order must follow the network's return order
    # (theta, phi, distance, omega); the previous order mislabeled every array.
    output_dict = dict(zip(['theta', 'phi', 'dist', 'omega'], averaged_outputs))
    np.savez_compressed(output_file, **output_dict)
    print(f'predictions for {input_file} saved to {output_file}')

def predict():
    """Console entry point: expose `get_ensembled_predictions` as a command line via Fire."""
    fire.Fire(get_ensembled_predictions)
Binary file added tr_rosetta_pytorch/models/models.tar.gz
Binary file not shown.
64 changes: 56 additions & 8 deletions tr_rosetta_pytorch/tr_rosetta_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,64 @@
from torch import nn, einsum
import torch.nn.functional as F

class trDesign(nn.Module):
def __init__(self):
super().__init__()
def elu():
    """Return a new ELU activation module that operates in place."""
    activation = nn.ELU(inplace=True)
    return activation

def forward(self, x):
return x
def instance_norm(filters, eps=1e-6, **kwargs):
    """Return an affine `nn.InstanceNorm2d` over `filters` channels.

    Extra keyword arguments are forwarded to `nn.InstanceNorm2d`.
    """
    options = dict(affine=True, eps=eps, **kwargs)
    return nn.InstanceNorm2d(filters, **options)

def conv2d(in_chan, out_chan, kernel_size, dilation=1, **kwargs):
    """Return an `nn.Conv2d` whose padding is derived from kernel size and dilation.

    The padding is `dilation * (kernel_size - 1) // 2`, which keeps the spatial
    size unchanged for odd kernel sizes at stride 1. Extra keyword arguments
    are forwarded to `nn.Conv2d`.
    """
    receptive_span = dilation * (kernel_size - 1)
    return nn.Conv2d(
        in_chan,
        out_chan,
        kernel_size,
        padding=receptive_span // 2,
        dilation=dilation,
        **kwargs
    )

class trRosetta(nn.Module):
def __init__(self):
class trRosettaNetwork(nn.Module):
    """Dilated residual 2D conv network producing distance and orientation distributions.

    Input is a 4-D tensor with `442 + 2 * 42` channels on dim 1; the two trailing
    dims are the pairwise (position x position) axes. `forward` returns four
    tensors whose channel dim (dim 1) holds per-bin probabilities:
    theta (25 bins), phi (13 bins), distance (37 bins) and omega (25 bins).

    Args:
        filters: number of channels used throughout the residual trunk.
        kernel: kernel size of the residual convolutions.
        num_layers: number of residual blocks.
    """

    def __init__(self, filters=64, kernel=3, num_layers=61):
        super().__init__()
        self.filters = filters
        self.kernel = kernel
        self.num_layers = num_layers

        # 1x1 projection of the input features down to `filters` channels
        self.first_block = nn.Sequential(
            conv2d(442 + 2 * 42, filters, 1),
            instance_norm(filters),
            elu()
        )

        # stack of residual blocks with dilations cycling through 1, 2, 4, 8, 16
        cycle_dilations = [1, 2, 4, 8, 16]
        dilations = [cycle_dilations[i % len(cycle_dilations)] for i in range(num_layers)]

        # NOTE: module order inside each Sequential defines the state-dict keys;
        # do not reorder, or existing checkpoints will no longer load.
        self.layers = nn.ModuleList([nn.Sequential(
            conv2d(filters, filters, kernel, dilation=dilation),
            instance_norm(filters),
            elu(),
            nn.Dropout(p=0.15),
            conv2d(filters, filters, kernel, dilation=dilation),
            instance_norm(filters)
        ) for dilation in dilations])

        self.activate = elu()

        # conv to anglegrams and distograms
        # The bins live on the channel axis of the conv output, so the softmax
        # must normalize over dim=1 (dim=-1 would normalize across positions).
        self.to_prob_theta = nn.Sequential(conv2d(filters, 25, 1), nn.Softmax(dim=1))
        self.to_prob_phi = nn.Sequential(conv2d(filters, 13, 1), nn.Softmax(dim=1))
        self.to_distance = nn.Sequential(conv2d(filters, 37, 1), nn.Softmax(dim=1))
        # beta-strand pairing head: not returned by forward, but kept so that
        # state dicts containing its weights still load
        self.to_prob_bb = nn.Sequential(conv2d(filters, 3, 1), nn.Softmax(dim=1))
        self.to_prob_omega = nn.Sequential(conv2d(filters, 25, 1), nn.Softmax(dim=1))

    def forward(self, x):
        """Return `(prob_theta, prob_phi, prob_distance, prob_omega)` for input `x`."""
        x = self.first_block(x)

        for layer in self.layers:
            x = self.activate(x + layer(x))

        # theta and phi are predicted from the non-symmetrized features
        prob_theta = self.to_prob_theta(x)  # anglegrams for theta
        prob_phi = self.to_prob_phi(x)      # anglegrams for phi

        x = 0.5 * (x + x.permute((0, 1, 3, 2)))  # symmetrize over the two pair axes

        prob_distance = self.to_distance(x)   # distograms
        prob_omega = self.to_prob_omega(x)    # anglegrams for omega
        # The beta-strand pairing head is intentionally not evaluated: its
        # output was never returned, so computing it was wasted work.

        return prob_theta, prob_phi, prob_distance, prob_omega
96 changes: 96 additions & 0 deletions tr_rosetta_pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import string

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

def d(tensor=None):
    """Return the device name, 'cuda' or 'cpu'.

    With no argument, reports whether CUDA is available; with a tensor,
    reports where that tensor lives.
    """
    on_cuda = torch.cuda.is_available() if tensor is None else tensor.is_cuda
    return 'cuda' if on_cuda else 'cpu'

# preprocessing fn

# read A3M and convert letters into
# integers in the 0..20 range
def parse_a3m(filename):
    """Parse an A3M alignment file into a 2-D uint8 array of residue indices.

    Header lines (starting with '>') are skipped, and lowercase letters are
    removed from each sequence line. The remaining characters are mapped to
    their index in "ARNDCQEGHILKMFPSTWYV-"; any other character is treated as
    a gap (20).

    Args:
        filename: path of the A3M file to read.

    Returns:
        Array of shape (number of sequences, alignment length), dtype uint8.
    """
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))

    seqs = []
    # `with` guarantees the handle is closed (the original leaked it).
    with open(filename, 'r') as handle:
        for line in handle:
            if line.startswith('>'):
                continue
            seq = line.strip().translate(table)
            # Blank lines (e.g. a trailing newline) would otherwise add an
            # empty row and make the array ragged.
            if seq:
                seqs.append(seq)

    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in seqs], dtype='|S1').view(np.uint8)

    # convert letters into numbers; the ASCII codes of the alphabet are all
    # greater than 20, so in-place replacement cannot collide with indices
    for index, code in enumerate(alphabet):
        msa[msa == code] = index

    # treat all unknown characters as gaps
    msa[msa > 20] = 20
    return msa

# 1-hot MSA to PSSM
def msa2pssm(msa1hot, w):
    """Return weighted per-column symbol frequencies with their entropy appended.

    `msa1hot` is indexed as (sequence, column, symbol) and `w` holds one weight
    per sequence. The result has one row per column: the weighted frequencies
    (offset by 1e-9) followed by the entropy of that column.
    """
    effective_count = w.sum()
    weighted_counts = (w[:, None, None] * msa1hot).sum(dim=0)
    freqs = weighted_counts / effective_count + 1e-9
    entropy = (-freqs * torch.log(freqs)).sum(dim=1)
    return torch.cat((freqs, entropy[:, None]), dim=1)

# reweight MSA based on cutoff
def reweight(msa1hot, cutoff):
    """Return one weight per sequence: the reciprocal of its neighbour count.

    Two sequences are neighbours when their number of matching positions
    exceeds `cutoff` times the alignment length (dim 1 of `msa1hot`).
    """
    min_matches = msa1hot.shape[1] * cutoff
    pairwise_matches = torch.einsum('ikl,jkl->ij', msa1hot, msa1hot)
    neighbour_counts = (pairwise_matches > min_matches).float().sum(dim=-1)
    return 1. / neighbour_counts

# shrunk covariance inversion
def fast_dca(msa1hot, weights, penalty = 4.5):
    """Compute coupling features from the inverse of a regularized covariance matrix.

    `msa1hot` is indexed as (sequence, column, symbol) and `weights` holds one
    weight per sequence. Returns a (column, column, symbol * symbol + 1) tensor:
    the reshaped inverse-covariance blocks plus one corrected contact score.
    """
    device = msa1hot.device
    n_seq, n_col, n_sym = msa1hot.shape

    flat = msa1hot.view(n_seq, -1)
    per_row = weights[:, None]
    num_points = weights.sum() - torch.sqrt(weights.mean())

    # weighted centering of the flattened one-hot rows
    mean = (flat * per_row).sum(dim=0, keepdim=True) / num_points
    centered = (flat - mean) * torch.sqrt(per_row)

    # covariance plus a diagonal penalty so that it is invertible
    cov = (centered.t() @ centered) / num_points
    shrinkage = torch.eye(n_col * n_sym).to(device) * penalty / torch.sqrt(weights.sum())
    inv_cov = torch.inverse(cov + shrinkage)

    blocks = inv_cov.view(n_col, n_sym, n_col, n_sym)
    features = blocks.transpose(1, 2).contiguous().reshape(n_col, n_col, n_sym * n_sym)

    # norm of each column-pair block (last symbol excluded), diagonal zeroed
    off_diagonal = 1 - torch.eye(n_col).to(device)
    norms = torch.sqrt((blocks[:, :-1, :, :-1] ** 2).sum(dim=(1, 3))) * off_diagonal

    # subtract the product of row and column sums divided by the total
    apc = norms.sum(dim=0, keepdim=True) * norms.sum(dim=1, keepdim=True) / norms.sum()
    contacts = (norms - apc) * off_diagonal
    return torch.cat((features, contacts[:, :, None]), dim=2)

def preprocess(msa_file, wmin=0.8, ns=21):
    """Turn an A3M alignment file into the 2-D feature tensor fed to the network.

    Args:
        msa_file: path of the A3M file (read with `parse_a3m`).
        wmin: identity cutoff passed to `reweight`.
        ns: number of symbols used for one-hot encoding. `parse_a3m` yields
            values 0..20, so this must be at least 21.

    Returns:
        Float tensor of shape (1, C, ncol, ncol) on device `d()`, where
        C = (ns * ns + 1) + 2 * (20 + ns + 1), i.e. 442 + 2 * 42 for ns=21.
    """
    device = d()

    a3m = torch.from_numpy(parse_a3m(msa_file)).long()
    nrow, ncol = a3m.shape

    msa1hot = F.one_hot(a3m, ns).float().to(device)
    w = reweight(msa1hot, wmin).float().to(device)

    # feature widths, derived from `ns` instead of hard-coding the ns=21 values
    n_feat_1d = 20 + ns + 1   # first-row one-hot (20) + frequencies (ns) + entropy (1)
    n_feat_2d = ns * ns + 1   # inverse-covariance block (ns * ns) + contact score (1)

    # 1d sequence

    f1d_seq = msa1hot[0, :, :20].float()
    f1d_pssm = msa2pssm(msa1hot, w)

    f1d = torch.cat((f1d_seq, f1d_pssm), dim=1)
    f1d = f1d[None, :, :].reshape((1, ncol, n_feat_1d))

    # 2d sequence

    if nrow > 1:
        f2d_dca = fast_dca(msa1hot, w)
    else:
        # A single sequence carries no covariance signal. The zeros must live
        # on the same device as the other features, otherwise torch.cat below
        # fails when running on CUDA.
        f2d_dca = torch.zeros((ncol, ncol, n_feat_2d), device=msa1hot.device).float()
    f2d_dca = f2d_dca[None, :, :, :]

    # tile the 1-D features along both pair axes and append the 2-D features
    f2d = torch.cat((
        f1d[:, :, None, :].repeat(1, 1, ncol, 1),
        f1d[:, None, :, :].repeat(1, ncol, 1, 1),
        f2d_dca
    ), dim=-1)

    f2d = f2d.view(1, ncol, ncol, n_feat_2d + 2 * n_feat_1d)
    # move features to dim 1; note that the two pair axes are swapped as well
    return f2d.permute((0, 3, 2, 1))

0 comments on commit 5598a24

Please sign in to comment.