[YDF] Move loss options definition
Old: learner/gradient_boosted_trees/gradient_boosted_trees.proto
New: model/gradient_boosted_trees/gradient_boosted_trees.proto
PiperOrigin-RevId: 676396458
rstz authored and copybara-github committed Sep 19, 2024
1 parent 0d66aba commit 7dafafd
Showing 5 changed files with 73 additions and 63 deletions.
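Both proto files appear to declare the same proto package (the same `proto::` alias resolves the old and new type names in the C++ hunk below), so for C++ callers the move amounts to a rename of the generated nested types. A sketch of the before/after spelling, with the full namespace inferred from that hunk rather than stated by the commit:

// Before: the loss messages were nested in the training config message.
using OldXeNdcg = yggdrasil_decision_forests::model::gradient_boosted_trees::
    proto::GradientBoostedTreesTrainingConfig::XeNdcg;

// After: they are nested in the standalone LossConfiguration message.
using NewXeNdcg = yggdrasil_decision_forests::model::gradient_boosted_trees::
    proto::LossConfiguration::XeNdcg;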
CHANGELOG.md: 6 additions, 0 deletions
@@ -12,6 +12,12 @@ Changelog under `yggdrasil_decision_forests/port/python/CHANGELOG.md`.
   but can still be used.
 - Allow configuring the truncation of NDCG losses.
 
+### Misc
+
+- Loss options are now defined in
+  model/gradient_boosted_trees/gradient_boosted_trees.proto (previously
+  learner/gradient_boosted_trees/gradient_boosted_trees.proto).
+
 ## 1.10.0 - 2024-08-21
 
 ### Features

learner/gradient_boosted_trees/gradient_boosted_trees.proto

@@ -177,66 +177,9 @@ message GradientBoostedTreesTrainingConfig {
   optional int32 early_stopping_initial_iteration = 37 [default = 10];
 
   oneof loss_options {
-    LambdaMartNdcg lambda_mart_ndcg = 12;
-    XeNdcg xe_ndcg = 26;
-    BinaryFocalLossOptions binary_focal_loss_options = 36;
-  }
-
-  message LambdaMartNdcg {
-    // If false, the gradient is computed using NDCG i.e. normalized-DCG. If
-    // true, the gradient is computed using DCG.
-    optional bool gradient_use_non_normalized_dcg = 1 [default = false];
-
-    // Number of candidates considered when computing the NDCG loss.
-    //
-    // NDCG losses are usually truncated at a particular rank level (generally
-    // between 4 and 10), i.e. only the highly ranked documents are considered
-    // when computing the rank. A smaller value results in a model with
-    // increased emphasis on the first results of the ranking.
-    //
-    // Note that the NDCG truncation of the cross-entropy NDCG loss must be
-    // configured separately.
-    optional int32 ndcg_truncation = 2 [default = 5];
-  }
-
-  message XeNdcg {
-    enum Gamma {
-      // For the time being, defaults to UNIFORM.
-      AUTO = 0;
-      // Gammas are sampled from a uniform distribution on [0, 1].
-      UNIFORM = 1;
-      // Gammas are set to 1 across the board. This is more appropriate for
-      // click datasets with a large number of documents per query.
-      ONE = 2;
-    }
-    optional Gamma gamma = 1 [default = UNIFORM];
-
-    // Number of candidates considered when computing the NDCG loss.
-    //
-    // NDCG losses are usually truncated at a particular rank level (generally
-    // between 4 and 10), i.e. only the highly ranked documents are considered
-    // when computing the rank. A smaller value results in a model with
-    // increased emphasis on the first results of the ranking.
-    //
-    // Note that the NDCG truncation of the classic NDCG loss must be configured
-    // separately.
-    optional int32 ndcg_truncation = 2 [default = 5];
-  }
-
-  message BinaryFocalLossOptions {
-    // Exponent of the misprediction multiplier in focal loss.
-    // Corresponds to the gamma parameter in
-    // https://arxiv.org/pdf/1708.02002.pdf
-    optional float misprediction_exponent = 1 [default = 2.0];
-
-    // A hypertuning coefficient to multiply the loss and its gradient(s) in
-    // case of a positive sample.
-    // Loss and gradient on positive samples will be multiplied by
-    // positive_sample_coefficient, on negative samples will be multiplied
-    // by (1 - positive_sample_coefficient).
-    // Corresponds to the 'alpha' parameter in
-    // https://arxiv.org/pdf/1708.02002.pdf
-    optional float positive_sample_coefficient = 2 [default = 0.5];
+    LossConfiguration.LambdaMartNdcg lambda_mart_ndcg = 12;
+    LossConfiguration.XeNdcg xe_ndcg = 26;
+    LossConfiguration.BinaryFocalLossOptions binary_focal_loss_options = 36;
   }
 
   // L2 regularization on the tree predictions i.e. on the value of the leaf.
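A minimal sketch of how a caller might set these fields through the generated C++ API after this change; the include path and namespace are inferred from the hunks below and are assumptions, not something this commit shows directly:

#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.pb.h"

namespace proto =
    yggdrasil_decision_forests::model::gradient_boosted_trees::proto;

int main() {
  proto::GradientBoostedTreesTrainingConfig config;

  // Select the LambdaMART-NDCG loss through the oneof; the setter names are
  // unchanged, only the nested message type moved to LossConfiguration.
  // Truncate the NDCG computation at rank 10 instead of the default 5.
  config.mutable_lambda_mart_ndcg()->set_ndcg_truncation(10);
  return 0;
}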

@@ -152,6 +152,7 @@ cc_library_ydf(
"//yggdrasil_decision_forests/learner/decision_tree:training",
"//yggdrasil_decision_forests/learner/gradient_boosted_trees:gradient_boosted_trees_cc_proto",
"//yggdrasil_decision_forests/model:abstract_model_cc_proto",
"//yggdrasil_decision_forests/model/gradient_boosted_trees:gradient_boosted_trees_cc_proto",
"//yggdrasil_decision_forests/utils:compatibility",
"//yggdrasil_decision_forests/utils:concurrency",
"//yggdrasil_decision_forests/utils:random",

@@ -34,6 +34,7 @@
#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.h"
#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_utils.h"
#include "yggdrasil_decision_forests/model/abstract_model.pb.h"
#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.pb.h"
#include "yggdrasil_decision_forests/utils/compatibility.h"
#include "yggdrasil_decision_forests/utils/concurrency.h"
#include "yggdrasil_decision_forests/utils/random.h"
@@ -104,11 +105,11 @@ absl::Status CrossEntropyNDCGLoss::UpdateGradients(
   params.resize(group_size);
 
   switch (gbt_config_.xe_ndcg().gamma()) {
-    case proto::GradientBoostedTreesTrainingConfig::XeNdcg::ONE:
+    case proto::LossConfiguration::XeNdcg::ONE:
       std::fill(params.begin(), params.end(), 1.f);
       break;
-    case proto::GradientBoostedTreesTrainingConfig::XeNdcg::AUTO:
-    case proto::GradientBoostedTreesTrainingConfig::XeNdcg::UNIFORM:
+    case proto::LossConfiguration::XeNdcg::AUTO:
+    case proto::LossConfiguration::XeNdcg::UNIFORM:
       for (int item_idx = 0; item_idx < group_size; item_idx++) {
         params[item_idx] = distribution(*random);
       }
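The switch above is driven entirely by the config; a small sketch of opting into constant gammas, using the same generated types as in the hunk (the helper function itself is hypothetical):

#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.pb.h"

namespace proto =
    yggdrasil_decision_forests::model::gradient_boosted_trees::proto;

// Requests constant gammas (ONE): UpdateGradients then fills `params` with
// 1.f instead of sampling from U[0, 1], which the proto comment recommends
// for click datasets with many documents per query.
void UseConstantGammas(proto::GradientBoostedTreesTrainingConfig* config) {
  config->mutable_xe_ndcg()->set_gamma(proto::LossConfiguration::XeNdcg::ONE);
}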

model/gradient_boosted_trees/gradient_boosted_trees.proto

@@ -127,3 +127,62 @@ extend model.proto.SerializedModel {
   optional GradientBoostedTreesSerializedModel
       gradient_boosted_trees_serialized_model = 1001;
 }
+
+message LossConfiguration {
+  message LambdaMartNdcg {
+    // If false, the gradient is computed using NDCG i.e. normalized-DCG. If
+    // true, the gradient is computed using DCG.
+    optional bool gradient_use_non_normalized_dcg = 1 [default = false];
+
+    // Number of candidates considered when computing the NDCG loss.
+    //
+    // NDCG losses are usually truncated at a particular rank level (generally
+    // between 4 and 10), i.e. only the highly ranked documents are considered
+    // when computing the rank. A smaller value results in a model with
+    // increased emphasis on the first results of the ranking.
+    //
+    // Note that the NDCG truncation of the cross-entropy NDCG loss must be
+    // configured separately.
+    optional int32 ndcg_truncation = 2 [default = 5];
+  }
+
+  message XeNdcg {
+    enum Gamma {
+      // For the time being, defaults to UNIFORM.
+      AUTO = 0;
+      // Gammas are sampled from a uniform distribution on [0, 1].
+      UNIFORM = 1;
+      // Gammas are set to 1 across the board. This is more appropriate for
+      // click datasets with a large number of documents per query.
+      ONE = 2;
+    }
+    optional Gamma gamma = 1 [default = UNIFORM];
+
+    // Number of candidates considered when computing the NDCG loss.
+    //
+    // NDCG losses are usually truncated at a particular rank level (generally
+    // between 4 and 10), i.e. only the highly ranked documents are considered
+    // when computing the rank. A smaller value results in a model with
+    // increased emphasis on the first results of the ranking.
+    //
+    // Note that the NDCG truncation of the classic NDCG loss must be configured
+    // separately.
+    optional int32 ndcg_truncation = 2 [default = 5];
+  }
+
+  message BinaryFocalLossOptions {
+    // Exponent of the misprediction multiplier in focal loss.
+    // Corresponds to the gamma parameter in
+    // https://arxiv.org/pdf/1708.02002.pdf
+    optional float misprediction_exponent = 1 [default = 2.0];
+
+    // A hypertuning coefficient to multiply the loss and its gradient(s) in
+    // case of a positive sample.
+    // Loss and gradient on positive samples will be multiplied by
+    // positive_sample_coefficient, on negative samples will be multiplied
+    // by (1 - positive_sample_coefficient).
+    // Corresponds to the 'alpha' parameter in
+    // https://arxiv.org/pdf/1708.02002.pdf
+    optional float positive_sample_coefficient = 2 [default = 0.5];
+  }
+}
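For reference, the two BinaryFocalLossOptions fields map onto the focal loss of https://arxiv.org/pdf/1708.02002.pdf as the field comments state, with \gamma taken from misprediction_exponent and \alpha from positive_sample_coefficient; in the paper's notation:

\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log(p_t),
\qquad
\alpha_t =
\begin{cases}
\alpha & \text{for positive samples} \\
1 - \alpha & \text{for negative samples.}
\end{cases}

With the defaults (\gamma = 2, \alpha = 0.5), positive and negative samples are weighted equally and only the misprediction multiplier (1 - p_t)^{\gamma} distinguishes the loss from weighted cross-entropy.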
