Skip to content

Commit 160ed25

Browse files
committed
not finished yet
1 parent 6451f12 commit 160ed25

File tree

1 file changed

+22
-88
lines changed

1 file changed

+22
-88
lines changed

AdaBoost/template/decision_stump_adaboost.h

+22-88
Original file line numberDiff line numberDiff line change
@@ -26,28 +26,9 @@ class DecisionStumpAdaBoost : public AdaBoost {
2626
void init_weak_classifier();
2727
// learn the weak classifier
2828
void weak_classifier_learn(DecisionStump & stump);
29+
// calculate the sumation of the weights
30+
void calculate_weight_sum(double &weight_sum, double &weight_label_sum, double &positive_weight_sum, double &negative_weight_sum);
2931

30-
// compute the outputs of the weak classifier with precalculated weights
31-
bool calculate_classifier_outputs(const double weight_label_sum_right,
32-
double &output_right,
33-
double &output_left,
34-
const double weight_sum = 0.0,
35-
const double weight_label_sum = 0.0,
36-
const double weight_sum_right = 0.0,
37-
const double positive_weight_sum_right = 0.0,
38-
const double negative_weight_sum_right = 0.0,
39-
const double positive_weight_sum = 0.0,
40-
const double negative_weight_sum = 0.0
41-
);
42-
43-
// compute the error with precalculated weights, and the outputs of the weak classifier
44-
double compute_error(const double positive_weight_sum_right,
45-
const double negative_weight_sum_right,
46-
const double positive_weight_sum_left,
47-
const double negative_weight_sum_left,
48-
const double output_right,
49-
const double output_left
50-
) const;
5132

5233
private:
5334
std::vector<DecisionStump> _weak_classifier_vec;
@@ -69,76 +50,29 @@ void DecisionStumpAdaBoost::weak_classifier_learn(DecisionStump &stump) {
6950
//
7051
}
7152

72-
bool DecisionStumpAdaBoost::calculate_classifier_outputs(const double weight_label_sum_right,
73-
double &output_right,
74-
double &output_left,
75-
const double weight_sum,
76-
const double weight_label_sum,
77-
const double weight_sum_right,
78-
const double positive_weight_sum_right,
79-
const double negative_weight_sum_right,
80-
const double positive_weight_sum,
81-
const double negative_weight_sum
82-
) {
83-
// different boosting types, require different parameters and return different values
84-
switch (this->get_boosting_type()) {
85-
case DISCRETE_TYPE:
86-
if (weight_label_sum_right > 0) {
87-
output_right = 1.0;
88-
output_left = -1.0;
89-
} else {
90-
output_left = 1.0;
91-
output_right = -1.0;
92-
}
93-
break;
94-
case REAL_TYPE:
95-
output_right = log((positive_weight_sum_right + epsilon) / (negative_weight_sum_right + epsilon)) / 2.0;
96-
output_left = log((positive_weight_sum - positive_weight_sum_right + epsilon) / (negative_weight_sum - negative_weight_sum_right + epsilon)) / 2.0;
97-
break;
98-
case GENTLE_TYPE:
99-
output_right = weight_label_sum_right / weight_sum_right;
100-
output_left = (weight_label_sum - weight_label_sum_right) / (weight_sum - weight_sum_right);
101-
break;
102-
default:
103-
return false;
53+
54+
void DesisionStumpAdaboost::calculate_weight_sum(double &weight_sum, double &weight_label_sum, double &positive_weight_sum, double &negative_weight_sum) {
55+
weight_sum = 0.0;
56+
weight_label_sum = 0.0;
57+
positive_weight_sum = 0.0;
58+
negative_weight_sum = 0.0;
59+
60+
size_t total_sample = this->get_total_sample();
61+
for (size_t i = 0; i < total_sample; ++i) {
62+
double w = this->get_weight_by_index(i);
63+
weight_sum += w;
64+
Record record = this->get_sorted_record_by_index(i);
65+
if (POSITIVE == record.label) {
66+
weight_label_sum += w;
67+
positive_weight_sum += w
68+
} else {
69+
weight_label_sum -= w;
70+
negative_weight_sum += w;
71+
}
10472
}
105-
return true;
10673
}
10774

108-
double DecisionStumpAdaBoost::compute_error(const double positive_weight_sum_right,
109-
const double negative_weight_sum_right,
110-
const double positive_weight_sum_left,
111-
const double negative_weight_sum_left,
112-
const double output_right,
113-
const double output_left
114-
) const {
115-
// the error
116-
double error = 0.0;
117-
switch (this->get_boosting_type()) {
118-
case DISCRETE_TYPE:
119-
error = positive_weight_sum_right * (1.0 - output_right) / 2.0;
120-
error += positive_weight_sum_left * (1.0 - output_left) / 2.0;
121-
error += negative_weight_sum_right * (1.0 + output_right) / 2.0;
122-
error += negative_weight_sum_left * (1.0 + output_left) / 2.0;
123-
break;
124-
case REAL_TYPE:
125-
error = positive_weight_sum_right * exp(-output_right);
126-
error += positive_weight_sum_left * exp(-output_left);
127-
error += negative_weight_sum_right * exp(output_right);
128-
error += negative_weight_sum_left * exp(output_left);
129-
break;
130-
case GENTLE_TYPE:
131-
error = positive_weight_sum_right * (1.0 - output_right) * (1.0 - output_right);
132-
error += positive_weight_sum_left * (1.0 - output_left) * (1.0 - output_left);
133-
error += negative_weight_sum_right * (1.0 + output_right) * (1.0 + output_right);
134-
error += negative_weight_sum_left * (1.0 + output_left) * (1.0 + output_left);
135-
break;
136-
default:
137-
return error;
138-
}
139-
return error;
140-
}
141-
75+
14276

14377

14478
#endif

0 commit comments

Comments
 (0)