1
1
import os
2
2
import random
3
3
4
+ import numpy as np
4
5
import pytorch_lightning as L
5
6
import torch
7
+ from esm .utils .constants import esm3 as C
8
+ from pytorch_lightning .utilities .types import TRAIN_DATALOADERS
6
9
from torch .utils .data import DataLoader , Dataset
10
+ from torch .utils .data .sampler import SubsetRandomSampler
7
11
8
12
import ioutils
9
13
10
14
15
class MyDataLoader(DataLoader):
    """DataLoader that advances its dataset's curriculum state once per epoch.

    Wraps a dataset exposing ``step()`` and an ``ifaug`` flag: when
    ``step_ds`` is true the dataset is stepped and augmentation is enabled at
    the start of every iteration; otherwise augmentation is switched off.
    """

    def __init__(self, ds, step_ds, *args, **kwargs):
        super().__init__(ds, *args, **kwargs)
        self.ds = ds            # direct handle (DataLoader also stores it as .dataset)
        self.epoch = 0          # number of times __iter__ has been entered
        self.step_ds = step_ds  # whether this loader drives dataset stepping

    def step(self):
        # Manually advance the dataset's schedule, honouring the step_ds gate.
        if not self.step_ds:
            return
        self.ds.step()

    def __iter__(self):
        self.epoch += 1
        # Train-style loaders step the schedule and turn augmentation on;
        # eval-style loaders force augmentation off.
        if not self.step_ds:
            self.ds.ifaug = False
        else:
            self.ds.step()
            self.ds.ifaug = True
        print("now epoch ", self.epoch)
        return super().__iter__()
35
+
36
+
11
37
def readVirusSequences (pos = None , trunc = 1498 , sample = 300 , seed = 1509 ):
12
38
random .seed (seed )
13
39
print ("read positive samples" )
@@ -38,7 +64,7 @@ def readVirusSequences(pos=None, trunc=1498, sample=300, seed=1509):
38
64
39
65
print ("read negative samples" )
40
66
gen = ioutils .readFasta (
41
- "/home/tyfei/datasets/ion_channel/Interprot/Negative_sample/decoy_1m_new .fasta" ,
67
+ "/home/tyfei/datasets/ion_channel/Interprot/Negative_sample/old/decoy_1m_new_rmdup .fasta" ,
42
68
truclength = trunc ,
43
69
)
44
70
seqs ["neg" ] = [i for i in gen ]
@@ -49,15 +75,292 @@ def readVirusSequences(pos=None, trunc=1498, sample=300, seed=1509):
49
75
print ("read virus sequences" )
50
76
allvirus = []
51
77
for i in os .listdir ("/home/tyfei/datasets/NCBI_virus/genbank_csv/" ):
52
- allvirus .extend (
53
- ioutils .readNCBICsv (
54
- "/home/tyfei/datasets/NCBI_virus/genbank_csv/" + i , truclength = trunc
78
+ try :
79
+ allvirus .extend (
80
+ ioutils .readNCBICsv (
81
+ "/home/tyfei/datasets/NCBI_virus/genbank_csv/" + i , truclength = trunc
82
+ )
55
83
)
56
- )
84
+ except Exception :
85
+ pass
57
86
58
87
return sequences , labels , allvirus
59
88
60
89
90
+ MIN_LENGTH = 50
91
+
92
+
93
class DataAugmentation:
    """Schedules masking and cropping strength as training progresses.

    Each entry of ``step_points`` is a step threshold; once the step counter
    passes a threshold, the corresponding ``maskp``/``crop`` settings become
    active. Thresholds are scanned in order, so the last one passed wins.
    """

    def __init__(
        self, step_points: list, maskp: list, crop: list, croprange: list
    ) -> None:
        assert len(step_points) == len(maskp)
        assert len(maskp) == len(crop)
        self.step_points = step_points  # step thresholds, presumably ascending
        self.maskp = maskp              # (point_p, block_p) mask probabilities per stage
        self.crop = crop                # crop probability per stage
        self.croprange = croprange      # candidate crop lengths to sample from

    def _getSettings(self, step):
        # Default: augmentation disabled until the first threshold is passed.
        active_maskp = (-1.0, -1.0)
        active_crop = -1.0
        for threshold, mp, cp in zip(self.step_points, self.maskp, self.crop):
            if step > threshold:
                active_maskp, active_crop = mp, cp
        return active_maskp, active_crop

    def getAugmentation(self, seqlen, step):
        """Return ``(maskp, croplen)`` for a sequence of length ``seqlen``.

        ``croplen`` is -1 when no crop should be applied this time.
        """
        maskp, crop = self._getSettings(step)
        # Short-circuit keeps RNG consumption identical: random.random() is
        # only drawn when cropping is enabled for this stage.
        if crop <= 0 or random.random() >= crop:
            return maskp, -1
        # Sample a target length and jitter it by +/-20%, then clamp.
        croplen = random.sample(self.croprange, 1)[0]
        croplen = int(croplen * np.random.uniform(0.8, 1.2))
        croplen = max(croplen, MIN_LENGTH)
        croplen = min(croplen, seqlen - 2)
        return maskp, croplen
124
+
125
+
126
+ class ESM3BaseDataset (Dataset ):
127
+ def __init__ (self , tracks = ["seq_t" , "structure_t" , "sasa_t" , "second_t" ]) -> None :
128
+ assert len (tracks ) > 0
129
+ self .tracks = tracks
130
+ self .step_cnt = 0
131
+
132
+ def step (self ):
133
+ self .step_cnt += 1
134
+
135
+ def resetCnt (self ):
136
+ self .step_cnt = 0
137
+
138
+ def getToken (self , track , token ):
139
+ # assert token in ["start", "end", "mask"]
140
+ match token :
141
+ case "start" :
142
+ match track :
143
+ case "seq_t" :
144
+ return C .SEQUENCE_BOS_TOKEN
145
+ case "structure_t" :
146
+ return C .STRUCTURE_BOS_TOKEN
147
+ case "sasa_t" :
148
+ return 0
149
+ case "second_t" :
150
+ return 0
151
+ case _:
152
+ raise ValueError
153
+ case "end" :
154
+ match track :
155
+ case "seq_t" :
156
+ return C .SEQUENCE_EOS_TOKEN
157
+ case "structure_t" :
158
+ return C .STRUCTURE_EOS_TOKEN
159
+ case "sasa_t" :
160
+ return 0
161
+ case "second_t" :
162
+ return 0
163
+ case _:
164
+ raise ValueError
165
+ case "mask" :
166
+ match track :
167
+ case "seq_t" :
168
+ return C .SEQUENCE_MASK_TOKEN
169
+ case "structure_t" :
170
+ return C .STRUCTURE_MASK_TOKEN
171
+ case "sasa_t" :
172
+ return C .SASA_UNK_TOKEN
173
+ case "second_t" :
174
+ return C .SS8_UNK_TOKEN
175
+ case _:
176
+ raise ValueError
177
+ case "pad" :
178
+ match track :
179
+ case "seq_t" :
180
+ return C .SEQUENCE_PAD_TOKEN
181
+ case "structure_t" :
182
+ return C .STRUCTURE_PAD_TOKEN
183
+ case "sasa_t" :
184
+ return C .SASA_PAD_TOKEN
185
+ case "second_t" :
186
+ return C .SS8_PAD_TOKEN
187
+ case _:
188
+ raise ValueError
189
+ case _:
190
+ raise ValueError
191
+
192
+ def _maskSequence (self , sample , pos ):
193
+ for i in sample :
194
+ sample [i ][pos ] = self .getToken (i , "mask" )
195
+
196
+ return sample
197
+
198
+ def _generateMaskingPos (self , num , length , method = "point" ):
199
+ assert length > num + 5
200
+ if method == "point" :
201
+ a = np .array (random .sample (range (length - 2 ), num )) + 1
202
+ return a
203
+ elif method == "block" :
204
+ s = random .randint (1 , length - num )
205
+ a = np .array (range (s , s + num ))
206
+ return a
207
+ else :
208
+ raise NotImplementedError
209
+
210
+ def _cropSequence (self , sample , start , end ):
211
+ for i in sample :
212
+ t = torch .zeros ((end - start + 2 ), dtype = torch .long )
213
+ t [1 :- 1 ] = torch .tensor (sample [i ][start :end ])
214
+ t [0 ] = self .getToken (i , "start" )
215
+ t [- 1 ] = self .getToken (i , "end" )
216
+ sample [i ] = t
217
+ return sample
218
+
219
+ def _augmentsample (self , sample , maskp , crop ):
220
+ samplelen = len (sample [self .tracks [0 ]])
221
+ if crop > 50 :
222
+ s = random .randint (1 , samplelen - crop - 1 )
223
+ sample = self ._cropSequence (sample , s , s + crop )
224
+ samplelen = crop + 2
225
+ if maskp [0 ] > 0 :
226
+ num = np .random .binomial (samplelen - 2 , maskp [0 ])
227
+ pos = self ._generateMaskingPos (num , samplelen )
228
+ if len (pos ) > 0 :
229
+ sample = self ._maskSequence (sample , pos )
230
+ if maskp [1 ] > 0 :
231
+ num = np .random .binomial (samplelen - 2 , maskp [0 ])
232
+ pos = self ._generateMaskingPos (num , samplelen , "block" )
233
+ if len (pos ) > 0 :
234
+ sample = self ._maskSequence (sample , pos )
235
+ return sample
236
+
237
+
238
class ESM3MultiTrackDataset(ESM3BaseDataset):
    """Paired training dataset: each item yields a labelled sample from
    ``data1`` together with a pseudo-randomly paired sample from ``data2``.

    ``data2`` is addressed through a shuffled index permutation that is
    reshuffled on every ``step()``, so pairings change between epochs.
    """

    def __init__(
        self,
        data1,
        data2,
        label,
        augment: DataAugmentation = None,
        tracks=("seq_t", "structure_t", "sasa_t", "second_t"),
    ) -> None:
        """
        Args:
            data1: indexable collection of per-track sample dicts (labelled).
            data2: indexable collection of per-track sample dicts to pair with.
            label: per-sample labels aligned with ``data1``.
            augment: optional augmentation schedule applied to data1 samples.
            tracks: track names to extract from each sample. (Default changed
                from a mutable list literal to a tuple — shared-mutable-default
                fix; callers passing lists are unaffected.)
        """
        super().__init__(tracks=list(tracks))
        self.data1 = data1
        self.data2 = data2
        self.label = label
        self.aug = augment
        self.iters = 0  # retained for compatibility; not used internally
        # Permutation mapping data1 indices onto data2 samples.
        self.data2order = np.arange(len(data2))
        random.shuffle(self.data2order)
        self.ifaug = False  # toggled per epoch by MyDataLoader

    def __len__(self):
        return len(self.data1)

    def step(self):
        # New epoch: re-pair data1 and data2 samples, then advance the schedule.
        random.shuffle(self.data2order)
        super().step()

    def __getitem__(self, idx):
        x1 = {}
        x2 = {}
        for track in self.tracks:
            x1[track] = self.data1[idx][track]
            # Wrap around when data2 is shorter than data1.
            x2[track] = self.data2[self.data2order[idx % len(self.data2)]][track]
        if self.aug is not None and self.ifaug:
            maskp, crop = self.aug.getAugmentation(
                len(x1[self.tracks[0]]), self.step_cnt
            )
            x1 = self._augmentsample(x1, maskp, crop)
        return x1, torch.tensor([self.label[idx]]), x2
277
+
278
+
279
class ESM3MultiTrackDatasetTEST(ESM3BaseDataset):
    """Inference-time dataset: yields unlabelled per-track samples from
    ``data1``.

    Unlike the training dataset, augmentation (when supplied) is applied
    unconditionally rather than being gated by an ``ifaug`` flag.
    """

    def __init__(
        self,
        data1,
        augment: DataAugmentation = None,
        tracks=["seq_t", "structure_t", "sasa_t", "second_t"],
    ) -> None:
        super().__init__(tracks=tracks)
        self.data1 = data1
        self.aug = augment

    def __len__(self):
        return len(self.data1)

    def step(self):
        # No extra per-epoch state here; just advance the shared counter.
        super().step()

    def __getitem__(self, idx):
        sample = {track: self.data1[idx][track] for track in self.tracks}
        if self.aug is None:
            return sample
        maskp, crop = self.aug.getAugmentation(
            len(sample[self.tracks[0]]), self.step_cnt
        )
        return self._augmentsample(sample, maskp, crop)
307
+
308
+
309
class ESM3datamodule(L.LightningDataModule):
    """LightningDataModule wrapping a train/val dataset (``ds1``) and a
    prediction dataset (``ds2``).

    ``ds1`` is split into train/val index subsets; ``MyDataLoader`` steps the
    dataset's augmentation schedule on the training loader only.
    """

    def __init__(
        self,
        ds1: ESM3BaseDataset,
        ds2: ESM3BaseDataset,
        batch_size=1,
        train_test_split=(0.85, 0.15),
        seed=1509,
    ):
        """
        Args:
            ds1: dataset split into train/val subsets.
            ds2: dataset served by predict_dataloader.
            batch_size: batch size for every loader.
            train_test_split: (train_fraction, val_fraction); only the first
                entry is used. (Default changed from a mutable list literal
                to a tuple; callers passing lists are unaffected.)
            seed: torch RNG seed used for the split permutation.
        """
        super().__init__()
        self.value = 0  # counts dataloader requests (debug aid)
        self.batch_size = batch_size
        self.seed = seed
        torch.manual_seed(self.seed)
        # BUG FIX: previously the first 85% of ds1 *in storage order* became
        # the training set, so any ordering in the data leaked into the
        # split. Shuffle the indices (seeded above) before splitting,
        # matching the intent of the commented-out random_split.
        all_indices = torch.randperm(len(ds1)).numpy()

        self.trainval_set = ds1
        n_train = int(len(all_indices) * train_test_split[0])
        self.train_indices = all_indices[:n_train]
        self.val_indices = all_indices[n_train:]
        self.testset = ds2

    def train_dataloader(self):
        """Training loader: steps the augmentation schedule each epoch."""
        self.value += 1
        self.trainval_set.resetCnt()
        print("get train loader")
        return MyDataLoader(
            self.trainval_set,
            True,
            batch_size=self.batch_size,
            sampler=SubsetRandomSampler(self.train_indices),
            num_workers=4,
        )

    def val_dataloader(self):
        """Validation loader: same dataset, augmentation disabled."""
        self.value += 1
        print("get val loader")
        return MyDataLoader(
            self.trainval_set,
            False,
            batch_size=self.batch_size,
            sampler=SubsetRandomSampler(self.val_indices),
            num_workers=4,
        )

    def predict_dataloader(self):
        """Prediction loader over ds2, augmentation disabled."""
        self.value += 1
        print("get predict loader")
        return MyDataLoader(
            self.testset, False, batch_size=self.batch_size, shuffle=True, num_workers=4
        )
362
+
363
+
61
364
class SeqDataset2 (Dataset ):
62
365
def __init__ (self , seq , label , seqtest ):
63
366
@@ -184,7 +487,6 @@ def setup(self, stage):
184
487
if stage == "test" :
185
488
raise NotImplementedError
186
489
187
-
188
490
def val_dataloader (self ):
189
491
return DataLoader (
190
492
self .val_set , batch_size = self .batch_size , shuffle = False , num_workers = 4
0 commit comments