Skip to content

Commit

Permalink
Fixes and tests dcpg_data.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cangermueller committed Oct 28, 2016
1 parent 54071d6 commit 5b749a8
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 33 deletions.
36 changes: 23 additions & 13 deletions scripts/dcpg_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,21 @@ def entropy(x, axis=1):


def diff(x, axis=1):
    """Flag positions whose values are not all identical along `axis`.

    The block previously carried two competing bodies (an early
    ``return np.array(..., dtype=np.int8)`` followed by a second, unreachable
    implementation). They are merged here into one.

    Parameters
    ----------
    x : array-like with ``min``/``max`` methods (ndarray or masked array)
        Input values.
    axis : int
        Axis along which minimum and maximum are compared.

    Returns
    -------
    Array of ``np.int8``: 1 where minimum and maximum differ, else 0.
    A masked input yields a masked result, so callers may still read
    ``.data`` from it.
    """
    # Min and max disagree exactly when at least two values differ.
    differs = x.min(axis=axis) != x.max(axis=axis)
    # astype (rather than np.array(...)) keeps a masked array masked.
    return differs.astype(np.int8)


def disp(x, axis=1):
    """Return ``var - mean * (1 - mean)`` computed along `axis`.

    Parameters
    ----------
    x : array-like with ``mean``/``var`` methods (ndarray or masked array)
        Input values.
    axis : int
        Axis along which mean and variance are reduced. Previously this
        argument was ignored and axis 1 was always used.

    Returns
    -------
    Array with the reduced axis removed.
    """
    mean = x.mean(axis=axis)
    return x.var(axis=axis) - mean * (1 - mean)


def mode(x, axis=1):
    """Return the most frequent value (0 or 1) along `axis`.

    The mean is rounded before casting. The earlier plain ``astype(np.int8)``
    truncated towards zero, so any position that was not unanimously 1
    (e.g. mean 0.67) was reported as 0. An exact tie (mean 0.5) still yields
    0 because NumPy rounds half to even.

    Parameters
    ----------
    x : array-like with a ``mean`` method (ndarray or masked array)
        Input values; assumed to hold only 0 and 1 — TODO confirm with caller.
    axis : int
        Axis along which the majority value is determined.

    Returns
    -------
    Array of ``np.int8`` containing 0 or 1.
    """
    majority = x.mean(axis=axis).round().astype(np.int8)
    # Sanity check kept from the original: anything else means x was not 0/1.
    assert np.all((majority == 0) | (majority == 1))
    return majority


def output_stats_meta_by_name(names):
funs = dict()
for name in names:
Expand All @@ -172,9 +179,11 @@ def output_stats_meta_by_name(names):
elif name == 'entropy':
fun = (entropy, np.float32)
elif name == 'diff':
fun = (mean, np.int8)
fun = (diff, np.int8)
elif name == 'disp':
fun = (disp, np.float32)
elif name == 'mode':
fun = (mode, np.int8)
else:
raise ValueError('Invalid statistic "%s"!' % name)
funs[name] = fun
Expand Down Expand Up @@ -213,7 +222,7 @@ def create_parser(self, name):
default=501,
help='DNA window length')
p.add_argument(
'--sc_profiles',
'--cpg_profiles',
nargs='+',
help='BED files with single-cell methylation profiles')
p.add_argument(
Expand All @@ -233,10 +242,10 @@ def create_parser(self, name):
help='Filter sites by CpG coverage. Number of observations per '
'site, or percentage if smaller than one.')
p.add_argument(
'--sc_stats',
'--cpg_stats',
help='Output statistics derived from single-cell profiles',
nargs='+',
choices=['mean', 'var', 'entropy', 'diff', 'disp'])
choices=['mean', 'var', 'entropy', 'diff', 'disp', 'mode'])
p.add_argument(
'--chromos',
nargs='+',
Expand Down Expand Up @@ -274,7 +283,7 @@ def main(self, name, opts):
log.debug(opts)

# Check input arguments
if not (opts.sc_profiles or opts.bulk_profiles):
if not (opts.cpg_profiles or opts.bulk_profiles):
if not (opts.pos_file or opts.dna_db):
raise ValueError('Position table and DNA database expected!')

Expand All @@ -284,22 +293,23 @@ def main(self, name, opts):
raise '--cpg_wlen must be even!'

# Parse functions for computing output statistics
cpg_stats_meta = output_stats_meta_by_name(opts.sc_stats)
cpg_stats_meta = output_stats_meta_by_name(opts.cpg_stats)

outputs = OrderedDict()

# Read single-cell profiles if provied
if opts.sc_profiles:
# Read single-cell profiles if provided
if opts.cpg_profiles:
log.info('Reading single-cell profiles ...')
outputs['cpg'] = read_cpg_profiles(opts.sc_profiles,
outputs['cpg'] = read_cpg_profiles(opts.cpg_profiles,
chromos=opts.chromos,
nrows=opts.nb_sample)

if opts.bulk_profiles:
log.info('Reading bulk profiles ...')
outputs['bulk'] = read_cpg_profiles(opts.bulk_profiles,
chromos=opts.chromos,
nrows=opts.nb_sample)
nrows=opts.nb_sample,
round=False)

# Create table with unique positions
if opts.pos_file:
Expand Down Expand Up @@ -390,7 +400,7 @@ def main(self, name, opts):
if len(chunk_outputs):
out_group = chunk_file.create_group('outputs')

# Write sc profiles
# Write cpg profiles
if 'cpg' in chunk_outputs:
for name, value in chunk_outputs['cpg'].items():
assert len(value) == len(chunk_pos)
Expand All @@ -403,7 +413,7 @@ def main(self, name, opts):
cpg_mat = np.ma.masked_values(chunk_outputs['cpg_mat'],
dat.CPG_NAN)
for name, fun in cpg_stats_meta.items():
stat = fun[0](cpg_mat)
stat = fun[0](cpg_mat).data.astype(fun[1])
assert len(stat) == len(chunk_pos)
out_group.create_dataset('stats/%s' % name,
data=stat,
Expand Down
4 changes: 2 additions & 2 deletions scripts/dcpg_data_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def main(self, name, opts):
log.setLevel(logging.INFO)
log.debug(opts)

output_names = dat.h5_ls(opts.data_files[0], 'outputs',
opts.output_names)
output_names = dat.get_output_names(opts.data_files[0],
regex=opts.output_names)
stats = dat.get_output_stats(opts.data_files, output_names)
tmp = []
for key, value in stats.items():
Expand Down
23 changes: 14 additions & 9 deletions scripts/dcpg_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@

LOG_PRECISION = 4

# Metrics returned by get_metrics() for 'cpg*', 'stats/diff' and 'stats/mode'
# outputs (the outputs trained with binary cross-entropy).
CLA_METRICS = [met.acc, met.f1, met.mcc, met.tpr, met.tnr]

# Metrics returned by get_metrics() for 'bulk*', 'stats/mean' and 'stats/var'
# outputs (the outputs trained with mean squared error).
# NOTE(review): the lists this constant replaces used `met.mae`; confirm that
# `met.mad` exists in the metrics module and is the intended metric.
REG_METRICS = [met.mse, met.mad]


def get_output_weights(output_names, weight_patterns):
regex_weights = dict()
Expand Down Expand Up @@ -67,7 +71,7 @@ def get_objectives(output_names):
objective = 'binary_crossentropy'
elif output_name.startswith('bulk'):
objective = 'mean_squared_error'
elif output_name == 'stats/diff':
elif output_name in ['stats/diff', 'stats/mode']:
objective = 'binary_crossentropy'
elif output_name in ['stats/mean', 'stats/var']:
objective = 'mean_squared_error'
Expand All @@ -78,12 +82,14 @@ def get_objectives(output_names):


def get_metrics(output_name):
if output_name.startswith('cpg') or output_name == 'stats/diff':
metrics = [met.acc, met.tpr, met.tnr, met.f1, met.mcc]
if output_name.startswith('cpg'):
metrics = CLA_METRICS
elif output_name.startswith('bulk'):
metrics = [met.mse, met.mae]
metrics = REG_METRICS
elif output_name in ['stats/diff', 'stats/mode']:
metrics = CLA_METRICS
elif output_name in ['stats/mean', 'stats/var']:
metrics = [met.mse, met.mae]
metrics = REG_METRICS
else:
raise ValueError('Invalid output name "%s"!' % output_name)
return metrics
Expand Down Expand Up @@ -422,10 +428,9 @@ def main(self, name, opts):
l2_decay=opts.l2_decay,
init=opts.initialization)

inputs = model_builder.inputs()
stem = model_builder(inputs)
outputs = mod.add_output_layers(stem, output_names)
model = Model(input=inputs, output=outputs, name=opts.model_name)
stem = model_builder()
outputs = mod.add_output_layers(stem.outputs, output_names)
model = Model(input=stem.inputs, output=outputs, name=_model_name)
model.summary()

mod.save_model(model, os.path.join(opts.out_dir, 'model.json'))
Expand Down
10 changes: 5 additions & 5 deletions tests/integration_tests/data/dcpg_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@


out_dir="./data"
cpg_files=$(ls ./cpg_files/*bed)

check=1
function run {
cmd=$@
echo
echo "#################################"
echo $cmd
echo "#################################"
eval $cmd
if [ $check -ne 0 -a $? -ne 0 ]; then
if [[ $check -ne 0 && $? -ne 0 ]]; then
1>&2 echo "Command failed!"
exit 1
fi
Expand All @@ -21,9 +18,12 @@ function run {
cmd="rm -rf $out_dir && mkdir -p $out_dir"
cmd="$cmd && dcpg_data.py
--dna_db ./dna_db
--cpg_files $(ls ./cpg_files/BS27_4_SER.bed ./cpg_files/BS28_2_SER.bed)
--cpg_profiles $(ls ./cpg_files/BS27_4_SER.bed ./cpg_files/BS28_2_SER.bed)
--bulk_profiles $(ls ./cpg_files/BS9N_2I.bed ./cpg_files/BS9N_SER.bed)
--cpg_stats mean var diff mode
--dna_wlen 501
--cpg_wlen 50
--chunk_size 5000
--chromos 18 19
--out_dir $out_dir"
run $cmd
67 changes: 63 additions & 4 deletions tests/integration_tests/data/test_dcpg_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import numpy as np
from numpy import testing as npt

from deepcpg.data import h5_read, CPG_NAN
from deepcpg.data.fasta import read_chromo
Expand All @@ -26,8 +27,14 @@ def setup_class(self):
'/inputs/cpg/BS27_4_SER/state',
'/inputs/cpg/BS28_2_SER/dist',
'/inputs/cpg/BS28_2_SER/state',
'/outputs/cpg_BS27_4_SER',
'/outputs/cpg_BS28_2_SER'
'/outputs/cpg/BS27_4_SER',
'/outputs/cpg/BS28_2_SER',
'/outputs/stats/mean',
'/outputs/stats/var',
'/outputs/stats/diff',
'/outputs/stats/mode',
'/outputs/bulk/BS9N_2I',
'/outputs/bulk/BS9N_SER'
]
self.data = h5_read(self.data_files, names)
self.chromo = self.data['chromo']
Expand All @@ -49,7 +56,7 @@ def test_outputs(self):
('19', 4442494, 0.0),
('19', 4447847, 1.0)
]
self._test_outputs('cpg_BS27_4_SER', expected)
self._test_outputs('cpg/BS27_4_SER', expected)

expected = [('18', 3000092, 1.0),
('18', 3010064, 0.0),
Expand All @@ -60,7 +67,7 @@ def test_outputs(self):
('19', 4192788, 0.0),
('19', 4202077, 0.0)
]
self._test_outputs('cpg_BS28_2_SER', expected)
self._test_outputs('cpg/BS28_2_SER', expected)

def _test_dna(self, chromo):
pos = self.pos[self.chromo == chromo.encode()]
Expand Down Expand Up @@ -170,3 +177,55 @@ def test_cpg_neighbors(self):
)
]
self._test_cpg_neighbors(name, expected)

def _test_stats(self, chromo, pos, stat, value):
    """Assert that statistic `stat` equals `value` at (`chromo`, `pos`).

    `chromo` is given as str and encoded before it is compared with the
    stored chromosome labels.
    """
    at_site = (self.pos == pos) & (self.chromo == chromo.encode())
    observed = self.data['/outputs/stats/%s' % stat][at_site]
    assert observed == value

def test_stats(self):
    """Check mean, var, diff and mode at a handful of known positions."""
    # One row per site: (chromo, pos, mean, var, diff, mode).
    expected = [
        ('18', 3010417, 1.0, 0.0, 0, 1),
        ('18', 3012173, 0.0, 0.0, 0, 0),
        ('18', 3052129, 1.0, 0.0, 0, 1),
        ('18', 3071630, 0.5, 0.25, 1, 0),
        ('19', 4201704, 0.0, 0.0, 0, 0),
        ('19', 4190571, 0.5, 0.25, 1, 0),
        ('19', 4190700, 0.0, 0.0, 0, 0),
    ]
    stat_names = ('mean', 'var', 'diff', 'mode')
    for chromo, pos, *values in expected:
        # Statistics are checked in the order mean, var, diff, mode per site.
        for stat_name, value in zip(stat_names, values):
            self._test_stats(chromo, pos, stat_name, value)

def _test_bulk(self, chromo, pos, name, expected):
    """Assert that bulk output `name` at (`chromo`, `pos`) matches `expected`.

    The comparison is made to two decimal places.
    """
    at_site = (self.pos == pos) & (self.chromo == chromo.encode())
    observed = self.data['/outputs/bulk/%s' % name][at_site]
    npt.assert_almost_equal(float(observed), expected, decimal=2)

def test_bulk(self):
    """Check the two bulk profiles at a handful of known positions."""
    # (chromo, pos, profile name, expected value).
    # NOTE(review): -1 presumably marks a missing observation (cf. CPG_NAN)
    # — confirm.
    cases = [
        ('18', 3000023, 'BS9N_2I', 0.0),
        ('18', 3000023, 'BS9N_SER', 0.75),
        ('18', 3000086, 'BS9N_2I', 0.0),
        ('18', 3000086, 'BS9N_SER', 0.666),
        ('18', 3004868, 'BS9N_2I', 0.042),
        ('18', 3004868, 'BS9N_SER', 0.1636),
        ('18', 3013979, 'BS9N_2I', -1),
        ('18', 3013979, 'BS9N_SER', 1.0),
        ('19', 4438754, 'BS9N_2I', -1),
        ('19', 4438754, 'BS9N_SER', 0.333),
    ]
    for chromo, pos, name, expected in cases:
        self._test_bulk(chromo, pos, name, expected)

0 comments on commit 5b749a8

Please sign in to comment.