Skip to content

Commit

Permalink
fix bug regarding overwrite in tools, replace MEME tomtom with tanger…
Browse files Browse the repository at this point in the history
…meme
  • Loading branch information
ruochiz committed Nov 25, 2024
1 parent a0910f0 commit 7edd525
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 222 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ dependencies = [
"gffutils",
"wandb",
"logomaker",
"shap"
"shap",
"tangermeme"
]
[project.scripts]
seq2print_train = "scprinter.seq.scripts.seq2print_lora_train:main"
seq2print_attr = "scprinter.seq.scripts.evaluation_model:main"
seq2print_tfbs = "scprinter.seq.scripts.generate_TFBS_bigwig:main"
seq2print_delta = "scprinter.seq.scripts.motif_delta_effects:main"
seq2print_modisco = "scprinter.seq.scripts.modisco_custom:main"

[tool.black]
line-length = 100
Expand Down
53 changes: 30 additions & 23 deletions scprinter/chromvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,9 @@ def _compute_deviations(motif_match, count, expectation_obs, expectation_var, de
return out


def bag_deviations(adata=None, ranked_df=None, cor=0.7, motif_corr_matrix=None):
def bag_deviations(
adata=None, ranked_df=None, cutoff=0.7, motif_motif_matrix=None, collapse_greater=True
):
"""
This function performs a bagging operation on transcription factors (TFs) based on their correlation with each other.
It selects a representative TF (sentinel TF) for each group of TFs that have a correlation coefficient greater than or equal to the specified threshold.
Expand All @@ -319,13 +321,14 @@ def bag_deviations(adata=None, ranked_df=None, cor=0.7, motif_corr_matrix=None):
An AnnData object containing the count matrix and TF information. If provided, TF variability will be computed from this object.
ranked_df : DataFrame, optional
A DataFrame containing TF names and their variability scores. If provided, TF variability will be computed from this object.
cor : float, optional
The correlation coefficient threshold. TFs with a correlation coefficient greater than or equal to this threshold will be grouped together.
motif_corr_matrix : str, optional
The path to the motif correlation matrix file. The common options are `scp.datasets.FigR_motifs_bagging_mouse` and `scp.datasets.FigR_motifs_bagging_human`.
use_name : bool, optional
A flag indicating whether to use TF names from the correlation matrix file. If True, TF names will be extracted from the correlation matrix file.
cutoff : float, optional
The coefficient threshold. TFs with a correlation coefficient greater (if collapse_higher=True, otherwise smaller) than or equal to this threshold will be grouped together.
motif_motif_matrix : str | pd.DataFrame
The path to the motif motif matrix file or a dataframe. The common options are `scp.datasets.FigR_motifs_bagging_mouse` and `scp.datasets.FigR_motifs_bagging_human`.
collapse_greater : bool, optional
Whether to collapse TFs with coefficient greater than or equal to the threshold.
If False, TFs with coefficient smaller than or equal to the threshold will be collapsed.
For instance, if it's correlation based method, set as True, if it's tomtom or testing based metod, set as False
Returns
-------
DataFrame
Expand All @@ -336,7 +339,7 @@ def bag_deviations(adata=None, ranked_df=None, cor=0.7, motif_corr_matrix=None):
If a DataFrame object is provided, a list containing the TF groups.
"""

assert motif_corr_matrix is not None, "Motif correlation matrix must be provided"
assert motif_motif_matrix is not None, "Motif correlation matrix must be provided"
assert adata is not None or ranked_df is not None, "Either adata or ranking_df must be provided"

# Compute variability and get transcription factors (TFs)
Expand All @@ -352,36 +355,40 @@ def bag_deviations(adata=None, ranked_df=None, cor=0.7, motif_corr_matrix=None):
TFnames = ranked_df.index.tolist()
TFnames_to_rank = {tf: i for i, tf in enumerate(TFnames)}
# Import correlation based on PWMs for the organism
if type(motif_corr_matrix) is pd.DataFrame:
cormat = motif_corr_matrix
if type(motif_motif_matrix) is pd.DataFrame:
cormat = motif_motif_matrix
else:
cormat = pd.read_csv(
motif_corr_matrix, sep="\t"
motif_motif_matrix, sep="\t"
) # Assuming the RDS file contains one object

# Historical code, kept for future references
# if use_name:
# tf1 = [xx.split("_")[2] for xx in cormat["TF1"]]
# tf2 = [xx.split("_")[2] for xx in cormat["TF2"]]
# cormat["TF1"] = tf1
# cormat["TF2"] = tf2
tf1 = cormat.iloc[:, 0]
tf2 = cormat.iloc[:, 1]

assert set(TFnames).issubset(
set(cormat["TF1"]).union(set(cormat["TF2"]))
set(tf1).union(set(tf2))
), "All TF names must be in the correlation matrix"
cormat = cormat[(cormat["TF1"].isin(TFnames)) & (cormat["TF2"].isin(TFnames))]
cormat = cormat[(tf1.isin(TFnames)) & (tf2.isin(TFnames))]

tf1 = cormat.iloc[:, 0]
tf2 = cormat.iloc[:, 1]
factor = np.array(cormat.iloc[:, 2])
if not collapse_greater:
factor = factor * -1
cutoff = -1 * cutoff

i = 1
TFgroups = []
while len(TFnames) != 0:
tfcur = TFnames[0]
boo = ((cormat["TF1"] == tfcur) | (cormat["TF2"] == tfcur)) & (cormat["Pearson"] >= cor)
boo = ((tf1 == tfcur) | (tf2 == tfcur)) & (factor >= cutoff)
hits = cormat[boo]
tfhits = list(set(list(np.unique(hits[["TF1", "TF2"]])) + [tfcur]))
tfhits = list(set(list(np.unique(hits.iloc[:, :2])) + [tfcur]))

# Update lists
TFnames = [tf for tf in TFnames if tf not in tfhits]
TFgroups.append(tfhits)
cormat = cormat[cormat["TF1"].isin(TFnames) & cormat["TF2"].isin(TFnames)]
cormat = cormat[tf1.isin(TFnames) & tf2.isin(TFnames)]
i += 1

sentinalTFs = []
Expand Down
34 changes: 34 additions & 0 deletions scprinter/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,3 +529,37 @@ def fetch_bias_bw(self):
)

mm10 = GRCm38


mm39_genome = Genome(
name="mm39",
chrom_sizes={
"chr1": 195154279,
"chr2": 181755017,
"chrX": 169476592,
"chr3": 159745316,
"chr4": 156860686,
"chr5": 151758149,
"chr6": 149588044,
"chr7": 144995196,
"chr10": 130530862,
"chr8": 130127694,
"chr14": 125139656,
"chr9": 124359700,
"chr11": 121973369,
"chr13": 120883175,
"chr12": 120092757,
"chr15": 104073951,
"chr16": 98008968,
"chr17": 95294699,
"chrY": 91455967,
"chr18": 90720763,
"chr19": 61420004,
},
gff_file="gencode_vM30_GRCm39.gff3.gz",
fa_file="gencode_vM30_GRCm39.fa.gz",
bias_file="/data/rzhang/mm39/mm39Tn5Bias.h5",
blacklist_file="/data/rzhang/mm39/mm39.excluderanges.bed",
bg=(0.29149763779592625, 0.2083275235867118, 0.20834346947899296, 0.291831369138369),
splits=mm10_splits,
)
105 changes: 105 additions & 0 deletions scprinter/motifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from itertools import product
from pathlib import Path

import h5py
import MOODS
import MOODS.parsers
import MOODS.scan
Expand All @@ -22,13 +23,117 @@
from Bio import motifs
from pyfaidx import Fasta
from scipy.sparse import SparseEfficiencyWarning, csr_matrix, diags, hstack, vstack
from statsmodels.stats.multitest import fdrcorrection
from tangermeme.io import read_meme
from tangermeme.tools.tomtom import tomtom
from tqdm.auto import tqdm, trange
from typing_extensions import Literal

from . import genome
from .datasets import CisBP_Human, FigR_motifs_human, FigR_motifs_mouse, JASPAR2022_core
from .utils import regionparser

"""
Most of the code are from modiscolite, and adapted to include the delta effects.
Chose to copy code because modiscolite somehow changes the backend of matplotlib and inline version
"""


def read_pfms(motif, trim_threshold=0.3, prefix=""):
if isinstance(motif, Motifs):
dict1 = {}
for mtx in motif.all_motifs:
name = mtx.name
c = mtx.counts
mtx = np.array([c["A"], c["C"], c["G"], c["T"]]).T
mtx = mtx / np.sum(mtx, axis=1)[:, None]
dict1[prefix + name] = mtx.T
elif isinstance(motif, dict):
dict1 = motif
else:
if ".h5" in motif:
# modisco format
modisco_motifs = {}
with h5py.File(motif, "r") as modisco_results:
for contribution_dir_name in modisco_results.keys():
metacluster = modisco_results[contribution_dir_name]
key = lambda x: int(x[0].split("_")[-1])

for idx, (key_, pattern) in enumerate(sorted(metacluster.items(), key=key)):
ppm = np.array(pattern["sequence"][:])
cwm = np.array(pattern["contrib_scores"][:])
pattern_name = f"{contribution_dir_name}.pattern_{idx}"
score = np.sum(np.abs(cwm), axis=1)
trim_thresh = (
np.max(score) * trim_threshold
) # Cut off anything less than 30% of max score
pass_inds = np.where(score >= trim_thresh)[0]
trimmed = ppm[np.min(pass_inds) : np.max(pass_inds) + 1]
modisco_motifs[prefix + pattern_name] = trimmed.T
dict1 = modisco_motifs
else:
dict1 = read_meme(motif)
dict1 = {prefix + k: v.detach().numpy() for k, v in dict1.items()}
return dict1


def tomtom_motif_motif_matrix(
motifs_1, prefix_1="", motifs_2=None, prefix_2="", trim_threshold=0.3
):
"""
create a motif motif matrix using tomtom
Parameters
----------
motif: list[str, Path, motifs.Motifs] | str, Path, motifs.Motifs
Returns
-------
p values: np.ndarray
"""

if not isinstance(motifs_1, list):
motifs_1 = [motifs_1]
if not isinstance(prefix_1, list):
prefix_1 = [prefix_1]
if motifs_2 is None:
motifs_2 = motifs_1
prefix_2 = prefix_1
if not isinstance(motifs_2, list):
motifs_2 = [motifs_2]
if not isinstance(prefix_2, list):
prefix_2 = [prefix_2]

all_names_1 = []
all_pfms_1 = []

for m, p in zip(motifs_1, prefix_1):
m = read_pfms(m, trim_threshold=trim_threshold, prefix=p)
all_names_1 += list(m.keys())
all_pfms_1 += list(m.values())

all_names_2 = []
all_pfms_2 = []

for m, p in zip(motifs_2, prefix_2):
m = read_pfms(m, trim_threshold=trim_threshold, prefix=p)
all_names_2 += list(m.keys())
all_pfms_2 += list(m.values())

p, scores, offsets, overlaps, strands = tomtom(all_pfms_1, all_pfms_2)
df = pd.DataFrame(
{
"Query_ID": np.repeat(list(all_names_1), len(all_names_2)),
"Target_ID": list(all_names_2) * len(all_names_1),
"Optimal_offset": offsets.reshape(-1),
"p-value": p.reshape(-1),
}
)
df["E-value"] = df["p-value"] * len(all_names_2)
df["q-value"] = fdrcorrection(df["p-value"])[1]
df["Overlap"] = overlaps.flatten()
df["Orientation"] = ["+" if x == 0 else "-" for x in strands.reshape((-1))]
return df


def consecutive(data, stepsize=1):
"""
Expand Down
Loading

0 comments on commit 7edd525

Please sign in to comment.