Skip to content

Commit

Permalink
new plots (and relevant code). using parallel GPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
IdoSpringer committed Dec 31, 2019
1 parent 9f1e6e5 commit 62b8b50
Show file tree
Hide file tree
Showing 14 changed files with 1,250 additions and 6 deletions.
2 changes: 1 addition & 1 deletion ERGO_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(self, embedding_dim, device, max_len, input_dim, encoding_dim, batc
self.batch_size = batch_size
# TCR Autoencoder
self.autoencoder = PaddingAutoencoder(max_len, input_dim, encoding_dim)
checkpoint = torch.load(ae_file)
checkpoint = torch.load(ae_file, map_location=device)
self.autoencoder.load_state_dict(checkpoint['model_state_dict'])
if train_ae is False:
for param in self.autoencoder.parameters():
Expand Down
29 changes: 29 additions & 0 deletions comp_logos/extract_pep_tcr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pickle


def extract_pep_tcr_files(pep, tcr_len=13):
    """
    Write the TCRs paired with peptide `pep` to '<pep>_pos' and '<pep>_neg'.

    The pickled train and test sets are read from 'mcpas_train.pickle' and
    'mcpas_test.pickle' in the current working directory. Every sample is
    unpacked as (tcr, pep_data, sign); a sample matches when pep_data[0]
    equals `pep` and the TCR has exactly `tcr_len` characters.

    All matching samples with sign 'p' are written to '<pep>_pos', one TCR
    per line. Matching samples with sign 'n' are written to '<pep>_neg' until
    as many negatives as positives were written, so the negative file never
    holds more lines than the positive one.

    :param pep: peptide sequence to extract TCRs for
    :param tcr_len: required TCR length (default 13, the original constant)
    """
    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # only run this on pickle files that come from a trusted source.
    with open('mcpas_train.pickle', "rb") as file:
        train = pickle.load(file)
    with open('mcpas_test.pickle', "rb") as file:
        test = pickle.load(file)
    # Concatenate once; both passes below scan the same sample order.
    samples = train + test

    def matching_tcrs(wanted_sign):
        # Yield, in data order, the TCRs of the samples matching `pep`,
        # `wanted_sign` and the length constraint.
        for tcr, pep_data, sign in samples:
            if pep_data[0] == pep and sign == wanted_sign and len(tcr) == tcr_len:
                yield tcr

    count = 0
    with open(pep + '_pos', 'w') as pos:
        for tcr in matching_tcrs('p'):
            pos.write(tcr + '\n')
            count += 1
    # The negative file is always created, even when it ends up empty.
    with open(pep + '_neg', 'w') as neg:
        if count:
            for tcr in matching_tcrs('n'):
                neg.write(tcr + '\n')
                count -= 1
                if not count:
                    # Quota reached, no need to scan the remaining samples
                    break


# Peptides for which the positive / negative TCR files are produced.
# The extraction runs as soon as this module is executed.
peptides = [
    'GLCTLVAML',
    'NLVPMVATV',
    'GILGFVFTL',
]
for pep in peptides:
    extract_pep_tcr_files(pep)
6 changes: 1 addition & 5 deletions evaluation_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,7 @@ def new_peps_score(args, model, test_data, new_tcrs, new_peps):
signs = [signs_to_prob[p[2]] for p in test_data if p[0] in new_tcrs and p[1][0] in new_peps]
return evaluate(args, model, tcrs, peps, signs)


# todo fix code. remove repeating parts
# todo hyperparameters tuning with NNI. only 4 models - AE/LSTM * McPAS/VDJdb
# todo average of 10 models. Not cross-validation, but close enough


if __name__ == '__main__':
Expand Down Expand Up @@ -286,8 +283,7 @@ def new_peps_score(args, model, test_data, new_tcrs, new_peps):
print(pep + '\t' + str(single_peptide_score(args, model, test_data, pep, None)[0]))
except ValueError:
print(pep + '\t' + 'none')

if args.function == 'load':
elif args.function == 'load':
model, data = load_model_and_data(args)
train_data, test_data = data
pass
Expand Down
Binary file added figures/MIS_POS.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file added mis_pos/__init__.py
Empty file.
87 changes: 87 additions & 0 deletions mis_pos/ae_mis_pos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from ae_utils import *


# tcrs must have 21 one-hot, not 22. padding index in pep must be 0.
def convert_data(tcrs, peps, tcr_atox, pep_atox, max_len, mis):
    """
    Convert the TCR and peptide lists in place.

    Every entry of `tcrs` is replaced by the result of pad_tcr (with position
    `mis` left out), then `peps` is handed to convert_peps together with
    `pep_atox`. Nothing is returned; callers rely on the in-place changes.
    """
    for position, sequence in enumerate(tcrs):
        tcrs[position] = pad_tcr(sequence, tcr_atox, max_len, mis)
    convert_peps(peps, pep_atox)


def pad_tcr(tcr, amino_to_ix, max_length, mis):
    """
    One-hot encode a TCR into a zero-padded (max_length, 21) tensor.

    The end marker 'X' is appended to the sequence before encoding. The row
    at index `mis` is skipped and therefore stays all-zero; rows after the
    end marker stay all-zero as well.
    """
    one_hot = torch.zeros(max_length, 20 + 1)
    for position, amino in enumerate(tcr + 'X'):
        # Leave the "missing" position empty
        if position != mis:
            one_hot[position][amino_to_ix[amino]] = 1
    return one_hot


def get_batches(tcrs, peps, signs, tcr_atox, pep_atox, batch_size, max_length, mis):
    """
    Split the data into full batches only.

    `tcrs` and `peps` are first converted in place by convert_data. Samples
    that do not fill a whole batch at the end of the data are dropped.
    Each batch is a tuple (tcr_tensor, padded_peps, pep_lens, batch_signs).
    """
    convert_data(tcrs, peps, tcr_atox, pep_atox, max_length, mis)
    batches = []
    start = 0
    # Only iterate over the part of the data that forms complete batches
    while start < len(tcrs) // batch_size * batch_size:
        stop = start + batch_size
        # Stack the encoded TCRs of this batch into a single tensor
        tcr_tensor = torch.zeros((batch_size, max_length, 21))
        for row, encoded in enumerate(tcrs[start:stop]):
            tcr_tensor[row] = encoded
        batch_signs = signs[start:stop]
        padded_peps, pep_lens = pad_batch(peps[start:stop])
        batches.append((tcr_tensor, padded_peps, pep_lens, batch_signs))
        start = stop
    return batches


def get_full_batches(tcrs, peps, signs, tcr_atox, pep_atox, batch_size, max_length, mis):
    """
    Split the data into batches, padding the last one up to `batch_size`.

    `tcrs` and `peps` are first converted in place by convert_data. Full
    batches carry their slice of `signs`. When samples are left over, one
    extra batch is appended in which the leftover samples are followed by
    filler samples (TCR 'X', peptide made of 'A's); the signs of that extra
    batch are all 0.0 placeholders.
    Each batch is a tuple (tcr_tensor, padded_peps, pep_lens, batch_signs).
    """
    convert_data(tcrs, peps, tcr_atox, pep_atox, max_length, mis)
    batches = []
    start = 0
    # Complete batches
    while start < len(tcrs) // batch_size * batch_size:
        stop = start + batch_size
        tcr_tensor = torch.zeros((batch_size, max_length, 21))
        for row, encoded in enumerate(tcrs[start:stop]):
            tcr_tensor[row] = encoded
        batch_signs = signs[start:stop]
        padded_peps, pep_lens = pad_batch(peps[start:stop])
        batches.append((tcr_tensor, padded_peps, pep_lens, batch_signs))
        start = stop
    # Leftover samples: fill the batch up with padding samples
    leftover = len(tcrs) - start
    if leftover > 0:
        missing = batch_size - leftover
        padding_tcrs = ['X'] * missing
        padding_peps = ['A' * leftover] * missing
        convert_data(padding_tcrs, padding_peps, tcr_atox, pep_atox, max_length, mis)
        tail_tcrs = tcrs[start:] + padding_tcrs
        tcr_tensor = torch.zeros((batch_size, max_length, 21))
        for row in range(batch_size):
            tcr_tensor[row] = tail_tcrs[row]
        padded_peps, pep_lens = pad_batch(peps[start:] + padding_peps)
        # Placeholder signs; the real signs of the leftover samples are not kept
        batches.append((tcr_tensor, padded_peps, pep_lens, [0.0] * batch_size))
    return batches
212 changes: 212 additions & 0 deletions mis_pos/ergo_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import random
import numpy as np
import csv
import os
import sklearn.model_selection as skl

# todo count how many TCRs and peps there are in each set (for TN FN TP FP tables)


def read_data(csv_file, file_key, _protein=False, _hla=False):
    '''
    Read TCR-peptide pairs from a McPAS or VDJdb table and split them.

    The first line of the file is skipped as a header. A row is kept when:
      - its species column is 'Human' (mcpas) / 'HomoSapiens' (vdjdb),
      - for vdjdb, column 1 is 'TRB',
      - TCR and peptide are neither 'NA' nor empty,
      - TCR + peptide contain none of the characters # * b f y ~ O /,
      - the protein / HLA column is not 'NA' when `_protein` / `_hla` is set.
    Blank rows are skipped.

    :param csv_file: path of the table (comma separated for mcpas,
                     tab separated for vdjdb)
    :param file_key: 'mcpas' or 'vdjdb'
    :param _protein: also store the protein in the peptide data
    :param _hla: also store the HLA in the peptide data
    :return: (all_pairs, train_pairs, test_pairs); every pair is
             (tcr, pep_data) with pep_data = [pep] + [protein] + [hla]
             (the optional entries only when requested)
    :raises ValueError: if `file_key` is not a supported key
    '''
    # Column layout of the two supported tables
    if file_key == 'mcpas':
        delimiter = ','
        protein_col, hla_col = 9, 13
        species_col, species = 2, 'Human'
        tcr_col, pep_col = 1, 11
    elif file_key == 'vdjdb':
        delimiter = '\t'
        protein_col, hla_col = 10, 6
        species_col, species = 5, 'HomoSapiens'
        tcr_col, pep_col = 2, 9
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        raise ValueError("unsupported file_key %r (expected 'mcpas' or 'vdjdb')" % (file_key,))
    invalid_chars = ('#', '*', 'b', 'f', 'y', '~', 'O', '/')
    all_pairs = []
    with open(csv_file, 'r', encoding='unicode_escape') as file:
        # Skip the header line
        file.readline()
        for line in csv.reader(file, delimiter=delimiter):
            if not line:
                # csv.reader yields [] for blank lines
                continue
            protein = hla = None
            if _protein:
                protein = line[protein_col]
                if protein == 'NA':
                    continue
            if _hla:
                hla = line[hla_col]
                if hla == 'NA':
                    continue
            if line[species_col] != species:
                continue
            tcr, pep = line[tcr_col], line[pep_col]
            # Only beta chains are used from VDJdb
            if file_key == 'vdjdb' and line[1] != 'TRB':
                continue
            # Proper tcr and peptides
            if tcr in ('NA', '') or pep in ('NA', ''):
                continue
            sequence = tcr + pep
            if any(char in sequence for char in invalid_chars):
                continue
            pep_data = [pep]
            if _protein:
                pep_data.append(protein)
            if _hla:
                pep_data.append(hla)
            all_pairs.append((tcr, pep_data))
    train_pairs, test_pairs = train_test_split(all_pairs)
    return all_pairs, train_pairs, test_pairs


def train_test_split(all_pairs):
    '''
    Randomly assign every TCR-PEP pair to the train or to the test set.

    One Bernoulli(0.8) sample is drawn per pair from numpy's global random
    state: a draw of 1 sends the pair to train, anything else to test.
    The relative order of the pairs is kept inside each set.
    '''
    train_pairs, test_pairs = [], []
    for pair in all_pairs:
        # 80% train, 20% test
        goes_to_train = np.random.binomial(1, 0.8) == 1
        target = train_pairs if goes_to_train else test_pairs
        target.append(pair)
    return train_pairs, test_pairs


def positive_examples(pairs):
    '''
    Tag every (tcr, pep_data) pair as a positive example: (tcr, pep_data, 'p').
    '''
    return [(tcr, pep_data, 'p') for tcr, pep_data in pairs]


def negative_examples(pairs, all_pairs, size, _protein=False):
    '''
    Randomly creating intentional negative examples from the same pairs dataset.

    A peptide is drawn from `pairs`, then up to 5 TCRs are drawn from `pairs`
    for it. A drawn (tcr, pep_data) combination becomes a negative example
    (tcr, pep_data, 'n') when it is not known to bind and was not produced
    before. "Known to bind" means:
      - `_protein` is False: (tcr, pep_data) is one of `all_pairs`;
      - `_protein` is True: pep_data[1] (the protein) equals the protein of
        any pair in `all_pairs` that has the same TCR.

    Sampling goes on until at least `size` examples exist. Because the 5 TCR
    draws of a round are always completed, up to 4 more than `size` examples
    may be returned. The loop does not terminate if fewer than `size`
    distinct negatives can be formed.

    Known pairs and produced examples are kept in hashed sets, so every
    membership test is O(1) instead of a scan over the whole list. The
    sequence of random draws is the same as with the list scans.

    :param pairs: (tcr, pep_data) pairs to sample TCRs and peptides from;
                  pep_data is a sequence of hashable items (e.g. strings)
    :param all_pairs: all known positive (tcr, pep_data) pairs
    :param size: minimal number of negative examples to create
    :param _protein: compare by protein instead of by the full peptide data
    :return: list of (tcr, pep_data, 'n') tuples
    '''
    # Get tcr and peps lists
    tcrs = [tcr for (tcr, pep_data) in pairs]
    peps = [pep_data for (tcr, pep_data) in pairs]
    # Index the known positives once
    if _protein:
        proteins_by_tcr = {}
        for pos_tcr, pos_pep in all_pairs:
            proteins_by_tcr.setdefault(pos_tcr, set()).add(pos_pep[1])
    else:
        positive_keys = {(pos_tcr, tuple(pos_pep)) for pos_tcr, pos_pep in all_pairs}
    examples = []
    seen_keys = set()
    while len(examples) < size:
        pep_data = random.choice(peps)
        # pep_data may be a list (unhashable), so a tuple copy is used as key
        pep_key = tuple(pep_data)
        for _ in range(5):
            tcr = random.choice(tcrs)
            if _protein:
                attach = pep_data[1] in proteins_by_tcr.get(tcr, ())
            else:
                attach = (tcr, pep_key) in positive_keys
            if attach:
                continue
            key = (tcr, pep_key)
            if key not in seen_keys:
                seen_keys.add(key)
                examples.append((tcr, pep_data, 'n'))
    return examples


def read_naive_negs(tcrgp_dir, benny_chain_dir):
    '''
    Collect negative TCRs from two directories and split them 80/20.

    From every '.csv' file in `tcrgp_dir` (header line skipped) the last
    column of the rows whose second column is 'control' is taken.
    From every '.cdr3' file in `benny_chain_dir` whose name contains both
    'beta' and 'naive', the first comma separated field of each line is taken.

    :return: (train, test) lists of TCRs produced by sklearn's
             train_test_split with test_size=0.2
    '''
    neg_tcrs = []
    # Control TCRs from the csv files
    for entry in os.listdir(tcrgp_dir):
        filename = os.fsdecode(entry)
        if not filename.endswith(".csv"):
            continue
        with open(tcrgp_dir + '/' + filename, 'r') as csv_file:
            csv_file.readline()
            neg_tcrs.extend(row[-1] for row in csv.reader(csv_file)
                            if row[1] == 'control')
    # TCRs from the naive beta chain files
    for entry in os.listdir(benny_chain_dir):
        filename = os.fsdecode(entry)
        is_naive = 'naive' in filename
        if not (filename.endswith(".cdr3") and 'beta' in filename and is_naive):
            continue
        with open(benny_chain_dir + '/' + filename, 'r') as cdr3_file:
            neg_tcrs.extend(line.strip().split(',')[0] for line in cdr3_file)
    # The list is passed twice; only the first train/test pair is used
    train, test, _, _ = skl.train_test_split(neg_tcrs, neg_tcrs, test_size=0.2)
    return train, test


def read_memory_negs(dir):
    '''
    Collect negative TCRs from the memory-cell files in `dir`, split 80/20.

    A file is used when its name ends with '.cdr3', contains 'beta' and
    contains 'CM' or 'EM'. The first comma separated field of each line
    is taken as the TCR.

    :return: (train, test) lists of TCRs produced by sklearn's
             train_test_split with test_size=0.2
    '''
    neg_tcrs = []
    for entry in os.listdir(dir):
        filename = os.fsdecode(entry)
        is_memory = 'CM' in filename or 'EM' in filename
        if not (filename.endswith(".cdr3") and 'beta' in filename and is_memory):
            continue
        with open(dir + '/' + filename, 'r') as cdr3_file:
            neg_tcrs.extend(line.strip().split(',')[0] for line in cdr3_file)
    # The list is passed twice; only the first train/test pair is used
    train, test, _, _ = skl.train_test_split(neg_tcrs, neg_tcrs, test_size=0.2)
    return train, test


def negative_external_examples(pairs, all_pairs, size, negs, _protein=False):
    '''
    Randomly creating negative examples with TCRs taken from an external list.

    A peptide is drawn from `pairs`, then up to 5 TCRs are drawn from `negs`
    for it. A drawn (tcr, pep_data) combination becomes a negative example
    (tcr, pep_data, 'n') when it is not known to bind and was not produced
    before. "Known to bind" means:
      - `_protein` is False: (tcr, pep_data) is one of `all_pairs`;
      - `_protein` is True: pep_data[1] (the protein) equals the protein of
        any pair in `all_pairs` that has the same TCR.

    Sampling goes on until at least `size` examples exist. Because the 5 TCR
    draws of a round are always completed, up to 4 more than `size` examples
    may be returned. The loop does not terminate if fewer than `size`
    distinct negatives can be formed.

    Known pairs and produced examples are kept in hashed sets, so every
    membership test is O(1) instead of a scan over the whole list. The
    sequence of random draws is the same as with the list scans.

    :param pairs: (tcr, pep_data) pairs to sample peptides from;
                  pep_data is a sequence of hashable items (e.g. strings)
    :param all_pairs: all known positive (tcr, pep_data) pairs
    :param size: minimal number of negative examples to create
    :param negs: external TCRs to sample from
    :param _protein: compare by protein instead of by the full peptide data
    :return: list of (tcr, pep_data, 'n') tuples
    '''
    # Get peps list
    peps = [pep_data for (tcr, pep_data) in pairs]
    # Index the known positives once
    if _protein:
        proteins_by_tcr = {}
        for pos_tcr, pos_pep in all_pairs:
            proteins_by_tcr.setdefault(pos_tcr, set()).add(pos_pep[1])
    else:
        positive_keys = {(pos_tcr, tuple(pos_pep)) for pos_tcr, pos_pep in all_pairs}
    examples = []
    seen_keys = set()
    while len(examples) < size:
        pep_data = random.choice(peps)
        # pep_data may be a list (unhashable), so a tuple copy is used as key
        pep_key = tuple(pep_data)
        for _ in range(5):
            tcr = random.choice(negs)
            if _protein:
                attach = pep_data[1] in proteins_by_tcr.get(tcr, ())
            else:
                attach = (tcr, pep_key) in positive_keys
            if attach:
                continue
            key = (tcr, pep_key)
            if key not in seen_keys:
                seen_keys.add(key)
                examples.append((tcr, pep_data, 'n'))
    return examples


def get_examples(pairs_file, key, sampling, _protein=False, _hla=False):
    '''
    Build the positive and negative examples of the train and test sets.

    The pairs are read and split with read_data. Positives are the pairs
    themselves. Negatives depend on `sampling`:
      - 'naive': TCRs from read_naive_negs('tcrgp_training_data', 'benny_chain'),
        as many negatives as positives are requested;
      - 'memory': TCRs from read_memory_negs('benny_chain'),
        as many negatives as positives are requested;
      - 'specific': TCRs of the same dataset, 5 negatives per positive
        are requested.
    For the external samplings the train negatives use the train part of
    the external TCRs and the test negatives use the test part.

    :return: (train_pos, train_neg, test_pos, test_neg)
    :raises ValueError: if `sampling` is not one of the supported values
    '''
    # Fail before any data is read instead of hitting an UnboundLocalError
    # at the return statement
    if sampling not in ('naive', 'memory', 'specific'):
        raise ValueError("unsupported sampling %r (expected 'naive', 'memory' or 'specific')"
                         % (sampling,))
    all_pairs, train_pairs, test_pairs = read_data(pairs_file, key, _protein=_protein, _hla=_hla)
    train_pos = positive_examples(train_pairs)
    test_pos = positive_examples(test_pairs)
    if sampling == 'specific':
        train_neg = negative_examples(train_pairs, all_pairs, 5 * len(train_pos), _protein=_protein)
        test_neg = negative_examples(test_pairs, all_pairs, 5 * len(test_pos), _protein=_protein)
    else:
        if sampling == 'naive':
            neg_train, neg_test = read_naive_negs('tcrgp_training_data', 'benny_chain')
        else:
            neg_train, neg_test = read_memory_negs('benny_chain')
        train_neg = negative_external_examples(train_pairs, all_pairs, len(train_pos), neg_train,
                                               _protein=_protein)
        # Test negatives must come from neg_test to keep train and test TCRs apart
        test_neg = negative_external_examples(test_pairs, all_pairs, len(test_pos), neg_test,
                                              _protein=_protein)
    return train_pos, train_neg, test_pos, test_neg


def load_data(pairs_file, key, sampling, _protein=False, _hla=False):
    '''
    Return the shuffled train and test sets (positives and negatives mixed).

    The examples come from get_examples; the train set is shuffled first,
    then the test set, both with the global `random` state.
    '''
    examples = get_examples(pairs_file, key, sampling, _protein=_protein, _hla=_hla)
    train_pos, train_neg, test_pos, test_neg = examples
    shuffled = []
    for positives, negatives in ((train_pos, train_neg), (test_pos, test_neg)):
        merged = positives + negatives
        random.shuffle(merged)
        shuffled.append(merged)
    train, test = shuffled
    return train, test


def check(file, key, sampling, _protein, _hla):
    '''
    Debug helper: load the data and print both sets, then both sizes.
    '''
    train, test = load_data(file, key, sampling, _protein, _hla)
    splits = (train, test)
    for split in splits:
        print(split)
    for split in splits:
        print(len(split))

# check()
Loading

0 comments on commit 62b8b50

Please sign in to comment.