Support weight in starspace (facebookresearch#76)
* local changes

* add support for weights

* [going to revert this]

* [new change: add a flag]

* save new flag

* Update README.md

* print new flag
ledw authored Nov 10, 2017
1 parent e180b8c commit eec8249
Showing 17 changed files with 162 additions and 90 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -263,6 +263,7 @@ use <a href="https://github.com/facebookresearch/Starspace/blob/master/examples/
 
 The following arguments are optional:
   -normalizeText      whether to run basic text preprocess for input files [1]
+  -useWeight          whether input file contains weights [0]
   -verbose            verbosity level [0]
   -debug              whether it's in debug mode [0]
   -thread             number of threads [10]
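For context: the parser changes later in this commit (see src/doc_parser.cpp below) read each token as word:weight when -useWeight is set, falling back to a weight of 1.0 when no ':' is present. A hypothetical weighted input line, assuming the default __label__ convention, could look like:

    good:2.0 movie plot:0.5 __label__positive

The weights and label here are illustrative, not taken from the commit.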
8 changes: 4 additions & 4 deletions src/data.cpp
@@ -131,7 +131,7 @@ void InternDataHandler::convert(
 }
 
 void InternDataHandler::getWordExamples(
-    const vector<int32_t>& doc,
+    const vector<Base>& doc,
     vector<ParseResults>& rslts) const {
 
   rslts.clear();
@@ -206,7 +206,7 @@ void InternDataHandler::getNextKExamples(int K, vector<ParseResults>& c) {
 
 // Randomly sample one example and randomly sample a label from this example
 // The result is usually used as negative samples in training
-void InternDataHandler::getRandomRHS(vector<int32_t>& results) const {
+void InternDataHandler::getRandomRHS(vector<Base>& results) const {
   assert(size_ > 0);
   results.clear();
   auto& ex = examples_[rand() % size_];
@@ -231,10 +231,10 @@ void InternDataHandler::save(std::ostream& out) {
   out << "data size : " << size_ << endl;
   for (auto& example : examples_) {
     out << "lhs : ";
-    for (auto t : example.LHSTokens) {out << t << ' ';}
+    for (auto t : example.LHSTokens) {out << t.first << ':' << t.second << ' ';}
     out << endl;
     out << "rhs : ";
-    for (auto t : example.RHSTokens) {out << t << ' ';}
+    for (auto t : example.RHSTokens) {out << t.first << ':' << t.second << ' ';}
     out << endl;
   }
 }
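Throughout the commit, token containers change from vector<int32_t> to vector<Base>. The usage above (t.first, t.second, and make_pair(wid, weight) later in doc_parser.cpp) implies Base is an (index, weight) pair. A minimal sketch of the new serialization under that assumption, with simplified stand-ins for the StarSpace types:

    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    // Simplified stand-ins for illustration; the real definitions live in
    // the StarSpace headers.
    using Base = std::pair<int32_t, float>;  // (dictionary id, weight)
    struct ParseResults {
      std::vector<Base> LHSTokens;
      std::vector<Base> RHSTokens;
    };

    // Mirrors the updated InternDataHandler::save loop: each token is now
    // written as id:weight instead of a bare id.
    void saveExample(std::ostream& out, const ParseResults& example) {
      out << "lhs : ";
      for (auto t : example.LHSTokens) { out << t.first << ':' << t.second << ' '; }
      out << '\n' << "rhs : ";
      for (auto t : example.RHSTokens) { out << t.first << ':' << t.second << ' '; }
      out << '\n';
    }

    int main() {
      ParseResults ex{{{3, 1.0f}, {7, 0.5f}}, {{42, 2.0f}}};
      saveExample(std::cout, ex);  // prints "lhs : 3:1 7:0.5" and "rhs : 42:2"
    }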
4 changes: 2 additions & 2 deletions src/data.h
@@ -26,14 +26,14 @@ class InternDataHandler {
 
   virtual void convert(const ParseResults& example, ParseResults& rslt) const;
 
-  virtual void getRandomRHS(std::vector<int32_t>& results) const;
+  virtual void getRandomRHS(std::vector<Base>& results) const;
 
   virtual void save(std::ostream& out);
 
   virtual void getWordExamples(int idx, std::vector<ParseResults>& rslt) const;
 
   void getWordExamples(
-      const std::vector<int32_t>& doc,
+      const std::vector<Base>& doc,
       std::vector<ParseResults>& rslt) const;
 
   void addExample(const ParseResults& example);
10 changes: 5 additions & 5 deletions src/doc_data.cpp
@@ -65,8 +65,8 @@ void LayerDataHandler::loadFromFile(
 }
 
 void LayerDataHandler::insert(
-    vector<int32_t>& rslt,
-    const vector<int32_t>& ex,
+    vector<Base>& rslt,
+    const vector<Base>& ex,
     float dropout) const {
 
   if (dropout < 1e-8) {
@@ -156,7 +156,7 @@ void LayerDataHandler::convert(
   }
 }
 
-void LayerDataHandler::getRandomRHS(vector<int32_t>& result) const {
+void LayerDataHandler::getRandomRHS(vector<Base>& result) const {
   assert(size_ > 0);
   auto& ex = examples_[rand() % size_];
   int r = rand() % ex.RHSFeatures.size();
@@ -183,11 +183,11 @@ void LayerDataHandler::save(ostream& out) {
   for (auto example : examples_) {
     out << "lhs: ";
     for (auto t : example.LHSTokens) {
-      out << t << ' ';
+      out << t.first << ':' << t.second << ' ';
     }
     out << "\nrhs: ";
    for (auto feat : example.RHSFeatures) {
-      for (auto r : feat) { cout << r << ' '; }
+      for (auto r : feat) { cout << r.first << ':' << r.second << ' '; }
       out << "\t";
     }
     out << endl;
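One note on the hunk above: the added RHS line streams to cout while the rest of save writes to its out parameter; that inconsistency predates this commit and is carried over here. A hypothetical cleanup (not part of the commit), assuming the same Base pair type as above:

    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    using Base = std::pair<int32_t, float>;  // assumed (id, weight) pair

    // Routes RHS features to the same stream as the rest of save(),
    // printing id:weight pairs, with feature groups separated by tabs.
    void saveRHSFeatures(std::ostream& out,
                         const std::vector<std::vector<Base>>& rhsFeatures) {
      for (const auto& feat : rhsFeatures) {
        for (const auto& r : feat) { out << r.first << ':' << r.second << ' '; }
        out << '\t';
      }
      out << '\n';
    }

    int main() {
      saveRHSFeatures(std::cout, {{{1, 1.0f}, {2, 0.5f}}, {{3, 2.0f}}});
    }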
6 changes: 3 additions & 3 deletions src/doc_data.h
@@ -36,14 +36,14 @@ class LayerDataHandler : public InternDataHandler {
   void loadFromFile(const std::string& file,
                     std::shared_ptr<DataParser> parser) override;
 
-  void getRandomRHS(std::vector<int32_t>& results) const override;
+  void getRandomRHS(std::vector<Base>& results) const override;
 
   void save(std::ostream& out) override;
 
  private:
   void insert(
-      std::vector<int32_t>& rslt,
-      const std::vector<int32_t>& ex,
+      std::vector<Base>& rslt,
+      const std::vector<Base>& ex,
       float dropout = 0.0) const;
 
 };
20 changes: 15 additions & 5 deletions src/doc_parser.cpp
@@ -26,20 +26,30 @@ LayerDataParser::LayerDataParser(
 
 bool LayerDataParser::parse(
     string& s,
-    vector<int32_t>& feats,
+    vector<Base>& feats,
     const string& sep) {
 
   // split each part into tokens
   vector<string> tokens;
   boost::split(tokens, s, boost::is_any_of(string(sep)));
 
   for (auto token : tokens) {
+    string t = token;
+    float weight = 1.0;
+    if (args_->useWeight) {
+      std::size_t pos = token.find(":");
+      if (pos != std::string::npos) {
+        t = token.substr(0, pos);
+        weight = atof(token.substr(pos + 1).c_str());
+      }
+    }
+
     if (args_->normalizeText) {
-      normalize_text(token);
+      normalize_text(t);
     }
-    int32_t wid = dict_->getId(token);
+    int32_t wid = dict_->getId(t);
     if (wid != -1) {
-      feats.push_back(wid);
+      feats.push_back(make_pair(wid, weight));
     }
   }
 
@@ -64,7 +74,7 @@ bool LayerDataParser::parse(
     start_idx = 1;
   }
   for (int i = start_idx; i < parts.size(); i++) {
-    vector<int32_t> feats;
+    vector<Base> feats;
     if (parse(parts[i], feats)) {
       rslt.RHSFeatures.push_back(feats);
     }
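The hunk above is the core of the feature: with -useWeight set, a token of the form word:weight is split on the first ':' and the weight parsed with atof; otherwise the weight defaults to 1.0. A self-contained sketch of that logic with hypothetical names (no dictionary lookup or args object):

    #include <cstdlib>
    #include <iostream>
    #include <string>
    #include <utility>

    // Standalone version of the token handling added to LayerDataParser::parse:
    // split "word:weight" on the first ':' when weights are enabled, defaulting
    // the weight to 1.0 otherwise.
    std::pair<std::string, float> splitWeightedToken(const std::string& token,
                                                     bool useWeight) {
      std::string word = token;
      float weight = 1.0f;
      if (useWeight) {
        std::size_t pos = token.find(':');
        if (pos != std::string::npos) {
          word = token.substr(0, pos);
          weight = std::atof(token.substr(pos + 1).c_str());
        }
      }
      return {word, weight};
    }

    int main() {
      auto a = splitWeightedToken("apple:0.5", true);   // ("apple", 0.5)
      auto b = splitWeightedToken("apple", true);       // ("apple", 1.0)
      auto c = splitWeightedToken("apple:0.5", false);  // ("apple:0.5", 1.0)
      std::cout << a.first << ' ' << a.second << '\n'
                << b.first << ' ' << b.second << '\n'
                << c.first << ' ' << c.second << '\n';
    }

One design point visible in the diff: tokens without an explicit weight keep weight 1.0, so unweighted input files remain valid even with -useWeight 1.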
2 changes: 1 addition & 1 deletion src/doc_parser.h
@@ -31,7 +31,7 @@ class LayerDataParser : public DataParser {
 
   bool parse(
       std::string& line,
-      std::vector<int32_t>& rslt,
+      std::vector<Base>& rslt,
       const std::string& sep=" ");
 
   bool parse(
63 changes: 37 additions & 26 deletions src/model.cpp
@@ -84,19 +84,30 @@ Real norm2(Matrix<Real>::Row a) {
   return std::max(std::numeric_limits<Real>::epsilon(), retval);
 }
 
-Matrix<Real> EmbedModel::projectRHS(std::vector<int32_t> ws) {
+// consistent accessor methods for straight indices and index-weight pairs
+int32_t index(int32_t idx) { return idx; }
+int32_t index(std::pair<int32_t, Real> idxWeightPair) {
+  return idxWeightPair.first;
+}
+
+constexpr float weight(int32_t idx) { return 1.0; }
+float weight(std::pair<int32_t, Real> idxWeightPair) {
+  return idxWeightPair.second;
+}
+
+Matrix<Real> EmbedModel::projectRHS(const std::vector<Base>& ws) {
   Matrix<Real> retval;
   projectRHS(ws, retval);
   return retval;
 }
 
-Matrix<Real> EmbedModel::projectLHS(std::vector<int32_t> ws) {
+Matrix<Real> EmbedModel::projectLHS(const std::vector<Base>& ws) {
   Matrix<Real> retval;
   projectLHS(ws, retval);
   return retval;
 }
 
-void EmbedModel::projectLHS(std::vector<int32_t> ws, Matrix<Real>& retval) {
+void EmbedModel::projectLHS(const std::vector<Base>& ws, Matrix<Real>& retval) {
   LHSEmbeddings_->forward(ws, retval);
   if (ws.size()) {
     auto norm = (args_->similarity == "dot") ?
@@ -105,7 +116,7 @@ void EmbedModel::projectLHS(std::vector<int32_t> ws, Matrix<Real>& retval) {
   }
 }
 
-void EmbedModel::projectRHS(std::vector<int32_t> ws, Matrix<Real>& retval) {
+void EmbedModel::projectRHS(const std::vector<Base>& ws, Matrix<Real>& retval) {
   RHSEmbeddings_->forward(ws, retval);
   if (ws.size()) {
     auto norm = (args_->similarity == "dot") ?
@@ -172,10 +183,10 @@ Real EmbedModel::train(shared_ptr<InternDataHandler> data,
       continue;
     }
 
-    if (args_->debug) {
-      auto printVec = [&](const vector<int32_t>& vec) {
+    if (amMaster && args_->debug) {
+      auto printVec = [&](const vector<Base>& vec) {
        cout << "vec : ";
-        for (auto v : vec) {cout << v << ' ';}
+        for (auto v : vec) {cout << v.first << ':' << v.second << ' ';}
        cout << endl;
       };
 
@@ -304,8 +315,8 @@ void EmbedModel::normalize(Matrix<float>::Row row, double maxNorm) {
 }
 
 float EmbedModel::trainOne(shared_ptr<InternDataHandler> data,
-                           const vector<int32_t>& items,
-                           const vector<int32_t>& labels,
+                           const vector<Base>& items,
+                           const vector<Base>& labels,
                            size_t negSearchLimit,
                            Real rate0) {
   if (items.size() == 0) return 0.0; // nothing to learn.
@@ -344,14 +355,14 @@ float EmbedModel::trainOne(shared_ptr<InternDataHandler> data,
   // Select negative examples
   Real loss = 0.0;
   std::vector<Matrix<Real>> negs;
-  std::vector<std::vector<int32_t>> negLabelsBatch;
+  std::vector<std::vector<Base>> negLabelsBatch;
   Matrix<Real> negMean;
   negMean.matrix = zero_matrix<Real>(1, cols);
 
   for (int i = 0; i < negSearchLimit &&
                   negs.size() < args_->maxNegSamples; i++) {
 
-    std::vector<int32_t> negLabels;
+    std::vector<Base> negLabels;
     do {
       data->getRandomRHS(negLabels);
     } while (negLabels == labels);
@@ -407,8 +418,8 @@ float EmbedModel::trainOne(shared_ptr<InternDataHandler> data,
 }
 
 float EmbedModel::trainNLL(shared_ptr<InternDataHandler> data,
-                           const vector<int32_t>& items,
-                           const vector<int32_t>& labels,
+                           const vector<Base>& items,
+                           const vector<Base>& labels,
                            int32_t negSearchLimit,
                            Real rate0) {
   if (items.size() == 0) return 0.0; // nothing to learn.
@@ -426,13 +437,13 @@ float EmbedModel::trainNLL(shared_ptr<InternDataHandler> data,
   auto numClass = args_->negSearchLimit + 1;
   std::vector<Real> prob(numClass);
   std::vector<Matrix<Real>> negClassVec;
-  std::vector<std::vector<int32_t>> negLabelsBatch;
+  std::vector<std::vector<Base>> negLabelsBatch;
 
   prob[0] = dot(lhs, rhsP);
   Real max = prob[0];
 
   for (int i = 1; i < numClass; i++) {
-    std::vector<int32_t> negLabels;
+    std::vector<Base> negLabels;
     do {
       data->getRandomRHS(negLabels);
     } while (negLabels == labels);
@@ -491,9 +502,9 @@ float EmbedModel::trainNLL(shared_ptr<InternDataHandler> data,
 }
 
 void EmbedModel::backward(
-    const vector<int32_t>& items,
-    const vector<int32_t>& labels,
-    const vector<vector<int32_t>>& negLabels,
+    const vector<Base>& items,
+    const vector<Base>& labels,
+    const vector<vector<Base>>& negLabels,
     Matrix<Real>& gradW,
     Matrix<Real>& lhs,
     Real rate_lhs,
@@ -535,21 +546,21 @@ void EmbedModel::backward(
 
   // Update input items.
   for (auto w : items) {
-    auto row = LHSEmbeddings_->row(w);
-    update(row, gradW, rate_lhs, n1, LHSUpdates_, w);
+    auto row = LHSEmbeddings_->row(index(w));
+    update(row, gradW, rate_lhs * weight(w), n1, LHSUpdates_, index(w));
   }
 
   // Update positive example.
-  for (auto label : labels) {
-    auto row = RHSEmbeddings_->row(label);
-    update(row, lhs, rate_rhsP, n2, RHSUpdates_, label);
+  for (auto la : labels) {
+    auto row = RHSEmbeddings_->row(index(la));
+    update(row, lhs, rate_rhsP * weight(la), n2, RHSUpdates_, index(la));
   }
 
   // Update negative example.
   for (size_t i = 0; i < negLabels.size(); i++) {
-    for (auto label : negLabels[i]) {
-      auto row = RHSEmbeddings_->row(label);
-      update(row, lhs, rate_rhsN[i], n2, RHSUpdates_, label);
+    for (auto la : negLabels[i]) {
+      auto row = RHSEmbeddings_->row(index(la));
+      update(row, lhs, rate_rhsN[i] * weight(la), n2, RHSUpdates_, index(la));
     }
   }
 }
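The index/weight overloads added at the top of model.cpp are what let one update code path serve both plain int32_t indices (where weight() is a constexpr 1.0) and weighted Base pairs, scaling each gradient step by the token's weight. A small self-contained illustration of the pattern; the template and names below are hypothetical, not from the commit:

    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    using Real = float;  // assumed, matching the Real alias in the diff

    // Overloads mirroring the commit: a bare index carries an implicit
    // weight of 1.0; an (index, weight) pair carries its own weight.
    int32_t index(int32_t idx) { return idx; }
    int32_t index(std::pair<int32_t, Real> p) { return p.first; }
    constexpr Real weight(int32_t) { return 1.0; }
    Real weight(std::pair<int32_t, Real> p) { return p.second; }

    // Hypothetical update loop: the same body works for both token
    // representations, scaling the learning rate by each token's weight.
    template <typename Token>
    void applyUpdates(const std::vector<Token>& tokens, Real rate) {
      for (const auto& t : tokens) {
        std::cout << "update row " << index(t)
                  << " at rate " << rate * weight(t) << '\n';
      }
    }

    int main() {
      std::vector<int32_t> plain = {3, 7};
      std::vector<std::pair<int32_t, Real>> weighted = {{3, 0.5f}, {7, 2.0f}};
      applyUpdates(plain, 0.1f);     // weights default to 1.0
      applyUpdates(weighted, 0.1f);  // rates scaled by 0.5 and 2.0
    }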
22 changes: 11 additions & 11 deletions src/model.h
@@ -55,20 +55,20 @@ struct EmbedModel : public boost::noncopyable {
   }
 
   float trainOne(std::shared_ptr<InternDataHandler> data,
-                 const std::vector<int32_t>& items,
-                 const std::vector<int32_t>& labels,
+                 const std::vector<Base>& items,
+                 const std::vector<Base>& labels,
                  size_t maxNegSamples,
                  Real rate);
 
   float trainNLL(std::shared_ptr<InternDataHandler> data,
-                 const std::vector<int32_t>& items,
-                 const std::vector<int32_t>& labels,
+                 const std::vector<Base>& items,
+                 const std::vector<Base>& labels,
                  int32_t negSearchLimit,
                  Real rate);
 
-  void backward(const std::vector<int32_t>& items,
-                const std::vector<int32_t>& labels,
-                const std::vector<std::vector<int32_t>>& negLabels,
+  void backward(const std::vector<Base>& items,
+                const std::vector<Base>& labels,
+                const std::vector<std::vector<Base>>& negLabels,
                 Matrix<Real>& gradW,
                 Matrix<Real>& lhs,
                 Real rate_lhs,
@@ -91,11 +91,11 @@ struct EmbedModel : public boost::noncopyable {
     return kNN(RHSEmbeddings_, point, numSim);
   }
 
-  Matrix<Real> projectRHS(std::vector<int32_t> ws);
-  Matrix<Real> projectLHS(std::vector<int32_t> ws);
+  Matrix<Real> projectRHS(const std::vector<Base>& ws);
+  Matrix<Real> projectLHS(const std::vector<Base>& ws);
 
-  void projectLHS(std::vector<int32_t> ws, Matrix<Real>& retval);
-  void projectRHS(std::vector<int32_t> ws, Matrix<Real>& retval);
+  void projectLHS(const std::vector<Base>& ws, Matrix<Real>& retval);
+  void projectRHS(const std::vector<Base>& ws, Matrix<Real>& retval);
 
   void loadTsv(std::istream& in, const std::string sep = "\t ");
   void loadTsv(const char* fname, const std::string sep = "\t ");
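Since projectLHS/projectRHS now take const std::vector<Base>& instead of by-value std::vector<int32_t>, any caller still holding plain indices needs to wrap them. A hypothetical adapter (not from the commit), assuming the same Base pair type as above:

    #include <cstdint>
    #include <utility>
    #include <vector>

    using Real = float;
    using Base = std::pair<int32_t, Real>;  // assumed (id, weight) pair

    // Wraps bare dictionary ids with the default weight 1.0 so legacy
    // callers can satisfy the new vector<Base> signatures.
    std::vector<Base> withUnitWeights(const std::vector<int32_t>& ids) {
      std::vector<Base> out;
      out.reserve(ids.size());
      for (int32_t id : ids) out.emplace_back(id, Real(1.0));
      return out;
    }

    int main() {
      auto ws = withUnitWeights({1, 2, 3});
      // e.g. model.projectLHS(ws, retval) with the updated signature.
      return ws.size() == 3 ? 0 : 1;
    }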
