Rearrange computation so that one loop over extensions could be done
kpu committed Feb 18, 2016
1 parent 5db51fb commit fe31e8e
Showing 1 changed file with 28 additions and 23 deletions.
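The rearrangement the commit message describes amounts to fusing the two passes over extension words: sum_x_p_I and the per-extension correction terms are accumulated un-normalized inside a single loop, and every division by Z_context is deferred until the loop has produced it (the full_cross /= Z_context and hessian_missing_Z_context / Z_context lines in the hunk below). A minimal sketch of that pattern, in plain C++ with scalar and std::vector stand-ins rather than the Eigen types used in tune_derivatives.cc; all identifiers here are illustrative, not the actual kenlm API:

#include <cstddef>
#include <vector>

// Fuse the extension-word loops and defer every division by Z_context,
// which is only known once sum_x_p_I has been accumulated.
double FusedExtensionPass(const std::vector<double> &p_I,          // p_I(x) per extension word
                          const std::vector<double> &weighted_ext, // exp(w . ln p(x | context))
                          double Z_epsilon, double weighted_backoffs,
                          std::vector<double> &correction) {       // out: per-extension correction
  double sum_x_p_I = 0.0, ext_sum = 0.0;
  correction.assign(p_I.size(), 0.0);
  for (std::size_t x = 0; x < p_I.size(); ++x) {
    sum_x_p_I += p_I[x];          // previously accumulated in its own loop
    ext_sum += weighted_ext[x];
    // Scalar stand-in for the full_cross / hessian_missing_Z_context terms:
    // accumulate without dividing by Z_context yet.
    correction[x] = weighted_ext[x] - p_I[x] * Z_epsilon * weighted_backoffs;
  }
  // Z_context only becomes available after the single pass.
  const double Z_context = Z_epsilon * weighted_backoffs * (1.0 - sum_x_p_I) + ext_sum;
  for (std::size_t x = 0; x < correction.size(); ++x)
    correction[x] /= Z_context;   // deferred normalization
  return Z_context;
}

The point of the reordering is simply that Z_context depends on sum_x_p_I, so anything that needs Z_context is either moved past the loop or accumulated un-normalized inside it.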
51 changes: 28 additions & 23 deletions lm/interpolate/tune_derivatives.cc
@@ -27,6 +27,8 @@ Accum Derivatives(Instances &in, const Vector &weights, Vector &gradient, Matrix
Matrix convolve;
Vector full_cross;

Matrix hessian_missing_Z_context;

// TODO make configurable memory size.
// TODO make use of this.
util::stream::Chain chain(util::stream::ChainConfig(in.ReadExtensionsEntrySize(), 2, 64 << 20));
@@ -38,48 +40,51 @@ Accum Derivatives(Instances &in, const Vector &weights, Vector &gradient, Matrix

// Compute \sum_{x: model does not backoff to unigram} p_I(x)
Accum sum_x_p_I = 0.0;
for (std::vector<WordIndex>::const_iterator x = n->extension_words.begin(); x != n->extension_words.end(); ++x) {
sum_x_p_I += interp_uni(*x);
}

// This should be divided by Z_context then added to the Hessian.
hessian_missing_Z_context = Matrix::Zero(weights.rows(), weights.rows());
// TODO move into loop over context.
weighted_extensions = (n->ln_extensions * weights).array().exp();
Accum Z_context = Z_epsilon * weighted_backoffs * (1.0 - sum_x_p_I) + weighted_extensions.sum();
sum_ln_Z_context += log(Z_context); // not used for rest of loop
// These are the correct values but need to be divided by Z_context.
full_cross = n->ln_extensions.transpose() * weighted_extensions;

// Adjust the first term of the Hessian to account for extension
for (std::size_t x = 0; x < n->extension_words.size(); ++x) {
WordIndex universal_x = n->extension_words[x];
sum_x_p_I += interp_uni(universal_x);
full_cross.noalias() -= interp_uni(universal_x) * Z_epsilon * weighted_backoffs /* we'll divide by Z_context later to form B_I */ * in.LNUnigrams().row(universal_x);
hessian_missing_Z_context.noalias() +=
// Replacement terms.
weighted_extensions(x) * n->ln_extensions.row(x).transpose() * n->ln_extensions.row(x)
// Presumed unigrams. TODO: individual terms with backoffs pulled out? Maybe faster?
- interp_uni(universal_x) * Z_epsilon * weighted_backoffs * (in.LNUnigrams().row(universal_x).transpose() + in.LNBackoffs(n)) * (in.LNUnigrams().row(universal_x) + in.LNBackoffs(n).transpose());
}

Accum Z_context = Z_epsilon * weighted_backoffs * (1.0 - sum_x_p_I) + weighted_extensions.sum();
sum_ln_Z_context += log(Z_context);
Accum B_I = Z_epsilon / Z_context * weighted_backoffs;
sum_B_I += B_I; // not used for rest of loop
sum_B_I += B_I;

// This is the gradient term for this instance except for -log p_i(w_n | w_1^{n-1}) which was accounted for as part of neg_correct_sum_.
// full_cross(i) is \sum_{all x} p_I(x | context) log p_i(x | context)
full_cross =
// Prior terms excluded dividing by Z_context because it wasn't known at the time.
full_cross /= Z_context;
full_cross +=
// Uncorrected term
B_I * (in.LNBackoffs(n) + unigram_cross)
// Correction term: add correct values
+ n->ln_extensions.transpose() * weighted_extensions / Z_context
// Subtract values that should not have been charged.
- sum_x_p_I * B_I * in.LNBackoffs(n);
for (std::vector<WordIndex>::const_iterator x = n->extension_words.begin(); x != n->extension_words.end(); ++x) {
full_cross.noalias() -= interp_uni(*x) * B_I * in.LNUnigrams().row(*x);
}

gradient += full_cross;

convolve = unigram_cross * in.LNBackoffs(n).transpose();
// There's one missing term here, which is independent of context and done at the end.
hessian.noalias() +=
// First term of Hessian, assuming all models back off to unigram.
B_I * (convolve + convolve.transpose() + in.LNBackoffs(n) * in.LNBackoffs(n).transpose())
// Error in the first term, correcting from unigram to full probabilities.
+ hessian_missing_Z_context / Z_context
// Second term of Hessian, with correct full probabilities.
- full_cross * full_cross.transpose();

// Adjust the first term of the Hessian to account for extension
for (std::size_t x = 0; x < n->extension_words.size(); ++x) {
WordIndex universal_x = n->extension_words[x];
hessian.noalias() +=
// Replacement terms.
weighted_extensions(x) / Z_context * n->ln_extensions.row(x).transpose() * n->ln_extensions.row(x)
// Presumed unigrams. TODO: individual terms with backoffs pulled out? Maybe faster?
- interp_uni(universal_x) * B_I * (in.LNUnigrams().row(universal_x).transpose() + in.LNBackoffs(n)) * (in.LNUnigrams().row(universal_x) + in.LNBackoffs(n).transpose());
}
}

for (Matrix::Index x = 0; x < interp_uni.rows(); ++x) {

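For reference, the per-context quantities as the rearranged hunk and its comments define them, writing b_w for weighted_backoffs and b for in.LNBackoffs(n) and glossing over row/column orientation; reading ln_extensions as holding ln p_i(x | context) per model i is an assumption from the surrounding code, not shown in this hunk:

Z_context = Z_epsilon * b_w * (1 - \sum_{x \in ext} p_I(x)) + \sum_{x \in ext} exp(\sum_i w_i \ln p_i(x | context))

B_I = Z_epsilon * b_w / Z_context

full_cross(i) = \sum_{all x} p_I(x | context) \ln p_i(x | context)   (per the in-code comment; built un-normalized, then divided by Z_context)

hessian += B_I * (C + C^T + b b^T) + hessian_missing_Z_context / Z_context - full_cross full_cross^T,   where C = unigram_cross * b^T

plus one context-independent term added at the end, per the comment above the hessian update.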