Skip to content

Commit

Permalink
feat: fix event enumeration modules to allow post-enumeration alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
soymintc committed Dec 10, 2024
1 parent 91c9bff commit c325118
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 16 deletions.
27 changes: 17 additions & 10 deletions ontmont/bundle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pandas as pd
import numpy as np
import tqdm

from .utils import (remove_duplicates_from_tra_table,
filter_sv_with_breakpoint_at_contig_ends, filter_breakpoints_at_contig_ends, enumerate_breakpoints)
Expand Down Expand Up @@ -30,10 +31,11 @@ def make_seg_table(bundle, seg_supports, segment_score_cutoff=5, label_irs=False
data = []

for brks in bundle:
for ix, seg in enumerate(brks.segs):
for seg in brks.segs:
assert seg.brk1.chrom == seg.brk2.chrom
chrom = seg.brk1.chrom
if chrom not in chrom_order: continue #
if chrom not in chrom_order:
continue
pos1 = seg.brk1.pos
pos2 = seg.brk2.pos
ori1 = seg.brk1.ori
Expand All @@ -50,12 +52,14 @@ def make_seg_table(bundle, seg_supports, segment_score_cutoff=5, label_irs=False
if coord not in seg_supports:
continue
support = seg_supports[coord]
field = [*coord, support]

if seg.aln_segment:
_segment_score = seg.aln_segment.score
if _segment_score >= segment_score_cutoff:
segment_score = _segment_score
segment_pvalue = seg.aln_segment.pvalue
field = [*coord, support]

if label_irs:
field += [segment_score, segment_pvalue]
data.append(field)
Expand Down Expand Up @@ -90,7 +94,7 @@ def make_brk_table(bundle, brk_supports,
]
chrom_order = ['chr'+str(c) for c in range(1, 22)] + ['chrX', 'chrY'] + vectors
brk_saved = set()

data = []
for brks in bundle:
for brk in brks:
if brk.chrom not in chrom_order: continue
Expand Down Expand Up @@ -139,28 +143,30 @@ def make_brk_table(bundle, brk_supports,
return brk_df


def make_aligned_brks_bundle(reads_df, genome=None, sw_palindrome=None, sw_holliday=None, margins=[15, 30, 60]):
def make_aligned_brks_bundle(reads_df, genome=None, sw_palindrome=None, sw_holliday=None, margins=(15, 30, 60), track_progress=False):
"""Make a list of ``BreakpointChain`` based on alignment table, genome, and alignment parameters
Args:
reads_df (pandas.DataFrame): Table of read alignment statistics
genome (pyfaidx.Fasta): Genome fasta
sw_palindrome (swalign.LocalAlignment): Parameters for detecting IR
sw_holliday (swalign.LocalAlignment): Parameters for detecting homology
margins (list, optional): Bases to slice from breakpoints. Defaults to [15, 30, 60].
margins (tuple, optional): Bases to slice from breakpoints. Defaults to (15, 30, 60).
Returns:
list: List of ``BreakpointChain``
"""
bundle = []
margin_max = max(margins)
for qname, qdf in reads_df.groupby('qname'):
iterable = reads_df.groupby('qname')
iterable = tqdm.tqdm(iterable) if track_progress else iterable
for qname, qdf in iterable:
brks = enumerate_breakpoints(qdf)
brks.qname = qname
brks.get_transitions()
brks.get_segments()

flag_align = sw_palinedrome and sw_holliday
flag_align = sw_palindrome and sw_holliday
if flag_align: # find IRs through SW alignment
for brk in brks:
brk.get_breakpoint_seqs(margin=margin_max, genome=genome)
Expand Down Expand Up @@ -216,7 +222,7 @@ def make_tra_table(bundle, tra_supports, label_irs=False):
tra_saved = set()
data = []
for brks in bundle:
for ix, tra in enumerate(brks.tras):
for tra in brks.tras:
coord1 = (tra.brk1.chrom, tra.brk1.pos, tra.brk1.ori)
coord2 = (tra.brk2.chrom, tra.brk2.pos, tra.brk2.ori)
coord_pair = (coord1, coord2)
Expand Down Expand Up @@ -245,6 +251,7 @@ def make_tra_table(bundle, tra_supports, label_irs=False):
tra_cols += [f'{key}_pvalue' for key in key_pairs]
tra_df = pd.DataFrame(data, columns=tra_cols)
tra_df = filter_sv_with_breakpoint_at_contig_ends(tra_df)
tra_df = remove_duplicates_from_tra_table(tra_df)
if label_irs:
tra_df = remove_duplicates_from_tra_table(tra_df)
tra_df.replace([-np.inf, np.inf], np.nan, inplace=True)
return tra_df
6 changes: 3 additions & 3 deletions ontmont/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ def extract_split_alignments(reads, max_reads=500):
alignments.append(alignment)
return alignments

def extract_read_data(bam:pysam.AlignmentFile, contig:str, start=None, end=None) -> pd.DataFrame:
def extract_read_data(bam:pysam.AlignmentFile, contig:str, start=None, end=None, max_reads=500) -> pd.DataFrame:
"""Extract alignment tables per read and concatenate
Args:
bam (pysam.AlignmentFile): BAM file
contig (str): Contig to extract reads from
start (int, optional): 1-based start position
end (int, optional): 1-based end position
bam (pysam.AlignmentFile): BAM file
max_reads (int, optional): Maximum number of reads to extract. Defaults to 500.
Returns:
pd.DataFrame: Dataframe of alignment data concatenated across all reads in the region
"""
if start is not None:
start -= 1
reads = bam.fetch(contig=contig, start=start, end=end) # convert to 0-based pos
alignments = extract_split_alignments(reads)
alignments = extract_split_alignments(reads, max_reads=max_reads)
df = make_split_read_table(alignments)
return df

Expand Down
4 changes: 3 additions & 1 deletion ontmont/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ def get_list(self):
n_fragments = qdf.shape[0]
if n_fragments <= 2: continue # don't count segments when none
for rix, row in qdf.iterrows():
if rix == 0 or rix == n_fragments-1: continue
if rix == 0 or rix == n_fragments-1:
continue
chrom, start, end = row['chrom'], int(row['start']), int(row['end'])
segment = (chrom, start, end)
self.list.append(segment)
Expand Down Expand Up @@ -210,6 +211,7 @@ class BreakpointPair:
def __init__(self, brk1, brk2):
self.brk1 = brk1
self.brk2 = brk2
self.aln_segment = False

def __repr__(self):
return f'{self.brk1.chrom}:{self.brk1.pos}:{self.brk1.ori}-{self.brk2.chrom}:{self.brk2.pos}:{self.brk2.ori}'
Expand Down
5 changes: 3 additions & 2 deletions ontmont/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .datatypes import Breakpoint, BreakpointChain, SplitAlignment

def remove_duplicates_from_tra_table(tra_df):
df = pd.DataFrame(columns = tra_df.columns)
data = []

vectors = [
'DelPBEF1NeoTransposon',
Expand Down Expand Up @@ -42,7 +42,8 @@ def remove_duplicates_from_tra_table(tra_df):
field = [chrom1, pos1, ori1, chrom2, pos2, ori2, support,
u1_u2_score, d1_d2_score, u1_d2r_score, d1_u2r_score,
u1_u2_pvalue, d1_d2_pvalue, u1_d2r_pvalue, d1_u2r_pvalue]
df.loc[df.shape[0]] = field
data.append(field)
df = pd.DataFrame(data, columns=tra_df.columns)

ix_cols = ['chrom1', 'pos1', 'ori1', 'chrom2', 'pos2', 'ori2']
df = df.groupby(ix_cols).agg({
Expand Down

0 comments on commit c325118

Please sign in to comment.