Commit
use separate thread for saving enhanced files
desh2608 committed Oct 30, 2022
1 parent 66de275 commit ad1b237
Showing 6 changed files with 242 additions and 101 deletions.
54 changes: 45 additions & 9 deletions README.md
@@ -4,6 +4,15 @@
in which the mask estimation is guided by a diarizer output. The original method was proposed
for the CHiME-5 challenge in [this paper](http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_boeddecker.pdf) by Boeddeker et al.

It is a kind of target-speaker extraction method. The inputs to the model are:

1. A multi-channel recording, e.g., from an array microphone, of a long, unsegmented,
multi-talker session (possibly with overlapping speech)
2. An RTTM file containing speaker segment boundaries

The system produces enhanced audio for each of the segments in the RTTM, removing the background
speech and noise and "extracting" only the target speaker in the segment.
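For reference, RTTM is a simple space-separated text format with one speaker segment per line. A minimal sketch with hypothetical recording and speaker IDs (the fields are: segment type, recording ID, channel, onset in seconds, duration in seconds, and speaker label, with `<NA>` for unused fields):

```
SPEAKER sessionA 1 10.50 2.30 <NA> <NA> spk1 <NA> <NA>
SPEAKER sessionA 1 12.10 1.75 <NA> <NA> spk2 <NA> <NA>
```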

This repository contains a GPU implementation of this method in Python, along with CLI binaries
to run the enhancement from the shell. We also provide several example "recipes" for using the
method.
@@ -22,6 +31,8 @@ examples in the `recipes` directory for how to use the `gss` module for several
are currently aiming to support LibriCSS, AMI, and AliMeeting.
* Inference can be run in a multi-node GPU environment, which makes it several times faster than the
original CPU implementation.
* We have implemented batch processing of segments (see [this issue](https://github.com/desh2608/gss/issues/12) for details)
to maximize GPU memory usage and provide additional speed-up.
* We provide both Python modules and CLI for using the enhancement functions, which can be
easily included in recipes from Kaldi, Icefall, ESPNet, etc.
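The batch processing above can be sketched as a greedy grouping of same-speaker segments under a duration cap. This is a simplified illustration only, not the actual implementation (which builds batches with Lhotse samplers); the function name and the `(speaker, duration)` representation are ours:

```python
from itertools import groupby


def batch_segments(segments, max_batch_duration):
    """Greedily group same-speaker segments into batches whose total
    duration stays under max_batch_duration. Each segment is a
    (speaker, duration) tuple; returns a list of batches."""
    batches = []
    key = lambda seg: seg[0]
    for _, group in groupby(sorted(segments, key=key), key=key):
        batch, total = [], 0.0
        for seg in group:
            # Flush the current batch if adding this segment would
            # exceed the duration cap.
            if batch and total + seg[1] > max_batch_duration:
                batches.append(batch)
                batch, total = [], 0.0
            batch.append(seg)
            total += seg[1]
        if batch:
            batches.append(batch)
    return batches
```

Each batch then contains segments of a single speaker, so they can share one mask-estimation pass on the GPU.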

@@ -67,7 +78,8 @@ RTTM file denoting speaker segments, run the following:
export CUDA_VISIBLE_DEVICES=0
gss enhance recording \
/path/to/sessionA.wav /path/to/rttm exp/enhanced_segs \
-  --recording-id sessionA --min-segment-length 0.2 --max-segment-length 10.0
+  --recording-id sessionA --min-segment-length 0.1 --max-segment-length 10.0 \
+  --max-batch-duration 20.0 --num-buckets 2 -o exp/segments.jsonl.gz
```

### Enhancing a corpus
@@ -86,7 +98,7 @@ will be used to get speaker activities.
4. Trim the recording-level cut set into segment-level cuts. These are the segments that will
actually be enhanced.

-5. Split the segments into as many parts as the number of GPU jobs you want to run. In the
+5. (Optional) Split the segments into as many parts as the number of GPU jobs you want to run. In the
recipes, we submit the jobs through `qsub`, similar to Kaldi or ESPNet recipes. You can
use the parallelization in those toolkits to additionally use a different scheduler such as
SLURM.
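The splitting in step 5 amounts to partitioning the segment list into as many roughly equal parts as there are GPU jobs; a minimal round-robin sketch (the recipes do this with shell and Lhotse utilities, so the function name here is illustrative):

```python
def split_segments(segments, num_jobs):
    """Partition segments into num_jobs roughly equal parts
    (round-robin), so each GPU job processes one part."""
    return [segments[i::num_jobs] for i in range(num_jobs)]
```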
@@ -97,26 +109,50 @@ SLURM.

* `--bss-iterations`: Number of iterations of the CACGMM inference.

* `--context-duration`: Context (in seconds) to include on both sides of the segment.

* `--min-segment-length`: Any segment shorter than this value will be removed. This is
particularly useful when using segments from a diarizer output since they often contain
-very small segments which are not relevant for ASR. A recommended setting is 0.2s.
+very small segments which are not relevant for ASR. A recommended setting is 0.1s.

* `--max-segment-length`: Segments longer than this value will be chunked up. This is
to prevent OOM errors since the segment STFTs are loaded onto the GPU. We use a setting
of 15s in most cases.

-Internally, we also have a fallback option to chunk up segments into increasingly smaller
* `--max-batch-duration`: Segments from the same speaker will be batched together to increase
GPU efficiency. We used 20s batches for enhancement on GPUs with 12G memory. For GPUs with
larger memory, this value can be increased.

* `--max-batch-cuts`: This sets an upper limit on the maximum number of cuts in a batch. To
simulate segment-wise enhancement, set this to 1.

* `--num-workers`: Number of workers to use for data-loading (default = 1). Use more if you
increase the `max-batch-duration`.

* `--num-buckets`: Number of buckets to use for sampling. Batches are drawn from the same
bucket (see Lhotse's [`DynamicBucketingSampler`](https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/sampling/dynamic_bucketing.py) for details).

* `--enhanced-manifest/-o`: Path to the manifest file to which the enhanced cut manifest will be
written. This is useful when supervisions need to be propagated to the enhanced segments,
e.g., for downstream ASR tasks.
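The manifest written via `-o` (e.g., `exp/segments.jsonl.gz`) is a gzipped JSONL file with one cut per line. In practice you would load it with Lhotse's `load_manifest`; purely as a sketch of the on-disk layout, it can also be inspected with the standard library (assuming each line carries an `id` field, which Lhotse cuts do):

```python
import gzip
import json


def read_cut_ids(path):
    """Yield the id of each cut stored in a .jsonl.gz manifest."""
    with gzip.open(path, "rt") as f:
        for line in f:
            yield json.loads(line)["id"]
```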

## Other details

+Internally, we also have a fallback option to chunk up batches into increasingly smaller
parts in case an OOM error is encountered (see `gss.core.enhancer.py`).

-The enhanced wav files will be written to `$EXP_DIR/enhanced`. The wav files are named
-as *recoid-spkid-start_end.wav*, i.e., 1 wav file is generated for each segment in the RTTM.
-The "start" and "end" are padded to 6 digits, for example: 21.18 seconds is encoded as
-`002118`. This convention should be fine if your audio duration is under ~2.75 h (9999s),
-otherwise, you should change the padding in `gss/core/enhancer.py`.
+The enhanced wav files are named as *recoid-spkid-start_end.wav*, i.e., 1 wav file is
+generated for each segment in the RTTM. The "start" and "end" are padded to 6 digits,
+for example: 21.18 seconds is encoded as `002118`. This convention should be fine if
+your audio duration is under ~2.75 h (9999s), otherwise, you should change the
+padding in `gss/core/enhancer.py`.
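The naming convention can be sketched as follows (times are encoded as zero-padded centiseconds; the helper names are illustrative, not the actual functions in `gss/core/enhancer.py`):

```python
def encode_time(seconds):
    """Encode a time in seconds as 6-digit zero-padded centiseconds,
    e.g. 21.18 -> "002118"."""
    return f"{int(round(seconds * 100)):06d}"


def segment_filename(recoid, spkid, start, end):
    """Build the recoid-spkid-start_end.wav name for an enhanced segment."""
    return f"{recoid}-{spkid}-{encode_time(start)}_{encode_time(end)}.wav"
```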

For examples of how to generate RTTMs for guiding the separation, please refer to my
[diarizer](https://github.com/desh2608/diarizer) toolkit.

**Additional parameters:** We have only made the most important parameters available in the
top-level CLI. To play with other parameters, check out the `gss.enhancer.get_enhancer()` function.

## Contributing

Contributions for core improvements or new recipes are welcome. Please run the following
77 changes: 59 additions & 18 deletions gss/bin/modes/enhance.py
@@ -1,9 +1,12 @@
import functools
import logging
import time
from pathlib import Path
from typing import Tuple

import click
-from lhotse import CutSet, Recording, SupervisionSet, load_manifest
+from lhotse import Recording, SupervisionSet, load_manifest
+from lhotse.cut import CutSet, MixedCut, MonoCut
+from lhotse.utils import fastcopy

from gss.bin.modes.cli_base import cli
@@ -39,6 +42,19 @@ def common_options(func):
        help="Number of iterations for BSS",
        show_default=True,
    )
+    @click.option(
+        "--context-duration",
+        type=float,
+        default=15.0,
+        help="Context duration in seconds for CACGMM",
+        show_default=True,
+    )
+    @click.option(
+        "--use-garbage-class/--no-garbage-class",
+        default=False,
+        help="Whether to use the additional noise class for CACGMM",
+        show_default=True,
+    )
    @click.option(
        "--min-segment-length",
        type=float,
@@ -53,13 +69,6 @@ def common_options(func):
        help="Chunk up longer segments to avoid OOM issues",
        show_default=True,
    )
-    @click.option(
-        "--context-duration",
-        type=float,
-        default=15.0,
-        help="Context duration in seconds for CACGMM",
-        show_default=True,
-    )
    @click.option(
        "--max-batch-duration",
        type=float,
@@ -74,13 +83,27 @@
        help="Maximum number of cuts in a batch",
        show_default=True,
    )
+    @click.option(
+        "--num-workers",
+        type=int,
+        default=1,
+        help="Number of workers for parallel processing",
+        show_default=True,
+    )
+    @click.option(
+        "--num-buckets",
+        type=int,
+        default=2,
+        help="Number of buckets per speaker for batching (use larger values if you set a higher max-segment-length)",
+        show_default=True,
+    )
    @click.option(
        "--enhanced-manifest",
        "-o",
        type=click.Path(),
        default=None,
        help="Path to the output manifest containing details of the enhanced segments.",
    )
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
@@ -108,12 +131,15 @@ def cuts_(
    enhanced_dir,
    num_channels,
    bss_iterations,
+    context_duration,
+    use_garbage_class,
    min_segment_length,
    max_segment_length,
-    context_duration,
    max_batch_duration,
    max_batch_cuts,
+    num_workers,
+    num_buckets,
    enhanced_manifest,
):
    """
    Enhance segments (represented by cuts).
@@ -143,18 +169,24 @@
    logger.info("Initializing GSS enhancer")
    enhancer = get_enhancer(
        cuts=cuts,
-        error_handling="keep_original",
-        activity_garbage_class=False,
        bss_iterations=bss_iterations,
        context_duration=context_duration,
+        activity_garbage_class=use_garbage_class,
        max_batch_duration=max_batch_duration,
        max_batch_cuts=max_batch_cuts,
+        num_workers=num_workers,
+        num_buckets=num_buckets,
    )

    logger.info(f"Enhancing {len(frozenset(c.id for c in cuts_per_segment))} segments")
-    num_errors = enhancer.enhance_cuts(cuts_per_segment, enhanced_dir)
-    logger.info(f"Finished with {num_errors} errors")
+    begin = time.time()
+    num_errors, out_cuts = enhancer.enhance_cuts(cuts_per_segment, enhanced_dir)
+    end = time.time()
+    logger.info(f"Finished in {end-begin:.2f}s with {num_errors} errors")
+
+    if enhanced_manifest is not None:
+        logger.info(f"Saving enhanced cuts manifest to {enhanced_manifest}")
+        out_cuts.to_file(enhanced_manifest)


@enhance.command(name="recording")
@@ -184,12 +216,15 @@ def recording_(
    recording_id,
    num_channels,
    bss_iterations,
+    context_duration,
+    use_garbage_class,
    min_segment_length,
    max_segment_length,
-    context_duration,
    max_batch_duration,
    max_batch_cuts,
+    num_workers,
+    num_buckets,
    enhanced_manifest,
):
    """
    Enhance a single recording using an RTTM file.
@@ -231,15 +266,21 @@
    logger.info("Initializing GSS enhancer")
    enhancer = get_enhancer(
        cuts=cuts,
-        error_handling="keep_original",
-        activity_garbage_class=False,
        bss_iterations=bss_iterations,
        context_duration=context_duration,
+        activity_garbage_class=use_garbage_class,
        max_batch_duration=max_batch_duration,
        max_batch_cuts=max_batch_cuts,
+        num_workers=num_workers,
+        num_buckets=num_buckets,
    )

    logger.info(f"Enhancing {len(frozenset(c.id for c in cuts_per_segment))} segments")
-    num_errors = enhancer.enhance_cuts(cuts_per_segment, enhanced_dir)
-    logger.info(f"Finished with {num_errors} errors")
+    begin = time.time()
+    num_errors, out_cuts = enhancer.enhance_cuts(cuts_per_segment, enhanced_dir)
+    end = time.time()
+    logger.info(f"Finished in {end-begin:.2f}s with {num_errors} errors")
+
+    if enhanced_manifest is not None:
+        logger.info(f"Saving enhanced cuts manifest to {enhanced_manifest}")
+        out_cuts.to_file(enhanced_manifest)
