Skip to content

Commit

Permalink
[src] Issue error if file sample rate differs from feature .conf (kal…
Browse files Browse the repository at this point in the history
…di-asr#4648)

Change rolled to all feature extractors.
  • Loading branch information
jtrmal authored Dec 1, 2021
1 parent 5cd9c1e commit a92babf
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 9 deletions.
13 changes: 10 additions & 3 deletions src/cudafeatbin/compute-fbank-feats-cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "feat/wave-reader.h"
#include "cudamatrix/cu-matrix.h"
#include "cudamatrix/cu-vector.h"


int main(int argc, char *argv[]) {
try {
using namespace kaldi;
Expand Down Expand Up @@ -66,7 +68,7 @@ int main(int argc, char *argv[]) {
po.PrintUsage();
exit(1);
}

g_cuda_allocator.SetOptions(g_allocator_options);
CuDevice::Instantiate().SelectGpuId("yes");
CuDevice::Instantiate().AllowMultithreading();
Expand All @@ -76,7 +78,7 @@ int main(int argc, char *argv[]) {

std::string output_wspecifier = po.GetArg(2);

// Fbank is implemented via the MFCC code path
// Fbank is implemented via the MFCC code path.
CudaSpectralFeatures fbank(fbank_opts);

SequentialTableReader<WaveHolder> reader(wav_rspecifier);
Expand All @@ -88,7 +90,7 @@ int main(int argc, char *argv[]) {
"needed if the vtln-map option is used.");
RandomAccessBaseFloatReaderMapped vtln_map_reader(vtln_map_rspecifier,
utt2spk_rspecifier);

if (output_format == "kaldi") {
if (!kaldi_writer.Open(output_wspecifier))
KALDI_ERR << "Could not initialize output with wspecifier "
Expand All @@ -106,6 +108,11 @@ int main(int argc, char *argv[]) {
num_utts++;
std::string utt = reader.Key();
const WaveData &wave_data = reader.Value();
if (wave_data.SampFreq() != fbank_opts.frame_opts.samp_freq) {
KALDI_ERR << "File: " << utt << " has an mismatched sampling "
<< "rate (config= " << fbank_opts.frame_opts.samp_freq
<< " vs file=" << wave_data.SampFreq() << ".";
}
if (wave_data.Duration() < min_duration) {
KALDI_WARN << "File: " << utt << " is too short ("
<< wave_data.Duration() << " sec): producing no output.";
Expand Down
8 changes: 7 additions & 1 deletion src/cudafeatbin/compute-fbank-online-batched-cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ using namespace kaldi;

// This class stores data for input and output for this binary.
// We will read/write slices of this input/output in an online
// fasion.
// fashion.
struct UtteranceDataHandle {
std::string utt;
WaveData wave_data_in;
Expand Down Expand Up @@ -186,6 +186,12 @@ int main(int argc, char *argv[]) {
for (; !reader.Done(); reader.Next()) {
std::string utt = reader.Key();
WaveData &wave_data = reader.Value();
if (wave_data.SampFreq() != feature_opts.frame_opts.samp_freq) {
KALDI_ERR << "File: " << utt << " has an mismatched sampling "
<< "rate (config= " << feature_opts.frame_opts.samp_freq
<< " vs file=" << wave_data.SampFreq() << ".";
}

duration += wave_data.Duration();
data_handles.emplace_back(utt, wave_data, frame_opts, feat_dim);
}
Expand Down
11 changes: 9 additions & 2 deletions src/cudafeatbin/compute-mfcc-feats-cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "feat/wave-reader.h"
#include "cudamatrix/cu-matrix.h"
#include "cudamatrix/cu-vector.h"


int main(int argc, char *argv[]) {
try {
using namespace kaldi;
Expand Down Expand Up @@ -66,7 +68,7 @@ int main(int argc, char *argv[]) {
po.PrintUsage();
exit(1);
}

g_cuda_allocator.SetOptions(g_allocator_options);
CuDevice::Instantiate().SelectGpuId("yes");
CuDevice::Instantiate().AllowMultithreading();
Expand All @@ -87,7 +89,7 @@ int main(int argc, char *argv[]) {
"needed if the vtln-map option is used.");
RandomAccessBaseFloatReaderMapped vtln_map_reader(vtln_map_rspecifier,
utt2spk_rspecifier);

if (output_format == "kaldi") {
if (!kaldi_writer.Open(output_wspecifier))
KALDI_ERR << "Could not initialize output with wspecifier "
Expand All @@ -105,6 +107,11 @@ int main(int argc, char *argv[]) {
num_utts++;
std::string utt = reader.Key();
const WaveData &wave_data = reader.Value();
if (wave_data.SampFreq() != mfcc_opts.frame_opts.samp_freq) {
KALDI_ERR << "File: " << utt << " has an mismatched sampling "
<< "rate (config= " << mfcc_opts.frame_opts.samp_freq
<< " vs file=" << wave_data.SampFreq() << ".";
}
if (wave_data.Duration() < min_duration) {
KALDI_WARN << "File: " << utt << " is too short ("
<< wave_data.Duration() << " sec): producing no output.";
Expand Down
9 changes: 8 additions & 1 deletion src/cudafeatbin/compute-mfcc-online-batched-cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ using namespace kaldi;

// This class stores data for input and output for this binary.
// We will read/write slices of this input/output in an online
// fasion.
// fashion.
struct UtteranceDataHandle {
std::string utt;
WaveData wave_data_in;
Expand Down Expand Up @@ -186,6 +186,12 @@ int main(int argc, char *argv[]) {
for (; !reader.Done(); reader.Next()) {
std::string utt = reader.Key();
WaveData &wave_data = reader.Value();
if (wave_data.SampFreq() != feature_opts.frame_opts.samp_freq) {
KALDI_ERR << "File: " << utt << " has an mismatched sampling "
<< "rate (config= " << feature_opts.frame_opts.samp_freq
<< " vs file=" << wave_data.SampFreq() << ".";
}

duration += wave_data.Duration();
data_handles.emplace_back(utt, wave_data, frame_opts, feat_dim);
}
Expand Down Expand Up @@ -373,6 +379,7 @@ int main(int argc, char *argv[]) {
#if HAVE_CUDA == 1
cudaProfilerStop();
#endif

return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
Expand Down
4 changes: 2 additions & 2 deletions src/cudafeatbin/compute-online-feats-batched-cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ using namespace kaldi;

// This class stores data for input and output for this binary.
// We will read/write slices of this input/output in an online
// fasion.
// fashion.
struct UtteranceDataHandle {
std::string utt;
WaveData wave_data_in;
Expand Down Expand Up @@ -205,7 +205,7 @@ int main(int argc, char *argv[]) {

// This binary is pipelined to allow concurrent memory copies and compute.
// State exists for each pipeline and successive chunks go to different
// pipelines in a modular fasion. The calling thread will synchronize with
// pipelines in a modular fashion. The calling thread will synchronize with
// a pipeline prior to launching work in that pipeline. 2 should be enough
// to get concurrency on current hardware.
const int num_pipelines = 3;
Expand Down

0 comments on commit a92babf

Please sign in to comment.