Adding initial code for rnnlm lattice rescoring (with Mikolov's tool).
chenguoguo committed Sep 9, 2015
1 parent 054e3ff commit 6042649
Showing 8 changed files with 3,034 additions and 2 deletions.
95 changes: 95 additions & 0 deletions egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh
@@ -0,0 +1,95 @@
#!/bin/bash

# Copyright 2015 Guoguo Chen
# Apache 2.0

# This script rescores lattices with RNNLM.
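#
# The rescoring happens in two passes: the old LM scores are first removed
# by composing the lattices with the old LM using a negative LM scale, and
# the RNNLM scores are then added in by composing with an on-demand FST
# wrapper around the RNNLM (see latbin/lattice-lmrescore-rnnlm.cc).
#
# Example invocation (paths are illustrative):
#   steps/lmrescore_rnnlm_lat.sh --weight 0.75 --cmd "run.pl" \
#     rnnlm data/lang_tg data/test exp/tri3/test_tg exp/tri3/test_rnnlm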

# Begin configuration section.
cmd=run.pl
skip_scoring=false
max_ngram_order=4
N=10
inv_acwt=12
weight=1.0 # Interpolation weight for RNNLM; the combined LM cost is (1 - weight) * old LM + weight * RNNLM.
# End configuration section.

echo "$0 $@" # Print the command line for logging

. ./utils/parse_options.sh

if [ $# != 5 ]; then
echo "Does language model rescoring of lattices (remove old LM, add new LM)"
echo "with RNNLM."
echo ""
echo "Usage: $0 [options] <rnnlm-dir> <old-lang-dir> \\"
echo " <data-dir> <input-decode-dir> <output-decode-dir>"
echo " e.g.: $0 rnnlm data/lang_tg data/test \\"
echo " exp/tri3/test_tg exp/tri3/test_rnnlm"
echo "options: [--cmd (run.pl|queue.pl [queue opts])]"
exit 1;
fi

[ -f path.sh ] && . ./path.sh;

rnnlm_dir=$1
oldlang=$2
data=$3
indir=$4
outdir=$5

oldlm=$oldlang/G.fst
if [ -f $oldlang/G.carpa ]; then
oldlm=$oldlang/G.carpa
elif [ ! -f $oldlm ]; then
echo "$0: expecting either $oldlang/G.fst or $oldlang/G.carpa to exist" &&\
exit 1;
fi

[ ! -f $oldlm ] && echo "$0: Missing file $oldlm" && exit 1;
[ ! -f $rnnlm_dir/rnnlm ] && echo "$0: Missing file $rnnlm_dir/rnnlm" && exit 1;
[ ! -f $rnnlm_dir/unk.probs ] &&\
echo "$0: Missing file $rnnlm_dir/unk.probs" && exit 1;
[ ! -f $oldlang/words.txt ] &&\
echo "$0: Missing file $oldlang/words.txt" && exit 1;
[ $(echo "$weight < 0 || $weight > 1" | bc) -eq 1 ] &&\
echo "$0: Interpolation weight should be in the range of [0, 1]" && exit 1;
! ls $indir/lat.*.gz >/dev/null &&\
echo "$0: No lattices in input directory $indir" && exit 1;

oldlm_command="fstproject --project_output=true $oldlm |"
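# (The projection keeps only the word labels of the old LM, so that it can
# be composed directly with the lattices, which have words as their labels.)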

acwt=`perl -e "print (1.0/$inv_acwt);"`

mkdir -p $outdir/log
nj=`cat $indir/num_jobs` || exit 1;
cp $indir/num_jobs $outdir

oldlm_weight=$(echo "-1 * $weight" | bc -l)
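# The first command in each pipeline subtracts the old LM scores (note the
# negative LM scale); the second adds in the RNNLM scores. G.fst is handled
# by lattice-lmrescore, and const-arpa LMs by lattice-lmrescore-const-arpa.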
if [ "$oldlm" == "$oldlang/G.fst" ]; then
$cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \
lattice-lmrescore --lm-scale=$oldlm_weight \
"ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm_command" ark:- \| \
lattice-lmrescore-rnnlm --lm-scale=$weight \
--max-ngram-order=$max_ngram_order ark:$rnnlm_dir/unk.probs \
$oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \
"ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1;
else
$cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \
lattice-lmrescore-const-arpa --lm-scale=$oldlm_weight \
"ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm_command" ark:- \| \
lattice-lmrescore-rnnlm --lm-scale=$weight \
--max-ngram-order=$max_ngram_order ark:$rnnlm_dir/unk.probs \
$oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \
"ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1;
fi

if ! $skip_scoring ; then
err_msg="Not scoring because local/score.sh does not exist or is not executable."
[ ! -x local/score.sh ] && echo $err_msg && exit 1;
local/score.sh --cmd "$cmd" $data $oldlang $outdir
else
echo "Not scoring because --skip-scoring was requested."
fi

exit 0;
2 changes: 1 addition & 1 deletion src/latbin/Makefile
@@ -20,7 +20,7 @@ BINFILES = lattice-best-path lattice-prune lattice-equivalent lattice-to-nbest \
lattice-minimize lattice-limit-depth lattice-depth-per-frame \
lattice-confidence lattice-determinize-phone-pruned \
lattice-determinize-phone-pruned-parallel lattice-expand-ngram \
lattice-lmrescore-const-arpa nbest-to-prons
lattice-lmrescore-const-arpa lattice-lmrescore-rnnlm nbest-to-prons

OBJFILES =

142 changes: 142 additions & 0 deletions src/latbin/lattice-lmrescore-rnnlm.cc
@@ -0,0 +1,142 @@
// latbin/lattice-lmrescore-rnnlm.cc

// Copyright 2015 Guoguo Chen

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.


#include "base/kaldi-common.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
#include "lm/kaldi-rnnlm.h"
#include "lm/mikolov-rnnlm-lib.h"
#include "util/common-utils.h"

int main(int argc, char *argv[]) {
try {
using namespace kaldi;
typedef kaldi::int32 int32;
typedef kaldi::int64 int64;

const char *usage =
"Rescores lattice with rnnlm. The LM will be wrapped into the\n"
"DeterministicOnDemandFst interface and the rescoring is done by\n"
"composing with the wrapped LM using a special type of composition\n"
"algorithm. Determinization will be applied on the composed lattice.\n"
"\n"
"Usage: lattice-lmrescore-rnnlm [options] [unk_prob_rspecifier] \\\n"
" <word-symbol-table-rxfilename> <lattice-rspecifier> \\\n"
" <rnnlm-rxfilename> <lattice-wspecifier>\n"
" e.g.: lattice-lmrescore-rnnlm --lm-scale=-1.0 words.txt \\\n"
" ark:in.lats rnnlm ark:out.lats\n";

ParseOptions po(usage);
int32 max_ngram_order = 3;
BaseFloat lm_scale = 1.0;

po.Register("lm-scale", &lm_scale, "Scaling factor for language model "
"costs; frequently 1.0 or -1.0");
po.Register("max-ngram-order", &max_ngram_order, "If positive, limit the "
"rnnlm context to the given number of words; -1 means we are "
"not going to limit it.");

KaldiRnnlmWrapperOpts opts;
opts.Register(&po);

po.Read(argc, argv);

if (po.NumArgs() != 4 && po.NumArgs() != 5) {
po.PrintUsage();
exit(1);
}

std::string lats_rspecifier, unk_prob_rspecifier,
word_symbols_rxfilename, rnnlm_rxfilename, lats_wspecifier;
if (po.NumArgs() == 4) {
unk_prob_rspecifier = "";
word_symbols_rxfilename = po.GetArg(1);
lats_rspecifier = po.GetArg(2);
rnnlm_rxfilename = po.GetArg(3);
lats_wspecifier = po.GetArg(4);
} else if (po.NumArgs() == 5) {
unk_prob_rspecifier = po.GetArg(1);
word_symbols_rxfilename = po.GetArg(2);
lats_rspecifier = po.GetArg(3);
rnnlm_rxfilename = po.GetArg(4);
lats_wspecifier = po.GetArg(5);
}

// Reads the language model.
KaldiRnnlmWrapper rnnlm(opts, unk_prob_rspecifier,
word_symbols_rxfilename, rnnlm_rxfilename);

// Reads and writes as compact lattice.
SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier);
CompactLatticeWriter compact_lattice_writer(lats_wspecifier);

int32 n_done = 0, n_fail = 0;
for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) {
std::string key = compact_lattice_reader.Key();
CompactLattice clat = compact_lattice_reader.Value();
compact_lattice_reader.FreeCurrent();

if (lm_scale != 0.0) {
// Before composing with the LM FST, we scale the lattice weights
// by the inverse of "lm_scale". We'll later scale by "lm_scale".
// We do it this way so we can determinize and it will give the
// right effect (taking the "best path" through the LM) regardless
// of the sign of lm_scale.
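// Net effect on each path that survives determinization: its graph cost
// w becomes w + lm_scale * w_rnnlm.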
fst::ScaleLattice(fst::GraphLatticeScale(1.0 / lm_scale), &clat);
ArcSort(&clat, fst::OLabelCompare<CompactLatticeArc>());

// Wraps the rnnlm into an FST. We re-create it for each lattice to prevent
// memory usage from increasing over time.
RnnlmDeterministicFst rnnlm_fst(max_ngram_order, &rnnlm);

// Composes lattice with language model.
CompactLattice composed_clat;
ComposeCompactLatticeDeterministic(clat, &rnnlm_fst, &composed_clat);

// Determinizes the composed lattice.
Lattice composed_lat;
ConvertLattice(composed_clat, &composed_lat);
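// Inverts so that the word labels are on the input side, since
// DeterminizeLattice determinizes on the input labels.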
Invert(&composed_lat);
CompactLattice determinized_clat;
DeterminizeLattice(composed_lat, &determinized_clat);
fst::ScaleLattice(fst::GraphLatticeScale(lm_scale), &determinized_clat);
if (determinized_clat.Start() == fst::kNoStateId) {
KALDI_WARN << "Empty lattice for utterance " << key
<< " (incompatible LM?)";
n_fail++;
} else {
compact_lattice_writer.Write(key, determinized_clat);
n_done++;
}
} else {
// Zero scale so nothing to do.
n_done++;
compact_lattice_writer.Write(key, clat);
}
}

KALDI_LOG << "Done " << n_done << " lattices, failed for " << n_fail;
return (n_done != 0 ? 0 : 1);
} catch(const std::exception &e) {
std::cerr << e.what();
return -1;
}
}
3 changes: 2 additions & 1 deletion src/lm/Makefile
@@ -12,7 +12,8 @@ include ../kaldi.mk

TESTFILES = lm-lib-test

OBJFILES = const-arpa-lm.o kaldi-lmtable.o kaldi-lm.o
OBJFILES = const-arpa-lm.o kaldi-lmtable.o kaldi-lm.o kaldi-rnnlm.o \
mikolov-rnnlm-lib.o

TESTOUTPUTS = composed.fst output.fst output1.fst output2.fst

140 changes: 140 additions & 0 deletions src/lm/kaldi-rnnlm.cc
@@ -0,0 +1,140 @@
// lm/kaldi-rnnlm.cc

// Copyright 2015 Guoguo Chen

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include <utility>

#include "lm/kaldi-rnnlm.h"
#include "util/stl-utils.h"
#include "util/text-utils.h"

namespace kaldi {

KaldiRnnlmWrapper::KaldiRnnlmWrapper(
const KaldiRnnlmWrapperOpts &opts,
const std::string &unk_prob_rspecifier,
const std::string &word_symbol_table_rxfilename,
const std::string &rnnlm_rxfilename) {
rnnlm_.setRnnLMFile(rnnlm_rxfilename);
rnnlm_.setRandSeed(1);
rnnlm_.setUnkSym(opts.unk_symbol);
rnnlm_.setUnkPenalty(unk_prob_rspecifier);
rnnlm_.restoreNet();

// Reads symbol table.
fst::SymbolTable *word_symbols = NULL;
if (!(word_symbols =
fst::SymbolTable::ReadText(word_symbol_table_rxfilename))) {
KALDI_ERR << "Could not read symbol table from file "
<< word_symbol_table_rxfilename;
}
label_to_word_.resize(word_symbols->NumSymbols() + 1);
for (int32 i = 0; i < label_to_word_.size() - 1; ++i) {
label_to_word_[i] = word_symbols->Find(i);
if (label_to_word_[i] == "") {
KALDI_ERR << "Could not find word for integer " << i << " in the word "
<< "symbol table; mismatched symbol table, or discontinuous "
<< "integers in your symbol table?";
}
}
label_to_word_[label_to_word_.size() - 1] = opts.eos_symbol;
eos_ = label_to_word_.size() - 1;
delete word_symbols;  // The mapping has been copied; free the symbol table.
}

BaseFloat KaldiRnnlmWrapper::GetLogProb(
int32 word, const std::vector<int32> &wseq,
const std::vector<BaseFloat> &context_in,
std::vector<BaseFloat> *context_out) {

std::vector<std::string> wseq_symbols(wseq.size());
for (int32 i = 0; i < wseq_symbols.size(); ++i) {
KALDI_ASSERT(wseq[i] < label_to_word_.size());
wseq_symbols[i] = label_to_word_[wseq[i]];
}

return rnnlm_.computeConditionalLogprob(label_to_word_[word], wseq_symbols,
context_in, context_out);
}

RnnlmDeterministicFst::RnnlmDeterministicFst(int32 max_ngram_order,
KaldiRnnlmWrapper *rnnlm) {
KALDI_ASSERT(rnnlm != NULL);
max_ngram_order_ = max_ngram_order;
rnnlm_ = rnnlm;

// Uses empty history for <s>.
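// Mikolov's rnnlm resets the hidden layer activations to 1.0 at the
// beginning of a sentence, so the <s> context is a vector of ones.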
std::vector<Label> bos;
std::vector<BaseFloat> bos_context(rnnlm->GetHiddenLayerSize(), 1.0f);
state_to_wseq_.push_back(bos);
state_to_context_.push_back(bos_context);
wseq_to_state_[bos] = 0;
start_state_ = 0;
}

fst::StdArc::Weight RnnlmDeterministicFst::Final(StateId s) {
// At this point, we should have created the state.
KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());

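// The final cost is the negated log-probability of the end-of-sentence
// symbol given this state's word history.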
std::vector<Label> wseq = state_to_wseq_[s];
BaseFloat logprob = rnnlm_->GetLogProb(rnnlm_->GetEos(), wseq,
state_to_context_[s], NULL);
return Weight(-logprob);
}

bool RnnlmDeterministicFst::GetArc(StateId s, Label ilabel, fst::StdArc *oarc) {
// At this point, we should have created the state.
KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());

std::vector<Label> wseq = state_to_wseq_[s];
std::vector<BaseFloat> new_context(rnnlm_->GetHiddenLayerSize());
BaseFloat logprob = rnnlm_->GetLogProb(ilabel, wseq,
state_to_context_[s], &new_context);

wseq.push_back(ilabel);
if (max_ngram_order_ > 0) {
while (wseq.size() >= max_ngram_order_) {
// The history state keeps at most <max_ngram_order_> - 1 words.
wseq.erase(wseq.begin(), wseq.begin() + 1);
}
}
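// After the truncation above, histories that share their most recent
// <max_ngram_order_> - 1 words map to the same FST state, which keeps the
// on-demand FST finite at the cost of approximating the full RNNLM history.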

std::pair<const std::vector<Label>, StateId> wseq_state_pair(
wseq, static_cast<Label>(state_to_wseq_.size()));

// Attempts to insert the current <wseq_state_pair>. If the pair already
// exists, then insert() returns false.
typedef MapType::iterator IterType;
std::pair<IterType, bool> result = wseq_to_state_.insert(wseq_state_pair);

// If the pair was just inserted, then also add it to <state_to_wseq_> and
// <state_to_context_>.
if (result.second == true) {
state_to_wseq_.push_back(wseq);
state_to_context_.push_back(new_context);
}

// Creates the arc.
oarc->ilabel = ilabel;
oarc->olabel = ilabel;
oarc->nextstate = result.first->second;
oarc->weight = Weight(-logprob);

return true;
}

} // namespace kaldi