Add support for multiple input fastqs
Sequencing a single cell library across multiple lanes is common practice, and CITE-seq-Count
should be able to handle that rather than expecting the user to merge the fastqs beforehand. To
keep the amount of new code small, the existing preprocessing is run on each input file and the
per-file results are merged before barcode correction.
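In practice, -R1 and -R2 now accept comma-separated lists, e.g.
-R1 lane1_R1.fastq.gz,lane2_R1.fastq.gz -R2 lane1_R2.fastq.gz,lane2_R2.fastq.gz.
A minimal sketch of the shape of the change, assuming a hypothetical run_single_pair helper that stands in for the existing per-pair preprocessing and mapping steps:

from collections import Counter, defaultdict

def run_all_pairs(read1_csv, read2_csv, run_single_pair):
    """Validate the file lists up front, then map each R1/R2 pair and merge."""
    read1_paths = read1_csv.split(',')
    read2_paths = read2_csv.split(',')
    if len(read1_paths) != len(read2_paths):
        raise SystemExit('Unequal number of read1 and read2 files provided')
    final_results = defaultdict(Counter)
    umis_per_cell, reads_per_cell = Counter(), Counter()
    for r1, r2 in zip(read1_paths, read2_paths):
        # run_single_pair is hypothetical: one pass of the pre-existing pipeline.
        results, umis, reads = run_single_pair(r1, r2)
        umis_per_cell.update(umis)    # Counter.update adds counts, not replaces
        reads_per_cell.update(reads)
        for cell, tag_counts in results.items():
            for tag, umi_counter in tag_counts.items():
                if tag in final_results[cell]:
                    final_results[cell][tag] += umi_counter  # Counter + Counter
                else:
                    final_results[cell][tag] = umi_counter
    return final_results, umis_per_cell, reads_per_cell

The per-tag membership check matters: Counter returns 0 for a missing key, so an unconditional += would try to add a Counter to an int, which is exactly why the diff below assigns on first sight of a tag.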
arkal committed Sep 16, 2019
1 parent 66734b0 commit a785fa3
Showing 3 changed files with 153 additions and 83 deletions.
194 changes: 114 additions & 80 deletions cite_seq_count/__main__.py
@@ -13,6 +13,7 @@
from argparse import RawTextHelpFormatter
from collections import OrderedDict
from collections import Counter
from collections import defaultdict

from multiprocess import cpu_count
from multiprocess import Pool
@@ -30,17 +31,19 @@ def get_args():
"""
parser = ArgumentParser(
prog='CITE-seq-Count', formatter_class=RawTextHelpFormatter,
description=("This script counts matching antibody tags from two fastq "
description=("This script counts matching antibody tags from paired fastq "
"files. Version {}".format(version)),
)

# REQUIRED INPUTS group.
inputs = parser.add_argument_group('Inputs',
description="Required input files.")
inputs.add_argument('-R1', '--read1', dest='read1_path', required=True,
help="The path of Read1 in gz format.")
help=("The path of Read1 in gz format, or a comma-separated list of paths to all Read1 files in"
" gz format (E.g. A1.fq.gz,B1,fq,gz,..."))
inputs.add_argument('-R2', '--read2', dest='read2_path', required=True,
help="The path of Read2 in gz format.")
help=("The path of Read2 in gz format, or a comma-separated list of paths to all Read2 files in"
" gz format (E.g. A2.fq.gz,B2,fq,gz,..."))
inputs.add_argument(
'-t', '--tags', dest='tags', required=True,
help=("The path to the csv file containing the antibody\n"
@@ -244,86 +247,117 @@ def main():
# Load TAGs/ABs.
ab_map = preprocessing.parse_tags_csv(args.tags)
ab_map = preprocessing.check_tags(ab_map, args.max_error)
# Get reads length. So far, there is no validation for Read2.
read1_length = preprocessing.get_read_length(args.read1_path)
read2_length = preprocessing.get_read_length(args.read2_path)
# Check Read1 length against CELL and UMI barcodes length.
(
barcode_slice,
umi_slice,
barcode_umi_length
) = preprocessing.check_barcodes_lengths(
read1_length,
args.cb_first,
args.cb_last,
args.umi_first, args.umi_last)

if args.first_n:
n_lines = args.first_n*4
else:
n_lines = preprocessing.get_n_lines(args.read1_path)
n_reads = int(n_lines/4)
n_threads = args.n_threads
print('Started mapping')
print('Processing {:,} reads'.format(n_reads))
#Run with one process
if n_threads <= 1 or n_reads < 1000001:
print('CITE-seq-Count is running with one core.')
(
final_results,
merged_no_match) = processing.map_reads(
read1_path=args.read1_path,
read2_path=args.read2_path,
tags=ab_map,
barcode_slice=barcode_slice,
umi_slice=umi_slice,
indexes=[0,n_reads],
whitelist=whitelist,
debug=args.debug,
start_trim=args.start_trim,
maximum_distance=args.max_error,
sliding_window=args.sliding_window)
print('Mapping done')
umis_per_cell = Counter()
reads_per_cell = Counter()
for cell_barcode,counts in final_results.items():
umis_per_cell[cell_barcode] = sum([len(counts[UMI]) for UMI in counts])
reads_per_cell[cell_barcode] = sum([sum(counts[UMI].values()) for UMI in counts])
else:
# Run with multiple processes
print('CITE-seq-Count is running with {} cores.'.format(n_threads))
p = Pool(processes=n_threads)
chunk_indexes = preprocessing.chunk_reads(n_reads, n_threads)
parallel_results = []

for indexes in chunk_indexes:
p.apply_async(processing.map_reads,
args=(
args.read1_path,
args.read2_path,
ab_map,
barcode_slice,
umi_slice,
indexes,
whitelist,
args.debug,
args.start_trim,
args.max_error,
args.sliding_window),
callback=parallel_results.append,
error_callback=sys.stderr)
p.close()
p.join()
print('Mapping done')
print('Merging results')
# Identify input file(s)
read1_paths, read2_paths = preprocessing.get_read_paths(args.read1_path, args.read2_path)

# Preprocessing and processing occur in separate loops so the program can fail early if
# one of the inputs is not valid.
read1_lengths = []
read2_lengths = []
for read1_path, read2_path in zip(read1_paths, read2_paths):
# Get read lengths. So far, there is no validation for Read2.
read1_lengths.append(preprocessing.get_read_length(read1_path))
read2_lengths.append(preprocessing.get_read_length(read2_path))
# Check Read1 length against CELL and UMI barcodes length.
(
final_results,
umis_per_cell,
reads_per_cell,
merged_no_match
) = processing.merge_results(parallel_results=parallel_results)
del(parallel_results)
barcode_slice,
umi_slice,
barcode_umi_length
) = preprocessing.check_barcodes_lengths(
read1_lengths[-1],
args.cb_first,
args.cb_last,
args.umi_first, args.umi_last)
# Ensure all files have the same input length
#if len(set(read1_lengths)) != 1:
# sys.exit('Input barcode fastqs (read1) do not all have same length.\nExiting')

# Initialize the counts dicts that will be generated from each input fastq pair
final_results = defaultdict(Counter)
umis_per_cell = Counter()
reads_per_cell = Counter()
merged_no_match = Counter()
for read1_path, read2_path in zip(read1_paths, read2_paths):
if args.first_n:
n_lines = args.first_n*4
else:
n_lines = preprocessing.get_n_lines(read1_path)
n_reads = int(n_lines/4)
n_threads = args.n_threads
print('Started mapping')
print('Processing {:,} reads'.format(n_reads))
#Run with one process
if n_threads <= 1 or n_reads < 1000001:
print('CITE-seq-Count is running with one core.')
(
_final_results,
_merged_no_match) = processing.map_reads(
read1_path=read1_path,
read2_path=read2_path,
tags=ab_map,
barcode_slice=barcode_slice,
umi_slice=umi_slice,
indexes=[0,n_reads],
whitelist=whitelist,
debug=args.debug,
start_trim=args.start_trim,
maximum_distance=args.max_error,
sliding_window=args.sliding_window)
print('Mapping done')
_umis_per_cell = Counter()
_reads_per_cell = Counter()
for cell_barcode, counts in _final_results.items():
_umis_per_cell[cell_barcode] = sum([len(counts[tag]) for tag in counts])
_reads_per_cell[cell_barcode] = sum([sum(counts[tag].values()) for tag in counts])
else:
# Run with multiple processes
print('CITE-seq-Count is running with {} cores.'.format(n_threads))
p = Pool(processes=n_threads)
chunk_indexes = preprocessing.chunk_reads(n_reads, n_threads)
parallel_results = []

for indexes in chunk_indexes:
p.apply_async(processing.map_reads,
args=(
read1_path,
read2_path,
ab_map,
barcode_slice,
umi_slice,
indexes,
whitelist,
args.debug,
args.start_trim,
args.max_error,
args.sliding_window),
callback=parallel_results.append,
error_callback=lambda err: print(err, file=sys.stderr))
p.close()
p.join()
print('Mapping done')
print('Merging results')

(
_final_results,
_umis_per_cell,
_reads_per_cell,
_merged_no_match
) = processing.merge_results(parallel_results=parallel_results)
del(parallel_results)

# Update the overall counts dicts
umis_per_cell.update(_umis_per_cell)
reads_per_cell.update(_reads_per_cell)
merged_no_match.update(_merged_no_match)
for cell_barcode in _final_results:
for tag in _final_results[cell_barcode]:
if tag in final_results[cell_barcode]:
# Counter + Counter = Counter
final_results[cell_barcode][tag] += _final_results[cell_barcode][tag]
else:
# Explicitly save the counter to that tag
final_results[cell_barcode][tag] = _final_results[cell_barcode][tag]

ordered_tags_map = OrderedDict()
for i,tag in enumerate(ab_map.values()):
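To make the count structures above concrete, here is a toy example (illustrative barcodes and numbers, not data from the repository) of how UMIs and reads per cell fall out of the nested tag -> Counter(UMI -> read count) mapping:

from collections import Counter

# Illustrative per-cell counts: tag -> Counter(UMI -> read count).
counts = {'CD3': Counter({'AAAA': 2, 'CCCC': 1}),
          'CD8': Counter({'GGGG': 4})}

# Mirrors the comprehensions in the single-core branch above.
umis = sum(len(counts[tag]) for tag in counts)            # 3 distinct UMIs
reads = sum(sum(counts[tag].values()) for tag in counts)  # 2 + 1 + 4 = 7 reads
assert (umis, reads) == (3, 7)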
22 changes: 21 additions & 1 deletion cite_seq_count/preprocessing.py
@@ -279,6 +279,26 @@ def get_n_lines(file_path):
with gzip.open(file_path, "rt",encoding="utf-8",errors='ignore') as f:
n_lines = sum(bl.count("\n") for bl in blocks(f))
if n_lines %4 !=0:
sys.exit('{}\'s number of lines is not a multiple of 4. The file might be corrupted.\n Exiting')
sys.exit('{}\'s number of lines is not a multiple of 4. The file '
'might be corrupted.\n Exiting'.format(file_path))
return(n_lines)


def get_read_paths(read1_path, read2_path):
"""
Splits two comma-separated strings of input files into lists of files
to process. Ensures both lists are of equal length.
Args:
read1_path (string): Comma-separated paths to read1.fq
read2_path (string): Comma-separated paths to read2.fq
Returns:
_read1_path (list(string)): list of paths to read1.fq
_read2_path (list(string)): list of paths to read2.fq
"""
_read1_path = read1_path.split(',')
_read2_path = read2_path.split(',')
if len(_read1_path) != len(_read2_path):
sys.exit('Unequal number of read1 ({}) and read2 ({}) files provided'
'\n Exiting'.format(len(_read1_path), len(_read2_path)))
return _read1_path, _read2_path
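
For reference, a quick usage sketch of the new get_read_paths helper (file names are illustrative; the import assumes the package layout used by the tests):

from cite_seq_count import preprocessing

# Two lanes of the same library, passed as comma-separated strings.
r1, r2 = preprocessing.get_read_paths(
    'lane1_R1.fastq.gz,lane2_R1.fastq.gz',
    'lane1_R2.fastq.gz,lane2_R2.fastq.gz')
assert r1 == ['lane1_R1.fastq.gz', 'lane2_R1.fastq.gz']
assert r2 == ['lane1_R2.fastq.gz', 'lane2_R2.fastq.gz']
# Passing unequal numbers of R1 and R2 files calls sys.exit() instead.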
20 changes: 18 additions & 2 deletions tests/test_preprocessing.py
@@ -15,7 +15,14 @@ def data():
pytest.correct_R2_path = 'tests/test_data/fastq/correct_R2.fastq.gz'
pytest.corrupt_R1_path = 'tests/test_data/fastq/corrupted_R1.fastq.gz'
pytest.corrupt_R2_path = 'tests/test_data/fastq/corrupted_R2.fastq.gz'


pytest.correct_R1_multipath = 'path/to/R1_1.fastq.gz,path/to/R1_2.fastq.gz'
pytest.correct_R2_multipath = 'path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz'
pytest.incorrect_R2_multipath = 'path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz,path/to/R2_3.fastq.gz'

pytest.correct_multipath_result = (['path/to/R1_1.fastq.gz', 'path/to/R1_2.fastq.gz'],
['path/to/R2_1.fastq.gz', 'path/to/R2_2.fastq.gz'])

# Create some variables to compare to
pytest.correct_whitelist = set(['ACTGTTTTATTGGCCT','TTCATAAGGTAGGGAT'])
pytest.correct_tags = {
@@ -70,4 +77,13 @@ def test_get_n_lines(data):
@pytest.mark.dependency(depends=['test_get_n_lines'])
def test_get_n_lines_not_multiple_of_4(data):
with pytest.raises(SystemExit):
preprocessing.get_n_lines(pytest.corrupt_R1_path)

@pytest.mark.dependency()
def test_correct_multipath(data):
assert preprocessing.get_read_paths(pytest.correct_R1_multipath, pytest.correct_R2_multipath) == pytest.correct_multipath_result

@pytest.mark.dependency(depends=['test_get_n_lines'])
def test_incorrect_multipath(data):
with pytest.raises(SystemExit):
preprocessing.get_read_paths(pytest.correct_R1_multipath, pytest.incorrect_R2_multipath)
