forked from pnpnpn/dna2vec
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request pnpnpn#1 from pnpnpn/gensim
training library
- Loading branch information
Showing
27 changed files
with
12,956 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
inputs/hg38/ | ||
results/ | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,50 @@ | ||
# dna2vec | ||
|
||
<https://arxiv.org/abs/1701.06279> | ||
**Dna2vec** is an open-source library to train distributed representations | ||
of variable-length k-mers. | ||
|
||
For more information, please refer to the paper: [dna2vec: Consistent vector representations of variable-length k-mers](https://arxiv.org/abs/1701.06279) | ||
|
||
Installation | ||
--- | ||
|
||
Note that this implementation has only been tested on Python 3.5.3, but we welcome any | ||
contributions or bug reports to make it more accessible. | |
|
||
1. Clone the `dna2vec` repository: `git clone https://github.com/pnpnpn/dna2vec` | ||
2. Install Python dependencies: `pip3 install -r requirements.txt` | ||
3. Test the installation: `python3 ./scripts/train_dna2vec.py -c configs/small_example.yml` | ||
|
||
Training dna2vec embeddings | ||
--- | ||
|
||
1. Download `hg38` from <http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.chromFa.tar.gz>. | ||
This will take a while as it's 938MB. | ||
2. Untar with `tar -zxvf hg38.chromFa.tar.gz`. You should see FASTA files for | ||
chromosome 1 to 22: `chr1.fa`, `chr2.fa`, ..., `chr22.fa`. | ||
3. Move the 22 FASTA files to folder `inputs/hg38/` | ||
4. Start the training with: `python3 ./scripts/train_dna2vec.py -c configs/hg38-20161219-0153.yml` | ||
5. Wait for a couple of days ... | ||
6. You should see a `*.w2v` and a corresponding `*.txt` file in your `results/` directory. | ||
|
||
Reading pretrained dna2vec | ||
--- | ||
|
||
You can read pretrained dna2vec vectors `pretrained/dna2vec-*.w2v` using | ||
the class `MultiKModel` in `dna2vec/multi_k_model.py`. For example: | ||
|
||
``` | ||
from dna2vec.multi_k_model import MultiKModel | ||
filepath = 'pretrained/dna2vec-20161219-0153-k3to8-100d-10c-29320Mbp-sliding-Xat.w2v' | ||
mk_model = MultiKModel(filepath) | ||
``` | ||
|
||
Contribute | ||
--- | ||
I would love for you to fork this project and send me pull requests. | |
Please contribute. | ||
|
||
License | ||
--- | ||
This software is licensed under the [MIT license](http://en.wikipedia.org/wiki/MIT_License) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import numpy as np | ||
import os | ||
from Bio.Seq import Seq | ||
|
||
NUCLEOTIDES = 'ACGT' | ||
|
||
class Tuple4:
    """
    Plain container for four values: two "positive" items (pos1, pos2)
    and two "negative" items (neg1, neg2).
    """

    def __init__(self, pos1, pos2, neg1, neg2):
        self.pos1, self.pos2 = pos1, pos2
        self.neg1, self.neg2 = neg1, neg2
|
||
def determine_out_filename(output_dir, fileroot, mode, extension='txt'):
    """
    Build the path `<output_dir>/<fileroot>.<mode>.<extension>`.
    """
    filename = '{}.{}.{}'.format(fileroot, mode, extension)
    return os.path.join(output_dir, filename)
|
||
def create_tuple4(kmer1, kmer2, kmer1_neg, kmer2_neg):
    """
    Join each input into a string and wrap the four strings in a Tuple4.

    all inputs are list of single nucleotides, e.g. ['A', 'A', 'C']
    """
    pos1, pos2, neg1, neg2 = (
        ''.join(nts) for nts in (kmer1, kmer2, kmer1_neg, kmer2_neg)
    )
    return Tuple4(pos1, pos2, neg1, neg2)
|
||
def insert_snippet(seq, snippet, idx):
    """
    Return `seq` with `snippet` spliced in before position `idx`.

    idx: 0 <= idx <= len(seq)
    """
    head, tail = seq[:idx], seq[idx:]
    return head + snippet + tail
|
||
def pairwise_key(v1, v2):
    """
    Format the two values as a single 'v1:v2' string key.
    """
    template = '{}:{}'
    return template.format(v1, v2)
|
||
def rand_kmer(rng, k_low, k_high=None):
    """
    Draw a random k-mer string over NUCLEOTIDES.

    The length is drawn from `rng.randint(k_low, k_high + 1)`;
    k_low and k_high are inclusive. When k_high is omitted it
    defaults to k_low.
    """
    upper = k_low if k_high is None else k_high
    k_len = rng.randint(k_low, upper + 1)
    nt_indices = rng.randint(4, size=k_len)
    return ''.join(NUCLEOTIDES[i] for i in nt_indices)
|
||
def rand_nt(rng):
    """
    Pick one character of NUCLEOTIDES using an index from `rng.randint(4)`.
    """
    nt_index = rng.randint(4)
    return NUCLEOTIDES[nt_index]
|
||
def generate_revcompl_pair(k_low, k_high=None, rng=None):
    """
    Return a tuple (kmer, reverse complement of kmer) for a random k-mer.

    k_high defaults to k_low, and rng defaults to the global `np.random`
    module when not supplied.
    """
    # TODO make params k_high and rng be required
    upper = k_low if k_high is None else k_high
    source = np.random if rng is None else rng
    kmer = rand_kmer(source, k_low, upper)
    return (kmer, revcompl(kmer))
|
||
def revcompl(kmer):
    """
    Return the reverse complement of `kmer` as a plain string,
    computed with Biopython's `Seq.reverse_complement`.
    """
    reversed_seq = Seq(kmer).reverse_complement()
    return str(reversed_seq)
|
||
def generate_1nt_mutation_4tuple(rng, k_len):
    """
    Build a Tuple4 of k-mers of length `k_len` that differ at one position.

    Two random k-mers are drawn, then one position and one nucleotide
    (`mutate_nt`) are drawn:
      pos1 - the first random k-mer, unchanged
      pos2 - the second random k-mer with `mutate_nt` at the position
      neg1 - pos1 with `mutate_nt` at the position
      neg2 - pos2 with pos1's original nucleotide at the position

    NOTE: `mutate_nt` comes from rand_nt(), which draws from all four
    nucleotides, so it can equal the original nucleotide.
    """
    pos1 = list(rand_kmer(rng, k_len, k_len))
    pos2 = list(rand_kmer(rng, k_len, k_len))

    site = rng.randint(len(pos1))
    original_nt = pos1[site]
    mutate_nt = rand_nt(rng)

    neg1 = pos1[:site] + [mutate_nt] + pos1[site + 1:]

    # pos2 is modified in place before its negative counterpart is derived
    pos2[site] = mutate_nt
    neg2 = pos2[:site] + [original_nt] + pos2[site + 1:]

    return create_tuple4(pos1, pos2, neg1, neg2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import sys | ||
|
||
class Tee(object):
    """
    Context manager that mirrors everything written to `sys.stdout`
    into an additional file object.

    On entry `sys.stdout` is replaced by this object; on exit the
    previous `sys.stdout` is restored. The file object passed in is
    not closed here; that stays the caller's responsibility.
    """

    def __init__(self, fptr):
        # fptr: file-like object with write() and flush()
        self.file = fptr

    def __enter__(self):
        self.stdout = sys.stdout
        sys.stdout = self
        # Return self so that `with Tee(f) as tee:` binds the Tee object.
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Always restore the real stdout; exceptions are not suppressed.
        sys.stdout = self.stdout

    def write(self, data):
        self.file.write(data)
        self.stdout.write(data)

    def flush(self):
        # Needed because callers such as print(..., flush=True) call
        # sys.stdout.flush() while this object is installed as stdout.
        self.file.flush()
        self.stdout.flush()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import time | ||
import logbook | ||
|
||
class Benchmark():
    """
    Records wall-clock and CPU time at construction and logs the
    elapsed amounts on demand.
    """

    def __init__(self):
        self.time_wall_start = time.time()
        # time.clock() was removed in Python 3.8; process_time() is the
        # documented replacement for measuring CPU time of this process.
        self.time_cpu_start = time.process_time()
        self.logger = logbook.Logger(self.__class__.__name__)

    def diff_time_wall_secs(self):
        """Return wall-clock seconds elapsed since construction."""
        return (time.time() - self.time_wall_start)

    def print_time(self, label=''):
        """Log elapsed wall and CPU time in minutes, prefixed by `label`."""
        self.logger.info("%s wall=%.3fm cpu=%.3fm" % (
            label,
            self.diff_time_wall_secs() / 60.0,
            (time.process_time() - self.time_cpu_start) / 60.0,
        ))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import random | ||
import string | ||
import resource | ||
import logbook | ||
import arrow | ||
import numpy as np | ||
import os | ||
|
||
def split_Xy(df, y_colname='label'):
    """
    Split a dataframe into (X, y): X is `df` without the `y_colname`
    column, y is that column on its own.
    """
    target = df[y_colname]
    features = df.drop([y_colname], axis=1)
    return (features, target)
|
||
def shuffle_tuple(tup, rng):
    """
    Return a new tuple holding the items of `tup` in an order produced
    by `rng.shuffle`. The input is not modified.
    """
    items = [item for item in tup]
    rng.shuffle(items)
    return tuple(items)
|
||
def shuffle_dataframe(df, rng):
    """
    Return `df` reindexed by a permutation of its index drawn from
    `rng.permutation`.

    this does NOT do in-place shuffling
    """
    shuffled_index = rng.permutation(df.index)
    return df.reindex(shuffled_index)
|
||
def random_str(N):
    """
    Return a string of N characters chosen with `random.SystemRandom`
    from ASCII lowercase letters, uppercase letters and digits.
    """
    alphabet = string.ascii_lowercase + string.ascii_uppercase + string.digits
    picker = random.SystemRandom()
    return ''.join([picker.choice(alphabet) for _ in range(N)])
|
||
def memory_usage():
    """
    Return the peak resident set size of this process in megabytes.

    `ru_maxrss` is reported in bytes on macOS and in kilobytes on Linux,
    so the divisor is chosen per platform. Other platforms are treated
    like Linux.
    """
    import sys  # local import: only needed for the platform check

    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == 'darwin':
        return peak / 1E6  # bytes -> MB
    return peak / 1E3  # kilobytes -> MB
|
||
def estimate_bytes(filenames):
    """
    Return the combined size in bytes of the given files, as reported
    by `os.stat`.
    """
    total = 0
    for path in filenames:
        total += os.stat(path).st_size
    return total
|
||
def get_output_fileroot(dirpath, name, postfix):
    """
    Build an output path root of the form
    `<dirpath>/<name>-<UTC timestamp YYYYMMDD-HHmm>-<postfix>-<3 random chars>`.
    """
    timestamp = arrow.utcnow().format('YYYYMMDD-HHmm')
    random_tag = random_str(3)
    return '{}/{}-{}-{}-{}'.format(dirpath, name, timestamp, postfix, random_tag)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
inputs: inputs/hg38/chr*.fa | ||
k-low: 3 | ||
k-high: 8 | ||
vec-dim: 100 | ||
epoch: 10 | ||
context: 10 | ||
out-dir: results/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
inputs: example_inputs/chrUn_KI27075*.fa | ||
debug: true | ||
k-low: 3 | ||
k-high: 5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import logbook | ||
import re | ||
from Bio import SeqIO | ||
from attic_util import util | ||
from itertools import islice | ||
import numpy as np | ||
|
||
def remove_empty(str_list):
    """
    Return a lazy iterator over the truthy items of `str_list`,
    i.e. with empty strings dropped.
    """
    # filter(None, ...) keeps items by truthiness, same as filter(bool, ...)
    return filter(None, str_list)
|
||
class SeqFragmenter:
    """
    Split a sequence into small sequences based on some criteria, e.g. 'N' characters
    """

    def __init__(self):
        pass

    def get_acgt_seqs(self, seq):
        """
        Split `str(seq)` on every run of characters other than
        A/C/G/T (either case) and return the non-empty pieces.
        """
        pieces = re.split(r'[^ACGTacgt]+', str(seq))
        return remove_empty(pieces)
|
||
class SlidingKmerFragmenter:
    """
    Slide only a single nucleotide
    """

    def __init__(self, k_low, k_high):
        self.k_low = k_low
        self.k_high = k_high

    def apply(self, rng, seq):
        """
        Return one k-mer per start position 0 .. len(seq) - k_high.

        Each k-mer's length is drawn from
        `rng.randint(self.k_low, self.k_high + 1)`, one draw per start
        position, in order.
        """
        kmers = []
        nb_starts = len(seq) - self.k_high + 1
        for start in range(nb_starts):
            k_len = rng.randint(self.k_low, self.k_high + 1)
            kmers.append(seq[start: start + k_len])
        return kmers
|
||
class DisjointKmerFragmenter:
    """
    Split a sequence into kmers
    """

    def __init__(self, k_low, k_high):
        self.k_low = k_low
        self.k_high = k_high

    @staticmethod
    def random_chunks(rng, li, min_chunk, max_chunk):
        """
        Yield consecutive, non-overlapping string chunks of `li`.

        Each chunk length is drawn from
        `rng.randint(min_chunk, max_chunk + 1)`.
        Both min_chunk and max_chunk are inclusive
        """
        remaining = iter(li)

        def draw():
            size = rng.randint(min_chunk, max_chunk + 1)
            return ''.join(islice(remaining, size))

        chunk = draw()
        # a chunk shorter than min_chunk ends the iteration;
        # that chunk is thrown out
        while len(chunk) >= min_chunk:
            yield chunk
            chunk = draw()

    def apply(self, rng, seq):
        """
        Drop a random prefix of length `rng.randint(self.k_low)` and
        return the rest split into random chunks, as a list of strings.
        """
        # randomly offset the beginning to create more variations
        offset = rng.randint(self.k_low)
        trimmed = seq[offset:]
        chunks = DisjointKmerFragmenter.random_chunks(rng, trimmed, self.k_low, self.k_high)
        return list(chunks)
|
||
class SeqMapper:
    """
    Upper-cases a sequence and, when `use_revcomp` is set, returns its
    reverse complement when `rng.rand()` is below 0.5.
    """

    def __init__(self, use_revcomp=True):
        self.use_revcomp = use_revcomp

    def apply(self, rng, seq):
        """
        Return `seq.upper()`, or its `reverse_complement()` when the
        random draw selects it. `rng` is only consulted when
        `use_revcomp` is true.
        """
        upper_seq = seq.upper()
        flip = self.use_revcomp and rng.rand() < 0.5
        return upper_seq.reverse_complement() if flip else upper_seq
|
||
class SeqGenerator:
    """
    Reads FASTA files for a number of epochs and yields their sequences
    cut into consecutive segments of random length.
    """

    def __init__(self, filenames, nb_epochs, seqlen_ulim=5000):
        self.logger = logbook.Logger(self.__class__.__name__)
        self.filenames = filenames
        self.nb_epochs = nb_epochs
        self.seqlen_ulim = seqlen_ulim
        self.logger.info('Number of epochs: {}'.format(nb_epochs))

    def filehandle_generator(self):
        """
        Yield an open file handle for every file, once per epoch.
        Each handle is closed when the consumer asks for the next one.
        """
        for epoch_idx in range(self.nb_epochs):
            for path in self.filenames:
                with open(path) as handle:
                    self.logger.info('Opened file: {}'.format(path))
                    self.logger.info('Memory usage: {} MB'.format(util.memory_usage()))
                    self.logger.info('Current epoch: {} / {}'.format(epoch_idx + 1, self.nb_epochs))
                    yield handle

    def generator(self, rng):
        """
        Yield consecutive segments of every FASTA record. Each segment
        length is drawn from
        `rng.randint(self.seqlen_ulim // 2, self.seqlen_ulim)`.
        """
        for handle in self.filehandle_generator():
            # Original author's note: SeqIO takes about twice as much
            # memory as a simple fh.readlines()
            for record in SeqIO.parse(handle, "fasta"):
                full_seq = record.seq
                total_len = len(full_seq)
                self.logger.info('Whole fasta seqlen: {}'.format(total_len))
                left = 0
                while left < total_len:
                    span = rng.randint(self.seqlen_ulim // 2, self.seqlen_ulim)
                    piece = full_seq[left: span + left]
                    left += span
                    self.logger.debug('input seq len: {}'.format(len(piece)))
                    yield piece
|
||
class KmerSeqIterable:
    """
    Re-iterable source of k-mer lists.

    Every call to __iter__ creates a fresh RandomState from `rand_seed`
    and runs the pipeline: seq_generator -> mapper -> seq_fragmenter ->
    kmer_fragmenter.
    """

    def __init__(self, rand_seed, seq_generator, mapper, seq_fragmenter, kmer_fragmenter, histogram):
        self.logger = logbook.Logger(self.__class__.__name__)
        self.rand_seed = rand_seed
        self.seq_generator = seq_generator
        self.mapper = mapper
        self.seq_fragmenter = seq_fragmenter
        self.kmer_fragmenter = kmer_fragmenter
        self.histogram = histogram
        self.iter_count = 0

    def __iter__(self):
        """
        Yield each non-empty k-mer list produced by the pipeline.
        The histogram is only updated while `iter_count` equals 1.
        """
        self.iter_count += 1
        rng = np.random.RandomState(self.rand_seed)
        for raw_seq in self.seq_generator.generator(rng):
            mapped = self.mapper.apply(rng, raw_seq)
            fragments = list(self.seq_fragmenter.get_acgt_seqs(mapped))
            self.logger.debug('Splits of len={} to: {}'.format(len(mapped), [len(f) for f in fragments]))

            for fragment in fragments:
                kmers = self.kmer_fragmenter.apply(rng, fragment)  # list of strings
                if len(kmers) == 0:
                    continue
                if self.iter_count == 1:
                    # only collect stats on the first call
                    self.histogram.add(kmers)
                yield kmers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from collections import Counter | ||
import logbook | ||
|
||
class Histogram:
    """
    Counts k-mers by length and logs the resulting distribution.
    """

    def __init__(self):
        self.logger = logbook.Logger(self.__class__.__name__)
        self.kmer_len_counter = Counter()
        self.nb_kmers = 0

    def add(self, seq):
        """
        seq - array of k-mer string
        """
        lengths = [len(kmer) for kmer in seq]
        self.kmer_len_counter.update(lengths)
        self.nb_kmers += len(lengths)

    def print_stat(self, fptr):
        """
        Log the share and count of each k-mer length in ascending order,
        then the total number of base-pairs. `fptr` is accepted but not
        used by this method.
        """
        for kmer_len, count in sorted(self.kmer_len_counter.items()):
            percent = 100.0 * count / self.nb_kmers
            self.logger.info('Percent of {:2d}-mers: {:3.1f}% ({})'.format(
                kmer_len,
                percent,
                count,
            ))

        total_bps = sum(length * count for length, count in self.kmer_len_counter.items())
        self.logger.info('Number of base-pairs: {}'.format(total_bps))
Oops, something went wrong.