forked from IdoSpringer/ERGO
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
new plots (and relevant code). using parallel GPUs
- Loading branch information
1 parent
9f1e6e5
commit 62b8b50
Showing
14 changed files
with
1,250 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import pickle | ||
|
||
|
||
def extract_pep_tcr_files(pep, tcr_length=13):
    """
    Write the fixed-length TCRs of one peptide to '<pep>_pos' and '<pep>_neg'.

    The examples are read from 'mcpas_train.pickle' and 'mcpas_test.pickle'
    in the current working directory. Each example is unpacked as
    (tcr, pep_data, sign), where pep_data[0] is compared with `pep` and
    sign is 'p' (positive) or 'n' (negative).

    Every positive TCR of length `tcr_length` is written to '<pep>_pos'.
    At most the same number of negative TCRs (first come, first served)
    is written to '<pep>_neg', so the two files are balanced.

    :param pep: peptide sequence to extract TCRs for
    :param tcr_length: only TCRs of exactly this length are kept
    """
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # use this on pickles that were produced locally by a trusted run.
    with open('mcpas_train.pickle', 'rb') as handle:
        train = pickle.load(handle)
    with open('mcpas_test.pickle', 'rb') as handle:
        test = pickle.load(handle)
    # Build the combined list once; both passes below iterate over it.
    examples = train + test

    quota = 0
    with open('_'.join([pep, 'pos']), 'w') as pos:
        for tcr, pep_data, sign in examples:
            if pep_data[0] == pep and sign == 'p' and len(tcr) == tcr_length:
                pos.write(tcr + '\n')
                quota += 1
    with open('_'.join([pep, 'neg']), 'w') as neg:
        for tcr, pep_data, sign in examples:
            if not quota:
                # As many negatives as positives were written; nothing more to do.
                break
            if pep_data[0] == pep and sign == 'n' and len(tcr) == tcr_length:
                neg.write(tcr + '\n')
                quota -= 1
|
||
|
||
# Script driver: produce the '<pep>_pos' / '<pep>_neg' files for each
# peptide listed here. Runs at import time.
peptides = [
    'GLCTLVAML',
    'NLVPMVATV',
    'GILGFVFTL',
]
for pep in peptides:
    extract_pep_tcr_files(pep)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from ae_utils import * | ||
|
||
|
||
# TCRs are one-hot encoded over 21 symbols (not 22); the padding index for peptides must be 0.
def convert_data(tcrs, peps, tcr_atox, pep_atox, max_len, mis):
    """
    Encode the TCR and peptide lists in place.

    Every entry of `tcrs` is overwritten with the tensor returned by
    `pad_tcr`; `peps` is handed to `convert_peps` together with `pep_atox`
    (presumably converted in place as well - confirm against ae_utils).
    Nothing is returned.
    """
    for position, sequence in enumerate(tcrs):
        tcrs[position] = pad_tcr(sequence, tcr_atox, max_len, mis)
    convert_peps(peps, pep_atox)
|
||
|
||
def pad_tcr(tcr, amino_to_ix, max_length, mis):
    """
    One-hot encode a TCR into a zero tensor of shape (max_length, 21).

    An 'X' end symbol is appended to the sequence before encoding. The row
    at position `mis` is deliberately left all-zero (that position is
    skipped); pass an index outside the sequence to encode every position.

    :param tcr: TCR sequence string
    :param amino_to_ix: mapping from symbol (including 'X') to column index
    :param max_length: number of rows in the returned tensor
    :param mis: position to leave unencoded
    :return: the encoded tensor
    """
    encoded = torch.zeros(max_length, 21)
    for position, amino in enumerate(tcr + 'X'):
        if position != mis:
            encoded[position][amino_to_ix[amino]] = 1
    return encoded
|
||
|
||
def get_batches(tcrs, peps, signs, tcr_atox, pep_atox, batch_size, max_length, mis):
    """
    Split the data into full batches.

    `tcrs` and `peps` are first encoded in place by `convert_data`. Only
    complete batches are produced: trailing samples that do not fill a
    whole batch are dropped. Each batch is a tuple
    (tcr_tensor, padded_peps, pep_lens, batch_signs).
    """
    convert_data(tcrs, peps, tcr_atox, pep_atox, max_length, mis)
    batches = []
    # Largest multiple of batch_size that fits in the data.
    usable = len(tcrs) // batch_size * batch_size
    for start in range(0, usable, batch_size):
        stop = start + batch_size
        # Stack the per-TCR tensors of this batch into one tensor.
        tcr_tensor = torch.zeros((batch_size, max_length, 21))
        for row, encoded_tcr in enumerate(tcrs[start:stop]):
            tcr_tensor[row] = encoded_tcr
        padded_peps, pep_lens = pad_batch(peps[start:stop])
        batches.append((tcr_tensor, padded_peps, pep_lens, signs[start:stop]))
    return batches
|
||
|
||
def get_full_batches(tcrs, peps, signs, tcr_atox, pep_atox, batch_size, max_length, mis):
    """
    Split the data into batches, padding the last one up to batch_size.

    Complete batches are built as tuples
    (tcr_tensor, padded_peps, pep_lens, batch_signs). If samples are left
    over, one more batch is appended: the leftovers followed by dummy
    samples ('X' TCRs and all-'A' peptides). The signs of that final batch
    are all 0.0, including those of its real samples.
    """
    convert_data(tcrs, peps, tcr_atox, pep_atox, max_length, mis)
    batches = []
    # Largest multiple of batch_size that fits in the data.
    usable = len(tcrs) // batch_size * batch_size
    for start in range(0, usable, batch_size):
        stop = start + batch_size
        tcr_tensor = torch.zeros((batch_size, max_length, 21))
        for row, encoded_tcr in enumerate(tcrs[start:stop]):
            tcr_tensor[row] = encoded_tcr
        padded_peps, pep_lens = pad_batch(peps[start:stop])
        batches.append((tcr_tensor, padded_peps, pep_lens, signs[start:stop]))

    # Pad the data of the last, incomplete batch (if any).
    leftover = len(tcrs) - usable
    if leftover > 0:
        missing = batch_size - leftover
        padding_tcrs = ['X'] * missing
        # Dummy peptides are as long as the number of real leftover samples.
        padding_peps = ['A' * leftover] * missing
        convert_data(padding_tcrs, padding_peps, tcr_atox, pep_atox, max_length, mis)
        tcr_tensor = torch.zeros((batch_size, max_length, 21))
        for row, encoded_tcr in enumerate(tcrs[usable:] + padding_tcrs):
            tcr_tensor[row] = encoded_tcr
        padded_peps, pep_lens = pad_batch(peps[usable:] + padding_peps)
        batches.append((tcr_tensor, padded_peps, pep_lens, [0.0] * batch_size))
    return batches
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
import random | ||
import numpy as np | ||
import csv | ||
import os | ||
import sklearn.model_selection as skl | ||
|
||
# TODO: count how many TCRs and peptides there are in each set (for the TN/FN/TP/FP tables)
|
||
|
||
def read_data(csv_file, file_key, _protein=False, _hla=False):
    """
    Read positive TCR-peptide pairs from a McPAS or VDJdb table.

    The first line of the file is skipped as a header. Rows are kept only
    if the species matches ('Human' for McPAS, 'HomoSapiens' for VDJdb),
    the chain is 'TRB' (VDJdb only), the TCR and peptide are neither 'NA'
    nor empty, and neither contains one of the rejected characters. When
    `_protein` / `_hla` is set, rows whose protein / HLA field is 'NA' are
    dropped and the field is appended to the peptide data.

    :param csv_file: path of the table to read
    :param file_key: 'mcpas' (comma separated) or 'vdjdb' (tab separated)
    :param _protein: include the protein in the peptide data
    :param _hla: include the HLA in the peptide data
    :return: (all_pairs, train_pairs, test_pairs); every pair is
        (tcr, pep_data) with pep_data a list [pep(, protein)(, hla)]
    :raises ValueError: if `file_key` is not a supported database key
    """
    # Per-database layout:
    # (delimiter, tcr, pep, protein, hla, species column, species, chain column)
    layouts = {
        'mcpas': (',', 1, 11, 9, 13, 2, 'Human', None),
        'vdjdb': ('\t', 2, 9, 10, 6, 5, 'HomoSapiens', 1),
    }
    if file_key not in layouts:
        # Fail early with a clear message instead of an unbound-name error.
        raise ValueError(
            "unknown file_key {!r}; expected one of {}".format(file_key, sorted(layouts)))
    (delimiter, tcr_col, pep_col, protein_col, hla_col,
     species_col, species, chain_col) = layouts[file_key]
    rejected_chars = ('#', '*', 'b', 'f', 'y', '~', 'O', '/')

    all_pairs = []
    with open(csv_file, 'r', encoding='unicode_escape') as file:
        file.readline()  # skip the header line
        for line in csv.reader(file, delimiter=delimiter):
            if not line:
                # csv.reader yields [] for blank lines; nothing to parse.
                continue
            protein = hla = None
            if _protein:
                protein = line[protein_col]
                if protein == 'NA':
                    continue
            if _hla:
                hla = line[hla_col]
                if hla == 'NA':
                    continue
            if line[species_col] != species:
                continue
            tcr, pep = line[tcr_col], line[pep_col]
            if chain_col is not None and line[chain_col] != 'TRB':
                continue
            # Proper tcr and peptides
            if tcr in ('NA', '') or pep in ('NA', ''):
                continue
            if any(char in tcr + pep for char in rejected_chars):
                continue
            pep_data = [pep]
            if _protein:
                pep_data.append(protein)
            if _hla:
                pep_data.append(hla)
            all_pairs.append((tcr, pep_data))
    train_pairs, test_pairs = train_test_split(all_pairs)
    return all_pairs, train_pairs, test_pairs
|
||
|
||
def train_test_split(all_pairs):
    """
    Randomly split the TCR-peptide pairs into train and test lists.

    Each pair is assigned independently: one Bernoulli draw per pair sends
    it to the train list with probability 0.8, otherwise to the test list.
    The relative order of the pairs is kept within each list.

    :return: (train_pairs, test_pairs)
    """
    train_pairs, test_pairs = [], []
    for pair in all_pairs:
        goes_to_train = np.random.binomial(1, 0.8) == 1
        target = train_pairs if goes_to_train else test_pairs
        target.append(pair)
    return train_pairs, test_pairs
|
||
|
||
def positive_examples(pairs):
    """
    Tag every (tcr, pep_data) pair as a positive example.

    :return: list of (tcr, pep_data, 'p') tuples, in the order of `pairs`
    """
    return [(tcr, pep_data, 'p') for tcr, pep_data in pairs]
|
||
|
||
def negative_examples(pairs, all_pairs, size, _protein=False):
    """
    Randomly create negative examples from the same pairs dataset.

    A peptide is drawn from `pairs`, then five TCRs are drawn from `pairs`
    for it. A (tcr, peptide) combination becomes a negative example when it
    is not a known positive and has not been produced before. A combination
    counts as a known positive when, with `_protein` unset, the exact pair
    appears in `all_pairs`; with `_protein` set, when the TCR is paired in
    `all_pairs` with any peptide of the same protein (pep_data[1]).

    Because five TCRs are tried per peptide, up to four more examples than
    `size` may be returned. The loop does not terminate if fewer than
    `size` distinct negatives can be formed from `pairs`.

    :param pairs: (tcr, pep_data) pairs to sample TCRs and peptides from
    :param all_pairs: all known positive (tcr, pep_data) pairs
    :param size: minimal number of negative examples to create
    :param _protein: compare by protein instead of by exact peptide data
    :return: list of (tcr, pep_data, 'n') tuples
    """
    # Get tcr and peps lists
    tcrs = [tcr for (tcr, _) in pairs]
    peps = [pep_data for (_, pep_data) in pairs]
    # Index the known positives once so each membership test is O(1)
    # instead of a scan over all_pairs. pep_data is a list, so it is keyed
    # by its tuple form.
    if _protein:
        proteins_by_tcr = {}
        for known_tcr, known_pep in all_pairs:
            proteins_by_tcr.setdefault(known_tcr, set()).add(known_pep[1])
    else:
        positive_keys = {(known_tcr, tuple(known_pep)) for known_tcr, known_pep in all_pairs}
    examples = []
    produced = set()
    while len(examples) < size:
        pep_data = random.choice(peps)
        pep_key = tuple(pep_data)
        for _ in range(5):
            tcr = random.choice(tcrs)
            if _protein:
                attach = pep_data[1] in proteins_by_tcr.get(tcr, ())
            else:
                attach = (tcr, pep_key) in positive_keys
            if not attach and (tcr, pep_key) not in produced:
                produced.add((tcr, pep_key))
                examples.append((tcr, pep_data, 'n'))
    return examples
|
||
|
||
def read_naive_negs(tcrgp_dir, benny_chain_dir):
    """
    Collect negative TCRs from two directories and split them 80/20.

    From every '.csv' file in `tcrgp_dir` (first line skipped) the last
    column of each row whose second column is 'control' is taken. From
    every '.cdr3' file in `benny_chain_dir` whose name contains both 'beta'
    and 'naive', the first comma-separated field of each line is taken.

    :return: (train, test) lists of TCRs
    """
    neg_tcrs = []
    for entry in os.listdir(tcrgp_dir):
        name = os.fsdecode(entry)
        if not name.endswith(".csv"):
            continue
        with open(tcrgp_dir + '/' + name, 'r') as table:
            table.readline()  # skip the header line
            neg_tcrs.extend(row[-1] for row in csv.reader(table) if row[1] == 'control')
    for entry in os.listdir(benny_chain_dir):
        name = os.fsdecode(entry)
        if name.endswith(".cdr3") and 'beta' in name and 'naive' in name:
            with open(benny_chain_dir + '/' + name, 'r') as listing:
                neg_tcrs.extend(line.strip().split(',')[0] for line in listing)
    # The list is passed as both data and labels; only the data halves are used.
    train, test, _, _ = skl.train_test_split(neg_tcrs, neg_tcrs, test_size=0.2)
    return train, test
|
||
|
||
def read_memory_negs(dir):
    """
    Collect negative TCRs from memory-cell files and split them 80/20.

    Every '.cdr3' file in `dir` whose name contains 'beta' and either 'CM'
    or 'EM' is read; the first comma-separated field of each line is taken
    as a TCR.

    :return: (train, test) lists of TCRs
    """
    neg_tcrs = []
    for entry in os.listdir(dir):
        name = os.fsdecode(entry)
        if not (name.endswith(".cdr3") and 'beta' in name):
            continue
        if 'CM' in name or 'EM' in name:
            with open(dir + '/' + name, 'r') as listing:
                neg_tcrs.extend(line.strip().split(',')[0] for line in listing)
    # The list is passed as both data and labels; only the data halves are used.
    train, test, _, _ = skl.train_test_split(neg_tcrs, neg_tcrs, test_size=0.2)
    return train, test
|
||
|
||
def negative_external_examples(pairs, all_pairs, size, negs, _protein=False):
    """
    Randomly create negative examples using TCRs from an external pool.

    A peptide is drawn from `pairs`, then five TCRs are drawn from `negs`
    for it. A (tcr, peptide) combination becomes a negative example when it
    is not a known positive and has not been produced before. A combination
    counts as a known positive when, with `_protein` unset, the exact pair
    appears in `all_pairs`; with `_protein` set, when the TCR is paired in
    `all_pairs` with any peptide of the same protein (pep_data[1]).

    Because five TCRs are tried per peptide, up to four more examples than
    `size` may be returned. The loop does not terminate if fewer than
    `size` distinct negatives can be formed.

    :param pairs: (tcr, pep_data) pairs to sample peptides from
    :param all_pairs: all known positive (tcr, pep_data) pairs
    :param size: minimal number of negative examples to create
    :param negs: external TCRs to sample from
    :param _protein: compare by protein instead of by exact peptide data
    :return: list of (tcr, pep_data, 'n') tuples
    """
    # Get peps list
    peps = [pep_data for (_, pep_data) in pairs]
    # Index the known positives once so each membership test is O(1)
    # instead of a scan over all_pairs. pep_data is a list, so it is keyed
    # by its tuple form.
    if _protein:
        proteins_by_tcr = {}
        for known_tcr, known_pep in all_pairs:
            proteins_by_tcr.setdefault(known_tcr, set()).add(known_pep[1])
    else:
        positive_keys = {(known_tcr, tuple(known_pep)) for known_tcr, known_pep in all_pairs}
    examples = []
    produced = set()
    while len(examples) < size:
        pep_data = random.choice(peps)
        pep_key = tuple(pep_data)
        for _ in range(5):
            tcr = random.choice(negs)
            if _protein:
                attach = pep_data[1] in proteins_by_tcr.get(tcr, ())
            else:
                attach = (tcr, pep_key) in positive_keys
            if not attach and (tcr, pep_key) not in produced:
                produced.add((tcr, pep_key))
                examples.append((tcr, pep_data, 'n'))
    return examples
|
||
|
||
def get_examples(pairs_file, key, sampling, _protein=False, _hla=False):
    """
    Build positive and negative examples for the train and test sets.

    Positives come from `read_data`. Negatives depend on `sampling`:
    'specific' recombines TCRs and peptides of the dataset itself (five
    negatives per positive), while 'naive' and 'memory' pair the dataset's
    peptides with external TCRs read from the 'tcrgp_training_data' and
    'benny_chain' directories (one negative per positive).

    :param pairs_file: path of the pairs table
    :param key: database key passed to `read_data`
    :param sampling: 'naive', 'memory' or 'specific'
    :return: (train_pos, train_neg, test_pos, test_neg)
    :raises ValueError: if `sampling` is not a supported strategy
    """
    strategies = ('naive', 'memory', 'specific')
    if sampling not in strategies:
        # Fail before any file is read, instead of hitting an unbound
        # name at the return statement.
        raise ValueError(
            "unknown sampling {!r}; expected one of {}".format(sampling, list(strategies)))
    all_pairs, train_pairs, test_pairs = read_data(pairs_file, key, _protein=_protein, _hla=_hla)
    train_pos = positive_examples(train_pairs)
    test_pos = positive_examples(test_pairs)
    if sampling == 'specific':
        train_neg = negative_examples(train_pairs, all_pairs, 5 * len(train_pos), _protein=_protein)
        test_neg = negative_examples(test_pairs, all_pairs, 5 * len(test_pos), _protein=_protein)
    else:
        if sampling == 'naive':
            neg_train, neg_test = read_naive_negs('tcrgp_training_data', 'benny_chain')
        else:
            neg_train, neg_test = read_memory_negs('benny_chain')
        train_neg = negative_external_examples(train_pairs, all_pairs, len(train_pos), neg_train,
                                               _protein=_protein)
        # Test negatives must come from the test split of the external TCRs
        # (neg_test), so that no external TCR is shared between train and test.
        test_neg = negative_external_examples(test_pairs, all_pairs, len(test_pos), neg_test,
                                              _protein=_protein)
    return train_pos, train_neg, test_pos, test_neg
|
||
|
||
def load_data(pairs_file, key, sampling, _protein=False, _hla=False):
    """
    Build shuffled train and test example lists.

    Positive and negative examples from `get_examples` are concatenated
    per set and shuffled with `random.shuffle` (train first, then test).

    :return: (train, test)
    """
    pos_train, neg_train, pos_test, neg_test = get_examples(
        pairs_file, key, sampling, _protein=_protein, _hla=_hla)
    shuffled_sets = []
    for positives, negatives in ((pos_train, neg_train), (pos_test, neg_test)):
        merged = positives + negatives
        random.shuffle(merged)
        shuffled_sets.append(merged)
    train, test = shuffled_sets
    return train, test
|
||
|
||
def check(file, key, sampling, _protein, _hla):
    """
    Debug helper: load the data and print both sets followed by their sizes.
    """
    train, test = load_data(file, key, sampling, _protein, _hla)
    for item in (train, test, len(train), len(test)):
        print(item)
|
||
# check() |
Oops, something went wrong.