Skip to content

Commit a7a9799

Browse files
committed
before adding step to data augmentation and parameter freezing
1 parent b33e6b6 commit a7a9799

17 files changed

+9024
-4463
lines changed

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,7 @@ tb_logs/
77
*.yaml
88
*.pt
99
*.txt
10-
ckpts/
10+
ckpts/
11+
ckptsbaseline/
12+
ckptsesm3/
13+
structure_temp/

VirusDataset.py

+308-6
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,39 @@
11
import os
22
import random
33

4+
import numpy as np
45
import pytorch_lightning as L
56
import torch
7+
from esm.utils.constants import esm3 as C
8+
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS
69
from torch.utils.data import DataLoader, Dataset
10+
from torch.utils.data.sampler import SubsetRandomSampler
711

812
import ioutils
913

1014

15+
class MyDataLoader(DataLoader):
    """DataLoader that drives its dataset's augmentation schedule.

    Expects a dataset exposing a ``step()`` method and an ``ifaug``
    attribute.  When constructed with ``step_ds=True`` (training), every
    fresh iteration advances the dataset's step counter and switches
    augmentation on; otherwise (validation/predict) augmentation is
    forced off.
    """

    def __init__(self, ds, step_ds, *args, **kwargs):
        super().__init__(ds, *args, **kwargs)
        # Dataset handle kept under its own name (DataLoader also stores it
        # as ``self.dataset``); ``epoch`` counts entries into __iter__.
        self.ds = ds
        self.step_ds = step_ds
        self.epoch = 0

    def step(self):
        """Manually advance the dataset's schedule (no-op unless enabled)."""
        if not self.step_ds:
            return
        self.ds.step()

    def __iter__(self):
        self.epoch += 1
        stepping = self.step_ds
        if stepping:
            # New training epoch: advance the curriculum step.
            self.ds.step()
        self.ds.ifaug = True if stepping else False
        print("now epoch ", self.epoch)
        return super().__iter__()
35+
36+
1137
def readVirusSequences(pos=None, trunc=1498, sample=300, seed=1509):
1238
random.seed(seed)
1339
print("read positive samples")
@@ -38,7 +64,7 @@ def readVirusSequences(pos=None, trunc=1498, sample=300, seed=1509):
3864

3965
print("read negative samples")
4066
gen = ioutils.readFasta(
41-
"/home/tyfei/datasets/ion_channel/Interprot/Negative_sample/decoy_1m_new.fasta",
67+
"/home/tyfei/datasets/ion_channel/Interprot/Negative_sample/old/decoy_1m_new_rmdup.fasta",
4268
truclength=trunc,
4369
)
4470
seqs["neg"] = [i for i in gen]
@@ -49,15 +75,292 @@ def readVirusSequences(pos=None, trunc=1498, sample=300, seed=1509):
4975
print("read virus sequences")
5076
allvirus = []
5177
for i in os.listdir("/home/tyfei/datasets/NCBI_virus/genbank_csv/"):
52-
allvirus.extend(
53-
ioutils.readNCBICsv(
54-
"/home/tyfei/datasets/NCBI_virus/genbank_csv/" + i, truclength=trunc
78+
try:
79+
allvirus.extend(
80+
ioutils.readNCBICsv(
81+
"/home/tyfei/datasets/NCBI_virus/genbank_csv/" + i, truclength=trunc
82+
)
5583
)
56-
)
84+
except Exception:
85+
pass
5786

5887
return sequences, labels, allvirus
5988

6089

90+
# Floor for a cropped sequence length: never crop below this many tokens.
MIN_LENGTH = 50


class DataAugmentation:
    """Step-scheduled augmentation settings (masking probabilities + cropping).

    ``step_points``, ``maskp`` and ``crop`` are parallel lists: once the
    current step exceeds ``step_points[i]``, the settings at index ``i``
    become active (later entries override earlier ones).  ``croprange`` is
    the pool of base crop lengths to sample from.
    """

    def __init__(
        self, step_points: list, maskp: list, crop: list, croprange: list
    ) -> None:
        assert len(step_points) == len(maskp)
        assert len(maskp) == len(crop)
        self.step_points = step_points
        self.maskp = maskp
        self.crop = crop
        self.croprange = croprange

    def _getSettings(self, step):
        # Walk the schedule; the last threshold that ``step`` has passed wins.
        active_maskp, active_crop = (-1.0, -1.0), -1.0
        for point, mp, cr in zip(self.step_points, self.maskp, self.crop):
            if step > point:
                active_maskp, active_crop = mp, cr
        return active_maskp, active_crop

    def getAugmentation(self, seqlen, step):
        """Return ``(maskp, crop_length)`` for a sequence of ``seqlen`` tokens.

        ``crop_length`` is -1 when no crop should be applied; otherwise it is
        a jittered sample from ``croprange``, clamped to
        ``[MIN_LENGTH, seqlen - 2]``.
        """
        maskp, crop = self._getSettings(step)
        if crop > 0 and random.random() < crop:
            length = random.sample(self.croprange, 1)[0]
            # +/-20% jitter around the sampled base length.
            length = int(length * np.random.uniform(0.8, 1.2))
            length = max(length, MIN_LENGTH)
            return maskp, min(length, seqlen - 2)
        return maskp, -1
124+
125+
126+
class ESM3BaseDataset(Dataset):
127+
def __init__(self, tracks=["seq_t", "structure_t", "sasa_t", "second_t"]) -> None:
128+
assert len(tracks) > 0
129+
self.tracks = tracks
130+
self.step_cnt = 0
131+
132+
def step(self):
133+
self.step_cnt += 1
134+
135+
def resetCnt(self):
136+
self.step_cnt = 0
137+
138+
def getToken(self, track, token):
139+
# assert token in ["start", "end", "mask"]
140+
match token:
141+
case "start":
142+
match track:
143+
case "seq_t":
144+
return C.SEQUENCE_BOS_TOKEN
145+
case "structure_t":
146+
return C.STRUCTURE_BOS_TOKEN
147+
case "sasa_t":
148+
return 0
149+
case "second_t":
150+
return 0
151+
case _:
152+
raise ValueError
153+
case "end":
154+
match track:
155+
case "seq_t":
156+
return C.SEQUENCE_EOS_TOKEN
157+
case "structure_t":
158+
return C.STRUCTURE_EOS_TOKEN
159+
case "sasa_t":
160+
return 0
161+
case "second_t":
162+
return 0
163+
case _:
164+
raise ValueError
165+
case "mask":
166+
match track:
167+
case "seq_t":
168+
return C.SEQUENCE_MASK_TOKEN
169+
case "structure_t":
170+
return C.STRUCTURE_MASK_TOKEN
171+
case "sasa_t":
172+
return C.SASA_UNK_TOKEN
173+
case "second_t":
174+
return C.SS8_UNK_TOKEN
175+
case _:
176+
raise ValueError
177+
case "pad":
178+
match track:
179+
case "seq_t":
180+
return C.SEQUENCE_PAD_TOKEN
181+
case "structure_t":
182+
return C.STRUCTURE_PAD_TOKEN
183+
case "sasa_t":
184+
return C.SASA_PAD_TOKEN
185+
case "second_t":
186+
return C.SS8_PAD_TOKEN
187+
case _:
188+
raise ValueError
189+
case _:
190+
raise ValueError
191+
192+
def _maskSequence(self, sample, pos):
193+
for i in sample:
194+
sample[i][pos] = self.getToken(i, "mask")
195+
196+
return sample
197+
198+
def _generateMaskingPos(self, num, length, method="point"):
199+
assert length > num + 5
200+
if method == "point":
201+
a = np.array(random.sample(range(length - 2), num)) + 1
202+
return a
203+
elif method == "block":
204+
s = random.randint(1, length - num)
205+
a = np.array(range(s, s + num))
206+
return a
207+
else:
208+
raise NotImplementedError
209+
210+
def _cropSequence(self, sample, start, end):
211+
for i in sample:
212+
t = torch.zeros((end - start + 2), dtype=torch.long)
213+
t[1:-1] = torch.tensor(sample[i][start:end])
214+
t[0] = self.getToken(i, "start")
215+
t[-1] = self.getToken(i, "end")
216+
sample[i] = t
217+
return sample
218+
219+
def _augmentsample(self, sample, maskp, crop):
220+
samplelen = len(sample[self.tracks[0]])
221+
if crop > 50:
222+
s = random.randint(1, samplelen - crop - 1)
223+
sample = self._cropSequence(sample, s, s + crop)
224+
samplelen = crop + 2
225+
if maskp[0] > 0:
226+
num = np.random.binomial(samplelen - 2, maskp[0])
227+
pos = self._generateMaskingPos(num, samplelen)
228+
if len(pos) > 0:
229+
sample = self._maskSequence(sample, pos)
230+
if maskp[1] > 0:
231+
num = np.random.binomial(samplelen - 2, maskp[0])
232+
pos = self._generateMaskingPos(num, samplelen, "block")
233+
if len(pos) > 0:
234+
sample = self._maskSequence(sample, pos)
235+
return sample
236+
237+
238+
class ESM3MultiTrackDataset(ESM3BaseDataset):
    """Training dataset pairing labelled samples with unlabelled virus samples.

    Each item yields ``(sample_dict, label_tensor, virus_sample_dict)``.
    The virus collection ``data2`` is traversed in a random order that is
    reshuffled on every ``step()``; augmentation of the labelled sample is
    applied only when ``ifaug`` is set (by the training dataloader).
    """

    def __init__(
        self,
        data1,
        data2,
        label,
        augment: DataAugmentation = None,
        tracks=["seq_t", "structure_t", "sasa_t", "second_t"],
    ) -> None:
        super().__init__(tracks=tracks)
        self.data1 = data1
        self.data2 = data2
        self.label = label
        self.aug = augment
        self.iters = 0
        self.ifaug = False
        # Random pairing order over data2; refreshed each step().
        self.data2order = np.arange(len(data2))
        random.shuffle(self.data2order)

    def __len__(self):
        return len(self.data1)

    def step(self):
        # New epoch: re-randomize the pairing with the virus samples.
        random.shuffle(self.data2order)
        super().step()

    def __getitem__(self, idx):
        paired = self.data2order[idx % len(self.data2)]
        x1 = {track: self.data1[idx][track] for track in self.tracks}
        x2 = {track: self.data2[paired][track] for track in self.tracks}
        if self.aug is not None and self.ifaug:
            maskp, crop = self.aug.getAugmentation(
                len(x1[self.tracks[0]]), self.step_cnt
            )
            x1 = self._augmentsample(x1, maskp, crop)
        return x1, torch.tensor([self.label[idx]]), x2
277+
278+
279+
class ESM3MultiTrackDatasetTEST(ESM3BaseDataset):
    """Inference-time dataset over a single collection of samples.

    Unlike the training dataset there is no ``ifaug`` gate here: whenever an
    augmentation object is supplied it is applied to every item.
    """

    def __init__(
        self,
        data1,
        augment: DataAugmentation = None,
        tracks=["seq_t", "structure_t", "sasa_t", "second_t"],
    ) -> None:
        super().__init__(tracks=tracks)
        self.data1 = data1
        self.aug = augment

    def __len__(self):
        return len(self.data1)

    def step(self):
        super().step()

    def __getitem__(self, idx):
        item = {track: self.data1[idx][track] for track in self.tracks}
        if self.aug is not None:
            maskp, crop = self.aug.getAugmentation(
                len(item[self.tracks[0]]), self.step_cnt
            )
            item = self._augmentsample(item, maskp, crop)
        return item
307+
308+
309+
class ESM3datamodule(L.LightningDataModule):
    """LightningDataModule serving train/val splits of ``ds1``, ``ds2`` for predict.

    The train/val split is a deterministic prefix/suffix cut of the dataset
    indices (indices are not shuffled before the cut); per-epoch randomness
    comes from the ``SubsetRandomSampler`` instances.
    """

    def __init__(
        self,
        ds1: ESM3BaseDataset,
        ds2: ESM3BaseDataset,
        batch_size=1,
        train_test_split=[0.85, 0.15],
        seed=1509,
    ):
        super().__init__()
        self.value = 0  # counts how many dataloaders have been requested
        self.batch_size = batch_size
        self.seed = seed
        torch.manual_seed(self.seed)

        cut = int(len(ds1) * train_test_split[0])
        indices = np.arange(len(ds1))

        self.trainval_set = ds1
        self.train_indices = indices[:cut]
        self.val_indices = indices[cut:]
        self.testset = ds2

    def _loader(self, dataset, step_ds, **kwargs):
        # Shared construction of the epoch-aware loader.
        return MyDataLoader(
            dataset, step_ds, batch_size=self.batch_size, num_workers=4, **kwargs
        )

    def train_dataloader(self):
        self.value += 1
        self.trainval_set.resetCnt()
        print("get train loader")
        return self._loader(
            self.trainval_set, True, sampler=SubsetRandomSampler(self.train_indices)
        )

    def val_dataloader(self):
        self.value += 1
        print("get val loader")
        return self._loader(
            self.trainval_set, False, sampler=SubsetRandomSampler(self.val_indices)
        )

    def predict_dataloader(self):
        self.value += 1
        print("get predict loader")
        # NOTE(review): shuffle=True randomizes prediction order — confirm
        # downstream consumers do not rely on dataset order here.
        return self._loader(self.testset, False, shuffle=True)
362+
363+
61364
class SeqDataset2(Dataset):
62365
def __init__(self, seq, label, seqtest):
63366

@@ -184,7 +487,6 @@ def setup(self, stage):
184487
if stage == "test":
185488
raise NotImplementedError
186489

187-
188490
def val_dataloader(self):
189491
return DataLoader(
190492
self.val_set, batch_size=self.batch_size, shuffle=False, num_workers=4

0 commit comments

Comments
 (0)