diff --git a/lhotse/audio/utils.py b/lhotse/audio/utils.py
index c4a604234..b5f7debd5 100644
--- a/lhotse/audio/utils.py
+++ b/lhotse/audio/utils.py
@@ -125,6 +125,7 @@ def suppress_audio_loading_errors(enabled: bool = True):
         AudioLoadingError,
         DurationMismatchError,
         NonPositiveEnergyError,
+        ConnectionResetError,  # when reading from object stores / network sources
         enabled=enabled,
     ):
         yield
@@ -141,6 +142,7 @@ def suppress_video_loading_errors(enabled: bool = True):
         AudioLoadingError,
         DurationMismatchError,
         NonPositiveEnergyError,
+        ConnectionResetError,  # when reading from object stores / network sources
         enabled=enabled,
     ):
         yield
diff --git a/lhotse/dataset/sampling/base.py b/lhotse/dataset/sampling/base.py
index 6545b1671..26adda779 100644
--- a/lhotse/dataset/sampling/base.py
+++ b/lhotse/dataset/sampling/base.py
@@ -2,7 +2,7 @@
 import os
 import warnings
 from abc import ABCMeta, abstractmethod
-from bisect import bisect_right
+from bisect import bisect_left
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from math import isclose
@@ -424,7 +424,7 @@ def select_bucket(
         ), f"select_bucket requires either example= or example_len= as the input (we received {example=} and {example_len=})."
         if example_len is None:
             example_len = self.measure_length(example)
-        return bisect_right(buckets, example_len)
+        return bisect_left(buckets, example_len)
 
     def copy(self) -> "SamplingConstraint":
         """Return a shallow copy of this constraint."""
diff --git a/lhotse/dataset/sampling/stateless.py b/lhotse/dataset/sampling/stateless.py
index 6667242b6..91f9395f1 100644
--- a/lhotse/dataset/sampling/stateless.py
+++ b/lhotse/dataset/sampling/stateless.py
@@ -1,7 +1,17 @@
 import logging
 import random
 from pathlib import Path
-from typing import Callable, Dict, Generator, Iterable, Optional, Sequence, Tuple, Union
+from typing import (
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import torch
 from cytoolz import compose_left
@@ -89,6 +99,8 @@ class StatelessSampler(torch.utils.data.Sampler, Dillable):
     :param max_duration: Maximum total number of audio seconds in a mini-batch (dynamic batch size).
     :param max_cuts: Maximum number of examples in a mini-batch (static batch size).
     :param num_buckets: If set, enables bucketing (each mini-batch has examples of a similar duration).
+    :param duration_bins: A list of floats (seconds); when provided, we'll skip the initial
+        estimation of bucket duration bins (useful to speed up the launch of experiments).
     :param quadratic_duration: If set, adds a penalty term for longer duration cuts.
         Works well with models that have quadratic time complexity to keep GPU utilization similar
         when using bucketing. Suggested values are between 30 and 45.
@@ -102,6 +114,7 @@ def __init__(
         max_duration: Optional[Seconds] = None,
         max_cuts: Optional[int] = None,
         num_buckets: Optional[int] = None,
+        duration_bins: Optional[List[Seconds]] = None,
         quadratic_duration: Optional[Seconds] = None,
     ) -> None:
         super().__init__(data_source=None)
@@ -146,6 +159,7 @@ def __init__(
         self.max_duration = max_duration
         self.max_cuts = max_cuts
         self.num_buckets = num_buckets
+        self.duration_bins = duration_bins
         self.quadratic_duration = quadratic_duration
         self.base_seed = base_seed
         assert any(
@@ -216,12 +230,13 @@ def _inner():
                 yield cut
                 n += 1
 
-        if self.num_buckets is not None and self.num_buckets > 1:
+        if self.num_buckets is not None or self.duration_bins is not None:
             inner_sampler = DynamicBucketingSampler(
                 _inner(),
                 max_duration=self.max_duration,
                 max_cuts=self.max_cuts,
                 num_buckets=self.num_buckets,
+                duration_bins=self.duration_bins,
                 shuffle=False,
                 drop_last=False,
                 quadratic_duration=self.quadratic_duration,
diff --git a/test/audio/test_audio_reads.py b/test/audio/test_audio_reads.py
index 2a4b9c7ed..d841ca9f9 100644
--- a/test/audio/test_audio_reads.py
+++ b/test/audio/test_audio_reads.py
@@ -2,6 +2,7 @@
 from io import BytesIO
 from pathlib import Path
 from tempfile import NamedTemporaryFile, TemporaryDirectory
+from unittest.mock import Mock
 
 import numpy as np
 import pytest
@@ -10,6 +11,7 @@
 
 import lhotse
 from lhotse import AudioSource, Recording
+from lhotse.audio import suppress_audio_loading_errors
 from lhotse.audio.backend import (
     info,
     read_opus_ffmpeg,
@@ -260,3 +262,27 @@ def test_set_audio_backend():
     )
     audio2 = recording.load_audio()
     np.testing.assert_array_almost_equal(audio1, audio2)
+
+
+def test_fault_tolerant_audio_network_exception():
+    def _mock_load_audio(*args, **kwargs):
+        raise ConnectionResetError()
+
+    source = Mock()
+    source.load_audio = _mock_load_audio
+    source.has_video = False
+
+    recording = Recording(
+        id="irrelevant",
+        sources=[source],
+        sampling_rate=16000,
+        num_samples=16000,
+        duration=1.0,
+        channel_ids=[0],
+    )
+
+    with pytest.raises(ConnectionResetError):
+        recording.load_audio()  # does raise
+
+    with suppress_audio_loading_errors(True):
+        recording.load_audio()  # is silently caught
diff --git a/test/dataset/sampling/test_dynamic_bucketing.py b/test/dataset/sampling/test_dynamic_bucketing.py
index e7d2db019..94bb40cc7 100644
--- a/test/dataset/sampling/test_dynamic_bucketing.py
+++ b/test/dataset/sampling/test_dynamic_bucketing.py
@@ -53,7 +53,7 @@ def test_dynamic_bucketing_drop_last_false():
     rng = random.Random(0)
 
     sampler = DynamicBucketer(
-        cuts, duration_bins=[2], max_duration=5, rng=rng, world_size=1
+        cuts, duration_bins=[1.5], max_duration=5, rng=rng, world_size=1
     )
     batches = [b for b in sampler]
     sampled_cuts = [c for b in batches for c in b]
@@ -90,7 +90,7 @@ def test_dynamic_bucketing_drop_last_true():
     rng = random.Random(0)
 
     sampler = DynamicBucketer(
-        cuts, duration_bins=[2], max_duration=5, rng=rng, drop_last=True, world_size=1
+        cuts, duration_bins=[1.5], max_duration=5, rng=rng, drop_last=True, world_size=1
     )
     batches = [b for b in sampler]
     sampled_cuts = [c for b in batches for c in b]
@@ -125,7 +125,7 @@ def test_dynamic_bucketing_sampler(concurrent):
             c.duration = 2
 
     sampler = DynamicBucketingSampler(
-        cuts, max_duration=5, num_buckets=2, seed=0, concurrent=concurrent
+        cuts, max_duration=5, duration_bins=[1.5], seed=0, concurrent=concurrent
     )
     batches = [b for b in sampler]
     sampled_cuts = [c for b in batches for c in b]
@@ -231,7 +231,9 @@ def test_dynamic_bucketing_sampler_too_small_data_can_be_sampled():
             c.duration = 2
 
     # 10 cuts with 30s total are not enough to satisfy max_duration of 100 with 2 buckets
-    sampler = DynamicBucketingSampler(cuts, max_duration=100, num_buckets=2, seed=0)
+    sampler = DynamicBucketingSampler(
+        cuts, max_duration=100, duration_bins=[1.5], seed=0
+    )
     batches = [b for b in sampler]
     sampled_cuts = [c for b in batches for c in b]
 
@@ -249,6 +251,35 @@ def test_dynamic_bucketing_sampler_too_small_data_can_be_sampled():
         assert len(b) == 5
 
 
+def test_dynamic_bucketing_sampler_much_less_data_than_ddp_ranks():
+    world_size = 128
+    orig_cut = dummy_cut(0)
+    cuts = CutSet([orig_cut])
+    samplers = [
+        DynamicBucketingSampler(
+            cuts,
+            max_duration=2000.0,
+            duration_bins=[1.5, 3.7, 15.2, 27.9, 40.0],
+            drop_last=False,
+            concurrent=False,
+            world_size=world_size,
+            rank=i,
+        )
+        for i in range(world_size)
+    ]
+    # None of the ranks drops anything; all of them return the one cut we have.
+    for sampler in samplers:
+        (batch,) = [b for b in sampler]
+        assert len(batch) == 1
+        (sampled_cut,) = batch
+        assert (
+            sampled_cut.id[: len(orig_cut.id)] == orig_cut.id
+        )  # same stem, possibly added '_dupX' suffix
+        # otherwise the cuts are identical
+        sampled_cut.id = orig_cut.id
+        assert sampled_cut == orig_cut
+
+
 def test_dynamic_bucketing_sampler_too_small_data_drop_last_true_results_in_no_batches():
     cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
     for i, c in enumerate(cuts):
@@ -337,7 +368,9 @@ def test_dynamic_bucketing_sampler_cut_pairs():
         else:
             c.duration = 2
 
-    sampler = DynamicBucketingSampler(cuts, cuts, max_duration=5, num_buckets=2, seed=0)
+    sampler = DynamicBucketingSampler(
+        cuts, cuts, max_duration=5, duration_bins=[1.5], seed=0
+    )
     batches = [b for b in sampler]
     sampled_cut_pairs = [cut_pair for b in batches for cut_pair in zip(*b)]
     source_cuts = [sc for sc, tc in sampled_cut_pairs]
@@ -473,7 +506,7 @@ def test_dynamic_bucketing_sampler_cut_triplets():
             c.duration = 2
 
     sampler = DynamicBucketingSampler(
-        cuts, cuts, cuts, max_duration=5, num_buckets=2, seed=0
+        cuts, cuts, cuts, max_duration=5, duration_bins=[1.5], seed=0
     )
     batches = [b for b in sampler]
     sampled_cut_triplets = [cut_triplet for b in batches for cut_triplet in zip(*b)]
@@ -542,7 +575,7 @@ def test_dynamic_bucketing_quadratic_duration():
 
     # quadratic_duration=30
     sampler = DynamicBucketingSampler(
-        cuts, max_duration=61, num_buckets=2, seed=0, quadratic_duration=30
+        cuts, max_duration=61, duration_bins=[10.0], seed=0, quadratic_duration=30
     )
     batches = [b for b in sampler]
     assert len(batches) == 6
@@ -556,7 +589,7 @@ def test_dynamic_bucketing_quadratic_duration():
 
     # quadratic_duration=None (disabled)
    sampler = DynamicBucketingSampler(
-        cuts, max_duration=61, num_buckets=2, seed=0, quadratic_duration=None
+        cuts, max_duration=61, duration_bins=[10.0], seed=0, quadratic_duration=None
     )
     batches = [b for b in sampler]
     assert len(batches) == 4
@@ -731,3 +764,31 @@ def test_dynamic_bucketing_sampler_fixed_batch_constraint():
 
     assert len(batches[7]) == 1
     assert sum(c.duration for c in batches[7]) == 1
+
+
+def test_select_bucket_includes_upper_bound_in_bin():
+    constraint = FixedBucketBatchSizeConstraint(
+        max_seq_len_buckets=[2.0, 4.0], batch_sizes=[2, 1]
+    )
+
+    # within bounds
+    assert (
+        constraint.select_bucket(constraint.max_seq_len_buckets, example_len=1.0) == 0
+    )
+    assert (
+        constraint.select_bucket(constraint.max_seq_len_buckets, example_len=2.0) == 0
+    )
+    assert (
+        constraint.select_bucket(constraint.max_seq_len_buckets, example_len=3.0) == 1
+    )
+    assert (
+        constraint.select_bucket(constraint.max_seq_len_buckets, example_len=4.0) == 1
+    )
+    constraint.add(dummy_cut(0, duration=4.0))  # can add max duration without exception
+
+    # out of bounds
+    assert (
+        constraint.select_bucket(constraint.max_seq_len_buckets, example_len=5.0) == 2
+    )
+    with pytest.raises(AssertionError):
+        constraint.add(dummy_cut(0, duration=5.0))
diff --git a/test/dataset/sampling/test_sampling.py b/test/dataset/sampling/test_sampling.py
index 74736794e..686b472b9 100644
--- a/test/dataset/sampling/test_sampling.py
+++ b/test/dataset/sampling/test_sampling.py
@@ -1232,3 +1232,27 @@ def test_sampler_map():
     b = batches[1]
     assert len(b) == 1
     assert b[0].duration == 5.0
+
+
+def test_sampler_much_less_data_than_ddp_ranks():
+    world_size = 128
+    orig_cut = dummy_cut(0)
+    cuts = CutSet([orig_cut])
+
+    samplers = [
+        DynamicCutSampler(
+            cuts, max_cuts=256, drop_last=False, world_size=world_size, rank=i
+        )
+        for i in range(world_size)
+    ]
+    # None of the ranks drops anything; all of them return the one cut we have.
+    for sampler in samplers:
+        (batch,) = [b for b in sampler]
+        assert len(batch) == 1
+        (sampled_cut,) = batch
+        assert (
+            sampled_cut.id[: len(orig_cut.id)] == orig_cut.id
+        )  # same stem, possibly added '_dupX' suffix
+        # otherwise the cuts are identical
+        sampled_cut.id = orig_cut.id
+        assert sampled_cut == orig_cut
diff --git a/test/dataset/sampling/test_stateless_sampler.py b/test/dataset/sampling/test_stateless_sampler.py
index 416e7e2b2..f0e32cbc3 100644
--- a/test/dataset/sampling/test_stateless_sampler.py
+++ b/test/dataset/sampling/test_stateless_sampler.py
@@ -189,7 +189,11 @@ def test_stateless_sampler_in_dataloader_with_iterable_dataset(
 def test_stateless_sampler_bucketing(cuts_files: Tuple[Path]):
     index_path = cuts_files[0].parent / "cuts.idx"
     sampler = StatelessSampler(
-        cuts_files, index_path=index_path, num_buckets=2, max_duration=4, base_seed=0
+        cuts_files,
+        index_path=index_path,
+        duration_bins=[1.5],
+        max_duration=4,
+        base_seed=0,
     )
 
     for idx, batch in enumerate(sampler):
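
Note on the bisect_right -> bisect_left change in select_bucket (an illustration, not part of the patch): it makes each bucket's upper bound inclusive, so an example whose length equals a bin edge stays in that bin instead of spilling into the next one. A minimal check with the standard library, using the same bins as test_select_bucket_includes_upper_bound_in_bin:

    from bisect import bisect_left, bisect_right

    bins = [2.0, 4.0]  # bucket upper bounds in seconds
    assert bisect_right(bins, 2.0) == 1  # old behavior: a boundary example overflows its bucket
    assert bisect_left(bins, 2.0) == 0  # new behavior: the upper bound belongs to its bucket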
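
Usage sketch for the new duration_bins argument (illustrative; the manifest path below is a hypothetical placeholder): passing precomputed bins to DynamicBucketingSampler, or to StatelessSampler which forwards them, skips the startup pass that estimates bucket boundaries from the data:

    from lhotse import CutSet
    from lhotse.dataset.sampling.dynamic_bucketing import DynamicBucketingSampler

    cuts = CutSet.from_file("cuts.jsonl.gz")  # hypothetical path
    sampler = DynamicBucketingSampler(
        cuts,
        max_duration=200.0,  # total seconds of audio per mini-batch
        duration_bins=[1.5, 3.7, 15.2, 27.9, 40.0],  # precomputed; skips estimation at startup
        shuffle=True,
        seed=0,
    )
    for batch in sampler:
        ...  # each mini-batch contains cuts of similar duration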