Skip to content

Commit

Permalink
added miner2-causalinference
Browse files Browse the repository at this point in the history
  • Loading branch information
weiju committed Dec 5, 2019
1 parent 27079af commit 72a710c
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 30 deletions.
109 changes: 109 additions & 0 deletions bin/miner2-causalinference
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#!/usr/bin/env python

import argparse
import pandas as pd
import numpy as np
import json
import sys
import os
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import logging

from miner2 import coexpression, preprocess, mechanistic_inference as mechinf, miner
from miner2 import GIT_SHA
from miner2 import __version__ as MINER_VERSION
from miner2 import util

# Passed as ``min_number_genes`` to mechinf.get_regulons() when the regulons
# are rebuilt from the coregulation modules.
MIN_REGULON_GENES = 5
# Description text handed to argparse. The replace() calls strip the
# 'miner2 ' prefix from the version string and the '$Id: ' / ' $' keyword
# wrapper from the Git SHA before they are interpolated.
DESCRIPTION = """miner-causalinference - MINER causal inference.
MINER Version %s (Git SHA %s)""" % (str(MINER_VERSION).replace('miner2 ', ''),
GIT_SHA.replace('$Id: ', '').replace(' $', ''))

if __name__ == '__main__':
    # Command-line driver: rebuild the regulons from the coregulation modules,
    # run miner.causalNetworkAnalysis() once per mutation matrix and compile
    # the per-matrix results into a single table plus a wiring diagram.
    LOG_FORMAT = '%(asctime)s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level=logging.DEBUG,
                        datefmt='%Y-%m-%d %H:%M:%S \t')

    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('expfile', help="input matrix")
    parser.add_argument('mapfile', help="identifier mapping file")
    parser.add_argument('coreg', help="coregulationModules.json file from miner-mechinf")
    parser.add_argument('coher', help="coherentMembers.csv file from miner-bcmembers")

    parser.add_argument('cmfile', help="common mutations file")
    parser.add_argument('tlfile', help="translocations file")
    parser.add_argument('cgfile', help="cytogenetics file")
    parser.add_argument('outdir', help="output directory")

    args = parser.parse_args()

    # Validate every input up front so that a mistyped path is reported with a
    # clear message before any of the expensive processing starts.
    required_inputs = [
        (args.expfile, "expression file not found"),
        (args.mapfile, "identifier mapping file not found"),
        (args.coreg, "coregulation modules file not found"),
        (args.coher, "coherent members file not found"),
        (args.cmfile, "common mutations file not found"),
        (args.tlfile, "translocations file not found"),
        (args.cgfile, "cytogenetics file not found"),
    ]
    for input_path, error_message in required_inputs:
        if not os.path.exists(input_path):
            sys.exit(error_message)
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    exp_data, conv_table = preprocess.main(args.expfile, args.mapfile)
    with open(args.coreg) as infile:
        coregulation_modules = json.load(infile)
    regulons = mechinf.get_regulons(coregulation_modules,
                                    min_number_genes=MIN_REGULON_GENES,
                                    freq_threshold=0.333)
    regulon_modules, regulon_df = mechinf.get_regulon_dictionary(regulons)
    coherent_samples_matrix = pd.read_csv(args.coher, index_col=0, header=0)

    common_mutations = pd.read_csv(args.cmfile, index_col=0, header=0)
    translocations = pd.read_csv(args.tlfile, index_col=0, header=0)
    cytogenetics = pd.read_csv(args.cgfile, index_col=0, header=0)

    eigengenes = miner.getEigengenes(regulon_modules, exp_data,
                                     regulon_dict=None, saveFolder=None)
    # Rescale the eigengenes so that their 95th percentile matches the 95th
    # percentile of the expression data.
    eigen_scale = np.percentile(exp_data, 95) / np.percentile(eigengenes, 95)
    eigengenes = eigen_scale * eigengenes
    eigengenes.index = np.array(eigengenes.index).astype(str)

    # Perform causal analysis for each mutation matrix; every run gets its own
    # sub-folder name, all below the same results directory.
    result_dir = os.path.join(args.outdir, "causal_analysis")
    mutation_inputs = [
        (common_mutations, "causal_results_common_mutations"),
        (translocations, "causal_results_translocations"),
        (cytogenetics, "causal_results_cytogenetics"),
    ]
    for mutation_matrix, causal_folder in mutation_inputs:
        miner.causalNetworkAnalysis(regulon_matrix=regulon_df,
                                    expression_matrix=exp_data,
                                    reference_matrix=eigengenes,
                                    mutation_matrix=mutation_matrix,
                                    resultsDirectory=result_dir,
                                    minRegulons=1,
                                    significance_threshold=0.05,
                                    causalFolder=causal_folder)

    # compile all causal results
    causal_results = miner.readCausalFiles(result_dir)
    causal_results.to_csv(os.path.join(args.outdir, "completeCausalResults.csv"))

    # The diagram is written to disk through ``savefile``; the return value is
    # not needed here.
    wire_diagram_out = os.path.join(args.outdir, 'wiring_diagram.csv')
    miner.wiringDiagram(causal_results, regulon_modules,
                        coherent_samples_matrix,
                        include_genes=False,
                        savefile=wire_diagram_out)
55 changes: 31 additions & 24 deletions bin/miner2-coexpr
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,35 @@ DESCRIPTION = """miner-coexpr - MINER cluster expression data.
MINER Version %s (Git SHA %s)""" % (str(MINER_VERSION).replace('miner2 ', ''),
GIT_SHA.replace('$Id: ', '').replace(' $', ''))


def plot_expression_stats(exp_data, outdir):
    """Write three diagnostic PDF plots of the expression matrix to *outdir*.

    The files produced are:

    * ``patient_expression_profiles.pdf`` - box plots of the first columns of
      ``exp_data`` (at most 50, fewer if the matrix is narrower).
    * ``expression_single_gene.pdf`` - histogram of the first row.
    * ``expression_single_patient.pdf`` - histogram of the first column.

    Parameters
    ----------
    exp_data : pandas.DataFrame
        Expression matrix; accessed positionally through ``.iloc``. Judging by
        the plot titles, rows are genes and columns are patient samples.
    outdir : str
        Existing directory the PDF files are written into.

    Each figure is closed after it has been saved so that repeated calls do
    not accumulate open matplotlib figures.
    """
    label_size = 14
    max_boxplot_columns = 50

    # Box plots of the leading columns; clamp to the real width so matrices
    # with fewer than 50 columns do not raise an IndexError.
    num_columns = min(max_boxplot_columns, exp_data.shape[1])
    fig = plt.figure()
    plt.boxplot([exp_data.iloc[:, i] for i in range(num_columns)])
    # Property names must be lower case: mixed-case spellings such as
    # ``FontSize`` are rejected by current matplotlib releases.
    plt.title("Patient expression profiles", fontsize=label_size)
    plt.ylabel("Relative expression", fontsize=label_size)
    plt.xlabel("Sample ID", fontsize=label_size)
    plt.savefig(os.path.join(outdir, "patient_expression_profiles.pdf"),
                bbox_inches="tight")
    plt.close(fig)

    # Distribution of the first row across all columns.
    fig = plt.figure()
    plt.hist(exp_data.iloc[0, :], bins=100, alpha=0.75)
    plt.title("Expression of single gene", fontsize=label_size)
    plt.ylabel("Frequency", fontsize=label_size)
    plt.xlabel("Relative expression", fontsize=label_size)
    plt.savefig(os.path.join(outdir, "expression_single_gene.pdf"),
                bbox_inches="tight")
    plt.close(fig)

    # Distribution of the first column across all rows.
    fig = plt.figure()
    plt.hist(exp_data.iloc[:, 0], bins=200, color=[0, 0.4, 0.8], alpha=0.75)
    plt.ylim(0, 350)
    plt.title("Expression of single patient sample", fontsize=label_size)
    plt.ylabel("Frequency", fontsize=label_size)
    plt.xlabel("Relative expression", fontsize=label_size)
    plt.savefig(os.path.join(outdir, "expression_single_patient.pdf"),
                bbox_inches="tight")
    plt.close(fig)


if __name__ == '__main__':
LOG_FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=logging.DEBUG,
Expand Down Expand Up @@ -55,6 +84,8 @@ if __name__ == '__main__':
util.write_dependency_infos(outfile)

exp_data, conv_table = preprocess.main(args.expfile, args.mapfile)
plot_expression_stats(exp_data, args.outdir)

t1 = time.time()
init_clusters = coexpression.cluster(exp_data,
min_number_genes=args.mingenes,
Expand Down Expand Up @@ -99,27 +130,3 @@ if __name__ == '__main__':

t2 = time.time()
logging.info("Completed clustering module in {:.2f} minutes".format((t2-t1)/60.))

"""
# visualize first 10 clusters
plt.figure(figsize=(8,8))
plt.imshow(exp_data.loc[np.hstack([revised_clusters[i] for i in range(10)]),:],
aspect="auto", cmap="viridis", vmin=-1, vmax=1)
plt.grid(False)
plt.ylabel("Genes", FontSize=20)
plt.xlabel("Samples", FontSize=20)
plt.title("First 10 clusters", FontSize=20)
# report coverage
#logging.info("Number of genes clustered: {:d}".format(len(set(np.hstack(initialClusters)))))
#logging.info("Number of unique clusters: {:d}".format(len(revisedClusters)))
# plot histogram of the cluster size distribution
counts_ = plt.hist([len(revised_clusters[key]) for key in revised_clusters.keys()],
bins=100)
plt.xlabel("Number of genes in cluster", FontSize=14)
plt.ylabel("Number of clusters", FontSize=14)
plt.savefig(os.path.join(args.outdir, "cluster_size_distribution.pdf"),
bbox_inches="tight")
"""
10 changes: 5 additions & 5 deletions miner2/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,15 @@ def readExpressionFromGZipFiles(directory):
expressionData = pd.concat(sample_dfs,axis=1)
return expressionData

def readCausalFiles(directory):

rootDir = directory
def readCausalFiles(rootDir):
sample_dfs = []
for dirName, subdirList, fileList in os.walk(rootDir):
#print("dn: %s, sdl: %s, fl: %s" % (dirName, str(subdirList), str(fileList)))
for fname in fileList:
#print('\t%s' % fname)
extension = fname.split(".")[-1]
if extension == 'csv':
path = os.path.join(rootDir,dirName,fname)
path = os.path.join(dirName, fname)
df = pd.read_csv(path, index_col=0,header=0)
df.index = np.array(df.index).astype(str)
sample_dfs.append(df)
Expand All @@ -209,6 +208,7 @@ def readCausalFiles(directory):
causalData.Regulon = renamed
return causalData


def entropy(vector):

data = np.array(vector)
Expand Down Expand Up @@ -2869,7 +2869,7 @@ def parallelCausalNetworkAnalysis(regulon_matrix,expression_matrix,reference_mat
def wiringDiagram(causal_results,regulonModules,coherent_samples_matrix,include_genes=False,savefile=None):
cytoscape_output = []
for regulon in list(set(causal_results.index)):

# regulon is of type 'str', coherent_samples_matrix indexes are 'int'
genes = regulonModules[regulon]
samples = coherent_samples_matrix.columns[coherent_samples_matrix.loc[int(regulon),:]==1]
condensed_genes = (";").join(genes)
Expand Down
62 changes: 62 additions & 0 deletions miner2/survival.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,53 @@ def survival_membership_analysis(task):
return coxResults


def survival_median_analysis(task):
    """Run a Cox regression on per-patient median expression of gene sets.

    Parameters
    ----------
    task : sequence
        ``task[0]`` is a ``(start, stop)`` pair selecting the slice of
        reference-dictionary keys handled by this call. ``task[1]`` is the
        tuple ``(referenceDictionary, expressionDf, SurvivalDf)`` where
        ``referenceDictionary`` maps a key to a list of row labels of
        ``expressionDf`` and the first two columns of ``SurvivalDf`` are used
        as duration and event columns.

    Returns
    -------
    dict
        Maps every processed key to ``(z, p)`` taken from the "median" row of
        the fitted model's summary. Keys whose gene set cannot be evaluated
        (for example labels missing from ``expressionDf`` or a failing fit)
        map to the neutral fallback ``(0, 1)``.
    """
    start, stop = task[0]
    referenceDictionary, expressionDf, SurvivalDf = task[1]

    # Only patients present in both tables can be used.
    overlapPatients = list(set(expressionDf.columns) & set(SurvivalDf.index))
    Survival = SurvivalDf.loc[overlapPatients, SurvivalDf.columns[0:2]]

    coxResults = {}
    # Dictionary views cannot be sliced on Python 3, so materialise them first.
    keys = list(referenceDictionary.keys())[start:stop]
    for ct, key in enumerate(keys, start=1):
        if ct % 10 == 0:
            print(ct)
        try:
            geneset = referenceDictionary[key]
            cluster = expressionDf.loc[geneset, overlapPatients]
            # Per patient: number of values different from -4.01
            # (presumably the floor value of the expression data - confirm).
            nz = np.count_nonzero(cluster + 4.01, axis=0)

            medians = []
            for i in range(cluster.shape[1]):
                values = cluster.iloc[:, i]
                if nz[i] >= 3:
                    # Enough informative values: ignore entries at the floor.
                    values = values[values > -4.01]
                medians.append(np.median(values))

            # Build a fresh model input for every key. Appending to the shared
            # ``Survival`` frame would pile up one extra "median" column per
            # iteration and distort all later fits.
            cox_input = Survival.copy()
            cox_input["median"] = pd.Series(medians, index=overlapPatients)
            cox_input.sort_values(by=cox_input.columns[0], inplace=True)

            cph = CoxPHFitter()
            cph.fit(cox_input,
                    duration_col=cox_input.columns[0],
                    event_col=cox_input.columns[1])

            # The summary is indexed by covariate name, i.e. "median", not by
            # the gene-set key.
            summary = cph.summary
            coxResults[key] = (summary.loc["median", "z"],
                               summary.loc["median", "p"])
        except Exception:
            # Best effort: a gene set that cannot be analysed is reported as
            # having no effect instead of aborting the whole task.
            coxResults[key] = (0, 1)

    return coxResults


def parallel_member_survival_analysis(membershipDf, numCores=5, survivalPath=None,
survivalData=None):

Expand All @@ -148,6 +195,21 @@ def parallel_member_survival_analysis(membershipDf, numCores=5, survivalPath=Non
return util.condense_output(coxOutput)


def parallel_median_survival_analysis(referenceDictionary,
                                      expressionDf,
                                      numCores=5,
                                      survivalPath=None,
                                      survivalData=None):
    """Distribute survival_median_analysis() over several worker tasks.

    Parameters
    ----------
    referenceDictionary : dict
        Gene sets to analyse; its keys are split into per-task chunks.
    expressionDf : pandas.DataFrame
        Expression matrix handed unchanged to every task.
    numCores : int
        Forwarded to util.split_for_multiprocessing() to partition the keys.
    survivalPath : str, optional
        CSV file read (first column as index) when ``survivalData`` is None.
    survivalData : pandas.DataFrame, optional
        Survival table; takes precedence over ``survivalPath``.

    Returns
    -------
    The per-task results after merging with util.condense_output().
    """
    # Load the survival table lazily: an explicitly passed frame wins.
    if survivalData is None:
        survivalData = pd.read_csv(survivalPath, index_col=0, header=0)

    # Every task carries its own key chunk plus the same shared inputs.
    shared_inputs = (referenceDictionary, expressionDf, survivalData)
    key_chunks = util.split_for_multiprocessing(referenceDictionary.keys(), numCores)
    tasks = [[key_chunks[idx], shared_inputs] for idx in range(len(key_chunks))]

    per_task_results = util.multiprocess(survival_median_analysis, tasks)
    return util.condense_output(per_task_results)


def combined_states(groups, ranked_groups, survivalDf, minSamples=4, maxStates=7):
high_risk_indices = []
for i in range(1, len(ranked_groups) + 1):
Expand Down
2 changes: 1 addition & 1 deletion rp_input.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"regulons": "prout_mechinf/regulons.json",
"regulons": "out-mechinf/regulons.json",
"primary_survival": "MATTDATA/survival/survivalIA12.csv",
"test_survival": "MATTDATA/survival/globalClinTraining.csv",
"datasets": [
Expand Down

0 comments on commit 72a710c

Please sign in to comment.