From 6042649b394ec82dc4f86b92137bc0bd9e8f80e1 Mon Sep 17 00:00:00 2001 From: Guoguo Chen Date: Wed, 9 Sep 2015 00:27:42 -0400 Subject: [PATCH] Adding initial code for rnnlm lattice rescoring (with Mikolov's tool). --- egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh | 95 + src/latbin/Makefile | 2 +- src/latbin/lattice-lmrescore-rnnlm.cc | 142 ++ src/lm/Makefile | 3 +- src/lm/kaldi-rnnlm.cc | 140 ++ src/lm/kaldi-rnnlm.h | 104 ++ src/lm/mikolov-rnnlm-lib.cc | 2187 +++++++++++++++++++++++ src/lm/mikolov-rnnlm-lib.h | 363 ++++ 8 files changed, 3034 insertions(+), 2 deletions(-) create mode 100755 egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh create mode 100644 src/latbin/lattice-lmrescore-rnnlm.cc create mode 100644 src/lm/kaldi-rnnlm.cc create mode 100644 src/lm/kaldi-rnnlm.h create mode 100644 src/lm/mikolov-rnnlm-lib.cc create mode 100644 src/lm/mikolov-rnnlm-lib.h diff --git a/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh b/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh new file mode 100755 index 00000000000..46186544b88 --- /dev/null +++ b/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh @@ -0,0 +1,95 @@ +#!/bin/bash + +# Copyright 2015 Guoguo Chen +# Apache 2.0 + +# This script rescores lattices with RNNLM. + +# Begin configuration section. +cmd=run.pl +skip_scoring=false +max_ngram_order=4 +N=10 +inv_acwt=12 +weight=1.0 # Interpolation weight for RNNLM. +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +. ./utils/parse_options.sh + +if [ $# != 5 ]; then + echo "Does language model rescoring of lattices (remove old LM, add new LM)" + echo "with RNNLM." + echo "" + echo "Usage: $0 [options] <rnnlm-dir> <old-lang-dir> <data-dir> \\" + echo " <input-decode-dir> <output-decode-dir>" + echo " e.g.: $0 rnnlm data/lang_tg data/test \\" + echo " exp/tri3/test_tg exp/tri3/test_rnnlm" + echo "options: [--cmd (run.pl|queue.pl [queue opts])]" + exit 1; +fi + +[ -f path.sh ] && . ./path.sh; + +rnnlm_dir=$1 +oldlang=$2 +data=$3 +indir=$4 +outdir=$5 + +oldlm=$oldlang/G.fst +if [ -f $oldlang/G.carpa ]; then + oldlm=$oldlang/G.carpa +elif [ ! 
-f $oldlm ]; then + echo "$0: expecting either $oldlang/G.fst or $oldlang/G.carpa to exist" &&\ + exit 1; +fi + +[ ! -f $oldlm ] && echo "$0: Missing file $oldlm" && exit 1; +[ ! -f $rnnlm_dir/rnnlm ] && echo "$0: Missing file $rnnlm_dir/rnnlm" && exit 1; +[ ! -f $rnnlm_dir/unk.probs ] &&\ + echo "$0: Missing file $rnnlm_dir/unk.probs" && exit 1; +[ ! -f $oldlang/words.txt ] &&\ + echo "$0: Missing file $oldlang/words.txt" && exit 1; +[ $(echo "$weight < 0 || $weight > 1" | bc) -eq 1 ] &&\ + echo "$0: Interpolation weight should be in the range of [0, 1]" && exit 1; +! ls $indir/lat.*.gz >/dev/null &&\ + echo "$0: No lattices input directory $indir" && exit 1; + +oldlm_command="fstproject --project_output=true $oldlm |" # Only used in the G.fst branch; G.carpa is read directly. + +acwt=`perl -e "print (1.0/$inv_acwt);"` + +mkdir -p $outdir/log +nj=`cat $indir/num_jobs` || exit 1; +cp $indir/num_jobs $outdir + +oldlm_weight=$(echo "-1 * $weight" | bc -l) +if [ "$oldlm" == "$oldlang/G.fst" ]; then + $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \ + lattice-lmrescore --lm-scale=$oldlm_weight \ + "ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm_command" ark:- \| \ + lattice-lmrescore-rnnlm --lm-scale=$weight \ + --max-ngram-order=$max_ngram_order ark:$rnnlm_dir/unk.probs \ + $oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \ + "ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1; +else + $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \ + lattice-lmrescore-const-arpa --lm-scale=$oldlm_weight \ + "ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm" ark:- \| \ + lattice-lmrescore-rnnlm --lm-scale=$weight \ + --max-ngram-order=$max_ngram_order ark:$rnnlm_dir/unk.probs \ + $oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \ + "ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1; +fi + +if ! $skip_scoring ; then + err_msg="Not scoring because local/score.sh does not exist or not executable." + [ ! -x local/score.sh ] && echo $err_msg && exit 1; + local/score.sh --cmd "$cmd" $data $oldlang $outdir +else + echo "Not scoring because requested so..." 
+fi + +exit 0; diff --git a/src/latbin/Makefile b/src/latbin/Makefile index a3f05621a68..f1633978fbf 100644 --- a/src/latbin/Makefile +++ b/src/latbin/Makefile @@ -20,7 +20,7 @@ BINFILES = lattice-best-path lattice-prune lattice-equivalent lattice-to-nbest \ lattice-minimize lattice-limit-depth lattice-depth-per-frame \ lattice-confidence lattice-determinize-phone-pruned \ lattice-determinize-phone-pruned-parallel lattice-expand-ngram \ - lattice-lmrescore-const-arpa nbest-to-prons + lattice-lmrescore-const-arpa lattice-lmrescore-rnnlm nbest-to-prons OBJFILES = diff --git a/src/latbin/lattice-lmrescore-rnnlm.cc b/src/latbin/lattice-lmrescore-rnnlm.cc new file mode 100644 index 00000000000..92a9b014297 --- /dev/null +++ b/src/latbin/lattice-lmrescore-rnnlm.cc @@ -0,0 +1,142 @@ +// latbin/lattice-lmrescore-rnnlm.cc + +// Copyright 2015 Guoguo Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "fstext/fstext-lib.h" +#include "lat/kaldi-lattice.h" +#include "lat/lattice-functions.h" +#include "lm/kaldi-rnnlm.h" +#include "lm/mikolov-rnnlm-lib.h" +#include "util/common-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Rescores lattice with rnnlm. 
The LM will be wrapped into the\n" + "DeterministicOnDemandFst interface and the rescoring is done by\n" + "composing with the wrapped LM using a special type of composition\n" + "algorithm. Determinization will be applied on the composed lattice.\n" + "\n" + "Usage: lattice-lmrescore-rnnlm [options] [unk_prob_rspecifier] \\\n" + " \\\n" + " \n" + " e.g.: lattice-lmrescore-rnnlm --lm-scale=-1.0 words.txt \\\n" + " ark:in.lats rnnlm ark:out.lats\n"; + + ParseOptions po(usage); + int32 max_ngram_order = 3; + BaseFloat lm_scale = 1.0; + + po.Register("lm-scale", &lm_scale, "Scaling factor for language model " + "costs; frequently 1.0 or -1.0"); + po.Register("max-ngram-order", &max_ngram_order, "If positive, limit the " + "rnnlm context to the given number, -1 means we are not going " + "to limit it."); + + KaldiRnnlmWrapperOpts opts; + opts.Register(&po); + + po.Read(argc, argv); + + if (po.NumArgs() != 4 && po.NumArgs() != 5) { + po.PrintUsage(); + exit(1); + } + + std::string lats_rspecifier, unk_prob_rspecifier, + word_symbols_rxfilename, rnnlm_rxfilename, lats_wspecifier; + if (po.NumArgs() == 4) { + unk_prob_rspecifier = ""; + word_symbols_rxfilename = po.GetArg(1); + lats_rspecifier = po.GetArg(2); + rnnlm_rxfilename = po.GetArg(3); + lats_wspecifier = po.GetArg(4); + } else if (po.NumArgs() == 5) { + unk_prob_rspecifier = po.GetArg(1); + word_symbols_rxfilename = po.GetArg(2); + lats_rspecifier = po.GetArg(3); + rnnlm_rxfilename = po.GetArg(4); + lats_wspecifier = po.GetArg(5); + } + + // Reads the language model. + KaldiRnnlmWrapper rnnlm(opts, unk_prob_rspecifier, + word_symbols_rxfilename, rnnlm_rxfilename); + + // Reads and writes as compact lattice. 
+ SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier); + CompactLatticeWriter compact_lattice_writer(lats_wspecifier); + + int32 n_done = 0, n_fail = 0; + for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) { + std::string key = compact_lattice_reader.Key(); + CompactLattice clat = compact_lattice_reader.Value(); + compact_lattice_reader.FreeCurrent(); + + if (lm_scale != 0.0) { + // Before composing with the LM FST, we scale the lattice weights + // by the inverse of "lm_scale". We'll later scale by "lm_scale". + // We do it this way so we can determinize and it will give the + // right effect (taking the "best path" through the LM) regardless + // of the sign of lm_scale. + fst::ScaleLattice(fst::GraphLatticeScale(1.0 / lm_scale), &clat); + ArcSort(&clat, fst::OLabelCompare()); + + // Wraps the rnnlm into FST. We re-create it for each lattice to prevent + // memory usage increasing with time. + RnnlmDeterministicFst rnnlm_fst(max_ngram_order, &rnnlm); + + // Composes lattice with language model. + CompactLattice composed_clat; + ComposeCompactLatticeDeterministic(clat, &rnnlm_fst, &composed_clat); + + // Determinizes the composed lattice. + Lattice composed_lat; + ConvertLattice(composed_clat, &composed_lat); + Invert(&composed_lat); + CompactLattice determinized_clat; + DeterminizeLattice(composed_lat, &determinized_clat); + fst::ScaleLattice(fst::GraphLatticeScale(lm_scale), &determinized_clat); + if (determinized_clat.Start() == fst::kNoStateId) { + KALDI_WARN << "Empty lattice for utterance " << key + << " (incompatible LM?)"; + n_fail++; + } else { + compact_lattice_writer.Write(key, determinized_clat); + n_done++; + } + } else { + // Zero scale so nothing to do. + n_done++; + compact_lattice_writer.Write(key, clat); + } + } + + KALDI_LOG << "Done " << n_done << " lattices, failed for " << n_fail; + return (n_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/lm/Makefile b/src/lm/Makefile index 5edc55a563a..ddda9576557 100644 --- a/src/lm/Makefile +++ b/src/lm/Makefile @@ -12,7 +12,8 @@ include ../kaldi.mk TESTFILES = lm-lib-test -OBJFILES = const-arpa-lm.o kaldi-lmtable.o kaldi-lm.o +OBJFILES = const-arpa-lm.o kaldi-lmtable.o kaldi-lm.o kaldi-rnnlm.o \ + mikolov-rnnlm-lib.o TESTOUTPUTS = composed.fst output.fst output1.fst output2.fst diff --git a/src/lm/kaldi-rnnlm.cc b/src/lm/kaldi-rnnlm.cc new file mode 100644 index 00000000000..e1fbcbdc08b --- /dev/null +++ b/src/lm/kaldi-rnnlm.cc @@ -0,0 +1,140 @@ +// lm/kaldi-rnnlm.cc + +// Copyright 2015 Guoguo Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "lm/kaldi-rnnlm.h" +#include "util/stl-utils.h" +#include "util/text-utils.h" + +namespace kaldi { + +KaldiRnnlmWrapper::KaldiRnnlmWrapper( + const KaldiRnnlmWrapperOpts &opts, + const std::string &unk_prob_rspecifier, + const std::string &word_symbol_table_rxfilename, + const std::string &rnnlm_rxfilename) { + rnnlm_.setRnnLMFile(rnnlm_rxfilename); + rnnlm_.setRandSeed(1); + rnnlm_.setUnkSym(opts.unk_symbol); + rnnlm_.setUnkPenalty(unk_prob_rspecifier); + rnnlm_.restoreNet(); + + // Reads symbol table. 
+ fst::SymbolTable *word_symbols = NULL; + if (!(word_symbols = + fst::SymbolTable::ReadText(word_symbol_table_rxfilename))) { + KALDI_ERR << "Could not read symbol table from file " + << word_symbol_table_rxfilename; + } + label_to_word_.resize(word_symbols->NumSymbols() + 1); + for (int32 i = 0; i < label_to_word_.size() - 1; ++i) { + label_to_word_[i] = word_symbols->Find(i); + if (label_to_word_[i] == "") { + KALDI_ERR << "Could not find word for integer " << i << " in the word " + << "symbol table, mismatched symbol table or you have discontinuous " + << "integers in your symbol table?"; + } + } + label_to_word_[label_to_word_.size() - 1] = opts.eos_symbol; + eos_ = label_to_word_.size() - 1; // NOTE(review): word_symbols is not deleted in this constructor. +} + +BaseFloat KaldiRnnlmWrapper::GetLogProb( + int32 word, const std::vector &wseq, + const std::vector &context_in, + std::vector *context_out) { + + std::vector wseq_symbols(wseq.size()); + for (int32 i = 0; i < wseq_symbols.size(); ++i) { + KALDI_ASSERT(wseq[i] < label_to_word_.size()); + wseq_symbols[i] = label_to_word_[wseq[i]]; + } + + return rnnlm_.computeConditionalLogprob(label_to_word_[word], wseq_symbols, + context_in, context_out); +} + +RnnlmDeterministicFst::RnnlmDeterministicFst(int32 max_ngram_order, + KaldiRnnlmWrapper *rnnlm) { + KALDI_ASSERT(rnnlm != NULL); + max_ngram_order_ = max_ngram_order; + rnnlm_ = rnnlm; + + // Uses empty history for . 
+ std::vector") {} + + void Register(OptionsItf *opts) { + opts->Register("unk-symbol", &unk_symbol, "Symbol for out-of-vocabulary " + "words in rnnlm."); + opts->Register("eos-symbol", &eos_symbol, "End of setence symbol in " + "rnnlm."); + } +}; + +class KaldiRnnlmWrapper { + public: + KaldiRnnlmWrapper(const KaldiRnnlmWrapperOpts &opts, + const std::string &unk_prob_rspecifier, + const std::string &word_symbol_table_rxfilename, + const std::string &rnnlm_rxfilename); + + int32 GetHiddenLayerSize() const { return rnnlm_.getHiddenLayerSize(); } + + int32 GetEos() const { return eos_; } + + BaseFloat GetLogProb(int32 word, const std::vector &wseq, + const std::vector &context_in, + std::vector *context_out); + + private: + CRnnLM rnnlm_; + std::vector label_to_word_; + int32 eos_; + + KALDI_DISALLOW_COPY_AND_ASSIGN(KaldiRnnlmWrapper); +}; + +class RnnlmDeterministicFst + : public fst::DeterministicOnDemandFst { + public: + typedef fst::StdArc::Weight Weight; + typedef fst::StdArc::StateId StateId; + typedef fst::StdArc::Label Label; + + // Does not take ownership. + RnnlmDeterministicFst(int32 max_ngram_order, KaldiRnnlmWrapper *rnnlm); + + // We cannot use "const" because the pure virtual function in the interface is + // not const. + virtual StateId Start() { return start_state_; } + + // We cannot use "const" because the pure virtual function in the interface is + // not const. + virtual Weight Final(StateId s); + + virtual bool GetArc(StateId s, Label ilabel, fst::StdArc* oarc); + + private: + typedef unordered_map, + StateId, VectorHasher