Skip to content

Commit

Permalink
gbdt done
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhiqiang authored and liuzhiqiang committed Mar 16, 2016
1 parent 4f6cbd8 commit bee5f61
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 5 deletions.
5 changes: 1 addition & 4 deletions gbdt/dtree.c
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,8 @@ void free_dtree(DTree * t){
}
}

double * eval_tree(DTD * ts, DTree * t){
int n = ts->row;
double * eval_tree(DTD * ts, DTree * t, double F, int n){
int i, offs, len;
double * F = (double*)malloc(sizeof(double) * n);
memset(F, 0, sizeof(double) * n);

len = n >> 6;
if ((n & ((1UL << 6) - 1)) > 0){
Expand Down
2 changes: 1 addition & 1 deletion gbdt/dtree.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ void free_dtree(DTree * t);
* brief : predict ts with tree t
* return : predict value(score)
* ------------------------------------ */
double * eval_tree(DTD * ts, DTree * t);
double * eval_tree(DTD * ts, DTree * t, double * F, int n);

#endif //DTREE_H
133 changes: 133 additions & 0 deletions gbdt/gbdt.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,137 @@
* info : implementation for gbdt
* ======================================================== */

#include <stdlib.h>
#include <string.h>
#include "gbdt.h"

GBDT * gbdt_create(G g_fn, H h_fn, R r_fn, GBMP p){
GBDT * gbdt = (GBDT*)malloc(sizeof(GBDT));
if (!gbdt){
goto gb_failed;
}
memset(gbdt, 0, sizeof(GBDT));
gbdt->g_fn = g_fn;
gbdt->h_fn = h_fn;
gbdt->r_fn = r_fn;
gbdt->p = p;
gbdt->dts = (DTree **)malloc(sizeof(void *) * p.max_trees);
if (!gbdt->dts){
goto dts_failed;
}
memset(gbdt->dts, 0, sizeof(void *) * p.max_trees);
DTD *(*tds)[2] = load_data(p.train_input, p.test_input);
if (!tds){
goto ds_failed;
}
gbdt->train_ds = (*tds)[0];
gbdt->test_ds = (*tds)[1];
gbdt->f = (double*)malloc(sizeof(double) * gbdt->train_ds->row);
if (!gbdt->f){
goto train_y_failed;
}
memset(gbdt->f, 0, sizeof(double) * gbdt->train_ds->row);
gbdt->t = (double*)malloc(sizeof(double) * gbdt->test_ds->row);
if (!gbdt->t){
goto test_y_failed;
}
memset(gbdt->t, 0, sizeof(double) * gbdt->test_ds->row);

return gbdt;

test_y_failed:
free(gbdt->f);
gbdt->f = NULL;
train_y_failed:
free_data(gbdt->train_ds);
free_data(gbdt->test_ds);
gbdt->train_ds = NULL;
gbdt->test_ds = NULL;
ds_failed:
free(gbdt->dts);
gbdt->dts = NULL;
dts_failed:
free(gbdt);
gbdt = NULL;
gb_failed:
return NULL;
}

int gbdt_train(GBDT * gbdt){
int i, j, n = gbdt->train_ds->row, l = gbdt->test_ds->row;
double * f = (double*)malloc(sizeof(double) * n);
double * t = (double*)malloc(sizeof(double) * l);
double * g = (double*)malloc(sizeof(double) * n);
double * h = (double*)malloc(sizeof(double) * n);
memset(f, 0, sizeof(double) * n);
memset(t, 0, sizeof(double) * l);
memset(g, 0, sizeof(double) * n);
memset(h, 0, sizeof(double) * n);
gbdt->tree_size = 0;

for (i = 0; i < gbdt->p.max_trees; i++) {
gbdt->g_fn(gbdt->f, gbdt->train_ds->y, g, n);
gbdt->h_fn(gbdt->f, gbdt->train_ds->y, h, n);

DTree * tt = generate_dtree(gbdt->train_ds, f, g, h, n);
if (tt){
gbdt->dts[i] = tt;
gbdt->tree_size += 1;
eval_tree(gbdt->test_ds, tt, t, l);
for (j = 0; j < n; j++){
gbdt->f[j] += f[j] * gbdt->p.rate;
}
for (j = 0; j < l; j++){
gbdt->t[j] += t[j] * gbdt->p.rate;
}
memset(f, 0, sizeof(double) * n);
memset(t, 0, sizeof(double) * l);
gbdt->r_fn(gbdt);
}
else{
break;
}
}
free(f); f = NULL;
free(t); t = NULL;
free(g); g = NULL;
free(h); h = NULL;

return 0;
}

void gbdt_save (GBDT * gbdt, int n);

void gbdt_free (GBDT * gbdt){
int i;
if (gbdt){
if (gbdt->train_ds){
free_data(gbdt->train_ds);
gbdt->train_ds = NULL;
}
if (gbdt->test_ds){
free(gbdt->test_ds);
gbdt->test_ds = NULL;
}
if (gbdt->f){
free(gbdt->f);
gbdt->f = NULL;
}
if (gbdt->t){
free(gbdt->t);
gbdt->t = NULL;
}
if (gbdt->dts){
for (i = 0; i < gbdt->tree_size; i++){
if (gbdt->dts[i]){
free_dtree(gbdt->dts[i]);
gbdt->dts[i] = NULL;
}
}
free(gbdt->dts);
gbdt->dts = NULL;
}
free(gbdt);
}
}

29 changes: 29 additions & 0 deletions gbdt/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#define _GBDT_H

#include "dtree.h"
#include "tdata.h"


/* --------------------------------------
Expand All @@ -23,6 +24,34 @@
* -------------------------------------- */
typedef void(*G)(double * f, double * y, double * g, int n);
typedef void(*H)(double * f, double * y, double * h, int n);
typedef void(*R)(void * m);

typedef struct _gbdt_param {
double rate;
int max_depth;
int max_trees;
char * train_input;
char * test_input;
char * out_dir;
} GBMP;

typedef struct _gbdt {
DTD * train_ds;
DTD * test_ds;
double * f;
double * t;
int tree_size;
DTree ** dts;
GBMP p;
G g_fn;
H h_fn;
R r_fn;
} GBDT;

GBDT * gbdt_create(G g_fn, H h_fn, R f_fn, GBMP p);
int gbdt_train(GBDT * gbdt);
void gbdt_save (GBDT * gbdt, int n);
void gbdt_free (GBDT * gbdt);

#endif //GBDT_H

4 changes: 4 additions & 0 deletions gbdt/tdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,8 @@ typedef struct {
char (*id_map)[FKL]; /* feature id name mapping */
}DTD;

DTD *(*load_data(char * train_input, char * test_input))[2];

void free_data(DTD *ts);

#endif //TDATA_H

0 comments on commit bee5f61

Please sign in to comment.