Commit ba3834a

before merging

1 parent a7a3af4 commit ba3834a
5 files changed: +495 -2 lines changed

VirusDataset.py

+172
@@ -0,0 +1,172 @@
import os
import random

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as L

import ioutils


def readVirusSequences(pos=None, trunc=1498, sample=300, seed=1509):
    random.seed(seed)
    print("read positive samples")
    seqs = {}
    if pos is None:
        pos = os.listdir("/home/tyfei/datasets/ion_channel/Interprot/ion_channel/0.99")
    for i in pos:
        try:
            if i.endswith(".fas"):
                gen = ioutils.readFasta("/home/tyfei/datasets/ion_channel/Interprot/ion_channel/0.99/" + i, truclength=trunc)
                seqs[i[:i.find(".")]] = [s for s in gen]
        except Exception:
            # skip files that cannot be parsed
            pass

    sequences = []
    labels = []
    for i in seqs:
        # cap the number of positive sequences drawn per family
        sampled = random.sample(seqs[i], min(sample, len(seqs[i])))
        sequences.extend(sampled)
        labels.extend([1] * len(sampled))

    print("read negative samples")
    gen = ioutils.readFasta("/home/tyfei/datasets/ion_channel/Interprot/Negative_sample/decoy_1m_new.fasta", truclength=trunc)
    seqs["neg"] = [s for s in gen]
    # draw as many decoys as there are positive samples
    sampled = random.sample(seqs["neg"], min(len(labels), len(seqs["neg"])))
    sequences.extend(sampled)
    labels.extend([0] * len(sampled))

    print("read virus sequences")
    allvirus = []
    for i in os.listdir("/home/tyfei/datasets/NCBI_virus/genbank_csv/"):
        allvirus.extend(ioutils.readNCBICsv("/home/tyfei/datasets/NCBI_virus/genbank_csv/" + i, truclength=trunc))

    return sequences, labels, allvirus


class SeqDataset2(Dataset):
    """Labelled sequences paired with an unlabelled set that is cycled alongside them."""

    def __init__(self, seq, label, seqtest):
        if not isinstance(seq, torch.Tensor):
            seq = torch.tensor(seq).long()
        self.seq = seq

        if not isinstance(label, torch.Tensor):
            label = torch.tensor(label).long()
        self.label = label

        if not isinstance(seqtest, torch.Tensor):
            seqtest = torch.tensor(seqtest).long()
        self.seqtest = seqtest

        self.seqlen = seq.shape[0]
        self.seqtestlen = seqtest.shape[0]

    def __len__(self):
        return max(self.seqlen, self.seqtestlen)

    def __getitem__(self, idx):
        # wrap around the shorter set so both sides always yield an item
        return self.seq[idx % self.seqlen], self.label[idx % self.seqlen], self.seqtest[idx % self.seqtestlen]


class TestDataset(Dataset):
    def __init__(self, seq):
        if not isinstance(seq, torch.Tensor):
            seq = torch.tensor(seq).long()
        self.seq = seq

    def __len__(self):
        return self.seq.shape[0]

    def __getitem__(self, idx):
        return self.seq[idx]


class SeqDataset(Dataset):
    def __init__(self, seq, label):
        if not isinstance(seq, torch.Tensor):
            seq = torch.tensor(seq).long()
        self.seq = seq

        if not isinstance(label, torch.Tensor):
            label = torch.tensor(label).long()
        self.label = label

    def __len__(self):
        return self.seq.shape[0]

    def __getitem__(self, idx):
        return self.seq[idx], self.label[idx]


class SeqdataModule(L.LightningDataModule):
    def __init__(self, trainset=None, testset=None, path="/home/tyfei/datasets/pts/virus",
                 batch_size=12, train_test_split=[0.8, 0.2], seed=1509) -> None:
        super().__init__()

        self.train_test_split = train_test_split
        self.batch_size = batch_size
        self.path = path
        self.seed = seed

        # datasets may be passed directly or as paths to objects saved with torch.save
        if isinstance(testset, str):
            self.test_set = torch.load(testset)
        else:
            self.test_set = testset

        if isinstance(trainset, str):
            self.trainset = torch.load(trainset)
        else:
            self.trainset = trainset

    def saveDataset(self):
        torch.save(self.trainset, os.path.join(self.path, "train.pt"))
        torch.save(self.test_set, os.path.join(self.path, "test.pt"))

    def setup(self, stage):
        if stage == "fit" or stage == "validate":
            if self.trainset is None:
                if os.path.exists(os.path.join(self.path, "train.pt")):
                    self.trainset = torch.load(os.path.join(self.path, "train.pt"))
                else:
                    raise FileNotFoundError(os.path.join(self.path, "train.pt"))

            if not hasattr(self, "train_set"):
                torch.manual_seed(self.seed)
                train_set, val_set = torch.utils.data.random_split(self.trainset, self.train_test_split)
                self.train_set = train_set
                self.val_set = val_set

        if stage == "predict":
            if self.test_set is None:
                if os.path.exists(os.path.join(self.path, "test.pt")):
                    self.test_set = torch.load(os.path.join(self.path, "test.pt"))
                else:
                    raise FileNotFoundError(os.path.join(self.path, "test.pt"))

        if stage == "test":
            raise NotImplementedError

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)

    def predict_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=4)
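
A minimal usage sketch for the data pipeline above (illustrative, not part of the commit): the tensors are random placeholders standing in for already-encoded, fixed-length sequences, the vocabulary size 33 and the shapes are assumptions, and the fractional random_split needs a reasonably recent PyTorch.

import torch
from VirusDataset import SeqDataset2, TestDataset, SeqdataModule

# random placeholders for encoded, fixed-length sequences (shapes and vocab size are assumptions)
seqs = torch.randint(0, 33, (600, 1500))      # labelled ion-channel / decoy sequences
labels = torch.randint(0, 2, (600,))          # 1 = positive, 0 = decoy
virus = torch.randint(0, 33, (900, 1500))     # unlabelled virus sequences

trainset = SeqDataset2(seqs, labels, virus)   # cycles the shorter side via idx % len
testset = TestDataset(virus)

dm = SeqdataModule(trainset=trainset, testset=testset, path=".", batch_size=12)
dm.setup("fit")                               # 80/20 random split with the fixed seed
x, y, x_virus = next(iter(dm.train_dataloader()))
print(x.shape, y.shape, x_virus.shape)        # three tensors with batch dimension 12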

datautils.py

+2
@@ -12,6 +12,8 @@
 import pytorch_lightning as L
 tqdm.pandas()
 
+
+
 def parseline(line):
     try:
         line = line.decode("utf-8").strip()

ionchannel.py

+211
@@ -0,0 +1,211 @@
import torch
import esm
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as L

import torchmetrics

from torch.autograd import Function


class GradientR(Function):
    """Identity in the forward pass; negated, scaled gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x, alpha)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        _, alpha = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_input = -alpha * grad_output
        return grad_input, None


class GradientReversal(nn.Module):
    def __init__(self, alpha):
        super().__init__()
        self.alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, x):
        return GradientR.apply(x, self.alpha)


class ionclf(L.LightningModule):
    def __init__(self, esm_model, unfix=None, addadversial=True, lamb=0.1, lr=5e-4) -> None:
        super().__init__()
        self.num_layers = esm_model.num_layers
        self.embed_dim = esm_model.embed_dim
        self.attention_heads = esm_model.attention_heads
        self.alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
        self.alphabet_size = len(self.alphabet)
        self.addadversial = addadversial
        self.lamb = lamb
        self.lr = lr

        self.esm_model = esm_model

        # classification head on the first-token (CLS) representation
        self.cls = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim // 2),
            nn.LayerNorm(self.embed_dim // 2),
            nn.GELU(),
            nn.Linear(self.embed_dim // 2, self.embed_dim // 4),
            nn.LayerNorm(self.embed_dim // 4),
            nn.GELU(),
            nn.Linear(self.embed_dim // 4, 1),
        )

        # domain discriminator, fed through the gradient-reversal layer
        self.dis = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim // 2),
            nn.LayerNorm(self.embed_dim // 2),
            nn.GELU(),
            nn.Linear(self.embed_dim // 2, self.embed_dim // 4),
            nn.LayerNorm(self.embed_dim // 4),
            nn.GELU(),
            nn.Linear(self.embed_dim // 4, 1),
        )

        self.reverse = GradientReversal(1)

        if unfix is None:
            self.fixParameters()
        else:
            self.fixParameters(unfix)

        self.acc = torchmetrics.Accuracy(task="binary")

        self.training_step_outputs = []
        self.validation_step_outputs = []

    def fixParameters(self, unfix=["9", "10", "11"]):
        # freeze all ESM parameters except those whose names contain one of the `unfix` tags;
        # everything outside the ESM backbone stays trainable
        for i, j in self.named_parameters():
            flag = 1
            if "esm_model" not in i:
                flag = 0
            for k in unfix:
                if k in i:
                    flag = 0

            if flag == 1:
                j.requires_grad = False
            else:
                j.requires_grad = True

    def forward(self, x):
        representations = self.esm_model(x, repr_layers=[self.num_layers])

        x = representations["representations"][self.num_layers][:, 0]
        x1 = self.reverse(x)
        pre = torch.sigmoid(self.cls(x))
        y = torch.sigmoid(self.dis(x1))

        return pre, y

    def _common_training_step(self, batch):
        X1, y, X2 = batch
        y_pre, dis_pre_x1 = self(X1)
        _y, dis_pre_x2 = self(X2)

        # classification loss on the labelled batch
        loss1 = F.binary_cross_entropy(y_pre.squeeze(), y.float())
        # adversarial loss: the discriminator tries to tell the two domains apart
        loss2 = F.binary_cross_entropy(dis_pre_x1, torch.zeros_like(dis_pre_x1)) + \
            F.binary_cross_entropy(dis_pre_x2, torch.ones_like(dis_pre_x2))

        if self.addadversial:
            loss = loss1 + loss2 * self.lamb
        else:
            loss = loss1

        return loss, loss1, loss2, y_pre, y

    def training_step(self, batch, batch_idx):
        loss, loss1, loss2, y_pre, y = self._common_training_step(batch)

        acc = self.acc(y_pre.squeeze(), y)

        self.log_dict({"predict loss": loss1.item(), "adversarial loss": loss2.item(), "acc": acc},
                      prog_bar=True, on_step=True)
        self.training_step_outputs.append({"loss": loss.detach().cpu(),
                                           "y": y_pre.detach().squeeze().cpu(),
                                           "true_label": y.cpu()})

        return loss

    def _common_epoch_end(self, outputs):
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        scores = torch.concatenate([x["y"] for x in outputs])
        y = torch.concatenate([x["true_label"] for x in outputs])

        outputs.clear()
        return loss, self.acc(scores, y)

    def on_train_epoch_end(self):
        loss, acc = self._common_epoch_end(self.training_step_outputs)

        self.log_dict(
            {
                "mean_loss": loss,
                "train_acc": acc,
            },
            on_step=False,
            on_epoch=True,
            prog_bar=False,
        )

    def validation_step(self, batch, batch_idx):
        loss, loss1, loss2, y_pre, y = self._common_training_step(batch)

        acc = self.acc(y_pre.squeeze(), y)

        self.log_dict({"predict loss": loss1.item(), "adversarial loss": loss2.item(), "acc": acc},
                      prog_bar=True, on_step=True)
        self.validation_step_outputs.append({"loss": loss.cpu(), "y": y_pre.squeeze().cpu(), "true_label": y.cpu()})

        return loss

    def on_validation_epoch_end(self):
        loss, acc = self._common_epoch_end(self.validation_step_outputs)
        self.log_dict(
            {
                "loss": loss,
                "validate_acc": acc,
            },
            on_step=False,
            on_epoch=True,
            prog_bar=False,
        )

    def test_step(self, batch, batch_idx):
        x = batch
        y_pre, _ = self(x)
        return y_pre

    def predict_step(self, batch, batch_idx):
        if isinstance(batch, (list, tuple)):
            if len(batch) == 3:
                X1, y, X2 = batch
            elif len(batch) == 2:
                X1, y = batch
            else:
                raise ValueError
        else:
            X1 = batch
        pre, _ = self(X1)

        return pre.squeeze()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr)

        return optimizer
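
A quick sanity check of the gradient-reversal layer defined above (illustrative only): the forward pass is the identity, while the backward pass flips the sign of the incoming gradient and scales it by alpha, which is what lets the discriminator loss push the shared features toward domain invariance.

import torch
from ionchannel import GradientReversal

rev = GradientReversal(alpha=0.5)
x = torch.ones(3, requires_grad=True)
rev(x).sum().backward()
print(x.grad)   # tensor([-0.5000, -0.5000, -0.5000]): negated and scaled by alpha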

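A sketch of how the model and the data module might be wired together (again not part of the commit). It assumes the fair-esm package with one of its pretrained ESM-2 checkpoints, and that "train.pt" is a previously saved dataset of sequences already tokenized with the ESM alphabet; both names are placeholders.

import esm
import pytorch_lightning as pl
from ionchannel import ionclf
from VirusDataset import SeqdataModule

# any ESM variant exposing num_layers / embed_dim / attention_heads should work the same way
esm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()

model = ionclf(esm_model, unfix=["10", "11"], addadversial=True, lamb=0.1, lr=5e-4)

# placeholder path to a SeqDataset2 saved with SeqdataModule.saveDataset()
dm = SeqdataModule(trainset="train.pt", path=".", batch_size=12)
trainer = pl.Trainer(max_epochs=5, log_every_n_steps=10)
trainer.fit(model, datamodule=dm)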