Skip to content

Commit

Permalink
Support controllable CutSet.mux weights in multiprocess dataloading (
Browse files Browse the repository at this point in the history
…#1266)

* Simplify the implementation of DurationBatcher. Avoids caching cuts for future re-use.

* Enable leveraging shared memory for updating mux weights in dataloading subprocesses

* Initial partial support for infinite_mux

* Support most `CutSet` operations without dill; fix tests; infinite_mux works

* Fixes

* Fix meeting simulation test

* make py3.8 happy

* fix
  • Loading branch information
pzelasko authored Jan 23, 2024
1 parent 69ab31d commit c678849
Show file tree
Hide file tree
Showing 15 changed files with 437 additions and 72 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Lhotse uses several environment variables to customize it's behavior. They are a
- `LHOTSE_AUDIO_DURATION_MISMATCH_TOLERANCE` - used when we load audio from a file and receive a different number of samples than declared in `Recording.num_samples`. This is sometimes necessary because different codecs (or even different versions of the same codec) may use different padding when decoding compressed audio. Typically values up to 0.1, or even 0.3 (second) are still reasonable, and anything beyond that indicates a serious issue.
- `LHOTSE_AUDIO_BACKEND` - may be set to any of the values returned from CLI `lhotse list-audio-backends` to override the default behavior of trial-and-error and always use a specific audio backend.
- `LHOTSE_AUDIO_LOADING_EXCEPTION_VERBOSE` - when set to `1` we'll emit full exception stack traces when every available audio backend fails to load a given file (they might be very large).
- `LHOTSE_DILL_ENABLED` - when it's set to `1|True|true|yes`, we will enable `dill`-based serialization of `CutSet` and `Sampler` across processes (it's disabled by default even when `dill` is installed).
- `LHOTSE_PREPARING_RELEASE` - used internally by developers when releasing a new version of Lhotse.
- `TORCHAUDIO_USE_BACKEND_DISPATCHER` - when set to `1` and torchaudio version is below 2.1, we'll enable the experimental ffmpeg backend of torchaudio.
- `RANK`, `WORLD_SIZE`, `WORKER`, and `NUM_WORKERS` are internally used to inform Lhotse Shar dataloading subprocesses.
Expand Down
2 changes: 2 additions & 0 deletions docs/getting-started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ Lhotse uses several environment variables to customize it's behavior. They are a

* ``LHOTSE_AUDIO_LOADING_EXCEPTION_VERBOSE`` - when set to 1 we'll emit full exception stack traces when every available audio backend fails to load a given file (they might be very large).

* ``LHOTSE_DILL_ENABLED`` - when it's set to ``1|True|true|yes``, we will enable ``dill``-based serialization of ``CutSet`` and ``Sampler`` across processes (it's disabled by default even when ``dill`` is installed).

* ``LHOTSE_PREPARING_RELEASE`` - used internally by developers when releasing a new version of Lhotse.

* ``TORCHAUDIO_USE_BACKEND_DISPATCHER`` - when set to 1 and torchaudio version is below 2.1, we'll enable the experimental ffmpeg backend of torchaudio.
Expand Down
1 change: 1 addition & 0 deletions lhotse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .cut import CutSet, MonoCut, MultiCut, create_cut_set_eager, create_cut_set_lazy
from .features import *
from .kaldi import load_kaldi_data_dir
from .lazy import dill_enabled, is_dill_enabled, set_dill_enabled
from .manipulation import combine, split_parallelize_combine, to_manifest
from .qa import fix_manifests, validate, validate_recordings_and_supervisions
from .serialization import load_manifest, load_manifest_lazy, store_manifest
Expand Down
134 changes: 109 additions & 25 deletions lhotse/cut/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ def filter_supervisions(
:param predicate: A callable that accepts `SupervisionSegment` and returns bool
:return: a CutSet with filtered supervisions
"""
return self.map(lambda cut: cut.filter_supervisions(predicate))
return self.map(partial(_filter_supervisions, predicate=predicate))

def merge_supervisions(
self,
Expand All @@ -982,8 +982,10 @@ def merge_supervisions(
``custom_merge_fn(custom_key, [s.custom[custom_key] for s in sups])``
"""
return self.map(
lambda cut: cut.merge_supervisions(
merge_policy=merge_policy, custom_merge_fn=custom_merge_fn
partial(
_merge_supervisions,
merge_policy=merge_policy,
custom_merge_fn=custom_merge_fn,
)
)

Expand Down Expand Up @@ -1341,7 +1343,8 @@ def pad(
duration = max(cut.duration for cut in self)

return self.map(
lambda cut: cut.pad(
partial(
_pad,
duration=duration,
num_frames=num_frames,
num_samples=num_samples,
Expand Down Expand Up @@ -1422,7 +1425,8 @@ def extend_by(
:return: a new CutSet instance.
"""
return self.map(
lambda cut: cut.extend_by(
partial(
_extend_by,
duration=duration,
direction=direction,
preserve_id=preserve_id,
Expand Down Expand Up @@ -1535,7 +1539,9 @@ def resample(self, sampling_rate: int, affix_id: bool = False) -> "CutSet":
cut are going to be present in a single manifest).
:return: a modified copy of the ``CutSet``.
"""
return self.map(lambda cut: cut.resample(sampling_rate, affix_id=affix_id))
return self.map(
partial(_resample, sampling_rate=sampling_rate, affix_id=affix_id)
)

def perturb_speed(self, factor: float, affix_id: bool = True) -> "CutSet":
"""
Expand All @@ -1550,7 +1556,7 @@ def perturb_speed(self, factor: float, affix_id: bool = True) -> "CutSet":
cut are going to be present in a single manifest).
:return: a modified copy of the ``CutSet``.
"""
return self.map(lambda cut: cut.perturb_speed(factor=factor, affix_id=affix_id))
return self.map(partial(_perturb_speed, factor=factor, affix_id=affix_id))

def perturb_tempo(self, factor: float, affix_id: bool = True) -> "CutSet":
"""
Expand All @@ -1568,7 +1574,7 @@ def perturb_tempo(self, factor: float, affix_id: bool = True) -> "CutSet":
cut are going to be present in a single manifest).
:return: a modified copy of the ``CutSet``.
"""
return self.map(lambda cut: cut.perturb_tempo(factor=factor, affix_id=affix_id))
return self.map(partial(_perturb_tempo, factor=factor, affix_id=affix_id))

def perturb_volume(self, factor: float, affix_id: bool = True) -> "CutSet":
"""
Expand All @@ -1582,9 +1588,7 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "CutSet":
cut are going to be present in a single manifest).
:return: a modified copy of the ``CutSet``.
"""
return self.map(
lambda cut: cut.perturb_volume(factor=factor, affix_id=affix_id)
)
return self.map(partial(_perturb_volume, factor=factor, affix_id=affix_id))

def normalize_loudness(
self, target: float, mix_first: bool = True, affix_id: bool = True
Expand All @@ -1599,8 +1603,11 @@ def normalize_loudness(
:return: a modified copy of the current ``CutSet``.
"""
return self.map(
lambda cut: cut.normalize_loudness(
target=target, mix_first=mix_first, affix_id=affix_id
partial(
_normalize_loudness,
target=target,
mix_first=mix_first,
affix_id=affix_id,
)
)

Expand All @@ -1612,7 +1619,7 @@ def dereverb_wpe(self, affix_id: bool = True) -> "CutSet":
by affixing it with "_wpe".
:return: a modified copy of the current ``CutSet``.
"""
return self.map(lambda cut: cut.dereverb_wpe(affix_id=affix_id))
return self.map(partial(_dereverb_wpe, affix_id=affix_id))

def reverb_rir(
self,
Expand Down Expand Up @@ -1643,7 +1650,8 @@ def reverb_rir(
"""
rir_recordings = list(rir_recordings) if rir_recordings else None
return self.map(
lambda cut: cut.reverb_rir(
partial(
_reverb_rir,
rir_recording=random.choice(rir_recordings) if rir_recordings else None,
normalize_output=normalize_output,
early_only=early_only,
Expand Down Expand Up @@ -1713,25 +1721,25 @@ def drop_features(self) -> "CutSet":
"""
Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from its extracted features.
"""
return self.map(lambda cut: cut.drop_features())
return self.map(_drop_features)

def drop_recordings(self) -> "CutSet":
"""
Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from its recordings.
"""
return self.map(lambda cut: cut.drop_recording())
return self.map(_drop_recordings)

def drop_supervisions(self) -> "CutSet":
"""
Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from its supervisions.
"""
return self.map(lambda cut: cut.drop_supervisions())
return self.map(_drop_supervisions)

def drop_alignments(self) -> "CutSet":
"""
Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from the alignments present in its supervisions.
"""
return self.map(lambda cut: cut.drop_alignments())
return self.map(_drop_alignments)

def compute_and_store_features(
self,
Expand Down Expand Up @@ -2439,7 +2447,7 @@ def modify_ids(self, transform_fn: Callable[[str], str]) -> "CutSet":
a new string (new cut ID).
:return: a new ``CutSet`` with cuts with modified IDs.
"""
return self.map(lambda cut: cut.with_id(transform_fn(cut.id)))
return self.map(partial(_with_id, transform_fn=transform_fn))

def fill_supervisions(
self, add_empty: bool = True, shrink_ok: bool = False
Expand All @@ -2461,7 +2469,7 @@ def fill_supervisions(
of calling this method.
"""
return self.map(
lambda cut: cut.fill_supervision(add_empty=add_empty, shrink_ok=shrink_ok)
partial(_fill_supervision, add_empty=add_empty, shrink_ok=shrink_ok)
)

def map_supervisions(
Expand All @@ -2473,7 +2481,7 @@ def map_supervisions(
:param transform_fn: a function that modifies a supervision as an argument.
:return: a new, modified CutSet.
"""
return self.map(lambda cut: cut.map_supervisions(transform_fn))
return self.map(partial(_map_supervisions, transform_fn=transform_fn))

def transform_text(self, transform_fn: Callable[[str], str]) -> "CutSet":
"""
Expand All @@ -2483,7 +2491,9 @@ def transform_text(self, transform_fn: Callable[[str], str]) -> "CutSet":
:param transform_fn: a function that accepts a string and returns a string.
:return: a new, modified CutSet.
"""
return self.map_supervisions(lambda s: s.transform_text(transform_fn))
return self.map_supervisions(
partial(_transform_text, transform_fn=transform_fn)
)

def __repr__(self) -> str:
try:
Expand Down Expand Up @@ -3265,8 +3275,82 @@ def _add_features_path_prefix_single(cut, path):
return cut.with_features_path_prefix(path)


def _call(obj, member_fn: str, *args, **kwargs) -> Callable:
return getattr(obj, member_fn)(*args, **kwargs)
def _with_id(cut, transform_fn):
return cut.with_id(transform_fn(cut.id))


def _fill_supervision(cut, add_empty, shrink_ok):
return cut.fill_supervision(add_empty=add_empty, shrink_ok=shrink_ok)


def _map_supervisions(cut, transform_fn):
return cut.map_supervisions(transform_fn)


def _transform_text(sup, transform_fn):
return sup.transform_text(transform_fn)


def _filter_supervisions(cut, predicate):
return cut.filter_supervisions(predicate)


def _merge_supervisions(cut, merge_policy, custom_merge_fn):
return cut.merge_supervisions(
merge_policy=merge_policy, custom_merge_fn=custom_merge_fn
)


def _pad(cut, *args, **kwargs):
return cut.pad(*args, **kwargs)


def _extend_by(cut, *args, **kwargs):
return cut.extend_by(*args, **kwargs)


def _resample(cut, *args, **kwargs):
return cut.resample(*args, **kwargs)


def _perturb_speed(cut, *args, **kwargs):
return cut.perturb_speed(*args, **kwargs)


def _perturb_tempo(cut, *args, **kwargs):
return cut.perturb_tempo(*args, **kwargs)


def _perturb_volume(cut, *args, **kwargs):
return cut.perturb_volume(*args, **kwargs)


def _reverb_rir(cut, *args, **kwargs):
return cut.reverb_rir(*args, **kwargs)


def _normalize_loudness(cut, *args, **kwargs):
return cut.normalize_loudness(*args, **kwargs)


def _dereverb_wpe(cut, *args, **kwargs):
return cut.dereverb_wpe(*args, **kwargs)


def _drop_features(cut, *args, **kwargs):
return cut.drop_features(*args, **kwargs)


def _drop_recordings(cut, *args, **kwargs):
return cut.drop_recording(*args, **kwargs)


def _drop_alignments(cut, *args, **kwargs):
return cut.drop_alignments(*args, **kwargs)


def _drop_supervisions(cut, *args, **kwargs):
return cut.drop_supervisions(*args, **kwargs)


def _export_to_shar_single(
Expand Down
5 changes: 2 additions & 3 deletions lhotse/dataset/sampling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,9 @@ def close_to_exceeding(self) -> bool:
if self.max_cuts is not None and self.num_cuts >= self.max_cuts:
return True

thresh = self.longest_seen

if self.max_duration is not None:
return self.current + thresh >= self.max_duration - 1e-3 # float precision
effective_duration = (self.num_cuts + 1) * self.longest_seen
return effective_duration > self.max_duration
return False

def reset(self) -> None:
Expand Down
32 changes: 10 additions & 22 deletions lhotse/dataset/sampling/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,8 @@ def detuplify(
while True:
# Check that we have not reached the end of the dataset.
try:
if self.reuse_cuts_buffer:
next_cut_or_tpl = self.reuse_cuts_buffer.popleft()
else:
# If this doesn't raise (typical case), it's not the end: keep processing.
next_cut_or_tpl = next(self.cuts_iter)
# If this doesn't raise (typical case), it's not the end: keep processing.
next_cut_or_tpl = next(self.cuts_iter)
except StopIteration:
# No more cuts to sample from: if we have a partial batch,
# we may output it, unless the user requested to drop it.
Expand All @@ -315,32 +312,23 @@ def detuplify(
raise StopIteration()

# Track the duration/frames/etc. constraints.
cuts.append(next_cut_or_tpl)
self.time_constraint.add(
next_cut_or_tpl[0]
if isinstance(next_cut_or_tpl, tuple)
else next_cut_or_tpl
)

# Did we exceed the max_frames and max_cuts constraints?
if not self.time_constraint.exceeded():
# No - add the next cut to the batch, and keep trying.
cuts.append(next_cut_or_tpl)
else:
# Yes. Do we have at least one cut in the batch?
if cuts:
# Yes. Return the batch, but keep the currently drawn cut for later.
self.reuse_cuts_buffer.append(next_cut_or_tpl)
break
else:
# No. We'll warn the user that the constrains might be too tight,
# and return the cut anyway.
if self.time_constraint.close_to_exceeding():
# Yes. Finish sampling this batch.
if self.time_constraint.exceeded():
warnings.warn(
"The first cut drawn in batch collection violates "
"the max_frames, max_cuts, or max_duration constraints - "
"we'll return it anyway. "
"Consider increasing max_frames/max_cuts/max_duration."
"We have exceeded the max_duration constraint during sampling. "
"This is likely because max_duration was set to a very low value ~10s, "
"or you're using a CutSet with very long cuts (e.g. 100s of seconds long)."
)
cuts.append(next_cut_or_tpl)
break

return detuplify(cuts)

Expand Down
Loading

0 comments on commit c678849

Please sign in to comment.