Skip to content

Commit

Permalink
dtree done
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhiqiang authored and liuzhiqiang committed Mar 15, 2016
1 parent 8de1b66 commit 4f6cbd8
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 20 deletions.
120 changes: 113 additions & 7 deletions gbdt/dtree.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,53 @@

#include "dtree.h"

/* Node of a binary decision tree: either a leaf carrying a weight, or an
 * inner node holding a split (attr, attr_val) and two children. */
struct _t_node {
struct _d_tree {
int n; /* number of instances in this node */
int leaf; /* 1 if this node is a leaf, 0 otherwise */
struct _t_node * child[2]; /* child nodes produced by the split */
struct _d_tree * child[2]; /* child nodes produced by the split */
int attr; /* split attribute, used when the node is not a leaf */
double attr_val; /* split value of that attribute */
double sg; /* sum of first-order gradients */
double sh; /* sum of second-order gradients */
double wei; /* value added to the model; leaves get -sg / sh in tree_grow */
};

static void tree_grow(DTree * t, DTrain * ds, unsigned long * bit_map, int len, double * g, double * h, int n){
/*
 * Add w to F[k] for every index k whose bit is set in bit_map.
 *
 * w       : value to add
 * bit_map : bitmap of selected indices; word i covers indices i*64 .. i*64+63,
 *           least significant bit first
 * len     : number of words in bit_map
 * F       : array updated in place
 * n       : length of F
 *
 * Set bits are visited in increasing index order, so the first index that
 * reaches n ends the whole scan.
 */
static void update_model(double w, unsigned long * bit_map, int len, double * F, int n){
    for (int word = 0; word < len; word++){
        unsigned long bits = bit_map[word];
        const int base = word << 6;
        /* walk the word from its lowest bit upwards until no set bit is left */
        for (int pos = 0; bits != 0UL; bits >>= 1, pos++){
            if ((bits & 1UL) == 0UL){
                continue;
            }
            const int ind = base + pos;
            if (ind >= n){
                return;
            }
            F[ind] += w;
        }
    }
}

static void tree_grow(DTree * t, DTD * ds, unsigned long * bit_map, int len, double * F, double * g, double * h, int n){
int offs = 0;
int row = -1;
int counter = 0;
Expand Down Expand Up @@ -70,6 +105,7 @@ static void tree_grow(DTree * t, DTrain * ds, unsigned long * bit_map, int len,
max_gain = gain;
b_attr = i;
attr_val = (val + last_value) / 2.0;

l_sg_b = l_sg;
l_sh_b = l_sh;
r_sg_b = r_sg;
Expand Down Expand Up @@ -127,19 +163,56 @@ static void tree_grow(DTree * t, DTrain * ds, unsigned long * bit_map, int len,
t->child[1] = rn;

// grow tree from ln and rn
tree_grow(ln, ds, l_bit, len, g, h, n);
tree_grow(rn, ds, r_bit, len, g, h, n);
tree_grow(ln, ds, l_bit, len, F, g, h, n);
tree_grow(rn, ds, r_bit, len, F, g, h, n);
}
else{
// t is a leaf node
t->leaf = 1;
t->wei = -1.0 * t->sg / t->sh;
// update model F in leaf node
// add leaf node's weight
update_model(t->wei, bit_map, len, F, n);
}
free(l_bit); l_bit = NULL;
free(r_bit); r_bit = NULL;
}

DTree * generate_dtree(DTrain * ds, double *F, double * g, double * h, int n){
/*
 * Route the rows selected by bit_map down the tree rooted at t and add the
 * weight of each leaf reached to the matching entries of F.
 *
 * ts      : data set; the entries of attribute a are
 *           ts->vals[ts->cl[a]] .. ts->vals[ts->cl[a] + ts->l[a] - 1]
 * t       : root of the (sub)tree to evaluate
 * bit_map : one bit per row, 64 rows per word; set bits are the rows that
 *           reached t
 * len     : number of words in bit_map
 * F       : per-row scores, updated in place
 * n       : length of F
 *
 * At an inner node, rows whose value for the split attribute is
 * >= attr_val go to child[0]; every other selected row (including rows with
 * no entry for that attribute) goes to child[1].
 * NOTE(review): confirm this orientation matches the one used by tree_grow.
 *
 * If the temporary bitmaps cannot be allocated the subtree is skipped and
 * F is left unchanged for the rows that reached it.
 */
static void scan_tree(DTD * ts, DTree * t, unsigned long * bit_map, int len, double * F, int n){
    if (t->leaf != 0){
        /* leaf: every row that arrived here receives the leaf weight */
        update_model(t->wei, bit_map, len, F, n);
        return;
    }

    /* bitmap words are treated as 8 bytes wide, as in the rest of this file */
    const size_t bytes = (size_t)len << 3;
    const int attr = t->attr;
    const double attr_val = t->attr_val;

    unsigned long * l_bit = (unsigned long *)malloc(bytes);
    unsigned long * r_bit = (unsigned long *)malloc(bytes);
    if (l_bit == NULL || r_bit == NULL){
        free(l_bit);
        free(r_bit);
        return;
    }

    /* start with every row on the right, then move matching rows left */
    memset(l_bit, 0, bytes);
    memcpy(r_bit, bit_map, bytes);

    for (int i = 0; i < ts->l[attr]; i++){
        const int id = i + ts->cl[attr];
        const int rowid = ts->vals[id].id;
        const double val = ts->vals[id].val;

        /* ignore row ids that fall outside the bitmap */
        if (rowid < 0 || (rowid >> 6) >= len){
            continue;
        }

        const unsigned long mask = 1UL << (rowid & 63);
        if ((bit_map[rowid >> 6] & mask) != 0UL && val >= attr_val){
            l_bit[rowid >> 6] |= mask;
            r_bit[rowid >> 6] &= ~mask;
        }
    }

    scan_tree(ts, t->child[0], l_bit, len, F, n);
    free(l_bit); l_bit = NULL;
    scan_tree(ts, t->child[1], r_bit, len, F, n);
    free(r_bit); r_bit = NULL;
}

DTree * generate_dtree(DTD * ds, double *F, double * g, double * h, int n){
int i ;
// root of tree
DTree * t = (DTree*)malloc(sizeof(DTree));
Expand All @@ -165,7 +238,7 @@ DTree * generate_dtree(DTrain * ds, double *F, double * g, double * h, int n){
memset(bit_map, -1, len << 3);

// tree grow from root t
tree_grow(t, ds, bit_map, len, g, h, n);
tree_grow(t, ds, bit_map, len, F, g, h, n);
free(bit_map); bit_map = NULL;

if (t->leaf == 1){
Expand All @@ -175,3 +248,36 @@ DTree * generate_dtree(DTrain * ds, double *F, double * g, double * h, int n){

return t;
}

/*
 * Release the tree rooted at t: both subtrees first, then the node itself.
 * A NULL t is accepted and ignored.
 */
void free_dtree(DTree * t){
    if (t == NULL){
        return;
    }
    for (int side = 0; side < 2; side++){
        if (t->child[side] != NULL){
            free_dtree(t->child[side]);
            t->child[side] = NULL;
        }
    }
    free(t);
}

/*
 * Score every row of ts with tree t.
 *
 * ts : data set to score; ts->row is the number of rows
 * t  : tree to evaluate
 *
 * Returns a newly allocated array of ts->row doubles, zero-initialised and
 * then filled in by scan_tree. The caller owns the array and must free() it.
 * Returns NULL when ts->row <= 0 or when an allocation fails.
 */
double * eval_tree(DTD * ts, DTree * t){
    const int n = ts->row;
    if (n <= 0){
        return NULL;
    }

    double * F = (double *)calloc((size_t)n, sizeof(double));
    if (F == NULL){
        return NULL;
    }

    /* one bit per row, 64 rows per word, rounded up */
    int len = n >> 6;
    if ((n & 63) != 0){
        len += 1;
    }

    /* bitmap words are treated as 8 bytes wide, as in the rest of this file */
    const size_t bytes = (size_t)len << 3;
    unsigned long * bit_map = (unsigned long *)malloc(bytes);
    if (bit_map == NULL){
        free(F);
        return NULL;
    }
    /* every row starts at the root, so all bits are set */
    memset(bit_map, -1, bytes);

    scan_tree(ts, t, bit_map, len, F, n);

    /* the bitmap is only scratch space for the traversal */
    free(bit_map);
    bit_map = NULL;

    return F;
}

18 changes: 15 additions & 3 deletions gbdt/dtree.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,27 @@

#include "tdata.h"


typedef struct _t_node DTree;
typedef struct _d_tree DTree;

/* -----------------------------------------
* brief : generate a tree from data ds
* ds : input data
* F : current model
* g : first-order gradient
* h : second-order gradient
* n : length of F, g, h
* ----------------------------------------- */
DTree * generate_dtree(DTrain * ds, double *F, double * g, double * h, int n);
DTree * generate_dtree(DTD * ds, double *F, double * g, double * h, int n);

/* -------------------------
* brief : free tree space
* ------------------------- */
void free_dtree(DTree * t);

/* ------------------------------------
* brief : predict ts with tree t
* return : predicted value (score) for each row of ts
* ------------------------------------ */
double * eval_tree(DTD * ts, DTree * t);

#endif //DTREE_H
12 changes: 2 additions & 10 deletions gbdt/tdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,9 @@ typedef struct {
int neg; /* -1|0 lables cnt */
int col; /* feature count of data */
int * l; /* row cnt of per feautre */
int * cl; /* cumulative row cnt .. */
DPair * vals; /* row id and feature value */
char (*id_map)[FKL]; /* feature id name mapping */
}DTrain;

typedef struct {
int row; /* row num of data */
int *l; /* fea cnt of per row */
DPair * vals; /* fea id and feature value */
double * y; /* labels of data */
int pos; /* +1 labels */
int neg; /* -1 lables */
}DTest;
}DTD;

#endif //TDATA_H

0 comments on commit 4f6cbd8

Please sign in to comment.