fix: optimize some function names and add test for read table extraction
soymintc committed Dec 2, 2024
1 parent aee011f commit 47e6067
Showing 6 changed files with 117 additions and 50 deletions.
108 changes: 61 additions & 47 deletions ontmont/bundle.py
@@ -6,7 +6,7 @@
from .irs import (get_best_onesided_ir, get_best_ir_within_breakpoints,
get_best_holliday_junctions, get_best_ir_within_segment)

def make_seg_table(bundle, seg_supports, segment_score_cutoff=5):
def make_seg_table(bundle, seg_supports, segment_score_cutoff=5, label_irs=False):
"""Create a dataframe based on a ``BreakpointChain`` bundle and supports dict
Args:
@@ -17,9 +17,6 @@ def make_seg_table(bundle, seg_supports, segment_score_cutoff=5):
Returns:
pandas.DataFrame: table of segments coordinate with supports and IR statistics
"""
seg_cols = ['chrom', 'pos1', 'pos2', 'support']
seg_cols += ['segment_score', 'segment_pvalue']
seg_df = pd.DataFrame(columns=seg_cols)
seg_saved = set()
vectors = [
'DelPBEF1NeoTransposon',
@@ -30,6 +27,7 @@ def make_seg_table(bundle, seg_supports, segment_score_cutoff=5, label_irs=False):
'puro-GFP-PGBD5_seq'
]
chrom_order = ['chr'+str(c) for c in range(1, 23)] + ['chrX', 'chrY'] + vectors  # chr1-chr22; range(1, 22) would drop chr22
data = []

for brks in bundle:
for ix, seg in enumerate(brks.segs):
@@ -57,14 +55,20 @@
if _segment_score >= segment_score_cutoff:
segment_score = _segment_score
segment_pvalue = seg.aln_segment.pvalue
field = [*coord, support, segment_score, segment_pvalue]
seg_df.loc[seg_df.shape[0]] = field
field = [*coord, support]
if label_irs:
field += [segment_score, segment_pvalue]
data.append(field)
seg_cols = ['chrom', 'pos1', 'pos2', 'support']
if label_irs:
seg_cols += ['segment_score', 'segment_pvalue']
seg_df = pd.DataFrame(data, columns=seg_cols)
seg_df.replace([-np.inf, np.inf], np.nan, inplace=True)
return seg_df
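Note: the refactor above swaps row-by-row `DataFrame.loc` appends for a plain list that is materialized once. A self-contained sketch of that pattern (the sample rows here are invented for illustration, not from the data):

import numpy as np
import pandas as pd

# Collect rows in a plain list and build the frame once. Growing the frame with
# seg_df.loc[seg_df.shape[0]] = field copies data on every append, so the list
# pattern scales linearly with the number of segments instead of quadratically.
data = []
for chrom, pos1, pos2, support in [('chr1', 100, 200, 5), ('chr2', 300, 400, 8)]:
    data.append([chrom, pos1, pos2, support])

seg_cols = ['chrom', 'pos1', 'pos2', 'support']
seg_df = pd.DataFrame(data, columns=seg_cols)
seg_df.replace([-np.inf, np.inf], np.nan, inplace=True)  # same cleanup as above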


def make_brk_table(bundle, brk_supports,
unilateral_score_cutoff=5, bilateral_score_cutoff=8):
unilateral_score_cutoff=5, bilateral_score_cutoff=8, label_irs=False):
"""Create a dataframe of breakpoints
Args:
@@ -85,10 +89,6 @@ def make_brk_table(bundle, brk_supports,
'puro-GFP-PGBD5_seq'
]
chrom_order = ['chr'+str(c) for c in range(1, 23)] + ['chrX', 'chrY'] + vectors  # chr1-chr22; range(1, 22) would drop chr22
brk_cols = ['chrom', 'pos', 'ori', 'support']
brk_cols += ['upstream_score', 'downstream_score', 'breakpoint_score']
brk_cols += ['upstream_pvalue', 'downstream_pvalue', 'breakpoint_pvalue']
brk_df = pd.DataFrame(columns=brk_cols)
brk_saved = set()

for brks in bundle:
@@ -122,18 +122,24 @@ def make_brk_table(bundle, brk_supports,
breakpoint_pvalue = brk.aln_breakpoint.pvalue

# if max([upstream_score, downstream_score, breakpoint_score]) > 0:
field = [*coord, support,
upstream_score, downstream_score, breakpoint_score,
upstream_pvalue, downstream_pvalue, breakpoint_pvalue]
brk_df.loc[brk_df.shape[0]] = field
field = [*coord, support]
if label_irs:
field += [upstream_score, downstream_score, breakpoint_score,
upstream_pvalue, downstream_pvalue, breakpoint_pvalue]
data.append(field)

brk_cols = ['chrom', 'pos', 'ori', 'support']
if label_irs:
brk_cols += ['upstream_score', 'downstream_score', 'breakpoint_score']
brk_cols += ['upstream_pvalue', 'downstream_pvalue', 'breakpoint_pvalue']
brk_df = pd.DataFrame(data, columns=brk_cols)
brk_df = filter_breakpoints_at_contig_ends(brk_df)
brk_df.replace([-np.inf, np.inf], np.nan, inplace=True)

return brk_df
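Note: because the IR statistics are now opt-in via `label_irs`, downstream code should not assume the score/p-value columns exist. A hedged helper sketch (not part of this commit; column names taken from the diff above):

import pandas as pd

def has_ir_columns(brk_df: pd.DataFrame) -> bool:
    # True only when the table was built with label_irs=True.
    ir_cols = ['upstream_score', 'downstream_score', 'breakpoint_score',
               'upstream_pvalue', 'downstream_pvalue', 'breakpoint_pvalue']
    return all(col in brk_df.columns for col in ir_cols)

assert not has_ir_columns(pd.DataFrame(columns=['chrom', 'pos', 'ori', 'support']))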


def make_brks_bundle(reads_df, genome, sw_palindrome, sw_holliday, margins=[15, 30, 60]):
def make_aligned_brks_bundle(reads_df, genome=None, sw_palindrome=None, sw_holliday=None, margins=[15, 30, 60]):
"""Make a list of ``BreapointChain`` based on alignment table, genome, and alignment parameters
Args:
@@ -150,23 +156,28 @@ def make_brks_bundle(reads_df, genome, sw_palindrome, sw_holliday, margins=[15,
margin_max = max(margins)
for qname, qdf in reads_df.groupby('qname'):
brks = enumerate_breakpoints(qdf)
brks.qname = qname
brks.get_transitions()
brks.get_segments()
for brk in brks:
brk.get_breakpoint_seqs(margin=margin_max, genome=genome)
seq1 = brk.upstream
seq2 = brk.downstream
direction1 = 'up'
direction2 = 'down'
brk.aln_upstream = get_best_onesided_ir(seq1, direction1, sw_palindrome, dist_cutoff=2, margins=margins)
brk.aln_downstream = get_best_onesided_ir(seq2, direction2, sw_palindrome, dist_cutoff=2, margins=margins)
brk.aln_breakpoint = get_best_ir_within_breakpoints(seq1, seq2, sw_palindrome, dist_cutoff1=100, dist_cutoff2=100, margins=margins)
for ix, seg in enumerate(brks.segs):
brks.segs[ix] = get_best_ir_within_segment(
seg, sw_palindrome, genome, dist_cutoff1=2, dist_cutoff2=5, margins=margins)
for ix, tra in enumerate(brks.tras):
brks.tras[ix] = get_best_holliday_junctions(
tra, sw_holliday, genome, score_cutoff=4, dist_cutoff1=2, dist_cutoff2=5, margins=margins)

flag_align = sw_palindrome and sw_holliday
if flag_align: # find IRs through SW alignment
for brk in brks:
brk.get_breakpoint_seqs(margin=margin_max, genome=genome)
seq1 = brk.upstream
seq2 = brk.downstream
direction1 = 'up'
direction2 = 'down'
brk.aln_upstream = get_best_onesided_ir(seq1, direction1, sw_palindrome, dist_cutoff=2, margins=margins)
brk.aln_downstream = get_best_onesided_ir(seq2, direction2, sw_palindrome, dist_cutoff=2, margins=margins)
brk.aln_breakpoint = get_best_ir_within_breakpoints(seq1, seq2, sw_palindrome, dist_cutoff1=100, dist_cutoff2=100, margins=margins)
for ix, seg in enumerate(brks.segs):
brks.segs[ix] = get_best_ir_within_segment(
seg, sw_palindrome, genome, dist_cutoff1=2, dist_cutoff2=5, margins=margins)
for ix, tra in enumerate(brks.tras):
brks.tras[ix] = get_best_holliday_junctions(
tra, sw_holliday, genome, score_cutoff=4, dist_cutoff1=2, dist_cutoff2=5, margins=margins)

bundle.append(brks)
return bundle
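Note: the `flag_align` guard makes the Smith-Waterman annotation optional; with the new `None` defaults the loop only builds breakpoints, segments, and transitions. A minimal runnable sketch of the guard. Be aware that `sw_palindrome and sw_holliday` tests truthiness, so the explicit `is not None` form shown here is the safer variant if a falsy-but-valid aligner object is ever passed:

def should_align(sw_palindrome=None, sw_holliday=None):
    # Run IR/Holliday-junction annotation only when both aligners are supplied.
    return sw_palindrome is not None and sw_holliday is not None

assert should_align() is False                  # breakpoint-only bundle
assert should_align('sw_pal', 'sw_hj') is True  # full IR labelling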

@@ -190,7 +201,7 @@ def make_brk_supports(bundle):
return brk_supports


def make_tra_table(bundle, tra_supports):
def make_tra_table(bundle, tra_supports, label_irs=False):
"""Make a table of SVs based on bundle and number of supports
Args:
@@ -202,11 +213,8 @@
"""
holliday_score_cutoff = 5
key_pairs = ['u1_u2', 'd1_d2', 'u1_d2r', 'd1_u2r']
tra_cols = ['chrom1', 'pos1', 'ori1', 'chrom2', 'pos2', 'ori2', 'support']
tra_cols += [f'{key}_score' for key in key_pairs]
tra_cols += [f'{key}_pvalue' for key in key_pairs]
tra_df = pd.DataFrame(columns=tra_cols)
tra_saved = set()
data = []
for brks in bundle:
for ix, tra in enumerate(brks.tras):
coord1 = (tra.brk1.chrom, tra.brk1.pos, tra.brk1.ori)
@@ -218,19 +226,25 @@ def make_tra_table(bundle, tra_supports):
continue
support = tra_supports[coord_pair]
field = [*coord1, *coord2, support]
field += ([0] * 4)
field += ([np.nan] * 4)

for kx, key in enumerate(key_pairs):
if tra.alns[key]:
holliday_score = tra.alns[key].score
if holliday_score >= holliday_score_cutoff:
field[7+kx] = holliday_score
field[11+kx] = tra.alns[key].pvalue
if label_irs:
field += ([0] * 4) # add aln scores
field += ([np.nan] * 4) # add p values

for kx, key in enumerate(key_pairs):
if tra.alns[key]:
holliday_score = tra.alns[key].score
if holliday_score >= holliday_score_cutoff:
field[7+kx] = holliday_score
field[11+kx] = tra.alns[key].pvalue
# if max(field[7:]) > 0:
tra_df.loc[tra_df.shape[0]] = field
data.append(field)

tra_cols = ['chrom1', 'pos1', 'ori1', 'chrom2', 'pos2', 'ori2', 'support']
if label_irs:
tra_cols += [f'{key}_score' for key in key_pairs]
tra_cols += [f'{key}_pvalue' for key in key_pairs]
tra_df = pd.DataFrame(data, columns=tra_cols)
tra_df = filter_sv_with_breakpoint_at_contig_ends(tra_df)
tra_df = remove_duplicates_from_tra_table(tra_df)
tra_df.replace([-np.inf, np.inf], np.nan, inplace=True)
return tra_df
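Note: the index arithmetic in the loop above assumes a fixed row layout when `label_irs=True`: slots 0-6 hold the coordinate pair and support, slots 7-10 the four `{key}_score` values, and slots 11-14 the matching p-values. A runnable illustration of that layout:

import numpy as np

key_pairs = ['u1_u2', 'd1_d2', 'u1_d2r', 'd1_u2r']
# Slots 0-6: chrom1, pos1, ori1, chrom2, pos2, ori2, support.
field = ['chr1', 100, '+', 'chr2', 200, '-', 3] + [0] * 4 + [np.nan] * 4
for kx, key in enumerate(key_pairs):
    field[7 + kx] = 10 + kx   # {key}_score slot
    field[11 + kx] = 0.01     # {key}_pvalue slot
assert field[7:11] == [10, 11, 12, 13]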
21 changes: 21 additions & 0 deletions ontmont/collect.py
@@ -39,6 +39,27 @@ def extract_split_alignments(reads, max_reads=500):
alignments.append(alignment)
return alignments

def extract_read_data(bam: pysam.AlignmentFile, contig: str, start=None, end=None) -> pd.DataFrame:
"""Extract alignment tables per read and concatenate
Args:
bam (pysam.AlignmentFile): BAM file
contig (str): Contig to extract reads from
start (int, optional): 1-based start position
end (int, optional): 1-based end position (inclusive)
Returns:
pd.DataFrame: Dataframe of alignment data concatenated across all reads in the region
"""
if start is not None:
    start -= 1  # convert 1-based start to pysam's 0-based coordinate
reads = bam.fetch(contig=contig, start=start, end=end)
alignments = extract_split_alignments(reads)
df = make_split_read_table(alignments)
return df
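Note: `pysam.AlignmentFile.fetch` expects 0-based, half-open coordinates, so the helper subtracts 1 from the 1-based `start` and passes `end` through unchanged (a 1-based inclusive end equals a 0-based exclusive end). A usage sketch mirroring the test added below in this commit, assuming the `tests/data/test.bam` fixture is on disk:

import pysam
from ontmont.collect import extract_read_data

bam = pysam.AlignmentFile('tests/data/test.bam')  # fixture added in this commit
# 1-based window [1477, 1478]; internally fetched as 0-based [1476, 1478)
df = extract_read_data(bam, contig='PBEF1NeoTransposon', start=1477, end=1478)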


def pull_breakpoints_from_reads_in_sv_regions(bam, tra, get_read_table=False, min_n_breakpoint=2, margin=10):
"""Extract and append ``BreakpointChain`` objects from a bam file and a table of SVs
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ontmont"
version = "0.0.7"
version = "0.0.8"
#dynamic = ["version"]
authors = [
{ name="Seongmin Choi", email="[email protected]" },
Binary file added tests/data/test.bam
Binary file added tests/data/test.bam.bai
36 changes: 34 additions & 2 deletions tests/test_collect.py
@@ -1,10 +1,15 @@
import os
import pytest

import pysam
import numpy as np
import pandas as pd

from ontmont.datatypes import Breakpoint, BreakpointPair, BreakpointChain

from ontmont.collect import (map_similar_coordinate_to_higher_rank, fix_lower_support_coordinates,
get_breakpoint_support_from_bundle, normalize_sv_table, pull_sv_supporting_reads_from_bundle, find_presence_of_matching_sv)
get_breakpoint_support_from_bundle, normalize_sv_table, pull_sv_supporting_reads_from_bundle, find_presence_of_matching_sv,
extract_read_data)

@pytest.fixture
def bundle():
Expand Down Expand Up @@ -105,4 +110,31 @@ def test_find_presence_of_matching_sv():
['chr3', 3000, '+', 'chr4', 4000, '+', False],
], columns=ix_cols+['match'])
sv1['match'] = find_presence_of_matching_sv(sv1, sv2, margin=50)
assert (sv1 != expected).sum().sum() == 0, sv1

def test_extract_read_data_1():
bam_path = 'tests/data/test.bam'
assert os.path.exists(bam_path), f'{bam_path} does not exist.'
bam = pysam.AlignmentFile(bam_path)
df = extract_read_data(bam, contig='PBEF1NeoTransposon', start=1, end=2)
assert df.shape[0] == 0, df

def test_extract_read_data_2():
bam_path = 'tests/data/test.bam'
assert os.path.exists(bam_path), f'{bam_path} does not exist.'
bam = pysam.AlignmentFile(bam_path)
df = extract_read_data(bam, contig='PBEF1NeoTransposon', start=1470, end=1477)
assert df.shape[0] == 0, df

def test_extract_read_data_3():
expected = np.array([['02ce28f5-83e5-53a4-a7ed-96f331c6b305', 'chr10', 51339923,
51340887, '+', 35, 960, 3763, 35],
['02ce28f5-83e5-53a4-a7ed-96f331c6b305', 'PBEF1NeoTransposon',
1478, 4996, '-', 288, 3480, 990, 990],
['02ce28f5-83e5-53a4-a7ed-96f331c6b305', 'chr10', 51340883,
51341166, '+', 4466, 281, 11, 4466]])
bam_path = 'tests/data/test.bam'
assert os.path.exists(bam_path), f'{bam_path} does not exist.'
bam = pysam.AlignmentFile(bam_path)
df = extract_read_data(bam, contig='PBEF1NeoTransposon', start=1477, end=1478)
assert np.all(df.to_numpy().astype(str) == expected.astype(str))
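Note: the final assertion casts both sides to strings before comparing, which sidesteps dtype mismatches between the mixed-type `expected` object array and the DataFrame's int/str columns. A minimal runnable illustration of the same pattern, with made-up values:

import numpy as np
import pandas as pd

df = pd.DataFrame([['chr10', 51339923, '+']], columns=['chrom', 'pos', 'strand'])
expected = np.array([['chr10', 51339923, '+']], dtype=object)
# Element-wise equality after a uniform string cast avoids int-vs-str dtype noise.
assert np.all(df.to_numpy().astype(str) == expected.astype(str))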
