Commit 5598a24 (1 parent: 71644ed)
Showing 9 changed files with 606 additions and 11 deletions.
@@ -1,3 +1,6 @@
# pytorch models
*.pt

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -0,0 +1 @@
recursive-include tr_rosetta_pytorch models.tar.gz
@@ -1 +1 @@
-from tr_rosetta_pytorch.tr_rosetta_pytorch import trRosetta, trDesign
+from tr_rosetta_pytorch.tr_rosetta_pytorch import trRosettaNetwork
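With this change the package re-exports the network class itself rather than the previous trRosetta / trDesign names. A minimal sketch of the new import path (the no-argument construction mirrors how the CLI module later in this diff instantiates the network):

```python
from tr_rosetta_pytorch import trRosettaNetwork

# default construction, as done in get_ensembled_predictions() below
net = trRosettaNetwork()
net.eval()
```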
@@ -0,0 +1,53 @@
import fire
import torch
import tarfile
import numpy as np
from pathlib import Path

from tr_rosetta_pytorch.tr_rosetta_pytorch import trRosettaNetwork
from tr_rosetta_pytorch.utils import preprocess, d

# paths

CURRENT_PATH = Path(__file__).parent
DEFAULT_MODEL_PATH = CURRENT_PATH / 'models'
MODEL_PATH = DEFAULT_MODEL_PATH / 'models.tar.gz'
MODEL_FILES = [*Path(DEFAULT_MODEL_PATH).glob('*.pt')]

# extract model files if not extracted

if len(MODEL_FILES) == 0:
    tar = tarfile.open(str(MODEL_PATH))
    tar.extractall(DEFAULT_MODEL_PATH)
    tar.close()

# prediction function

@torch.no_grad()
def get_ensembled_predictions(input_file, output_file=None, model_dir=DEFAULT_MODEL_PATH):
    net = trRosettaNetwork()
    i = preprocess(input_file)

    if output_file is None:
        input_path = Path(input_file)
        output_file = f'{input_path.parents[0] / input_path.stem}.npz'

    outputs = []
    model_files = [*Path(model_dir).glob('*.pt')]

    if len(model_files) == 0:
        raise Exception('No model files can be found')

    for model_file in model_files:
        net.load_state_dict(torch.load(model_file, map_location=torch.device(d())))
        net.to(d()).eval()
        output = net(i)
        outputs.append(output)

    averaged_outputs = [torch.stack(model_output).mean(dim=0).cpu().numpy() for model_output in zip(*outputs)]
    output_dict = dict(zip(['dist', 'omega', 'theta', 'phi'], averaged_outputs))
    np.savez_compressed(output_file, **output_dict)
    print(f'predictions for {input_file} saved to {output_file}')

def predict():
    fire.Fire(get_ensembled_predictions)
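To make the new prediction entry point concrete, here is a small usage sketch. The programmatic call uses only what the code above defines; the module path `tr_rosetta_pytorch.cli` and the input path are assumptions for illustration, and the shell command at the end is only a guess at how `predict` might be exposed as a console script (that wiring is not part of the diff shown here).

```python
# Usage sketch for the ensembled prediction function above.
# Assumes the module above lives at tr_rosetta_pytorch/cli.py and that
# an A3M alignment exists at data/test.a3m (hypothetical path).
from tr_rosetta_pytorch.cli import get_ensembled_predictions

# Loads every *.pt checkpoint found in the model directory, runs each on the
# preprocessed MSA features, averages the dist / omega / theta / phi maps
# across models, and writes them to a compressed .npz next to the input file.
get_ensembled_predictions('data/test.a3m')

# If `predict` is registered as a console script, the equivalent CLI call
# via python-fire would look roughly like:
#   <script-name> data/test.a3m --output-file=data/test.npz
```

The four arrays saved to the .npz correspond to the keys assigned in the code: the inter-residue distance distogram and the omega / theta / phi orientation maps.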
Binary file not shown.
@@ -0,0 +1,96 @@
import string

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

def d(tensor=None):
    if tensor is None:
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    return 'cuda' if tensor.is_cuda else 'cpu'

# preprocessing fn

# read A3M and convert letters into
# integers in the 0..20 range
def parse_a3m(filename):
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    seqs = [line.strip().translate(table) for line in open(filename, 'r') if line[0] != '>']
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in seqs], dtype='|S1').view(np.uint8)

    # convert letters into numbers
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i

    # treat all unknown characters as gaps
    msa[msa > 20] = 20
    return msa

# 1-hot MSA to PSSM
def msa2pssm(msa1hot, w):
    beff = w.sum()
    f_i = (w[:, None, None] * msa1hot).sum(dim=0) / beff + 1e-9
    h_i = (-f_i * torch.log(f_i)).sum(dim=1)
    return torch.cat((f_i, h_i[:, None]), dim=1)

# reweight MSA based on cutoff
def reweight(msa1hot, cutoff):
    id_min = msa1hot.shape[1] * cutoff
    id_mtx = torch.einsum('ikl,jkl->ij', msa1hot, msa1hot)
    id_mask = id_mtx > id_min
    w = 1. / id_mask.float().sum(dim=-1)
    return w

# shrunk covariance inversion
def fast_dca(msa1hot, weights, penalty=4.5):
    device = msa1hot.device
    nr, nc, ns = msa1hot.shape
    x = msa1hot.view(nr, -1)
    num_points = weights.sum() - torch.sqrt(weights.mean())

    mean = (x * weights[:, None]).sum(dim=0, keepdim=True) / num_points
    x = (x - mean) * torch.sqrt(weights[:, None])

    cov = (x.t() @ x) / num_points
    cov_reg = cov + torch.eye(nc * ns).to(device) * penalty / torch.sqrt(weights.sum())

    inv_cov = torch.inverse(cov_reg)
    x1 = inv_cov.view(nc, ns, nc, ns)
    x2 = x1.transpose(1, 2).contiguous()
    features = x2.reshape(nc, nc, ns * ns)

    x3 = torch.sqrt((x1[:, :-1, :, :-1] ** 2).sum(dim=(1, 3))) * (1 - torch.eye(nc).to(device))
    apc = x3.sum(dim=0, keepdim=True) * x3.sum(dim=1, keepdim=True) / x3.sum()
    contacts = (x3 - apc) * (1 - torch.eye(nc).to(device))
    return torch.cat((features, contacts[:, :, None]), dim=2)
def preprocess(msa_file, wmin=0.8, ns=21):
    a3m = torch.from_numpy(parse_a3m(msa_file)).long()
    nrow, ncol = a3m.shape

    msa1hot = F.one_hot(a3m, ns).float().to(d())
    w = reweight(msa1hot, wmin).float().to(d())

    # 1d sequence

    f1d_seq = msa1hot[0, :, :20].float()
    f1d_pssm = msa2pssm(msa1hot, w)

    f1d = torch.cat((f1d_seq, f1d_pssm), dim=1)
    f1d = f1d[None, :, :].reshape((1, ncol, 42))

    # 2d sequence

    # keep the single-sequence zero fallback on the same device as the other features
    f2d_dca = fast_dca(msa1hot, w) if nrow > 1 else torch.zeros((ncol, ncol, 442)).float().to(d())
    f2d_dca = f2d_dca[None, :, :, :]

    f2d = torch.cat((
        f1d[:, :, None, :].repeat(1, 1, ncol, 1),
        f1d[:, None, :, :].repeat(1, ncol, 1, 1),
        f2d_dca
    ), dim=-1)

    f2d = f2d.view(1, ncol, ncol, 442 + 2*42)
    return f2d.permute((0, 3, 2, 1))
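As a sanity check on the feature shapes this preprocessing pipeline produces, the sketch below traces a toy alignment through the functions above. Only functions from this diff are used; the alignment size (4 sequences × 50 columns) is an arbitrary example.

```python
# Shape walkthrough for the MSA featurization above (toy example).
import torch
import torch.nn.functional as F

from tr_rosetta_pytorch.utils import reweight, msa2pssm, fast_dca

nrow, ncol, ns = 4, 50, 21                   # 4 aligned sequences, 50 columns, 21 symbols
a3m = torch.randint(0, ns, (nrow, ncol))     # stand-in for parse_a3m() output

msa1hot = F.one_hot(a3m, ns).float()         # (nrow, ncol, 21)
w = reweight(msa1hot, 0.8)                   # (nrow,) per-sequence weights

f1d_seq = msa1hot[0, :, :20]                 # (ncol, 20) one-hot query sequence
f1d_pssm = msa2pssm(msa1hot, w)              # (ncol, 22) frequencies + positional entropy
f1d = torch.cat((f1d_seq, f1d_pssm), dim=1)  # (ncol, 42)

f2d_dca = fast_dca(msa1hot, w)               # (ncol, ncol, 442) couplings + APC contact map

# tiling the 1d features pairwise and appending the DCA features gives
# 42 + 42 + 442 = 526 channels per residue pair, matching the final
# view(1, ncol, ncol, 442 + 2*42) in preprocess()
print(f1d.shape, f2d_dca.shape)
```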