Skip to content

Commit

Permalink
ADD: AUC Metric
Browse files Browse the repository at this point in the history
  • Loading branch information
xswang committed Nov 18, 2017
1 parent 4c9a8f5 commit 917620b
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/loss/metric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ REGISTER_METRIC("f1", F1Metric);
REGISTER_METRIC("mae", MAEMetric);
REGISTER_METRIC("mape", MAPEMetric);
REGISTER_METRIC("rmsd", RMSDMetric);
// Associates the name "auc" with AUCMetric (declared in src/loss/metric.h).
// NOTE(review): REGISTER_METRIC is defined elsewhere; presumably it makes the
// class creatable by this string name -- confirm against the macro.
REGISTER_METRIC("auc", AUCMetric);

} // namespace xLearn
105 changes: 105 additions & 0 deletions src/loss/metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,111 @@ class F1Metric : public Metric {
DISALLOW_COPY_AND_ASSIGN(F1Metric);
};

class AUCMetric : public Metric {
public:
struct Info {
Info() {
click_vec_.resize(MAX_BUCKET_SIZE, 0);
noclick_vec_.resize(MAX_BUCKET_SIZE, 0);
}
std::vector<int32_t> click_vec_;
std::vector<int32_t> noclick_vec_;
};

public:
AUCMetric()
: auc_(0.0) {
init();
}
~AUCMetric() { }

void init() {
glo_click_number_.resize(MAX_BUCKET_SIZE, 0);
glo_noclick_number_.resize(MAX_BUCKET_SIZE, 0);
}

static void auc_accum_thread(const std::vector<real_t>* Y,
const std::vector<real_t>* pred,
Info* info,
size_t start_idx,
size_t end_idx) {
CHECK_GE(end_idx, start_idx);
for (size_t i = start_idx; i < end_idx; ++i) {
real_t r_label = (*Y)[i] > 0 ? 1 : -1;
int32_t bkt_id = int32_t((*pred)[i] * MAX_BUCKET_SIZE);
if (r_label > 0) {
info->click_vec_[bkt_id] += 1;
} else {
info->noclick_vec_[bkt_id] += 1;
}
} // end for
} // end auc_accum_thread

void Accumulate(const std::vector<real_t>& Y,
const std::vector<real_t>& pred) {
CHECK_EQ(Y.size(), pred.size());
// multi-thread
Info single_info;
std::vector<Info> info(threadNumber_, single_info);
for (int i = 0; i < threadNumber_; ++i) {
size_t start_idx = getStart(pred.size(), threadNumber_, i);
size_t end_idx = getEnd(pred.size(), threadNumber_, i);
pool_->enqueue(std::bind(auc_accum_thread,
&Y,
&pred,
&info[i],
start_idx,
end_idx));
}
pool_->Sync(threadNumber_);
for (size_t i = 0; i < info.size(); ++i) {
for (int32_t j = 0; j < MAX_BUCKET_SIZE; ++j) {
glo_click_number_[j] += info[i].click_vec_[j];
glo_noclick_number_[j] += info[i].noclick_vec_[j];
} // end for
} // end for
auc_ = CalcAUC(glo_click_number_, glo_noclick_number_);
}

double CalcAUC(std::vector<int32_t> click_vec,
std::vector<int32_t> noclick_vec) {
CHECK_EQ(click_vec.size(), noclick_vec.size());
int32_t click_sum = 0;
int32_t noclick_sum= 0;
int32_t pre_click_sum = 0.0;
int32_t clicksum_dot_noclicksum = 0;
double auc = 0.0;
double auc_res = 0.0;
for (int32_t i = 0; i < MAX_BUCKET_SIZE; ++i) {
pre_click_sum = click_sum;
click_sum += glo_click_number_[i];
noclick_sum += glo_noclick_number_[i];
auc += (pre_click_sum + click_sum) * glo_noclick_number_[i] * 1.0 / 2;
}
clicksum_dot_noclicksum = click_sum * noclick_sum;
auc_res = auc / (clicksum_dot_noclicksum);
return 1.0 - auc_res;
}

inline void Reset() {
}

inline real_t GetMetric() {
return auc_;
}

inline std::string metric_type() {
return "AUC";
}

private:
double auc_;
const static int32_t MAX_BUCKET_SIZE = 2;
std::vector<int32_t> glo_noclick_number_;
std::vector<int32_t> glo_click_number_;
private:
DISALLOW_COPY_AND_ASSIGN(AUCMetric);
};

/*********************************************************
* For regression *
Expand Down

0 comments on commit 917620b

Please sign in to comment.