Subsume tune_main into interpolate
kpu committed Feb 20, 2016
1 parent 733ae5a commit 4122e98
Showing 5 changed files with 63 additions and 69 deletions.
10 changes: 6 additions & 4 deletions lm/interpolate/Jamfile
@@ -15,11 +15,13 @@ obj tune_instances_test.o : tune_instances_test.cc ..//kenlm /top//boost_unit_te
run tune_instances_test.o tune_instances.o interp ..//kenlm ../common//common ../../util/stream//stream /top//boost_unit_test_framework : : tune_instances_data/toy0.1 ;
fakelib tuning : tune_instances.o tune_derivatives.cc tune_weights.cc interp ..//kenlm : <include>$(with-eigen) ;
unit-test tune_derivatives_test : tune_derivatives_test.cc tuning /top//boost_unit_test_framework : <include>$(with-eigen) ;
exe tune : tune_main.cc tuning /top//boost_program_options : <include>$(with-eigen) ;
explicit tune_instances.o tune_instances_test.o tune_instances_test tuning tune_derivatives_test tune ;
alias all_tuning : : [ check-target-builds tune_have_eigen_init_parallel.o "Eigen has Eigen::initParallel() required for log-linear tuning" : <source>tune_instances_test <source>tune <source>tune_derivatives_test ] ;

exe interpolate : interpolate_main.cc interp /top//boost_program_options ;
#Given weights, interpolation doesn't require Eigen. But it's pretty useless without that and this is also the main to tune weights.
exe interpolate : interpolate_main.cc tuning interp /top//boost_program_options : <include>$(with-eigen) ;
explicit tune_instances.o tune_instances_test.o tune_instances_test tuning tune_derivatives_test tune interpolate ;

alias all_eigen : : [ check-target-builds tune_have_eigen_init_parallel.o "Eigen has Eigen::initParallel() required for log-linear tuning" : <source>tune_instances_test <source>interpolate <source>tune_derivatives_test ] ;

exe streaming_example : ../builder//builder interp streaming_example_main.cc /top//boost_program_options ;

unit-test normalize_test : interp normalize_test.cc /top//boost_unit_test_framework ;
62 changes: 53 additions & 9 deletions lm/interpolate/interpolate_main.cc
@@ -1,43 +1,87 @@
#include "lm/common/model_buffer.hh"
#include "lm/common/size_option.hh"
#include "lm/interpolate/pipeline.hh"
#include "lm/interpolate/tune_instances.hh"
#include "lm/interpolate/tune_weights.hh"
#include "util/fixed_array.hh"
#include "util/usage.hh"

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas" // Older gcc doesn't have "-Wunused-local-typedefs" and complains.
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
#include <Eigen/Core>
#pragma GCC diagnostic pop

#include <boost/program_options.hpp>

#include <iostream>
#include <vector>

int main(int argc, char *argv[]) {
lm::interpolate::Config config;
lm::interpolate::Config pipe_config;
lm::interpolate::InstancesConfig instances_config;
std::vector<std::string> input_models;
std::string tuning_file;

namespace po = boost::program_options;
po::options_description options("Log-linear interpolation options");
options.add_options()
("help,h", po::bool_switch(), "Show this help message")
("lambda,w", po::value<std::vector<float> >(&config.lambdas)->multitoken()->required(), "Interpolation weights")
("model,m", po::value<std::vector<std::string> >(&input_models)->multitoken()->required(), "Models to interpolate")
("temp_prefix,T", po::value<std::string>(&config.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
("memory,S", lm::SizeOption(config.sort.total_memory, util::GuessPhysicalMemory() ? "50%" : "1G"), "Sorting memory")
("sort_block", lm::SizeOption(config.sort.buffer_size, "64M"), "Block size");
("lambda,w", po::value<std::vector<float> >(&pipe_config.lambdas)->multitoken(), "Interpolation weights")
("tuning,t", po::value<std::string>(&tuning_file), "File to tune on: a text file with one sentence per line")
("just_tune", po::bool_switch(), "Tune and print weights then quit")
("temp_prefix,T", po::value<std::string>(&pipe_config.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
("memory,S", lm::SizeOption(pipe_config.sort.total_memory, util::GuessPhysicalMemory() ? "50%" : "1G"), "Sorting memory: this is a very rough guide")
("sort_block", lm::SizeOption(pipe_config.sort.buffer_size, "64M"), "Block size");
po::variables_map vm;

po::store(po::parse_command_line(argc, argv, options), vm);
if (argc == 1 || vm["help"].as<bool>()) {
std::cerr << "Interpolate multiple models\n\n" << options << std::endl;
return 1;
}
po::notify(vm);
instances_config.sort = pipe_config.sort;
instances_config.lazy_memory = instances_config.sort.total_memory;
instances_config.model_read_chain_mem = instances_config.sort.buffer_size;

if (pipe_config.lambdas.empty() && tuning_file.empty()) {
std::cerr << "Provide a tuning file with -t xor weights with -w." << std::endl;
return 1;
}
if (!pipe_config.lambdas.empty() && !tuning_file.empty()) {
std::cerr << "Provide weights xor a tuning file, not both." << std::endl;
return 1;
}

if (!tuning_file.empty()) {
// Tune weights
std::vector<StringPiece> model_names;
for (std::vector<std::string>::const_iterator i = input_models.begin(); i != input_models.end(); ++i) {
model_names.push_back(*i);
}
lm::interpolate::TuneWeights(util::OpenReadOrThrow(tuning_file.c_str()), model_names, instances_config, pipe_config.lambdas);

std::cerr << "Final weights:";
std::ostream &to = vm["just_tune"].as<bool>() ? std::cout : std::cerr;
for (std::vector<float>::const_iterator i = pipe_config.lambdas.begin(); i != pipe_config.lambdas.end(); ++i) {
to << ' ' << *i;
}
to << std::endl;
}
if (vm["just_tune"].as<bool>()) {
return 0;
}

if (config.lambdas.size() != input_models.size()) {
std::cerr << "Number of models " << input_models.size() << " should match the number of weights" << config.lambdas.size() << "." << std::endl;
if (pipe_config.lambdas.size() != input_models.size()) {
std::cerr << "Number of models " << input_models.size() << " should match the number of weights" << pipe_config.lambdas.size() << "." << std::endl;
return 1;
}

util::FixedArray<lm::ModelBuffer> models(input_models.size());
for (std::size_t i = 0; i < input_models.size(); ++i) {
models.push_back(input_models[i]);
}
lm::interpolate::Pipeline(models, config, 1);
lm::interpolate::Pipeline(models, pipe_config, 1);
return 0;
}
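
For orientation on the merged binary: the options above make -w and -t mutually exclusive, and --just_tune stops after printing the tuned weights. A hypothetical invocation sketch (file names are placeholders; the Jamfile builds the target as interpolate, and Pipeline(models, pipe_config, 1) suggests the interpolated model is written to stdout):

# Weights already known: pass one -w value per -m model.
interpolate -m model0 model1 -w 0.6 0.4 >interpolated

# Tune weights on held-out text, then interpolate in the same run.
interpolate -m model0 model1 -t tune.txt >interpolated

# Tune only: print the weights (to stdout, since --just_tune is set) and exit.
interpolate -m model0 model1 -t tune.txt --just_tune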
52 changes: 0 additions & 52 deletions lm/interpolate/tune_main.cc

This file was deleted.

5 changes: 3 additions & 2 deletions lm/interpolate/tune_weights.cc
@@ -13,9 +13,9 @@
#include <iostream>

namespace lm { namespace interpolate {
void TuneWeights(int tune_file, const std::vector<StringPiece> &model_names, const InstancesConfig &config, Vector &weights) {
void TuneWeights(int tune_file, const std::vector<StringPiece> &model_names, const InstancesConfig &config, std::vector<float> &weights_out) {
Instances instances(tune_file, model_names, config);
weights = Vector::Constant(model_names.size(), 1.0 / model_names.size());
Vector weights = Vector::Constant(model_names.size(), 1.0 / model_names.size());
Vector gradient;
Matrix hessian;
for (std::size_t iteration = 0; iteration < 10 /*TODO fancy stopping criteria */; ++iteration) {
@@ -28,5 +28,6 @@ void TuneWeights(int tune_file, const std::vector<StringPiece> &model_names, con
// TODO: 1.0 step size was too big and it kept getting unstable. More math.
weights -= 0.7 * hessian.inverse() * gradient;
}
weights_out.assign(weights.data(), weights.data() + weights.size());
}
}} // namespaces
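
A note on the hunk above: the loop performs a damped Newton update of the weights. Writing the tuning objective as F (its gradient and Hessian are computed earlier in the loop, outside this hunk), the step is

\lambda \leftarrow \lambda - 0.7 \, H^{-1} \nabla F(\lambda)

with the 0.7 damping factor taken directly from the code; per the TODO, a full step of 1.0 was unstable.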
3 changes: 1 addition & 2 deletions lm/interpolate/tune_weights.hh
@@ -1,7 +1,6 @@
#ifndef LM_INTERPOLATE_TUNE_WEIGHTS_H
#define LM_INTERPOLATE_TUNE_WEIGHTS_H

#include "lm/interpolate/tune_matrix.hh"
#include "util/string_piece.hh"

#include <vector>
@@ -10,7 +9,7 @@ namespace lm { namespace interpolate {
class InstancesConfig;

// Run a tuning loop, producing weights as output.
void TuneWeights(int tune_file, const std::vector<StringPiece> &model_names, const InstancesConfig &config, Vector &weights);
void TuneWeights(int tune_file, const std::vector<StringPiece> &model_names, const InstancesConfig &config, std::vector<float> &weights);

}} // namespaces
#endif // LM_INTERPOLATE_TUNE_WEIGHTS_H
