@@ -26,28 +26,9 @@ class DecisionStumpAdaBoost : public AdaBoost {
26
26
void init_weak_classifier ();
27
27
// learn the weak classifier
28
28
void weak_classifier_learn (DecisionStump & stump);
29
+ // calculate the sumation of the weights
30
+ void calculate_weight_sum (double &weight_sum, double &weight_label_sum, double &positive_weight_sum, double &negative_weight_sum);
29
31
30
- // compute the outputs of the weak classifier with precalculated weights
31
- bool calculate_classifier_outputs (const double weight_label_sum_right,
32
- double &output_right,
33
- double &output_left,
34
- const double weight_sum = 0.0 ,
35
- const double weight_label_sum = 0.0 ,
36
- const double weight_sum_right = 0.0 ,
37
- const double positive_weight_sum_right = 0.0 ,
38
- const double negative_weight_sum_right = 0.0 ,
39
- const double positive_weight_sum = 0.0 ,
40
- const double negative_weight_sum = 0.0
41
- );
42
-
43
- // compute the error with precalculated weights, and the outputs of the weak classifier
44
- double compute_error (const double positive_weight_sum_right,
45
- const double negative_weight_sum_right,
46
- const double positive_weight_sum_left,
47
- const double negative_weight_sum_left,
48
- const double output_right,
49
- const double output_left
50
- ) const ;
51
32
52
33
private:
53
34
std::vector<DecisionStump> _weak_classifier_vec;
@@ -69,76 +50,29 @@ void DecisionStumpAdaBoost::weak_classifier_learn(DecisionStump &stump) {
69
50
//
70
51
}
71
52
72
- bool DecisionStumpAdaBoost::calculate_classifier_outputs (const double weight_label_sum_right,
73
- double &output_right,
74
- double &output_left,
75
- const double weight_sum,
76
- const double weight_label_sum,
77
- const double weight_sum_right,
78
- const double positive_weight_sum_right,
79
- const double negative_weight_sum_right,
80
- const double positive_weight_sum,
81
- const double negative_weight_sum
82
- ) {
83
- // different boosting types, require different parameters and return different values
84
- switch (this ->get_boosting_type ()) {
85
- case DISCRETE_TYPE:
86
- if (weight_label_sum_right > 0 ) {
87
- output_right = 1.0 ;
88
- output_left = -1.0 ;
89
- } else {
90
- output_left = 1.0 ;
91
- output_right = -1.0 ;
92
- }
93
- break ;
94
- case REAL_TYPE:
95
- output_right = log ((positive_weight_sum_right + epsilon) / (negative_weight_sum_right + epsilon)) / 2.0 ;
96
- output_left = log ((positive_weight_sum - positive_weight_sum_right + epsilon) / (negative_weight_sum - negative_weight_sum_right + epsilon)) / 2.0 ;
97
- break ;
98
- case GENTLE_TYPE:
99
- output_right = weight_label_sum_right / weight_sum_right;
100
- output_left = (weight_label_sum - weight_label_sum_right) / (weight_sum - weight_sum_right);
101
- break ;
102
- default :
103
- return false ;
53
+
54
+ void DesisionStumpAdaboost::calculate_weight_sum (double &weight_sum, double &weight_label_sum, double &positive_weight_sum, double &negative_weight_sum) {
55
+ weight_sum = 0.0 ;
56
+ weight_label_sum = 0.0 ;
57
+ positive_weight_sum = 0.0 ;
58
+ negative_weight_sum = 0.0 ;
59
+
60
+ size_t total_sample = this ->get_total_sample ();
61
+ for (size_t i = 0 ; i < total_sample; ++i) {
62
+ double w = this ->get_weight_by_index (i);
63
+ weight_sum += w;
64
+ Record record = this ->get_sorted_record_by_index (i);
65
+ if (POSITIVE == record.label ) {
66
+ weight_label_sum += w;
67
+ positive_weight_sum += w
68
+ } else {
69
+ weight_label_sum -= w;
70
+ negative_weight_sum += w;
71
+ }
104
72
}
105
- return true ;
106
73
}
107
74
108
- double DecisionStumpAdaBoost::compute_error (const double positive_weight_sum_right,
109
- const double negative_weight_sum_right,
110
- const double positive_weight_sum_left,
111
- const double negative_weight_sum_left,
112
- const double output_right,
113
- const double output_left
114
- ) const {
115
- // the error
116
- double error = 0.0 ;
117
- switch (this ->get_boosting_type ()) {
118
- case DISCRETE_TYPE:
119
- error = positive_weight_sum_right * (1.0 - output_right) / 2.0 ;
120
- error += positive_weight_sum_left * (1.0 - output_left) / 2.0 ;
121
- error += negative_weight_sum_right * (1.0 + output_right) / 2.0 ;
122
- error += negative_weight_sum_left * (1.0 + output_left) / 2.0 ;
123
- break ;
124
- case REAL_TYPE:
125
- error = positive_weight_sum_right * exp (-output_right);
126
- error += positive_weight_sum_left * exp (-output_left);
127
- error += negative_weight_sum_right * exp (output_right);
128
- error += negative_weight_sum_left * exp (output_left);
129
- break ;
130
- case GENTLE_TYPE:
131
- error = positive_weight_sum_right * (1.0 - output_right) * (1.0 - output_right);
132
- error += positive_weight_sum_left * (1.0 - output_left) * (1.0 - output_left);
133
- error += negative_weight_sum_right * (1.0 + output_right) * (1.0 + output_right);
134
- error += negative_weight_sum_left * (1.0 + output_left) * (1.0 + output_left);
135
- break ;
136
- default :
137
- return error;
138
- }
139
- return error;
140
- }
141
-
75
+
142
76
143
77
144
78
#endif
0 commit comments